1use sentinel_config::{FallbackConfig, FallbackUpstream};
11
12use super::context::FallbackReason;
13
/// Outcome of a fallback evaluation: which upstream to route the retried
/// request to, and why the fallback was triggered.
#[derive(Debug, Clone)]
pub struct FallbackDecision {
    // Identifier of the next upstream to try.
    pub next_upstream: String,
    // Trigger that caused the fallback (health, budget, error code, latency,
    // or connection error).
    pub reason: FallbackReason,
    // `(original_model, mapped_model)` when the next upstream requires a
    // different model name; `None` when the model name is unchanged.
    pub model_mapping: Option<(String, String)>,
}
24
/// Stateless evaluator that decides whether — and to which upstream — a
/// request should fall back, given the fallback configuration and the
/// per-request attempt state.
pub struct FallbackEvaluator<'a> {
    // Fallback upstreams, trigger switches, and `max_attempts`.
    config: &'a FallbackConfig,
    // Upstream ids already attempted for this request; never selected again.
    tried_upstreams: &'a [String],
    // Number of fallback attempts performed so far (checked against
    // `config.max_attempts`).
    current_attempt: u32,
}
35
36impl<'a> FallbackEvaluator<'a> {
37 pub fn new(
44 config: &'a FallbackConfig,
45 tried_upstreams: &'a [String],
46 current_attempt: u32,
47 ) -> Self {
48 Self {
49 config,
50 tried_upstreams,
51 current_attempt,
52 }
53 }
54
55 pub fn can_attempt_fallback(&self) -> bool {
57 self.current_attempt < self.config.max_attempts
58 }
59
60 pub fn should_fallback_before_request(
75 &self,
76 upstream_id: &str,
77 is_healthy: bool,
78 is_budget_exhausted: bool,
79 current_model: Option<&str>,
80 ) -> Option<FallbackDecision> {
81 if !self.can_attempt_fallback() {
82 return None;
83 }
84
85 if self.config.triggers.on_health_failure && !is_healthy {
87 return self.create_fallback_decision(
88 FallbackReason::HealthCheckFailed,
89 upstream_id,
90 current_model,
91 );
92 }
93
94 if self.config.triggers.on_budget_exhausted && is_budget_exhausted {
96 return self.create_fallback_decision(
97 FallbackReason::BudgetExhausted,
98 upstream_id,
99 current_model,
100 );
101 }
102
103 None
104 }
105
106 pub fn should_fallback_after_response(
121 &self,
122 upstream_id: &str,
123 status_code: u16,
124 latency_ms: u64,
125 current_model: Option<&str>,
126 ) -> Option<FallbackDecision> {
127 if !self.can_attempt_fallback() {
128 return None;
129 }
130
131 if !self.config.triggers.on_error_codes.is_empty()
133 && self.config.triggers.on_error_codes.contains(&status_code)
134 {
135 return self.create_fallback_decision(
136 FallbackReason::ErrorCode(status_code),
137 upstream_id,
138 current_model,
139 );
140 }
141
142 if let Some(threshold_ms) = self.config.triggers.on_latency_threshold_ms {
144 if latency_ms > threshold_ms {
145 return self.create_fallback_decision(
146 FallbackReason::LatencyThreshold {
147 observed_ms: latency_ms,
148 threshold_ms,
149 },
150 upstream_id,
151 current_model,
152 );
153 }
154 }
155
156 None
157 }
158
159 pub fn should_fallback_on_connection_error(
169 &self,
170 upstream_id: &str,
171 error_message: &str,
172 current_model: Option<&str>,
173 ) -> Option<FallbackDecision> {
174 if !self.can_attempt_fallback() {
175 return None;
176 }
177
178 if self.config.triggers.on_connection_error {
179 return self.create_fallback_decision(
180 FallbackReason::ConnectionError(error_message.to_string()),
181 upstream_id,
182 current_model,
183 );
184 }
185
186 None
187 }
188
189 pub fn next_fallback(&self) -> Option<&FallbackUpstream> {
198 self.config.upstreams.iter().find(|fb| {
199 !self.tried_upstreams.contains(&fb.upstream)
201 })
202 }
203
204 pub fn map_model(&self, upstream: &FallbackUpstream, model: &str) -> String {
215 if let Some(mapped) = upstream.model_mapping.get(model) {
217 return mapped.to_string();
218 }
219
220 for (pattern, mapped) in &upstream.model_mapping {
222 if glob_match(pattern, model) {
223 return mapped.to_string();
224 }
225 }
226
227 model.to_string()
229 }
230
231 fn create_fallback_decision(
233 &self,
234 reason: FallbackReason,
235 _current_upstream: &str,
236 current_model: Option<&str>,
237 ) -> Option<FallbackDecision> {
238 let next = self.next_fallback()?;
239
240 let model_mapping = current_model.map(|model| {
241 let mapped = self.map_model(next, model);
242 if mapped != model {
243 Some((model.to_string(), mapped))
244 } else {
245 None
246 }
247 }).flatten();
248
249 Some(FallbackDecision {
250 next_upstream: next.upstream.clone(),
251 reason,
252 model_mapping,
253 })
254 }
255}
256
/// Match `text` against a glob `pattern` in which `*` matches any (possibly
/// empty) run of characters; every other character matches literally.
///
/// Implemented with the classic two-pointer, single-star-backtracking scan:
/// on a mismatch after a `*`, the most recent star absorbs one more text
/// character and matching resumes. This runs in O(pattern · text) time and
/// constant extra space, replacing the previous recursive search whose cost
/// grew exponentially with the number of `*`s in the pattern.
fn glob_match(pattern: &str, text: &str) -> bool {
    let p: Vec<char> = pattern.chars().collect();
    let t: Vec<char> = text.chars().collect();

    let mut pi = 0; // current position in pattern
    let mut ti = 0; // current position in text
    let mut star: Option<usize> = None; // index of the most recent '*'
    let mut resume = 0; // text position to retry from after a mismatch

    while ti < t.len() {
        if pi < p.len() && p[pi] == '*' {
            // Tentatively let the star match the empty string.
            star = Some(pi);
            resume = ti;
            pi += 1;
        } else if pi < p.len() && p[pi] == t[ti] {
            pi += 1;
            ti += 1;
        } else if let Some(s) = star {
            // Mismatch after a '*': the star absorbs one more character.
            pi = s + 1;
            resume += 1;
            ti = resume;
        } else {
            return false;
        }
    }

    // Text exhausted: any remaining pattern characters must all be '*'.
    p[pi..].iter().all(|&c| c == '*')
}
272
/// Recursive worker: does `pattern[p_idx..]` match `text[t_idx..]`?
/// A `*` may consume any number of characters (including none); every other
/// pattern character must equal the corresponding text character.
fn glob_match_recursive(pattern: &[char], text: &[char], p_idx: usize, t_idx: usize) -> bool {
    match pattern.get(p_idx) {
        // Pattern exhausted: match only if the text is exhausted too.
        None => t_idx >= text.len(),
        // Star: try every split point from "match nothing" up to "match
        // the whole remaining text".
        Some('*') => {
            (t_idx..=text.len()).any(|split| glob_match_recursive(pattern, text, p_idx + 1, split))
        }
        // Literal: current characters must agree, then match the rest.
        Some(&expected) => {
            text.get(t_idx) == Some(&expected)
                && glob_match_recursive(pattern, text, p_idx + 1, t_idx + 1)
        }
    }
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_config::{FallbackTriggers, InferenceProvider};
    use std::collections::HashMap;

    /// Turn a slice of `(&str, &str)` pairs into an owned model-mapping map.
    fn mapping(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    /// Two-tier fixture: an Anthropic fallback with exact model mappings,
    /// then a local-GPU fallback keyed by glob patterns.
    fn create_test_config() -> FallbackConfig {
        FallbackConfig {
            upstreams: vec![
                FallbackUpstream {
                    upstream: "anthropic-fallback".to_string(),
                    provider: InferenceProvider::Anthropic,
                    model_mapping: mapping(&[
                        ("gpt-4", "claude-3-opus"),
                        ("gpt-4o", "claude-3-5-sonnet"),
                        ("gpt-3.5-turbo", "claude-3-haiku"),
                    ]),
                    skip_if_unhealthy: true,
                },
                FallbackUpstream {
                    upstream: "local-gpu".to_string(),
                    provider: InferenceProvider::Generic,
                    model_mapping: mapping(&[
                        ("gpt-4*", "llama-3-70b"),
                        ("gpt-3.5*", "llama-3-8b"),
                    ]),
                    skip_if_unhealthy: true,
                },
            ],
            triggers: FallbackTriggers {
                on_health_failure: true,
                on_budget_exhausted: true,
                on_latency_threshold_ms: Some(5000),
                on_error_codes: vec![429, 500, 502, 503, 504],
                on_connection_error: true,
            },
            max_attempts: 2,
        }
    }

    #[test]
    fn test_glob_match_exact() {
        for (pat, text, expected) in [("gpt-4", "gpt-4", true), ("gpt-4", "gpt-4-turbo", false)] {
            assert_eq!(glob_match(pat, text), expected, "{pat} vs {text}");
        }
    }

    #[test]
    fn test_glob_match_suffix_wildcard() {
        for (text, expected) in [
            ("gpt-4", true),
            ("gpt-4-turbo", true),
            ("gpt-4o", true),
            ("gpt-3.5-turbo", false),
        ] {
            assert_eq!(glob_match("gpt-4*", text), expected, "gpt-4* vs {text}");
        }
    }

    #[test]
    fn test_glob_match_middle_wildcard() {
        for (text, expected) in [
            ("claude-3-sonnet", true),
            ("claude-3.5-sonnet", true),
            ("claude-3-opus", false),
        ] {
            assert_eq!(
                glob_match("claude-*-sonnet", text),
                expected,
                "claude-*-sonnet vs {text}"
            );
        }
    }

    #[test]
    fn test_glob_match_prefix_wildcard() {
        for (text, expected) in [
            ("gpt-4-turbo", true),
            ("gpt-3.5-turbo", true),
            ("gpt-4", false),
        ] {
            assert_eq!(glob_match("*-turbo", text), expected, "*-turbo vs {text}");
        }
    }

    #[test]
    fn test_can_attempt_fallback() {
        let cfg = create_test_config();

        // max_attempts is 2, so attempts 0 and 1 are allowed, 2 is not.
        for (attempt, expected) in [(0, true), (1, true), (2, false)] {
            let ev = FallbackEvaluator::new(&cfg, &[], attempt);
            assert_eq!(ev.can_attempt_fallback(), expected, "attempt {attempt}");
        }
    }

    #[test]
    fn test_fallback_on_health_failure() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        let d = ev
            .should_fallback_before_request("openai-primary", false, false, Some("gpt-4"))
            .expect("unhealthy upstream should trigger fallback");

        assert_eq!(d.next_upstream, "anthropic-fallback");
        assert!(matches!(d.reason, FallbackReason::HealthCheckFailed));
        assert_eq!(
            d.model_mapping,
            Some(("gpt-4".to_string(), "claude-3-opus".to_string()))
        );
    }

    #[test]
    fn test_fallback_on_budget_exhausted() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        let d = ev
            .should_fallback_before_request("openai-primary", true, true, Some("gpt-4o"))
            .expect("exhausted budget should trigger fallback");

        assert_eq!(d.next_upstream, "anthropic-fallback");
        assert!(matches!(d.reason, FallbackReason::BudgetExhausted));
        assert_eq!(
            d.model_mapping,
            Some(("gpt-4o".to_string(), "claude-3-5-sonnet".to_string()))
        );
    }

    #[test]
    fn test_fallback_on_error_code() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        // 503 is configured as a trigger code; 200 and 404 are not.
        assert!(ev
            .should_fallback_after_response("openai-primary", 503, 1000, Some("gpt-4"))
            .is_some());
        assert!(ev
            .should_fallback_after_response("openai-primary", 200, 1000, Some("gpt-4"))
            .is_none());
        assert!(ev
            .should_fallback_after_response("openai-primary", 404, 1000, Some("gpt-4"))
            .is_none());
    }

    #[test]
    fn test_fallback_on_latency_threshold() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        // 6000 ms exceeds the 5000 ms threshold.
        let d = ev
            .should_fallback_after_response("openai-primary", 200, 6000, Some("gpt-4"))
            .expect("slow response should trigger fallback");
        assert!(matches!(
            d.reason,
            FallbackReason::LatencyThreshold {
                observed_ms: 6000,
                threshold_ms: 5000
            }
        ));

        // 4000 ms is below the threshold.
        assert!(ev
            .should_fallback_after_response("openai-primary", 200, 4000, Some("gpt-4"))
            .is_none());
    }

    #[test]
    fn test_fallback_on_connection_error() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        let d = ev
            .should_fallback_on_connection_error(
                "openai-primary",
                "connection refused",
                Some("gpt-4"),
            )
            .expect("connection error should trigger fallback");
        assert!(matches!(d.reason, FallbackReason::ConnectionError(_)));
    }

    #[test]
    fn test_next_fallback_skips_tried() {
        let cfg = create_test_config();

        // Nothing tried yet: the first configured fallback is chosen.
        let ev = FallbackEvaluator::new(&cfg, &[], 0);
        assert_eq!(
            ev.next_fallback().expect("first fallback").upstream,
            "anthropic-fallback"
        );

        // First fallback already tried: move on to the second.
        let tried = ["anthropic-fallback".to_string()];
        let ev = FallbackEvaluator::new(&cfg, &tried, 1);
        assert_eq!(
            ev.next_fallback().expect("second fallback").upstream,
            "local-gpu"
        );

        // Everything tried: no candidate remains.
        let tried = ["anthropic-fallback".to_string(), "local-gpu".to_string()];
        let ev = FallbackEvaluator::new(&cfg, &tried, 2);
        assert!(ev.next_fallback().is_none());
    }

    #[test]
    fn test_model_mapping_exact() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);
        let anthropic = &cfg.upstreams[0];

        for (input, expected) in [
            ("gpt-4", "claude-3-opus"),
            ("gpt-4o", "claude-3-5-sonnet"),
            ("gpt-3.5-turbo", "claude-3-haiku"),
            // Unknown models pass through unchanged.
            ("unknown-model", "unknown-model"),
        ] {
            assert_eq!(ev.map_model(anthropic, input), expected);
        }
    }

    #[test]
    fn test_model_mapping_glob() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);
        let local = &cfg.upstreams[1];

        for (input, expected) in [
            ("gpt-4", "llama-3-70b"),
            ("gpt-4-turbo", "llama-3-70b"),
            ("gpt-4o", "llama-3-70b"),
            ("gpt-3.5-turbo", "llama-3-8b"),
        ] {
            assert_eq!(ev.map_model(local, input), expected);
        }
    }

    #[test]
    fn test_no_fallback_when_healthy_and_budget_ok() {
        let cfg = create_test_config();
        let ev = FallbackEvaluator::new(&cfg, &[], 0);

        assert!(ev
            .should_fallback_before_request("openai-primary", true, false, Some("gpt-4"))
            .is_none());
    }

    #[test]
    fn test_no_fallback_when_max_attempts_reached() {
        let cfg = create_test_config();
        // current_attempt equals max_attempts (2): no further fallback,
        // even though the upstream is unhealthy.
        let ev = FallbackEvaluator::new(&cfg, &[], 2);

        assert!(ev
            .should_fallback_before_request("openai-primary", false, false, Some("gpt-4"))
            .is_none());
    }
}