// pi/provider.rs
1//! LLM provider abstraction layer.
2//!
3//! This module defines the [`Provider`] trait and shared request/response types used by all
4//! backends (Anthropic/OpenAI/Gemini/etc).
5//!
6//! Providers are responsible for:
7//! - Translating [`crate::model::Message`] history into provider-specific HTTP requests.
8//! - Emitting [`StreamEvent`] values as SSE/HTTP chunks arrive.
9//! - Advertising tool schemas to the model (so it can call [`crate::tools`] by name).
10
11pub use crate::model::StreamEvent;
12use crate::model::{Message, ThinkingLevel};
13use async_trait::async_trait;
14use futures::Stream;
15use serde::{Deserialize, Serialize};
16use std::borrow::Cow;
17use std::collections::HashMap;
18use std::pin::Pin;
19
20// ============================================================================
21// Provider Trait
22// ============================================================================
23
/// An LLM backend capable of streaming assistant output (and tool calls).
///
/// A `Provider` is typically configured for a specific API + model and is used by the agent loop
/// to produce a stream of [`StreamEvent`] updates.
///
/// Implementations must be `Send + Sync` so a single instance can be shared across async tasks.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Get the provider name (conventionally a [`KnownProvider`] identifier such as
    /// `"anthropic"` — confirm against concrete implementations).
    fn name(&self) -> &str;

    /// Get the API type (conventionally an [`Api`] identifier such as
    /// `"anthropic-messages"` — confirm against concrete implementations).
    fn api(&self) -> &str;

    /// Get the model identifier used by this provider.
    fn model_id(&self) -> &str;

    /// Start streaming a completion.
    ///
    /// Implementations should yield [`StreamEvent`] items as soon as they are decoded, and should
    /// stop promptly when the request is cancelled.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the request cannot be started; failures that occur mid-stream are
    /// surfaced as `Err` items of the returned stream.
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> crate::error::Result<Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>>;
}
49
50// ============================================================================
51// Context
52// ============================================================================
53
/// Inputs to a single completion request.
///
/// The agent loop builds a `Context` from the current session state and tool registry, then hands
/// it to a [`Provider`] implementation to perform provider-specific request encoding.
///
/// Uses [`Cow`] for `messages` and `tools` to avoid deep-cloning the full conversation history on
/// every turn when no mutation is needed (the common case). See [`Context::owned`] for
/// constructing a fully-owned (`'static`) value.
#[derive(Debug, Clone)]
pub struct Context<'a> {
    /// Provider-specific system prompt content.
    ///
    /// Uses [`Cow`] to borrow from `AgentConfig.system_prompt` on every turn without
    /// cloning.  Providers that need an owned `String` can call `.into_owned()`.
    pub system_prompt: Option<Cow<'a, str>>,
    /// Conversation history (user/assistant/tool results).
    pub messages: Cow<'a, [Message]>,
    /// Tool definitions available to the model for this request.
    pub tools: Cow<'a, [ToolDef]>,
}
73
74impl Default for Context<'_> {
75    fn default() -> Self {
76        Self {
77            system_prompt: None,
78            messages: Cow::Owned(Vec::new()),
79            tools: Cow::Owned(Vec::new()),
80        }
81    }
82}
83
84impl Context<'_> {
85    /// Create a `Context` with fully-owned data (no borrowing).
86    ///
87    /// Convenient for tests and one-off callers that already have owned vectors.
88    pub fn owned(
89        system_prompt: Option<String>,
90        messages: Vec<Message>,
91        tools: Vec<ToolDef>,
92    ) -> Context<'static> {
93        Context {
94            system_prompt: system_prompt.map(Cow::Owned),
95            messages: Cow::Owned(messages),
96            tools: Cow::Owned(tools),
97        }
98    }
99}
100
101// ============================================================================
102// Tool Definition
103// ============================================================================
104
/// A tool definition exposed to the model.
///
/// Providers translate this struct into the backend's tool/schema representation (typically JSON
/// Schema) so the model can emit tool calls that the host executes locally.
#[derive(Debug, Clone)]
pub struct ToolDef {
    /// Name the model uses to reference this tool in tool calls.
    pub name: String,
    /// Description of what the tool does, shown to the model.
    pub description: String,
    /// Schema for the tool's arguments, expressed as JSON Schema.
    pub parameters: serde_json::Value, // JSON Schema
}
115
116// ============================================================================
117// Stream Options
118// ============================================================================
119
/// Options that control streaming completion behavior.
///
/// Most options are passed through to the provider request (temperature, max tokens, headers).
/// Some fields are Pi-specific conveniences (e.g. `session_id` for logging/correlation).
#[derive(Debug, Clone, Default)]
pub struct StreamOptions {
    /// Sampling temperature forwarded to the provider request, when set.
    pub temperature: Option<f32>,
    /// Maximum number of tokens to generate, when set.
    pub max_tokens: Option<u32>,
    /// API key to use for this request, when set.
    pub api_key: Option<String>,
    /// Cache retention policy requested from the provider (see [`CacheRetention`]).
    pub cache_retention: CacheRetention,
    /// Pi session identifier, used for logging/correlation.
    pub session_id: Option<String>,
    /// Extra HTTP headers to attach to the provider request.
    pub headers: HashMap<String, String>,
    /// Requested reasoning/thinking effort level, when set.
    pub thinking_level: Option<ThinkingLevel>,
    /// Custom per-level thinking token budgets; presumably defaults apply when `None` —
    /// confirm against provider implementations.
    pub thinking_budgets: Option<ThinkingBudgets>,
}
135
/// Cache retention policy.
///
/// Controls how long the provider is asked to retain cached request context; exact
/// semantics are provider-specific.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CacheRetention {
    /// Do not request caching (the default).
    #[default]
    None,
    /// Provider-managed short-lived caching (provider-specific semantics).
    Short,
    /// Provider-managed long-lived caching (e.g. ~1 hour TTL on Anthropic).
    Long,
}
146
/// Custom thinking token budgets per level.
///
/// Each field is the token budget granted to the model's reasoning/thinking phase at the
/// corresponding effort level.
#[derive(Debug, Clone)]
pub struct ThinkingBudgets {
    /// Budget for the "minimal" effort level.
    pub minimal: u32,
    /// Budget for the "low" effort level.
    pub low: u32,
    /// Budget for the "medium" effort level.
    pub medium: u32,
    /// Budget for the "high" effort level.
    pub high: u32,
    /// Budget for the "xhigh" effort level.
    pub xhigh: u32,
}

