Skip to main content

codineer_api/providers/
mod.rs

1use std::time::Duration;
2
3pub mod codineer_provider;
4pub mod openai_compat;
5
/// Retry and backoff limits.
///
/// This type only carries the numbers; how the backoff grows between
/// `initial_backoff` and `max_backoff` is decided by whoever consumes it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RetryPolicy {
    /// Upper bound on the number of retries.
    pub max_retries: u32,
    /// Backoff used before the first retry.
    pub initial_backoff: Duration,
    /// Largest backoff allowed between retries.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Two retries, starting at 200 ms of backoff and capped at 2 s.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(200);
        let max_backoff = Duration::from_secs(2);
        RetryPolicy {
            max_retries: 2,
            initial_backoff,
            max_backoff,
        }
    }
}
22
/// The provider family a model belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderKind {
    /// Anthropic models (registry entries authenticate via `ANTHROPIC_API_KEY`).
    CodineerApi,
    /// xAI / Grok models.
    Xai,
    /// OpenAI models.
    OpenAi,
    /// A service named by a `provider/model` prefix (see `parse_custom_provider_prefix`).
    Custom,
}

/// Connection settings for a built-in provider.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ProviderMetadata {
    /// Provider family these settings belong to.
    pub provider: ProviderKind,
    /// Name of the environment variable expected to hold the API key (e.g. `XAI_API_KEY`).
    pub auth_env: &'static str,
    /// Name of an environment variable for the base URL (e.g. `XAI_BASE_URL`);
    /// presumably overrides `default_base_url` when set — confirm at call sites.
    pub base_url_env: &'static str,
    /// Built-in base URL for the provider.
    pub default_base_url: &'static str,
}
38
39const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
40    (
41        "opus",
42        ProviderMetadata {
43            provider: ProviderKind::CodineerApi,
44            auth_env: "ANTHROPIC_API_KEY",
45            base_url_env: "ANTHROPIC_BASE_URL",
46            default_base_url: codineer_provider::DEFAULT_BASE_URL,
47        },
48    ),
49    (
50        "sonnet",
51        ProviderMetadata {
52            provider: ProviderKind::CodineerApi,
53            auth_env: "ANTHROPIC_API_KEY",
54            base_url_env: "ANTHROPIC_BASE_URL",
55            default_base_url: codineer_provider::DEFAULT_BASE_URL,
56        },
57    ),
58    (
59        "haiku",
60        ProviderMetadata {
61            provider: ProviderKind::CodineerApi,
62            auth_env: "ANTHROPIC_API_KEY",
63            base_url_env: "ANTHROPIC_BASE_URL",
64            default_base_url: codineer_provider::DEFAULT_BASE_URL,
65        },
66    ),
67    (
68        "claude-opus-4-6",
69        ProviderMetadata {
70            provider: ProviderKind::CodineerApi,
71            auth_env: "ANTHROPIC_API_KEY",
72            base_url_env: "ANTHROPIC_BASE_URL",
73            default_base_url: codineer_provider::DEFAULT_BASE_URL,
74        },
75    ),
76    (
77        "claude-sonnet-4-6",
78        ProviderMetadata {
79            provider: ProviderKind::CodineerApi,
80            auth_env: "ANTHROPIC_API_KEY",
81            base_url_env: "ANTHROPIC_BASE_URL",
82            default_base_url: codineer_provider::DEFAULT_BASE_URL,
83        },
84    ),
85    (
86        "claude-haiku-4-5-20251213",
87        ProviderMetadata {
88            provider: ProviderKind::CodineerApi,
89            auth_env: "ANTHROPIC_API_KEY",
90            base_url_env: "ANTHROPIC_BASE_URL",
91            default_base_url: codineer_provider::DEFAULT_BASE_URL,
92        },
93    ),
94    (
95        "grok",
96        ProviderMetadata {
97            provider: ProviderKind::Xai,
98            auth_env: "XAI_API_KEY",
99            base_url_env: "XAI_BASE_URL",
100            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
101        },
102    ),
103    (
104        "grok-3",
105        ProviderMetadata {
106            provider: ProviderKind::Xai,
107            auth_env: "XAI_API_KEY",
108            base_url_env: "XAI_BASE_URL",
109            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
110        },
111    ),
112    (
113        "grok-mini",
114        ProviderMetadata {
115            provider: ProviderKind::Xai,
116            auth_env: "XAI_API_KEY",
117            base_url_env: "XAI_BASE_URL",
118            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
119        },
120    ),
121    (
122        "grok-3-mini",
123        ProviderMetadata {
124            provider: ProviderKind::Xai,
125            auth_env: "XAI_API_KEY",
126            base_url_env: "XAI_BASE_URL",
127            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
128        },
129    ),
130    (
131        "grok-2",
132        ProviderMetadata {
133            provider: ProviderKind::Xai,
134            auth_env: "XAI_API_KEY",
135            base_url_env: "XAI_BASE_URL",
136            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
137        },
138    ),
139    (
140        "gpt",
141        ProviderMetadata {
142            provider: ProviderKind::OpenAi,
143            auth_env: "OPENAI_API_KEY",
144            base_url_env: "OPENAI_BASE_URL",
145            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
146        },
147    ),
148    (
149        "gpt-4o",
150        ProviderMetadata {
151            provider: ProviderKind::OpenAi,
152            auth_env: "OPENAI_API_KEY",
153            base_url_env: "OPENAI_BASE_URL",
154            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
155        },
156    ),
157    (
158        "mini",
159        ProviderMetadata {
160            provider: ProviderKind::OpenAi,
161            auth_env: "OPENAI_API_KEY",
162            base_url_env: "OPENAI_BASE_URL",
163            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
164        },
165    ),
166    (
167        "gpt-4o-mini",
168        ProviderMetadata {
169            provider: ProviderKind::OpenAi,
170            auth_env: "OPENAI_API_KEY",
171            base_url_env: "OPENAI_BASE_URL",
172            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
173        },
174    ),
175    (
176        "o3",
177        ProviderMetadata {
178            provider: ProviderKind::OpenAi,
179            auth_env: "OPENAI_API_KEY",
180            base_url_env: "OPENAI_BASE_URL",
181            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
182        },
183    ),
184    (
185        "o3-mini",
186        ProviderMetadata {
187            provider: ProviderKind::OpenAi,
188            auth_env: "OPENAI_API_KEY",
189            base_url_env: "OPENAI_BASE_URL",
190            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
191        },
192    ),
193];
194
/// A built-in preset for an OpenAI-compatible service, looked up by `name`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BuiltinProviderPreset {
    /// Lowercase lookup key used by `builtin_preset` (e.g. `"ollama"`).
    pub name: &'static str,
    /// Base URL of the service's OpenAI-compatible API.
    pub base_url: &'static str,
    /// Environment variable holding the API key, or `""` for local providers that need
    /// no key. Prefer [`BuiltinProviderPreset::required_api_key_env`] over testing for `""`.
    pub api_key_env: &'static str,
    /// One-line human-readable summary of the preset.
    pub description: &'static str,
}

