Skip to main content

ai_agent/utils/model/
model_validate.rs

1//! Model validation functions.
2//!
3//! Translated from openclaudecode/src/utils/model/validateModel.ts
4
5use crate::constants::env::ai;
6use std::collections::HashMap;
7use std::sync::{Mutex, OnceLock};
8
9// =============================================================================
10// TYPES
11// =============================================================================
12
/// Outcome of validating a model name.
///
/// The constructors keep the two fields consistent: [`ModelValidationResult::valid`] carries no
/// error, while [`ModelValidationResult::invalid`] always carries one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelValidationResult {
    /// Whether the model is valid.
    pub valid: bool,
    /// Human-readable reason the model was rejected; `None` for results built with `valid()`.
    pub error: Option<String>,
}

impl ModelValidationResult {
    /// Builds a successful result with no error message.
    pub fn valid() -> Self {
        Self {
            valid: true,
            error: None,
        }
    }

    /// Builds a failed result whose explanation is `error`.
    pub fn invalid(error: impl Into<String>) -> Self {
        Self {
            valid: false,
            error: Some(error.into()),
        }
    }
}
37
38// =============================================================================
39// CACHE
40// =============================================================================
41
/// Process-wide record of model names that have passed API validation, so validating the same
/// name again skips the API call. Only successes are inserted; the stored value is always `true`.
static VALID_MODEL_CACHE: OnceLock<Mutex<HashMap<String, bool>>> = OnceLock::new();

/// Returns the cache, creating it on first use.
fn get_valid_model_cache() -> &'static Mutex<HashMap<String, bool>> {
    VALID_MODEL_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

/// Locks the cache, recovering the map if an earlier lock holder panicked.
///
/// The only writer in this module performs a single `insert`, so a poisoned lock cannot expose a
/// half-updated map. Recovering keeps one unrelated panic from making every later validation
/// panic as well (which `.lock().unwrap()` would do).
fn lock_valid_model_cache() -> std::sync::MutexGuard<'static, HashMap<String, bool>> {
    get_valid_model_cache()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Records `model` as successfully validated.
fn cache_valid_model(model: &str) {
    lock_valid_model_cache().insert(model.to_string(), true);
}

/// Returns `true` if `model` was previously recorded by [`cache_valid_model`].
fn is_cached_as_valid(model: &str) -> bool {
    lock_valid_model_cache().get(model).copied().unwrap_or(false)
}
58
59// =============================================================================
60// MODEL ALIASES
61// =============================================================================
62
/// Shorthand model names that are accepted without contacting the API.
const MODEL_ALIASES: &[&str] = &["opus", "sonnet", "haiku", "opusplan", "haikuplan", "best"];

