Skip to main content

stakpak_shared/models/
llm.rs

1//! LLM Provider and Model Configuration
2//!
3//! This module provides the configuration types for LLM providers and models.
4//!
5//! # Provider Configuration
6//!
7//! Providers are configured in a `providers` HashMap where the key becomes the
8//! model prefix for routing requests to the correct provider.
9//!
10//! ## Built-in Providers
11//!
12//! - `openai` - OpenAI API
13//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
14//! - `gemini` - Google Gemini API
//! - `amazon-bedrock` - AWS Bedrock (uses AWS credential chain, no API key)
16//!
17//! For built-in providers, you can use the model name directly without a prefix:
18//! - `claude-sonnet-4-5` → auto-detected as Anthropic
19//! - `gpt-4` → auto-detected as OpenAI
20//! - `gemini-2.5-pro` → auto-detected as Gemini
21//!
22//! ## Custom Providers
23//!
24//! Any OpenAI-compatible API can be configured using `type = "custom"`.
25//! The provider key becomes the model prefix.
26//!
27//! # Model Routing
28//!
29//! Models can be specified with or without a provider prefix:
30//!
31//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
32//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
33//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
34//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
35//!   sends `anthropic/claude-opus` to the API
36//!
37//! # Example Configuration
38//!
39//! ```toml
40//! [profiles.default]
41//! provider = "local"
42//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
43//! eco_model = "offline/llama3"       # custom provider
44//!
45//! [profiles.default.providers.anthropic]
46//! type = "anthropic"
47//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
48//!
49//! [profiles.default.providers.offline]
50//! type = "custom"
51//! api_endpoint = "http://localhost:11434/v1"
52//! ```
53
54use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58use super::auth::ProviderAuth;
59
60// =============================================================================
61// Provider Configuration
62// =============================================================================
63
/// Unified provider configuration enum
///
/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
/// where the key is the provider name and becomes the model prefix for routing.
///
/// The variant is selected by the `type` field of the serialized form
/// (internally tagged); tags are the lowercase variant names unless a variant
/// carries an explicit `rename`.
///
/// # Provider Key = Model Prefix
///
/// The key used in the HashMap becomes the prefix used in model names:
/// - Config key: `providers.offline`
/// - Model usage: `offline/llama3`
/// - Routing: finds `offline` provider, sends `llama3` to API
///
/// # Example TOML
/// ```toml
/// [profiles.myprofile.providers.openai]
/// type = "openai"
///
/// [profiles.myprofile.providers.openai.auth]
/// type = "api"
/// key = "sk-..."
///
/// [profiles.myprofile.providers.anthropic]
/// type = "anthropic"
///
/// [profiles.myprofile.providers.anthropic.auth]
/// type = "oauth"
/// access = "eyJ..."
/// refresh = "eyJ..."
/// expires = 1735600000000
/// name = "Claude Max"
///
/// [profiles.myprofile.providers.offline]
/// type = "custom"
/// api_endpoint = "http://localhost:11434/v1"
/// ```
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI provider configuration
    OpenAI {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional custom API endpoint URL
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Anthropic provider configuration
    Anthropic {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional custom API endpoint URL
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Legacy OAuth access token (prefer `auth` field with OAuth type)
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
        /// Authentication credentials (preferred over api_key/access_token)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Google Gemini provider configuration
    Gemini {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional custom API endpoint URL
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
    ///
    /// The provider key in the config becomes the model prefix.
    /// For example, if configured as `providers.offline`, use models as:
    /// - `offline/llama3` - passes `llama3` to the API
    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.offline]
    /// type = "custom"
    /// api_endpoint = "http://localhost:11434/v1"
    ///
    /// # Then use models as:
    /// model = "offline/llama3"
    /// ```
    Custom {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (required for custom providers)
        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
        api_endpoint: String,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Stakpak provider configuration
    ///
    /// Routes inference through Stakpak's unified API, which provides:
    /// - Access to multiple LLM providers via a single endpoint
    /// - Usage tracking and billing
    /// - Session management and checkpoints
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.stakpak]
    /// type = "stakpak"
    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
    ///
    /// [profiles.myprofile.providers.stakpak.auth]
    /// type = "api"
    /// key = "your-stakpak-api-key"
    ///
    /// # Then use models as:
    /// model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
    /// ```
    Stakpak {
        /// Legacy API key field (prefer `auth` field)
        /// Note: This field is optional when using `auth`
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (default: https://apiv2.stakpak.dev)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// AWS Bedrock provider configuration
    ///
    /// Uses AWS credential chain for authentication (no API key needed).
    /// Supports env vars, shared credentials, SSO, and instance roles.
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.amazon-bedrock]
    /// type = "amazon-bedrock"
    /// region = "us-east-1"
    /// profile_name = "my-aws-profile"  # optional
    ///
    /// # Then use models as (friendly aliases work):
    /// model = "amazon-bedrock/claude-sonnet-4-5"
    /// ```
    #[serde(rename = "amazon-bedrock")]
    Bedrock {
        /// AWS region (e.g., "us-east-1")
        region: String,
        /// Optional AWS named profile (from ~/.aws/config)
        #[serde(skip_serializing_if = "Option::is_none")]
        profile_name: Option<String>,
    },
    /// GitHub Copilot provider configuration
    ///
    /// Uses the GitHub Device Authorization Grant to obtain an OAuth token, then
    /// calls the OpenAI-compatible Copilot API endpoint.
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.github-copilot]
    /// type = "github-copilot"
    ///
    /// [profiles.myprofile.providers.github-copilot.auth]
    /// type = "oauth"
    /// access = "ghu_..."
    /// refresh = ""
    /// expires = 9223372036854775807
    /// name = "GitHub Copilot"
    ///
    /// # Then use models as:
    /// model = "github-copilot/gpt-4o"
    /// ```
    #[serde(rename = "github-copilot")]
    GitHubCopilot {
        /// Optional custom API endpoint (defaults to https://api.githubcopilot.com)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (OAuth access token from device flow)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },

    /// OpenRouter provider configuration
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.openrouter]
    /// type = "openrouter"
    ///
    /// [profiles.myprofile.providers.openrouter.auth]
    /// type = "api"
    /// key = "your-openrouter-api-key"
    /// ```
    // The explicit rename yields the same tag that `rename_all = "lowercase"`
    // would already produce for this variant name.
    #[serde(rename = "openrouter")]
    OpenRouter {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional custom API endpoint (defaults to https://openrouter.ai/api/v1)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
}
263
264impl ProviderConfig {
265    /// Get the provider type name
266    pub fn provider_type(&self) -> &'static str {
267        match self {
268            ProviderConfig::OpenAI { .. } => "openai",
269            ProviderConfig::Anthropic { .. } => "anthropic",
270            ProviderConfig::Gemini { .. } => "gemini",
271            ProviderConfig::Custom { .. } => "custom",
272            ProviderConfig::Stakpak { .. } => "stakpak",
273            ProviderConfig::Bedrock { .. } => "amazon-bedrock",
274            ProviderConfig::GitHubCopilot { .. } => "github-copilot",
275            ProviderConfig::OpenRouter { .. } => "openrouter",
276        }
277    }
278
279    /// Get the API key if set (checks `auth` field first, then legacy `api_key`)
280    pub fn api_key(&self) -> Option<&str> {
281        // First check auth field
282        if let Some(auth) = self.get_auth_ref()
283            && let Some(key) = auth.api_key_value()
284        {
285            return Some(key);
286        }
287        // Fall back to legacy api_key field
288        match self {
289            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
290            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
291            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
292            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
293            ProviderConfig::Stakpak { api_key, .. } => api_key.as_deref(),
294            ProviderConfig::OpenRouter { api_key, .. } => api_key.as_deref(),
295            ProviderConfig::Bedrock { .. } => None, // AWS credential chain, no API key
296            ProviderConfig::GitHubCopilot { .. } => None, // OAuth only, no API key
297        }
298    }
299
300    /// Get the auth credentials reference
301    fn get_auth_ref(&self) -> Option<&ProviderAuth> {
302        match self {
303            ProviderConfig::OpenAI { auth, .. } => auth.as_ref(),
304            ProviderConfig::Anthropic { auth, .. } => auth.as_ref(),
305            ProviderConfig::Gemini { auth, .. } => auth.as_ref(),
306            ProviderConfig::Custom { auth, .. } => auth.as_ref(),
307            ProviderConfig::Stakpak { auth, .. } => auth.as_ref(),
308            ProviderConfig::Bedrock { .. } => None,
309            ProviderConfig::GitHubCopilot { auth, .. } => auth.as_ref(),
310            ProviderConfig::OpenRouter { auth, .. } => auth.as_ref(),
311        }
312    }
313
314    /// Get resolved authentication credentials.
315    ///
316    /// Resolution order:
317    /// 1. `auth` field (preferred)
318    /// 2. Legacy `api_key` field (converted to ProviderAuth::Api)
319    /// 3. Legacy `access_token` field for Anthropic (converted to ProviderAuth with access token)
320    pub fn get_auth(&self) -> Option<ProviderAuth> {
321        // First check auth field
322        if let Some(auth) = self.get_auth_ref() {
323            return Some(auth.clone());
324        }
325
326        // Fall back to legacy fields
327        match self {
328            ProviderConfig::OpenAI { api_key, .. }
329            | ProviderConfig::Gemini { api_key, .. }
330            | ProviderConfig::Custom { api_key, .. }
331            | ProviderConfig::Stakpak { api_key, .. }
332            | ProviderConfig::OpenRouter { api_key, .. } => {
333                api_key.as_ref().map(ProviderAuth::api_key)
334            }
335            ProviderConfig::Anthropic {
336                api_key,
337                access_token,
338                ..
339            } => {
340                // Prefer api_key, then access_token (as OAuth bearer token, not API key)
341                if let Some(key) = api_key {
342                    Some(ProviderAuth::api_key(key))
343                } else {
344                    // Legacy access_token is an OAuth bearer token — wrap it as OAuth
345                    // with empty refresh token and zero expiry so it will be treated as
346                    // expired and trigger a re-auth rather than silently failing.
347                    access_token
348                        .as_ref()
349                        .map(|token| ProviderAuth::oauth(token, "", 0))
350                }
351            }
352            ProviderConfig::Bedrock { .. } => None,
353            // GitHubCopilot has no legacy fields; auth is always in the `auth` field
354            ProviderConfig::GitHubCopilot { .. } => None,
355        }
356    }
357
358    /// Set authentication credentials on this provider config.
359    ///
360    /// Also clears any legacy credential fields (`api_key`, `access_token`)
361    /// so they don't shadow the new `auth` field on future reads.
362    pub fn set_auth(&mut self, auth: ProviderAuth) {
363        match self {
364            ProviderConfig::OpenAI {
365                auth: auth_field,
366                api_key,
367                ..
368            }
369            | ProviderConfig::Gemini {
370                auth: auth_field,
371                api_key,
372                ..
373            }
374            | ProviderConfig::Custom {
375                auth: auth_field,
376                api_key,
377                ..
378            }
379            | ProviderConfig::Stakpak {
380                auth: auth_field,
381                api_key,
382                ..
383            }
384            | ProviderConfig::OpenRouter {
385                auth: auth_field,
386                api_key,
387                ..
388            } => {
389                *auth_field = Some(auth);
390                *api_key = None;
391            }
392            ProviderConfig::Anthropic {
393                auth: auth_field,
394                api_key,
395                access_token,
396                ..
397            } => {
398                *auth_field = Some(auth);
399                *api_key = None;
400                *access_token = None;
401            }
402            ProviderConfig::GitHubCopilot {
403                auth: auth_field, ..
404            } => {
405                *auth_field = Some(auth);
406            }
407            ProviderConfig::Bedrock { .. } => {
408                // Bedrock uses AWS credential chain, no auth field
409            }
410        }
411    }
412
413    /// Clear authentication credentials from this provider config.
414    ///
415    /// Clears both the `auth` field and any legacy credential fields
416    /// (`api_key`, `access_token`) to ensure credentials are fully removed.
417    pub fn clear_auth(&mut self) {
418        match self {
419            ProviderConfig::OpenAI {
420                auth: auth_field,
421                api_key,
422                ..
423            }
424            | ProviderConfig::Gemini {
425                auth: auth_field,
426                api_key,
427                ..
428            }
429            | ProviderConfig::Custom {
430                auth: auth_field,
431                api_key,
432                ..
433            }
434            | ProviderConfig::Stakpak {
435                auth: auth_field,
436                api_key,
437                ..
438            }
439            | ProviderConfig::OpenRouter {
440                auth: auth_field,
441                api_key,
442                ..
443            } => {
444                *auth_field = None;
445                *api_key = None;
446            }
447            ProviderConfig::Anthropic {
448                auth: auth_field,
449                api_key,
450                access_token,
451                ..
452            } => {
453                *auth_field = None;
454                *api_key = None;
455                *access_token = None;
456            }
457            ProviderConfig::GitHubCopilot {
458                auth: auth_field, ..
459            } => {
460                *auth_field = None;
461            }
462            ProviderConfig::Bedrock { .. } => {
463                // Bedrock uses AWS credential chain, no auth field
464            }
465        }
466    }
467
468    /// Get the API endpoint if set
469    pub fn api_endpoint(&self) -> Option<&str> {
470        match self {
471            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
472            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
473            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
474            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
475            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
476            ProviderConfig::OpenRouter { api_endpoint, .. } => api_endpoint.as_deref(),
477            ProviderConfig::Bedrock { .. } => None, // No custom endpoint in config
478            ProviderConfig::GitHubCopilot { api_endpoint, .. } => api_endpoint.as_deref(),
479        }
480    }
481
482    /// Set the API endpoint for providers that support it.
483    ///
484    /// For `Custom`, `None` is ignored because custom providers require an endpoint.
485    /// For `Bedrock`, this is a no-op.
486    pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
487        match self {
488            ProviderConfig::OpenAI { api_endpoint, .. }
489            | ProviderConfig::Anthropic { api_endpoint, .. }
490            | ProviderConfig::Gemini { api_endpoint, .. }
491            | ProviderConfig::Stakpak { api_endpoint, .. }
492            | ProviderConfig::GitHubCopilot { api_endpoint, .. }
493            | ProviderConfig::OpenRouter { api_endpoint, .. } => {
494                *api_endpoint = endpoint;
495            }
496            ProviderConfig::Custom { api_endpoint, .. } => {
497                if let Some(custom_endpoint) = endpoint {
498                    *api_endpoint = custom_endpoint;
499                }
500            }
501            ProviderConfig::Bedrock { .. } => {}
502        }
503    }
504
505    /// Get the access token (for OAuth-based providers such as Anthropic and GitHub Copilot)
506    ///
507    /// Checks the `auth` field first for OAuth access token, then falls back
508    /// to the legacy `access_token` field (Anthropic only).
509    pub fn access_token(&self) -> Option<&str> {
510        // First check auth field for OAuth access token
511        if let Some(auth) = self.get_auth_ref()
512            && let Some(token) = auth.access_token()
513        {
514            return Some(token);
515        }
516        // Fall back to legacy access_token field (Anthropic only)
517        match self {
518            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
519            _ => None,
520        }
521    }
522
523    /// Create an OpenAI provider config (legacy, uses api_key field)
524    pub fn openai(api_key: Option<String>) -> Self {
525        ProviderConfig::OpenAI {
526            api_key,
527            api_endpoint: None,
528            auth: None,
529        }
530    }
531
532    /// Create an OpenAI provider config with auth
533    pub fn openai_with_auth(auth: ProviderAuth) -> Self {
534        ProviderConfig::OpenAI {
535            api_key: None,
536            api_endpoint: None,
537            auth: Some(auth),
538        }
539    }
540
541    /// Create an Anthropic provider config (legacy, uses api_key/access_token fields)
542    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
543        ProviderConfig::Anthropic {
544            api_key,
545            api_endpoint: None,
546            access_token,
547            auth: None,
548        }
549    }
550
551    /// Create an Anthropic provider config with auth
552    pub fn anthropic_with_auth(auth: ProviderAuth) -> Self {
553        ProviderConfig::Anthropic {
554            api_key: None,
555            api_endpoint: None,
556            access_token: None,
557            auth: Some(auth),
558        }
559    }
560
561    /// Create a Gemini provider config (legacy, uses api_key field)
562    pub fn gemini(api_key: Option<String>) -> Self {
563        ProviderConfig::Gemini {
564            api_key,
565            api_endpoint: None,
566            auth: None,
567        }
568    }
569
570    /// Create a Gemini provider config with auth
571    pub fn gemini_with_auth(auth: ProviderAuth) -> Self {
572        ProviderConfig::Gemini {
573            api_key: None,
574            api_endpoint: None,
575            auth: Some(auth),
576        }
577    }
578
579    /// Create a custom provider config (legacy, uses api_key field)
580    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
581        ProviderConfig::Custom {
582            api_key,
583            api_endpoint,
584            auth: None,
585        }
586    }
587
588    /// Create a custom provider config with auth
589    pub fn custom_with_auth(api_endpoint: String, auth: ProviderAuth) -> Self {
590        ProviderConfig::Custom {
591            api_key: None,
592            api_endpoint,
593            auth: Some(auth),
594        }
595    }
596
597    /// Create a Stakpak provider config (legacy, uses api_key field)
598    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
599        ProviderConfig::Stakpak {
600            api_key: Some(api_key),
601            api_endpoint,
602            auth: None,
603        }
604    }
605
606    /// Create a Stakpak provider config with auth
607    pub fn stakpak_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
608        ProviderConfig::Stakpak {
609            api_key: None,
610            api_endpoint,
611            auth: Some(auth),
612        }
613    }
614
615    /// Create an OpenRouter provider config (legacy, uses api_key field)
616    pub fn openrouter(api_key: Option<String>, api_endpoint: Option<String>) -> Self {
617        ProviderConfig::OpenRouter {
618            api_key,
619            api_endpoint,
620            auth: None,
621        }
622    }
623
624    /// Create an OpenRouter provider config with auth
625    pub fn openrouter_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
626        ProviderConfig::OpenRouter {
627            api_key: None,
628            api_endpoint,
629            auth: Some(auth),
630        }
631    }
632
633    /// Create a GitHub Copilot provider config with auth (OAuth token from device flow)
634    pub fn github_copilot_with_auth(auth: ProviderAuth) -> Self {
635        ProviderConfig::GitHubCopilot {
636            api_endpoint: None,
637            auth: Some(auth),
638        }
639    }
640
641    /// Create a Bedrock provider config
642    pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
643        ProviderConfig::Bedrock {
644            region,
645            profile_name,
646        }
647    }
648
649    /// Get the AWS region (Bedrock only)
650    pub fn region(&self) -> Option<&str> {
651        match self {
652            ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
653            _ => None,
654        }
655    }
656
657    /// Get the AWS profile name (Bedrock only)
658    pub fn profile_name(&self) -> Option<&str> {
659        match self {
660            ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
661            _ => None,
662        }
663    }
664
665    /// Create an empty provider config for a given provider name.
666    ///
667    /// Used during migration when we need to create a provider config
668    /// to attach auth credentials to.
669    pub fn empty_for_provider(provider_name: &str) -> Option<Self> {
670        match provider_name {
671            "openai" => Some(ProviderConfig::OpenAI {
672                api_key: None,
673                api_endpoint: None,
674                auth: None,
675            }),
676            "anthropic" => Some(ProviderConfig::Anthropic {
677                api_key: None,
678                api_endpoint: None,
679                access_token: None,
680                auth: None,
681            }),
682            "gemini" => Some(ProviderConfig::Gemini {
683                api_key: None,
684                api_endpoint: None,
685                auth: None,
686            }),
687            "stakpak" => Some(ProviderConfig::Stakpak {
688                api_key: None,
689                api_endpoint: None,
690                auth: None,
691            }),
692            "github-copilot" => Some(ProviderConfig::GitHubCopilot {
693                api_endpoint: None,
694                auth: None,
695            }),
696            "openrouter" => Some(ProviderConfig::OpenRouter {
697                api_key: None,
698                api_endpoint: None,
699                auth: None,
700            }),
701            // Custom providers need an endpoint, Bedrock uses AWS credential chain
702            _ => None,
703        }
704    }
705}
706
/// Aggregated provider configuration for LLM operations
///
/// This struct holds all configured providers, keyed by provider name.
/// `Default` yields an empty provider map.
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    /// All provider configurations (key = provider name)
    pub providers: HashMap<String, ProviderConfig>,
}
715
716impl LLMProviderConfig {
717    /// Create a new empty provider config
718    pub fn new() -> Self {
719        Self {
720            providers: HashMap::new(),
721        }
722    }
723
724    /// Add a provider configuration
725    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
726        self.providers.insert(name.into(), config);
727    }
728
729    /// Get a provider configuration by name
730    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
731        self.providers.get(name)
732    }
733
734    /// Check if any providers are configured
735    pub fn is_empty(&self) -> bool {
736        self.providers.is_empty()
737    }
738}
739
/// Provider-specific options for LLM requests
///
/// Each field holds the options for one provider family. Fields that are `None`
/// are omitted from the serialized output; `Default` leaves all of them `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google/Gemini-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}
755
/// Anthropic-specific options
///
/// `Default` leaves `thinking` unset.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended thinking configuration (omitted from serialized output when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}
763
/// Thinking/reasoning options
///
/// Prefer `LLMThinkingOptions::new`, which raises too-small budgets to 1024.
/// The field is public and the type derives `Deserialize`, so values built by
/// struct literal or deserialization are NOT clamped.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Budget tokens for thinking (must be >= 1024)
    pub budget_tokens: u32,
}
770
771impl LLMThinkingOptions {
772    pub fn new(budget_tokens: u32) -> Self {
773        Self {
774            budget_tokens: budget_tokens.max(1024),
775        }
776    }
777}
778
/// OpenAI-specific options
///
/// `Default` leaves `reasoning_effort` unset.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
    ///
    /// Stored as a free-form string; this type does not validate the value.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