impl BuiltinProviderPreset {
    /// The environment variable that must hold an API key, or `None` when the preset
    /// needs no key (its `api_key_env` is empty).
    #[must_use]
    pub const fn required_api_key_env(&self) -> Option<&'static str> {
        if self.api_key_env.is_empty() {
            None
        } else {
            Some(self.api_key_env)
        }
    }
}

/// Built-in presets for OpenAI-compatible services.
///
/// Each entry is a [`BuiltinProviderPreset`]; local providers (Ollama, LM Studio) leave
/// `api_key_env` empty because they need no API key.
pub const BUILTIN_PROVIDER_PRESETS: &[BuiltinProviderPreset] = &[
    BuiltinProviderPreset {
        name: "ollama",
        base_url: "http://localhost:11434/v1",
        api_key_env: "",
        description: "Local Ollama instance (no API key needed)",
    },
    BuiltinProviderPreset {
        name: "lmstudio",
        base_url: "http://localhost:1234/v1",
        api_key_env: "",
        description: "Local LM Studio instance (no API key needed)",
    },
    BuiltinProviderPreset {
        name: "openrouter",
        base_url: "https://openrouter.ai/api/v1",
        api_key_env: "OPENROUTER_API_KEY",
        description: "OpenRouter (free models available)",
    },
    BuiltinProviderPreset {
        name: "groq",
        base_url: "https://api.groq.com/openai/v1",
        api_key_env: "GROQ_API_KEY",
        description: "Groq Cloud (generous free tier)",
    },
];
231
232/// If model starts with `provider/`, return `(provider_name, model_name)`.
233/// Otherwise return `None`.
234#[must_use]
235pub fn parse_custom_provider_prefix(model: &str) -> Option<(&str, &str)> {
236    let trimmed = model.trim();
237    let slash_pos = trimmed.find('/')?;
238    let provider = &trimmed[..slash_pos];
239    let model_part = &trimmed[slash_pos + 1..];
240    if provider.is_empty() || model_part.is_empty() {
241        return None;
242    }
243    if MODEL_REGISTRY
244        .iter()
245        .any(|(alias, _)| *alias == provider.to_ascii_lowercase())
246    {
247        return None;
248    }
249    Some((provider, model_part))
250}
251
252/// Look up a built-in provider preset by name (case-insensitive).
253#[must_use]
254pub fn builtin_preset(name: &str) -> Option<&'static BuiltinProviderPreset> {
255    let lower = name.to_ascii_lowercase();
256    BUILTIN_PROVIDER_PRESETS
257        .iter()
258        .find(|preset| preset.name == lower)
259}
260
261#[must_use]
262pub fn resolve_model_alias(model: &str) -> String {
263    let trimmed = model.trim();
264    if parse_custom_provider_prefix(trimmed).is_some() {
265        return trimmed.to_string();
266    }
267    let lower = trimmed.to_ascii_lowercase();
268    MODEL_REGISTRY
269        .iter()
270        .find_map(|(alias, metadata)| {
271            (*alias == lower).then_some(match metadata.provider {
272                ProviderKind::CodineerApi => match *alias {
273                    "opus" => "claude-opus-4-6",
274                    "sonnet" => "claude-sonnet-4-6",
275                    "haiku" => "claude-haiku-4-5-20251213",
276                    _ => trimmed,
277                },
278                ProviderKind::Xai => match *alias {
279                    "grok" | "grok-3" => "grok-3",
280                    "grok-mini" | "grok-3-mini" => "grok-3-mini",
281                    "grok-2" => "grok-2",
282                    _ => trimmed,
283                },
284                ProviderKind::OpenAi => match *alias {
285                    "gpt" | "gpt-4o" => "gpt-4o",
286                    "mini" | "gpt-4o-mini" => "gpt-4o-mini",
287                    "o3" => "o3",
288                    "o3-mini" => "o3-mini",
289                    _ => trimmed,
290                },
291                ProviderKind::Custom => trimmed,
292            })
293        })
294        .map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
295}
296
297#[must_use]
298pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
299    let canonical = resolve_model_alias(model);
300    let lower = canonical.to_ascii_lowercase();
301    if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) {
302        return Some(*metadata);
303    }
304    if lower.starts_with("grok") {
305        return Some(ProviderMetadata {
306            provider: ProviderKind::Xai,
307            auth_env: "XAI_API_KEY",
308            base_url_env: "XAI_BASE_URL",
309            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
310        });
311    }
312    None
313}
314
315#[must_use]
316pub fn detect_provider_kind(model: &str) -> ProviderKind {
317    if parse_custom_provider_prefix(model).is_some() {
318        return ProviderKind::Custom;
319    }
320    if let Some(metadata) = metadata_for_model(model) {
321        return metadata.provider;
322    }
323    let fallback = detect_available_provider().unwrap_or(ProviderKind::CodineerApi);
324    eprintln!("[warn] unknown model \"{model}\", falling back to {fallback:?} provider");
325    fallback
326}
327
328fn detect_available_provider() -> Option<ProviderKind> {
329    if codineer_provider::has_auth_from_env_or_saved().unwrap_or(false) {
330        return Some(ProviderKind::CodineerApi);
331    }
332    if openai_compat::has_api_key("OPENAI_API_KEY") {
333        return Some(ProviderKind::OpenAi);
334    }
335    if openai_compat::has_api_key("XAI_API_KEY") {
336        return Some(ProviderKind::Xai);
337    }
338    None
339}
340
341/// Detect which provider has available credentials and return its default model.
342/// Returns `None` if no credentials are found for any provider.
343#[must_use]
344pub fn auto_detect_default_model() -> Option<&'static str> {
345    match detect_available_provider()? {
346        ProviderKind::CodineerApi => Some("claude-sonnet-4-6"),
347        ProviderKind::Xai => Some("grok-3"),
348        ProviderKind::OpenAi => Some("gpt-4o"),
349        ProviderKind::Custom => None,
350    }
351}
352
353#[must_use]
354pub fn max_tokens_for_model(model: &str) -> u32 {
355    let canonical = resolve_model_alias(model);
356    if canonical.starts_with("claude-opus") || canonical == "opus" {
357        32_000
358    } else if parse_custom_provider_prefix(&canonical).is_some() {
359        // Local / custom models often have smaller context windows;
360        // 16k is a safe default that avoids hitting limits on 8B–32B models.
361        16_000
362    } else {
363        64_000
364    }
365}
366
367/// A model alias entry for listing.
368#[derive(Debug, Clone)]
369pub struct ModelAliasEntry {
370    pub alias: &'static str,
371    pub canonical: String,
372    pub provider: ProviderKind,
373}
374
375/// Return all built-in model aliases, optionally filtered by provider kind.
376#[must_use]
377pub fn list_builtin_models(filter_provider: Option<ProviderKind>) -> Vec<ModelAliasEntry> {
378    MODEL_REGISTRY
379        .iter()
380        .filter(|(_, meta)| filter_provider.is_none_or(|p| meta.provider == p))
381        .map(|(alias, meta)| ModelAliasEntry {
382            alias,
383            canonical: resolve_model_alias(alias),
384            provider: meta.provider,
385        })
386        .collect()
387}
388
389/// Resolve a provider name to `ProviderKind` from known aliases.
390#[must_use]
391pub fn provider_kind_by_name(name: &str) -> Option<ProviderKind> {
392    let lower = name.to_ascii_lowercase();
393    match lower.as_str() {
394        "anthropic" | "claude" => Some(ProviderKind::CodineerApi),
395        "xai" | "grok" => Some(ProviderKind::Xai),
396        "openai" | "gpt" => Some(ProviderKind::OpenAi),
397        _ => None,
398    }
399}
400
401impl ProviderKind {
402    pub const fn display_name(self) -> &'static str {
403        match self {
404            Self::CodineerApi => "Anthropic",
405            Self::Xai => "xAI",
406            Self::OpenAi => "OpenAI",
407            Self::Custom => "Custom",
408        }
409    }
410}
411
#[cfg(test)]
mod tests {
    // Unit tests for alias resolution, provider detection, presets and defaults.
    // None of these tests depend on environment credentials: unknown-model fallbacks in
    // `detect_provider_kind` read the environment, so they are deliberately not exercised.
    use std::time::Duration;

