// stakpak_shared/models/llm.rs
1//! LLM Provider and Model Configuration
2//!
3//! This module provides the configuration types for LLM providers and models.
4//!
5//! # Provider Configuration
6//!
7//! Providers are configured in a `providers` HashMap where the key becomes the
8//! model prefix for routing requests to the correct provider.
9//!
10//! ## Built-in Providers
11//!
12//! - `openai` - OpenAI API
13//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
14//! - `gemini` - Google Gemini API
15//! - `bedrock` - AWS Bedrock (uses AWS credential chain, no API key)
16//!
17//! For built-in providers, you can use the model name directly without a prefix:
18//! - `claude-sonnet-4-5` → auto-detected as Anthropic
19//! - `gpt-4` → auto-detected as OpenAI
20//! - `gemini-2.5-pro` → auto-detected as Gemini
21//!
22//! ## Custom Providers
23//!
24//! Any OpenAI-compatible API can be configured using `type = "custom"`.
25//! The provider key becomes the model prefix.
26//!
27//! # Model Routing
28//!
29//! Models can be specified with or without a provider prefix:
30//!
31//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
32//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
33//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
34//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
35//!   sends `anthropic/claude-opus` to the API
36//!
37//! # Example Configuration
38//!
39//! ```toml
40//! [profiles.default]
41//! provider = "local"
42//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
43//! eco_model = "offline/llama3"       # custom provider
44//!
45//! [profiles.default.providers.anthropic]
46//! type = "anthropic"
47//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
48//!
49//! [profiles.default.providers.offline]
50//! type = "custom"
51//! api_endpoint = "http://localhost:11434/v1"
52//! ```
53
54use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58use super::auth::ProviderAuth;
59
60// =============================================================================
61// Provider Configuration
62// =============================================================================
63
64/// Unified provider configuration enum
65///
66/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
67/// where the key is the provider name and becomes the model prefix for routing.
68///
69/// # Provider Key = Model Prefix
70///
71/// The key used in the HashMap becomes the prefix used in model names:
72/// - Config key: `providers.offline`
73/// - Model usage: `offline/llama3`
74/// - Routing: finds `offline` provider, sends `llama3` to API
75///
76/// # Example TOML
77/// ```toml
78/// [profiles.myprofile.providers.openai]
79/// type = "openai"
80///
81/// [profiles.myprofile.providers.openai.auth]
82/// type = "api"
83/// key = "sk-..."
84///
85/// [profiles.myprofile.providers.anthropic]
86/// type = "anthropic"
87///
88/// [profiles.myprofile.providers.anthropic.auth]
89/// type = "oauth"
90/// access = "eyJ..."
91/// refresh = "eyJ..."
92/// expires = 1735600000000
93/// name = "Claude Max"
94///
95/// [profiles.myprofile.providers.offline]
96/// type = "custom"
97/// api_endpoint = "http://localhost:11434/v1"
98/// ```
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI provider configuration
    OpenAI {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional API endpoint override
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Anthropic provider configuration
    Anthropic {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional API endpoint override
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Legacy OAuth access token (prefer `auth` field with OAuth type)
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
        /// Authentication credentials (preferred over api_key/access_token)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Google Gemini provider configuration
    Gemini {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional API endpoint override
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
    ///
    /// The provider key in the config becomes the model prefix.
    /// For example, if configured as `providers.offline`, use models as:
    /// - `offline/llama3` - passes `llama3` to the API
    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.offline]
    /// type = "custom"
    /// api_endpoint = "http://localhost:11434/v1"
    ///
    /// # Then use models as:
    /// model = "offline/llama3"
    /// ```
    Custom {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (required for custom providers)
        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
        api_endpoint: String,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Stakpak provider configuration
    ///
    /// Routes inference through Stakpak's unified API, which provides:
    /// - Access to multiple LLM providers via a single endpoint
    /// - Usage tracking and billing
    /// - Session management and checkpoints
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.stakpak]
    /// type = "stakpak"
    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
    ///
    /// [profiles.myprofile.providers.stakpak.auth]
    /// type = "api"
    /// key = "your-stakpak-api-key"
    ///
    /// # Then use models as:
    /// model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
    /// ```
    Stakpak {
        /// Legacy API key field (prefer `auth` field)
        /// Note: This field is optional when using `auth`
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (default: https://apiv2.stakpak.dev)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// AWS Bedrock provider configuration
    ///
    /// Uses AWS credential chain for authentication (no API key needed).
    /// Supports env vars, shared credentials, SSO, and instance roles.
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.amazon-bedrock]
    /// type = "amazon-bedrock"
    /// region = "us-east-1"
    /// profile_name = "my-aws-profile"  # optional
    ///
    /// # Then use models as (friendly aliases work):
    /// model = "amazon-bedrock/claude-sonnet-4-5"
    /// ```
    #[serde(rename = "amazon-bedrock")]
    Bedrock {
        /// AWS region (e.g., "us-east-1")
        region: String,
        /// Optional AWS named profile (from ~/.aws/config)
        #[serde(skip_serializing_if = "Option::is_none")]
        profile_name: Option<String>,
    },
}
221
222impl ProviderConfig {
223    /// Get the provider type name
224    pub fn provider_type(&self) -> &'static str {
225        match self {
226            ProviderConfig::OpenAI { .. } => "openai",
227            ProviderConfig::Anthropic { .. } => "anthropic",
228            ProviderConfig::Gemini { .. } => "gemini",
229            ProviderConfig::Custom { .. } => "custom",
230            ProviderConfig::Stakpak { .. } => "stakpak",
231            ProviderConfig::Bedrock { .. } => "amazon-bedrock",
232        }
233    }
234
235    /// Get the API key if set (checks `auth` field first, then legacy `api_key`)
236    pub fn api_key(&self) -> Option<&str> {
237        // First check auth field
238        if let Some(auth) = self.get_auth_ref()
239            && let Some(key) = auth.api_key_value()
240        {
241            return Some(key);
242        }
243        // Fall back to legacy api_key field
244        match self {
245            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
246            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
247            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
248            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
249            ProviderConfig::Stakpak { api_key, .. } => api_key.as_deref(),
250            ProviderConfig::Bedrock { .. } => None, // AWS credential chain, no API key
251        }
252    }
253
254    /// Get the auth credentials reference
255    fn get_auth_ref(&self) -> Option<&ProviderAuth> {
256        match self {
257            ProviderConfig::OpenAI { auth, .. } => auth.as_ref(),
258            ProviderConfig::Anthropic { auth, .. } => auth.as_ref(),
259            ProviderConfig::Gemini { auth, .. } => auth.as_ref(),
260            ProviderConfig::Custom { auth, .. } => auth.as_ref(),
261            ProviderConfig::Stakpak { auth, .. } => auth.as_ref(),
262            ProviderConfig::Bedrock { .. } => None,
263        }
264    }
265
266    /// Get resolved authentication credentials.
267    ///
268    /// Resolution order:
269    /// 1. `auth` field (preferred)
270    /// 2. Legacy `api_key` field (converted to ProviderAuth::Api)
271    /// 3. Legacy `access_token` field for Anthropic (converted to ProviderAuth with access token)
272    pub fn get_auth(&self) -> Option<ProviderAuth> {
273        // First check auth field
274        if let Some(auth) = self.get_auth_ref() {
275            return Some(auth.clone());
276        }
277
278        // Fall back to legacy fields
279        match self {
280            ProviderConfig::OpenAI { api_key, .. }
281            | ProviderConfig::Gemini { api_key, .. }
282            | ProviderConfig::Custom { api_key, .. }
283            | ProviderConfig::Stakpak { api_key, .. } => {
284                api_key.as_ref().map(ProviderAuth::api_key)
285            }
286            ProviderConfig::Anthropic {
287                api_key,
288                access_token,
289                ..
290            } => {
291                // Prefer api_key, then access_token (as OAuth bearer token, not API key)
292                if let Some(key) = api_key {
293                    Some(ProviderAuth::api_key(key))
294                } else {
295                    // Legacy access_token is an OAuth bearer token — wrap it as OAuth
296                    // with empty refresh token and zero expiry so it will be treated as
297                    // expired and trigger a re-auth rather than silently failing.
298                    access_token
299                        .as_ref()
300                        .map(|token| ProviderAuth::oauth(token, "", 0))
301                }
302            }
303            ProviderConfig::Bedrock { .. } => None,
304        }
305    }
306
307    /// Set authentication credentials on this provider config.
308    ///
309    /// Also clears any legacy credential fields (`api_key`, `access_token`)
310    /// so they don't shadow the new `auth` field on future reads.
311    pub fn set_auth(&mut self, auth: ProviderAuth) {
312        match self {
313            ProviderConfig::OpenAI {
314                auth: auth_field,
315                api_key,
316                ..
317            }
318            | ProviderConfig::Gemini {
319                auth: auth_field,
320                api_key,
321                ..
322            }
323            | ProviderConfig::Custom {
324                auth: auth_field,
325                api_key,
326                ..
327            }
328            | ProviderConfig::Stakpak {
329                auth: auth_field,
330                api_key,
331                ..
332            } => {
333                *auth_field = Some(auth);
334                *api_key = None;
335            }
336            ProviderConfig::Anthropic {
337                auth: auth_field,
338                api_key,
339                access_token,
340                ..
341            } => {
342                *auth_field = Some(auth);
343                *api_key = None;
344                *access_token = None;
345            }
346            ProviderConfig::Bedrock { .. } => {
347                // Bedrock uses AWS credential chain, no auth field
348            }
349        }
350    }
351
352    /// Clear authentication credentials from this provider config.
353    ///
354    /// Clears both the `auth` field and any legacy credential fields
355    /// (`api_key`, `access_token`) to ensure credentials are fully removed.
356    pub fn clear_auth(&mut self) {
357        match self {
358            ProviderConfig::OpenAI {
359                auth: auth_field,
360                api_key,
361                ..
362            }
363            | ProviderConfig::Gemini {
364                auth: auth_field,
365                api_key,
366                ..
367            }
368            | ProviderConfig::Custom {
369                auth: auth_field,
370                api_key,
371                ..
372            }
373            | ProviderConfig::Stakpak {
374                auth: auth_field,
375                api_key,
376                ..
377            } => {
378                *auth_field = None;
379                *api_key = None;
380            }
381            ProviderConfig::Anthropic {
382                auth: auth_field,
383                api_key,
384                access_token,
385                ..
386            } => {
387                *auth_field = None;
388                *api_key = None;
389                *access_token = None;
390            }
391            ProviderConfig::Bedrock { .. } => {
392                // Bedrock uses AWS credential chain, no auth field
393            }
394        }
395    }
396
397    /// Get the API endpoint if set
398    pub fn api_endpoint(&self) -> Option<&str> {
399        match self {
400            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
401            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
402            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
403            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
404            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
405            ProviderConfig::Bedrock { .. } => None, // No custom endpoint in config
406        }
407    }
408
409    /// Set the API endpoint for providers that support it.
410    ///
411    /// For `Custom`, `None` is ignored because custom providers require an endpoint.
412    /// For `Bedrock`, this is a no-op.
413    pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
414        match self {
415            ProviderConfig::OpenAI { api_endpoint, .. }
416            | ProviderConfig::Anthropic { api_endpoint, .. }
417            | ProviderConfig::Gemini { api_endpoint, .. }
418            | ProviderConfig::Stakpak { api_endpoint, .. } => {
419                *api_endpoint = endpoint;
420            }
421            ProviderConfig::Custom { api_endpoint, .. } => {
422                if let Some(custom_endpoint) = endpoint {
423                    *api_endpoint = custom_endpoint;
424                }
425            }
426            ProviderConfig::Bedrock { .. } => {}
427        }
428    }
429
430    /// Get the access token (Anthropic only)
431    ///
432    /// Checks the `auth` field first for OAuth access token, then falls back
433    /// to the legacy `access_token` field.
434    pub fn access_token(&self) -> Option<&str> {
435        // First check auth field for OAuth access token
436        if let Some(auth) = self.get_auth_ref()
437            && let Some(token) = auth.access_token()
438        {
439            return Some(token);
440        }
441        // Fall back to legacy access_token field
442        match self {
443            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
444            _ => None,
445        }
446    }
447
448    /// Create an OpenAI provider config (legacy, uses api_key field)
449    pub fn openai(api_key: Option<String>) -> Self {
450        ProviderConfig::OpenAI {
451            api_key,
452            api_endpoint: None,
453            auth: None,
454        }
455    }
456
457    /// Create an OpenAI provider config with auth
458    pub fn openai_with_auth(auth: ProviderAuth) -> Self {
459        ProviderConfig::OpenAI {
460            api_key: None,
461            api_endpoint: None,
462            auth: Some(auth),
463        }
464    }
465
466    /// Create an Anthropic provider config (legacy, uses api_key/access_token fields)
467    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
468        ProviderConfig::Anthropic {
469            api_key,
470            api_endpoint: None,
471            access_token,
472            auth: None,
473        }
474    }
475
476    /// Create an Anthropic provider config with auth
477    pub fn anthropic_with_auth(auth: ProviderAuth) -> Self {
478        ProviderConfig::Anthropic {
479            api_key: None,
480            api_endpoint: None,
481            access_token: None,
482            auth: Some(auth),
483        }
484    }
485
486    /// Create a Gemini provider config (legacy, uses api_key field)
487    pub fn gemini(api_key: Option<String>) -> Self {
488        ProviderConfig::Gemini {
489            api_key,
490            api_endpoint: None,
491            auth: None,
492        }
493    }
494
495    /// Create a Gemini provider config with auth
496    pub fn gemini_with_auth(auth: ProviderAuth) -> Self {
497        ProviderConfig::Gemini {
498            api_key: None,
499            api_endpoint: None,
500            auth: Some(auth),
501        }
502    }
503
504    /// Create a custom provider config (legacy, uses api_key field)
505    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
506        ProviderConfig::Custom {
507            api_key,
508            api_endpoint,
509            auth: None,
510        }
511    }
512
513    /// Create a custom provider config with auth
514    pub fn custom_with_auth(api_endpoint: String, auth: ProviderAuth) -> Self {
515        ProviderConfig::Custom {
516            api_key: None,
517            api_endpoint,
518            auth: Some(auth),
519        }
520    }
521
522    /// Create a Stakpak provider config (legacy, uses api_key field)
523    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
524        ProviderConfig::Stakpak {
525            api_key: Some(api_key),
526            api_endpoint,
527            auth: None,
528        }
529    }
530
531    /// Create a Stakpak provider config with auth
532    pub fn stakpak_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
533        ProviderConfig::Stakpak {
534            api_key: None,
535            api_endpoint,
536            auth: Some(auth),
537        }
538    }
539
540    /// Create a Bedrock provider config
541    pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
542        ProviderConfig::Bedrock {
543            region,
544            profile_name,
545        }
546    }
547
548    /// Get the AWS region (Bedrock only)
549    pub fn region(&self) -> Option<&str> {
550        match self {
551            ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
552            _ => None,
553        }
554    }
555
556    /// Get the AWS profile name (Bedrock only)
557    pub fn profile_name(&self) -> Option<&str> {
558        match self {
559            ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
560            _ => None,
561        }
562    }
563
564    /// Create an empty provider config for a given provider name.
565    ///
566    /// Used during migration when we need to create a provider config
567    /// to attach auth credentials to.
568    pub fn empty_for_provider(provider_name: &str) -> Option<Self> {
569        match provider_name {
570            "openai" => Some(ProviderConfig::OpenAI {
571                api_key: None,
572                api_endpoint: None,
573                auth: None,
574            }),
575            "anthropic" => Some(ProviderConfig::Anthropic {
576                api_key: None,
577                api_endpoint: None,
578                access_token: None,
579                auth: None,
580            }),
581            "gemini" => Some(ProviderConfig::Gemini {
582                api_key: None,
583                api_endpoint: None,
584                auth: None,
585            }),
586            "stakpak" => Some(ProviderConfig::Stakpak {
587                api_key: None,
588                api_endpoint: None,
589                auth: None,
590            }),
591            // Custom providers need an endpoint, Bedrock uses AWS credential chain
592            _ => None,
593        }
594    }
595}
596
/// Aggregated provider configuration for LLM operations
///
/// This struct holds all configured providers, keyed by provider name.
/// The key doubles as the model-routing prefix (see module docs:
/// `offline/llama3` routes to the provider stored under `offline`).
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    /// All provider configurations (key = provider name)
    pub providers: HashMap<String, ProviderConfig>,
}
605
606impl LLMProviderConfig {
607    /// Create a new empty provider config
608    pub fn new() -> Self {
609        Self {
610            providers: HashMap::new(),
611        }
612    }
613
614    /// Add a provider configuration
615    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
616        self.providers.insert(name.into(), config);
617    }
618
619    /// Get a provider configuration by name
620    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
621        self.providers.get(name)
622    }
623
624    /// Check if any providers are configured
625    pub fn is_empty(&self) -> bool {
626        self.providers.is_empty()
627    }
628}
629
/// Provider-specific options for LLM requests
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google/Gemini-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

