Skip to main content

wraith_api/providers/
mod.rs

1use std::future::Future;
2use std::pin::Pin;
3
4use crate::error::ApiError;
5use crate::types::{MessageRequest, MessageResponse};
6
7pub mod anthropic;
8pub mod openai_compat;
9
10pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>> + Send + 'a>>;
11
12pub trait Provider {
13    type Stream;
14
15    fn send_message<'a>(
16        &'a self,
17        request: &'a MessageRequest,
18    ) -> ProviderFuture<'a, MessageResponse>;
19
20    fn stream_message<'a>(
21        &'a self,
22        request: &'a MessageRequest,
23    ) -> ProviderFuture<'a, Self::Stream>;
24}
25
/// Backend family a model request is routed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Anthropic Claude models, including the `opus` / `sonnet` / `haiku` aliases.
    Anthropic,
    /// xAI Grok models (`grok*`).
    Xai,
    /// OpenAI; picked by `detect_provider_kind` from the `OPENAI_API_KEY` credential
    /// rather than from the model name.
    OpenAi,
    /// Google Gemini models (`gemini*`).
    Gemini,
    /// OpenRouter; model ids containing `/` (e.g. `vendor/model`).
    OpenRouter,
}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq)]
36pub struct ProviderMetadata {
37    pub provider: ProviderKind,
38    pub auth_env: &'static str,
39    pub base_url_env: &'static str,
40    pub default_base_url: &'static str,
41}
42
43const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
44    (
45        "opus",
46        ProviderMetadata {
47            provider: ProviderKind::Anthropic,
48            auth_env: "ANTHROPIC_API_KEY",
49            base_url_env: "ANTHROPIC_BASE_URL",
50            default_base_url: anthropic::DEFAULT_BASE_URL,
51        },
52    ),
53    (
54        "sonnet",
55        ProviderMetadata {
56            provider: ProviderKind::Anthropic,
57            auth_env: "ANTHROPIC_API_KEY",
58            base_url_env: "ANTHROPIC_BASE_URL",
59            default_base_url: anthropic::DEFAULT_BASE_URL,
60        },
61    ),
62    (
63        "haiku",
64        ProviderMetadata {
65            provider: ProviderKind::Anthropic,
66            auth_env: "ANTHROPIC_API_KEY",
67            base_url_env: "ANTHROPIC_BASE_URL",
68            default_base_url: anthropic::DEFAULT_BASE_URL,
69        },
70    ),
71    (
72        "claude-opus-4-6",
73        ProviderMetadata {
74            provider: ProviderKind::Anthropic,
75            auth_env: "ANTHROPIC_API_KEY",
76            base_url_env: "ANTHROPIC_BASE_URL",
77            default_base_url: anthropic::DEFAULT_BASE_URL,
78        },
79    ),
80    (
81        "claude-sonnet-4-6",
82        ProviderMetadata {
83            provider: ProviderKind::Anthropic,
84            auth_env: "ANTHROPIC_API_KEY",
85            base_url_env: "ANTHROPIC_BASE_URL",
86            default_base_url: anthropic::DEFAULT_BASE_URL,
87        },
88    ),
89    (
90        "claude-haiku-4-5-20251213",
91        ProviderMetadata {
92            provider: ProviderKind::Anthropic,
93            auth_env: "ANTHROPIC_API_KEY",
94            base_url_env: "ANTHROPIC_BASE_URL",
95            default_base_url: anthropic::DEFAULT_BASE_URL,
96        },
97    ),
98    (
99        "grok",
100        ProviderMetadata {
101            provider: ProviderKind::Xai,
102            auth_env: "XAI_API_KEY",
103            base_url_env: "XAI_BASE_URL",
104            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
105        },
106    ),
107    (
108        "grok-3",
109        ProviderMetadata {
110            provider: ProviderKind::Xai,
111            auth_env: "XAI_API_KEY",
112            base_url_env: "XAI_BASE_URL",
113            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
114        },
115    ),
116    (
117        "grok-mini",
118        ProviderMetadata {
119            provider: ProviderKind::Xai,
120            auth_env: "XAI_API_KEY",
121            base_url_env: "XAI_BASE_URL",
122            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
123        },
124    ),
125    (
126        "grok-3-mini",
127        ProviderMetadata {
128            provider: ProviderKind::Xai,
129            auth_env: "XAI_API_KEY",
130            base_url_env: "XAI_BASE_URL",
131            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
132        },
133    ),
134    (
135        "grok-2",
136        ProviderMetadata {
137            provider: ProviderKind::Xai,
138            auth_env: "XAI_API_KEY",
139            base_url_env: "XAI_BASE_URL",
140            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
141        },
142    ),
143    (
144        "gemini",
145        ProviderMetadata {
146            provider: ProviderKind::Gemini,
147            auth_env: "GEMINI_API_KEY",
148            base_url_env: "GEMINI_BASE_URL",
149            default_base_url: openai_compat::DEFAULT_GEMINI_BASE_URL,
150        },
151    ),
152    (
153        "gemini-2.0-flash",
154        ProviderMetadata {
155            provider: ProviderKind::Gemini,
156            auth_env: "GEMINI_API_KEY",
157            base_url_env: "GEMINI_BASE_URL",
158            default_base_url: openai_compat::DEFAULT_GEMINI_BASE_URL,
159        },
160    ),
161    (
162        "gemini-2.5-pro",
163        ProviderMetadata {
164            provider: ProviderKind::Gemini,
165            auth_env: "GEMINI_API_KEY",
166            base_url_env: "GEMINI_BASE_URL",
167            default_base_url: openai_compat::DEFAULT_GEMINI_BASE_URL,
168        },
169    ),
170    (
171        "gemini-2.5-flash",
172        ProviderMetadata {
173            provider: ProviderKind::Gemini,
174            auth_env: "GEMINI_API_KEY",
175            base_url_env: "GEMINI_BASE_URL",
176            default_base_url: openai_compat::DEFAULT_GEMINI_BASE_URL,
177        },
178    ),
179    (
180        "gemini-1.5-pro",
181        ProviderMetadata {
182            provider: ProviderKind::Gemini,
183            auth_env: "GEMINI_API_KEY",
184            base_url_env: "GEMINI_BASE_URL",
185            default_base_url: openai_compat::DEFAULT_GEMINI_BASE_URL,
186        },
187    ),
188];
189
190#[must_use]
191pub fn resolve_model_alias(model: &str) -> String {
192    let trimmed = model.trim();
193    let lower = trimmed.to_ascii_lowercase();
194    MODEL_REGISTRY
195        .iter()
196        .find_map(|(alias, metadata)| {
197            (*alias == lower).then_some(match metadata.provider {
198                ProviderKind::Anthropic => match *alias {
199                    "opus" => "claude-opus-4-6",
200                    "sonnet" => "claude-sonnet-4-6",
201                    "haiku" => "claude-haiku-4-5-20251213",
202                    _ => trimmed,
203                },
204                ProviderKind::Xai => match *alias {
205                    "grok" | "grok-3" => "grok-3",
206                    "grok-mini" | "grok-3-mini" => "grok-3-mini",
207                    "grok-2" => "grok-2",
208                    _ => trimmed,
209                },
210                ProviderKind::OpenAi | ProviderKind::Gemini | ProviderKind::OpenRouter => trimmed,
211            })
212        })
213        .map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
214}
215
216#[must_use]
217pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
218    let canonical = resolve_model_alias(model);
219    let lower = canonical.to_ascii_lowercase();
220    if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) {
221        return Some(*metadata);
222    }
223    if lower.starts_with("grok") {
224        return Some(ProviderMetadata {
225            provider: ProviderKind::Xai,
226            auth_env: "XAI_API_KEY",
227            base_url_env: "XAI_BASE_URL",
228            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
229        });
230    }
231    if lower.starts_with("gemini") {
232        return Some(ProviderMetadata {
233            provider: ProviderKind::Gemini,
234            auth_env: "GEMINI_API_KEY",
235            base_url_env: "GEMINI_BASE_URL",
236            default_base_url: openai_compat::DEFAULT_GEMINI_BASE_URL,
237        });
238    }
239    if lower.contains('/') {
240        return Some(ProviderMetadata {
241            provider: ProviderKind::OpenRouter,
242            auth_env: "OPENROUTER_API_KEY",
243            base_url_env: "OPENROUTER_BASE_URL",
244            default_base_url: openai_compat::DEFAULT_OPENROUTER_BASE_URL,
245        });
246    }
247    None
248}
249
250#[must_use]
251pub fn detect_provider_kind(model: &str) -> ProviderKind {
252    if let Some(metadata) = metadata_for_model(model) {
253        return metadata.provider;
254    }
255    if anthropic::has_auth_from_env_or_saved().unwrap_or(false) {
256        return ProviderKind::Anthropic;
257    }
258    if openai_compat::has_api_key("OPENAI_API_KEY") {
259        return ProviderKind::OpenAi;
260    }
261    if openai_compat::has_api_key("XAI_API_KEY") {
262        return ProviderKind::Xai;
263    }
264    if openai_compat::has_api_key("GEMINI_API_KEY") {
265        return ProviderKind::Gemini;
266    }
267    if openai_compat::has_api_key("OPENROUTER_API_KEY") {
268        return ProviderKind::OpenRouter;
269    }
270    ProviderKind::Anthropic
271}
272
273#[must_use]
274pub fn max_tokens_for_model(model: &str) -> u32 {
275    let canonical = resolve_model_alias(model);
276    if canonical.contains("opus") {
277        32_000
278    } else {
279        64_000
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::{
        detect_provider_kind, max_tokens_for_model, metadata_for_model, resolve_model_alias,
        ProviderKind,
    };

    #[test]
    fn resolves_grok_aliases() {
        assert_eq!(resolve_model_alias("grok"), "grok-3");
        assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
        assert_eq!(resolve_model_alias("grok-2"), "grok-2");
    }

    #[test]
    fn resolves_anthropic_aliases_case_insensitively() {
        assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
        assert_eq!(resolve_model_alias("  Sonnet "), "claude-sonnet-4-6");
        assert_eq!(resolve_model_alias("HAIKU"), "claude-haiku-4-5-20251213");
    }

    #[test]
    fn passes_unknown_models_through_trimmed() {
        assert_eq!(resolve_model_alias("  gpt-4o\n"), "gpt-4o");
        assert_eq!(resolve_model_alias("openai/GPT-4o"), "openai/GPT-4o");
    }

    #[test]
    fn detects_provider_from_model_name_first() {
        assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
        assert_eq!(
            detect_provider_kind("claude-sonnet-4-6"),
            ProviderKind::Anthropic
        );
        assert_eq!(detect_provider_kind("gemini-2.5-flash"), ProviderKind::Gemini);
        assert_eq!(
            detect_provider_kind("openai/gpt-4o"),
            ProviderKind::OpenRouter
        );
    }

    #[test]
    fn routes_unlisted_models_by_prefix_or_slash() {
        let grok = metadata_for_model("grok-4").expect("grok prefix should route to xAI");
        assert_eq!(grok.provider, ProviderKind::Xai);
        assert_eq!(grok.auth_env, "XAI_API_KEY");

        assert_eq!(
            metadata_for_model("gemini-3-pro").map(|m| m.provider),
            Some(ProviderKind::Gemini)
        );

        let openrouter =
            metadata_for_model("openai/gpt-4o").expect("slash ids should route to OpenRouter");
        assert_eq!(openrouter.provider, ProviderKind::OpenRouter);
        assert_eq!(openrouter.auth_env, "OPENROUTER_API_KEY");

        // No registry entry, prefix rule or slash: routing is left to credentials.
        assert_eq!(metadata_for_model("gpt-4o"), None);
    }

    #[test]
    fn keeps_existing_max_token_heuristic() {
        assert_eq!(max_tokens_for_model("opus"), 32_000);
        assert_eq!(max_tokens_for_model("Opus"), 32_000);
        assert_eq!(max_tokens_for_model("sonnet"), 64_000);
        assert_eq!(max_tokens_for_model("grok-3"), 64_000);
    }
}