Skip to main content

aptu_core/config/
ai.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! AI provider configuration.
4
5use serde::{Deserialize, Serialize};
6
7/// Default `OpenRouter` model identifier.
8pub const DEFAULT_OPENROUTER_MODEL: &str = "mistralai/mistral-small-2603";
9/// Default `Gemini` model identifier.
10pub const DEFAULT_GEMINI_MODEL: &str = "gemini-3.1-flash-lite-preview";
11
12/// Task type for model selection.
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14pub enum TaskType {
15    /// Issue triage task.
16    Triage,
17    /// Pull request review task.
18    Review,
19    /// Label creation task.
20    Create,
21}
22
23/// Task-specific AI model override.
24#[derive(Debug, Deserialize, Serialize, Default, Clone)]
25#[serde(default)]
26pub struct TaskOverride {
27    /// Optional provider override for this task.
28    pub provider: Option<String>,
29    /// Optional model override for this task.
30    pub model: Option<String>,
31}
32
33/// Task-specific AI configuration.
34#[derive(Debug, Deserialize, Serialize, Default, Clone)]
35#[serde(default)]
36pub struct TasksConfig {
37    /// Triage task configuration.
38    pub triage: Option<TaskOverride>,
39    /// Review task configuration.
40    pub review: Option<TaskOverride>,
41    /// Create task configuration.
42    pub create: Option<TaskOverride>,
43}
44
45/// Single entry in the fallback provider chain.
46#[derive(Debug, Clone, Serialize)]
47pub struct FallbackEntry {
48    /// Provider name (e.g., "openrouter", "anthropic", "gemini").
49    pub provider: String,
50    /// Optional model override for this specific provider.
51    pub model: Option<String>,
52}
53
54impl<'de> Deserialize<'de> for FallbackEntry {
55    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
56    where
57        D: serde::Deserializer<'de>,
58    {
59        #[derive(Deserialize)]
60        #[serde(untagged)]
61        enum EntryVariant {
62            String(String),
63            Struct {
64                provider: String,
65                model: Option<String>,
66            },
67        }
68
69        match EntryVariant::deserialize(deserializer)? {
70            EntryVariant::String(provider) => Ok(FallbackEntry {
71                provider,
72                model: None,
73            }),
74            EntryVariant::Struct { provider, model } => Ok(FallbackEntry { provider, model }),
75        }
76    }
77}
78
79/// Fallback provider chain configuration.
80#[derive(Debug, Deserialize, Serialize, Clone, Default)]
81#[serde(default)]
82pub struct FallbackConfig {
83    /// Chain of fallback entries to try in order when primary fails.
84    pub chain: Vec<FallbackEntry>,
85}
86
87/// Default value for `retry_max_attempts`.
88fn default_retry_max_attempts() -> u32 {
89    3
90}
91
92/// AI provider settings.
93#[derive(Debug, Deserialize, Serialize, Clone)]
94#[serde(default)]
95pub struct AiConfig {
96    /// AI provider: one of `"gemini"`, `"openrouter"`, `"groq"`, `"cerebras"`, `"zenmux"`, or `"zai"`.
97    pub provider: String,
98    /// Model identifier.
99    pub model: String,
100    /// Request timeout in seconds.
101    pub timeout_seconds: u64,
102    /// Allow paid models (default: true).
103    pub allow_paid_models: bool,
104    /// Maximum tokens for API responses.
105    pub max_tokens: u32,
106    /// Temperature for API requests (0.0-1.0).
107    pub temperature: f32,
108    /// Circuit breaker failure threshold before opening (default: 3).
109    pub circuit_breaker_threshold: u32,
110    /// Circuit breaker reset timeout in seconds (default: 60).
111    pub circuit_breaker_reset_seconds: u64,
112    /// Maximum retry attempts for rate-limited requests (default: 3).
113    #[serde(default = "default_retry_max_attempts")]
114    pub retry_max_attempts: u32,
115    /// Task-specific model overrides.
116    pub tasks: Option<TasksConfig>,
117    /// Fallback provider chain for resilience.
118    pub fallback: Option<FallbackConfig>,
119    /// Custom guidance to override or extend default best practices.
120    ///
121    /// Allows users to provide project-specific tooling recommendations
122    /// that will be appended to the default best practices context.
123    /// Useful for enforcing project-specific choices (e.g., poetry instead of uv).
124    pub custom_guidance: Option<String>,
125    /// Enable pre-flight model validation with fuzzy matching (default: true).
126    ///
127    /// When enabled, validates that the configured model ID exists in the
128    /// cached model registry before creating an AI client. Provides helpful
129    /// suggestions if an invalid model ID is detected.
130    pub validation_enabled: bool,
131}
132
133impl Default for AiConfig {
134    fn default() -> Self {
135        Self {
136            provider: "openrouter".to_string(),
137            model: DEFAULT_OPENROUTER_MODEL.to_string(),
138            timeout_seconds: 30,
139            allow_paid_models: true,
140            max_tokens: 4096,
141            temperature: 0.3,
142            circuit_breaker_threshold: 3,
143            circuit_breaker_reset_seconds: 60,
144            retry_max_attempts: default_retry_max_attempts(),
145            tasks: None,
146            fallback: None,
147            custom_guidance: None,
148            validation_enabled: true,
149        }
150    }
151}
152
153impl AiConfig {
154    /// Resolve provider and model for a specific task type.
155    ///
156    /// Returns a tuple of (provider, model) by checking task-specific overrides first,
157    /// then falling back to the default provider and model.
158    ///
159    /// # Arguments
160    ///
161    /// * `task` - The task type to resolve configuration for
162    ///
163    /// # Returns
164    ///
165    /// A tuple of (`provider_name`, `model_name`) strings
166    #[must_use]
167    pub fn resolve_for_task(&self, task: TaskType) -> (String, String) {
168        let task_override = match task {
169            TaskType::Triage => self.tasks.as_ref().and_then(|t| t.triage.as_ref()),
170            TaskType::Review => self.tasks.as_ref().and_then(|t| t.review.as_ref()),
171            TaskType::Create => self.tasks.as_ref().and_then(|t| t.create.as_ref()),
172        };
173
174        let provider = task_override
175            .and_then(|o| o.provider.clone())
176            .unwrap_or_else(|| self.provider.clone());
177
178        let model = task_override
179            .and_then(|o| o.model.clone())
180            .unwrap_or_else(|| self.model.clone());
181
182        (provider, model)
183    }
184}