impl Default for ThinkingBudgets {
    /// Power-of-two defaults: 1K / 2K / 8K / 16K tokens, with `xhigh` at double `high`.
    fn default() -> Self {
        Self {
            minimal: 1 << 10, // 1024
            low: 1 << 11,     // 2048
            medium: 1 << 13,  // 8192
            high: 1 << 14,    // 16384
            xhigh: 1 << 15,   // 32768 — double `high`; revisit if models expose larger maxima
        }
    }
}
168
169// ============================================================================
170// Model Definition
171// ============================================================================
172
/// A model definition loaded from the models registry.
///
/// This struct is used to drive provider selection, request limits (context window/max tokens),
/// and cost accounting.
#[derive(Debug, Clone, Serialize)]
pub struct Model {
    /// Model identifier (used for registry lookups).
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// API type identifier; see [`Api`] for known values.
    pub api: String,
    /// Provider identifier; see [`KnownProvider`] for known values.
    pub provider: String,
    /// Base URL of the provider endpoint.
    pub base_url: String,
    /// Whether the model supports reasoning/"thinking" — assumed from the registry flag;
    /// confirm against provider implementations.
    pub reasoning: bool,
    /// Input modalities the model accepts (text, image).
    pub input: Vec<InputType>,
    /// Pricing per million tokens; see [`Model::calculate_cost`].
    pub cost: ModelCost,
    /// Context window size, in tokens.
    pub context_window: u32,
    /// Maximum output tokens per request.
    pub max_tokens: u32,
    /// Extra HTTP headers to send with requests to this model.
    pub headers: HashMap<String, String>,
}
191
/// Input types supported by a model.
///
/// Serialized in lowercase (`"text"`, `"image"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum InputType {
    /// Plain text input.
    Text,
    /// Image input (multimodal models).
    Image,
}
199
/// Model pricing per million tokens.
///
/// Serialized in camelCase (`cacheRead` / `cacheWrite`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelCost {
    /// Dollars per million input (prompt) tokens.
    pub input: f64,
    /// Dollars per million output (completion) tokens.
    pub output: f64,
    /// Dollars per million tokens read from the cache.
    pub cache_read: f64,
    /// Dollars per million tokens written to the cache.
    pub cache_write: f64,
}
209
210impl Model {
211    /// Calculate cost for usage.
212    #[allow(clippy::cast_precision_loss)] // Token counts within practical range won't lose precision
213    pub fn calculate_cost(
214        &self,
215        input: u64,
216        output: u64,
217        cache_read: u64,
218        cache_write: u64,
219    ) -> f64 {
220        let input_cost = (self.cost.input / 1_000_000.0) * input as f64;
221        let output_cost = (self.cost.output / 1_000_000.0) * output as f64;
222        let cache_read_cost = (self.cost.cache_read / 1_000_000.0) * cache_read as f64;
223        let cache_write_cost = (self.cost.cache_write / 1_000_000.0) * cache_write as f64;
224        input_cost + output_cost + cache_read_cost + cache_write_cost
225    }
226}
227
228// ============================================================================
229// Known APIs and Providers
230// ============================================================================
231
/// Known API types.
///
/// `Display` and `FromStr` round-trip the kebab-case identifiers; any unrecognized
/// non-empty identifier is preserved via [`Api::Custom`], and the empty string is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Api {
    AnthropicMessages,
    OpenAICompletions,
    OpenAIResponses,
    OpenAICodexResponses,
    AzureOpenAIResponses,
    BedrockConverseStream,
    GoogleGenerativeAI,
    GoogleGeminiCli,
    GoogleVertex,
    /// Any other non-empty API identifier.
    Custom(String),
}

