use axum::{extract::Query, routing::get, Json, Router};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, SystemTime};
use crate::state::AppState;
#[derive(Debug, Deserialize)]
struct ModelsQuery {
provider: String,
}
struct ProviderModelConfig {
command: &'static str,
args: &'static [&'static str],
filter_fn: fn(&str) -> bool,
}
fn default_filter(line: &str) -> bool {
!line.is_empty() && line.contains('/')
}
fn provider_model_configs() -> HashMap<&'static str, ProviderModelConfig> {
let mut map = HashMap::new();
map.insert(
"opencode",
ProviderModelConfig {
command: "opencode",
args: &["models"],
filter_fn: default_filter,
},
);
map
}
struct ModelsCache {
by_provider: HashMap<String, (Vec<String>, SystemTime)>,
}
static MODELS_CACHE: OnceLock<Arc<Mutex<ModelsCache>>> = OnceLock::new();
fn get_models_cache() -> &'static Arc<Mutex<ModelsCache>> {
MODELS_CACHE.get_or_init(|| {
Arc::new(Mutex::new(ModelsCache {
by_provider: HashMap::new(),
}))
})
}
const MODELS_CACHE_TTL: Duration = Duration::from_secs(300);
pub fn router() -> Router<AppState> {
Router::new().route("/models", get(list_models))
}
async fn list_models(Query(query): Query<ModelsQuery>) -> Json<serde_json::Value> {
let provider = query.provider.as_str();
{
let cache = get_models_cache().lock().unwrap();
if let Some((models, ts)) = cache.by_provider.get(provider) {
if ts.elapsed().unwrap_or(MODELS_CACHE_TTL) < MODELS_CACHE_TTL {
return Json(serde_json::json!({ "models": models, "cached": true }));
}
}
}
let configs = provider_model_configs();
let Some(config) = configs.get(provider) else {
return Json(
serde_json::json!({ "models": [], "error": "Provider does not support model listing" }),
);
};
let resolved = match crate::shell_env::which(config.command) {
Some(p) => p,
None => {
return Json(serde_json::json!({
"models": [],
"error": format!("'{}' not found in PATH", config.command)
}));
}
};
let result = tokio::time::timeout(
Duration::from_secs(15),
tokio::process::Command::new(&resolved)
.args(config.args)
.env("PATH", crate::shell_env::full_path())
.output(),
)
.await;
let models: Vec<String> = match result {
Ok(Ok(output)) => {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout
.lines()
.map(|l| l.trim().to_string())
.filter(|l| (config.filter_fn)(l))
.collect()
}
Ok(Err(e)) => {
tracing::warn!(
"[provider_models] Failed to run '{}': {}",
config.command,
e
);
return Json(serde_json::json!({ "models": [], "error": e.to_string() }));
}
Err(_) => {
tracing::warn!(
"[provider_models] Timeout listing models for '{}'",
provider
);
return Json(serde_json::json!({ "models": [], "error": "Timeout" }));
}
};
{
let mut cache = get_models_cache().lock().unwrap();
cache
.by_provider
.insert(provider.to_string(), (models.clone(), SystemTime::now()));
}
Json(serde_json::json!({ "models": models }))
}