Skip to main content

codineer_api/providers/
mod.rs

1use std::time::Duration;
2
3pub mod codineer_provider;
4pub mod openai_compat;
5
/// Retry configuration: how many retries to attempt and the bounds on the
/// delay between attempts.
///
/// How the delay grows between `initial_backoff` and `max_backoff` is decided
/// by whatever code consumes this policy (not shown in this module).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Delay used for the first backoff.
    pub initial_backoff: Duration,
    /// Upper bound on any single backoff delay.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Two retries, starting at 200 ms and capped at 2 s.
    fn default() -> Self {
        const DEFAULT_MAX_RETRIES: u32 = 2;
        let first_delay = Duration::from_millis(200);
        let delay_ceiling = Duration::from_secs(2);

        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            initial_backoff: first_delay,
            max_backoff: delay_ceiling,
        }
    }
}
22
/// The backend family that serves a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Anthropic (Claude) models, reached through `codineer_provider`.
    CodineerApi,
    /// xAI (Grok) models, reached through the OpenAI-compatible client.
    Xai,
    /// OpenAI models (`gpt…`, `o1`/`o3`/`o4…`, `chatgpt-…`).
    OpenAi,
    /// A model addressed as `provider/model`
    /// (see `parse_custom_provider_prefix`).
    Custom,
}
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub struct ProviderMetadata {
33    pub provider: ProviderKind,
34    pub auth_env: &'static str,
35    pub base_url_env: &'static str,
36    pub default_base_url: &'static str,
37}
38
39const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
40    (
41        "claude-opus-4-6",
42        ProviderMetadata {
43            provider: ProviderKind::CodineerApi,
44            auth_env: "ANTHROPIC_API_KEY",
45            base_url_env: "ANTHROPIC_BASE_URL",
46            default_base_url: codineer_provider::DEFAULT_BASE_URL,
47        },
48    ),
49    (
50        "claude-sonnet-4-6",
51        ProviderMetadata {
52            provider: ProviderKind::CodineerApi,
53            auth_env: "ANTHROPIC_API_KEY",
54            base_url_env: "ANTHROPIC_BASE_URL",
55            default_base_url: codineer_provider::DEFAULT_BASE_URL,
56        },
57    ),
58    (
59        "claude-haiku-4-5-20251213",
60        ProviderMetadata {
61            provider: ProviderKind::CodineerApi,
62            auth_env: "ANTHROPIC_API_KEY",
63            base_url_env: "ANTHROPIC_BASE_URL",
64            default_base_url: codineer_provider::DEFAULT_BASE_URL,
65        },
66    ),
67    (
68        "grok-3",
69        ProviderMetadata {
70            provider: ProviderKind::Xai,
71            auth_env: "XAI_API_KEY",
72            base_url_env: "XAI_BASE_URL",
73            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
74        },
75    ),
76    (
77        "grok-3-mini",
78        ProviderMetadata {
79            provider: ProviderKind::Xai,
80            auth_env: "XAI_API_KEY",
81            base_url_env: "XAI_BASE_URL",
82            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
83        },
84    ),
85    (
86        "grok-2",
87        ProviderMetadata {
88            provider: ProviderKind::Xai,
89            auth_env: "XAI_API_KEY",
90            base_url_env: "XAI_BASE_URL",
91            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
92        },
93    ),
94    (
95        "gpt-4o",
96        ProviderMetadata {
97            provider: ProviderKind::OpenAi,
98            auth_env: "OPENAI_API_KEY",
99            base_url_env: "OPENAI_BASE_URL",
100            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
101        },
102    ),
103    (
104        "gpt-4o-mini",
105        ProviderMetadata {
106            provider: ProviderKind::OpenAi,
107            auth_env: "OPENAI_API_KEY",
108            base_url_env: "OPENAI_BASE_URL",
109            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
110        },
111    ),
112    (
113        "o3",
114        ProviderMetadata {
115            provider: ProviderKind::OpenAi,
116            auth_env: "OPENAI_API_KEY",
117            base_url_env: "OPENAI_BASE_URL",
118            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
119        },
120    ),
121    (
122        "o3-mini",
123        ProviderMetadata {
124            provider: ProviderKind::OpenAi,
125            auth_env: "OPENAI_API_KEY",
126            base_url_env: "OPENAI_BASE_URL",
127            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
128        },
129    ),
130];
131
/// Built-in settings for an OpenAI-compatible service, looked up by name via
/// [`builtin_preset`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BuiltinProviderPreset {
    /// Lowercase ASCII identifier (e.g. `"ollama"`).
    pub name: &'static str,
    /// Base URL of the service's OpenAI-compatible API.
    pub base_url: &'static str,
    /// Environment variable holding the API key; empty for local providers
    /// that need no key.
    pub api_key_env: &'static str,
    /// Short human-readable description of the service.
    pub description: &'static str,
}

