1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
/// Coarse category assigned to a failed request.
///
/// Values are produced by [`ErrorKind::from_status`] (from an HTTP status and
/// response body) and by [`ErrorKind::classify`] (from a
/// `hyperinfer_core::HyperInferError`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorKind {
    /// HTTP 429.
    RateLimit,
    /// HTTP 500, 502, 503 or 504.
    ServerError,
    /// An `Http` error whose `is_timeout()` returns `true`.
    Timeout,
    /// HTTP 400 whose body mentions `content_policy` or `content_filter`.
    ContentPolicy,
    /// HTTP 400 whose body mentions `context_length`, `context_window` or
    /// `maximum context`.
    ContextWindow,
    /// HTTP 401 or 403.
    AuthError,
    /// Anything that does not fall into one of the categories above.
    Other,
}
15
16impl ErrorKind {
17 pub fn from_status(status: u16, body: &str) -> Self {
18 match status {
19 429 => Self::RateLimit,
20 401 | 403 => Self::AuthError,
21 400 => {
22 let lower = body.to_lowercase();
23 if lower.contains("content_policy") || lower.contains("content_filter") {
24 Self::ContentPolicy
25 } else if lower.contains("context_length")
26 || lower.contains("context_window")
27 || lower.contains("maximum context")
28 {
29 Self::ContextWindow
30 } else {
31 Self::Other
32 }
33 }
34 500 | 502 | 503 | 504 => Self::ServerError,
35 _ => Self::Other,
36 }
37 }
38
39 pub fn is_timeout(err: &hyperinfer_core::HyperInferError) -> bool {
40 match err {
41 hyperinfer_core::HyperInferError::Http(e) => e.is_timeout(),
42 _ => false,
43 }
44 }
45
46 pub fn classify(err: &hyperinfer_core::HyperInferError) -> Self {
47 match err {
48 hyperinfer_core::HyperInferError::ApiError { status, message } => {
49 Self::from_status(*status, message)
50 }
51 hyperinfer_core::HyperInferError::Http(e) if e.is_timeout() => Self::Timeout,
52 _ => Self::Other,
53 }
54 }
55}
56
/// Fallback routing tables consulted by [`FallbackConfig::get_fallbacks`].
///
/// The per-model maps are keyed by the model name passed to `get_fallbacks`;
/// each value is the list of fallback targets returned for that model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FallbackConfig {
    /// Per-model targets used for every error kind except
    /// `ErrorKind::ContentPolicy` and `ErrorKind::ContextWindow`.
    pub fallbacks: HashMap<String, Vec<String>>,
    /// Targets returned when the table selected for the error kind has no
    /// entry for the requested model.
    pub default_fallbacks: Vec<String>,
    /// Per-model targets used for `ErrorKind::ContentPolicy`.
    pub content_policy_fallbacks: HashMap<String, Vec<String>>,
    /// Per-model targets used for `ErrorKind::ContextWindow`.
    pub context_window_fallbacks: HashMap<String, Vec<String>>,
    /// Set to 5 by `FallbackConfig::new`. Not read by any code visible here —
    /// presumably caps how many fallback targets are attempted (TODO confirm
    /// against the caller).
    pub max_fallbacks: usize,
    /// Set to 3 by `FallbackConfig::new`. Not read by any code visible here —
    /// presumably a retry count (TODO confirm against the caller).
    pub num_retries: u32,
}
66
67impl Default for FallbackConfig {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl FallbackConfig {
74 pub fn new() -> Self {
75 Self {
76 fallbacks: HashMap::new(),
77 default_fallbacks: Vec::new(),
78 content_policy_fallbacks: HashMap::new(),
79 context_window_fallbacks: HashMap::new(),
80 max_fallbacks: 5,
81 num_retries: 3,
82 }
83 }
84
85 pub fn with_fallback(mut self, model: impl Into<String>, targets: Vec<String>) -> Self {
86 self.fallbacks.insert(model.into(), targets);
87 self
88 }
89
90 pub fn with_default_fallbacks(mut self, targets: Vec<String>) -> Self {
91 self.default_fallbacks = targets;
92 self
93 }
94
95 pub fn with_content_policy_fallback(
96 mut self,
97 model: impl Into<String>,
98 targets: Vec<String>,
99 ) -> Self {
100 self.content_policy_fallbacks.insert(model.into(), targets);
101 self
102 }
103
104 pub fn with_context_window_fallback(
105 mut self,
106 model: impl Into<String>,
107 targets: Vec<String>,
108 ) -> Self {
109 self.context_window_fallbacks.insert(model.into(), targets);
110 self
111 }
112
113 pub fn get_fallbacks(&self, model: &str, error_kind: &ErrorKind) -> Vec<String> {
114 let map = match error_kind {
115 ErrorKind::ContentPolicy => Some(&self.content_policy_fallbacks),
116 ErrorKind::ContextWindow => Some(&self.context_window_fallbacks),
117 _ => Some(&self.fallbacks),
118 };
119
120 if let Some(map) = map {
121 if let Some(targets) = map.get(model) {
122 return targets.clone();
123 }
124 }
125
126 self.default_fallbacks.clone()
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `ApiError` variant that `ErrorKind::classify` forwards to
    /// `ErrorKind::from_status`.
    fn api_error(status: u16, message: &str) -> hyperinfer_core::HyperInferError {
        hyperinfer_core::HyperInferError::ApiError {
            status,
            message: message.into(),
        }
    }

    /// Shorthand for classifying an API error with the given status and body.
    fn classify_api(status: u16, message: &str) -> ErrorKind {
        ErrorKind::classify(&api_error(status, message))
    }

    #[test]
    fn test_classify_429() {
        assert_eq!(classify_api(429, "rate limited"), ErrorKind::RateLimit);
    }

    #[test]
    fn test_classify_500() {
        assert_eq!(classify_api(500, "internal error"), ErrorKind::ServerError);
    }

    #[test]
    fn test_classify_502() {
        assert_eq!(classify_api(502, "bad gateway"), ErrorKind::ServerError);
    }

    #[test]
    fn test_classify_503_and_504() {
        assert_eq!(classify_api(503, "unavailable"), ErrorKind::ServerError);
        assert_eq!(classify_api(504, "gateway timeout"), ErrorKind::ServerError);
    }

    #[test]
    fn test_classify_401() {
        assert_eq!(classify_api(401, "unauthorized"), ErrorKind::AuthError);
    }

    #[test]
    fn test_classify_403() {
        assert_eq!(classify_api(403, "forbidden"), ErrorKind::AuthError);
    }

    #[test]
    fn test_classify_content_policy() {
        assert_eq!(
            classify_api(400, "violated content_policy rules"),
            ErrorKind::ContentPolicy
        );
    }

    #[test]
    fn test_classify_content_filter_is_case_insensitive() {
        assert_eq!(
            classify_api(400, "Blocked by CONTENT_FILTER"),
            ErrorKind::ContentPolicy
        );
    }

    #[test]
    fn test_classify_context_window() {
        assert_eq!(
            classify_api(400, "exceeds context_length limit"),
            ErrorKind::ContextWindow
        );
    }

    #[test]
    fn test_classify_context_window_alternate_markers() {
        assert_eq!(
            classify_api(400, "context_window exceeded"),
            ErrorKind::ContextWindow
        );
        assert_eq!(
            classify_api(400, "This model's maximum context length is 8192 tokens"),
            ErrorKind::ContextWindow
        );
    }

    #[test]
    fn test_classify_content_policy_takes_precedence() {
        // Content-policy markers are checked before context-window markers.
        assert_eq!(
            classify_api(400, "content_policy and context_length"),
            ErrorKind::ContentPolicy
        );
    }

    #[test]
    fn test_classify_unknown_400() {
        assert_eq!(classify_api(400, "bad request"), ErrorKind::Other);
    }

    #[test]
    fn test_classify_body_ignored_outside_400() {
        assert_eq!(classify_api(500, "content_policy"), ErrorKind::ServerError);
        assert_eq!(classify_api(404, "context_length"), ErrorKind::Other);
    }

    #[test]
    fn test_classify_unmapped_status() {
        assert_eq!(classify_api(404, "not found"), ErrorKind::Other);
        assert_eq!(classify_api(501, "not implemented"), ErrorKind::Other);
    }

    #[test]
    fn test_fallback_lookup_specific() {
        let config = FallbackConfig::new()
            .with_fallback("gpt-4", vec!["claude-3".into(), "gemini-pro".into()]);
        let result = config.get_fallbacks("gpt-4", &ErrorKind::ServerError);
        assert_eq!(result, vec!["claude-3", "gemini-pro"]);
    }

    #[test]
    fn test_fallback_lookup_default() {
        let config = FallbackConfig::new().with_default_fallbacks(vec!["default-model".into()]);
        let result = config.get_fallbacks("unknown-model", &ErrorKind::ServerError);
        assert_eq!(result, vec!["default-model"]);
    }

    #[test]
    fn test_fallback_content_policy_specific() {
        let config = FallbackConfig::new()
            .with_content_policy_fallback("gpt-4", vec!["claude-3-opus".into()]);
        let result = config.get_fallbacks("gpt-4", &ErrorKind::ContentPolicy);
        assert_eq!(result, vec!["claude-3-opus"]);
    }

    #[test]
    fn test_fallback_context_window_specific() {
        let config = FallbackConfig::new()
            .with_context_window_fallback("gpt-4", vec!["gemini-pro-1m".into()]);
        let result = config.get_fallbacks("gpt-4", &ErrorKind::ContextWindow);
        assert_eq!(result, vec!["gemini-pro-1m"]);
    }

    #[test]
    fn test_fallback_special_kinds_skip_general_table() {
        // With no kind-specific entry, the lookup goes straight to the defaults
        // rather than to the general table.
        let config = FallbackConfig::new()
            .with_fallback("gpt-4", vec!["claude-3".into()])
            .with_default_fallbacks(vec!["default-model".into()]);
        assert_eq!(
            config.get_fallbacks("gpt-4", &ErrorKind::ContentPolicy),
            vec!["default-model"]
        );
        assert_eq!(
            config.get_fallbacks("gpt-4", &ErrorKind::ContextWindow),
            vec!["default-model"]
        );
    }

    #[test]
    fn test_with_fallback_replaces_existing_entry() {
        let config = FallbackConfig::new()
            .with_fallback("gpt-4", vec!["first".into()])
            .with_fallback("gpt-4", vec!["second".into()]);
        assert_eq!(config.get_fallbacks("gpt-4", &ErrorKind::Other), vec!["second"]);
    }

    #[test]
    fn test_fallback_no_match_returns_empty() {
        let config = FallbackConfig::new();
        let result = config.get_fallbacks("gpt-4", &ErrorKind::ServerError);
        assert!(result.is_empty());
    }

    #[test]
    fn test_default_matches_new() {
        let config = FallbackConfig::default();
        assert!(config.fallbacks.is_empty());
        assert!(config.default_fallbacks.is_empty());
        assert!(config.content_policy_fallbacks.is_empty());
        assert!(config.context_window_fallbacks.is_empty());
        assert_eq!(config.max_fallbacks, 5);
        assert_eq!(config.num_retries, 3);
    }
}