Skip to main content

codineer_api/providers/
mod.rs

1use std::time::Duration;
2
3pub mod codineer_provider;
4pub mod openai_compat;
5
/// Retry and backoff settings for provider requests.
///
/// NOTE(review): not consumed in this module; presumably read by the provider clients in
/// the submodules — confirm the exact retry/backoff semantics there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RetryPolicy {
    /// Upper bound on retry attempts.
    pub max_retries: u32,
    /// Starting backoff delay.
    pub initial_backoff: Duration,
    /// Largest backoff delay.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Two retries, a 200 ms initial backoff and a 2 s maximum backoff.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(200);
        let max_backoff = Duration::from_secs(2);
        RetryPolicy {
            max_retries: 2,
            initial_backoff,
            max_backoff,
        }
    }
}
22
/// The provider family that serves a model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderKind {
    /// Anthropic (Claude models; credentials via `ANTHROPIC_API_KEY`).
    CodineerApi,
    /// xAI (Grok models; credentials via `XAI_API_KEY`).
    Xai,
    /// OpenAI (GPT / o-series models; credentials via `OPENAI_API_KEY`).
    OpenAi,
    /// A provider named by a `provider/model` prefix (see `parse_custom_provider_prefix`).
    Custom,
}
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub struct ProviderMetadata {
33    pub provider: ProviderKind,
34    pub auth_env: &'static str,
35    pub base_url_env: &'static str,
36    pub default_base_url: &'static str,
37}
38
39const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
40    (
41        "opus",
42        ProviderMetadata {
43            provider: ProviderKind::CodineerApi,
44            auth_env: "ANTHROPIC_API_KEY",
45            base_url_env: "ANTHROPIC_BASE_URL",
46            default_base_url: codineer_provider::DEFAULT_BASE_URL,
47        },
48    ),
49    (
50        "sonnet",
51        ProviderMetadata {
52            provider: ProviderKind::CodineerApi,
53            auth_env: "ANTHROPIC_API_KEY",
54            base_url_env: "ANTHROPIC_BASE_URL",
55            default_base_url: codineer_provider::DEFAULT_BASE_URL,
56        },
57    ),
58    (
59        "haiku",
60        ProviderMetadata {
61            provider: ProviderKind::CodineerApi,
62            auth_env: "ANTHROPIC_API_KEY",
63            base_url_env: "ANTHROPIC_BASE_URL",
64            default_base_url: codineer_provider::DEFAULT_BASE_URL,
65        },
66    ),
67    (
68        "claude-opus-4-6",
69        ProviderMetadata {
70            provider: ProviderKind::CodineerApi,
71            auth_env: "ANTHROPIC_API_KEY",
72            base_url_env: "ANTHROPIC_BASE_URL",
73            default_base_url: codineer_provider::DEFAULT_BASE_URL,
74        },
75    ),
76    (
77        "claude-sonnet-4-6",
78        ProviderMetadata {
79            provider: ProviderKind::CodineerApi,
80            auth_env: "ANTHROPIC_API_KEY",
81            base_url_env: "ANTHROPIC_BASE_URL",
82            default_base_url: codineer_provider::DEFAULT_BASE_URL,
83        },
84    ),
85    (
86        "claude-haiku-4-5-20251213",
87        ProviderMetadata {
88            provider: ProviderKind::CodineerApi,
89            auth_env: "ANTHROPIC_API_KEY",
90            base_url_env: "ANTHROPIC_BASE_URL",
91            default_base_url: codineer_provider::DEFAULT_BASE_URL,
92        },
93    ),
94    (
95        "grok",
96        ProviderMetadata {
97            provider: ProviderKind::Xai,
98            auth_env: "XAI_API_KEY",
99            base_url_env: "XAI_BASE_URL",
100            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
101        },
102    ),
103    (
104        "grok-3",
105        ProviderMetadata {
106            provider: ProviderKind::Xai,
107            auth_env: "XAI_API_KEY",
108            base_url_env: "XAI_BASE_URL",
109            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
110        },
111    ),
112    (
113        "grok-mini",
114        ProviderMetadata {
115            provider: ProviderKind::Xai,
116            auth_env: "XAI_API_KEY",
117            base_url_env: "XAI_BASE_URL",
118            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
119        },
120    ),
121    (
122        "grok-3-mini",
123        ProviderMetadata {
124            provider: ProviderKind::Xai,
125            auth_env: "XAI_API_KEY",
126            base_url_env: "XAI_BASE_URL",
127            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
128        },
129    ),
130    (
131        "grok-2",
132        ProviderMetadata {
133            provider: ProviderKind::Xai,
134            auth_env: "XAI_API_KEY",
135            base_url_env: "XAI_BASE_URL",
136            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
137        },
138    ),
139    (
140        "gpt",
141        ProviderMetadata {
142            provider: ProviderKind::OpenAi,
143            auth_env: "OPENAI_API_KEY",
144            base_url_env: "OPENAI_BASE_URL",
145            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
146        },
147    ),
148    (
149        "gpt-4o",
150        ProviderMetadata {
151            provider: ProviderKind::OpenAi,
152            auth_env: "OPENAI_API_KEY",
153            base_url_env: "OPENAI_BASE_URL",
154            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
155        },
156    ),
157    (
158        "mini",
159        ProviderMetadata {
160            provider: ProviderKind::OpenAi,
161            auth_env: "OPENAI_API_KEY",
162            base_url_env: "OPENAI_BASE_URL",
163            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
164        },
165    ),
166    (
167        "gpt-4o-mini",
168        ProviderMetadata {
169            provider: ProviderKind::OpenAi,
170            auth_env: "OPENAI_API_KEY",
171            base_url_env: "OPENAI_BASE_URL",
172            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
173        },
174    ),
175    (
176        "o3",
177        ProviderMetadata {
178            provider: ProviderKind::OpenAi,
179            auth_env: "OPENAI_API_KEY",
180            base_url_env: "OPENAI_BASE_URL",
181            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
182        },
183    ),
184    (
185        "o3-mini",
186        ProviderMetadata {
187            provider: ProviderKind::OpenAi,
188            auth_env: "OPENAI_API_KEY",
189            base_url_env: "OPENAI_BASE_URL",
190            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
191        },
192    ),
193];
194
/// A built-in preset describing how to reach an OpenAI-compatible service.
///
/// NOTE(review): presumably matched against the `provider` half of a `provider/model`
/// string — confirm at the call sites of [`builtin_preset`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BuiltinProviderPreset {
    /// Lookup key. Must be lowercase ASCII, because [`builtin_preset`] lowercases the
    /// requested name before comparing it with this field.
    pub name: &'static str,
    /// Base URL of the service's OpenAI-compatible API.
    pub base_url: &'static str,
    /// Environment variable holding the API key, or `""` for local providers that need none.
    pub api_key_env: &'static str,
    /// Short human-readable description of the preset.
    pub description: &'static str,
}

