1pub mod asi;
43pub mod aware;
44pub mod bandit;
45pub mod cascade;
46pub mod coe;
47pub mod reputation;
48pub mod state;
49pub mod thompson;
50pub mod triage;
51
52pub use aware::RouterAware;
53pub use state::RouterState;
54
55use std::collections::HashMap;
56use std::path::Path;
57use std::sync::Arc;
58use std::sync::atomic::{AtomicU64, Ordering};
59
60use parking_lot::Mutex;
61
62use crate::any::AnyProvider;
63use crate::ema::EmaTracker;
64use crate::embed::owned_strs;
65use crate::error::LlmError;
66use crate::provider::{ChatResponse, ChatStream, LlmProvider, Message, StatusTx, ToolDefinition};
67use coe::{CoeDecision, CoeRouter, run_coe};
68
69use asi::AsiState;
70use bandit::{BanditState, embedding_to_features};
71use cascade::{CascadeState, ClassifierMode, heuristic_score};
72use reputation::ReputationTracker;
73use thompson::ThompsonState;
74
/// Unix time (whole seconds) at which the "asi: coherence below threshold" warning was
/// last emitted; `0` until the first warning.
///
/// Used to rate-limit that warning to at most once per 60 seconds. Being a `static`, the
/// limit is shared by every router in the process.
static ASI_WARN_LAST_SECS: AtomicU64 = AtomicU64::new(0);

/// Maximum number of ASI coherence-update tasks tracked in `asi_tasks` at once; when the
/// limit is reached `spawn_asi_update` skips the update instead of spawning another task.
const MAX_ASI_TASKS: usize = 8;
84use zeph_common::math::cosine_similarity;
85
86fn blocking_load<T>(f: impl FnOnce() -> T) -> T {
92 if tokio::runtime::Handle::try_current()
93 .is_ok_and(|h| h.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread)
94 {
95 tokio::task::block_in_place(f)
96 } else {
97 f()
98 }
99}
100
/// Bounded cache of feature vectors keyed by a 64-bit query hash.
///
/// Eviction is first-in-first-out: `order` records insertion order and the oldest key is
/// removed once `map` already holds `capacity` entries. Lookups never change an entry's
/// position in the queue.
#[derive(Debug)]
struct BanditEmbedCache {
    map: HashMap<u64, Vec<f32>>,
    order: std::collections::VecDeque<u64>,
    capacity: usize,
}

impl BanditEmbedCache {
    /// Create an empty cache that evicts once `capacity` entries are stored.
    fn new(capacity: usize) -> Self {
        let map = HashMap::with_capacity(capacity);
        let order = std::collections::VecDeque::with_capacity(capacity);
        Self {
            map,
            order,
            capacity,
        }
    }

    /// Borrow the vector stored under `key`, if any.
    fn get(&self, key: u64) -> Option<&Vec<f32>> {
        self.map.get(&key)
    }

    /// Store `value` under `key`.
    ///
    /// A key that is already present is left untouched (the new value is discarded).
    /// When the cache is full, the oldest inserted key is evicted first.
    fn insert(&mut self, key: u64, value: Vec<f32>) {
        if self.map.contains_key(&key) {
            return;
        }

        // Only consult the queue when the cache is actually full.
        let evicted = if self.map.len() >= self.capacity {
            self.order.pop_front()
        } else {
            None
        };
        if let Some(oldest) = evicted {
            self.map.remove(&oldest);
        }

        self.order.push_back(key);
        self.map.insert(key, value);
    }
}

impl Default for BanditEmbedCache {
    /// A cache with room for 512 entries.
    fn default() -> Self {
        Self::new(512)
    }
}
145
/// Unbounded embedding cache keyed by the exact input text.
#[derive(Debug, Default)]
struct TurnEmbedCache {
    entries: HashMap<String, Vec<f32>>,
}

impl TurnEmbedCache {
    /// Borrow the embedding previously stored for `text`, if there is one.
    fn get(&self, text: &str) -> Option<&Vec<f32>> {
        self.entries.get(text)
    }

    /// Store `embedding` for `text`, replacing any earlier entry for the same text.
    fn insert(&mut self, text: impl Into<String>, embedding: Vec<f32>) {
        let key: String = text.into();
        self.entries.insert(key, embedding);
    }
}
165
/// Strategy the router uses to decide which provider to try first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum RouterStrategy {
    /// Default strategy: providers are ordered by `ema_ordered_providers`, and
    /// availability outcomes are recorded through `ema_record`.
    #[default]
    Ema,
    /// Providers are ordered by `thompson_ordered_providers`, and availability outcomes
    /// update the Thompson state. Selected by `with_thompson`.
    Thompson,
    /// Providers are used in their configured order (optionally reordered by
    /// `cost_tiers`). Selected by `with_cascade`.
    Cascade,
    /// Providers are chosen per query by the bandit state. Selected by `with_bandit`.
    Bandit,
}
180
/// Tuning parameters for the bandit routing strategy (see `RouterProvider::with_bandit`).
#[derive(Debug, Clone)]
#[allow(clippy::doc_markdown)] pub struct BanditRouterConfig {
    /// Passed to the bandit's `select` call. `with_bandit` replaces values `<= 0.0`
    /// with `0.01`. (Presumably the exploration coefficient — confirm in `bandit`.)
    pub alpha: f32,
    /// Feature-vector dimension: used to create the bandit state and to reduce
    /// embeddings via `embedding_to_features`. `with_bandit` resets values outside
    /// `[1, 256]` to `32`.
    pub dim: usize,
    /// Weight of the cost term: the recorded reward is
    /// `quality_score - cost_weight * cost_fraction`, clamped to `[-1, 1]`.
    pub cost_weight: f32,
    /// Decay applied once to the loaded state in `with_bandit` when `< 1.0`. Values
    /// outside `(0.0, 1.0]` are reset to `1.0` (no decay).
    pub decay_factor: f32,
    /// Warm-up query count handed to the bandit's `select`. `0` means "auto":
    /// `with_bandit` substitutes ten times the number of providers (at least 10).
    pub warmup_queries: u64,
    /// Timeout, in milliseconds, for the embedding call made when computing features.
    pub embedding_timeout_ms: u64,
    /// Capacity of the per-router feature-vector cache.
    pub cache_size: usize,
    /// Passed to the bandit's `select` together with the last recorded memory
    /// confidence; how the two interact is decided inside `bandit`.
    pub memory_confidence_threshold: f32,
}

impl Default for BanditRouterConfig {
    /// Defaults: `alpha = 1.0`, `dim = 32`, `cost_weight = 0.1`, no decay, automatic
    /// warm-up, a 50 ms embedding timeout, a 512-entry cache and a `0.9` memory
    /// confidence threshold.
    fn default() -> Self {
        Self {
            alpha: 1.0,
            dim: 32,
            cost_weight: 0.1,
            decay_factor: 1.0,
            warmup_queries: 0, embedding_timeout_ms: 50,
            cache_size: 512,
            memory_confidence_threshold: 0.9,
        }
    }
}
223
/// Settings for ASI coherence tracking (see `RouterProvider::with_asi`).
#[derive(Debug, Clone)]
pub struct AsiRouterConfig {
    /// Window size passed to `AsiState::push_embedding` for each recorded response
    /// embedding.
    pub window: usize,
    /// Providers whose coherence falls below this value are penalised during ordering
    /// and trigger a rate-limited warning.
    pub coherence_threshold: f32,
    /// Scale of the Thompson-path penalty: `penalty_weight * (threshold - coherence)`
    /// is added to the provider's `beta`. The EMA path does not read this field.
    pub penalty_weight: f32,
}

impl Default for AsiRouterConfig {
    /// Defaults: `window = 5`, `coherence_threshold = 0.7`, `penalty_weight = 0.3`.
    fn default() -> Self {
        Self {
            window: 5,
            coherence_threshold: 0.7,
            penalty_weight: 0.3,
        }
    }
}
247
/// Settings for the cascade routing strategy (see `RouterProvider::with_cascade`).
#[derive(Debug, Clone)]
pub struct CascadeRouterConfig {
    /// A response whose quality score is below this value is marked for escalation.
    pub quality_threshold: f64,
    /// Escalation limit. NOTE(review): not read in the visible part of this file —
    /// presumably enforced by the cascade chat path; confirm.
    pub max_escalations: u8,
    /// Which quality classifier to use; `Judge` falls back to the heuristic when no
    /// judge is available or the judge call fails.
    pub classifier_mode: ClassifierMode,
    /// Window size used to create the `CascadeState`.
    pub window_size: usize,
    /// Optional token budget. NOTE(review): not read in the visible part of this
    /// file — confirm where it is enforced.
    pub max_cascade_tokens: Option<u32>,
    /// Optional provider; presumably the one handed to `evaluate_quality` as the judge
    /// in `Judge` mode — verify against the caller.
    pub summary_provider: Option<Arc<dyn crate::provider_dyn::LlmProviderDyn>>,
    /// Optional provider names in the desired order. Providers named here are moved to
    /// the front in this order; unnamed providers keep their relative order after them.
    pub cost_tiers: Option<Vec<String>>,
    /// Timeout, in milliseconds, for a judge scoring call.
    pub judge_timeout_ms: u64,
}

