1use grapsus_config::{FallbackConfig, FallbackUpstream};
11
12use super::context::FallbackReason;
13
/// Outcome of a fallback evaluation: which upstream to retry on and why.
#[derive(Debug, Clone)]
pub struct FallbackDecision {
    /// Identifier of the fallback upstream the request should be routed to.
    pub next_upstream: String,
    /// The trigger that caused this fallback decision.
    pub reason: FallbackReason,
    /// `(original_model, mapped_model)` when the fallback upstream maps the
    /// requested model to a different name; `None` when the name is unchanged
    /// or no model was supplied.
    pub model_mapping: Option<(String, String)>,
}
24
/// Stateless evaluator that decides whether — and to which upstream — a
/// request should fall back, given the fallback configuration and the
/// attempt history of the current request.
pub struct FallbackEvaluator<'a> {
    /// Fallback chain, triggers, and attempt budget.
    config: &'a FallbackConfig,
    /// Upstream ids already attempted for this request; these are skipped
    /// when selecting the next fallback.
    tried_upstreams: &'a [String],
    /// Number of fallback attempts already performed (0 on the first try).
    current_attempt: u32,
}
35
36impl<'a> FallbackEvaluator<'a> {
37 pub fn new(
44 config: &'a FallbackConfig,
45 tried_upstreams: &'a [String],
46 current_attempt: u32,
47 ) -> Self {
48 Self {
49 config,
50 tried_upstreams,
51 current_attempt,
52 }
53 }
54
55 pub fn can_attempt_fallback(&self) -> bool {
57 self.current_attempt < self.config.max_attempts
58 }
59
60 pub fn should_fallback_before_request(
75 &self,
76 upstream_id: &str,
77 is_healthy: bool,
78 is_budget_exhausted: bool,
79 current_model: Option<&str>,
80 ) -> Option<FallbackDecision> {
81 if !self.can_attempt_fallback() {
82 return None;
83 }
84
85 if self.config.triggers.on_health_failure && !is_healthy {
87 return self.create_fallback_decision(
88 FallbackReason::HealthCheckFailed,
89 upstream_id,
90 current_model,
91 );
92 }
93
94 if self.config.triggers.on_budget_exhausted && is_budget_exhausted {
96 return self.create_fallback_decision(
97 FallbackReason::BudgetExhausted,
98 upstream_id,
99 current_model,
100 );
101 }
102
103 None
104 }
105
106 pub fn should_fallback_after_response(
121 &self,
122 upstream_id: &str,
123 status_code: u16,
124 latency_ms: u64,
125 current_model: Option<&str>,
126 ) -> Option<FallbackDecision> {
127 if !self.can_attempt_fallback() {
128 return None;
129 }
130
131 if !self.config.triggers.on_error_codes.is_empty()
133 && self.config.triggers.on_error_codes.contains(&status_code)
134 {
135 return self.create_fallback_decision(
136 FallbackReason::ErrorCode(status_code),
137 upstream_id,
138 current_model,
139 );
140 }
141
142 if let Some(threshold_ms) = self.config.triggers.on_latency_threshold_ms {
144 if latency_ms > threshold_ms {
145 return self.create_fallback_decision(
146 FallbackReason::LatencyThreshold {
147 observed_ms: latency_ms,
148 threshold_ms,
149 },
150 upstream_id,
151 current_model,
152 );
153 }
154 }
155
156 None
157 }
158
159 pub fn should_fallback_on_connection_error(
169 &self,
170 upstream_id: &str,
171 error_message: &str,
172 current_model: Option<&str>,
173 ) -> Option<FallbackDecision> {
174 if !self.can_attempt_fallback() {
175 return None;
176 }
177
178 if self.config.triggers.on_connection_error {
179 return self.create_fallback_decision(
180 FallbackReason::ConnectionError(error_message.to_string()),
181 upstream_id,
182 current_model,
183 );
184 }
185
186 None
187 }
188
189 pub fn next_fallback(&self) -> Option<&FallbackUpstream> {
198 self.config.upstreams.iter().find(|fb| {
199 !self.tried_upstreams.contains(&fb.upstream)
201 })
202 }
203
204 pub fn map_model(&self, upstream: &FallbackUpstream, model: &str) -> String {
215 if let Some(mapped) = upstream.model_mapping.get(model) {
217 return mapped.to_string();
218 }
219
220 for (pattern, mapped) in &upstream.model_mapping {
222 if glob_match(pattern, model) {
223 return mapped.to_string();
224 }
225 }
226
227 model.to_string()
229 }
230
231 fn create_fallback_decision(
233 &self,
234 reason: FallbackReason,
235 _current_upstream: &str,
236 current_model: Option<&str>,
237 ) -> Option<FallbackDecision> {
238 let next = self.next_fallback()?;
239
240 let model_mapping = current_model.and_then(|model| {
241 let mapped = self.map_model(next, model);
242 if mapped != model {
243 Some((model.to_string(), mapped))
244 } else {
245 None
246 }
247 });
248
249 Some(FallbackDecision {
250 next_upstream: next.upstream.clone(),
251 reason,
252 model_mapping,
253 })
254 }
255}
256
/// Matches `text` against a glob `pattern` in which `*` matches any
/// (possibly empty) run of characters. Operates on Unicode scalar values,
/// not bytes.
fn glob_match(pattern: &str, text: &str) -> bool {
    let pattern_chars: Vec<char> = pattern.chars().collect();
    let text_chars: Vec<char> = text.chars().collect();

    glob_match_recursive(&pattern_chars, &text_chars, 0, 0)
}