/// Built-in provider presets for OpenAI-compatible services.
///
/// Each entry is a [`BuiltinProviderPreset`]. Local providers (Ollama, LM Studio) leave
/// `api_key_env` empty because they need no API key.
pub const BUILTIN_PROVIDER_PRESETS: &[BuiltinProviderPreset] = &[
    BuiltinProviderPreset {
        name: "ollama",
        base_url: "http://localhost:11434/v1",
        api_key_env: "",
        description: "Local Ollama instance (no API key needed)",
    },
    BuiltinProviderPreset {
        name: "lmstudio",
        base_url: "http://localhost:1234/v1",
        api_key_env: "",
        description: "Local LM Studio instance (no API key needed)",
    },
    BuiltinProviderPreset {
        name: "openrouter",
        base_url: "https://openrouter.ai/api/v1",
        api_key_env: "OPENROUTER_API_KEY",
        description: "OpenRouter (free models available)",
    },
    BuiltinProviderPreset {
        name: "groq",
        base_url: "https://api.groq.com/openai/v1",
        api_key_env: "GROQ_API_KEY",
        description: "Groq Cloud (generous free tier)",
    },
];
231
232/// If model starts with `provider/`, return `(provider_name, model_name)`.
233/// Otherwise return `None`.
234#[must_use]
235pub fn parse_custom_provider_prefix(model: &str) -> Option<(&str, &str)> {
236    let trimmed = model.trim();
237    let slash_pos = trimmed.find('/')?;
238    let provider = &trimmed[..slash_pos];
239    let model_part = &trimmed[slash_pos + 1..];
240    if provider.is_empty() || model_part.is_empty() {
241        return None;
242    }
243    if MODEL_REGISTRY
244        .iter()
245        .any(|(alias, _)| *alias == provider.to_ascii_lowercase())
246    {
247        return None;
248    }
249    Some((provider, model_part))
250}
251
252/// Look up a built-in provider preset by name (case-insensitive).
253#[must_use]
254pub fn builtin_preset(name: &str) -> Option<&'static BuiltinProviderPreset> {
255    let lower = name.to_ascii_lowercase();
256    BUILTIN_PROVIDER_PRESETS
257        .iter()
258        .find(|preset| preset.name == lower)
259}
260
261#[must_use]
262pub fn resolve_model_alias(model: &str) -> String {
263    let trimmed = model.trim();
264    if parse_custom_provider_prefix(trimmed).is_some() {
265        return trimmed.to_string();
266    }
267    let lower = trimmed.to_ascii_lowercase();
268    MODEL_REGISTRY
269        .iter()
270        .find_map(|(alias, metadata)| {
271            (*alias == lower).then_some(match metadata.provider {
272                ProviderKind::CodineerApi => match *alias {
273                    "opus" => "claude-opus-4-6",
274                    "sonnet" => "claude-sonnet-4-6",
275                    "haiku" => "claude-haiku-4-5-20251213",
276                    _ => trimmed,
277                },
278                ProviderKind::Xai => match *alias {
279                    "grok" | "grok-3" => "grok-3",
280                    "grok-mini" | "grok-3-mini" => "grok-3-mini",
281                    "grok-2" => "grok-2",
282                    _ => trimmed,
283                },
284                ProviderKind::OpenAi => match *alias {
285                    "gpt" | "gpt-4o" => "gpt-4o",
286                    "mini" | "gpt-4o-mini" => "gpt-4o-mini",
287                    "o3" => "o3",
288                    "o3-mini" => "o3-mini",
289                    _ => trimmed,
290                },
291                ProviderKind::Custom => trimmed,
292            })
293        })
294        .map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
295}
296
297#[must_use]
298pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
299    let canonical = resolve_model_alias(model);
300    let lower = canonical.to_ascii_lowercase();
301    if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) {
302        return Some(*metadata);
303    }
304    if lower.starts_with("grok") {
305        return Some(ProviderMetadata {
306            provider: ProviderKind::Xai,
307            auth_env: "XAI_API_KEY",
308            base_url_env: "XAI_BASE_URL",
309            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
310        });
311    }
312    if lower.starts_with("claude-") || lower == "claude" {
313        return Some(ProviderMetadata {
314            provider: ProviderKind::CodineerApi,
315            auth_env: "ANTHROPIC_API_KEY",
316            base_url_env: "ANTHROPIC_BASE_URL",
317            default_base_url: codineer_provider::DEFAULT_BASE_URL,
318        });
319    }
320    if lower.starts_with("gpt")
321        || lower.starts_with("o1")
322        || lower.starts_with("o3")
323        || lower.starts_with("o4")
324        || lower.starts_with("chatgpt-")
325    {
326        return Some(ProviderMetadata {
327            provider: ProviderKind::OpenAi,
328            auth_env: "OPENAI_API_KEY",
329            base_url_env: "OPENAI_BASE_URL",
330            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
331        });
332    }
333    None
334}
335
336#[must_use]
337pub fn detect_provider_kind(model: &str) -> ProviderKind {
338    if parse_custom_provider_prefix(model).is_some() {
339        return ProviderKind::Custom;
340    }
341    if let Some(metadata) = metadata_for_model(model) {
342        return metadata.provider;
343    }
344    let fallback = detect_available_provider().unwrap_or(ProviderKind::CodineerApi);
345    eprintln!("[warn] unknown model \"{model}\", falling back to {fallback:?} provider");
346    fallback
347}
348
349fn detect_available_provider() -> Option<ProviderKind> {
350    if codineer_provider::has_auth_from_env_or_saved().unwrap_or(false) {
351        return Some(ProviderKind::CodineerApi);
352    }
353    if openai_compat::has_api_key("OPENAI_API_KEY") {
354        return Some(ProviderKind::OpenAi);
355    }
356    if openai_compat::has_api_key("XAI_API_KEY") {
357        return Some(ProviderKind::Xai);
358    }
359    None
360}
361
362/// Detect which provider has available credentials and return its default model.
363/// Returns `None` if no credentials are found for any provider.
364#[must_use]
365pub fn auto_detect_default_model() -> Option<&'static str> {
366    match detect_available_provider()? {
367        ProviderKind::CodineerApi => Some("claude-sonnet-4-6"),
368        ProviderKind::Xai => Some("grok-3"),
369        ProviderKind::OpenAi => Some("gpt-4o"),
370        ProviderKind::Custom => None,
371    }
372}
373
374#[must_use]
375pub fn max_tokens_for_model(model: &str) -> u32 {
376    let canonical = resolve_model_alias(model);
377    if canonical.starts_with("claude-opus") || canonical == "opus" {
378        32_000
379    } else if parse_custom_provider_prefix(&canonical).is_some() {
380        // Local / custom models often have smaller context windows;
381        // 16k is a safe default that avoids hitting limits on 8B–32B models.
382        16_000
383    } else {
384        64_000
385    }
386}
387
388/// A model alias entry for listing.
389#[derive(Debug, Clone)]
390pub struct ModelAliasEntry {
391    pub alias: &'static str,
392    pub canonical: String,
393    pub provider: ProviderKind,
394}
395
396/// Return all built-in model aliases, optionally filtered by provider kind.
397#[must_use]
398pub fn list_builtin_models(filter_provider: Option<ProviderKind>) -> Vec<ModelAliasEntry> {
399    MODEL_REGISTRY
400        .iter()
401        .filter(|(_, meta)| filter_provider.is_none_or(|p| meta.provider == p))
402        .map(|(alias, meta)| ModelAliasEntry {
403            alias,
404            canonical: resolve_model_alias(alias),
405            provider: meta.provider,
406        })
407        .collect()
408}
409
410/// Resolve a provider name to `ProviderKind` from known aliases.
411#[must_use]
412pub fn provider_kind_by_name(name: &str) -> Option<ProviderKind> {
413    let lower = name.to_ascii_lowercase();
414    match lower.as_str() {
415        "anthropic" | "claude" => Some(ProviderKind::CodineerApi),
416        "xai" | "grok" => Some(ProviderKind::Xai),
417        "openai" | "gpt" => Some(ProviderKind::OpenAi),
418        _ => None,
419    }
420}
421
422impl ProviderKind {
423    pub const fn display_name(self) -> &'static str {
424        match self {
425            Self::CodineerApi => "Anthropic",
426            Self::Xai => "xAI",
427            Self::OpenAi => "OpenAI",
428            Self::Custom => "Custom",
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::{
        builtin_preset, detect_provider_kind, list_builtin_models, max_tokens_for_model,
        parse_custom_provider_prefix, provider_kind_by_name, resolve_model_alias, ProviderKind,
    };

    #[test]
    fn resolves_grok_aliases() {
        let cases = [
            ("grok", "grok-3"),
            ("grok-mini", "grok-3-mini"),
            ("grok-2", "grok-2"),
        ];
        for (alias, expected) in cases {
            assert_eq!(resolve_model_alias(alias), expected);
        }
    }

    #[test]
    fn detects_provider_from_model_name_first() {
        let cases = [
            ("grok", ProviderKind::Xai),
            ("claude-sonnet-4-6", ProviderKind::CodineerApi),
        ];
        for (model, expected) in cases {
            assert_eq!(detect_provider_kind(model), expected);
        }
    }

    #[test]
    fn detects_provider_by_unlisted_model_id_prefix() {
        let cases = [
            ("claude-3-5-sonnet-20241022", ProviderKind::CodineerApi),
            ("gpt-4-turbo", ProviderKind::OpenAi),
            ("o1-preview", ProviderKind::OpenAi),
            ("o3-pro", ProviderKind::OpenAi),
        ];
        for (model, expected) in cases {
            assert_eq!(detect_provider_kind(model), expected);
        }
    }

    #[test]
    fn keeps_existing_max_token_heuristic() {
        for (model, expected) in [("opus", 32_000_u32), ("grok-3", 64_000_u32)] {
            assert_eq!(max_tokens_for_model(model), expected);
        }
    }

    #[test]
    fn parses_custom_provider_prefix() {
        let splits = [
            ("ollama/qwen2.5-coder", ("ollama", "qwen2.5-coder")),
            ("groq/llama-3.3-70b", ("groq", "llama-3.3-70b")),
            (
                "openrouter/meta-llama/llama-3.1-8b:free",
                ("openrouter", "meta-llama/llama-3.1-8b:free"),
            ),
        ];
        for (input, expected) in splits {
            assert_eq!(parse_custom_provider_prefix(input), Some(expected));
        }
        for input in ["grok-3", "sonnet", "/model", "provider/"] {
            assert_eq!(parse_custom_provider_prefix(input), None);
        }
    }

    #[test]
    fn detects_custom_provider_kind() {
        for model in ["ollama/qwen2.5-coder", "lmstudio/my-model"] {
            assert_eq!(detect_provider_kind(model), ProviderKind::Custom);
        }
    }

    #[test]
    fn resolves_custom_model_passthrough() {
        let model = "ollama/qwen2.5-coder";
        assert_eq!(resolve_model_alias(model), model);
    }

    #[test]
    fn custom_model_tokens_smaller_default() {
        assert_eq!(max_tokens_for_model("ollama/qwen2.5-coder"), 16_000);
    }

    #[test]
    fn builtin_presets_lookup() {
        let ollama = builtin_preset("ollama").expect("ollama preset should exist");
        assert_eq!(ollama.base_url, "http://localhost:11434/v1");
        assert!(ollama.api_key_env.is_empty());

        let groq = builtin_preset("groq").expect("groq preset should exist");
        assert_eq!(groq.api_key_env, "GROQ_API_KEY");

        assert!(builtin_preset("nonexistent").is_none());
    }

    #[test]
    fn list_builtin_models_returns_all_when_unfiltered() {
        let all = list_builtin_models(None);
        assert!(!all.is_empty());
        for kind in [ProviderKind::CodineerApi, ProviderKind::Xai] {
            assert!(all.iter().any(|entry| entry.provider == kind));
        }
    }

    #[test]
    fn list_builtin_models_filters_by_provider() {
        for kind in [ProviderKind::Xai, ProviderKind::CodineerApi] {
            let entries = list_builtin_models(Some(kind));
            assert!(!entries.is_empty());
            assert!(entries.iter().all(|entry| entry.provider == kind));
        }
    }

    #[test]
    fn list_builtin_models_custom_filter_returns_empty() {
        assert!(list_builtin_models(Some(ProviderKind::Custom)).is_empty());
    }

    #[test]
    fn provider_kind_by_name_resolves_known() {
        let cases = [
            ("anthropic", ProviderKind::CodineerApi),
            ("claude", ProviderKind::CodineerApi),
            ("xai", ProviderKind::Xai),
            ("grok", ProviderKind::Xai),
            ("openai", ProviderKind::OpenAi),
            ("gpt", ProviderKind::OpenAi),
        ];
        for (name, expected) in cases {
            assert_eq!(provider_kind_by_name(name), Some(expected));
        }
    }

    #[test]
    fn provider_kind_by_name_case_insensitive() {
        let cases = [
            ("Anthropic", ProviderKind::CodineerApi),
            ("XAI", ProviderKind::Xai),
        ];
        for (name, expected) in cases {
            assert_eq!(provider_kind_by_name(name), Some(expected));
        }
    }

    #[test]
    fn provider_kind_by_name_returns_none_for_unknown() {
        for name in ["ollama", "unknown", ""] {
            assert_eq!(provider_kind_by_name(name), None);
        }
    }

    #[test]
    fn provider_kind_display_name_covers_all_variants() {
        let cases = [
            (ProviderKind::CodineerApi, "Anthropic"),
            (ProviderKind::Xai, "xAI"),
            (ProviderKind::OpenAi, "OpenAI"),
            (ProviderKind::Custom, "Custom"),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.display_name(), expected);
        }
    }
}