impl Default for CascadeRouterConfig {
    /// Defaults: threshold `0.5`, two escalations, heuristic classifier, window of 50,
    /// no token budget, no judge provider, no cost tiers, 5 s judge timeout.
    fn default() -> Self {
        Self {
            quality_threshold: 0.5,
            max_escalations: 2,
            classifier_mode: ClassifierMode::Heuristic,
            window_size: 50,
            max_cascade_tokens: None,
            summary_provider: None,
            cost_tiers: None,
            judge_timeout_ms: 5_000,
        }
    }
}
281
/// Provider that routes requests across a set of underlying providers.
///
/// Optional features are enabled through the `with_*` builder methods. Most mutable
/// state sits behind `Arc`, so clones of a router share it.
#[derive(Debug, Clone)]
pub struct RouterProvider {
    /// Shared routing state (provider list, provider order, last active provider,
    /// embedding counters, ...).
    pub(crate) state: RouterState,
    /// Status channel installed by `set_status_tx`.
    status_tx: Option<StatusTx>,
    /// EMA tracker, present after `with_ema`.
    ema: Option<EmaTracker>,
    /// Active routing strategy.
    strategy: RouterStrategy,
    /// Thompson state; set by `with_thompson`, and by `with_bandit` as a fallback
    /// selector.
    thompson: Option<Arc<Mutex<ThompsonState>>>,
    /// File the Thompson state is saved to by `save_thompson_state`.
    thompson_state_path: Option<std::path::PathBuf>,
    /// Cascade state, present after `with_cascade`.
    cascade_state: Option<Arc<Mutex<CascadeState>>>,
    /// Cascade configuration, present after `with_cascade`.
    cascade_config: Option<CascadeRouterConfig>,
    /// Reputation tracker, present after `with_reputation`.
    reputation: Option<Arc<Mutex<ReputationTracker>>>,
    /// File the reputation state is saved to by `save_reputation_state`.
    reputation_state_path: Option<std::path::PathBuf>,
    /// Influence of reputation on ordering; `with_reputation` clamps it to `[0, 1]`.
    reputation_weight: f64,
    /// Bandit state, present after `with_bandit`.
    bandit: Option<Arc<Mutex<BanditState>>>,
    /// File the bandit state is saved to by `save_bandit_state`.
    bandit_state_path: Option<std::path::PathBuf>,
    /// Validated bandit configuration, present after `with_bandit`.
    bandit_config: Option<BanditRouterConfig>,
    /// Provider used to embed queries for the bandit; without it no features are
    /// computed and selection falls back.
    bandit_embedding_provider: Option<Arc<dyn crate::provider_dyn::LlmProviderDyn>>,
    /// Cache of bandit feature vectors keyed by query hash.
    bandit_embed_cache: Arc<Mutex<BanditEmbedCache>>,
    /// ASI coherence state, present after `with_asi`.
    asi: Option<Arc<Mutex<AsiState>>>,
    /// ASI configuration, present after `with_asi`.
    asi_config: Option<AsiRouterConfig>,
    /// Threshold stored by `with_quality_gate`. NOTE(review): not read in the visible
    /// part of this file.
    quality_gate: Option<f32>,
    /// CoE router, present after a successful `with_coe`.
    coe: Option<Arc<CoeRouter>>,
    /// Timeout in milliseconds for the embed call made by ASI updates; `0` disables
    /// the timeout.
    embed_timeout_ms: u64,
    /// Background ASI update tasks, capped at `MAX_ASI_TASKS`.
    asi_tasks: Arc<Mutex<tokio::task::JoinSet<()>>>,
}
345
346impl RouterProvider {
347 #[must_use]
353 pub fn new(providers: Vec<AnyProvider>) -> Self {
354 let state = RouterState::new(Arc::from(providers));
355 Self {
356 state,
357 status_tx: None,
358 ema: None,
359 strategy: RouterStrategy::Ema,
360 thompson: None,
361 thompson_state_path: None,
362 cascade_state: None,
363 cascade_config: None,
364 reputation: None,
365 reputation_state_path: None,
366 reputation_weight: 0.3,
367 bandit: None,
368 bandit_state_path: None,
369 bandit_config: None,
370 bandit_embedding_provider: None,
371 bandit_embed_cache: Arc::new(Mutex::new(BanditEmbedCache::default())),
372 asi: None,
373 asi_config: None,
374 quality_gate: None,
375 coe: None,
376 embed_timeout_ms: 5000,
377 asi_tasks: Arc::new(Mutex::new(tokio::task::JoinSet::new())),
378 }
379 }
380
381 #[must_use]
393 pub fn with_embed_timeout(mut self, timeout_ms: u64) -> Self {
394 self.embed_timeout_ms = timeout_ms;
395 self
396 }
397
398 #[must_use]
402 pub fn with_embed_concurrency(mut self, limit: usize) -> Self {
403 self.state.embed_semaphore = if limit > 0 {
404 Some(Arc::new(tokio::sync::Semaphore::new(limit)))
405 } else {
406 None
407 };
408 self
409 }
410
411 pub fn set_memory_confidence(&self, confidence: Option<f32>) {
416 let raw = confidence.map_or(u32::MAX, f32::to_bits);
417 self.state
418 .last_memory_confidence
419 .store(raw, std::sync::atomic::Ordering::Relaxed);
420 }
421
422 #[must_use]
424 pub fn with_ema(mut self, alpha: f64, reorder_interval: u64) -> Self {
425 self.ema = Some(EmaTracker::new(alpha, reorder_interval));
426 self
427 }
428
429 #[must_use]
436 pub fn with_coe(
437 mut self,
438 config: coe::CoeConfig,
439 secondary: AnyProvider,
440 embed: AnyProvider,
441 ) -> Self {
442 if matches!(
443 self.strategy,
444 RouterStrategy::Cascade | RouterStrategy::Bandit
445 ) {
446 tracing::warn!(
447 strategy = ?self.strategy,
448 "coe disabled for strategy; supported: ema, thompson"
449 );
450 return self;
451 }
452 self.coe = Some(Arc::new(CoeRouter {
453 config,
454 secondary: Arc::new(secondary) as Arc<dyn crate::provider_dyn::LlmProviderDyn>,
455 embed: Arc::new(embed) as Arc<dyn crate::provider_dyn::LlmProviderDyn>,
456 metrics: Arc::new(coe::CoeMetrics::default()),
457 }));
458 self
459 }
460
461 #[must_use]
463 pub fn coe_metrics(&self) -> Option<(u64, u64, u64, u64)> {
464 self.coe.as_ref().map(|c| {
465 (
466 c.metrics.kept_primary.load(Ordering::Relaxed),
467 c.metrics.intra_escalations.load(Ordering::Relaxed),
468 c.metrics.inter_escalations.load(Ordering::Relaxed),
469 c.metrics.embed_failures.load(Ordering::Relaxed),
470 )
471 })
472 }
473
474 #[must_use]
481 pub fn with_asi(mut self, config: AsiRouterConfig) -> Self {
482 self.asi = Some(Arc::new(Mutex::new(AsiState::default())));
483 self.asi_config = Some(config);
484 self
485 }
486
487 #[must_use]
494 pub fn with_quality_gate(mut self, threshold: f32) -> Self {
495 self.quality_gate = Some(threshold);
496 self
497 }
498
499 #[must_use]
504 pub fn with_thompson(mut self, state_path: Option<&Path>) -> Self {
505 self.strategy = RouterStrategy::Thompson;
506 let path = state_path.map_or_else(ThompsonState::default_path, Path::to_path_buf);
507 let mut state = blocking_load(|| ThompsonState::load(&path));
508 let known: std::collections::HashSet<String> = self
510 .state
511 .providers
512 .iter()
513 .map(|p| p.name().to_owned())
514 .collect();
515 state.prune(&known);
516 self.thompson = Some(Arc::new(Mutex::new(state)));
517 self.thompson_state_path = Some(path);
518 self
519 }
520
521 #[must_use]
535 pub fn with_bandit(
536 mut self,
537 mut config: BanditRouterConfig,
538 state_path: Option<&Path>,
539 embedding_provider: Option<AnyProvider>,
540 ) -> Self {
541 self.strategy = RouterStrategy::Bandit;
542 let n = self.state.providers.len();
543 if config.warmup_queries == 0 {
544 config.warmup_queries = u64::try_from(10 * n.max(1)).unwrap_or(100);
545 }
546 let cache_size = config.cache_size;
547 let path = state_path.map_or_else(BanditState::default_path, Path::to_path_buf);
548 let mut state = blocking_load(|| BanditState::load(&path));
549 if state.dim == 0 {
550 state = BanditState::new(config.dim);
551 } else if state.dim != config.dim {
552 tracing::warn!(
554 old_dim = state.dim,
555 new_dim = config.dim,
556 "bandit: dim changed, resetting state"
557 );
558 state = BanditState::new(config.dim);
559 }
560 if config.alpha <= 0.0 {
562 tracing::warn!(alpha = config.alpha, "bandit: alpha <= 0, clamping to 0.01");
563 config.alpha = 0.01;
564 }
565 if config.dim == 0 || config.dim > 256 {
566 tracing::warn!(
567 dim = config.dim,
568 "bandit: dim out of range [1, 256], clamping to 32"
569 );
570 config.dim = 32;
571 }
572 if config.decay_factor <= 0.0 || config.decay_factor > 1.0 {
573 tracing::warn!(
574 decay_factor = config.decay_factor,
575 "bandit: decay_factor out of (0.0, 1.0], clamping to 1.0"
576 );
577 config.decay_factor = 1.0;
578 }
579 if config.decay_factor < 1.0 {
580 state.apply_decay(config.decay_factor);
581 }
582 let known: std::collections::HashSet<String> = self
583 .state
584 .providers
585 .iter()
586 .map(|p| p.name().to_owned())
587 .collect();
588 state.prune(&known);
589 self.bandit = Some(Arc::new(Mutex::new(state)));
590 self.bandit_state_path = Some(path);
591 self.bandit_embed_cache = Arc::new(Mutex::new(BanditEmbedCache::new(cache_size)));
592 self.bandit_embedding_provider =
593 embedding_provider.map(|p| Arc::new(p) as Arc<dyn crate::provider_dyn::LlmProviderDyn>);
594 self.thompson = Some(Arc::new(Mutex::new(ThompsonState::default())));
597 self.bandit_config = Some(config);
598 self
599 }
600
601 pub async fn save_bandit_state(&self) {
605 let (Some(bandit), Some(path)) = (&self.bandit, &self.bandit_state_path) else {
606 return;
607 };
608 let bandit = Arc::clone(bandit);
609 let path = path.clone();
610 tokio::task::spawn_blocking(move || {
611 let state = bandit.lock();
612 if let Err(e) = state.save(&path) {
613 tracing::warn!(error = %e, "failed to save bandit state");
614 }
615 })
616 .await
617 .unwrap_or_else(|e| tracing::warn!(error = %e, "bandit state save task panicked"));
618 }
619
620 #[must_use]
624 pub fn bandit_stats(&self) -> Vec<(String, u64, f32)> {
625 let Some(ref bandit) = self.bandit else {
626 return vec![];
627 };
628 let state = bandit.lock();
629 state.stats()
630 }
631
632 #[must_use]
640 pub fn with_reputation(
641 mut self,
642 decay_factor: f64,
643 weight: f64,
644 min_observations: u64,
645 state_path: Option<&Path>,
646 ) -> Self {
647 let path = state_path.map_or_else(ReputationTracker::default_path, Path::to_path_buf);
648 let mut tracker = blocking_load(|| ReputationTracker::load(&path));
650 let known: std::collections::HashSet<String> = self
651 .state
652 .providers
653 .iter()
654 .map(|p| p.name().to_owned())
655 .collect();
656 tracker.apply_decay();
657 tracker.prune(&known);
658 let tracker = {
660 let stats = tracker.stats();
661 let mut t = ReputationTracker::new(decay_factor, min_observations);
662 for (name, alpha, beta, _, obs) in stats {
663 t.models.insert(
664 name,
665 reputation::ReputationEntry {
666 dist: thompson::BetaDist { alpha, beta },
667 observations: obs,
668 },
669 );
670 }
671 t
672 };
673 self.reputation = Some(Arc::new(Mutex::new(tracker)));
674 self.reputation_state_path = Some(path);
675 self.reputation_weight = weight.clamp(0.0, 1.0);
676 self
677 }
678
679 pub fn record_quality_outcome(&self, _provider_name: &str, success: bool) {
689 if matches!(
690 self.strategy,
691 RouterStrategy::Cascade | RouterStrategy::Bandit
692 ) {
693 return;
696 }
697 let Some(ref reputation) = self.reputation else {
698 return;
699 };
700 let active = self.state.last_active_provider.lock().clone();
701 let Some(provider_name) = active else {
702 return;
703 };
704 let mut tracker = reputation.lock();
705 tracker.record_quality(&provider_name, success);
706 }
707
708 #[must_use]
714 pub fn last_selected_provider_kind(&self) -> &'static str {
715 let name = self.state.last_active_provider.lock().clone();
716 let Some(name) = name else {
717 return "local";
718 };
719 self.state
720 .providers
721 .iter()
722 .find(|p| p.name() == name)
723 .map_or("local", |p| p.provider_kind_str())
724 }
725
726 pub async fn save_reputation_state(&self) {
729 let (Some(reputation), Some(path)) = (&self.reputation, &self.reputation_state_path) else {
730 return;
731 };
732 let reputation = Arc::clone(reputation);
733 let path = path.clone();
734 tokio::task::spawn_blocking(move || {
735 let state = reputation.lock();
736 if let Err(e) = state.save(&path) {
737 tracing::warn!(error = %e, "failed to save reputation state");
738 }
739 })
740 .await
741 .unwrap_or_else(|e| tracing::warn!(error = %e, "reputation state save task panicked"));
742 }
743
744 #[must_use]
746 pub fn reputation_stats(&self) -> Vec<(String, f64, f64, f64, u64)> {
747 let Some(ref reputation) = self.reputation else {
748 return vec![];
749 };
750 let tracker = reputation.lock();
751 tracker.stats()
752 }
753
754 #[must_use]
768 pub fn with_cascade(mut self, config: CascadeRouterConfig) -> Self {
769 self.strategy = RouterStrategy::Cascade;
770
771 if let Some(ref tiers) = config.cost_tiers
772 && !tiers.is_empty()
773 {
774 let provider_names: std::collections::HashSet<&str> =
775 self.state.providers.iter().map(AnyProvider::name).collect();
776 for name in tiers {
777 if !provider_names.contains(name.as_str()) {
778 tracing::warn!(
779 name = %name,
780 "cascade: cost_tiers entry does not match any provider name"
781 );
782 }
783 }
784
785 let tier_pos: std::collections::HashMap<&str, usize> = tiers
786 .iter()
787 .enumerate()
788 .map(|(i, n)| (n.as_str(), i))
789 .collect();
790
791 let before: Vec<_> = self
792 .state
793 .providers
794 .iter()
795 .map(|p| p.name().to_owned())
796 .collect();
797 let mut indexed: Vec<(usize, AnyProvider)> =
798 self.state.providers.iter().cloned().enumerate().collect();
799 indexed.sort_by_key(|(orig_idx, p)| {
800 tier_pos
801 .get(p.name())
802 .copied()
803 .map_or((1usize, *orig_idx), |t| (0, t))
804 });
805 let after: Vec<_> = indexed.iter().map(|(_, p)| p.name().to_owned()).collect();
806 if before != after {
807 tracing::debug!(
808 before = ?before,
809 after = ?after,
810 "cascade: providers reordered by cost_tiers"
811 );
812 }
813 self.state.providers =
814 Arc::from(indexed.into_iter().map(|(_, p)| p).collect::<Vec<_>>());
815 }
816
817 let window = config.window_size;
818 self.cascade_state = Some(Arc::new(Mutex::new(CascadeState::new(window))));
819 self.cascade_config = Some(config);
820 self
821 }
822
823 pub async fn save_thompson_state(&self) {
830 let (Some(thompson), Some(path)) = (&self.thompson, &self.thompson_state_path) else {
831 return;
832 };
833 let thompson = Arc::clone(thompson);
834 let path = path.clone();
835 tokio::task::spawn_blocking(move || {
836 let state = thompson.lock();
837 if let Err(e) = state.save(&path) {
838 tracing::warn!(error = %e, "failed to save Thompson router state");
839 }
840 })
841 .await
842 .unwrap_or_else(|e| tracing::warn!(error = %e, "Thompson state save task panicked"));
843 }
844
845 fn query_hash(query: &str) -> u64 {
847 use std::hash::{Hash as _, Hasher as _};
848 let mut h = std::collections::hash_map::DefaultHasher::new();
849 query.hash(&mut h);
850 h.finish()
851 }
852
853 async fn bandit_features(&self, query: &str) -> Option<Vec<f32>> {
860 let cfg = self.bandit_config.as_ref()?;
861 let key = Self::query_hash(query);
862
863 {
865 let cache = self.bandit_embed_cache.lock();
866 if let Some(cached) = cache.get(key) {
867 return Some(cached.clone());
868 }
869 }
870
871 let provider = self.bandit_embedding_provider.as_ref()?;
872 let timeout = std::time::Duration::from_millis(cfg.embedding_timeout_ms);
873 let embed_future = provider.embed(query);
874 let embedding = match tokio::time::timeout(timeout, embed_future).await {
875 Ok(Ok(emb)) => emb,
876 Ok(Err(e)) => {
877 tracing::debug!(error = %e, "bandit: embedding failed, falling back");
878 return None;
879 }
880 Err(_) => {
881 tracing::debug!(
882 timeout_ms = cfg.embedding_timeout_ms,
883 "bandit: embedding timed out, falling back"
884 );
885 return None;
886 }
887 };
888
889 let features = embedding_to_features(&embedding, cfg.dim)?;
890
891 {
893 let mut cache = self.bandit_embed_cache.lock();
894 cache.insert(key, features.clone());
895 }
896 Some(features)
897 }
898
899 async fn bandit_select_provider(&self, query: &str) -> Option<AnyProvider> {
905 let Some(ref bandit_arc) = self.bandit else {
906 return self.state.providers.first().cloned();
907 };
908 let cfg = self.bandit_config.as_ref()?;
909
910 let names: Vec<String> = self
911 .state
912 .providers
913 .iter()
914 .map(|p| p.name().to_owned())
915 .collect();
916
917 if let Some(features) = self.bandit_features(query).await {
919 let raw = self
920 .state
921 .last_memory_confidence
922 .load(std::sync::atomic::Ordering::Relaxed);
923 let memory_confidence = if raw == u32::MAX {
924 None
925 } else {
926 Some(f32::from_bits(raw))
927 };
928 let selected = {
929 let state = bandit_arc.lock();
930 state.select(
931 &names,
932 &features,
933 cfg.alpha,
934 cfg.warmup_queries,
935 &|_| true,
936 cfg.cost_weight,
937 &self.state.provider_models,
938 memory_confidence,
939 cfg.memory_confidence_threshold,
940 )
941 };
942 if let Some(name) = selected {
943 tracing::debug!(
944 provider = %name,
945 strategy = "bandit",
946 memory_confidence = ?memory_confidence,
947 "selected provider"
948 );
949 return self
950 .state
951 .providers
952 .iter()
953 .find(|p| p.name() == name)
954 .cloned();
955 }
956 }
957
958 if let Some(ref thompson) = self.thompson {
960 let mut state = thompson.lock();
961 if let Some(sel) = state.select(&names) {
962 tracing::debug!(
963 provider = %sel.provider,
964 strategy = "bandit-fallback-thompson",
965 "selected provider"
966 );
967 return self
968 .state
969 .providers
970 .iter()
971 .find(|p| p.name() == sel.provider)
972 .cloned();
973 }
974 }
975
976 self.state.providers.first().cloned()
978 }
979
980 fn bandit_record_reward(
985 &self,
986 provider_name: &str,
987 features: &[f32],
988 quality_score: f64,
989 cost_fraction: f64,
990 ) {
991 let Some(ref bandit_arc) = self.bandit else {
992 return;
993 };
994 let Some(cfg) = &self.bandit_config else {
995 return;
996 };
997 #[allow(clippy::cast_possible_truncation)]
998 let reward = (quality_score as f32) - cfg.cost_weight * (cost_fraction as f32);
999 let reward = reward.clamp(-1.0, 1.0);
1000 let mut state = bandit_arc.lock();
1001 state.update(provider_name, features, reward);
1002 tracing::debug!(
1003 provider = provider_name,
1004 reward,
1005 quality = quality_score,
1006 "bandit: recorded reward"
1007 );
1008 }
1009
1010 fn ordered_providers(&self) -> Vec<AnyProvider> {
1011 match self.strategy {
1012 RouterStrategy::Thompson => self.thompson_ordered_providers(),
1013 RouterStrategy::Ema => self.ema_ordered_providers(),
1014 RouterStrategy::Cascade | RouterStrategy::Bandit => self.state.providers.to_vec(),
1018 }
1019 }
1020
    /// Order providers for the EMA strategy.
    ///
    /// Starts from the stored `provider_order`, then applies up to two re-sorts:
    ///
    /// 1. Reputation pass (only when both reputation and EMA tracking are enabled):
    ///    each provider's EMA score (`success_ema - latency_ema_ms / 10_000`) is scaled
    ///    by `1 + w * (rep_factor - 0.5) * 2` when a reputation factor is available.
    /// 2. ASI pass (only when ASI is enabled): the raw EMA score is multiplied by
    ///    `clamp(coherence / threshold, 0.5, 1.0)`. Low coherence also triggers a
    ///    warning, limited to one per 60 seconds process-wide.
    ///
    /// Both sorts are descending and stable, so ties keep their previous order.
    ///
    /// NOTE(review): the ASI pass starts again from the raw EMA score, so when both
    /// passes run the reputation adjustment only survives as the tie-break order —
    /// confirm this is intended.
    fn ema_ordered_providers(&self) -> Vec<AnyProvider> {
        // This guard stays alive until the function returns.
        let order = self.state.provider_order.lock();
        let mut ordered: Vec<AnyProvider> = order
            .iter()
            // Indices that no longer point at a provider are silently dropped.
            .filter_map(|&i| self.state.providers.get(i).cloned())
            .collect();

        if let Some(ref reputation) = self.reputation
            && let Some(ref ema) = self.ema
        {
            let rep = reputation.lock();
            let w = self.reputation_weight;
            let snap = ema.snapshot();
            let mut scored: Vec<(usize, f64)> = ordered
                .iter()
                .enumerate()
                .map(|(idx, p)| {
                    // Providers without EMA statistics score 0.0.
                    let ema_score = snap
                        .get(p.name())
                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
                    let score = if let Some(rep_factor) = rep.ema_reputation_factor(p.name()) {
                        // A factor of 0.5 is neutral; the scale ranges over [1 - w, 1 + w]
                        // if the factor lies in [0, 1] (assumed — confirm in `reputation`).
                        let adjustment = 1.0 + w * (rep_factor - 0.5) * 2.0;
                        ema_score * adjustment
                    } else {
                        ema_score
                    };
                    (idx, score)
                })
                .collect();
            // Descending by score; incomparable values (NaN) are treated as equal.
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let reordered: Vec<AnyProvider> = scored
                .into_iter()
                .filter_map(|(idx, _)| ordered.get(idx).cloned())
                .collect();
            ordered = reordered;
        }

        if let (Some(asi_arc), Some(asi_cfg)) = (&self.asi, &self.asi_config) {
            let asi: parking_lot::MutexGuard<'_, AsiState> = asi_arc.lock();
            // `None` when EMA tracking is off; every base score is then 0.0.
            let snap = self.ema.as_ref().map(EmaTracker::snapshot);
            let mut scored: Vec<(usize, f64)> = ordered
                .iter()
                .enumerate()
                .map(|(idx, p)| {
                    let coherence = asi.coherence(p.name());
                    if coherence < asi_cfg.coherence_threshold {
                        let now = std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or(std::time::Duration::MAX)
                            .as_secs();
                        let last = ASI_WARN_LAST_SECS.load(Ordering::Relaxed);
                        // Only the caller that wins the compare-exchange emits the
                        // warning; everyone else logs at trace level.
                        if now.saturating_sub(last) >= 60
                            && ASI_WARN_LAST_SECS
                                .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
                                .is_ok()
                        {
                            tracing::warn!(
                                provider = p.name(),
                                coherence,
                                threshold = asi_cfg.coherence_threshold,
                                "asi: coherence below threshold"
                            );
                        } else {
                            tracing::trace!(
                                provider = p.name(),
                                coherence,
                                threshold = asi_cfg.coherence_threshold,
                                "asi: coherence below threshold (warn rate-limited)"
                            );
                        }
                    }
                    let base_score = snap
                        .as_ref()
                        .and_then(|s| s.get(p.name()))
                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
                    // At or above the threshold the multiplier is 1.0; the penalty
                    // bottoms out at 0.5.
                    let multiplier = (coherence / asi_cfg.coherence_threshold).clamp(0.5, 1.0);
                    #[allow(clippy::cast_possible_truncation)]
                    let adjusted = base_score * f64::from(multiplier);
                    (idx, adjusted)
                })
                .collect();
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let reordered: Vec<AnyProvider> = scored
                .into_iter()
                .filter_map(|(idx, _)| ordered.get(idx).cloned())
                .collect();
            ordered = reordered;
        }

        if let Some(first) = ordered.first() {
            tracing::debug!(
                provider = %first.name(),
                strategy = "ema",
                "selected provider"
            );
        }
        ordered
    }