    use super::{
        builtin_preset, detect_provider_kind, list_builtin_models, max_tokens_for_model,
        metadata_for_model, parse_custom_provider_prefix, provider_kind_by_name,
        resolve_model_alias, ProviderKind, RetryPolicy,
    };

    #[test]
    fn resolves_grok_aliases() {
        assert_eq!(resolve_model_alias("grok"), "grok-3");
        assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
        assert_eq!(resolve_model_alias("grok-2"), "grok-2");
    }

    #[test]
    fn resolves_aliases_case_insensitively() {
        assert_eq!(resolve_model_alias("  SONNET "), "claude-sonnet-4-6");
        assert_eq!(resolve_model_alias("Opus"), "claude-opus-4-6");
        assert_eq!(resolve_model_alias("GPT"), "gpt-4o");
        assert_eq!(resolve_model_alias("mini"), "gpt-4o-mini");
    }

    #[test]
    fn unknown_models_pass_through_trimmed() {
        assert_eq!(resolve_model_alias("  some-new-model  "), "some-new-model");
    }

    #[test]
    fn detects_provider_from_model_name_first() {
        assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
        assert_eq!(
            detect_provider_kind("claude-sonnet-4-6"),
            ProviderKind::CodineerApi
        );
    }

    #[test]
    fn keeps_existing_max_token_heuristic() {
        assert_eq!(max_tokens_for_model("opus"), 32_000);
        assert_eq!(max_tokens_for_model("grok-3"), 64_000);
    }

