use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use super::ProviderType;
/// A capability that can be attributed to a model.
///
/// Serialized and deserialized with snake_case variant names
/// (for example `ToolUse` becomes `tool_use`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelCapability {
/// Chat capability; serialized as `chat`.
Chat,
/// Tool-use capability; serialized as `tool_use`.
ToolUse,
/// Vision capability; serialized as `vision`.
Vision,
/// Embedding capability; serialized as `embedding`.
Embedding,
/// Audio capability; serialized as `audio`.
Audio,
/// Image-generation capability; serialized as `image_generation`.
ImageGeneration,
}
/// Renders a capability as its lowercase snake_case label
/// (`chat`, `tool_use`, `vision`, `embedding`, `audio`, `image_generation`).
impl std::fmt::Display for ModelCapability {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first so the formatter is touched in exactly one place.
        let label = match self {
            Self::Chat => "chat",
            Self::ToolUse => "tool_use",
            Self::Vision => "vision",
            Self::Embedding => "embedding",
            Self::Audio => "audio",
            Self::ImageGeneration => "image_generation",
        };
        f.write_str(label)
    }
}
/// Provider-neutral description of a model, as returned by [`ModelLister::list_models`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AvailableModel {
/// Model identifier.
pub id: String,
/// Human-readable name, when one is available.
pub display_name: Option<String>,
/// Provider the model belongs to.
pub provider: ProviderType,
/// Capabilities attributed to the model.
pub capabilities: Vec<ModelCapability>,
/// Owner reported for the model, if any.
pub owned_by: Option<String>,
/// Context window size, if known (presumably in tokens — confirm).
pub context_window: Option<u32>,
/// Maximum number of output tokens, if known.
pub max_output_tokens: Option<u32>,
/// Creation time, if known (presumably a Unix timestamp — confirm).
pub created_at: Option<i64>,
}
impl AvailableModel {
    /// Returns `true` when [`ModelCapability::Chat`] is among this model's capabilities.
    pub fn is_chat_capable(&self) -> bool {
        self.capabilities
            .iter()
            .any(|capability| matches!(capability, ModelCapability::Chat))
    }
}
/// Something that can enumerate the models it has access to.
///
/// Implementors must be `Send + Sync`; the async method is provided through
/// the `async_trait` macro.
#[async_trait]
pub trait ModelLister: Send + Sync {
/// Returns the available models, or an error if they could not be listed.
async fn list_models(&self) -> Result<Vec<AvailableModel>>;
}
/// Infers a model's capabilities from its identifier alone.
///
/// The id is lowercased first, so matching is case-insensitive. Rules are
/// applied in order and the first match wins:
///
/// 1. id contains `embedding` → [`ModelCapability::Embedding`] only
/// 2. id starts with `whisper` or `tts` → [`ModelCapability::Audio`] only
/// 3. id starts with `dall-e` → [`ModelCapability::ImageGeneration`] only
/// 4. anything else → [`ModelCapability::Chat`] and [`ModelCapability::ToolUse`],
///    plus [`ModelCapability::Vision`] when the id contains `vision`, `gpt-4o`,
///    `gpt-4-turbo` or `gpt-5`, or is an o-series id (`o` immediately followed
///    by a digit, e.g. `o1`, `o3-mini`).
///
/// This is a name-based heuristic; it does not query the provider.
pub fn infer_openai_capabilities(model_id: &str) -> Vec<ModelCapability> {
    /// Substrings that mark a chat model as accepting image input.
    const VISION_MARKERS: [&str; 4] = ["vision", "gpt-4o", "gpt-4-turbo", "gpt-5"];

    let id = model_id.to_lowercase();

    // `contains` already covers the `text-embedding-*` prefix.
    if id.contains("embedding") {
        return vec![ModelCapability::Embedding];
    }
    if id.starts_with("whisper") || id.starts_with("tts") {
        return vec![ModelCapability::Audio];
    }
    if id.starts_with("dall-e") {
        return vec![ModelCapability::ImageGeneration];
    }

    // Only treat `o<digit>…` as o-series. A bare "starts with o" test would also
    // match unrelated ids such as `openchat-3.5` or `omni-moderation-latest`.
    let is_o_series = matches!(id.as_bytes(), [b'o', second, ..] if second.is_ascii_digit());
    let has_vision =
        is_o_series || VISION_MARKERS.iter().any(|marker| id.contains(*marker));

    let mut caps = vec![ModelCapability::Chat, ModelCapability::ToolUse];
    if has_vision {
        caps.push(ModelCapability::Vision);
    }
    caps
}
/// Builds a boxed [`ModelLister`] for the given provider.
///
/// # Arguments
///
/// * `provider_type` – which backend to build a lister for.
/// * `api_key` – required for every supported provider except Ollama.
/// * `base_url` – optional endpoint override. It is forwarded as-is for
///   OpenAI, OpenAI Responses and Ollama, and is not used for Anthropic or Google.
///
/// For Groq, Together, Fireworks and Anyscale the endpoint is chosen in this
/// order: `base_url`, then the registry entry's `models_url`, then
/// `https://api.openai.com/v1/models`.
///
/// NOTE(review): with that last fallback, a non-OpenAI provider's key would be
/// handed to a lister pointed at the OpenAI endpoint — confirm this is intended.
///
/// # Errors
///
/// Returns an error when a required API key is missing, or when the provider
/// does not support model listing through this function.
pub fn create_model_lister(
    provider_type: ProviderType,
    api_key: Option<&str>,
    base_url: Option<&str>,
) -> Result<Box<dyn ModelLister>> {
    let lister: Box<dyn ModelLister> = match provider_type {
        ProviderType::Anthropic => match api_key {
            Some(key) => Box::new(super::anthropic::AnthropicModelLister::new(
                key.to_owned(),
            )),
            None => return Err(anyhow::anyhow!("Anthropic requires an API key")),
        },
        ProviderType::OpenAI => match api_key {
            Some(key) => Box::new(super::openai_chat::OpenAIModelLister::new(
                key.to_owned(),
                base_url.map(String::from),
            )),
            None => return Err(anyhow::anyhow!("OpenAI requires an API key")),
        },
        ProviderType::Google => match api_key {
            Some(key) => Box::new(super::gemini::GoogleModelLister::new(key.to_owned())),
            None => return Err(anyhow::anyhow!("Google requires an API key")),
        },
        // These providers are all served by the OpenAI lister with a different endpoint.
        ProviderType::Groq
        | ProviderType::Together
        | ProviderType::Fireworks
        | ProviderType::Anyscale => {
            let owned_key = match api_key {
                Some(key) => key.to_owned(),
                None => {
                    return Err(anyhow::anyhow!("{} requires an API key", provider_type));
                }
            };
            let registry_url =
                super::registry::lookup(provider_type).and_then(|entry| entry.models_url);
            // Caller override first, then the registry, then the hard-coded default.
            let models_url = match base_url.or(registry_url) {
                Some(url) => url,
                None => "https://api.openai.com/v1/models",
            };
            Box::new(super::openai_chat::OpenAIModelLister::new(
                owned_key,
                Some(models_url.to_owned()),
            ))
        }
        // Ollama is the only supported provider that needs no API key.
        ProviderType::Ollama => Box::new(super::ollama::OllamaModelLister::new(
            base_url.map(String::from),
        )),
        ProviderType::OpenAiResponses => match api_key {
            Some(key) => Box::new(super::openai_chat::OpenAIModelLister::new(
                key.to_owned(),
                base_url.map(String::from),
            )),
            None => return Err(anyhow::anyhow!("OpenAI Responses requires an API key")),
        },
        ProviderType::Brainwires
        | ProviderType::Custom
        | ProviderType::MiniMax
        | ProviderType::Bedrock
        | ProviderType::VertexAI
        | ProviderType::ElevenLabs
        | ProviderType::Deepgram
        | ProviderType::Azure
        | ProviderType::Fish
        | ProviderType::Cartesia
        | ProviderType::Murf => {
            return Err(anyhow::anyhow!(
                "Model listing is not supported for {} provider via this interface",
                provider_type
            ));
        }
    };

    Ok(lister)
}
/// Deserialization target for an Anthropic model-list JSON payload.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicListResponse {
/// Model entries contained in the payload.
pub data: Vec<AnthropicModelEntry>,
/// Required `has_more` flag from the payload (presumably signals further pages — confirm).
pub has_more: bool,
/// Optional `last_id` value; `None` when the field is absent.
#[serde(default)]
pub last_id: Option<String>,
}
/// One model entry inside an [`AnthropicListResponse`].
#[derive(Debug, Deserialize)]
pub struct AnthropicModelEntry {
/// Model identifier, e.g. `claude-sonnet-4-20250514`.
pub id: String,
/// Human-readable model name (required in the payload).
pub display_name: String,
/// Value of the JSON `type` field, renamed because `type` is a Rust keyword.
#[serde(rename = "type")]
pub _type: Option<String>,
/// Creation timestamp kept as the raw string from the payload, if present.
pub created_at: Option<String>,
}
/// Deserialization target for an OpenAI-style model-list JSON payload.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAIListResponse {
/// Model entries contained in the payload.
pub data: Vec<OpenAIModelEntry>,
}
/// One model entry inside an [`OpenAIListResponse`].
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAIModelEntry {
/// Model identifier, e.g. `gpt-4o`.
pub id: String,
/// Owner reported for the model, if present.
pub owned_by: Option<String>,
/// Value of the `created` field, if present (presumably a Unix timestamp — confirm).
pub created: Option<i64>,
}
/// Deserialization target for a Google model-list JSON payload.
///
/// Field names are mapped from camelCase JSON keys (`nextPageToken`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub(crate) struct GoogleListResponse {
    /// Model entries; an absent `models` key deserializes to an empty list.
    #[serde(default)]
    pub models: Vec<GoogleModelEntry>,
    /// Value of `nextPageToken`, if present.
    pub next_page_token: Option<String>,
}
/// One model entry inside a [`GoogleListResponse`].
///
/// Field names are mapped from camelCase JSON keys (`displayName`,
/// `inputTokenLimit`, `outputTokenLimit`, `supportedGenerationMethods`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub(crate) struct GoogleModelEntry {
    /// Resource name of the model, e.g. `models/gemini-2.0-flash`.
    pub name: String,
    /// Human-readable model name, if present.
    pub display_name: Option<String>,
    /// Value of `inputTokenLimit`, if present.
    pub input_token_limit: Option<u32>,
    /// Value of `outputTokenLimit`, if present.
    pub output_token_limit: Option<u32>,
    /// Values of `supportedGenerationMethods`; empty when the key is absent.
    #[serde(default)]
    pub supported_generation_methods: Vec<String>,
}
/// Deserialization target for an Ollama tags JSON payload.
#[derive(Debug, Deserialize)]
pub(crate) struct OllamaTagsResponse {
/// Model entries contained in the payload (the `models` key is required).
pub models: Vec<OllamaModelEntry>,
}
/// One model entry inside an [`OllamaTagsResponse`].
#[derive(Debug, Deserialize)]
// NOTE(review): presumably not every deserialized field is read elsewhere — confirm.
#[allow(dead_code)]
pub(crate) struct OllamaModelEntry {
/// Model name including its tag, e.g. `llama3.1:latest`.
pub name: String,
/// Modification timestamp kept as the raw string from the payload, if present.
pub modified_at: Option<String>,
/// Value of the `size` field, if present (presumably bytes — confirm).
pub size: Option<u64>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `gpt-4o` is inferred as a chat model with tool use and vision.
    #[test]
    fn test_infer_openai_capabilities_chat() {
        let inferred = infer_openai_capabilities("gpt-4o");
        let expected = [
            ModelCapability::Chat,
            ModelCapability::ToolUse,
            ModelCapability::Vision,
        ];
        for capability in &expected {
            assert!(inferred.contains(capability));
        }
    }