1128
    /// Order providers for the Thompson strategy.
    ///
    /// One provider is selected by the Thompson state and swapped into the first
    /// position; all other providers keep their configured positions (the provider that
    /// was first takes the selected provider's old slot).
    ///
    /// When reputation and/or ASI are enabled, selection uses per-provider prior
    /// overrides: reputation shifts `(alpha, beta)` via `shift_thompson_priors`, and a
    /// provider whose coherence is below the ASI threshold gets
    /// `penalty_weight * (threshold - coherence)` added to `beta`. Low coherence also
    /// triggers a warning, limited to one per 60 seconds process-wide.
    ///
    /// Without Thompson state the configured provider order is returned unchanged.
    fn thompson_ordered_providers(&self) -> Vec<AnyProvider> {
        let Some(ref thompson) = self.thompson else {
            return self.state.providers.to_vec();
        };
        // The Thompson guard is held for the rest of the function.
        let mut state = thompson.lock();
        let names: Vec<String> = self
            .state
            .providers
            .iter()
            .map(|p| p.name().to_owned())
            .collect();

        let has_reputation = self.reputation.is_some();
        let has_asi = self.asi.is_some() && self.asi_config.is_some();

        let selected = if has_reputation || has_asi {
            // Locks are taken in the order: thompson, reputation, asi.
            let rep_guard = self.reputation.as_ref().map(|r| r.lock());
            let asi_guard: Option<parking_lot::MutexGuard<'_, AsiState>> =
                self.asi.as_ref().map(|a| a.lock());
            let w = self.reputation_weight;

            let overrides: std::collections::HashMap<String, (f64, f64)> = names
                .iter()
                .map(|name| {
                    let base = state.get_distribution(name);
                    let (alpha, mut beta) = if let Some(ref rep) = rep_guard {
                        rep.shift_thompson_priors(name, base.alpha, base.beta, w)
                    } else {
                        (base.alpha, base.beta)
                    };
                    if let (Some(asi), Some(asi_cfg)) = (&asi_guard, &self.asi_config) {
                        let coherence = asi.coherence(name);
                        if coherence < asi_cfg.coherence_threshold {
                            let now = std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or(std::time::Duration::MAX)
                                .as_secs();
                            let last = ASI_WARN_LAST_SECS.load(Ordering::Relaxed);
                            // Only the caller that wins the compare-exchange emits the
                            // warning; everyone else logs at trace level.
                            if now.saturating_sub(last) >= 60
                                && ASI_WARN_LAST_SECS
                                    .compare_exchange(
                                        last,
                                        now,
                                        Ordering::Relaxed,
                                        Ordering::Relaxed,
                                    )
                                    .is_ok()
                            {
                                tracing::warn!(
                                    provider = name.as_str(),
                                    coherence,
                                    threshold = asi_cfg.coherence_threshold,
                                    "asi: coherence below threshold"
                                );
                            } else {
                                tracing::trace!(
                                    provider = name.as_str(),
                                    coherence,
                                    threshold = asi_cfg.coherence_threshold,
                                    "asi: coherence below threshold (warn rate-limited)"
                                );
                            }
                            // A larger beta makes the provider look less successful.
                            let deficit = asi_cfg.coherence_threshold - coherence;
                            let penalty = f64::from(asi_cfg.penalty_weight * deficit);
                            beta += penalty;
                        }
                    }
                    (name.clone(), (alpha, beta))
                })
                .collect();

            // Release the auxiliary locks before sampling.
            drop(rep_guard);
            drop(asi_guard);
            state.select_with_priors(&names, &overrides)
        } else {
            state.select(&names)
        };

        if let Some(ref sel) = selected {
            tracing::debug!(
                provider = %sel.provider,
                strategy = "thompson",
                mode = if sel.exploit { "exploit" } else { "explore" },
                alpha = sel.alpha,
                beta = sel.beta,
                "selected provider"
            );
        }
        let mut ordered = self.state.providers.to_vec();
        if let Some(ref sel) = selected
            && let Some(pos) = ordered.iter().position(|p| p.name() == sel.provider)
        {
            // Swap rather than rotate: only two entries change place.
            ordered.swap(0, pos);
        }
        ordered
    }
1231
1232 fn record_availability(&self, provider_name: &str, success: bool, latency_ms: u64) {
1238 match self.strategy {
1239 RouterStrategy::Thompson => {
1240 if let Some(ref thompson) = self.thompson {
1241 let mut state = thompson.lock();
1242 state.update(provider_name, success);
1243 }
1244 }
1245 RouterStrategy::Ema => {
1246 self.ema_record(provider_name, success, latency_ms);
1247 }
1248 RouterStrategy::Cascade | RouterStrategy::Bandit => {
1249 }
1252 }
1253 }
1254
1255 fn ema_record(&self, provider_name: &str, success: bool, latency_ms: u64) {
1256 let Some(ref ema) = self.ema else {
1257 return;
1258 };
1259 ema.record(provider_name, success, latency_ms);
1260 let current_names: Vec<String> = self
1261 .state
1262 .providers
1263 .iter()
1264 .map(|p| p.name().to_owned())
1265 .collect();
1266 if let Some(new_order_names) = ema.maybe_reorder(¤t_names) {
1267 let name_to_idx: std::collections::HashMap<&str, usize> = self
1268 .state
1269 .providers
1270 .iter()
1271 .enumerate()
1272 .map(|(i, p)| (p.name(), i))
1273 .collect();
1274 let new_order: Vec<usize> = new_order_names
1275 .iter()
1276 .filter_map(|n| name_to_idx.get(n.as_str()).copied())
1277 .collect();
1278 let mut order = self.state.provider_order.lock();
1279 *order = new_order;
1280 }
1281 }
1282
1283 #[must_use]
1287 pub fn thompson_stats(&self) -> Vec<(String, f64, f64)> {
1288 let Some(ref thompson) = self.thompson else {
1289 return vec![];
1290 };
1291 let state = thompson.lock();
1292 state.provider_stats()
1293 }
1294
1295 pub fn set_status_tx(&mut self, tx: StatusTx) {
1296 if let Some(providers) = Arc::get_mut(&mut self.state.providers) {
1297 for p in providers {
1298 p.set_status_tx(tx.clone());
1299 }
1300 } else {
1301 let mut v: Vec<_> = self.state.providers.iter().cloned().collect();
1303 for p in &mut v {
1304 p.set_status_tx(tx.clone());
1305 }
1306 self.state.providers = Arc::from(v);
1307 }
1308 self.status_tx = Some(tx);
1309 }
1310
1311 pub async fn list_models_remote(
1319 &self,
1320 ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
1321 let mut seen = std::collections::HashSet::new();
1322 let mut all = Vec::new();
1323 for p in self.state.providers.iter() {
1324 match p.list_models_remote().await {
1325 Ok(models) => {
1326 for m in models {
1327 if seen.insert(m.id.clone()) {
1328 all.push(m);
1329 }
1330 }
1331 }
1332 Err(e) => {
1333 tracing::warn!(error = %e, "router: list_models_remote sub-provider failed");
1334 }
1335 }
1336 }
1337 Ok(all)
1338 }
1339
1340 fn evaluate_heuristic(response: &str, threshold: f64) -> cascade::QualityVerdict {
1342 let mut verdict = heuristic_score(response);
1343 verdict.should_escalate = verdict.score < threshold;
1344 verdict
1345 }
1346
1347 async fn evaluate_quality(
1352 response: &str,
1353 threshold: f64,
1354 mode: ClassifierMode,
1355 summary_provider: Option<&dyn crate::provider_dyn::LlmProviderDyn>,
1356 judge_timeout_ms: u64,
1357 ) -> cascade::QualityVerdict {
1358 if mode == ClassifierMode::Judge {
1359 if let Some(judge) = summary_provider {
1360 match cascade::judge_score(
1361 judge,
1362 response,
1363 std::time::Duration::from_millis(judge_timeout_ms),
1364 )
1365 .await
1366 {
1367 Some(score) => {
1368 let should_escalate = score < threshold;
1369 tracing::debug!(
1370 score,
1371 threshold,
1372 should_escalate,
1373 "cascade: judge scored response"
1374 );
1375 return cascade::QualityVerdict {
1376 score,
1377 should_escalate,
1378 reason: format!("judge score: {score:.2}"),
1379 };
1380 }
1381 None => {
1382 tracing::warn!("cascade: judge call failed, falling back to heuristic");
1383 }
1384 }
1385 } else {
1386 tracing::warn!(
1387 "cascade: classifier_mode=judge but no summary_provider configured, \
1388 using heuristic"
1389 );
1390 }
1391 }
1392 Self::evaluate_heuristic(response, threshold)
1393 }
1394}
1395
// NOTE(review): neither constant is referenced in the visible part of this file;
// presumably they drive a retry loop in the router's `embed` implementation — confirm.

