Skip to main content

zeph_llm/router/
mod.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Multi-provider router with pluggable routing strategies.
5//!
6//! [`RouterProvider`] implements [`LlmProvider`] and forwards every call to one of
7//! its configured backends, chosen according to the active [`RouterStrategy`].
8//!
9//! # Routing strategies
10//!
11//! | Strategy | Module | Description |
12//! |---|---|---|
13//! | [`RouterStrategy::Ema`] | `crate::ema` | EMA-weighted latency-aware ordering |
14//! | [`RouterStrategy::Thompson`] | [`thompson`] | Bayesian Beta-distribution sampling |
15//! | [`RouterStrategy::Cascade`] | [`cascade`] | Cheapest-first with quality escalation |
16//! | [`RouterStrategy::Bandit`] | [`bandit`] | Contextual `LinUCB` (PILOT algorithm) |
17//!
18//! Strategies are selected via builder methods on [`RouterProvider`]:
19//! - [`RouterProvider::with_ema`]
20//! - [`RouterProvider::with_thompson`]
21//! - [`RouterProvider::with_cascade`]
22//! - [`RouterProvider::with_bandit`]
23//!
24//! # Reputation-Aware Provider Selection (RAPS)
25//!
26//! All strategies support an optional Bayesian reputation layer ([`reputation`]) that
27//! penalizes providers which produce semantically invalid tool arguments. Enable with
28//! [`RouterProvider::with_reputation`].
29//!
30//! # Agent Stability Index (ASI)
31//!
32//! An optional session-level coherence tracker ([`asi`]) measures embedding-based
33//! response quality and feeds back into Thompson selection. Enable with
34//! [`RouterProvider::with_asi`].
35//!
36//! # Security
37//!
38//! Thompson and Bandit state files are loaded from user-controlled paths at startup.
39//! Files are validated (finite floats, clamped range) and written with `0o600` permissions
40//! on Unix. Do not store state files in world-writable directories.
41
42pub mod asi;
43pub mod aware;
44pub mod bandit;
45pub mod cascade;
46pub mod coe;
47pub mod reputation;
48pub mod state;
49pub mod thompson;
50pub mod triage;
51
52pub use aware::RouterAware;
53pub use state::RouterState;
54
55use std::collections::HashMap;
56use std::path::Path;
57use std::sync::Arc;
58use std::sync::atomic::{AtomicU64, Ordering};
59
60use parking_lot::Mutex;
61
62use crate::any::AnyProvider;
63use crate::ema::EmaTracker;
64use crate::embed::owned_strs;
65use crate::error::LlmError;
66use crate::provider::{ChatResponse, ChatStream, LlmProvider, Message, StatusTx, ToolDefinition};
67use coe::{CoeDecision, CoeRouter, run_coe};
68
69use asi::AsiState;
70use bandit::{BanditState, embedding_to_features};
71use cascade::{CascadeState, ClassifierMode, heuristic_score};
72use reputation::ReputationTracker;
73use thompson::ThompsonState;
74
75/// Rate-limits the ASI coherence WARN to at most once per 60 seconds process-wide.
///
/// Declared as a `static`, so a single counter is shared by every router in the process.
/// It is initialised to `0`.
/// NOTE(review): the name suggests it stores a second-resolution timestamp of the last
/// emitted warning — confirm at the use site.
76static ASI_WARN_LAST_SECS: AtomicU64 = AtomicU64::new(0);
77
78/// Maximum number of concurrent fire-and-forget ASI coherence update tasks.
79///
80/// When the `JoinSet` reaches this limit, new spawns are skipped (not aborted) to
81/// preserve in-flight work. ASI tasks are analytics-only and do not affect
82/// memory persistence.
83const MAX_ASI_TASKS: usize = 8;
84use zeph_common::math::cosine_similarity;
85
86/// Runs `f` without blocking the Tokio executor.
87///
88/// On a multi-thread runtime uses `block_in_place`; on a `current_thread` runtime (unit
89/// tests, single-threaded entry points) falls back to a direct call since there is no
90/// executor thread pool to offload to.
91fn blocking_load<T>(f: impl FnOnce() -> T) -> T {
92    if tokio::runtime::Handle::try_current()
93        .is_ok_and(|h| h.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread)
94    {
95        tokio::task::block_in_place(f)
96    } else {
97        f()
98    }
99}
100
/// Simple bounded embedding cache for bandit feature vectors.
///
/// Entries are keyed by a `u64` hash of the query text, so two different queries whose
/// hashes collide share one cached vector. Eviction is FIFO on insertion order (not LRU):
/// reading an entry never refreshes its position. The `lru` crate is not in the workspace;
/// a `HashMap` plus an insertion-order `VecDeque` avoids a new dependency.
///
/// A `capacity` of `0` disables caching entirely: `insert` stores nothing.
#[derive(Debug)]
struct BanditEmbedCache {
    /// Cached vectors by key.
    map: HashMap<u64, Vec<f32>>,
    /// Keys in insertion order; the front is the next eviction victim.
    order: std::collections::VecDeque<u64>,
    /// Maximum number of entries kept in `map`.
    capacity: usize,
}

impl BanditEmbedCache {
    /// Creates an empty cache that holds at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        Self {
            map: HashMap::with_capacity(capacity),
            order: std::collections::VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Returns the cached vector for `key`, if any. Does not affect eviction order.
    fn get(&self, key: u64) -> Option<&Vec<f32>> {
        self.map.get(&key)
    }

    /// Stores `value` under `key` unless the key is already present.
    ///
    /// An existing entry is left untouched (first write wins). When the cache is full the
    /// oldest inserted entry is evicted first. With `capacity == 0` nothing is stored.
    fn insert(&mut self, key: u64, value: Vec<f32>) {
        // A zero-capacity cache must stay empty; without this guard the eviction step below
        // finds nothing to pop and the entry would be stored anyway.
        if self.capacity == 0 || self.map.contains_key(&key) {
            return;
        }
        // `map` and `order` always hold the same keys, so this normally evicts exactly one.
        while self.map.len() >= self.capacity {
            let Some(oldest) = self.order.pop_front() else {
                break;
            };
            self.map.remove(&oldest);
        }
        self.map.insert(key, value);
        self.order.push_back(key);
    }
}

impl Default for BanditEmbedCache {
    /// A cache with room for 512 embeddings.
    fn default() -> Self {
        Self::new(512)
    }
}
145
/// Short-lived embedding cache whose keys are the exact input strings.
///
/// One instance is created at the start of each `chat()` call and dropped at the end. With
/// only 2-4 entries per turn the cost of owning `String` keys is negligible, and using the
/// full text rules out the hash-collision risk of `u64`-keyed caches.
#[derive(Debug, Default)]
struct TurnEmbedCache {
    /// Embedding per input text.
    entries: HashMap<String, Vec<f32>>,
}

impl TurnEmbedCache {
    /// Looks up the embedding stored for exactly `text`.
    fn get(&self, text: &str) -> Option<&Vec<f32>> {
        let cached = self.entries.get(text);
        cached
    }

    /// Stores `embedding` for `text`, replacing any previous embedding for the same text.
    fn insert(&mut self, text: impl Into<String>, embedding: Vec<f32>) {
        let key: String = text.into();
        let _previous = self.entries.insert(key, embedding);
    }
}
165
/// Selects how [`RouterProvider`] picks a backend for each call.
///
/// The default is [`RouterStrategy::Ema`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum RouterStrategy {
    /// Latency-aware ordering driven by an exponential moving average.
    #[default]
    Ema,
    /// Thompson Sampling over Beta distributions.
    Thompson,
    /// Cascade: start with the cheapest provider and escalate on degenerate output.
    Cascade,
    /// PILOT: `LinUCB` contextual bandit with online learning and budget-aware selection.
    Bandit,
}
180
/// Tunables for PILOT bandit routing in `RouterProvider`.
///
/// The algorithm itself and its trade-offs are described in the [`bandit`] module.
#[derive(Debug, Clone)]
#[allow(clippy::doc_markdown)] // PILOT, LinUCB, Thompson are proper nouns/acronyms
pub struct BanditRouterConfig {
    /// Exploration parameter of `LinUCB`; larger values explore more. Default: 1.0.
    pub alpha: f32,
    /// Number of leading embedding components used as the feature vector. Default: 32.
    pub dim: usize,
    /// Weight of the cost penalty in the reward:
    /// `reward = quality - cost_weight * cost_fraction`.
    /// Default: 0.1. Raise it to penalise expensive providers more aggressively.
    pub cost_weight: f32,
    /// Session-level decay factor; values below 1.0 trigger re-exploration over time.
    /// Default: 1.0.
    pub decay_factor: f32,
    /// Minimum number of total updates before `LinUCB` replaces the Thompson fallback.
    /// `Default::default()` yields `0`, which `with_bandit()` replaces with
    /// `10 * num_providers`.
    pub warmup_queries: u64,
    /// Hard timeout for the embedding call, in milliseconds. When exceeded, selection falls
    /// back to Thompson/uniform. Default: 50.
    pub embedding_timeout_ms: u64,
    /// Upper bound on cached embeddings (keyed by query string hash). Default: 512.
    pub cache_size: usize,
    /// MAR threshold: once `memory_hit_confidence >= this`, cheap providers are favoured.
    /// Default: 0.9. Use 1.0 to disable MAR.
    pub memory_confidence_threshold: f32,
}

impl Default for BanditRouterConfig {
    fn default() -> Self {
        // `0` is a placeholder: with_bandit() derives the real value from the provider count.
        let warmup_queries = 0;
        Self {
            alpha: 1.0,
            dim: 32,
            cost_weight: 0.1,
            decay_factor: 1.0,
            warmup_queries,
            embedding_timeout_ms: 50,
            cache_size: 512,
            memory_confidence_threshold: 0.9,
        }
    }
}
223
/// ASI settings handed to [`RouterProvider::with_asi`] at runtime.
///
/// This mirrors `AsiRouterConfig` from the configuration layer but is defined in `zeph-llm`
/// so the crate does not depend on `zeph-config`; the bootstrap layer converts config → this
/// struct.
#[derive(Debug, Clone)]
pub struct AsiRouterConfig {
    /// Size of the sliding window. Default: 5.
    pub window: usize,
    /// Providers scoring below this coherence value are penalized. Default: 0.7.
    pub coherence_threshold: f32,
    /// Amount added to the Thompson beta parameter on low coherence. Default: 0.3.
    pub penalty_weight: f32,
}

