use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use katu_core::{ModelId, ProviderId, RouteId};
use crate::cache::CachePolicy;
use katu_core::GenerationOptions;
use crate::http::HttpOptions;
/// Kinds of input content a model can accept.
///
/// Serialized in `snake_case`: `"text"`, `"image"`, `"audio"`, `"video"`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum InputModality {
    /// Plain text content.
    Text,
    /// Image content.
    Image,
    /// Audio content.
    Audio,
    /// Video content.
    Video,
}
/// Discrete reasoning-effort levels, declared from least (`None`) to most (`Max`).
///
/// `PartialOrd`/`Ord` are derived from that declaration order, so a requested
/// level can be compared against a `ThinkingConfig`'s `min_effort` /
/// `max_effort`, or clamped with [`Ord::clamp`]. Do not reorder the variants
/// without accounting for this.
///
/// Serialized in `snake_case`: `"none"`, `"low"`, `"medium"`, `"high"`,
/// `"x_high"`, `"max"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
    None,
    Low,
    Medium,
    High,
    // `rename_all = "snake_case"` spells this variant `"x_high"`. Also accept
    // the unseparated `"xhigh"` when deserializing because that spelling is
    // easy to write by hand; serialization still emits `"x_high"`.
    #[serde(alias = "xhigh")]
    XHigh,
    Max,
}
/// Mechanism a model exposes for controlling its extended "thinking".
///
/// Serialized in `snake_case`: `"adaptive"`, `"budget"`, `"effort"`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingMode {
    /// Presumably the model decides how much to think on its own — confirm
    /// against the provider adapters.
    Adaptive,
    /// Presumably controlled by a numeric budget (see
    /// `ThinkingConfig::default_budget`).
    Budget,
    /// Presumably controlled by a discrete [`ReasoningEffort`] level.
    Effort,
}
/// Thinking/reasoning support for a model: which control mechanism it uses,
/// plus optional defaults and bounds for that mechanism.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ThinkingConfig {
    /// Control mechanism the model uses.
    pub mode: ThinkingMode,
    /// Default thinking budget when none is requested; presumably a token
    /// count paired with [`ThinkingMode::Budget`] — confirm.
    pub default_budget: Option<u32>,
    /// Lower bound on reasoning effort, if the model has one.
    pub min_effort: Option<ReasoningEffort>,
    /// Upper bound on reasoning effort, if the model has one.
    pub max_effort: Option<ReasoningEffort>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelCapabilities {
pub input_modalities: Vec<InputModality>,
pub tool_calls: bool,
pub streaming_tool_input: bool,
pub structured_output: bool,
pub prompt_caching: bool,
pub thinking: Option<ThinkingConfig>,
}
impl ModelCapabilities {
    /// Returns `true` when `modality` appears in `input_modalities`.
    ///
    /// This is a linear scan of the list.
    #[must_use]
    pub fn supports_modality(&self, modality: InputModality) -> bool {
        let accepted: &[InputModality] = &self.input_modalities;
        accepted.contains(&modality)
    }

