use crate::switcher::EvictionPolicy;
use async_trait::async_trait;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::Notify;
use tracing::debug;
/// Snapshot handed to [`SwitchPolicy::on_pending_request`]: a pending request for
/// `target_model` plus the state of the currently active model.
#[derive(Debug, Clone)]
pub struct PolicyContext {
    /// Model the pending request wants to run on.
    pub target_model: String,
    /// Currently active model, or `None` when no model is active.
    pub active_model: Option<String>,
    /// Queue depth for `target_model`; cost-aware policies compare it against a
    /// required minimum depth.
    pub target_queue_depth: usize,
    /// How long the oldest pending request has waited; policies compare it against
    /// their `max_wait` staleness limit.
    // NOTE(review): presumably measured over `target_model`'s queue only — confirm with caller.
    pub oldest_waiting: Duration,
    /// In-flight request count on the active model (the policies in this file only log it).
    pub active_in_flight: usize,
    /// How long the active model has been active.
    pub active_duration: Duration,
}
/// Snapshot handed to [`SwitchPolicy::schedule_tick`] on each scheduler tick.
#[derive(Debug, Clone)]
pub struct ScheduleContext {
    /// Currently active model, or `None` when no model is active.
    pub active_model: Option<String>,
    /// How long the active model has been active.
    pub active_duration: Duration,
    /// Queue depth per model, keyed by model name; missing keys are treated as 0
    /// by `TimeSlicePolicy`.
    pub queue_depths: HashMap<String, usize>,
    /// In-flight request count on the active model.
    pub active_in_flight: usize,
}
pub struct SwitchContext {
pub from_model: Option<String>,
pub to_model: String,
in_flight_drained: Arc<Notify>,
get_in_flight: Box<dyn Fn() -> usize + Send + Sync>,
}
impl SwitchContext {
pub fn new(
from_model: Option<String>,
to_model: String,
in_flight_drained: Arc<Notify>,
get_in_flight: Box<dyn Fn() -> usize + Send + Sync>,
) -> Self {
Self {
from_model,
to_model,
in_flight_drained,
get_in_flight,
}
}
pub async fn wait_for_in_flight(&self) {
while (self.get_in_flight)() > 0 {
self.in_flight_drained.notified().await;
}
}
pub fn in_flight_count(&self) -> usize {
(self.get_in_flight)()
}
}
/// Outcome of [`SwitchPolicy::on_pending_request`].
pub enum PolicyDecision {
    /// Switch to the target model immediately.
    SwitchNow,
    /// Postpone the decision; the policies in this file resolve the future after a sleep.
    // NOTE(review): what the caller does once the future completes (switch vs.
    // re-evaluate) is not visible here — confirm against the switcher.
    Defer(Pin<Box<dyn Future<Output = ()> + Send + 'static>>),
    /// Take no action for this request.
    Skip,
}
/// Strategy that decides when and how the active model is switched.
#[async_trait]
pub trait SwitchPolicy: Send + Sync {
    /// Called for a pending request whose target may differ from the active model.
    async fn on_pending_request(&self, ctx: &PolicyContext) -> PolicyDecision;
    /// Called before a switch is carried out, e.g. to wait for in-flight requests.
    async fn prepare_switch(&self, ctx: &mut SwitchContext);
    /// Called after a switch from `_from` to `_to` that took `_duration`; no-op by default.
    fn on_switch_complete(&self, _from: &str, _to: &str, _duration: Duration) {}
    /// Eviction policy to apply when switching.
    fn eviction_policy(&self) -> EvictionPolicy;
    /// Timeout applied to requests.
    fn request_timeout(&self) -> Duration;
    /// Minimum time a model should stay active.
    fn min_active_duration(&self) -> Duration;
    /// Period of the background scheduler; `None` (the default) means no periodic ticks.
    fn scheduler_interval(&self) -> Option<Duration> {
        None
    }
    /// Periodic hook that may name a model to switch to; returns `None` by default.
    fn schedule_tick(&self, _ctx: &ScheduleContext) -> Option<String> {
        None
    }
}
/// Policy that answers every pending request with an immediate switch, optionally
/// draining in-flight requests first.
pub struct FifoPolicy {
    eviction: EvictionPolicy,
    request_timeout: Duration,
    /// When `true`, `prepare_switch` waits for in-flight requests to reach zero.
    drain_before_switch: bool,
    min_active_duration: Duration,
}

impl FifoPolicy {
    /// Creates a FIFO policy from explicit settings.
    pub fn new(
        eviction: EvictionPolicy,
        request_timeout: Duration,
        drain_before_switch: bool,
        min_active_duration: Duration,
    ) -> Self {
        Self {
            min_active_duration,
            drain_before_switch,
            request_timeout,
            eviction,
        }
    }
}

impl Default for FifoPolicy {
    /// 300 s request timeout, drain before switching, 5 s minimum active duration.
    // NOTE(review): the meaning of `EvictionPolicy::from(1)` depends on that type's
    // `From` impl — confirm which variant `1` selects.
    fn default() -> Self {
        let eviction = EvictionPolicy::from(1);
        let request_timeout = Duration::from_secs(300);
        let min_active_duration = Duration::from_secs(5);
        Self::new(eviction, request_timeout, true, min_active_duration)
    }
}
#[async_trait]
impl SwitchPolicy for FifoPolicy {
    /// FIFO never defers or batches: every pending request triggers a switch.
    async fn on_pending_request(&self, _ctx: &PolicyContext) -> PolicyDecision {
        PolicyDecision::SwitchNow
    }

