1use super::deployment::DeploymentId;
7use std::collections::HashMap;
8use std::sync::RwLock;
9
/// Category of failure that selects which fallback table to consult.
///
/// Each variant corresponds to exactly one table inside [`FallbackConfig`];
/// lookups for one category never see entries registered for another.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FallbackType {
    /// General-purpose fallbacks (registered with `FallbackConfig::add_general`).
    General,
    /// Fallbacks for context-window failures (registered with
    /// `FallbackConfig::add_context_window`); presumably used when a request is
    /// too large for the primary model — confirm against the caller.
    ContextWindow,
    /// Fallbacks for content-policy failures (registered with
    /// `FallbackConfig::add_content_policy`).
    ContentPolicy,
    /// Fallbacks for rate-limit failures (registered with
    /// `FallbackConfig::add_rate_limit`).
    RateLimit,
}
32
33#[derive(Debug, Clone)]
43pub struct ExecutionResult<T> {
44 pub result: T,
46 pub deployment_id: DeploymentId,
48 pub attempts: u32,
50 pub model_used: String,
52 pub used_fallback: bool,
54 pub latency_us: u64,
56}
57
/// Per-model fallback tables, one table per [`FallbackType`].
///
/// Every table maps a primary model name to the list of fallback model names
/// registered for it, stored in the order given. Each table sits behind its own
/// `RwLock`, so lookups through `&self` may run from several threads at once.
#[derive(Default, Debug)]
pub struct FallbackConfig {
    /// Table consulted for [`FallbackType::General`].
    general: RwLock<HashMap<String, Vec<String>>>,

    /// Table consulted for [`FallbackType::ContextWindow`].
    context_window: RwLock<HashMap<String, Vec<String>>>,

    /// Table consulted for [`FallbackType::ContentPolicy`].
    content_policy: RwLock<HashMap<String, Vec<String>>>,

