Skip to main content

phi_core/provider/
model.rs

1//! Model configuration and provider compatibility flags.
2/*
3ARCHITECTURE: model.rs — the "model identity card"
4
5Every LLM provider has a different API shape, auth style, field names, and
6quirks. This module defines the data structures that capture all of that
7variation in a single `ModelConfig` value.
8
9Key types:
10  `ApiProtocol`  — which wire protocol to use (Anthropic vs OpenAI vs Gemini vs ...)
11  `ModelConfig`  — the full model "identity card": base_url, auth, limits, quirks
12  `OpenAiCompat` — per-provider flags for the 15+ OpenAI-compatible providers
13  `CostConfig`   — token pricing (optional, used for cost tracking)
14
15How it flows:
16  1. Caller builds or loads a `ModelConfig` (factory methods: `ModelConfig::anthropic()`,
17     `ModelConfig::openai()`, etc., or deserialize from JSON/YAML)
18  2. Sets it on `StreamConfig::model_config`
19  3. `ProviderRegistry::for_protocol()` picks the right `StreamProvider` impl
20     based on `config.api`
21  4. The provider uses `base_url`, `compat`, `headers` etc. from `ModelConfig`
22     to customise API calls
23
24Why not hard-code provider details in each provider file?
25  ModelConfig externalizes the provider-specific details so users can configure
26  custom endpoints, private deployments, or new providers without changing
27  provider source code.
28*/
29
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33
34use super::traits::ProviderError;
35
36// ---------------------------------------------------------------------------
37// CredentialProvider — pluggable, refreshable API-key source
38// ---------------------------------------------------------------------------
39
40/// Pluggable source of the API key for a [`ModelConfig`].
41///
42/// Long-running agents on short-lived credentials (AWS STS, OAuth, Workload-Identity)
43/// would otherwise hit `ProviderError::Auth` mid-run and stop. Wiring a
44/// `CredentialProvider` lets the agent resolve the current key per-call and refresh
45/// on auth failures — the retry loop in `streaming.rs` calls `invalidate()` once on
46/// `Auth` and retries the stream call before propagating.
47///
48/// The trait is intentionally tiny — implementors are free to cache, validate against
49/// an external metadata service, or block on a key-management API as needed.
50///
51/// # Example
52///
53/// ```no_run
54/// use async_trait::async_trait;
55/// use phi_core::provider::{CredentialProvider, ProviderError};
56/// use std::sync::Mutex;
57///
58/// #[derive(Debug)]
59/// struct StsProvider {
60///     cached: Mutex<Option<String>>,
61/// }
62///
63/// #[async_trait]
64/// impl CredentialProvider for StsProvider {
65///     async fn current(&self) -> Result<String, ProviderError> {
66///         if let Some(k) = self.cached.lock().unwrap().clone() {
67///             return Ok(k);
68///         }
69///         // Hit STS, cache, return... (omitted)
70///         Err(ProviderError::Auth("STS unavailable".into()))
71///     }
72///
73///     async fn invalidate(&self) -> Result<(), ProviderError> {
74///         self.cached.lock().unwrap().take();
75///         Ok(())
76///     }
77/// }
78/// ```
79#[async_trait::async_trait]
80pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
81    /// Return the current API key for this credential. Implementations may cache,
82    /// re-fetch from a metadata service, or compute on the fly. Called once per
83    /// `StreamProvider::stream()` invocation.
84    async fn current(&self) -> Result<String, ProviderError>;
85
86    /// Hint that the current cached credential has been rejected by the upstream
87    /// API and a fresh value should be fetched on the next `current()` call.
88    ///
89    /// Default impl is a no-op for providers that always re-fetch.
90    async fn invalidate(&self) -> Result<(), ProviderError> {
91        Ok(())
92    }
93}
94
95/// Reference implementation of [`CredentialProvider`] that always returns a fixed key.
96///
97/// Useful for tests and for wiring a [`ModelConfig`] uniformly when refresh is not
98/// needed — equivalent to leaving `ModelConfig::credentials = None` and relying on
99/// the static `api_key` field, but lets test harnesses count `invalidate()` calls.
#[derive(Clone, Debug)]
pub struct StaticCredentialProvider {
    /// The fixed key handed back on every lookup; it is never refreshed.
    key: String,
}
104
105impl StaticCredentialProvider {
106    pub fn new(key: impl Into<String>) -> Self {
107        Self { key: key.into() }
108    }
109}
110
111#[async_trait::async_trait]
112impl CredentialProvider for StaticCredentialProvider {
113    async fn current(&self) -> Result<String, ProviderError> {
114        Ok(self.key.clone())
115    }
116}
117
118/// Which API protocol a model uses.
119/*
120ARCHITECTURE: ApiProtocol — the dispatch key for the provider registry
121
122`ProviderRegistry::for_protocol(api: ApiProtocol)` maps each variant to
123a concrete `StreamProvider` implementation:
124  AnthropicMessages       → AnthropicProvider
125  OpenAiCompletions       → OpenAiCompatProvider (handles 15+ providers)
126  OpenAiResponses         → OpenAiResponsesProvider
127  AzureOpenAiResponses    → AzureOpenAiProvider
128  GoogleGenerativeAi      → GoogleProvider
129  GoogleVertex            → GoogleVertexProvider
130  BedrockConverseStream   → BedrockProvider
131
132This is the "Strategy via enum dispatch" pattern: the enum variant IS the strategy
133selector. The registry (registry.rs) `match`es on this enum and returns the right
134provider. At runtime, the caller only holds a `Box<dyn StreamProvider>` and never
135needs to know which variant was used.
136
137RUST QUIRK: `Hash` derive — required for use as HashMap keys
138
139`#[derive(Hash)]` enables values of this type to be used as keys in `HashMap<K, V>`.
140`Hash` computes an integer hash of the value. Combined with `PartialEq + Eq`
141(also derived), this is what HashMap needs:
142  - `Hash` to find the bucket
143  - `Eq` to confirm the key matches within the bucket (hash collisions)
144
145Why does `ApiProtocol` need to be a HashMap key?
146  In `ProviderRegistry`, we may store `HashMap<ApiProtocol, Box<dyn StreamProvider>>`.
147  Without `Hash + Eq`, that HashMap would fail to compile.
148
149RUST QUIRK: `Copy` on an enum with no data fields
150  All variants of `ApiProtocol` carry no data — they're just tags.
151  `Copy` lets the compiler bitwise-copy the value instead of moving it.
152  After `let api = model.api;`, `model.api` is STILL valid (Copy semantics).
153  Python analogy: Python enums are always by-reference, so no equivalent concept.
154
155RUST QUIRK: `#[serde(rename_all = "snake_case")]`
156  When serializing to JSON/YAML, variant names are converted to snake_case:
157    `AnthropicMessages` → "anthropic_messages"
158    `BedrockConverseStream` → "bedrock_converse_stream"
159  This makes config files human-readable without matching Rust's PascalCase convention.
160*/
161#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
162#[serde(rename_all = "snake_case")]
163pub enum ApiProtocol {
164    AnthropicMessages,
165    OpenAiCompletions,
166    OpenAiResponses,
167    AzureOpenAiResponses,
168    GoogleGenerativeAi,
169    GoogleVertex,
170    BedrockConverseStream,
171}
172
173impl std::fmt::Display for ApiProtocol {
174    /*
175    RUST QUIRK: Implementing `Display` manually (vs deriving it)
176
177    `Display` (the `{}` formatter) is NOT derivable — you must write it by hand.
178    `Debug` (the `{:?}` formatter) IS derivable.
179
180    Why? `Debug` is purely for developers (shows the Rust name), so auto-generated
181    is fine. `Display` is for end-users, and you control the string representation.
182
183    Here we return snake_case strings ("anthropic_messages") instead of the
184    Rust PascalCase names ("AnthropicMessages") — consistent with the serde rename.
185
186    `write!(f, "...")` — writes into the formatter buffer `f`.
187    Returns `fmt::Result` (Ok or Err), required by the trait.
188    Python analogy: implementing __str__(self) → return "anthropic_messages"
189    */
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        match self {
192            Self::AnthropicMessages => write!(f, "anthropic_messages"),
193            Self::OpenAiCompletions => write!(f, "openai_completions"),
194            Self::OpenAiResponses => write!(f, "openai_responses"),
195            Self::AzureOpenAiResponses => write!(f, "azure_openai_responses"),
196            Self::GoogleGenerativeAi => write!(f, "google_generative_ai"),
197            Self::GoogleVertex => write!(f, "google_vertex"),
198            Self::BedrockConverseStream => write!(f, "bedrock_converse_stream"),
199        }
200    }
201}
202
203/// Cost per million tokens (input/output).
204/*
205ARCHITECTURE: CostConfig — optional cost tracking
206
207LLM providers charge differently for input vs output tokens, and some offer
208reduced prices for cache reads and cache writes (Anthropic prompt caching).
209
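`CostConfig` is embedded in `ModelConfig` behind a `#[serde(default)]` field,
so callers who don't care about cost tracking can omit the whole `cost` section
and every price defaults to 0.0. When a `cost` section IS present, only the two
cache prices carry `#[serde(default)]`; `input_per_million` and
`output_per_million` must be supplied.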
213
214RUST QUIRK: `#[serde(default)]` — per-field default during deserialization
215  When deserializing a `ModelConfig`, if "cache_read_per_million" is absent in
216  the JSON/YAML, serde calls `Default::default()` for that field instead of
217  returning an error. This makes the struct forward-compatible: old config files
218  (without the cache fields) still deserialize correctly.
219  Python analogy: `dataclasses.field(default=0.0)` or `pydantic.Field(default=0.0)`
220*/
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct CostConfig {
223    pub input_per_million: f64,
224    pub output_per_million: f64,
225    #[serde(default)]
226    pub cache_read_per_million: f64,
227    #[serde(default)]
228    pub cache_write_per_million: f64,
229}
230
231impl Default for CostConfig {
232    fn default() -> Self {
233        Self {
234            input_per_million: 0.0,
235            output_per_million: 0.0,
236            cache_read_per_million: 0.0,
237            cache_write_per_million: 0.0,
238        }
239    }
240}
241
242/// How a provider handles the `max_tokens` field.
243/*
244ARCHITECTURE: MaxTokensField — a per-provider API quirk
245
246The OpenAI-compatible API has two field names for the same concept:
247  `max_tokens`           — the original field name, used by most providers
248  `max_completion_tokens`— new name, required by OpenAI o-series reasoning models
249
250Both control the maximum number of tokens in the response, but OpenAI split
251them so reasoning token budgets are counted separately. The provider must use
252the correct field name, or the API returns an error.
253
254`MaxTokensField` is a small enum used as a flag inside `OpenAiCompat`, avoiding
255a raw `bool` (which would be less self-documenting).
256
257RUST QUIRK: `#[derive(Default)]` + `#[default]` on a variant
258  `#[derive(Default)]` auto-generates `Default::default()` for the enum.
259  `#[default]` on a specific variant marks it as the default value:
260    `MaxTokensField::default()` → `MaxTokensField::MaxTokens`
261  Without `#[default]`, the derive macro wouldn't know which variant to pick.
262  Python analogy: no direct equivalent; closest is Enum with a class variable for default.
263*/
264#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
265#[serde(rename_all = "snake_case")]
266pub enum MaxTokensField {
267    #[default]
268    MaxTokens,
269    MaxCompletionTokens,
270}
271
272/// How a provider formats thinking/reasoning output.
273/*
274ARCHITECTURE: ThinkingFormat — per-provider reasoning output format
275
276Extended thinking / chain-of-thought output is formatted differently by each provider:
277  `OpenAi` — reasoning appears in a dedicated `reasoning_content` array
278  `Xai`    — Grok's format (slightly different JSON structure)
279  `Qwen`   — Qwen's format (another variation)
280
281This flag tells `openai_compat.rs` which parsing branch to use when extracting
282thinking deltas from the streaming response. Without this flag, we'd need a
283separate provider file for each thinking-capable service.
284*/
285#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
286#[serde(rename_all = "snake_case")]
287pub enum ThinkingFormat {
288    #[default]
289    OpenAi,
290    Xai,
291    Qwen,
292    /// OpenRouter streaming format: reads thinking text from `delta.reasoning_details`
293    /// array entries where `type == "thinking"`.
294    OpenRouter,
295}
296
297/// Compatibility flags for OpenAI-compatible providers.
298/// Different providers have different quirks even though they share the same base API.
299/*
300ARCHITECTURE: OpenAiCompat — the "quirk matrix" for 15+ OpenAI-compatible providers
301
302The OpenAI Chat Completions API is a de-facto standard that dozens of providers
303implement. But every provider deviates in small ways:
304  - OpenAI o-series uses `max_completion_tokens` not `max_tokens`
305  - xAI (Grok) uses a different thinking output format
306  - Some providers don't include usage data in streaming chunks
307  - Some require a `name` field in tool results
308  - Some need a dummy assistant message inserted after tool results
309
310Instead of writing a separate provider for each quirk combination, we have ONE
311`openai_compat.rs` provider that reads `OpenAiCompat` flags at runtime and
312branches accordingly. New providers = new `OpenAiCompat::new_provider()` factory.
313
314The factory methods (`openai()`, `xai()`, `groq()`, ...) use `..Default::default()`
315struct update syntax to express only the fields that differ from defaults.
316Python analogy: a dataclass with defaults, and factory classmethods that override
317only the fields that need to change.
318*/
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct OpenAiCompat {
321    /// Supports the `store` parameter for conversation persistence.
322    pub supports_store: bool,
323    /// Supports `developer` role (system-level instructions).
324    pub supports_developer_role: bool,
325    /// Supports `reasoning_effort` parameter.
326    pub supports_reasoning_effort: bool,
327    /// Includes usage data in streaming responses.
328    pub supports_usage_in_streaming: bool,
329    /// Which field name to use for max tokens.
330    pub max_tokens_field: MaxTokensField,
331    /// Tool results must include a `name` field.
332    pub requires_tool_result_name: bool,
333    /// Must insert an assistant message after tool results.
334    pub requires_assistant_after_tool_result: bool,
335    /// How thinking/reasoning content is formatted in streaming.
336    pub thinking_format: ThinkingFormat,
337}
338
339impl Default for OpenAiCompat {
340    /*
341    RUST QUIRK: `impl Default` manually (rather than `#[derive(Default)]`)
342
343    `#[derive(Default)]` would work only if every field's type implements `Default`
344    AND the zero-values are the right defaults. Here, `supports_usage_in_streaming`
345    should default to `true`, not `false`. Since `bool` defaults to `false`, we
346    must override it manually.
347
348    A manually written `Default` impl is common when some field defaults are
349    non-trivial (non-zero numbers, non-empty strings, true booleans, etc.).
350    */
351    fn default() -> Self {
352        Self {
353            supports_store: false,
354            supports_developer_role: false,
355            supports_reasoning_effort: false,
356            supports_usage_in_streaming: true, // most OpenAI-compat providers include usage
357            max_tokens_field: MaxTokensField::MaxTokens,
358            requires_tool_result_name: false,
359            requires_assistant_after_tool_result: false,
360            thinking_format: ThinkingFormat::OpenAi,
361        }
362    }
363}
364
365impl OpenAiCompat {
366    /// Compat flags for native OpenAI.
367    /*
368    RUST QUIRK: `..Default::default()` — struct update syntax for overriding defaults
369
370    `Self { supports_store: true, ..Default::default() }` means:
371      "build a Self where supports_store = true (and supports_developer_role = true,
372       supports_reasoning_effort = true, max_tokens_field = MaxCompletionTokens)
373       and all OTHER fields come from Default::default()"
374
375    The `..expr` "spreads" the remaining fields from a base value.
376    It MUST be last in the struct literal.
377    Python analogy: dataclasses.replace(OpenAiCompat(), supports_store=True, ...)
378
379    Why is this better than repeating all fields?
380      - Fewer lines to write (only express differences from defaults)
381      - If a new field is added with a sensible default, existing factory methods
382        automatically get the right value — no manual update needed
383    */
384    pub fn openai() -> Self {
385        Self {
386            supports_store: true,
387            supports_developer_role: true,
388            supports_reasoning_effort: true,
389            supports_usage_in_streaming: true,
390            max_tokens_field: MaxTokensField::MaxCompletionTokens,
391            ..Default::default()
392        }
393    }
394
395    /// Compat flags for xAI (Grok).
396    pub fn xai() -> Self {
397        Self {
398            supports_usage_in_streaming: true,
399            thinking_format: ThinkingFormat::Xai, // Grok uses a different thinking JSON shape
400            ..Default::default()
401        }
402    }
403
404    /// Compat flags for Groq.
405    pub fn groq() -> Self {
406        Self {
407            supports_usage_in_streaming: true,
408            ..Default::default()
409        }
410    }
411
412    /// Compat flags for Cerebras.
413    pub fn cerebras() -> Self {
414        Self::default() // no deviations from defaults
415    }
416
417    /// Compat flags for OpenRouter.
418    pub fn openrouter() -> Self {
419        Self {
420            supports_developer_role: true, // OpenRouter supports "developer" role
421            supports_usage_in_streaming: true,
422            max_tokens_field: MaxTokensField::MaxTokens, // OpenRouter uses max_tokens (not max_completion_tokens)
423            thinking_format: ThinkingFormat::OpenRouter, // reasoning_details array format
424            ..Default::default()
425        }
426    }
427
428    /// Compat flags for Mistral.
429    pub fn mistral() -> Self {
430        Self {
431            supports_usage_in_streaming: true,
432            max_tokens_field: MaxTokensField::MaxTokens,
433            ..Default::default()
434        }
435    }
436
437    /// Compat flags for DeepSeek.
438    pub fn deepseek() -> Self {
439        Self {
440            supports_usage_in_streaming: true,
441            max_tokens_field: MaxTokensField::MaxCompletionTokens,
442            ..Default::default()
443        }
444    }
445}
446
447/// Full model configuration. Knows everything needed to make API calls.
448/*
449ARCHITECTURE: ModelConfig — the single source of truth for a model's identity
450
451`ModelConfig` bundles everything a provider needs to make API calls:
452  - `id` / `name`    — which model to request (sent in the API body)
453  - `api`            — which provider implementation to use (dispatch key)
454  - `provider`       — human label for logging/display
455  - `base_url`       — the HTTP endpoint (can be a private deployment or proxy)
456  - `reasoning`      — whether this model supports extended thinking
457  - `context_window` — max input tokens (used for context compaction decisions)
458  - `max_tokens`     — default output token limit
459  - `cost`           — token pricing for cost tracking
460  - `headers`        — additional HTTP headers (e.g., API-version headers)
461  - `compat`         — OpenAI quirk flags (only for OpenAiCompletions protocol)
462
Factory methods (`anthropic()`, `openai()`, `local()`, `google()`,
`openrouter()`) cover common cases. Custom providers are built by constructing
the struct directly.
465
466RUST QUIRK: `HashMap<String, String>` — a key-value dictionary
467  `HashMap<K, V>` from `std::collections` — Rust's standard hash map.
468  Here it stores additional HTTP headers like `{"X-My-Header": "value"}`.
469  Python analogy: `dict[str, str]`.
470  `#[serde(default)]` means it deserializes as an empty HashMap if absent in config.
471
472RUST QUIRK: `Option<OpenAiCompat>` — present only for OpenAI-compat providers
473  Anthropic/Google/Bedrock have their own provider files that don't use `compat`.
474  For them, `compat` is `None`. For OpenAI-compatible providers, `compat` is
475  `Some(OpenAiCompat { ... })`. This models "this field only makes sense for
476  a subset of configurations." The provider accesses it with `compat.as_ref()?` or
477  `compat.unwrap_or_default()`.
478*/
479#[derive(Debug, Clone, Serialize, Deserialize)]
480pub struct ModelConfig {
481    /// Model identifier sent to the API (e.g. "gpt-4o", "claude-sonnet-4-20250514").
482    pub id: String,
483    /// Human-friendly name.
484    pub name: String,
485    /// Which API protocol to use.
486    pub api: ApiProtocol,
487    /// Provider name (e.g. "openai", "anthropic", "xai").
488    pub provider: String,
489    /// Base URL for API requests (without trailing slash).
490    pub base_url: String,
491    /// Authentication credential for this provider (API key, Bearer token, or
492    /// `access_key:secret[:session_token]` for Bedrock).
493    /// Defaults to an empty string so config files can omit it and supply via env instead.
494    #[serde(default)]
495    pub api_key: String,
496    /// Whether this model supports reasoning/thinking.
497    pub reasoning: bool,
498    /// Context window size in tokens.
499    pub context_window: u32,
500    /// Default max output tokens.
501    pub max_tokens: u32,
502    /// Cost configuration.
503    #[serde(default)]
504    pub cost: CostConfig,
505    /// Additional headers to send with requests.
506    #[serde(default)]
507    pub headers: HashMap<String, String>,
508    /// OpenAI-compat quirk flags (only for OpenAiCompletions protocol).
509    #[serde(default)]
510    pub compat: Option<OpenAiCompat>,
511    /// Optional refreshable credential source. When `Some`, every `stream()` call
512    /// resolves the API key via `credentials.current().await` instead of reading
513    /// `api_key` directly; the retry loop calls `credentials.invalidate().await`
514    /// once on `ProviderError::Auth` and retries the call before propagating.
515    /// When `None` (the default), `api_key` is used verbatim, preserving the legacy
516    /// static-key behaviour.
517    #[serde(skip)]
518    pub credentials: Option<Arc<dyn CredentialProvider>>,
519}
520
521impl ModelConfig {
522    /// Create a new Anthropic model config.
523    pub fn anthropic(
524        id: impl Into<String>, // API ID — model identifier sent in the request body (e.g. "claude-sonnet-4-20250514")
525        name: impl Into<String>, // DISPLAY NAME — human-readable label for logging/UI; not sent to the API
526        api_key: impl Into<String>, // AUTH — "sk-ant-..." or OAuth token "sk-ant-oat..."
527    ) -> Self {
528        Self {
529            id: id.into(),
530            name: name.into(),
531            api: ApiProtocol::AnthropicMessages,
532            provider: "anthropic".into(),
533            base_url: "https://api.anthropic.com".into(),
534            api_key: api_key.into(),
535            reasoning: false,
536            context_window: 200_000,
537            max_tokens: 8192,
538            cost: CostConfig::default(),
539            headers: HashMap::new(),
540            compat: None, // Anthropic has its own protocol, no compat flags needed
541            credentials: None,
542        }
543    }
544
545    /// Create a new OpenAI model config.
546    pub fn openai(
547        id: impl Into<String>, // API ID — model identifier sent in the request body (e.g. "gpt-4o")
548        name: impl Into<String>, // DISPLAY NAME — human-readable label for logging/UI; not sent to the API
549        api_key: impl Into<String>, // AUTH — "sk-..."
550    ) -> Self {
551        Self {
552            id: id.into(),
553            name: name.into(),
554            api: ApiProtocol::OpenAiCompletions,
555            provider: "openai".into(),
556            base_url: "https://api.openai.com/v1".into(),
557            api_key: api_key.into(),
558            reasoning: false,
559            context_window: 128_000,
560            max_tokens: 4096,
561            cost: CostConfig::default(),
562            headers: HashMap::new(),
563            compat: Some(OpenAiCompat::openai()), // OpenAI needs compat flags (store, developer role, etc.)
564            credentials: None,
565        }
566    }
567
568    /// Create a config for a local OpenAI-compatible server (LM Studio, Ollama, etc.).
569    /// Pass an empty string for `api_key` — most local servers don't require authentication.
570    pub fn local(
571        base_url: impl Into<String>, // ENDPOINT — full base URL of the local server (e.g. "http://localhost:1234/v1")
572        model_id: impl Into<String>, // API ID — model name expected by the local server (e.g. "llama-3.1-8b")
573        api_key: impl Into<String>,  // AUTH — empty string for unauthenticated local servers
574    ) -> Self {
575        Self {
576            id: model_id.into(),
577            name: "Local Model".into(),
578            api: ApiProtocol::OpenAiCompletions,
579            provider: "local".into(),
580            base_url: base_url.into(), // caller provides e.g. "http://localhost:1234/v1"
581            api_key: api_key.into(),
582            reasoning: false,
583            context_window: 128_000,
584            max_tokens: 4096,
585            cost: CostConfig::default(),
586            headers: HashMap::new(),
587            compat: Some(OpenAiCompat::default()), // most local servers are generic OpenAI-compat
588            credentials: None,
589        }
590    }
591
592    /// Create a new Google Generative AI (Gemini) model config.
593    pub fn google(
594        id: impl Into<String>, // API ID — model identifier sent in the request URL (e.g. "gemini-2.5-pro")
595        name: impl Into<String>, // DISPLAY NAME — human-readable label for logging/UI; not sent to the API
596        api_key: impl Into<String>, // AUTH — Google AI Studio API key
597    ) -> Self {
598        Self {
599            id: id.into(),
600            name: name.into(),
601            api: ApiProtocol::GoogleGenerativeAi,
602            provider: "google".into(),
603            base_url: "https://generativelanguage.googleapis.com".into(),
604            api_key: api_key.into(),
605            reasoning: false,
606            context_window: 1_000_000,
607            max_tokens: 8192,
608            cost: CostConfig::default(),
609            headers: HashMap::new(),
610            compat: None, // Google has its own protocol, no compat flags needed
611            credentials: None,
612        }
613    }
614
615    /// Create a new OpenRouter model config.
616    /// `model_id` uses the `provider/model` format (e.g. `"anthropic/claude-sonnet-4"`).
617    pub fn openrouter(
618        model_id: impl Into<String>, // API ID — "provider/model" format (e.g. "anthropic/claude-sonnet-4")
619        api_key: impl Into<String>,  // AUTH — "sk-or-..."
620    ) -> Self {
621        let id = model_id.into();
622        Self {
623            name: id.clone(),
624            id,
625            api: ApiProtocol::OpenAiCompletions,
626            provider: "openrouter".into(),
627            base_url: "https://openrouter.ai/api/v1".into(),
628            api_key: api_key.into(),
629            reasoning: false,
630            context_window: 200_000, // conservative default; varies by routed model
631            max_tokens: 4096,
632            cost: CostConfig::default(),
633            headers: HashMap::new(),
634            compat: Some(OpenAiCompat::openrouter()),
635            credentials: None,
636        }
637    }
638
639    /// Attach a refreshable credential source. When set, the API key is resolved
640    /// per-call via `credentials.current().await` instead of being read directly
641    /// from `self.api_key`. The retry loop also calls `credentials.invalidate()`
642    /// once on `ProviderError::Auth` and re-tries the stream call.
643    pub fn with_credentials(mut self, creds: Arc<dyn CredentialProvider>) -> Self {
644        self.credentials = Some(creds);
645        self
646    }
647
648    /// Resolve the API key for an outgoing request.
649    ///
650    /// When `credentials` is set, delegate to its `current()` method (which may
651    /// re-fetch from a metadata service or return a cached value). Otherwise fall
652    /// back to the static `api_key` field. Providers should call this once at the
653    /// top of `stream()` instead of reading `api_key` directly.
654    pub async fn resolve_api_key(&self) -> Result<String, ProviderError> {
655        match &self.credentials {
656            Some(c) => c.current().await,
657            None => Ok(self.api_key.clone()),
658        }
659    }
660
661    /// Signal that the currently cached credential has been rejected. If
662    /// `credentials` is set, delegate to its `invalidate()`; otherwise a no-op.
663    /// Invoked by the streaming retry loop on `ProviderError::Auth`.
664    pub async fn invalidate_credentials(&self) -> Result<(), ProviderError> {
665        match &self.credentials {
666            Some(c) => c.invalidate().await,
667            None => Ok(()),
668        }
669    }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// The Anthropic factory selects the Anthropic protocol, keeps the key,
    /// and carries no OpenAI-compat flags.
    #[test]
    fn test_model_config_anthropic() {
        let cfg = ModelConfig::anthropic("claude-sonnet-4-20250514", "Claude Sonnet 4", "sk-ant-key");

        assert_eq!(cfg.api, ApiProtocol::AnthropicMessages);
        assert_eq!(cfg.provider, "anthropic");
        assert_eq!(cfg.api_key, "sk-ant-key");
        assert!(cfg.compat.is_none());
    }

    /// The OpenAI factory attaches the OpenAI-specific compat flags.
    #[test]
    fn test_model_config_openai() {
        let cfg = ModelConfig::openai("gpt-4o", "GPT-4o", "sk-key");
        assert_eq!(cfg.api, ApiProtocol::OpenAiCompletions);

        let flags = cfg.compat.unwrap();
        assert!(flags.supports_store);
        assert!(flags.supports_developer_role);
        assert_eq!(flags.max_tokens_field, MaxTokensField::MaxCompletionTokens);
    }

    /// Spot-check the per-provider compat factories.
    #[test]
    fn test_openai_compat_variants() {
        let grok = OpenAiCompat::xai();
        assert_eq!(grok.thinking_format, ThinkingFormat::Xai);
        assert!(!grok.supports_store);

        let groq = OpenAiCompat::groq();
        assert!(groq.supports_usage_in_streaming);
        assert!(!groq.supports_store);

        let ds = OpenAiCompat::deepseek();
        assert_eq!(ds.max_tokens_field, MaxTokensField::MaxCompletionTokens);
    }

    /// `Display` prints the lowercase, underscore-separated protocol labels.
    #[test]
    fn test_api_protocol_display() {
        let expected = [
            (ApiProtocol::AnthropicMessages, "anthropic_messages"),
            (ApiProtocol::OpenAiCompletions, "openai_completions"),
            (ApiProtocol::GoogleGenerativeAi, "google_generative_ai"),
        ];
        for (protocol, label) in expected {
            assert_eq!(protocol.to_string(), label);
        }
    }

    /// A default cost config charges nothing for input or output.
    #[test]
    fn test_cost_config_default() {
        let zero = CostConfig::default();
        assert_eq!(zero.input_per_million, 0.0);
        assert_eq!(zero.output_per_million, 0.0);
    }
}