/// Retry count for embedding calls.
const EMBED_MAX_RETRIES: u32 = 3;
/// Base delay between embedding retries, in milliseconds.
const EMBED_BASE_DELAY_MS: u64 = 500;
1398
1399impl RouterProvider {
1400 async fn embed_cached(
1406 &self,
1407 text: &str,
1408 cache: &Mutex<TurnEmbedCache>,
1409 ) -> Result<Vec<f32>, crate::error::LlmError> {
1410 self.state.embed_call_count.fetch_add(1, Ordering::Relaxed);
1411 if let Some(emb) = cache.lock().get(text) {
1412 self.state.embed_cache_hits.fetch_add(1, Ordering::Relaxed);
1413 return Ok(emb.clone());
1414 }
1415 let emb = self.embed(text).await?;
1416 cache.lock().insert(text, emb.clone());
1417 Ok(emb)
1418 }
1419
1420 #[must_use]
1422 pub fn embed_cache_metrics(&self) -> (u64, u64) {
1423 (
1424 self.state.embed_call_count.load(Ordering::Relaxed),
1425 self.state.embed_cache_hits.load(Ordering::Relaxed),
1426 )
1427 }
1428
    /// Schedule a background update of `provider`'s ASI coherence window with the
    /// embedding of `response`.
    ///
    /// * At most one update is scheduled per `turn_id`: a call whose `turn_id` equals
    ///   the previously recorded one returns immediately.
    /// * Nothing is scheduled unless ASI is enabled, or when `MAX_ASI_TASKS` tasks are
    ///   already tracked (finished tasks are reaped first).
    /// * `precomputed_embedding`, when given, is used as is; otherwise the task embeds
    ///   `response` itself, bounded by `embed_timeout_ms` (no bound when it is `0`).
    ///   Embedding timeouts and errors are logged at debug level and the update is
    ///   skipped.
    fn spawn_asi_update(
        &self,
        provider: &str,
        response: String,
        turn_id: u64,
        precomputed_embedding: Option<Vec<f32>>,
    ) {
        // The turn id is recorded before any other check, so the turn counts as handled
        // even when ASI is disabled or the task limit causes the update to be skipped.
        let prev = self.state.asi_last_turn.swap(turn_id, Ordering::AcqRel);
        if prev == turn_id {
            return;
        }

        let Some(ref asi_arc) = self.asi else { return };
        let Some(ref asi_cfg) = self.asi_config else {
            return;
        };

        let mut tasks = self.asi_tasks.lock();
        // Reap tasks that have already finished so they do not count towards the limit.
        while tasks.try_join_next().is_some() {}
        if tasks.len() >= MAX_ASI_TASKS {
            tracing::debug!("asi: task limit reached, skipping coherence update");
            return;
        }

        let asi = Arc::clone(asi_arc);
        // The task needs an owned router to be able to call `embed`.
        let router = self.clone();
        let window_size = asi_cfg.window;
        let provider_name = provider.to_owned();
        let embed_timeout_ms = self.embed_timeout_ms;
        tasks.spawn(async move {
            let emb = if let Some(e) = precomputed_embedding {
                e
            } else {
                let embed_fut = router.embed(&response);
                let embed_result = if embed_timeout_ms > 0 {
                    let timeout = std::time::Duration::from_millis(embed_timeout_ms);
                    if let Ok(r) = tokio::time::timeout(timeout, embed_fut).await {
                        r
                    } else {
                        tracing::debug!(
                            provider = provider_name,
                            timeout_ms = embed_timeout_ms,
                            "asi: embed timed out, skipping coherence update"
                        );
                        return;
                    }
                } else {
                    // A timeout of 0 means "wait as long as it takes".
                    embed_fut.await
                };
                match embed_result {
                    Ok(e) => e,
                    Err(err) => {
                        tracing::debug!(
                            provider = provider_name,
                            error = %err,
                            "asi: embed failed, skipping coherence update"
                        );
                        return;
                    }
                }
            };
            // The ASI lock is only taken once the embedding is available.
            let mut state = asi.lock();
            state.push_embedding(&provider_name, emb, window_size);
        });
    }
1509}
1510
1511fn record_fallback_error(
1516 router: &RouterProvider,
1517 provider_name: &str,
1518 error: &LlmError,
1519 elapsed_ms: u64,
1520 status_tx: Option<&StatusTx>,
1521 log_msg: &'static str,
1522) {
1523 router.record_availability(provider_name, false, elapsed_ms);
1524 if error.is_rate_limited() {
1525 router.record_availability(provider_name, false, 0);
1526 }
1527 if let Some(tx) = status_tx {
1528 let _ = tx.send(format!("router: {provider_name} failed, falling back"));
1529 }
1530 tracing::warn!(provider = provider_name, error = %error, "{}", log_msg);
1531}
1532
1533impl LlmProvider for RouterProvider {
1534 fn context_window(&self) -> Option<usize> {
1535 self.state
1536 .providers
1537 .first()
1538 .and_then(LlmProvider::context_window)
1539 }
1540
/// Routes a chat request and returns the first acceptable reply.
///
/// * `Cascade` and `Bandit` strategies delegate to `cascade_chat` and
///   `bandit_chat` respectively.
/// * Otherwise providers from `ordered_providers()` are tried in order. When
///   `quality_gate` is set and the last message could be embedded, each reply
///   is embedded too and compared to the query with cosine similarity; a
///   reply scoring below the threshold is kept as a fallback candidate and
///   the next provider is tried.
/// * An accepted reply is optionally passed through `run_coe` when a CoE
///   router is configured; if `run_coe` fails the original reply is returned.
///
/// # Errors
///
/// Returns [`LlmError::NoProviders`] when every provider failed and no
/// below-threshold reply was retained.
#[allow(clippy::too_many_lines)]
fn chat(
    &self,
    messages: &[Message],
) -> impl std::future::Future<Output = Result<String, LlmError>> + Send {
    let status_tx = self.status_tx.clone();
    let messages = messages.to_vec();
    let router = self.clone();
    #[cfg(feature = "profiling")]
    let model = self.model_identifier().to_owned();
    let fut = Box::pin(async move {
        // One turn id per call; `spawn_asi_update` uses it to skip repeated
        // updates for the same turn.
        let turn_id = router.state.turn_counter.fetch_add(1, Ordering::Relaxed);

        tracing::info!(
            strategy = ?router.strategy,
            turn_id,
            provider_count = router.state.providers.len(),
            "llm.router.select"
        );

        if router.strategy == RouterStrategy::Cascade {
            return router
                .cascade_chat(&router.state.providers, &messages, status_tx)
                .await;
        }
        if router.strategy == RouterStrategy::Bandit {
            return router.bandit_chat(&messages, status_tx).await;
        }
        let providers = router.ordered_providers();

        // Embedding cache scoped to this single call.
        let turn_cache = Mutex::new(TurnEmbedCache::default());

        // The quality gate compares replies against the last message only.
        let query_text = messages
            .last()
            .map(Message::to_llm_content)
            .unwrap_or_default();
        // An embedding failure here (`.ok()`) simply disables the gate for this call.
        let query_embedding = if router.quality_gate.is_some() && !query_text.is_empty() {
            router.embed_cached(query_text, &turn_cache).await.ok()
        } else {
            None
        };

        // Highest-scoring reply that failed the quality gate, kept as a last resort.
        let mut best_response: Option<(f32, String)> = None;

        for p in &providers {
            let start = std::time::Instant::now();
            match p.chat_with_extras(&messages).await {
                Ok((r, extras)) => {
                    router.record_availability(
                        p.name(),
                        true,
                        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                    );

                    if let (Some(threshold), Some(qemb)) =
                        (router.quality_gate, &query_embedding)
                    {
                        let resp_emb = router.embed_cached(&r, &turn_cache).await.ok();
                        // If the reply cannot be embedded, fall back to `threshold`
                        // itself so the reply passes the gate.
                        let similarity = resp_emb
                            .as_ref()
                            .map_or(threshold, |e| cosine_similarity(qemb, e));
                        if similarity < threshold {
                            tracing::info!(
                                provider = p.name(),
                                score = similarity,
                                threshold,
                                "thompson_quality_fallback"
                            );
                            let is_better = best_response
                                .as_ref()
                                .is_none_or(|(best, _)| similarity > *best);
                            if is_better {
                                best_response = Some((similarity, r.clone()));
                            }
                            router.spawn_asi_update(p.name(), r, turn_id, resp_emb);
                            // Below threshold: try the next provider.
                            continue;
                        }
                        // Reuse the embedding computed for the gate.
                        router.spawn_asi_update(p.name(), r.clone(), turn_id, resp_emb);

                        if let Some(ref coe_router) = router.coe
                            && let Ok((final_r, pname, decision)) = run_coe(
                                coe_router,
                                p.name().to_owned(),
                                r.clone(),
                                extras,
                                &messages,
                            )
                            .await
                        {
                            // An escalation counts against the original provider
                            // and in favour of the CoE secondary.
                            if matches!(
                                decision,
                                CoeDecision::EscalateIntra | CoeDecision::EscalateInter
                            ) {
                                router.record_quality_outcome(&pname, false);
                                router
                                    .record_quality_outcome(coe_router.secondary.name(), true);
                            }
                            return Ok(final_r);
                        }

                        return Ok(r);
                    }

                    // No quality gate (or no query embedding): accept the reply as-is.
                    router.spawn_asi_update(p.name(), r.clone(), turn_id, None);

                    if let Some(ref coe_router) = router.coe
                        && let Ok((final_r, pname, decision)) = run_coe(
                            coe_router,
                            p.name().to_owned(),
                            r.clone(),
                            extras,
                            &messages,
                        )
                        .await
                    {
                        if matches!(
                            decision,
                            CoeDecision::EscalateIntra | CoeDecision::EscalateInter
                        ) {
                            router.record_quality_outcome(&pname, false);
                            router.record_quality_outcome(coe_router.secondary.name(), true);
                        }
                        return Ok(final_r);
                    }

                    return Ok(r);
                }
                Err(e) => {
                    record_fallback_error(
                        &router,
                        p.name(),
                        &e,
                        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        status_tx.as_ref(),
                        "router fallback",
                    );
                }
            }
        }

        // Nobody passed the gate: return the best rejected reply if there is one.
        if let Some((_, response)) = best_response {
            return Ok(response);
        }

        Err(LlmError::NoProviders)
    });
    #[cfg(feature = "profiling")]
    let fut = {
        use tracing::Instrument as _;
        fut.instrument(tracing::info_span!("llm.router.chat", model = model))
    };
    fut
}
1714
1715 fn chat_stream(
1716 &self,
1717 messages: &[Message],
1718 ) -> impl std::future::Future<Output = Result<ChatStream, LlmError>> + Send {
1719 let status_tx = self.status_tx.clone();
1720 let messages = messages.to_vec();
1721 let router = self.clone();
1722 #[cfg(feature = "profiling")]
1723 let model = self.model_identifier().to_owned();
1724 let fut = Box::pin(async move {
1725 if router.strategy == RouterStrategy::Cascade {
1728 return router
1730 .cascade_chat_stream(&router.state.providers, &messages, status_tx)
1731 .await;
1732 }
1733 if router.strategy == RouterStrategy::Bandit {
1734 let query = messages
1738 .last()
1739 .map(super::provider::Message::to_llm_content)
1740 .unwrap_or_default();
1741 let p = router
1742 .bandit_select_provider(query)
1743 .await
1744 .ok_or(LlmError::NoProviders)?;
1745 return p.chat_stream(&messages).await;
1746 }
1747 let providers = router.ordered_providers();
1748 for p in &providers {
1749 let start = std::time::Instant::now();
1750 match p.chat_stream(&messages).await {
1751 Ok(r) => {
1752 router.record_availability(
1759 p.name(),
1760 true,
1761 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1762 );
1763 return Ok(r);
1764 }
1765 Err(e) => {
1766 record_fallback_error(
1767 &router,
1768 p.name(),
1769 &e,
1770 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1771 status_tx.as_ref(),
1772 "router stream fallback",
1773 );
1774 }
1775 }
1776 }
1777 Err(LlmError::NoProviders)
1778 });
1779 #[cfg(feature = "profiling")]
1780 let fut = {
1781 use tracing::Instrument as _;
1782 fut.instrument(tracing::info_span!("llm.router.chat_stream", model = model))
1783 };
1784 fut
1785 }
1786
1787 fn supports_streaming(&self) -> bool {
1788 self.state
1789 .providers
1790 .iter()
1791 .any(LlmProvider::supports_streaming)
1792 }
1793
1794 #[allow(clippy::too_many_lines)] fn embed(
1796 &self,
1797 text: &str,
1798 ) -> impl std::future::Future<Output = Result<Vec<f32>, LlmError>> + Send {
1799 let providers = self.ordered_providers();
1800 let status_tx = self.status_tx.clone();
1801 let text = text.to_owned();
1802 let router = self.clone();
1803 let embed_timeout_ms = self.embed_timeout_ms;
1804 #[cfg(feature = "profiling")]
1805 let model = self.model_identifier().to_owned();
1806 let fut = Box::pin(async move {
1807 for p in &providers {
1808 if !p.supports_embeddings() {
1809 continue;
1810 }
1811 let mut last_err: Option<LlmError> = None;
1812 for attempt in 0..=EMBED_MAX_RETRIES {
1813 if attempt > 0 {
1814 let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
1815 tracing::warn!(
1816 provider = p.name(),
1817 attempt,
1818 delay_ms = delay,
1819 "embed: rate limited, retrying after backoff"
1820 );
1821 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1822 }
1823 let start = std::time::Instant::now();
1824 let embed_result: Result<Vec<f32>, LlmError> = if embed_timeout_ms > 0 {
1826 let timeout = std::time::Duration::from_millis(embed_timeout_ms);
1827 match tokio::time::timeout(timeout, p.embed(&text)).await {
1828 Ok(inner) => inner,
1829 Err(_elapsed) => {
1830 tracing::warn!(
1831 provider = p.name(),
1832 timeout_ms = embed_timeout_ms,
1833 "embed: provider timed out, falling back"
1834 );
1835 last_err = Some(LlmError::Timeout);
1836 break;
1837 }
1838 }
1839 } else {
1840 p.embed(&text).await
1841 };
1842 match embed_result {
1843 Ok(r) => {
1844 router.record_availability(
1845 p.name(),
1846 true,
1847 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1848 );
1849 return Ok(r);
1850 }
1851 Err(e) if e.is_invalid_input() => {
1852 tracing::warn!(
1855 provider = p.name(),
1856 error = %e,
1857 "embed: invalid input, not retrying on other providers"
1858 );
1859 return Err(e);
1860 }
1861 Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
1862 last_err = Some(e);
1863 }
1864 Err(e) => {
1865 router.record_availability(
1866 p.name(),
1867 false,
1868 u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1869 );
1870 if let Some(ref tx) = status_tx {
1871 let _ = tx.send(format!(
1872 "router: {} embed failed, falling back",
1873 p.name()
1874 ));
1875 }
1876 tracing::warn!(provider = p.name(), error = %e, "router embed fallback");
1877 last_err = Some(e);
1878 break;
1879 }
1880 }
1881 }
1882 if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
1884 router.record_availability(p.name(), false, 0);
1885 if let Some(ref tx) = status_tx {
1886 let _ = tx.send(format!(
1887 "router: {} embed rate limited, falling back",
1888 p.name()
1889 ));
1890 }
1891 tracing::warn!(
1892 provider = p.name(),
1893 "embed: rate limit retries exhausted, falling back"
1894 );
1895 }
1896 }
1897 Err(LlmError::NoProviders)
1898 });
1899 #[cfg(feature = "profiling")]
1900 let fut = {
1901 use tracing::Instrument as _;
1902 fut.instrument(tracing::info_span!("llm.router.embed", model = model))
1903 };
1904 fut
1905 }
1906
/// Embeds all `texts` in one batch with the first embedding-capable provider
/// that succeeds.
///
/// When the router state carries an `embed_semaphore`, a permit is held for
/// the whole call. Providers are tried in `ordered_providers()` order, with
/// the same rate-limit retry/backoff policy as `embed`
/// (`EMBED_MAX_RETRIES`, `EMBED_BASE_DELAY_MS`). No per-call timeout is
/// applied here.
///
/// # Errors
///
/// Returns the provider's error immediately when it reports invalid input,
/// and [`LlmError::NoProviders`] when the semaphore is closed or every
/// provider was skipped or failed.
fn embed_batch(
    &self,
    texts: &[&str],
) -> impl std::future::Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
    let providers = self.ordered_providers();
    let status_tx = self.status_tx.clone();
    // Own the inputs so the returned future does not borrow from the caller.
    let owned = owned_strs(texts);
    let router = self.clone();
    let semaphore = self.state.embed_semaphore.clone();
    #[cfg(feature = "profiling")]
    let model = self.model_identifier().to_owned();
    let fut = Box::pin(async move {
        // Held until the future completes; a closed semaphore maps to `NoProviders`.
        let _permit = if let Some(ref sem) = semaphore {
            Some(sem.acquire().await.map_err(|_| LlmError::NoProviders)?)
        } else {
            None
        };
        let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
        for p in &providers {
            if !p.supports_embeddings() {
                continue;
            }
            let mut last_err: Option<LlmError> = None;
            for attempt in 0..=EMBED_MAX_RETRIES {
                // `attempt > 0` is only reached after a rate-limit error.
                if attempt > 0 {
                    let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
                    tracing::warn!(
                        provider = p.name(),
                        attempt,
                        delay_ms = delay,
                        "embed_batch: rate limited, retrying after backoff"
                    );
                    tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
                }
                let start = std::time::Instant::now();
                match p.embed_batch(&refs).await {
                    Ok(r) => {
                        router.record_availability(
                            p.name(),
                            true,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        return Ok(r);
                    }
                    // Invalid input is returned straight to the caller.
                    Err(e) if e.is_invalid_input() => {
                        tracing::warn!(
                            provider = p.name(),
                            error = %e,
                            "embed_batch: invalid input, not retrying on other providers"
                        );
                        return Err(e);
                    }
                    // Rate limited with retries left: loop again after backoff.
                    Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
                        last_err = Some(e);
                    }
                    // Any other error (or rate limit on the final attempt):
                    // give up on this provider.
                    Err(e) => {
                        router.record_availability(
                            p.name(),
                            false,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        if let Some(ref tx) = status_tx {
                            let _ = tx.send(format!(
                                "router: {} embed_batch failed, falling back",
                                p.name()
                            ));
                        }
                        tracing::warn!(
                            provider = p.name(),
                            error = %e,
                            "router embed_batch fallback"
                        );
                        last_err = Some(e);
                        break;
                    }
                }
            }
            // Rate-limit exhaustion gets an extra zero-latency failure record.
            if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
                router.record_availability(p.name(), false, 0);
                if let Some(ref tx) = status_tx {
                    let _ = tx.send(format!(
                        "router: {} embed_batch rate limited, falling back",
                        p.name()
                    ));
                }
                tracing::warn!(
                    provider = p.name(),
                    "embed_batch: rate limit retries exhausted, falling back"
                );
            }
        }
        Err(LlmError::NoProviders)
    });
    #[cfg(feature = "profiling")]
    let fut = {
        use tracing::Instrument as _;
        fut.instrument(tracing::info_span!("llm.router.embed_batch", model = model))
    };
    fut
}
2009
2010 fn supports_embeddings(&self) -> bool {
2011 self.state
2012 .providers
2013 .iter()
2014 .any(LlmProvider::supports_embeddings)
2015 }
2016
/// Returns the fixed provider name `"router"`, independent of which
/// underlying providers are configured.
#[allow(clippy::unnecessary_literal_bound)]
fn name(&self) -> &str {
    "router"
}
2021
/// Returns the fixed model identifier `"router"`; the router does not expose
/// the identifier of whichever provider ends up serving a request.
#[allow(clippy::unnecessary_literal_bound)]
fn model_identifier(&self) -> &str {
    "router"
}
2026
2027 fn supports_tool_use(&self) -> bool {
2028 self.state
2029 .providers
2030 .iter()
2031 .any(LlmProvider::supports_tool_use)
2032 }
2033
2034 fn list_models(&self) -> Vec<String> {
2035 self.state
2036 .providers
2037 .iter()
2038 .flat_map(super::provider::LlmProvider::list_models)
2039 .collect()
2040 }
2041
/// Sends a tool-enabled chat request through the router.
///
/// * `Bandit` strategy: a single provider is chosen with
///   `bandit_select_provider` (keyed on the last message). There is no
///   fallback; the provider's result is returned as-is.
/// * Other strategies: providers from `ordered_providers()` that support
///   tool use are tried in order until one succeeds.
///
/// On success the serving provider's name is stored in
/// `last_active_provider`.
///
/// # Errors
///
/// * [`LlmError::NoProviders`] when no provider is available, when the
///   bandit-selected provider lacks tool support, or when every provider
///   failed.
/// * The provider's own error when it reports invalid input (no further
///   providers are tried), or in bandit mode for any provider error.
#[allow(refining_impl_trait_reachable)]
fn chat_with_tools(
    &self,
    messages: &[Message],
    tools: &[ToolDefinition],
) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
    let messages = messages.to_vec();
    #[cfg(feature = "profiling")]
    let tool_count = tools.len();
    let tools = tools.to_vec();
    let status_tx = self.status_tx.clone();
    let router = self.clone();
    #[cfg(feature = "profiling")]
    let model = self.model_identifier().to_owned();
    let fut = Box::pin(async move {
        if router.strategy == RouterStrategy::Bandit {
            let query = messages
                .last()
                .map(super::provider::Message::to_llm_content)
                .unwrap_or_default();
            let p = router
                .bandit_select_provider(query)
                .await
                .ok_or(LlmError::NoProviders)?;
            // The bandit pick is final: no other provider is tried.
            if !p.supports_tool_use() {
                return Err(LlmError::NoProviders);
            }
            let result = p.chat_with_tools(&messages, &tools).await;
            if result.is_ok() {
                *router.state.last_active_provider.lock() = Some(p.name().to_owned());
            }
            return result;
        }

        let providers = router.ordered_providers();
        for p in &providers {
            if !p.supports_tool_use() {
                continue;
            }
            let start = std::time::Instant::now();
            match p.chat_with_tools(&messages, &tools).await {
                Ok(r) => {
                    router.record_availability(
                        p.name(),
                        true,
                        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                    );
                    *router.state.last_active_provider.lock() = Some(p.name().to_owned());
                    return Ok(r);
                }
                Err(e) => {
                    router.record_availability(
                        p.name(),
                        false,
                        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                    );
                    // Invalid input is returned straight to the caller.
                    if e.is_invalid_input() {
                        tracing::warn!(
                            provider = p.name(),
                            error = %e,
                            "chat_with_tools: invalid input, not retrying on other providers"
                        );
                        return Err(e);
                    }
                    // Rate limiting is recorded a second time, with zero latency.
                    if e.is_rate_limited() {
                        router.record_availability(p.name(), false, 0);
                    }
                    if let Some(ref tx) = status_tx {
                        let _ = tx.send(format!(
                            "router: {} tool call failed, falling back",
                            p.name()
                        ));
                    }
                    tracing::warn!(provider = p.name(), error = %e, "router tool fallback");
                }
            }
        }
        Err(LlmError::NoProviders)
    });
    #[cfg(feature = "profiling")]
    let fut = {
        use tracing::Instrument as _;
        fut.instrument(tracing::info_span!(
            "llm.router.chat_with_tools",
            model = model,
            tool_count = tool_count
        ))
    };
    fut
}
2138
2139 fn debug_request_json(
2140 &self,
2141 messages: &[Message],
2142 tools: &[ToolDefinition],
2143 stream: bool,
2144 ) -> serde_json::Value {
2145 let candidate = if tools.is_empty() {
2146 self.ordered_providers().into_iter().next()
2147 } else {
2148 self.ordered_providers()
2149 .into_iter()
2150 .find(super::provider::LlmProvider::supports_tool_use)
2151 };
2152 candidate.map_or_else(
2153 || crate::provider::default_debug_request_json(messages, tools),
2154 |provider| provider.debug_request_json(messages, tools, stream),
2155 )
2156 }
2157
/// Always returns `None`: the router itself reports no cache usage.
fn last_cache_usage(&self) -> Option<(u64, u64)> {
    None
}
2161}
2162
2163impl RouterProvider {
2166 #[cfg_attr(
2168 feature = "profiling",
2169 tracing::instrument(name = "llm.router.bandit_chat", skip_all)
2170 )]
2171 async fn bandit_chat(
2172 &self,
2173 messages: &[Message],
2174 status_tx: Option<StatusTx>,
2175 ) -> Result<String, LlmError> {
2176 let query = messages
2177 .last()
2178 .map(super::provider::Message::to_llm_content)
2179 .unwrap_or_default();
2180 let features = self.bandit_features(query.as_ref()).await;
2181
2182 let p = self
2183 .bandit_select_provider(query.as_ref())
2184 .await
2185 .ok_or(LlmError::NoProviders)?;
2186
2187 if let Some(ref tx) = status_tx {
2188 let _ = tx.send(format!("bandit: routing to {}", p.name()));
2189 }
2190
2191 let result = p.chat(messages).await;
2192 match &result {
2193 Ok(response) => {
2194 let verdict = heuristic_score(response);
2195 let feat_ref: &[f32];
2198 let zero_vec: Vec<f32>;
2199 let dim = self.bandit_config.as_ref().map_or(32, |c| c.dim);
2200 if let Some(ref feat) = features {
2201 feat_ref = feat;
2202 } else {
2203 zero_vec = vec![0.0; dim];
2204 feat_ref = &zero_vec;
2205 tracing::debug!(
2206 provider = p.name(),
2207 "bandit: recording reward with zero features (embed unavailable)"
2208 );
2209 }
2210 self.bandit_record_reward(p.name(), feat_ref, verdict.score, 0.0);
2211 }
2212 Err(e) => {
2213 tracing::warn!(provider = p.name(), error = %e, "bandit: provider failed");
2214 }
2215 }
2216 result
2217 }
2218}
2219
/// Outcome of evaluating one provider response during a cascade, as produced
/// by `cascade_evaluate_response`.
struct CascadeEvalResult {
    /// Quality verdict returned by `RouterProvider::evaluate_quality`.
    verdict: cascade::QualityVerdict,
    /// Running token total for the cascade, including the estimate for this response.
    tokens_used: u32,
    /// `true` when a token budget is configured and `tokens_used` has reached it.
    budget_exhausted: bool,
}
2230
2231async fn cascade_evaluate_response(
2234 provider_name: &str,
2235 response: &str,
2236 cfg: &CascadeRouterConfig,
2237 cascade_state: &Mutex<CascadeState>,
2238 tokens_used_before: u32,
2239 log_prefix: &str,
2240) -> CascadeEvalResult {
2241 let estimated_tokens =
2242 u32::try_from(zeph_common::text::estimate_tokens(response).max(1)).unwrap_or(u32::MAX);
2243 let tokens_used = tokens_used_before.saturating_add(estimated_tokens);
2244
2245 let verdict = RouterProvider::evaluate_quality(
2246 response,
2247 cfg.quality_threshold,
2248 cfg.classifier_mode,
2249 cfg.summary_provider.as_deref(),
2250 cfg.judge_timeout_ms,
2251 )
2252 .await;
2253
2254 {
2255 let mut state = cascade_state.lock();
2256 state.record(provider_name, verdict.score);
2257 }
2258
2259 tracing::debug!(
2260 provider = %provider_name,
2261 score = verdict.score,
2262 threshold = cfg.quality_threshold,
2263 should_escalate = verdict.should_escalate,
2264 reason = %verdict.reason,
2265 "{log_prefix}: quality verdict"
2266 );
2267
2268 let budget_exhausted = cfg
2269 .max_cascade_tokens
2270 .is_some_and(|budget| tokens_used >= budget);
2271
2272 CascadeEvalResult {
2273 verdict,
2274 tokens_used,
2275 budget_exhausted,
2276 }
2277}
2278
2279impl RouterProvider {
/// Runs a cascade over `providers`: each reply is scored and, if the verdict
/// asks for escalation, the next provider is tried.
///
/// A reply is returned as soon as one of these holds: the verdict does not
/// request escalation, the provider is the last one, no escalations remain,
/// or the token budget is exhausted. When escalation was requested but is
/// blocked by the budget or the escalation limit, the best-scoring non-empty
/// reply seen so far is returned instead of the current one.
///
/// # Errors
///
/// Returns [`LlmError::NoProviders`] when the loop ends without having kept
/// any non-empty reply (for example, every provider failed).
///
/// # Panics
///
/// Panics if `cascade_config` or `cascade_state` is not set.
#[cfg_attr(
    feature = "profiling",
    tracing::instrument(name = "llm.router.cascade_chat", skip_all)
)]
#[allow(clippy::too_many_lines)]
async fn cascade_chat(
    &self,
    providers: &[AnyProvider],
    messages: &[Message],
    status_tx: Option<StatusTx>,
) -> Result<String, LlmError> {
    let cfg = self
        .cascade_config
        .as_ref()
        .expect("cascade_config must be set");
    let cascade_state = self
        .cascade_state
        .as_ref()
        .expect("cascade_state must be set");

    let mut escalations_remaining = cfg.max_escalations;
    // Best non-empty reply seen so far, paired with its quality score.
    let mut best: Option<(String, f64)> = None;
    // Estimated tokens consumed by all replies evaluated so far.
    let mut tokens_used: u32 = 0;

    for (idx, p) in providers.iter().enumerate() {
        tracing::debug!(
            provider = %p.name(),
            attempt = idx + 1,
            total = providers.len(),
            classifier_mode = ?cfg.classifier_mode,
            quality_threshold = cfg.quality_threshold,
            "cascade: trying provider"
        );
        let start = std::time::Instant::now();
        match p.chat(messages).await {
            Err(e) => {
                // A failed provider does not consume an escalation; just move on.
                let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                self.record_availability(p.name(), false, latency);
                if let Some(tx) = &status_tx {
                    let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
                }
                tracing::warn!(provider = p.name(), error = %e, "cascade: provider error");
            }
            Ok(response) => {
                // Latency is taken before evaluation so it covers the chat call only.
                let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

                let eval = cascade_evaluate_response(
                    p.name(),
                    &response,
                    cfg,
                    cascade_state,
                    tokens_used,
                    "cascade",
                )
                .await;
                tokens_used = eval.tokens_used;
                let verdict = eval.verdict;
                let budget_exhausted = eval.budget_exhausted;

                // Empty replies never become the best-seen candidate.
                let is_better = !response.is_empty()
                    && best
                        .as_ref()
                        .is_none_or(|(_, best_score)| verdict.score > *best_score);
                if is_better {
                    tracing::debug!(
                        provider = %p.name(),
                        score = verdict.score,
                        "cascade: best_seen updated"
                    );
                    best = Some((response.clone(), verdict.score));
                }

                let is_last = idx == providers.len() - 1;

                // Stop here if the reply is good enough or we cannot go further.
                if !verdict.should_escalate
                    || is_last
                    || escalations_remaining == 0
                    || budget_exhausted
                {
                    self.record_availability(p.name(), true, latency);
                    // Escalation was wanted but is blocked: prefer the best reply seen.
                    if verdict.should_escalate
                        && (budget_exhausted || escalations_remaining == 0)
                    {
                        let best_response = best.take().map_or(response, |(r, _)| r);
                        tracing::info!(
                            tokens_used,
                            budget = cfg.max_cascade_tokens,
                            escalations_remaining,
                            "cascade: escalation blocked, returning best response"
                        );
                        return Ok(best_response);
                    }
                    return Ok(response);
                }

                // The provider answered, so it is available even though we escalate.
                self.record_availability(p.name(), true, latency);
                escalations_remaining -= 1;

                if let Some(tx) = &status_tx {
                    let _ = tx.send(format!(
                        "cascade: {} quality {:.2} < {:.2}, escalating ({} left)",
                        p.name(),
                        verdict.score,
                        cfg.quality_threshold,
                        escalations_remaining
                    ));
                }
                tracing::info!(
                    provider = %p.name(),
                    score = verdict.score,
                    threshold = cfg.quality_threshold,
                    escalations_remaining,
                    "cascade: escalating to next provider"
                );
            }
        }
    }

    // Reached when the loop ends without returning (e.g. the last provider
    // failed, or `providers` is empty).
    if let Some((_, score)) = &best {
        tracing::info!(
            score,
            "cascade: all providers exhausted, returning best-seen response"
        );
    } else {
        tracing::warn!("cascade: all providers failed, no response available");
    }
    best.map(|(r, _)| r).ok_or(LlmError::NoProviders)
}
2419
/// Streaming variant of the cascade.
///
/// Every provider except the last is consumed into a buffer with
/// `collect_stream`, scored, and either replayed as a stream (when accepted
/// or when escalation is blocked) or escalated past. The last provider is
/// streamed directly without classification. If the last provider fails to
/// open a stream, the best buffered reply seen so far is replayed instead.
///
/// # Errors
///
/// Returns [`LlmError::NoProviders`] when `providers` is empty, or the last
/// provider's error when it fails and no buffered reply was kept.
///
/// # Panics
///
/// Panics if `cascade_config` or `cascade_state` is not set.
#[allow(clippy::too_many_lines)]
async fn cascade_chat_stream(
    &self,
    providers: &[AnyProvider],
    messages: &[Message],
    status_tx: Option<StatusTx>,
) -> Result<ChatStream, LlmError> {
    let cfg = self
        .cascade_config
        .as_ref()
        .expect("cascade_config must be set");
    let cascade_state = self
        .cascade_state
        .as_ref()
        .expect("cascade_state must be set");

    let mut escalations_remaining = cfg.max_escalations;
    let mut tokens_used: u32 = 0;
    // Best non-empty buffered reply seen so far, paired with its quality score.
    let mut best_seen: Option<(CollectedStream, f64)> = None;

    // `early` providers are buffered and classified; `last` is streamed as-is.
    let (last, early) = providers.split_last().ok_or(LlmError::NoProviders)?;

    for (idx, p) in early.iter().enumerate() {
        tracing::debug!(
            provider = %p.name(),
            attempt = idx + 1,
            total = providers.len(),
            classifier_mode = ?cfg.classifier_mode,
            quality_threshold = cfg.quality_threshold,
            "cascade stream: trying provider (buffered)"
        );
        let start = std::time::Instant::now();
        let stream = match p.chat_stream(messages).await {
            Err(e) => {
                let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                self.record_availability(p.name(), false, latency);
                tracing::warn!(provider = p.name(), error = %e, "cascade stream: provider error");
                if let Some(tx) = &status_tx {
                    let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
                }
                continue;
            }
            Ok(s) => s,
        };

        // Drain the whole stream; latency therefore covers the full response.
        let buffered = collect_stream(stream).await;
        let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

        match buffered {
            Err(e) => {
                // Mid-stream failure (or oversized response): try the next provider.
                self.record_availability(p.name(), false, latency);
                tracing::warn!(provider = p.name(), error = %e, "cascade stream: stream error");
            }
            Ok(collected) => {
                let eval = cascade_evaluate_response(
                    p.name(),
                    &collected.content,
                    cfg,
                    cascade_state,
                    tokens_used,
                    "cascade stream",
                )
                .await;
                tokens_used = eval.tokens_used;
                let verdict = eval.verdict;
                let budget_exhausted = eval.budget_exhausted;

                // Replies with neither content nor tool calls never become best-seen.
                let is_better = !collected.is_empty()
                    && best_seen
                        .as_ref()
                        .is_none_or(|(_, best_score)| verdict.score > *best_score);
                if is_better {
                    tracing::debug!(
                        provider = %p.name(),
                        score = verdict.score,
                        "cascade stream: best_seen updated"
                    );
                    best_seen = Some((collected.clone(), verdict.score));
                }

                if !verdict.should_escalate || escalations_remaining == 0 || budget_exhausted {
                    self.record_availability(p.name(), true, latency);

                    // Escalation was wanted but is blocked: prefer the best reply seen.
                    let response = if verdict.should_escalate
                        && (budget_exhausted || escalations_remaining == 0)
                    {
                        tracing::info!(
                            tokens_used,
                            budget = cfg.max_cascade_tokens,
                            escalations_remaining,
                            "cascade stream: escalation blocked, returning best response"
                        );
                        best_seen.take().map_or(collected, |(r, _)| r)
                    } else {
                        collected
                    };

                    return Ok(response.into_stream());
                }

                // The provider answered, so it is available even though we escalate.
                self.record_availability(p.name(), true, latency);
                escalations_remaining -= 1;

                if let Some(tx) = &status_tx {
                    let _ = tx.send(format!(
                        "cascade: {} quality {:.2} < {:.2}, escalating",
                        p.name(),
                        verdict.score,
                        cfg.quality_threshold,
                    ));
                }
                tracing::info!(
                    provider = %p.name(),
                    score = verdict.score,
                    threshold = cfg.quality_threshold,
                    escalations_remaining,
                    "cascade stream: escalating to next provider"
                );
            }
        }
    }

    tracing::debug!(
        provider = %last.name(),
        attempt = providers.len(),
        total = providers.len(),
        "cascade stream: trying last provider (streaming, no classification)"
    );
    let start = std::time::Instant::now();
    match last.chat_stream(messages).await {
        Ok(stream) => {
            // Only the time to open the stream is measured for the last provider.
            self.record_availability(
                last.name(),
                true,
                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
            );
            Ok(stream)
        }
        Err(e) => {
            self.record_availability(
                last.name(),
                false,
                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
            );
            // Prefer replaying an earlier buffered reply over surfacing the error.
            if let Some((best_collected, _)) = best_seen {
                tracing::info!(
                    "cascade stream: last provider failed, returning best-seen response"
                );
                return Ok(best_collected.into_stream());
            }
            Err(e)
        }
    }
}
2604}
2605
/// Upper bound, in bytes, on the content text that `collect_stream` will
/// buffer before giving up (1 MiB).
const CASCADE_STREAM_MAX_BYTES: usize = 1024 * 1024;

