1use serde::{Deserialize, Serialize};
6
/// Model identifier that `AiConfig::default()` pairs with the `openrouter` provider.
pub const DEFAULT_OPENROUTER_MODEL: &str = "mistralai/mistral-small-2603";

/// Gemini model identifier exported for callers that need a Gemini default.
/// NOTE(review): its consumer is not visible in this section — confirm where it is used.
pub const DEFAULT_GEMINI_MODEL: &str = "gemini-3.1-flash-lite-preview";
11
/// The kinds of AI task that may carry their own provider/model override.
///
/// Each variant selects one slot of [`TasksConfig`] when resolving settings
/// through [`AiConfig::resolve_for_task`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    /// Uses the `tasks.triage` override slot.
    Triage,
    /// Uses the `tasks.review` override slot.
    Review,
    /// Uses the `tasks.create` override slot.
    Create,
}
22
23#[derive(Debug, Deserialize, Serialize, Default, Clone)]
25#[serde(default)]
26pub struct TaskOverride {
27 pub provider: Option<String>,
29 pub model: Option<String>,
31}
32
33#[derive(Debug, Deserialize, Serialize, Default, Clone)]
35#[serde(default)]
36pub struct TasksConfig {
37 pub triage: Option<TaskOverride>,
39 pub review: Option<TaskOverride>,
41 pub create: Option<TaskOverride>,
43}
44
45#[derive(Debug, Clone, Serialize)]
47pub struct FallbackEntry {
48 pub provider: String,
50 pub model: Option<String>,
52}
53
54impl<'de> Deserialize<'de> for FallbackEntry {
55 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
56 where
57 D: serde::Deserializer<'de>,
58 {
59 #[derive(Deserialize)]
60 #[serde(untagged)]
61 enum EntryVariant {
62 String(String),
63 Struct {
64 provider: String,
65 model: Option<String>,
66 },
67 }
68
69 match EntryVariant::deserialize(deserializer)? {
70 EntryVariant::String(provider) => Ok(FallbackEntry {
71 provider,
72 model: None,
73 }),
74 EntryVariant::Struct { provider, model } => Ok(FallbackEntry { provider, model }),
75 }
76 }
77}
78
79#[derive(Debug, Deserialize, Serialize, Clone, Default)]
81#[serde(default)]
82pub struct FallbackConfig {
83 pub chain: Vec<FallbackEntry>,
85}
86
/// Default for `AiConfig::retry_max_attempts`, shared by the field-level serde
/// default and `AiConfig::default()`.
fn default_retry_max_attempts() -> u32 {
    const ATTEMPTS: u32 = 3;
    ATTEMPTS
}
91
92#[derive(Debug, Deserialize, Serialize, Clone)]
94#[serde(default)]
95pub struct AiConfig {
96 pub provider: String,
98 pub model: String,
100 pub timeout_seconds: u64,
102 pub allow_paid_models: bool,
104 pub max_tokens: u32,
106 pub temperature: f32,
108 pub circuit_breaker_threshold: u32,
110 pub circuit_breaker_reset_seconds: u64,
112 #[serde(default = "default_retry_max_attempts")]
114 pub retry_max_attempts: u32,
115 pub tasks: Option<TasksConfig>,
117 pub fallback: Option<FallbackConfig>,
119 pub custom_guidance: Option<String>,
125 pub validation_enabled: bool,
131}
132
133impl Default for AiConfig {
134 fn default() -> Self {
135 Self {
136 provider: "openrouter".to_string(),
137 model: DEFAULT_OPENROUTER_MODEL.to_string(),
138 timeout_seconds: 30,
139 allow_paid_models: true,
140 max_tokens: 4096,
141 temperature: 0.3,
142 circuit_breaker_threshold: 3,
143 circuit_breaker_reset_seconds: 60,
144 retry_max_attempts: default_retry_max_attempts(),
145 tasks: None,
146 fallback: None,
147 custom_guidance: None,
148 validation_enabled: true,
149 }
150 }
151}
152
153impl AiConfig {
154 #[must_use]
167 pub fn resolve_for_task(&self, task: TaskType) -> (String, String) {
168 let task_override = match task {
169 TaskType::Triage => self.tasks.as_ref().and_then(|t| t.triage.as_ref()),
170 TaskType::Review => self.tasks.as_ref().and_then(|t| t.review.as_ref()),
171 TaskType::Create => self.tasks.as_ref().and_then(|t| t.create.as_ref()),
172 };
173
174 let provider = task_override
175 .and_then(|o| o.provider.clone())
176 .unwrap_or_else(|| self.provider.clone());
177
178 let model = task_override
179 .and_then(|o| o.model.clone())
180 .unwrap_or_else(|| self.model.clone());
181
182 (provider, model)
183 }
184}