/// Anthropic-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended thinking configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

/// Thinking/reasoning options
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Budget tokens for thinking (must be >= 1024; the `new` constructor
    /// clamps smaller values up to that minimum)
    pub budget_tokens: u32,
}
660
661impl LLMThinkingOptions {
662    pub fn new(budget_tokens: u32) -> Self {
663        Self {
664            budget_tokens: budget_tokens.max(1024),
665        }
666    }
667}
668
/// OpenAI-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

/// Google/Gemini-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Thinking budget in tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}

/// A non-streaming LLM completion request.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    /// Target model (may carry a provider prefix; see module docs)
    pub model: Model,
    /// Conversation messages to send
    pub messages: Vec<LLMMessage>,
    /// Maximum number of tokens to generate
    pub max_tokens: u32,
    /// Tools the model may call, if any
    pub tools: Option<Vec<LLMTool>>,
    /// Provider-specific request options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}

/// A streaming LLM completion request.
///
/// Mirrors [`LLMInput`] plus a channel on which generation deltas are sent.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    /// Sender side of the channel that receives incremental output
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    pub headers: Option<std::collections::HashMap<String, String>>,
}
709
710impl From<&LLMStreamInput> for LLMInput {
711    fn from(value: &LLMStreamInput) -> Self {
712        LLMInput {
713            model: value.model.clone(),
714            messages: value.messages.clone(),
715            max_tokens: value.max_tokens,
716            tools: value.tools.clone(),
717            provider_options: value.provider_options.clone(),
718            headers: value.headers.clone(),
719        }
720    }
721}
722
/// A chat message: free-form role string plus string-or-typed content.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    /// Message role (free-form string, e.g. "user" or "assistant")
    pub role: String,
    /// Message body: plain text or a list of typed content parts
    pub content: LLMMessageContent,
}
728
729#[derive(Serialize, Deserialize, Debug, Clone)]
730pub struct SimpleLLMMessage {
731    #[serde(rename = "role")]
732    pub role: SimpleLLMRole,
733    pub content: String,
734}
735
/// Role of a simple message author.
///
/// Serialized as lowercase strings ("user", "assistant"), matching the
/// `Display` implementation.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}
742
743impl std::fmt::Display for SimpleLLMRole {
744    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
745        match self {
746            SimpleLLMRole::User => write!(f, "user"),
747            SimpleLLMRole::Assistant => write!(f, "assistant"),
748        }
749    }
750}
751
/// Message content: either a bare string or a list of typed parts.
///
/// `#[serde(untagged)]` lets deserialization accept both a JSON string and
/// a JSON array of typed content objects.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}
758
759#[allow(clippy::to_string_trait_impl)]
760impl ToString for LLMMessageContent {
761    fn to_string(&self) -> String {
762        match self {
763            LLMMessageContent::String(s) => s.clone(),
764            LLMMessageContent::List(l) => l
765                .iter()
766                .map(|c| match c {
767                    LLMMessageTypedContent::Text { text } => text.clone(),
768                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
769                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
770                    LLMMessageTypedContent::Image { .. } => String::new(),
771                })
772                .collect::<Vec<_>>()
773                .join("\n"),
774        }
775    }
776}
777
778impl From<String> for LLMMessageContent {
779    fn from(value: String) -> Self {
780        LLMMessageContent::String(value)
781    }
782}
783
784impl Default for LLMMessageContent {
785    fn default() -> Self {
786        LLMMessageContent::String(String::new())
787    }
788}
789
790impl LLMMessageContent {
791    /// Convert into a Vec of typed content parts.
792    /// A `String` variant is returned as a single `Text` part (empty strings yield an empty vec).
793    pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
794        match self {
795            LLMMessageContent::List(parts) => parts,
796            LLMMessageContent::String(s) if s.is_empty() => vec![],
797            LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
798        }
799    }
800}
801
/// A typed content part within a message, tagged by `type` in JSON.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    /// Plain text content
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        /// Tool arguments (also accepts `input` as the key on deserialize)
        #[serde(alias = "input")]
        args: serde_json::Value,
        /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    /// The result of a tool call, keyed back to it by `tool_use_id`
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    /// An image content part
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

