use serde::{Deserialize, Serialize};
use crate::BoxFuture;
use crate::agents::error::AgentError;
/// Discrete effort setting that a model may accept.
///
/// The levels a particular model supports are listed in
/// [`ModelInfo::supported_effort_levels`]. Variants are declared from `Low` to
/// `Max`; the precise meaning of each level is not defined in this module.
///
/// No serde attributes are applied, so variants serialize under their Rust
/// names — renaming a variant changes the wire format.
///
/// `Hash` is derived so levels can be used as `HashSet`/`HashMap` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EffortLevel {
    /// Lowest declared level.
    Low,
    /// Second level.
    Medium,
    /// Third level.
    High,
    /// Highest declared level.
    Max,
}
/// Thinking-mode configuration.
///
/// This type is not referenced elsewhere in this module. The variant
/// descriptions below are inferred from their names — NOTE(review): confirm
/// against the code that consumes it.
///
/// `PartialEq`/`Eq` are derived (matching [`EffortLevel`]) so configurations
/// can be compared directly. `Copy` is intentionally not derived: the enum is
/// `#[non_exhaustive]` and a future variant might carry non-`Copy` data.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ThinkingConfig {
    /// Presumably leaves the amount of thinking to the model/provider.
    Adaptive,
    /// Thinking explicitly enabled.
    Enabled {
        /// Token budget for thinking (presumably an upper bound — confirm).
        budget_tokens: u32,
    },
    /// Thinking explicitly disabled.
    Disabled,
}
/// Feature flags describing what a model supports.
///
/// `Default` yields every flag set to `false`. [`ModelSelector`] filters on
/// these flags through its `with_*` methods.
///
/// `PartialEq`/`Eq` are derived so capability sets can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
// Each flag is an independent yes/no capability that is queried on its own;
// plain bools are the clearest representation, so the pedantic lint is silenced.
#[allow(clippy::struct_excessive_bools)]
pub struct ModelCapabilities {
    /// Supports tool/function calling.
    pub tool_calling: bool,
    /// Supports vision (image) input.
    pub vision: bool,
    /// Supports streamed responses.
    pub streaming: bool,
    /// Supports structured output.
    pub structured_output: bool,
    /// Accepts an effort level; see [`ModelInfo::supported_effort_levels`].
    pub effort_levels: bool,
}
/// Metadata describing one model.
///
/// [`ModelSelector`] runs its queries over slices of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Model identifier. `ModelSelector::by_name` compares it for exact equality
/// and `ModelSelector::by_provider` tests it with `starts_with`. Uniqueness is
/// not enforced here; `by_name` returns the first match. The tests in this file
/// use a `"provider/model"` shape — NOTE(review): confirm that is canonical.
pub id: String,
/// Human-readable name.
pub display_name: String,
/// Free-form description text.
pub description: String,
/// Feature flags; see [`ModelCapabilities`].
pub capabilities: ModelCapabilities,
/// Context window size. `ModelSelector::by_min_context` filters on it with an
/// inclusive `>=` against its `min_tokens` argument, so presumably measured in
/// tokens — confirm with whatever populates it.
pub context_window: u32,
/// Output size limit; presumably in tokens, per the field name — confirm.
pub max_output_tokens: u32,
/// Effort levels this model accepts. NOTE(review): presumably only meaningful
/// when `capabilities.effort_levels` is `true`; nothing here enforces that.
pub supported_effort_levels: Vec<EffortLevel>,
}
/// Asynchronous source of [`ModelInfo`] entries.
///
/// The `Send + Sync` supertraits allow an implementation to be shared between
/// threads.
pub trait ModelProvider: Send + Sync {
/// Lists the models this provider reports.
///
/// The `BoxFuture` lifetime parameter is the elided `'_` of `&self`, so
/// (assuming it bounds the boxed future — confirm in `crate::BoxFuture`) the
/// future may borrow the provider while it runs.
///
/// # Errors
///
/// The future resolves to `Err(AgentError)` when listing fails; which failure
/// cases exist is up to each implementation.
fn list_models(&self) -> BoxFuture<'_, Result<Vec<ModelInfo>, AgentError>>;
}
/// Read-only query helper over a borrowed slice of [`ModelInfo`].
///
/// Every query preserves the order of the underlying slice. Results borrow
/// from the slice (lifetime `'a`), not from the selector, so they remain
/// usable after a temporary selector is dropped, e.g.
/// `let hit = ModelSelector::new(&models).by_name("openai/gpt-4o");`.
///
/// The selector is just a slice reference, so it is cheap to `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct ModelSelector<'a> {
    models: &'a [ModelInfo],
}

