1use anyhow::Result;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use super::ProviderType;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum ModelCapability {
15 Chat,
17 ToolUse,
19 Vision,
21 Embedding,
23 Audio,
25 ImageGeneration,
27}
28
29impl std::fmt::Display for ModelCapability {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 Self::Chat => write!(f, "chat"),
33 Self::ToolUse => write!(f, "tool_use"),
34 Self::Vision => write!(f, "vision"),
35 Self::Embedding => write!(f, "embedding"),
36 Self::Audio => write!(f, "audio"),
37 Self::ImageGeneration => write!(f, "image_generation"),
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct AvailableModel {
45 pub id: String,
47 pub display_name: Option<String>,
49 pub provider: ProviderType,
51 pub capabilities: Vec<ModelCapability>,
53 pub owned_by: Option<String>,
55 pub context_window: Option<u32>,
57 pub max_output_tokens: Option<u32>,
59 pub created_at: Option<i64>,
61}
62
63impl AvailableModel {
64 pub fn is_chat_capable(&self) -> bool {
66 self.capabilities.contains(&ModelCapability::Chat)
67 }
68}
69
70#[async_trait]
72pub trait ModelLister: Send + Sync {
73 async fn list_models(&self) -> Result<Vec<AvailableModel>>;
75}
76
77pub fn infer_openai_capabilities(model_id: &str) -> Vec<ModelCapability> {
85 let id = model_id.to_lowercase();
86
87 if id.contains("embedding") || id.starts_with("text-embedding") {
89 return vec![ModelCapability::Embedding];
90 }
91
92 if id.starts_with("whisper") || id.starts_with("tts") {
94 return vec![ModelCapability::Audio];
95 }
96
97 if id.starts_with("dall-e") {
99 return vec![ModelCapability::ImageGeneration];
100 }
101
102 let mut caps = vec![ModelCapability::Chat, ModelCapability::ToolUse];
104
105 if id.contains("vision")
107 || id.contains("gpt-4o")
108 || id.contains("gpt-4-turbo")
109 || id.contains("gpt-5")
110 || (id.starts_with("o") && !id.starts_with("omni"))
111 {
112 caps.push(ModelCapability::Vision);
113 }
114
115 caps
116}
117
118pub fn create_model_lister(
127 provider_type: ProviderType,
128 api_key: Option<&str>,
129 base_url: Option<&str>,
130) -> Result<Box<dyn ModelLister>> {
131 match provider_type {
132 ProviderType::Anthropic => {
133 let key = api_key
134 .ok_or_else(|| anyhow::anyhow!("Anthropic requires an API key"))?
135 .to_string();
136 Ok(Box::new(super::anthropic::AnthropicModelLister::new(key)))
137 }
138 ProviderType::OpenAI => {
139 let key = api_key
140 .ok_or_else(|| anyhow::anyhow!("OpenAI requires an API key"))?
141 .to_string();
142 Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
143 key,
144 base_url.map(|s| s.to_string()),
145 )))
146 }
147 ProviderType::Google => {
148 let key = api_key
149 .ok_or_else(|| anyhow::anyhow!("Google requires an API key"))?
150 .to_string();
151 Ok(Box::new(super::gemini::GoogleModelLister::new(key)))
152 }
153 ProviderType::Groq
154 | ProviderType::Together
155 | ProviderType::Fireworks
156 | ProviderType::Anyscale => {
157 let key = api_key
159 .ok_or_else(|| anyhow::anyhow!("{} requires an API key", provider_type))?
160 .to_string();
161 let registry_url = super::registry::lookup(provider_type).and_then(|e| e.models_url);
162 let url = base_url
163 .or(registry_url)
164 .unwrap_or("https://api.openai.com/v1/models");
165 Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
166 key,
167 Some(url.to_string()),
168 )))
169 }
170 ProviderType::Ollama => Ok(Box::new(super::ollama::OllamaModelLister::new(
171 base_url.map(|s| s.to_string()),
172 ))),
173 ProviderType::OpenAiResponses => {
174 let key = api_key
176 .ok_or_else(|| anyhow::anyhow!("OpenAI Responses requires an API key"))?
177 .to_string();
178 Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
179 key,
180 base_url.map(|s| s.to_string()),
181 )))
182 }
183 ProviderType::Brainwires
184 | ProviderType::Custom
185 | ProviderType::MiniMax
186 | ProviderType::Bedrock
187 | ProviderType::VertexAI
188 | ProviderType::ElevenLabs
189 | ProviderType::Deepgram
190 | ProviderType::Azure
191 | ProviderType::Fish
192 | ProviderType::Cartesia
193 | ProviderType::Murf => Err(anyhow::anyhow!(
194 "Model listing is not supported for {} provider via this interface",
195 provider_type
196 )),
197 }
198}
199
200#[derive(Debug, Deserialize)]
206pub(crate) struct AnthropicListResponse {
207 pub data: Vec<AnthropicModelEntry>,
208 pub has_more: bool,
209 #[serde(default)]
210 pub last_id: Option<String>,
211}
212
213#[derive(Debug, Deserialize)]
215pub struct AnthropicModelEntry {
216 pub id: String,
218 pub display_name: String,
220 #[serde(rename = "type")]
222 pub _type: Option<String>,
223 pub created_at: Option<String>,
225}
226
227#[derive(Debug, Deserialize)]
229pub(crate) struct OpenAIListResponse {
230 pub data: Vec<OpenAIModelEntry>,
231}
232
233#[derive(Debug, Deserialize)]
234pub(crate) struct OpenAIModelEntry {
235 pub id: String,
236 pub owned_by: Option<String>,
237 pub created: Option<i64>,
238}
239
240#[derive(Debug, Deserialize)]
242#[allow(dead_code)]
243pub(crate) struct GoogleListResponse {
244 #[serde(default)]
245 pub models: Vec<GoogleModelEntry>,
246 #[serde(rename = "nextPageToken")]
247 pub next_page_token: Option<String>,
248}
249
250#[derive(Debug, Deserialize)]
251#[allow(dead_code)]
252pub(crate) struct GoogleModelEntry {
253 pub name: String,
255 #[serde(rename = "displayName")]
256 pub display_name: Option<String>,
257 #[serde(rename = "inputTokenLimit")]
258 pub input_token_limit: Option<u32>,
259 #[serde(rename = "outputTokenLimit")]
260 pub output_token_limit: Option<u32>,
261 #[serde(rename = "supportedGenerationMethods", default)]
262 pub supported_generation_methods: Vec<String>,
263}
264
265#[derive(Debug, Deserialize)]
267pub(crate) struct OllamaTagsResponse {
268 pub models: Vec<OllamaModelEntry>,
269}
270
271#[derive(Debug, Deserialize)]
272#[allow(dead_code)]
273pub(crate) struct OllamaModelEntry {
274 pub name: String,
275 pub modified_at: Option<String>,
276 pub size: Option<u64>,
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an OpenAI-provider model that carries only an id and capabilities.
    fn model_with(id: &str, capabilities: Vec<ModelCapability>) -> AvailableModel {
        AvailableModel {
            id: id.to_string(),
            display_name: None,
            provider: ProviderType::OpenAI,
            capabilities,
            owned_by: None,
            context_window: None,
            max_output_tokens: None,
            created_at: None,
        }
    }

    #[test]
    fn test_infer_openai_capabilities_chat() {
        let caps = infer_openai_capabilities("gpt-4o");
        for expected in [
            ModelCapability::Chat,
            ModelCapability::ToolUse,
            ModelCapability::Vision,
        ] {
            assert!(caps.contains(&expected), "gpt-4o should report {}", expected);
        }
    }

    #[test]
    fn test_infer_openai_capabilities_embedding() {
        let caps = infer_openai_capabilities("text-embedding-3-small");
        assert!(caps.contains(&ModelCapability::Embedding));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_audio() {
        let caps = infer_openai_capabilities("whisper-1");
        assert!(caps.contains(&ModelCapability::Audio));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_image_gen() {
        let caps = infer_openai_capabilities("dall-e-3");
        assert!(caps.contains(&ModelCapability::ImageGeneration));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_basic_chat() {
        let caps = infer_openai_capabilities("gpt-3.5-turbo");
        assert!(caps.contains(&ModelCapability::Chat));
        assert!(caps.contains(&ModelCapability::ToolUse));
        assert!(!caps.contains(&ModelCapability::Vision));
    }

    #[test]
    fn test_available_model_is_chat_capable() {
        let chat_model = model_with("test", vec![ModelCapability::Chat]);
        assert!(chat_model.is_chat_capable());

        let embedding_model = model_with("embed", vec![ModelCapability::Embedding]);
        assert!(!embedding_model.is_chat_capable());
    }

    #[test]
    fn test_parse_anthropic_response() {
        let json = r#"{
            "data": [
                {"id": "claude-sonnet-4-20250514", "display_name": "Claude Sonnet 4", "type": "model", "created_at": "2025-05-14T00:00:00Z"},
                {"id": "claude-3-5-haiku-20241022", "display_name": "Claude 3.5 Haiku", "type": "model"}
            ],
            "has_more": false
        }"#;
        let resp: AnthropicListResponse =
            serde_json::from_str(json).expect("Anthropic fixture should deserialize");
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "claude-sonnet-4-20250514");
        assert!(!resp.has_more);
    }

    #[test]
    fn test_parse_openai_response() {
        let json = r#"{
            "data": [
                {"id": "gpt-4o", "owned_by": "openai", "created": 1715367049},
                {"id": "text-embedding-3-small", "owned_by": "openai", "created": 1705948997}
            ]
        }"#;
        let resp: OpenAIListResponse =
            serde_json::from_str(json).expect("OpenAI fixture should deserialize");
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "gpt-4o");
    }

    #[test]
    fn test_parse_google_response() {
        let json = r#"{
            "models": [
                {
                    "name": "models/gemini-2.0-flash",
                    "displayName": "Gemini 2.0 Flash",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 8192,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                }
            ]
        }"#;
        let resp: GoogleListResponse =
            serde_json::from_str(json).expect("Google fixture should deserialize");
        assert_eq!(resp.models.len(), 1);
        assert_eq!(resp.models[0].input_token_limit, Some(1_048_576));
    }

    #[test]
    fn test_parse_ollama_response() {
        let json = r#"{
            "models": [
                {"name": "llama3.1:latest", "modified_at": "2024-08-01T00:00:00Z", "size": 4000000000},
                {"name": "codellama:7b", "modified_at": "2024-07-15T00:00:00Z", "size": 3800000000}
            ]
        }"#;
        let resp: OllamaTagsResponse =
            serde_json::from_str(json).expect("Ollama fixture should deserialize");
        assert_eq!(resp.models.len(), 2);
        assert_eq!(resp.models[0].name, "llama3.1:latest");
    }

    #[test]
    fn test_model_capability_display() {
        let expectations = [
            (ModelCapability::Chat, "chat"),
            (ModelCapability::ToolUse, "tool_use"),
            (ModelCapability::Vision, "vision"),
        ];
        for (capability, text) in expectations {
            assert_eq!(capability.to_string(), text);
        }
    }

    #[test]
    fn test_create_model_lister_no_key() {
        // `Box<dyn ModelLister>` is not `Debug`, so match instead of `unwrap_err`.
        match create_model_lister(ProviderType::Anthropic, None, None) {
            Ok(_) => panic!("expected an error when no API key is supplied"),
            Err(err) => assert!(err.to_string().contains("API key")),
        }
    }

    #[test]
    fn test_create_model_lister_ollama_no_key() {
        let result = create_model_lister(ProviderType::Ollama, None, None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_model_lister_brainwires_unsupported() {
        let result = create_model_lister(ProviderType::Brainwires, Some("key"), None);
        assert!(result.is_err());
    }
}