impl Default for AsiRouterConfig {
    fn default() -> Self {
        let (window, coherence_threshold, penalty_weight) = (5, 0.7, 0.3);
        Self {
            window,
            coherence_threshold,
            penalty_weight,
        }
    }
}
247
248/// Configuration for cascade routing in `RouterProvider`.
249#[derive(Debug, Clone)]
250pub struct CascadeRouterConfig {
251    pub quality_threshold: f64,
252    pub max_escalations: u8,
253    pub classifier_mode: ClassifierMode,
254    pub window_size: usize,
255    pub max_cascade_tokens: Option<u32>,
256    /// LLM provider used for judge-mode quality scoring.
257    /// Required when `classifier_mode = Judge`; falls back to heuristic if `None`.
258    pub summary_provider: Option<Arc<dyn crate::provider_dyn::LlmProviderDyn>>,
259    /// Explicit cost ordering of provider names (cheapest first).
260    /// When set, providers are sorted by their position in this list at construction time.
261    /// Providers not listed are appended after listed ones in original chain order.
262    pub cost_tiers: Option<Vec<String>>,
263    /// Hard timeout for the judge LLM call (milliseconds). Default: 5000.
264    pub judge_timeout_ms: u64,
265}
266
267impl Default for CascadeRouterConfig {
268    fn default() -> Self {
269        Self {
270            quality_threshold: 0.5,
271            max_escalations: 2,
272            classifier_mode: ClassifierMode::Heuristic,
273            window_size: 50,
274            max_cascade_tokens: None,
275            summary_provider: None,
276            cost_tiers: None,
277            judge_timeout_ms: 5_000,
278        }
279    }
280}
281
282/// Multi-provider LLM router implementing [`LlmProvider`].
283///
284/// Construct with [`RouterProvider::new`] and configure a routing strategy via the
285/// builder methods. All configuration is immutable after construction except for
286/// runtime state (EMA statistics, Thompson distribution, bandit weights) which is
287/// stored behind `Arc<Mutex<_>>` and updated on every successful call.
288///
289/// Cloning is cheap: [`RouterState`] and all per-strategy state are `Arc`-wrapped
290/// and shared between the original and all clones — clone cost is proportional to
291/// the number of `Arc` fields, not to provider count or strategy complexity.
292#[derive(Debug, Clone)]
293pub struct RouterProvider {
294    /// Shared cross-strategy runtime signals (providers, turn counter, MAR, etc.).
295    ///
296    /// All fields inside are `Arc`-wrapped; clone is O(1).
297    pub(crate) state: RouterState,
298    status_tx: Option<StatusTx>,
299    ema: Option<EmaTracker>,
300    strategy: RouterStrategy,
301    thompson: Option<Arc<Mutex<ThompsonState>>>,
302    /// Path for persisting Thompson state. `None` disables persistence.
303    thompson_state_path: Option<std::path::PathBuf>,
304    /// Cascade routing state (quality history per provider).
305    cascade_state: Option<Arc<Mutex<CascadeState>>>,
306    /// Cascade routing configuration.
307    cascade_config: Option<CascadeRouterConfig>,
308    /// Bayesian reputation tracker (RAPS). None when disabled.
309    reputation: Option<Arc<Mutex<ReputationTracker>>>,
310    /// Path for persisting reputation state.
311    reputation_state_path: Option<std::path::PathBuf>,
312    /// Reputation weight in [0.0, 1.0] for routing score blend.
313    reputation_weight: f64,
314    /// PILOT bandit state.
315    bandit: Option<Arc<Mutex<BanditState>>>,
316    /// Path for persisting bandit state. `None` disables persistence.
317    bandit_state_path: Option<std::path::PathBuf>,
318    /// Bandit routing configuration.
319    bandit_config: Option<BanditRouterConfig>,
320    /// Dedicated embedding provider for bandit feature vectors.
321    /// When `None`, bandit falls back to Thompson/uniform on embed failure.
322    bandit_embedding_provider: Option<Arc<dyn crate::provider_dyn::LlmProviderDyn>>,
323    /// Bounded embedding cache with FIFO eviction (not LRU): maps query-string hash to feature vector.
324    /// Shared across requests and clones; keyed by `u64` hash of query text.
325    bandit_embed_cache: Arc<Mutex<BanditEmbedCache>>,
326    /// Agent Stability Index state (session-only coherence tracking).
327    asi: Option<Arc<Mutex<AsiState>>>,
328    /// ASI configuration. `None` when ASI is disabled.
329    asi_config: Option<AsiRouterConfig>,
330    /// Embedding-based quality gate threshold. `None` = disabled.
331    /// After provider selection, `cosine_similarity(query_emb, response_emb)` must be >= this
332    /// value; otherwise the next provider in the ordered list is tried.
333    quality_gate: Option<f32>,
334    /// `CoE` (Collaborative Entropy) router. `None` when `CoE` is disabled.
335    coe: Option<Arc<CoeRouter>>,
336    /// Per-call timeout for `embed()` across all non-bandit providers (milliseconds).
337    /// Defaults to 5000 ms. A stalled provider is skipped and the next one is tried.
338    embed_timeout_ms: u64,
339    /// Bounded set of fire-and-forget ASI coherence update tasks.
340    ///
341    /// Shared across all clones via `Arc`; capped at [`MAX_ASI_TASKS`]. New spawns are
342    /// skipped (not aborted) when the cap is reached to preserve in-flight work.
343    asi_tasks: Arc<Mutex<tokio::task::JoinSet<()>>>,
344}
345
346impl RouterProvider {
347    /// Create a new router over `providers`.
348    ///
349    /// Use the builder methods (e.g., [`with_thompson`][Self::with_thompson],
350    /// [`with_cascade`][Self::with_cascade]) to configure a routing strategy.
351    /// The default strategy is [`RouterStrategy::Ema`].
352    #[must_use]
353    pub fn new(providers: Vec<AnyProvider>) -> Self {
354        let state = RouterState::new(Arc::from(providers));
355        Self {
356            state,
357            status_tx: None,
358            ema: None,
359            strategy: RouterStrategy::Ema,
360            thompson: None,
361            thompson_state_path: None,
362            cascade_state: None,
363            cascade_config: None,
364            reputation: None,
365            reputation_state_path: None,
366            reputation_weight: 0.3,
367            bandit: None,
368            bandit_state_path: None,
369            bandit_config: None,
370            bandit_embedding_provider: None,
371            bandit_embed_cache: Arc::new(Mutex::new(BanditEmbedCache::default())),
372            asi: None,
373            asi_config: None,
374            quality_gate: None,
375            coe: None,
376            embed_timeout_ms: 5000,
377            asi_tasks: Arc::new(Mutex::new(tokio::task::JoinSet::new())),
378        }
379    }
380
381    /// Set the per-call timeout for [`embed`][Self::embed] across all non-bandit providers.
382    ///
383    /// A stalled provider is skipped and the next candidate is tried. Default is `5000` ms.
384    /// Pass `0` to disable the timeout (not recommended for production).
385    ///
386    /// # Examples
387    ///
388    /// ```no_run
389    /// # use zeph_llm::router::RouterProvider;
390    /// let router = RouterProvider::new(vec![]).with_embed_timeout(3000);
391    /// ```
392    #[must_use]
393    pub fn with_embed_timeout(mut self, timeout_ms: u64) -> Self {
394        self.embed_timeout_ms = timeout_ms;
395        self
396    }
397
398    /// Set the maximum number of concurrent `embed_batch` calls.
399    ///
400    /// A value of 0 disables the semaphore (unlimited). Default is no semaphore.
401    #[must_use]
402    pub fn with_embed_concurrency(mut self, limit: usize) -> Self {
403        self.state.embed_semaphore = if limit > 0 {
404            Some(Arc::new(tokio::sync::Semaphore::new(limit)))
405        } else {
406            None
407        };
408        self
409    }
410
411    /// Set the MAR (Memory-Augmented Routing) signal for the current turn.
412    ///
413    /// Must be called before `chat` / `chat_stream` to influence bandit provider selection.
414    /// Pass `None` to disable MAR for this turn.
415    pub fn set_memory_confidence(&self, confidence: Option<f32>) {
416        let raw = confidence.map_or(u32::MAX, f32::to_bits);
417        self.state
418            .last_memory_confidence
419            .store(raw, std::sync::atomic::Ordering::Relaxed);
420    }
421
422    /// Enable EMA-based adaptive provider ordering.
423    #[must_use]
424    pub fn with_ema(mut self, alpha: f64, reorder_interval: u64) -> Self {
425        self.ema = Some(EmaTracker::new(alpha, reorder_interval));
426        self
427    }
428
429    /// Enable Collaborative Entropy (`CoE`) for Ema/Thompson strategies.
430    ///
431    /// `CoE` detects uncertain responses via intra-entropy and inter-divergence signals,
432    /// escalating to `secondary` when either threshold is exceeded.
433    ///
434    /// No-op (with a `warn!`) when the active strategy is `Cascade` or `Bandit`.
435    #[must_use]
436    pub fn with_coe(
437        mut self,
438        config: coe::CoeConfig,
439        secondary: AnyProvider,
440        embed: AnyProvider,
441    ) -> Self {
442        if matches!(
443            self.strategy,
444            RouterStrategy::Cascade | RouterStrategy::Bandit
445        ) {
446            tracing::warn!(
447                strategy = ?self.strategy,
448                "coe disabled for strategy; supported: ema, thompson"
449            );
450            return self;
451        }
452        self.coe = Some(Arc::new(CoeRouter {
453            config,
454            secondary: Arc::new(secondary) as Arc<dyn crate::provider_dyn::LlmProviderDyn>,
455            embed: Arc::new(embed) as Arc<dyn crate::provider_dyn::LlmProviderDyn>,
456            metrics: Arc::new(coe::CoeMetrics::default()),
457        }));
458        self
459    }
460
461    /// Return session-level `CoE` metrics snapshot, or `None` if `CoE` is disabled.
462    #[must_use]
463    pub fn coe_metrics(&self) -> Option<(u64, u64, u64, u64)> {
464        self.coe.as_ref().map(|c| {
465            (
466                c.metrics.kept_primary.load(Ordering::Relaxed),
467                c.metrics.intra_escalations.load(Ordering::Relaxed),
468                c.metrics.inter_escalations.load(Ordering::Relaxed),
469                c.metrics.embed_failures.load(Ordering::Relaxed),
470            )
471        })
472    }
473
474    /// Enable Agent Stability Index (ASI) coherence tracking.
475    ///
476    /// When enabled, each successful response is embedded in a background task and added
477    /// to a per-provider sliding window. The coherence score (cosine similarity of the
478    /// latest embedding vs. window mean) penalizes Thompson/EMA routing priors for
479    /// providers whose responses drift.
480    #[must_use]
481    pub fn with_asi(mut self, config: AsiRouterConfig) -> Self {
482        self.asi = Some(Arc::new(Mutex::new(AsiState::default())));
483        self.asi_config = Some(config);
484        self
485    }
486
487    /// Enable embedding-based quality gate for Thompson/EMA routing.
488    ///
489    /// After provider selection, computes cosine similarity between the query embedding
490    /// and the response embedding. If below `threshold`, tries the next provider in the
491    /// ordered list. On full exhaustion, returns the best response seen (highest similarity).
492    /// Fail-open: embedding errors disable the gate for that request.
493    #[must_use]
494    pub fn with_quality_gate(mut self, threshold: f32) -> Self {
495        self.quality_gate = Some(threshold);
496        self
497    }
498
499    /// Enable Thompson Sampling strategy.
500    ///
501    /// Loads existing state from `state_path` if present; falls back to uniform prior.
502    /// Prunes stale entries for providers not in the current chain.
503    #[must_use]
504    pub fn with_thompson(mut self, state_path: Option<&Path>) -> Self {
505        self.strategy = RouterStrategy::Thompson;
506        let path = state_path.map_or_else(ThompsonState::default_path, Path::to_path_buf);
507        let mut state = blocking_load(|| ThompsonState::load(&path));
508        // CRIT-3: prune orphan entries from previous configs.
509        let known: std::collections::HashSet<String> = self
510            .state
511            .providers
512            .iter()
513            .map(|p| p.name().to_owned())
514            .collect();
515        state.prune(&known);
516        self.thompson = Some(Arc::new(Mutex::new(state)));
517        self.thompson_state_path = Some(path);
518        self
519    }
520
521    /// Enable PILOT bandit routing strategy (`LinUCB` contextual bandit).
522    ///
523    /// Loads existing state from `state_path` (or the default path) using
524    /// [`tokio::task::block_in_place`] to avoid blocking the async executor.
525    /// Applies session-level decay if `config.decay_factor < 1.0`, and prunes arms for
526    /// removed providers.
527    ///
528    /// `embedding_provider` is used to obtain feature vectors for each query.
529    /// When `None`, the bandit falls back to Thompson/uniform selection whenever an
530    /// embedding cannot be obtained within `config.embedding_timeout_ms`.
531    ///
532    /// The `warmup_queries` default of `0` in `BanditRouterConfig` is overridden here to
533    /// `10 * num_providers` to ensure sufficient initial exploration.
534    #[must_use]
535    pub fn with_bandit(
536        mut self,
537        mut config: BanditRouterConfig,
538        state_path: Option<&Path>,
539        embedding_provider: Option<AnyProvider>,
540    ) -> Self {
541        self.strategy = RouterStrategy::Bandit;
542        let n = self.state.providers.len();
543        if config.warmup_queries == 0 {
544            config.warmup_queries = u64::try_from(10 * n.max(1)).unwrap_or(100);
545        }
546        let cache_size = config.cache_size;
547        let path = state_path.map_or_else(BanditState::default_path, Path::to_path_buf);
548        let mut state = blocking_load(|| BanditState::load(&path));
549        if state.dim == 0 {
550            state = BanditState::new(config.dim);
551        } else if state.dim != config.dim {
552            // Config changed dim — reset state rather than use mismatched arms.
553            tracing::warn!(
554                old_dim = state.dim,
555                new_dim = config.dim,
556                "bandit: dim changed, resetting state"
557            );
558            state = BanditState::new(config.dim);
559        }
560        // Validate config bounds before applying. Clamp to safe ranges with a warning.
561        if config.alpha <= 0.0 {
562            tracing::warn!(alpha = config.alpha, "bandit: alpha <= 0, clamping to 0.01");
563            config.alpha = 0.01;
564        }
565        if config.dim == 0 || config.dim > 256 {
566            tracing::warn!(
567                dim = config.dim,
568                "bandit: dim out of range [1, 256], clamping to 32"
569            );
570            config.dim = 32;
571        }
572        if config.decay_factor <= 0.0 || config.decay_factor > 1.0 {
573            tracing::warn!(
574                decay_factor = config.decay_factor,
575                "bandit: decay_factor out of (0.0, 1.0], clamping to 1.0"
576            );
577            config.decay_factor = 1.0;
578        }
579        if config.decay_factor < 1.0 {
580            state.apply_decay(config.decay_factor);
581        }
582        let known: std::collections::HashSet<String> = self
583            .state
584            .providers
585            .iter()
586            .map(|p| p.name().to_owned())
587            .collect();
588        state.prune(&known);
589        self.bandit = Some(Arc::new(Mutex::new(state)));
590        self.bandit_state_path = Some(path);
591        self.bandit_embed_cache = Arc::new(Mutex::new(BanditEmbedCache::new(cache_size)));
592        self.bandit_embedding_provider =
593            embedding_provider.map(|p| Arc::new(p) as Arc<dyn crate::provider_dyn::LlmProviderDyn>);
594        // Initialize Thompson state for cold-start fallback (total_updates < warmup_queries).
595        // Uses default uniform priors; no persistence path needed since it's a fallback only.
596        self.thompson = Some(Arc::new(Mutex::new(ThompsonState::default())));
597        self.bandit_config = Some(config);
598        self
599    }
600
601    /// Persist current bandit state to disk. No-op if bandit strategy is not active.
602    ///
603    /// Uses [`tokio::task::spawn_blocking`] so it is safe to call from any async context.
604    pub async fn save_bandit_state(&self) {
605        let (Some(bandit), Some(path)) = (&self.bandit, &self.bandit_state_path) else {
606            return;
607        };
608        let bandit = Arc::clone(bandit);
609        let path = path.clone();
610        tokio::task::spawn_blocking(move || {
611            let state = bandit.lock();
612            if let Err(e) = state.save(&path) {
613                tracing::warn!(error = %e, "failed to save bandit state");
614            }
615        })
616        .await
617        .unwrap_or_else(|e| tracing::warn!(error = %e, "bandit state save task panicked"));
618    }
619
620    /// Return bandit diagnostic stats: `(provider_name, pulls, mean_reward)`.
621    ///
622    /// Returns an empty vec if bandit strategy is not active.
623    #[must_use]
624    pub fn bandit_stats(&self) -> Vec<(String, u64, f32)> {
625        let Some(ref bandit) = self.bandit else {
626            return vec![];
627        };
628        let state = bandit.lock();
629        state.stats()
630    }
631
632    /// Enable Bayesian reputation scoring (RAPS).
633    ///
634    /// Loads existing state from `state_path` (or the default path) using
635    /// [`tokio::task::block_in_place`] to avoid blocking the async executor.
636    /// Applies session-level decay and prunes stale provider entries.
637    ///
638    /// No-op for Cascade routing (reputation is not used for cost-tier ordering).
639    #[must_use]
640    pub fn with_reputation(
641        mut self,
642        decay_factor: f64,
643        weight: f64,
644        min_observations: u64,
645        state_path: Option<&Path>,
646    ) -> Self {
647        let path = state_path.map_or_else(ReputationTracker::default_path, Path::to_path_buf);
648        // Load persisted state, apply decay, and prune orphaned providers.
649        let mut tracker = blocking_load(|| ReputationTracker::load(&path));
650        let known: std::collections::HashSet<String> = self
651            .state
652            .providers
653            .iter()
654            .map(|p| p.name().to_owned())
655            .collect();
656        tracker.apply_decay();
657        tracker.prune(&known);
658        // Overwrite config params (decay/min_obs may differ from the persisted defaults).
659        let tracker = {
660            let stats = tracker.stats();
661            let mut t = ReputationTracker::new(decay_factor, min_observations);
662            for (name, alpha, beta, _, obs) in stats {
663                t.models.insert(
664                    name,
665                    reputation::ReputationEntry {
666                        dist: thompson::BetaDist { alpha, beta },
667                        observations: obs,
668                    },
669                );
670            }
671            t
672        };
673        self.reputation = Some(Arc::new(Mutex::new(tracker)));
674        self.reputation_state_path = Some(path);
675        self.reputation_weight = weight.clamp(0.0, 1.0);
676        self
677    }
678
679    /// Record a quality outcome for the last active sub-provider (tool execution result).
680    ///
681    /// Call only for semantic failures (invalid tool args, parse errors).
682    /// Do NOT call for network errors, rate limits, or transient I/O failures.
683    /// No-op when reputation scoring is disabled, strategy is Cascade, or no tool call
684    /// has been made yet in this session.
685    ///
686    /// The `_provider_name` parameter is ignored — quality is attributed to the sub-provider
687    /// that served the most recent `chat_with_tools` call, tracked via `last_active_provider`.
688    pub fn record_quality_outcome(&self, _provider_name: &str, success: bool) {
689        if matches!(
690            self.strategy,
691            RouterStrategy::Cascade | RouterStrategy::Bandit
692        ) {
693            // Cascade: quality tracked via CascadeState.
694            // Bandit: quality fed via bandit_record_reward() after each response.
695            return;
696        }
697        let Some(ref reputation) = self.reputation else {
698            return;
699        };
700        let active = self.state.last_active_provider.lock().clone();
701        let Some(provider_name) = active else {
702            return;
703        };
704        let mut tracker = reputation.lock();
705        tracker.record_quality(&provider_name, success);
706    }
707
708    /// Returns the `provider_kind_str` of the last provider selected by the router.
709    ///
710    /// Used by [`crate::any::AnyProvider::provider_kind_str`] to attribute cost to the
711    /// actual child provider rather than returning the generic `"local"` sentinel for all
712    /// router-dispatched calls. Falls back to `"local"` when no call has been made yet.
713    #[must_use]
714    pub fn last_selected_provider_kind(&self) -> &'static str {
715        let name = self.state.last_active_provider.lock().clone();
716        let Some(name) = name else {
717            return "local";
718        };
719        self.state
720            .providers
721            .iter()
722            .find(|p| p.name() == name)
723            .map_or("local", |p| p.provider_kind_str())
724    }
725
726    /// Persist current reputation state to disk. No-op if reputation is disabled.
727    /// Uses [`tokio::task::spawn_blocking`] so it is safe to call from any async context.
728    pub async fn save_reputation_state(&self) {
729        let (Some(reputation), Some(path)) = (&self.reputation, &self.reputation_state_path) else {
730            return;
731        };
732        let reputation = Arc::clone(reputation);
733        let path = path.clone();
734        tokio::task::spawn_blocking(move || {
735            let state = reputation.lock();
736            if let Err(e) = state.save(&path) {
737                tracing::warn!(error = %e, "failed to save reputation state");
738            }
739        })
740        .await
741        .unwrap_or_else(|e| tracing::warn!(error = %e, "reputation state save task panicked"));
742    }
743
744    /// Return reputation stats for all tracked providers: (name, alpha, beta, mean, observations).
745    #[must_use]
746    pub fn reputation_stats(&self) -> Vec<(String, f64, f64, f64, u64)> {
747        let Some(ref reputation) = self.reputation else {
748            return vec![];
749        };
750        let tracker = reputation.lock();
751        tracker.stats()
752    }
753
754    /// Enable Cascade routing strategy.
755    ///
756    /// Providers are tried in chain order (cheapest first). Each response is evaluated
757    /// by the quality classifier; if it falls below `quality_threshold`, the next
758    /// provider is tried. At most `max_escalations` quality-based escalations occur.
759    ///
760    /// Network/API errors do not count against the escalation budget.
761    /// The best response seen so far is returned if all escalations are exhausted.
762    ///
763    /// When `config.cost_tiers` is set, providers are reordered once at construction
764    /// time (no per-request cost). Providers absent from `cost_tiers` are appended
765    /// after listed ones in original chain order. Unknown names in `cost_tiers` emit
766    /// a warning and are otherwise ignored.
767    #[must_use]
768    pub fn with_cascade(mut self, config: CascadeRouterConfig) -> Self {
769        self.strategy = RouterStrategy::Cascade;
770
771        if let Some(ref tiers) = config.cost_tiers
772            && !tiers.is_empty()
773        {
774            let provider_names: std::collections::HashSet<&str> =
775                self.state.providers.iter().map(AnyProvider::name).collect();
776            for name in tiers {
777                if !provider_names.contains(name.as_str()) {
778                    tracing::warn!(
779                        name = %name,
780                        "cascade: cost_tiers entry does not match any provider name"
781                    );
782                }
783            }
784
785            let tier_pos: std::collections::HashMap<&str, usize> = tiers
786                .iter()
787                .enumerate()
788                .map(|(i, n)| (n.as_str(), i))
789                .collect();
790
791            let before: Vec<_> = self
792                .state
793                .providers
794                .iter()
795                .map(|p| p.name().to_owned())
796                .collect();
797            let mut indexed: Vec<(usize, AnyProvider)> =
798                self.state.providers.iter().cloned().enumerate().collect();
799            indexed.sort_by_key(|(orig_idx, p)| {
800                tier_pos
801                    .get(p.name())
802                    .copied()
803                    .map_or((1usize, *orig_idx), |t| (0, t))
804            });
805            let after: Vec<_> = indexed.iter().map(|(_, p)| p.name().to_owned()).collect();
806            if before != after {
807                tracing::debug!(
808                    before = ?before,
809                    after = ?after,
810                    "cascade: providers reordered by cost_tiers"
811                );
812            }
813            self.state.providers =
814                Arc::from(indexed.into_iter().map(|(_, p)| p).collect::<Vec<_>>());
815        }
816
817        let window = config.window_size;
818        self.cascade_state = Some(Arc::new(Mutex::new(CascadeState::new(window))));
819        self.cascade_config = Some(config);
820        self
821    }
822
823    /// Persist current Thompson state to disk.
824    ///
825    /// No-op if Thompson strategy is not active.
826    ///
827    /// Uses [`tokio::task::spawn_blocking`] so it is safe to call from any async context,
828    /// including mid-request paths.
829    pub async fn save_thompson_state(&self) {
830        let (Some(thompson), Some(path)) = (&self.thompson, &self.thompson_state_path) else {
831            return;
832        };
833        let thompson = Arc::clone(thompson);
834        let path = path.clone();
835        tokio::task::spawn_blocking(move || {
836            let state = thompson.lock();
837            if let Err(e) = state.save(&path) {
838                tracing::warn!(error = %e, "failed to save Thompson router state");
839            }
840        })
841        .await
842        .unwrap_or_else(|e| tracing::warn!(error = %e, "Thompson state save task panicked"));
843    }
844
845    /// Hash a query string to a `u64` cache key.
846    fn query_hash(query: &str) -> u64 {
847        use std::hash::{Hash as _, Hasher as _};
848        let mut h = std::collections::hash_map::DefaultHasher::new();
849        query.hash(&mut h);
850        h.finish()
851    }
852
853    /// Fetch or compute the feature vector for `query` using the bandit embedding provider.
854    ///
855    /// Returns `None` if:
856    /// - No embedding provider is configured.
857    /// - The embedding call exceeds `embedding_timeout_ms`.
858    /// - The embedding is shorter than `dim` or is all-zero.
859    async fn bandit_features(&self, query: &str) -> Option<Vec<f32>> {
860        let cfg = self.bandit_config.as_ref()?;
861        let key = Self::query_hash(query);
862
863        // Check cache first (no async needed).
864        {
865            let cache = self.bandit_embed_cache.lock();
866            if let Some(cached) = cache.get(key) {
867                return Some(cached.clone());
868            }
869        }
870
871        let provider = self.bandit_embedding_provider.as_ref()?;
872        let timeout = std::time::Duration::from_millis(cfg.embedding_timeout_ms);
873        let embed_future = provider.embed(query);
874        let embedding = match tokio::time::timeout(timeout, embed_future).await {
875            Ok(Ok(emb)) => emb,
876            Ok(Err(e)) => {
877                tracing::debug!(error = %e, "bandit: embedding failed, falling back");
878                return None;
879            }
880            Err(_) => {
881                tracing::debug!(
882                    timeout_ms = cfg.embedding_timeout_ms,
883                    "bandit: embedding timed out, falling back"
884                );
885                return None;
886            }
887        };
888
889        let features = embedding_to_features(&embedding, cfg.dim)?;
890
891        // Insert into cache.
892        {
893            let mut cache = self.bandit_embed_cache.lock();
894            cache.insert(key, features.clone());
895        }
896        Some(features)
897    }
898
899    /// Select a provider using `LinUCB` bandit, with Thompson fallback on cold start / missing features.
900    ///
901    /// Falls through to Thompson or first available provider when bandit cannot decide.
902    /// Budget enforcement via global `CostTracker` is handled at the caller level.
903    /// Per-provider budget fractions are intentionally NOT implemented (scope creep, see #2230).
904    async fn bandit_select_provider(&self, query: &str) -> Option<AnyProvider> {
905        let Some(ref bandit_arc) = self.bandit else {
906            return self.state.providers.first().cloned();
907        };
908        let cfg = self.bandit_config.as_ref()?;
909
910        let names: Vec<String> = self
911            .state
912            .providers
913            .iter()
914            .map(|p| p.name().to_owned())
915            .collect();
916
917        // Try LinUCB selection with feature vector.
918        if let Some(features) = self.bandit_features(query).await {
919            let raw = self
920                .state
921                .last_memory_confidence
922                .load(std::sync::atomic::Ordering::Relaxed);
923            let memory_confidence = if raw == u32::MAX {
924                None
925            } else {
926                Some(f32::from_bits(raw))
927            };
928            let selected = {
929                let state = bandit_arc.lock();
930                state.select(
931                    &names,
932                    &features,
933                    cfg.alpha,
934                    cfg.warmup_queries,
935                    &|_| true,
936                    cfg.cost_weight,
937                    &self.state.provider_models,
938                    memory_confidence,
939                    cfg.memory_confidence_threshold,
940                )
941            };
942            if let Some(name) = selected {
943                tracing::debug!(
944                    provider = %name,
945                    strategy = "bandit",
946                    memory_confidence = ?memory_confidence,
947                    "selected provider"
948                );
949                return self
950                    .state
951                    .providers
952                    .iter()
953                    .find(|p| p.name() == name)
954                    .cloned();
955            }
956        }
957
958        // Fallback: Thompson sampling.
959        if let Some(ref thompson) = self.thompson {
960            let mut state = thompson.lock();
961            if let Some(sel) = state.select(&names) {
962                tracing::debug!(
963                    provider = %sel.provider,
964                    strategy = "bandit-fallback-thompson",
965                    "selected provider"
966                );
967                return self
968                    .state
969                    .providers
970                    .iter()
971                    .find(|p| p.name() == sel.provider)
972                    .cloned();
973            }
974        }
975
976        // Last resort: first provider.
977        self.state.providers.first().cloned()
978    }
979
980    /// Record the bandit reward for a completed request.
981    ///
982    /// `quality_score`: heuristic quality in [0, 1] from `heuristic_score()`.
983    /// `cost_fraction`: `request_cost_cents / max_daily_cents` (0 when budget is unlimited).
984    fn bandit_record_reward(
985        &self,
986        provider_name: &str,
987        features: &[f32],
988        quality_score: f64,
989        cost_fraction: f64,
990    ) {
991        let Some(ref bandit_arc) = self.bandit else {
992            return;
993        };
994        let Some(cfg) = &self.bandit_config else {
995            return;
996        };
997        #[allow(clippy::cast_possible_truncation)]
998        let reward = (quality_score as f32) - cfg.cost_weight * (cost_fraction as f32);
999        let reward = reward.clamp(-1.0, 1.0);
1000        let mut state = bandit_arc.lock();
1001        state.update(provider_name, features, reward);
1002        tracing::debug!(
1003            provider = provider_name,
1004            reward,
1005            quality = quality_score,
1006            "bandit: recorded reward"
1007        );
1008    }
1009
1010    fn ordered_providers(&self) -> Vec<AnyProvider> {
1011        match self.strategy {
1012            RouterStrategy::Thompson => self.thompson_ordered_providers(),
1013            RouterStrategy::Ema => self.ema_ordered_providers(),
1014            // Cascade/Bandit: sync path used only for debug_request_json(); hot paths use
1015            // dedicated async selection methods. For Cascade, providers are sorted at
1016            // construction time.
1017            RouterStrategy::Cascade | RouterStrategy::Bandit => self.state.providers.to_vec(),
1018        }
1019    }
1020
1021    fn ema_ordered_providers(&self) -> Vec<AnyProvider> {
1022        let order = self.state.provider_order.lock();
1023        let mut ordered: Vec<AnyProvider> = order
1024            .iter()
1025            .filter_map(|&i| self.state.providers.get(i).cloned())
1026            .collect();
1027
1028        // CRIT-2 fix: apply reputation as a multiplicative adjustment to the EMA score,
1029        // not an additive term. This avoids unbounded score inflation.
1030        //
1031        // Adjustment formula: ema_score * (1 + weight * (rep_factor - 0.5) * 2)
1032        // where rep_factor in [0,1]: 0.5 = neutral, >0.5 = positive, <0.5 = negative.
1033        // CRIT-1 fix: reputation factor is sampled per-provider (each has its own Beta mean).
1034        if let Some(ref reputation) = self.reputation
1035            && let Some(ref ema) = self.ema
1036        {
1037            let rep = reputation.lock();
1038            let w = self.reputation_weight;
1039            let snap = ema.snapshot();
1040            let mut scored: Vec<(usize, f64)> = ordered
1041                .iter()
1042                .enumerate()
1043                .map(|(idx, p)| {
1044                    let ema_score = snap
1045                        .get(p.name())
1046                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
1047                    let score = if let Some(rep_factor) = rep.ema_reputation_factor(p.name()) {
1048                        // Multiplicative blend: neutral at rep_factor=0.5, range ±weight.
1049                        let adjustment = 1.0 + w * (rep_factor - 0.5) * 2.0;
1050                        ema_score * adjustment
1051                    } else {
1052                        ema_score
1053                    };
1054                    (idx, score)
1055                })
1056                .collect();
1057            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1058            let reordered: Vec<AnyProvider> = scored
1059                .into_iter()
1060                .filter_map(|(idx, _)| ordered.get(idx).cloned())
1061                .collect();
1062            ordered = reordered;
1063        }
1064
1065        // ASI: re-score by down-weighting providers with low coherence.
1066        if let (Some(asi_arc), Some(asi_cfg)) = (&self.asi, &self.asi_config) {
1067            let asi: parking_lot::MutexGuard<'_, AsiState> = asi_arc.lock();
1068            let snap = self.ema.as_ref().map(EmaTracker::snapshot);
1069            let mut scored: Vec<(usize, f64)> = ordered
1070                .iter()
1071                .enumerate()
1072                .map(|(idx, p)| {
1073                    let coherence = asi.coherence(p.name());
1074                    if coherence < asi_cfg.coherence_threshold {
1075                        let now = std::time::SystemTime::now()
1076                            .duration_since(std::time::UNIX_EPOCH)
1077                            .unwrap_or(std::time::Duration::MAX)
1078                            .as_secs();
1079                        let last = ASI_WARN_LAST_SECS.load(Ordering::Relaxed);
1080                        if now.saturating_sub(last) >= 60
1081                            && ASI_WARN_LAST_SECS
1082                                .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
1083                                .is_ok()
1084                        {
1085                            tracing::warn!(
1086                                provider = p.name(),
1087                                coherence,
1088                                threshold = asi_cfg.coherence_threshold,
1089                                "asi: coherence below threshold"
1090                            );
1091                        } else {
1092                            tracing::trace!(
1093                                provider = p.name(),
1094                                coherence,
1095                                threshold = asi_cfg.coherence_threshold,
1096                                "asi: coherence below threshold (warn rate-limited)"
1097                            );
1098                        }
1099                    }
1100                    let base_score = snap
1101                        .as_ref()
1102                        .and_then(|s| s.get(p.name()))
1103                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
1104                    // Multiply EMA score by coherence multiplier clamped to [0.5, 1.0].
1105                    let multiplier = (coherence / asi_cfg.coherence_threshold).clamp(0.5, 1.0);
1106                    #[allow(clippy::cast_possible_truncation)]
1107                    let adjusted = base_score * f64::from(multiplier);
1108                    (idx, adjusted)
1109                })
1110                .collect();
1111            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1112            let reordered: Vec<AnyProvider> = scored
1113                .into_iter()
1114                .filter_map(|(idx, _)| ordered.get(idx).cloned())
1115                .collect();
1116            ordered = reordered;
1117        }
1118
1119        if let Some(first) = ordered.first() {
1120            tracing::debug!(
1121                provider = %first.name(),
1122                strategy = "ema",
1123                "selected provider"
1124            );
1125        }
1126        ordered
1127    }
1128
1129    fn thompson_ordered_providers(&self) -> Vec<AnyProvider> {
1130        let Some(ref thompson) = self.thompson else {
1131            return self.state.providers.to_vec();
1132        };
1133        let mut state = thompson.lock();
1134        let names: Vec<String> = self
1135            .state
1136            .providers
1137            .iter()
1138            .map(|p| p.name().to_owned())
1139            .collect();
1140
1141        // Compute per-provider prior overrides: start from base Beta distribution, apply
1142        // reputation shift (CRIT-3), then apply ASI coherence penalty.
1143        let has_reputation = self.reputation.is_some();
1144        let has_asi = self.asi.is_some() && self.asi_config.is_some();
1145
1146        let selected = if has_reputation || has_asi {
1147            // Build overrides by composing reputation and ASI adjustments.
1148            let rep_guard = self.reputation.as_ref().map(|r| r.lock());
1149            let asi_guard: Option<parking_lot::MutexGuard<'_, AsiState>> =
1150                self.asi.as_ref().map(|a| a.lock());
1151            let w = self.reputation_weight;
1152
1153            let overrides: std::collections::HashMap<String, (f64, f64)> = names
1154                .iter()
1155                .map(|name| {
1156                    let base = state.get_distribution(name);
1157                    // Apply reputation prior shift.
1158                    let (alpha, mut beta) = if let Some(ref rep) = rep_guard {
1159                        rep.shift_thompson_priors(name, base.alpha, base.beta, w)
1160                    } else {
1161                        (base.alpha, base.beta)
1162                    };
1163                    // Apply ASI coherence penalty: shift beta by penalty_weight * deficit.
1164                    if let (Some(asi), Some(asi_cfg)) = (&asi_guard, &self.asi_config) {
1165                        let coherence = asi.coherence(name);
1166                        if coherence < asi_cfg.coherence_threshold {
1167                            let now = std::time::SystemTime::now()
1168                                .duration_since(std::time::UNIX_EPOCH)
1169                                .unwrap_or(std::time::Duration::MAX)
1170                                .as_secs();
1171                            let last = ASI_WARN_LAST_SECS.load(Ordering::Relaxed);
1172                            if now.saturating_sub(last) >= 60
1173                                && ASI_WARN_LAST_SECS
1174                                    .compare_exchange(
1175                                        last,
1176                                        now,
1177                                        Ordering::Relaxed,
1178                                        Ordering::Relaxed,
1179                                    )
1180                                    .is_ok()
1181                            {
1182                                tracing::warn!(
1183                                    provider = name.as_str(),
1184                                    coherence,
1185                                    threshold = asi_cfg.coherence_threshold,
1186                                    "asi: coherence below threshold"
1187                                );
1188                            } else {
1189                                tracing::trace!(
1190                                    provider = name.as_str(),
1191                                    coherence,
1192                                    threshold = asi_cfg.coherence_threshold,
1193                                    "asi: coherence below threshold (warn rate-limited)"
1194                                );
1195                            }
1196                            let deficit = asi_cfg.coherence_threshold - coherence;
1197                            let penalty = f64::from(asi_cfg.penalty_weight * deficit);
1198                            beta += penalty;
1199                        }
1200                    }
1201                    (name.clone(), (alpha, beta))
1202                })
1203                .collect();
1204
1205            drop(rep_guard);
1206            drop(asi_guard);
1207            state.select_with_priors(&names, &overrides)
1208        } else {
1209            state.select(&names)
1210        };
1211
1212        if let Some(ref sel) = selected {
1213            tracing::debug!(
1214                provider = %sel.provider,
1215                strategy = "thompson",
1216                mode = if sel.exploit { "exploit" } else { "explore" },
1217                alpha = sel.alpha,
1218                beta = sel.beta,
1219                "selected provider"
1220            );
1221        }
1222        // Put selected provider first, keep rest in original order.
1223        let mut ordered = self.state.providers.to_vec();
1224        if let Some(ref sel) = selected
1225            && let Some(pos) = ordered.iter().position(|p| p.name() == sel.provider)
1226        {
1227            ordered.swap(0, pos);
1228        }
1229        ordered
1230    }
1231
1232    /// Record availability outcome (network success/failure) for EMA or Thompson.
1233    ///
1234    /// For cascade routing, quality outcomes are tracked separately in `CascadeState`.
1235    /// Only availability outcomes (API up/down) are recorded here to avoid corrupting
1236    /// Thompson/EMA distributions with quality-based failures (HIGH-01).
1237    fn record_availability(&self, provider_name: &str, success: bool, latency_ms: u64) {
1238        match self.strategy {
1239            RouterStrategy::Thompson => {
1240                if let Some(ref thompson) = self.thompson {
1241                    let mut state = thompson.lock();
1242                    state.update(provider_name, success);
1243                }
1244            }
1245            RouterStrategy::Ema => {
1246                self.ema_record(provider_name, success, latency_ms);
1247            }
1248            RouterStrategy::Cascade | RouterStrategy::Bandit => {
1249                // Cascade does not use Thompson/EMA for ordering; no-op.
1250                // Bandit tracks rewards separately via bandit_record_reward().
1251            }
1252        }
1253    }
1254
1255    fn ema_record(&self, provider_name: &str, success: bool, latency_ms: u64) {
1256        let Some(ref ema) = self.ema else {
1257            return;
1258        };
1259        ema.record(provider_name, success, latency_ms);
1260        let current_names: Vec<String> = self
1261            .state
1262            .providers
1263            .iter()
1264            .map(|p| p.name().to_owned())
1265            .collect();
1266        if let Some(new_order_names) = ema.maybe_reorder(&current_names) {
1267            let name_to_idx: std::collections::HashMap<&str, usize> = self
1268                .state
1269                .providers
1270                .iter()
1271                .enumerate()
1272                .map(|(i, p)| (p.name(), i))
1273                .collect();
1274            let new_order: Vec<usize> = new_order_names
1275                .iter()
1276                .filter_map(|n| name_to_idx.get(n.as_str()).copied())
1277                .collect();
1278            let mut order = self.state.provider_order.lock();
1279            *order = new_order;
1280        }
1281    }
1282
1283    /// Return a snapshot of Thompson distribution parameters for all tracked providers.
1284    ///
1285    /// Returns an empty vec if Thompson strategy is not active.
1286    #[must_use]
1287    pub fn thompson_stats(&self) -> Vec<(String, f64, f64)> {
1288        let Some(ref thompson) = self.thompson else {
1289            return vec![];
1290        };
1291        let state = thompson.lock();
1292        state.provider_stats()
1293    }
1294
1295    pub fn set_status_tx(&mut self, tx: StatusTx) {
1296        if let Some(providers) = Arc::get_mut(&mut self.state.providers) {
1297            for p in providers {
1298                p.set_status_tx(tx.clone());
1299            }
1300        } else {
1301            // Defensive path: should never happen at bootstrap (refcount == 1).
1302            let mut v: Vec<_> = self.state.providers.iter().cloned().collect();
1303            for p in &mut v {
1304                p.set_status_tx(tx.clone());
1305            }
1306            self.state.providers = Arc::from(v);
1307        }
1308        self.status_tx = Some(tx);
1309    }
1310
1311    /// Aggregate model lists from all sub-providers, deduplicating by id.
1312    ///
1313    /// Individual sub-provider errors are logged as warnings and skipped.
1314    ///
1315    /// # Errors
1316    ///
1317    /// Always succeeds (errors per-provider are swallowed).
1318    pub async fn list_models_remote(
1319        &self,
1320    ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
1321        let mut seen = std::collections::HashSet::new();
1322        let mut all = Vec::new();
1323        for p in self.state.providers.iter() {
1324            match p.list_models_remote().await {
1325                Ok(models) => {
1326                    for m in models {
1327                        if seen.insert(m.id.clone()) {
1328                            all.push(m);
1329                        }
1330                    }
1331                }
1332                Err(e) => {
1333                    tracing::warn!(error = %e, "router: list_models_remote sub-provider failed");
1334                }
1335            }
1336        }
1337        Ok(all)
1338    }
1339
1340    /// Evaluate quality with heuristics only.
1341    fn evaluate_heuristic(response: &str, threshold: f64) -> cascade::QualityVerdict {
1342        let mut verdict = heuristic_score(response);
1343        verdict.should_escalate = verdict.score < threshold;
1344        verdict
1345    }
1346
1347    /// Evaluate quality using the configured classifier mode.
1348    ///
1349    /// For `ClassifierMode::Judge`, calls the summary provider and falls back to heuristic
1350    /// on any error or timeout. For `ClassifierMode::Heuristic`, evaluates synchronously.
1351    async fn evaluate_quality(
1352        response: &str,
1353        threshold: f64,
1354        mode: ClassifierMode,
1355        summary_provider: Option<&dyn crate::provider_dyn::LlmProviderDyn>,
1356        judge_timeout_ms: u64,
1357    ) -> cascade::QualityVerdict {
1358        if mode == ClassifierMode::Judge {
1359            if let Some(judge) = summary_provider {
1360                match cascade::judge_score(
1361                    judge,
1362                    response,
1363                    std::time::Duration::from_millis(judge_timeout_ms),
1364                )
1365                .await
1366                {
1367                    Some(score) => {
1368                        let should_escalate = score < threshold;
1369                        tracing::debug!(
1370                            score,
1371                            threshold,
1372                            should_escalate,
1373                            "cascade: judge scored response"
1374                        );
1375                        return cascade::QualityVerdict {
1376                            score,
1377                            should_escalate,
1378                            reason: format!("judge score: {score:.2}"),
1379                        };
1380                    }
1381                    None => {
1382                        tracing::warn!("cascade: judge call failed, falling back to heuristic");
1383                    }
1384                }
1385            } else {
1386                tracing::warn!(
1387                    "cascade: classifier_mode=judge but no summary_provider configured, \
1388                     using heuristic"
1389                );
1390            }
1391        }
1392        Self::evaluate_heuristic(response, threshold)
1393    }
1394}
1395
/// Retry limit for embedding calls.
///
/// NOTE(review): the code that consumes this constant is outside the visible chunk;
/// presumably the router's `embed` retry loop — confirm.
const EMBED_MAX_RETRIES: u32 = 3;
/// Base delay for embedding retries, in milliseconds.
///
/// NOTE(review): presumably the starting point of a backoff schedule — confirm against
/// the retry loop, which is not visible here.
const EMBED_BASE_DELAY_MS: u64 = 500;
1398
1399impl RouterProvider {
1400    /// Embed `text` with per-turn caching.
1401    ///
1402    /// Checks `cache` before calling the underlying provider. On a cache hit, increments
1403    /// `embed_cache_hits`; on a miss, embeds via `self.embed()` and populates the cache.
1404    /// Either way, `embed_call_count` is incremented for observability.
1405    async fn embed_cached(
1406        &self,
1407        text: &str,
1408        cache: &Mutex<TurnEmbedCache>,
1409    ) -> Result<Vec<f32>, crate::error::LlmError> {
1410        self.state.embed_call_count.fetch_add(1, Ordering::Relaxed);
1411        if let Some(emb) = cache.lock().get(text) {
1412            self.state.embed_cache_hits.fetch_add(1, Ordering::Relaxed);
1413            return Ok(emb.clone());
1414        }
1415        let emb = self.embed(text).await?;
1416        cache.lock().insert(text, emb.clone());
1417        Ok(emb)
1418    }
1419
1420    /// Return session-level embedding cache metrics: `(total_calls, cache_hits)`.
1421    #[must_use]
1422    pub fn embed_cache_metrics(&self) -> (u64, u64) {
1423        (
1424            self.state.embed_call_count.load(Ordering::Relaxed),
1425            self.state.embed_cache_hits.load(Ordering::Relaxed),
1426        )
1427    }
1428
1429    /// Spawn a background task to update the ASI window for `provider`.
1430    ///
1431    /// Fire-and-forget: routing is not blocked on the embed call. If the embed fails,
1432    /// the ASI window is not updated (no penalty for embed failure).
1433    ///
1434    /// `turn_id` is used to debounce: at most one ASI update fires per turn even when
1435    /// `chat()` is called N times concurrently (e.g., tool schema fetches). Subsequent
1436    /// calls within the same turn are silently dropped.
1437    ///
1438    /// `precomputed_embedding` — when `Some`, skips the embed call entirely (reuse from
1439    /// quality gate). When `None`, embeds `response` inline in the spawned task.
1440    fn spawn_asi_update(
1441        &self,
1442        provider: &str,
1443        response: String,
1444        turn_id: u64,
1445        precomputed_embedding: Option<Vec<f32>>,
1446    ) {
1447        // Debounce: swap in turn_id; if the previous value equals turn_id, another call
1448        // already claimed this turn → drop silently. `swap` is atomic so exactly one
1449        // concurrent caller wins the "first for this turn" race.
1450        let prev = self.state.asi_last_turn.swap(turn_id, Ordering::AcqRel);
1451        if prev == turn_id {
1452            return;
1453        }
1454
1455        let Some(ref asi_arc) = self.asi else { return };
1456        let Some(ref asi_cfg) = self.asi_config else {
1457            return;
1458        };
1459
1460        let mut tasks = self.asi_tasks.lock();
1461        // Drain finished tasks so completed handles don't count toward the cap.
1462        while tasks.try_join_next().is_some() {}
1463        if tasks.len() >= MAX_ASI_TASKS {
1464            tracing::debug!("asi: task limit reached, skipping coherence update");
1465            return;
1466        }
1467
1468        let asi = Arc::clone(asi_arc);
1469        let router = self.clone();
1470        let window_size = asi_cfg.window;
1471        let provider_name = provider.to_owned();
1472        let embed_timeout_ms = self.embed_timeout_ms;
1473        tasks.spawn(async move {
1474            let emb = if let Some(e) = precomputed_embedding {
1475                e
1476            } else {
1477                let embed_fut = router.embed(&response);
1478                let embed_result = if embed_timeout_ms > 0 {
1479                    let timeout = std::time::Duration::from_millis(embed_timeout_ms);
1480                    if let Ok(r) = tokio::time::timeout(timeout, embed_fut).await {
1481                        r
1482                    } else {
1483                        tracing::debug!(
1484                            provider = provider_name,
1485                            timeout_ms = embed_timeout_ms,
1486                            "asi: embed timed out, skipping coherence update"
1487                        );
1488                        return;
1489                    }
1490                } else {
1491                    embed_fut.await
1492                };
1493                match embed_result {
1494                    Ok(e) => e,
1495                    Err(err) => {
1496                        tracing::debug!(
1497                            provider = provider_name,
1498                            error = %err,
1499                            "asi: embed failed, skipping coherence update"
1500                        );
1501                        return;
1502                    }
1503                }
1504            };
1505            let mut state = asi.lock();
1506            state.push_embedding(&provider_name, emb, window_size);
1507        });
1508    }
1509}
1510
1511/// Record a provider error during the fallback loop and emit a warning log.
1512///
1513/// Shared by [`RouterProvider::chat`] and [`RouterProvider::chat_stream`] to avoid
1514/// duplicating error-path bookkeeping. Not part of the public API.
1515fn record_fallback_error(
1516    router: &RouterProvider,
1517    provider_name: &str,
1518    error: &LlmError,
1519    elapsed_ms: u64,
1520    status_tx: Option<&StatusTx>,
1521    log_msg: &'static str,
1522) {
1523    router.record_availability(provider_name, false, elapsed_ms);
1524    if error.is_rate_limited() {
1525        router.record_availability(provider_name, false, 0);
1526    }
1527    if let Some(tx) = status_tx {
1528        let _ = tx.send(format!("router: {provider_name} failed, falling back"));
1529    }
1530    tracing::warn!(provider = provider_name, error = %error, "{}", log_msg);
1531}
1532
1533impl LlmProvider for RouterProvider {
1534    fn context_window(&self) -> Option<usize> {
1535        self.state
1536            .providers
1537            .first()
1538            .and_then(LlmProvider::context_window)
1539    }
1540
1541    #[allow(clippy::too_many_lines)] // CoE + quality-gate inline logic; extracting would obscure the control flow
1542    fn chat(
1543        &self,
1544        messages: &[Message],
1545    ) -> impl std::future::Future<Output = Result<String, LlmError>> + Send {
1546        let status_tx = self.status_tx.clone();
1547        let messages = messages.to_vec();
1548        let router = self.clone();
1549        #[cfg(feature = "profiling")]
1550        let model = self.model_identifier().to_owned();
1551        // NOTE: `chat` and `chat_stream` share error-path logic via `record_fallback_error`.
1552        // Their success paths diverge (quality gate + CoE vs. plain stream-open), so a
1553        // shared loop helper would reduce clarity without removing significant duplication.
1554        let fut = Box::pin(async move {
1555            // Increment turn counter once per top-level chat() call. All concurrent sub-calls
1556            // (tool schema fetches, embed probes) that re-enter chat() will see the same
1557            // turn_id via the shared Arc<AtomicU64>, enabling ASI debounce.
1558            let turn_id = router.state.turn_counter.fetch_add(1, Ordering::Relaxed);
1559
1560            tracing::info!(
1561                strategy = ?router.strategy,
1562                turn_id,
1563                provider_count = router.state.providers.len(),
1564                "llm.router.select"
1565            );
1566
1567            if router.strategy == RouterStrategy::Cascade {
1568                // Cascade: pass Arc slice directly — providers are sorted at construction,
1569                // so no Vec allocation needed on the hot path.
1570                return router
1571                    .cascade_chat(&router.state.providers, &messages, status_tx)
1572                    .await;
1573            }
1574            if router.strategy == RouterStrategy::Bandit {
1575                return router.bandit_chat(&messages, status_tx).await;
1576            }
1577            let providers = router.ordered_providers();
1578
1579            // Per-turn embedding cache: avoids re-embedding the same text across quality
1580            // gate and ASI update within a single chat() call.
1581            let turn_cache = Mutex::new(TurnEmbedCache::default());
1582
1583            // Pre-compute query embedding once for quality gate (fail-open on error).
1584            let query_text = messages
1585                .last()
1586                .map(Message::to_llm_content)
1587                .unwrap_or_default();
1588            let query_embedding = if router.quality_gate.is_some() && !query_text.is_empty() {
1589                router.embed_cached(query_text, &turn_cache).await.ok()
1590            } else {
1591                None
1592            };
1593
1594            // Best response seen so far (for quality gate exhaustion fallback, M2).
1595            let mut best_response: Option<(f32, String)> = None;
1596
1597            for p in &providers {
1598                let start = std::time::Instant::now();
1599                match p.chat_with_extras(&messages).await {
1600                    Ok((r, extras)) => {
1601                        router.record_availability(
1602                            p.name(),
1603                            true,
1604                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1605                        );
1606
1607                        // Quality gate: check response-query embedding similarity.
1608                        if let (Some(threshold), Some(qemb)) =
1609                            (router.quality_gate, &query_embedding)
1610                        {
1611                            let resp_emb = router.embed_cached(&r, &turn_cache).await.ok();
1612                            let similarity = resp_emb
1613                                .as_ref()
1614                                .map_or(threshold, |e| cosine_similarity(qemb, e)); // fail-open: None → treat as passing
1615                            if similarity < threshold {
1616                                tracing::info!(
1617                                    provider = p.name(),
1618                                    score = similarity,
1619                                    threshold,
1620                                    "thompson_quality_fallback"
1621                                );
1622                                // Track best response seen so far.
1623                                let is_better = best_response
1624                                    .as_ref()
1625                                    .is_none_or(|(best, _)| similarity > *best);
1626                                if is_better {
1627                                    best_response = Some((similarity, r.clone()));
1628                                }
1629                                // Spawn ASI update even on quality failure, reusing resp_emb.
1630                                router.spawn_asi_update(p.name(), r, turn_id, resp_emb);
1631                                continue;
1632                            }
1633                            // Pass resp_emb to ASI to avoid a redundant embed call.
1634                            router.spawn_asi_update(p.name(), r.clone(), turn_id, resp_emb);
1635
1636                            // CoE: pass already-obtained primary result to avoid double call.
1637                            if let Some(ref coe_router) = router.coe
1638                                && let Ok((final_r, pname, decision)) = run_coe(
1639                                    coe_router,
1640                                    p.name().to_owned(),
1641                                    r.clone(),
1642                                    extras,
1643                                    &messages,
1644                                )
1645                                .await
1646                            {
1647                                if matches!(
1648                                    decision,
1649                                    CoeDecision::EscalateIntra | CoeDecision::EscalateInter
1650                                ) {
1651                                    router.record_quality_outcome(&pname, false);
1652                                    router
1653                                        .record_quality_outcome(coe_router.secondary.name(), true);
1654                                }
1655                                return Ok(final_r);
1656                            }
1657
1658                            return Ok(r);
1659                        }
1660
1661                        // Spawn ASI embedding update (fire-and-forget, no precomputed embedding).
1662                        router.spawn_asi_update(p.name(), r.clone(), turn_id, None);
1663
1664                        // CoE: pass already-obtained primary result to avoid double call.
1665                        if let Some(ref coe_router) = router.coe
1666                            && let Ok((final_r, pname, decision)) = run_coe(
1667                                coe_router,
1668                                p.name().to_owned(),
1669                                r.clone(),
1670                                extras,
1671                                &messages,
1672                            )
1673                            .await
1674                        {
1675                            if matches!(
1676                                decision,
1677                                CoeDecision::EscalateIntra | CoeDecision::EscalateInter
1678                            ) {
1679                                router.record_quality_outcome(&pname, false);
1680                                router.record_quality_outcome(coe_router.secondary.name(), true);
1681                            }
1682                            return Ok(final_r);
1683                        }
1684
1685                        return Ok(r);
1686                    }
1687                    Err(e) => {
1688                        record_fallback_error(
1689                            &router,
1690                            p.name(),
1691                            &e,
1692                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1693                            status_tx.as_ref(),
1694                            "router fallback",
1695                        );
1696                    }
1697                }
1698            }
1699
1700            // All providers exhausted by quality gate: return best response seen (M2).
1701            if let Some((_, response)) = best_response {
1702                return Ok(response);
1703            }
1704
1705            Err(LlmError::NoProviders)
1706        });
1707        #[cfg(feature = "profiling")]
1708        let fut = {
1709            use tracing::Instrument as _;
1710            fut.instrument(tracing::info_span!("llm.router.chat", model = model))
1711        };
1712        fut
1713    }
1714
1715    fn chat_stream(
1716        &self,
1717        messages: &[Message],
1718    ) -> impl std::future::Future<Output = Result<ChatStream, LlmError>> + Send {
1719        let status_tx = self.status_tx.clone();
1720        let messages = messages.to_vec();
1721        let router = self.clone();
1722        #[cfg(feature = "profiling")]
1723        let model = self.model_identifier().to_owned();
1724        let fut = Box::pin(async move {
1725            // NOTE: see DRY design decision above `chat()` — error path shared via
1726            // `record_fallback_error`; success paths diverge intentionally.
1727            if router.strategy == RouterStrategy::Cascade {
1728                // Cascade: pass Arc slice directly — no Vec allocation on the hot path.
1729                return router
1730                    .cascade_chat_stream(&router.state.providers, &messages, status_tx)
1731                    .await;
1732            }
1733            if router.strategy == RouterStrategy::Bandit {
1734                // Bandit stream: select provider then stream from it.
1735                // Reward is not recorded for streams (stream completion is async);
1736                // this is a known pre-1.0 limitation — same as Thompson stream mode.
1737                let query = messages
1738                    .last()
1739                    .map(super::provider::Message::to_llm_content)
1740                    .unwrap_or_default();
1741                let p = router
1742                    .bandit_select_provider(query)
1743                    .await
1744                    .ok_or(LlmError::NoProviders)?;
1745                return p.chat_stream(&messages).await;
1746            }
1747            let providers = router.ordered_providers();
1748            for p in &providers {
1749                let start = std::time::Instant::now();
1750                match p.chat_stream(&messages).await {
1751                    Ok(r) => {
1752                        // NOTE: success is recorded at stream-open time, not on stream
1753                        // completion. A provider that opens the stream but then fails
1754                        // mid-delivery still gets alpha += 1. This is a known pre-1.0
1755                        // limitation: fixing it requires wrapping ChatStream to intercept
1756                        // the completion/error signal, which adds latency on the hot path.
1757                        // Tracked in the adaptive-inference epic (CRIT-2).
1758                        router.record_availability(
1759                            p.name(),
1760                            true,
1761                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1762                        );
1763                        return Ok(r);
1764                    }
1765                    Err(e) => {
1766                        record_fallback_error(
1767                            &router,
1768                            p.name(),
1769                            &e,
1770                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1771                            status_tx.as_ref(),
1772                            "router stream fallback",
1773                        );
1774                    }
1775                }
1776            }
1777            Err(LlmError::NoProviders)
1778        });
1779        #[cfg(feature = "profiling")]
1780        let fut = {
1781            use tracing::Instrument as _;
1782            fut.instrument(tracing::info_span!("llm.router.chat_stream", model = model))
1783        };
1784        fut
1785    }
1786
1787    fn supports_streaming(&self) -> bool {
1788        self.state
1789            .providers
1790            .iter()
1791            .any(LlmProvider::supports_streaming)
1792    }
1793
1794    #[allow(clippy::too_many_lines)] // retry + timeout + fallback + availability tracking: splitting would break the shared `last_err` accumulator
1795    fn embed(
1796        &self,
1797        text: &str,
1798    ) -> impl std::future::Future<Output = Result<Vec<f32>, LlmError>> + Send {
1799        let providers = self.ordered_providers();
1800        let status_tx = self.status_tx.clone();
1801        let text = text.to_owned();
1802        let router = self.clone();
1803        let embed_timeout_ms = self.embed_timeout_ms;
1804        #[cfg(feature = "profiling")]
1805        let model = self.model_identifier().to_owned();
1806        let fut = Box::pin(async move {
1807            for p in &providers {
1808                if !p.supports_embeddings() {
1809                    continue;
1810                }
1811                let mut last_err: Option<LlmError> = None;
1812                for attempt in 0..=EMBED_MAX_RETRIES {
1813                    if attempt > 0 {
1814                        let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
1815                        tracing::warn!(
1816                            provider = p.name(),
1817                            attempt,
1818                            delay_ms = delay,
1819                            "embed: rate limited, retrying after backoff"
1820                        );
1821                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1822                    }
1823                    let start = std::time::Instant::now();
1824                    // Apply per-call timeout when configured (embed_timeout_ms > 0).
1825                    let embed_result: Result<Vec<f32>, LlmError> = if embed_timeout_ms > 0 {
1826                        let timeout = std::time::Duration::from_millis(embed_timeout_ms);
1827                        match tokio::time::timeout(timeout, p.embed(&text)).await {
1828                            Ok(inner) => inner,
1829                            Err(_elapsed) => {
1830                                tracing::warn!(
1831                                    provider = p.name(),
1832                                    timeout_ms = embed_timeout_ms,
1833                                    "embed: provider timed out, falling back"
1834                                );
1835                                last_err = Some(LlmError::Timeout);
1836                                break;
1837                            }
1838                        }
1839                    } else {
1840                        p.embed(&text).await
1841                    };
1842                    match embed_result {
1843                        Ok(r) => {
1844                            router.record_availability(
1845                                p.name(),
1846                                true,
1847                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1848                            );
1849                            return Ok(r);
1850                        }
1851                        Err(e) if e.is_invalid_input() => {
1852                            // The input itself is invalid — retrying on another provider
1853                            // would fail identically. Do not penalize provider reputation.
1854                            tracing::warn!(
1855                                provider = p.name(),
1856                                error = %e,
1857                                "embed: invalid input, not retrying on other providers"
1858                            );
1859                            return Err(e);
1860                        }
1861                        Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
1862                            last_err = Some(e);
1863                        }
1864                        Err(e) => {
1865                            router.record_availability(
1866                                p.name(),
1867                                false,
1868                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1869                            );
1870                            if let Some(ref tx) = status_tx {
1871                                let _ = tx.send(format!(
1872                                    "router: {} embed failed, falling back",
1873                                    p.name()
1874                                ));
1875                            }
1876                            tracing::warn!(provider = p.name(), error = %e, "router embed fallback");
1877                            last_err = Some(e);
1878                            break;
1879                        }
1880                    }
1881                }
1882                // All retries exhausted for this provider (rate-limited every time).
1883                if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
1884                    router.record_availability(p.name(), false, 0);
1885                    if let Some(ref tx) = status_tx {
1886                        let _ = tx.send(format!(
1887                            "router: {} embed rate limited, falling back",
1888                            p.name()
1889                        ));
1890                    }
1891                    tracing::warn!(
1892                        provider = p.name(),
1893                        "embed: rate limit retries exhausted, falling back"
1894                    );
1895                }
1896            }
1897            Err(LlmError::NoProviders)
1898        });
1899        #[cfg(feature = "profiling")]
1900        let fut = {
1901            use tracing::Instrument as _;
1902            fut.instrument(tracing::info_span!("llm.router.embed", model = model))
1903        };
1904        fut
1905    }
1906
1907    fn embed_batch(
1908        &self,
1909        texts: &[&str],
1910    ) -> impl std::future::Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
1911        let providers = self.ordered_providers();
1912        let status_tx = self.status_tx.clone();
1913        let owned = owned_strs(texts);
1914        let router = self.clone();
1915        let semaphore = self.state.embed_semaphore.clone();
1916        #[cfg(feature = "profiling")]
1917        let model = self.model_identifier().to_owned();
1918        let fut = Box::pin(async move {
1919            // Acquire embed semaphore permit before any HTTP work to cap concurrency.
1920            let _permit = if let Some(ref sem) = semaphore {
1921                Some(sem.acquire().await.map_err(|_| LlmError::NoProviders)?)
1922            } else {
1923                None
1924            };
1925            let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
1926            for p in &providers {
1927                if !p.supports_embeddings() {
1928                    continue;
1929                }
1930                let mut last_err: Option<LlmError> = None;
1931                for attempt in 0..=EMBED_MAX_RETRIES {
1932                    if attempt > 0 {
1933                        let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
1934                        tracing::warn!(
1935                            provider = p.name(),
1936                            attempt,
1937                            delay_ms = delay,
1938                            "embed_batch: rate limited, retrying after backoff"
1939                        );
1940                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1941                    }
1942                    let start = std::time::Instant::now();
1943                    match p.embed_batch(&refs).await {
1944                        Ok(r) => {
1945                            router.record_availability(
1946                                p.name(),
1947                                true,
1948                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1949                            );
1950                            return Ok(r);
1951                        }
1952                        Err(e) if e.is_invalid_input() => {
1953                            tracing::warn!(
1954                                provider = p.name(),
1955                                error = %e,
1956                                "embed_batch: invalid input, not retrying on other providers"
1957                            );
1958                            return Err(e);
1959                        }
1960                        Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
1961                            last_err = Some(e);
1962                        }
1963                        Err(e) => {
1964                            router.record_availability(
1965                                p.name(),
1966                                false,
1967                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1968                            );
1969                            if let Some(ref tx) = status_tx {
1970                                let _ = tx.send(format!(
1971                                    "router: {} embed_batch failed, falling back",
1972                                    p.name()
1973                                ));
1974                            }
1975                            tracing::warn!(
1976                                provider = p.name(),
1977                                error = %e,
1978                                "router embed_batch fallback"
1979                            );
1980                            last_err = Some(e);
1981                            break;
1982                        }
1983                    }
1984                }
1985                // All retries exhausted for this provider (rate-limited every time).
1986                if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
1987                    router.record_availability(p.name(), false, 0);
1988                    if let Some(ref tx) = status_tx {
1989                        let _ = tx.send(format!(
1990                            "router: {} embed_batch rate limited, falling back",
1991                            p.name()
1992                        ));
1993                    }
1994                    tracing::warn!(
1995                        provider = p.name(),
1996                        "embed_batch: rate limit retries exhausted, falling back"
1997                    );
1998                }
1999            }
2000            Err(LlmError::NoProviders)
2001        });
2002        #[cfg(feature = "profiling")]
2003        let fut = {
2004            use tracing::Instrument as _;
2005            fut.instrument(tracing::info_span!("llm.router.embed_batch", model = model))
2006        };
2007        fut
2008    }
2009
2010    fn supports_embeddings(&self) -> bool {
2011        self.state
2012            .providers
2013            .iter()
2014            .any(LlmProvider::supports_embeddings)
2015    }
2016
    /// Returns the fixed name `"router"`, regardless of which backend serves a call.
    // NOTE(review): the lint is presumably silenced so the return type stays `&str`
    // like the trait declaration instead of `&'static str` — confirm.
    #[allow(clippy::unnecessary_literal_bound)]
    fn name(&self) -> &str {
        "router"
    }