/// A stream that has been fully read into memory, grouped by chunk kind.
#[derive(Clone, Default, Debug)]
struct CollectedStream {
    /// Concatenated text of all `Content` chunks.
    content: String,
    /// Payloads of `Thinking` chunks, in arrival order.
    thinking: Vec<String>,
    /// Tool-use requests gathered from all `ToolUse` chunks.
    tool_calls: Vec<crate::provider::ToolUseRequest>,
    /// Payload of the most recent `Compaction` chunk, if any was received.
    compaction: Option<String>,
}
2621
2622impl CollectedStream {
2623 fn into_stream(self) -> ChatStream {
2625 use crate::provider::StreamChunk;
2626 let mut chunks: Vec<Result<StreamChunk, LlmError>> = Vec::new();
2627 for t in self.thinking {
2628 chunks.push(Ok(StreamChunk::Thinking(t)));
2629 }
2630 if !self.tool_calls.is_empty() {
2631 chunks.push(Ok(StreamChunk::ToolUse(self.tool_calls)));
2632 }
2633 if let Some(c) = self.compaction {
2634 chunks.push(Ok(StreamChunk::Compaction(c)));
2635 }
2636 if !self.content.is_empty() {
2637 chunks.push(Ok(StreamChunk::Content(self.content)));
2638 }
2639 Box::pin(tokio_stream::iter(chunks))
2640 }
2641
2642 fn is_empty(&self) -> bool {
2643 self.content.is_empty() && self.tool_calls.is_empty()
2644 }
2645}
2646
2647async fn collect_stream(stream: ChatStream) -> Result<CollectedStream, LlmError> {
2651 use tokio_stream::StreamExt as _;
2652
2653 let mut stream = stream;
2654 let mut collected = CollectedStream::default();
2655 while let Some(chunk) = stream.next().await {
2656 match chunk? {
2657 crate::provider::StreamChunk::Content(c) => {
2658 if collected.content.len() + c.len() > CASCADE_STREAM_MAX_BYTES {
2659 return Err(LlmError::Other(
2660 "cascade: stream response exceeds 1 MiB buffer limit".into(),
2661 ));
2662 }
2663 collected.content.push_str(&c);
2664 }
2665 crate::provider::StreamChunk::Thinking(t) => {
2666 collected.thinking.push(t);
2667 }
2668 crate::provider::StreamChunk::ToolUse(tools) => {
2669 collected.tool_calls.extend(tools);
2670 }
2671 crate::provider::StreamChunk::Compaction(c) => {
2672 collected.compaction = Some(c);
2673 }
2674 }
2675 }
2676 Ok(collected)
2677}
2678
2679#[cfg(test)]
2680mod tests {
2681 use super::*;
2682 use crate::provider::Role;
2683
2684 #[test]
2685 fn empty_router_name() {
2686 let r = RouterProvider::new(vec![]);
2687 assert_eq!(r.name(), "router");
2688 }
2689
2690 #[test]
2691 fn empty_router_supports_nothing() {
2692 let r = RouterProvider::new(vec![]);
2693 assert!(!r.supports_streaming());
2694 assert!(!r.supports_embeddings());
2695 assert!(!r.supports_tool_use());
2696 }
2697
2698 #[test]
2699 fn empty_router_context_window_none() {
2700 let r = RouterProvider::new(vec![]);
2701 assert!(r.context_window().is_none());
2702 }
2703
2704 #[tokio::test]
2705 async fn empty_router_chat_returns_no_providers() {
2706 let r = RouterProvider::new(vec![]);
2707 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2708 let err = r.chat(&msgs).await.unwrap_err();
2709 assert!(matches!(err, LlmError::NoProviders));
2710 }
2711
2712 #[tokio::test]
2713 async fn empty_router_chat_stream_returns_no_providers() {
2714 let r = RouterProvider::new(vec![]);
2715 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2716 let result = r.chat_stream(&msgs).await;
2717 assert!(matches!(result, Err(LlmError::NoProviders)));
2718 }
2719
2720 #[tokio::test]
2721 async fn empty_router_embed_returns_no_providers() {
2722 let r = RouterProvider::new(vec![]);
2723 let err = r.embed("test").await.unwrap_err();
2724 assert!(matches!(err, LlmError::NoProviders));
2725 }
2726
2727 #[tokio::test]
2728 async fn empty_router_chat_with_tools_returns_no_providers() {
2729 let r = RouterProvider::new(vec![]);
2730 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2731 let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
2732 assert!(matches!(err, LlmError::NoProviders));
2733 }
2734
2735 #[tokio::test]
2736 async fn router_falls_back_on_unreachable() {
2737 use crate::ollama::OllamaProvider;
2738
2739 let p1 = AnyProvider::Ollama(OllamaProvider::new(
2740 "http://127.0.0.1:1",
2741 "m".into(),
2742 "e".into(),
2743 ));
2744 let p2 = AnyProvider::Ollama(OllamaProvider::new(
2745 "http://127.0.0.1:2",
2746 "m".into(),
2747 "e".into(),
2748 ));
2749 let r = RouterProvider::new(vec![p1, p2]);
2750 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2751 let err = r.chat(&msgs).await.unwrap_err();
2752 assert!(matches!(err, LlmError::NoProviders));
2753 }
2754
2755 #[test]
2756 fn router_with_streaming_provider() {
2757 use crate::ollama::OllamaProvider;
2758
2759 let p = AnyProvider::Ollama(OllamaProvider::new(
2760 "http://127.0.0.1:1",
2761 "m".into(),
2762 "e".into(),
2763 ));
2764 let r = RouterProvider::new(vec![p]);
2765 assert!(r.supports_streaming());
2766 assert!(r.supports_embeddings());
2767 }
2768
2769 #[test]
2770 fn clone_preserves_providers() {
2771 use crate::ollama::OllamaProvider;
2772
2773 let p = AnyProvider::Ollama(OllamaProvider::new(
2774 "http://127.0.0.1:1",
2775 "m".into(),
2776 "e".into(),
2777 ));
2778 let r = RouterProvider::new(vec![p]);
2779 let c = r.clone();
2780 assert_eq!(c.state.providers.len(), 1);
2781 assert_eq!(c.name(), "router");
2782 }
2783
2784 #[test]
2785 fn last_cache_usage_returns_none() {
2786 let r = RouterProvider::new(vec![]);
2787 assert!(r.last_cache_usage().is_none());
2788 }
2789
2790 #[test]
2791 fn thompson_strategy_is_set() {
2792 let r = RouterProvider::new(vec![]).with_thompson(None);
2793 assert_eq!(r.strategy, RouterStrategy::Thompson);
2794 assert!(r.thompson.is_some());
2795 }
2796
2797 #[tokio::test]
2798 async fn save_thompson_state_noop_without_thompson() {
2799 let r = RouterProvider::new(vec![]);
2800 r.save_thompson_state().await; }
2802
2803 #[test]
2804 fn thompson_ordered_providers_empty() {
2805 let r = RouterProvider::new(vec![]).with_thompson(None);
2806 let ordered = r.ordered_providers();
2807 assert!(ordered.is_empty());
2808 }
2809
2810 #[test]
2811 fn concurrent_record_outcome_does_not_deadlock() {
2812 use std::sync::Arc;
2813 let r = Arc::new(RouterProvider::new(vec![]).with_thompson(None));
2814 let handles: Vec<_> = (0..8)
2815 .map(|i| {
2816 let router = Arc::clone(&r);
2817 std::thread::spawn(move || {
2818 router.record_availability(&format!("p{i}"), i % 2 == 0, 10);
2819 })
2820 })
2821 .collect();
2822 for h in handles {
2823 h.join().expect("thread panicked");
2824 }
2825 let stats = r.thompson_stats();
2827 assert_eq!(stats.len(), 8);
2828 }
2829
2830 #[test]
2833 fn cascade_strategy_is_set() {
2834 let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2835 assert_eq!(r.strategy, RouterStrategy::Cascade);
2836 assert!(r.cascade_state.is_some());
2837 assert!(r.cascade_config.is_some());
2838 }
2839
2840 #[test]
2841 fn cascade_ordered_providers_preserves_chain_order() {
2842 use crate::ollama::OllamaProvider;
2843 let p1 = AnyProvider::Ollama(OllamaProvider::new(
2844 "http://127.0.0.1:1",
2845 "a".into(),
2846 String::new(),
2847 ));
2848 let p2 = AnyProvider::Ollama(OllamaProvider::new(
2849 "http://127.0.0.1:2",
2850 "b".into(),
2851 String::new(),
2852 ));
2853 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2854 let ordered = r.ordered_providers();
2855 assert_eq!(ordered.len(), 2);
2856 }
2857
2858 #[tokio::test]
2859 async fn cascade_empty_router_returns_no_providers() {
2860 let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2861 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2862 let err = r.chat(&msgs).await.unwrap_err();
2863 assert!(matches!(err, LlmError::NoProviders));
2864 }
2865
2866 #[tokio::test]
2867 async fn cascade_returns_best_seen_when_all_fail_after_good_response() {
2868 use crate::mock::MockProvider;
2869
2870 let cheap =
2872 AnyProvider::Mock(MockProvider::with_responses(vec!["ok".to_owned()]).with_delay(0));
2873 let expensive = AnyProvider::Mock(MockProvider::failing());
2875
2876 let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2877 quality_threshold: 0.9, max_escalations: 2,
2879 ..CascadeRouterConfig::default()
2880 });
2881 let msgs = vec![Message::from_legacy(Role::User, "hello")];
2882 let result = r.chat(&msgs).await.unwrap();
2884 assert_eq!(result, "ok");
2885 }
2886
2887 #[tokio::test]
2888 async fn cascade_accepts_good_quality_response() {
2889 use crate::mock::MockProvider;
2890
2891 let good_response = "This is a comprehensive, well-structured response that provides \
2892 detailed information about the topic. It covers multiple aspects and explains \
2893 the reasoning clearly with proper sentence structure.";
2894
2895 let cheap = AnyProvider::Mock(
2896 MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2897 );
2898 let expensive = AnyProvider::Mock(MockProvider::failing());
2900
2901 let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2902 quality_threshold: 0.5,
2903 max_escalations: 1,
2904 ..CascadeRouterConfig::default()
2905 });
2906 let msgs = vec![Message::from_legacy(Role::User, "explain something")];
2907 let result = r.chat(&msgs).await.unwrap();
2908 assert_eq!(result, good_response);
2909 }
2910
2911 #[tokio::test]
2912 async fn cascade_max_escalations_budget_exhausted_returns_last_attempted() {
2913 use crate::mock::MockProvider;
2914
2915 let p1 =
2918 AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2919 let p2 =
2920 AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2921 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2924 quality_threshold: 0.9,
2925 max_escalations: 1, ..CascadeRouterConfig::default()
2927 });
2928 let msgs = vec![Message::from_legacy(Role::User, "test")];
2929 let result = r.chat(&msgs).await.unwrap();
2930 assert_eq!(result, "x");
2931 }
2932
2933 #[tokio::test]
2934 async fn cascade_token_budget_stops_escalation() {
2935 use crate::mock::MockProvider;
2936
2937 let p1 =
2938 AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2939 let p2 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2942 quality_threshold: 0.9, max_escalations: 5,
2944 max_cascade_tokens: Some(1), ..CascadeRouterConfig::default()
2946 });
2947 let msgs = vec![Message::from_legacy(Role::User, "test")];
2948 let result = r.chat(&msgs).await.unwrap();
2949 assert_eq!(result, "x"); }
2951
2952 #[tokio::test]
2953 async fn cascade_budget_returns_best_seen_not_current() {
2954 use crate::mock::MockProvider;
2955
2956 let good_response = "This is a reasonable response with enough content to score well.";
2959 let bad_response = "x"; let p1 = AnyProvider::Mock(
2962 MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2963 );
2964 let p2 = AnyProvider::Mock(
2965 MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2966 );
2967
2968 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2969 quality_threshold: 0.95, max_escalations: 5,
2971 max_cascade_tokens: Some(1), ..CascadeRouterConfig::default()
2973 });
2974 let msgs = vec![Message::from_legacy(Role::User, "test")];
2975 let result = r.chat(&msgs).await.unwrap();
2979 assert_ne!(result, bad_response, "should return best-seen, not current");
2981 }
2982
2983 #[tokio::test]
2984 async fn cascade_escalations_exhausted_returns_best_seen_not_current() {
2985 use crate::mock::MockProvider;
2986
2987 let good_response = "This is a reasonable response with enough content to score well.";
2990 let bad_response = "x";
2991
2992 let p1 = AnyProvider::Mock(
2993 MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2994 );
2995 let p2 = AnyProvider::Mock(
2996 MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2997 );
2998 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3001 quality_threshold: 0.95, max_escalations: 1, ..CascadeRouterConfig::default()
3004 });
3005 let msgs = vec![Message::from_legacy(Role::User, "test")];
3006 let result = r.chat(&msgs).await.unwrap();
3007 assert_eq!(
3008 result, good_response,
3009 "should return best-seen (p1), not the degenerate current response (p2)"
3010 );
3011 assert_ne!(
3012 result, bad_response,
3013 "must not return degenerate p2 response"
3014 );
3015 }
3016
3017 #[tokio::test]
3018 async fn cascade_stream_escalations_exhausted_returns_best_seen_not_current() {
3019 use crate::mock::MockProvider;
3020
3021 let good_response = "This is a reasonable response with enough content to score well.";
3025 let bad_response = "x";
3026
3027 let p1 = AnyProvider::Mock(
3028 MockProvider::with_responses(vec![good_response.to_owned()])
3029 .with_delay(0)
3030 .with_streaming(),
3031 );
3032 let p2 = AnyProvider::Mock(
3033 MockProvider::with_responses(vec![bad_response.to_owned()])
3034 .with_delay(0)
3035 .with_streaming(),
3036 );
3037 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3040 quality_threshold: 0.95, max_escalations: 1, ..CascadeRouterConfig::default()
3043 });
3044 let msgs = vec![Message::from_legacy(Role::User, "test")];
3045 let stream = r.chat_stream(&msgs).await.unwrap();
3046 let collected = collect_stream(stream).await.unwrap();
3047 assert_eq!(
3048 collected.content, good_response,
3049 "should return best-seen (p1), not the degenerate current response (p2)"
3050 );
3051 assert_ne!(
3052 collected.content, bad_response,
3053 "must not return degenerate p2 response"
3054 );
3055 }
3056
3057 #[tokio::test]
3058 async fn cascade_all_providers_fail_returns_no_providers() {
3059 use crate::mock::MockProvider;
3060
3061 let p1 = AnyProvider::Mock(MockProvider::failing());
3062 let p2 = AnyProvider::Mock(MockProvider::failing());
3063
3064 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3065 let msgs = vec![Message::from_legacy(Role::User, "test")];
3066 let err = r.chat(&msgs).await.unwrap_err();
3067 assert!(matches!(err, LlmError::NoProviders));
3068 }
3069
3070 #[tokio::test]
3071 async fn cascade_stream_good_quality_no_escalation() {
3072 use crate::mock::MockProvider;
3073
3074 let good = "This is a well-formed response with sufficient length and coherent structure.";
3075 let p1 = AnyProvider::Mock(
3076 MockProvider::with_responses(vec![good.to_owned()])
3077 .with_delay(0)
3078 .with_streaming(),
3079 );
3080 let p2 = AnyProvider::Mock(MockProvider::failing());
3081
3082 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3083 quality_threshold: 0.5,
3084 max_escalations: 1,
3085 ..CascadeRouterConfig::default()
3086 });
3087 let msgs = vec![Message::from_legacy(Role::User, "q")];
3088 let stream = r.chat_stream(&msgs).await.unwrap();
3089 let collected = collect_stream(stream).await.unwrap();
3090 assert_eq!(collected.content, good);
3091 }
3092
3093 #[tokio::test]
3094 async fn cascade_stream_escalates_to_last_provider() {
3095 use crate::mock::MockProvider;
3096
3097 let bad = "x"; let good = "This is the expensive model's comprehensive response.";
3099 let p1 = AnyProvider::Mock(
3100 MockProvider::with_responses(vec![bad.to_owned()])
3101 .with_delay(0)
3102 .with_streaming(),
3103 );
3104 let p2 = AnyProvider::Mock(
3105 MockProvider::with_responses(vec![good.to_owned()])
3106 .with_delay(0)
3107 .with_streaming(),
3108 );
3109
3110 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3111 quality_threshold: 0.9, max_escalations: 1,
3113 ..CascadeRouterConfig::default()
3114 });
3115 let msgs = vec![Message::from_legacy(Role::User, "q")];
3116 let stream = r.chat_stream(&msgs).await.unwrap();
3117 let collected = collect_stream(stream).await.unwrap();
3118 assert_eq!(collected.content, good);
3119 }
3120
3121 #[tokio::test]
3122 async fn cascade_stream_budget_returns_best_seen() {
3123 use crate::mock::MockProvider;
3124
3125 let good_response = "This is a reasonable response with enough content to score well.";
3130 let bad_response = "x"; let p1 = AnyProvider::Mock(
3133 MockProvider::with_responses(vec![good_response.to_owned()])
3134 .with_delay(0)
3135 .with_streaming(),
3136 );
3137 let p2 = AnyProvider::Mock(
3138 MockProvider::with_responses(vec![bad_response.to_owned()])
3139 .with_delay(0)
3140 .with_streaming(),
3141 );
3142 let p3 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3145 quality_threshold: 0.95, max_escalations: 5,
3147 max_cascade_tokens: Some(1), ..CascadeRouterConfig::default()
3149 });
3150 let msgs = vec![Message::from_legacy(Role::User, "test")];
3151 let stream = r.chat_stream(&msgs).await.unwrap();
3152 let collected = collect_stream(stream).await.unwrap();
3153 assert_eq!(
3155 collected.content, good_response,
3156 "should return best-seen p1 response when budget exhausted"
3157 );
3158 }
3159
3160 #[tokio::test]
3161 async fn cascade_stream_budget_returns_best_seen_not_current() {
3162 use crate::mock::MockProvider;
3163
3164 let good_response = "This is a reasonable response with enough content to score well.";
3170 let bad_response = "x"; let p1 = AnyProvider::Mock(
3173 MockProvider::with_responses(vec![good_response.to_owned()])
3174 .with_delay(0)
3175 .with_streaming(),
3176 );
3177 let p2 = AnyProvider::Mock(
3178 MockProvider::with_responses(vec![bad_response.to_owned()])
3179 .with_delay(0)
3180 .with_streaming(),
3181 );
3182 let p3 = AnyProvider::Mock(MockProvider::failing()); let p4 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2, p3, p4]).with_cascade(CascadeRouterConfig {
3188 quality_threshold: 0.95, max_escalations: 5,
3190 max_cascade_tokens: Some(17), ..CascadeRouterConfig::default()
3192 });
3193 let msgs = vec![Message::from_legacy(Role::User, "test")];
3194 let stream = r.chat_stream(&msgs).await.unwrap();
3195 let collected = collect_stream(stream).await.unwrap();
3196 assert_eq!(
3198 collected.content, good_response,
3199 "should return best-seen (p1), not current degenerate (p2)"
3200 );
3201 assert_ne!(
3202 collected.content, bad_response,
3203 "must not return the degenerate p2 response"
3204 );
3205 }
3206
3207 #[tokio::test]
3208 async fn cascade_stream_last_fails_returns_best_seen() {
3209 use crate::mock::MockProvider;
3210
3211 let low_quality = "ok"; let p1 = AnyProvider::Mock(
3217 MockProvider::with_responses(vec![low_quality.to_owned()])
3218 .with_delay(0)
3219 .with_streaming(),
3220 );
3221 let p2 = AnyProvider::Mock(MockProvider::failing()); let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3224 quality_threshold: 0.9, max_escalations: 2,
3226 ..CascadeRouterConfig::default()
3227 });
3228 let msgs = vec![Message::from_legacy(Role::User, "hello")];
3229 let stream = r.chat_stream(&msgs).await.unwrap();
3230 let collected = collect_stream(stream).await.unwrap();
3231 assert_eq!(collected.content, low_quality);
3232 }
3233
3234 #[tokio::test]
3235 async fn cascade_stream_all_fail_returns_error() {
3236 use crate::mock::MockProvider;
3237
3238 let p1 = AnyProvider::Mock(MockProvider::failing());
3242 let p2 = AnyProvider::Mock(MockProvider::failing());
3243
3244 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3245 let msgs = vec![Message::from_legacy(Role::User, "test")];
3246 let result = r.chat_stream(&msgs).await;
3247 assert!(
3248 result.is_err(),
3249 "expected error when all providers fail with no best_seen"
3250 );
3251 }
3252
3253 #[test]
3254 fn cascade_config_default_values() {
3255 let cfg = CascadeRouterConfig::default();
3256 assert!((cfg.quality_threshold - 0.5).abs() < f64::EPSILON);
3257 assert_eq!(cfg.max_escalations, 2);
3258 assert_eq!(cfg.window_size, 50);
3259 assert!(cfg.max_cascade_tokens.is_none());
3260 assert_eq!(cfg.classifier_mode, cascade::ClassifierMode::Heuristic);
3261 }
3262
3263 #[test]
3264 fn evaluate_heuristic_empty_should_escalate_above_threshold() {
3265 let verdict = RouterProvider::evaluate_heuristic("", 0.05);
3266 assert!(verdict.should_escalate);
3268 }
3269
3270 #[test]
3271 fn evaluate_heuristic_good_response_does_not_escalate() {
3272 let text = "The answer to your question is straightforward. Consider the options and pick the best one.";
3273 let verdict = RouterProvider::evaluate_heuristic(text, 0.5);
3274 assert!(!verdict.should_escalate, "score={}", verdict.score);
3275 }
3276
3277 #[tokio::test]
3281 async fn cascade_empty_response_not_stored_as_best_seen() {
3282 use crate::mock::MockProvider;
3283
3284 let p = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
3287 let cfg = CascadeRouterConfig {
3288 quality_threshold: 0.0,
3289 ..Default::default()
3290 };
3291 let r = RouterProvider::new(vec![p]).with_cascade(cfg);
3292 let msgs = vec![Message::from_legacy(Role::User, "hi")];
3293 let result = r.chat(&msgs).await;
3296 assert!(result.is_ok());
3297 assert_eq!(result.unwrap(), "");
3298 }
3299
3300 #[tokio::test]
3303 async fn cascade_empty_best_seen_not_returned_on_all_fail() {
3304 use crate::mock::MockProvider;
3305
3306 let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
3309 let p2 = AnyProvider::Mock(MockProvider::failing());
3310
3311 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3312 let msgs = vec![Message::from_legacy(Role::User, "hi")];
3313 let result = r.chat(&msgs).await;
3314 assert!(
3316 result.is_err(),
3317 "expected error, not silent empty string; got: {result:?}"
3318 );
3319 }
3320
3321 #[tokio::test]
3323 async fn cascade_stream_empty_response_not_stored_as_best_seen() {
3324 use crate::mock::MockProvider;
3325
3326 let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
3329 let p2 = AnyProvider::Mock(
3330 MockProvider::with_responses(vec!["real answer".to_owned()]).with_streaming(),
3331 );
3332
3333 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3334 let msgs = vec![Message::from_legacy(Role::User, "hi")];
3335 let stream = r.chat_stream(&msgs).await.expect("should not error");
3336 let collected = collect_stream(stream).await.expect("stream should succeed");
3337 assert_eq!(collected.content, "real answer");
3338 }
3339
3340 #[test]
3343 fn arc_providers_clone_shares_allocation() {
3344 use crate::mock::MockProvider;
3345 let p = AnyProvider::Mock(MockProvider::default());
3346 let r = RouterProvider::new(vec![p]);
3347 let c = r.clone();
3348 assert!(Arc::ptr_eq(&r.state.providers, &c.state.providers));
3350 }
3351
3352 #[test]
3353 fn cost_tiers_reorders_providers_at_construction() {
3354 use crate::mock::MockProvider;
3355 let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3356 let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3357 let p3 = AnyProvider::Mock(MockProvider::default().with_name("openai"));
3358 let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3359 cost_tiers: Some(vec!["ollama".into(), "claude".into()]),
3360 ..CascadeRouterConfig::default()
3361 });
3362 let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3363 assert_eq!(names, vec!["ollama", "claude", "openai"]);
3365 }
3366
3367 #[test]
3368 fn cost_tiers_none_preserves_chain_order() {
3369 use crate::mock::MockProvider;
3370 let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3371 let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3372 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3373 cost_tiers: None,
3374 ..CascadeRouterConfig::default()
3375 });
3376 let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3377 assert_eq!(names, vec!["claude", "ollama"]);
3378 }
3379
3380 #[test]
3381 fn cost_tiers_empty_vec_preserves_chain_order() {
3382 use crate::mock::MockProvider;
3383 let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3384 let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3385 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3386 cost_tiers: Some(vec![]),
3387 ..CascadeRouterConfig::default()
3388 });
3389 let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3390 assert_eq!(names, vec!["claude", "ollama"]);
3391 }
3392
3393 #[test]
3394 fn cost_tiers_unknown_name_ignored() {
3395 use crate::mock::MockProvider;
3396 let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3397 let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3398 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3399 cost_tiers: Some(vec!["nonexistent".into(), "ollama".into()]),
3400 ..CascadeRouterConfig::default()
3401 });
3402 let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3403 assert_eq!(names, vec!["ollama", "claude"]);
3405 }
3406
3407 #[test]
3408 fn cost_tiers_all_providers_listed() {
3409 use crate::mock::MockProvider;
3410 let p1 = AnyProvider::Mock(MockProvider::default().with_name("c"));
3411 let p2 = AnyProvider::Mock(MockProvider::default().with_name("b"));
3412 let p3 = AnyProvider::Mock(MockProvider::default().with_name("a"));
3413 let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3414 cost_tiers: Some(vec!["a".into(), "b".into(), "c".into()]),
3415 ..CascadeRouterConfig::default()
3416 });
3417 let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3418 assert_eq!(names, vec!["a", "b", "c"]);
3419 }
3420
3421 #[test]
3422 fn cost_tiers_duplicate_name_uses_last_position() {
3423 use crate::mock::MockProvider;
3424 let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3425 let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3426 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3429 cost_tiers: Some(vec!["claude".into(), "ollama".into(), "ollama".into()]),
3430 ..CascadeRouterConfig::default()
3431 });
3432 let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3433 assert_eq!(names, vec!["claude", "ollama"]);
3434 }
3435
3436 #[test]
3437 fn cost_tiers_empty_router_does_not_panic() {
3438 let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig {
3439 cost_tiers: Some(vec!["foo".into()]),
3440 ..CascadeRouterConfig::default()
3441 });
3442 assert_eq!(r.state.providers.len(), 0);
3443 }
3444
3445 #[test]
3446 fn set_status_tx_works_with_arc() {
3447 use crate::mock::MockProvider;
3448 let p = AnyProvider::Mock(MockProvider::default());
3449 let mut r = RouterProvider::new(vec![p]);
3450 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
3451 r.set_status_tx(tx); }
3453
3454 #[tokio::test]
3455 async fn cascade_chat_with_tools_unaffected_by_cost_tiers() {
3456 use crate::mock::MockProvider;
3457 let p1 = AnyProvider::Mock(MockProvider::failing().with_name("cheap"));
3460 let p2 = AnyProvider::Mock(MockProvider::failing().with_name("expensive"));
3461 let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3462 cost_tiers: Some(vec!["cheap".into()]),
3463 ..CascadeRouterConfig::default()
3464 });
3465 let msgs = vec![Message::from_legacy(Role::User, "hi")];
3466 let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
3468 assert!(matches!(err, LlmError::NoProviders));
3469 }
3470
3471 #[tokio::test]
3476 async fn embed_retries_on_rate_limited_then_succeeds() {
3477 use crate::mock::MockProvider;
3478
3479 let p = AnyProvider::Mock({
3480 let mut m = MockProvider::default()
3481 .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
3482 .with_name("p1");
3483 m.supports_embeddings = true;
3484 m.embedding = vec![0.1, 0.2];
3485 m
3486 });
3487 let r = RouterProvider::new(vec![p]);
3488 let result = r.embed("text").await.unwrap();
3489 assert_eq!(result, vec![0.1, 0.2]);
3490 }
3491
3492 #[tokio::test]
3495 async fn embed_falls_back_after_all_retries_exhausted() {
3496 use crate::mock::MockProvider;
3497
3498 let p1 = AnyProvider::Mock({
3500 let mut m = MockProvider::default()
3501 .with_errors(vec![
3502 LlmError::RateLimited,
3503 LlmError::RateLimited,
3504 LlmError::RateLimited,
3505 LlmError::RateLimited,
3506 ])
3507 .with_name("p1");
3508 m.supports_embeddings = true;
3509 m
3510 });
3511 let p2 = AnyProvider::Mock({
3512 let mut m = MockProvider::default().with_name("p2");
3513 m.supports_embeddings = true;
3514 m.embedding = vec![9.0, 8.0];
3515 m
3516 });
3517 let r = RouterProvider::new(vec![p1, p2]);
3518 let result = r.embed("text").await.unwrap();
3519 assert_eq!(result, vec![9.0, 8.0]);
3520 }
3521
3522 #[tokio::test]
3524 async fn embed_batch_retries_on_rate_limited_then_succeeds() {
3525 use crate::mock::MockProvider;
3526
3527 let p = AnyProvider::Mock({
3528 let mut m = MockProvider::default()
3529 .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
3530 .with_name("p1");
3531 m.supports_embeddings = true;
3532 m.embedding = vec![0.5, 0.6];
3533 m
3534 });
3535 let r = RouterProvider::new(vec![p]);
3536 let result = r.embed_batch(&["a", "b"]).await.unwrap();
3537 assert_eq!(result, vec![vec![0.5, 0.6], vec![0.5, 0.6]]);
3538 }
3539
3540 #[tokio::test]
3543 async fn embed_batch_falls_back_after_all_retries_exhausted() {
3544 use crate::mock::MockProvider;
3545
3546 let p1 = AnyProvider::Mock({
3548 let mut m = MockProvider::default()
3549 .with_errors(vec![
3550 LlmError::RateLimited,
3551 LlmError::RateLimited,
3552 LlmError::RateLimited,
3553 LlmError::RateLimited,
3554 ])
3555 .with_name("p1");
3556 m.supports_embeddings = true;
3557 m
3558 });
3559 let p2 = AnyProvider::Mock({
3560 let mut m = MockProvider::default().with_name("p2");
3561 m.supports_embeddings = true;
3562 m.embedding = vec![7.0, 8.0];
3563 m
3564 });
3565 let r = RouterProvider::new(vec![p1, p2]);
3566 let result = r.embed_batch(&["x"]).await.unwrap();
3567 assert_eq!(result, vec![vec![7.0, 8.0]]);
3568 }
3569
3570 #[tokio::test]
3575 async fn embed_invalid_input_breaks_loop_and_returns_invalid_input() {
3576 use crate::mock::MockProvider;
3577
3578 let p = AnyProvider::Mock(MockProvider::default().with_embed_invalid_input());
3579 let r = RouterProvider::new(vec![p]).with_thompson(None);
3580 let err = r.embed("some text").await.unwrap_err();
3581 assert!(
3582 matches!(err, LlmError::InvalidInput { .. }),
3583 "expected InvalidInput, got {err:?}"
3584 );
3585 }
3586
3587 #[tokio::test]
3590 async fn embed_invalid_input_does_not_fall_through_to_second_provider() {
3591 use crate::mock::MockProvider;
3592
3593 let p1 = AnyProvider::Mock(
3597 MockProvider::default()
3598 .with_embed_invalid_input()
3599 .with_name("p1"),
3600 );
3601 let p2 = AnyProvider::Mock({
3602 let mut m = MockProvider::default();
3603 m.supports_embeddings = true;
3604 m.name_override = Some("p2".into());
3605 m
3606 });
3607
3608 let r = RouterProvider::new(vec![p1, p2]);
3609 let err = r.embed("test").await.unwrap_err();
3610
3611 assert!(
3613 matches!(&err, LlmError::InvalidInput { provider, .. } if provider == "p1"),
3614 "expected InvalidInput from p1, got {err:?}"
3615 );
3616 }
3617
3618 #[tokio::test]
3623 async fn chat_with_tools_invalid_input_breaks_loop_and_returns_invalid_input() {
3624 use crate::mock::MockProvider;
3625 use crate::provider::ToolDefinition;
3626
3627 let p = AnyProvider::Mock(MockProvider::default().with_tool_chat_invalid_input());
3628 let r = RouterProvider::new(vec![p]).with_thompson(None);
3629 let err = r
3630 .chat_with_tools(&[], &[] as &[ToolDefinition])
3631 .await
3632 .unwrap_err();
3633 assert!(
3634 matches!(err, LlmError::InvalidInput { .. }),
3635 "expected InvalidInput, got {err:?}"
3636 );
3637 }
3638
3639 #[tokio::test]
3642 async fn chat_with_tools_invalid_input_does_not_fall_through_to_second_provider() {
3643 use crate::mock::MockProvider;
3644 use crate::provider::ToolDefinition;
3645
3646 let p1 = AnyProvider::Mock(
3647 MockProvider::default()
3648 .with_tool_chat_invalid_input()
3649 .with_name("p1"),
3650 );
3651 let p2 = AnyProvider::Mock(MockProvider::default().with_name("p2"));
3652
3653 let r = RouterProvider::new(vec![p1, p2]);
3654 let err = r
3655 .chat_with_tools(&[], &[] as &[ToolDefinition])
3656 .await
3657 .unwrap_err();
3658
3659 assert!(
3660 matches!(&err, LlmError::InvalidInput { provider, .. } if provider == "p1"),
3661 "expected InvalidInput from p1, got {err:?}"
3662 );
3663 }
3664
3665 #[tokio::test]
3668 async fn embed_skips_non_embedding_providers_and_falls_through() {
3669 use crate::mock::MockProvider;
3670
3671 let p1 = AnyProvider::Mock({
3674 let mut m = MockProvider::default().with_name("p1");
3675 m.supports_embeddings = false;
3676 m
3677 });
3678 let p2 = AnyProvider::Mock({
3679 let mut m = MockProvider::default().with_name("p2");
3680 m.supports_embeddings = true;
3681 m.embedding = vec![1.0, 2.0, 3.0];
3682 m
3683 });
3684
3685 let r = RouterProvider::new(vec![p1, p2]);
3686 let result = r.embed("hello").await.unwrap();
3687 assert_eq!(result, vec![1.0, 2.0, 3.0]);
3688 }
3689
3690 #[tokio::test]
3694 async fn embed_invalid_input_does_not_record_availability() {
3695 use crate::mock::MockProvider;
3696
3697 let p = AnyProvider::Mock(
3698 MockProvider::default()
3699 .with_embed_invalid_input()
3700 .with_name("test-provider"),
3701 );
3702 let r = RouterProvider::new(vec![p]).with_thompson(None);
3703 let _ = r.embed("text").await;
3704
3705 let stats = r.thompson_stats();
3708 let provider_in_stats = stats.iter().any(|(name, ..)| name == "test-provider");
3709 assert!(
3710 !provider_in_stats,
3711 "InvalidInput must not update provider reputation; stats: {stats:?}"
3712 );
3713 }
3714
3715 #[tokio::test]
3720 async fn embed_timeout_single_provider_returns_no_providers() {
3721 use crate::mock::MockProvider;
3722
3723 let p = AnyProvider::Mock(
3724 MockProvider::default()
3725 .with_embed_delay(200)
3726 .with_name("slow"),
3727 );
3728 let r = RouterProvider::new(vec![p]).with_embed_timeout(10);
3729 let err = r.embed("hello").await.unwrap_err();
3730 assert!(
3731 matches!(err, LlmError::NoProviders),
3732 "expected NoProviders after timeout, got {err:?}"
3733 );
3734 }
3735
3736 #[tokio::test]
3739 async fn embed_timeout_falls_back_to_next_provider() {
3740 use crate::mock::MockProvider;
3741
3742 let p1 = AnyProvider::Mock(
3743 MockProvider::default()
3744 .with_embed_delay(200)
3745 .with_name("slow"),
3746 );
3747 let p2 = AnyProvider::Mock({
3748 let mut m = MockProvider::default().with_name("fast");
3749 m.supports_embeddings = true;
3750 m.embedding = vec![1.0, 2.0, 3.0];
3751 m
3752 });
3753 let r = RouterProvider::new(vec![p1, p2]).with_embed_timeout(10);
3754 let result = r.embed("hello").await.unwrap();
3755 assert_eq!(result, vec![1.0, 2.0, 3.0]);
3756 }
3757
3758 #[tokio::test]
3763 async fn quality_gate_passes_when_similarity_above_threshold() {
3764 use crate::mock::MockProvider;
3765
3766 let p1 = AnyProvider::Mock({
3769 let mut m = MockProvider::with_responses(vec!["answer".to_owned()]).with_name("p1");
3770 m.supports_embeddings = true;
3771 m.embedding = vec![1.0, 0.0];
3772 m
3773 });
3774 let r = RouterProvider::new(vec![p1])
3775 .with_thompson(None)
3776 .with_quality_gate(0.5);
3777 let msgs = vec![Message::from_legacy(Role::User, "question")];
3778 let result = r.chat(&msgs).await.unwrap();
3779 assert_eq!(result, "answer");
3780 }
3781
3782 #[tokio::test]
3785 async fn quality_gate_exhaustion_returns_best_seen() {
3786 use crate::mock::MockProvider;
3787
3788 let p1 = AnyProvider::Mock({
3792 let mut m =
3793 MockProvider::with_responses(vec!["best_so_far".to_owned()]).with_name("p1");
3794 m.supports_embeddings = true;
3795 m.embedding = vec![0.0, 1.0];
3797 m
3798 });
3799 let p2 = AnyProvider::Mock(MockProvider::failing().with_name("p2"));
3800 let r = RouterProvider::new(vec![p1, p2])
3801 .with_thompson(None)
3802 .with_quality_gate(0.9);
3803 let msgs = vec![Message::from_legacy(Role::User, "question")];
3804 let result = r.chat(&msgs).await.unwrap();
3805 assert_eq!(result, "best_so_far");
3806 }
3807
#[test]
/// A threshold of 5.0 fails the `(0.0, 1.0]` range guard, so the router must be
/// left without a quality gate.
///
/// The guard is reproduced inline here; presumably it mirrors the production
/// wiring code — verify against the caller of `with_quality_gate`.
fn routing_signals_quality_gate_above_one_is_ignored() {
    let threshold: f32 = 5.0;
    let base = RouterProvider::new(vec![]);

    let in_range = threshold.is_finite() && threshold > 0.0 && threshold <= 1.0;
    let router = if in_range {
        base.with_quality_gate(threshold)
    } else {
        base
    };

    assert!(
        router.quality_gate.is_none(),
        "out-of-range quality_gate must not be wired; got {:?}",
        router.quality_gate
    );
}
3827
#[test]
/// A threshold of 0.8 passes the `(0.0, 1.0]` range guard, so the router must
/// end up with `quality_gate == Some(0.8)`.
///
/// The guard is reproduced inline here; presumably it mirrors the production
/// wiring code — verify against the caller of `with_quality_gate`.
fn routing_signals_quality_gate_valid_is_wired() {
    let threshold: f32 = 0.8;
    let base = RouterProvider::new(vec![]);

    let in_range = threshold.is_finite() && threshold > 0.0 && threshold <= 1.0;
    let router = if in_range {
        base.with_quality_gate(threshold)
    } else {
        base
    };

    assert_eq!(
        router.quality_gate,
        Some(0.8),
        "valid quality_gate must be wired"
    );
}
3842
#[test]
/// Swapping the same turn id into `asi_last_turn` twice: the first swap sees a
/// different previous value (not dropped), the second sees the same id (dropped).
fn asi_debounce_same_turn_fires_once() {
    let router = RouterProvider::new(vec![]);
    let turn_id = 42u64;
    let last_turn = &router.state.asi_last_turn;

    // A call counts as "dropped" when the previously stored turn equals this turn.
    let first_dropped = last_turn.swap(turn_id, Ordering::AcqRel) == turn_id;
    let second_dropped = last_turn.swap(turn_id, Ordering::AcqRel) == turn_id;

    assert!(!first_dropped, "first call in turn must not be dropped");
    assert!(second_dropped, "second call in same turn must be dropped");
}
3861
#[test]
/// Swapping turn 1 and then turn 2 into `asi_last_turn`: neither swap observes
/// its own id as the previous value, so neither turn is treated as dropped.
fn asi_debounce_next_turn_fires_again() {
    let router = RouterProvider::new(vec![]);
    let last_turn = &router.state.asi_last_turn;

    let before_turn_one = last_turn.swap(1u64, Ordering::AcqRel);
    assert_ne!(
        before_turn_one, 1u64,
        "turn 1: initial value != 1, should proceed"
    );

    let before_turn_two = last_turn.swap(2u64, Ordering::AcqRel);
    let suppressed = before_turn_two == 2u64;
    assert!(!suppressed, "turn 2 must not be dropped (different turn_id)");
}
3875
#[test]
/// A cloned router must observe the increment made through the original:
/// the clone's `fetch_add` returns the original's previous value plus one.
fn turn_counter_increments_across_clones() {
    let original = RouterProvider::new(vec![]);
    let twin = original.clone();

    // `fetch_add` returns the value before the increment.
    let seen_by_original = original.state.turn_counter.fetch_add(1, Ordering::Relaxed);
    let seen_by_twin = twin.state.turn_counter.fetch_add(1, Ordering::Relaxed);

    assert_eq!(
        seen_by_twin,
        seen_by_original + 1,
        "cloned router shares turn_counter"
    );
}
3887
#[test]
/// Configuring an embed concurrency of 0 must leave `embed_semaphore` unset.
fn with_embed_concurrency_zero_means_no_semaphore() {
    let router = RouterProvider::new(vec![]).with_embed_concurrency(0);
    let semaphore = &router.state.embed_semaphore;
    assert!(semaphore.is_none(), "0 should disable semaphore");
}
3896
#[test]
/// Configuring an embed concurrency of 4 must install a semaphore that reports
/// 4 available permits.
fn with_embed_concurrency_positive_creates_semaphore() {
    let router = RouterProvider::new(vec![]).with_embed_concurrency(4);

    let permits = router
        .state
        .embed_semaphore
        .as_ref()
        .map(|sem| sem.available_permits());

    assert_eq!(permits.expect("semaphore should exist"), 4);
}
3907
3908 #[tokio::test]
3909 async fn embed_semaphore_limits_concurrency() {
3910 use std::sync::Arc as StdArc;
3911 use std::sync::atomic::{AtomicUsize, Ordering as AO};
3912
3913 let sem = Arc::new(tokio::sync::Semaphore::new(2));
3916 let concurrent_peak = StdArc::new(AtomicUsize::new(0));
3917 let active = StdArc::new(AtomicUsize::new(0));
3918
3919 let mut handles = vec![];
3920 for _ in 0..6 {
3921 let sem_clone = sem.clone();
3922 let peak = concurrent_peak.clone();
3923 let active = active.clone();
3924 handles.push(tokio::spawn(async move {
3925 let _permit = sem_clone.acquire().await.unwrap();
3926 let cur = active.fetch_add(1, AO::SeqCst) + 1;
3927 let mut p = peak.load(AO::SeqCst);
3929 while p < cur {
3930 match peak.compare_exchange(p, cur, AO::SeqCst, AO::SeqCst) {
3931 Ok(_) => break,
3932 Err(new) => p = new,
3933 }
3934 }
3935 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
3936 active.fetch_sub(1, AO::SeqCst);
3937 }));
3938 }
3939 for h in handles {
3940 h.await.unwrap();
3941 }
3942 assert!(
3943 concurrent_peak.load(AO::SeqCst) <= 2,
3944 "peak concurrency should not exceed semaphore limit"
3945 );
3946 }
3947
3948 #[tokio::test]
3953 async fn turn_embed_cache_hit_increments_counter() {
3954 use crate::mock::MockProvider;
3955
3956 let mut m = MockProvider::default();
3957 m.supports_embeddings = true;
3958 m.embedding = vec![0.5, 0.5];
3959 let provider_embed_calls = Arc::clone(&m.embed_call_count);
3960
3961 let r = RouterProvider::new(vec![AnyProvider::Mock(m)]);
3962 let cache = Mutex::new(TurnEmbedCache::default());
3963
3964 let emb1 = r.embed_cached("hello", &cache).await.unwrap();
3966 let emb2 = r.embed_cached("hello", &cache).await.unwrap();
3968
3969 assert_eq!(emb1, emb2, "cached embedding must match original");
3970 assert_eq!(
3971 provider_embed_calls.load(Ordering::Relaxed),
3972 1,
3973 "provider embed() must be called exactly once (second call hits cache)"
3974 );
3975 let (total, hits) = r.embed_cache_metrics();
3976 assert_eq!(
3977 total, 2,
3978 "embed_call_count must be 2 (two embed_cached calls)"
3979 );
3980 assert_eq!(hits, 1, "embed_cache_hits must be 1 (one cache hit)");
3981 }
3982
3983 #[tokio::test]
3986 async fn spawn_asi_update_with_precomputed_skips_embed() {
3987 use crate::mock::MockProvider;
3988
3989 let mut m = MockProvider::with_responses(vec!["ok".to_owned()]);
3990 m.supports_embeddings = true;
3991 m.embedding = vec![1.0, 0.0];
3992 let provider_embed_calls = Arc::clone(&m.embed_call_count);
3993
3994 let r =
3995 RouterProvider::new(vec![AnyProvider::Mock(m)]).with_asi(AsiRouterConfig::default());
3996
3997 let precomputed = vec![0.9_f32, 0.1];
3998 let turn_id = 42u64;
3999
4000 r.state.asi_last_turn.store(u64::MAX, Ordering::SeqCst);
4002
4003 r.spawn_asi_update(
4004 "p1",
4005 "response".to_owned(),
4006 turn_id,
4007 Some(precomputed.clone()),
4008 );
4009
4010 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
4012
4013 assert_eq!(
4015 provider_embed_calls.load(Ordering::Relaxed),
4016 0,
4017 "embed() must not be called when precomputed_embedding is Some"
4018 );
4019
4020 let asi = r.asi.as_ref().unwrap().lock();
4022 let coherence = asi.coherence("p1");
4023 let _ = coherence; }
4027
4028 #[tokio::test]
4031 async fn blocking_load_runs_closure_on_current_thread_runtime() {
4032 let result = super::blocking_load(|| 42_u32);
4033 assert_eq!(result, 42, "blocking_load must return the closure result");
4034 }
4035
4036 #[tokio::test]
4041 async fn spawn_asi_update_reaped_after_cap_full() {
4042 use crate::mock::MockProvider;
4043 use std::sync::atomic::Ordering;
4044
4045 let mut m = MockProvider::with_responses(vec!["ok".to_owned()]);
4046 m.supports_embeddings = true;
4047 m.embedding = vec![1.0, 0.0];
4048 let embed_calls = Arc::clone(&m.embed_call_count);
4049
4050 let r =
4051 RouterProvider::new(vec![AnyProvider::Mock(m)]).with_asi(AsiRouterConfig::default());
4052 r.state.asi_last_turn.store(u64::MAX, Ordering::SeqCst);
4053
4054 for i in 0..super::MAX_ASI_TASKS {
4056 r.spawn_asi_update("p1", format!("resp{i}"), i as u64, Some(vec![0.5, 0.5]));
4057 }
4058 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
4059
4060 r.spawn_asi_update(
4063 "p1",
4064 "extra".to_owned(),
4065 super::MAX_ASI_TASKS as u64,
4066 Some(vec![0.9, 0.1]),
4067 );
4068
4069 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
4071
4072 assert_eq!(
4074 embed_calls.load(Ordering::Relaxed),
4075 0,
4076 "embed() must not be called when precomputed_embedding is Some"
4077 );
4078
4079 r.spawn_asi_update(
4082 "p1",
4083 "probe".to_owned(),
4084 (super::MAX_ASI_TASKS + 1) as u64,
4085 Some(vec![0.1, 0.9]),
4086 );
4087
4088 let remaining = r.asi_tasks.lock().len();
4090 assert!(
4091 remaining <= 1,
4092 "completed tasks must be reaped; at most 1 in-flight task expected, got {remaining}"
4093 );
4094 }
4095
4096 #[tokio::test]
4101 async fn spawn_asi_update_embed_timeout_does_not_update_asi() {
4102 use crate::mock::MockProvider;
4103 use std::sync::atomic::Ordering;
4104
4105 let mut m = MockProvider::with_responses(vec!["ok".to_owned()]);
4107 m.supports_embeddings = true;
4108 m.embedding = vec![1.0, 0.0];
4109 m.embed_delay_ms = 200;
4110 let provider_embed_calls = Arc::clone(&m.embed_call_count);
4111
4112 let r = RouterProvider::new(vec![AnyProvider::Mock(m)])
4113 .with_asi(AsiRouterConfig::default())
4114 .with_embed_timeout(10);
4115
4116 r.state.asi_last_turn.store(u64::MAX, Ordering::SeqCst);
4118
4119 r.spawn_asi_update("p1", "response".to_owned(), 1u64, None);
4121
4122 tokio::time::sleep(std::time::Duration::from_millis(150)).await;
4124
4125 assert!(
4127 provider_embed_calls.load(Ordering::Relaxed) >= 1,
4128 "embed() must have been attempted"
4129 );
4130
4131 let asi = r.asi.as_ref().unwrap().lock();
4133 let coherence = asi.coherence("p1");
4134 assert!(
4136 (coherence - 1.0).abs() < f32::EPSILON,
4137 "ASI window must be empty after embed timeout; coherence={coherence}"
4138 );
4139 }
4140}