use super::cache::CacheManager;
use super::error::Result;
use super::{DiscoveredModel, DiscoveredProvider};
use reqwest::Client;
use serde::Deserialize;
/// Default endpoint serving the models.dev provider/model catalogue as JSON.
const MODELS_DEV_API_URL: &str = "https://models.dev/api.json";

/// Async client for the models.dev catalogue.
///
/// Construct with [`ModelsDevClient::new`] (public endpoint) or
/// [`ModelsDevClient::with_url`] (custom endpoint, e.g. a mirror or a local
/// test server).
#[derive(Debug, Clone)]
pub struct ModelsDevClient {
    /// HTTP client used for every request issued by this wrapper.
    client: Client,
    /// URL the catalogue JSON is fetched from.
    api_url: String,
}
impl Default for ModelsDevClient {
fn default() -> Self {
Self::new()
}
}
impl ModelsDevClient {
pub fn new() -> Self {
Self {
client: Client::new(),
api_url: MODELS_DEV_API_URL.to_string(),
}
}
pub fn with_url(api_url: String) -> Self {
Self {
client: Client::new(),
api_url,
}
}
pub async fn fetch_providers(&self) -> Result<Vec<DiscoveredProvider>> {
let response = self
.client
.get(&self.api_url)
.send()
.await
.map_err(|e| super::error::DiscoveryError::Network(e.to_string()))?;
let providers: ModelsDevResponse = response
.json()
.await
.map_err(|e| super::error::DiscoveryError::Parse(e.to_string()))?;
Ok(providers.into_providers())
}
pub async fn fetch_providers_cached(
&self,
force_refresh: bool,
) -> Result<Vec<DiscoveredProvider>> {
let cache = CacheManager::new()?;
let cache_key = "models_dev_providers";
if !force_refresh {
if let Some(providers) = cache.get::<Vec<DiscoveredProvider>>(cache_key)? {
return Ok(providers);
}
}
let providers = self.fetch_providers().await?;
cache.set(cache_key, &providers)?;
Ok(providers)
}
pub async fn fetch_provider_models(&self, provider_id: &str) -> Result<Vec<DiscoveredModel>> {
let providers = self.fetch_providers().await?;
Ok(providers
.into_iter()
.find(|p| p.id == provider_id)
.map(|p| p.models)
.unwrap_or_default())
}
pub async fn fetch_provider_models_cached(
&self,
provider_id: &str,
force_refresh: bool,
) -> Result<Vec<DiscoveredModel>> {
let providers = self.fetch_providers_cached(force_refresh).await?;
Ok(providers
.into_iter()
.find(|p| p.id == provider_id)
.map(|p| p.models)
.unwrap_or_default())
}
}
#[derive(Debug, Deserialize)]
struct ModelsDevResponse {
#[serde(flatten)]
providers: HashMap<String, ModelsDevProvider>,
}
#[derive(Debug, Deserialize)]
struct ModelsDevProvider {
name: String,
#[serde(default)]
models: HashMap<String, ModelsDevModel>,
}
#[derive(Debug, Deserialize)]
struct ModelsDevModel {
name: Option<String>,
context_length: Option<u64>,
max_output_tokens: Option<u64>,
pricing: Option<ModelsDevPricing>,
}
#[derive(Debug, Deserialize)]
struct ModelsDevPricing {
prompt: Option<String>,
completion: Option<String>,
}
use std::collections::HashMap;
impl ModelsDevResponse {
fn into_providers(self) -> Vec<DiscoveredProvider> {
self.providers
.into_iter()
.map(|(id, provider)| DiscoveredProvider {
id: id.clone(),
name: provider.name.clone(),
models: provider
.models
.into_iter()
.map(|(model_id, model)| DiscoveredModel {
id: model_id,
name: model.name.unwrap_or_default(),
provider_id: id.clone(),
context_length: model.context_length,
max_output_tokens: model.max_output_tokens,
input_cost_per_million: model
.pricing
.as_ref()
.and_then(|p| p.prompt.as_ref()?.parse::<f64>().ok()),
output_cost_per_million: model
.pricing
.as_ref()
.and_then(|p| p.completion.as_ref()?.parse::<f64>().ok()),
})
.collect(),
})
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` targets the public models.dev endpoint.
    #[test]
    fn test_client_creation() {
        let default_client = ModelsDevClient::new();
        assert_eq!(default_client.api_url.as_str(), MODELS_DEV_API_URL);
    }

    /// `with_url()` stores the caller-supplied URL verbatim.
    #[test]
    fn test_client_custom_url() {
        let local_url = "http://localhost:8080/api.json";
        let custom_client = ModelsDevClient::with_url(String::from(local_url));
        assert_eq!(custom_client.api_url, local_url);
    }
}