routa_server/api/
provider_models.rs1use axum::{extract::Query, routing::get, Json, Router};
9use serde::Deserialize;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex, OnceLock};
12use std::time::{Duration, SystemTime};
13
14use crate::state::AppState;
15
16#[derive(Debug, Deserialize)]
17struct ModelsQuery {
18 provider: String,
19}
20
/// Describes how to ask one provider's CLI for its list of models.
struct ProviderModelConfig {
    /// Executable name; resolved against the shell `PATH` before it is run.
    command: &'static str,
    /// Arguments passed to `command`.
    args: &'static [&'static str],
    /// Predicate applied to every trimmed stdout line; accepted lines are returned as models.
    filter_fn: fn(&str) -> bool,
}
31
/// Default line filter: keeps any line that contains a `/`.
///
/// An empty line can never contain `/`, so no separate emptiness check is needed.
fn default_filter(line: &str) -> bool {
    line.contains('/')
}
35
36fn provider_model_configs() -> HashMap<&'static str, ProviderModelConfig> {
38 let mut map = HashMap::new();
39 map.insert(
40 "opencode",
41 ProviderModelConfig {
42 command: "opencode",
43 args: &["models"],
44 filter_fn: default_filter,
45 },
46 );
47 map
50}
51
/// In-memory store of model lists fetched from provider CLIs.
struct ModelsCache {
    /// Provider id -> (model ids, time at which that list was stored).
    by_provider: HashMap<String, (Vec<String>, SystemTime)>,
}
57
58static MODELS_CACHE: OnceLock<Arc<Mutex<ModelsCache>>> = OnceLock::new();
59
60fn get_models_cache() -> &'static Arc<Mutex<ModelsCache>> {
61 MODELS_CACHE.get_or_init(|| {
62 Arc::new(Mutex::new(ModelsCache {
63 by_provider: HashMap::new(),
64 }))
65 })
66}
67
68const MODELS_CACHE_TTL: Duration = Duration::from_secs(300); pub fn router() -> Router<AppState> {
73 Router::new().route("/models", get(list_models))
74}
75
76async fn list_models(Query(query): Query<ModelsQuery>) -> Json<serde_json::Value> {
77 let provider = query.provider.as_str();
78
79 {
81 let cache = get_models_cache().lock().unwrap();
82 if let Some((models, ts)) = cache.by_provider.get(provider) {
83 if ts.elapsed().unwrap_or(MODELS_CACHE_TTL) < MODELS_CACHE_TTL {
84 return Json(serde_json::json!({ "models": models, "cached": true }));
85 }
86 }
87 }
88
89 let configs = provider_model_configs();
90 let Some(config) = configs.get(provider) else {
91 return Json(
92 serde_json::json!({ "models": [], "error": "Provider does not support model listing" }),
93 );
94 };
95
96 let resolved = match crate::shell_env::which(config.command) {
97 Some(p) => p,
98 None => {
99 return Json(serde_json::json!({
100 "models": [],
101 "error": format!("'{}' not found in PATH", config.command)
102 }));
103 }
104 };
105
106 let result = tokio::time::timeout(
107 Duration::from_secs(15),
108 tokio::process::Command::new(&resolved)
109 .args(config.args)
110 .env("PATH", crate::shell_env::full_path())
111 .output(),
112 )
113 .await;
114
115 let models: Vec<String> = match result {
116 Ok(Ok(output)) => {
117 let stdout = String::from_utf8_lossy(&output.stdout);
118 stdout
119 .lines()
120 .map(|l| l.trim().to_string())
121 .filter(|l| (config.filter_fn)(l))
122 .collect()
123 }
124 Ok(Err(e)) => {
125 tracing::warn!(
126 "[provider_models] Failed to run '{}': {}",
127 config.command,
128 e
129 );
130 return Json(serde_json::json!({ "models": [], "error": e.to_string() }));
131 }
132 Err(_) => {
133 tracing::warn!(
134 "[provider_models] Timeout listing models for '{}'",
135 provider
136 );
137 return Json(serde_json::json!({ "models": [], "error": "Timeout" }));
138 }
139 };
140
141 {
143 let mut cache = get_models_cache().lock().unwrap();
144 cache
145 .by_provider
146 .insert(provider.to_string(), (models.clone(), SystemTime::now()));
147 }
148
149 Json(serde_json::json!({ "models": models }))
150}