Skip to main content

brainwires_provider/
model_listing.rs

1//! Model listing and validation for AI providers.
2//!
3//! Each provider implements [`ModelLister`] to query available models from its API.
4
5use anyhow::Result;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use super::ProviderType;
10
11/// Capabilities a model may support.
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum ModelCapability {
15    /// Text chat / completions.
16    Chat,
17    /// Tool / function calling.
18    ToolUse,
19    /// Image / vision understanding.
20    Vision,
21    /// Text embedding generation.
22    Embedding,
23    /// Audio processing.
24    Audio,
25    /// Image generation.
26    ImageGeneration,
27}
28
29impl std::fmt::Display for ModelCapability {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            Self::Chat => write!(f, "chat"),
33            Self::ToolUse => write!(f, "tool_use"),
34            Self::Vision => write!(f, "vision"),
35            Self::Embedding => write!(f, "embedding"),
36            Self::Audio => write!(f, "audio"),
37            Self::ImageGeneration => write!(f, "image_generation"),
38        }
39    }
40}
41
42/// A model available from a provider.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct AvailableModel {
45    /// Model identifier (e.g. "claude-sonnet-4-20250514", "gpt-4o").
46    pub id: String,
47    /// Human-readable name, if provided by the API.
48    pub display_name: Option<String>,
49    /// Which provider owns this model.
50    pub provider: ProviderType,
51    /// What the model can do.
52    pub capabilities: Vec<ModelCapability>,
53    /// Organization/owner string from the API.
54    pub owned_by: Option<String>,
55    /// Maximum input context window (tokens).
56    pub context_window: Option<u32>,
57    /// Maximum output tokens the model can produce.
58    pub max_output_tokens: Option<u32>,
59    /// Unix timestamp (seconds) when the model was created.
60    pub created_at: Option<i64>,
61}
62
63impl AvailableModel {
64    /// Whether this model supports chat completions.
65    pub fn is_chat_capable(&self) -> bool {
66        self.capabilities.contains(&ModelCapability::Chat)
67    }
68}
69
70/// Trait for querying a provider's model catalogue.
71#[async_trait]
72pub trait ModelLister: Send + Sync {
73    /// Fetch all models available for this provider.
74    async fn list_models(&self) -> Result<Vec<AvailableModel>>;
75}
76
77// ---------------------------------------------------------------------------
78// Shared helpers
79// ---------------------------------------------------------------------------
80
81/// Infer capabilities for an OpenAI-format model ID.
82///
83/// Shared by the OpenAI and Groq listers.
84pub fn infer_openai_capabilities(model_id: &str) -> Vec<ModelCapability> {
85    let id = model_id.to_lowercase();
86
87    // Embedding models
88    if id.contains("embedding") || id.starts_with("text-embedding") {
89        return vec![ModelCapability::Embedding];
90    }
91
92    // Audio models
93    if id.starts_with("whisper") || id.starts_with("tts") {
94        return vec![ModelCapability::Audio];
95    }
96
97    // Image generation
98    if id.starts_with("dall-e") {
99        return vec![ModelCapability::ImageGeneration];
100    }
101
102    // Chat-capable models get Chat + ToolUse by default
103    let mut caps = vec![ModelCapability::Chat, ModelCapability::ToolUse];
104
105    // Vision-capable models
106    if id.contains("vision")
107        || id.contains("gpt-4o")
108        || id.contains("gpt-4-turbo")
109        || id.contains("gpt-5")
110        || (id.starts_with("o") && !id.starts_with("omni"))
111    {
112        caps.push(ModelCapability::Vision);
113    }
114
115    caps
116}
117
118// ---------------------------------------------------------------------------
119// Factory
120// ---------------------------------------------------------------------------
121
122/// Create a [`ModelLister`] for the given provider.
123///
124/// * `api_key` — required for cloud providers, ignored for Ollama.
125/// * `base_url` — optional override (used for Ollama or custom endpoints).
126pub fn create_model_lister(
127    provider_type: ProviderType,
128    api_key: Option<&str>,
129    base_url: Option<&str>,
130) -> Result<Box<dyn ModelLister>> {
131    match provider_type {
132        ProviderType::Anthropic => {
133            let key = api_key
134                .ok_or_else(|| anyhow::anyhow!("Anthropic requires an API key"))?
135                .to_string();
136            Ok(Box::new(super::anthropic::AnthropicModelLister::new(key)))
137        }
138        ProviderType::OpenAI => {
139            let key = api_key
140                .ok_or_else(|| anyhow::anyhow!("OpenAI requires an API key"))?
141                .to_string();
142            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
143                key,
144                base_url.map(|s| s.to_string()),
145            )))
146        }
147        ProviderType::Google => {
148            let key = api_key
149                .ok_or_else(|| anyhow::anyhow!("Google requires an API key"))?
150                .to_string();
151            Ok(Box::new(super::gemini::GoogleModelLister::new(key)))
152        }
153        ProviderType::Groq
154        | ProviderType::Together
155        | ProviderType::Fireworks
156        | ProviderType::Anyscale => {
157            // All OpenAI-compatible: reuse OpenAI model lister with the registry's models URL
158            let key = api_key
159                .ok_or_else(|| anyhow::anyhow!("{} requires an API key", provider_type))?
160                .to_string();
161            let registry_url = super::registry::lookup(provider_type).and_then(|e| e.models_url);
162            let url = base_url
163                .or(registry_url)
164                .unwrap_or("https://api.openai.com/v1/models");
165            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
166                key,
167                Some(url.to_string()),
168            )))
169        }
170        ProviderType::Ollama => Ok(Box::new(super::ollama::OllamaModelLister::new(
171            base_url.map(|s| s.to_string()),
172        ))),
173        ProviderType::OpenAiResponses => {
174            // Shares the same models endpoint as OpenAI Chat Completions
175            let key = api_key
176                .ok_or_else(|| anyhow::anyhow!("OpenAI Responses requires an API key"))?
177                .to_string();
178            Ok(Box::new(super::openai_chat::OpenAIModelLister::new(
179                key,
180                base_url.map(|s| s.to_string()),
181            )))
182        }
183        ProviderType::Brainwires
184        | ProviderType::Custom
185        | ProviderType::MiniMax
186        | ProviderType::Bedrock
187        | ProviderType::VertexAI
188        | ProviderType::ElevenLabs
189        | ProviderType::Deepgram
190        | ProviderType::Azure
191        | ProviderType::Fish
192        | ProviderType::Cartesia
193        | ProviderType::Murf => Err(anyhow::anyhow!(
194            "Model listing is not supported for {} provider via this interface",
195            provider_type
196        )),
197    }
198}
199
200// ---------------------------------------------------------------------------
201// Response types shared across listers
202// ---------------------------------------------------------------------------
203
204/// Anthropic `/v1/models` list response.
205#[derive(Debug, Deserialize)]
206pub(crate) struct AnthropicListResponse {
207    pub data: Vec<AnthropicModelEntry>,
208    pub has_more: bool,
209    #[serde(default)]
210    pub last_id: Option<String>,
211}
212
213/// A model entry from the Anthropic API.
214#[derive(Debug, Deserialize)]
215pub struct AnthropicModelEntry {
216    /// Model identifier (e.g. `"claude-sonnet-4-20250514"`).
217    pub id: String,
218    /// Human-readable name.
219    pub display_name: String,
220    /// Resource type (always `"model"`).
221    #[serde(rename = "type")]
222    pub _type: Option<String>,
223    /// ISO 8601 creation timestamp.
224    pub created_at: Option<String>,
225}
226
227/// OpenAI `/v1/models` response.
228#[derive(Debug, Deserialize)]
229pub(crate) struct OpenAIListResponse {
230    pub data: Vec<OpenAIModelEntry>,
231}
232
233#[derive(Debug, Deserialize)]
234pub(crate) struct OpenAIModelEntry {
235    pub id: String,
236    pub owned_by: Option<String>,
237    pub created: Option<i64>,
238}
239
240/// Google `models` list response.
241#[derive(Debug, Deserialize)]
242#[allow(dead_code)]
243pub(crate) struct GoogleListResponse {
244    #[serde(default)]
245    pub models: Vec<GoogleModelEntry>,
246    #[serde(rename = "nextPageToken")]
247    pub next_page_token: Option<String>,
248}
249
250#[derive(Debug, Deserialize)]
251#[allow(dead_code)]
252pub(crate) struct GoogleModelEntry {
253    /// e.g. "models/gemini-2.0-flash"
254    pub name: String,
255    #[serde(rename = "displayName")]
256    pub display_name: Option<String>,
257    #[serde(rename = "inputTokenLimit")]
258    pub input_token_limit: Option<u32>,
259    #[serde(rename = "outputTokenLimit")]
260    pub output_token_limit: Option<u32>,
261    #[serde(rename = "supportedGenerationMethods", default)]
262    pub supported_generation_methods: Vec<String>,
263}
264
265/// Ollama `/api/tags` response.
266#[derive(Debug, Deserialize)]
267pub(crate) struct OllamaTagsResponse {
268    pub models: Vec<OllamaModelEntry>,
269}
270
271#[derive(Debug, Deserialize)]
272#[allow(dead_code)]
273pub(crate) struct OllamaModelEntry {
274    pub name: String,
275    pub modified_at: Option<String>,
276    pub size: Option<u64>,
277}
278
279// ---------------------------------------------------------------------------
280// Tests
281// ---------------------------------------------------------------------------
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `AvailableModel` with only the fields these tests care about.
    fn model_with(id: &str, capabilities: Vec<ModelCapability>) -> AvailableModel {
        AvailableModel {
            id: id.to_string(),
            display_name: None,
            provider: ProviderType::OpenAI,
            capabilities,
            owned_by: None,
            context_window: None,
            max_output_tokens: None,
            created_at: None,
        }
    }

    #[test]
    fn test_infer_openai_capabilities_chat() {
        let caps = infer_openai_capabilities("gpt-4o");
        assert!(caps.contains(&ModelCapability::Chat));
        assert!(caps.contains(&ModelCapability::ToolUse));
        assert!(caps.contains(&ModelCapability::Vision));
    }

    #[test]
    fn test_infer_openai_capabilities_embedding() {
        let caps = infer_openai_capabilities("text-embedding-3-small");
        assert!(caps.contains(&ModelCapability::Embedding));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_audio() {
        let caps = infer_openai_capabilities("whisper-1");
        assert!(caps.contains(&ModelCapability::Audio));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_image_gen() {
        let caps = infer_openai_capabilities("dall-e-3");
        assert!(caps.contains(&ModelCapability::ImageGeneration));
        assert!(!caps.contains(&ModelCapability::Chat));
    }

    #[test]
    fn test_infer_openai_capabilities_basic_chat() {
        let caps = infer_openai_capabilities("gpt-3.5-turbo");
        assert!(caps.contains(&ModelCapability::Chat));
        assert!(caps.contains(&ModelCapability::ToolUse));
        assert!(!caps.contains(&ModelCapability::Vision));
    }

    #[test]
    fn test_available_model_is_chat_capable() {
        assert!(model_with("test", vec![ModelCapability::Chat]).is_chat_capable());
        assert!(!model_with("embed", vec![ModelCapability::Embedding]).is_chat_capable());
        assert!(!model_with("empty", Vec::new()).is_chat_capable());
    }

    #[test]
    fn test_parse_anthropic_response() {
        let json = r#"{
            "data": [
                {"id": "claude-sonnet-4-20250514", "display_name": "Claude Sonnet 4", "type": "model", "created_at": "2025-05-14T00:00:00Z"},
                {"id": "claude-3-5-haiku-20241022", "display_name": "Claude 3.5 Haiku", "type": "model"}
            ],
            "has_more": false
        }"#;
        let resp: AnthropicListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "claude-sonnet-4-20250514");
        assert!(!resp.has_more);
        // Optional fields missing from the payload deserialize as `None`.
        assert!(resp.last_id.is_none());
        assert!(resp.data[1].created_at.is_none());
    }

    #[test]
    fn test_parse_openai_response() {
        let json = r#"{
            "data": [
                {"id": "gpt-4o", "owned_by": "openai", "created": 1715367049},
                {"id": "text-embedding-3-small", "owned_by": "openai", "created": 1705948997}
            ]
        }"#;
        let resp: OpenAIListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.data.len(), 2);
        assert_eq!(resp.data[0].id, "gpt-4o");
        assert_eq!(resp.data[0].owned_by.as_deref(), Some("openai"));
        assert_eq!(resp.data[0].created, Some(1715367049));
    }

    #[test]
    fn test_parse_google_response() {
        let json = r#"{
            "models": [
                {
                    "name": "models/gemini-2.0-flash",
                    "displayName": "Gemini 2.0 Flash",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 8192,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                }
            ]
        }"#;
        let resp: GoogleListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.models.len(), 1);
        assert_eq!(resp.models[0].input_token_limit, Some(1048576));
        assert_eq!(resp.models[0].output_token_limit, Some(8192));
        assert_eq!(resp.models[0].supported_generation_methods.len(), 2);
        assert!(resp.next_page_token.is_none());
    }

    #[test]
    fn test_parse_ollama_response() {
        let json = r#"{
            "models": [
                {"name": "llama3.1:latest", "modified_at": "2024-08-01T00:00:00Z", "size": 4000000000},
                {"name": "codellama:7b", "modified_at": "2024-07-15T00:00:00Z", "size": 3800000000}
            ]
        }"#;
        let resp: OllamaTagsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.models.len(), 2);
        assert_eq!(resp.models[0].name, "llama3.1:latest");
        assert_eq!(resp.models[0].size, Some(4000000000));
    }

    #[test]
    fn test_model_capability_display() {
        let cases = [
            (ModelCapability::Chat, "chat"),
            (ModelCapability::ToolUse, "tool_use"),
            (ModelCapability::Vision, "vision"),
            (ModelCapability::Embedding, "embedding"),
            (ModelCapability::Audio, "audio"),
            (ModelCapability::ImageGeneration, "image_generation"),
        ];
        for (capability, name) in cases.iter() {
            assert_eq!(capability.to_string(), *name);
            // Display and serde must agree so logs and serialized data use the same names.
            let quoted = format!("\"{}\"", name);
            assert_eq!(serde_json::to_string(capability).unwrap(), quoted);
            let parsed: ModelCapability = serde_json::from_str(&quoted).unwrap();
            assert_eq!(parsed, *capability);
        }
    }

    #[test]
    fn test_create_model_lister_no_key() {
        let result = create_model_lister(ProviderType::Anthropic, None, None);
        assert!(result.is_err());
        // `Box<dyn ModelLister>` is not `Debug`, so map it away before `unwrap_err`.
        let err = result.map(|_| ()).unwrap_err();
        assert!(err.to_string().contains("API key"));
    }

    #[test]
    fn test_create_model_lister_cloud_providers_require_key() {
        let providers = [
            ProviderType::OpenAI,
            ProviderType::Google,
            ProviderType::OpenAiResponses,
            ProviderType::Groq,
            ProviderType::Together,
            ProviderType::Fireworks,
            ProviderType::Anyscale,
        ];
        for provider in providers {
            let err = create_model_lister(provider, None, None)
                .map(|_| ())
                .unwrap_err();
            assert!(err.to_string().contains("API key"), "unexpected error: {}", err);
        }
    }

    #[test]
    fn test_create_model_lister_ollama_no_key() {
        let result = create_model_lister(ProviderType::Ollama, None, None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_model_lister_brainwires_unsupported() {
        let result = create_model_lister(ProviderType::Brainwires, Some("key"), None);
        assert!(result.is_err());
        let err = result.map(|_| ()).unwrap_err();
        assert!(err.to_string().contains("not supported"));
    }
}