    /// Waits for in-flight requests to drain, but only if configured to.
    async fn prepare_switch(&self, ctx: &mut SwitchContext) {
        if !self.drain_before_switch {
            return;
        }
        ctx.wait_for_in_flight().await;
    }

    fn eviction_policy(&self) -> EvictionPolicy {
        self.eviction
    }

    fn request_timeout(&self) -> Duration {
        self.request_timeout
    }

    fn min_active_duration(&self) -> Duration {
        self.min_active_duration
    }
}
/// Exponential moving average of observed switch latency for one direction.
struct SwitchTiming {
    /// Current estimate in milliseconds.
    ema_ms: AtomicU64,
    /// Number of observations recorded so far.
    count: AtomicU64,
}

impl SwitchTiming {
    /// Each observation is capped at this many milliseconds so that one pathological
    /// switch cannot dominate the average.
    const MAX_SAMPLE_MS: u128 = 60_000;

    /// Creates an estimator seeded with `initial_estimate_ms`. The seed is replaced
    /// outright by the first recorded observation.
    fn new(initial_estimate_ms: u64) -> Self {
        Self {
            ema_ms: AtomicU64::new(initial_estimate_ms),
            count: AtomicU64::new(0),
        }
    }

    /// Folds one observed switch duration into the estimate (weight 0.3 for the new
    /// sample, 0.7 for the history, integer arithmetic).
    fn record(&self, duration: Duration) {
        // Clamp in u128 *before* narrowing. Casting first would wrap durations beyond
        // u64::MAX ms into arbitrary small values. After the clamp the cast cannot truncate.
        let ms = duration.as_millis().min(Self::MAX_SAMPLE_MS) as u64;
        let n = self.count.fetch_add(1, Ordering::Relaxed);
        if n == 0 {
            self.ema_ms.store(ms, Ordering::Relaxed);
        } else {
            // Atomic read-modify-write, so concurrent `record` calls cannot overwrite
            // each other's update. The closure always returns `Some`, so this never fails.
            let _ = self
                .ema_ms
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old| {
                    Some((3 * ms + 7 * old) / 10)
                });
        }
    }