/// Source payload for an image content part.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    /// Source kind; field serialized as `type` (`r#type` is a raw identifier)
    #[serde(rename = "type")]
    pub r#type: String,
    /// Image MIME type — presumably, given the name (e.g. "image/png");
    /// TODO confirm with producers
    pub media_type: String,
    /// Image payload — NOTE(review): looks base64-encoded given the string
    /// type and `media_type` pairing; confirm against callers
    pub data: String,
}
833
834impl Default for LLMMessageTypedContent {
835    fn default() -> Self {
836        LLMMessageTypedContent::Text {
837            text: String::new(),
838        }
839    }
840}
841
/// One completion choice in a non-streaming response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    /// Why generation stopped, if reported by the provider
    pub finish_reason: Option<String>,
    /// Position of this choice in the response
    pub index: u32,
    /// The generated message
    pub message: LLMMessage,
}

/// A non-streaming completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    /// Creation timestamp — presumably epoch seconds; TODO confirm with producers
    pub created: u64,
    /// Token accounting, when the provider reports it
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

/// Incremental content carried by one stream chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

/// One choice within a streaming response chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    /// Full message, when the chunk carries one
    pub message: Option<LLMMessage>,
    /// The incremental delta for this chunk
    pub delta: LLMStreamDelta,
}

