Skip to main content

hyperinfer_router/
fallback.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
6pub enum ErrorKind {
7    RateLimit,
8    ServerError,
9    Timeout,
10    ContentPolicy,
11    ContextWindow,
12    AuthError,
13    Other,
14}
15
16impl ErrorKind {
17    pub fn from_status(status: u16, body: &str) -> Self {
18        match status {
19            429 => Self::RateLimit,
20            401 | 403 => Self::AuthError,
21            400 => {
22                let lower = body.to_lowercase();
23                if lower.contains("content_policy") || lower.contains("content_filter") {
24                    Self::ContentPolicy
25                } else if lower.contains("context_length")
26                    || lower.contains("context_window")
27                    || lower.contains("maximum context")
28                {
29                    Self::ContextWindow
30                } else {
31                    Self::Other
32                }
33            }
34            500 | 502 | 503 | 504 => Self::ServerError,
35            _ => Self::Other,
36        }
37    }
38
39    pub fn is_timeout(err: &hyperinfer_core::HyperInferError) -> bool {
40        match err {
41            hyperinfer_core::HyperInferError::Http(e) => e.is_timeout(),
42            _ => false,
43        }
44    }
45
46    pub fn classify(err: &hyperinfer_core::HyperInferError) -> Self {
47        match err {
48            hyperinfer_core::HyperInferError::ApiError { status, message } => {
49                Self::from_status(*status, message)
50            }
51            hyperinfer_core::HyperInferError::Http(e) if e.is_timeout() => Self::Timeout,
52            _ => Self::Other,
53        }
54    }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct FallbackConfig {
59    pub fallbacks: HashMap<String, Vec<String>>,
60    pub default_fallbacks: Vec<String>,
61    pub content_policy_fallbacks: HashMap<String, Vec<String>>,
62    pub context_window_fallbacks: HashMap<String, Vec<String>>,
63    pub max_fallbacks: usize,
64    pub num_retries: u32,
65}
66
67impl Default for FallbackConfig {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl FallbackConfig {
74    pub fn new() -> Self {
75        Self {
76            fallbacks: HashMap::new(),
77            default_fallbacks: Vec::new(),
78            content_policy_fallbacks: HashMap::new(),
79            context_window_fallbacks: HashMap::new(),
80            max_fallbacks: 5,
81            num_retries: 3,
82        }
83    }
84
85    pub fn with_fallback(mut self, model: impl Into<String>, targets: Vec<String>) -> Self {
86        self.fallbacks.insert(model.into(), targets);
87        self
88    }
89
90    pub fn with_default_fallbacks(mut self, targets: Vec<String>) -> Self {
91        self.default_fallbacks = targets;
92        self
93    }
94
95    pub fn with_content_policy_fallback(
96        mut self,
97        model: impl Into<String>,
98        targets: Vec<String>,
99    ) -> Self {
100        self.content_policy_fallbacks.insert(model.into(), targets);
101        self
102    }
103
104    pub fn with_context_window_fallback(
105        mut self,
106        model: impl Into<String>,
107        targets: Vec<String>,
108    ) -> Self {
109        self.context_window_fallbacks.insert(model.into(), targets);
110        self
111    }
112
113    pub fn get_fallbacks(&self, model: &str, error_kind: &ErrorKind) -> Vec<String> {
114        let map = match error_kind {
115            ErrorKind::ContentPolicy => Some(&self.content_policy_fallbacks),
116            ErrorKind::ContextWindow => Some(&self.context_window_fallbacks),
117            _ => Some(&self.fallbacks),
118        };
119
120        if let Some(map) = map {
121            if let Some(targets) = map.get(model) {
122                return targets.clone();
123            }
124        }
125
126        self.default_fallbacks.clone()
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ApiError` with the given status and message and classifies it.
    fn classify_api(status: u16, message: &str) -> ErrorKind {
        let err = hyperinfer_core::HyperInferError::ApiError {
            status,
            message: message.into(),
        };
        ErrorKind::classify(&err)
    }

    /// Converts string literals into the owned target list the builders take.
    fn targets(names: &[&str]) -> Vec<String> {
        names.iter().copied().map(String::from).collect()
    }

    #[test]
    fn test_classify_429() {
        assert_eq!(classify_api(429, "rate limited"), ErrorKind::RateLimit);
    }

    #[test]
    fn test_classify_500() {
        assert_eq!(classify_api(500, "internal error"), ErrorKind::ServerError);
    }

    #[test]
    fn test_classify_502() {
        assert_eq!(classify_api(502, "bad gateway"), ErrorKind::ServerError);
    }

    #[test]
    fn test_classify_401() {
        assert_eq!(classify_api(401, "unauthorized"), ErrorKind::AuthError);
    }

    #[test]
    fn test_classify_content_policy() {
        assert_eq!(
            classify_api(400, "violated content_policy rules"),
            ErrorKind::ContentPolicy
        );
    }

    #[test]
    fn test_classify_context_window() {
        assert_eq!(
            classify_api(400, "exceeds context_length limit"),
            ErrorKind::ContextWindow
        );
    }

    #[test]
    fn test_classify_unknown_400() {
        assert_eq!(classify_api(400, "bad request"), ErrorKind::Other);
    }

    #[test]
    fn test_fallback_lookup_specific() {
        let config =
            FallbackConfig::new().with_fallback("gpt-4", targets(&["claude-3", "gemini-pro"]));
        let found = config.get_fallbacks("gpt-4", &ErrorKind::ServerError);
        assert_eq!(found, ["claude-3", "gemini-pro"]);
    }

    #[test]
    fn test_fallback_lookup_default() {
        let config = FallbackConfig::new().with_default_fallbacks(targets(&["default-model"]));
        let found = config.get_fallbacks("unknown-model", &ErrorKind::ServerError);
        assert_eq!(found, ["default-model"]);
    }

    #[test]
    fn test_fallback_content_policy_specific() {
        let config = FallbackConfig::new()
            .with_content_policy_fallback("gpt-4", targets(&["claude-3-opus"]));
        let found = config.get_fallbacks("gpt-4", &ErrorKind::ContentPolicy);
        assert_eq!(found, ["claude-3-opus"]);
    }

    #[test]
    fn test_fallback_context_window_specific() {
        let config = FallbackConfig::new()
            .with_context_window_fallback("gpt-4", targets(&["gemini-pro-1m"]));
        let found = config.get_fallbacks("gpt-4", &ErrorKind::ContextWindow);
        assert_eq!(found, ["gemini-pro-1m"]);
    }

    #[test]
    fn test_fallback_no_match_returns_empty() {
        let found = FallbackConfig::new().get_fallbacks("gpt-4", &ErrorKind::ServerError);
        assert!(found.is_empty());
    }
}
234}