    /// Table consulted for [`FallbackType::RateLimit`].
    rate_limit: RwLock<HashMap<String, Vec<String>>>,
}
80
81impl FallbackConfig {
82 pub fn new() -> Self {
84 Self {
85 general: RwLock::new(HashMap::new()),
86 context_window: RwLock::new(HashMap::new()),
87 content_policy: RwLock::new(HashMap::new()),
88 rate_limit: RwLock::new(HashMap::new()),
89 }
90 }
91
92 pub fn add_general(self, model: &str, fallbacks: Vec<String>) -> Self {
100 self.general
101 .write()
102 .expect("FallbackConfig general lock poisoned")
103 .insert(model.to_string(), fallbacks);
104 self
105 }
106
107 pub fn add_context_window(self, model: &str, fallbacks: Vec<String>) -> Self {
115 self.context_window
116 .write()
117 .expect("FallbackConfig context_window lock poisoned")
118 .insert(model.to_string(), fallbacks);
119 self
120 }
121
122 pub fn add_content_policy(self, model: &str, fallbacks: Vec<String>) -> Self {
130 self.content_policy
131 .write()
132 .expect("FallbackConfig content_policy lock poisoned")
133 .insert(model.to_string(), fallbacks);
134 self
135 }
136
137 pub fn add_rate_limit(self, model: &str, fallbacks: Vec<String>) -> Self {
145 self.rate_limit
146 .write()
147 .expect("FallbackConfig rate_limit lock poisoned")
148 .insert(model.to_string(), fallbacks);
149 self
150 }
151
152 pub fn get_fallbacks_for_type(
161 &self,
162 model_name: &str,
163 fallback_type: FallbackType,
164 ) -> Vec<String> {
165 let lock = match fallback_type {
166 FallbackType::General => &self.general,
167 FallbackType::ContextWindow => &self.context_window,
168 FallbackType::ContentPolicy => &self.content_policy,
169 FallbackType::RateLimit => &self.rate_limit,
170 };
171
172 lock.read()
173 .expect("FallbackConfig lock poisoned")
174 .get(model_name)
175 .cloned()
176 .unwrap_or_default()
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived `Debug` prints exactly the variant name.
    #[test]
    fn test_fallback_type_debug() {
        assert_eq!(format!("{:?}", FallbackType::General), "General");
        assert_eq!(format!("{:?}", FallbackType::ContextWindow), "ContextWindow");
        assert_eq!(format!("{:?}", FallbackType::ContentPolicy), "ContentPolicy");
        assert_eq!(format!("{:?}", FallbackType::RateLimit), "RateLimit");
    }

    /// Calls `Clone::clone` explicitly; a plain `let` binding would only exercise `Copy`.
    #[test]
    #[allow(clippy::clone_on_copy)] // intentional: this test targets the `Clone` impl
    fn test_fallback_type_clone() {
        let t = FallbackType::ContextWindow;
        let cloned = t.clone();
        assert_eq!(cloned, t);
        assert_eq!(cloned, FallbackType::ContextWindow);
    }

    /// Using `t` after binding it to `copied` only compiles because the type is `Copy`.
    #[test]
    fn test_fallback_type_copy() {
        let t = FallbackType::RateLimit;
        let copied = t;
        assert_eq!(t, copied);
    }

    #[test]
    fn test_fallback_type_eq() {
        assert_eq!(FallbackType::General, FallbackType::General);
        assert_ne!(FallbackType::General, FallbackType::ContextWindow);
        assert_ne!(FallbackType::ContextWindow, FallbackType::ContentPolicy);
        assert_ne!(FallbackType::ContentPolicy, FallbackType::RateLimit);
        assert_ne!(FallbackType::General, FallbackType::RateLimit);
    }

    /// Every field set at construction is readable back unchanged.
    #[test]
    fn test_execution_result_creation() {
        let result = ExecutionResult {
            result: "success".to_string(),
            deployment_id: "openai-gpt-4".to_string(),
            attempts: 1,
            model_used: "gpt-4".to_string(),
            used_fallback: false,
            latency_us: 1500,
        };

        assert_eq!(result.result, "success");
        assert_eq!(result.deployment_id, "openai-gpt-4");
        assert_eq!(result.attempts, 1);
        assert_eq!(result.model_used, "gpt-4");
        assert!(!result.used_fallback);
        assert_eq!(result.latency_us, 1500);
    }

    #[test]
    fn test_execution_result_with_fallback() {
        let result = ExecutionResult {
            result: 42,
            deployment_id: "anthropic-claude-3".to_string(),
            attempts: 3,
            model_used: "claude-3".to_string(),
            used_fallback: true,
            latency_us: 5000,
        };

        assert_eq!(result.result, 42);
        assert_eq!(result.deployment_id, "anthropic-claude-3");
        assert_eq!(result.attempts, 3);
        assert_eq!(result.model_used, "claude-3");
        assert!(result.used_fallback);
    }

    #[test]
    fn test_execution_result_debug() {
        let result = ExecutionResult {
            result: "test",
            deployment_id: "openai-gpt-4".to_string(),
            attempts: 1,
            model_used: "gpt-4".to_string(),
            used_fallback: false,
            latency_us: 100,
        };
        let debug = format!("{:?}", result);
        assert!(debug.contains("ExecutionResult"));
        assert!(debug.contains("attempts"));
    }

    #[test]
    fn test_execution_result_clone() {
        let result = ExecutionResult {
            result: vec![1, 2, 3],
            deployment_id: "openai-gpt-4".to_string(),
            attempts: 2,
            model_used: "gpt-4".to_string(),
            used_fallback: true,
            latency_us: 2000,
        };
        let cloned = result.clone();
        assert_eq!(cloned.result, vec![1, 2, 3]);
        assert_eq!(cloned.deployment_id, result.deployment_id);
        assert_eq!(cloned.attempts, 2);
        assert!(cloned.used_fallback);
    }

    /// A freshly constructed config has no fallbacks registered.
    #[test]
    fn test_fallback_config_new() {
        let config = FallbackConfig::new();
        assert!(config
            .get_fallbacks_for_type("gpt-4", FallbackType::General)
            .is_empty());
    }

    #[test]
    fn test_fallback_config_default() {
        let config = FallbackConfig::default();
        assert!(config
            .get_fallbacks_for_type("gpt-4", FallbackType::General)
            .is_empty());
    }

    #[test]
    fn test_fallback_config_debug() {
        let config = FallbackConfig::new();
        let debug = format!("{:?}", config);
        assert!(debug.contains("FallbackConfig"));
    }

    /// Registered fallbacks come back in the order they were given.
    #[test]
    fn test_add_general_fallback() {
        let config = FallbackConfig::new().add_general(
            "gpt-4",
            vec!["gpt-3.5-turbo".to_string(), "claude-3".to_string()],
        );

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        assert_eq!(fallbacks, vec!["gpt-3.5-turbo", "claude-3"]);
    }

    #[test]
    fn test_add_context_window_fallback() {
        let config =
            FallbackConfig::new().add_context_window("gpt-4", vec!["gpt-4-32k".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow);
        assert_eq!(fallbacks, vec!["gpt-4-32k"]);
    }

    #[test]
    fn test_add_content_policy_fallback() {
        let config =
            FallbackConfig::new().add_content_policy("gpt-4", vec!["claude-3".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContentPolicy);
        assert_eq!(fallbacks, vec!["claude-3"]);
    }

    #[test]
    fn test_add_rate_limit_fallback() {
        let config = FallbackConfig::new().add_rate_limit(
            "gpt-4",
            vec!["gpt-3.5-turbo".to_string(), "gpt-4-turbo".to_string()],
        );

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::RateLimit);
        assert_eq!(fallbacks, vec!["gpt-3.5-turbo", "gpt-4-turbo"]);
    }

    #[test]
    fn test_builder_pattern_chaining() {
        let config = FallbackConfig::new()
            .add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()])
            .add_context_window("gpt-4", vec!["gpt-4-32k".to_string()])
            .add_content_policy("gpt-4", vec!["claude-3".to_string()])
            .add_rate_limit("gpt-4", vec!["gemini".to_string()]);

        for fallback_type in [
            FallbackType::General,
            FallbackType::ContextWindow,
            FallbackType::ContentPolicy,
            FallbackType::RateLimit,
        ] {
            assert_eq!(
                config.get_fallbacks_for_type("gpt-4", fallback_type).len(),
                1,
                "unexpected fallback count for {:?}",
                fallback_type
            );
        }
    }

    #[test]
    fn test_multiple_models() {
        let config = FallbackConfig::new()
            .add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()])
            .add_general("claude-3", vec!["gemini".to_string()]);

        let gpt4_fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        let claude_fallbacks = config.get_fallbacks_for_type("claude-3", FallbackType::General);

        assert_eq!(gpt4_fallbacks, vec!["gpt-3.5-turbo".to_string()]);
        assert_eq!(claude_fallbacks, vec!["gemini".to_string()]);
    }

    #[test]
    fn test_get_fallbacks_nonexistent_model() {
        let config = FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("nonexistent", FallbackType::General);
        assert!(fallbacks.is_empty());
    }

    /// Entries registered under one category are invisible to the others.
    #[test]
    fn test_get_fallbacks_wrong_type() {
        let config = FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow);
        assert!(fallbacks.is_empty());
    }

    #[test]
    fn test_empty_fallback_list() {
        let config = FallbackConfig::new().add_general("gpt-4", vec![]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        assert!(fallbacks.is_empty());
    }

    /// Re-registering a model replaces its list instead of appending to it.
    #[test]
    fn test_override_fallback() {
        let config = FallbackConfig::new()
            .add_general("gpt-4", vec!["first".to_string()])
            .add_general("gpt-4", vec!["second".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        assert_eq!(fallbacks, vec!["second".to_string()]);
    }

    /// Many threads reading through a shared `Arc` all observe the registered list.
    #[test]
    fn test_concurrent_reads() {
        use std::sync::Arc;
        use std::thread;

        let config =
            Arc::new(FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5".to_string()]));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let c = Arc::clone(&config);
                thread::spawn(move || {
                    for _ in 0..100 {
                        assert_eq!(
                            c.get_fallbacks_for_type("gpt-4", FallbackType::General),
                            vec!["gpt-3.5".to_string()]
                        );
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("reader thread panicked");
        }
    }

    /// Model names are opaque keys: punctuation is not interpreted.
    #[test]
    fn test_special_characters_in_model_name() {
        let config =
            FallbackConfig::new().add_general("model/v2.0:latest", vec!["backup".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("model/v2.0:latest", FallbackType::General);
        assert_eq!(fallbacks, vec!["backup"]);
    }

    #[test]
    fn test_unicode_in_model_name() {
        let config = FallbackConfig::new().add_general("模型", vec!["备份".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("模型", FallbackType::General);
        assert_eq!(fallbacks, vec!["备份".to_string()]);
    }

    #[test]
    fn test_empty_model_name() {
        let config = FallbackConfig::new().add_general("", vec!["fallback".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("", FallbackType::General);
        assert_eq!(fallbacks, vec!["fallback".to_string()]);
    }

    #[test]
    fn test_many_fallbacks() {
        let fallbacks: Vec<String> = (0..100).map(|i| format!("model_{}", i)).collect();
        let config = FallbackConfig::new().add_general("primary", fallbacks.clone());

        let result = config.get_fallbacks_for_type("primary", FallbackType::General);
        assert_eq!(result, fallbacks);
        assert_eq!(result[0], "model_0");
        assert_eq!(result[99], "model_99");
    }

    #[test]
    fn test_fallback_type_all_variants() {
        let config = FallbackConfig::new()
            .add_general("model", vec!["g".to_string()])
            .add_context_window("model", vec!["cw".to_string()])
            .add_content_policy("model", vec!["cp".to_string()])
            .add_rate_limit("model", vec!["rl".to_string()]);

        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::General),
            vec!["g"]
        );
        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::ContextWindow),
            vec!["cw"]
        );
        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::ContentPolicy),
            vec!["cp"]
        );
        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::RateLimit),
            vec!["rl"]
        );
    }
}