2021
    /// Returns the fixed identifier `"router"`; no backend model name is exposed here.
    // NOTE(review): the lint is presumably silenced so the return type stays `&str`
    // like the trait declaration instead of `&'static str` — confirm.
    #[allow(clippy::unnecessary_literal_bound)]
    fn model_identifier(&self) -> &str {
        "router"
    }
2026
2027    fn supports_tool_use(&self) -> bool {
2028        self.state
2029            .providers
2030            .iter()
2031            .any(LlmProvider::supports_tool_use)
2032    }
2033
2034    fn list_models(&self) -> Vec<String> {
2035        self.state
2036            .providers
2037            .iter()
2038            .flat_map(super::provider::LlmProvider::list_models)
2039            .collect()
2040    }
2041
2042    #[allow(refining_impl_trait_reachable)]
2043    fn chat_with_tools(
2044        &self,
2045        messages: &[Message],
2046        tools: &[ToolDefinition],
2047    ) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
2048        let messages = messages.to_vec();
2049        #[cfg(feature = "profiling")]
2050        let tool_count = tools.len();
2051        let tools = tools.to_vec();
2052        let status_tx = self.status_tx.clone();
2053        let router = self.clone();
2054        #[cfg(feature = "profiling")]
2055        let model = self.model_identifier().to_owned();
2056        let fut = Box::pin(async move {
2057            // Bandit routing for tool calls: select a single provider, no quality escalation.
2058            if router.strategy == RouterStrategy::Bandit {
2059                let query = messages
2060                    .last()
2061                    .map(super::provider::Message::to_llm_content)
2062                    .unwrap_or_default();
2063                let p = router
2064                    .bandit_select_provider(query)
2065                    .await
2066                    .ok_or(LlmError::NoProviders)?;
2067                if !p.supports_tool_use() {
2068                    return Err(LlmError::NoProviders);
2069                }
2070                let result = p.chat_with_tools(&messages, &tools).await;
2071                if result.is_ok() {
2072                    *router.state.last_active_provider.lock() = Some(p.name().to_owned());
2073                }
2074                return result;
2075            }
2076
2077            // Cascade is intentionally skipped for tool calls: evaluating quality of
2078            // a tool-call response (structured JSON with tool name + args) requires
2079            // different heuristics than text quality. Skipping cascade for tool calls
2080            // avoids inappropriate escalation based on text signals (HIGH-04).
2081            let providers = router.ordered_providers();
2082            for p in &providers {
2083                if !p.supports_tool_use() {
2084                    continue;
2085                }
2086                let start = std::time::Instant::now();
2087                match p.chat_with_tools(&messages, &tools).await {
2088                    Ok(r) => {
2089                        router.record_availability(
2090                            p.name(),
2091                            true,
2092                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
2093                        );
2094                        // Track which sub-provider served this tool call for reputation attribution.
2095                        *router.state.last_active_provider.lock() = Some(p.name().to_owned());
2096                        return Ok(r);
2097                    }
2098                    Err(e) => {
2099                        router.record_availability(
2100                            p.name(),
2101                            false,
2102                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
2103                        );
2104                        if e.is_invalid_input() {
2105                            tracing::warn!(
2106                                provider = p.name(),
2107                                error = %e,
2108                                "chat_with_tools: invalid input, not retrying on other providers"
2109                            );
2110                            return Err(e);
2111                        }
2112                        if e.is_rate_limited() {
2113                            router.record_availability(p.name(), false, 0);
2114                        }
2115                        if let Some(ref tx) = status_tx {
2116                            let _ = tx.send(format!(
2117                                "router: {} tool call failed, falling back",
2118                                p.name()
2119                            ));
2120                        }
2121                        tracing::warn!(provider = p.name(), error = %e, "router tool fallback");
2122                    }
2123                }
2124            }
2125            Err(LlmError::NoProviders)
2126        });
2127        #[cfg(feature = "profiling")]
2128        let fut = {
2129            use tracing::Instrument as _;
2130            fut.instrument(tracing::info_span!(
2131                "llm.router.chat_with_tools",
2132                model = model,
2133                tool_count = tool_count
2134            ))
2135        };
2136        fut
2137    }
2138
2139    fn debug_request_json(
2140        &self,
2141        messages: &[Message],
2142        tools: &[ToolDefinition],
2143        stream: bool,
2144    ) -> serde_json::Value {
2145        let candidate = if tools.is_empty() {
2146            self.ordered_providers().into_iter().next()
2147        } else {
2148            self.ordered_providers()
2149                .into_iter()
2150                .find(super::provider::LlmProvider::supports_tool_use)
2151        };
2152        candidate.map_or_else(
2153            || crate::provider::default_debug_request_json(messages, tools),
2154            |provider| provider.debug_request_json(messages, tools, stream),
2155        )
2156    }
2157
    /// Always returns `None`: this implementation reports no cache-usage figures.
    fn last_cache_usage(&self) -> Option<(u64, u64)> {
        None
    }
