phi_core/provider/model.rs
1//! Model configuration and provider compatibility flags.
2/*
3ARCHITECTURE: model.rs — the "model identity card"
4
5Every LLM provider has a different API shape, auth style, field names, and
6quirks. This module defines the data structures that capture all of that
7variation in a single `ModelConfig` value.
8
9Key types:
10 `ApiProtocol` — which wire protocol to use (Anthropic vs OpenAI vs Gemini vs ...)
11 `ModelConfig` — the full model "identity card": base_url, auth, limits, quirks
12 `OpenAiCompat` — per-provider flags for the 15+ OpenAI-compatible providers
13 `CostConfig` — token pricing (optional, used for cost tracking)
14
15How it flows:
16 1. Caller builds or loads a `ModelConfig` (factory methods: `ModelConfig::anthropic()`,
17 `ModelConfig::openai()`, etc., or deserialize from JSON/YAML)
18 2. Sets it on `StreamConfig::model_config`
19 3. `ProviderRegistry::for_protocol()` picks the right `StreamProvider` impl
20 based on `config.api`
21 4. The provider uses `base_url`, `compat`, `headers` etc. from `ModelConfig`
22 to customise API calls
23
24Why not hard-code provider details in each provider file?
25 ModelConfig externalizes the provider-specific details so users can configure
26 custom endpoints, private deployments, or new providers without changing
27 provider source code.
28*/
29
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33
34use super::traits::ProviderError;
35
36// ---------------------------------------------------------------------------
37// CredentialProvider — pluggable, refreshable API-key source
38// ---------------------------------------------------------------------------
39
40/// Pluggable source of the API key for a [`ModelConfig`].
41///
42/// Long-running agents on short-lived credentials (AWS STS, OAuth, Workload-Identity)
43/// would otherwise hit `ProviderError::Auth` mid-run and stop. Wiring a
44/// `CredentialProvider` lets the agent resolve the current key per-call and refresh
45/// on auth failures — the retry loop in `streaming.rs` calls `invalidate()` once on
46/// `Auth` and retries the stream call before propagating.
47///
48/// The trait is intentionally tiny — implementors are free to cache, validate against
49/// an external metadata service, or block on a key-management API as needed.
50///
51/// # Example
52///
53/// ```no_run
54/// use async_trait::async_trait;
55/// use phi_core::provider::{CredentialProvider, ProviderError};
56/// use std::sync::Mutex;
57///
58/// #[derive(Debug)]
59/// struct StsProvider {
60/// cached: Mutex<Option<String>>,
61/// }
62///
63/// #[async_trait]
64/// impl CredentialProvider for StsProvider {
65/// async fn current(&self) -> Result<String, ProviderError> {
66/// if let Some(k) = self.cached.lock().unwrap().clone() {
67/// return Ok(k);
68/// }
69/// // Hit STS, cache, return... (omitted)
70/// Err(ProviderError::Auth("STS unavailable".into()))
71/// }
72///
73/// async fn invalidate(&self) -> Result<(), ProviderError> {
74/// self.cached.lock().unwrap().take();
75/// Ok(())
76/// }
77/// }
78/// ```
79#[async_trait::async_trait]
80pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
81 /// Return the current API key for this credential. Implementations may cache,
82 /// re-fetch from a metadata service, or compute on the fly. Called once per
83 /// `StreamProvider::stream()` invocation.
84 async fn current(&self) -> Result<String, ProviderError>;
85
86 /// Hint that the current cached credential has been rejected by the upstream
87 /// API and a fresh value should be fetched on the next `current()` call.
88 ///
89 /// Default impl is a no-op for providers that always re-fetch.
90 async fn invalidate(&self) -> Result<(), ProviderError> {
91 Ok(())
92 }
93}
94
95/// Reference implementation of [`CredentialProvider`] that always returns a fixed key.
96///
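/// Useful for tests and for wiring a [`ModelConfig`] uniformly when refresh is not
/// needed — equivalent to leaving `ModelConfig::credentials = None` and relying on
/// the static `api_key` field. `invalidate()` is the trait's default no-op, so this
/// type does not record invalidations itself; a test harness that needs to count
/// `invalidate()` calls must wrap it or supply its own provider.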
#[derive(Clone)]
pub struct StaticCredentialProvider {
    /// The fixed API key handed out on every request. Kept private so the
    /// secret can only be read through the `CredentialProvider` interface.
    key: String,
}

// `Debug` is written by hand (instead of derived) so that the secret never
// reaches logs, panic messages, or `{:?}` output of a struct that embeds this
// provider. The trait bound `CredentialProvider: Debug` is still satisfied.
impl std::fmt::Debug for StaticCredentialProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticCredentialProvider")
            // `format_args!` prints the marker without surrounding quotes.
            .field("key", &format_args!("<redacted>"))
            .finish()
    }
}
104
105impl StaticCredentialProvider {
106 pub fn new(key: impl Into<String>) -> Self {
107 Self { key: key.into() }
108 }
109}
110
111#[async_trait::async_trait]
112impl CredentialProvider for StaticCredentialProvider {
113 async fn current(&self) -> Result<String, ProviderError> {
114 Ok(self.key.clone())
115 }
116}
117
118/// Which API protocol a model uses.
119/*
120ARCHITECTURE: ApiProtocol — the dispatch key for the provider registry
121
122`ProviderRegistry::for_protocol(api: ApiProtocol)` maps each variant to
123a concrete `StreamProvider` implementation:
124 AnthropicMessages → AnthropicProvider
125 OpenAiCompletions → OpenAiCompatProvider (handles 15+ providers)
126 OpenAiResponses → OpenAiResponsesProvider
127 AzureOpenAiResponses → AzureOpenAiProvider
128 GoogleGenerativeAi → GoogleProvider
129 GoogleVertex → GoogleVertexProvider
130 BedrockConverseStream → BedrockProvider
131
132This is the "Strategy via enum dispatch" pattern: the enum variant IS the strategy
133selector. The registry (registry.rs) `match`es on this enum and returns the right
134provider. At runtime, the caller only holds a `Box<dyn StreamProvider>` and never
135needs to know which variant was used.
136
137RUST QUIRK: `Hash` derive — required for use as HashMap keys
138
139`#[derive(Hash)]` enables values of this type to be used as keys in `HashMap<K, V>`.
140`Hash` computes an integer hash of the value. Combined with `PartialEq + Eq`
141(also derived), this is what HashMap needs:
142 - `Hash` to find the bucket
143 - `Eq` to confirm the key matches within the bucket (hash collisions)
144
145Why does `ApiProtocol` need to be a HashMap key?
146 In `ProviderRegistry`, we may store `HashMap<ApiProtocol, Box<dyn StreamProvider>>`.
147 Without `Hash + Eq`, that HashMap would fail to compile.
148
149RUST QUIRK: `Copy` on an enum with no data fields
150 All variants of `ApiProtocol` carry no data — they're just tags.
151 `Copy` lets the compiler bitwise-copy the value instead of moving it.
152 After `let api = model.api;`, `model.api` is STILL valid (Copy semantics).
153 Python analogy: Python enums are always by-reference, so no equivalent concept.
154
155RUST QUIRK: `#[serde(rename_all = "snake_case")]`
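  When serializing to JSON/YAML, variant names are converted to snake_case:
    `AnthropicMessages` → "anthropic_messages"
    `BedrockConverseStream` → "bedrock_converse_stream"
    `OpenAiCompletions` → "open_ai_completions"
  serde inserts an underscore before every interior capital letter, so the
  `OpenAi*` variants serialize with "open_ai", which is NOT the "openai_..."
  spelling that the `Display` impl prints.
  This makes config files human-readable without matching Rust's PascalCase convention.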
160*/
161#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
162#[serde(rename_all = "snake_case")]
163pub enum ApiProtocol {
164 AnthropicMessages,
165 OpenAiCompletions,
166 OpenAiResponses,
167 AzureOpenAiResponses,
168 GoogleGenerativeAi,
169 GoogleVertex,
170 BedrockConverseStream,
171}
172
173impl std::fmt::Display for ApiProtocol {
174 /*
175 RUST QUIRK: Implementing `Display` manually (vs deriving it)
176
177 `Display` (the `{}` formatter) is NOT derivable — you must write it by hand.
178 `Debug` (the `{:?}` formatter) IS derivable.
179
180 Why? `Debug` is purely for developers (shows the Rust name), so auto-generated
181 is fine. `Display` is for end-users, and you control the string representation.
182
183 Here we return snake_case strings ("anthropic_messages") instead of the
184 Rust PascalCase names ("AnthropicMessages") — consistent with the serde rename.
185
186 `write!(f, "...")` — writes into the formatter buffer `f`.
187 Returns `fmt::Result` (Ok or Err), required by the trait.
188 Python analogy: implementing __str__(self) → return "anthropic_messages"
189 */
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 match self {
192 Self::AnthropicMessages => write!(f, "anthropic_messages"),
193 Self::OpenAiCompletions => write!(f, "openai_completions"),
194 Self::OpenAiResponses => write!(f, "openai_responses"),
195 Self::AzureOpenAiResponses => write!(f, "azure_openai_responses"),
196 Self::GoogleGenerativeAi => write!(f, "google_generative_ai"),
197 Self::GoogleVertex => write!(f, "google_vertex"),
198 Self::BedrockConverseStream => write!(f, "bedrock_converse_stream"),
199 }
200 }
201}
202
203/// Cost per million tokens (input/output).
204/*
205ARCHITECTURE: CostConfig — optional cost tracking
206
207LLM providers charge differently for input vs output tokens, and some offer
208reduced prices for cache reads and cache writes (Anthropic prompt caching).
209
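`CostConfig` is embedded in `ModelConfig`, where the `cost` field is marked
`#[serde(default)]`: callers who don't care about cost tracking can omit the whole
`cost` section and every price defaults to 0.0. When a `cost` section IS present,
only the two cache prices are `#[serde(default)]`; `input_per_million` and
`output_per_million` must be supplied.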
213
214RUST QUIRK: `#[serde(default)]` — per-field default during deserialization
215 When deserializing a `ModelConfig`, if "cache_read_per_million" is absent in
216 the JSON/YAML, serde calls `Default::default()` for that field instead of
217 returning an error. This makes the struct forward-compatible: old config files
218 (without the cache fields) still deserialize correctly.
219 Python analogy: `dataclasses.field(default=0.0)` or `pydantic.Field(default=0.0)`
220*/
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct CostConfig {
223 pub input_per_million: f64,
224 pub output_per_million: f64,
225 #[serde(default)]
226 pub cache_read_per_million: f64,
227 #[serde(default)]
228 pub cache_write_per_million: f64,
229}
230
231impl Default for CostConfig {
232 fn default() -> Self {
233 Self {
234 input_per_million: 0.0,
235 output_per_million: 0.0,
236 cache_read_per_million: 0.0,
237 cache_write_per_million: 0.0,
238 }
239 }
240}
241
242/// How a provider handles the `max_tokens` field.
243/*
244ARCHITECTURE: MaxTokensField — a per-provider API quirk
245
246The OpenAI-compatible API has two field names for the same concept:
247 `max_tokens` — the original field name, used by most providers
248 `max_completion_tokens`— new name, required by OpenAI o-series reasoning models
249
250Both control the maximum number of tokens in the response, but OpenAI split
251them so reasoning token budgets are counted separately. The provider must use
252the correct field name, or the API returns an error.
253
254`MaxTokensField` is a small enum used as a flag inside `OpenAiCompat`, avoiding
255a raw `bool` (which would be less self-documenting).
256
257RUST QUIRK: `#[derive(Default)]` + `#[default]` on a variant
258 `#[derive(Default)]` auto-generates `Default::default()` for the enum.
259 `#[default]` on a specific variant marks it as the default value:
260 `MaxTokensField::default()` → `MaxTokensField::MaxTokens`
261 Without `#[default]`, the derive macro wouldn't know which variant to pick.
262 Python analogy: no direct equivalent; closest is Enum with a class variable for default.
263*/
264#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
265#[serde(rename_all = "snake_case")]
266pub enum MaxTokensField {
267 #[default]
268 MaxTokens,
269 MaxCompletionTokens,
270}
271
272/// How a provider formats thinking/reasoning output.
273/*
274ARCHITECTURE: ThinkingFormat — per-provider reasoning output format
275
276Extended thinking / chain-of-thought output is formatted differently by each provider:
277 `OpenAi` — reasoning appears in a dedicated `reasoning_content` array
278 `Xai` — Grok's format (slightly different JSON structure)
279 `Qwen` — Qwen's format (another variation)
280
281This flag tells `openai_compat.rs` which parsing branch to use when extracting
282thinking deltas from the streaming response. Without this flag, we'd need a
283separate provider file for each thinking-capable service.
284*/
285#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
286#[serde(rename_all = "snake_case")]
287pub enum ThinkingFormat {
288 #[default]
289 OpenAi,
290 Xai,
291 Qwen,
292 /// OpenRouter streaming format: reads thinking text from `delta.reasoning_details`
293 /// array entries where `type == "thinking"`.
294 OpenRouter,
295}
296
297/// Compatibility flags for OpenAI-compatible providers.
298/// Different providers have different quirks even though they share the same base API.
299/*
300ARCHITECTURE: OpenAiCompat — the "quirk matrix" for 15+ OpenAI-compatible providers
301
302The OpenAI Chat Completions API is a de-facto standard that dozens of providers
303implement. But every provider deviates in small ways:
304 - OpenAI o-series uses `max_completion_tokens` not `max_tokens`
305 - xAI (Grok) uses a different thinking output format
306 - Some providers don't include usage data in streaming chunks
307 - Some require a `name` field in tool results
308 - Some need a dummy assistant message inserted after tool results
309
310Instead of writing a separate provider for each quirk combination, we have ONE
311`openai_compat.rs` provider that reads `OpenAiCompat` flags at runtime and
312branches accordingly. New providers = new `OpenAiCompat::new_provider()` factory.
313
314The factory methods (`openai()`, `xai()`, `groq()`, ...) use `..Default::default()`
315struct update syntax to express only the fields that differ from defaults.
316Python analogy: a dataclass with defaults, and factory classmethods that override
317only the fields that need to change.
318*/
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct OpenAiCompat {
321 /// Supports the `store` parameter for conversation persistence.
322 pub supports_store: bool,
323 /// Supports `developer` role (system-level instructions).
324 pub supports_developer_role: bool,
325 /// Supports `reasoning_effort` parameter.
326 pub supports_reasoning_effort: bool,
327 /// Includes usage data in streaming responses.
328 pub supports_usage_in_streaming: bool,
329 /// Which field name to use for max tokens.
330 pub max_tokens_field: MaxTokensField,
331 /// Tool results must include a `name` field.
332 pub requires_tool_result_name: bool,
333 /// Must insert an assistant message after tool results.
334 pub requires_assistant_after_tool_result: bool,
335 /// How thinking/reasoning content is formatted in streaming.
336 pub thinking_format: ThinkingFormat,
337}
338
339impl Default for OpenAiCompat {
340 /*
341 RUST QUIRK: `impl Default` manually (rather than `#[derive(Default)]`)
342
343 `#[derive(Default)]` would work only if every field's type implements `Default`
344 AND the zero-values are the right defaults. Here, `supports_usage_in_streaming`
345 should default to `true`, not `false`. Since `bool` defaults to `false`, we
346 must override it manually.
347
348 A manually written `Default` impl is common when some field defaults are
349 non-trivial (non-zero numbers, non-empty strings, true booleans, etc.).
350 */
351 fn default() -> Self {
352 Self {
353 supports_store: false,
354 supports_developer_role: false,
355 supports_reasoning_effort: false,
356 supports_usage_in_streaming: true, // most OpenAI-compat providers include usage
357 max_tokens_field: MaxTokensField::MaxTokens,
358 requires_tool_result_name: false,
359 requires_assistant_after_tool_result: false,
360 thinking_format: ThinkingFormat::OpenAi,
361 }
362 }
363}
364
365impl OpenAiCompat {
366 /// Compat flags for native OpenAI.
367 /*
368 RUST QUIRK: `..Default::default()` — struct update syntax for overriding defaults
369
370 `Self { supports_store: true, ..Default::default() }` means:
371 "build a Self where supports_store = true (and supports_developer_role = true,
372 supports_reasoning_effort = true, max_tokens_field = MaxCompletionTokens)
373 and all OTHER fields come from Default::default()"
374
375 The `..expr` "spreads" the remaining fields from a base value.
376 It MUST be last in the struct literal.
377 Python analogy: dataclasses.replace(OpenAiCompat(), supports_store=True, ...)
378
379 Why is this better than repeating all fields?
380 - Fewer lines to write (only express differences from defaults)
381 - If a new field is added with a sensible default, existing factory methods
382 automatically get the right value — no manual update needed
383 */
384 pub fn openai() -> Self {
385 Self {
386 supports_store: true,
387 supports_developer_role: true,
388 supports_reasoning_effort: true,
389 supports_usage_in_streaming: true,
390 max_tokens_field: MaxTokensField::MaxCompletionTokens,
391 ..Default::default()
392 }
393 }
394
395 /// Compat flags for xAI (Grok).
396 pub fn xai() -> Self {
397 Self {
398 supports_usage_in_streaming: true,
399 thinking_format: ThinkingFormat::Xai, // Grok uses a different thinking JSON shape
400 ..Default::default()
401 }
402 }
403
404 /// Compat flags for Groq.
405 pub fn groq() -> Self {
406 Self {
407 supports_usage_in_streaming: true,
408 ..Default::default()
409 }
410 }
411
412 /// Compat flags for Cerebras.
413 pub fn cerebras() -> Self {
414 Self::default() // no deviations from defaults
415 }
416
417 /// Compat flags for OpenRouter.
418 pub fn openrouter() -> Self {
419 Self {
420 supports_developer_role: true, // OpenRouter supports "developer" role
421 supports_usage_in_streaming: true,
422 max_tokens_field: MaxTokensField::MaxTokens, // OpenRouter uses max_tokens (not max_completion_tokens)
423 thinking_format: ThinkingFormat::OpenRouter, // reasoning_details array format
424 ..Default::default()
425 }
426 }
427
428 /// Compat flags for Mistral.
429 pub fn mistral() -> Self {
430 Self {
431 supports_usage_in_streaming: true,
432 max_tokens_field: MaxTokensField::MaxTokens,
433 ..Default::default()
434 }
435 }
436
437 /// Compat flags for DeepSeek.
438 pub fn deepseek() -> Self {
439 Self {
440 supports_usage_in_streaming: true,
441 max_tokens_field: MaxTokensField::MaxCompletionTokens,
442 ..Default::default()
443 }
444 }
445}
446
447/// Full model configuration. Knows everything needed to make API calls.
448/*
449ARCHITECTURE: ModelConfig — the single source of truth for a model's identity
450
451`ModelConfig` bundles everything a provider needs to make API calls:
452 - `id` / `name` — which model to request (sent in the API body)
453 - `api` — which provider implementation to use (dispatch key)
454 - `provider` — human label for logging/display
455 - `base_url` — the HTTP endpoint (can be a private deployment or proxy)
456 - `reasoning` — whether this model supports extended thinking
457 - `context_window` — max input tokens (used for context compaction decisions)
458 - `max_tokens` — default output token limit
459 - `cost` — token pricing for cost tracking
460 - `headers` — additional HTTP headers (e.g., API-version headers)
461 - `compat` — OpenAI quirk flags (only for OpenAiCompletions protocol)
462
463Factory methods (`anthropic()`, `openai()`, `local()`, `google()`) cover common
464cases. Custom providers are built by constructing the struct directly.
465
466RUST QUIRK: `HashMap<String, String>` — a key-value dictionary
467 `HashMap<K, V>` from `std::collections` — Rust's standard hash map.
468 Here it stores additional HTTP headers like `{"X-My-Header": "value"}`.
469 Python analogy: `dict[str, str]`.
470 `#[serde(default)]` means it deserializes as an empty HashMap if absent in config.
471
472RUST QUIRK: `Option<OpenAiCompat>` — present only for OpenAI-compat providers
473 Anthropic/Google/Bedrock have their own provider files that don't use `compat`.
474 For them, `compat` is `None`. For OpenAI-compatible providers, `compat` is
475 `Some(OpenAiCompat { ... })`. This models "this field only makes sense for
476 a subset of configurations." The provider accesses it with `compat.as_ref()?` or
477 `compat.unwrap_or_default()`.
478*/
479#[derive(Debug, Clone, Serialize, Deserialize)]
480pub struct ModelConfig {
481 /// Model identifier sent to the API (e.g. "gpt-4o", "claude-sonnet-4-20250514").
482 pub id: String,
483 /// Human-friendly name.
484 pub name: String,
485 /// Which API protocol to use.
486 pub api: ApiProtocol,
487 /// Provider name (e.g. "openai", "anthropic", "xai").
488 pub provider: String,
489 /// Base URL for API requests (without trailing slash).
490 pub base_url: String,
491 /// Authentication credential for this provider (API key, Bearer token, or
492 /// `access_key:secret[:session_token]` for Bedrock).
493 /// Defaults to an empty string so config files can omit it and supply via env instead.
494 #[serde(default)]
495 pub api_key: String,
496 /// Whether this model supports reasoning/thinking.
497 pub reasoning: bool,
498 /// Context window size in tokens.
499 pub context_window: u32,
500 /// Default max output tokens.
501 pub max_tokens: u32,
502 /// Cost configuration.
503 #[serde(default)]
504 pub cost: CostConfig,
505 /// Additional headers to send with requests.
506 #[serde(default)]
507 pub headers: HashMap<String, String>,
508 /// OpenAI-compat quirk flags (only for OpenAiCompletions protocol).
509 #[serde(default)]
510 pub compat: Option<OpenAiCompat>,
511 /// Optional refreshable credential source. When `Some`, every `stream()` call
512 /// resolves the API key via `credentials.current().await` instead of reading
513 /// `api_key` directly; the retry loop calls `credentials.invalidate().await`
514 /// once on `ProviderError::Auth` and retries the call before propagating.
515 /// When `None` (the default), `api_key` is used verbatim, preserving the legacy
516 /// static-key behaviour.
517 #[serde(skip)]
518 pub credentials: Option<Arc<dyn CredentialProvider>>,
519}
520
521impl ModelConfig {
522 /// Create a new Anthropic model config.
523 pub fn anthropic(
524 id: impl Into<String>, // API ID — model identifier sent in the request body (e.g. "claude-sonnet-4-20250514")
525 name: impl Into<String>, // DISPLAY NAME — human-readable label for logging/UI; not sent to the API
526 api_key: impl Into<String>, // AUTH — "sk-ant-..." or OAuth token "sk-ant-oat..."
527 ) -> Self {
528 Self {
529 id: id.into(),
530 name: name.into(),
531 api: ApiProtocol::AnthropicMessages,
532 provider: "anthropic".into(),
533 base_url: "https://api.anthropic.com".into(),
534 api_key: api_key.into(),
535 reasoning: false,
536 context_window: 200_000,
537 max_tokens: 8192,
538 cost: CostConfig::default(),
539 headers: HashMap::new(),
540 compat: None, // Anthropic has its own protocol, no compat flags needed
541 credentials: None,
542 }
543 }
544
545 /// Create a new OpenAI model config.
546 pub fn openai(
547 id: impl Into<String>, // API ID — model identifier sent in the request body (e.g. "gpt-4o")
548 name: impl Into<String>, // DISPLAY NAME — human-readable label for logging/UI; not sent to the API
549 api_key: impl Into<String>, // AUTH — "sk-..."
550 ) -> Self {
551 Self {
552 id: id.into(),
553 name: name.into(),
554 api: ApiProtocol::OpenAiCompletions,
555 provider: "openai".into(),
556 base_url: "https://api.openai.com/v1".into(),
557 api_key: api_key.into(),
558 reasoning: false,
559 context_window: 128_000,
560 max_tokens: 4096,
561 cost: CostConfig::default(),
562 headers: HashMap::new(),
563 compat: Some(OpenAiCompat::openai()), // OpenAI needs compat flags (store, developer role, etc.)
564 credentials: None,
565 }
566 }
567
568 /// Create a config for a local OpenAI-compatible server (LM Studio, Ollama, etc.).
569 /// Pass an empty string for `api_key` — most local servers don't require authentication.
570 pub fn local(
571 base_url: impl Into<String>, // ENDPOINT — full base URL of the local server (e.g. "http://localhost:1234/v1")
572 model_id: impl Into<String>, // API ID — model name expected by the local server (e.g. "llama-3.1-8b")
573 api_key: impl Into<String>, // AUTH — empty string for unauthenticated local servers
574 ) -> Self {
575 Self {
576 id: model_id.into(),
577 name: "Local Model".into(),
578 api: ApiProtocol::OpenAiCompletions,
579 provider: "local".into(),
580 base_url: base_url.into(), // caller provides e.g. "http://localhost:1234/v1"
581 api_key: api_key.into(),
582 reasoning: false,
583 context_window: 128_000,
584 max_tokens: 4096,
585 cost: CostConfig::default(),
586 headers: HashMap::new(),
587 compat: Some(OpenAiCompat::default()), // most local servers are generic OpenAI-compat
588 credentials: None,
589 }
590 }
591
592 /// Create a new Google Generative AI (Gemini) model config.
593 pub fn google(
594 id: impl Into<String>, // API ID — model identifier sent in the request URL (e.g. "gemini-2.5-pro")
595 name: impl Into<String>, // DISPLAY NAME — human-readable label for logging/UI; not sent to the API
596 api_key: impl Into<String>, // AUTH — Google AI Studio API key
597 ) -> Self {
598 Self {
599 id: id.into(),
600 name: name.into(),
601 api: ApiProtocol::GoogleGenerativeAi,
602 provider: "google".into(),
603 base_url: "https://generativelanguage.googleapis.com".into(),
604 api_key: api_key.into(),
605 reasoning: false,
606 context_window: 1_000_000,
607 max_tokens: 8192,
608 cost: CostConfig::default(),
609 headers: HashMap::new(),
610 compat: None, // Google has its own protocol, no compat flags needed
611 credentials: None,
612 }
613 }
614
615 /// Create a new OpenRouter model config.
616 /// `model_id` uses the `provider/model` format (e.g. `"anthropic/claude-sonnet-4"`).
617 pub fn openrouter(
618 model_id: impl Into<String>, // API ID — "provider/model" format (e.g. "anthropic/claude-sonnet-4")
619 api_key: impl Into<String>, // AUTH — "sk-or-..."
620 ) -> Self {
621 let id = model_id.into();
622 Self {
623 name: id.clone(),
624 id,
625 api: ApiProtocol::OpenAiCompletions,
626 provider: "openrouter".into(),
627 base_url: "https://openrouter.ai/api/v1".into(),
628 api_key: api_key.into(),
629 reasoning: false,
630 context_window: 200_000, // conservative default; varies by routed model
631 max_tokens: 4096,
632 cost: CostConfig::default(),
633 headers: HashMap::new(),
634 compat: Some(OpenAiCompat::openrouter()),
635 credentials: None,
636 }
637 }
638
639 /// Attach a refreshable credential source. When set, the API key is resolved
640 /// per-call via `credentials.current().await` instead of being read directly
641 /// from `self.api_key`. The retry loop also calls `credentials.invalidate()`
642 /// once on `ProviderError::Auth` and re-tries the stream call.
643 pub fn with_credentials(mut self, creds: Arc<dyn CredentialProvider>) -> Self {
644 self.credentials = Some(creds);
645 self
646 }
647
648 /// Resolve the API key for an outgoing request.
649 ///
650 /// When `credentials` is set, delegate to its `current()` method (which may
651 /// re-fetch from a metadata service or return a cached value). Otherwise fall
652 /// back to the static `api_key` field. Providers should call this once at the
653 /// top of `stream()` instead of reading `api_key` directly.
654 pub async fn resolve_api_key(&self) -> Result<String, ProviderError> {
655 match &self.credentials {
656 Some(c) => c.current().await,
657 None => Ok(self.api_key.clone()),
658 }
659 }
660
661 /// Signal that the currently cached credential has been rejected. If
662 /// `credentials` is set, delegate to its `invalidate()`; otherwise a no-op.
663 /// Invoked by the streaming retry loop on `ProviderError::Auth`.
664 pub async fn invalidate_credentials(&self) -> Result<(), ProviderError> {
665 match &self.credentials {
666 Some(c) => c.invalidate().await,
667 None => Ok(()),
668 }
669 }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// The Anthropic factory selects the Anthropic protocol, labels the provider,
    /// stores the key verbatim and attaches no OpenAI compat flags.
    #[test]
    fn test_model_config_anthropic() {
        let cfg =
            ModelConfig::anthropic("claude-sonnet-4-20250514", "Claude Sonnet 4", "sk-ant-key");

        assert_eq!(cfg.api, ApiProtocol::AnthropicMessages);
        assert_eq!(cfg.provider, "anthropic");
        assert_eq!(cfg.api_key, "sk-ant-key");
        assert!(cfg.compat.is_none());
    }

    /// The OpenAI factory selects Chat Completions and carries the native
    /// OpenAI compat flags.
    #[test]
    fn test_model_config_openai() {
        let cfg = ModelConfig::openai("gpt-4o", "GPT-4o", "sk-key");
        assert_eq!(cfg.api, ApiProtocol::OpenAiCompletions);

        let flags = cfg.compat.unwrap();
        assert!(flags.supports_store);
        assert!(flags.supports_developer_role);
        assert_eq!(flags.max_tokens_field, MaxTokensField::MaxCompletionTokens);
    }

    /// Spot-check the per-provider compat factories.
    #[test]
    fn test_openai_compat_variants() {
        // xAI: own thinking format, no `store` support.
        let xai_flags = OpenAiCompat::xai();
        assert_eq!(xai_flags.thinking_format, ThinkingFormat::Xai);
        assert!(!xai_flags.supports_store);

        // Groq: usage in streaming, no `store` support.
        let groq_flags = OpenAiCompat::groq();
        assert!(groq_flags.supports_usage_in_streaming);
        assert!(!groq_flags.supports_store);

        // DeepSeek: output limit goes in `max_completion_tokens`.
        let deepseek_flags = OpenAiCompat::deepseek();
        assert_eq!(
            deepseek_flags.max_tokens_field,
            MaxTokensField::MaxCompletionTokens
        );
    }

    /// `Display` prints the fixed lowercase labels.
    #[test]
    fn test_api_protocol_display() {
        let anthropic = ApiProtocol::AnthropicMessages.to_string();
        assert_eq!(anthropic, "anthropic_messages");

        let openai = ApiProtocol::OpenAiCompletions.to_string();
        assert_eq!(openai, "openai_completions");

        let google = ApiProtocol::GoogleGenerativeAi.to_string();
        assert_eq!(google, "google_generative_ai");
    }

    /// The default cost config charges nothing for input or output tokens.
    #[test]
    fn test_cost_config_default() {
        let zero_cost = CostConfig::default();
        assert_eq!(zero_cost.input_per_million, 0.0);
        assert_eq!(zero_cost.output_per_million, 0.0);
    }
}
735}