    /// Embedding ids carry the embedding capability and no chat capability.
    #[test]
    fn test_infer_openai_capabilities_embedding() {
        let inferred = infer_openai_capabilities("text-embedding-3-small");
        assert!(inferred.iter().any(|c| *c == ModelCapability::Embedding));
        assert!(inferred.iter().all(|c| *c != ModelCapability::Chat));
    }

    /// Whisper ids carry the audio capability and no chat capability.
    #[test]
    fn test_infer_openai_capabilities_audio() {
        let inferred = infer_openai_capabilities("whisper-1");
        assert!(inferred.iter().any(|c| *c == ModelCapability::Audio));
        assert!(inferred.iter().all(|c| *c != ModelCapability::Chat));
    }

    /// DALL-E ids carry the image-generation capability and no chat capability.
    #[test]
    fn test_infer_openai_capabilities_image_gen() {
        let inferred = infer_openai_capabilities("dall-e-3");
        assert!(inferred
            .iter()
            .any(|c| *c == ModelCapability::ImageGeneration));
        assert!(inferred.iter().all(|c| *c != ModelCapability::Chat));
    }

    /// A plain chat model gets chat and tool use, but not vision.
    #[test]
    fn test_infer_openai_capabilities_basic_chat() {
        let inferred = infer_openai_capabilities("gpt-3.5-turbo");
        for capability in &[ModelCapability::Chat, ModelCapability::ToolUse] {
            assert!(inferred.contains(capability));
        }
        assert!(inferred.iter().all(|c| *c != ModelCapability::Vision));
    }

