// Source path: litellm_rs/core/router/fallback.rs

//! Fallback configuration and execution result types
//!
//! This module defines fallback configuration for error handling
//! and execution result metadata.

use super::deployment::DeploymentId;
use std::collections::HashMap;
use std::sync::RwLock;

/// Fallback type enumeration
///
/// Defines different types of fallback scenarios that can trigger alternative model selection.
/// Each type corresponds to a specific error condition and has its own fallback mapping.
///
/// The type is `Copy` and `Hash`, so it can be passed by value and used as a key in
/// hash-based collections.
///
/// ## Fallback Priority
///
/// When determining fallback models, the router checks in this order:
/// 1. Specific fallback type (ContextWindow, ContentPolicy, RateLimit)
/// 2. General fallback (if no specific type matches)
/// 3. Empty list (no fallback available)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FallbackType {
    /// General fallback for any error
    General,
    /// Context window exceeded - model cannot handle the input size
    ContextWindow,
    /// Content policy violation - content was filtered/rejected
    ContentPolicy,
    /// Rate limit exceeded - too many requests
    RateLimit,
}

33/// Execution result with metadata
34///
35/// Contains the result of a successful execution along with metadata about
36/// the execution such as which deployment was used, how many attempts were made,
37/// and whether fallback was used.
38///
39/// # Type Parameters
40///
41/// * `T` - The type of the result value
42#[derive(Debug, Clone)]
43pub struct ExecutionResult<T> {
44    /// The successful result value
45    pub result: T,
46    /// The deployment ID that successfully handled the request
47    pub deployment_id: DeploymentId,
48    /// Total number of attempts across all retries and fallbacks
49    pub attempts: u32,
50    /// The actual model that was used (may differ from requested if fallback occurred)
51    pub model_used: String,
52    /// Whether a fallback model was used (true if not the original model)
53    pub used_fallback: bool,
54    /// Total execution latency in microseconds (including retries)
55    pub latency_us: u64,
56}
57
58/// Fallback configuration
59///
60/// Manages fallback mappings for different error types. Each model can have different
61/// fallback models configured for different scenarios.
62///
63/// ## Thread Safety
64///
65/// Uses `RwLock` to allow concurrent reads and exclusive writes.
66#[derive(Debug, Default)]
67pub struct FallbackConfig {
68    /// General fallback: model_name -> fallback model_names
69    general: RwLock<HashMap<String, Vec<String>>>,
70
71    /// Context window exceeded fallback
72    context_window: RwLock<HashMap<String, Vec<String>>>,
73
74    /// Content policy violation fallback
75    content_policy: RwLock<HashMap<String, Vec<String>>>,
76
77    /// Rate limit exceeded fallback
78    rate_limit: RwLock<HashMap<String, Vec<String>>>,
79}
80
81impl FallbackConfig {
82    /// Create a new empty fallback configuration
83    pub fn new() -> Self {
84        Self {
85            general: RwLock::new(HashMap::new()),
86            context_window: RwLock::new(HashMap::new()),
87            content_policy: RwLock::new(HashMap::new()),
88            rate_limit: RwLock::new(HashMap::new()),
89        }
90    }
91
92    /// Add general fallback models for a model (builder pattern)
93    ///
94    /// General fallbacks are used when no specific fallback type matches the error.
95    ///
96    /// # Panics
97    ///
98    /// Panics if the internal lock is poisoned (indicates a bug in the calling code).
99    pub fn add_general(self, model: &str, fallbacks: Vec<String>) -> Self {
100        self.general
101            .write()
102            .expect("FallbackConfig general lock poisoned")
103            .insert(model.to_string(), fallbacks);
104        self
105    }
106
107    /// Add context window fallback models for a model (builder pattern)
108    ///
109    /// Context window fallbacks are used when the input exceeds the model's maximum context length.
110    ///
111    /// # Panics
112    ///
113    /// Panics if the internal lock is poisoned (indicates a bug in the calling code).
114    pub fn add_context_window(self, model: &str, fallbacks: Vec<String>) -> Self {
115        self.context_window
116            .write()
117            .expect("FallbackConfig context_window lock poisoned")
118            .insert(model.to_string(), fallbacks);
119        self
120    }
121
122    /// Add content policy fallback models for a model (builder pattern)
123    ///
124    /// Content policy fallbacks are used when content is filtered or rejected by safety systems.
125    ///
126    /// # Panics
127    ///
128    /// Panics if the internal lock is poisoned (indicates a bug in the calling code).
129    pub fn add_content_policy(self, model: &str, fallbacks: Vec<String>) -> Self {
130        self.content_policy
131            .write()
132            .expect("FallbackConfig content_policy lock poisoned")
133            .insert(model.to_string(), fallbacks);
134        self
135    }
136
137    /// Add rate limit fallback models for a model (builder pattern)
138    ///
139    /// Rate limit fallbacks are used when the model's rate limit is exceeded.
140    ///
141    /// # Panics
142    ///
143    /// Panics if the internal lock is poisoned (indicates a bug in the calling code).
144    pub fn add_rate_limit(self, model: &str, fallbacks: Vec<String>) -> Self {
145        self.rate_limit
146            .write()
147            .expect("FallbackConfig rate_limit lock poisoned")
148            .insert(model.to_string(), fallbacks);
149        self
150    }
151
152    /// Get fallback models for a specific type
153    ///
154    /// Returns a cloned vector of fallback model names. Returns empty vector if no fallbacks
155    /// are configured for the given model and type.
156    ///
157    /// # Panics
158    ///
159    /// Panics if the internal lock is poisoned (indicates a bug in the calling code).
160    pub fn get_fallbacks_for_type(
161        &self,
162        model_name: &str,
163        fallback_type: FallbackType,
164    ) -> Vec<String> {
165        let lock = match fallback_type {
166            FallbackType::General => &self.general,
167            FallbackType::ContextWindow => &self.context_window,
168            FallbackType::ContentPolicy => &self.content_policy,
169            FallbackType::RateLimit => &self.rate_limit,
170        };
171
172        lock.read()
173            .expect("FallbackConfig lock poisoned")
174            .get(model_name)
175            .cloned()
176            .unwrap_or_default()
177    }
178}
179
/// Unit tests for the fallback types: derive behaviour of [`FallbackType`],
/// field access on [`ExecutionResult`], and the builder/lookup contract of
/// [`FallbackConfig`].
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== FallbackType Tests ====================

    #[test]
    fn test_fallback_type_debug() {
        assert!(format!("{:?}", FallbackType::General).contains("General"));
        assert!(format!("{:?}", FallbackType::ContextWindow).contains("ContextWindow"));
        assert!(format!("{:?}", FallbackType::ContentPolicy).contains("ContentPolicy"));
        assert!(format!("{:?}", FallbackType::RateLimit).contains("RateLimit"));
    }

    #[test]
    // The explicit `.clone()` is the point of this test: a plain `let cloned = t;`
    // would only exercise `Copy`.
    #[allow(clippy::clone_on_copy)]
    fn test_fallback_type_clone() {
        let t = FallbackType::ContextWindow;
        let cloned = t.clone();
        assert_eq!(cloned, FallbackType::ContextWindow);
    }

    #[test]
    fn test_fallback_type_copy() {
        let t = FallbackType::RateLimit;
        let copied = t;
        // `t` is still usable after the assignment, which proves it was copied, not moved.
        assert_eq!(t, copied);
    }

    #[test]
    fn test_fallback_type_eq() {
        assert_eq!(FallbackType::General, FallbackType::General);
        assert_ne!(FallbackType::General, FallbackType::ContextWindow);
        assert_ne!(FallbackType::ContextWindow, FallbackType::ContentPolicy);
        assert_ne!(FallbackType::ContentPolicy, FallbackType::RateLimit);
    }

    // ==================== ExecutionResult Tests ====================

    #[test]
    fn test_execution_result_creation() {
        let result = ExecutionResult {
            result: "success".to_string(),
            deployment_id: "openai-gpt-4".to_string(),
            attempts: 1,
            model_used: "gpt-4".to_string(),
            used_fallback: false,
            latency_us: 1500,
        };

        assert_eq!(result.result, "success");
        assert_eq!(result.attempts, 1);
        assert_eq!(result.model_used, "gpt-4");
        assert!(!result.used_fallback);
        assert_eq!(result.latency_us, 1500);
    }

    #[test]
    fn test_execution_result_with_fallback() {
        let result = ExecutionResult {
            result: 42,
            deployment_id: "anthropic-claude-3".to_string(),
            attempts: 3,
            model_used: "claude-3".to_string(),
            used_fallback: true,
            latency_us: 5000,
        };

        assert_eq!(result.result, 42);
        assert_eq!(result.attempts, 3);
        assert!(result.used_fallback);
    }

    #[test]
    fn test_execution_result_debug() {
        let result = ExecutionResult {
            result: "test",
            deployment_id: "openai-gpt-4".to_string(),
            attempts: 1,
            model_used: "gpt-4".to_string(),
            used_fallback: false,
            latency_us: 100,
        };
        let debug = format!("{:?}", result);
        assert!(debug.contains("ExecutionResult"));
        assert!(debug.contains("attempts"));
    }

    #[test]
    fn test_execution_result_clone() {
        let result = ExecutionResult {
            result: vec![1, 2, 3],
            deployment_id: "openai-gpt-4".to_string(),
            attempts: 2,
            model_used: "gpt-4".to_string(),
            used_fallback: true,
            latency_us: 2000,
        };
        let cloned = result.clone();
        assert_eq!(cloned.result, vec![1, 2, 3]);
        assert_eq!(cloned.attempts, 2);
        assert!(cloned.used_fallback);
    }

    // ==================== FallbackConfig Tests ====================

    #[test]
    fn test_fallback_config_new() {
        let config = FallbackConfig::new();
        assert!(
            config
                .get_fallbacks_for_type("gpt-4", FallbackType::General)
                .is_empty()
        );
    }

    #[test]
    fn test_fallback_config_default() {
        let config = FallbackConfig::default();
        assert!(
            config
                .get_fallbacks_for_type("gpt-4", FallbackType::General)
                .is_empty()
        );
    }

    #[test]
    fn test_fallback_config_debug() {
        let config = FallbackConfig::new();
        let debug = format!("{:?}", config);
        assert!(debug.contains("FallbackConfig"));
    }

    #[test]
    fn test_add_general_fallback() {
        let config = FallbackConfig::new().add_general(
            "gpt-4",
            vec!["gpt-3.5-turbo".to_string(), "claude-3".to_string()],
        );

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        assert_eq!(fallbacks.len(), 2);
        assert!(fallbacks.contains(&"gpt-3.5-turbo".to_string()));
        assert!(fallbacks.contains(&"claude-3".to_string()));
    }

    #[test]
    fn test_add_context_window_fallback() {
        let config =
            FallbackConfig::new().add_context_window("gpt-4", vec!["gpt-4-32k".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow);
        assert_eq!(fallbacks.len(), 1);
        assert_eq!(fallbacks[0], "gpt-4-32k");
    }

    #[test]
    fn test_add_content_policy_fallback() {
        let config =
            FallbackConfig::new().add_content_policy("gpt-4", vec!["claude-3".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContentPolicy);
        assert_eq!(fallbacks.len(), 1);
        assert_eq!(fallbacks[0], "claude-3");
    }

    #[test]
    fn test_add_rate_limit_fallback() {
        let config = FallbackConfig::new().add_rate_limit(
            "gpt-4",
            vec!["gpt-3.5-turbo".to_string(), "gpt-4-turbo".to_string()],
        );

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::RateLimit);
        assert_eq!(fallbacks.len(), 2);
    }

    #[test]
    fn test_builder_pattern_chaining() {
        let config = FallbackConfig::new()
            .add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()])
            .add_context_window("gpt-4", vec!["gpt-4-32k".to_string()])
            .add_content_policy("gpt-4", vec!["claude-3".to_string()])
            .add_rate_limit("gpt-4", vec!["gemini".to_string()]);

        assert_eq!(
            config
                .get_fallbacks_for_type("gpt-4", FallbackType::General)
                .len(),
            1
        );
        assert_eq!(
            config
                .get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow)
                .len(),
            1
        );
        assert_eq!(
            config
                .get_fallbacks_for_type("gpt-4", FallbackType::ContentPolicy)
                .len(),
            1
        );
        assert_eq!(
            config
                .get_fallbacks_for_type("gpt-4", FallbackType::RateLimit)
                .len(),
            1
        );
    }

    #[test]
    fn test_multiple_models() {
        let config = FallbackConfig::new()
            .add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()])
            .add_general("claude-3", vec!["gemini".to_string()]);

        let gpt4_fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        let claude_fallbacks = config.get_fallbacks_for_type("claude-3", FallbackType::General);

        assert_eq!(gpt4_fallbacks, vec!["gpt-3.5-turbo".to_string()]);
        assert_eq!(claude_fallbacks, vec!["gemini".to_string()]);
    }

    #[test]
    fn test_get_fallbacks_nonexistent_model() {
        let config = FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("nonexistent", FallbackType::General);
        assert!(fallbacks.is_empty());
    }

    #[test]
    fn test_get_fallbacks_wrong_type() {
        let config = FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5-turbo".to_string()]);

        // No context window fallback configured
        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::ContextWindow);
        assert!(fallbacks.is_empty());
    }

    #[test]
    fn test_empty_fallback_list() {
        let config = FallbackConfig::new().add_general("gpt-4", vec![]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        assert!(fallbacks.is_empty());
    }

    #[test]
    fn test_override_fallback() {
        let config = FallbackConfig::new()
            .add_general("gpt-4", vec!["first".to_string()])
            .add_general("gpt-4", vec!["second".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("gpt-4", FallbackType::General);
        assert_eq!(fallbacks, vec!["second".to_string()]);
    }

    // ==================== Thread Safety Tests ====================

    #[test]
    fn test_concurrent_reads() {
        use std::sync::Arc;
        use std::thread;

        let config =
            Arc::new(FallbackConfig::new().add_general("gpt-4", vec!["gpt-3.5".to_string()]));

        let mut handles = vec![];

        for _ in 0..10 {
            let c = config.clone();
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    // Check the value on every read so a wrong result fails the test
                    // instead of being silently discarded.
                    let fallbacks = c.get_fallbacks_for_type("gpt-4", FallbackType::General);
                    assert_eq!(fallbacks, vec!["gpt-3.5".to_string()]);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().expect("reader thread panicked");
        }
    }

    // ==================== Edge Cases ====================

    #[test]
    fn test_special_characters_in_model_name() {
        let config =
            FallbackConfig::new().add_general("model/v2.0:latest", vec!["backup".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("model/v2.0:latest", FallbackType::General);
        assert_eq!(fallbacks.len(), 1);
    }

    #[test]
    fn test_unicode_in_model_name() {
        let config = FallbackConfig::new().add_general("模型", vec!["备份".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("模型", FallbackType::General);
        assert_eq!(fallbacks, vec!["备份".to_string()]);
    }

    #[test]
    fn test_empty_model_name() {
        let config = FallbackConfig::new().add_general("", vec!["fallback".to_string()]);

        let fallbacks = config.get_fallbacks_for_type("", FallbackType::General);
        assert_eq!(fallbacks, vec!["fallback".to_string()]);
    }

    #[test]
    fn test_many_fallbacks() {
        let fallbacks: Vec<String> = (0..100).map(|i| format!("model_{}", i)).collect();
        let config = FallbackConfig::new().add_general("primary", fallbacks.clone());

        let result = config.get_fallbacks_for_type("primary", FallbackType::General);
        assert_eq!(result.len(), 100);
        assert_eq!(result[0], "model_0");
        assert_eq!(result[99], "model_99");
    }

    #[test]
    fn test_fallback_type_all_variants() {
        let config = FallbackConfig::new()
            .add_general("model", vec!["g".to_string()])
            .add_context_window("model", vec!["cw".to_string()])
            .add_content_policy("model", vec!["cp".to_string()])
            .add_rate_limit("model", vec!["rl".to_string()]);

        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::General),
            vec!["g"]
        );
        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::ContextWindow),
            vec!["cw"]
        );
        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::ContentPolicy),
            vec!["cp"]
        );
        assert_eq!(
            config.get_fallbacks_for_type("model", FallbackType::RateLimit),
            vec!["rl"]
        );
    }
}