Skip to main content

stakpak_shared/models/
llm.rs

//! LLM Provider and Model Configuration
//!
//! This module provides the configuration types for LLM providers and models.
//!
//! # Provider Configuration
//!
//! Providers are configured in a `providers` HashMap where the key becomes the
//! model prefix for routing requests to the correct provider.
//!
//! ## Built-in Providers
//!
//! - `openai` - OpenAI API
//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
//! - `gemini` - Google Gemini API
//!
//! For recognized models of the built-in providers, you can use the model name
//! directly without a prefix:
//! - `claude-sonnet-4-5` → auto-detected as Anthropic
//! - `gpt-5` → auto-detected as OpenAI
//! - `gemini-2.5-pro` → auto-detected as Gemini
//!
//! Auto-detection only covers the model families this module knows about. An
//! unrecognized name without a prefix (e.g. `gpt-4`) is treated as a custom
//! model routed to a provider named `custom`.
//!
//! ## Custom Providers
//!
//! Any OpenAI-compatible API can be configured using `type = "custom"`.
//! The provider key becomes the model prefix.
//!
//! # Model Routing
//!
//! Models can be specified with or without a provider prefix:
//!
//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
//!   sends `anthropic/claude-opus` to the API
//!
//! # Example Configuration
//!
//! ```toml
//! [profiles.default]
//! provider = "local"
//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
//! eco_model = "offline/llama3"       # custom provider
//!
//! [profiles.default.providers.anthropic]
//! type = "anthropic"
//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
//!
//! [profiles.default.providers.offline]
//! type = "custom"
//! api_endpoint = "http://localhost:11434/v1"
//! ```
52
53use crate::models::{
54    integrations::{anthropic::AnthropicModel, gemini::GeminiModel, openai::OpenAIModel},
55    model_pricing::{ContextAware, ModelContextInfo},
56};
57use serde::{Deserialize, Serialize};
58use std::collections::HashMap;
59use std::fmt::Display;
60
// =============================================================================
// Provider Configuration
// =============================================================================
64
/// Unified provider configuration enum.
///
/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
/// where the key is the provider name and becomes the model prefix for routing.
///
/// # Serialization
///
/// Serde uses an internally tagged representation: the `type` field carries the
/// lowercased variant name (`openai`, `anthropic`, `gemini`, `custom`,
/// `stakpak`). Optional fields that are `None` are omitted when serializing.
///
/// # Provider Key = Model Prefix
///
/// The key used in the HashMap becomes the prefix used in model names:
/// - Config key: `providers.offline`
/// - Model usage: `offline/llama3`
/// - Routing: finds `offline` provider, sends `llama3` to API
///
/// # Example TOML
/// ```toml
/// [profiles.myprofile.providers.openai]
/// type = "openai"
/// api_key = "sk-..."
///
/// [profiles.myprofile.providers.anthropic]
/// type = "anthropic"
/// api_key = "sk-ant-..."
/// access_token = "oauth-token"
///
/// [profiles.myprofile.providers.offline]
/// type = "custom"
/// api_endpoint = "http://localhost:11434/v1"
/// ```
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI provider configuration
    OpenAI {
        /// Optional API key (NOTE(review): may also be resolved from auth.toml /
        /// environment elsewhere — confirm with the loader).
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional API endpoint override.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Anthropic provider configuration
    Anthropic {
        /// Optional API key.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional API endpoint override.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// OAuth access token (for Claude subscription)
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    /// Google Gemini provider configuration
    Gemini {
        /// Optional API key.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional API endpoint override.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
    ///
    /// The provider key in the config becomes the model prefix.
    /// For example, if configured as `providers.offline`, use models as:
    /// - `offline/llama3` - passes `llama3` to the API
    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.offline]
    /// type = "custom"
    /// api_endpoint = "http://localhost:11434/v1"
    ///
    /// # Then use models as:
    /// smart_model = "offline/llama3"
    /// eco_model = "offline/phi3"
    /// ```
    Custom {
        /// Optional API key (many local servers need none).
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (required for custom providers)
        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
        api_endpoint: String,
    },
    /// Stakpak provider configuration
    ///
    /// Routes inference through Stakpak's unified API, which provides:
    /// - Access to multiple LLM providers via a single endpoint
    /// - Usage tracking and billing
    /// - Session management and checkpoints
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.stakpak]
    /// type = "stakpak"
    /// api_key = "your-stakpak-api-key"
    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
    ///
    /// # Then use models as:
    /// smart_model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
    /// ```
    Stakpak {
        /// Stakpak API key (required)
        api_key: String,
        /// API endpoint URL (default: https://apiv2.stakpak.dev)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
}
168
169impl ProviderConfig {
170    /// Get the provider type name
171    pub fn provider_type(&self) -> &'static str {
172        match self {
173            ProviderConfig::OpenAI { .. } => "openai",
174            ProviderConfig::Anthropic { .. } => "anthropic",
175            ProviderConfig::Gemini { .. } => "gemini",
176            ProviderConfig::Custom { .. } => "custom",
177            ProviderConfig::Stakpak { .. } => "stakpak",
178        }
179    }
180
181    /// Get the API key if set
182    pub fn api_key(&self) -> Option<&str> {
183        match self {
184            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
185            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
186            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
187            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
188            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
189        }
190    }
191
192    /// Get the API endpoint if set
193    pub fn api_endpoint(&self) -> Option<&str> {
194        match self {
195            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
196            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
197            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
198            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
199            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
200        }
201    }
202
203    /// Get the access token (Anthropic only)
204    pub fn access_token(&self) -> Option<&str> {
205        match self {
206            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
207            _ => None,
208        }
209    }
210
211    /// Create an OpenAI provider config
212    pub fn openai(api_key: Option<String>) -> Self {
213        ProviderConfig::OpenAI {
214            api_key,
215            api_endpoint: None,
216        }
217    }
218
219    /// Create an Anthropic provider config
220    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
221        ProviderConfig::Anthropic {
222            api_key,
223            api_endpoint: None,
224            access_token,
225        }
226    }
227
228    /// Create a Gemini provider config
229    pub fn gemini(api_key: Option<String>) -> Self {
230        ProviderConfig::Gemini {
231            api_key,
232            api_endpoint: None,
233        }
234    }
235
236    /// Create a custom provider config
237    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
238        ProviderConfig::Custom {
239            api_key,
240            api_endpoint,
241        }
242    }
243
244    /// Create a Stakpak provider config
245    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
246        ProviderConfig::Stakpak {
247            api_key,
248            api_endpoint,
249        }
250    }
251}
252
/// A model selection: either a known model of a built-in provider or a
/// provider/model pair for a custom provider.
///
/// Parsed from model strings via `From<String>`; only `Serialize` is derived.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum LLMModel {
    /// A known Anthropic model.
    Anthropic(AnthropicModel),
    /// A known Google Gemini model.
    Gemini(GeminiModel),
    /// A known OpenAI model.
    OpenAI(OpenAIModel),
    /// Custom provider with explicit provider name and model.
    ///
    /// Used for custom OpenAI-compatible providers like LiteLLM, Ollama, etc.
    /// The provider name matches the key in the `providers` HashMap config.
    /// Unrecognized model names given without a prefix also end up here, with
    /// provider `"custom"`.
    ///
    /// # Examples
    /// - `litellm/claude-opus` → `provider: "litellm"`, `model: "claude-opus"`
    /// - `litellm/anthropic/claude-opus` → `provider: "litellm"`, `model: "anthropic/claude-opus"`
    /// - `ollama/llama3` → `provider: "ollama"`, `model: "llama3"`
    Custom {
        /// Provider name matching the key in providers config (e.g., "litellm", "ollama")
        provider: String,
        /// Model name/path to pass to the provider API (can include nested prefixes)
        model: String,
        /// Optional display name for the model (shown in UI instead of provider/model)
        name: Option<String>,
    },
}
276
277impl ContextAware for LLMModel {
278    fn context_info(&self) -> ModelContextInfo {
279        match self {
280            LLMModel::Anthropic(model) => model.context_info(),
281            LLMModel::Gemini(model) => model.context_info(),
282            LLMModel::OpenAI(model) => model.context_info(),
283            LLMModel::Custom { .. } => ModelContextInfo::default(),
284        }
285    }
286
287    fn model_name(&self) -> String {
288        match self {
289            LLMModel::Anthropic(model) => model.model_name(),
290            LLMModel::Gemini(model) => model.model_name(),
291            LLMModel::OpenAI(model) => model.model_name(),
292            LLMModel::Custom {
293                provider,
294                model,
295                name,
296            } => name
297                .clone()
298                .unwrap_or_else(|| format!("{}/{}", provider, model)),
299        }
300    }
301}
302
303/// Aggregated provider configuration for LLM operations
304///
305/// This struct holds all configured providers, keyed by provider name.
306#[derive(Debug, Clone, Default)]
307pub struct LLMProviderConfig {
308    /// All provider configurations (key = provider name)
309    pub providers: HashMap<String, ProviderConfig>,
310}
311
312impl LLMProviderConfig {
313    /// Create a new empty provider config
314    pub fn new() -> Self {
315        Self {
316            providers: HashMap::new(),
317        }
318    }
319
320    /// Add a provider configuration
321    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
322        self.providers.insert(name.into(), config);
323    }
324
325    /// Get a provider configuration by name
326    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
327        self.providers.get(name)
328    }
329
330    /// Check if any providers are configured
331    pub fn is_empty(&self) -> bool {
332        self.providers.is_empty()
333    }
334}
335
336impl From<String> for LLMModel {
337    /// Parse a model string into an LLMModel.
338    ///
339    /// # Format
340    /// - `provider/model` - Explicit provider prefix
341    /// - `provider/nested/model` - Provider with nested model path (e.g., for LiteLLM)
342    /// - `model-name` - Auto-detect provider from model name
343    ///
344    /// # Examples
345    /// - `"litellm/anthropic/claude-opus"` → Custom { provider: "litellm", model: "anthropic/claude-opus" }
346    /// - `"anthropic/claude-opus-4-5"` → Anthropic(Claude45Opus) (built-in provider)
347    /// - `"claude-opus-4-5"` → Anthropic(Claude45Opus) (auto-detected)
348    /// - `"ollama/llama3"` → Custom { provider: "ollama", model: "llama3" }
349    fn from(value: String) -> Self {
350        // Check for explicit provider/model format (e.g., "litellm/anthropic/claude-opus")
351        // split_once takes only the first segment as provider, rest is the model path
352        if let Some((provider, model)) = value.split_once('/') {
353            // Check if it's a known built-in provider with explicit prefix
354            match provider {
355                "anthropic" => return Self::from_model_name(model),
356                "openai" => return Self::from_model_name(model),
357                "google" | "gemini" => return Self::from_model_name(model),
358                // Unknown provider = custom provider (model can contain additional slashes)
359                _ => {
360                    // Extract display name from the last segment (e.g., "anthropic/claude-opus" -> "claude-opus")
361                    let display_name = model.rsplit('/').next().unwrap_or(model).to_string();
362                    return LLMModel::Custom {
363                        provider: provider.to_string(),
364                        model: model.to_string(), // Preserves nested paths like "anthropic/claude-opus"
365                        name: Some(display_name),
366                    };
367                }
368            }
369        }
370
371        // Fall back to auto-detection by model name prefix
372        Self::from_model_name(&value)
373    }
374}
375
376impl LLMModel {
377    /// Parse model name without provider prefix
378    fn from_model_name(model: &str) -> Self {
379        if model.starts_with("claude-haiku-4-5") {
380            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
381        } else if model.starts_with("claude-sonnet-4-5") {
382            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
383        } else if model.starts_with("claude-opus-4-5") {
384            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
385        } else if model == "gemini-2.5-flash-lite" {
386            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
387        } else if model.starts_with("gemini-2.5-flash") {
388            LLMModel::Gemini(GeminiModel::Gemini25Flash)
389        } else if model.starts_with("gemini-2.5-pro") {
390            LLMModel::Gemini(GeminiModel::Gemini25Pro)
391        } else if model.starts_with("gemini-3-pro-preview") {
392            LLMModel::Gemini(GeminiModel::Gemini3Pro)
393        } else if model.starts_with("gemini-3-flash-preview") {
394            LLMModel::Gemini(GeminiModel::Gemini3Flash)
395        } else if model.starts_with("gpt-5-mini") {
396            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
397        } else if model.starts_with("gpt-5") {
398            LLMModel::OpenAI(OpenAIModel::GPT5)
399        } else {
400            // Unknown model without provider prefix - treat as custom with "custom" provider
401            LLMModel::Custom {
402                provider: "custom".to_string(),
403                model: model.to_string(),
404                name: Some(model.to_string()), // Use model name as display name
405            }
406        }
407    }
408
409    /// Get the provider name for this model
410    pub fn provider_name(&self) -> &str {
411        match self {
412            LLMModel::Anthropic(_) => "anthropic",
413            LLMModel::Gemini(_) => "google",
414            LLMModel::OpenAI(_) => "openai",
415            LLMModel::Custom { provider, .. } => provider,
416        }
417    }
418
419    /// Get just the model name without provider prefix
420    pub fn model_id(&self) -> String {
421        match self {
422            LLMModel::Anthropic(m) => m.to_string(),
423            LLMModel::Gemini(m) => m.to_string(),
424            LLMModel::OpenAI(m) => m.to_string(),
425            LLMModel::Custom { model, .. } => model.clone(),
426        }
427    }
428
429    /// Set a display name for a custom model
430    pub fn with_name(self, name: impl Into<String>) -> Self {
431        match self {
432            LLMModel::Custom {
433                provider, model, ..
434            } => LLMModel::Custom {
435                provider,
436                model,
437                name: Some(name.into()),
438            },
439            other => other, // Built-in models don't support custom names
440        }
441    }
442
443    /// Get the display name if set (for custom models only)
444    pub fn display_name(&self) -> Option<&str> {
445        match self {
446            LLMModel::Custom { name, .. } => name.as_deref(),
447            _ => None,
448        }
449    }
450}
451
452impl Display for LLMModel {
453    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454        match self {
455            LLMModel::Anthropic(model) => write!(f, "{}", model),
456            LLMModel::Gemini(model) => write!(f, "{}", model),
457            LLMModel::OpenAI(model) => write!(f, "{}", model),
458            LLMModel::Custom {
459                provider,
460                model,
461                name,
462            } => {
463                if let Some(name) = name {
464                    write!(f, "{}", name)
465                } else {
466                    write!(f, "{}/{}", provider, model)
467                }
468            }
469        }
470    }
471}
472
/// Provider-specific options for LLM requests.
///
/// Each field holds options for one provider family; fields left as `None`
/// are omitted when serializing.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google/Gemini-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}
488
/// Anthropic-specific options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended thinking configuration; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}
496
/// Thinking/reasoning options.
///
/// Note: only `LLMThinkingOptions::new` enforces the 1024-token minimum. The
/// field is public and the type is `Deserialize`, so values built directly or
/// deserialized are not clamped.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Budget tokens for thinking (must be >= 1024)
    pub budget_tokens: u32,
}
503
504impl LLMThinkingOptions {
505    pub fn new(budget_tokens: u32) -> Self {
506        Self {
507            budget_tokens: budget_tokens.max(1024),
508        }
509    }
510}
511
/// OpenAI-specific options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high").
    /// Free-form string; the value is not validated here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