    #[test]
    fn opus_models_get_smaller_budget_than_others() {
        assert_eq!(max_tokens_for_model("claude-opus-4-6"), 32_000);
        assert_eq!(max_tokens_for_model("sonnet"), 64_000);
    }

    #[test]
    fn parses_custom_provider_prefix() {
        assert_eq!(
            parse_custom_provider_prefix("ollama/qwen2.5-coder"),
            Some(("ollama", "qwen2.5-coder"))
        );
        assert_eq!(
            parse_custom_provider_prefix("groq/llama-3.3-70b"),
            Some(("groq", "llama-3.3-70b"))
        );
        assert_eq!(
            parse_custom_provider_prefix("openrouter/meta-llama/llama-3.1-8b:free"),
            Some(("openrouter", "meta-llama/llama-3.1-8b:free"))
        );
        assert_eq!(parse_custom_provider_prefix("grok-3"), None);
        assert_eq!(parse_custom_provider_prefix("sonnet"), None);
        assert_eq!(parse_custom_provider_prefix("/model"), None);
        assert_eq!(parse_custom_provider_prefix("provider/"), None);
    }

    #[test]
    fn registry_alias_is_not_a_custom_provider_prefix() {
        assert_eq!(parse_custom_provider_prefix("GROK/beta"), None);
        assert_eq!(
            parse_custom_provider_prefix("  ollama/llama3  "),
            Some(("ollama", "llama3"))
        );
    }

