ai_agent/utils/model/
model_validate.rs1use crate::constants::env::ai;
6use std::collections::HashMap;
7use std::sync::{Mutex, OnceLock};
8
/// Outcome of validating a model name.
///
/// Build values with [`ModelValidationResult::valid`] or
/// [`ModelValidationResult::invalid`] so that `valid` and `error` stay
/// consistent with each other.
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use = "a model validation result should be checked before the model is used"]
pub struct ModelValidationResult {
    /// Whether the model passed validation.
    pub valid: bool,
    /// Human-readable reason for a failed validation. The constructors set this
    /// to `None` for valid results and `Some(..)` for invalid ones.
    pub error: Option<String>,
}

impl ModelValidationResult {
    /// Builds a successful result with no error message.
    pub fn valid() -> Self {
        Self {
            valid: true,
            error: None,
        }
    }

    /// Builds a failed result whose `error` is the given message.
    pub fn invalid(error: impl Into<String>) -> Self {
        Self {
            valid: false,
            error: Some(error.into()),
        }
    }
}
37
/// Process-wide record of model names that have already passed the remote
/// validation call, so repeat validations can skip the round-trip.
static VALID_MODEL_CACHE: OnceLock<Mutex<HashMap<String, bool>>> = OnceLock::new();

/// Returns the validation cache, creating it empty on first use.
fn get_valid_model_cache() -> &'static Mutex<HashMap<String, bool>> {
    VALID_MODEL_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

/// Locks the validation cache, recovering the guard if the mutex is poisoned.
///
/// Within this module the map is only written by `cache_valid_model`, which
/// performs a single complete `insert`, so a panic in another lock holder
/// cannot leave the map half-updated. Recovering keeps one unrelated panic
/// from making every later validation panic as well.
fn lock_valid_model_cache() -> std::sync::MutexGuard<'static, HashMap<String, bool>> {
    get_valid_model_cache()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Records `model` as having passed validation.
fn cache_valid_model(model: &str) {
    lock_valid_model_cache().insert(model.to_owned(), true);
}

/// Returns `true` if `model` was previously recorded as valid.
fn is_cached_as_valid(model: &str) -> bool {
    lock_valid_model_cache().get(model).copied().unwrap_or(false)
}
58
/// Short model names that are accepted without a remote validation call.
/// Callers compare against these after lowercasing the input.
const MODEL_ALIASES: &[&str] = &["opus", "sonnet", "haiku", "opusplan", "haikuplan", "best"];

/// Reports whether `model`, once lowercased, exactly equals one of
/// `MODEL_ALIASES`. No trimming is performed here.
fn is_model_alias(model: &str) -> bool {
    let lowered = model.to_lowercase();
    MODEL_ALIASES.iter().any(|&alias| lowered == alias)
}
69
70pub async fn validate_model(model: &str) -> ModelValidationResult {
78 let normalized_model = model.trim().to_string();
79
80 if normalized_model.is_empty() {
82 return ModelValidationResult::invalid("Model name cannot be empty");
83 }
84
85 if !is_model_allowed(&normalized_model) {
87 return ModelValidationResult::invalid(format!(
88 "Model '{}' is not in the list of available models",
89 normalized_model
90 ));
91 }
92
93 let lower_model = normalized_model.to_lowercase();
95 if MODEL_ALIASES.contains(&lower_model.as_str()) {
96 return ModelValidationResult::valid();
97 }
98
99 if let Ok(custom_model) = std::env::var(ai::ANTHROPIC_CUSTOM_MODEL_OPTION) {
101 if normalized_model == custom_model {
102 return ModelValidationResult::valid();
103 }
104 }
105
106 if is_cached_as_valid(&normalized_model) {
108 return ModelValidationResult::valid();
109 }
110
111 match do_validate_api_call(&normalized_model).await {
114 Ok(_) => {
115 cache_valid_model(&normalized_model);
117 ModelValidationResult::valid()
118 }
119 Err(e) => handle_validation_error(e, &normalized_model),
120 }
121}
122
123async fn do_validate_api_call(_model: &str) -> Result<(), ValidationError> {
125 Ok(())
129}
130
/// Failure categories produced by the remote model-validation call.
///
/// Each variant carries the underlying error text; `handle_validation_error`
/// turns these into user-facing messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The model was reported as not found.
    NotFound(String),
    /// The request was rejected for authentication reasons.
    Authentication(String),
    /// The request failed at the connection level.
    Connection(String),
    /// The API returned an error not covered by the variants above.
    Api(String),
    /// Any other failure.
    Unknown(String),
}

impl std::fmt::Display for ValidationError {
    /// Formats as `<Variant>: <message>`, e.g. `NotFound: no such model`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (kind, message) = match self {
            ValidationError::NotFound(msg) => ("NotFound", msg),
            ValidationError::Authentication(msg) => ("Authentication", msg),
            ValidationError::Connection(msg) => ("Connection", msg),
            ValidationError::Api(msg) => ("Api", msg),
            ValidationError::Unknown(msg) => ("Unknown", msg),
        };
        write!(f, "{}: {}", kind, message)
    }
}

