use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Capability flags and optional size limits recorded for a model.
///
/// The derived `Default` yields every flag `false` and both limits `None`.
// The flags are independent of one another, so the pedantic
// `struct_excessive_bools` lint is silenced for this type.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct ModelCapabilities {
/// Whether thinking is supported.
pub supports_thinking: bool,
/// Whether vision is supported (presumably image input — TODO confirm).
pub supports_vision: bool,
/// Whether tool use is supported.
pub supports_tool_use: bool,
/// Whether streaming is supported.
pub supports_streaming: bool,
/// Whether structured output is supported.
pub supports_structured_output: bool,
/// Context-window limit, or `None` when not recorded.
/// Presumably counted in tokens (the matching builder names its argument `tokens`).
pub max_context_window: Option<u64>,
/// Output limit in tokens, or `None` when not recorded.
pub max_output_tokens: Option<u64>,
}
impl ModelCapabilities {
    /// Returns the `Default` capability set: no flag enabled, no limit recorded.
    #[must_use]
    pub fn none() -> Self {
        <Self as Default>::default()
    }

    /// Returns `self` with `supports_thinking` replaced by `val`.
    #[must_use]
    pub const fn with_thinking(self, val: bool) -> Self {
        Self {
            supports_thinking: val,
            ..self
        }
    }

    /// Returns `self` with `supports_vision` replaced by `val`.
    #[must_use]
    pub const fn with_vision(self, val: bool) -> Self {
        Self {
            supports_vision: val,
            ..self
        }
    }

    /// Returns `self` with `supports_tool_use` replaced by `val`.
    #[must_use]
    pub const fn with_tool_use(self, val: bool) -> Self {
        Self {
            supports_tool_use: val,
            ..self
        }
    }

    /// Returns `self` with `supports_streaming` replaced by `val`.
    #[must_use]
    pub const fn with_streaming(self, val: bool) -> Self {
        Self {
            supports_streaming: val,
            ..self
        }
    }

    /// Returns `self` with `supports_structured_output` replaced by `val`.
    #[must_use]
    pub const fn with_structured_output(self, val: bool) -> Self {
        Self {
            supports_structured_output: val,
            ..self
        }
    }

    /// Returns `self` with `max_context_window` set to `Some(tokens)`.
    #[must_use]
    pub const fn with_max_context_window(self, tokens: u64) -> Self {
        Self {
            max_context_window: Some(tokens),
            ..self
        }
    }

    /// Returns `self` with `max_output_tokens` set to `Some(tokens)`.
    #[must_use]
    pub const fn with_max_output_tokens(self, tokens: u64) -> Self {
        Self {
            max_output_tokens: Some(tokens),
            ..self
        }
    }
}
/// Thinking level setting, with variants declared from `Off` through `ExtraHigh`.
///
/// Serialized with snake_case variant names (for example `ExtraHigh` becomes
/// `"extra_high"`). No ordering trait is derived, so levels can be tested for
/// equality and hashed but not compared with `<` / `>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingLevel {
/// The `Default` variant.
#[default]
Off,
Minimal,
Low,
Medium,
High,
ExtraHigh,
}
/// Budgets keyed by [`ThinkingLevel`].
///
/// NOTE(review): the `u64` values are presumably token counts — confirm
/// against the code that consumes them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ThinkingBudgets {
/// Budget recorded for each level; a level may have no entry.
pub budgets: HashMap<ThinkingLevel, u64>,
}
impl ThinkingBudgets {
#[must_use]
pub const fn new(budgets: HashMap<ThinkingLevel, u64>) -> Self {
Self { budgets }
}
#[must_use]
pub fn get(&self, level: &ThinkingLevel) -> Option<u64> {
self.budgets.get(level).copied()
}
}
/// Identifies a model by provider and model id, together with its thinking
/// settings, optional provider-specific config and optional capabilities.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// NOTE(review): `Eq` is presumably withheld because of the
// `serde_json::Value` field — confirm before removing this allow.
#[allow(clippy::derive_partial_eq_without_eq)]
pub struct ModelSpec {
/// Name of the provider.
pub provider: String,
/// Identifier of the model.
pub model_id: String,
/// Thinking level to use; `ModelSpec::new` sets this to `ThinkingLevel::Off`.
pub thinking_level: ThinkingLevel,
/// Optional per-level budgets. Unlike the two fields below, this carries no
/// `skip_serializing_if`, so a `None` is still written out when serializing.
pub thinking_budgets: Option<ThinkingBudgets>,
/// Free-form provider-specific JSON. Omitted from serialized output when
/// `None`, and defaults to `None` when absent from the input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub provider_config: Option<serde_json::Value>,
/// Explicitly declared capabilities. Omitted from serialized output when
/// `None`, and defaults to `None` when absent from the input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub capabilities: Option<ModelCapabilities>,
}
impl ModelSpec {
    /// Creates a spec for `model_id` under `provider`.
    ///
    /// Thinking starts at [`ThinkingLevel::Off`]; budgets, provider config and
    /// capabilities are all left unset.
    #[must_use]
    pub fn new(provider: impl Into<String>, model_id: impl Into<String>) -> Self {
        Self {
            provider: provider.into(),
            model_id: model_id.into(),
            thinking_level: ThinkingLevel::Off,
            thinking_budgets: None,
            provider_config: None,
            capabilities: None,
        }
    }

    /// Returns `self` with the thinking level replaced by `level`.
    #[must_use]
    pub const fn with_thinking_level(mut self, level: ThinkingLevel) -> Self {
        self.thinking_level = level;
        self
    }

    /// Returns `self` with the thinking budgets set to `budgets`.
    #[must_use]
    pub fn with_thinking_budgets(mut self, budgets: ThinkingBudgets) -> Self {
        self.thinking_budgets = Some(budgets);
        self
    }

    /// Returns `self` with the provider-specific config set to `config`.
    #[must_use]
    pub fn with_provider_config(mut self, config: serde_json::Value) -> Self {
        self.provider_config = Some(config);
        self
    }

    /// Returns `self` with explicitly declared `capabilities`.
    #[must_use]
    pub const fn with_capabilities(mut self, capabilities: ModelCapabilities) -> Self {
        self.capabilities = Some(capabilities);
        self
    }

    /// Returns a copy of the declared capabilities, or the `Default`
    /// capability set when none were declared.
    #[must_use]
    pub fn capabilities(&self) -> ModelCapabilities {
        self.capabilities.clone().unwrap_or_default()
    }

    /// Deserializes the provider config into `T`.
    ///
    /// Returns `None` both when no config is set and when the stored value
    /// does not deserialize into `T`; the deserialization error is discarded,
    /// so callers cannot tell the two cases apart.
    #[must_use]
    pub fn provider_config_as<T: serde::de::DeserializeOwned>(&self) -> Option<T> {
        let config = self.provider_config.as_ref()?;
        // `&serde_json::Value` is itself a `Deserializer`, so the stored tree
        // is read in place rather than deep-cloned on every call.
        T::deserialize(config).ok()
    }
}