impl std::fmt::Display for Api {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Kebab-case identifiers; must stay in sync with `FromStr` below.
        let id = match self {
            Self::AnthropicMessages => "anthropic-messages",
            Self::OpenAICompletions => "openai-completions",
            Self::OpenAIResponses => "openai-responses",
            Self::OpenAICodexResponses => "openai-codex-responses",
            Self::AzureOpenAIResponses => "azure-openai-responses",
            Self::BedrockConverseStream => "bedrock-converse-stream",
            Self::GoogleGenerativeAI => "google-generative-ai",
            Self::GoogleGeminiCli => "google-gemini-cli",
            Self::GoogleVertex => "google-vertex",
            Self::Custom(s) => s.as_str(),
        };
        f.write_str(id)
    }
}

impl std::str::FromStr for Api {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Only the empty string is invalid; everything unrecognized becomes `Custom`.
        if s.is_empty() {
            return Err("API identifier cannot be empty".to_string());
        }
        Ok(match s {
            "anthropic-messages" => Self::AnthropicMessages,
            "openai-completions" => Self::OpenAICompletions,
            "openai-responses" => Self::OpenAIResponses,
            "openai-codex-responses" => Self::OpenAICodexResponses,
            "azure-openai-responses" => Self::AzureOpenAIResponses,
            "bedrock-converse-stream" => Self::BedrockConverseStream,
            "google-generative-ai" => Self::GoogleGenerativeAI,
            "google-gemini-cli" => Self::GoogleGeminiCli,
            "google-vertex" => Self::GoogleVertex,
            other => Self::Custom(other.to_string()),
        })
    }
}
283
/// Known providers.
///
/// `Display` and `FromStr` round-trip the kebab-case identifiers (with a few accepted
/// aliases for Azure); any unrecognized non-empty identifier is preserved via
/// [`KnownProvider::Custom`], and the empty string is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(clippy::upper_case_acronyms)] // These are proper names/brands
pub enum KnownProvider {
    Anthropic,
    OpenAI,
    Google,
    GoogleVertex,
    AmazonBedrock,
    AzureOpenAI,
    GithubCopilot,
    XAI,
    Groq,
    Cerebras,
    OpenRouter,
    Mistral,
    /// Any other non-empty provider identifier.
    Custom(String),
}

impl std::fmt::Display for KnownProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Canonical identifiers; must stay in sync with `FromStr` below.
        let id = match self {
            Self::Anthropic => "anthropic",
            Self::OpenAI => "openai",
            Self::Google => "google",
            Self::GoogleVertex => "google-vertex",
            Self::AmazonBedrock => "amazon-bedrock",
            Self::AzureOpenAI => "azure-openai",
            Self::GithubCopilot => "github-copilot",
            Self::XAI => "xai",
            Self::Groq => "groq",
            Self::Cerebras => "cerebras",
            Self::OpenRouter => "openrouter",
            Self::Mistral => "mistral",
            Self::Custom(s) => s.as_str(),
        };
        f.write_str(id)
    }
}

