use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use crate::schema::{CostModel, ModelSchema, TrustTier};
/// Serialized wrapper around the list of discovered model schemas.
///
/// This is the document shape that `load_cache` parses and `save_cache`
/// writes: a single object with a `models` array.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct DiscoveryDoc {
/// Discovered model schemas; deserializes to an empty list when the
/// `models` key is missing from the document.
#[serde(default)]
pub models: Vec<ModelSchema>,
}
/// Returns the location of the discovery cache file inside `models_dir`.
///
/// The file is always named `discovered_models.json`; no filesystem access
/// is performed.
pub fn cache_path(models_dir: &Path) -> PathBuf {
    let mut location = models_dir.to_path_buf();
    location.push("discovered_models.json");
    location
}
/// Reads the discovery cache at `path` and returns the schemas it contains.
///
/// This is best-effort: if the file cannot be read, or its contents do not
/// parse as a [`DiscoveryDoc`], an empty list is returned instead of an error.
pub fn load_cache(path: &Path) -> Vec<ModelSchema> {
    // An unreadable (e.g. missing) file is treated the same as an empty cache.
    let Ok(raw) = std::fs::read_to_string(path) else {
        return Vec::new();
    };
    match serde_json::from_str::<DiscoveryDoc>(&raw) {
        Ok(doc) => doc.models,
        // Unparseable contents are ignored rather than reported.
        Err(_) => Vec::new(),
    }
}
/// Writes `models` to `path` as a pretty-printed JSON [`DiscoveryDoc`].
///
/// # Errors
///
/// Returns the stringified error if serialization fails or if the file
/// cannot be written.
pub fn save_cache(path: &Path, models: &[ModelSchema]) -> Result<(), String> {
    let doc = DiscoveryDoc {
        models: models.to_vec(),
    };
    match serde_json::to_string_pretty(&doc) {
        Ok(rendered) => std::fs::write(path, rendered).map_err(|err| err.to_string()),
        Err(err) => Err(err.to_string()),
    }
}
/// Decides whether a model id names a chat-capable generation model.
///
/// The comparison is ASCII case-insensitive. An id qualifies when it
/// belongs to one of the recognized families (`gpt-…`, `chatgpt-…`, or `o`
/// immediately followed by an ASCII digit) and contains none of the
/// substrings that mark non-chat variants (embeddings, audio, image, ...).
pub fn is_chat_model(id: &str) -> bool {
    /// Id prefixes of the text-generation families.
    const FAMILY_PREFIXES: &[&str] = &["gpt-", "chatgpt-"];
    /// Substrings that disqualify an id even when its family matches.
    const NON_CHAT_MARKERS: &[&str] = &[
        "embedding",
        "audio",
        "realtime",
        "transcribe",
        "tts",
        "image",
        "moderation",
        "search",
        "dall",
        "whisper",
        "instruct",
    ];

    let lowered = id.to_ascii_lowercase();

    // `o3`, `o4-mini`, ...: the character right after the leading `o` must
    // be a digit, which keeps ids such as `omni-…` out of the family.
    let o_series = lowered
        .strip_prefix('o')
        .and_then(|rest| rest.chars().next())
        .is_some_and(|c| c.is_ascii_digit());
    let in_family = o_series
        || FAMILY_PREFIXES
            .iter()
            .any(|prefix| lowered.starts_with(*prefix));

    in_family
        && NON_CHAT_MARKERS
            .iter()
            .all(|marker| !lowered.contains(*marker))
}
/// Derives the model-listing URL from a provider endpoint URL.
///
/// Everything up to and including the first `/v1` path segment is kept and
/// `/models` is appended, so `https://host/v1/chat/completions` becomes
/// `https://host/v1/models`.
///
/// `/v1` only counts when it is a complete segment, i.e. it is followed by
/// the end of the string, `/`, `?` or `#`. This prevents false matches
/// inside the authority (`https://v1.example.com/...`) or inside longer
/// segments such as `/v1beta` or `/v10`, which would otherwise be truncated
/// into a URL that does not exist.
///
/// Returns `None` when the endpoint has no such segment.
pub fn models_url_from_endpoint(endpoint: &str) -> Option<String> {
    const MARKER: &str = "/v1";

    let base_end = endpoint
        .match_indices(MARKER)
        .map(|(start, _)| start + MARKER.len())
        .find(|&end| {
            // Accept the match only when the segment ends right here.
            matches!(
                endpoint.as_bytes().get(end).copied(),
                None | Some(b'/' | b'?' | b'#')
            )
        })?;

    // `base_end` directly follows an ASCII match, so it is a char boundary.
    let base = &endpoint[..base_end];
    Some(format!("{base}/models"))
}
/// Builds the schema for a newly discovered model from an existing template.
///
/// The result is a clone of `template` with these fields overridden:
/// - `id` becomes `"{provider}/{bare_id}:latest"`;
/// - `name` and `version` are both set to `bare_id`;
/// - `trust_tier` is `TrustTier::Community`;
/// - `cost` is reset to `CostModel::default()`;
/// - `public_benchmarks` is emptied;
/// - `deprecated` and `available` are both `false`.
///
/// Every other field keeps the template's value.
pub fn discovered_schema(provider: &str, bare_id: &str, template: &ModelSchema) -> ModelSchema {
    let mut schema = template.clone();

    // Status flags and template-specific data are reset first.
    schema.available = false;
    schema.deprecated = false;
    schema.public_benchmarks = Vec::new();
    schema.cost = CostModel::default();
    schema.trust_tier = TrustTier::Community;

    // Identity fields are derived from the provider and the bare model id.
    schema.version = bare_id.to_owned();
    schema.name = bare_id.to_owned();
    schema.id = format!("{provider}/{bare_id}:latest");

    schema
}
/// Fetches the ids of the models listed at `models_url`.
///
/// Sends a GET request authenticated with `api_key` as a bearer token and
/// parses the response body as JSON. The string `id` of every entry in the
/// top-level `data` array is collected; entries without a string `id` are
/// skipped, and a missing or non-array `data` yields an empty list.
///
/// # Errors
///
/// Returns a message prefixed with `discovery fetch:` when the request fails
/// or the response status is not a success, and one prefixed with
/// `discovery parse:` when the body is not valid JSON.
pub async fn fetch_model_ids(
    http: &reqwest::Client,
    models_url: &str,
    api_key: &str,
) -> Result<Vec<String>, String> {
    let request = http.get(models_url).bearer_auth(api_key);
    let resp = match request.send().await {
        Ok(resp) => resp,
        Err(e) => return Err(format!("discovery fetch: {e}")),
    };

    let status = resp.status();
    if !status.is_success() {
        return Err(format!("discovery fetch: HTTP {status}"));
    }

    let body = resp
        .json::<serde_json::Value>()
        .await
        .map_err(|e| format!("discovery parse: {e}"))?;

    // A body without a `data` array is treated as "no models", not an error.
    let Some(entries) = body.get("data").and_then(|data| data.as_array()) else {
        return Ok(Vec::new());
    };

    let ids = entries
        .iter()
        .filter_map(|entry| entry.get("id")?.as_str().map(str::to_owned))
        .collect();
    Ok(ids)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chat-family ids are accepted; embedding, audio, image and other
    /// non-generation ids are rejected.
    #[test]
    fn chat_models_pass_non_generation_rejected() {
        let accepted = [
            "gpt-5.5",
            "gpt-5.5-mini",
            "gpt-4.1",
            "o3",
            "o4-mini",
            "chatgpt-4o-latest",
        ];
        let rejected = [
            "text-embedding-3-large",
            "gpt-4o-audio-preview",
            "gpt-realtime",
            "dall-e-3",
            "whisper-1",
            "omni-moderation-latest",
            "gpt-4o-transcribe",
            "tts-1",
            "gpt-image-1",
            "babbage-002",
        ];

        for ok in accepted {
            assert!(is_chat_model(ok), "{ok} should be a chat model");
        }
        for no in rejected {
            assert!(!is_chat_model(no), "{no} should be rejected");
        }
    }

    /// Both endpoint flavours map to the same models URL, and an endpoint
    /// without a `/v1` part yields `None`.
    #[test]
    fn models_url_derives_from_chat_endpoint() {
        let expected = Some("https://api.openai.com/v1/models");
        for endpoint in [
            "https://api.openai.com/v1/chat/completions",
            "https://api.openai.com/v1/responses",
        ] {
            assert_eq!(models_url_from_endpoint(endpoint).as_deref(), expected);
        }

        assert_eq!(models_url_from_endpoint("https://example.com/openai"), None);
    }

    /// A discovered schema keeps the template's provider and capabilities
    /// while being relabelled, marked community-tier, unpriced and
    /// unavailable.
    #[test]
    fn discovered_schema_inherits_template_and_marks_community() {
        let builtin: Vec<ModelSchema> =
            serde_json::from_str(include_str!("builtin_catalog.json")).unwrap();
        let template = builtin
            .iter()
            .find(|entry| entry.id == "openai/gpt-5.4:latest")
            .expect("gpt-5.4 present in builtin catalog");

        let schema = discovered_schema("openai", "gpt-5.5", template);

        assert_eq!(schema.id, "openai/gpt-5.5:latest");
        assert_eq!(schema.name, "gpt-5.5");
        assert_eq!(schema.provider, "openai");
        assert_eq!(schema.trust_tier, TrustTier::Community);
        assert!(schema.cost.input_per_mtok.is_none());
        assert!(!schema.available);
        assert!(!schema.capabilities.is_empty());
    }
}