Skip to main content

llmux/
switcher.rs

1//! Model Switcher — coordinates wake/sleep transitions between models.
2//!
3//! The switcher tracks which model is active, manages in-flight request
4//! counting, drains requests before switching, and delegates all lifecycle
5//! operations to the [`HookRunner`].
6
7use crate::cost::SwitchCostTracker;
8use crate::hooks::HookRunner;
9use crate::policy::{PolicyContext, PolicyDecision, ScheduleContext, SwitchContext, SwitchPolicy};
10use crate::types::{SwitchError, SwitcherState};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
16use tracing::{debug, error, info, trace, warn};
17
18/// Signal sent through the oneshot channel when a model becomes active.
19///
20/// The receiver must call [`ReadySignal::settle`] after acquiring its
21/// in-flight guard. `notify_pending` blocks on these settle signals so
22/// that the switch lock is held until all notified requests are actively
23/// being processed — preventing the next switch from draining the model
24/// before any request has started.
25pub(crate) struct ReadySignal {
26    settle_tx: mpsc::Sender<()>,
27}
28
29impl ReadySignal {
30    pub(crate) async fn settle(self) {
31        let _ = self.settle_tx.send(()).await;
32    }
33}
34
35/// A pending request waiting for a model to become active
36struct PendingRequest {
37    #[allow(dead_code)]
38    model: String,
39    queued_at: Instant,
40    ready_tx: oneshot::Sender<Result<ReadySignal, SwitchError>>,
41}
42
43/// Per-model state tracking
44struct ModelState {
45    in_flight: AtomicUsize,
46    pending: Mutex<Vec<PendingRequest>>,
47    in_flight_changed: Arc<Notify>,
48    /// Set to `true` while draining in-flight requests before sleep.
49    /// When true, `acquire_in_flight` will refuse new guards so that
50    /// no requests sneak in between drain completion and the actual sleep call.
51    draining: AtomicBool,
52}
53
54impl Default for ModelState {
55    fn default() -> Self {
56        Self {
57            in_flight: AtomicUsize::new(0),
58            pending: Mutex::new(Vec::new()),
59            in_flight_changed: Arc::new(Notify::new()),
60            draining: AtomicBool::new(false),
61        }
62    }
63}
64
65struct SwitcherInner {
66    hooks: Arc<HookRunner>,
67    policy: Box<dyn SwitchPolicy>,
68    state: RwLock<SwitcherState>,
69    model_states: HashMap<String, Arc<ModelState>>,
70    switch_lock: Mutex<()>,
71    /// When the currently active model was activated (for cooldown enforcement)
72    activated_at: RwLock<Option<Instant>>,
73    /// When the last switch failure occurred (for backoff)
74    last_switch_failure: RwLock<Option<Instant>>,
75    /// Empirical switch cost tracking (EMA of observed durations)
76    cost_tracker: SwitchCostTracker,
77}
78
79/// The model switcher coordinates wake/sleep transitions.
80pub struct ModelSwitcher {
81    inner: Arc<SwitcherInner>,
82}
83
84impl Clone for ModelSwitcher {
85    fn clone(&self) -> Self {
86        Self {
87            inner: Arc::clone(&self.inner),
88        }
89    }
90}
91
92impl ModelSwitcher {
93    pub fn new(hooks: Arc<HookRunner>, policy: Box<dyn SwitchPolicy>) -> Self {
94        let model_states: HashMap<String, Arc<ModelState>> = hooks
95            .registered_models()
96            .into_iter()
97            .map(|model| (model, Arc::new(ModelState::default())))
98            .collect();
99
100        Self {
101            inner: Arc::new(SwitcherInner {
102                hooks,
103                policy,
104                state: RwLock::new(SwitcherState::Idle),
105                model_states,
106                switch_lock: Mutex::new(()),
107                activated_at: RwLock::new(None),
108                last_switch_failure: RwLock::new(None),
109                cost_tracker: SwitchCostTracker::new(0.3),
110            }),
111        }
112    }
113
114    pub async fn state(&self) -> SwitcherState {
115        self.inner.state.read().await.clone()
116    }
117
118    pub async fn active_model(&self) -> Option<String> {
119        match &*self.inner.state.read().await {
120            SwitcherState::Active { model } => Some(model.clone()),
121            _ => None,
122        }
123    }
124
125    pub fn registered_models(&self) -> Vec<String> {
126        self.inner.model_states.keys().cloned().collect()
127    }
128
129    pub fn hooks(&self) -> &Arc<HookRunner> {
130        &self.inner.hooks
131    }
132
133    pub fn is_registered(&self, model: &str) -> bool {
134        self.inner.model_states.contains_key(model)
135    }
136
137    pub fn model_port(&self, model: &str) -> Option<u16> {
138        self.inner.hooks.model_port(model)
139    }
140
141    pub fn in_flight_count(&self, model: &str) -> usize {
142        self.inner
143            .model_states
144            .get(model)
145            .map(|s| s.in_flight.load(Ordering::SeqCst))
146            .unwrap_or(0)
147    }
148
149    /// Estimated cost of switching between two models, based on observed durations.
150    /// `from` is `None` for cold starts from Idle.
151    pub fn estimated_switch_cost(&self, from: Option<&str>, to: &str) -> Option<Duration> {
152        self.inner.cost_tracker.estimate(from, to)
153    }
154
155    /// Force a switch to the given model, bypassing policy.
156    pub async fn force_switch(&self, model: &str) -> Result<(), SwitchError> {
157        if !self.is_registered(model) {
158            return Err(SwitchError::ModelNotFound(model.to_string()));
159        }
160
161        // If already active, nothing to do
162        {
163            let state = self.inner.state.read().await;
164            if let SwitcherState::Active { model: active } = &*state
165                && active == model
166            {
167                return Ok(());
168            }
169        }
170
171        self.do_switch(model).await;
172
173        let state = self.inner.state.read().await;
174        match &*state {
175            SwitcherState::Active { model: active } if active == model => Ok(()),
176            _ => Err(SwitchError::NotReady(model.to_string())),
177        }
178    }
179
180    /// Ensure a model is ready for requests.
181    ///
182    /// Returns `Ok(None)` immediately if the model is already active (fast
183    /// path). Otherwise queues the request, triggers a switch if needed, and
184    /// waits up to the policy timeout. Returns `Ok(Some(signal))` — the
185    /// caller **must** call [`ReadySignal::settle`] after acquiring its
186    /// in-flight guard so that `notify_pending` knows the request is actively
187    /// being processed.
188    pub(crate) async fn ensure_model_ready(
189        &self,
190        model: &str,
191    ) -> Result<Option<ReadySignal>, SwitchError> {
192        let model_state = self
193            .inner
194            .model_states
195            .get(model)
196            .ok_or_else(|| SwitchError::ModelNotFound(model.to_string()))?;
197
198        // Fast path: model is already active
199        {
200            let state = self.inner.state.read().await;
201            if let SwitcherState::Active { model: active } = &*state
202                && active == model
203            {
204                trace!(model = %model, "Model already active");
205                return Ok(None);
206            }
207        }
208
209        // Queue the request
210        let (ready_tx, ready_rx) = oneshot::channel();
211        let pending = PendingRequest {
212            model: model.to_string(),
213            queued_at: Instant::now(),
214            ready_tx,
215        };
216
217        {
218            let mut queue = model_state.pending.lock().await;
219            queue.push(pending);
220            let depth = queue.len();
221            debug!(model = %model, queue_depth = depth, "Request queued");
222            metrics::gauge!("llmux_request_queue_depth", "model" => model.to_string())
223                .set(depth as f64);
224        }
225
226        self.maybe_trigger_switch(model).await;
227
228        match self.inner.policy.request_timeout() {
229            Some(timeout) => match tokio::time::timeout(timeout, ready_rx).await {
230                Ok(Ok(result)) => result.map(Some),
231                Ok(Err(_)) => Err(SwitchError::Internal("channel closed".to_string())),
232                Err(_) => {
233                    warn!(
234                        event = "request_timeout",
235                        model = %model,
236                        timeout_secs = timeout.as_secs_f64(),
237                        "Request timed out waiting for model"
238                    );
239                    Err(SwitchError::Timeout)
240                }
241            },
242            None => match ready_rx.await {
243                Ok(result) => result.map(Some),
244                Err(_) => Err(SwitchError::Internal("channel closed".to_string())),
245            },
246        }
247    }
248
249    /// Acquire an in-flight guard.
250    ///
251    /// Returns `None` if the model is not registered or if it is currently
252    /// draining. Uses increment-then-check to close the TOCTOU window.
253    pub fn acquire_in_flight(&self, model: &str) -> Option<InFlightGuard> {
254        let model_state = self.inner.model_states.get(model)?;
255
256        let new_count = model_state.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
257
258        if model_state.draining.load(Ordering::SeqCst) {
259            model_state.in_flight.fetch_sub(1, Ordering::SeqCst);
260            model_state.in_flight_changed.notify_waiters();
261            return None;
262        }
263
264        metrics::gauge!("llmux_model_in_flight", "model" => model.to_string())
265            .set(new_count as f64);
266
267        Some(InFlightGuard {
268            model_state: Arc::clone(model_state),
269            model: model.to_string(),
270        })
271    }
272
273    /// Get queue depths for every registered model.
274    pub async fn queue_depths(&self) -> HashMap<String, usize> {
275        let mut depths = HashMap::new();
276        for (model, state) in &self.inner.model_states {
277            let queue = state.pending.lock().await;
278            depths.insert(model.clone(), queue.len());
279        }
280        depths
281    }
282
283    /// Spawn a background scheduler task if the policy requests one.
284    pub fn spawn_scheduler(self) -> Option<tokio::task::JoinHandle<()>> {
285        let interval = self.inner.policy.scheduler_interval()?;
286
287        info!(
288            interval_ms = interval.as_millis(),
289            "Spawning background scheduler"
290        );
291
292        Some(tokio::spawn(async move {
293            let mut tick = tokio::time::interval(interval);
294            tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
295
296            loop {
297                tick.tick().await;
298                let ctx = self.build_schedule_context().await;
299                if let Some(target) = self.inner.policy.schedule_tick(&ctx) {
300                    debug!(target = %target, "Scheduler: triggering switch");
301                    self.do_switch(&target).await;
302                }
303            }
304        }))
305    }
306
307    // -----------------------------------------------------------------------
308    // Private
309    // -----------------------------------------------------------------------
310
311    async fn maybe_trigger_switch(&self, target_model: &str) {
312        let model_state = match self.inner.model_states.get(target_model) {
313            Some(s) => s,
314            None => return,
315        };
316
317        let ctx = {
318            let state = self.inner.state.read().await;
319            let queue = model_state.pending.lock().await;
320
321            let oldest_waiting = queue
322                .first()
323                .map(|p| p.queued_at.elapsed())
324                .unwrap_or(Duration::ZERO);
325
326            let (active_model, active_in_flight) = match &*state {
327                SwitcherState::Active { model } => {
328                    (Some(model.clone()), self.in_flight_count(model))
329                }
330                _ => (None, 0),
331            };
332
333            let active_duration = self
334                .inner
335                .activated_at
336                .read()
337                .await
338                .map(|t| t.elapsed())
339                .unwrap_or(Duration::ZERO);
340
341            let estimated_switch_cost = self
342                .inner
343                .cost_tracker
344                .estimate(active_model.as_deref(), target_model);
345
346            PolicyContext {
347                target_model: target_model.to_string(),
348                active_model,
349                target_queue_depth: queue.len(),
350                oldest_waiting,
351                active_in_flight,
352                active_duration,
353                estimated_switch_cost,
354            }
355        };
356
357        // Already switching to this model?
358        {
359            let state = self.inner.state.read().await;
360            if let SwitcherState::Switching { to, .. } = &*state
361                && to == target_model
362            {
363                return;
364            }
365        }
366
367        let decision = self.inner.policy.on_pending_request(&ctx).await;
368
369        match decision {
370            PolicyDecision::SwitchNow => {
371                debug!(model = %target_model, "Policy: switch now");
372                let switcher = self.clone();
373                let target = target_model.to_string();
374                tokio::spawn(async move {
375                    switcher.do_switch(&target).await;
376                });
377            }
378            PolicyDecision::Defer(future) => {
379                debug!(model = %target_model, "Policy: defer");
380                let switcher = self.clone();
381                let target = target_model.to_string();
382                tokio::spawn(async move {
383                    future.await;
384                    switcher.do_switch(&target).await;
385                });
386            }
387            PolicyDecision::Skip => {
388                trace!(model = %target_model, "Policy: skip");
389            }
390        }
391    }
392
393    async fn do_switch(&self, target_model: &str) {
394        let _guard = self.inner.switch_lock.lock().await;
395        let switch_start = Instant::now();
396
397        // Backoff after recent failure
398        {
399            let last_failure = self.inner.last_switch_failure.read().await;
400            if let Some(failed_at) = *last_failure {
401                let backoff = Duration::from_secs(2);
402                let elapsed = failed_at.elapsed();
403                if elapsed < backoff {
404                    let remaining = backoff - elapsed;
405                    info!(remaining = ?remaining, "Backing off after recent switch failure");
406                    drop(last_failure);
407                    tokio::time::sleep(remaining).await;
408                }
409            }
410        }
411
412        // Double-check state
413        {
414            let state = self.inner.state.read().await;
415            match &*state {
416                SwitcherState::Active { model } if model == target_model => {
417                    self.notify_pending(target_model, Ok(())).await;
418                    return;
419                }
420                SwitcherState::Switching { to, .. } if to == target_model => {
421                    return;
422                }
423                _ => {}
424            }
425        }
426
427        // Skip if there are no pending requests for the target model.
428        // This happens when a stale do_switch task grabs the lock after
429        // the target model's requests were already served by an earlier
430        // activation.
431        if let Some(target_state) = self.inner.model_states.get(target_model) {
432            let queue = target_state.pending.lock().await;
433            if queue.is_empty() {
434                debug!(
435                    model = %target_model,
436                    "No pending requests, skipping stale switch"
437                );
438                return;
439            }
440        }
441
442        let from_model = {
443            let state = self.inner.state.read().await;
444            match &*state {
445                SwitcherState::Active { model } => Some(model.clone()),
446                _ => None,
447            }
448        };
449
450        let from_label = from_model.as_deref().unwrap_or("idle").to_string();
451
452        // Record how long the outgoing model was active
453        if from_model.is_some() {
454            let active_dur = self
455                .inner
456                .activated_at
457                .read()
458                .await
459                .map(|t| t.elapsed())
460                .unwrap_or(Duration::ZERO);
461            metrics::histogram!(
462                "llmux_model_active_duration_seconds",
463                "model" => from_label.clone()
464            )
465            .record(active_dur.as_secs_f64());
466        }
467
468        // Update state to Switching
469        {
470            let mut state = self.inner.state.write().await;
471            *state = SwitcherState::Switching {
472                from: from_model.clone(),
473                to: target_model.to_string(),
474            };
475        }
476
477        info!(
478            event = "switch_started",
479            from = %from_label,
480            to = %target_model,
481            "Starting model switch"
482        );
483
484        // Phase 1: Cooldown — ensure the old model has been active long enough
485        if from_model.is_some() {
486            let min_active = self.inner.policy.min_active_duration();
487            let activated_at = *self.inner.activated_at.read().await;
488            if let Some(activated) = activated_at {
489                let elapsed = activated.elapsed();
490                if elapsed < min_active {
491                    let remaining = min_active - elapsed;
492                    info!(remaining = ?remaining, "Waiting for cooldown");
493                    tokio::time::sleep(remaining).await;
494                }
495            }
496        }
497
498        // Phase 2: Drain — set draining flag and wait for in-flight to complete
499        let drain_start = Instant::now();
500
501        if let Some(ref from) = from_model
502            && let Some(from_state) = self.inner.model_states.get(from)
503        {
504            from_state.draining.store(true, Ordering::SeqCst);
505        }
506
507        if let Some(ref from) = from_model
508            && let Some(from_state) = self.inner.model_states.get(from)
509        {
510            let in_flight_changed = Arc::clone(&from_state.in_flight_changed);
511            let from_state_clone = Arc::clone(from_state);
512
513            let mut switch_ctx = SwitchContext::new(
514                from_model.clone(),
515                target_model.to_string(),
516                in_flight_changed,
517                Box::new(move || from_state_clone.in_flight.load(Ordering::SeqCst)),
518            );
519
520            self.inner.policy.prepare_switch(&mut switch_ctx).await;
521        }
522
523        if from_model.is_some() {
524            metrics::histogram!(
525                "llmux_switch_drain_duration_seconds",
526                "from" => from_label.clone(),
527                "to" => target_model.to_string()
528            )
529            .record(drain_start.elapsed().as_secs_f64());
530        }
531
532        // Phase 3: Sleep old model via hook
533        if let Some(ref from) = from_model {
534            debug!(model = %from, "Running sleep hook");
535            if let Err(e) = self.inner.hooks.run_sleep(from).await {
536                error!(
537                    event = "sleep_hook_failed",
538                    model = %from,
539                    to = %target_model,
540                    error = %e,
541                    "Sleep hook failed, continuing with wake (idempotent)"
542                );
543            }
544        }
545
546        // Clear draining flag
547        if let Some(ref from) = from_model
548            && let Some(from_state) = self.inner.model_states.get(from)
549        {
550            from_state.draining.store(false, Ordering::SeqCst);
551        }
552
553        // Phase 4: Wake new model via hook
554        debug!(model = %target_model, "Running wake hook");
555        match self.inner.hooks.run_wake(target_model).await {
556            Ok(()) => {
557                let total_dur = switch_start.elapsed();
558
559                {
560                    let mut state = self.inner.state.write().await;
561                    *state = SwitcherState::Active {
562                        model: target_model.to_string(),
563                    };
564                }
565                *self.inner.activated_at.write().await = Some(Instant::now());
566                *self.inner.last_switch_failure.write().await = None;
567
568                self.inner
569                    .cost_tracker
570                    .record(from_model.as_deref(), target_model, total_dur);
571
572                // Structured log event for timeline reconstruction
573                info!(
574                    event = "model_activated",
575                    model = %target_model,
576                    from = %from_label,
577                    duration_secs = total_dur.as_secs_f64(),
578                    "Model is now active"
579                );
580
581                // Metrics
582                metrics::counter!(
583                    "llmux_switch_total",
584                    "from" => from_label.clone(),
585                    "to" => target_model.to_string(),
586                    "result" => "success"
587                )
588                .increment(1);
589                metrics::histogram!(
590                    "llmux_switch_duration_seconds",
591                    "from" => from_label.clone(),
592                    "to" => target_model.to_string()
593                )
594                .record(total_dur.as_secs_f64());
595
596                if let Some(ema) = self
597                    .inner
598                    .cost_tracker
599                    .estimate(from_model.as_deref(), target_model)
600                {
601                    metrics::gauge!(
602                        "llmux_switch_cost_ema_seconds",
603                        "from" => from_label,
604                        "to" => target_model.to_string()
605                    )
606                    .set(ema.as_secs_f64());
607                }
608
609                let from_str = from_model.as_deref().unwrap_or("");
610                self.inner
611                    .policy
612                    .on_switch_complete(from_str, target_model, total_dur);
613
614                self.notify_pending(target_model, Ok(())).await;
615            }
616            Err(e) => {
617                // Structured log event for timeline reconstruction
618                info!(
619                    event = "switch_failed",
620                    model = %target_model,
621                    from = %from_label,
622                    error = %e,
623                    "Switch failed, returning to idle"
624                );
625
626                metrics::counter!(
627                    "llmux_switch_total",
628                    "from" => from_label,
629                    "to" => target_model.to_string(),
630                    "result" => "failure"
631                )
632                .increment(1);
633
634                // Try to clean up the partially-woken model
635                let _ = self.inner.hooks.run_sleep(target_model).await;
636
637                *self.inner.last_switch_failure.write().await = Some(Instant::now());
638                {
639                    let mut state = self.inner.state.write().await;
640                    *state = SwitcherState::Idle;
641                }
642
643                self.notify_pending(
644                    target_model,
645                    Err(SwitchError::HookFailed {
646                        model: target_model.to_string(),
647                        detail: e.to_string(),
648                    }),
649                )
650                .await;
651            }
652        }
653    }
654
655    async fn build_schedule_context(&self) -> ScheduleContext {
656        let (active_model, active_in_flight) = match &*self.inner.state.read().await {
657            SwitcherState::Active { model } => (Some(model.clone()), self.in_flight_count(model)),
658            _ => (None, 0),
659        };
660
661        let active_duration = self
662            .inner
663            .activated_at
664            .read()
665            .await
666            .map(|t| t.elapsed())
667            .unwrap_or(Duration::ZERO);
668
669        let queue_depths = self.queue_depths().await;
670
671        let switch_costs = self
672            .inner
673            .cost_tracker
674            .estimates_from(active_model.as_deref());
675
676        ScheduleContext {
677            active_model,
678            active_duration,
679            queue_depths,
680            active_in_flight,
681            switch_costs,
682        }
683    }
684
685    /// Notify pending requests and — on success — wait for them to settle.
686    ///
687    /// When the result is `Ok`, each notified request receives a
688    /// [`ReadySignal`]. The request must call `settle()` after acquiring
689    /// its in-flight guard. This method blocks until all delivered signals
690    /// have settled (or their receivers are dropped), ensuring the switch
691    /// lock is held while requests transition from "notified" to "in-flight."
692    ///
693    /// When the result is `Err`, requests are notified of the failure and
694    /// no settle wait is needed.
695    async fn notify_pending(&self, model: &str, result: Result<(), SwitchError>) {
696        let Some(model_state) = self.inner.model_states.get(model) else {
697            return;
698        };
699
700        let mut queue = model_state.pending.lock().await;
701        let count = queue.len();
702        if count == 0 {
703            return;
704        }
705
706        let mut delivered = 0;
707
708        // For successful activations, create a settle channel so we can
709        // wait for notified requests to acquire their in-flight guards.
710        let settle_tx = if result.is_ok() {
711            // +1 capacity: one per request, non-blocking sends
712            Some(mpsc::channel::<()>(count))
713        } else {
714            None
715        };
716
717        for pending in queue.drain(..) {
718            let r = match (&result, &settle_tx) {
719                (Ok(()), Some((tx, _))) => Ok(ReadySignal {
720                    settle_tx: tx.clone(),
721                }),
722                (Err(e), _) => Err(SwitchError::Internal(e.to_string())),
723                _ => unreachable!(),
724            };
725            if pending.ready_tx.send(r).is_ok() {
726                delivered += 1;
727            }
728        }
729
730        metrics::gauge!("llmux_request_queue_depth", "model" => model.to_string()).set(0.0);
731        drop(queue); // release pending lock before the settle wait
732
733        if count > 0 {
734            let expired = count - delivered;
735            if expired > 0 {
736                warn!(model = %model, count, delivered, expired,
737                    "Notified pending requests ({expired} already timed out)");
738            } else {
739                debug!(model = %model, count, "Notified pending requests");
740            }
741        }
742
743        // Wait for all delivered requests to acquire in-flight guards.
744        // Each request calls ReadySignal::settle() after acquiring its
745        // guard, which sends () on the channel. If a request drops its
746        // signal without settling (e.g. cancelled), its sender clone is
747        // dropped; once all senders are gone the channel closes and recv
748        // returns None.
749        if let Some((tx, mut rx)) = settle_tx {
750            // Drop the original sender — only the clones sent to requests
751            // should keep the channel alive.
752            drop(tx);
753
754            if delivered > 0 {
755                let settle_wait = async {
756                    for _ in 0..delivered {
757                        if rx.recv().await.is_none() {
758                            break; // all senders dropped
759                        }
760                    }
761                };
762
763                if tokio::time::timeout(Duration::from_secs(5), settle_wait)
764                    .await
765                    .is_err()
766                {
767                    warn!(
768                        model = %model,
769                        delivered,
770                        "Settle timeout — proceeding with switch lock release"
771                    );
772                }
773            }
774        }
775    }
776}
777
778/// Guard that tracks in-flight requests. When dropped, decrements the count
779/// and notifies the drain waiter.
780pub struct InFlightGuard {
781    model_state: Arc<ModelState>,
782    model: String,
783}
784
785impl Drop for InFlightGuard {
786    fn drop(&mut self) {
787        let prev = self.model_state.in_flight.fetch_sub(1, Ordering::SeqCst);
788        metrics::gauge!("llmux_model_in_flight", "model" => self.model.clone())
789            .set((prev - 1) as f64);
790        if prev == 1 {
791            self.model_state.in_flight_changed.notify_waiters();
792        }
793    }
794}
795
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ModelConfig;
    use crate::policy::FifoPolicy;

    /// Hook runner with two models ("model-a" on 8001, "model-b" on 8002)
    /// whose wake/sleep/alive hooks are all `"true"`.
    fn make_test_hooks() -> Arc<HookRunner> {
        let mut configs = HashMap::new();
        configs.insert(
            "model-a".to_string(),
            ModelConfig {
                port: 8001,
                wake: "true".to_string(),
                sleep: "true".to_string(),
                alive: "true".to_string(),
            },
        );
        configs.insert(
            "model-b".to_string(),
            ModelConfig {
                port: 8002,
                wake: "true".to_string(),
                sleep: "true".to_string(),
                alive: "true".to_string(),
            },
        );
        Arc::new(HookRunner::new(configs))
    }

    /// Switcher over the two test models using the default FIFO policy.
    fn make_switcher() -> ModelSwitcher {
        ModelSwitcher::new(make_test_hooks(), Box::new(FifoPolicy::default()))
    }

    /// Put the switcher directly into `Active { model }` without running hooks.
    async fn set_active(switcher: &ModelSwitcher, model: &str) {
        let mut state = switcher.inner.state.write().await;
        *state = SwitcherState::Active {
            model: model.to_string(),
        };
    }

    #[test]
    fn test_switcher_creation() {
        let switcher = make_switcher();

        assert!(switcher.is_registered("model-a"));
        assert!(switcher.is_registered("model-b"));
        assert!(!switcher.is_registered("model-c"));

        let mut models = switcher.registered_models();
        models.sort();
        assert_eq!(models, vec!["model-a".to_string(), "model-b".to_string()]);
    }

    #[tokio::test]
    async fn test_new_switcher_is_idle_with_empty_queues() {
        let switcher = make_switcher();

        assert!(matches!(switcher.state().await, SwitcherState::Idle));
        assert_eq!(switcher.active_model().await, None);

        let depths = switcher.queue_depths().await;
        assert_eq!(depths.get("model-a"), Some(&0));
        assert_eq!(depths.get("model-b"), Some(&0));
    }

    #[test]
    fn test_in_flight_tracking() {
        let switcher = make_switcher();

        assert_eq!(switcher.in_flight_count("model-a"), 0);

        {
            let first = switcher.acquire_in_flight("model-a");
            assert!(first.is_some());
            assert_eq!(switcher.in_flight_count("model-a"), 1);

            let second = switcher.acquire_in_flight("model-a");
            assert!(second.is_some());
            assert_eq!(switcher.in_flight_count("model-a"), 2);

            drop(first);
            assert_eq!(switcher.in_flight_count("model-a"), 1);
        }

        assert_eq!(switcher.in_flight_count("model-a"), 0);
        // Counts are per model.
        assert_eq!(switcher.in_flight_count("model-b"), 0);
    }

    #[test]
    fn test_unknown_model_has_no_in_flight_state() {
        let switcher = make_switcher();

        assert!(switcher.acquire_in_flight("model-c").is_none());
        assert_eq!(switcher.in_flight_count("model-c"), 0);
    }

    #[test]
    fn test_acquire_in_flight_rejected_while_draining() {
        let switcher = make_switcher();

        let guard = switcher.acquire_in_flight("model-a");
        assert!(guard.is_some());
        assert_eq!(switcher.in_flight_count("model-a"), 1);
        drop(guard);

        // Set draining flag
        let model_state = switcher.inner.model_states.get("model-a").unwrap();
        model_state.draining.store(true, Ordering::SeqCst);

        let guard = switcher.acquire_in_flight("model-a");
        assert!(guard.is_none());
        // The rejected acquisition must roll its increment back.
        assert_eq!(switcher.in_flight_count("model-a"), 0);

        model_state.draining.store(false, Ordering::SeqCst);

        let guard = switcher.acquire_in_flight("model-a");
        assert!(guard.is_some());
        assert_eq!(switcher.in_flight_count("model-a"), 1);
        drop(guard);
    }

    #[test]
    fn test_model_port() {
        let switcher = make_switcher();

        assert_eq!(switcher.model_port("model-a"), Some(8001));
        assert_eq!(switcher.model_port("model-b"), Some(8002));
        assert_eq!(switcher.model_port("model-c"), None);
    }

    #[tokio::test]
    async fn test_force_switch_unknown_model() {
        let switcher = make_switcher();

        let result = switcher.force_switch("nonexistent").await;
        assert!(matches!(result, Err(SwitchError::ModelNotFound(_))));
    }

    #[tokio::test]
    async fn test_force_switch_already_active() {
        let switcher = make_switcher();
        set_active(&switcher, "model-a").await;

        let result = switcher.force_switch("model-a").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ensure_model_ready_unknown_model() {
        let switcher = make_switcher();

        let result = switcher.ensure_model_ready("model-c").await;
        assert!(matches!(result, Err(SwitchError::ModelNotFound(_))));
    }

    #[tokio::test]
    async fn test_ensure_model_ready_fast_path_when_active() {
        let switcher = make_switcher();
        set_active(&switcher, "model-a").await;

        let result = switcher.ensure_model_ready("model-a").await;
        assert!(matches!(result, Ok(None)));

        // The fast path never queues anything.
        let depths = switcher.queue_depths().await;
        assert_eq!(depths.get("model-a"), Some(&0));
    }
}