impl<'a> ModelSelector<'a> {
    /// Wraps `models` without copying or allocating.
    #[must_use]
    pub const fn new(models: &'a [ModelInfo]) -> Self {
        Self { models }
    }

    /// Returns the first model whose `id` equals `id` exactly, or `None`.
    #[must_use]
    pub fn by_name(&self, id: &str) -> Option<&'a ModelInfo> {
        self.models.iter().find(|m| m.id == id)
    }

    /// Returns every model whose `id` starts with `prefix`.
    ///
    /// This is a raw string-prefix test, not a parse of a provider segment:
    /// `"open"` matches both `"openai/..."` and `"openrouter/..."`. Include the
    /// separator (e.g. `"openai/"`) to select a single provider. An empty
    /// prefix matches every model.
    #[must_use]
    pub fn by_provider(&self, prefix: &str) -> Vec<&'a ModelInfo> {
        self.matching(|m| m.id.starts_with(prefix))
    }

    /// Returns the models whose capabilities include tool calling.
    #[must_use]
    pub fn with_tool_calling(&self) -> Vec<&'a ModelInfo> {
        self.matching(|m| m.capabilities.tool_calling)
    }

    /// Returns the models whose capabilities include vision.
    #[must_use]
    pub fn with_vision(&self) -> Vec<&'a ModelInfo> {
        self.matching(|m| m.capabilities.vision)
    }

    /// Returns the models whose capabilities include streaming.
    #[must_use]
    pub fn with_streaming(&self) -> Vec<&'a ModelInfo> {
        self.matching(|m| m.capabilities.streaming)
    }

    /// Returns the models whose capabilities include structured output.
    #[must_use]
    pub fn with_structured_output(&self) -> Vec<&'a ModelInfo> {
        self.matching(|m| m.capabilities.structured_output)
    }

    /// Returns the models whose `effort_levels` capability flag is set.
    #[must_use]
    pub fn with_effort_levels(&self) -> Vec<&'a ModelInfo> {
        self.matching(|m| m.capabilities.effort_levels)
    }

    /// Returns the models whose `context_window` is at least `min_tokens`
    /// (inclusive).
    #[must_use]
    pub fn by_min_context(&self, min_tokens: u32) -> Vec<&'a ModelInfo> {
        self.matching(|m| m.context_window >= min_tokens)
    }

    /// Collects, in slice order, every model for which `pred` returns `true`.
    ///
    /// Shared by all list-returning queries so they differ only in predicate.
    fn matching(&self, pred: impl Fn(&ModelInfo) -> bool) -> Vec<&'a ModelInfo> {
        self.models.iter().filter(|&m| pred(m)).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture with the given tool-calling/vision flags. Every other
    /// field is fixed: streaming on, structured output and effort levels off,
    /// a context window of 100_000 and 4096 max output tokens.
    fn make_model(id: &str, tool_calling: bool, vision: bool) -> ModelInfo {
        ModelInfo {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            capabilities: ModelCapabilities {
                tool_calling,
                vision,
                streaming: true,
                structured_output: false,
                effort_levels: false,
            },
            context_window: 100_000,
            max_output_tokens: 4096,
            supported_effort_levels: Vec::new(),
        }
    }

    /// Like `make_model` (no tool calling, no vision) but with a custom context window.
    fn make_model_with_context(id: &str, context_window: u32) -> ModelInfo {
        ModelInfo {
            context_window,
            ..make_model(id, false, false)
        }
    }

    /// Projects selector results onto their ids so assertions also pin the order.
    fn ids<'m>(models: &[&'m ModelInfo]) -> Vec<&'m str> {
        models.iter().map(|&m| m.id.as_str()).collect()
    }

    #[test]
    fn test_selector_by_name() {
        let models = [
            make_model("anthropic/claude-3-5-sonnet", true, true),
            make_model("openai/gpt-4o", true, false),
        ];
        let sel = ModelSelector::new(&models);
        assert_eq!(
            sel.by_name("openai/gpt-4o").map(|m| m.id.as_str()),
            Some("openai/gpt-4o")
        );
        // Matching is exact: a prefix of an existing id is not a hit.
        assert!(sel.by_name("openai/").is_none());
        assert!(sel.by_name("nonexistent").is_none());
    }

    #[test]
    fn test_selector_by_provider() {
        let models = [
            make_model("anthropic/claude-3-5-sonnet", true, true),
            make_model("anthropic/claude-3-haiku", true, false),
            make_model("openai/gpt-4o", true, false),
        ];
        let sel = ModelSelector::new(&models);
        assert_eq!(
            ids(&sel.by_provider("anthropic/")),
            ["anthropic/claude-3-5-sonnet", "anthropic/claude-3-haiku"]
        );
        assert_eq!(ids(&sel.by_provider("openai/")), ["openai/gpt-4o"]);
        assert!(sel.by_provider("google/").is_empty());
    }

    #[test]
    fn test_selector_with_tool_calling() {
        let models = [
            make_model("tools", true, false),
            make_model("chat-only", false, false),
        ];
        let sel = ModelSelector::new(&models);
        assert_eq!(ids(&sel.with_tool_calling()), ["tools"]);
    }

    #[test]
    fn test_selector_with_vision() {
        let models = [
            make_model("vision-model", true, true),
            make_model("text-only", true, false),
        ];
        let sel = ModelSelector::new(&models);
        assert_eq!(ids(&sel.with_vision()), ["vision-model"]);
    }

    #[test]
    fn test_selector_with_streaming() {
        let mut batch_only = make_model("batch-only", true, false);
        batch_only.capabilities.streaming = false;
        let models = [make_model("streamer", true, false), batch_only];
        let sel = ModelSelector::new(&models);
        assert_eq!(ids(&sel.with_streaming()), ["streamer"]);
    }

    #[test]
    fn test_selector_with_structured_output() {
        let mut structured = make_model("structured", true, false);
        structured.capabilities.structured_output = true;
        let models = [make_model("free-form", true, false), structured];
        let sel = ModelSelector::new(&models);
        assert_eq!(ids(&sel.with_structured_output()), ["structured"]);
    }

    #[test]
    fn test_selector_with_effort_levels() {
        let mut tunable = make_model("tunable", true, false);
        tunable.capabilities.effort_levels = true;
        tunable.supported_effort_levels = vec![EffortLevel::Low, EffortLevel::High];
        let models = [tunable, make_model("fixed", true, false)];
        let sel = ModelSelector::new(&models);
        assert_eq!(ids(&sel.with_effort_levels()), ["tunable"]);
    }

    #[test]
    fn test_selector_by_min_context() {
        let models = [
            make_model_with_context("small", 8_000),
            make_model_with_context("medium", 128_000),
            make_model_with_context("large", 200_000),
        ];
        let sel = ModelSelector::new(&models);
        // The lower bound is inclusive.
        assert_eq!(ids(&sel.by_min_context(128_000)), ["medium", "large"]);
        assert_eq!(sel.by_min_context(0).len(), 3);
        assert!(sel.by_min_context(u32::MAX).is_empty());
    }

    #[test]
    fn test_selector_empty_slice() {
        let models: Vec<ModelInfo> = Vec::new();
        let sel = ModelSelector::new(&models);
        assert!(sel.by_name("anything").is_none());
        assert!(sel.by_provider("").is_empty());
        assert!(sel.with_tool_calling().is_empty());
        assert!(sel.with_vision().is_empty());
        assert!(sel.with_streaming().is_empty());
        assert!(sel.with_structured_output().is_empty());
        assert!(sel.with_effort_levels().is_empty());
        assert!(sel.by_min_context(0).is_empty());
    }
}