impl std::str::FromStr for KnownProvider {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Only the empty string is invalid; everything unrecognized becomes `Custom`.
        if s.is_empty() {
            return Err("Provider identifier cannot be empty".to_string());
        }
        Ok(match s {
            "anthropic" => Self::Anthropic,
            "openai" => Self::OpenAI,
            "google" => Self::Google,
            "google-vertex" => Self::GoogleVertex,
            "amazon-bedrock" => Self::AmazonBedrock,
            // Azure is accepted under several historical aliases.
            "azure-openai" | "azure" | "azure-cognitive-services" => Self::AzureOpenAI,
            "github-copilot" => Self::GithubCopilot,
            "xai" => Self::XAI,
            "groq" => Self::Groq,
            "cerebras" => Self::Cerebras,
            "openrouter" => Self::OpenRouter,
            "mistral" => Self::Mistral,
            other => Self::Custom(other.to_string()),
        })
    }
}
345
#[cfg(test)]
mod tests {
    use super::*;

    // ========================================================================
    // Api enum: FromStr + Display round-trips
    // ========================================================================

    #[test]
    fn api_from_str_known_variants() {
        let cases = [
            ("anthropic-messages", Api::AnthropicMessages),
            ("openai-completions", Api::OpenAICompletions),
            ("openai-responses", Api::OpenAIResponses),
            ("openai-codex-responses", Api::OpenAICodexResponses),
            ("azure-openai-responses", Api::AzureOpenAIResponses),
            ("bedrock-converse-stream", Api::BedrockConverseStream),
            ("google-generative-ai", Api::GoogleGenerativeAI),
            ("google-gemini-cli", Api::GoogleGeminiCli),
            ("google-vertex", Api::GoogleVertex),
        ];
        for (input, expected) in &cases {
            let parsed: Api = input.parse().unwrap();
            assert_eq!(&parsed, expected, "from_str({input})");
        }
    }

    #[test]
    fn api_display_known_variants() {
        let cases = [
            (Api::AnthropicMessages, "anthropic-messages"),
            (Api::OpenAICompletions, "openai-completions"),
            (Api::OpenAIResponses, "openai-responses"),
            (Api::OpenAICodexResponses, "openai-codex-responses"),
            (Api::AzureOpenAIResponses, "azure-openai-responses"),
            (Api::BedrockConverseStream, "bedrock-converse-stream"),
            (Api::GoogleGenerativeAI, "google-generative-ai"),
            (Api::GoogleGeminiCli, "google-gemini-cli"),
            (Api::GoogleVertex, "google-vertex"),
        ];
        for (variant, expected) in &cases {
            assert_eq!(&variant.to_string(), expected, "display for {variant:?}");
        }
    }

    #[test]
    fn api_round_trip_all_known() {
        let variants = [
            Api::AnthropicMessages,
            Api::OpenAICompletions,
            Api::OpenAIResponses,
            Api::OpenAICodexResponses,
            Api::AzureOpenAIResponses,
            Api::BedrockConverseStream,
            Api::GoogleGenerativeAI,
            Api::GoogleGeminiCli,
            Api::GoogleVertex,
        ];
        for variant in &variants {
            let s = variant.to_string();
            let parsed: Api = s.parse().unwrap();
            assert_eq!(&parsed, variant, "round-trip failed for {variant:?} -> {s}");
        }
    }

    #[test]
    fn api_custom_variant() {
        let parsed: Api = "my-custom-api".parse().unwrap();
        assert_eq!(parsed, Api::Custom("my-custom-api".to_string()));
        assert_eq!(parsed.to_string(), "my-custom-api");
    }

