opencode_provider_manager/discovery/
models_dev.rs1use super::cache::CacheManager;
6use super::error::Result;
7use super::{DiscoveredModel, DiscoveredProvider};
8use reqwest::Client;
9use serde::Deserialize;
10
/// Public models.dev catalogue endpoint queried by `ModelsDevClient::new`.
const MODELS_DEV_API_URL: &str = "https://models.dev/api.json";
12
13pub struct ModelsDevClient {
15 client: Client,
16 api_url: String,
17}
18
19impl Default for ModelsDevClient {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl ModelsDevClient {
26 pub fn new() -> Self {
28 Self {
29 client: Client::new(),
30 api_url: MODELS_DEV_API_URL.to_string(),
31 }
32 }
33
34 pub fn with_url(api_url: String) -> Self {
36 Self {
37 client: Client::new(),
38 api_url,
39 }
40 }
41
42 pub async fn fetch_providers(&self) -> Result<Vec<DiscoveredProvider>> {
44 let response = self
45 .client
46 .get(&self.api_url)
47 .send()
48 .await
49 .map_err(|e| super::error::DiscoveryError::Network(e.to_string()))?;
50
51 let providers: ModelsDevResponse = response
52 .json()
53 .await
54 .map_err(|e| super::error::DiscoveryError::Parse(e.to_string()))?;
55
56 Ok(providers.into_providers())
57 }
58
59 pub async fn fetch_providers_cached(
61 &self,
62 force_refresh: bool,
63 ) -> Result<Vec<DiscoveredProvider>> {
64 let cache = CacheManager::new()?;
65 let cache_key = "models_dev_providers";
66
67 if !force_refresh {
68 if let Some(providers) = cache.get::<Vec<DiscoveredProvider>>(cache_key)? {
69 return Ok(providers);
70 }
71 }
72
73 let providers = self.fetch_providers().await?;
74 cache.set(cache_key, &providers)?;
75 Ok(providers)
76 }
77
78 pub async fn fetch_provider_models(&self, provider_id: &str) -> Result<Vec<DiscoveredModel>> {
80 let providers = self.fetch_providers().await?;
81 Ok(providers
82 .into_iter()
83 .find(|p| p.id == provider_id)
84 .map(|p| p.models)
85 .unwrap_or_default())
86 }
87
88 pub async fn fetch_provider_models_cached(
90 &self,
91 provider_id: &str,
92 force_refresh: bool,
93 ) -> Result<Vec<DiscoveredModel>> {
94 let providers = self.fetch_providers_cached(force_refresh).await?;
95 Ok(providers
96 .into_iter()
97 .find(|p| p.id == provider_id)
98 .map(|p| p.models)
99 .unwrap_or_default())
100 }
101}
102
103#[derive(Debug, Deserialize)]
105struct ModelsDevResponse {
106 #[serde(flatten)]
107 providers: HashMap<String, ModelsDevProvider>,
108}
109
110#[derive(Debug, Deserialize)]
111struct ModelsDevProvider {
112 name: String,
113 #[serde(default)]
114 models: HashMap<String, ModelsDevModel>,
115}
116
117#[derive(Debug, Deserialize)]
118struct ModelsDevModel {
119 name: Option<String>,
120 context_length: Option<u64>,
121 max_output_tokens: Option<u64>,
122 pricing: Option<ModelsDevPricing>,
123}
124
125#[derive(Debug, Deserialize)]
126struct ModelsDevPricing {
127 prompt: Option<String>,
128 completion: Option<String>,
129}
130
131use std::collections::HashMap;
132
133impl ModelsDevResponse {
134 fn into_providers(self) -> Vec<DiscoveredProvider> {
135 self.providers
136 .into_iter()
137 .map(|(id, provider)| DiscoveredProvider {
138 id: id.clone(),
139 name: provider.name.clone(),
140 models: provider
141 .models
142 .into_iter()
143 .map(|(model_id, model)| DiscoveredModel {
144 id: model_id,
145 name: model.name.unwrap_or_default(),
146 provider_id: id.clone(),
147 context_length: model.context_length,
148 max_output_tokens: model.max_output_tokens,
149 input_cost_per_million: model
150 .pricing
151 .as_ref()
152 .and_then(|p| p.prompt.as_ref()?.parse::<f64>().ok()),
153 output_cost_per_million: model
154 .pricing
155 .as_ref()
156 .and_then(|p| p.completion.as_ref()?.parse::<f64>().ok()),
157 })
158 .collect(),
159 })
160 .collect()
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = ModelsDevClient::new();
        assert_eq!(client.api_url, MODELS_DEV_API_URL);
    }

    /// `Default` must target the same endpoint as `new`.
    #[test]
    fn test_client_default_matches_new() {
        let client = ModelsDevClient::default();
        assert_eq!(client.api_url, MODELS_DEV_API_URL);
    }

    #[test]
    fn test_client_custom_url() {
        let url = "http://localhost:8080/api.json";
        let client = ModelsDevClient::with_url(url.to_string());
        assert_eq!(client.api_url, url);
    }

    /// An empty payload converts to an empty provider list rather than failing.
    #[test]
    fn test_empty_response_yields_no_providers() {
        let response = ModelsDevResponse {
            providers: std::collections::HashMap::new(),
        };
        assert!(response.into_providers().is_empty());
    }
}