Skip to main content

opencode_provider_manager/discovery/
models_dev.rs

1//! models.dev API client for fetching provider and model catalogs.
2//!
3//! Endpoint: https://models.dev/api.json
4
5use super::cache::CacheManager;
6use super::error::Result;
7use super::{DiscoveredModel, DiscoveredProvider};
8use reqwest::Client;
9use serde::Deserialize;
10
11const MODELS_DEV_API_URL: &str = "https://models.dev/api.json";
12
13/// Client for the models.dev API.
14pub struct ModelsDevClient {
15    client: Client,
16    api_url: String,
17}
18
19impl Default for ModelsDevClient {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl ModelsDevClient {
26    /// Create a new client with the default API URL.
27    pub fn new() -> Self {
28        Self {
29            client: Client::new(),
30            api_url: MODELS_DEV_API_URL.to_string(),
31        }
32    }
33
34    /// Create a new client with a custom API URL (for testing).
35    pub fn with_url(api_url: String) -> Self {
36        Self {
37            client: Client::new(),
38            api_url,
39        }
40    }
41
42    /// Fetch all providers and their models from models.dev.
43    pub async fn fetch_providers(&self) -> Result<Vec<DiscoveredProvider>> {
44        let response = self
45            .client
46            .get(&self.api_url)
47            .send()
48            .await
49            .map_err(|e| super::error::DiscoveryError::Network(e.to_string()))?;
50
51        let providers: ModelsDevResponse = response
52            .json()
53            .await
54            .map_err(|e| super::error::DiscoveryError::Parse(e.to_string()))?;
55
56        Ok(providers.into_providers())
57    }
58
59    /// Fetch all providers from the local cache when available, otherwise from models.dev.
60    pub async fn fetch_providers_cached(
61        &self,
62        force_refresh: bool,
63    ) -> Result<Vec<DiscoveredProvider>> {
64        let cache = CacheManager::new()?;
65        let cache_key = "models_dev_providers";
66
67        if !force_refresh {
68            if let Some(providers) = cache.get::<Vec<DiscoveredProvider>>(cache_key)? {
69                return Ok(providers);
70            }
71        }
72
73        let providers = self.fetch_providers().await?;
74        cache.set(cache_key, &providers)?;
75        Ok(providers)
76    }
77
78    /// Fetch a specific provider's models.
79    pub async fn fetch_provider_models(&self, provider_id: &str) -> Result<Vec<DiscoveredModel>> {
80        let providers = self.fetch_providers().await?;
81        Ok(providers
82            .into_iter()
83            .find(|p| p.id == provider_id)
84            .map(|p| p.models)
85            .unwrap_or_default())
86    }
87
88    /// Fetch a provider's models through the models.dev cache.
89    pub async fn fetch_provider_models_cached(
90        &self,
91        provider_id: &str,
92        force_refresh: bool,
93    ) -> Result<Vec<DiscoveredModel>> {
94        let providers = self.fetch_providers_cached(force_refresh).await?;
95        Ok(providers
96            .into_iter()
97            .find(|p| p.id == provider_id)
98            .map(|p| p.models)
99            .unwrap_or_default())
100    }
101}
102
103/// Internal representation of the models.dev API response.
104#[derive(Debug, Deserialize)]
105struct ModelsDevResponse {
106    #[serde(flatten)]
107    providers: HashMap<String, ModelsDevProvider>,
108}
109
110#[derive(Debug, Deserialize)]
111struct ModelsDevProvider {
112    name: String,
113    #[serde(default)]
114    models: HashMap<String, ModelsDevModel>,
115}
116
117#[derive(Debug, Deserialize)]
118struct ModelsDevModel {
119    name: Option<String>,
120    context_length: Option<u64>,
121    max_output_tokens: Option<u64>,
122    pricing: Option<ModelsDevPricing>,
123}
124
125#[derive(Debug, Deserialize)]
126struct ModelsDevPricing {
127    prompt: Option<String>,
128    completion: Option<String>,
129}
130
131use std::collections::HashMap;
132
133impl ModelsDevResponse {
134    fn into_providers(self) -> Vec<DiscoveredProvider> {
135        self.providers
136            .into_iter()
137            .map(|(id, provider)| DiscoveredProvider {
138                id: id.clone(),
139                name: provider.name.clone(),
140                models: provider
141                    .models
142                    .into_iter()
143                    .map(|(model_id, model)| DiscoveredModel {
144                        id: model_id,
145                        name: model.name.unwrap_or_default(),
146                        provider_id: id.clone(),
147                        context_length: model.context_length,
148                        max_output_tokens: model.max_output_tokens,
149                        input_cost_per_million: model
150                            .pricing
151                            .as_ref()
152                            .and_then(|p| p.prompt.as_ref()?.parse::<f64>().ok()),
153                        output_cost_per_million: model
154                            .pricing
155                            .as_ref()
156                            .and_then(|p| p.completion.as_ref()?.parse::<f64>().ok()),
157                    })
158                    .collect(),
159            })
160            .collect()
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must target the public models.dev endpoint.
    #[test]
    fn test_client_creation() {
        let default_client = ModelsDevClient::new();
        assert_eq!(default_client.api_url, MODELS_DEV_API_URL);
    }

    /// `with_url()` must store the supplied URL unchanged.
    #[test]
    fn test_client_custom_url() {
        let custom_url = "http://localhost:8080/api.json";
        let custom_client = ModelsDevClient::with_url(custom_url.to_string());
        assert_eq!(custom_client.api_url, custom_url);
    }
}