    /// Current latency estimate in milliseconds.
    fn estimated_ms(&self) -> u64 {
        self.ema_ms.load(Ordering::Relaxed)
    }
}
/// Per-target-model deferral bookkeeping used by `CostAwarePolicy`.
///
/// The derived `Default` starts with no deferral pending (`AtomicBool::default()` is `false`).
#[derive(Default)]
struct CoalesceState {
    /// `true` while a `Defer` decision for this model is outstanding.
    defer_pending: AtomicBool,
}
/// Switch policy that weighs the estimated cost of a model switch against the
/// demand queued for the target model, deferring and coalescing requests while
/// a switch would not yet pay for itself.
pub struct CostAwarePolicy {
    eviction: EvictionPolicy,
    request_timeout: Duration,
    min_active_duration: Duration,
    /// Deferral length used when the target queue is too shallow to justify a switch.
    coalesce_window: Duration,
    /// Required queue depth per second of estimated switch cost.
    amortization_factor: f64,
    /// Wait time after which a switch happens regardless of cost.
    max_wait: Duration,
    /// Per-direction latency estimators keyed by `timing_key(from, to)`.
    timings: HashMap<String, SwitchTiming>,
    /// Estimate (ms) used when no per-direction timing applies.
    default_estimate_ms: u64,
    /// Per-target-model deferral bookkeeping.
    coalesce_states: HashMap<String, CoalesceState>,
}

impl CostAwarePolicy {
    /// Initial switch-latency estimate in milliseconds. It seeds every per-direction
    /// estimator and is also the fallback for directions that have no estimator.
    const DEFAULT_ESTIMATE_MS: u64 = 10_000;

    /// Creates a policy. A latency estimator is registered for every ordered pair of
    /// distinct names in `model_names`, and coalescing state for every name.
    pub fn new(
        eviction: EvictionPolicy,
        request_timeout: Duration,
        min_active_duration: Duration,
        coalesce_window: Duration,
        amortization_factor: f64,
        max_wait: Duration,
        model_names: Vec<String>,
    ) -> Self {
        let mut timings = HashMap::new();
        let mut coalesce_states = HashMap::new();
        for from in &model_names {
            for to in model_names.iter().filter(|to| *to != from) {
                timings.insert(
                    Self::timing_key(from, to),
                    SwitchTiming::new(Self::DEFAULT_ESTIMATE_MS),
                );
            }
            coalesce_states.insert(from.clone(), CoalesceState::default());
        }
        Self {
            eviction,
            request_timeout,
            min_active_duration,
            coalesce_window,
            amortization_factor,
            max_wait,
            timings,
            default_estimate_ms: Self::DEFAULT_ESTIMATE_MS,
            coalesce_states,
        }
    }

    /// Key under which the `from` → `to` latency estimator is stored (and logged).
    fn timing_key(from: &str, to: &str) -> String {
        format!("{from}→{to}")
    }

    /// Feeds an observed switch duration into that direction's estimator.
    /// Directions that were not registered in `new` are ignored.
    fn record_switch(&self, from: &str, to: &str, duration: Duration) {
        let key = Self::timing_key(from, to);
        let Some(timing) = self.timings.get(&key) else {
            return;
        };
        timing.record(duration);
        debug!(
            direction = %key,
            observed_ms = duration.as_millis(),
            new_ema_ms = timing.estimated_ms(),
            "Recorded switch timing"
        );
    }

    /// Estimated switch cost in seconds. Falls back to the default estimate when
    /// `from` is `None` or the direction has no estimator.
    fn estimated_switch_cost(&self, from: Option<&str>, to: &str) -> f64 {
        let ms = from
            .and_then(|from| self.timings.get(&Self::timing_key(from, to)))
            .map_or(self.default_estimate_ms, SwitchTiming::estimated_ms);
        ms as f64 / 1000.0
    }

