// brainwires_providers/model_listing.rs

1//! Model listing and validation for AI providers.
2//!
3//! Each provider implements [`ModelLister`] to query available models from its API.
4
5use anyhow::Result;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use super::ProviderType;
10
11/// Capabilities a model may support.
/// Capabilities a model may support.
///
/// Serialized in `snake_case` (e.g. `tool_use`), which matches the
/// [`std::fmt::Display`] output defined below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelCapability {
    /// Text chat / completions.
    Chat,
    /// Tool / function calling.
    ToolUse,
    /// Image / vision understanding.
    Vision,
    /// Text embedding generation.
    Embedding,
    /// Audio processing (e.g. speech-to-text / text-to-speech models).
    Audio,
    /// Image generation.
    ImageGeneration,
}
28
29impl std::fmt::Display for ModelCapability {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            Self::Chat => write!(f, "chat"),
33            Self::ToolUse => write!(f, "tool_use"),
34            Self::Vision => write!(f, "vision"),
35            Self::Embedding => write!(f, "embedding"),
36            Self::Audio => write!(f, "audio"),
37            Self::ImageGeneration => write!(f, "image_generation"),
38        }
39    }
40}
41
/// A model available from a provider.
///
/// Produced by [`ModelLister::list_models`]; optional fields are `None` when
/// the provider's API does not report them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AvailableModel {
    /// Model identifier (e.g. "claude-sonnet-4-20250514", "gpt-4o").
    pub id: String,
    /// Human-readable name, if provided by the API.
    pub display_name: Option<String>,
    /// Which provider owns this model.
    pub provider: ProviderType,
    /// What the model can do.
    pub capabilities: Vec<ModelCapability>,
    /// Organization/owner string from the API.
    pub owned_by: Option<String>,
    /// Maximum input context window (tokens).
    pub context_window: Option<u32>,
    /// Maximum output tokens the model can produce.
    pub max_output_tokens: Option<u32>,
    /// Unix timestamp (seconds) when the model was created.
    pub created_at: Option<i64>,
}
62
63impl AvailableModel {
64    /// Whether this model supports chat completions.
65    pub fn is_chat_capable(&self) -> bool {
66        self.capabilities.contains(&ModelCapability::Chat)
67    }
68}
69
/// Trait for querying a provider's model catalogue.
#[async_trait]
pub trait ModelLister: Send + Sync {
    /// Fetch all models available for this provider.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider API cannot be reached or returns an
    /// unexpected response.
    async fn list_models(&self) -> Result<Vec<AvailableModel>>;
}
76
77// ---------------------------------------------------------------------------
78// Shared helpers
79// ---------------------------------------------------------------------------
80
81/// Infer capabilities for an OpenAI-format model ID.
82///
83/// Shared by the OpenAI and Groq listers.
84pub fn infer_openai_capabilities(model_id: &str) -> Vec<ModelCapability> {
85    let id = model_id.to_lowercase();
86
87    // Embedding models
88    if id.contains("embedding") || id.starts_with("text-embedding") {
89        return vec![ModelCapability::Embedding];
90    }
91
92    // Audio models
93    if id.starts_with("whisper") || id.starts_with("tts") {
94        return vec![ModelCapability::Audio];
95    }
96
97    // Image generation
98    if id.starts_with("dall-e") {
99        return vec![ModelCapability::ImageGeneration];
100    }
101
102    // Chat-capable models get Chat + ToolUse by default
103    let mut caps = vec![ModelCapability::Chat, ModelCapability::ToolUse];
104
105    // Vision-capable models
106    if id.contains("vision")
107        || id.contains("gpt-4o")
108        || id.contains("gpt-4-turbo")
109        || id.contains("gpt-5")
110        || (id.starts_with("o") && !id.starts_with("omni"))
111    {
112        caps.push(ModelCapability::Vision);
113    }
114
115    caps
116}
117
118// ---------------------------------------------------------------------------
119// Factory
120// ---------------------------------------------------------------------------
121
122/// Create a [`ModelLister`] for the given provider.
123///
124/// * `api_key` — required for cloud providers, ignored for Ollama.
125/// * `base_url` — optional override (used for Ollama or custom endpoints).
126pub fn create_model_lister(
127    provider_type: ProviderType,
128    api_key: Option<&str>,
129    base_url: Option<&str>,
130) -> Result<Box<dyn ModelLister>> {
131    match provider_type {
132        ProviderType::Anthropic => {
133            let key = api_key
134                .ok_or_else(|| anyhow::anyhow!("Anthropic requires an API key"))?
135                .to_string();
136            Ok(Box::new(super::anthropic::AnthropicModelLister::new(key)))
137        }
138        ProviderType::OpenAI => {
139            let key = api_key
140                .ok_or_else(|| anyhow::anyhow!("OpenAI requires an API key"))?
141                .to_string();
142            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
143                key,
144                base_url.map(|s| s.to_string()),
145            )))
146        }
147        ProviderType::Google => {
148            let key = api_key
149                .ok_or_else(|| anyhow::anyhow!("Google requires an API key"))?
150                .to_string();
151            Ok(Box::new(super::gemini::GoogleModelLister::new(key)))
152        }
153        ProviderType::Groq
154        | ProviderType::Together
155        | ProviderType::Fireworks
156        | ProviderType::Anyscale => {
157            // All OpenAI-compatible: reuse OpenAI model lister with the registry's models URL
158            let key = api_key
159                .ok_or_else(|| anyhow::anyhow!("{} requires an API key", provider_type))?
160                .to_string();
161            let registry_url = super::registry::lookup(provider_type).and_then(|e| e.models_url);
162            let url = base_url
163                .or(registry_url)
164                .unwrap_or("https://api.openai.com/v1/models");
165            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
166                key,
167                Some(url.to_string()),
168            )))
169        }
170        ProviderType::Ollama => Ok(Box::new(super::ollama::OllamaModelLister::new(
171            base_url.map(|s| s.to_string()),
172        ))),
173        ProviderType::OpenAiResponses => {
174            // Shares the same models endpoint as OpenAI Chat Completions
175            let key = api_key
176                .ok_or_else(|| anyhow::anyhow!("OpenAI Responses requires an API key"))?
177                .to_string();
178            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
179                key,
180                base_url.map(|s| s.to_string()),
181            )))
182        }
183        ProviderType::Brainwires
184        | ProviderType::Custom
185        | ProviderType::Bedrock
186        | ProviderType::VertexAI
187        | ProviderType::ElevenLabs
188        | ProviderType::Deepgram
189        | ProviderType::Azure
190        | ProviderType::Fish
191        | ProviderType::Cartesia
192        | ProviderType::Murf => Err(anyhow::anyhow!(
193            "Model listing is not supported for {} provider via this interface",
194            provider_type
195        )),
196    }
197}
198
199// ---------------------------------------------------------------------------
200// Response types shared across listers
201// ---------------------------------------------------------------------------
202
/// Anthropic `/v1/models` list response.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicListResponse {
    /// Page of model entries.
    pub data: Vec<AnthropicModelEntry>,
    /// Whether more pages are available.
    pub has_more: bool,
    /// Cursor of the last entry in this page. NOTE(review): presumably used
    /// as the pagination cursor for the next request — confirm in the
    /// Anthropic lister.
    #[serde(default)]
    pub last_id: Option<String>,
}
211
/// A model entry from the Anthropic API.
#[derive(Debug, Deserialize)]
pub struct AnthropicModelEntry {
    /// Model identifier (e.g. `"claude-sonnet-4-20250514"`).
    pub id: String,
    /// Human-readable name.
    pub display_name: String,
    /// Resource type (always `"model"`). Kept only so deserialization
    /// tolerates the field; unused otherwise.
    #[serde(rename = "type")]
    pub _type: Option<String>,
    /// ISO 8601 creation timestamp.
    pub created_at: Option<String>,
}
225
/// OpenAI `/v1/models` response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAIListResponse {
    /// All models visible to the API key.
    pub data: Vec<OpenAIModelEntry>,
}
231
/// A model entry from the OpenAI API.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAIModelEntry {
    /// Model identifier (e.g. `"gpt-4o"`).
    pub id: String,
    /// Owning organization string (e.g. `"openai"`).
    pub owned_by: Option<String>,
    /// Unix timestamp (seconds) when the model was created.
    pub created: Option<i64>,
}
238
/// Google `models` list response.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct GoogleListResponse {
    /// Page of model entries (absent field deserializes to empty).
    #[serde(default)]
    pub models: Vec<GoogleModelEntry>,
    /// Token for fetching the next page, if any.
    #[serde(rename = "nextPageToken")]
    pub next_page_token: Option<String>,
}
248
249#[derive(Debug, Deserialize)]
250#[allow(dead_code)]
251pub(crate) struct GoogleModelEntry {
252    /// e.g. "models/gemini-2.0-flash"
253    pub name: String,
254    #[serde(rename = "displayName")]
255    pub display_name: Option<String>,
256    #[serde(rename = "inputTokenLimit")]
257    pub input_token_limit: Option<u32>,
258    #[serde(rename = "outputTokenLimit")]
259    pub output_token_limit: Option<u32>,
260    #[serde(rename = "supportedGenerationMethods", default)]
261    pub supported_generation_methods: Vec<String>,
262}
263
/// Ollama `/api/tags` response.
#[derive(Debug, Deserialize)]
pub(crate) struct OllamaTagsResponse {
    /// Locally installed models.
    pub models: Vec<OllamaModelEntry>,
}
269
/// A model entry from the Ollama tags API.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct OllamaModelEntry {
    /// Model name including tag (e.g. `"llama3.1:latest"`).
    pub name: String,
    /// Last-modified timestamp string from the API.
    pub modified_at: Option<String>,
    /// On-disk size in bytes.
    pub size: Option<u64>,
}
277
278// ---------------------------------------------------------------------------
279// Tests
280// ---------------------------------------------------------------------------
281
#[cfg(test)]
mod tests {
    use super::*;

