1use crate::cost::SwitchCostTracker;
8use crate::hooks::HookRunner;
9use crate::policy::{PolicyContext, PolicyDecision, ScheduleContext, SwitchContext, SwitchPolicy};
10use crate::types::{SwitchError, SwitcherState};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
14use std::time::{Duration, Instant};
15use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
16use tracing::{debug, error, info, trace, warn};
17
18pub(crate) struct ReadySignal {
26 settle_tx: mpsc::Sender<()>,
27}
28
29impl ReadySignal {
30 pub(crate) async fn settle(self) {
31 let _ = self.settle_tx.send(()).await;
32 }
33}
34
35struct PendingRequest {
37 #[allow(dead_code)]
38 model: String,
39 queued_at: Instant,
40 ready_tx: oneshot::Sender<Result<ReadySignal, SwitchError>>,
41}
42
43struct ModelState {
45 in_flight: AtomicUsize,
46 pending: Mutex<Vec<PendingRequest>>,
47 in_flight_changed: Arc<Notify>,
48 draining: AtomicBool,
52}
53
54impl Default for ModelState {
55 fn default() -> Self {
56 Self {
57 in_flight: AtomicUsize::new(0),
58 pending: Mutex::new(Vec::new()),
59 in_flight_changed: Arc::new(Notify::new()),
60 draining: AtomicBool::new(false),
61 }
62 }
63}
64
65struct SwitcherInner {
66 hooks: Arc<HookRunner>,
67 policy: Box<dyn SwitchPolicy>,
68 state: RwLock<SwitcherState>,
69 model_states: HashMap<String, Arc<ModelState>>,
70 switch_lock: Mutex<()>,
71 activated_at: RwLock<Option<Instant>>,
73 last_switch_failure: RwLock<Option<Instant>>,
75 cost_tracker: SwitchCostTracker,
77}
78
79pub struct ModelSwitcher {
81 inner: Arc<SwitcherInner>,
82}
83
84impl Clone for ModelSwitcher {
85 fn clone(&self) -> Self {
86 Self {
87 inner: Arc::clone(&self.inner),
88 }
89 }
90}
91
92impl ModelSwitcher {
93 pub fn new(hooks: Arc<HookRunner>, policy: Box<dyn SwitchPolicy>) -> Self {
94 let model_states: HashMap<String, Arc<ModelState>> = hooks
95 .registered_models()
96 .into_iter()
97 .map(|model| (model, Arc::new(ModelState::default())))
98 .collect();
99
100 Self {
101 inner: Arc::new(SwitcherInner {
102 hooks,
103 policy,
104 state: RwLock::new(SwitcherState::Idle),
105 model_states,
106 switch_lock: Mutex::new(()),
107 activated_at: RwLock::new(None),
108 last_switch_failure: RwLock::new(None),
109 cost_tracker: SwitchCostTracker::new(0.3),
110 }),
111 }
112 }
113
114 pub async fn state(&self) -> SwitcherState {
115 self.inner.state.read().await.clone()
116 }
117
118 pub async fn active_model(&self) -> Option<String> {
119 match &*self.inner.state.read().await {
120 SwitcherState::Active { model } => Some(model.clone()),
121 _ => None,
122 }
123 }
124
125 pub fn registered_models(&self) -> Vec<String> {
126 self.inner.model_states.keys().cloned().collect()
127 }
128
129 pub fn hooks(&self) -> &Arc<HookRunner> {
130 &self.inner.hooks
131 }
132
133 pub fn is_registered(&self, model: &str) -> bool {
134 self.inner.model_states.contains_key(model)
135 }
136
137 pub fn model_port(&self, model: &str) -> Option<u16> {
138 self.inner.hooks.model_port(model)
139 }
140
141 pub fn in_flight_count(&self, model: &str) -> usize {
142 self.inner
143 .model_states
144 .get(model)
145 .map(|s| s.in_flight.load(Ordering::SeqCst))
146 .unwrap_or(0)
147 }
148
149 pub fn estimated_switch_cost(&self, from: Option<&str>, to: &str) -> Option<Duration> {
152 self.inner.cost_tracker.estimate(from, to)
153 }
154
155 pub async fn force_switch(&self, model: &str) -> Result<(), SwitchError> {
157 if !self.is_registered(model) {
158 return Err(SwitchError::ModelNotFound(model.to_string()));
159 }
160
161 {
163 let state = self.inner.state.read().await;
164 if let SwitcherState::Active { model: active } = &*state
165 && active == model
166 {
167 return Ok(());
168 }
169 }
170
171 self.do_switch(model).await;
172
173 let state = self.inner.state.read().await;
174 match &*state {
175 SwitcherState::Active { model: active } if active == model => Ok(()),
176 _ => Err(SwitchError::NotReady(model.to_string())),
177 }
178 }
179
180 pub(crate) async fn ensure_model_ready(
189 &self,
190 model: &str,
191 ) -> Result<Option<ReadySignal>, SwitchError> {
192 let model_state = self
193 .inner
194 .model_states
195 .get(model)
196 .ok_or_else(|| SwitchError::ModelNotFound(model.to_string()))?;
197
198 {
200 let state = self.inner.state.read().await;
201 if let SwitcherState::Active { model: active } = &*state
202 && active == model
203 {
204 trace!(model = %model, "Model already active");
205 return Ok(None);
206 }
207 }
208
209 let (ready_tx, ready_rx) = oneshot::channel();
211 let pending = PendingRequest {
212 model: model.to_string(),
213 queued_at: Instant::now(),
214 ready_tx,
215 };
216
217 {
218 let mut queue = model_state.pending.lock().await;
219 queue.push(pending);
220 let depth = queue.len();
221 debug!(model = %model, queue_depth = depth, "Request queued");
222 metrics::gauge!("llmux_request_queue_depth", "model" => model.to_string())
223 .set(depth as f64);
224 }
225
226 self.maybe_trigger_switch(model).await;
227
228 match self.inner.policy.request_timeout() {
229 Some(timeout) => match tokio::time::timeout(timeout, ready_rx).await {
230 Ok(Ok(result)) => result.map(Some),
231 Ok(Err(_)) => Err(SwitchError::Internal("channel closed".to_string())),
232 Err(_) => {
233 warn!(
234 event = "request_timeout",
235 model = %model,
236 timeout_secs = timeout.as_secs_f64(),
237 "Request timed out waiting for model"
238 );
239 Err(SwitchError::Timeout)
240 }
241 },
242 None => match ready_rx.await {
243 Ok(result) => result.map(Some),
244 Err(_) => Err(SwitchError::Internal("channel closed".to_string())),
245 },
246 }
247 }
248
249 pub fn acquire_in_flight(&self, model: &str) -> Option<InFlightGuard> {
254 let model_state = self.inner.model_states.get(model)?;
255
256 let new_count = model_state.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
257
258 if model_state.draining.load(Ordering::SeqCst) {
259 model_state.in_flight.fetch_sub(1, Ordering::SeqCst);
260 model_state.in_flight_changed.notify_waiters();
261 return None;
262 }
263
264 metrics::gauge!("llmux_model_in_flight", "model" => model.to_string())
265 .set(new_count as f64);
266
267 Some(InFlightGuard {
268 model_state: Arc::clone(model_state),
269 model: model.to_string(),
270 })
271 }
272
273 pub async fn queue_depths(&self) -> HashMap<String, usize> {
275 let mut depths = HashMap::new();
276 for (model, state) in &self.inner.model_states {
277 let queue = state.pending.lock().await;
278 depths.insert(model.clone(), queue.len());
279 }
280 depths
281 }
282
283 pub fn spawn_scheduler(self) -> Option<tokio::task::JoinHandle<()>> {
285 let interval = self.inner.policy.scheduler_interval()?;
286
287 info!(
288 interval_ms = interval.as_millis(),
289 "Spawning background scheduler"
290 );
291
292 Some(tokio::spawn(async move {
293 let mut tick = tokio::time::interval(interval);
294 tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
295
296 loop {
297 tick.tick().await;
298 let ctx = self.build_schedule_context().await;
299 if let Some(target) = self.inner.policy.schedule_tick(&ctx) {
300 debug!(target = %target, "Scheduler: triggering switch");
301 self.do_switch(&target).await;
302 }
303 }
304 }))
305 }
306
307 async fn maybe_trigger_switch(&self, target_model: &str) {
312 let model_state = match self.inner.model_states.get(target_model) {
313 Some(s) => s,
314 None => return,
315 };
316
317 let ctx = {
318 let state = self.inner.state.read().await;
319 let queue = model_state.pending.lock().await;
320
321 let oldest_waiting = queue
322 .first()
323 .map(|p| p.queued_at.elapsed())
324 .unwrap_or(Duration::ZERO);
325
326 let (active_model, active_in_flight) = match &*state {
327 SwitcherState::Active { model } => {
328 (Some(model.clone()), self.in_flight_count(model))
329 }
330 _ => (None, 0),
331 };
332
333 let active_duration = self
334 .inner
335 .activated_at
336 .read()
337 .await
338 .map(|t| t.elapsed())
339 .unwrap_or(Duration::ZERO);
340
341 let estimated_switch_cost = self
342 .inner
343 .cost_tracker
344 .estimate(active_model.as_deref(), target_model);
345
346 PolicyContext {
347 target_model: target_model.to_string(),
348 active_model,
349 target_queue_depth: queue.len(),
350 oldest_waiting,
351 active_in_flight,
352 active_duration,
353 estimated_switch_cost,
354 }
355 };
356
357 {
359 let state = self.inner.state.read().await;
360 if let SwitcherState::Switching { to, .. } = &*state
361 && to == target_model
362 {
363 return;
364 }
365 }
366
367 let decision = self.inner.policy.on_pending_request(&ctx).await;
368
369 match decision {
370 PolicyDecision::SwitchNow => {
371 debug!(model = %target_model, "Policy: switch now");
372 let switcher = self.clone();
373 let target = target_model.to_string();
374 tokio::spawn(async move {
375 switcher.do_switch(&target).await;
376 });
377 }
378 PolicyDecision::Defer(future) => {
379 debug!(model = %target_model, "Policy: defer");
380 let switcher = self.clone();
381 let target = target_model.to_string();
382 tokio::spawn(async move {
383 future.await;
384 switcher.do_switch(&target).await;
385 });
386 }
387 PolicyDecision::Skip => {
388 trace!(model = %target_model, "Policy: skip");
389 }
390 }
391 }
392
393 async fn do_switch(&self, target_model: &str) {
394 let _guard = self.inner.switch_lock.lock().await;
395 let switch_start = Instant::now();
396
397 {
399 let last_failure = self.inner.last_switch_failure.read().await;
400 if let Some(failed_at) = *last_failure {
401 let backoff = Duration::from_secs(2);
402 let elapsed = failed_at.elapsed();
403 if elapsed < backoff {
404 let remaining = backoff - elapsed;
405 info!(remaining = ?remaining, "Backing off after recent switch failure");
406 drop(last_failure);
407 tokio::time::sleep(remaining).await;
408 }
409 }
410 }
411
412 {
414 let state = self.inner.state.read().await;
415 match &*state {
416 SwitcherState::Active { model } if model == target_model => {
417 self.notify_pending(target_model, Ok(())).await;
418 return;
419 }
420 SwitcherState::Switching { to, .. } if to == target_model => {
421 return;
422 }
423 _ => {}
424 }
425 }
426
427 if let Some(target_state) = self.inner.model_states.get(target_model) {
432 let queue = target_state.pending.lock().await;
433 if queue.is_empty() {
434 debug!(
435 model = %target_model,
436 "No pending requests, skipping stale switch"
437 );
438 return;
439 }
440 }
441
442 let from_model = {
443 let state = self.inner.state.read().await;
444 match &*state {
445 SwitcherState::Active { model } => Some(model.clone()),
446 _ => None,
447 }
448 };
449
450 let from_label = from_model.as_deref().unwrap_or("idle").to_string();
451
452 if from_model.is_some() {
454 let active_dur = self
455 .inner
456 .activated_at
457 .read()
458 .await
459 .map(|t| t.elapsed())
460 .unwrap_or(Duration::ZERO);
461 metrics::histogram!(
462 "llmux_model_active_duration_seconds",
463 "model" => from_label.clone()
464 )
465 .record(active_dur.as_secs_f64());
466 }
467
468 {
470 let mut state = self.inner.state.write().await;
471 *state = SwitcherState::Switching {
472 from: from_model.clone(),
473 to: target_model.to_string(),
474 };
475 }
476
477 info!(
478 event = "switch_started",
479 from = %from_label,
480 to = %target_model,
481 "Starting model switch"
482 );
483
484 if from_model.is_some() {
486 let min_active = self.inner.policy.min_active_duration();
487 let activated_at = *self.inner.activated_at.read().await;
488 if let Some(activated) = activated_at {
489 let elapsed = activated.elapsed();
490 if elapsed < min_active {
491 let remaining = min_active - elapsed;
492 info!(remaining = ?remaining, "Waiting for cooldown");
493 tokio::time::sleep(remaining).await;
494 }
495 }
496 }
497
498 let drain_start = Instant::now();
500
501 if let Some(ref from) = from_model
502 && let Some(from_state) = self.inner.model_states.get(from)
503 {
504 from_state.draining.store(true, Ordering::SeqCst);
505 }
506
507 if let Some(ref from) = from_model
508 && let Some(from_state) = self.inner.model_states.get(from)
509 {
510 let in_flight_changed = Arc::clone(&from_state.in_flight_changed);
511 let from_state_clone = Arc::clone(from_state);
512
513 let mut switch_ctx = SwitchContext::new(
514 from_model.clone(),
515 target_model.to_string(),
516 in_flight_changed,
517 Box::new(move || from_state_clone.in_flight.load(Ordering::SeqCst)),
518 );
519
520 self.inner.policy.prepare_switch(&mut switch_ctx).await;
521 }
522
523 if from_model.is_some() {
524 metrics::histogram!(
525 "llmux_switch_drain_duration_seconds",
526 "from" => from_label.clone(),
527 "to" => target_model.to_string()
528 )
529 .record(drain_start.elapsed().as_secs_f64());
530 }
531
532 if let Some(ref from) = from_model {
534 debug!(model = %from, "Running sleep hook");
535 if let Err(e) = self.inner.hooks.run_sleep(from).await {
536 error!(
537 event = "sleep_hook_failed",
538 model = %from,
539 to = %target_model,
540 error = %e,
541 "Sleep hook failed, continuing with wake (idempotent)"
542 );
543 }
544 }
545
546 if let Some(ref from) = from_model
548 && let Some(from_state) = self.inner.model_states.get(from)
549 {
550 from_state.draining.store(false, Ordering::SeqCst);
551 }
552
553 debug!(model = %target_model, "Running wake hook");
555 match self.inner.hooks.run_wake(target_model).await {
556 Ok(()) => {
557 let total_dur = switch_start.elapsed();
558
559 {
560 let mut state = self.inner.state.write().await;
561 *state = SwitcherState::Active {
562 model: target_model.to_string(),
563 };
564 }
565 *self.inner.activated_at.write().await = Some(Instant::now());
566 *self.inner.last_switch_failure.write().await = None;
567
568 self.inner
569 .cost_tracker
570 .record(from_model.as_deref(), target_model, total_dur);
571
572 info!(
574 event = "model_activated",
575 model = %target_model,
576 from = %from_label,
577 duration_secs = total_dur.as_secs_f64(),
578 "Model is now active"
579 );
580
581 metrics::counter!(
583 "llmux_switch_total",
584 "from" => from_label.clone(),
585 "to" => target_model.to_string(),
586 "result" => "success"
587 )
588 .increment(1);
589 metrics::histogram!(
590 "llmux_switch_duration_seconds",
591 "from" => from_label.clone(),
592 "to" => target_model.to_string()
593 )
594 .record(total_dur.as_secs_f64());
595
596 if let Some(ema) = self
597 .inner
598 .cost_tracker
599 .estimate(from_model.as_deref(), target_model)
600 {
601 metrics::gauge!(
602 "llmux_switch_cost_ema_seconds",
603 "from" => from_label,
604 "to" => target_model.to_string()
605 )
606 .set(ema.as_secs_f64());
607 }
608
609 let from_str = from_model.as_deref().unwrap_or("");
610 self.inner
611 .policy
612 .on_switch_complete(from_str, target_model, total_dur);
613
614 self.notify_pending(target_model, Ok(())).await;
615 }
616 Err(e) => {
617 info!(
619 event = "switch_failed",
620 model = %target_model,
621 from = %from_label,
622 error = %e,
623 "Switch failed, returning to idle"
624 );
625
626 metrics::counter!(
627 "llmux_switch_total",
628 "from" => from_label,
629 "to" => target_model.to_string(),
630 "result" => "failure"
631 )
632 .increment(1);
633
634 let _ = self.inner.hooks.run_sleep(target_model).await;
636
637 *self.inner.last_switch_failure.write().await = Some(Instant::now());
638 {
639 let mut state = self.inner.state.write().await;
640 *state = SwitcherState::Idle;
641 }
642
643 self.notify_pending(
644 target_model,
645 Err(SwitchError::HookFailed {
646 model: target_model.to_string(),
647 detail: e.to_string(),
648 }),
649 )
650 .await;
651 }
652 }
653 }
654
655 async fn build_schedule_context(&self) -> ScheduleContext {
656 let (active_model, active_in_flight) = match &*self.inner.state.read().await {
657 SwitcherState::Active { model } => (Some(model.clone()), self.in_flight_count(model)),
658 _ => (None, 0),
659 };
660
661 let active_duration = self
662 .inner
663 .activated_at
664 .read()
665 .await
666 .map(|t| t.elapsed())
667 .unwrap_or(Duration::ZERO);
668
669 let queue_depths = self.queue_depths().await;
670
671 let switch_costs = self
672 .inner
673 .cost_tracker
674 .estimates_from(active_model.as_deref());
675
676 ScheduleContext {
677 active_model,
678 active_duration,
679 queue_depths,
680 active_in_flight,
681 switch_costs,
682 }
683 }
684
685 async fn notify_pending(&self, model: &str, result: Result<(), SwitchError>) {
696 let Some(model_state) = self.inner.model_states.get(model) else {
697 return;
698 };
699
700 let mut queue = model_state.pending.lock().await;
701 let count = queue.len();
702 if count == 0 {
703 return;
704 }
705
706 let mut delivered = 0;
707
708 let settle_tx = if result.is_ok() {
711 Some(mpsc::channel::<()>(count))
713 } else {
714 None
715 };
716
717 for pending in queue.drain(..) {
718 let r = match (&result, &settle_tx) {
719 (Ok(()), Some((tx, _))) => Ok(ReadySignal {
720 settle_tx: tx.clone(),
721 }),
722 (Err(e), _) => Err(SwitchError::Internal(e.to_string())),
723 _ => unreachable!(),
724 };
725 if pending.ready_tx.send(r).is_ok() {
726 delivered += 1;
727 }
728 }
729
730 metrics::gauge!("llmux_request_queue_depth", "model" => model.to_string()).set(0.0);
731 drop(queue); if count > 0 {
734 let expired = count - delivered;
735 if expired > 0 {
736 warn!(model = %model, count, delivered, expired,
737 "Notified pending requests ({expired} already timed out)");
738 } else {
739 debug!(model = %model, count, "Notified pending requests");
740 }
741 }
742
743 if let Some((tx, mut rx)) = settle_tx {
750 drop(tx);
753
754 if delivered > 0 {
755 let settle_wait = async {
756 for _ in 0..delivered {
757 if rx.recv().await.is_none() {
758 break; }
760 }
761 };
762
763 if tokio::time::timeout(Duration::from_secs(5), settle_wait)
764 .await
765 .is_err()
766 {
767 warn!(
768 model = %model,
769 delivered,
770 "Settle timeout — proceeding with switch lock release"
771 );
772 }
773 }
774 }
775 }
776}
777
778pub struct InFlightGuard {
781 model_state: Arc<ModelState>,
782 model: String,
783}
784
785impl Drop for InFlightGuard {
786 fn drop(&mut self) {
787 let prev = self.model_state.in_flight.fetch_sub(1, Ordering::SeqCst);
788 metrics::gauge!("llmux_model_in_flight", "model" => self.model.clone())
789 .set((prev - 1) as f64);
790 if prev == 1 {
791 self.model_state.in_flight_changed.notify_waiters();
792 }
793 }
794}
795
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ModelConfig;
    use crate::policy::FifoPolicy;

    /// Two models, "model-a" (port 8001) and "model-b" (port 8002), whose hooks all run
    /// the no-op `true` command.
    fn make_test_hooks() -> Arc<HookRunner> {
        let noop_model = |port| ModelConfig {
            port,
            wake: "true".to_string(),
            sleep: "true".to_string(),
            alive: "true".to_string(),
        };
        let configs: HashMap<_, _> = [("model-a", noop_model(8001)), ("model-b", noop_model(8002))]
            .into_iter()
            .map(|(name, config)| (name.to_string(), config))
            .collect();
        Arc::new(HookRunner::new(configs))
    }

    /// A switcher over the test hooks using the FIFO policy.
    fn new_switcher() -> ModelSwitcher {
        ModelSwitcher::new(make_test_hooks(), Box::new(FifoPolicy::default()))
    }

    #[test]
    fn test_switcher_creation() {
        let switcher = new_switcher();

        for known in ["model-a", "model-b"] {
            assert!(switcher.is_registered(known));
        }
        assert!(!switcher.is_registered("model-c"));
    }

    #[tokio::test]
    async fn test_in_flight_tracking() {
        let switcher = new_switcher();
        assert_eq!(switcher.in_flight_count("model-a"), 0);

        let guard = switcher.acquire_in_flight("model-a");
        assert_eq!(switcher.in_flight_count("model-a"), 1);
        drop(guard);

        assert_eq!(switcher.in_flight_count("model-a"), 0);
    }

    #[test]
    fn test_acquire_in_flight_rejected_while_draining() {
        let switcher = new_switcher();

        let first = switcher.acquire_in_flight("model-a");
        assert!(first.is_some());
        assert_eq!(switcher.in_flight_count("model-a"), 1);
        drop(first);

        let model_state = switcher
            .inner
            .model_states
            .get("model-a")
            .expect("model-a is registered");

        model_state.draining.store(true, Ordering::SeqCst);
        let rejected = switcher.acquire_in_flight("model-a");
        assert!(rejected.is_none());
        assert_eq!(switcher.in_flight_count("model-a"), 0);

        model_state.draining.store(false, Ordering::SeqCst);
        let accepted = switcher.acquire_in_flight("model-a");
        assert!(accepted.is_some());
        assert_eq!(switcher.in_flight_count("model-a"), 1);
        drop(accepted);
    }

    #[tokio::test]
    async fn test_model_port() {
        let switcher = new_switcher();

        let expected = [
            ("model-a", Some(8001)),
            ("model-b", Some(8002)),
            ("model-c", None),
        ];
        for (name, port) in expected {
            assert_eq!(switcher.model_port(name), port);
        }
    }

    #[tokio::test]
    async fn test_force_switch_unknown_model() {
        let switcher = new_switcher();

        let outcome = switcher.force_switch("nonexistent").await;
        assert!(matches!(outcome, Err(SwitchError::ModelNotFound(_))));
    }

    #[tokio::test]
    async fn test_force_switch_already_active() {
        let switcher = new_switcher();

        *switcher.inner.state.write().await = SwitcherState::Active {
            model: "model-a".to_string(),
        };

        assert!(switcher.force_switch("model-a").await.is_ok());
    }
}