786
/// Google/Gemini-specific options
///
/// `Default` leaves `thinking_budget` unset.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Thinking budget in tokens (omitted from serialized output when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
794
/// Request payload for an LLM call.
///
/// Carries the same request fields as `LLMStreamInput` minus the stream channel;
/// a `From<&LLMStreamInput>` conversion exists in this module. The type is
/// serialize-only (it derives `Serialize` but not `Deserialize`).
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    /// Model the request is addressed to
    pub model: Model,
    /// Conversation messages sent with the request
    pub messages: Vec<LLMMessage>,
    /// Token limit for the request
    // NOTE(review): presumably the cap on generated tokens — confirm against the provider client.
    pub max_tokens: u32,
    /// Optional tool definitions. Unlike the optional fields below, this one has
    /// no `skip_serializing_if`, so `None` is still emitted when serializing.
    pub tools: Option<Vec<LLMTool>>,
    /// Provider-specific options (omitted from serialized output when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}
807
/// Request payload for a streaming LLM call.
///
/// Holds the same request fields as `LLMInput` plus the sending half of a tokio
/// mpsc channel for `GenerationDelta` values. The type derives only `Debug`
/// (no `Clone`, no serde); use `LLMInput::from(&stream_input)` to obtain a
/// cloneable, serializable copy of the request fields.
#[derive(Debug)]
pub struct LLMStreamInput {
    /// Model the request is addressed to
    pub model: Model,
    /// Conversation messages sent with the request
    pub messages: Vec<LLMMessage>,
    /// Token limit for the request
    // NOTE(review): presumably the cap on generated tokens — confirm against the provider client.
    pub max_tokens: u32,
    /// Sender for `GenerationDelta` items
    // NOTE(review): presumably incremental generation output is pushed here — confirm with the producer.
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    /// Optional tool definitions
    pub tools: Option<Vec<LLMTool>>,
    /// Provider-specific options
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    pub headers: Option<std::collections::HashMap<String, String>>,
}
819
820impl From<&LLMStreamInput> for LLMInput {
821    fn from(value: &LLMStreamInput) -> Self {
822        LLMInput {
823            model: value.model.clone(),
824            messages: value.messages.clone(),
825            max_tokens: value.max_tokens,
826            tools: value.tools.clone(),
827            provider_options: value.provider_options.clone(),
828            headers: value.headers.clone(),
829        }
830    }
831}
832
/// A single message in an LLM conversation.
///
/// `Default` yields an empty `role` and the default `LLMMessageContent`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    /// Role of the message author, stored as a free-form string
    pub role: String,
    /// Message body: either plain text or a list of typed content parts
    pub content: LLMMessageContent,
}
838
/// A text-only message whose role is restricted to the `SimpleLLMRole` variants.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    /// Author of the message.
    // The explicit rename matches the field name, so it does not change the serialized key.
    #[serde(rename = "role")]
    pub role: SimpleLLMRole,
    /// Plain-text message body
    pub content: String,
}
845
/// Role of a `SimpleLLMMessage` author.
///
/// Serialized in lowercase (`"user"` / `"assistant"`), which matches the
/// strings produced by this type's `Display` implementation.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    /// Serialized as `"user"`
    User,
    /// Serialized as `"assistant"`
    Assistant,
}
852
853impl std::fmt::Display for SimpleLLMRole {
854    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855        match self {
856            SimpleLLMRole::User => write!(f, "user"),
857            SimpleLLMRole::Assistant => write!(f, "assistant"),
858        }
859    }
860}
861
/// Body of an [`LLMMessage`]: either plain text or a list of typed parts.
///
/// The enum is `untagged`, so it serializes as a bare string or a bare array,
/// and deserialization tries the variants in declaration order.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    /// Plain-text content.
    String(String),
    /// Structured content made of typed parts (text, tool calls, tool results, images).
    List(Vec<LLMMessageTypedContent>),
}
868
869#[allow(clippy::to_string_trait_impl)]
870impl ToString for LLMMessageContent {
871    fn to_string(&self) -> String {
872        match self {
873            LLMMessageContent::String(s) => s.clone(),
874            LLMMessageContent::List(l) => l
875                .iter()
876                .map(|c| match c {
877                    LLMMessageTypedContent::Text { text } => text.clone(),
878                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
879                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
880                    LLMMessageTypedContent::Image { .. } => String::new(),
881                })
882                .collect::<Vec<_>>()
883                .join("\n"),
884        }
885    }
886}
887
888impl From<String> for LLMMessageContent {
889    fn from(value: String) -> Self {
890        LLMMessageContent::String(value)
891    }
892}
893
894impl Default for LLMMessageContent {
895    fn default() -> Self {
896        LLMMessageContent::String(String::new())
897    }
898}
899
900impl LLMMessageContent {
901    /// Convert into a Vec of typed content parts.
902    /// A `String` variant is returned as a single `Text` part (empty strings yield an empty vec).
903    pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
904        match self {
905            LLMMessageContent::List(parts) => parts,
906            LLMMessageContent::String(s) if s.is_empty() => vec![],
907            LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
908        }
909    }
910}
911
/// One typed part of a structured message body.
///
/// Internally tagged: the serialized object carries a `type` field whose
/// value (`text`, `tool_use`, `tool_result`, or `image`) selects the variant.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    /// Plain text; tagged `"text"`.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation; tagged `"tool_use"`.
    #[serde(rename = "tool_use")]
    ToolCall {
        /// Identifier of this tool call.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Tool arguments as arbitrary JSON. Serialized as `args`;
        /// `input` is also accepted when deserializing.
        #[serde(alias = "input")]
        args: serde_json::Value,
        /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    /// The result of a tool invocation; tagged `"tool_result"`.
    #[serde(rename = "tool_result")]
    ToolResult {
        /// Identifier of the tool call this result belongs to.
        // NOTE(review): presumably matches `ToolCall::id` — confirm with the producers of these parts.
        tool_use_id: String,
        /// Result text.
        content: String,
    },
    /// An image; tagged `"image"`.
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
935
/// Source of an image part, stored as three strings.
// NOTE(review): the field names suggest a base64-style payload (`type`, `media_type`, `data`),
// but nothing in this type enforces an encoding — confirm against the provider clients.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    /// Kind of source; serialized under the key `type`.
    #[serde(rename = "type")]
    pub r#type: String,
    /// Media type of the image (presumably a MIME type such as `image/png` — confirm).
    pub media_type: String,
    /// Image payload as a string.
    pub data: String,
}
943
944impl Default for LLMMessageTypedContent {
945    fn default() -> Self {
946        LLMMessageTypedContent::Text {
947            text: String::new(),
948        }
949    }
950}
951
/// One entry of [`LLMCompletionResponse::choices`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    /// Reason generation stopped, if one was reported.
    pub finish_reason: Option<String>,
    /// Index of this choice.
    pub index: u32,
    /// The message carried by this choice.
    pub message: LLMMessage,
}
958
/// Response to a non-streaming completion request.
// NOTE(review): the field set resembles an OpenAI-style chat completion object — confirm
// before relying on that for wire compatibility.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    /// Model name reported in the response.
    pub model: String,
    /// Object type string reported in the response.
    pub object: String,
    /// The returned choices.
    pub choices: Vec<LLMChoice>,
    /// Creation time (presumably a Unix timestamp — confirm the unit with the producers).
    pub created: u64,
    /// Token usage, when available.
    pub usage: Option<LLMTokenUsage>,
    /// Identifier of the response.
    pub id: String,
}
968
/// Incremental content carried by an [`LLMStreamChoice`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    /// Text carried by this delta; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
974
/// One entry of [`LLMCompletionStreamResponse::choices`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    /// Reason generation stopped, if one was reported.
    pub finish_reason: Option<String>,
    /// Index of this choice.
    pub index: u32,
    /// Optional full message attached to this choice.
    pub message: Option<LLMMessage>,
    /// The delta for this choice; always present.
    pub delta: LLMStreamDelta,
}
982
/// One response object of a streamed completion.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    /// Model name reported in the response.
    pub model: String,
    /// Object type string reported in the response.
    pub object: String,
    /// The streamed choices.
    pub choices: Vec<LLMStreamChoice>,
    /// Creation time (presumably a Unix timestamp — confirm the unit with the producers).
    pub created: u64,
    /// Token usage, when available; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    /// Identifier of the response.
    pub id: String,
    /// Optional list of citation strings. Unlike `usage`, this field has no
    /// `skip_serializing_if`, so `None` is still serialized (as `null` in JSON).
    pub citations: Option<Vec<String>>,
}
994
/// Description of a tool that can be offered to a model.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Schema of the tool's input as arbitrary JSON (presumably JSON Schema — confirm).
    pub input_schema: serde_json::Value,
}
1001
/// Token counts for a request, with an optional per-category breakdown.
///
/// `Default` yields zero for every count and no breakdown.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
    /// Total token count. This is stored as its own field; the struct does not
    /// compute it from the other two.
    pub total_tokens: u32,

    /// Per-category breakdown; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