2161}
2162
2163// ── Bandit routing helpers ────────────────────────────────────────────────────
2164
2165impl RouterProvider {
2166    /// Bandit `chat()` implementation: select provider, call, record reward.
2167    #[cfg_attr(
2168        feature = "profiling",
2169        tracing::instrument(name = "llm.router.bandit_chat", skip_all)
2170    )]
2171    async fn bandit_chat(
2172        &self,
2173        messages: &[Message],
2174        status_tx: Option<StatusTx>,
2175    ) -> Result<String, LlmError> {
2176        let query = messages
2177            .last()
2178            .map(super::provider::Message::to_llm_content)
2179            .unwrap_or_default();
2180        let features = self.bandit_features(query.as_ref()).await;
2181
2182        let p = self
2183            .bandit_select_provider(query.as_ref())
2184            .await
2185            .ok_or(LlmError::NoProviders)?;
2186
2187        if let Some(ref tx) = status_tx {
2188            let _ = tx.send(format!("bandit: routing to {}", p.name()));
2189        }
2190
2191        let result = p.chat(messages).await;
2192        match &result {
2193            Ok(response) => {
2194                let verdict = heuristic_score(response);
2195                // Record reward even when embedding failed (use zero vector so the arm's
2196                // update count increments — prevents permanent cold-start on flaky embedders).
2197                let feat_ref: &[f32];
2198                let zero_vec: Vec<f32>;
2199                let dim = self.bandit_config.as_ref().map_or(32, |c| c.dim);
2200                if let Some(ref feat) = features {
2201                    feat_ref = feat;
2202                } else {
2203                    zero_vec = vec![0.0; dim];
2204                    feat_ref = &zero_vec;
2205                    tracing::debug!(
2206                        provider = p.name(),
2207                        "bandit: recording reward with zero features (embed unavailable)"
2208                    );
2209                }
2210                self.bandit_record_reward(p.name(), feat_ref, verdict.score, 0.0);
2211            }
2212            Err(e) => {
2213                tracing::warn!(provider = p.name(), error = %e, "bandit: provider failed");
2214            }
2215        }
2216        result
2217    }
2218}
2219
2220// ── Cascade routing helpers ───────────────────────────────────────────────────
2221
/// Outcome of evaluating one provider's response during cascade routing.
///
/// Produced by `cascade_evaluate_response` and consumed by the cascade loops to decide
/// whether to accept the response or escalate to the next provider.
struct CascadeEvalResult {
    /// Quality verdict for the response (score, escalation flag, reason).
    verdict: cascade::QualityVerdict,
    /// Updated token counter after adding this response's estimated cost.
    tokens_used: u32,
    /// Whether the token budget is now exhausted.
    budget_exhausted: bool,
}
2230
2231/// Evaluate a cascade response: score it, record the verdict in shared state, and
2232/// compute whether the token budget is exhausted.
2233async fn cascade_evaluate_response(
2234    provider_name: &str,
2235    response: &str,
2236    cfg: &CascadeRouterConfig,
2237    cascade_state: &Mutex<CascadeState>,
2238    tokens_used_before: u32,
2239    log_prefix: &str,
2240) -> CascadeEvalResult {
2241    let estimated_tokens =
2242        u32::try_from(zeph_common::text::estimate_tokens(response).max(1)).unwrap_or(u32::MAX);
2243    let tokens_used = tokens_used_before.saturating_add(estimated_tokens);
2244
2245    let verdict = RouterProvider::evaluate_quality(
2246        response,
2247        cfg.quality_threshold,
2248        cfg.classifier_mode,
2249        cfg.summary_provider.as_deref(),
2250        cfg.judge_timeout_ms,
2251    )
2252    .await;
2253
2254    {
2255        let mut state = cascade_state.lock();
2256        state.record(provider_name, verdict.score);
2257    }
2258
2259    tracing::debug!(
2260        provider = %provider_name,
2261        score = verdict.score,
2262        threshold = cfg.quality_threshold,
2263        should_escalate = verdict.should_escalate,
2264        reason = %verdict.reason,
2265        "{log_prefix}: quality verdict"
2266    );
2267
2268    let budget_exhausted = cfg
2269        .max_cascade_tokens
2270        .is_some_and(|budget| tokens_used >= budget);
2271
2272    CascadeEvalResult {
2273        verdict,
2274        tokens_used,
2275        budget_exhausted,
2276    }
2277}
2278
2279impl RouterProvider {
    /// Cascade chat: try providers in order, escalate on degenerate output.
    ///
    /// Each provider in `providers` is called in turn. A successful response is scored
    /// via `cascade_evaluate_response`; it is accepted when the verdict does not ask for
    /// escalation, when the provider is the last in the chain, when no escalations
    /// remain, or when the token budget is exhausted. Otherwise the loop escalates to
    /// the next provider.
    ///
    /// When escalation is wanted but blocked (budget exhausted or no escalations left),
    /// the highest-scoring non-empty response seen so far is returned instead of the
    /// current one. The same best-seen response is returned when the loop finishes
    /// without accepting anything (e.g. the final provider errored).
    ///
    /// Provider errors are recorded as availability failures and do not consume the
    /// escalation budget.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::NoProviders`] when the loop completes without returning and
    /// no non-empty response was recorded (this includes an empty `providers` slice).
    ///
    /// # Panics
    ///
    /// Panics if `cascade_config` or `cascade_state` is not set on the router.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "llm.router.cascade_chat", skip_all)
    )]
    #[allow(clippy::too_many_lines)] // cascade loop: per-provider error/ok/budget/escalation branches are tightly coupled — extracting would obscure the control flow
    async fn cascade_chat(
        &self,
        providers: &[AnyProvider],
        messages: &[Message],
        status_tx: Option<StatusTx>,
    ) -> Result<String, LlmError> {
        let cfg = self
            .cascade_config
            .as_ref()
            .expect("cascade_config must be set");
        let cascade_state = self
            .cascade_state
            .as_ref()
            .expect("cascade_state must be set");

        let mut escalations_remaining = cfg.max_escalations;
        let mut best: Option<(String, f64)> = None; // (response, score)
        // Running estimate of tokens produced by all responses evaluated so far.
        let mut tokens_used: u32 = 0;

        for (idx, p) in providers.iter().enumerate() {
            tracing::debug!(
                provider = %p.name(),
                attempt = idx + 1,
                total = providers.len(),
                classifier_mode = ?cfg.classifier_mode,
                quality_threshold = cfg.quality_threshold,
                "cascade: trying provider"
            );
            let start = std::time::Instant::now();
            match p.chat(messages).await {
                Err(e) => {
                    // Network/API error: record availability failure but don't consume escalation budget.
                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                    self.record_availability(p.name(), false, latency);
                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
                    }
                    tracing::warn!(provider = p.name(), error = %e, "cascade: provider error");
                }
                Ok(response) => {
                    // Latency is captured before scoring so evaluation time is not
                    // attributed to the provider.
                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

                    let eval = cascade_evaluate_response(
                        p.name(),
                        &response,
                        cfg,
                        cascade_state,
                        tokens_used,
                        "cascade",
                    )
                    .await;
                    tokens_used = eval.tokens_used;
                    let verdict = eval.verdict;
                    let budget_exhausted = eval.budget_exhausted;

                    // Update best-seen response; skip empty strings to avoid silent failures.
                    // Strict `>` means an earlier response wins a tie on score.
                    let is_better = !response.is_empty()
                        && best
                            .as_ref()
                            .is_none_or(|(_, best_score)| verdict.score > *best_score);
                    if is_better {
                        tracing::debug!(
                            provider = %p.name(),
                            score = verdict.score,
                            "cascade: best_seen updated"
                        );
                        best = Some((response.clone(), verdict.score));
                    }

                    // `providers` is non-empty inside the loop, so the subtraction
                    // cannot underflow.
                    let is_last = idx == providers.len() - 1;

                    // Accept-and-return conditions: good enough, nowhere left to go,
                    // or not allowed to go further.
                    if !verdict.should_escalate
                        || is_last
                        || escalations_remaining == 0
                        || budget_exhausted
                    {
                        self.record_availability(p.name(), true, latency);
                        // When escalation is blocked (budget exhausted or escalation count
                        // at zero) and the current response would have triggered escalation,
                        // return the best-seen response instead of the current (possibly
                        // lower-quality) one.
                        if verdict.should_escalate
                            && (budget_exhausted || escalations_remaining == 0)
                        {
                            let best_response = best.take().map_or(response, |(r, _)| r);
                            tracing::info!(
                                tokens_used,
                                budget = cfg.max_cascade_tokens,
                                escalations_remaining,
                                "cascade: escalation blocked, returning best response"
                            );
                            return Ok(best_response);
                        }
                        // NOTE(review): when only `is_last` applies, the current response
                        // is returned even if an earlier one scored higher — confirm this
                        // is the intended policy.
                        return Ok(response);
                    }

                    // Escalate: record availability success (provider worked, just low quality).
                    self.record_availability(p.name(), true, latency);
                    // Cannot underflow: the `escalations_remaining == 0` case returned above.
                    escalations_remaining -= 1;

                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!(
                            "cascade: {} quality {:.2} < {:.2}, escalating ({} left)",
                            p.name(),
                            verdict.score,
                            cfg.quality_threshold,
                            escalations_remaining
                        ));
                    }
                    tracing::info!(
                        provider = %p.name(),
                        score = verdict.score,
                        threshold = cfg.quality_threshold,
                        escalations_remaining,
                        "cascade: escalating to next provider"
                    );
                }
            }
        }

        // All providers tried — return best-seen response, or NoProviders if none worked.
        if let Some((_, score)) = &best {
            tracing::info!(
                score,
                "cascade: all providers exhausted, returning best-seen response"
            );
        } else {
            tracing::warn!("cascade: all providers failed, no response available");
        }
        best.map(|(r, _)| r).ok_or(LlmError::NoProviders)
    }