    #[test]
    fn detects_custom_provider_kind() {
        assert_eq!(
            detect_provider_kind("ollama/qwen2.5-coder"),
            ProviderKind::Custom
        );
        assert_eq!(
            detect_provider_kind("lmstudio/my-model"),
            ProviderKind::Custom
        );
    }

    #[test]
    fn resolves_custom_model_passthrough() {
        assert_eq!(
            resolve_model_alias("ollama/qwen2.5-coder"),
            "ollama/qwen2.5-coder"
        );
    }

    #[test]
    fn custom_model_tokens_smaller_default() {
        assert_eq!(max_tokens_for_model("ollama/qwen2.5-coder"), 16_000);
    }

    #[test]
    fn metadata_falls_back_to_xai_for_unregistered_grok_models() {
        let meta = metadata_for_model("grok-4").expect("grok-prefixed models map to xAI");
        assert_eq!(meta.provider, ProviderKind::Xai);
        assert_eq!(meta.auth_env, "XAI_API_KEY");
        assert_eq!(meta.base_url_env, "XAI_BASE_URL");
    }

    #[test]
    fn metadata_is_none_for_unknown_and_custom_models() {
        assert!(metadata_for_model("llama-3").is_none());
        assert!(metadata_for_model("ollama/qwen2.5-coder").is_none());
    }