1011
/// Token categories; one variant per field of [`PromptTokensDetails`], as
/// paired up by `PromptTokensDetails::iter`.
///
/// Serialized in snake_case (e.g. `"cache_read_input_tokens"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    /// Serialized as `"input_tokens"`.
    InputTokens,
    /// Serialized as `"output_tokens"`.
    OutputTokens,
    /// Serialized as `"cache_read_input_tokens"`.
    CacheReadInputTokens,
    /// Serialized as `"cache_write_input_tokens"`.
    CacheWriteInputTokens,
}
1020
/// Optional per-category token counts.
///
/// Every field is optional and is omitted from the serialized output when
/// `None`. `Default` yields `None` for all four fields.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    /// Input token count, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    /// Output token count, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    /// Cache-read input token count, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    /// Cache-write input token count, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
1032
1033impl PromptTokensDetails {
1034    /// Returns an iterator over the token types and their values
1035    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
1036        [
1037            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
1038            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
1039            (
1040                TokenType::CacheReadInputTokens,
1041                self.cache_read_input_tokens.unwrap_or(0),
1042            ),
1043            (
1044                TokenType::CacheWriteInputTokens,
1045                self.cache_write_input_tokens.unwrap_or(0),
1046            ),
1047        ]
1048        .into_iter()
1049    }
1050}
1051
1052impl std::ops::Add for PromptTokensDetails {
1053    type Output = Self;
1054
1055    fn add(self, rhs: Self) -> Self::Output {
1056        Self {
1057            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
1058            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
1059            cache_read_input_tokens: Some(
1060                self.cache_read_input_tokens.unwrap_or(0)
1061                    + rhs.cache_read_input_tokens.unwrap_or(0),
1062            ),
1063            cache_write_input_tokens: Some(
1064                self.cache_write_input_tokens.unwrap_or(0)
1065                    + rhs.cache_write_input_tokens.unwrap_or(0),
1066            ),
1067        }
1068    }
1069}
1070
1071impl std::ops::AddAssign for PromptTokensDetails {
1072    fn add_assign(&mut self, rhs: Self) {
1073        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
1074        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
1075        self.cache_read_input_tokens = Some(
1076            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
1077        );
1078        self.cache_write_input_tokens = Some(
1079            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
1080        );
1081    }
1082}
1083
/// One incremental event of a streamed generation; this is the item type of
/// `LLMStreamInput::stream_channel_tx`.
///
/// Internally tagged: the serialized object carries a `type` field holding the
/// variant name as written here (`Content`, `Thinking`, ...), since no rename
/// is applied.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    /// A piece of response text.
    Content { content: String },
    /// A piece of thinking text.
    Thinking { thinking: String },
    /// A tool-use update.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token usage information.
    Usage { usage: LLMTokenUsage },
    /// Arbitrary JSON metadata.
    Metadata { metadata: serde_json::Value },
}
1093
/// Payload of [`GenerationDelta::ToolUse`].
// NOTE(review): `id`, `name`, and `input` are all optional, presumably because a tool call
// arrives in fragments that the consumer assembles by `index` — confirm in the stream handlers.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    /// Tool call identifier, when present in this delta.
    pub id: Option<String>,
    /// Tool name, when present in this delta.
    pub name: Option<String>,
    /// Tool input as a string, when present in this delta.
    pub input: Option<String>,
    /// Index associated with this tool use.
    pub index: usize,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
