1pub mod asi;
13pub mod bandit;
14pub mod cascade;
15pub mod reputation;
16pub mod thompson;
17pub mod triage;
18
19use std::collections::HashMap;
20use std::path::Path;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicU64, Ordering};
23
24use parking_lot::Mutex;
25
26use crate::any::AnyProvider;
27use crate::ema::EmaTracker;
28use crate::embed::owned_strs;
29use crate::error::LlmError;
30use crate::provider::{ChatResponse, ChatStream, LlmProvider, Message, StatusTx, ToolDefinition};
31
32use asi::AsiState;
33use bandit::{BanditState, embedding_to_features};
34use cascade::{CascadeState, ClassifierMode, heuristic_score};
35use reputation::ReputationTracker;
36use thompson::ThompsonState;
37use zeph_common::math::cosine_similarity;
38
/// Bounded FIFO cache mapping query hashes to embedding feature vectors.
///
/// Eviction is insertion-ordered: once `capacity` entries are stored, the
/// oldest insertion is dropped to make room. Lookups do not refresh an
/// entry's position, and re-inserting an existing key is a no-op.
#[derive(Debug)]
struct BanditEmbedCache {
    map: HashMap<u64, Vec<f32>>,
    order: std::collections::VecDeque<u64>,
    capacity: usize,
}

impl BanditEmbedCache {
    /// Creates an empty cache holding at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        Self {
            map: HashMap::with_capacity(capacity),
            order: std::collections::VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Returns the cached feature vector for `key`, if present.
    fn get(&self, key: u64) -> Option<&Vec<f32>> {
        self.map.get(&key)
    }

    /// Stores `value` under `key`, evicting the oldest entry when full.
    fn insert(&mut self, key: u64, value: Vec<f32>) {
        if self.map.contains_key(&key) {
            // Existing entries keep their original value and position.
            return;
        }
        if self.map.len() >= self.capacity {
            if let Some(oldest) = self.order.pop_front() {
                self.map.remove(&oldest);
            }
        }
        self.map.insert(key, value);
        self.order.push_back(key);
    }
}

impl Default for BanditEmbedCache {
    /// Default capacity mirrors `BanditRouterConfig::default().cache_size`.
    fn default() -> Self {
        Self::new(512)
    }
}
83
/// Provider-selection strategy used by `RouterProvider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RouterStrategy {
    /// Order providers by exponential-moving-average success/latency (default).
    #[default]
    Ema,
    /// Thompson sampling over per-provider Beta distributions.
    Thompson,
    /// Try providers cheapest-first, escalating when a response scores
    /// below the configured quality threshold.
    Cascade,
    /// Contextual bandit over query-embedding features; falls back to
    /// Thompson sampling when embeddings are unavailable.
    Bandit,
}
97
/// Tuning knobs for the contextual-bandit routing strategy.
#[derive(Debug, Clone)]
#[allow(clippy::doc_markdown)]
pub struct BanditRouterConfig {
    /// Exploration parameter passed to `BanditState::select`;
    /// clamped to 0.01 when non-positive.
    pub alpha: f32,
    /// Feature-vector dimensionality; clamped to 32 outside [1, 256].
    pub dim: usize,
    /// Weight of the cost term subtracted from the quality reward.
    pub cost_weight: f32,
    /// Per-load decay applied to accumulated state; must lie in (0.0, 1.0].
    pub decay_factor: f32,
    /// Warm-up query count before learned scores are trusted; 0 means
    /// "derive from provider count" (10 x providers, see `with_bandit`).
    pub warmup_queries: u64,
    /// Time budget for computing a query embedding before falling back.
    pub embedding_timeout_ms: u64,
    /// Capacity of the query-hash -> features cache.
    pub cache_size: usize,
    /// Threshold compared against the current memory confidence during
    /// selection (semantics live in `BanditState::select`).
    pub memory_confidence_threshold: f32,
}

impl Default for BanditRouterConfig {
    fn default() -> Self {
        Self {
            alpha: 1.0,
            dim: 32,
            cost_weight: 0.1,
            decay_factor: 1.0,
            warmup_queries: 0,
            embedding_timeout_ms: 50,
            cache_size: 512,
            memory_confidence_threshold: 0.9,
        }
    }
}
140
/// Configuration for the ASI coherence-tracking overlay.
#[derive(Debug, Clone)]
pub struct AsiRouterConfig {
    /// Number of recent response embeddings retained per provider.
    pub window: usize,
    /// Coherence level below which a provider is penalized (and a warning
    /// is logged) during ordering.
    pub coherence_threshold: f32,
    /// Scales the Thompson beta penalty applied per unit of coherence
    /// deficit below the threshold.
    pub penalty_weight: f32,
}

impl Default for AsiRouterConfig {
    fn default() -> Self {
        Self {
            window: 5,
            coherence_threshold: 0.7,
            penalty_weight: 0.3,
        }
    }
}
164
/// Configuration for cascade routing: cheap providers first, escalating to
/// stronger ones when a response scores below `quality_threshold`.
#[derive(Debug, Clone)]
pub struct CascadeRouterConfig {
    // Minimum acceptable quality score; lower scores trigger escalation.
    pub quality_threshold: f64,
    // Maximum number of escalation steps per request.
    pub max_escalations: u8,
    // How responses are scored: heuristic or LLM judge.
    pub classifier_mode: ClassifierMode,
    // Sliding-window size handed to `CascadeState::new`.
    pub window_size: usize,
    // Optional token budget for cascading — interpretation lives in the
    // cascade module (not visible here).
    pub max_cascade_tokens: Option<u32>,
    // Provider used as the judge when `classifier_mode` is `Judge`.
    pub summary_provider: Option<AnyProvider>,
    // Explicit cheap-to-expensive provider names; listed providers are
    // moved to the front in tier order by `with_cascade`.
    pub cost_tiers: Option<Vec<String>>,
}