519
/// Google/Gemini-specific options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Thinking budget in tokens; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
527
/// A complete (non-streaming) LLM request.
///
/// `tools` has no `skip_serializing_if`, so `None` serializes as `null`;
/// `provider_options` and `headers` are omitted when `None`.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    /// Target model.
    pub model: LLMModel,
    /// Conversation messages.
    pub messages: Vec<LLMMessage>,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Tools offered to the model, if any.
    pub tools: Option<Vec<LLMTool>>,
    /// Per-provider request options.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}
540
/// A streaming LLM request.
///
/// Same fields as `LLMInput` plus `stream_channel_tx`, the sender that
/// generation deltas are delivered on.
#[derive(Debug)]
pub struct LLMStreamInput {
    /// Target model.
    pub model: LLMModel,
    /// Conversation messages.
    pub messages: Vec<LLMMessage>,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Channel on which `GenerationDelta` values are sent.
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    /// Tools offered to the model, if any.
    pub tools: Option<Vec<LLMTool>>,
    /// Per-provider request options.
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    pub headers: Option<std::collections::HashMap<String, String>>,
}
552
553impl From<&LLMStreamInput> for LLMInput {
554    fn from(value: &LLMStreamInput) -> Self {
555        LLMInput {
556            model: value.model.clone(),
557            messages: value.messages.clone(),
558            max_tokens: value.max_tokens,
559            tools: value.tools.clone(),
560            provider_options: value.provider_options.clone(),
561            headers: value.headers.clone(),
562        }
563    }
564}
565
/// A single conversation message with a free-form role string and content.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    /// Role name as a plain string (not validated here).
    pub role: String,
    /// Message content: plain text or a list of typed parts.
    pub content: LLMMessageContent,
}
571
/// A plain-text message restricted to the `user` / `assistant` roles.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    /// Message role (the explicit rename matches the field name, so the key is `role`).
    #[serde(rename = "role")]
    pub role: SimpleLLMRole,
    /// Message text.
    pub content: String,
}
578
/// Role of a `SimpleLLMMessage`; serialized in lowercase (`user`, `assistant`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    /// The human/user side of the conversation.
    User,
    /// The model's side of the conversation.
    Assistant,
}
585
586impl std::fmt::Display for SimpleLLMRole {
587    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588        match self {
589            SimpleLLMRole::User => write!(f, "user"),
590            SimpleLLMRole::Assistant => write!(f, "assistant"),
591        }
592    }
593}
594
/// Message content: either a plain string or a list of typed parts.
///
/// Untagged for serde, so it (de)serializes as a bare JSON string or a bare
/// array of `LLMMessageTypedContent` objects.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    /// Plain text content.
    String(String),
    /// Structured content parts (text, tool calls/results, images).
    List(Vec<LLMMessageTypedContent>),
}
601
602#[allow(clippy::to_string_trait_impl)]
603impl ToString for LLMMessageContent {
604    fn to_string(&self) -> String {
605        match self {
606            LLMMessageContent::String(s) => s.clone(),
607            LLMMessageContent::List(l) => l
608                .iter()
609                .map(|c| match c {
610                    LLMMessageTypedContent::Text { text } => text.clone(),
611                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
612                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
613                    LLMMessageTypedContent::Image { .. } => String::new(),
614                })
615                .collect::<Vec<_>>()
616                .join("\n"),
617        }
618    }
619}
620
621impl From<String> for LLMMessageContent {
622    fn from(value: String) -> Self {
623        LLMMessageContent::String(value)
624    }
625}
626
627impl Default for LLMMessageContent {
628    fn default() -> Self {
629        LLMMessageContent::String(String::new())
630    }
631}
632
/// One typed part of a structured message.
///
/// Serde-tagged by a `type` field: `text`, `tool_use`, `tool_result`, `image`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    /// Plain text part.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolCall {
        /// Tool call identifier (referenced by `ToolResult::tool_use_id`).
        id: String,
        /// Name of the tool to invoke.
        name: String,
        /// Tool arguments as JSON. Serialized as `args`; `input` is also
        /// accepted when deserializing.
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    /// The result of a previously requested tool call.
    #[serde(rename = "tool_result")]
    ToolResult {
        /// Identifier of the tool call this result answers.
        tool_use_id: String,
        /// Result payload as text.
        content: String,
    },
    /// An inline image part.
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
653
/// Source description of an inline image part.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    /// Source kind, serialized as `type` (presumably e.g. "base64" — confirm with producers).
    #[serde(rename = "type")]
    pub r#type: String,
    /// MIME type of the image data.
    pub media_type: String,
    /// Image payload; its encoding is indicated by `type`.
    pub data: String,
}
661
662impl Default for LLMMessageTypedContent {
663    fn default() -> Self {
664        LLMMessageTypedContent::Text {
665            text: String::new(),
666        }
667    }
668}
669
/// One completion choice in a non-streaming response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    /// Why generation stopped, if reported.
    pub finish_reason: Option<String>,
    /// Position of this choice in the response.
    pub index: u32,
    /// The generated message.
    pub message: LLMMessage,
}
676
/// A complete (non-streaming) completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    /// Model identifier reported by the provider.
    pub model: String,
    /// Response object type string.
    pub object: String,
    /// Generated choices.
    pub choices: Vec<LLMChoice>,
    /// Creation timestamp (NOTE(review): presumably Unix seconds — confirm).
    pub created: u64,
    /// Token usage, if reported; serialized as `null` when `None`.
    pub usage: Option<LLMTokenUsage>,
    /// Response identifier.
    pub id: String,
}
686
/// Incremental content carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    /// Newly generated text, omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
692
/// One choice within a streaming chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    /// Why generation stopped, if reported in this chunk.
    pub finish_reason: Option<String>,
    /// Position of this choice in the response.
    pub index: u32,
    /// Full message, if the chunk carries one.
    pub message: Option<LLMMessage>,
    /// Incremental content for this chunk.
    pub delta: LLMStreamDelta,
}
700
/// One chunk of a streaming completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    /// Model identifier reported by the provider.
    pub model: String,
    /// Response object type string.
    pub object: String,
    /// Choices contained in this chunk.
    pub choices: Vec<LLMStreamChoice>,
    /// Creation timestamp (NOTE(review): presumably Unix seconds — confirm).
    pub created: u64,
    /// Token usage, if reported; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    /// Response identifier.
    pub id: String,
    /// Citation URLs/strings, if any; serialized as `null` when `None`
    /// (unlike `usage`, no skip attribute is set).
    pub citations: Option<Vec<String>>,
}
712
/// A tool definition offered to the model.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    /// Tool name the model uses to call it.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Schema of the tool's input as a JSON value (presumably JSON Schema — confirm).
    pub input_schema: serde_json::Value,
}
719
/// Token usage reported for a request.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced by the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported (not recomputed here).
    pub total_tokens: u32,

    /// Optional per-category breakdown; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