/// Built-in provider presets for OpenAI-compatible services.
///
/// Each entry is a [`BuiltinProviderPreset`]. `api_key_env` is empty for the
/// local providers (Ollama, LM Studio), which need no API key. Keep `name`
/// lowercase ASCII: [`builtin_preset`] compares it against a lowercased query.
pub const BUILTIN_PROVIDER_PRESETS: &[BuiltinProviderPreset] = &[
    BuiltinProviderPreset {
        name: "ollama",
        base_url: "http://localhost:11434/v1",
        api_key_env: "",
        description: "Local Ollama instance (no API key needed)",
    },
    BuiltinProviderPreset {
        name: "lmstudio",
        base_url: "http://localhost:1234/v1",
        api_key_env: "",
        description: "Local LM Studio instance (no API key needed)",
    },
    BuiltinProviderPreset {
        name: "openrouter",
        base_url: "https://openrouter.ai/api/v1",
        api_key_env: "OPENROUTER_API_KEY",
        description: "OpenRouter (free models available)",
    },
    BuiltinProviderPreset {
        name: "groq",
        base_url: "https://api.groq.com/openai/v1",
        api_key_env: "GROQ_API_KEY",
        description: "Groq Cloud (generous free tier)",
    },
];
168
/// Split a `provider/model` id at its first `/`.
///
/// Surrounding whitespace is ignored. Returns `Some((provider, model))` only
/// when both sides of the slash are non-empty. Everything after the first
/// slash is the model part, so `openrouter/meta-llama/x` yields
/// `("openrouter", "meta-llama/x")`. Otherwise returns `None`.
#[must_use]
pub fn parse_custom_provider_prefix(model: &str) -> Option<(&str, &str)> {
    match model.trim().split_once('/') {
        Some((provider, rest)) if !provider.is_empty() && !rest.is_empty() => {
            Some((provider, rest))
        }
        _ => None,
    }
}
182
183/// Look up a built-in provider preset by name (case-insensitive).
184#[must_use]
185pub fn builtin_preset(name: &str) -> Option<&'static BuiltinProviderPreset> {
186    let lower = name.to_ascii_lowercase();
187    BUILTIN_PROVIDER_PRESETS
188        .iter()
189        .find(|preset| preset.name == lower)
190}
191
192/// Normalize a model name: trim whitespace, apply user-defined aliases.
193/// Pass an empty map if no user aliases are available.
194#[must_use]
195pub fn resolve_model_alias(
196    model: &str,
197    user_aliases: &std::collections::BTreeMap<String, String>,
198) -> String {
199    let trimmed = model.trim();
200    if parse_custom_provider_prefix(trimmed).is_some() {
201        return trimmed.to_string();
202    }
203    let lower = trimmed.to_ascii_lowercase();
204    user_aliases
205        .get(&lower)
206        .cloned()
207        .unwrap_or_else(|| trimmed.to_string())
208}
209
210#[must_use]
211pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
212    let lower = model.trim().to_ascii_lowercase();
213    if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) {
214        return Some(*metadata);
215    }
216    if lower.starts_with("grok") {
217        return Some(ProviderMetadata {
218            provider: ProviderKind::Xai,
219            auth_env: "XAI_API_KEY",
220            base_url_env: "XAI_BASE_URL",
221            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
222        });
223    }
224    if lower.starts_with("claude-") || lower == "claude" {
225        return Some(ProviderMetadata {
226            provider: ProviderKind::CodineerApi,
227            auth_env: "ANTHROPIC_API_KEY",
228            base_url_env: "ANTHROPIC_BASE_URL",
229            default_base_url: codineer_provider::DEFAULT_BASE_URL,
230        });
231    }
232    if lower.starts_with("gpt")
233        || lower.starts_with("o1")
234        || lower.starts_with("o3")
235        || lower.starts_with("o4")
236        || lower.starts_with("chatgpt-")
237    {
238        return Some(ProviderMetadata {
239            provider: ProviderKind::OpenAi,
240            auth_env: "OPENAI_API_KEY",
241            base_url_env: "OPENAI_BASE_URL",
242            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
243        });
244    }
245    None
246}
247
248#[must_use]
249pub fn detect_provider_kind(model: &str) -> ProviderKind {
250    if parse_custom_provider_prefix(model).is_some() {
251        return ProviderKind::Custom;
252    }
253    if let Some(metadata) = metadata_for_model(model) {
254        return metadata.provider;
255    }
256    let fallback = detect_available_provider().unwrap_or(ProviderKind::CodineerApi);
257    eprintln!("[warn] unknown model \"{model}\", falling back to {fallback:?} provider");
258    fallback
259}
260
261fn detect_available_provider() -> Option<ProviderKind> {
262    if codineer_provider::has_auth_from_env_or_saved().unwrap_or(false) {
263        return Some(ProviderKind::CodineerApi);
264    }
265    if openai_compat::has_api_key("OPENAI_API_KEY") {
266        return Some(ProviderKind::OpenAi);
267    }
268    if openai_compat::has_api_key("XAI_API_KEY") {
269        return Some(ProviderKind::Xai);
270    }
271    None
272}
273
274/// Detect which provider has available credentials and return its default model.
275/// Returns `None` if no credentials are found for any provider.
276#[must_use]
277pub fn auto_detect_default_model() -> Option<&'static str> {
278    match detect_available_provider()? {
279        ProviderKind::CodineerApi => Some("claude-sonnet-4-6"),
280        ProviderKind::Xai => Some("grok-3"),
281        ProviderKind::OpenAi => Some("gpt-4o"),
282        ProviderKind::Custom => None,
283    }
284}
285
286#[must_use]
287pub fn max_tokens_for_model(model: &str) -> u32 {
288    let canonical = model.trim();
289    if canonical.starts_with("claude-opus") {
290        32_000
291    } else if parse_custom_provider_prefix(canonical).is_some() {
292        // Local / custom models often have smaller context windows;
293        // 16k is a safe default that avoids hitting limits on 8B–32B models.
294        16_000
295    } else {
296        64_000
297    }
298}
299
300/// Return all known model names from the registry.
301#[must_use]
302pub fn list_known_models(
303    filter_provider: Option<ProviderKind>,
304) -> Vec<(&'static str, ProviderKind)> {
305    MODEL_REGISTRY
306        .iter()
307        .filter(|(_, meta)| filter_provider.is_none_or(|p| meta.provider == p))
308        .map(|(name, meta)| (*name, meta.provider))
309        .collect()
310}
311
312/// Resolve a provider name to `ProviderKind` from known aliases.
313#[must_use]
314pub fn provider_kind_by_name(name: &str) -> Option<ProviderKind> {
315    let lower = name.to_ascii_lowercase();
316    match lower.as_str() {
317        "anthropic" | "claude" => Some(ProviderKind::CodineerApi),
318        "xai" | "grok" => Some(ProviderKind::Xai),
319        "openai" | "gpt" => Some(ProviderKind::OpenAi),
320        _ => None,
321    }
322}
323
324impl ProviderKind {
325    pub const fn display_name(self) -> &'static str {
326        match self {
327            Self::CodineerApi => "Anthropic",
328            Self::Xai => "xAI",
329            Self::OpenAi => "OpenAI",
330            Self::Custom => "Custom",
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::{
        builtin_preset, detect_provider_kind, list_known_models, max_tokens_for_model,
        parse_custom_provider_prefix, provider_kind_by_name, resolve_model_alias, ProviderKind,
    };
    use std::collections::BTreeMap;

    /// Alias map with no entries.
    fn no_aliases() -> BTreeMap<String, String> {
        BTreeMap::new()
    }

    /// Alias map with two short nicknames.
    fn nickname_aliases() -> BTreeMap<String, String> {
        [("sonnet", "claude-sonnet-4-6"), ("grok", "grok-3")]
            .into_iter()
            .map(|(alias, target)| (alias.to_owned(), target.to_owned()))
            .collect()
    }

    #[test]
    fn resolves_user_aliases() {
        let aliases = nickname_aliases();
        for (input, expected) in [("sonnet", "claude-sonnet-4-6"), ("grok", "grok-3")] {
            assert_eq!(resolve_model_alias(input, &aliases), expected, "alias {input}");
        }
    }

    #[test]
    fn passthrough_when_no_alias() {
        let aliases = no_aliases();
        for model in ["grok-2", "custom-model"] {
            assert_eq!(resolve_model_alias(model, &aliases), model, "model {model}");
        }
    }

    #[test]
    fn detects_provider_from_model_name_first() {
        let cases = [
            ("grok", ProviderKind::Xai),
            ("claude-sonnet-4-6", ProviderKind::CodineerApi),
        ];
        for (model, expected) in cases {
            assert_eq!(detect_provider_kind(model), expected, "model {model}");
        }
    }

    #[test]
    fn detects_provider_by_unlisted_model_id_prefix() {
        let cases = [
            ("claude-3-5-sonnet-20241022", ProviderKind::CodineerApi),
            ("gpt-4-turbo", ProviderKind::OpenAi),
            ("o1-preview", ProviderKind::OpenAi),
            ("o3-pro", ProviderKind::OpenAi),
        ];
        for (model, expected) in cases {
            assert_eq!(detect_provider_kind(model), expected, "model {model}");
        }
    }

    #[test]
    fn keeps_existing_max_token_heuristic() {
        for (model, expected) in [("claude-opus-4-6", 32_000_u32), ("grok-3", 64_000_u32)] {
            assert_eq!(max_tokens_for_model(model), expected, "model {model}");
        }
    }

    #[test]
    fn parses_custom_provider_prefix() {
        let accepted = [
            ("ollama/qwen2.5-coder", ("ollama", "qwen2.5-coder")),
            ("groq/llama-3.3-70b", ("groq", "llama-3.3-70b")),
            (
                "openrouter/meta-llama/llama-3.1-8b:free",
                ("openrouter", "meta-llama/llama-3.1-8b:free"),
            ),
        ];
        for (input, expected) in accepted {
            assert_eq!(
                parse_custom_provider_prefix(input),
                Some(expected),
                "input {input}"
            );
        }

        for input in ["grok-3", "sonnet", "/model", "provider/"] {
            assert_eq!(parse_custom_provider_prefix(input), None, "input {input}");
        }
    }

    #[test]
    fn detects_custom_provider_kind() {
        for model in ["ollama/qwen2.5-coder", "lmstudio/my-model"] {
            assert_eq!(
                detect_provider_kind(model),
                ProviderKind::Custom,
                "model {model}"
            );
        }
    }

    #[test]
    fn resolves_custom_model_passthrough() {
        let model = "ollama/qwen2.5-coder";
        assert_eq!(resolve_model_alias(model, &no_aliases()), model);
    }

    #[test]
    fn custom_model_tokens_smaller_default() {
        assert_eq!(max_tokens_for_model("ollama/qwen2.5-coder"), 16_000);
    }

    #[test]
    fn builtin_presets_lookup() {
        let ollama = builtin_preset("ollama").expect("ollama preset should exist");
        assert_eq!(ollama.base_url, "http://localhost:11434/v1");
        assert!(ollama.api_key_env.is_empty(), "local preset needs no key");

        let groq_key_env = builtin_preset("groq").map(|preset| preset.api_key_env);
        assert_eq!(groq_key_env, Some("GROQ_API_KEY"));

        assert!(builtin_preset("nonexistent").is_none());
    }

    #[test]
    fn list_known_models_returns_all_when_unfiltered() {
        let all = list_known_models(None);
        assert!(!all.is_empty());
        for kind in [ProviderKind::CodineerApi, ProviderKind::Xai] {
            assert!(all.iter().any(|&(_, k)| k == kind), "missing {kind:?}");
        }
    }

    #[test]
    fn list_known_models_filters_by_provider() {
        for kind in [ProviderKind::Xai, ProviderKind::CodineerApi] {
            let models = list_known_models(Some(kind));
            assert!(!models.is_empty(), "no models for {kind:?}");
            assert!(
                models.iter().all(|&(_, k)| k == kind),
                "foreign model in {kind:?} list"
            );
        }
    }

    #[test]
    fn list_known_models_custom_filter_returns_empty() {
        assert!(list_known_models(Some(ProviderKind::Custom)).is_empty());
    }

    #[test]
    fn provider_kind_by_name_resolves_known() {
        let cases = [
            ("anthropic", ProviderKind::CodineerApi),
            ("claude", ProviderKind::CodineerApi),
            ("xai", ProviderKind::Xai),
            ("grok", ProviderKind::Xai),
            ("openai", ProviderKind::OpenAi),
            ("gpt", ProviderKind::OpenAi),
        ];
        for (name, kind) in cases {
            assert_eq!(provider_kind_by_name(name), Some(kind), "name {name}");
        }
    }

    #[test]
    fn provider_kind_by_name_case_insensitive() {
        let cases = [
            ("Anthropic", ProviderKind::CodineerApi),
            ("XAI", ProviderKind::Xai),
        ];
        for (name, kind) in cases {
            assert_eq!(provider_kind_by_name(name), Some(kind), "name {name}");
        }
    }

    #[test]
    fn provider_kind_by_name_returns_none_for_unknown() {
        for name in ["ollama", "unknown", ""] {
            assert_eq!(provider_kind_by_name(name), None, "name {name:?}");
        }
    }

    #[test]
    fn provider_kind_display_name_covers_all_variants() {
        let expected = [
            (ProviderKind::CodineerApi, "Anthropic"),
            (ProviderKind::Xai, "xAI"),
            (ProviderKind::OpenAi, "OpenAI"),
            (ProviderKind::Custom, "Custom"),
        ];
        for (kind, label) in expected {
            assert_eq!(kind.display_name(), label, "variant {kind:?}");
        }
    }
}