/// Glob matching starting at `p_idx` in `pattern` and `t_idx` in `text`.
///
/// Implemented iteratively with greedy `*` backtracking, which runs in
/// O(|pattern| * |text|). The previous recursive formulation retried every
/// split point per `*`, making it exponential for patterns with several
/// wildcards and recursion-depth-bounded by the input length. The function
/// name is kept for interface compatibility.
fn glob_match_recursive(pattern: &[char], text: &[char], p_idx: usize, t_idx: usize) -> bool {
    let mut p = p_idx;
    let mut t = t_idx;
    // Position of the most recent `*`, and the text index up to which that
    // `*` is currently assumed to consume.
    let mut star_p: Option<usize> = None;
    let mut star_t = t_idx;

    while t < text.len() {
        if p < pattern.len() && pattern[p] == '*' {
            // Tentatively let `*` match the empty string; remember where to
            // backtrack if that fails.
            star_p = Some(p);
            star_t = t;
            p += 1;
        } else if p < pattern.len() && pattern[p] == text[t] {
            p += 1;
            t += 1;
        } else if let Some(sp) = star_p {
            // Mismatch: grow the last `*` by one character and retry.
            star_t += 1;
            t = star_t;
            p = sp + 1;
        } else {
            return false;
        }
    }

    // Any trailing `*`s match the empty remainder.
    while p < pattern.len() && pattern[p] == '*' {
        p += 1;
    }
    p == pattern.len()
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use grapsus_config::{FallbackTriggers, InferenceProvider};
    use std::collections::HashMap;

    /// Builds the two-upstream fallback config shared by every test below:
    /// an Anthropic upstream with exact model mappings, then a local GPU
    /// upstream with glob mappings.
    fn create_test_config() -> FallbackConfig {
        let anthropic = FallbackUpstream {
            upstream: "anthropic-fallback".to_string(),
            provider: InferenceProvider::Anthropic,
            model_mapping: HashMap::from([
                ("gpt-4".to_string(), "claude-3-opus".to_string()),
                ("gpt-4o".to_string(), "claude-3-5-sonnet".to_string()),
                ("gpt-3.5-turbo".to_string(), "claude-3-haiku".to_string()),
            ]),
            skip_if_unhealthy: true,
        };
        let local_gpu = FallbackUpstream {
            upstream: "local-gpu".to_string(),
            provider: InferenceProvider::Generic,
            model_mapping: HashMap::from([
                ("gpt-4*".to_string(), "llama-3-70b".to_string()),
                ("gpt-3.5*".to_string(), "llama-3-8b".to_string()),
            ]),
            skip_if_unhealthy: true,
        };

        FallbackConfig {
            upstreams: vec![anthropic, local_gpu],
            triggers: FallbackTriggers {
                on_health_failure: true,
                on_budget_exhausted: true,
                on_latency_threshold_ms: Some(5000),
                on_error_codes: vec![429, 500, 502, 503, 504],
                on_connection_error: true,
            },
            max_attempts: 2,
        }
    }

    #[test]
    fn test_glob_match_exact() {
        assert!(glob_match("gpt-4", "gpt-4"));
        assert!(!glob_match("gpt-4", "gpt-4-turbo"));
    }

    #[test]
    fn test_glob_match_suffix_wildcard() {
        assert!(glob_match("gpt-4*", "gpt-4"));
        assert!(glob_match("gpt-4*", "gpt-4-turbo"));
        assert!(glob_match("gpt-4*", "gpt-4o"));
        assert!(!glob_match("gpt-4*", "gpt-3.5-turbo"));
    }

    #[test]
    fn test_glob_match_middle_wildcard() {
        assert!(glob_match("claude-*-sonnet", "claude-3-sonnet"));
        assert!(glob_match("claude-*-sonnet", "claude-3.5-sonnet"));
        assert!(!glob_match("claude-*-sonnet", "claude-3-opus"));
    }

    #[test]
    fn test_glob_match_prefix_wildcard() {
        assert!(glob_match("*-turbo", "gpt-4-turbo"));
        assert!(glob_match("*-turbo", "gpt-3.5-turbo"));
        assert!(!glob_match("*-turbo", "gpt-4"));
    }

    #[test]
    fn test_can_attempt_fallback() {
        let cfg = create_test_config();

        // max_attempts is 2, so attempts 0 and 1 are allowed, 2 is not.
        assert!(FallbackEvaluator::new(&cfg, &[], 0).can_attempt_fallback());
        assert!(FallbackEvaluator::new(&cfg, &[], 1).can_attempt_fallback());
        assert!(!FallbackEvaluator::new(&cfg, &[], 2).can_attempt_fallback());
    }

    #[test]
    fn test_fallback_on_health_failure() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        let decision = ev
            .should_fallback_before_request("openai-primary", false, false, Some("gpt-4"))
            .expect("unhealthy upstream should trigger a fallback");

        assert_eq!(decision.next_upstream, "anthropic-fallback");
        assert!(matches!(decision.reason, FallbackReason::HealthCheckFailed));
        assert_eq!(
            decision.model_mapping,
            Some(("gpt-4".to_string(), "claude-3-opus".to_string()))
        );
    }

    #[test]
    fn test_fallback_on_budget_exhausted() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        let decision = ev
            .should_fallback_before_request("openai-primary", true, true, Some("gpt-4o"))
            .expect("exhausted budget should trigger a fallback");

        assert_eq!(decision.next_upstream, "anthropic-fallback");
        assert!(matches!(decision.reason, FallbackReason::BudgetExhausted));
        assert_eq!(
            decision.model_mapping,
            Some(("gpt-4o".to_string(), "claude-3-5-sonnet".to_string()))
        );
    }

    #[test]
    fn test_fallback_on_error_code() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        // 503 is in the configured error-code list.
        assert!(ev
            .should_fallback_after_response("openai-primary", 503, 1000, Some("gpt-4"))
            .is_some());
        // 200 and 404 are not.
        assert!(ev
            .should_fallback_after_response("openai-primary", 200, 1000, Some("gpt-4"))
            .is_none());
        assert!(ev
            .should_fallback_after_response("openai-primary", 404, 1000, Some("gpt-4"))
            .is_none());
    }

    #[test]
    fn test_fallback_on_latency_threshold() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        // Above the 5000 ms threshold: fall back.
        let decision = ev
            .should_fallback_after_response("openai-primary", 200, 6000, Some("gpt-4"))
            .expect("latency above threshold should trigger a fallback");
        assert!(matches!(
            decision.reason,
            FallbackReason::LatencyThreshold {
                observed_ms: 6000,
                threshold_ms: 5000
            }
        ));

        // Below the threshold: no fallback.
        assert!(ev
            .should_fallback_after_response("openai-primary", 200, 4000, Some("gpt-4"))
            .is_none());
    }

    #[test]
    fn test_fallback_on_connection_error() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        let decision = ev
            .should_fallback_on_connection_error(
                "openai-primary",
                "connection refused",
                Some("gpt-4"),
            )
            .expect("connection error should trigger a fallback");

        assert!(matches!(
            decision.reason,
            FallbackReason::ConnectionError(_)
        ));
    }

    #[test]
    fn test_next_fallback_skips_tried() {
        let cfg = create_test_config();

        // Nothing tried yet: first configured upstream is chosen.
        let ev = FallbackEvaluator::new(&cfg, &[], 0);
        assert_eq!(ev.next_fallback().unwrap().upstream, "anthropic-fallback");

        // First upstream already tried: second one is chosen.
        let tried = ["anthropic-fallback".to_string()];
        let ev = FallbackEvaluator::new(&cfg, &tried, 1);
        assert_eq!(ev.next_fallback().unwrap().upstream, "local-gpu");

        // Everything tried: chain is exhausted.
        let tried = ["anthropic-fallback".to_string(), "local-gpu".to_string()];
        let ev = FallbackEvaluator::new(&cfg, &tried, 2);
        assert!(ev.next_fallback().is_none());
    }

    #[test]
    fn test_model_mapping_exact() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);
        let upstream = &cfg.upstreams[0];

        assert_eq!(ev.map_model(upstream, "gpt-4"), "claude-3-opus");
        assert_eq!(ev.map_model(upstream, "gpt-4o"), "claude-3-5-sonnet");
        assert_eq!(ev.map_model(upstream, "gpt-3.5-turbo"), "claude-3-haiku");
        // Unmapped models pass through untouched.
        assert_eq!(ev.map_model(upstream, "unknown-model"), "unknown-model");
    }

    #[test]
    fn test_model_mapping_glob() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);
        let upstream = &cfg.upstreams[1];

        assert_eq!(ev.map_model(upstream, "gpt-4"), "llama-3-70b");
        assert_eq!(ev.map_model(upstream, "gpt-4-turbo"), "llama-3-70b");
        assert_eq!(ev.map_model(upstream, "gpt-4o"), "llama-3-70b");
        assert_eq!(ev.map_model(upstream, "gpt-3.5-turbo"), "llama-3-8b");
    }

    #[test]
    fn test_no_fallback_when_healthy_and_budget_ok() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        assert!(ev
            .should_fallback_before_request("openai-primary", true, false, Some("gpt-4"))
            .is_none());
    }

    #[test]
    fn test_no_fallback_when_max_attempts_reached() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 2);

        assert!(ev
            .should_fallback_before_request("openai-primary", false, false, Some("gpt-4"))
            .is_none());
    }
}