    /// Queue depth needed to amortize a switch costing `switch_cost_secs`: the
    /// ceiling of `amortization_factor * switch_cost_secs`, but never below 1.
    /// The float-to-usize cast saturates, so NaN or negative products become 1.
    fn min_queue_depth(&self, switch_cost_secs: f64) -> usize {
        let min = (self.amortization_factor * switch_cost_secs).ceil() as usize;
        min.max(1)
    }
}
#[async_trait]
impl SwitchPolicy for CostAwarePolicy {
    /// Decides whether switching to `ctx.target_model` is worth its estimated cost.
    ///
    /// Checked in order:
    /// 1. `SwitchNow` if the oldest waiting request has waited at least `max_wait`.
    /// 2. `SwitchNow` if no model is active.
    /// 3. `Defer` for the remainder of the active model's "serving window" (its
    ///    active time is still below the estimated switch cost).
    /// 4. `SwitchNow` if the target queue depth meets the amortized requirement.
    /// 5. Otherwise `Defer` for `coalesce_window`.
    ///
    /// At most one deferral per target model is outstanding. A request that finds
    /// one already pending gets `Skip`.
    // NOTE(review): the pending flag is cleared only in `prepare_switch`. If the
    // caller re-evaluates after a `Defer` completes without switching, it receives
    // `Skip` until the `max_wait` override fires — confirm that is intended.
    async fn on_pending_request(&self, ctx: &PolicyContext) -> PolicyDecision {
        let switch_cost =
            self.estimated_switch_cost(ctx.active_model.as_deref(), &ctx.target_model);
        let required_depth = self.min_queue_depth(switch_cost);
        debug!(
            target_model = %ctx.target_model,
            queue_depth = ctx.target_queue_depth,
            oldest_waiting_ms = ctx.oldest_waiting.as_millis(),
            switch_cost_s = format!("{:.1}", switch_cost),
            required_depth,
            active_in_flight = ctx.active_in_flight,
            "CostAware: evaluating switch"
        );
        if ctx.oldest_waiting >= self.max_wait {
            debug!("CostAware: staleness override — switching now");
            return PolicyDecision::SwitchNow;
        }
        if ctx.active_model.is_none() {
            return PolicyDecision::SwitchNow;
        }
        let serving_window = Duration::from_secs_f64(switch_cost);
        if ctx.active_duration < serving_window {
            let remaining = serving_window - ctx.active_duration;
            debug!(
                active_duration_ms = ctx.active_duration.as_millis(),
                serving_window_ms = serving_window.as_millis(),
                remaining_ms = remaining.as_millis(),
                "CostAware: within serving window, deferring"
            );
            // `swap` is an atomic test-and-set. A separate load and store would let
            // two concurrent callers both observe `false` and both defer.
            if let Some(state) = self.coalesce_states.get(&ctx.target_model)
                && state.defer_pending.swap(true, Ordering::SeqCst)
            {
                return PolicyDecision::Skip;
            }
            let target = ctx.target_model.clone();
            return PolicyDecision::Defer(Box::pin(async move {
                tokio::time::sleep(remaining).await;
                debug!(model = %target, "CostAware: serving window expired");
            }));
        }
        if ctx.target_queue_depth >= required_depth {
            debug!(
                "CostAware: queue depth {} >= required {}, switching now",
                ctx.target_queue_depth, required_depth
            );
            return PolicyDecision::SwitchNow;
        }
        // Same atomic test-and-set as in the serving-window branch above.
        if let Some(state) = self.coalesce_states.get(&ctx.target_model)
            && state.defer_pending.swap(true, Ordering::SeqCst)
        {
            return PolicyDecision::Skip;
        }
        let window = self.coalesce_window;
        let target = ctx.target_model.clone();
        PolicyDecision::Defer(Box::pin(async move {
            tokio::time::sleep(window).await;
            debug!(model = %target, "CostAware: coalescing window expired");
        }))
    }

    /// Clears the deferral flag for the incoming model, then waits for in-flight
    /// requests to drain.
    async fn prepare_switch(&self, ctx: &mut SwitchContext) {
        if let Some(state) = self.coalesce_states.get(&ctx.to_model) {
            state.defer_pending.store(false, Ordering::SeqCst);
        }
        ctx.wait_for_in_flight().await;
    }

    /// Feeds the observed switch duration into the latency estimator for `from` → `to`.
    fn on_switch_complete(&self, from: &str, to: &str, duration: Duration) {
        self.record_switch(from, to, duration);
    }

    fn eviction_policy(&self) -> EvictionPolicy {
        self.eviction
    }

    fn request_timeout(&self) -> Duration {
        self.request_timeout
    }