1104
// Unit tests for `ProviderConfig` and `LLMProviderConfig`: serde round-trips
// (JSON and TOML), constructor helpers, and accessor behavior per provider.
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // ProviderConfig Tests
    // =========================================================================

    // Serialization checks below match on raw JSON substrings, so they depend
    // on serde_json's compact output (no whitespace between tokens).

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    // Exercises the constructor helpers and the accessors they feed.
    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_set_api_endpoint_updates_supported_providers() {
        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));

        // Bedrock is expected to ignore the endpoint: the getter still reports None.
        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
        assert_eq!(bedrock.api_endpoint(), None);
    }

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_toml_parsing() {
        // Despite the test name, the input is JSON: a map of provider name to
        // config, shaped like the `providers` tables of the TOML configuration.
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    // =========================================================================
    // Bedrock ProviderConfig Tests
    // =========================================================================

    // The Bedrock variant is tagged "amazon-bedrock" on the wire, and the same
    // string is expected from `provider_type()`.

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None); // Bedrock uses AWS credential chain
        assert_eq!(config.api_endpoint(), None); // No custom endpoint
    }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    // Serializes to TOML and parses it back; the result must equal the input.
    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        assert!(result.is_err()); // region is required
    }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    // `region()` and `profile_name()` are Bedrock-only accessors; other
    // providers are expected to report None.

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}