use serde::{Deserialize, Serialize};
/// Model identifier used by `AiConfig::default()`, paired there with the
/// default `"openrouter"` provider.
pub const DEFAULT_OPENROUTER_MODEL: &str = "mistralai/mistral-small-2603";
/// Model identifier presumably intended as the default for a Gemini provider.
// NOTE(review): this constant is not referenced in this section of the file —
// confirm where (and for which provider string) it is actually used.
pub const DEFAULT_GEMINI_MODEL: &str = "gemini-3.1-flash-lite-preview";
/// Kind of AI task; selects which per-task override section of the
/// configuration applies in [`AiConfig::resolve_for_task`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    /// Resolved against the `tasks.triage` override.
    Triage,
    /// Resolved against the `tasks.review` override.
    Review,
    /// Resolved against the `tasks.create` override.
    Create,
}
/// Per-task override of the provider and/or model.
///
/// Both fields are optional (and default to `None` when absent from the input,
/// via `#[serde(default)]`); `AiConfig::resolve_for_task` falls back to the
/// top-level `AiConfig::provider` / `AiConfig::model` for any field left `None`.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(default)]
pub struct TaskOverride {
/// Provider to use for this task instead of the top-level provider.
pub provider: Option<String>,
/// Model to use for this task instead of the top-level model.
pub model: Option<String>,
}
/// Optional override sections, one per [`TaskType`] variant.
///
/// Any section missing from the input deserializes as `None`
/// (container-level `#[serde(default)]`).
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[serde(default)]
pub struct TasksConfig {
/// Override consulted for `TaskType::Triage`.
pub triage: Option<TaskOverride>,
/// Override consulted for `TaskType::Review`.
pub review: Option<TaskOverride>,
/// Override consulted for `TaskType::Create`.
pub create: Option<TaskOverride>,
}
#[derive(Debug, Clone, Serialize)]
pub struct FallbackEntry {
pub provider: String,
pub model: Option<String>,
}
impl<'de> Deserialize<'de> for FallbackEntry {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum EntryVariant {
String(String),
Struct {
provider: String,
model: Option<String>,
},
}
match EntryVariant::deserialize(deserializer)? {
EntryVariant::String(provider) => Ok(FallbackEntry {
provider,
model: None,
}),
EntryVariant::Struct { provider, model } => Ok(FallbackEntry { provider, model }),
}
}
}
/// Fallback configuration: an ordered list of [`FallbackEntry`] values.
///
/// A missing `chain` deserializes as an empty list (`#[serde(default)]`).
// NOTE(review): presumably entries are tried in list order when the primary
// provider fails — the consumer is not visible here; confirm.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[serde(default)]
pub struct FallbackConfig {
/// Fallback entries, in the order they appear in the configuration.
pub chain: Vec<FallbackEntry>,
}
/// Default value for `AiConfig::retry_max_attempts`.
///
/// Referenced both by the field's `#[serde(default = ...)]` attribute and by
/// `AiConfig::default()`, so the two paths cannot drift apart.
fn default_retry_max_attempts() -> u32 {
    3_u32
}
/// Top-level AI configuration.
///
/// Because of the container-level `#[serde(default)]`, any field missing from
/// the input takes its value from `AiConfig::default()`.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(default)]
pub struct AiConfig {
/// Provider name used when no per-task override sets one (default `"openrouter"`).
pub provider: String,
/// Model identifier used when no per-task override sets one
/// (default `DEFAULT_OPENROUTER_MODEL`).
pub model: String,
/// Timeout, in seconds per the field name (default 30). Where it is applied is
/// not visible in this section.
pub timeout_seconds: u64,
/// Flag controlling paid-model usage (default `true`); enforcement is not
/// visible in this section.
pub allow_paid_models: bool,
/// Token limit setting (default 4096); presumably sent as the request's max tokens — confirm.
pub max_tokens: u32,
/// Sampling temperature setting (default 0.3).
pub temperature: f32,
/// Circuit-breaker threshold (default 3); semantics defined by the consumer — confirm.
pub circuit_breaker_threshold: u32,
/// Circuit-breaker reset interval in seconds per the field name (default 60).
pub circuit_breaker_reset_seconds: u64,
// The field-level default duplicates the container default: both resolve to
// `default_retry_max_attempts()` (3).
#[serde(default = "default_retry_max_attempts")]
/// Maximum retry attempts (default 3).
pub retry_max_attempts: u32,
/// Optional per-task provider/model overrides; consulted by `resolve_for_task`.
pub tasks: Option<TasksConfig>,
/// Optional fallback chain configuration (default `None`).
pub fallback: Option<FallbackConfig>,
/// Optional free-form guidance text (default `None`); how it is used is not
/// visible in this section.
pub custom_guidance: Option<String>,
/// Validation toggle (default `true`); what is validated is not visible here.
pub validation_enabled: bool,
}
impl Default for AiConfig {
    /// Built-in defaults used for every field absent from the input.
    ///
    /// - provider `"openrouter"` with model [`DEFAULT_OPENROUTER_MODEL`];
    /// - `timeout_seconds` 30, `max_tokens` 4096, `temperature` 0.3;
    /// - `circuit_breaker_threshold` 3, `circuit_breaker_reset_seconds` 60;
    /// - `retry_max_attempts` from `default_retry_max_attempts()`;
    /// - `allow_paid_models` and `validation_enabled` both `true`;
    /// - no task overrides, fallback chain, or custom guidance.
    fn default() -> Self {
        let provider = String::from("openrouter");
        let model = String::from(DEFAULT_OPENROUTER_MODEL);

        Self {
            provider,
            model,
            // Request shaping.
            timeout_seconds: 30,
            max_tokens: 4096,
            temperature: 0.3,
            // Resilience.
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: default_retry_max_attempts(),
            // Feature flags.
            allow_paid_models: true,
            validation_enabled: true,
            // Optional sections start empty.
            tasks: None,
            fallback: None,
            custom_guidance: None,
        }
    }
}
impl AiConfig {
    /// Resolves the `(provider, model)` pair to use for `task`.
    ///
    /// Each half is resolved independently. A value set in the matching
    /// `tasks.triage` / `tasks.review` / `tasks.create` override wins. Otherwise
    /// the top-level `provider` / `model` is used. Both strings are returned owned.
    ///
    /// NOTE(review): because the halves resolve independently, an override that
    /// sets only `provider` is paired with the top-level `model`, which may name
    /// a model for a different provider. Confirm callers expect this.
    #[must_use]
    pub fn resolve_for_task(&self, task: TaskType) -> (String, String) {
        // `None` when there is no `tasks` section or no section for this task.
        let overrides = self.tasks.as_ref().and_then(|tasks| match task {
            TaskType::Triage => tasks.triage.as_ref(),
            TaskType::Review => tasks.review.as_ref(),
            TaskType::Create => tasks.create.as_ref(),
        });

        let provider = overrides
            .and_then(|o| o.provider.as_deref())
            .unwrap_or(self.provider.as_str())
            .to_owned();
        let model = overrides
            .and_then(|o| o.model.as_deref())
            .unwrap_or(self.model.as_str())
            .to_owned();

        (provider, model)
    }
}