use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use super::ProviderType;

/// A capability that a model can support.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelCapability {
    /// Text generation / chat completion.
    Chat,
    /// Function or tool calling.
    ToolUse,
    /// Image understanding (vision input).
    Vision,
    /// Text embeddings.
    Embedding,
    /// Speech-to-text or text-to-speech.
    Audio,
    /// Image generation.
    ImageGeneration,
}

impl std::fmt::Display for ModelCapability {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Chat => write!(f, "chat"),
            Self::ToolUse => write!(f, "tool_use"),
            Self::Vision => write!(f, "vision"),
            Self::Embedding => write!(f, "embedding"),
            Self::Audio => write!(f, "audio"),
            Self::ImageGeneration => write!(f, "image_generation"),
        }
    }
}

/// A model advertised by a provider's model-listing endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AvailableModel {
    /// Provider-specific model identifier.
    pub id: String,
    /// Human-readable name, when the provider reports one.
    pub display_name: Option<String>,
    /// Provider that serves this model.
    pub provider: ProviderType,
    /// Capabilities reported by or inferred for the model.
    pub capabilities: Vec<ModelCapability>,
    /// Owning organization, when reported.
    pub owned_by: Option<String>,
    /// Maximum context window in tokens, when known.
    pub context_window: Option<u32>,
    /// Maximum output tokens, when known.
    pub max_output_tokens: Option<u32>,
    /// Creation time as a Unix timestamp, when reported.
    pub created_at: Option<i64>,
}

impl AvailableModel {
    /// Returns `true` if the model advertises the chat capability.
    pub fn is_chat_capable(&self) -> bool {
        self.capabilities.contains(&ModelCapability::Chat)
    }
}
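
/// Lists the models available from a single provider.
///
/// A minimal implementation sketch, assuming a hypothetical `StaticLister`
/// that is not part of this module:
///
/// ```ignore
/// struct StaticLister;
///
/// #[async_trait]
/// impl ModelLister for StaticLister {
///     async fn list_models(&self) -> Result<Vec<AvailableModel>> {
///         // A real lister would call the provider's HTTP API here.
///         Ok(Vec::new())
///     }
/// }
/// ```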
#[async_trait]
pub trait ModelLister: Send + Sync {
    async fn list_models(&self) -> Result<Vec<AvailableModel>>;
}
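
/// Infer capabilities for an OpenAI-style model id from its name: embedding,
/// Whisper/TTS, and DALL-E ids map to a single capability; everything else is
/// treated as a chat model with tool use, and known multimodal families also
/// get vision.
///
/// A usage sketch mirroring the heuristics below:
///
/// ```ignore
/// let caps = infer_openai_capabilities("gpt-4o");
/// assert!(caps.contains(&ModelCapability::Vision));
///
/// let caps = infer_openai_capabilities("text-embedding-3-small");
/// assert_eq!(caps, vec![ModelCapability::Embedding]);
/// ```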
pub fn infer_openai_capabilities(model_id: &str) -> Vec<ModelCapability> {
    let id = model_id.to_lowercase();

    // Embedding models.
    if id.contains("embedding") || id.starts_with("text-embedding") {
        return vec![ModelCapability::Embedding];
    }

    // Speech-to-text and text-to-speech models.
    if id.starts_with("whisper") || id.starts_with("tts") {
        return vec![ModelCapability::Audio];
    }

    // Image generation models.
    if id.starts_with("dall-e") {
        return vec![ModelCapability::ImageGeneration];
    }

    // Everything else is treated as a chat model with tool use.
    let mut caps = vec![ModelCapability::Chat, ModelCapability::ToolUse];

    // Known multimodal families also get vision.
    if id.contains("vision")
        || id.contains("gpt-4o")
        || id.contains("gpt-4-turbo")
        || id.contains("gpt-5")
        || (id.starts_with("o") && !id.starts_with("omni"))
    {
        caps.push(ModelCapability::Vision);
    }

    caps
}
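
/// Build a boxed [`ModelLister`] for `provider_type`.
///
/// Most cloud providers require an API key; `base_url`, where supported,
/// overrides the default endpoint. Providers without a model-listing API
/// return an error.
///
/// A usage sketch, assuming an async context and a valid key (the key shown
/// is a placeholder):
///
/// ```ignore
/// let lister = create_model_lister(ProviderType::OpenAI, Some("sk-..."), None)?;
/// let models = lister.list_models().await?;
/// for model in models.iter().filter(|m| m.is_chat_capable()) {
///     println!("{}", model.id);
/// }
/// ```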
pub fn create_model_lister(
    provider_type: ProviderType,
    api_key: Option<&str>,
    base_url: Option<&str>,
) -> Result<Box<dyn ModelLister>> {
    match provider_type {
        ProviderType::Anthropic => {
            let key = api_key
                .ok_or_else(|| anyhow::anyhow!("Anthropic requires an API key"))?
                .to_string();
            Ok(Box::new(super::anthropic::AnthropicModelLister::new(key)))
        }
        ProviderType::OpenAI => {
            let key = api_key
                .ok_or_else(|| anyhow::anyhow!("OpenAI requires an API key"))?
                .to_string();
            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
                key,
                base_url.map(|s| s.to_string()),
            )))
        }
        ProviderType::Google => {
            let key = api_key
                .ok_or_else(|| anyhow::anyhow!("Google requires an API key"))?
                .to_string();
            Ok(Box::new(super::gemini::GoogleModelLister::new(key)))
        }
        // OpenAI-compatible providers: prefer an explicit base_url, then the
        // registry's models URL, then the OpenAI default.
        ProviderType::Groq
        | ProviderType::Together
        | ProviderType::Fireworks
        | ProviderType::Anyscale => {
            let key = api_key
                .ok_or_else(|| anyhow::anyhow!("{} requires an API key", provider_type))?
                .to_string();
            let registry_url = super::registry::lookup(provider_type).and_then(|e| e.models_url);
            let url = base_url
                .or(registry_url)
                .unwrap_or("https://api.openai.com/v1/models");
            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
                key,
                Some(url.to_string()),
            )))
        }
        // Ollama does not require an API key.
        ProviderType::Ollama => Ok(Box::new(super::ollama::OllamaModelLister::new(
            base_url.map(|s| s.to_string()),
        ))),
        ProviderType::OpenAiResponses => {
            let key = api_key
                .ok_or_else(|| anyhow::anyhow!("OpenAI Responses requires an API key"))?
                .to_string();
            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
                key,
                base_url.map(|s| s.to_string()),
            )))
        }
        // Model listing is not exposed for these providers through this interface.
        ProviderType::Brainwires
        | ProviderType::Custom
        | ProviderType::Bedrock
        | ProviderType::VertexAI
        | ProviderType::ElevenLabs
        | ProviderType::Deepgram
        | ProviderType::Azure
        | ProviderType::Fish
        | ProviderType::Cartesia
        | ProviderType::Murf => Err(anyhow::anyhow!(
            "Model listing is not supported for {} provider via this interface",
            provider_type
        )),
    }
}

/// Wire format of Anthropic's model-list response.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicListResponse {
    pub data: Vec<AnthropicModelEntry>,
    pub has_more: bool,
    #[serde(default)]
    pub last_id: Option<String>,
}

/// A single model entry from Anthropic's model list.
#[derive(Debug, Deserialize)]
pub struct AnthropicModelEntry {
    pub id: String,
    pub display_name: String,
    #[serde(rename = "type")]
    pub _type: Option<String>,
    pub created_at: Option<String>,
}

/// Wire format of an OpenAI-compatible model-list response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAIListResponse {
    pub data: Vec<OpenAIModelEntry>,
}

/// A single model entry from an OpenAI-compatible model list.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAIModelEntry {
    pub id: String,
    pub owned_by: Option<String>,
    pub created: Option<i64>,
}

/// Wire format of Google's model-list response.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct GoogleListResponse {
    #[serde(default)]
    pub models: Vec<GoogleModelEntry>,
    #[serde(rename = "nextPageToken")]
    pub next_page_token: Option<String>,
}

/// A single model entry from Google's model list.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct GoogleModelEntry {
    pub name: String,
    #[serde(rename = "displayName")]
    pub display_name: Option<String>,
    #[serde(rename = "inputTokenLimit")]
    pub input_token_limit: Option<u32>,
    #[serde(rename = "outputTokenLimit")]
    pub output_token_limit: Option<u32>,
    #[serde(rename = "supportedGenerationMethods", default)]
    pub supported_generation_methods: Vec<String>,
}

/// Wire format of Ollama's local tags (model list) response.
#[derive(Debug, Deserialize)]
pub(crate) struct OllamaTagsResponse {
    pub models: Vec<OllamaModelEntry>,
}

/// A single local model entry from Ollama.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct OllamaModelEntry {
    pub name: String,
    pub modified_at: Option<String>,
    pub size: Option<u64>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infer_openai_capabilities_chat() {
        let caps = infer_openai_capabilities("gpt-4o");
        assert!(caps.contains(&ModelCapability::Chat));
        assert!(caps.contains(&ModelCapability::ToolUse));
        assert!(caps.contains(&ModelCapability::Vision));
    }

    #[test]
    fn test_infer_openai_capabilities_embedding() {
        let caps = infer_openai_capabilities("text-embedding-3-small");
        assert!(caps.contains(&ModelCapability::Embedding));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_audio() {
        let caps = infer_openai_capabilities("whisper-1");
        assert!(caps.contains(&ModelCapability::Audio));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_image_gen() {
        let caps = infer_openai_capabilities("dall-e-3");
        assert!(caps.contains(&ModelCapability::ImageGeneration));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_basic_chat() {
        let caps = infer_openai_capabilities("gpt-3.5-turbo");
        assert!(caps.contains(&ModelCapability::Chat));
        assert!(caps.contains(&ModelCapability::ToolUse));
        assert!(!caps.contains(&ModelCapability::Vision));
    }

    #[test]
    fn test_available_model_is_chat_capable() {
        let model = AvailableModel {
            id: "test".to_string(),
            display_name: None,
            provider: ProviderType::OpenAI,
            capabilities: vec![ModelCapability::Chat],
            owned_by: None,
            context_window: None,
            max_output_tokens: None,
            created_at: None,
        };
        assert!(model.is_chat_capable());

        let embedding_model = AvailableModel {
            id: "embed".to_string(),
            display_name: None,
            provider: ProviderType::OpenAI,
            capabilities: vec![ModelCapability::Embedding],
            owned_by: None,
            context_window: None,
            max_output_tokens: None,
            created_at: None,
        };
        assert!(!embedding_model.is_chat_capable());
    }

    #[test]
    fn test_parse_anthropic_response() {
        let json = r#"{
            "data": [
                {"id": "claude-sonnet-4-20250514", "display_name": "Claude Sonnet 4", "type": "model", "created_at": "2025-05-14T00:00:00Z"},
                {"id": "claude-3-5-haiku-20241022", "display_name": "Claude 3.5 Haiku", "type": "model"}
            ],
            "has_more": false
        }"#;
        let resp: AnthropicListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "claude-sonnet-4-20250514");
        assert!(!resp.has_more);
    }

    #[test]
    fn test_parse_openai_response() {
        let json = r#"{
            "data": [
                {"id": "gpt-4o", "owned_by": "openai", "created": 1715367049},
                {"id": "text-embedding-3-small", "owned_by": "openai", "created": 1705948997}
            ]
        }"#;
        let resp: OpenAIListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "gpt-4o");
    }

    #[test]
    fn test_parse_google_response() {
        let json = r#"{
            "models": [
                {
                    "name": "models/gemini-2.0-flash",
                    "displayName": "Gemini 2.0 Flash",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 8192,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                }
            ]
        }"#;
        let resp: GoogleListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.models.len(), 1);
        assert_eq!(resp.models[0].input_token_limit, Some(1048576));
    }

    #[test]
    fn test_parse_ollama_response() {
        let json = r#"{
            "models": [
                {"name": "llama3.1:latest", "modified_at": "2024-08-01T00:00:00Z", "size": 4000000000},
                {"name": "codellama:7b", "modified_at": "2024-07-15T00:00:00Z", "size": 3800000000}
            ]
        }"#;
        let resp: OllamaTagsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.models.len(), 2);
        assert_eq!(resp.models[0].name, "llama3.1:latest");
    }

    #[test]
    fn test_model_capability_display() {
        assert_eq!(ModelCapability::Chat.to_string(), "chat");
        assert_eq!(ModelCapability::ToolUse.to_string(), "tool_use");
        assert_eq!(ModelCapability::Vision.to_string(), "vision");
    }

    #[test]
    fn test_create_model_lister_no_key() {
        let result = create_model_lister(ProviderType::Anthropic, None, None);
        assert!(result.is_err());
        let err = result.map(|_| ()).unwrap_err();
        assert!(err.to_string().contains("API key"));
    }

    #[test]
    fn test_create_model_lister_ollama_no_key() {
        let result = create_model_lister(ProviderType::Ollama, None, None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_model_lister_brainwires_unsupported() {
        let result = create_model_lister(ProviderType::Brainwires, Some("key"), None);
        assert!(result.is_err());
    }
}