impl Default for CascadeRouterConfig {
    fn default() -> Self {
        Self {
            quality_threshold: 0.5,
            max_escalations: 2,
            classifier_mode: ClassifierMode::Heuristic,
            window_size: 50,
            max_cascade_tokens: None,
            summary_provider: None,
            cost_tiers: None,
        }
    }
}
195
/// Fault-tolerant meta-provider that routes chat/embedding requests across
/// multiple sub-providers according to the configured [`RouterStrategy`].
#[derive(Debug, Clone)]
pub struct RouterProvider {
    // Sub-providers in their configured base order.
    providers: Arc<[AnyProvider]>,
    // Channel for user-facing status notices (fallbacks, rate limits).
    status_tx: Option<StatusTx>,
    // Success/latency EMA tracker (EMA strategy only).
    ema: Option<EmaTracker>,
    // Index ordering over `providers`, refreshed by `ema_record`.
    provider_order: Arc<Mutex<Vec<usize>>>,
    // Active selection strategy.
    strategy: RouterStrategy,
    // Thompson sampling state; also the bandit strategy's fallback selector.
    thompson: Option<Arc<Mutex<ThompsonState>>>,
    // Persistence location for `save_thompson_state`.
    thompson_state_path: Option<std::path::PathBuf>,
    // Rolling quality window for cascade routing.
    cascade_state: Option<Arc<Mutex<CascadeState>>>,
    // Cascade settings; set by `with_cascade`.
    cascade_config: Option<CascadeRouterConfig>,
    // Long-term per-model reputation blended into EMA/Thompson ordering.
    reputation: Option<Arc<Mutex<ReputationTracker>>>,
    // Persistence location for `save_reputation_state`.
    reputation_state_path: Option<std::path::PathBuf>,
    // Blend weight for reputation adjustments, clamped to [0.0, 1.0].
    reputation_weight: f64,
    // Provider credited by `record_quality_outcome` — written outside this
    // chunk, presumably by the chat paths; TODO confirm.
    last_active_provider: Arc<Mutex<Option<String>>>,
    // Contextual bandit state (Bandit strategy only).
    bandit: Option<Arc<Mutex<BanditState>>>,
    // Persistence location for `save_bandit_state`.
    bandit_state_path: Option<std::path::PathBuf>,
    // Bandit settings; set by `with_bandit`.
    bandit_config: Option<BanditRouterConfig>,
    // Provider used to embed queries for bandit features.
    bandit_embedding_provider: Option<AnyProvider>,
    // FIFO cache of query-hash -> feature vectors.
    bandit_embed_cache: Arc<Mutex<BanditEmbedCache>>,
    // Most recent memory-retrieval confidence fed into bandit selection.
    last_memory_confidence: Arc<Mutex<Option<f32>>>,
    // provider name -> model identifier, captured at construction.
    provider_models: Arc<std::collections::HashMap<String, String>>,
    // Per-provider response-coherence windows (ASI overlay).
    asi: Option<Arc<Mutex<AsiState>>>,
    // ASI settings; set by `with_asi`.
    asi_config: Option<AsiRouterConfig>,
    // Minimum query/response embedding similarity before falling back.
    quality_gate: Option<f32>,
    // Monotonic id incremented once per chat turn.
    turn_counter: Arc<AtomicU64>,
    // Turn id of the last ASI update, for per-turn deduplication.
    asi_last_turn: Arc<AtomicU64>,
    // Optional cap on concurrent `embed_batch` calls.
    embed_semaphore: Option<Arc<tokio::sync::Semaphore>>,
}
257
258impl RouterProvider {
259 #[must_use]
260 pub fn new(providers: Vec<AnyProvider>) -> Self {
261 let n = providers.len();
262 let provider_models: std::collections::HashMap<String, String> = providers
263 .iter()
264 .map(|p| (p.name().to_owned(), p.model_identifier().to_owned()))
265 .collect();
266 Self {
267 providers: Arc::from(providers),
268 status_tx: None,
269 ema: None,
270 provider_order: Arc::new(Mutex::new((0..n).collect())),
271 strategy: RouterStrategy::Ema,
272 thompson: None,
273 thompson_state_path: None,
274 cascade_state: None,
275 cascade_config: None,
276 reputation: None,
277 reputation_state_path: None,
278 reputation_weight: 0.3,
279 last_active_provider: Arc::new(Mutex::new(None)),
280 bandit: None,
281 bandit_state_path: None,
282 bandit_config: None,
283 bandit_embedding_provider: None,
284 bandit_embed_cache: Arc::new(Mutex::new(BanditEmbedCache::default())),
285 last_memory_confidence: Arc::new(Mutex::new(None)),
286 provider_models: Arc::new(provider_models),
287 asi: None,
288 asi_config: None,
289 quality_gate: None,
290 turn_counter: Arc::new(AtomicU64::new(0)),
291 asi_last_turn: Arc::new(AtomicU64::new(u64::MAX)),
292 embed_semaphore: None,
293 }
294 }
295
296 #[must_use]
300 pub fn with_embed_concurrency(mut self, limit: usize) -> Self {
301 self.embed_semaphore = if limit > 0 {
302 Some(Arc::new(tokio::sync::Semaphore::new(limit)))
303 } else {
304 None
305 };
306 self
307 }
308
309 pub fn set_memory_confidence(&self, confidence: Option<f32>) {
314 *self.last_memory_confidence.lock() = confidence;
315 }
316
317 #[must_use]
319 pub fn with_ema(mut self, alpha: f64, reorder_interval: u64) -> Self {
320 self.ema = Some(EmaTracker::new(alpha, reorder_interval));
321 self
322 }
323
324 #[must_use]
331 pub fn with_asi(mut self, config: AsiRouterConfig) -> Self {
332 self.asi = Some(Arc::new(Mutex::new(AsiState::default())));
333 self.asi_config = Some(config);
334 self
335 }
336
337 #[must_use]
344 pub fn with_quality_gate(mut self, threshold: f32) -> Self {
345 self.quality_gate = Some(threshold);
346 self
347 }
348
349 #[must_use]
354 pub fn with_thompson(mut self, state_path: Option<&Path>) -> Self {
355 self.strategy = RouterStrategy::Thompson;
356 let path = state_path.map_or_else(ThompsonState::default_path, Path::to_path_buf);
357 let mut state = ThompsonState::load(&path);
358 let known: std::collections::HashSet<String> =
360 self.providers.iter().map(|p| p.name().to_owned()).collect();
361 state.prune(&known);
362 self.thompson = Some(Arc::new(Mutex::new(state)));
363 self.thompson_state_path = Some(path);
364 self
365 }
366
367 #[must_use]
379 pub fn with_bandit(
380 mut self,
381 mut config: BanditRouterConfig,
382 state_path: Option<&Path>,
383 embedding_provider: Option<AnyProvider>,
384 ) -> Self {
385 self.strategy = RouterStrategy::Bandit;
386 let n = self.providers.len();
387 if config.warmup_queries == 0 {
388 config.warmup_queries = u64::try_from(10 * n.max(1)).unwrap_or(100);
389 }
390 let cache_size = config.cache_size;
391 let path = state_path.map_or_else(BanditState::default_path, Path::to_path_buf);
392 let mut state = BanditState::load(&path);
393 if state.dim == 0 {
394 state = BanditState::new(config.dim);
395 } else if state.dim != config.dim {
396 tracing::warn!(
398 old_dim = state.dim,
399 new_dim = config.dim,
400 "bandit: dim changed, resetting state"
401 );
402 state = BanditState::new(config.dim);
403 }
404 if config.alpha <= 0.0 {
406 tracing::warn!(alpha = config.alpha, "bandit: alpha <= 0, clamping to 0.01");
407 config.alpha = 0.01;
408 }
409 if config.dim == 0 || config.dim > 256 {
410 tracing::warn!(
411 dim = config.dim,
412 "bandit: dim out of range [1, 256], clamping to 32"
413 );
414 config.dim = 32;
415 }
416 if config.decay_factor <= 0.0 || config.decay_factor > 1.0 {
417 tracing::warn!(
418 decay_factor = config.decay_factor,
419 "bandit: decay_factor out of (0.0, 1.0], clamping to 1.0"
420 );
421 config.decay_factor = 1.0;
422 }
423 if config.decay_factor < 1.0 {
424 state.apply_decay(config.decay_factor);
425 }
426 let known: std::collections::HashSet<String> =
427 self.providers.iter().map(|p| p.name().to_owned()).collect();
428 state.prune(&known);
429 self.bandit = Some(Arc::new(Mutex::new(state)));
430 self.bandit_state_path = Some(path);
431 self.bandit_embed_cache = Arc::new(Mutex::new(BanditEmbedCache::new(cache_size)));
432 self.bandit_embedding_provider = embedding_provider;
433 self.thompson = Some(Arc::new(Mutex::new(ThompsonState::default())));
436 self.bandit_config = Some(config);
437 self
438 }
439
440 pub fn save_bandit_state(&self) {
442 let (Some(bandit), Some(path)) = (&self.bandit, &self.bandit_state_path) else {
443 return;
444 };
445 let state = bandit.lock();
446 if let Err(e) = state.save(path) {
447 tracing::warn!(error = %e, "failed to save bandit state");
448 }
449 }
450
451 #[must_use]
455 pub fn bandit_stats(&self) -> Vec<(String, u64, f32)> {
456 let Some(ref bandit) = self.bandit else {
457 return vec![];
458 };
459 let state = bandit.lock();
460 state.stats()
461 }
462
    /// Enables reputation tracking that biases EMA/Thompson ordering.
    ///
    /// Loads persisted reputation from `state_path` (or the default path),
    /// applies time decay, prunes providers no longer configured, then
    /// rebuilds the tracker with the requested `decay_factor` and
    /// `min_observations` while carrying over each model's Beta distribution
    /// and observation count. `weight` is clamped into [0.0, 1.0].
    #[must_use]
    pub fn with_reputation(
        mut self,
        decay_factor: f64,
        weight: f64,
        min_observations: u64,
        state_path: Option<&Path>,
    ) -> Self {
        let path = state_path.map_or_else(ReputationTracker::default_path, Path::to_path_buf);
        let mut tracker = ReputationTracker::load(&path);
        let known: std::collections::HashSet<String> =
            self.providers.iter().map(|p| p.name().to_owned()).collect();
        tracker.apply_decay();
        tracker.prune(&known);
        // Rebuild with the new parameters, preserving learned distributions;
        // the loaded tracker's own decay/min-observation settings are
        // intentionally discarded in favor of the arguments.
        let tracker = {
            let stats = tracker.stats();
            let mut t = ReputationTracker::new(decay_factor, min_observations);
            for (name, alpha, beta, _, obs) in stats {
                t.models.insert(
                    name,
                    reputation::ReputationEntry {
                        dist: thompson::BetaDist { alpha, beta },
                        observations: obs,
                    },
                );
            }
            t
        };
        self.reputation = Some(Arc::new(Mutex::new(tracker)));
        self.reputation_state_path = Some(path);
        self.reputation_weight = weight.clamp(0.0, 1.0);
        self
    }
504
505 pub fn record_quality_outcome(&self, _provider_name: &str, success: bool) {
515 if matches!(
516 self.strategy,
517 RouterStrategy::Cascade | RouterStrategy::Bandit
518 ) {
519 return;
522 }
523 let Some(ref reputation) = self.reputation else {
524 return;
525 };
526 let active = self.last_active_provider.lock().clone();
527 let Some(provider_name) = active else {
528 return;
529 };
530 let mut tracker = reputation.lock();
531 tracker.record_quality(&provider_name, success);
532 }
533
534 pub fn save_reputation_state(&self) {
536 let (Some(reputation), Some(path)) = (&self.reputation, &self.reputation_state_path) else {
537 return;
538 };
539 let state = reputation.lock();
540 if let Err(e) = state.save(path) {
541 tracing::warn!(error = %e, "failed to save reputation state");
542 }
543 }
544
545 #[must_use]
547 pub fn reputation_stats(&self) -> Vec<(String, f64, f64, f64, u64)> {
548 let Some(ref reputation) = self.reputation else {
549 return vec![];
550 };
551 let tracker = reputation.lock();
552 tracker.stats()
553 }
554
    /// Switches to cascade routing, reordering providers by `cost_tiers`.
    ///
    /// Providers named in `cost_tiers` are moved to the front in tier order;
    /// unnamed providers follow, keeping their original relative order
    /// (sort key `(0, tier)` for listed vs `(1, orig_idx)` for unlisted).
    #[must_use]
    pub fn with_cascade(mut self, config: CascadeRouterConfig) -> Self {
        self.strategy = RouterStrategy::Cascade;

        if let Some(ref tiers) = config.cost_tiers
            && !tiers.is_empty()
        {
            // Map provider name -> tier rank for the sort below.
            let tier_pos: std::collections::HashMap<&str, usize> = tiers
                .iter()
                .enumerate()
                .map(|(i, n)| (n.as_str(), i))
                .collect();

            let before: Vec<_> = self.providers.iter().map(|p| p.name().to_owned()).collect();
            let mut indexed: Vec<(usize, AnyProvider)> =
                self.providers.iter().cloned().enumerate().collect();
            indexed.sort_by_key(|(orig_idx, p)| {
                tier_pos
                    .get(p.name())
                    .copied()
                    .map_or((1usize, *orig_idx), |t| (0, t))
            });
            let after: Vec<_> = indexed.iter().map(|(_, p)| p.name().to_owned()).collect();
            if before != after {
                tracing::debug!(
                    before = ?before,
                    after = ?after,
                    "cascade: providers reordered by cost_tiers"
                );
            }
            self.providers = Arc::from(indexed.into_iter().map(|(_, p)| p).collect::<Vec<_>>());
        }

        let window = config.window_size;
        self.cascade_state = Some(Arc::new(Mutex::new(CascadeState::new(window))));
        self.cascade_config = Some(config);
        self
    }
606
607 pub fn save_thompson_state(&self) {
617 let (Some(thompson), Some(path)) = (&self.thompson, &self.thompson_state_path) else {
618 return;
619 };
620 let state = thompson.lock();
621 if let Err(e) = state.save(path) {
622 tracing::warn!(error = %e, "failed to save Thompson router state");
623 }
624 }
625
626 fn query_hash(query: &str) -> u64 {
628 use std::hash::{Hash as _, Hasher as _};
629 let mut h = std::collections::hash_map::DefaultHasher::new();
630 query.hash(&mut h);
631 h.finish()
632 }
633
    /// Computes (or fetches from cache) the feature vector for `query`.
    ///
    /// Returns `None` when bandit routing is not configured, no embedding
    /// provider is set, the embedding call errors or exceeds
    /// `embedding_timeout_ms`, or feature extraction fails — callers fall
    /// back to Thompson selection in that case.
    async fn bandit_features(&self, query: &str) -> Option<Vec<f32>> {
        let cfg = self.bandit_config.as_ref()?;
        let key = Self::query_hash(query);

        // Fast path: cached features for an identical query. Lock is scoped
        // so it is not held across the await below.
        {
            let cache = self.bandit_embed_cache.lock();
            if let Some(cached) = cache.get(key) {
                return Some(cached.clone());
            }
        }

        let provider = self.bandit_embedding_provider.as_ref()?;
        let timeout = std::time::Duration::from_millis(cfg.embedding_timeout_ms);
        let embed_future = provider.embed(query);
        let embedding = match tokio::time::timeout(timeout, embed_future).await {
            Ok(Ok(emb)) => emb,
            Ok(Err(e)) => {
                tracing::debug!(error = %e, "bandit: embedding failed, falling back");
                return None;
            }
            Err(_) => {
                tracing::debug!(
                    timeout_ms = cfg.embedding_timeout_ms,
                    "bandit: embedding timed out, falling back"
                );
                return None;
            }
        };

        let features = embedding_to_features(&embedding, cfg.dim)?;

        // Populate the cache for future identical queries.
        {
            let mut cache = self.bandit_embed_cache.lock();
            cache.insert(key, features.clone());
        }
        Some(features)
    }
679
    /// Picks a provider for `query` under the bandit strategy.
    ///
    /// Order of preference: contextual bandit over query features, then
    /// Thompson sampling when features are unavailable or the bandit makes
    /// no pick, then the first configured provider.
    async fn bandit_select_provider(&self, query: &str) -> Option<AnyProvider> {
        let Some(ref bandit_arc) = self.bandit else {
            return self.providers.first().cloned();
        };
        let cfg = self.bandit_config.as_ref()?;

        let names: Vec<String> = self.providers.iter().map(|p| p.name().to_owned()).collect();

        if let Some(features) = self.bandit_features(query).await {
            let memory_confidence = self.last_memory_confidence.lock().as_ref().copied();
            // Bandit lock scope is kept tight: selection only, no awaits.
            let selected = {
                let state = bandit_arc.lock();
                state.select(
                    &names,
                    &features,
                    cfg.alpha,
                    cfg.warmup_queries,
                    &|_| true,
                    cfg.cost_weight,
                    &self.provider_models,
                    memory_confidence,
                    cfg.memory_confidence_threshold,
                )
            };
            if let Some(name) = selected {
                tracing::debug!(
                    provider = %name,
                    strategy = "bandit",
                    memory_confidence = ?memory_confidence,
                    "selected provider"
                );
                return self.providers.iter().find(|p| p.name() == name).cloned();
            }
        }

        // No features or no bandit pick: fall back to Thompson sampling.
        if let Some(ref thompson) = self.thompson {
            let mut state = thompson.lock();
            if let Some(sel) = state.select(&names) {
                tracing::debug!(
                    provider = %sel.provider,
                    strategy = "bandit-fallback-thompson",
                    "selected provider"
                );
                return self
                    .providers
                    .iter()
                    .find(|p| p.name() == sel.provider)
                    .cloned();
            }
        }

        self.providers.first().cloned()
    }
741
742 fn bandit_record_reward(
747 &self,
748 provider_name: &str,
749 features: &[f32],
750 quality_score: f64,
751 cost_fraction: f64,
752 ) {
753 let Some(ref bandit_arc) = self.bandit else {
754 return;
755 };
756 let Some(cfg) = &self.bandit_config else {
757 return;
758 };
759 #[allow(clippy::cast_possible_truncation)]
760 let reward = (quality_score as f32) - cfg.cost_weight * (cost_fraction as f32);
761 let reward = reward.clamp(-1.0, 1.0);
762 let mut state = bandit_arc.lock();
763 state.update(provider_name, features, reward);
764 tracing::debug!(
765 provider = provider_name,
766 reward,
767 quality = quality_score,
768 "bandit: recorded reward"
769 );
770 }
771
772 fn ordered_providers(&self) -> Vec<AnyProvider> {
773 match self.strategy {
774 RouterStrategy::Thompson => self.thompson_ordered_providers(),
775 RouterStrategy::Ema => self.ema_ordered_providers(),
776 RouterStrategy::Cascade | RouterStrategy::Bandit => self.providers.to_vec(),
780 }
781 }
782
    /// Providers ordered for the EMA strategy.
    ///
    /// Starts from the cached `provider_order`, then (when enabled) rescores
    /// by EMA success/latency blended with reputation, then applies an ASI
    /// coherence multiplier. Each pass sorts descending by score.
    fn ema_ordered_providers(&self) -> Vec<AnyProvider> {
        let order = self.provider_order.lock();
        let mut ordered: Vec<AnyProvider> = order
            .iter()
            .filter_map(|&i| self.providers.get(i).cloned())
            .collect();

        // Reputation pass: scale the EMA score by a reputation-derived
        // adjustment. Factor 0.5 is neutral; `w` scales the swing.
        if let Some(ref reputation) = self.reputation
            && let Some(ref ema) = self.ema
        {
            let rep = reputation.lock();
            let w = self.reputation_weight;
            let snap = ema.snapshot();
            let mut scored: Vec<(usize, f64)> = ordered
                .iter()
                .enumerate()
                .map(|(idx, p)| {
                    // Base score: success EMA minus a scaled latency term.
                    let ema_score = snap
                        .get(p.name())
                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
                    let score = if let Some(rep_factor) = rep.ema_reputation_factor(p.name()) {
                        let adjustment = 1.0 + w * (rep_factor - 0.5) * 2.0;
                        ema_score * adjustment
                    } else {
                        ema_score
                    };
                    (idx, score)
                })
                .collect();
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let reordered: Vec<AnyProvider> = scored
                .into_iter()
                .filter_map(|(idx, _)| ordered.get(idx).cloned())
                .collect();
            ordered = reordered;
        }

        // ASI pass: scale the base EMA score down (floored at 50%) for
        // providers whose recent responses are incoherent; at/above the
        // threshold the multiplier clamps to 1.0, so there is no boost.
        if let (Some(asi_arc), Some(asi_cfg)) = (&self.asi, &self.asi_config) {
            let asi: parking_lot::MutexGuard<'_, AsiState> = asi_arc.lock();
            let snap = self.ema.as_ref().map(EmaTracker::snapshot);
            let mut scored: Vec<(usize, f64)> = ordered
                .iter()
                .enumerate()
                .map(|(idx, p)| {
                    let coherence = asi.coherence(p.name());
                    if coherence < asi_cfg.coherence_threshold {
                        tracing::warn!(
                            provider = p.name(),
                            coherence,
                            threshold = asi_cfg.coherence_threshold,
                            "asi: coherence below threshold"
                        );
                    }
                    let base_score = snap
                        .as_ref()
                        .and_then(|s| s.get(p.name()))
                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
                    let multiplier = (coherence / asi_cfg.coherence_threshold).clamp(0.5, 1.0);
                    #[allow(clippy::cast_possible_truncation)]
                    let adjusted = base_score * f64::from(multiplier);
                    (idx, adjusted)
                })
                .collect();
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let reordered: Vec<AnyProvider> = scored
                .into_iter()
                .filter_map(|(idx, _)| ordered.get(idx).cloned())
                .collect();
            ordered = reordered;
        }

        if let Some(first) = ordered.first() {
            tracing::debug!(
                provider = %first.name(),
                strategy = "ema",
                "selected provider"
            );
        }
        ordered
    }
872
    /// Providers ordered for the Thompson strategy: the sampled winner is
    /// swapped to the front; the rest keep base order as fallbacks.
    ///
    /// Reputation (when enabled) shifts each arm's Beta priors; ASI (when
    /// enabled) adds a beta penalty proportional to any coherence deficit.
    fn thompson_ordered_providers(&self) -> Vec<AnyProvider> {
        let Some(ref thompson) = self.thompson else {
            return self.providers.to_vec();
        };
        let mut state = thompson.lock();
        let names: Vec<String> = self.providers.iter().map(|p| p.name().to_owned()).collect();

        let has_reputation = self.reputation.is_some();
        let has_asi = self.asi.is_some() && self.asi_config.is_some();

        let selected = if has_reputation || has_asi {
            let rep_guard = self.reputation.as_ref().map(|r| r.lock());
            let asi_guard: Option<parking_lot::MutexGuard<'_, AsiState>> =
                self.asi.as_ref().map(|a| a.lock());
            let w = self.reputation_weight;

            // Build adjusted (alpha, beta) priors per provider name.
            let overrides: std::collections::HashMap<String, (f64, f64)> = names
                .iter()
                .map(|name| {
                    let base = state.get_distribution(name);
                    let (alpha, mut beta) = if let Some(ref rep) = rep_guard {
                        rep.shift_thompson_priors(name, base.alpha, base.beta, w)
                    } else {
                        (base.alpha, base.beta)
                    };
                    if let (Some(asi), Some(asi_cfg)) = (&asi_guard, &self.asi_config) {
                        let coherence = asi.coherence(name);
                        if coherence < asi_cfg.coherence_threshold {
                            tracing::warn!(
                                provider = name.as_str(),
                                coherence,
                                threshold = asi_cfg.coherence_threshold,
                                "asi: coherence below threshold"
                            );
                            // Raising beta lowers the expected sample,
                            // discouraging incoherent providers.
                            let deficit = asi_cfg.coherence_threshold - coherence;
                            let penalty = f64::from(asi_cfg.penalty_weight * deficit);
                            beta += penalty;
                        }
                    }
                    (name.clone(), (alpha, beta))
                })
                .collect();

            // Release auxiliary guards before sampling to keep lock scope
            // minimal.
            drop(rep_guard);
            drop(asi_guard);
            state.select_with_priors(&names, &overrides)
        } else {
            state.select(&names)
        };

        if let Some(ref sel) = selected {
            tracing::debug!(
                provider = %sel.provider,
                strategy = "thompson",
                mode = if sel.exploit { "exploit" } else { "explore" },
                alpha = sel.alpha,
                beta = sel.beta,
                "selected provider"
            );
        }
        let mut ordered = self.providers.to_vec();
        if let Some(ref sel) = selected
            && let Some(pos) = ordered.iter().position(|p| p.name() == sel.provider)
        {
            ordered.swap(0, pos);
        }
        ordered
    }
947
948 fn record_availability(&self, provider_name: &str, success: bool, latency_ms: u64) {
954 match self.strategy {
955 RouterStrategy::Thompson => {
956 if let Some(ref thompson) = self.thompson {
957 let mut state = thompson.lock();
958 state.update(provider_name, success);
959 }
960 }
961 RouterStrategy::Ema => {
962 self.ema_record(provider_name, success, latency_ms);
963 }
964 RouterStrategy::Cascade | RouterStrategy::Bandit => {
965 }
968 }
969 }
970
971 fn ema_record(&self, provider_name: &str, success: bool, latency_ms: u64) {
972 let Some(ref ema) = self.ema else {
973 return;
974 };
975 ema.record(provider_name, success, latency_ms);
976 let current_names: Vec<String> =
977 self.providers.iter().map(|p| p.name().to_owned()).collect();
978 if let Some(new_order_names) = ema.maybe_reorder(¤t_names) {
979 let name_to_idx: std::collections::HashMap<&str, usize> = self
980 .providers
981 .iter()
982 .enumerate()
983 .map(|(i, p)| (p.name(), i))
984 .collect();
985 let new_order: Vec<usize> = new_order_names
986 .iter()
987 .filter_map(|n| name_to_idx.get(n.as_str()).copied())
988 .collect();
989 let mut order = self.provider_order.lock();
990 *order = new_order;
991 }
992 }
993
994 #[must_use]
998 pub fn thompson_stats(&self) -> Vec<(String, f64, f64)> {
999 let Some(ref thompson) = self.thompson else {
1000 return vec![];
1001 };
1002 let state = thompson.lock();
1003 state.provider_stats()
1004 }
1005
1006 pub fn set_status_tx(&mut self, tx: StatusTx) {
1007 if let Some(providers) = Arc::get_mut(&mut self.providers) {
1008 for p in providers {
1009 p.set_status_tx(tx.clone());
1010 }
1011 } else {
1012 let mut v: Vec<_> = self.providers.iter().cloned().collect();
1014 for p in &mut v {
1015 p.set_status_tx(tx.clone());
1016 }
1017 self.providers = Arc::from(v);
1018 }
1019 self.status_tx = Some(tx);
1020 }
1021
1022 pub async fn list_models_remote(
1030 &self,
1031 ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
1032 let mut seen = std::collections::HashSet::new();
1033 let mut all = Vec::new();
1034 for p in self.providers.iter() {
1035 match p.list_models_remote().await {
1036 Ok(models) => {
1037 for m in models {
1038 if seen.insert(m.id.clone()) {
1039 all.push(m);
1040 }
1041 }
1042 }
1043 Err(e) => {
1044 tracing::warn!(error = %e, "router: list_models_remote sub-provider failed");
1045 }
1046 }
1047 }
1048 Ok(all)
1049 }
1050
    /// Scores `response` with the heuristic classifier and flags escalation
    /// when the score falls below `threshold`.
    fn evaluate_heuristic(response: &str, threshold: f64) -> cascade::QualityVerdict {
        let mut verdict = heuristic_score(response);
        verdict.should_escalate = verdict.score < threshold;
        verdict
    }

    /// Scores `response` for cascade escalation.
    ///
    /// In `Judge` mode an LLM judge scores the response; if the judge call
    /// fails or no judge provider is configured, the heuristic classifier
    /// is used instead (with a warning).
    async fn evaluate_quality(
        response: &str,
        threshold: f64,
        mode: ClassifierMode,
        summary_provider: Option<&AnyProvider>,
    ) -> cascade::QualityVerdict {
        if mode == ClassifierMode::Judge {
            if let Some(judge) = summary_provider {
                match cascade::judge_score(judge, response).await {
                    Some(score) => {
                        let should_escalate = score < threshold;
                        tracing::debug!(
                            score,
                            threshold,
                            should_escalate,
                            "cascade: judge scored response"
                        );
                        return cascade::QualityVerdict {
                            score,
                            should_escalate,
                            reason: format!("judge score: {score:.2}"),
                        };
                    }
                    None => {
                        tracing::warn!("cascade: judge call failed, falling back to heuristic");
                    }
                }
            } else {
                tracing::warn!(
                    "cascade: classifier_mode=judge but no summary_provider configured, \
                     using heuristic"
                );
            }
        }
        // Heuristic fallback for all non-judge (or failed-judge) paths.
        Self::evaluate_heuristic(response, threshold)
    }
1098}
1099
/// Maximum retry attempts for rate-limited embedding calls.
const EMBED_MAX_RETRIES: u32 = 3;
/// Base backoff delay; doubles on each retry (500ms, 1s, 2s).
const EMBED_BASE_DELAY_MS: u64 = 500;
1102
impl RouterProvider {
    /// Schedules a background coherence update for `provider`'s response.
    ///
    /// `asi_last_turn` is swapped first so that only one response per turn
    /// (the first to reach this point) feeds the ASI window, even when
    /// several providers answered during fallback. No-op unless ASI is
    /// configured.
    fn spawn_asi_update(&self, provider: &str, response: String, turn_id: u64) {
        let prev = self.asi_last_turn.swap(turn_id, Ordering::AcqRel);
        if prev == turn_id {
            // Another response from this turn already claimed the update.
            return;
        }

        let Some(ref asi_arc) = self.asi else { return };
        let Some(ref asi_cfg) = self.asi_config else {
            return;
        };
        let asi = Arc::clone(asi_arc);
        let router = self.clone();
        let window_size = asi_cfg.window;
        let provider_name = provider.to_owned();
        // Embedding happens off the request path; a failure only skips this
        // one coherence update.
        tokio::spawn(async move {
            match router.embed(&response).await {
                Ok(emb) => {
                    let mut state = asi.lock();
                    state.push_embedding(&provider_name, emb, window_size);
                }
                Err(e) => {
                    tracing::debug!(
                        provider = provider_name,
                        error = %e,
                        "asi: embed failed, skipping coherence update"
                    );
                }
            }
        });
    }
}
1146
1147impl LlmProvider for RouterProvider {
1148 fn context_window(&self) -> Option<usize> {
1149 self.providers.first().and_then(LlmProvider::context_window)
1150 }
1151
    /// Routes a chat request according to the configured strategy.
    ///
    /// Cascade and bandit strategies delegate to their dedicated paths.
    /// Otherwise providers are tried in strategy order, falling back on
    /// error. With a quality gate configured, below-threshold responses are
    /// kept as candidates while later providers are tried; if no provider
    /// clears the gate, the best candidate is returned.
    fn chat(
        &self,
        messages: &[Message],
    ) -> impl std::future::Future<Output = Result<String, LlmError>> + Send {
        let status_tx = self.status_tx.clone();
        let messages = messages.to_vec();
        let router = self.clone();
        Box::pin(async move {
            // Monotonic turn id: deduplicates ASI updates within one turn.
            let turn_id = router.turn_counter.fetch_add(1, Ordering::Relaxed);

            if router.strategy == RouterStrategy::Cascade {
                return router
                    .cascade_chat(&router.providers, &messages, status_tx)
                    .await;
            }
            if router.strategy == RouterStrategy::Bandit {
                return router.bandit_chat(&messages, status_tx).await;
            }
            let providers = router.ordered_providers();

            // Embed the latest message once, only when the quality gate will
            // need it for similarity checks.
            let query_text = messages
                .last()
                .map(Message::to_llm_content)
                .unwrap_or_default();
            let query_embedding = if router.quality_gate.is_some() && !query_text.is_empty() {
                router.embed(query_text).await.ok()
            } else {
                None
            };

            // Best below-threshold response so far: (similarity, text).
            let mut best_response: Option<(f32, String)> = None;

            for p in &providers {
                let start = std::time::Instant::now();
                match p.chat(&messages).await {
                    Ok(r) => {
                        router.record_availability(
                            p.name(),
                            true,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );

                        if let (Some(threshold), Some(qemb)) =
                            (router.quality_gate, &query_embedding)
                        {
                            let resp_emb = router.embed(&r).await.ok();
                            // If the response can't be embedded, score it AT
                            // the threshold so it passes rather than forcing
                            // a fallback.
                            let similarity = resp_emb
                                .as_ref()
                                .map_or(threshold, |e| cosine_similarity(qemb, e));
                            if similarity < threshold {
                                tracing::info!(
                                    provider = p.name(),
                                    score = similarity,
                                    threshold,
                                    "thompson_quality_fallback"
                                );
                                let is_better = best_response
                                    .as_ref()
                                    .is_none_or(|(best, _)| similarity > *best);
                                if is_better {
                                    best_response = Some((similarity, r.clone()));
                                }
                                router.spawn_asi_update(p.name(), r, turn_id);
                                continue;
                            }
                        }

                        router.spawn_asi_update(p.name(), r.clone(), turn_id);

                        return Ok(r);
                    }
                    Err(e) => {
                        router.record_availability(
                            p.name(),
                            false,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        if let Some(ref tx) = status_tx {
                            let _ = tx.send(format!("router: {} failed, falling back", p.name()));
                        }
                        tracing::warn!(provider = p.name(), error = %e, "router fallback");
                    }
                }
            }

            // Every provider failed or scored below the gate; return the
            // highest-similarity candidate if one exists.
            if let Some((_, response)) = best_response {
                return Ok(response);
            }

            Err(LlmError::NoProviders)
        })
    }
1258
    /// Streaming variant of `chat`: same strategy dispatch and error
    /// fallback, but without the quality gate (a stream is not scored).
    fn chat_stream(
        &self,
        messages: &[Message],
    ) -> impl std::future::Future<Output = Result<ChatStream, LlmError>> + Send {
        let status_tx = self.status_tx.clone();
        let messages = messages.to_vec();
        let router = self.clone();
        Box::pin(async move {
            if router.strategy == RouterStrategy::Cascade {
                return router
                    .cascade_chat_stream(&router.providers, &messages, status_tx)
                    .await;
            }
            if router.strategy == RouterStrategy::Bandit {
                // Select once on the latest message; no fallback loop here.
                let query = messages
                    .last()
                    .map(super::provider::Message::to_llm_content)
                    .unwrap_or_default();
                let p = router
                    .bandit_select_provider(query)
                    .await
                    .ok_or(LlmError::NoProviders)?;
                return p.chat_stream(&messages).await;
            }
            let providers = router.ordered_providers();
            for p in &providers {
                let start = std::time::Instant::now();
                match p.chat_stream(&messages).await {
                    Ok(r) => {
                        // NOTE(review): success is recorded when the stream
                        // opens, not when it completes.
                        router.record_availability(
                            p.name(),
                            true,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        return Ok(r);
                    }
                    Err(e) => {
                        router.record_availability(
                            p.name(),
                            false,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        if let Some(ref tx) = status_tx {
                            let _ = tx.send(format!("router: {} failed, falling back", p.name()));
                        }
                        tracing::warn!(provider = p.name(), error = %e, "router stream fallback");
                    }
                }
            }
            Err(LlmError::NoProviders)
        })
    }
1321
1322 fn supports_streaming(&self) -> bool {
1323 self.providers.iter().any(LlmProvider::supports_streaming)
1324 }
1325
1326 fn embed(
1327 &self,
1328 text: &str,
1329 ) -> impl std::future::Future<Output = Result<Vec<f32>, LlmError>> + Send {
1330 let providers = self.ordered_providers();
1331 let status_tx = self.status_tx.clone();
1332 let text = text.to_owned();
1333 let router = self.clone();
1334 Box::pin(async move {
1335 for p in &providers {
1336 if !p.supports_embeddings() {
1337 continue;
1338 }
1339 let mut last_err: Option<LlmError> = None;
1340 for attempt in 0..=EMBED_MAX_RETRIES {
1341 if attempt > 0 {
1342 let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
1343 tracing::warn!(
1344 provider = p.name(),
1345 attempt,
1346 delay_ms = delay,
1347 "embed: rate limited, retrying after backoff"
1348 );
1349 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1350 }
1351 let start = std::time::Instant::now();
1352 match p.embed(&text).await {
1353 Ok(r) => {
1354 router.record_availability(
1355 p.name(),
1356 true,
1357 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1358 );
1359 return Ok(r);
1360 }
1361 Err(e) if e.is_invalid_input() => {
1362 tracing::warn!(
1365 provider = p.name(),
1366 error = %e,
1367 "embed: invalid input, not retrying on other providers"
1368 );
1369 return Err(e);
1370 }
1371 Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
1372 last_err = Some(e);
1373 }
1374 Err(e) => {
1375 router.record_availability(
1376 p.name(),
1377 false,
1378 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1379 );
1380 if let Some(ref tx) = status_tx {
1381 let _ = tx.send(format!(
1382 "router: {} embed failed, falling back",
1383 p.name()
1384 ));
1385 }
1386 tracing::warn!(provider = p.name(), error = %e, "router embed fallback");
1387 last_err = Some(e);
1388 break;
1389 }
1390 }
1391 }
1392 if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
1394 router.record_availability(p.name(), false, 0);
1395 if let Some(ref tx) = status_tx {
1396 let _ = tx.send(format!(
1397 "router: {} embed rate limited, falling back",
1398 p.name()
1399 ));
1400 }
1401 tracing::warn!(
1402 provider = p.name(),
1403 "embed: rate limit retries exhausted, falling back"
1404 );
1405 }
1406 }
1407 Err(LlmError::NoProviders)
1408 })
1409 }
1410
1411 fn embed_batch(
1412 &self,
1413 texts: &[&str],
1414 ) -> impl std::future::Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
1415 let providers = self.ordered_providers();
1416 let status_tx = self.status_tx.clone();
1417 let owned = owned_strs(texts);
1418 let router = self.clone();
1419 let semaphore = self.embed_semaphore.clone();
1420 Box::pin(async move {
1421 let _permit = if let Some(ref sem) = semaphore {
1423 Some(sem.acquire().await.map_err(|_| LlmError::NoProviders)?)
1424 } else {
1425 None
1426 };
1427 let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
1428 for p in &providers {
1429 if !p.supports_embeddings() {
1430 continue;
1431 }
1432 let mut last_err: Option<LlmError> = None;
1433 for attempt in 0..=EMBED_MAX_RETRIES {
1434 if attempt > 0 {
1435 let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
1436 tracing::warn!(
1437 provider = p.name(),
1438 attempt,
1439 delay_ms = delay,
1440 "embed_batch: rate limited, retrying after backoff"
1441 );
1442 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1443 }
1444 let start = std::time::Instant::now();
1445 match p.embed_batch(&refs).await {
1446 Ok(r) => {
1447 router.record_availability(
1448 p.name(),
1449 true,
1450 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1451 );
1452 return Ok(r);
1453 }
1454 Err(e) if e.is_invalid_input() => {
1455 tracing::warn!(
1456 provider = p.name(),
1457 error = %e,
1458 "embed_batch: invalid input, not retrying on other providers"
1459 );
1460 return Err(e);
1461 }
1462 Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
1463 last_err = Some(e);
1464 }
1465 Err(e) => {
1466 router.record_availability(
1467 p.name(),
1468 false,
1469 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1470 );
1471 if let Some(ref tx) = status_tx {
1472 let _ = tx.send(format!(
1473 "router: {} embed_batch failed, falling back",
1474 p.name()
1475 ));
1476 }
1477 tracing::warn!(
1478 provider = p.name(),
1479 error = %e,
1480 "router embed_batch fallback"
1481 );
1482 last_err = Some(e);
1483 break;
1484 }
1485 }
1486 }
1487 if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
1489 router.record_availability(p.name(), false, 0);
1490 if let Some(ref tx) = status_tx {
1491 let _ = tx.send(format!(
1492 "router: {} embed_batch rate limited, falling back",
1493 p.name()
1494 ));
1495 }
1496 tracing::warn!(
1497 provider = p.name(),
1498 "embed_batch: rate limit retries exhausted, falling back"
1499 );
1500 }
1501 }
1502 Err(LlmError::NoProviders)
1503 })
1504 }
1505
1506 fn supports_embeddings(&self) -> bool {
1507 self.providers.iter().any(LlmProvider::supports_embeddings)
1508 }
1509
1510 #[allow(clippy::unnecessary_literal_bound)]
1511 fn name(&self) -> &str {
1512 "router"
1513 }
1514
1515 fn supports_tool_use(&self) -> bool {
1516 self.providers.iter().any(LlmProvider::supports_tool_use)
1517 }
1518
1519 fn list_models(&self) -> Vec<String> {
1520 self.providers
1521 .iter()
1522 .flat_map(super::provider::LlmProvider::list_models)
1523 .collect()
1524 }
1525
    /// Tool-calling chat across the provider chain.
    ///
    /// Bandit strategy picks a single provider contextually and does NOT
    /// fall back on failure; any other strategy walks `ordered_providers()`
    /// and falls back to the next tool-capable provider on error.
    #[allow(refining_impl_trait_reachable)]
    fn chat_with_tools(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
    ) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
        // Clone inputs so the returned future owns everything it captures.
        let messages = messages.to_vec();
        let tools = tools.to_vec();
        let status_tx = self.status_tx.clone();
        let router = self.clone();
        Box::pin(async move {
            if router.strategy == RouterStrategy::Bandit {
                // Contextual selection keyed on the latest message's content.
                let query = messages
                    .last()
                    .map(super::provider::Message::to_llm_content)
                    .unwrap_or_default();
                let p = router
                    .bandit_select_provider(query)
                    .await
                    .ok_or(LlmError::NoProviders)?;
                if !p.supports_tool_use() {
                    return Err(LlmError::NoProviders);
                }
                let result = p.chat_with_tools(&messages, &tools).await;
                // Only a successful call updates the "last active" marker.
                if result.is_ok() {
                    *router.last_active_provider.lock() = Some(p.name().to_owned());
                }
                return result;
            }

            // Default path: preference order with fallback on any error.
            let providers = router.ordered_providers();
            for p in &providers {
                if !p.supports_tool_use() {
                    continue;
                }
                let start = std::time::Instant::now();
                match p.chat_with_tools(&messages, &tools).await {
                    Ok(r) => {
                        router.record_availability(
                            p.name(),
                            true,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        *router.last_active_provider.lock() = Some(p.name().to_owned());
                        return Ok(r);
                    }
                    Err(e) => {
                        router.record_availability(
                            p.name(),
                            false,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        if let Some(ref tx) = status_tx {
                            let _ = tx.send(format!(
                                "router: {} tool call failed, falling back",
                                p.name()
                            ));
                        }
                        tracing::warn!(provider = p.name(), error = %e, "router tool fallback");
                    }
                }
            }
            Err(LlmError::NoProviders)
        })
    }
1597
1598 fn debug_request_json(
1599 &self,
1600 messages: &[Message],
1601 tools: &[ToolDefinition],
1602 stream: bool,
1603 ) -> serde_json::Value {
1604 let candidate = if tools.is_empty() {
1605 self.ordered_providers().into_iter().next()
1606 } else {
1607 self.ordered_providers()
1608 .into_iter()
1609 .find(super::provider::LlmProvider::supports_tool_use)
1610 };
1611 candidate.map_or_else(
1612 || crate::provider::default_debug_request_json(messages, tools),
1613 |provider| provider.debug_request_json(messages, tools, stream),
1614 )
1615 }
1616
1617 fn last_cache_usage(&self) -> Option<(u64, u64)> {
1618 None
1619 }
1620}
1621
1622impl RouterProvider {
1625 async fn bandit_chat(
1627 &self,
1628 messages: &[Message],
1629 status_tx: Option<StatusTx>,
1630 ) -> Result<String, LlmError> {
1631 let query = messages
1632 .last()
1633 .map(super::provider::Message::to_llm_content)
1634 .unwrap_or_default();
1635 let features = self.bandit_features(query.as_ref()).await;
1636
1637 let p = self
1638 .bandit_select_provider(query.as_ref())
1639 .await
1640 .ok_or(LlmError::NoProviders)?;
1641
1642 if let Some(ref tx) = status_tx {
1643 let _ = tx.send(format!("bandit: routing to {}", p.name()));
1644 }
1645
1646 let result = p.chat(messages).await;
1647 match &result {
1648 Ok(response) => {
1649 let verdict = heuristic_score(response);
1650 let feat_ref: &[f32];
1653 let zero_vec: Vec<f32>;
1654 let dim = self.bandit_config.as_ref().map_or(32, |c| c.dim);
1655 if let Some(ref feat) = features {
1656 feat_ref = feat;
1657 } else {
1658 zero_vec = vec![0.0; dim];
1659 feat_ref = &zero_vec;
1660 tracing::debug!(
1661 provider = p.name(),
1662 "bandit: recording reward with zero features (embed unavailable)"
1663 );
1664 }
1665 self.bandit_record_reward(p.name(), feat_ref, verdict.score, 0.0);
1666 }
1667 Err(e) => {
1668 tracing::warn!(provider = p.name(), error = %e, "bandit: provider failed");
1669 }
1670 }
1671 result
1672 }
1673}
1674
/// Outcome of scoring a single cascade response.
struct CascadeEvalResult {
    /// Quality verdict (score, escalation decision, reason).
    verdict: cascade::QualityVerdict,
    /// Running token total after charging this response.
    tokens_used: u32,
    /// True when `max_cascade_tokens` is set and `tokens_used` reached it.
    budget_exhausted: bool,
}
1685
1686async fn cascade_evaluate_response(
1689 provider_name: &str,
1690 response: &str,
1691 cfg: &CascadeRouterConfig,
1692 cascade_state: &Mutex<CascadeState>,
1693 tokens_used_before: u32,
1694 log_prefix: &str,
1695) -> CascadeEvalResult {
1696 let estimated_tokens =
1697 u32::try_from(zeph_common::text::estimate_tokens(response).max(1)).unwrap_or(u32::MAX);
1698 let tokens_used = tokens_used_before.saturating_add(estimated_tokens);
1699
1700 let verdict = RouterProvider::evaluate_quality(
1701 response,
1702 cfg.quality_threshold,
1703 cfg.classifier_mode,
1704 cfg.summary_provider.as_ref(),
1705 )
1706 .await;
1707
1708 {
1709 let mut state = cascade_state.lock();
1710 state.record(provider_name, verdict.score);
1711 }
1712
1713 tracing::debug!(
1714 provider = %provider_name,
1715 score = verdict.score,
1716 threshold = cfg.quality_threshold,
1717 should_escalate = verdict.should_escalate,
1718 reason = %verdict.reason,
1719 "{log_prefix}: quality verdict"
1720 );
1721
1722 let budget_exhausted = cfg
1723 .max_cascade_tokens
1724 .is_some_and(|budget| tokens_used >= budget);
1725
1726 CascadeEvalResult {
1727 verdict,
1728 tokens_used,
1729 budget_exhausted,
1730 }
1731}
1732
1733impl RouterProvider {
    /// Cascade chat: tries providers cheapest-first, escalating to the next
    /// one while responses score below `quality_threshold`, bounded by
    /// `max_escalations` and the optional `max_cascade_tokens` budget.
    ///
    /// Tracks the best-scoring non-empty response seen; when escalation is
    /// blocked or all providers fail, that best-seen response is returned.
    #[allow(clippy::too_many_lines)] // sequential escalation logic reads best unsplit
    async fn cascade_chat(
        &self,
        providers: &[AnyProvider],
        messages: &[Message],
        status_tx: Option<StatusTx>,
    ) -> Result<String, LlmError> {
        // Both are set by the cascade constructor; absence is a programming bug.
        let cfg = self
            .cascade_config
            .as_ref()
            .expect("cascade_config must be set");
        let cascade_state = self
            .cascade_state
            .as_ref()
            .expect("cascade_state must be set");

        let mut escalations_remaining = cfg.max_escalations;
        // Best non-empty (response, score) seen so far across the chain.
        let mut best: Option<(String, f64)> = None;
        let mut tokens_used: u32 = 0;

        for (idx, p) in providers.iter().enumerate() {
            tracing::debug!(
                provider = %p.name(),
                attempt = idx + 1,
                total = providers.len(),
                classifier_mode = ?cfg.classifier_mode,
                quality_threshold = cfg.quality_threshold,
                "cascade: trying provider"
            );
            let start = std::time::Instant::now();
            match p.chat(messages).await {
                Err(e) => {
                    // Provider error: record the failure and move down the chain.
                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                    self.record_availability(p.name(), false, latency);
                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
                    }
                    tracing::warn!(provider = p.name(), error = %e, "cascade: provider error");
                }
                Ok(response) => {
                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

                    // Score the response and update token accounting.
                    let eval = cascade_evaluate_response(
                        p.name(),
                        &response,
                        cfg,
                        cascade_state,
                        tokens_used,
                        "cascade",
                    )
                    .await;
                    tokens_used = eval.tokens_used;
                    let verdict = eval.verdict;
                    let budget_exhausted = eval.budget_exhausted;

                    // Track the best-scoring non-empty response for fallback.
                    let is_better = !response.is_empty()
                        && best
                            .as_ref()
                            .is_none_or(|(_, best_score)| verdict.score > *best_score);
                    if is_better {
                        tracing::debug!(
                            provider = %p.name(),
                            score = verdict.score,
                            "cascade: best_seen updated"
                        );
                        best = Some((response.clone(), verdict.score));
                    }

                    let is_last = idx == providers.len() - 1;

                    // Stop when quality is acceptable or escalation is impossible.
                    if !verdict.should_escalate
                        || is_last
                        || escalations_remaining == 0
                        || budget_exhausted
                    {
                        self.record_availability(p.name(), true, latency);
                        // Escalation wanted but blocked: prefer the best-seen
                        // response over the current (possibly worse) one.
                        if verdict.should_escalate
                            && (budget_exhausted || escalations_remaining == 0)
                        {
                            let best_response = best.take().map_or(response, |(r, _)| r);
                            tracing::info!(
                                tokens_used,
                                budget = cfg.max_cascade_tokens,
                                escalations_remaining,
                                "cascade: escalation blocked, returning best response"
                            );
                            return Ok(best_response);
                        }
                        return Ok(response);
                    }

                    self.record_availability(p.name(), true, latency);
                    escalations_remaining -= 1;

                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!(
                            "cascade: {} quality {:.2} < {:.2}, escalating ({} left)",
                            p.name(),
                            verdict.score,
                            cfg.quality_threshold,
                            escalations_remaining
                        ));
                    }
                    tracing::info!(
                        provider = %p.name(),
                        score = verdict.score,
                        threshold = cfg.quality_threshold,
                        escalations_remaining,
                        "cascade: escalating to next provider"
                    );
                }
            }
        }

        // Every provider errored, or escalation ran off the end of the chain.
        if let Some((_, score)) = &best {
            tracing::info!(
                score,
                "cascade: all providers exhausted, returning best-seen response"
            );
        } else {
            tracing::warn!("cascade: all providers failed, no response available");
        }
        best.map(|(r, _)| r).ok_or(LlmError::NoProviders)
    }