2419
    /// Cascade `chat_stream`: buffer cheap response, classify, escalate or replay.
    ///
    /// Every provider except the last is fully buffered via `collect_stream`, scored via
    /// `cascade_evaluate_response`, and either replayed to the caller (accepted) or
    /// discarded in favour of the next provider (escalated). The last provider is
    /// streamed straight through without buffering or scoring.
    ///
    /// If the last provider fails to open a stream, the best buffered response from an
    /// earlier provider (when there is one) is replayed instead of returning the error.
    ///
    /// # Streaming latency tradeoff
    ///
    /// The first N-1 providers are fully buffered before classification. If escalation
    /// occurs, the user experiences: cheap model's full response time + expensive model's
    /// TTFT. This is strictly worse than direct routing to the expensive model for
    /// hard queries. Acceptable for v1; see CRIT-01 in critic handoff for alternatives.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::NoProviders`] when `providers` is empty, or the last
    /// provider's error when it fails to open a stream and no best-seen response exists.
    ///
    /// # Panics
    ///
    /// Panics if `cascade_config` or `cascade_state` is not set on the router.
    #[allow(clippy::too_many_lines)] // sequential cascade semantics: buffer→classify→escalate
    async fn cascade_chat_stream(
        &self,
        providers: &[AnyProvider],
        messages: &[Message],
        status_tx: Option<StatusTx>,
    ) -> Result<ChatStream, LlmError> {
        let cfg = self
            .cascade_config
            .as_ref()
            .expect("cascade_config must be set");
        let cascade_state = self
            .cascade_state
            .as_ref()
            .expect("cascade_state must be set");

        let mut escalations_remaining = cfg.max_escalations;
        // Running estimate of tokens produced by all buffered responses so far.
        let mut tokens_used: u32 = 0;
        // Tracks the highest-scoring fully-buffered response seen so far.
        // Only populated from the early provider loop; the last provider streams
        // directly without buffering or scoring, so it never updates best_seen.
        let mut best_seen: Option<(CollectedStream, f64)> = None;

        // Try all providers except the last without consuming the escalation budget
        // for errors (only quality failures consume it).
        let (last, early) = providers.split_last().ok_or(LlmError::NoProviders)?;

        for (idx, p) in early.iter().enumerate() {
            tracing::debug!(
                provider = %p.name(),
                attempt = idx + 1,
                total = providers.len(),
                classifier_mode = ?cfg.classifier_mode,
                quality_threshold = cfg.quality_threshold,
                "cascade stream: trying provider (buffered)"
            );
            // Buffer response to classify quality.
            let start = std::time::Instant::now();
            let stream = match p.chat_stream(messages).await {
                Err(e) => {
                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                    self.record_availability(p.name(), false, latency);
                    tracing::warn!(provider = p.name(), error = %e, "cascade stream: provider error");
                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
                    }
                    continue;
                }
                Ok(s) => s,
            };

            // Collect the full stream.
            let buffered = collect_stream(stream).await;
            // Latency here covers opening the stream plus draining it completely.
            let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

            match buffered {
                Err(e) => {
                    // Stream failed mid-delivery; treat as availability failure.
                    // NOTE(review): unlike the connect-failure branch above, no
                    // `status_tx` message is sent here — confirm whether that is intended.
                    self.record_availability(p.name(), false, latency);
                    tracing::warn!(provider = p.name(), error = %e, "cascade stream: stream error");
                }
                Ok(collected) => {
                    // Only the text content is scored; other chunk kinds are carried
                    // along for replay but do not influence the verdict.
                    let eval = cascade_evaluate_response(
                        p.name(),
                        &collected.content,
                        cfg,
                        cascade_state,
                        tokens_used,
                        "cascade stream",
                    )
                    .await;
                    tokens_used = eval.tokens_used;
                    let verdict = eval.verdict;
                    let budget_exhausted = eval.budget_exhausted;

                    // Track the best response seen so far across early providers.
                    // Skip empty responses (no content and no tool calls) to avoid
                    // returning silent failures on all-fail fallback.
                    // Strict `>` means an earlier response wins a tie on score.
                    let is_better = !collected.is_empty()
                        && best_seen
                            .as_ref()
                            .is_none_or(|(_, best_score)| verdict.score > *best_score);
                    if is_better {
                        tracing::debug!(
                            provider = %p.name(),
                            score = verdict.score,
                            "cascade stream: best_seen updated"
                        );
                        best_seen = Some((collected.clone(), verdict.score));
                    }

                    if !verdict.should_escalate || escalations_remaining == 0 || budget_exhausted {
                        self.record_availability(p.name(), true, latency);

                        // When escalation is blocked (budget exhausted or escalation count
                        // at zero) and the current response would have triggered escalation,
                        // return the best-seen response instead of the current (possibly
                        // lower-quality) one.
                        let response = if verdict.should_escalate
                            && (budget_exhausted || escalations_remaining == 0)
                        {
                            tracing::info!(
                                tokens_used,
                                budget = cfg.max_cascade_tokens,
                                escalations_remaining,
                                "cascade stream: escalation blocked, returning best response"
                            );
                            best_seen.take().map_or(collected, |(r, _)| r)
                        } else {
                            collected
                        };

                        // Replay the buffered chunks as a fresh stream for the caller.
                        return Ok(response.into_stream());
                    }

                    // Escalate.
                    self.record_availability(p.name(), true, latency);
                    // Cannot underflow: the `escalations_remaining == 0` case returned above.
                    escalations_remaining -= 1;

                    if let Some(tx) = &status_tx {
                        let _ = tx.send(format!(
                            "cascade: {} quality {:.2} < {:.2}, escalating",
                            p.name(),
                            verdict.score,
                            cfg.quality_threshold,
                        ));
                    }
                    tracing::info!(
                        provider = %p.name(),
                        score = verdict.score,
                        threshold = cfg.quality_threshold,
                        escalations_remaining,
                        "cascade stream: escalating to next provider"
                    );
                }
            }
        }

        // Last provider: stream directly without buffering.
        // Note: if the stream itself fails mid-delivery (after Ok(stream) is returned),
        // there is no fallback to best_seen — the caller receives a partial response.
        // This is a pre-existing limitation; fixing it would require wrapping the stream.
        tracing::debug!(
            provider = %last.name(),
            attempt = providers.len(),
            total = providers.len(),
            "cascade stream: trying last provider (streaming, no classification)"
        );
        let start = std::time::Instant::now();
        match last.chat_stream(messages).await {
            Ok(stream) => {
                // Success is recorded as soon as the stream opens; the latency therefore
                // reflects time-to-stream, not time to drain it.
                self.record_availability(
                    last.name(),
                    true,
                    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                );
                Ok(stream)
            }
            Err(e) => {
                self.record_availability(
                    last.name(),
                    false,
                    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                );
                // If we have a best-seen response from an early provider, return it
                // instead of propagating the last provider's error.
                if let Some((best_collected, _)) = best_seen {
                    tracing::info!(
                        "cascade stream: last provider failed, returning best-seen response"
                    );
                    return Ok(best_collected.into_stream());
                }
                Err(e)
            }
        }
    }