/// Returns `true` when `model`, lowercased, is exactly one of [`MODEL_ALIASES`].
///
/// No trimming is performed, so surrounding whitespace prevents a match.
fn is_model_alias(model: &str) -> bool {
    let lowered = model.to_lowercase();
    MODEL_ALIASES.iter().any(|&alias| alias == lowered)
}
69
70// =============================================================================
71// VALIDATION
72// =============================================================================
73
74/// Validates a model by attempting an actual API call.
75/// This is an async function that would make an actual API call.
76/// In this stub, we provide a simplified synchronous version.
77pub async fn validate_model(model: &str) -> ModelValidationResult {
78    let normalized_model = model.trim().to_string();
79
80    // Empty model is invalid
81    if normalized_model.is_empty() {
82        return ModelValidationResult::invalid("Model name cannot be empty");
83    }
84
85    // Check against availableModels allowlist before any API call
86    if !is_model_allowed(&normalized_model) {
87        return ModelValidationResult::invalid(format!(
88            "Model '{}' is not in the list of available models",
89            normalized_model
90        ));
91    }
92
93    // Check if it's a known alias (these are always valid)
94    let lower_model = normalized_model.to_lowercase();
95    if MODEL_ALIASES.contains(&lower_model.as_str()) {
96        return ModelValidationResult::valid();
97    }
98
99    // Check if it matches ANTHROPIC_CUSTOM_MODEL_OPTION (pre-validated by the user)
100    if let Ok(custom_model) = std::env::var(ai::ANTHROPIC_CUSTOM_MODEL_OPTION) {
101        if normalized_model == custom_model {
102            return ModelValidationResult::valid();
103        }
104    }
105
106    // Check cache first
107    if is_cached_as_valid(&normalized_model) {
108        return ModelValidationResult::valid();
109    }
110
111    // Try to make an actual API call with minimal parameters
112    // In a real implementation, this would call sideQuery or similar
113    match do_validate_api_call(&normalized_model).await {
114        Ok(_) => {
115            // If we got here, the model is valid
116            cache_valid_model(&normalized_model);
117            ModelValidationResult::valid()
118        }
119        Err(e) => handle_validation_error(e, &normalized_model),
120    }
121}
122
123/// Do actual API call to validate model
124async fn do_validate_api_call(_model: &str) -> Result<(), ValidationError> {
125    // Stub - would need to make actual API call via sideQuery or similar
126    // For now, we'll just return Ok as a placeholder
127    // In the real implementation, this would call the API and return an error if it fails
128    Ok(())
129}
130
/// Failure categories a validation API call can report.
///
/// Every variant carries the underlying error text. The type implements
/// [`std::error::Error`], so it can be boxed or propagated like any other error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The model does not exist (a not-found / HTTP 404 response).
    NotFound(String),
    /// The API rejected the supplied credentials.
    Authentication(String),
    /// The API could not be reached.
    Connection(String),
    /// Any other API error; the text presumably includes the response body.
    Api(String),
    /// An error that fits none of the other categories.
    Unknown(String),
}

impl std::fmt::Display for ValidationError {
    /// Formats as `"<Variant>: <message>"`, e.g. `"NotFound: ..."`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (kind, msg) = match self {
            ValidationError::NotFound(msg) => ("NotFound", msg),
            ValidationError::Authentication(msg) => ("Authentication", msg),
            ValidationError::Connection(msg) => ("Connection", msg),
            ValidationError::Api(msg) => ("Api", msg),
            ValidationError::Unknown(msg) => ("Unknown", msg),
        };
        write!(f, "{}: {}", kind, msg)
    }
}