1869
    /// Streaming cascade: every provider except the last is fully buffered
    /// so its output can be quality-scored; only the final provider streams
    /// straight through (its output cannot be classified before forwarding).
    #[allow(clippy::too_many_lines)] // sequential escalation logic reads best unsplit
    async fn cascade_chat_stream(
        &self,
        providers: &[AnyProvider],
        messages: &[Message],
        status_tx: Option<StatusTx>,
    ) -> Result<ChatStream, LlmError> {
        // Both are set by the cascade constructor; absence is a programming bug.
        let cfg = self
            .cascade_config
            .as_ref()
            .expect("cascade_config must be set");
        let cascade_state = self
            .cascade_state
            .as_ref()
            .expect("cascade_state must be set");

        let mut escalations_remaining = cfg.max_escalations;
        let mut tokens_used: u32 = 0;
        // Best non-empty (text, score) among the buffered providers.
        let mut best_seen: Option<(String, f64)> = None;

        // `early` providers are buffered and scored; `last` streams through.
        let (last, early) = providers.split_last().ok_or(LlmError::NoProviders)?;

        for (idx, p) in early.iter().enumerate() {
            tracing::debug!(
                provider = %p.name(),
                attempt = idx + 1,
                total = providers.len(),
                classifier_mode = ?cfg.classifier_mode,
                quality_threshold = cfg.quality_threshold,
                "cascade stream: trying provider (buffered)"
            );
            let start = std::time::Instant::now();
            let stream = match p.chat_stream(messages).await {
                Err(e) => {
                    // Could not even open the stream: move down the chain.
                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                    self.record_availability(p.name(), false, latency);
                    tracing::warn!(provider = p.name(), error = %e, "cascade stream: provider error");
                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
                    }
                    continue;
                }
                Ok(s) => s,
            };

            // Buffer the full stream so the text can be classified.
            let buffered = collect_stream(stream).await;
            let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

            match buffered {
                Err(e) => {
                    // Stream broke mid-flight: count as a failed attempt.
                    self.record_availability(p.name(), false, latency);
                    tracing::warn!(provider = p.name(), error = %e, "cascade stream: stream error");
                }
                Ok(text) => {
                    let eval = cascade_evaluate_response(
                        p.name(),
                        &text,
                        cfg,
                        cascade_state,
                        tokens_used,
                        "cascade stream",
                    )
                    .await;
                    tokens_used = eval.tokens_used;
                    let verdict = eval.verdict;
                    let budget_exhausted = eval.budget_exhausted;

                    // Track the best-scoring non-empty text for fallback.
                    let is_better = !text.is_empty()
                        && best_seen
                            .as_ref()
                            .is_none_or(|(_, best_score)| verdict.score > *best_score);
                    if is_better {
                        tracing::debug!(
                            provider = %p.name(),
                            score = verdict.score,
                            "cascade stream: best_seen updated"
                        );
                        best_seen = Some((text.clone(), verdict.score));
                    }

                    // Stop when quality is acceptable or escalation is blocked.
                    if !verdict.should_escalate || escalations_remaining == 0 || budget_exhausted {
                        self.record_availability(p.name(), true, latency);

                        // Escalation wanted but blocked: prefer the best-seen
                        // text over the current (possibly worse) one.
                        let response_text = if verdict.should_escalate
                            && (budget_exhausted || escalations_remaining == 0)
                        {
                            tracing::info!(
                                tokens_used,
                                budget = cfg.max_cascade_tokens,
                                escalations_remaining,
                                "cascade stream: escalation blocked, returning best response"
                            );
                            best_seen.take().map_or(text, |(r, _)| r)
                        } else {
                            text
                        };

                        // Re-wrap the buffered text as a single-chunk stream.
                        let stream: ChatStream = Box::pin(tokio_stream::once(Ok(
                            crate::provider::StreamChunk::Content(response_text),
                        )));
                        return Ok(stream);
                    }

                    self.record_availability(p.name(), true, latency);
                    escalations_remaining -= 1;

                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!(
                            "cascade: {} quality {:.2} < {:.2}, escalating",
                            p.name(),
                            verdict.score,
                            cfg.quality_threshold,
                        ));
                    }
                    tracing::info!(
                        provider = %p.name(),
                        score = verdict.score,
                        threshold = cfg.quality_threshold,
                        escalations_remaining,
                        "cascade stream: escalating to next provider"
                    );
                }
            }
        }

        // Final tier: stream straight through without classification.
        tracing::debug!(
            provider = %last.name(),
            attempt = providers.len(),
            total = providers.len(),
            "cascade stream: trying last provider (streaming, no classification)"
        );
        let start = std::time::Instant::now();
        match last.chat_stream(messages).await {
            Ok(stream) => {
                self.record_availability(
                    last.name(),
                    true,
                    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                );
                Ok(stream)
            }
            Err(e) => {
                self.record_availability(
                    last.name(),
                    false,
                    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                );
                // Last tier failed: fall back to the best buffered response.
                if let Some((best_text, _)) = best_seen {
                    tracing::info!(
                        "cascade stream: last provider failed, returning best-seen response"
                    );
                    let stream: ChatStream = Box::pin(tokio_stream::once(Ok(
                        crate::provider::StreamChunk::Content(best_text),
                    )));
                    return Ok(stream);
                }
                Err(e)
            }
        }
    }