2604}
2605
/// Maximum bytes buffered per stream in cascade routing (SEC-CASCADE-03).
///
/// `collect_stream` checks the accumulated `Content` text against this limit and fails
/// the stream once it would be exceeded.
const CASCADE_STREAM_MAX_BYTES: usize = 1024 * 1024; // 1 MiB
2608
/// All chunks accumulated from a single provider stream, preserving non-Content chunks.
///
/// Keeping all chunk types allows the router to re-emit a buffered response faithfully
/// (including `Thinking`, `ToolUse`, and `Compaction` chunks) instead of silently
/// dropping them when a best-seen response is replayed.
#[derive(Clone, Default, Debug)]
struct CollectedStream {
    /// Concatenation of every `Content` chunk, in arrival order.
    content: String,
    /// Each `Thinking` chunk as received, in arrival order.
    thinking: Vec<String>,
    /// Tool-use requests gathered from all `ToolUse` chunks.
    tool_calls: Vec<crate::provider::ToolUseRequest>,
    /// Payload of the most recent `Compaction` chunk, if any was received.
    compaction: Option<String>,
}
2621
2622impl CollectedStream {
2623    /// Reconstructs a `ChatStream` that re-emits all accumulated chunks in order.
2624    fn into_stream(self) -> ChatStream {
2625        use crate::provider::StreamChunk;
2626        let mut chunks: Vec<Result<StreamChunk, LlmError>> = Vec::new();
2627        for t in self.thinking {
2628            chunks.push(Ok(StreamChunk::Thinking(t)));
2629        }
2630        if !self.tool_calls.is_empty() {
2631            chunks.push(Ok(StreamChunk::ToolUse(self.tool_calls)));
2632        }
2633        if let Some(c) = self.compaction {
2634            chunks.push(Ok(StreamChunk::Compaction(c)));
2635        }
2636        if !self.content.is_empty() {
2637            chunks.push(Ok(StreamChunk::Content(self.content)));
2638        }
2639        Box::pin(tokio_stream::iter(chunks))
2640    }
2641
2642    fn is_empty(&self) -> bool {
2643        self.content.is_empty() && self.tool_calls.is_empty()
2644    }
2645}
2646
2647/// Collect a `ChatStream` into a [`CollectedStream`], preserving all chunk types.
2648///
2649/// Returns `Err` if the accumulated `Content` buffer exceeds [`CASCADE_STREAM_MAX_BYTES`].
2650async fn collect_stream(stream: ChatStream) -> Result<CollectedStream, LlmError> {
2651    use tokio_stream::StreamExt as _;
2652
2653    let mut stream = stream;
2654    let mut collected = CollectedStream::default();
2655    while let Some(chunk) = stream.next().await {
2656        match chunk? {
2657            crate::provider::StreamChunk::Content(c) => {
2658                if collected.content.len() + c.len() > CASCADE_STREAM_MAX_BYTES {
2659                    return Err(LlmError::Other(
2660                        "cascade: stream response exceeds 1 MiB buffer limit".into(),
2661                    ));
2662                }
2663                collected.content.push_str(&c);
2664            }
2665            crate::provider::StreamChunk::Thinking(t) => {
2666                collected.thinking.push(t);
2667            }
2668            crate::provider::StreamChunk::ToolUse(tools) => {
2669                collected.tool_calls.extend(tools);
2670            }
2671            crate::provider::StreamChunk::Compaction(c) => {
2672                collected.compaction = Some(c);
2673            }
2674        }
2675    }
2676    Ok(collected)
2677}
2678
2679#[cfg(test)]
2680mod tests {
2681    use super::*;
2682    use crate::provider::Role;
2683
2684    #[test]
2685    fn empty_router_name() {
2686        let r = RouterProvider::new(vec![]);
2687        assert_eq!(r.name(), "router");
2688    }
2689
2690    #[test]
2691    fn empty_router_supports_nothing() {
2692        let r = RouterProvider::new(vec![]);
2693        assert!(!r.supports_streaming());
2694        assert!(!r.supports_embeddings());
2695        assert!(!r.supports_tool_use());
2696    }
2697
2698    #[test]
2699    fn empty_router_context_window_none() {
2700        let r = RouterProvider::new(vec![]);
2701        assert!(r.context_window().is_none());
2702    }
2703
2704    #[tokio::test]
2705    async fn empty_router_chat_returns_no_providers() {
2706        let r = RouterProvider::new(vec![]);
2707        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2708        let err = r.chat(&msgs).await.unwrap_err();
2709        assert!(matches!(err, LlmError::NoProviders));
2710    }
2711
2712    #[tokio::test]
2713    async fn empty_router_chat_stream_returns_no_providers() {
2714        let r = RouterProvider::new(vec![]);
2715        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2716        let result = r.chat_stream(&msgs).await;
2717        assert!(matches!(result, Err(LlmError::NoProviders)));
2718    }
2719
2720    #[tokio::test]
2721    async fn empty_router_embed_returns_no_providers() {
2722        let r = RouterProvider::new(vec![]);
2723        let err = r.embed("test").await.unwrap_err();
2724        assert!(matches!(err, LlmError::NoProviders));
2725    }
2726
2727    #[tokio::test]
2728    async fn empty_router_chat_with_tools_returns_no_providers() {
2729        let r = RouterProvider::new(vec![]);
2730        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2731        let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
2732        assert!(matches!(err, LlmError::NoProviders));
2733    }
2734
2735    #[tokio::test]
2736    async fn router_falls_back_on_unreachable() {
2737        use crate::ollama::OllamaProvider;
2738
2739        let p1 = AnyProvider::Ollama(OllamaProvider::new(
2740            "http://127.0.0.1:1",
2741            "m".into(),
2742            "e".into(),
2743        ));
2744        let p2 = AnyProvider::Ollama(OllamaProvider::new(
2745            "http://127.0.0.1:2",
2746            "m".into(),
2747            "e".into(),
2748        ));
2749        let r = RouterProvider::new(vec![p1, p2]);
2750        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2751        let err = r.chat(&msgs).await.unwrap_err();
2752        assert!(matches!(err, LlmError::NoProviders));
2753    }
2754
2755    #[test]
2756    fn router_with_streaming_provider() {
2757        use crate::ollama::OllamaProvider;
2758
2759        let p = AnyProvider::Ollama(OllamaProvider::new(
2760            "http://127.0.0.1:1",
2761            "m".into(),
2762            "e".into(),
2763        ));
2764        let r = RouterProvider::new(vec![p]);
2765        assert!(r.supports_streaming());
2766        assert!(r.supports_embeddings());
2767    }
2768
2769    #[test]
2770    fn clone_preserves_providers() {
2771        use crate::ollama::OllamaProvider;
2772
2773        let p = AnyProvider::Ollama(OllamaProvider::new(
2774            "http://127.0.0.1:1",
2775            "m".into(),
2776            "e".into(),
2777        ));
2778        let r = RouterProvider::new(vec![p]);
2779        let c = r.clone();
2780        assert_eq!(c.state.providers.len(), 1);
2781        assert_eq!(c.name(), "router");
2782    }
2783
2784    #[test]
2785    fn last_cache_usage_returns_none() {
2786        let r = RouterProvider::new(vec![]);
2787        assert!(r.last_cache_usage().is_none());
2788    }
2789
2790    #[test]
2791    fn thompson_strategy_is_set() {
2792        let r = RouterProvider::new(vec![]).with_thompson(None);
2793        assert_eq!(r.strategy, RouterStrategy::Thompson);
2794        assert!(r.thompson.is_some());
2795    }
2796
2797    #[tokio::test]
2798    async fn save_thompson_state_noop_without_thompson() {
2799        let r = RouterProvider::new(vec![]);
2800        r.save_thompson_state().await; // should not panic
2801    }
2802
2803    #[test]
2804    fn thompson_ordered_providers_empty() {
2805        let r = RouterProvider::new(vec![]).with_thompson(None);
2806        let ordered = r.ordered_providers();
2807        assert!(ordered.is_empty());
2808    }
2809
2810    #[test]
2811    fn concurrent_record_outcome_does_not_deadlock() {
2812        use std::sync::Arc;
2813        let r = Arc::new(RouterProvider::new(vec![]).with_thompson(None));
2814        let handles: Vec<_> = (0..8)
2815            .map(|i| {
2816                let router = Arc::clone(&r);
2817                std::thread::spawn(move || {
2818                    router.record_availability(&format!("p{i}"), i % 2 == 0, 10);
2819                })
2820            })
2821            .collect();
2822        for h in handles {
2823            h.join().expect("thread panicked");
2824        }
2825        // If we reach here, no deadlock occurred.
2826        let stats = r.thompson_stats();
2827        assert_eq!(stats.len(), 8);
2828    }
2829
2830    // ── Cascade tests ──────────────────────────────────────────────────────────
2831
2832    #[test]
2833    fn cascade_strategy_is_set() {
2834        let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2835        assert_eq!(r.strategy, RouterStrategy::Cascade);
2836        assert!(r.cascade_state.is_some());
2837        assert!(r.cascade_config.is_some());
2838    }
2839
2840    #[test]
2841    fn cascade_ordered_providers_preserves_chain_order() {
2842        use crate::ollama::OllamaProvider;
2843        let p1 = AnyProvider::Ollama(OllamaProvider::new(
2844            "http://127.0.0.1:1",
2845            "a".into(),
2846            String::new(),
2847        ));
2848        let p2 = AnyProvider::Ollama(OllamaProvider::new(
2849            "http://127.0.0.1:2",
2850            "b".into(),
2851            String::new(),
2852        ));
2853        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2854        let ordered = r.ordered_providers();
2855        assert_eq!(ordered.len(), 2);
2856    }
2857
2858    #[tokio::test]
2859    async fn cascade_empty_router_returns_no_providers() {
2860        let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2861        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2862        let err = r.chat(&msgs).await.unwrap_err();
2863        assert!(matches!(err, LlmError::NoProviders));
2864    }
2865
2866    #[tokio::test]
2867    async fn cascade_returns_best_seen_when_all_fail_after_good_response() {
2868        use crate::mock::MockProvider;
2869
2870        // Provider 1: returns low-quality response (short "ok", triggers escalation at 0.9 threshold)
2871        let cheap =
2872            AnyProvider::Mock(MockProvider::with_responses(vec!["ok".to_owned()]).with_delay(0));
2873        // Provider 2: fails with availability error
2874        let expensive = AnyProvider::Mock(MockProvider::failing());
2875
2876        let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2877            quality_threshold: 0.9, // high threshold ensures "ok" fails quality check
2878            max_escalations: 2,
2879            ..CascadeRouterConfig::default()
2880        });
2881        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2882        // Should return "ok" from cheap provider (best-seen), not NoProviders.
2883        let result = r.chat(&msgs).await.unwrap();
2884        assert_eq!(result, "ok");
2885    }
2886
2887    #[tokio::test]
2888    async fn cascade_accepts_good_quality_response() {
2889        use crate::mock::MockProvider;
2890
2891        let good_response = "This is a comprehensive, well-structured response that provides \
2892            detailed information about the topic. It covers multiple aspects and explains \
2893            the reasoning clearly with proper sentence structure.";
2894
2895        let cheap = AnyProvider::Mock(
2896            MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2897        );
2898        // second provider should never be called
2899        let expensive = AnyProvider::Mock(MockProvider::failing());
2900
2901        let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2902            quality_threshold: 0.5,
2903            max_escalations: 1,
2904            ..CascadeRouterConfig::default()
2905        });
2906        let msgs = vec![Message::from_legacy(Role::User, "explain something")];
2907        let result = r.chat(&msgs).await.unwrap();
2908        assert_eq!(result, good_response);
2909    }
2910
2911    #[tokio::test]
2912    async fn cascade_max_escalations_budget_exhausted_returns_last_attempted() {
2913        use crate::mock::MockProvider;
2914
2915        // All three providers return degenerate response "x" but budget limits to 1 escalation.
2916        // p1 -> escalation budget 1 -> p2 -> budget=0 -> accept p2's response (not p3).
2917        let p1 =
2918            AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2919        let p2 =
2920            AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2921        let p3 = AnyProvider::Mock(MockProvider::failing()); // should never be reached
2922
2923        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2924            quality_threshold: 0.9,
2925            max_escalations: 1, // only 1 escalation allowed
2926            ..CascadeRouterConfig::default()
2927        });
2928        let msgs = vec![Message::from_legacy(Role::User, "test")];
2929        let result = r.chat(&msgs).await.unwrap();
2930        assert_eq!(result, "x");
2931    }
2932
2933    #[tokio::test]
2934    async fn cascade_token_budget_stops_escalation() {
2935        use crate::mock::MockProvider;
2936
2937        let p1 =
2938            AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2939        let p2 = AnyProvider::Mock(MockProvider::failing()); // should not be reached
2940
2941        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2942            quality_threshold: 0.9, // "x" will fail quality
2943            max_escalations: 5,
2944            max_cascade_tokens: Some(1), // 1 token budget — exhausted after first response (~4 chars / 4 = 0 + 1 min)
2945            ..CascadeRouterConfig::default()
2946        });
2947        let msgs = vec![Message::from_legacy(Role::User, "test")];
2948        let result = r.chat(&msgs).await.unwrap();
2949        assert_eq!(result, "x"); // returned despite low quality due to token budget
2950    }
2951
2952    #[tokio::test]
2953    async fn cascade_budget_returns_best_seen_not_current() {
2954        use crate::mock::MockProvider;
2955
2956        // p1 returns a decent response, p2 returns a worse one but exhausts the budget.
2957        // With budget_exhausted, we should get the best-seen (p1) not the current (p2).
2958        let good_response = "This is a reasonable response with enough content to score well.";
2959        let bad_response = "x"; // degenerate, score << good_response
2960
2961        let p1 = AnyProvider::Mock(
2962            MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2963        );
2964        let p2 = AnyProvider::Mock(
2965            MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2966        );
2967
2968        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2969            quality_threshold: 0.95, // both fail quality check but good > bad
2970            max_escalations: 5,
2971            max_cascade_tokens: Some(1), // budget exhausted after p1 (1 token min)
2972            ..CascadeRouterConfig::default()
2973        });
2974        let msgs = vec![Message::from_legacy(Role::User, "test")];
2975        // p1 exhausts the budget; should return p1's response (better), not p2's (worse).
2976        // Note: p2 is reached since budget check happens AFTER p1's response is processed
2977        // and p1 fails quality. Budget exhausted at p2 → return best-seen (p1).
2978        let result = r.chat(&msgs).await.unwrap();
2979        // The result must not be the degenerate "x" response.
2980        assert_ne!(result, bad_response, "should return best-seen, not current");
2981    }
2982
2983    #[tokio::test]
2984    async fn cascade_escalations_exhausted_returns_best_seen_not_current() {
2985        use crate::mock::MockProvider;
2986
2987        // p1: decent response, fails quality at 0.95 → escalates (escalations_remaining: 1 → 0)
2988        // p2: degenerate "x", fails quality → escalations_remaining == 0 → blocked → best_seen wins
2989        let good_response = "This is a reasonable response with enough content to score well.";
2990        let bad_response = "x";
2991
2992        let p1 = AnyProvider::Mock(
2993            MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2994        );
2995        let p2 = AnyProvider::Mock(
2996            MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2997        );
2998        let p3 = AnyProvider::Mock(MockProvider::failing()); // should not be reached
2999
3000        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3001            quality_threshold: 0.95, // both fail quality; p1 score > p2 score
3002            max_escalations: 1,      // p1 escalates (budget: 1→0), p2 is blocked
3003            ..CascadeRouterConfig::default()
3004        });
3005        let msgs = vec![Message::from_legacy(Role::User, "test")];
3006        let result = r.chat(&msgs).await.unwrap();
3007        assert_eq!(
3008            result, good_response,
3009            "should return best-seen (p1), not the degenerate current response (p2)"
3010        );
3011        assert_ne!(
3012            result, bad_response,
3013            "must not return degenerate p2 response"
3014        );
3015    }
3016
3017    #[tokio::test]
3018    async fn cascade_stream_escalations_exhausted_returns_best_seen_not_current() {
3019        use crate::mock::MockProvider;
3020
3021        // Same scenario as above but for cascade_chat_stream.
3022        // p1: decent response, fails quality at 0.95 → escalates (escalations_remaining: 1 → 0)
3023        // p2: degenerate "x", fails quality → escalations_remaining == 0 → return best_seen
3024        let good_response = "This is a reasonable response with enough content to score well.";
3025        let bad_response = "x";
3026
3027        let p1 = AnyProvider::Mock(
3028            MockProvider::with_responses(vec![good_response.to_owned()])
3029                .with_delay(0)
3030                .with_streaming(),
3031        );
3032        let p2 = AnyProvider::Mock(
3033            MockProvider::with_responses(vec![bad_response.to_owned()])
3034                .with_delay(0)
3035                .with_streaming(),
3036        );
3037        let p3 = AnyProvider::Mock(MockProvider::failing()); // last provider, should not be reached
3038
3039        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3040            quality_threshold: 0.95, // both fail quality; p1 score > p2 score
3041            max_escalations: 1,      // p1 escalates (budget: 1→0), p2 is blocked
3042            ..CascadeRouterConfig::default()
3043        });
3044        let msgs = vec![Message::from_legacy(Role::User, "test")];
3045        let stream = r.chat_stream(&msgs).await.unwrap();
3046        let collected = collect_stream(stream).await.unwrap();
3047        assert_eq!(
3048            collected.content, good_response,
3049            "should return best-seen (p1), not the degenerate current response (p2)"
3050        );
3051        assert_ne!(
3052            collected.content, bad_response,
3053            "must not return degenerate p2 response"
3054        );
3055    }
3056
3057    #[tokio::test]
3058    async fn cascade_all_providers_fail_returns_no_providers() {
3059        use crate::mock::MockProvider;
3060
3061        let p1 = AnyProvider::Mock(MockProvider::failing());
3062        let p2 = AnyProvider::Mock(MockProvider::failing());
3063
3064        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3065        let msgs = vec![Message::from_legacy(Role::User, "test")];
3066        let err = r.chat(&msgs).await.unwrap_err();
3067        assert!(matches!(err, LlmError::NoProviders));
3068    }
3069
3070    #[tokio::test]
3071    async fn cascade_stream_good_quality_no_escalation() {
3072        use crate::mock::MockProvider;
3073
3074        let good = "This is a well-formed response with sufficient length and coherent structure.";
3075        let p1 = AnyProvider::Mock(
3076            MockProvider::with_responses(vec![good.to_owned()])
3077                .with_delay(0)
3078                .with_streaming(),
3079        );
3080        let p2 = AnyProvider::Mock(MockProvider::failing());
3081
3082        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3083            quality_threshold: 0.5,
3084            max_escalations: 1,
3085            ..CascadeRouterConfig::default()
3086        });
3087        let msgs = vec![Message::from_legacy(Role::User, "q")];
3088        let stream = r.chat_stream(&msgs).await.unwrap();
3089        let collected = collect_stream(stream).await.unwrap();
3090        assert_eq!(collected.content, good);
3091    }
3092
3093    #[tokio::test]
3094    async fn cascade_stream_escalates_to_last_provider() {
3095        use crate::mock::MockProvider;
3096
3097        let bad = "x"; // low quality, should escalate
3098        let good = "This is the expensive model's comprehensive response.";
3099        let p1 = AnyProvider::Mock(
3100            MockProvider::with_responses(vec![bad.to_owned()])
3101                .with_delay(0)
3102                .with_streaming(),
3103        );
3104        let p2 = AnyProvider::Mock(
3105            MockProvider::with_responses(vec![good.to_owned()])
3106                .with_delay(0)
3107                .with_streaming(),
3108        );
3109
3110        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3111            quality_threshold: 0.9, // "x" fails quality
3112            max_escalations: 1,
3113            ..CascadeRouterConfig::default()
3114        });
3115        let msgs = vec![Message::from_legacy(Role::User, "q")];
3116        let stream = r.chat_stream(&msgs).await.unwrap();
3117        let collected = collect_stream(stream).await.unwrap();
3118        assert_eq!(collected.content, good);
3119    }
3120
3121    #[tokio::test]
3122    async fn cascade_stream_budget_returns_best_seen() {
3123        use crate::mock::MockProvider;
3124
3125        // Three providers: early=[p1, p2], last=p3.
3126        // p1 returns a decent response (fails quality threshold at 0.95, triggers escalation).
3127        // Budget is set to 1 token, so it is exhausted immediately after p1 processes.
3128        // best_seen = p1's response; budget_exhausted + should_escalate → return best_seen.
3129        let good_response = "This is a reasonable response with enough content to score well.";
3130        let bad_response = "x"; // degenerate, score << good_response
3131
3132        let p1 = AnyProvider::Mock(
3133            MockProvider::with_responses(vec![good_response.to_owned()])
3134                .with_delay(0)
3135                .with_streaming(),
3136        );
3137        let p2 = AnyProvider::Mock(
3138            MockProvider::with_responses(vec![bad_response.to_owned()])
3139                .with_delay(0)
3140                .with_streaming(),
3141        );
3142        let p3 = AnyProvider::Mock(MockProvider::failing()); // last provider, not reached
3143
3144        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3145            quality_threshold: 0.95, // p1 fails quality check → triggers escalation path
3146            max_escalations: 5,
3147            max_cascade_tokens: Some(1), // budget exhausted after p1 (1 token min)
3148            ..CascadeRouterConfig::default()
3149        });
3150        let msgs = vec![Message::from_legacy(Role::User, "test")];
3151        let stream = r.chat_stream(&msgs).await.unwrap();
3152        let collected = collect_stream(stream).await.unwrap();
3153        // Must return best-seen (p1's good response).
3154        assert_eq!(
3155            collected.content, good_response,
3156            "should return best-seen p1 response when budget exhausted"
3157        );
3158    }
3159
3160    #[tokio::test]
3161    async fn cascade_stream_budget_returns_best_seen_not_current() {
3162        use crate::mock::MockProvider;
3163
3164        // Four providers: early=[p1, p2, p3], last=p4.
3165        // p1 returns a good response, fails quality at 0.95 (score ~0.6), escalates; budget not yet exhausted.
3166        // p2 returns a degenerate response "x", fails quality, exhausts the budget.
3167        // At budget exhaustion: best_seen = p1 (higher score), current = p2's "x".
3168        // Must return best_seen (p1), not current (p2).
3169        let good_response = "This is a reasonable response with enough content to score well.";
3170        let bad_response = "x"; // 1 char → estimated_tokens = max(1/4, 1) = 1
3171
3172        let p1 = AnyProvider::Mock(
3173            MockProvider::with_responses(vec![good_response.to_owned()])
3174                .with_delay(0)
3175                .with_streaming(),
3176        );
3177        let p2 = AnyProvider::Mock(
3178            MockProvider::with_responses(vec![bad_response.to_owned()])
3179                .with_delay(0)
3180                .with_streaming(),
3181        );
3182        let p3 = AnyProvider::Mock(MockProvider::failing()); // last provider, not reached
3183        let p4 = AnyProvider::Mock(MockProvider::failing()); // last provider, not reached
3184
3185        // Budget = 20: p1 uses ~16 tokens (65 chars / 4), p2 uses 1 → total 17 ≥ 20? No.
3186        // Use budget = 17 so p2 exhausts it.
3187        let r = RouterProvider::new(vec![p1, p2, p3, p4]).with_cascade(CascadeRouterConfig {
3188            quality_threshold: 0.95, // both fail; p1 score > p2 score
3189            max_escalations: 5,
3190            max_cascade_tokens: Some(17), // p1 uses 16, p2 uses 1 → total 17 ≥ 17 after p2
3191            ..CascadeRouterConfig::default()
3192        });
3193        let msgs = vec![Message::from_legacy(Role::User, "test")];
3194        let stream = r.chat_stream(&msgs).await.unwrap();
3195        let collected = collect_stream(stream).await.unwrap();
3196        // Must return p1 (best_seen), not p2 (current at time of budget exhaustion).
3197        assert_eq!(
3198            collected.content, good_response,
3199            "should return best-seen (p1), not current degenerate (p2)"
3200        );
3201        assert_ne!(
3202            collected.content, bad_response,
3203            "must not return the degenerate p2 response"
3204        );
3205    }
3206
3207    #[tokio::test]
3208    async fn cascade_stream_last_fails_returns_best_seen() {
3209        use crate::mock::MockProvider;
3210
3211        // Two providers: early=[p1], last=p2.
3212        // p1 returns a low-quality response that triggers escalation.
3213        // p2 (last) fails with an error.
3214        // Should return p1's response (best-seen) instead of propagating the error.
3215        let low_quality = "ok"; // short, triggers escalation at 0.9 threshold
3216        let p1 = AnyProvider::Mock(
3217            MockProvider::with_responses(vec![low_quality.to_owned()])
3218                .with_delay(0)
3219                .with_streaming(),
3220        );
3221        let p2 = AnyProvider::Mock(MockProvider::failing()); // last provider fails
3222
3223        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3224            quality_threshold: 0.9, // "ok" fails quality, triggers escalation
3225            max_escalations: 2,
3226            ..CascadeRouterConfig::default()
3227        });
3228        let msgs = vec![Message::from_legacy(Role::User, "hello")];
3229        let stream = r.chat_stream(&msgs).await.unwrap();
3230        let collected = collect_stream(stream).await.unwrap();
3231        assert_eq!(collected.content, low_quality);
3232    }
3233
3234    #[tokio::test]
3235    async fn cascade_stream_all_fail_returns_error() {
3236        use crate::mock::MockProvider;
3237
3238        // Two providers, both fail. No best_seen accumulated.
3239        // p1 is early (errors → continue), p2 is last (errors → propagated).
3240        // The last provider's error must be propagated, not swallowed.
3241        let p1 = AnyProvider::Mock(MockProvider::failing());
3242        let p2 = AnyProvider::Mock(MockProvider::failing());
3243
3244        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3245        let msgs = vec![Message::from_legacy(Role::User, "test")];
3246        let result = r.chat_stream(&msgs).await;
3247        assert!(
3248            result.is_err(),
3249            "expected error when all providers fail with no best_seen"
3250        );
3251    }
3252
3253    #[test]
3254    fn cascade_config_default_values() {
3255        let cfg = CascadeRouterConfig::default();
3256        assert!((cfg.quality_threshold - 0.5).abs() < f64::EPSILON);
3257        assert_eq!(cfg.max_escalations, 2);
3258        assert_eq!(cfg.window_size, 50);
3259        assert!(cfg.max_cascade_tokens.is_none());
3260        assert_eq!(cfg.classifier_mode, cascade::ClassifierMode::Heuristic);
3261    }
3262
3263    #[test]
3264    fn evaluate_heuristic_empty_should_escalate_above_threshold() {
3265        let verdict = RouterProvider::evaluate_heuristic("", 0.05);
3266        // score = 0.0, threshold = 0.05 → should_escalate = true
3267        assert!(verdict.should_escalate);
3268    }
3269
3270    #[test]
3271    fn evaluate_heuristic_good_response_does_not_escalate() {
3272        let text = "The answer to your question is straightforward. Consider the options and pick the best one.";
3273        let verdict = RouterProvider::evaluate_heuristic(text, 0.5);
3274        assert!(!verdict.should_escalate, "score={}", verdict.score);
3275    }
3276
3277    /// Empty string from the only provider must not be stored as `best_seen`.
3278    /// When all providers fail or return empty, the caller should get an error,
3279    /// not a silent empty response.
3280    #[tokio::test]
3281    async fn cascade_empty_response_not_stored_as_best_seen() {
3282        use crate::mock::MockProvider;
3283
3284        // Single provider returns empty string (score=0.0, should_escalate may be true/false).
3285        // With quality_threshold=0.0 it won't escalate, so we can check the return value.
3286        let p = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
3287        let cfg = CascadeRouterConfig {
3288            quality_threshold: 0.0,
3289            ..Default::default()
3290        };
3291        let r = RouterProvider::new(vec![p]).with_cascade(cfg);
3292        let msgs = vec![Message::from_legacy(Role::User, "hi")];
3293        // The provider returns "" — cascade must return it as-is (no best_seen involved
3294        // with a single provider), but this test confirms "" is not stored when escalating.
3295        let result = r.chat(&msgs).await;
3296        assert!(result.is_ok());
3297        assert_eq!(result.unwrap(), "");
3298    }
3299
3300    /// When provider 1 returns empty and provider 2 fails, `best_seen` must not hold
3301    /// the empty string — the caller must get an error, not a silent empty response.
3302    #[tokio::test]
3303    async fn cascade_empty_best_seen_not_returned_on_all_fail() {
3304        use crate::mock::MockProvider;
3305
3306        // p1: returns empty string (causes escalation with default threshold)
3307        // p2: hard error
3308        let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
3309        let p2 = AnyProvider::Mock(MockProvider::failing());
3310
3311        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3312        let msgs = vec![Message::from_legacy(Role::User, "hi")];
3313        let result = r.chat(&msgs).await;
3314        // best_seen must NOT be the empty string; error must propagate.
3315        assert!(
3316            result.is_err(),
3317            "expected error, not silent empty string; got: {result:?}"
3318        );
3319    }
3320
3321    /// Stream variant: empty string from early provider must not be stored as `best_seen`.
3322    #[tokio::test]
3323    async fn cascade_stream_empty_response_not_stored_as_best_seen() {
3324        use crate::mock::MockProvider;
3325
3326        // p1 (early): returns "" — should NOT be stored as best_seen.
3327        // p2 (last): returns a real response.
3328        let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
3329        let p2 = AnyProvider::Mock(
3330            MockProvider::with_responses(vec!["real answer".to_owned()]).with_streaming(),
3331        );
3332
3333        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
3334        let msgs = vec![Message::from_legacy(Role::User, "hi")];
3335        let stream = r.chat_stream(&msgs).await.expect("should not error");
3336        let collected = collect_stream(stream).await.expect("stream should succeed");
3337        assert_eq!(collected.content, "real answer");
3338    }
3339
3340    // ── Arc<[AnyProvider]> + cost_tiers tests ──────────────────────────────────
3341
3342    #[test]
3343    fn arc_providers_clone_shares_allocation() {
3344        use crate::mock::MockProvider;
3345        let p = AnyProvider::Mock(MockProvider::default());
3346        let r = RouterProvider::new(vec![p]);
3347        let c = r.clone();
3348        // Both RouterProvider instances must share the same Arc allocation.
3349        assert!(Arc::ptr_eq(&r.state.providers, &c.state.providers));
3350    }
3351
3352    #[test]
3353    fn cost_tiers_reorders_providers_at_construction() {
3354        use crate::mock::MockProvider;
3355        let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3356        let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3357        let p3 = AnyProvider::Mock(MockProvider::default().with_name("openai"));
3358        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3359            cost_tiers: Some(vec!["ollama".into(), "claude".into()]),
3360            ..CascadeRouterConfig::default()
3361        });
3362        let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3363        // ollama first (tier 0), claude second (tier 1), openai last (unlisted, original idx 2)
3364        assert_eq!(names, vec!["ollama", "claude", "openai"]);
3365    }
3366
3367    #[test]
3368    fn cost_tiers_none_preserves_chain_order() {
3369        use crate::mock::MockProvider;
3370        let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3371        let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3372        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3373            cost_tiers: None,
3374            ..CascadeRouterConfig::default()
3375        });
3376        let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3377        assert_eq!(names, vec!["claude", "ollama"]);
3378    }
3379
3380    #[test]
3381    fn cost_tiers_empty_vec_preserves_chain_order() {
3382        use crate::mock::MockProvider;
3383        let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3384        let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3385        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3386            cost_tiers: Some(vec![]),
3387            ..CascadeRouterConfig::default()
3388        });
3389        let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3390        assert_eq!(names, vec!["claude", "ollama"]);
3391    }
3392
3393    #[test]
3394    fn cost_tiers_unknown_name_ignored() {
3395        use crate::mock::MockProvider;
3396        let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3397        let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3398        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3399            cost_tiers: Some(vec!["nonexistent".into(), "ollama".into()]),
3400            ..CascadeRouterConfig::default()
3401        });
3402        let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3403        // "nonexistent" ignored; "ollama" is tier 1 → first; "claude" unlisted → second
3404        assert_eq!(names, vec!["ollama", "claude"]);
3405    }
3406
3407    #[test]
3408    fn cost_tiers_all_providers_listed() {
3409        use crate::mock::MockProvider;
3410        let p1 = AnyProvider::Mock(MockProvider::default().with_name("c"));
3411        let p2 = AnyProvider::Mock(MockProvider::default().with_name("b"));
3412        let p3 = AnyProvider::Mock(MockProvider::default().with_name("a"));
3413        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
3414            cost_tiers: Some(vec!["a".into(), "b".into(), "c".into()]),
3415            ..CascadeRouterConfig::default()
3416        });
3417        let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3418        assert_eq!(names, vec!["a", "b", "c"]);
3419    }
3420
3421    #[test]
3422    fn cost_tiers_duplicate_name_uses_last_position() {
3423        use crate::mock::MockProvider;
3424        let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
3425        let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
3426        // "ollama" appears twice in tiers: HashMap overwrites → position 2.
3427        // claude=tier 0, ollama=tier 2 → claude before ollama.
3428        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3429            cost_tiers: Some(vec!["claude".into(), "ollama".into(), "ollama".into()]),
3430            ..CascadeRouterConfig::default()
3431        });
3432        let names: Vec<&str> = r.state.providers.iter().map(LlmProvider::name).collect();
3433        assert_eq!(names, vec!["claude", "ollama"]);
3434    }
3435
3436    #[test]
3437    fn cost_tiers_empty_router_does_not_panic() {
3438        let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig {
3439            cost_tiers: Some(vec!["foo".into()]),
3440            ..CascadeRouterConfig::default()
3441        });
3442        assert_eq!(r.state.providers.len(), 0);
3443    }
3444
3445    #[test]
3446    fn set_status_tx_works_with_arc() {
3447        use crate::mock::MockProvider;
3448        let p = AnyProvider::Mock(MockProvider::default());
3449        let mut r = RouterProvider::new(vec![p]);
3450        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
3451        r.set_status_tx(tx); // must not panic
3452    }
3453
3454    #[tokio::test]
3455    async fn cascade_chat_with_tools_unaffected_by_cost_tiers() {
3456        use crate::mock::MockProvider;
3457        // chat_with_tools skips cascade entirely (HIGH-04). Verify that cost_tiers
3458        // ordering does not accidentally affect the non-cascade tool fallback path.
3459        let p1 = AnyProvider::Mock(MockProvider::failing().with_name("cheap"));
3460        let p2 = AnyProvider::Mock(MockProvider::failing().with_name("expensive"));
3461        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
3462            cost_tiers: Some(vec!["cheap".into()]),
3463            ..CascadeRouterConfig::default()
3464        });
3465        let msgs = vec![Message::from_legacy(Role::User, "hi")];
3466        // Both providers fail → NoProviders, not a cascade-specific error.
3467        let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
3468        assert!(matches!(err, LlmError::NoProviders));
3469    }
3470
3471    // ── Embed retry / rate-limit tests ────────────────────────────────────────
3472
3473    /// Provider returns `RateLimited` twice then succeeds on the third attempt.
3474    /// The router must retry and return the successful embedding.
3475    #[tokio::test]
3476    async fn embed_retries_on_rate_limited_then_succeeds() {
3477        use crate::mock::MockProvider;
3478
3479        let p = AnyProvider::Mock({
3480            let mut m = MockProvider::default()
3481                .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
3482                .with_name("p1");
3483            m.supports_embeddings = true;
3484            m.embedding = vec![0.1, 0.2];
3485            m
3486        });
3487        let r = RouterProvider::new(vec![p]);
3488        let result = r.embed("text").await.unwrap();
3489        assert_eq!(result, vec![0.1, 0.2]);
3490    }
3491
3492    /// When all retries (3) are exhausted on the first provider, the router falls
3493    /// back to the second provider and returns its embedding.
3494    #[tokio::test]
3495    async fn embed_falls_back_after_all_retries_exhausted() {
3496        use crate::mock::MockProvider;
3497
3498        // p1: 4 RateLimited errors (attempt 0..=3 all fail)
3499        let p1 = AnyProvider::Mock({
3500            let mut m = MockProvider::default()
3501                .with_errors(vec![
3502                    LlmError::RateLimited,
3503                    LlmError::RateLimited,
3504                    LlmError::RateLimited,
3505                    LlmError::RateLimited,
3506                ])
3507                .with_name("p1");
3508            m.supports_embeddings = true;
3509            m
3510        });
3511        let p2 = AnyProvider::Mock({
3512            let mut m = MockProvider::default().with_name("p2");
3513            m.supports_embeddings = true;
3514            m.embedding = vec![9.0, 8.0];
3515            m
3516        });
3517        let r = RouterProvider::new(vec![p1, p2]);
3518        let result = r.embed("text").await.unwrap();
3519        assert_eq!(result, vec![9.0, 8.0]);
3520    }
3521
3522    /// Provider returns `RateLimited` twice then succeeds via `embed_batch`.
3523    #[tokio::test]
3524    async fn embed_batch_retries_on_rate_limited_then_succeeds() {
3525        use crate::mock::MockProvider;
3526
3527        let p = AnyProvider::Mock({
3528            let mut m = MockProvider::default()
3529                .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
3530                .with_name("p1");
3531            m.supports_embeddings = true;
3532            m.embedding = vec![0.5, 0.6];
3533            m
3534        });
3535        let r = RouterProvider::new(vec![p]);
3536        let result = r.embed_batch(&["a", "b"]).await.unwrap();
3537        assert_eq!(result, vec![vec![0.5, 0.6], vec![0.5, 0.6]]);
3538    }
3539
3540    /// When all `embed_batch` retries are exhausted on the first provider, falls back
3541    /// to the second provider.
3542    #[tokio::test]
3543    async fn embed_batch_falls_back_after_all_retries_exhausted() {
3544        use crate::mock::MockProvider;
3545
3546        // p1 needs 4 errors per embed call * 1 text = 4 total (attempt 0..=3)
3547        let p1 = AnyProvider::Mock({
3548            let mut m = MockProvider::default()
3549                .with_errors(vec![
3550                    LlmError::RateLimited,
3551                    LlmError::RateLimited,
3552                    LlmError::RateLimited,
3553                    LlmError::RateLimited,
3554                ])
3555                .with_name("p1");
3556            m.supports_embeddings = true;
3557            m
3558        });
3559        let p2 = AnyProvider::Mock({
3560            let mut m = MockProvider::default().with_name("p2");
3561            m.supports_embeddings = true;
3562            m.embedding = vec![7.0, 8.0];
3563            m
3564        });
3565        let r = RouterProvider::new(vec![p1, p2]);
3566        let result = r.embed_batch(&["x"]).await.unwrap();
3567        assert_eq!(result, vec![vec![7.0, 8.0]]);
3568    }
3569
3570    // ── InvalidInput embed break tests ────────────────────────────────────────
3571
3572    /// When a provider returns `InvalidInput` from `embed()`, the router must break
3573    /// the fallback loop immediately and return `InvalidInput` — not `NoProviders`.
3574    #[tokio::test]
3575    async fn embed_invalid_input_breaks_loop_and_returns_invalid_input() {
3576        use crate::mock::MockProvider;
3577
3578        let p = AnyProvider::Mock(MockProvider::default().with_embed_invalid_input());
3579        let r = RouterProvider::new(vec![p]).with_thompson(None);
3580        let err = r.embed("some text").await.unwrap_err();
3581        assert!(
3582            matches!(err, LlmError::InvalidInput { .. }),
3583            "expected InvalidInput, got {err:?}"
3584        );
3585    }
3586
3587    /// When a provider returns `InvalidInput`, the router must NOT fall through to
3588    /// the next provider — a second embed-capable provider must never be called.
3589    #[tokio::test]
3590    async fn embed_invalid_input_does_not_fall_through_to_second_provider() {
3591        use crate::mock::MockProvider;
3592
3593        // p1 returns InvalidInput; p2 is a functioning embed provider.
3594        // If the loop falls through, p2 returns Ok — which would mean the error was
3595        // swallowed instead of breaking immediately.
3596        let p1 = AnyProvider::Mock(
3597            MockProvider::default()
3598                .with_embed_invalid_input()
3599                .with_name("p1"),
3600        );
3601        let p2 = AnyProvider::Mock({
3602            let mut m = MockProvider::default();
3603            m.supports_embeddings = true;
3604            m.name_override = Some("p2".into());
3605            m
3606        });
3607
3608        let r = RouterProvider::new(vec![p1, p2]);
3609        let err = r.embed("test").await.unwrap_err();
3610
3611        // The error must carry p1's name, proving p2 was never reached.
3612        assert!(
3613            matches!(&err, LlmError::InvalidInput { provider, .. } if provider == "p1"),
3614            "expected InvalidInput from p1, got {err:?}"
3615        );
3616    }
3617
3618    // ── InvalidInput chat_with_tools break tests ───────────────────────────────
3619
3620    /// When a provider returns `InvalidInput` from `chat_with_tools()`, the router must break
3621    /// the fallback loop immediately and return `InvalidInput` — not `NoProviders`.
3622    #[tokio::test]
3623    async fn chat_with_tools_invalid_input_breaks_loop_and_returns_invalid_input() {
3624        use crate::mock::MockProvider;
3625        use crate::provider::ToolDefinition;
3626
3627        let p = AnyProvider::Mock(MockProvider::default().with_tool_chat_invalid_input());
3628        let r = RouterProvider::new(vec![p]).with_thompson(None);
3629        let err = r
3630            .chat_with_tools(&[], &[] as &[ToolDefinition])
3631            .await
3632            .unwrap_err();
3633        assert!(
3634            matches!(err, LlmError::InvalidInput { .. }),
3635            "expected InvalidInput, got {err:?}"
3636        );
3637    }
3638
3639    /// When a provider returns `InvalidInput` from `chat_with_tools()`, the router must NOT
3640    /// fall through to the next provider.
3641    #[tokio::test]
3642    async fn chat_with_tools_invalid_input_does_not_fall_through_to_second_provider() {
3643        use crate::mock::MockProvider;
3644        use crate::provider::ToolDefinition;
3645
3646        let p1 = AnyProvider::Mock(
3647            MockProvider::default()
3648                .with_tool_chat_invalid_input()
3649                .with_name("p1"),
3650        );
3651        let p2 = AnyProvider::Mock(MockProvider::default().with_name("p2"));
3652
3653        let r = RouterProvider::new(vec![p1, p2]);
3654        let err = r
3655            .chat_with_tools(&[], &[] as &[ToolDefinition])
3656            .await
3657            .unwrap_err();
3658
3659        assert!(
3660            matches!(&err, LlmError::InvalidInput { provider, .. } if provider == "p1"),
3661            "expected InvalidInput from p1, got {err:?}"
3662        );
3663    }
3664
3665    /// The router skips providers that do not support embeddings and continues to
3666    /// the next one, returning a successful result from the first capable provider.
3667    #[tokio::test]
3668    async fn embed_skips_non_embedding_providers_and_falls_through() {
3669        use crate::mock::MockProvider;
3670
3671        // p1 does not support embeddings — skipped by the loop guard.
3672        // p2 supports embeddings and returns successfully.
3673        let p1 = AnyProvider::Mock({
3674            let mut m = MockProvider::default().with_name("p1");
3675            m.supports_embeddings = false;
3676            m
3677        });
3678        let p2 = AnyProvider::Mock({
3679            let mut m = MockProvider::default().with_name("p2");
3680            m.supports_embeddings = true;
3681            m.embedding = vec![1.0, 2.0, 3.0];
3682            m
3683        });
3684
3685        let r = RouterProvider::new(vec![p1, p2]);
3686        let result = r.embed("hello").await.unwrap();
3687        assert_eq!(result, vec![1.0, 2.0, 3.0]);
3688    }
3689
3690    /// `InvalidInput` from embed does not call `record_availability` (no reputation penalty).
3691    /// We verify this indirectly: `thompson_stats` must show no entry for the provider
3692    /// after an `InvalidInput` embed, whereas a normal embed failure increments it.
3693    #[tokio::test]
3694    async fn embed_invalid_input_does_not_record_availability() {
3695        use crate::mock::MockProvider;
3696
3697        let p = AnyProvider::Mock(
3698            MockProvider::default()
3699                .with_embed_invalid_input()
3700                .with_name("test-provider"),
3701        );
3702        let r = RouterProvider::new(vec![p]).with_thompson(None);
3703        let _ = r.embed("text").await;
3704
3705        // record_availability is only called on success or generic error,
3706        // not on InvalidInput. So thompson_stats must have no entry for "test-provider".
3707        let stats = r.thompson_stats();
3708        let provider_in_stats = stats.iter().any(|(name, ..)| name == "test-provider");
3709        assert!(
3710            !provider_in_stats,
3711            "InvalidInput must not update provider reputation; stats: {stats:?}"
3712        );
3713    }
3714
3715    // ── embed timeout tests ───────────────────────────────────────────────────
3716
3717    /// When the only provider's `embed()` exceeds `embed_timeout_ms`, the router
3718    /// exhausts the fallback list and returns `LlmError::NoProviders`.
3719    #[tokio::test]
3720    async fn embed_timeout_single_provider_returns_no_providers() {
3721        use crate::mock::MockProvider;
3722
3723        let p = AnyProvider::Mock(
3724            MockProvider::default()
3725                .with_embed_delay(200)
3726                .with_name("slow"),
3727        );
3728        let r = RouterProvider::new(vec![p]).with_embed_timeout(10);
3729        let err = r.embed("hello").await.unwrap_err();
3730        assert!(
3731            matches!(err, LlmError::NoProviders),
3732            "expected NoProviders after timeout, got {err:?}"
3733        );
3734    }
3735
3736    /// After a timeout on the first provider, the router falls back to the next
3737    /// embed-capable provider and returns its successful result.
3738    #[tokio::test]
3739    async fn embed_timeout_falls_back_to_next_provider() {
3740        use crate::mock::MockProvider;
3741
3742        let p1 = AnyProvider::Mock(
3743            MockProvider::default()
3744                .with_embed_delay(200)
3745                .with_name("slow"),
3746        );
3747        let p2 = AnyProvider::Mock({
3748            let mut m = MockProvider::default().with_name("fast");
3749            m.supports_embeddings = true;
3750            m.embedding = vec![1.0, 2.0, 3.0];
3751            m
3752        });
3753        let r = RouterProvider::new(vec![p1, p2]).with_embed_timeout(10);
3754        let result = r.embed("hello").await.unwrap();
3755        assert_eq!(result, vec![1.0, 2.0, 3.0]);
3756    }
3757
3758    // ── quality_gate tests ────────────────────────────────────────────────────
3759
3760    /// `with_quality_gate()` happy path: when cosine similarity >= threshold the
3761    /// response is returned directly without falling back.
3762    #[tokio::test]
3763    async fn quality_gate_passes_when_similarity_above_threshold() {
3764        use crate::mock::MockProvider;
3765
3766        // p1 returns a response; embed returns a unit vector so cosine similarity
3767        // with itself is 1.0 (>= any reasonable threshold).
3768        let p1 = AnyProvider::Mock({
3769            let mut m = MockProvider::with_responses(vec!["answer".to_owned()]).with_name("p1");
3770            m.supports_embeddings = true;
3771            m.embedding = vec![1.0, 0.0];
3772            m
3773        });
3774        let r = RouterProvider::new(vec![p1])
3775            .with_thompson(None)
3776            .with_quality_gate(0.5);
3777        let msgs = vec![Message::from_legacy(Role::User, "question")];
3778        let result = r.chat(&msgs).await.unwrap();
3779        assert_eq!(result, "answer");
3780    }
3781
3782    /// `with_quality_gate()` exhaustion: when all providers fail the gate the router
3783    /// returns the best-seen response (highest similarity) rather than an error.
3784    #[tokio::test]
3785    async fn quality_gate_exhaustion_returns_best_seen() {
3786        use crate::mock::MockProvider;
3787
3788        // p1 returns a response but embedding similarity is 0.0 (orthogonal vectors)
3789        // so it fails the gate (0.0 < 0.9). p2 fails entirely.
3790        // Expected: best_seen from p1 is returned.
3791        let p1 = AnyProvider::Mock({
3792            let mut m =
3793                MockProvider::with_responses(vec!["best_so_far".to_owned()]).with_name("p1");
3794            m.supports_embeddings = true;
3795            // query embed = [1,0], response embed = [0,1] → similarity = 0.0
3796            m.embedding = vec![0.0, 1.0];
3797            m
3798        });
3799        let p2 = AnyProvider::Mock(MockProvider::failing().with_name("p2"));
3800        let r = RouterProvider::new(vec![p1, p2])
3801            .with_thompson(None)
3802            .with_quality_gate(0.9);
3803        let msgs = vec![Message::from_legacy(Role::User, "question")];
3804        let result = r.chat(&msgs).await.unwrap();
3805        assert_eq!(result, "best_so_far");
3806    }
3807
3808    // ── apply_routing_signals guard logic tests ───────────────────────────────
3809
3810    /// `quality_gate = 5.0` (> 1.0) must be silently ignored — the field is left
3811    /// as `None` and no panic occurs.
3812    #[test]
3813    fn routing_signals_quality_gate_above_one_is_ignored() {
3814        // Build a RouterProvider directly and check that with_quality_gate is only
3815        // called for in-range values by replicating the guard from provider.rs.
3816        let threshold: f32 = 5.0;
3817        let mut router = RouterProvider::new(vec![]);
3818        if threshold.is_finite() && threshold > 0.0 && threshold <= 1.0 {
3819            router = router.with_quality_gate(threshold);
3820        }
3821        assert!(
3822            router.quality_gate.is_none(),
3823            "out-of-range quality_gate must not be wired; got {:?}",
3824            router.quality_gate
3825        );
3826    }
3827
3828    /// `quality_gate = 0.8` (valid) must be wired into the router.
3829    #[test]
3830    fn routing_signals_quality_gate_valid_is_wired() {
3831        let threshold: f32 = 0.8;
3832        let mut router = RouterProvider::new(vec![]);
3833        if threshold.is_finite() && threshold > 0.0 && threshold <= 1.0 {
3834            router = router.with_quality_gate(threshold);
3835        }
3836        assert_eq!(
3837            router.quality_gate,
3838            Some(0.8),
3839            "valid quality_gate must be wired"
3840        );
3841    }
3842
3843    // --- ASI debounce tests ---
3844
3845    #[test]
3846    fn asi_debounce_same_turn_fires_once() {
3847        let router = RouterProvider::new(vec![]);
3848        let turn_id = 42u64;
3849
3850        // First call: prev == u64::MAX (initial) → not equal to turn_id → proceeds (returns false)
3851        let prev1 = router.state.asi_last_turn.swap(turn_id, Ordering::AcqRel);
3852        let first_dropped = prev1 == turn_id;
3853
3854        // Second call same turn: prev == turn_id → dropped
3855        let prev2 = router.state.asi_last_turn.swap(turn_id, Ordering::AcqRel);
3856        let second_dropped = prev2 == turn_id;
3857
3858        assert!(!first_dropped, "first call in turn must not be dropped");
3859        assert!(second_dropped, "second call in same turn must be dropped");
3860    }
3861
3862    #[test]
3863    fn asi_debounce_next_turn_fires_again() {
3864        let router = RouterProvider::new(vec![]);
3865
3866        // Simulate turn 1
3867        let prev1 = router.state.asi_last_turn.swap(1u64, Ordering::AcqRel);
3868        assert_ne!(prev1, 1u64, "turn 1: initial value != 1, should proceed");
3869
3870        // Simulate turn 2 — different turn_id
3871        let prev2 = router.state.asi_last_turn.swap(2u64, Ordering::AcqRel);
3872        let dropped = prev2 == 2u64;
3873        assert!(!dropped, "turn 2 must not be dropped (different turn_id)");
3874    }
3875
3876    #[test]
3877    fn turn_counter_increments_across_clones() {
3878        let router = RouterProvider::new(vec![]);
3879        let clone = router.clone();
3880
3881        let t0 = router.state.turn_counter.fetch_add(1, Ordering::Relaxed);
3882        let t1 = clone.state.turn_counter.fetch_add(1, Ordering::Relaxed);
3883
3884        // Both clones share the same Arc<AtomicU64>
3885        assert_eq!(t1, t0 + 1, "cloned router shares turn_counter");
3886    }
3887
3888    #[test]
3889    fn with_embed_concurrency_zero_means_no_semaphore() {
3890        let r = RouterProvider::new(vec![]).with_embed_concurrency(0);
3891        assert!(
3892            r.state.embed_semaphore.is_none(),
3893            "0 should disable semaphore"
3894        );
3895    }
3896
3897    #[test]
3898    fn with_embed_concurrency_positive_creates_semaphore() {
3899        let r = RouterProvider::new(vec![]).with_embed_concurrency(4);
3900        let sem = r
3901            .state
3902            .embed_semaphore
3903            .as_ref()
3904            .expect("semaphore should exist");
3905        assert_eq!(sem.available_permits(), 4);
3906    }
3907
3908    #[tokio::test]
3909    async fn embed_semaphore_limits_concurrency() {
3910        use std::sync::Arc as StdArc;
3911        use std::sync::atomic::{AtomicUsize, Ordering as AO};
3912
3913        // Use a semaphore with 2 permits. Verify that at most 2 concurrent
3914        // tasks can hold the permit at the same time.
3915        let sem = Arc::new(tokio::sync::Semaphore::new(2));
3916        let concurrent_peak = StdArc::new(AtomicUsize::new(0));
3917        let active = StdArc::new(AtomicUsize::new(0));
3918
3919        let mut handles = vec![];
3920        for _ in 0..6 {
3921            let sem_clone = sem.clone();
3922            let peak = concurrent_peak.clone();
3923            let active = active.clone();
3924            handles.push(tokio::spawn(async move {
3925                let _permit = sem_clone.acquire().await.unwrap();
3926                let cur = active.fetch_add(1, AO::SeqCst) + 1;
3927                // Track peak concurrent usage.
3928                let mut p = peak.load(AO::SeqCst);
3929                while p < cur {
3930                    match peak.compare_exchange(p, cur, AO::SeqCst, AO::SeqCst) {
3931                        Ok(_) => break,
3932                        Err(new) => p = new,
3933                    }
3934                }
3935                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
3936                active.fetch_sub(1, AO::SeqCst);
3937            }));
3938        }
3939        for h in handles {
3940            h.await.unwrap();
3941        }
3942        assert!(
3943            concurrent_peak.load(AO::SeqCst) <= 2,
3944            "peak concurrency should not exceed semaphore limit"
3945        );
3946    }
3947
3948    // ── TurnEmbedCache tests (#2819) ──────────────────────────────────────────
3949
3950    /// T2: A second `embed_cached` call with the same text hits the cache instead of
3951    /// calling the underlying provider, and `embed_cache_hits` increments to 1.
3952    #[tokio::test]
3953    async fn turn_embed_cache_hit_increments_counter() {
3954        use crate::mock::MockProvider;
3955
3956        let mut m = MockProvider::default();
3957        m.supports_embeddings = true;
3958        m.embedding = vec![0.5, 0.5];
3959        let provider_embed_calls = Arc::clone(&m.embed_call_count);
3960
3961        let r = RouterProvider::new(vec![AnyProvider::Mock(m)]);
3962        let cache = Mutex::new(TurnEmbedCache::default());
3963
3964        // First call — cache miss → calls provider.
3965        let emb1 = r.embed_cached("hello", &cache).await.unwrap();
3966        // Second call — same text → cache hit, no provider call.
3967        let emb2 = r.embed_cached("hello", &cache).await.unwrap();
3968
3969        assert_eq!(emb1, emb2, "cached embedding must match original");
3970        assert_eq!(
3971            provider_embed_calls.load(Ordering::Relaxed),
3972            1,
3973            "provider embed() must be called exactly once (second call hits cache)"
3974        );
3975        let (total, hits) = r.embed_cache_metrics();
3976        assert_eq!(
3977            total, 2,
3978            "embed_call_count must be 2 (two embed_cached calls)"
3979        );
3980        assert_eq!(hits, 1, "embed_cache_hits must be 1 (one cache hit)");
3981    }
3982
3983    /// T3: Passing `Some(precomputed_embedding)` to `spawn_asi_update` does not trigger
3984    /// an `embed()` call on the provider; the ASI window is updated with the provided embedding.
3985    #[tokio::test]
3986    async fn spawn_asi_update_with_precomputed_skips_embed() {
3987        use crate::mock::MockProvider;
3988
3989        let mut m = MockProvider::with_responses(vec!["ok".to_owned()]);
3990        m.supports_embeddings = true;
3991        m.embedding = vec![1.0, 0.0];
3992        let provider_embed_calls = Arc::clone(&m.embed_call_count);
3993
3994        let r =
3995            RouterProvider::new(vec![AnyProvider::Mock(m)]).with_asi(AsiRouterConfig::default());
3996
3997        let precomputed = vec![0.9_f32, 0.1];
3998        let turn_id = 42u64;
3999
4000        // Inject a different turn id into asi_last_turn so the debounce doesn't fire.
4001        r.state.asi_last_turn.store(u64::MAX, Ordering::SeqCst);
4002
4003        r.spawn_asi_update(
4004            "p1",
4005            "response".to_owned(),
4006            turn_id,
4007            Some(precomputed.clone()),
4008        );
4009
4010        // Give the spawned task time to complete.
4011        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
4012
4013        // Provider embed() must not have been called.
4014        assert_eq!(
4015            provider_embed_calls.load(Ordering::Relaxed),
4016            0,
4017            "embed() must not be called when precomputed_embedding is Some"
4018        );
4019
4020        // The ASI window must have received the precomputed embedding.
4021        let asi = r.asi.as_ref().unwrap().lock();
4022        let coherence = asi.coherence("p1");
4023        // coherence_score returns None when the window has < 2 entries; after one push it's None.
4024        // We only verify the ASI has the provider in its window (score will be None with 1 entry).
4025        let _ = coherence; // just verifying no panic
4026    }
4027
4028    /// Regression for #4296: `blocking_load` must not panic on a `current_thread` runtime
4029    /// and must actually call the closure, returning its result.
4030    #[tokio::test]
4031    async fn blocking_load_runs_closure_on_current_thread_runtime() {
4032        let result = super::blocking_load(|| 42_u32);
4033        assert_eq!(result, 42, "blocking_load must return the closure result");
4034    }
4035
4036    // ── spawn_asi_update JoinSet reap regression (#4644) ─────────────────────
4037
4038    /// Regression for #4644: completed tasks in `asi_tasks` must be reaped before the cap
4039    /// check so that a full-but-finished `JoinSet` does not permanently block new spawns.
4040    #[tokio::test]
4041    async fn spawn_asi_update_reaped_after_cap_full() {
4042        use crate::mock::MockProvider;
4043        use std::sync::atomic::Ordering;
4044
4045        let mut m = MockProvider::with_responses(vec!["ok".to_owned()]);
4046        m.supports_embeddings = true;
4047        m.embedding = vec![1.0, 0.0];
4048        let embed_calls = Arc::clone(&m.embed_call_count);
4049
4050        let r =
4051            RouterProvider::new(vec![AnyProvider::Mock(m)]).with_asi(AsiRouterConfig::default());
4052        r.state.asi_last_turn.store(u64::MAX, Ordering::SeqCst);
4053
4054        // Spawn exactly MAX_ASI_TASKS tasks and wait for all to complete.
4055        for i in 0..super::MAX_ASI_TASKS {
4056            r.spawn_asi_update("p1", format!("resp{i}"), i as u64, Some(vec![0.5, 0.5]));
4057        }
4058        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
4059
4060        // After all tasks finish the JoinSet is full of completed handles.
4061        // Without the drain fix this next call would be silently skipped.
4062        r.spawn_asi_update(
4063            "p1",
4064            "extra".to_owned(),
4065            super::MAX_ASI_TASKS as u64,
4066            Some(vec![0.9, 0.1]),
4067        );
4068
4069        // Give the newly spawned task time to complete.
4070        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
4071
4072        // embed() must never have been called — all calls used a precomputed embedding.
4073        assert_eq!(
4074            embed_calls.load(Ordering::Relaxed),
4075            0,
4076            "embed() must not be called when precomputed_embedding is Some"
4077        );
4078
4079        // Trigger one more spawn — the drain inside will reap the last completed task,
4080        // proving the extra spawn was not permanently blocked by the full JoinSet.
4081        r.spawn_asi_update(
4082            "p1",
4083            "probe".to_owned(),
4084            (super::MAX_ASI_TASKS + 1) as u64,
4085            Some(vec![0.1, 0.9]),
4086        );
4087
4088        // All previously finished tasks must have been reaped; only the probe may remain.
4089        let remaining = r.asi_tasks.lock().len();
4090        assert!(
4091            remaining <= 1,
4092            "completed tasks must be reaped; at most 1 in-flight task expected, got {remaining}"
4093        );
4094    }
4095
4096    // ── spawn_asi_update timeout regression (#4566) ───────────────────────────
4097
4098    /// Regression for #4566: when `embed()` inside `spawn_asi_update` exceeds `embed_timeout_ms`,
4099    /// the ASI coherence window must NOT be updated (task returns early without pushing embedding).
4100    #[tokio::test]
4101    async fn spawn_asi_update_embed_timeout_does_not_update_asi() {
4102        use crate::mock::MockProvider;
4103        use std::sync::atomic::Ordering;
4104
4105        // Provider that takes 200 ms to embed — well above the 10 ms timeout.
4106        let mut m = MockProvider::with_responses(vec!["ok".to_owned()]);
4107        m.supports_embeddings = true;
4108        m.embedding = vec![1.0, 0.0];
4109        m.embed_delay_ms = 200;
4110        let provider_embed_calls = Arc::clone(&m.embed_call_count);
4111
4112        let r = RouterProvider::new(vec![AnyProvider::Mock(m)])
4113            .with_asi(AsiRouterConfig::default())
4114            .with_embed_timeout(10);
4115
4116        // Inject a sentinel turn id so the debounce does not fire.
4117        r.state.asi_last_turn.store(u64::MAX, Ordering::SeqCst);
4118
4119        // No precomputed embedding → router will attempt to call embed().
4120        r.spawn_asi_update("p1", "response".to_owned(), 1u64, None);
4121
4122        // Wait long enough for the spawned task to reach its timeout and return.
4123        tokio::time::sleep(std::time::Duration::from_millis(150)).await;
4124
4125        // embed() was called (the call was made before the timeout fired).
4126        assert!(
4127            provider_embed_calls.load(Ordering::Relaxed) >= 1,
4128            "embed() must have been attempted"
4129        );
4130
4131        // ASI window must be empty — timeout fired before push_embedding could run.
4132        let asi = r.asi.as_ref().unwrap().lock();
4133        let coherence = asi.coherence("p1");
4134        // coherence() returns 1.0 when the provider is unknown (no entries in the window).
4135        assert!(
4136            (coherence - 1.0).abs() < f32::EPSILON,
4137            "ASI window must be empty after embed timeout; coherence={coherence}"
4138        );
4139    }
4140}