    /// Returns `true` when a [`ThinkingConfig`] is present.
    #[must_use]
    pub fn supports_thinking(&self) -> bool {
        self.thinking.is_some()
    }
}
impl Default for ModelCapabilities {
    /// Baseline capabilities: text-only input with tool calls enabled.
    ///
    /// Thinking, streaming tool input, structured output and prompt caching
    /// are all off.
    fn default() -> Self {
        Self {
            input_modalities: Vec::from([InputModality::Text]),
            tool_calls: true,
            thinking: None,
            streaming_tool_input: false,
            structured_output: false,
            prompt_caching: false,
        }
    }
}
/// Size limits for a model.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ModelLimits {
    /// Size of the context window; presumably measured in tokens.
    pub context_window: u32,
    /// Maximum number of output tokens the model may produce.
    pub max_output_tokens: u32,
}
/// Prices charged for a model's token traffic.
///
/// NOTE(review): units are not stated in this type. Values such as
/// 3.0 / 15.0 / 0.30 / 3.75 look like USD per million tokens; confirm before
/// using these for cost accounting.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ModelPricing {
    /// Price for regular (uncached) input tokens.
    pub input: f64,
    /// Price for output tokens.
    pub output: f64,
    /// Price for input tokens served from the prompt cache.
    pub cache_read: f64,
    /// Price for input tokens written into the prompt cache.
    pub cache_write: f64,
}
/// Everything needed to address one model through one provider route.
///
/// Optional fields are omitted from serialized output when `None`.
///
/// Security: `api_key` *is* included when serializing (so a serialized
/// `ModelRef` must be treated as a secret), but the [`Debug`](std::fmt::Debug)
/// output redacts it so the key does not leak into logs or panic messages.
/// `headers` and `query_params` are printed as-is by `Debug`.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelRef {
    /// Model identifier.
    pub id: ModelId,
    /// Provider identifier.
    pub provider: ProviderId,
    /// Route identifier.
    pub route: RouteId,
    /// Optional human-readable name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Base URL of the endpoint.
    pub base_url: String,
    /// Optional API credential. Redacted in `Debug` output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Extra header name/value pairs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
    /// Extra query-string parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub query_params: Option<HashMap<String, String>>,
    /// Context/output size limits.
    pub limits: ModelLimits,
    /// Supported features. Falls back to `ModelCapabilities::default()` when
    /// the field is absent from the input being deserialized.
    #[serde(default)]
    pub capabilities: ModelCapabilities,
    /// Optional generation settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation: Option<GenerationOptions>,
    /// Optional prompt-cache policy.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_policy: Option<CachePolicy>,
    /// Optional pricing information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pricing: Option<ModelPricing>,
    /// Free-form, provider-specific options.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<serde_json::Value>,
    /// Optional HTTP client settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub http: Option<HttpOptions>,
}

impl std::fmt::Debug for ModelRef {
    /// Same shape as a derived `Debug`, except that a present `api_key` is
    /// shown as `Some("<redacted>")` instead of its value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelRef")
            .field("id", &self.id)
            .field("provider", &self.provider)
            .field("route", &self.route)
            .field("display_name", &self.display_name)
            .field("base_url", &self.base_url)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("headers", &self.headers)
            .field("query_params", &self.query_params)
            .field("limits", &self.limits)
            .field("capabilities", &self.capabilities)
            .field("generation", &self.generation)
            .field("cache_policy", &self.cache_policy)
            .field("pricing", &self.pricing)
            .field("provider_options", &self.provider_options)
            .field("http", &self.http)
            .finish()
    }
}
impl ModelRef {
    /// Creates a model reference from its required parts.
    ///
    /// `capabilities` starts as [`ModelCapabilities::default()`]; every
    /// optional field starts as `None`. Use the `with_*` methods to fill them in.
    #[must_use]
    pub fn new(
        id: ModelId,
        provider: ProviderId,
        route: RouteId,
        base_url: impl Into<String>,
        limits: ModelLimits,
    ) -> Self {
        Self {
            id,
            provider,
            route,
            display_name: None,
            base_url: base_url.into(),
            api_key: None,
            headers: None,
            query_params: None,
            limits,
            capabilities: ModelCapabilities::default(),
            generation: None,
            cache_policy: None,
            pricing: None,
            provider_options: None,
            http: None,
        }
    }