    #[test]
    fn builtin_presets_lookup() {
        let ollama = builtin_preset("ollama").expect("ollama preset should exist");
        assert_eq!(ollama.base_url, "http://localhost:11434/v1");
        assert!(ollama.api_key_env.is_empty());

        let groq = builtin_preset("groq").expect("groq preset should exist");
        assert_eq!(groq.api_key_env, "GROQ_API_KEY");

        assert!(builtin_preset("nonexistent").is_none());
    }

    #[test]
    fn builtin_preset_lookup_ignores_case() {
        let preset = builtin_preset("LMStudio").expect("lmstudio preset should exist");
        assert_eq!(preset.name, "lmstudio");
    }

    #[test]
    fn list_builtin_models_returns_all_when_unfiltered() {
        let all = list_builtin_models(None);
        assert!(!all.is_empty());
        assert!(all.iter().any(|e| e.provider == ProviderKind::CodineerApi));
        assert!(all.iter().any(|e| e.provider == ProviderKind::Xai));
    }

    #[test]
    fn listed_models_carry_canonical_names() {
        let all = list_builtin_models(None);
        let opus = all
            .iter()
            .find(|e| e.alias == "opus")
            .expect("opus alias should be listed");
        assert_eq!(opus.canonical, "claude-opus-4-6");
        assert_eq!(opus.provider, ProviderKind::CodineerApi);
    }

    #[test]
    fn list_builtin_models_filters_by_provider() {
        let xai = list_builtin_models(Some(ProviderKind::Xai));
        assert!(!xai.is_empty());
        assert!(xai.iter().all(|e| e.provider == ProviderKind::Xai));

        let anthropic = list_builtin_models(Some(ProviderKind::CodineerApi));
        assert!(!anthropic.is_empty());
        assert!(anthropic
            .iter()
            .all(|e| e.provider == ProviderKind::CodineerApi));
    }

    #[test]
    fn list_builtin_models_custom_filter_returns_empty() {
        let custom = list_builtin_models(Some(ProviderKind::Custom));
        assert!(custom.is_empty());
    }

    #[test]
    fn provider_kind_by_name_resolves_known() {
        assert_eq!(
            provider_kind_by_name("anthropic"),
            Some(ProviderKind::CodineerApi)
        );
        assert_eq!(
            provider_kind_by_name("claude"),
            Some(ProviderKind::CodineerApi)
        );
        assert_eq!(provider_kind_by_name("xai"), Some(ProviderKind::Xai));
        assert_eq!(provider_kind_by_name("grok"), Some(ProviderKind::Xai));
        assert_eq!(provider_kind_by_name("openai"), Some(ProviderKind::OpenAi));
        assert_eq!(provider_kind_by_name("gpt"), Some(ProviderKind::OpenAi));
    }

    #[test]
    fn provider_kind_by_name_case_insensitive() {
        assert_eq!(
            provider_kind_by_name("Anthropic"),
            Some(ProviderKind::CodineerApi)
        );
        assert_eq!(provider_kind_by_name("XAI"), Some(ProviderKind::Xai));
    }

    #[test]
    fn provider_kind_by_name_returns_none_for_unknown() {
        assert_eq!(provider_kind_by_name("ollama"), None);
        assert_eq!(provider_kind_by_name("unknown"), None);
        assert_eq!(provider_kind_by_name(""), None);
    }

    #[test]
    fn provider_kind_display_name_covers_all_variants() {
        assert_eq!(ProviderKind::CodineerApi.display_name(), "Anthropic");
        assert_eq!(ProviderKind::Xai.display_name(), "xAI");
        assert_eq!(ProviderKind::OpenAi.display_name(), "OpenAI");
        assert_eq!(ProviderKind::Custom.display_name(), "Custom");
    }

    #[test]
    fn retry_policy_default_values() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 2);
        assert_eq!(policy.initial_backoff, Duration::from_millis(200));
        assert_eq!(policy.max_backoff, Duration::from_secs(2));
    }
}