1use crate::serve::backends::ServingBackend;
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::time::{Duration, Instant};
14
/// Lifecycle state of a single streamed generation.
///
/// NOTE(review): nothing in this module moves a stream between these states; the
/// per-variant descriptions follow the variant names and should be confirmed against
/// the code that drives the stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamState {
    /// Not yet streaming (presumably waiting for the first token).
    Pending,
    /// Output is being streamed.
    Streaming,
    /// The stream finished.
    Completed,
    /// The stream failed; the payload is a human-readable reason.
    Failed(String),
    /// The stream was presumably resumed elsewhere after a failure.
    Recovered,
}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StreamingContext {
37 pub prompt: String,
39 pub generated_prefix: String,
41 pub token_count: usize,
43 pub primary_backend: String,
45 pub request_id: String,
47}
48
49impl StreamingContext {
50 #[must_use]
52 pub fn new(prompt: impl Into<String>, request_id: impl Into<String>) -> Self {
53 Self {
54 prompt: prompt.into(),
55 generated_prefix: String::new(),
56 token_count: 0,
57 primary_backend: String::new(),
58 request_id: request_id.into(),
59 }
60 }
61
62 pub fn append(&mut self, tokens: &str) {
64 self.generated_prefix.push_str(tokens);
65 self.token_count += tokens.split_whitespace().count().max(1);
67 }
68
69 #[must_use]
71 pub fn continuation_prompt(&self) -> String {
72 if self.generated_prefix.is_empty() {
73 self.prompt.clone()
74 } else {
75 format!("{}{}", self.prompt, self.generated_prefix)
76 }
77 }
78
79 #[must_use]
81 pub fn worth_recovering(&self) -> bool {
82 self.token_count >= 5 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct FailoverConfig {
93 pub max_retries: u32,
95 pub failover_timeout: Duration,
97 pub min_tokens_for_recovery: usize,
99 pub include_prefix: bool,
101 pub fallback_order: Vec<ServingBackend>,
103}
104
105impl Default for FailoverConfig {
106 fn default() -> Self {
107 Self {
108 max_retries: 2,
109 failover_timeout: Duration::from_secs(30),
110 min_tokens_for_recovery: 5,
111 include_prefix: true,
112 fallback_order: vec![
113 ServingBackend::Realizar,
114 ServingBackend::Ollama,
115 ServingBackend::Together,
116 ServingBackend::Groq,
117 ],
118 }
119 }
120}
121
122#[derive(Debug, Clone)]
128pub struct FailoverAttempt {
129 pub backend: ServingBackend,
130 pub started_at: Instant,
131 pub result: Option<FailoverResult>,
132}
133
/// Outcome of a single failover attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FailoverResult {
    /// The fallback backend continued the stream.
    Success,
    /// The attempt did not finish in time.
    Timeout,
    /// The backend reported an error; the payload is its message.
    BackendError(String),
    /// No fallback backend could be selected.
    NoBackendsAvailable,
}
142
143pub struct FailoverManager {
145 config: FailoverConfig,
146 contexts: std::collections::HashMap<String, StreamingContext>,
148 history: VecDeque<FailoverAttempt>,
150 max_history: usize,
152}
153
154impl FailoverManager {
155 #[must_use]
157 pub fn new(config: FailoverConfig) -> Self {
158 Self {
159 config,
160 contexts: std::collections::HashMap::new(),
161 history: VecDeque::new(),
162 max_history: 100,
163 }
164 }
165
166 #[must_use]
168 pub fn with_defaults() -> Self {
169 Self::new(FailoverConfig::default())
170 }
171
172 pub fn start_tracking(&mut self, request_id: &str, prompt: &str) {
174 let context = StreamingContext::new(prompt, request_id);
175 self.contexts.insert(request_id.to_string(), context);
176 }
177
178 pub fn append_tokens(&mut self, request_id: &str, tokens: &str) {
180 if let Some(ctx) = self.contexts.get_mut(request_id) {
181 ctx.append(tokens);
182 }
183 }
184
185 pub fn complete(&mut self, request_id: &str) {
187 self.contexts.remove(request_id);
188 }
189
190 #[must_use]
192 pub fn get_context(&self, request_id: &str) -> Option<&StreamingContext> {
193 self.contexts.get(request_id)
194 }
195
196 #[must_use]
198 pub fn should_failover(&self, request_id: &str) -> bool {
199 self.contexts.get(request_id).map(|ctx| ctx.worth_recovering()).unwrap_or(false)
200 }
201
202 #[must_use]
204 pub fn next_backend(&self, failed_backend: ServingBackend) -> Option<ServingBackend> {
205 self.config.fallback_order.iter().find(|&&b| b != failed_backend).copied()
206 }
207
208 #[must_use]
210 pub fn prepare_failover(&self, request_id: &str) -> Option<FailoverRequest> {
211 let ctx = self.contexts.get(request_id)?;
212
213 let prompt =
214 if self.config.include_prefix { ctx.continuation_prompt() } else { ctx.prompt.clone() };
215
216 Some(FailoverRequest {
217 request_id: request_id.to_string(),
218 prompt,
219 generated_prefix: ctx.generated_prefix.clone(),
220 token_count: ctx.token_count,
221 })
222 }
223
224 pub fn record_attempt(&mut self, attempt: FailoverAttempt) {
226 self.history.push_back(attempt);
227 while self.history.len() > self.max_history {
228 self.history.pop_front();
229 }
230 }
231
232 #[must_use]
234 pub fn stats(&self) -> FailoverStats {
235 let total = self.history.len();
236 let successes =
237 self.history.iter().filter(|a| a.result == Some(FailoverResult::Success)).count();
238 let timeouts =
239 self.history.iter().filter(|a| a.result == Some(FailoverResult::Timeout)).count();
240
241 FailoverStats {
242 total_attempts: total,
243 successful: successes,
244 timeouts,
245 active_contexts: self.contexts.len(),
246 }
247 }
248
249 #[must_use]
251 pub fn config(&self) -> &FailoverConfig {
252 &self.config
253 }
254}
255
256impl Default for FailoverManager {
257 fn default() -> Self {
258 Self::with_defaults()
259 }
260}
261
/// Request to replay against a fallback backend, built from a tracked stream.
#[derive(Debug, Clone)]
pub struct FailoverRequest {
    /// Identifier of the original request.
    pub request_id: String,
    /// Prompt to send to the fallback backend.
    pub prompt: String,
    /// Text generated before the failure.
    pub generated_prefix: String,
    /// Approximate number of tokens in `generated_prefix`.
    pub token_count: usize,
}
270
/// Aggregate failover metrics over the retained attempt history.
#[derive(Debug, Clone, Default)]
pub struct FailoverStats {
    /// Number of attempts in the retained history.
    pub total_attempts: usize,
    /// Attempts that ended in success.
    pub successful: usize,
    /// Attempts that timed out.
    pub timeouts: usize,
    /// Requests currently being tracked.
    pub active_contexts: usize,
}
279
280impl FailoverStats {
281 #[must_use]
283 pub fn success_rate(&self) -> f64 {
284 if self.total_attempts == 0 {
285 0.0
286 } else {
287 (self.successful as f64 / self.total_attempts as f64) * 100.0
288 }
289 }
290}
291
// Unit tests for streaming failover, grouped by spec identifier (SERVE_FLO_001..007).
#[cfg(test)]
// Test names embed upper-case spec IDs such as `SERVE_FLO_001`, which trips `non_snake_case`.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    // --- SERVE_FLO_001: StreamingContext ---------------------------------------------

    #[test]
    fn test_SERVE_FLO_001_context_new() {
        let ctx = StreamingContext::new("Hello, how are you?", "req-123");
        assert_eq!(ctx.prompt, "Hello, how are you?");
        assert_eq!(ctx.request_id, "req-123");
        assert!(ctx.generated_prefix.is_empty());
        assert_eq!(ctx.token_count, 0);
    }

    #[test]
    fn test_SERVE_FLO_001_context_append() {
        let mut ctx = StreamingContext::new("Test", "req-1");
        ctx.append("Hello ");
        ctx.append("world!");
        assert_eq!(ctx.generated_prefix, "Hello world!");
        assert!(ctx.token_count > 0);
    }

    #[test]
    fn test_SERVE_FLO_001_continuation_prompt() {
        let mut ctx = StreamingContext::new("Prompt: ", "req-1");
        ctx.append("Response so far");
        assert_eq!(ctx.continuation_prompt(), "Prompt: Response so far");
    }

    #[test]
    fn test_SERVE_FLO_001_continuation_prompt_empty() {
        let ctx = StreamingContext::new("Just prompt", "req-1");
        assert_eq!(ctx.continuation_prompt(), "Just prompt");
    }

    #[test]
    fn test_SERVE_FLO_001_worth_recovering() {
        let mut ctx = StreamingContext::new("Test", "req-1");
        assert!(!ctx.worth_recovering());

        ctx.append("one two three four five six");
        assert!(ctx.worth_recovering());
    }

    // --- SERVE_FLO_002: FailoverConfig -----------------------------------------------

    #[test]
    fn test_SERVE_FLO_002_default_config() {
        let config = FailoverConfig::default();
        assert_eq!(config.max_retries, 2);
        assert!(config.include_prefix);
        assert!(!config.fallback_order.is_empty());
    }

    #[test]
    fn test_SERVE_FLO_002_fallback_order() {
        let config = FailoverConfig::default();
        assert!(config.fallback_order.contains(&ServingBackend::Realizar));
        assert!(config.fallback_order.contains(&ServingBackend::Together));
    }

    // --- SERVE_FLO_003: request tracking ---------------------------------------------

    #[test]
    fn test_SERVE_FLO_003_start_tracking() {
        let mut manager = FailoverManager::with_defaults();
        manager.start_tracking("req-1", "Test prompt");
        assert!(manager.get_context("req-1").is_some());
    }

    #[test]
    fn test_SERVE_FLO_003_append_tokens() {
        let mut manager = FailoverManager::with_defaults();
        manager.start_tracking("req-1", "Prompt");
        manager.append_tokens("req-1", "Generated");

        let ctx = manager.get_context("req-1").expect("unexpected failure");
        assert_eq!(ctx.generated_prefix, "Generated");
    }

    #[test]
    fn test_SERVE_FLO_003_complete_removes() {
        let mut manager = FailoverManager::with_defaults();
        manager.start_tracking("req-1", "Prompt");
        manager.complete("req-1");
        assert!(manager.get_context("req-1").is_none());
    }

    #[test]
    fn test_SERVE_FLO_003_should_failover() {
        let mut manager = FailoverManager::with_defaults();
        manager.start_tracking("req-1", "Prompt");

        // Nothing generated yet: not worth failing over.
        assert!(!manager.should_failover("req-1"));

        // Six whitespace-separated words clear the default threshold of 5.
        manager.append_tokens("req-1", "one two three four five six");
        assert!(manager.should_failover("req-1"));
    }

    // --- SERVE_FLO_004: backend selection --------------------------------------------

    #[test]
    fn test_SERVE_FLO_004_next_backend_skips_failed() {
        let manager = FailoverManager::with_defaults();
        let next = manager.next_backend(ServingBackend::Realizar);
        assert!(next.is_some());
        assert_ne!(next.expect("unexpected failure"), ServingBackend::Realizar);
    }

    #[test]
    fn test_SERVE_FLO_004_next_backend_order() {
        let config = FailoverConfig {
            fallback_order: vec![ServingBackend::Ollama, ServingBackend::Together],
            ..Default::default()
        };
        let manager = FailoverManager::new(config);

        let next = manager.next_backend(ServingBackend::Realizar);
        assert_eq!(next, Some(ServingBackend::Ollama));
    }

    // --- SERVE_FLO_005: failover request preparation ---------------------------------

    #[test]
    fn test_SERVE_FLO_005_prepare_failover() {
        let mut manager = FailoverManager::with_defaults();
        manager.start_tracking("req-1", "Original prompt");
        manager.append_tokens("req-1", " partial response");

        let request = manager.prepare_failover("req-1").expect("unexpected failure");
        assert_eq!(request.request_id, "req-1");
        assert!(request.prompt.contains("Original prompt"));
        assert!(request.prompt.contains("partial response"));
    }

    #[test]
    fn test_SERVE_FLO_005_prepare_failover_not_found() {
        let manager = FailoverManager::with_defaults();
        assert!(manager.prepare_failover("nonexistent").is_none());
    }

    // --- SERVE_FLO_006: statistics ---------------------------------------------------

    #[test]
    fn test_SERVE_FLO_006_empty_stats() {
        let manager = FailoverManager::with_defaults();
        let stats = manager.stats();
        assert_eq!(stats.total_attempts, 0);
        assert_eq!(stats.success_rate(), 0.0);
    }

    #[test]
    fn test_SERVE_FLO_006_record_attempt() {
        let mut manager = FailoverManager::with_defaults();
        manager.record_attempt(FailoverAttempt {
            backend: ServingBackend::Together,
            started_at: Instant::now(),
            result: Some(FailoverResult::Success),
        });

        let stats = manager.stats();
        assert_eq!(stats.total_attempts, 1);
        assert_eq!(stats.successful, 1);
        assert_eq!(stats.success_rate(), 100.0);
    }

    #[test]
    fn test_SERVE_FLO_006_mixed_results() {
        let mut manager = FailoverManager::with_defaults();

        manager.record_attempt(FailoverAttempt {
            backend: ServingBackend::Together,
            started_at: Instant::now(),
            result: Some(FailoverResult::Success),
        });
        manager.record_attempt(FailoverAttempt {
            backend: ServingBackend::Groq,
            started_at: Instant::now(),
            result: Some(FailoverResult::Timeout),
        });

        let stats = manager.stats();
        assert_eq!(stats.total_attempts, 2);
        assert_eq!(stats.successful, 1);
        assert_eq!(stats.timeouts, 1);
        assert_eq!(stats.success_rate(), 50.0);
    }

    // --- SERVE_FLO_007: stream states ------------------------------------------------

    #[test]
    fn test_SERVE_FLO_007_stream_states() {
        assert_eq!(StreamState::Pending, StreamState::Pending);
        assert_ne!(StreamState::Streaming, StreamState::Completed);

        let failed = StreamState::Failed("Connection reset".to_string());
        // Exhaustive match: a non-`Failed` value must fail the test, not silently pass.
        match failed {
            StreamState::Failed(msg) => assert!(msg.contains("reset")),
            other => panic!("expected StreamState::Failed, got {:?}", other),
        }
    }

    // --- Derived-trait and edge-case coverage ----------------------------------------

    #[test]
    fn test_stream_state_clone() {
        let state = StreamState::Failed("error".to_string());
        let cloned = state.clone();
        assert_eq!(state, cloned);
    }

    #[test]
    fn test_stream_state_debug() {
        let state = StreamState::Recovered;
        let debug_str = format!("{:?}", state);
        assert!(debug_str.contains("Recovered"));
    }

    #[test]
    fn test_stream_state_all_variants() {
        let states = [
            StreamState::Pending,
            StreamState::Streaming,
            StreamState::Completed,
            StreamState::Failed("err".to_string()),
            StreamState::Recovered,
        ];
        assert_eq!(states.len(), 5);
    }

    #[test]
    fn test_streaming_context_clone() {
        let mut ctx = StreamingContext::new("prompt", "req-1");
        ctx.append("tokens");
        let cloned = ctx.clone();
        assert_eq!(ctx.prompt, cloned.prompt);
        assert_eq!(ctx.generated_prefix, cloned.generated_prefix);
    }

    #[test]
    fn test_streaming_context_debug() {
        let ctx = StreamingContext::new("test prompt", "req-debug");
        let debug_str = format!("{:?}", ctx);
        assert!(debug_str.contains("test prompt"));
        assert!(debug_str.contains("req-debug"));
    }

    #[test]
    fn test_streaming_context_serialize() {
        let ctx = StreamingContext::new("serializable", "req-ser");
        let json = serde_json::to_string(&ctx).expect("json serialize failed");
        assert!(json.contains("serializable"));
        assert!(json.contains("req-ser"));
    }

    #[test]
    fn test_streaming_context_deserialize() {
        let json = r#"{"prompt":"deserialized","generated_prefix":"prefix","token_count":5,"primary_backend":"test","request_id":"req-de"}"#;
        let ctx: StreamingContext = serde_json::from_str(json).expect("json deserialize failed");
        assert_eq!(ctx.prompt, "deserialized");
        assert_eq!(ctx.generated_prefix, "prefix");
        assert_eq!(ctx.token_count, 5);
    }

    #[test]
    fn test_failover_config_clone() {
        let config = FailoverConfig::default();
        let cloned = config.clone();
        assert_eq!(config.max_retries, cloned.max_retries);
    }

    #[test]
    fn test_failover_config_debug() {
        let config = FailoverConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("max_retries"));
    }

    #[test]
    fn test_failover_config_serialize() {
        let config = FailoverConfig::default();
        let json = serde_json::to_string(&config).expect("json serialize failed");
        assert!(json.contains("max_retries"));
        assert!(json.contains("include_prefix"));
    }

    #[test]
    fn test_failover_attempt_clone() {
        let attempt = FailoverAttempt {
            backend: ServingBackend::Ollama,
            started_at: Instant::now(),
            result: Some(FailoverResult::Success),
        };
        let cloned = attempt.clone();
        assert_eq!(attempt.backend, cloned.backend);
    }

    #[test]
    fn test_failover_attempt_debug() {
        let attempt = FailoverAttempt {
            backend: ServingBackend::Together,
            started_at: Instant::now(),
            result: None,
        };
        let debug_str = format!("{:?}", attempt);
        assert!(debug_str.contains("Together"));
    }

    #[test]
    fn test_failover_result_clone() {
        let result = FailoverResult::BackendError("timeout".to_string());
        let cloned = result.clone();
        assert_eq!(result, cloned);
    }

    #[test]
    fn test_failover_result_debug() {
        let result = FailoverResult::NoBackendsAvailable;
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("NoBackendsAvailable"));
    }

    #[test]
    fn test_failover_result_all_variants() {
        let results = [
            FailoverResult::Success,
            FailoverResult::Timeout,
            FailoverResult::BackendError("err".to_string()),
            FailoverResult::NoBackendsAvailable,
        ];
        assert_eq!(results.len(), 4);
        assert_eq!(results[0], FailoverResult::Success);
    }

    #[test]
    fn test_failover_request_clone() {
        let request = FailoverRequest {
            request_id: "req-1".to_string(),
            prompt: "prompt".to_string(),
            generated_prefix: "prefix".to_string(),
            token_count: 10,
        };
        let cloned = request.clone();
        assert_eq!(request.request_id, cloned.request_id);
    }

    #[test]
    fn test_failover_request_debug() {
        let request = FailoverRequest {
            request_id: "debug-req".to_string(),
            prompt: "test".to_string(),
            generated_prefix: "".to_string(),
            token_count: 0,
        };
        let debug_str = format!("{:?}", request);
        assert!(debug_str.contains("debug-req"));
    }

    #[test]
    fn test_failover_stats_clone() {
        let stats =
            FailoverStats { total_attempts: 10, successful: 8, timeouts: 1, active_contexts: 2 };
        let cloned = stats.clone();
        assert_eq!(stats.total_attempts, cloned.total_attempts);
    }

    #[test]
    fn test_failover_stats_debug() {
        let stats = FailoverStats::default();
        let debug_str = format!("{:?}", stats);
        assert!(debug_str.contains("FailoverStats"));
    }

    #[test]
    fn test_failover_stats_default() {
        let stats = FailoverStats::default();
        assert_eq!(stats.total_attempts, 0);
        assert_eq!(stats.successful, 0);
        assert_eq!(stats.timeouts, 0);
        assert_eq!(stats.active_contexts, 0);
    }

    #[test]
    fn test_failover_manager_config() {
        let config = FailoverConfig { max_retries: 5, ..Default::default() };
        let manager = FailoverManager::new(config);
        assert_eq!(manager.config().max_retries, 5);
    }

    #[test]
    fn test_failover_manager_default() {
        let manager = FailoverManager::default();
        let stats = manager.stats();
        assert_eq!(stats.total_attempts, 0);
    }

    #[test]
    fn test_failover_manager_history_trimming() {
        let mut manager = FailoverManager::with_defaults();
        // Record more attempts than the history cap of 100.
        for _ in 0..110 {
            manager.record_attempt(FailoverAttempt {
                backend: ServingBackend::Together,
                started_at: Instant::now(),
                result: Some(FailoverResult::Success),
            });
        }
        let stats = manager.stats();
        // Only the most recent 100 attempts are retained.
        assert_eq!(stats.total_attempts, 100);
    }

    #[test]
    fn test_append_tokens_nonexistent() {
        let mut manager = FailoverManager::with_defaults();
        manager.append_tokens("nonexistent", "tokens");
        // Appending to an untracked request is a no-op and must not create a context.
        assert!(manager.get_context("nonexistent").is_none());
    }

    #[test]
    fn test_should_failover_nonexistent() {
        let manager = FailoverManager::with_defaults();
        assert!(!manager.should_failover("nonexistent"));
    }

    #[test]
    fn test_prepare_failover_without_prefix() {
        let config = FailoverConfig { include_prefix: false, ..Default::default() };
        let mut manager = FailoverManager::new(config);
        manager.start_tracking("req-1", "Original prompt");
        manager.append_tokens("req-1", " generated");

        let request = manager.prepare_failover("req-1").expect("unexpected failure");
        // With `include_prefix` disabled, the generated text is not appended to the prompt.
        assert_eq!(request.prompt, "Original prompt");
    }

    #[test]
    fn test_stats_active_contexts() {
        let mut manager = FailoverManager::with_defaults();
        manager.start_tracking("req-1", "p1");
        manager.start_tracking("req-2", "p2");
        manager.start_tracking("req-3", "p3");

        let stats = manager.stats();
        assert_eq!(stats.active_contexts, 3);

        manager.complete("req-1");
        let stats = manager.stats();
        assert_eq!(stats.active_contexts, 2);
    }

    #[test]
    fn test_streaming_context_primary_backend() {
        let mut ctx = StreamingContext::new("prompt", "req-1");
        ctx.primary_backend = "realizar".to_string();
        assert_eq!(ctx.primary_backend, "realizar");
    }
}