Skip to main content

sparrow/provider/
detect.rs

1//! Provider auto-detection: scan environment for API keys, test connectivity
2//! with lightweight API calls, rank providers by cost tier (free > paid),
3//! and return a list of ready-to-use providers.
4//!
5//! Integrates with the first-run wizard in [`crate::onboarding::wizard`].
6
7use crate::config::providers::{ProviderDef, find_provider, provider_registry};
8use serde::{Deserialize, Serialize};
9
10// ─── Detection result types ──────────────────────────────────────────────────
11
12/// The result of scanning for one provider.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct DetectedProvider {
15    /// Registry id (e.g. "anthropic", "nvidia")
16    pub id: String,
17    /// Human label (e.g. "Anthropic", "NVIDIA NIM")
18    pub label: String,
19    /// Whether we found an API key in the environment
20    pub key_found: bool,
21    /// The env var name if applicable
22    pub env_var: Option<String>,
23    /// Cost tier
24    pub tier: ProviderTier,
25    /// Whether we successfully validated the key with a lightweight API call
26    pub validated: Option<bool>,
27    /// Error message if validation failed
28    pub validation_error: Option<String>,
29    /// Signup URL for getting a key
30    pub signup_url: Option<String>,
31    /// Whether this provider is recommended for the user
32    pub recommended: bool,
33    /// Short description for the wizard UI
34    pub description: String,
35}
36
37/// Cost tier for ranking.
38#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
39pub enum ProviderTier {
40    /// Completely free (NVIDIA NIM, Groq free tier, Gemini free tier)
41    Free = 0,
42    /// Has a generous free tier (some paid models but free tier exists)
43    FreeTier = 1,
44    /// Paid but cheap (DeepSeek, etc.)
45    Cheap = 2,
46    /// Paid, standard pricing
47    Paid = 3,
48    /// Requires signup / no key found
49    RequiresSignup = 4,
50}
51
52// ─── Known API key environment variables ─────────────────────────────────────
53
/// Table of recognised API key environment variables.
///
/// Each entry is `(environment variable name, provider id)`.
const KNOWN_API_KEY_ENVS: &[(&str, &str)] = &[
    ("OPENAI_API_KEY", "openai-codex"),
    ("ANTHROPIC_API_KEY", "anthropic"),
    ("GEMINI_API_KEY", "gemini"),
    ("GROQ_API_KEY", "groq"),
    ("NVIDIA_API_KEY", "nvidia"),
    ("DEEPSEEK_API_KEY", "deepseek"),
    ("OPENROUTER_API_KEY", "openrouter"),
    ("XAI_API_KEY", "xai"),
    ("HF_TOKEN", "huggingface"),
    ("NOUS_API_KEY", "nous"),
    ("NOVITA_API_KEY", "novita"),
    ("DASHSCOPE_API_KEY", "alibaba"),
    ("MOONSHOT_API_KEY", "kimi-coding"),
    ("MISTRAL_API_KEY", "mistral"),
    ("TOGETHER_API_KEY", "together"),
    ("CEREBRAS_API_KEY", "cerebras"),
    ("FIREWORKS_API_KEY", "fireworks"),
    ("PERPLEXITY_API_KEY", "perplexity"),
    ("COHERE_API_KEY", "cohere"),
    ("AWS_ACCESS_KEY_ID", "bedrock"),
    ("COPILOT_TOKEN", "copilot"),
];
78
79// ─── Environment scanning ────────────────────────────────────────────────────
80
81/// Scan the environment for all known API keys.
82///
83/// Returns a map of provider_id → (env_var_name, key_value).
84pub fn scan_environment() -> Vec<(&'static str, &'static str, String)> {
85    KNOWN_API_KEY_ENVS
86        .iter()
87        .filter_map(|&(env_var, provider_id)| {
88            std::env::var(env_var)
89                .ok()
90                .filter(|v| !v.trim().is_empty())
91                .map(|key| (provider_id, env_var, key))
92        })
93        .collect()
94}
95
96/// Detect all providers with their status.
97///
98/// Returns a vector of [`DetectedProvider`] sorted by tier (free first), then
99/// by whether a key was found.
100pub fn detect_all_providers() -> Vec<DetectedProvider> {
101    let env_keys = scan_environment();
102    let env_keys_map: std::collections::HashMap<&str, (&str, String)> = env_keys
103        .iter()
104        .map(|(pid, env, key)| (*pid, (*env, key.clone())))
105        .collect();
106
107    let mut providers: Vec<DetectedProvider> = Vec::new();
108
109    for def in provider_registry() {
110        let key_info = env_keys_map.get(def.id.as_str());
111
112        let tier = classify_tier(&def);
113        let signup_url = signup_url_for(&def);
114
115        let description = match tier {
116            ProviderTier::Free => format!(
117                "Gratuit — {}. Modèle recommandé : {}",
118                def.notes.trim_end_matches('.'),
119                def.models
120                    .iter()
121                    .find(|m| m.recommended)
122                    .map(|m| m.name.as_str())
123                    .unwrap_or(def.models.first().map(|m| m.name.as_str()).unwrap_or("N/A")),
124            ),
125            _ => def.notes.clone(),
126        };
127
128        providers.push(DetectedProvider {
129            id: def.id.clone(),
130            label: def.label.clone(),
131            key_found: key_info.is_some(),
132            env_var: key_info.map(|(env, _)| env.to_string()),
133            tier,
134            validated: None,
135            validation_error: None,
136            signup_url,
137            recommended: def.models.iter().any(|m| m.recommended),
138            description,
139        });
140    }
141
142    // Sort: free first, then by key found
143    providers.sort_by(|a, b| {
144        a.tier
145            .cmp(&b.tier)
146            .then_with(|| b.key_found.cmp(&a.key_found)) // key found first within same tier
147            .then_with(|| a.label.cmp(&b.label))
148    });
149
150    providers
151}
152
153/// Classify a provider definition into a cost tier.
154fn classify_tier(def: &ProviderDef) -> ProviderTier {
155    // Check tags first
156    if def.tags.iter().any(|t| t == "free") {
157        return ProviderTier::Free;
158    }
159
160    // Check if any model has zero cost (free tier)
161    let all_free = !def.models.is_empty()
162        && def
163            .models
164            .iter()
165            .all(|m| m.cost_input_per_mtok == 0.0 && m.cost_output_per_mtok == 0.0);
166
167    if all_free {
168        return ProviderTier::Free;
169    }
170
171    let has_free_models = def
172        .models
173        .iter()
174        .any(|m| m.cost_input_per_mtok == 0.0 && m.cost_output_per_mtok == 0.0);
175
176    if has_free_models {
177        return ProviderTier::FreeTier;
178    }
179
180    // Cheap if cheapest model < $1/M input tokens
181    let cheapest_input = def
182        .models
183        .iter()
184        .map(|m| m.cost_input_per_mtok)
185        .fold(f64::MAX, f64::min);
186
187    if cheapest_input < 1.0 {
188        return ProviderTier::Cheap;
189    }
190
191    ProviderTier::Paid
192}
193
194/// Get the signup URL for a provider.
195fn signup_url_for(def: &ProviderDef) -> Option<String> {
196    match def.id.as_str() {
197        "anthropic" => Some("https://console.anthropic.com/settings/keys".into()),
198        "openai-codex" => Some("https://platform.openai.com/api-keys".into()),
199        "gemini" => Some("https://aistudio.google.com/app/apikey".into()),
200        "groq" => Some("https://console.groq.com/keys".into()),
201        "nvidia" => Some("https://build.nvidia.com/explore/discover".into()),
202        "deepseek" => Some("https://platform.deepseek.com/api_keys".into()),
203        "openrouter" => Some("https://openrouter.ai/keys".into()),
204        "xai" => Some("https://console.x.ai/".into()),
205        "huggingface" => Some("https://huggingface.co/settings/tokens".into()),
206        "nous" => Some("https://portal.nousresearch.com".into()),
207        "novita" => Some("https://novita.ai/dashboard/key".into()),
208        "alibaba" => Some("https://bailian.console.aliyun.com".into()),
209        "kimi-coding" => Some("https://platform.moonshot.cn/console".into()),
210        "mistral" => Some("https://console.mistral.ai/api-keys/".into()),
211        "together" => Some("https://api.together.xyz/settings/api-keys".into()),
212        "cerebras" => Some("https://cloud.cerebras.ai/".into()),
213        "fireworks" => Some("https://fireworks.ai/api-keys".into()),
214        "perplexity" => Some("https://www.perplexity.ai/settings/api".into()),
215        "cohere" => Some("https://dashboard.cohere.com/api-keys".into()),
216        _ => None,
217    }
218}
219
/// Report whether the `gh` CLI can be run (for GitHub Copilot integration).
///
/// Runs `gh --version` with its output discarded. Returns `true` only when
/// the process could be spawned and exited successfully.
pub fn gh_cli_installed() -> bool {
    use std::process::{Command, Stdio};

    let outcome = Command::new("gh")
        .arg("--version")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();

    // A spawn failure (e.g. binary not on PATH) counts as "not installed".
    matches!(outcome, Ok(status) if status.success())
}
230
231/// Test a provider's API key with a minimal request.
232///
233/// Returns `Ok(())` if the key is valid, `Err(msg)` with a French error
234/// message otherwise.
235pub async fn validate_api_key(provider_id: &str, api_key: &str) -> Result<(), String> {
236    let def = match find_provider(provider_id) {
237        Some(d) => d,
238        None => {
239            return Err(format!(
240                "Provider \"{provider_id}\" inconnu dans le registre Sparrow."
241            ));
242        }
243    };
244
245    // Build a minimal request based on the adapter type
246    match def.adapter.as_str() {
247        "anthropic-messages" => validate_anthropic_key(api_key).await,
248        "openai-compatible" => validate_openai_compatible_key(&def.base_url, api_key).await,
249        "ollama" => {
250            // Ollama doesn't need an API key — just check if it's reachable
251            validate_ollama_connection(&def.base_url).await
252        }
253        _ => validate_openai_compatible_key(&def.base_url, api_key).await,
254    }
255}
256
257/// Validate an Anthropic API key with a GET to /v1/models.
258async fn validate_anthropic_key(api_key: &str) -> Result<(), String> {
259    let client = reqwest::Client::builder()
260        .timeout(std::time::Duration::from_secs(10))
261        .build()
262        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
263
264    let resp = client
265        .get("https://api.anthropic.com/v1/models?limit=1")
266        .header("x-api-key", api_key)
267        .header("anthropic-version", "2023-06-01")
268        .send()
269        .await
270        .map_err(|e| {
271            if e.is_timeout() {
272                "Timeout — le serveur Anthropic ne répond pas. Check ta connexion.".into()
273            } else if e.is_connect() {
274                "Impossible de contacter api.anthropic.com. Vérifie ta connexion ou VPN.".into()
275            } else {
276                format!("Erreur réseau : {e}")
277            }
278        })?;
279
280    match resp.status().as_u16() {
281        200 => Ok(()),
282        401 | 403 => Err("Clé API Anthropic invalide. Vérifie ta clé sur https://console.anthropic.com/settings/keys".into()),
283        429 => Err("Rate limit Anthropic — trop de requêtes. Réessaie dans quelques secondes.".into()),
284        s => Err(format!("Erreur HTTP {s} du serveur Anthropic.")),
285    }
286}
287
288/// Validate an OpenAI-compatible API key with a GET to /v1/models?limit=1.
289async fn validate_openai_compatible_key(base_url: &str, api_key: &str) -> Result<(), String> {
290    let client = reqwest::Client::builder()
291        .timeout(std::time::Duration::from_secs(10))
292        .build()
293        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
294
295    let url = format!("{}/models?limit=1", base_url.trim_end_matches('/'));
296
297    let resp = client
298        .get(&url)
299        .bearer_auth(api_key)
300        .send()
301        .await
302        .map_err(|e| {
303            if e.is_timeout() {
304                format!("Timeout — le serveur à {url} ne répond pas. Check ta connexion.")
305            } else if e.is_connect() {
306                format!("Impossible de contacter {url}. Vérifie ta connexion ou VPN.")
307            } else {
308                format!("Erreur réseau : {e}")
309            }
310        })?;
311
312    match resp.status().as_u16() {
313        200 => Ok(()),
314        401 | 403 => Err("Clé API invalide. Vérifie ta clé.".into()),
315        404 => {
316            // Some providers don't have /v1/models — try a chat completions endpoint instead
317            validate_with_chat_request(base_url, api_key).await
318        }
319        429 => Err("Rate limit — trop de requêtes. Réessaie dans quelques secondes.".into()),
320        s => Err(format!("Erreur HTTP {s}.")),
321    }
322}
323
324/// Fallback validation: send a minimal chat completion request (1 token max).
325async fn validate_with_chat_request(base_url: &str, api_key: &str) -> Result<(), String> {
326    let client = reqwest::Client::builder()
327        .timeout(std::time::Duration::from_secs(10))
328        .build()
329        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
330
331    let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
332
333    let body = serde_json::json!({
334        "model": "gpt-3.5-turbo",  // widely supported model name for testing
335        "messages": [{"role": "user", "content": "hi"}],
336        "max_tokens": 1,
337        "temperature": 0.0,
338    });
339
340    let resp = client
341        .post(&url)
342        .bearer_auth(api_key)
343        .json(&body)
344        .send()
345        .await
346        .map_err(|e| {
347            if e.is_timeout() {
348                "Timeout — le serveur ne répond pas.".into()
349            } else if e.is_connect() {
350                format!("Impossible de contacter {url}.")
351            } else {
352                format!("Erreur réseau : {e}")
353            }
354        })?;
355
356    match resp.status().as_u16() {
357        200 => Ok(()),
358        401 | 403 => Err("Clé API invalide.".into()),
359        404 => Err(
360            "Endpoint chat/completions introuvable. L'URL de base est peut-être incorrecte.".into(),
361        ),
362        429 => Err("Rate limit — trop de requêtes.".into()),
363        s => {
364            // Even 400/422 is "good" — it means the key was accepted, just the model
365            // name was wrong.
366            if s == 400 || s == 422 {
367                Ok(())
368            } else {
369                Err(format!("Erreur HTTP {s}."))
370            }
371        }
372    }
373}
374
375/// Validate an Ollama connection (no API key needed).
376async fn validate_ollama_connection(base_url: &str) -> Result<(), String> {
377    let client = reqwest::Client::builder()
378        .timeout(std::time::Duration::from_secs(5))
379        .build()
380        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
381
382    let root = base_url.trim_end_matches('/').trim_end_matches("/v1");
383    let url = format!("{root}/api/tags");
384
385    let resp = client.get(&url).send().await.map_err(|e| {
386        if e.is_connect() {
387            format!(
388                "Ollama ne tourne pas sur {root}.\n\
389                 → Lance `ollama serve` dans un autre terminal.\n\
390                 → Ou installe Ollama : https://ollama.com"
391            )
392        } else {
393            format!("Erreur réseau : {e}")
394        }
395    })?;
396
397    match resp.status().as_u16() {
398        200 => Ok(()),
399        s => Err(format!(
400            "Ollama a répondu HTTP {s}. Vérifie que le serveur tourne."
401        )),
402    }
403}
404
405/// Run validation for all detected providers with keys.
406///
407/// Returns the list of providers with `validated` fields populated.
408pub async fn validate_detected_providers(providers: &mut [DetectedProvider]) {
409    for p in providers.iter_mut() {
410        if !p.key_found {
411            p.validated = Some(false);
412            p.validation_error = Some("Aucune clé API trouvée dans l'environnement.".into());
413            continue;
414        }
415
416        let env_var = match &p.env_var {
417            Some(env) => env.clone(),
418            None => {
419                p.validated = Some(false);
420                p.validation_error = Some("Variable d'environnement inconnue.".into());
421                continue;
422            }
423        };
424
425        let api_key = match std::env::var(&env_var) {
426            Ok(k) if !k.trim().is_empty() => k,
427            _ => {
428                p.validated = Some(false);
429                p.validation_error = Some(format!("Variable {env_var} vide."));
430                continue;
431            }
432        };
433
434        match validate_api_key(&p.id, &api_key).await {
435            Ok(()) => {
436                p.validated = Some(true);
437                p.validation_error = None;
438            }
439            Err(e) => {
440                p.validated = Some(false);
441                p.validation_error = Some(e);
442            }
443        }
444    }
445}
446
447/// Get a summary list of ready-to-use providers (key found + validated).
448pub fn ready_providers(providers: &[DetectedProvider]) -> Vec<&DetectedProvider> {
449    providers
450        .iter()
451        .filter(|p| p.key_found && p.validated == Some(true))
452        .collect()
453}
454
455/// Get a list of free providers (regardless of key status), for suggestions.
456pub fn free_providers(providers: &[DetectedProvider]) -> Vec<&DetectedProvider> {
457    providers
458        .iter()
459        .filter(|p| matches!(p.tier, ProviderTier::Free))
460        .collect()
461}