    #[test]
    fn api_empty_string_rejected() {
        let result: Result<Api, _> = "".parse();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "API identifier cannot be empty");
    }

    // ========================================================================
    // KnownProvider enum: FromStr + Display round-trips
    // ========================================================================

    #[test]
    fn provider_from_str_known_variants() {
        let cases = [
            ("anthropic", KnownProvider::Anthropic),
            ("openai", KnownProvider::OpenAI),
            ("google", KnownProvider::Google),
            ("google-vertex", KnownProvider::GoogleVertex),
            ("amazon-bedrock", KnownProvider::AmazonBedrock),
            ("azure-openai", KnownProvider::AzureOpenAI),
            ("azure", KnownProvider::AzureOpenAI),
            ("azure-cognitive-services", KnownProvider::AzureOpenAI),
            ("github-copilot", KnownProvider::GithubCopilot),
            ("xai", KnownProvider::XAI),
            ("groq", KnownProvider::Groq),
            ("cerebras", KnownProvider::Cerebras),
            ("openrouter", KnownProvider::OpenRouter),
            ("mistral", KnownProvider::Mistral),
        ];
        for (input, expected) in &cases {
            let parsed: KnownProvider = input.parse().unwrap();
            assert_eq!(&parsed, expected, "from_str({input})");
        }
    }

    #[test]
    fn provider_display_known_variants() {
        let cases = [
            (KnownProvider::Anthropic, "anthropic"),
            (KnownProvider::OpenAI, "openai"),
            (KnownProvider::Google, "google"),
            (KnownProvider::GoogleVertex, "google-vertex"),
            (KnownProvider::AmazonBedrock, "amazon-bedrock"),
            (KnownProvider::AzureOpenAI, "azure-openai"),
            (KnownProvider::GithubCopilot, "github-copilot"),
            (KnownProvider::XAI, "xai"),
            (KnownProvider::Groq, "groq"),
            (KnownProvider::Cerebras, "cerebras"),
            (KnownProvider::OpenRouter, "openrouter"),
            (KnownProvider::Mistral, "mistral"),
        ];
        for (variant, expected) in &cases {
            assert_eq!(&variant.to_string(), expected, "display for {variant:?}");
        }
    }

    #[test]
    fn provider_round_trip_all_known() {
        let variants = [
            KnownProvider::Anthropic,
            KnownProvider::OpenAI,
            KnownProvider::Google,
            KnownProvider::GoogleVertex,
            KnownProvider::AmazonBedrock,
            KnownProvider::AzureOpenAI,
            KnownProvider::GithubCopilot,
            KnownProvider::XAI,
            KnownProvider::Groq,
            KnownProvider::Cerebras,
            KnownProvider::OpenRouter,
            KnownProvider::Mistral,
        ];
        for variant in &variants {
            let s = variant.to_string();
            let parsed: KnownProvider = s.parse().unwrap();
            assert_eq!(&parsed, variant, "round-trip failed for {variant:?} -> {s}");
        }
    }

    #[test]
    fn provider_custom_variant() {
        let parsed: KnownProvider = "my-custom-provider".parse().unwrap();
        assert_eq!(
            parsed,
            KnownProvider::Custom("my-custom-provider".to_string())
        );
        assert_eq!(parsed.to_string(), "my-custom-provider");
    }

    #[test]
    fn provider_empty_string_rejected() {
        let result: Result<KnownProvider, _> = "".parse();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Provider identifier cannot be empty");
    }

    // ========================================================================
    // Model::calculate_cost
    // ========================================================================

    /// Builds a fixture model with fixed, known pricing for the cost tests below.
    fn test_model() -> Model {
        Model {
            id: "test-model".to_string(),
            name: "Test Model".to_string(),
            api: "anthropic-messages".to_string(),
            provider: "anthropic".to_string(),
            base_url: "https://api.anthropic.com".to_string(),
            reasoning: false,
            input: vec![InputType::Text],
            cost: ModelCost {
                input: 3.0,   // $3 per million input tokens
                output: 15.0, // $15 per million output tokens
                cache_read: 0.3,
                cache_write: 3.75,
            },
            context_window: 200_000,
            max_tokens: 8192,
            headers: HashMap::new(),
        }
    }

    #[test]
    fn calculate_cost_basic() {
        let model = test_model();
        // 1000 input tokens at $3/M = $0.003
        // 500 output tokens at $15/M = $0.0075
        let cost = model.calculate_cost(1000, 500, 0, 0);
        let input_expected = (3.0 / 1_000_000.0) * 1000.0;
        let output_expected = (15.0 / 1_000_000.0) * 500.0;
        let expected = input_expected + output_expected;
        assert!(
            (cost - expected).abs() < f64::EPSILON,
            "expected {expected}, got {cost}"
        );
    }

    #[test]
    fn calculate_cost_with_cache() {
        let model = test_model();
        let cost = model.calculate_cost(1000, 500, 2000, 1000);
        let input_expected = (3.0 / 1_000_000.0) * 1000.0;
        let output_expected = (15.0 / 1_000_000.0) * 500.0;
        let cache_read_expected = (0.3 / 1_000_000.0) * 2000.0;
        let cache_write_expected = (3.75 / 1_000_000.0) * 1000.0;
        let expected =
            input_expected + output_expected + cache_read_expected + cache_write_expected;
        assert!(
            (cost - expected).abs() < 1e-12,
            "expected {expected}, got {cost}"
        );
    }

    #[test]
    fn calculate_cost_zero_tokens() {
        let model = test_model();
        let cost = model.calculate_cost(0, 0, 0, 0);
        assert!((cost).abs() < f64::EPSILON, "zero tokens should cost $0");
    }

    #[test]
    fn calculate_cost_large_token_count() {
        let model = test_model();
        // 1 million tokens each
        let cost = model.calculate_cost(1_000_000, 1_000_000, 0, 0);
        let expected = 3.0 + 15.0; // $3 input + $15 output
        assert!(
            (cost - expected).abs() < 1e-10,
            "expected {expected}, got {cost}"
        );
    }

    // ========================================================================
    // Default values
    // ========================================================================

    #[test]
    fn thinking_budgets_default() {
        let budgets = ThinkingBudgets::default();
        assert_eq!(budgets.minimal, 1024);
        assert_eq!(budgets.low, 2048);
        assert_eq!(budgets.medium, 8192);
        assert_eq!(budgets.high, 16384);
        assert_eq!(budgets.xhigh, 32768);
    }

    #[test]
    fn cache_retention_default_is_none() {
        assert_eq!(CacheRetention::default(), CacheRetention::None);
    }

    #[test]
    fn stream_options_default() {
        let opts = StreamOptions::default();
        assert!(opts.temperature.is_none());
        assert!(opts.max_tokens.is_none());
        assert!(opts.api_key.is_none());
        assert_eq!(opts.cache_retention, CacheRetention::None);
        assert!(opts.session_id.is_none());
        assert!(opts.headers.is_empty());
        assert!(opts.thinking_level.is_none());
        assert!(opts.thinking_budgets.is_none());
    }

    #[test]
    fn context_default() {
        let ctx = Context::default();
        assert!(ctx.system_prompt.is_none());
        assert!(ctx.messages.is_empty());
        assert!(ctx.tools.is_empty());
    }

    // ========================================================================
    // InputType serde
    // ========================================================================

    #[test]
    fn input_type_serialization() {
        let text_json = serde_json::to_string(&InputType::Text).unwrap();
        assert_eq!(text_json, "\"text\"");

        let image_json = serde_json::to_string(&InputType::Image).unwrap();
        assert_eq!(image_json, "\"image\"");

        let text: InputType = serde_json::from_str("\"text\"").unwrap();
        assert_eq!(text, InputType::Text);

        let image: InputType = serde_json::from_str("\"image\"").unwrap();
        assert_eq!(image, InputType::Image);
    }

    // ========================================================================
    // ModelCost serde
    // ========================================================================

    #[test]
    fn model_cost_camel_case_serialization() {
        let cost = ModelCost {
            input: 3.0,
            output: 15.0,
            cache_read: 0.3,
            cache_write: 3.75,
        };
        let json = serde_json::to_string(&cost).unwrap();
        assert!(
            json.contains("\"cacheRead\""),
            "should use camelCase: {json}"
        );
        assert!(
            json.contains("\"cacheWrite\""),
            "should use camelCase: {json}"
        );

        let parsed: ModelCost = serde_json::from_str(&json).unwrap();
        assert!((parsed.input - 3.0).abs() < f64::EPSILON);
        assert!((parsed.cache_read - 0.3).abs() < f64::EPSILON);
    }

    // Property-based tests; these require the `proptest` dev-dependency.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        /// Fixture model in which every pricing rate equals `rate`, used to check
        /// algebraic properties of `calculate_cost`.
        fn arb_model(rate: f64) -> Model {
            Model {
                id: "m".to_string(),
                name: "m".to_string(),
                api: "anthropic-messages".to_string(),
                provider: "test".to_string(),
                base_url: String::new(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: rate,
                    output: rate,
                    cache_read: rate,
                    cache_write: rate,
                },
                context_window: 128_000,
                max_tokens: 8192,
                headers: HashMap::new(),
            }
        }

        // ====================================================================
        // calculate_cost
        // ====================================================================

        proptest! {
            #[test]
            fn cost_zero_tokens_is_zero(rate in 0.0f64..1000.0) {
                let m = arb_model(rate);
                let cost = m.calculate_cost(0, 0, 0, 0);
                assert!((cost).abs() < f64::EPSILON);
            }

            #[test]
            fn cost_non_negative(
                rate in 0.0f64..100.0,
                input in 0u64..10_000_000,
                output in 0u64..10_000_000,
                cr in 0u64..10_000_000,
                cw in 0u64..10_000_000,
            ) {
                let m = arb_model(rate);
                assert!(m.calculate_cost(input, output, cr, cw) >= 0.0);
            }

            #[test]
            fn cost_linearity(
                rate in 0.001f64..50.0,
                tokens in 1u64..1_000_000,
            ) {
                let m = arb_model(rate);
                let single = m.calculate_cost(tokens, 0, 0, 0);
                let double = m.calculate_cost(tokens * 2, 0, 0, 0);
                assert!((double - single * 2.0).abs() < 1e-6,
                    "doubling tokens should double cost: single={single}, double={double}");
            }

            #[test]
            fn cost_additivity(
                rate in 0.001f64..50.0,
                input in 0u64..1_000_000,
                output in 0u64..1_000_000,
            ) {
                let m = arb_model(rate);
                let combined = m.calculate_cost(input, output, 0, 0);
                let separate = m.calculate_cost(input, 0, 0, 0)
                    + m.calculate_cost(0, output, 0, 0);
                assert!((combined - separate).abs() < 1e-10,
                    "cost should be additive: combined={combined}, separate={separate}");
            }
        }

        // ====================================================================
        // Api FromStr + Display round-trip
        // ====================================================================

        proptest! {
            #[test]
            fn api_custom_round_trip(s in "[a-z][a-z0-9-]{0,20}") {
                let known = [
                    "anthropic-messages", "openai-completions", "openai-responses", "openai-codex-responses",
                    "azure-openai-responses", "bedrock-converse-stream",
                    "google-generative-ai", "google-gemini-cli", "google-vertex",
                ];
                if !known.contains(&s.as_str()) {
                    let parsed: Api = s.parse().unwrap();
                    assert_eq!(parsed.to_string(), s);
                }
            }

            #[test]
            fn api_never_panics(s in ".*") {
                let _ = s.parse::<Api>(); // must not panic
            }

            #[test]
            fn api_empty_always_rejected(ws in "[ \t]*") {
                if ws.is_empty() {
                    assert!(ws.parse::<Api>().is_err());
                }
            }
        }

        // ====================================================================
        // KnownProvider FromStr + Display round-trip
        // ====================================================================

        proptest! {
            #[test]
            fn provider_custom_round_trip(s in "[a-z][a-z0-9-]{0,20}") {
                let known = [
                    "anthropic", "openai", "google", "google-vertex",
                    "amazon-bedrock", "azure-openai", "azure",
                    "azure-cognitive-services", "github-copilot", "xai",
                    "groq", "cerebras", "openrouter", "mistral",
                ];
                if !known.contains(&s.as_str()) {
                    let parsed: KnownProvider = s.parse().unwrap();
                    assert_eq!(parsed.to_string(), s);
                }
            }

            #[test]
            fn provider_never_panics(s in ".*") {
                let _ = s.parse::<KnownProvider>(); // must not panic
            }
        }

        // ====================================================================
        // ModelCost serde round-trip
        // ====================================================================

        proptest! {
            #[test]
            fn model_cost_serde_round_trip(
                input in 0.0f64..1000.0,
                output in 0.0f64..1000.0,
                cr in 0.0f64..1000.0,
                cw in 0.0f64..1000.0,
            ) {
                let cost = ModelCost { input, output, cache_read: cr, cache_write: cw };
                let json = serde_json::to_string(&cost).unwrap();
                let parsed: ModelCost = serde_json::from_str(&json).unwrap();
                assert!((parsed.input - input).abs() < 1e-10);
                assert!((parsed.output - output).abs() < 1e-10);
                assert!((parsed.cache_read - cr).abs() < 1e-10);
                assert!((parsed.cache_write - cw).abs() < 1e-10);
            }
        }
    }
}