2059}
2060
2061const CASCADE_STREAM_MAX_BYTES: usize = 1024 * 1024; async fn collect_stream(stream: ChatStream) -> Result<String, LlmError> {
2068 use tokio_stream::StreamExt as _;
2069
2070 let mut stream = stream;
2071 let mut buf = String::new();
2072 while let Some(chunk) = stream.next().await {
2073 match chunk? {
2074 crate::provider::StreamChunk::Content(c) => {
2075 if buf.len() + c.len() > CASCADE_STREAM_MAX_BYTES {
2076 return Err(LlmError::Other(
2077 "cascade: stream response exceeds 1 MiB buffer limit".into(),
2078 ));
2079 }
2080 buf.push_str(&c);
2081 }
2082 crate::provider::StreamChunk::Thinking(_)
2083 | crate::provider::StreamChunk::Compaction(_)
2084 | crate::provider::StreamChunk::ToolUse(_) => {}
2085 }
2086 }
2087 Ok(buf)
2088}
2089
2090#[cfg(test)]
2091mod tests {
2092 use super::*;
2093 use crate::provider::Role;
2094
2095 #[test]
2096 fn empty_router_name() {
2097 let r = RouterProvider::new(vec![]);
2098 assert_eq!(r.name(), "router");
2099 }
2100
2101 #[test]
2102 fn empty_router_supports_nothing() {
2103 let r = RouterProvider::new(vec![]);
2104 assert!(!r.supports_streaming());
2105 assert!(!r.supports_embeddings());
2106 assert!(!r.supports_tool_use());
2107 }
2108
2109 #[test]
2110 fn empty_router_context_window_none() {
2111 let r = RouterProvider::new(vec![]);
2112 assert!(r.context_window().is_none());
2113 }
2114
2115 #[tokio::test]
2116 async fn empty_router_chat_returns_no_providers() {
2117 let r = RouterProvider::new(vec![]);
2118 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2119 let err = r.chat(&msgs).await.unwrap_err();
2120 assert!(matches!(err, LlmError::NoProviders));
2121 }
2122
2123 #[tokio::test]
2124 async fn empty_router_chat_stream_returns_no_providers() {
2125 let r = RouterProvider::new(vec![]);
2126 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2127 let result = r.chat_stream(&msgs).await;
2128 assert!(matches!(result, Err(LlmError::NoProviders)));
2129 }
2130
2131 #[tokio::test]
2132 async fn empty_router_embed_returns_no_providers() {
2133 let r = RouterProvider::new(vec![]);
2134 let err = r.embed("test").await.unwrap_err();
2135 assert!(matches!(err, LlmError::NoProviders));
2136 }
2137
2138 #[tokio::test]
2139 async fn empty_router_chat_with_tools_returns_no_providers() {
2140 let r = RouterProvider::new(vec![]);
2141 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2142 let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
2143 assert!(matches!(err, LlmError::NoProviders));
2144 }
2145
2146 #[tokio::test]
2147 async fn router_falls_back_on_unreachable() {
2148 use crate::ollama::OllamaProvider;
2149
2150 let p1 = AnyProvider::Ollama(OllamaProvider::new(
2151 "http://127.0.0.1:1",
2152 "m".into(),
2153 "e".into(),
2154 ));
2155 let p2 = AnyProvider::Ollama(OllamaProvider::new(
2156 "http://127.0.0.1:2",
2157 "m".into(),
2158 "e".into(),
2159 ));
2160 let r = RouterProvider::new(vec![p1, p2]);
2161 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2162 let err = r.chat(&msgs).await.unwrap_err();
2163 assert!(matches!(err, LlmError::NoProviders));
2164 }
2165
2166 #[test]
2167 fn router_with_streaming_provider() {
2168 use crate::ollama::OllamaProvider;
2169
2170 let p = AnyProvider::Ollama(OllamaProvider::new(
2171 "http://127.0.0.1:1",
2172 "m".into(),
2173 "e".into(),
2174 ));
2175 let r = RouterProvider::new(vec![p]);
2176 assert!(r.supports_streaming());
2177 assert!(r.supports_embeddings());
2178 }
2179
2180 #[test]
2181 fn clone_preserves_providers() {
2182 use crate::ollama::OllamaProvider;
2183
2184 let p = AnyProvider::Ollama(OllamaProvider::new(
2185 "http://127.0.0.1:1",
2186 "m".into(),
2187 "e".into(),
2188 ));
2189 let r = RouterProvider::new(vec![p]);
2190 let c = r.clone();
2191 assert_eq!(c.providers.len(), 1);
2192 assert_eq!(c.name(), "router");
2193 }
2194
2195 #[test]
2196 fn last_cache_usage_returns_none() {
2197 let r = RouterProvider::new(vec![]);
2198 assert!(r.last_cache_usage().is_none());
2199 }
2200
2201 #[test]
2202 fn thompson_strategy_is_set() {
2203 let r = RouterProvider::new(vec![]).with_thompson(None);
2204 assert_eq!(r.strategy, RouterStrategy::Thompson);
2205 assert!(r.thompson.is_some());
2206 }
2207
2208 #[test]
2209 fn save_thompson_state_noop_without_thompson() {
2210 let r = RouterProvider::new(vec![]);
2211 r.save_thompson_state(); }
2213
2214 #[test]
2215 fn thompson_ordered_providers_empty() {
2216 let r = RouterProvider::new(vec![]).with_thompson(None);
2217 let ordered = r.ordered_providers();
2218 assert!(ordered.is_empty());
2219 }
2220
2221 #[test]
2222 fn concurrent_record_outcome_does_not_deadlock() {
2223 use std::sync::Arc;
2224 let r = Arc::new(RouterProvider::new(vec![]).with_thompson(None));
2225 let handles: Vec<_> = (0..8)
2226 .map(|i| {
2227 let router = Arc::clone(&r);
2228 std::thread::spawn(move || {
2229 router.record_availability(&format!("p{i}"), i % 2 == 0, 10);
2230 })
2231 })
2232 .collect();
2233 for h in handles {
2234 h.join().expect("thread panicked");
2235 }
2236 let stats = r.thompson_stats();
2238 assert_eq!(stats.len(), 8);
2239 }
2240
2241 #[test]
2244 fn cascade_strategy_is_set() {
2245 let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2246 assert_eq!(r.strategy, RouterStrategy::Cascade);
2247 assert!(r.cascade_state.is_some());
2248 assert!(r.cascade_config.is_some());
2249 }
2250
2251 #[test]
2252 fn cascade_ordered_providers_preserves_chain_order() {
2253 use crate::ollama::OllamaProvider;
2254 let p1 = AnyProvider::Ollama(OllamaProvider::new(
2255 "http://127.0.0.1:1",
2256 "a".into(),
2257 String::new(),
2258 ));
2259 let p2 = AnyProvider::Ollama(OllamaProvider::new(
2260 "http://127.0.0.1:2",
2261 "b".into(),
2262 String::new(),
2263 ));
2264 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2265 let ordered = r.ordered_providers();
2266 assert_eq!(ordered.len(), 2);
2267 }
2268
2269 #[tokio::test]
2270 async fn cascade_empty_router_returns_no_providers() {
2271 let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2272 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2273 let err = r.chat(&msgs).await.unwrap_err();
2274 assert!(matches!(err, LlmError::NoProviders));
2275 }
2276
2277 #[tokio::test]
2278 async fn cascade_returns_best_seen_when_all_fail_after_good_response() {
2279 use crate::mock::MockProvider;
2280
2281 let cheap =
2283 AnyProvider::Mock(MockProvider::with_responses(vec!["ok".to_owned()]).with_delay(0));
2284 let expensive = AnyProvider::Mock(MockProvider::failing());
2286
2287 let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2288 quality_threshold: 0.9, max_escalations: 2,
2290 ..CascadeRouterConfig::default()
2291 });
2292 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2293 let result = r.chat(&msgs).await.unwrap();
2295 assert_eq!(result, "ok");
2296 }
2297
2298 #[tokio::test]
2299 async fn cascade_accepts_good_quality_response() {
2300 use crate::mock::MockProvider;
2301
2302 let good_response = "This is a comprehensive, well-structured response that provides \
2303 detailed information about the topic. It covers multiple aspects and explains \
2304 the reasoning clearly with proper sentence structure.";
2305
2306 let cheap = AnyProvider::Mock(
2307 MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2308 );
2309 let expensive = AnyProvider::Mock(MockProvider::failing());
2311
2312 let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2313 quality_threshold: 0.5,
2314 max_escalations: 1,
2315 ..CascadeRouterConfig::default()
2316 });
2317 let msgs = vec![Message::from_legacy(Role::User, "explain something")];
2318 let result = r.chat(&msgs).await.unwrap();
2319 assert_eq!(result, good_response);
2320 }
2321
2322 #[tokio::test]
2323 async fn cascade_max_escalations_budget_exhausted_returns_last_attempted() {
2324 use crate::mock::MockProvider;
2325
2326 let p1 =
2329 AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2330 let p2 =
2331 AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2332 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2335 quality_threshold: 0.9,
2336 max_escalations: 1, ..CascadeRouterConfig::default()
2338 });
2339 let msgs = vec![Message::from_legacy(Role::User, "test")];
2340 let result = r.chat(&msgs).await.unwrap();
2341 assert_eq!(result, "x");
2342 }
2343
2344 #[tokio::test]
2345 async fn cascade_token_budget_stops_escalation() {
2346 use crate::mock::MockProvider;
2347
2348 let p1 =
2349 AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2350 let p2 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2353 quality_threshold: 0.9, max_escalations: 5,
2355 max_cascade_tokens: Some(1), ..CascadeRouterConfig::default()
2357 });
2358 let msgs = vec![Message::from_legacy(Role::User, "test")];
2359 let result = r.chat(&msgs).await.unwrap();
2360 assert_eq!(result, "x"); }
2362
2363 #[tokio::test]
2364 async fn cascade_budget_returns_best_seen_not_current() {
2365 use crate::mock::MockProvider;
2366
2367 let good_response = "This is a reasonable response with enough content to score well.";
2370 let bad_response = "x"; let p1 = AnyProvider::Mock(
2373 MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2374 );
2375 let p2 = AnyProvider::Mock(
2376 MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2377 );
2378
2379 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2380 quality_threshold: 0.95, max_escalations: 5,
2382 max_cascade_tokens: Some(1), ..CascadeRouterConfig::default()
2384 });
2385 let msgs = vec![Message::from_legacy(Role::User, "test")];
2386 let result = r.chat(&msgs).await.unwrap();
2390 assert_ne!(result, bad_response, "should return best-seen, not current");
2392 }
2393
2394 #[tokio::test]
2395 async fn cascade_escalations_exhausted_returns_best_seen_not_current() {
2396 use crate::mock::MockProvider;
2397
2398 let good_response = "This is a reasonable response with enough content to score well.";
2401 let bad_response = "x";
2402
2403 let p1 = AnyProvider::Mock(
2404 MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2405 );
2406 let p2 = AnyProvider::Mock(
2407 MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2408 );
2409 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2412 quality_threshold: 0.95, max_escalations: 1, ..CascadeRouterConfig::default()
2415 });
2416 let msgs = vec![Message::from_legacy(Role::User, "test")];
2417 let result = r.chat(&msgs).await.unwrap();
2418 assert_eq!(
2419 result, good_response,
2420 "should return best-seen (p1), not the degenerate current response (p2)"
2421 );
2422 assert_ne!(
2423 result, bad_response,
2424 "must not return degenerate p2 response"
2425 );
2426 }
2427
2428 #[tokio::test]
2429 async fn cascade_stream_escalations_exhausted_returns_best_seen_not_current() {
2430 use crate::mock::MockProvider;
2431
2432 let good_response = "This is a reasonable response with enough content to score well.";
2436 let bad_response = "x";
2437
2438 let p1 = AnyProvider::Mock(
2439 MockProvider::with_responses(vec![good_response.to_owned()])
2440 .with_delay(0)
2441 .with_streaming(),
2442 );
2443 let p2 = AnyProvider::Mock(
2444 MockProvider::with_responses(vec![bad_response.to_owned()])
2445 .with_delay(0)
2446 .with_streaming(),
2447 );
2448 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2451 quality_threshold: 0.95, max_escalations: 1, ..CascadeRouterConfig::default()
2454 });
2455 let msgs = vec![Message::from_legacy(Role::User, "test")];
2456 let stream = r.chat_stream(&msgs).await.unwrap();
2457 let collected = collect_stream(stream).await.unwrap();
2458 assert_eq!(
2459 collected, good_response,
2460 "should return best-seen (p1), not the degenerate current response (p2)"
2461 );
2462 assert_ne!(
2463 collected, bad_response,
2464 "must not return degenerate p2 response"
2465 );
2466 }
2467
2468 #[tokio::test]
2469 async fn cascade_all_providers_fail_returns_no_providers() {
2470 use crate::mock::MockProvider;
2471
2472 let p1 = AnyProvider::Mock(MockProvider::failing());
2473 let p2 = AnyProvider::Mock(MockProvider::failing());
2474
2475 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2476 let msgs = vec![Message::from_legacy(Role::User, "test")];
2477 let err = r.chat(&msgs).await.unwrap_err();
2478 assert!(matches!(err, LlmError::NoProviders));
2479 }
2480
2481 #[tokio::test]
2482 async fn cascade_stream_good_quality_no_escalation() {
2483 use crate::mock::MockProvider;
2484
2485 let good = "This is a well-formed response with sufficient length and coherent structure.";
2486 let p1 = AnyProvider::Mock(
2487 MockProvider::with_responses(vec![good.to_owned()])
2488 .with_delay(0)
2489 .with_streaming(),
2490 );
2491 let p2 = AnyProvider::Mock(MockProvider::failing());
2492
2493 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2494 quality_threshold: 0.5,
2495 max_escalations: 1,
2496 ..CascadeRouterConfig::default()
2497 });
2498 let msgs = vec![Message::from_legacy(Role::User, "q")];
2499 let stream = r.chat_stream(&msgs).await.unwrap();
2500 let collected = collect_stream(stream).await.unwrap();
2501 assert_eq!(collected, good);
2502 }
2503
2504 #[tokio::test]
2505 async fn cascade_stream_escalates_to_last_provider() {
2506 use crate::mock::MockProvider;
2507
2508 let bad = "x"; let good = "This is the expensive model's comprehensive response.";
2510 let p1 = AnyProvider::Mock(
2511 MockProvider::with_responses(vec![bad.to_owned()])
2512 .with_delay(0)
2513 .with_streaming(),
2514 );
2515 let p2 = AnyProvider::Mock(
2516 MockProvider::with_responses(vec![good.to_owned()])
2517 .with_delay(0)
2518 .with_streaming(),
2519 );
2520
2521 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2522 quality_threshold: 0.9, max_escalations: 1,
2524 ..CascadeRouterConfig::default()
2525 });
2526 let msgs = vec![Message::from_legacy(Role::User, "q")];
2527 let stream = r.chat_stream(&msgs).await.unwrap();
2528 let collected = collect_stream(stream).await.unwrap();
2529 assert_eq!(collected, good);
2530 }
2531
2532 #[tokio::test]
2533 async fn cascade_stream_budget_returns_best_seen() {
2534 use crate::mock::MockProvider;
2535
2536 let good_response = "This is a reasonable response with enough content to score well.";
2541 let bad_response = "x"; let p1 = AnyProvider::Mock(
2544 MockProvider::with_responses(vec![good_response.to_owned()])
2545 .with_delay(0)
2546 .with_streaming(),
2547 );
2548 let p2 = AnyProvider::Mock(
2549 MockProvider::with_responses(vec![bad_response.to_owned()])
2550 .with_delay(0)
2551 .with_streaming(),
2552 );
2553 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2556 quality_threshold: 0.95, max_escalations: 5,
2558 max_cascade_tokens: Some(1), ..CascadeRouterConfig::default()
2560 });
2561 let msgs = vec![Message::from_legacy(Role::User, "test")];
2562 let stream = r.chat_stream(&msgs).await.unwrap();
2563 let collected = collect_stream(stream).await.unwrap();
2564 assert_eq!(
2566 collected, good_response,
2567 "should return best-seen p1 response when budget exhausted"
2568 );
2569 }
2570
2571 #[tokio::test]
2572 async fn cascade_stream_budget_returns_best_seen_not_current() {
2573 use crate::mock::MockProvider;
2574
2575 let good_response = "This is a reasonable response with enough content to score well.";
2581 let bad_response = "x"; let p1 = AnyProvider::Mock(
2584 MockProvider::with_responses(vec![good_response.to_owned()])
2585 .with_delay(0)
2586 .with_streaming(),
2587 );
2588 let p2 = AnyProvider::Mock(
2589 MockProvider::with_responses(vec![bad_response.to_owned()])
2590 .with_delay(0)
2591 .with_streaming(),
2592 );
2593 let p3 = AnyProvider::Mock(MockProvider::failing()); let p4 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3, p4]).with_cascade(CascadeRouterConfig {
2599 quality_threshold: 0.95, max_escalations: 5,
2601 max_cascade_tokens: Some(17), ..CascadeRouterConfig::default()
2603 });
2604 let msgs = vec![Message::from_legacy(Role::User, "test")];
2605 let stream = r.chat_stream(&msgs).await.unwrap();
2606 let collected = collect_stream(stream).await.unwrap();
2607 assert_eq!(
2609 collected, good_response,
2610 "should return best-seen (p1), not current degenerate (p2)"
2611 );
2612 assert_ne!(
2613 collected, bad_response,
2614 "must not return the degenerate p2 response"
2615 );
2616 }
2617
2618 #[tokio::test]
2619 async fn cascade_stream_last_fails_returns_best_seen() {
2620 use crate::mock::MockProvider;
2621
2622 let low_quality = "ok"; let p1 = AnyProvider::Mock(
2628 MockProvider::with_responses(vec![low_quality.to_owned()])
2629 .with_delay(0)
2630 .with_streaming(),
2631 );
2632 let p2 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2635 quality_threshold: 0.9, max_escalations: 2,
2637 ..CascadeRouterConfig::default()
2638 });
2639 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2640 let stream = r.chat_stream(&msgs).await.unwrap();
2641 let collected = collect_stream(stream).await.unwrap();
2642 assert_eq!(collected, low_quality);
2643 }
2644
2645 #[tokio::test]
2646 async fn cascade_stream_all_fail_returns_error() {
2647 use crate::mock::MockProvider;
2648
2649 let p1 = AnyProvider::Mock(MockProvider::failing());
2653 let p2 = AnyProvider::Mock(MockProvider::failing());
2654
2655 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2656 let msgs = vec![Message::from_legacy(Role::User, "test")];
2657 let result = r.chat_stream(&msgs).await;
2658 assert!(
2659 result.is_err(),
2660 "expected error when all providers fail with no best_seen"
2661 );
2662 }
2663
2664 #[test]
2665 fn cascade_config_default_values() {
2666 let cfg = CascadeRouterConfig::default();
2667 assert!((cfg.quality_threshold - 0.5).abs() < f64::EPSILON);
2668 assert_eq!(cfg.max_escalations, 2);
2669 assert_eq!(cfg.window_size, 50);
2670 assert!(cfg.max_cascade_tokens.is_none());
2671 assert_eq!(cfg.classifier_mode, cascade::ClassifierMode::Heuristic);
2672 }
2673
2674 #[test]
2675 fn evaluate_heuristic_empty_should_escalate_above_threshold() {
2676 let verdict = RouterProvider::evaluate_heuristic("", 0.05);
2677 assert!(verdict.should_escalate);
2679 }
2680
2681 #[test]
2682 fn evaluate_heuristic_good_response_does_not_escalate() {
2683 let text = "The answer to your question is straightforward. Consider the options and pick the best one.";
2684 let verdict = RouterProvider::evaluate_heuristic(text, 0.5);
2685 assert!(!verdict.should_escalate, "score={}", verdict.score);
2686 }
2687
2688 #[tokio::test]
2692 async fn cascade_empty_response_not_stored_as_best_seen() {
2693 use crate::mock::MockProvider;
2694
2695 let p = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
2698 let cfg = CascadeRouterConfig {
2699 quality_threshold: 0.0,
2700 ..Default::default()
2701 };
2702 let r = RouterProvider::new(vec![p]).with_cascade(cfg);
2703 let msgs = vec![Message::from_legacy(Role::User, "hi")];
2704 let result = r.chat(&msgs).await;
2707 assert!(result.is_ok());
2708 assert_eq!(result.unwrap(), "");
2709 }
2710
2711 #[tokio::test]
2714 async fn cascade_empty_best_seen_not_returned_on_all_fail() {
2715 use crate::mock::MockProvider;
2716
2717 let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
2720 let p2 = AnyProvider::Mock(MockProvider::failing());
2721
2722 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2723 let msgs = vec![Message::from_legacy(Role::User, "hi")];
2724 let result = r.chat(&msgs).await;
2725 assert!(
2727 result.is_err(),
2728 "expected error, not silent empty string; got: {result:?}"
2729 );
2730 }
2731
2732 #[tokio::test]
2734 async fn cascade_stream_empty_response_not_stored_as_best_seen() {
2735 use crate::mock::MockProvider;
2736
2737 let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
2740 let p2 = AnyProvider::Mock(
2741 MockProvider::with_responses(vec!["real answer".to_owned()]).with_streaming(),
2742 );
2743
2744 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2745 let msgs = vec![Message::from_legacy(Role::User, "hi")];
2746 let stream = r.chat_stream(&msgs).await.expect("should not error");
2747 let text = collect_stream(stream).await.expect("stream should succeed");
2748 assert_eq!(text, "real answer");
2749 }
2750
2751 #[test]
2754 fn arc_providers_clone_shares_allocation() {
2755 use crate::mock::MockProvider;
2756 let p = AnyProvider::Mock(MockProvider::default());
2757 let r = RouterProvider::new(vec![p]);
2758 let c = r.clone();
2759 assert!(Arc::ptr_eq(&r.providers, &c.providers));
2761 }
2762
2763 #[test]
2764 fn cost_tiers_reorders_providers_at_construction() {
2765 use crate::mock::MockProvider;
2766 let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
2767 let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
2768 let p3 = AnyProvider::Mock(MockProvider::default().with_name("openai"));
2769 let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2770 cost_tiers: Some(vec!["ollama".into(), "claude".into()]),
2771 ..CascadeRouterConfig::default()
2772 });
2773 let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
2774 assert_eq!(names, vec!["ollama", "claude", "openai"]);
2776 }
2777
2778 #[test]
2779 fn cost_tiers_none_preserves_chain_order() {
2780 use crate::mock::MockProvider;
2781 let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
2782 let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
2783 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2784 cost_tiers: None,
2785 ..CascadeRouterConfig::default()
2786 });
2787 let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
2788 assert_eq!(names, vec!["claude", "ollama"]);
2789 }
2790
2791 #[test]
2792 fn cost_tiers_empty_vec_preserves_chain_order() {
2793 use crate::mock::MockProvider;
2794 let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
2795 let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
2796 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2797 cost_tiers: Some(vec![]),
2798 ..CascadeRouterConfig::default()
2799 });
2800 let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
2801 assert_eq!(names, vec!["claude", "ollama"]);
2802 }
2803
2804 #[test]
2805 fn cost_tiers_unknown_name_ignored() {
2806 use crate::mock::MockProvider;
2807 let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
2808 let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
2809 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2810 cost_tiers: Some(vec!["nonexistent".into(), "ollama".into()]),
2811 ..CascadeRouterConfig::default()
2812 });
2813 let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
2814 assert_eq!(names, vec!["ollama", "claude"]);
2816 }
2817
2818 #[test]
2819 fn cost_tiers_all_providers_listed() {
2820 use crate::mock::MockProvider;
2821 let p1 = AnyProvider::Mock(MockProvider::default().with_name("c"));
2822 let p2 = AnyProvider::Mock(MockProvider::default().with_name("b"));
2823 let p3 = AnyProvider::Mock(MockProvider::default().with_name("a"));
2824 let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2825 cost_tiers: Some(vec!["a".into(), "b".into(), "c".into()]),
2826 ..CascadeRouterConfig::default()
2827 });
2828 let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
2829 assert_eq!(names, vec!["a", "b", "c"]);
2830 }
2831
2832 #[test]
2833 fn cost_tiers_duplicate_name_uses_last_position() {
2834 use crate::mock::MockProvider;
2835 let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
2836 let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
2837 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2840 cost_tiers: Some(vec!["claude".into(), "ollama".into(), "ollama".into()]),
2841 ..CascadeRouterConfig::default()
2842 });
2843 let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
2844 assert_eq!(names, vec!["claude", "ollama"]);
2845 }
2846
2847 #[test]
2848 fn cost_tiers_empty_router_does_not_panic() {
2849 let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig {
2850 cost_tiers: Some(vec!["foo".into()]),
2851 ..CascadeRouterConfig::default()
2852 });
2853 assert_eq!(r.providers.len(), 0);
2854 }
2855
2856 #[test]
2857 fn set_status_tx_works_with_arc() {
2858 use crate::mock::MockProvider;
2859 let p = AnyProvider::Mock(MockProvider::default());
2860 let mut r = RouterProvider::new(vec![p]);
2861 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
2862 r.set_status_tx(tx); }
2864
2865 #[tokio::test]
2866 async fn cascade_chat_with_tools_unaffected_by_cost_tiers() {
2867 use crate::mock::MockProvider;
2868 let p1 = AnyProvider::Mock(MockProvider::failing().with_name("cheap"));
2871 let p2 = AnyProvider::Mock(MockProvider::failing().with_name("expensive"));
2872 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2873 cost_tiers: Some(vec!["cheap".into()]),
2874 ..CascadeRouterConfig::default()
2875 });
2876 let msgs = vec![Message::from_legacy(Role::User, "hi")];
2877 let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
2879 assert!(matches!(err, LlmError::NoProviders));
2880 }
2881
2882 #[tokio::test]
2887 async fn embed_retries_on_rate_limited_then_succeeds() {
2888 use crate::mock::MockProvider;
2889
2890 let p = AnyProvider::Mock({
2891 let mut m = MockProvider::default()
2892 .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
2893 .with_name("p1");
2894 m.supports_embeddings = true;
2895 m.embedding = vec![0.1, 0.2];
2896 m
2897 });
2898 let r = RouterProvider::new(vec![p]);
2899 let result = r.embed("text").await.unwrap();
2900 assert_eq!(result, vec![0.1, 0.2]);
2901 }
2902
2903 #[tokio::test]
2906 async fn embed_falls_back_after_all_retries_exhausted() {
2907 use crate::mock::MockProvider;
2908
2909 let p1 = AnyProvider::Mock({
2911 let mut m = MockProvider::default()
2912 .with_errors(vec![
2913 LlmError::RateLimited,
2914 LlmError::RateLimited,
2915 LlmError::RateLimited,
2916 LlmError::RateLimited,
2917 ])
2918 .with_name("p1");
2919 m.supports_embeddings = true;
2920 m
2921 });
2922 let p2 = AnyProvider::Mock({
2923 let mut m = MockProvider::default().with_name("p2");
2924 m.supports_embeddings = true;
2925 m.embedding = vec![9.0, 8.0];
2926 m
2927 });
2928 let r = RouterProvider::new(vec![p1, p2]);
2929 let result = r.embed("text").await.unwrap();
2930 assert_eq!(result, vec![9.0, 8.0]);
2931 }
2932
2933 #[tokio::test]
2935 async fn embed_batch_retries_on_rate_limited_then_succeeds() {
2936 use crate::mock::MockProvider;
2937
2938 let p = AnyProvider::Mock({
2939 let mut m = MockProvider::default()
2940 .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
2941 .with_name("p1");
2942 m.supports_embeddings = true;
2943 m.embedding = vec![0.5, 0.6];
2944 m
2945 });
2946 let r = RouterProvider::new(vec![p]);
2947 let result = r.embed_batch(&["a", "b"]).await.unwrap();
2948 assert_eq!(result, vec![vec![0.5, 0.6], vec![0.5, 0.6]]);
2949 }
2950
2951 #[tokio::test]
2954 async fn embed_batch_falls_back_after_all_retries_exhausted() {
2955 use crate::mock::MockProvider;
2956
2957 let p1 = AnyProvider::Mock({
2959 let mut m = MockProvider::default()
2960 .with_errors(vec![
2961 LlmError::RateLimited,
2962 LlmError::RateLimited,
2963 LlmError::RateLimited,
2964 LlmError::RateLimited,
2965 ])
2966 .with_name("p1");
2967 m.supports_embeddings = true;
2968 m
2969 });
2970 let p2 = AnyProvider::Mock({
2971 let mut m = MockProvider::default().with_name("p2");
2972 m.supports_embeddings = true;
2973 m.embedding = vec![7.0, 8.0];
2974 m
2975 });
2976 let r = RouterProvider::new(vec![p1, p2]);
2977 let result = r.embed_batch(&["x"]).await.unwrap();
2978 assert_eq!(result, vec![vec![7.0, 8.0]]);
2979 }
2980
2981 #[tokio::test]
2986 async fn embed_invalid_input_breaks_loop_and_returns_invalid_input() {
2987 use crate::mock::MockProvider;
2988
2989 let p = AnyProvider::Mock(MockProvider::default().with_embed_invalid_input());
2990 let r = RouterProvider::new(vec![p]).with_thompson(None);
2991 let err = r.embed("some text").await.unwrap_err();
2992 assert!(
2993 matches!(err, LlmError::InvalidInput { .. }),
2994 "expected InvalidInput, got {err:?}"
2995 );
2996 }
2997
2998 #[tokio::test]
3001 async fn embed_invalid_input_does_not_fall_through_to_second_provider() {
3002 use crate::mock::MockProvider;
3003
3004 let p1 = AnyProvider::Mock(
3008 MockProvider::default()
3009 .with_embed_invalid_input()
3010 .with_name("p1"),
3011 );
3012 let p2 = AnyProvider::Mock({
3013 let mut m = MockProvider::default();
3014 m.supports_embeddings = true;
3015 m.name_override = Some("p2".into());
3016 m
3017 });
3018
3019 let r = RouterProvider::new(vec![p1, p2]);
3020 let err = r.embed("test").await.unwrap_err();
3021
3022 assert!(
3024 matches!(&err, LlmError::InvalidInput { provider, .. } if provider == "p1"),
3025 "expected InvalidInput from p1, got {err:?}"
3026 );
3027 }
3028
3029 #[tokio::test]
3032 async fn embed_skips_non_embedding_providers_and_falls_through() {
3033 use crate::mock::MockProvider;
3034
3035 let p1 = AnyProvider::Mock({
3038 let mut m = MockProvider::default().with_name("p1");
3039 m.supports_embeddings = false;
3040 m
3041 });
3042 let p2 = AnyProvider::Mock({
3043 let mut m = MockProvider::default().with_name("p2");
3044 m.supports_embeddings = true;
3045 m.embedding = vec![1.0, 2.0, 3.0];
3046 m
3047 });
3048
3049 let r = RouterProvider::new(vec![p1, p2]);
3050 let result = r.embed("hello").await.unwrap();
3051 assert_eq!(result, vec![1.0, 2.0, 3.0]);
3052 }
3053
3054 #[tokio::test]
3058 async fn embed_invalid_input_does_not_record_availability() {
3059 use crate::mock::MockProvider;
3060
3061 let p = AnyProvider::Mock(
3062 MockProvider::default()
3063 .with_embed_invalid_input()
3064 .with_name("test-provider"),
3065 );
3066 let r = RouterProvider::new(vec![p]).with_thompson(None);
3067 let _ = r.embed("text").await;
3068
3069 let stats = r.thompson_stats();
3072 let provider_in_stats = stats.iter().any(|(name, ..)| name == "test-provider");
3073 assert!(
3074 !provider_in_stats,
3075 "InvalidInput must not update provider reputation; stats: {stats:?}"
3076 );
3077 }
3078
3079 #[tokio::test]
3084 async fn quality_gate_passes_when_similarity_above_threshold() {
3085 use crate::mock::MockProvider;
3086
3087 let p1 = AnyProvider::Mock({
3090 let mut m = MockProvider::with_responses(vec!["answer".to_owned()]).with_name("p1");
3091 m.supports_embeddings = true;
3092 m.embedding = vec![1.0, 0.0];
3093 m
3094 });
3095 let r = RouterProvider::new(vec![p1])
3096 .with_thompson(None)
3097 .with_quality_gate(0.5);
3098 let msgs = vec![Message::from_legacy(Role::User, "question")];
3099 let result = r.chat(&msgs).await.unwrap();
3100 assert_eq!(result, "answer");
3101 }
3102
3103 #[tokio::test]
3106 async fn quality_gate_exhaustion_returns_best_seen() {
3107 use crate::mock::MockProvider;
3108
3109 let p1 = AnyProvider::Mock({
3113 let mut m =
3114 MockProvider::with_responses(vec!["best_so_far".to_owned()]).with_name("p1");
3115 m.supports_embeddings = true;
3116 m.embedding = vec![0.0, 1.0];
3118 m
3119 });
3120 let p2 = AnyProvider::Mock(MockProvider::failing().with_name("p2"));
3121 let r = RouterProvider::new(vec![p1, p2])
3122 .with_thompson(None)
3123 .with_quality_gate(0.9);
3124 let msgs = vec![Message::from_legacy(Role::User, "question")];
3125 let result = r.chat(&msgs).await.unwrap();
3126 assert_eq!(result, "best_so_far");
3127 }
3128
3129 #[test]
3134 fn routing_signals_quality_gate_above_one_is_ignored() {
3135 let threshold: f32 = 5.0;
3138 let mut router = RouterProvider::new(vec![]);
3139 if threshold.is_finite() && threshold > 0.0 && threshold <= 1.0 {
3140 router = router.with_quality_gate(threshold);
3141 }
3142 assert!(
3143 router.quality_gate.is_none(),
3144 "out-of-range quality_gate must not be wired; got {:?}",
3145 router.quality_gate
3146 );
3147 }
3148
3149 #[test]
3151 fn routing_signals_quality_gate_valid_is_wired() {
3152 let threshold: f32 = 0.8;
3153 let mut router = RouterProvider::new(vec![]);
3154 if threshold.is_finite() && threshold > 0.0 && threshold <= 1.0 {
3155 router = router.with_quality_gate(threshold);
3156 }
3157 assert_eq!(
3158 router.quality_gate,
3159 Some(0.8),
3160 "valid quality_gate must be wired"
3161 );
3162 }
3163
3164 #[test]
3167 fn asi_debounce_same_turn_fires_once() {
3168 let router = RouterProvider::new(vec![]);
3169 let turn_id = 42u64;
3170
3171 let prev1 = router.asi_last_turn.swap(turn_id, Ordering::AcqRel);
3173 let first_dropped = prev1 == turn_id;
3174
3175 let prev2 = router.asi_last_turn.swap(turn_id, Ordering::AcqRel);
3177 let second_dropped = prev2 == turn_id;
3178
3179 assert!(!first_dropped, "first call in turn must not be dropped");
3180 assert!(second_dropped, "second call in same turn must be dropped");
3181 }
3182
3183 #[test]
3184 fn asi_debounce_next_turn_fires_again() {
3185 let router = RouterProvider::new(vec![]);
3186
3187 let prev1 = router.asi_last_turn.swap(1u64, Ordering::AcqRel);
3189 assert_ne!(prev1, 1u64, "turn 1: initial value != 1, should proceed");
3190
3191 let prev2 = router.asi_last_turn.swap(2u64, Ordering::AcqRel);
3193 let dropped = prev2 == 2u64;
3194 assert!(!dropped, "turn 2 must not be dropped (different turn_id)");
3195 }
3196
3197 #[test]
3198 fn turn_counter_increments_across_clones() {
3199 let router = RouterProvider::new(vec![]);
3200 let clone = router.clone();
3201
3202 let t0 = router.turn_counter.fetch_add(1, Ordering::Relaxed);
3203 let t1 = clone.turn_counter.fetch_add(1, Ordering::Relaxed);
3204
3205 assert_eq!(t1, t0 + 1, "cloned router shares turn_counter");
3207 }
3208
3209 #[test]
3210 fn with_embed_concurrency_zero_means_no_semaphore() {
3211 let r = RouterProvider::new(vec![]).with_embed_concurrency(0);
3212 assert!(r.embed_semaphore.is_none(), "0 should disable semaphore");
3213 }
3214
3215 #[test]
3216 fn with_embed_concurrency_positive_creates_semaphore() {
3217 let r = RouterProvider::new(vec![]).with_embed_concurrency(4);
3218 let sem = r.embed_semaphore.as_ref().expect("semaphore should exist");
3219 assert_eq!(sem.available_permits(), 4);
3220 }
3221
3222 #[tokio::test]
3223 async fn embed_semaphore_limits_concurrency() {
3224 use std::sync::Arc as StdArc;
3225 use std::sync::atomic::{AtomicUsize, Ordering as AO};
3226
3227 let sem = Arc::new(tokio::sync::Semaphore::new(2));
3230 let concurrent_peak = StdArc::new(AtomicUsize::new(0));
3231 let active = StdArc::new(AtomicUsize::new(0));
3232
3233 let mut handles = vec![];
3234 for _ in 0..6 {
3235 let sem_clone = sem.clone();
3236 let peak = concurrent_peak.clone();
3237 let active = active.clone();
3238 handles.push(tokio::spawn(async move {
3239 let _permit = sem_clone.acquire().await.unwrap();
3240 let cur = active.fetch_add(1, AO::SeqCst) + 1;
3241 let mut p = peak.load(AO::SeqCst);
3243 while p < cur {
3244 match peak.compare_exchange(p, cur, AO::SeqCst, AO::SeqCst) {
3245 Ok(_) => break,
3246 Err(new) => p = new,
3247 }
3248 }
3249 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
3250 active.fetch_sub(1, AO::SeqCst);
3251 }));
3252 }
3253 for h in handles {
3254 h.await.unwrap();
3255 }
3256 assert!(
3257 concurrent_peak.load(AO::SeqCst) <= 2,
3258 "peak concurrency should not exceed semaphore limit"
3259 );
3260 }
3261}