    /// `is_chat_capable` depends only on whether `Chat` is in `capabilities`.
    #[test]
    fn test_available_model_is_chat_capable() {
        let model_with = |id: &str, capability: ModelCapability| AvailableModel {
            id: id.to_string(),
            display_name: None,
            provider: ProviderType::OpenAI,
            capabilities: vec![capability],
            owned_by: None,
            context_window: None,
            max_output_tokens: None,
            created_at: None,
        };

        assert!(model_with("test", ModelCapability::Chat).is_chat_capable());
        assert!(!model_with("embed", ModelCapability::Embedding).is_chat_capable());
    }

    /// An Anthropic payload parses, including an entry without `created_at`.
    #[test]
    fn test_parse_anthropic_response() {
        let payload = r#"{
"data": [
{"id": "claude-sonnet-4-20250514", "display_name": "Claude Sonnet 4", "type": "model", "created_at": "2025-05-14T00:00:00Z"},
{"id": "claude-3-5-haiku-20241022", "display_name": "Claude 3.5 Haiku", "type": "model"}
],
"has_more": false
}"#;
        let parsed = serde_json::from_str::<AnthropicListResponse>(payload).unwrap();
        assert_eq!(parsed.data.len(), 2);
        assert_eq!(
            parsed.data.first().map(|entry| entry.id.as_str()),
            Some("claude-sonnet-4-20250514")
        );
        assert!(!parsed.has_more);
    }

    /// An OpenAI payload parses into its model entries.
    #[test]
    fn test_parse_openai_response() {
        let payload = r#"{
"data": [
{"id": "gpt-4o", "owned_by": "openai", "created": 1715367049},
{"id": "text-embedding-3-small", "owned_by": "openai", "created": 1705948997}
]
}"#;
        let parsed = serde_json::from_str::<OpenAIListResponse>(payload).unwrap();
        assert_eq!(parsed.data.len(), 2);
        assert_eq!(
            parsed.data.first().map(|entry| entry.id.as_str()),
            Some("gpt-4o")
        );
    }

    /// A Google payload parses and its camelCase keys map onto the struct fields.
    #[test]
    fn test_parse_google_response() {
        let payload = r#"{
"models": [
{
"name": "models/gemini-2.0-flash",
"displayName": "Gemini 2.0 Flash",
"inputTokenLimit": 1048576,
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"]
}
]
}"#;
        let parsed = serde_json::from_str::<GoogleListResponse>(payload).unwrap();
        assert_eq!(parsed.models.len(), 1);
        assert_eq!(
            parsed.models.first().and_then(|entry| entry.input_token_limit),
            Some(1_048_576)
        );
    }

    /// An Ollama tags payload parses into its model entries.
    #[test]
    fn test_parse_ollama_response() {
        let payload = r#"{
"models": [
{"name": "llama3.1:latest", "modified_at": "2024-08-01T00:00:00Z", "size": 4000000000},
{"name": "codellama:7b", "modified_at": "2024-07-15T00:00:00Z", "size": 3800000000}
]
}"#;
        let parsed = serde_json::from_str::<OllamaTagsResponse>(payload).unwrap();
        assert_eq!(parsed.models.len(), 2);
        assert_eq!(
            parsed.models.first().map(|entry| entry.name.as_str()),
            Some("llama3.1:latest")
        );
    }

    /// `Display` renders capabilities as snake_case labels.
    #[test]
    fn test_model_capability_display() {
        let cases = [
            (ModelCapability::Chat, "chat"),
            (ModelCapability::ToolUse, "tool_use"),
            (ModelCapability::Vision, "vision"),
        ];
        for (capability, label) in &cases {
            assert_eq!(capability.to_string(), *label);
        }
    }

    /// Anthropic without an API key is rejected with a message mentioning the key.
    #[test]
    fn test_create_model_lister_no_key() {
        let outcome = create_model_lister(ProviderType::Anthropic, None, None);
        assert!(outcome.is_err());
        // The boxed lister is not `Debug`, so read the message without `unwrap_err`.
        let message = outcome.err().map(|e| e.to_string()).unwrap_or_default();
        assert!(message.contains("API key"));
    }

    /// Ollama does not need an API key.
    #[test]
    fn test_create_model_lister_ollama_no_key() {
        assert!(create_model_lister(ProviderType::Ollama, None, None).is_ok());
    }

    /// Brainwires is not supported by this factory, even with a key.
    #[test]
    fn test_create_model_lister_brainwires_unsupported() {
        assert!(create_model_lister(ProviderType::Brainwires, Some("key"), None).is_err());
    }
}