/// Allows `ValidationError` to be propagated with `?` into
/// `Box<dyn std::error::Error>` and used with generic error tooling.
impl std::error::Error for ValidationError {}
152
153fn handle_validation_error(error: ValidationError, model_name: &str) -> ModelValidationResult {
155 match error {
156 ValidationError::NotFound(_) => {
158 let fallback = get_3p_fallback_suggestion(model_name);
159 let suggestion = fallback
160 .map(|f| format!(". Try '{}' instead", f))
161 .unwrap_or_default();
162 ModelValidationResult::invalid(format!(
163 "Model '{}' not found{}",
164 model_name, suggestion
165 ))
166 }
167
168 ValidationError::Authentication(_) => ModelValidationResult::invalid(
170 "Authentication failed. Please check your API credentials.",
171 ),
172
173 ValidationError::Connection(_) => {
174 ModelValidationResult::invalid("Network error. Please check your internet connection.")
175 }
176
177 ValidationError::Api(msg) => {
179 if msg.contains("model:") && msg.contains("not_found_error") {
180 return ModelValidationResult::invalid(format!("Model '{}' not found", model_name));
181 }
182 ModelValidationResult::invalid(format!("API error: {}", msg))
184 }
185
186 ValidationError::Unknown(msg) => {
188 ModelValidationResult::invalid(format!("Unable to validate model: {}", msg))
189 }
190 }
191}
192
193fn get_3p_fallback_suggestion(model: &str) -> Option<String> {
199 if get_api_provider() == "firstParty" {
200 return None;
201 }
202
203 let lower_model = model.to_lowercase();
204
205 if lower_model.contains("opus-4-6") || lower_model.contains("opus_4_6") {
206 return Some(get_model_strings().opus_41.clone());
207 }
208 if lower_model.contains("sonnet-4-6") || lower_model.contains("sonnet_4_6") {
209 return Some(get_model_strings().sonnet_45.clone());
210 }
211 if lower_model.contains("sonnet-4-5") || lower_model.contains("sonnet_4_5") {
212 return Some(get_model_strings().sonnet_40.clone());
213 }
214
215 None
216}
217
218fn get_api_provider() -> String {
224 std::env::var(ai::API_PROVIDER)
225 .ok()
226 .unwrap_or_else(|| "firstParty".to_string())
227}
228
/// Reports whether `_model` passes the model allow-list.
///
/// NOTE(review): placeholder; no allow-list is applied and every name,
/// including the empty string, is accepted.
fn is_model_allowed(_model: &str) -> bool {
    // No restriction configured yet.
    true
}
235
/// Builds the table of concrete model IDs used for fallback suggestions.
///
/// NOTE(review): these are first-party Anthropic API IDs. The visible caller
/// (`get_3p_fallback_suggestion`) uses them only for third-party providers,
/// whose ID formats may differ. Confirm these are what users should enter.
fn get_model_strings() -> ModelStrings {
    ModelStrings {
        opus_41: "claude-opus-4-1-20250805".to_string(),
        opus_45: "claude-opus-4-5-20251101".to_string(),
        // NOTE(review): date suffix unverified; confirm against the published model list.
        opus_46: "claude-opus-4-6-20251106".to_string(),
        // The dated Claude Sonnet 4 ID has no minor-version segment.
        sonnet_40: "claude-sonnet-4-20250514".to_string(),
        sonnet_45: "claude-sonnet-4-5-20250929".to_string(),
        // NOTE(review): date suffix unverified; confirm against the published model list.
        sonnet_46: "claude-sonnet-4-6-20251106".to_string(),
    }
}

/// Concrete model IDs, one field per model family/version.
#[derive(Debug, Clone)]
struct ModelStrings {
    /// Claude Opus 4.1.
    opus_41: String,
    /// Claude Opus 4.5.
    opus_45: String,
    /// Claude Opus 4.6.
    opus_46: String,
    /// Claude Sonnet 4.
    sonnet_40: String,
    /// Claude Sonnet 4.5.
    sonnet_45: String,
    /// Claude Sonnet 4.6.
    sonnet_46: String,
}

// Accessors cover every family/version; not all of them are read in this
// module, so silence the dead-code lint rather than drop entries.
#[allow(dead_code)]
impl ModelStrings {
    /// Owned copy of the Opus 4.1 ID.
    fn opus_41(&self) -> String {
        self.opus_41.clone()
    }
    /// Owned copy of the Opus 4.5 ID.
    fn opus_45(&self) -> String {
        self.opus_45.clone()
    }
    /// Owned copy of the Opus 4.6 ID.
    fn opus_46(&self) -> String {
        self.opus_46.clone()
    }
    /// Owned copy of the Sonnet 4 ID.
    fn sonnet_40(&self) -> String {
        self.sonnet_40.clone()
    }
    /// Owned copy of the Sonnet 4.5 ID.
    fn sonnet_45(&self) -> String {
        self.sonnet_45.clone()
    }
    /// Owned copy of the Sonnet 4.6 ID.
    fn sonnet_46(&self) -> String {
        self.sonnet_46.clone()
    }
}