use anyhow::{anyhow, bail, Context, Result};
use std::collections::HashMap;
use crate::config::TomlConfig;
use crate::skill::manifest::{ModelConfig, Provider};
/// Built-in sampling temperature for the well-known role names.
///
/// Used as the last fallback when neither the environment nor `tama.toml`
/// sets a temperature for a role. Matching is exact and case-sensitive;
/// any other role yields `None`.
fn role_default_temperature(role: &str) -> Option<f64> {
    const BUILT_IN: [(&str, f64); 3] = [("thinker", 1.0), ("worker", 0.0), ("default", 0.8)];
    BUILT_IN
        .iter()
        .find(|(known, _)| *known == role)
        .map(|&(_, temp)| temp)
}
/// Model role a skill uses when it names neither a model nor a role.
///
/// The `"react"` pattern gets the `"thinker"` role; every other pattern
/// (matched exactly, case-sensitive) gets `"worker"`.
pub fn pattern_default_role(pattern: &str) -> &'static str {
    if pattern == "react" {
        "thinker"
    } else {
        "worker"
    }
}
/// A model configuration with every applicable override already folded in.
///
/// Built either from a `[models]` role or from a direct `provider:model` spec.
#[derive(Clone)]
pub struct ResolvedModel {
/// Role this model was resolved for; direct specs use `direct:<spec>`.
pub role: String,
/// Backend provider parsed from the `provider:` prefix of the model spec.
pub provider: Provider,
/// Model identifier: the part of the spec after the first `:`.
pub model_name: String,
/// Sampling temperature; `None` when nothing configured one.
pub temperature: Option<f64>,
/// Max-tokens setting; `None` when nothing configured one.
pub max_tokens: Option<u32>,
/// API key read from the environment, or the placeholder `"ollama"` for Ollama.
/// NOTE(review): this is a secret — keep it out of any future `Debug`/log output.
pub api_key: String,
/// Endpoint override; `None` when no override was found.
pub base_url: Option<String>,
}
/// Registry of model roles resolved from the `[models]` table of `tama.toml`.
pub struct ModelRegistry {
/// Resolved models keyed by role name.
models: HashMap<String, ResolvedModel>,
}
impl ModelRegistry {
pub fn build(config: &TomlConfig) -> Result<Self> {
let mut models = HashMap::new();
for (role, entry) in &config.models {
let resolved = resolve_role(role, entry.name(), entry.temperature(), entry.max_tokens(), entry.base_url(), &config.providers)
.with_context(|| format!("failed to configure model role '{role}'"))?;
models.insert(role.clone(), resolved);
}
Ok(ModelRegistry { models })
}
pub fn get(&self, role: &str) -> Result<&ResolvedModel> {
self.models
.get(role)
.ok_or_else(|| anyhow!("model role '{}' is not defined in tama.toml [models]", role))
}
pub fn resolve(&self, model_config: Option<&ModelConfig>, pattern: &str) -> Result<ResolvedModel> {
let (mut resolved, local_temp, local_max_tokens) = match model_config {
Some(mc) if mc.name.is_some() => {
let name = mc.name.as_ref().unwrap();
let r = resolve_direct_spec(name)
.with_context(|| format!("in model: name: '{name}'"))?;
(r, mc.temperature.map(|t| t as f64), mc.max_tokens)
}
Some(mc) if mc.role.is_some() => {
let role = mc.role.as_ref().unwrap();
let r = self.get_or_default(role)?;
(r.clone(), mc.temperature.map(|t| t as f64), mc.max_tokens)
}
_ => {
let default_role = pattern_default_role(pattern);
let r = self.get_or_default(default_role)?;
(r.clone(), None, None)
}
};
if let Some(t) = local_temp {
resolved.temperature = Some(t);
}
if let Some(mt) = local_max_tokens {
resolved.max_tokens = Some(mt);
}
Ok(resolved)
}
fn get_or_default(&self, role: &str) -> Result<&ResolvedModel> {
if let Some(m) = self.models.get(role) {
return Ok(m);
}
if let Some(m) = self.models.get("default") {
return Ok(m);
}
bail!(
"model role '{}' is not defined in tama.toml [models], and no 'default' role is defined.\n \
Add a [models] section to tama.toml or set TAMA_MODEL_{}_NAME.",
role,
role.to_uppercase().replace('-', "_")
)
}
}
fn resolve_role(
role: &str,
entry_name: &str,
entry_temperature: Option<f64>,
entry_max_tokens: Option<u32>,
entry_base_url: Option<&str>,
providers: &HashMap<String, crate::config::ProviderEntry>,
) -> Result<ResolvedModel> {
let role_upper = role.to_uppercase().replace('-', "_");
let name_str = std::env::var(format!("TAMA_MODEL_{role_upper}_NAME"))
.unwrap_or_else(|_| entry_name.to_string());
let (provider, model_name) = parse_provider_model(&name_str)
.with_context(|| format!("invalid model spec '{name_str}'"))?;
let env_temp: Option<f64> = std::env::var(format!("TAMA_MODEL_{role_upper}_TEMPERATURE"))
.ok()
.and_then(|s| s.parse().ok());
let temperature = env_temp.or(entry_temperature).or_else(|| role_default_temperature(role));
let env_max_tokens: Option<u32> = std::env::var(format!("TAMA_MODEL_{role_upper}_MAX_TOKENS"))
.ok()
.and_then(|s| s.parse().ok());
let max_tokens = env_max_tokens.or(entry_max_tokens);
let api_key = if provider == Provider::Ollama {
"ollama".to_string()
} else {
let provider_env = provider_api_key_env(&provider);
std::env::var(format!("TAMA_MODEL_{role_upper}_API_KEY"))
.or_else(|_| std::env::var(provider_env))
.with_context(|| {
format!(
"model role '{}' (provider: {}) is missing an API key.\n \
Set TAMA_MODEL_{role_upper}_API_KEY or {}.",
role, provider, provider_env
)
})?
};
let provider_str = provider.to_string().to_uppercase();
let base_url = std::env::var(format!("TAMA_MODEL_{role_upper}_BASE_URL"))
.ok()
.or_else(|| std::env::var(format!("TAMA_PROVIDER_{provider_str}_BASE_URL")).ok())
.or_else(|| entry_base_url.map(str::to_string))
.or_else(|| {
providers
.get(&provider.to_string())
.and_then(|p| p.base_url.clone())
});
Ok(ResolvedModel {
role: role.to_string(),
provider,
model_name,
temperature,
max_tokens,
api_key,
base_url,
})
}
/// Resolves a direct `provider:model` spec (a skill's `model: name:`) without
/// going through a `[models]` role.
///
/// The API key comes from `TAMA_PROVIDER_<PROVIDER>_API_KEY`, falling back to
/// the provider's standard variable; Ollama needs none. The base URL comes only
/// from `TAMA_PROVIDER_<PROVIDER>_BASE_URL`. Temperature and max tokens are
/// left unset for the caller to override.
///
/// # Errors
///
/// Fails if `name` is not a valid `provider:model` spec, or no API key is
/// found for a provider that needs one.
fn resolve_direct_spec(name: &str) -> Result<ResolvedModel> {
    let (provider, model_name) = parse_provider_model(name)?;
    // Computed once; used for both the API-key and base-URL variable names.
    let provider_str = provider.to_string().to_uppercase();

    let api_key = if provider == Provider::Ollama {
        "ollama".to_string()
    } else {
        let provider_key = format!("TAMA_PROVIDER_{provider_str}_API_KEY");
        let provider_env = provider_api_key_env(&provider);
        std::env::var(&provider_key)
            .or_else(|_| std::env::var(provider_env))
            .with_context(|| {
                // Name both accepted variables, not just the standard one.
                format!(
                    "direct model spec '{name}' is missing an API key.\n \
                     Set {provider_key} or {provider_env}."
                )
            })?
    };

    let base_url = std::env::var(format!("TAMA_PROVIDER_{provider_str}_BASE_URL")).ok();

    Ok(ResolvedModel {
        role: format!("direct:{name}"),
        provider,
        model_name,
        temperature: None,
        max_tokens: None,
        api_key,
        base_url,
    })
}
fn parse_provider_model(s: &str) -> Result<(Provider, String)> {
let (provider_str, model) = s
.split_once(':')
.with_context(|| format!("invalid model spec '{}': expected 'provider:model'", s))?;
if model.is_empty() {
bail!("model name cannot be empty in '{}'", s);
}
use crate::skill::manifest::Provider::*;
let provider = match provider_str {
"anthropic" => Anthropic,
"openai" => OpenAi,
"google" => Google,
"ollama" => Ollama,
other => bail!(
"unknown provider '{}': supported providers: anthropic, openai, google, ollama",
other
),
};
Ok((provider, model.to_string()))
}
/// Name of the standard environment variable holding `provider`'s API key.
///
/// # Panics
///
/// Panics for [`Provider::Ollama`], which needs no key; callers must handle
/// Ollama before calling this.
fn provider_api_key_env(provider: &Provider) -> &'static str {
    match *provider {
        Provider::Ollama => unreachable!("Ollama doesn't use an API key"),
        Provider::Google => "GEMINI_API_KEY",
        Provider::OpenAi => "OPENAI_API_KEY",
        Provider::Anthropic => "ANTHROPIC_API_KEY",
    }
}