/// A streaming completion response chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    /// Citation URLs, if any.
    /// NOTE(review): unlike `usage`, not marked `skip_serializing_if`, so a
    /// `null` is emitted when absent — confirm consumers rely on (or at
    /// least tolerate) that before changing it.
    pub citations: Option<Vec<String>>,
}
884
/// A tool the model may invoke: name, human-readable description, and the
/// schema of its expected input.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    /// Tool identifier the model uses to call it.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: String,
    /// Schema for the tool's arguments, stored as raw JSON — presumably JSON
    /// Schema per the Anthropic `input_schema` convention; TODO confirm.
    pub input_schema: serde_json::Value,
}
891
/// Aggregate token counts for a request (OpenAI-style field names).
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion/output.
    pub completion_tokens: u32,
    /// Total tokens — presumably prompt + completion; not enforced here.
    pub total_tokens: u32,

    /// Finer-grained breakdown (cache reads/writes, etc.); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
901
/// Categories of tokens tracked in a prompt-token breakdown.
/// Serialized in snake_case (e.g. `input_tokens`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    /// Regular (non-cached) input tokens.
    InputTokens,
    /// Output tokens.
    OutputTokens,
    /// Input tokens served from the provider's prompt cache.
    CacheReadInputTokens,
    /// Input tokens written into the provider's prompt cache.
    CacheWriteInputTokens,
}
910
/// Per-category token counts; each field is omitted from JSON when `None`.
/// The arithmetic impls in this module treat `None` as 0.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
922
923impl PromptTokensDetails {
924    /// Returns an iterator over the token types and their values
925    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
926        [
927            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
928            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
929            (
930                TokenType::CacheReadInputTokens,
931                self.cache_read_input_tokens.unwrap_or(0),
932            ),
933            (
934                TokenType::CacheWriteInputTokens,
935                self.cache_write_input_tokens.unwrap_or(0),
936            ),
937        ]
938        .into_iter()
939    }
940}
941
942impl std::ops::Add for PromptTokensDetails {
943    type Output = Self;
944
945    fn add(self, rhs: Self) -> Self::Output {
946        Self {
947            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
948            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
949            cache_read_input_tokens: Some(
950                self.cache_read_input_tokens.unwrap_or(0)
951                    + rhs.cache_read_input_tokens.unwrap_or(0),
952            ),
953            cache_write_input_tokens: Some(
954                self.cache_write_input_tokens.unwrap_or(0)
955                    + rhs.cache_write_input_tokens.unwrap_or(0),
956            ),
957        }
958    }
959}
960
961impl std::ops::AddAssign for PromptTokensDetails {
962    fn add_assign(&mut self, rhs: Self) {
963        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
964        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
965        self.cache_read_input_tokens = Some(
966            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
967        );
968        self.cache_write_input_tokens = Some(
969            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
970        );
971    }
972}
973
/// One incremental event in a generation stream.
/// Internally tagged JSON: `{"type":"Content","content":...}` — variant
/// names are serialized as written (no rename_all).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    /// Visible output text fragment.
    Content { content: String },
    /// Model reasoning/thinking text fragment.
    Thinking { thinking: String },
    /// Incremental tool-call data.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token-usage update.
    Usage { usage: LLMTokenUsage },
    /// Arbitrary provider metadata.
    Metadata { metadata: serde_json::Value },
}
983
/// Partial tool-call data carried by a `ToolUse` delta.
/// Presumably fields arrive incrementally across chunks, correlated by
/// `index` — TODO confirm with the streaming adapters.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    /// Tool-call id, when announced — presumably on the first fragment only.
    pub id: Option<String>,
    /// Tool name, when announced — presumably on the first fragment only.
    pub name: Option<String>,
    /// Argument fragment — presumably a piece of JSON-encoded input; TODO confirm.
    pub input: Option<String>,
    /// Position of this tool call within the response.
    pub index: usize,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
994
// Unit tests for ProviderConfig: JSON/TOML (de)serialization, helper
// constructors/accessors, and provider-map parsing. ProviderConfig itself
// is declared earlier in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // ProviderConfig Tests
    // =========================================================================

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        // Key-less custom providers (e.g. local Ollama) must still serialize.
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        // Exercises the convenience constructors and their accessors together.
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_set_api_endpoint_updates_supported_providers() {
        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));

        // Bedrock has no endpoint field, so set_api_endpoint is a no-op for it.
        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
        assert_eq!(bedrock.api_endpoint(), None);
    }

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_toml_parsing() {
        // Test parsing a HashMap of providers from TOML-like JSON
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    // =========================================================================
    // Bedrock ProviderConfig Tests
    // =========================================================================

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None); // Bedrock uses AWS credential chain
        assert_eq!(config.api_endpoint(), None); // No custom endpoint
    }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        // Bedrock exposes region/profile but none of the key/endpoint accessors.
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        // serialize -> parse -> compare; relies on ProviderConfig: PartialEq.
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        assert!(result.is_err()); // region is required
    }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}