Skip to main content

routa_server/api/
provider_models.rs

//! Provider Models API
//!
//! `GET /api/providers/models?provider=<id>`
//!
//! Runs the provider's model-listing command and returns the available models.
//! Designed to be extensible: each provider can define its own model-listing command.

8use axum::{extract::Query, routing::get, Json, Router};
9use serde::Deserialize;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex, OnceLock};
12use std::time::{Duration, SystemTime};
13
14use crate::state::AppState;
15
16#[derive(Debug, Deserialize)]
17struct ModelsQuery {
18    provider: String,
19}
20
/// Describes how to list models for a provider.
struct ProviderModelConfig {
    /// Executable to run; resolved against `PATH` before it is spawned (e.g. `"opencode"`).
    command: &'static str,
    /// Arguments passed to `command` (e.g. `["models"]`).
    args: &'static [&'static str],
    /// Predicate applied to each trimmed line of the command's stdout.
    /// Lines for which it returns `true` are kept verbatim as model IDs;
    /// all other lines are dropped.
    filter_fn: fn(&str) -> bool,
}

/// Default line filter: keeps a line only when it contains a `/`.
///
/// Any line containing `/` is necessarily non-empty, so blank lines are
/// rejected as well.
fn default_filter(line: &str) -> bool {
    line.contains('/')
}

36/// Registry of providers that support model listing.
37fn provider_model_configs() -> HashMap<&'static str, ProviderModelConfig> {
38    let mut map = HashMap::new();
39    map.insert(
40        "opencode",
41        ProviderModelConfig {
42            command: "opencode",
43            args: &["models"],
44            filter_fn: default_filter,
45        },
46    );
47    // Future providers can be added here, e.g.:
48    // map.insert("gemini", ProviderModelConfig { command: "gemini", args: &["models", "--list"], filter_fn: ... });
49    map
50}
51
// ─── Cache ───────────────────────────────────────────────────────────────────

/// In-process cache of model listings, keyed by provider id.
///
/// Each entry pairs the listed model IDs with the wall-clock time at which
/// they were stored.
struct ModelsCache {
    by_provider: HashMap<String, (Vec<String>, SystemTime)>,
}

/// Process-wide cache instance, created lazily on first access.
static MODELS_CACHE: OnceLock<Arc<Mutex<ModelsCache>>> = OnceLock::new();

/// Returns the shared models cache, initialising it empty on first use.
fn get_models_cache() -> &'static Arc<Mutex<ModelsCache>> {
    MODELS_CACHE.get_or_init(|| {
        let empty = ModelsCache {
            by_provider: HashMap::new(),
        };
        Arc::new(Mutex::new(empty))
    })
}

/// How long a cached listing is considered fresh: five minutes.
const MODELS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);

70// ─── Router ──────────────────────────────────────────────────────────────────
71
72pub fn router() -> Router<AppState> {
73    Router::new().route("/models", get(list_models))
74}
75
76async fn list_models(Query(query): Query<ModelsQuery>) -> Json<serde_json::Value> {
77    let provider = query.provider.as_str();
78
79    // Check cache
80    {
81        let cache = get_models_cache().lock().unwrap();
82        if let Some((models, ts)) = cache.by_provider.get(provider) {
83            if ts.elapsed().unwrap_or(MODELS_CACHE_TTL) < MODELS_CACHE_TTL {
84                return Json(serde_json::json!({ "models": models, "cached": true }));
85            }
86        }
87    }
88
89    let configs = provider_model_configs();
90    let Some(config) = configs.get(provider) else {
91        return Json(
92            serde_json::json!({ "models": [], "error": "Provider does not support model listing" }),
93        );
94    };
95
96    let resolved = match crate::shell_env::which(config.command) {
97        Some(p) => p,
98        None => {
99            return Json(serde_json::json!({
100                "models": [],
101                "error": format!("'{}' not found in PATH", config.command)
102            }));
103        }
104    };
105
106    let result = tokio::time::timeout(
107        Duration::from_secs(15),
108        tokio::process::Command::new(&resolved)
109            .args(config.args)
110            .env("PATH", crate::shell_env::full_path())
111            .output(),
112    )
113    .await;
114
115    let models: Vec<String> = match result {
116        Ok(Ok(output)) => {
117            let stdout = String::from_utf8_lossy(&output.stdout);
118            stdout
119                .lines()
120                .map(|l| l.trim().to_string())
121                .filter(|l| (config.filter_fn)(l))
122                .collect()
123        }
124        Ok(Err(e)) => {
125            tracing::warn!(
126                "[provider_models] Failed to run '{}': {}",
127                config.command,
128                e
129            );
130            return Json(serde_json::json!({ "models": [], "error": e.to_string() }));
131        }
132        Err(_) => {
133            tracing::warn!(
134                "[provider_models] Timeout listing models for '{}'",
135                provider
136            );
137            return Json(serde_json::json!({ "models": [], "error": "Timeout" }));
138        }
139    };
140
141    // Update cache
142    {
143        let mut cache = get_models_cache().lock().unwrap();
144        cache
145            .by_provider
146            .insert(provider.to_string(), (models.clone(), SystemTime::now()));
147    }
148
149    Json(serde_json::json!({ "models": models }))
150}