impl std::error::Error for ValidationError {}
152
153/// Handle validation error and return appropriate result
154fn handle_validation_error(error: ValidationError, model_name: &str) -> ModelValidationResult {
155    match error {
156        // NotFoundError (404) means the model doesn't exist
157        ValidationError::NotFound(_) => {
158            let fallback = get_3p_fallback_suggestion(model_name);
159            let suggestion = fallback
160                .map(|f| format!(". Try '{}' instead", f))
161                .unwrap_or_default();
162            ModelValidationResult::invalid(format!(
163                "Model '{}' not found{}",
164                model_name, suggestion
165            ))
166        }
167
168        // For other API errors, provide context-specific messages
169        ValidationError::Authentication(_) => ModelValidationResult::invalid(
170            "Authentication failed. Please check your API credentials.",
171        ),
172
173        ValidationError::Connection(_) => {
174            ModelValidationResult::invalid("Network error. Please check your internet connection.")
175        }
176
177        // Check error body for model-specific errors
178        ValidationError::Api(msg) => {
179            if msg.contains("model:") && msg.contains("not_found_error") {
180                return ModelValidationResult::invalid(format!("Model '{}' not found", model_name));
181            }
182            // Generic API error
183            ModelValidationResult::invalid(format!("API error: {}", msg))
184        }
185
186        // For unknown errors, be safe and reject
187        ValidationError::Unknown(msg) => {
188            ModelValidationResult::invalid(format!("Unable to validate model: {}", msg))
189        }
190    }
191}
192
193// =============================================================================
194// FALLBACK SUGGESTIONS
195// =============================================================================
196
197/// Suggest a fallback model for 3P users when the selected model is unavailable.
198fn get_3p_fallback_suggestion(model: &str) -> Option<String> {
199    if get_api_provider() == "firstParty" {
200        return None;
201    }
202
203    let lower_model = model.to_lowercase();
204
205    if lower_model.contains("opus-4-6") || lower_model.contains("opus_4_6") {
206        return Some(get_model_strings().opus_41.clone());
207    }
208    if lower_model.contains("sonnet-4-6") || lower_model.contains("sonnet_4_6") {
209        return Some(get_model_strings().sonnet_45.clone());
210    }
211    if lower_model.contains("sonnet-4-5") || lower_model.contains("sonnet_4_5") {
212        return Some(get_model_strings().sonnet_40.clone());
213    }
214
215    None
216}
217
218// =============================================================================
219// STUB HELPERS
220// =============================================================================
221
222/// Get API provider
223fn get_api_provider() -> String {
224    std::env::var(ai::API_PROVIDER)
225        .ok()
226        .unwrap_or_else(|| "firstParty".to_string())
227}
228
/// Reports whether `_model` is permitted by the availableModels allowlist.
///
/// Placeholder: the `modelAllowlist` logic has not been ported, so every name is accepted.
fn is_model_allowed(_model: &str) -> bool {
    // TODO: consult the modelAllowlist module once it exists.
    true
}
235
/// Returns the model identifiers offered as fallback suggestions.
///
/// Stub: the identifiers are fixed and do not depend on the configured API provider.
fn get_model_strings() -> ModelStrings {
    ModelStrings {
        opus_41: "claude-opus-4-1-20250805".to_string(),
        opus_45: "claude-opus-4-5-20251101".to_string(),
        // NOTE(review): confirm the 4.6 identifiers against Anthropic's published model list.
        opus_46: "claude-opus-4-6-20251106".to_string(),
        // Claude Sonnet 4's published ID has no `-0` minor component.
        sonnet_40: "claude-sonnet-4-20250514".to_string(),
        // `20241022` is the Claude 3.5 Sonnet snapshot date, not Sonnet 4.5's.
        sonnet_45: "claude-sonnet-4-5-20250929".to_string(),
        sonnet_46: "claude-sonnet-4-6-20251106".to_string(),
    }
}

/// Concrete model identifiers, one per model version this module can refer to.
#[derive(Debug, Clone)]
struct ModelStrings {
    /// Claude Opus 4.1.
    opus_41: String,
    /// Claude Opus 4.5.
    opus_45: String,
    /// Claude Opus 4.6.
    opus_46: String,
    /// Claude Sonnet 4.
    sonnet_40: String,
    /// Claude Sonnet 4.5.
    sonnet_45: String,
    /// Claude Sonnet 4.6.
    sonnet_46: String,
}

impl ModelStrings {
    /// Owned copy of the Claude Opus 4.1 identifier.
    fn opus_41(&self) -> String {
        self.opus_41.clone()
    }
    /// Owned copy of the Claude Opus 4.5 identifier.
    fn opus_45(&self) -> String {
        self.opus_45.clone()
    }
    /// Owned copy of the Claude Opus 4.6 identifier.
    fn opus_46(&self) -> String {
        self.opus_46.clone()
    }
    /// Owned copy of the Claude Sonnet 4 identifier.
    fn sonnet_40(&self) -> String {
        self.sonnet_40.clone()
    }
    /// Owned copy of the Claude Sonnet 4.5 identifier.
    fn sonnet_45(&self) -> String {
        self.sonnet_45.clone()
    }
    /// Owned copy of the Claude Sonnet 4.6 identifier.
    fn sonnet_46(&self) -> String {
        self.sonnet_46.clone()
    }
}