    fn min_active_duration(&self) -> Duration {
        self.min_active_duration
    }
}
/// Policy that switches mainly from a periodic scheduler tick. A pending request
/// forces a switch only when nothing is active or the request has gone stale.
pub struct TimeSlicePolicy {
    eviction: EvictionPolicy,
    request_timeout: Duration,
    /// Minimum time the active model stays before `schedule_tick` may switch away.
    min_active_duration: Duration,
    /// Waiting time after which a pending request forces an immediate switch.
    max_wait: Duration,
    /// Period reported through `scheduler_interval`.
    tick_interval: Duration,
}

impl TimeSlicePolicy {
    /// Creates a time-slice policy.
    ///
    /// `_min_quantum` and `_model_names` are accepted for signature compatibility
    /// but are currently unused.
    pub fn new(
        eviction: EvictionPolicy,
        request_timeout: Duration,
        min_active_duration: Duration,
        max_wait: Duration,
        _min_quantum: Duration,
        tick_interval: Duration,
        _model_names: Vec<String>,
    ) -> Self {
        Self {
            tick_interval,
            max_wait,
            min_active_duration,
            request_timeout,
            eviction,
        }
    }
}
#[async_trait]
impl SwitchPolicy for TimeSlicePolicy {
    /// Returns `SwitchNow` when the oldest request has waited at least `max_wait` or
    /// no model is active. Otherwise returns `Skip`; switching is then presumably
    /// driven by the periodic `schedule_tick`.
    async fn on_pending_request(&self, ctx: &PolicyContext) -> PolicyDecision {
        if ctx.oldest_waiting >= self.max_wait {
            debug!(
                target_model = %ctx.target_model,
                oldest_waiting_ms = ctx.oldest_waiting.as_millis(),
                "TimeSlice: staleness override — switching now"
            );
            return PolicyDecision::SwitchNow;
        }
        if ctx.active_model.is_none() {
            return PolicyDecision::SwitchNow;
        }
        PolicyDecision::Skip
    }

    /// Waits for in-flight requests to drain before switching.
    async fn prepare_switch(&self, ctx: &mut SwitchContext) {
        ctx.wait_for_in_flight().await;
    }

    fn eviction_policy(&self) -> EvictionPolicy {
        self.eviction
    }

    fn request_timeout(&self) -> Duration {
        self.request_timeout
    }

    fn min_active_duration(&self) -> Duration {
        self.min_active_duration
    }

    /// Ticks every `tick_interval`.
    fn scheduler_interval(&self) -> Option<Duration> {
        Some(self.tick_interval)
    }

    /// Names the model to switch to, or returns `None` to stay put.
    ///
    /// Returns `None` in any of these cases:
    /// - no model is active;
    /// - the active model has been active for less than `min_active_duration`;
    /// - the active model still has queued or in-flight work;
    /// - no other model has queued requests.
    ///
    /// Otherwise it picks the other model with the deepest queue. Ties go to the
    /// lexicographically smallest name, so the choice does not depend on `HashMap`
    /// iteration order.
    fn schedule_tick(&self, ctx: &ScheduleContext) -> Option<String> {
        let active = ctx.active_model.as_deref()?;
        if ctx.active_duration < self.min_active_duration {
            return None;
        }
        let active_depth = ctx.queue_depths.get(active).copied().unwrap_or(0);
        if active_depth > 0 || ctx.active_in_flight > 0 {
            return None;
        }
        let best_target = ctx
            .queue_depths
            .iter()
            .filter(|&(model, &depth)| model.as_str() != active && depth > 0)
            // Order by depth, then by *reversed* name, so that among equal depths the
            // smallest name compares greatest and wins `max_by`.
            .max_by(|(a_model, a_depth), (b_model, b_depth)| {
                a_depth.cmp(b_depth).then_with(|| b_model.cmp(a_model))
            });
        let (target_model, _) = best_target?;
        debug!(
            from = %active,
            to = %target_model,
            "TimeSlice: active model idle, switching to model with most demand"
        );
        Some(target_model.clone())
    }
}