    // --- Capability inference heuristics ---------------------------------

    #[test]
    fn test_infer_openai_capabilities_chat() {
        let caps = infer_openai_capabilities("gpt-4o");
        assert!(caps.contains(&ModelCapability::Chat));
        assert!(caps.contains(&ModelCapability::ToolUse));
        assert!(caps.contains(&ModelCapability::Vision));
    }

    #[test]
    fn test_infer_openai_capabilities_embedding() {
        // Embedding models must NOT be marked chat-capable.
        let caps = infer_openai_capabilities("text-embedding-3-small");
        assert!(caps.contains(&ModelCapability::Embedding));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_audio() {
        let caps = infer_openai_capabilities("whisper-1");
        assert!(caps.contains(&ModelCapability::Audio));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_image_gen() {
        let caps = infer_openai_capabilities("dall-e-3");
        assert!(caps.contains(&ModelCapability::ImageGeneration));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_basic_chat() {
        // A plain chat model gets Chat + ToolUse but not Vision.
        let caps = infer_openai_capabilities("gpt-3.5-turbo");
        assert!(caps.contains(&ModelCapability::Chat));
        assert!(caps.contains(&ModelCapability::ToolUse));
        assert!(!caps.contains(&ModelCapability::Vision));
    }

    #[test]
    fn test_available_model_is_chat_capable() {
        let model = AvailableModel {
            id: "test".to_string(),
            display_name: None,
            provider: ProviderType::OpenAI,
            capabilities: vec![ModelCapability::Chat],
            owned_by: None,
            context_window: None,
            max_output_tokens: None,
            created_at: None,
        };
        assert!(model.is_chat_capable());

        let embedding_model = AvailableModel {
            id: "embed".to_string(),
            display_name: None,
            provider: ProviderType::OpenAI,
            capabilities: vec![ModelCapability::Embedding],
            owned_by: None,
            context_window: None,
            max_output_tokens: None,
            created_at: None,
        };
        assert!(!embedding_model.is_chat_capable());
    }

    // --- Response-type deserialization -----------------------------------

    #[test]
    fn test_parse_anthropic_response() {
        // Second entry deliberately omits optional fields (type/created_at).
        let json = r#"{
            "data": [
                {"id": "claude-sonnet-4-20250514", "display_name": "Claude Sonnet 4", "type": "model", "created_at": "2025-05-14T00:00:00Z"},
                {"id": "claude-3-5-haiku-20241022", "display_name": "Claude 3.5 Haiku", "type": "model"}
            ],
            "has_more": false
        }"#;
        let resp: AnthropicListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "claude-sonnet-4-20250514");
        assert!(!resp.has_more);
    }

    #[test]
    fn test_parse_openai_response() {
        let json = r#"{
            "data": [
                {"id": "gpt-4o", "owned_by": "openai", "created": 1715367049},
                {"id": "text-embedding-3-small", "owned_by": "openai", "created": 1705948997}
            ]
        }"#;
        let resp: OpenAIListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "gpt-4o");
    }

    #[test]
    fn test_parse_google_response() {
        // Exercises the camelCase field renames.
        let json = r#"{
            "models": [
                {
                    "name": "models/gemini-2.0-flash",
                    "displayName": "Gemini 2.0 Flash",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 8192,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                }
            ]
        }"#;
        let resp: GoogleListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.models.len(), 1);
        assert_eq!(resp.models[0].input_token_limit, Some(1048576));
    }

    #[test]
    fn test_parse_ollama_response() {
        let json = r#"{
            "models": [
                {"name": "llama3.1:latest", "modified_at": "2024-08-01T00:00:00Z", "size": 4000000000},
                {"name": "codellama:7b", "modified_at": "2024-07-15T00:00:00Z", "size": 3800000000}
            ]
        }"#;
        let resp: OllamaTagsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.models.len(), 2);
        assert_eq!(resp.models[0].name, "llama3.1:latest");
    }

    #[test]
    fn test_model_capability_display() {
        assert_eq!(ModelCapability::Chat.to_string(), "chat");
        assert_eq!(ModelCapability::ToolUse.to_string(), "tool_use");
        assert_eq!(ModelCapability::Vision.to_string(), "vision");
    }

    // --- Factory ----------------------------------------------------------

    #[test]
    fn test_create_model_lister_no_key() {
        let result = create_model_lister(ProviderType::Anthropic, None, None);
        assert!(result.is_err());
        // map to () first: Box<dyn ModelLister> is not Debug, so unwrap_err()
        // on the original Result would not compile.
        let err = result.map(|_| ()).unwrap_err();
        assert!(err.to_string().contains("API key"));
    }

    #[test]
    fn test_create_model_lister_ollama_no_key() {
        let result = create_model_lister(ProviderType::Ollama, None, None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_model_lister_brainwires_unsupported() {
        let result = create_model_lister(ProviderType::Brainwires, Some("key"), None);
        assert!(result.is_err());
    }
}