    /// Sets the human-readable display name, replacing any previous one.
    #[must_use]
    pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
        self.display_name = Some(name.into());
        self
    }

    /// Sets the API key, replacing any previous one.
    #[must_use]
    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Adds a header, creating the header map on first use.
    ///
    /// An existing entry with the same (case-sensitive) key is overwritten.
    #[must_use]
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }

    /// Adds a query parameter, creating the parameter map on first use.
    ///
    /// An existing entry with the same key is overwritten.
    #[must_use]
    pub fn with_query_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.query_params
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }

    /// Replaces the whole capability set.
    #[must_use]
    pub fn with_capabilities(mut self, capabilities: ModelCapabilities) -> Self {
        self.capabilities = capabilities;
        self
    }

    /// Sets the generation options, replacing any previous ones.
    #[must_use]
    pub fn with_generation(mut self, generation: GenerationOptions) -> Self {
        self.generation = Some(generation);
        self
    }

    /// Sets the prompt-cache policy, replacing any previous one.
    #[must_use]
    pub fn with_cache_policy(mut self, policy: CachePolicy) -> Self {
        self.cache_policy = Some(policy);
        self
    }

    /// Sets the pricing information, replacing any previous value.
    #[must_use]
    pub fn with_pricing(mut self, pricing: ModelPricing) -> Self {
        self.pricing = Some(pricing);
        self
    }

    /// Sets the free-form provider options, replacing any previous value.
    #[must_use]
    pub fn with_provider_options(mut self, options: serde_json::Value) -> Self {
        self.provider_options = Some(options);
        self
    }

    /// Sets the HTTP options, replacing any previous ones.
    #[must_use]
    pub fn with_http(mut self, http: HttpOptions) -> Self {
        self.http = Some(http);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model using only the arguments `ModelRef::new` requires.
    fn sample_model() -> ModelRef {
        ModelRef::new(
            ModelId::new("claude-sonnet-4-20250514"),
            ProviderId::new("anthropic"),
            RouteId::new("anthropic-messages"),
            "https://api.anthropic.com/v1",
            ModelLimits {
                context_window: 200_000,
                max_output_tokens: 8192,
            },
        )
    }

    #[test]
    fn test_new_has_required_fields() {
        let m = sample_model();
        assert_eq!(m.id.as_str(), "claude-sonnet-4-20250514");
        assert_eq!(m.provider.as_str(), "anthropic");
        assert_eq!(m.route.as_str(), "anthropic-messages");
        assert_eq!(m.base_url, "https://api.anthropic.com/v1");
        assert_eq!(m.limits.context_window, 200_000);
        assert_eq!(m.limits.max_output_tokens, 8192);
    }

    #[test]
    fn test_new_optional_fields_are_none() {
        let m = sample_model();
        assert_eq!(m.display_name, None);
        assert_eq!(m.api_key, None);
        assert_eq!(m.headers, None);
        assert_eq!(m.query_params, None);
        assert_eq!(m.generation, None);
        assert_eq!(m.cache_policy, None);
        assert_eq!(m.pricing, None);
        assert_eq!(m.provider_options, None);
        assert_eq!(m.http, None);
    }

    #[test]
    fn test_new_uses_default_capabilities() {
        assert_eq!(sample_model().capabilities, ModelCapabilities::default());
    }

    #[test]
    fn test_builder_chain() {
        let m = sample_model()
            .with_display_name("Claude Sonnet 4")
            .with_api_key("sk-ant-xxx")
            .with_header("x-custom", "value")
            .with_query_param("version", "1")
            .with_generation(GenerationOptions::new().with_max_tokens(4096))
            .with_cache_policy(CachePolicy::Auto)
            .with_pricing(ModelPricing {
                input: 3.0,
                output: 15.0,
                cache_read: 0.30,
                cache_write: 3.75,
            });
        assert_eq!(m.display_name.as_deref(), Some("Claude Sonnet 4"));
        assert_eq!(m.api_key.as_deref(), Some("sk-ant-xxx"));
        assert_eq!(
            m.headers.as_ref().unwrap().get("x-custom").unwrap(),
            "value"
        );
        assert_eq!(
            m.query_params
                .as_ref()
                .and_then(|q| q.get("version"))
                .map(String::as_str),
            Some("1")
        );
        assert_eq!(
            m.generation.as_ref().unwrap().max_tokens,
            Some(4096)
        );
        assert_eq!(m.cache_policy, Some(CachePolicy::Auto));
        assert_eq!(m.pricing.as_ref().unwrap().input, 3.0);
    }

    #[test]
    fn test_with_header_accumulates_and_overwrites() {
        let m = sample_model()
            .with_header("a", "1")
            .with_header("b", "2")
            .with_header("a", "3");
        let headers = m.headers.expect("with_header creates the map");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("a").map(String::as_str), Some("3"));
        assert_eq!(headers.get("b").map(String::as_str), Some("2"));
    }

    #[test]
    fn test_capabilities_default() {
        let m = sample_model();
        assert!(m.capabilities.supports_modality(InputModality::Text));
        assert!(!m.capabilities.supports_modality(InputModality::Image));
        assert!(m.capabilities.tool_calls);
        assert!(!m.capabilities.supports_thinking());
    }

    #[test]
    fn test_capabilities_with_thinking() {
        let m = sample_model().with_capabilities(ModelCapabilities {
            input_modalities: vec![InputModality::Text, InputModality::Image],
            tool_calls: true,
            streaming_tool_input: true,
            structured_output: false,
            prompt_caching: true,
            thinking: Some(ThinkingConfig {
                mode: ThinkingMode::Adaptive,
                default_budget: None,
                min_effort: None,
                max_effort: None,
            }),
        });
        assert!(m.capabilities.supports_thinking());
        assert!(m.capabilities.supports_modality(InputModality::Image));
        assert!(m.capabilities.streaming_tool_input);
    }

    #[test]
    fn test_serde_roundtrip_minimal() {
        let m = sample_model();
        let json = serde_json::to_string(&m).unwrap();
        let restored: ModelRef = serde_json::from_str(&json).unwrap();
        // Compare the whole value, not a handful of fields, so a field that
        // fails to round-trip cannot slip through.
        assert_eq!(m, restored);
    }

    #[test]
    fn test_serde_roundtrip_full() {
        let m = sample_model()
            .with_display_name("Claude Sonnet 4")
            .with_api_key("sk-test")
            .with_capabilities(ModelCapabilities {
                input_modalities: vec![InputModality::Text, InputModality::Image],
                tool_calls: true,
                streaming_tool_input: true,
                structured_output: true,
                prompt_caching: true,
                thinking: Some(ThinkingConfig {
                    mode: ThinkingMode::Budget,
                    default_budget: Some(10000),
                    min_effort: Some(ReasoningEffort::Low),
                    max_effort: Some(ReasoningEffort::High),
                }),
            })
            .with_generation(GenerationOptions::new().with_max_tokens(4096).with_temperature(0.7))
            .with_cache_policy(CachePolicy::Auto)
            .with_pricing(ModelPricing {
                input: 3.0,
                output: 15.0,
                cache_read: 0.30,
                cache_write: 3.75,
            })
            .with_provider_options(serde_json::json!({"region": "us-east-1"}))
            .with_http(HttpOptions::new().with_header("x-extra", "val"));
        let json = serde_json::to_string_pretty(&m).unwrap();
        let restored: ModelRef = serde_json::from_str(&json).unwrap();
        assert_eq!(m, restored);
    }

    #[test]
    fn test_serde_skips_none_fields() {
        let value = serde_json::to_value(sample_model()).unwrap();
        let obj = value
            .as_object()
            .expect("ModelRef serializes as a JSON object");
        // Check object keys rather than substrings: `"http"` would otherwise
        // match inside `base_url`.
        for key in [
            "display_name",
            "api_key",
            "headers",
            "query_params",
            "generation",
            "cache_policy",
            "pricing",
            "provider_options",
            "http",
        ] {
            assert!(!obj.contains_key(key), "unexpected key `{key}` in {value}");
        }
    }
}