729
/// Token categories tracked in `PromptTokensDetails`; serialized in snake_case
/// (e.g. `cache_read_input_tokens`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    /// Regular input tokens.
    InputTokens,
    /// Output tokens.
    OutputTokens,
    /// Input tokens served from a prompt cache.
    CacheReadInputTokens,
    /// Input tokens written to a prompt cache.
    CacheWriteInputTokens,
}
738
/// Per-category token counts. Each field is optional and omitted from
/// serialized output when `None`; arithmetic helpers treat `None` as `0`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    /// Regular input tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    /// Output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    /// Input tokens read from a prompt cache.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    /// Input tokens written to a prompt cache.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
750
751impl PromptTokensDetails {
752    /// Returns an iterator over the token types and their values
753    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
754        [
755            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
756            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
757            (
758                TokenType::CacheReadInputTokens,
759                self.cache_read_input_tokens.unwrap_or(0),
760            ),
761            (
762                TokenType::CacheWriteInputTokens,
763                self.cache_write_input_tokens.unwrap_or(0),
764            ),
765        ]
766        .into_iter()
767    }
768}
769
770impl std::ops::Add for PromptTokensDetails {
771    type Output = Self;
772
773    fn add(self, rhs: Self) -> Self::Output {
774        Self {
775            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
776            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
777            cache_read_input_tokens: Some(
778                self.cache_read_input_tokens.unwrap_or(0)
779                    + rhs.cache_read_input_tokens.unwrap_or(0),
780            ),
781            cache_write_input_tokens: Some(
782                self.cache_write_input_tokens.unwrap_or(0)
783                    + rhs.cache_write_input_tokens.unwrap_or(0),
784            ),
785        }
786    }
787}
788
789impl std::ops::AddAssign for PromptTokensDetails {
790    fn add_assign(&mut self, rhs: Self) {
791        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
792        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
793        self.cache_read_input_tokens = Some(
794            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
795        );
796        self.cache_write_input_tokens = Some(
797            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
798        );
799    }
800}
801
/// One incremental event emitted while streaming a generation.
///
/// Serde-tagged by a `type` field holding the variant name as written
/// (`Content`, `Thinking`, `ToolUse`, `Usage`, `Metadata`); no rename is applied.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    /// A chunk of generated text.
    Content { content: String },
    /// A chunk of model reasoning/thinking text.
    Thinking { thinking: String },
    /// A (possibly partial) tool-call update.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token usage information.
    Usage { usage: LLMTokenUsage },
    /// Arbitrary provider metadata as JSON.
    Metadata { metadata: serde_json::Value },
}
811
/// A streamed tool-call update.
///
/// Fields are optional. NOTE(review): presumably each chunk carries only the
/// parts known so far, with `input` being a fragment of the argument text to
/// be concatenated per `index` — confirm against the stream producers.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    /// Tool call identifier, if present in this chunk.
    pub id: Option<String>,
    /// Tool name, if present in this chunk.
    pub name: Option<String>,
    /// Tool input text, if present in this chunk.
    pub input: Option<String>,
    /// Index identifying which tool call this update belongs to.
    pub index: usize,
}
819
820#[cfg(test)]
821mod tests {
822    use super::*;
823
824    #[test]
825    fn test_llm_model_from_known_anthropic_model() {
826        let model = LLMModel::from("claude-opus-4-5-20251101".to_string());
827        assert!(matches!(
828            model,
829            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
830        ));
831    }
832
833    #[test]
834    fn test_llm_model_from_known_openai_model() {
835        let model = LLMModel::from("gpt-5".to_string());
836        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
837    }
838
839    #[test]
840    fn test_llm_model_from_known_gemini_model() {
841        let model = LLMModel::from("gemini-2.5-flash".to_string());
842        assert!(matches!(
843            model,
844            LLMModel::Gemini(GeminiModel::Gemini25Flash)
845        ));
846    }
847
848    #[test]
849    fn test_llm_model_from_custom_provider_with_slash() {
850        let model = LLMModel::from("litellm/claude-opus-4-5".to_string());
851        match model {
852            LLMModel::Custom {
853                provider,
854                model,
855                name,
856            } => {
857                assert_eq!(provider, "litellm");
858                assert_eq!(model, "claude-opus-4-5");
859                // Display name is automatically extracted from last segment
860                assert_eq!(name, Some("claude-opus-4-5".to_string()));
861            }
862            _ => panic!("Expected Custom model"),
863        }
864    }
865
866    #[test]
867    fn test_llm_model_from_ollama_provider() {
868        let model = LLMModel::from("ollama/llama3".to_string());
869        match model {
870            LLMModel::Custom {
871                provider,
872                model,
873                name,
874            } => {
875                assert_eq!(provider, "ollama");
876                assert_eq!(model, "llama3");
877                // Display name is automatically extracted from last segment
878                assert_eq!(name, Some("llama3".to_string()));
879            }
880            _ => panic!("Expected Custom model"),
881        }
882    }
883
884    #[test]
885    fn test_llm_model_from_nested_provider() {
886        // Test nested path like stakpak/anthropic/claude-sonnet-4-5
887        let model = LLMModel::from("stakpak/anthropic/claude-sonnet-4-5".to_string());
888        match model {
889            LLMModel::Custom {
890                provider,
891                model,
892                name,
893            } => {
894                assert_eq!(provider, "stakpak");
895                assert_eq!(model, "anthropic/claude-sonnet-4-5");
896                // Display name is the last segment only
897                assert_eq!(name, Some("claude-sonnet-4-5".to_string()));
898            }
899            _ => panic!("Expected Custom model"),
900        }
901    }
902
903    #[test]
904    fn test_llm_model_explicit_anthropic_prefix() {
905        // Explicit anthropic/ prefix should still parse to Anthropic variant
906        let model = LLMModel::from("anthropic/claude-opus-4-5".to_string());
907        assert!(matches!(
908            model,
909            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
910        ));
911    }
912
913    #[test]
914    fn test_llm_model_explicit_openai_prefix() {
915        let model = LLMModel::from("openai/gpt-5".to_string());
916        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
917    }
918
919    #[test]
920    fn test_llm_model_explicit_google_prefix() {
921        let model = LLMModel::from("google/gemini-2.5-flash".to_string());
922        assert!(matches!(
923            model,
924            LLMModel::Gemini(GeminiModel::Gemini25Flash)
925        ));
926    }
927
928    #[test]
929    fn test_llm_model_explicit_gemini_prefix() {
930        // gemini/ alias should also work
931        let model = LLMModel::from("gemini/gemini-2.5-flash".to_string());
932        assert!(matches!(
933            model,
934            LLMModel::Gemini(GeminiModel::Gemini25Flash)
935        ));
936    }
937
938    #[test]
939    fn test_llm_model_unknown_model_becomes_custom() {
940        let model = LLMModel::from("some-random-model".to_string());
941        match model {
942            LLMModel::Custom {
943                provider,
944                model,
945                name,
946            } => {
947                assert_eq!(provider, "custom");
948                assert_eq!(model, "some-random-model");
949                // Display name is the model name itself
950                assert_eq!(name, Some("some-random-model".to_string()));
951            }
952            _ => panic!("Expected Custom model"),
953        }
954    }
955
956    #[test]
957    fn test_llm_model_display_anthropic() {
958        let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet);
959        let s = model.to_string();
960        assert!(s.contains("claude"));
961    }
962
963    #[test]
964    fn test_llm_model_display_custom() {
965        let model = LLMModel::Custom {
966            provider: "litellm".to_string(),
967            model: "claude-opus".to_string(),
968            name: None,
969        };
970        assert_eq!(model.to_string(), "litellm/claude-opus");
971    }
972
973    #[test]
974    fn test_llm_model_display_custom_with_name() {
975        let model = LLMModel::Custom {
976            provider: "litellm".to_string(),
977            model: "claude-opus".to_string(),
978            name: Some("My Custom Model".to_string()),
979        };
980        assert_eq!(model.to_string(), "My Custom Model");
981    }
982
983    #[test]
984    fn test_llm_model_with_name() {
985        let model = LLMModel::from("ollama/llama3".to_string()).with_name("Local Llama");
986        assert_eq!(model.to_string(), "Local Llama");
987        assert_eq!(model.display_name(), Some("Local Llama"));
988        // model_id should still return the original model
989        assert_eq!(model.model_id(), "llama3");
990    }
991
992    #[test]
993    fn test_llm_model_provider_name() {
994        assert_eq!(
995            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet).provider_name(),
996            "anthropic"
997        );
998        assert_eq!(
999            LLMModel::OpenAI(OpenAIModel::GPT5).provider_name(),
1000            "openai"
1001        );
1002        assert_eq!(
1003            LLMModel::Gemini(GeminiModel::Gemini25Flash).provider_name(),
1004            "google"
1005        );
1006        assert_eq!(
1007            LLMModel::Custom {
1008                provider: "litellm".to_string(),
1009                model: "test".to_string(),
1010                name: None,
1011            }
1012            .provider_name(),
1013            "litellm"
1014        );
1015    }
1016
1017    #[test]
1018    fn test_llm_model_model_id() {
1019        let model = LLMModel::Custom {
1020            provider: "litellm".to_string(),
1021            model: "claude-opus-4-5".to_string(),
1022            name: None,
1023        };
1024        assert_eq!(model.model_id(), "claude-opus-4-5");
1025    }
1026
1027    // =========================================================================
1028    // ProviderConfig Tests
1029    // =========================================================================
1030
1031    #[test]
1032    fn test_provider_config_openai_serialization() {
1033        let config = ProviderConfig::OpenAI {
1034            api_key: Some("sk-test".to_string()),
1035            api_endpoint: None,
1036        };
1037        let json = serde_json::to_string(&config).unwrap();
1038        assert!(json.contains("\"type\":\"openai\""));
1039        assert!(json.contains("\"api_key\":\"sk-test\""));
1040        assert!(!json.contains("api_endpoint")); // Should be skipped when None
1041    }
1042
1043    #[test]
1044    fn test_provider_config_openai_with_endpoint() {
1045        let config = ProviderConfig::OpenAI {
1046            api_key: Some("sk-test".to_string()),
1047            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
1048        };
1049        let json = serde_json::to_string(&config).unwrap();
1050        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
1051    }
1052
1053    #[test]
1054    fn test_provider_config_anthropic_serialization() {
1055        let config = ProviderConfig::Anthropic {
1056            api_key: Some("sk-ant-test".to_string()),
1057            api_endpoint: None,
1058            access_token: Some("oauth-token".to_string()),
1059        };
1060        let json = serde_json::to_string(&config).unwrap();
1061        assert!(json.contains("\"type\":\"anthropic\""));
1062        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
1063        assert!(json.contains("\"access_token\":\"oauth-token\""));
1064    }
1065
1066    #[test]
1067    fn test_provider_config_gemini_serialization() {
1068        let config = ProviderConfig::Gemini {
1069            api_key: Some("gemini-key".to_string()),
1070            api_endpoint: None,
1071        };
1072        let json = serde_json::to_string(&config).unwrap();
1073        assert!(json.contains("\"type\":\"gemini\""));
1074        assert!(json.contains("\"api_key\":\"gemini-key\""));
1075    }
1076
1077    #[test]
1078    fn test_provider_config_custom_serialization() {
1079        let config = ProviderConfig::Custom {
1080            api_key: Some("sk-custom".to_string()),
1081            api_endpoint: "http://localhost:4000".to_string(),
1082        };
1083        let json = serde_json::to_string(&config).unwrap();
1084        assert!(json.contains("\"type\":\"custom\""));
1085        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
1086        assert!(json.contains("\"api_key\":\"sk-custom\""));
1087    }
1088
1089    #[test]
1090    fn test_provider_config_custom_without_key() {
1091        let config = ProviderConfig::Custom {
1092            api_key: None,
1093            api_endpoint: "http://localhost:11434/v1".to_string(),
1094        };
1095        let json = serde_json::to_string(&config).unwrap();
1096        assert!(json.contains("\"type\":\"custom\""));
1097        assert!(json.contains("\"api_endpoint\""));
1098        assert!(!json.contains("api_key")); // Should be skipped when None
1099    }
1100
1101    #[test]
1102    fn test_provider_config_deserialization_openai() {
1103        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
1104        let config: ProviderConfig = serde_json::from_str(json).unwrap();
1105        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
1106        assert_eq!(config.api_key(), Some("sk-test"));
1107    }
1108
1109    #[test]
1110    fn test_provider_config_deserialization_anthropic() {
1111        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
1112        let config: ProviderConfig = serde_json::from_str(json).unwrap();
1113        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
1114        assert_eq!(config.api_key(), Some("sk-ant"));
1115        assert_eq!(config.access_token(), Some("oauth"));
1116    }
1117
1118    #[test]
1119    fn test_provider_config_deserialization_gemini() {
1120        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
1121        let config: ProviderConfig = serde_json::from_str(json).unwrap();
1122        assert!(matches!(config, ProviderConfig::Gemini { .. }));
1123        assert_eq!(config.api_key(), Some("gemini-key"));
1124    }
1125
1126    #[test]
1127    fn test_provider_config_deserialization_custom() {
1128        let json =
1129            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
1130        let config: ProviderConfig = serde_json::from_str(json).unwrap();
1131        assert!(matches!(config, ProviderConfig::Custom { .. }));
1132        assert_eq!(config.api_key(), Some("sk-custom"));
1133        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
1134    }
1135
1136    #[test]
1137    fn test_provider_config_helper_methods() {
1138        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
1139        assert_eq!(openai.provider_type(), "openai");
1140        assert_eq!(openai.api_key(), Some("sk-openai"));
1141
1142        let anthropic =
1143            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
1144        assert_eq!(anthropic.provider_type(), "anthropic");
1145        assert_eq!(anthropic.access_token(), Some("oauth"));
1146
1147        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
1148        assert_eq!(gemini.provider_type(), "gemini");
1149
1150        let custom = ProviderConfig::custom(
1151            "http://localhost:4000".to_string(),
1152            Some("sk-custom".to_string()),
1153        );
1154        assert_eq!(custom.provider_type(), "custom");
1155        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
1156    }
1157
1158    #[test]
1159    fn test_llm_provider_config_new() {
1160        let config = LLMProviderConfig::new();
1161        assert!(config.is_empty());
1162    }
1163
1164    #[test]
1165    fn test_llm_provider_config_add_and_get() {
1166        let mut config = LLMProviderConfig::new();
1167        config.add_provider(
1168            "openai",
1169            ProviderConfig::openai(Some("sk-test".to_string())),
1170        );
1171        config.add_provider(
1172            "anthropic",
1173            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
1174        );
1175
1176        assert!(!config.is_empty());
1177        assert!(config.get_provider("openai").is_some());
1178        assert!(config.get_provider("anthropic").is_some());
1179        assert!(config.get_provider("unknown").is_none());
1180    }
1181
1182    #[test]
1183    fn test_provider_config_toml_parsing() {
1184        // Test parsing a HashMap of providers from TOML-like JSON
1185        let json = r#"{
1186            "openai": {"type": "openai", "api_key": "sk-openai"},
1187            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
1188            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
1189        }"#;
1190
1191        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
1192        assert_eq!(providers.len(), 3);
1193
1194        assert!(matches!(
1195            providers.get("openai"),
1196            Some(ProviderConfig::OpenAI { .. })
1197        ));
1198        assert!(matches!(
1199            providers.get("anthropic"),
1200            Some(ProviderConfig::Anthropic { .. })
1201        ));
1202        assert!(matches!(
1203            providers.get("litellm"),
1204            Some(ProviderConfig::Custom { .. })
1205        ));
1206    }
1207}