Skip to main content

stakpak_shared/models/
llm.rs

//! LLM Provider and Model Configuration
//!
//! This module provides the configuration types for LLM providers and models.
//!
//! # Provider Configuration
//!
//! Providers are configured in a `providers` HashMap where the key becomes the
//! model prefix for routing requests to the correct provider.
//!
//! ## Built-in Providers
//!
//! - `openai` - OpenAI API
//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
//! - `gemini` - Google Gemini API
//!
//! For built-in providers, you can use the model name directly without a prefix:
//! - `claude-sonnet-4-5` → auto-detected as Anthropic
//! - `gpt-4` → auto-detected as OpenAI
//! - `gemini-2.5-pro` → auto-detected as Gemini
//!
//! ## Stakpak Provider
//!
//! `type = "stakpak"` routes inference through Stakpak's unified API. It requires
//! an `api_key`, and its models are referenced with the `stakpak/` prefix, e.g.
//! `stakpak/anthropic/claude-sonnet-4-5-20250929`.
//!
//! ## Custom Providers
//!
//! Any OpenAI-compatible API (Ollama, vLLM, etc.) can be configured using
//! `type = "custom"`; an `api_endpoint` is required. The provider key becomes
//! the model prefix.
//!
//! # Model Routing
//!
//! Models can be specified with or without a provider prefix:
//!
//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
//!   sends `anthropic/claude-opus` to the API
//!
//! Only the first path segment is the provider prefix; the rest of the model
//! name is sent to the provider unchanged.
//!
//! # Example Configuration
//!
//! ```toml
//! [profiles.default]
//! provider = "local"
//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
//! eco_model = "offline/llama3"       # custom provider
//!
//! [profiles.default.providers.anthropic]
//! type = "anthropic"
//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
//!
//! [profiles.default.providers.offline]
//! type = "custom"
//! api_endpoint = "http://localhost:11434/v1"
//! ```
52
53use serde::{Deserialize, Serialize};
54use stakai::Model;
55use std::collections::HashMap;
56
57// =============================================================================
58// Provider Configuration
59// =============================================================================
60
61/// Unified provider configuration enum
62///
63/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
64/// where the key is the provider name and becomes the model prefix for routing.
65///
66/// # Provider Key = Model Prefix
67///
68/// The key used in the HashMap becomes the prefix used in model names:
69/// - Config key: `providers.offline`
70/// - Model usage: `offline/llama3`
71/// - Routing: finds `offline` provider, sends `llama3` to API
72///
73/// # Example TOML
74/// ```toml
75/// [profiles.myprofile.providers.openai]
76/// type = "openai"
77/// api_key = "sk-..."
78///
79/// [profiles.myprofile.providers.anthropic]
80/// type = "anthropic"
81/// api_key = "sk-ant-..."
82/// access_token = "oauth-token"
83///
84/// [profiles.myprofile.providers.offline]
85/// type = "custom"
86/// api_endpoint = "http://localhost:11434/v1"
87/// ```
88#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
89#[serde(tag = "type", rename_all = "lowercase")]
90pub enum ProviderConfig {
91    /// OpenAI provider configuration
92    OpenAI {
93        #[serde(skip_serializing_if = "Option::is_none")]
94        api_key: Option<String>,
95        #[serde(skip_serializing_if = "Option::is_none")]
96        api_endpoint: Option<String>,
97    },
98    /// Anthropic provider configuration
99    Anthropic {
100        #[serde(skip_serializing_if = "Option::is_none")]
101        api_key: Option<String>,
102        #[serde(skip_serializing_if = "Option::is_none")]
103        api_endpoint: Option<String>,
104        /// OAuth access token (for Claude subscription)
105        #[serde(skip_serializing_if = "Option::is_none")]
106        access_token: Option<String>,
107    },
108    /// Google Gemini provider configuration
109    Gemini {
110        #[serde(skip_serializing_if = "Option::is_none")]
111        api_key: Option<String>,
112        #[serde(skip_serializing_if = "Option::is_none")]
113        api_endpoint: Option<String>,
114    },
115    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
116    ///
117    /// The provider key in the config becomes the model prefix.
118    /// For example, if configured as `providers.offline`, use models as:
119    /// - `offline/llama3` - passes `llama3` to the API
120    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
121    ///
122    /// # Example TOML
123    /// ```toml
124    /// [profiles.myprofile.providers.offline]
125    /// type = "custom"
126    /// api_endpoint = "http://localhost:11434/v1"
127    ///
128    /// # Then use models as:
129    /// smart_model = "offline/llama3"
130    /// eco_model = "offline/phi3"
131    /// ```
132    Custom {
133        #[serde(skip_serializing_if = "Option::is_none")]
134        api_key: Option<String>,
135        /// API endpoint URL (required for custom providers)
136        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
137        api_endpoint: String,
138    },
139    /// Stakpak provider configuration
140    ///
141    /// Routes inference through Stakpak's unified API, which provides:
142    /// - Access to multiple LLM providers via a single endpoint
143    /// - Usage tracking and billing
144    /// - Session management and checkpoints
145    ///
146    /// # Example TOML
147    /// ```toml
148    /// [profiles.myprofile.providers.stakpak]
149    /// type = "stakpak"
150    /// api_key = "your-stakpak-api-key"
151    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
152    ///
153    /// # Then use models as:
154    /// smart_model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
155    /// ```
156    Stakpak {
157        /// Stakpak API key (required)
158        api_key: String,
159        /// API endpoint URL (default: https://apiv2.stakpak.dev)
160        #[serde(skip_serializing_if = "Option::is_none")]
161        api_endpoint: Option<String>,
162    },
163}
164
165impl ProviderConfig {
166    /// Get the provider type name
167    pub fn provider_type(&self) -> &'static str {
168        match self {
169            ProviderConfig::OpenAI { .. } => "openai",
170            ProviderConfig::Anthropic { .. } => "anthropic",
171            ProviderConfig::Gemini { .. } => "gemini",
172            ProviderConfig::Custom { .. } => "custom",
173            ProviderConfig::Stakpak { .. } => "stakpak",
174        }
175    }
176
177    /// Get the API key if set
178    pub fn api_key(&self) -> Option<&str> {
179        match self {
180            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
181            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
182            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
183            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
184            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
185        }
186    }
187
188    /// Get the API endpoint if set
189    pub fn api_endpoint(&self) -> Option<&str> {
190        match self {
191            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
192            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
193            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
194            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
195            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
196        }
197    }
198
199    /// Get the access token (Anthropic only)
200    pub fn access_token(&self) -> Option<&str> {
201        match self {
202            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
203            _ => None,
204        }
205    }
206
207    /// Create an OpenAI provider config
208    pub fn openai(api_key: Option<String>) -> Self {
209        ProviderConfig::OpenAI {
210            api_key,
211            api_endpoint: None,
212        }
213    }
214
215    /// Create an Anthropic provider config
216    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
217        ProviderConfig::Anthropic {
218            api_key,
219            api_endpoint: None,
220            access_token,
221        }
222    }
223
224    /// Create a Gemini provider config
225    pub fn gemini(api_key: Option<String>) -> Self {
226        ProviderConfig::Gemini {
227            api_key,
228            api_endpoint: None,
229        }
230    }
231
232    /// Create a custom provider config
233    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
234        ProviderConfig::Custom {
235            api_key,
236            api_endpoint,
237        }
238    }
239
240    /// Create a Stakpak provider config
241    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
242        ProviderConfig::Stakpak {
243            api_key,
244            api_endpoint,
245        }
246    }
247}
248
249/// Aggregated provider configuration for LLM operations
250///
251/// This struct holds all configured providers, keyed by provider name.
252#[derive(Debug, Clone, Default)]
253pub struct LLMProviderConfig {
254    /// All provider configurations (key = provider name)
255    pub providers: HashMap<String, ProviderConfig>,
256}
257
258impl LLMProviderConfig {
259    /// Create a new empty provider config
260    pub fn new() -> Self {
261        Self {
262            providers: HashMap::new(),
263        }
264    }
265
266    /// Add a provider configuration
267    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
268        self.providers.insert(name.into(), config);
269    }
270
271    /// Get a provider configuration by name
272    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
273        self.providers.get(name)
274    }
275
276    /// Check if any providers are configured
277    pub fn is_empty(&self) -> bool {
278        self.providers.is_empty()
279    }
280}
281
282/// Provider-specific options for LLM requests
283#[derive(Clone, Debug, Serialize, Deserialize, Default)]
284pub struct LLMProviderOptions {
285    /// Anthropic-specific options
286    #[serde(skip_serializing_if = "Option::is_none")]
287    pub anthropic: Option<LLMAnthropicOptions>,
288
289    /// OpenAI-specific options
290    #[serde(skip_serializing_if = "Option::is_none")]
291    pub openai: Option<LLMOpenAIOptions>,
292
293    /// Google/Gemini-specific options
294    #[serde(skip_serializing_if = "Option::is_none")]
295    pub google: Option<LLMGoogleOptions>,
296}
297
298/// Anthropic-specific options
299#[derive(Clone, Debug, Serialize, Deserialize, Default)]
300pub struct LLMAnthropicOptions {
301    /// Extended thinking configuration
302    #[serde(skip_serializing_if = "Option::is_none")]
303    pub thinking: Option<LLMThinkingOptions>,
304}
305
306/// Thinking/reasoning options
307#[derive(Clone, Debug, Serialize, Deserialize)]
308pub struct LLMThinkingOptions {
309    /// Budget tokens for thinking (must be >= 1024)
310    pub budget_tokens: u32,
311}
312
313impl LLMThinkingOptions {
314    pub fn new(budget_tokens: u32) -> Self {
315        Self {
316            budget_tokens: budget_tokens.max(1024),
317        }
318    }
319}
320
321/// OpenAI-specific options
322#[derive(Clone, Debug, Serialize, Deserialize, Default)]
323pub struct LLMOpenAIOptions {
324    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
325    #[serde(skip_serializing_if = "Option::is_none")]
326    pub reasoning_effort: Option<String>,
327}
328
329/// Google/Gemini-specific options
330#[derive(Clone, Debug, Serialize, Deserialize, Default)]
331pub struct LLMGoogleOptions {
332    /// Thinking budget in tokens
333    #[serde(skip_serializing_if = "Option::is_none")]
334    pub thinking_budget: Option<u32>,
335}
336
337#[derive(Clone, Debug, Serialize)]
338pub struct LLMInput {
339    pub model: Model,
340    pub messages: Vec<LLMMessage>,
341    pub max_tokens: u32,
342    pub tools: Option<Vec<LLMTool>>,
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub provider_options: Option<LLMProviderOptions>,
345    /// Custom headers to pass to the inference provider
346    #[serde(skip_serializing_if = "Option::is_none")]
347    pub headers: Option<std::collections::HashMap<String, String>>,
348}
349
350#[derive(Debug)]
351pub struct LLMStreamInput {
352    pub model: Model,
353    pub messages: Vec<LLMMessage>,
354    pub max_tokens: u32,
355    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
356    pub tools: Option<Vec<LLMTool>>,
357    pub provider_options: Option<LLMProviderOptions>,
358    /// Custom headers to pass to the inference provider
359    pub headers: Option<std::collections::HashMap<String, String>>,
360}
361
362impl From<&LLMStreamInput> for LLMInput {
363    fn from(value: &LLMStreamInput) -> Self {
364        LLMInput {
365            model: value.model.clone(),
366            messages: value.messages.clone(),
367            max_tokens: value.max_tokens,
368            tools: value.tools.clone(),
369            provider_options: value.provider_options.clone(),
370            headers: value.headers.clone(),
371        }
372    }
373}
374
375#[derive(Serialize, Deserialize, Debug, Clone, Default)]
376pub struct LLMMessage {
377    pub role: String,
378    pub content: LLMMessageContent,
379}
380
381#[derive(Serialize, Deserialize, Debug, Clone)]
382pub struct SimpleLLMMessage {
383    #[serde(rename = "role")]
384    pub role: SimpleLLMRole,
385    pub content: String,
386}
387
388#[derive(Serialize, Deserialize, Debug, Clone)]
389#[serde(rename_all = "lowercase")]
390pub enum SimpleLLMRole {
391    User,
392    Assistant,
393}
394
395impl std::fmt::Display for SimpleLLMRole {
396    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397        match self {
398            SimpleLLMRole::User => write!(f, "user"),
399            SimpleLLMRole::Assistant => write!(f, "assistant"),
400        }
401    }
402}
403
404#[derive(Serialize, Deserialize, Debug, Clone)]
405#[serde(untagged)]
406pub enum LLMMessageContent {
407    String(String),
408    List(Vec<LLMMessageTypedContent>),
409}
410
411#[allow(clippy::to_string_trait_impl)]
412impl ToString for LLMMessageContent {
413    fn to_string(&self) -> String {
414        match self {
415            LLMMessageContent::String(s) => s.clone(),
416            LLMMessageContent::List(l) => l
417                .iter()
418                .map(|c| match c {
419                    LLMMessageTypedContent::Text { text } => text.clone(),
420                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
421                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
422                    LLMMessageTypedContent::Image { .. } => String::new(),
423                })
424                .collect::<Vec<_>>()
425                .join("\n"),
426        }
427    }
428}
429
430impl From<String> for LLMMessageContent {
431    fn from(value: String) -> Self {
432        LLMMessageContent::String(value)
433    }
434}
435
436impl Default for LLMMessageContent {
437    fn default() -> Self {
438        LLMMessageContent::String(String::new())
439    }
440}
441
442impl LLMMessageContent {
443    /// Convert into a Vec of typed content parts.
444    /// A `String` variant is returned as a single `Text` part (empty strings yield an empty vec).
445    pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
446        match self {
447            LLMMessageContent::List(parts) => parts,
448            LLMMessageContent::String(s) if s.is_empty() => vec![],
449            LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
450        }
451    }
452}
453
454#[derive(Serialize, Deserialize, Debug, Clone)]
455#[serde(tag = "type")]
456pub enum LLMMessageTypedContent {
457    #[serde(rename = "text")]
458    Text { text: String },
459    #[serde(rename = "tool_use")]
460    ToolCall {
461        id: String,
462        name: String,
463        #[serde(alias = "input")]
464        args: serde_json::Value,
465        /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
466        #[serde(skip_serializing_if = "Option::is_none")]
467        metadata: Option<serde_json::Value>,
468    },
469    #[serde(rename = "tool_result")]
470    ToolResult {
471        tool_use_id: String,
472        content: String,
473    },
474    #[serde(rename = "image")]
475    Image { source: LLMMessageImageSource },
476}
477
478#[derive(Serialize, Deserialize, Debug, Clone)]
479pub struct LLMMessageImageSource {
480    #[serde(rename = "type")]
481    pub r#type: String,
482    pub media_type: String,
483    pub data: String,
484}
485
486impl Default for LLMMessageTypedContent {
487    fn default() -> Self {
488        LLMMessageTypedContent::Text {
489            text: String::new(),
490        }
491    }
492}
493
494#[derive(Serialize, Deserialize, Debug, Clone)]
495pub struct LLMChoice {
496    pub finish_reason: Option<String>,
497    pub index: u32,
498    pub message: LLMMessage,
499}
500
501#[derive(Serialize, Deserialize, Debug, Clone)]
502pub struct LLMCompletionResponse {
503    pub model: String,
504    pub object: String,
505    pub choices: Vec<LLMChoice>,
506    pub created: u64,
507    pub usage: Option<LLMTokenUsage>,
508    pub id: String,
509}
510
511#[derive(Serialize, Deserialize, Debug, Clone)]
512pub struct LLMStreamDelta {
513    #[serde(skip_serializing_if = "Option::is_none")]
514    pub content: Option<String>,
515}
516
517#[derive(Serialize, Deserialize, Debug, Clone)]
518pub struct LLMStreamChoice {
519    pub finish_reason: Option<String>,
520    pub index: u32,
521    pub message: Option<LLMMessage>,
522    pub delta: LLMStreamDelta,
523}
524
525#[derive(Serialize, Deserialize, Debug, Clone)]
526pub struct LLMCompletionStreamResponse {
527    pub model: String,
528    pub object: String,
529    pub choices: Vec<LLMStreamChoice>,
530    pub created: u64,
531    #[serde(skip_serializing_if = "Option::is_none")]
532    pub usage: Option<LLMTokenUsage>,
533    pub id: String,
534    pub citations: Option<Vec<String>>,
535}
536
537#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
538pub struct LLMTool {
539    pub name: String,
540    pub description: String,
541    pub input_schema: serde_json::Value,
542}
543
544#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
545pub struct LLMTokenUsage {
546    pub prompt_tokens: u32,
547    pub completion_tokens: u32,
548    pub total_tokens: u32,
549
550    #[serde(skip_serializing_if = "Option::is_none")]
551    pub prompt_tokens_details: Option<PromptTokensDetails>,
552}
553
554#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
555#[serde(rename_all = "snake_case")]
556pub enum TokenType {
557    InputTokens,
558    OutputTokens,
559    CacheReadInputTokens,
560    CacheWriteInputTokens,
561}
562
563#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
564pub struct PromptTokensDetails {
565    #[serde(skip_serializing_if = "Option::is_none")]
566    pub input_tokens: Option<u32>,
567    #[serde(skip_serializing_if = "Option::is_none")]
568    pub output_tokens: Option<u32>,
569    #[serde(skip_serializing_if = "Option::is_none")]
570    pub cache_read_input_tokens: Option<u32>,
571    #[serde(skip_serializing_if = "Option::is_none")]
572    pub cache_write_input_tokens: Option<u32>,
573}
574
575impl PromptTokensDetails {
576    /// Returns an iterator over the token types and their values
577    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
578        [
579            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
580            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
581            (
582                TokenType::CacheReadInputTokens,
583                self.cache_read_input_tokens.unwrap_or(0),
584            ),
585            (
586                TokenType::CacheWriteInputTokens,
587                self.cache_write_input_tokens.unwrap_or(0),
588            ),
589        ]
590        .into_iter()
591    }
592}
593
594impl std::ops::Add for PromptTokensDetails {
595    type Output = Self;
596
597    fn add(self, rhs: Self) -> Self::Output {
598        Self {
599            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
600            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
601            cache_read_input_tokens: Some(
602                self.cache_read_input_tokens.unwrap_or(0)
603                    + rhs.cache_read_input_tokens.unwrap_or(0),
604            ),
605            cache_write_input_tokens: Some(
606                self.cache_write_input_tokens.unwrap_or(0)
607                    + rhs.cache_write_input_tokens.unwrap_or(0),
608            ),
609        }
610    }
611}
612
613impl std::ops::AddAssign for PromptTokensDetails {
614    fn add_assign(&mut self, rhs: Self) {
615        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
616        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
617        self.cache_read_input_tokens = Some(
618            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
619        );
620        self.cache_write_input_tokens = Some(
621            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
622        );
623    }
624}
625
626#[derive(Serialize, Deserialize, Debug, Clone)]
627#[serde(tag = "type")]
628pub enum GenerationDelta {
629    Content { content: String },
630    Thinking { thinking: String },
631    ToolUse { tool_use: GenerationDeltaToolUse },
632    Usage { usage: LLMTokenUsage },
633    Metadata { metadata: serde_json::Value },
634}
635
636#[derive(Serialize, Deserialize, Debug, Clone)]
637pub struct GenerationDeltaToolUse {
638    pub id: Option<String>,
639    pub name: Option<String>,
640    pub input: Option<String>,
641    pub index: usize,
642    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
643    #[serde(skip_serializing_if = "Option::is_none")]
644    pub metadata: Option<serde_json::Value>,
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // ProviderConfig Tests
    // =========================================================================

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_stakpak_serialization() {
        let config = ProviderConfig::Stakpak {
            api_key: "stakpak-key".to_string(),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"stakpak\""));
        assert!(json.contains("\"api_key\":\"stakpak-key\""));
        assert!(!json.contains("api_endpoint")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_deserialization_stakpak() {
        let json = r#"{"type":"stakpak","api_key":"stakpak-key","api_endpoint":"https://apiv2.stakpak.dev"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Stakpak { .. }));
        assert_eq!(config.provider_type(), "stakpak");
        assert_eq!(config.api_key(), Some("stakpak-key"));
        assert_eq!(config.api_endpoint(), Some("https://apiv2.stakpak.dev"));
        assert_eq!(config.access_token(), None);
    }

    #[test]
    fn test_provider_config_rejects_missing_required_fields_and_unknown_type() {
        // `custom` requires an endpoint and `stakpak` requires an API key.
        assert!(serde_json::from_str::<ProviderConfig>(r#"{"type":"custom"}"#).is_err());
        assert!(serde_json::from_str::<ProviderConfig>(r#"{"type":"stakpak"}"#).is_err());
        // Unknown provider types are rejected rather than silently accepted.
        assert!(serde_json::from_str::<ProviderConfig>(r#"{"type":"bogus"}"#).is_err());
    }

    #[test]
    fn test_provider_config_serde_roundtrip() {
        let configs = [
            ProviderConfig::openai(Some("sk".to_string())),
            ProviderConfig::anthropic(None, Some("oauth".to_string())),
            ProviderConfig::gemini(None),
            ProviderConfig::custom("http://localhost:11434/v1".to_string(), None),
            ProviderConfig::stakpak(
                "key".to_string(),
                Some("https://apiv2.stakpak.dev".to_string()),
            ),
        ];
        for config in &configs {
            let json = serde_json::to_string(config).unwrap();
            let back: ProviderConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, config);
        }
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));

        let stakpak = ProviderConfig::stakpak("stakpak-key".to_string(), None);
        assert_eq!(stakpak.provider_type(), "stakpak");
        assert_eq!(stakpak.api_key(), Some("stakpak-key"));
        assert_eq!(stakpak.api_endpoint(), None);
    }

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_toml_parsing() {
        // Test parsing a HashMap of providers from TOML-like JSON
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    // =========================================================================
    // Options, Message and Token Tests
    // =========================================================================

    #[test]
    fn test_thinking_options_new_clamps_budget() {
        assert_eq!(LLMThinkingOptions::new(10).budget_tokens, 1024);
        assert_eq!(LLMThinkingOptions::new(4096).budget_tokens, 4096);
    }

    #[test]
    fn test_message_content_to_string_joins_parts() {
        let content = LLMMessageContent::List(vec![
            LLMMessageTypedContent::Text {
                text: "a".to_string(),
            },
            LLMMessageTypedContent::ToolResult {
                tool_use_id: "t1".to_string(),
                content: "b".to_string(),
            },
        ]);
        assert_eq!(content.to_string(), "a\nb");
        assert_eq!(
            LLMMessageContent::String("plain".to_string()).to_string(),
            "plain"
        );
    }

    #[test]
    fn test_message_content_into_parts() {
        assert!(LLMMessageContent::String(String::new())
            .into_parts()
            .is_empty());
        let parts = LLMMessageContent::from("hi".to_string()).into_parts();
        assert!(matches!(
            parts.as_slice(),
            [LLMMessageTypedContent::Text { text }] if text == "hi"
        ));
    }

    #[test]
    fn test_prompt_tokens_details_add() {
        let a = PromptTokensDetails {
            input_tokens: Some(1),
            output_tokens: None,
            cache_read_input_tokens: Some(3),
            cache_write_input_tokens: None,
        };
        let b = PromptTokensDetails {
            input_tokens: Some(2),
            ..Default::default()
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, Some(3));
        assert_eq!(sum.output_tokens, Some(0));
        assert_eq!(sum.cache_read_input_tokens, Some(3));
        assert_eq!(sum.cache_write_input_tokens, Some(0));
    }
}