use crate::llm::LlmConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Pricing figures for a model.
///
/// Keys are matched in camelCase (`cacheRead`, `cacheWrite`). The snake_case spellings
/// `cache_read` / `cache_write` are also accepted as deserialization aliases, so a source
/// using snake_case keys is honored instead of silently falling back to `0.0`. Serialized
/// output still uses camelCase. A document that supplies both spellings of the same field is
/// rejected as a duplicate field.
///
/// Every field defaults to `0.0` when absent.
/// NOTE(review): the price unit (e.g. per million tokens) is not established here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ModelCost {
    /// Price for input tokens.
    #[serde(default)]
    pub input: f64,
    /// Price for output tokens.
    #[serde(default)]
    pub output: f64,
    /// Price for cache reads.
    #[serde(default, alias = "cache_read")]
    pub cache_read: f64,
    /// Price for cache writes.
    #[serde(default, alias = "cache_write")]
    pub cache_write: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModelLimit {
#[serde(default)]
pub context: u32,
#[serde(default)]
pub output: u32,
}
/// Input and output modalities supported by a model.
///
/// Any list missing from the input is taken from `ModelModalities::default()`, i.e. empty.
/// NOTE(review): entries are presumably names such as `"text"` or `"image"` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct ModelModalities {
    /// Modalities the model accepts as input.
    pub input: Vec<String>,
    /// Modalities the model produces as output.
    pub output: Vec<String>,
}
/// Configuration for a single model (an entry of [`ProviderConfig::models`]).
///
/// Keys are matched in camelCase (`apiKey`, `baseUrl`, `sessionIdHeader`, `toolCall`,
/// `releaseDate`). Each multi-word field also accepts its snake_case spelling as a
/// deserialization alias. Without this, a source that writes `tool_call: false` would have
/// the key ignored and the flag would silently default to `true`. Serialized output still
/// uses camelCase. A document that supplies both spellings of the same field is rejected as
/// a duplicate field.
///
/// Only `id` is required; every other field falls back to a default when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelConfig {
    /// Identifier matched exactly by [`ProviderConfig::find_model`].
    pub id: String,
    /// Display name; empty when absent.
    #[serde(default)]
    pub name: String,
    /// Model family label; empty when absent.
    #[serde(default)]
    pub family: String,
    /// Per-model API key; takes precedence over the provider-level key.
    #[serde(default, alias = "api_key")]
    pub api_key: Option<String>,
    /// Per-model base URL; takes precedence over the provider-level URL.
    #[serde(default, alias = "base_url")]
    pub base_url: Option<String>,
    /// Extra headers, layered over the provider's headers (model wins on identical keys).
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Per-model override of the provider's session-id header name.
    /// NOTE(review): presumably the HTTP header carrying a session identifier — confirm.
    #[serde(default, alias = "session_id_header")]
    pub session_id_header: Option<String>,
    /// Defaults to `false`.
    /// NOTE(review): presumably whether the model accepts attachments — confirm.
    #[serde(default)]
    pub attachment: bool,
    /// Reasoning-capable flag; `apply_model_caps` only applies a thinking budget when set.
    /// Defaults to `false`.
    #[serde(default)]
    pub reasoning: bool,
    /// Defaults to `true` when absent.
    /// NOTE(review): presumably whether the model supports tool calls — confirm.
    #[serde(default = "default_true", alias = "tool_call")]
    pub tool_call: bool,
    /// When `false`, `apply_model_caps` sets `disable_temperature` on the LLM config.
    /// Defaults to `true` when absent.
    #[serde(default = "default_true")]
    pub temperature: bool,
    /// Release date as free-form text; no format is validated here.
    #[serde(default, alias = "release_date")]
    pub release_date: Option<String>,
    /// Supported input/output modalities; empty lists when absent.
    #[serde(default)]
    pub modalities: ModelModalities,
    /// Pricing; all zeros when absent.
    #[serde(default)]
    pub cost: ModelCost,
    /// Token limits; all zeros when absent.
    #[serde(default)]
    pub limit: ModelLimit,
}
/// Serde `default = "default_true"` helper for boolean fields that must be `true` when
/// omitted (used by `ModelConfig::tool_call` and `ModelConfig::temperature`).
pub(crate) fn default_true() -> bool {
    true
}
/// A provider entry: shared connection settings plus the models it serves.
///
/// Keys are matched in camelCase (`apiKey`, `baseUrl`, `sessionIdHeader`). The snake_case
/// spellings are also accepted as deserialization aliases, consistent with [`ModelConfig`].
/// Serialized output still uses camelCase. Only `name` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProviderConfig {
    /// Provider name.
    pub name: String,
    /// Fallback API key, used when a model does not set its own.
    #[serde(default, alias = "api_key")]
    pub api_key: Option<String>,
    /// Fallback base URL, used when a model does not set its own.
    #[serde(default, alias = "base_url")]
    pub base_url: Option<String>,
    /// Base headers; a model's own headers are layered on top by `get_headers`.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Fallback session-id header name, used when a model does not set its own.
    #[serde(default, alias = "session_id_header")]
    pub session_id_header: Option<String>,
    /// Models offered by this provider; empty when absent.
    #[serde(default)]
    pub models: Vec<ModelConfig>,
}
/// Applies a model's capability flags and limits to an [`LlmConfig`], returning the result.
///
/// Adjustments are made in this order:
/// 1. `thinking_budget` is forwarded via `with_thinking_budget` only when the model is
///    marked `reasoning`. Otherwise the budget is ignored.
/// 2. A non-zero `limit.output` is forwarded via `with_max_tokens`.
/// 3. `temperature == false` sets `disable_temperature`. When `temperature` is `true`,
///    the existing value is left untouched.
pub(crate) fn apply_model_caps(
    config: LlmConfig,
    model: &ModelConfig,
    thinking_budget: Option<usize>,
) -> LlmConfig {
    // A budget is only meaningful for reasoning-capable models.
    let mut config = match thinking_budget.filter(|_| model.reasoning) {
        Some(budget) => config.with_thinking_budget(budget),
        None => config,
    };
    // Zero means "no limit recorded", so leave the config's own max-tokens setting alone.
    if model.limit.output != 0 {
        config = config.with_max_tokens(model.limit.output as usize);
    }
    if !model.temperature {
        config.disable_temperature = true;
    }
    config
}
impl ProviderConfig {
    /// Returns the first model whose `id` equals `model_id` exactly, or `None`.
    pub fn find_model(&self, model_id: &str) -> Option<&ModelConfig> {
        self.models
            .iter()
            .find(|candidate| candidate.id.as_str() == model_id)
    }
    /// Resolves the API key: the model's own key if set, else the provider's, else `None`.
    pub fn get_api_key<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
        let model_key = model.api_key.as_deref();
        let provider_key = self.api_key.as_deref();
        model_key.or(provider_key)
    }
    /// Resolves the base URL: the model's own URL if set, else the provider's, else `None`.
    pub fn get_base_url<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
        let model_url = model.base_url.as_deref();
        let provider_url = self.base_url.as_deref();
        model_url.or(provider_url)
    }
    /// Builds an owned header map.
    ///
    /// Provider headers come first and model headers second. A model header therefore
    /// replaces a provider header with an identical key. Keys are compared exactly
    /// (case-sensitively), as `HashMap<String, String>` does.
    pub fn get_headers(&self, model: &ModelConfig) -> HashMap<String, String> {
        // Collecting into a HashMap inserts in iteration order, so later (model) entries win.
        self.headers
            .iter()
            .chain(&model.headers)
            .map(|(name, value)| (name.clone(), value.clone()))
            .collect()
    }
    /// Resolves the session-id header name: model override first, then the provider default.
    pub fn get_session_id_header<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> {
        let model_header = model.session_id_header.as_deref();
        let provider_header = self.session_id_header.as_deref();
        model_header.or(provider_header)
    }
}