use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;

/// Supported model providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    #[default]
    OpenAI,
    Anthropic,
    Google,
}

impl Provider {
    /// Every supported provider, in display order.
    pub const ALL: &'static [Provider] =
        &[Provider::OpenAI, Provider::Anthropic, Provider::Google];

    /// Canonical lowercase name for the provider.
    pub const fn name(&self) -> &'static str {
        match self {
            Self::OpenAI => "openai",
            Self::Anthropic => "anthropic",
            Self::Google => "google",
        }
    }

    /// Default model used when the configuration does not specify one.
    pub const fn default_model(&self) -> &'static str {
        match self {
            Self::OpenAI => "gpt-5.1",
            Self::Anthropic => "claude-sonnet-4-5-20250929",
            Self::Google => "gemini-3-pro-preview",
        }
    }

    /// Default lightweight model used when no fast model is configured.
    pub const fn default_fast_model(&self) -> &'static str {
        match self {
            Self::OpenAI => "gpt-5.1-mini",
            Self::Anthropic => "claude-haiku-4-5-20251001",
            Self::Google => "gemini-2.5-flash",
        }
    }

    /// Context window size, in tokens, used when no token limit is configured.
    pub const fn context_window(&self) -> usize {
        match self {
            Self::OpenAI => 128_000,
            Self::Anthropic => 200_000,
            Self::Google => 1_000_000,
        }
    }

    /// Environment variable consulted for the provider's API key.
    pub const fn api_key_env(&self) -> &'static str {
        match self {
            Self::OpenAI => "OPENAI_API_KEY",
            Self::Anthropic => "ANTHROPIC_API_KEY",
            Self::Google => "GOOGLE_API_KEY",
        }
    }

    /// Canonical names of all supported providers.
    pub fn all_names() -> Vec<&'static str> {
        Self::ALL.iter().map(Self::name).collect()
    }
}

impl FromStr for Provider {
    type Err = ProviderError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lower = s.to_lowercase();
        // Accept "claude" as an alias for the Anthropic provider.
        let normalized = if lower == "claude" {
            "anthropic"
        } else {
            lower.as_str()
        };

        Self::ALL
            .iter()
            .find(|p| p.name() == normalized)
            .copied()
            .ok_or_else(|| ProviderError::Unknown(s.to_string()))
    }
}

impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.name())
    }
}

/// Errors that can occur when selecting or configuring a provider.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    #[error("Unknown provider: {0}. Supported: openai, anthropic, google")]
    Unknown(String),
    #[error("API key required for provider: {0}")]
    MissingApiKey(String),
}

/// Per-provider configuration; empty fields fall back to the provider defaults.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ProviderConfig {
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub api_key: String,
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fast_model: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_limit: Option<usize>,
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub additional_params: HashMap<String, String>,
}

impl ProviderConfig {
    /// Builds a config pre-populated with the provider's default models.
    pub fn with_defaults(provider: Provider) -> Self {
        Self {
            api_key: String::new(),
            model: provider.default_model().to_string(),
            fast_model: Some(provider.default_fast_model().to_string()),
            token_limit: None,
            additional_params: HashMap::new(),
        }
    }

    /// Configured model, or the provider default if none is set.
    pub fn effective_model(&self, provider: Provider) -> &str {
        if self.model.is_empty() {
            provider.default_model()
        } else {
            &self.model
        }
    }

    /// Configured fast model, or the provider default if none is set.
    pub fn effective_fast_model(&self, provider: Provider) -> &str {
        self.fast_model
            .as_deref()
            .unwrap_or_else(|| provider.default_fast_model())
    }

    /// Configured token limit, or the provider's context window if none is set.
    pub fn effective_token_limit(&self, provider: Provider) -> usize {
        self.token_limit
            .unwrap_or_else(|| provider.context_window())
    }

    /// Whether an API key has been set on this config.
    pub fn has_api_key(&self) -> bool {
        !self.api_key.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_from_str() {
        assert_eq!("openai".parse::<Provider>().ok(), Some(Provider::OpenAI));
        assert_eq!(
            "ANTHROPIC".parse::<Provider>().ok(),
            Some(Provider::Anthropic)
        );
        assert_eq!("claude".parse::<Provider>().ok(), Some(Provider::Anthropic));
        assert!("invalid".parse::<Provider>().is_err());
    }
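
    // Extra check sketched in, not part of the original suite: Display and
    // all_names() should both report the same canonical lowercase names.
    #[test]
    fn test_provider_display_and_names() {
        assert_eq!(Provider::Anthropic.to_string(), "anthropic");
        assert_eq!(Provider::all_names(), vec!["openai", "anthropic", "google"]);
    }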

    #[test]
    fn test_provider_defaults() {
        assert_eq!(Provider::OpenAI.default_model(), "gpt-5.1");
        assert_eq!(Provider::Anthropic.context_window(), 200_000);
        assert_eq!(Provider::Google.api_key_env(), "GOOGLE_API_KEY");
    }

    #[test]
    fn test_provider_config_defaults() {
        let config = ProviderConfig::with_defaults(Provider::Anthropic);
        assert_eq!(config.model, "claude-sonnet-4-5-20250929");
        assert_eq!(
            config.fast_model.as_deref(),
            Some("claude-haiku-4-5-20251001")
        );
    }
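
    // Extra fallback check sketched in, not part of the original suite: an
    // empty config should resolve to the provider defaults via effective_*.
    #[test]
    fn test_provider_config_fallbacks() {
        let config = ProviderConfig::default();
        assert!(!config.has_api_key());
        assert_eq!(config.effective_model(Provider::Google), "gemini-3-pro-preview");
        assert_eq!(config.effective_fast_model(Provider::Google), "gemini-2.5-flash");
        assert_eq!(config.effective_token_limit(Provider::Google), 1_000_000);
    }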
}