use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
/// Snapshot of scheduler state handed to [`SwitchPolicy::on_pending_request`].
///
/// NOTE(review): nothing in this file reads these fields, so the per-field
/// descriptions below are inferred from names and types only — confirm against
/// the code that constructs this struct.
#[derive(Debug, Clone)]
pub struct PolicyContext {
    /// Presumably the model the pending request(s) are waiting for.
    pub target_model: String,
    /// Presumably the currently active model; `None` when no model is active.
    pub active_model: Option<String>,
    /// Presumably the number of requests queued for `target_model`.
    pub target_queue_depth: usize,
    /// Presumably how long the oldest queued request has been waiting.
    pub oldest_waiting: Duration,
    /// Presumably the number of requests still in flight on the active model.
    pub active_in_flight: usize,
}
/// State handed to [`SwitchPolicy::prepare_switch`] while a model switch is being prepared.
///
/// Exposes the source/target model names plus a way to observe and wait on the
/// number of in-flight requests.
pub struct SwitchContext {
    /// Model being switched away from; `None` when there is no current model.
    pub from_model: Option<String>,
    /// Model being switched to.
    pub to_model: String,
    /// Signalled by the owner of the in-flight counter; awaited by
    /// [`SwitchContext::wait_for_in_flight`] between counter checks.
    in_flight_drained: Arc<Notify>,
    /// Returns the current number of in-flight requests.
    get_in_flight: Box<dyn Fn() -> usize + Send + Sync>,
}

impl SwitchContext {
    /// Builds a context from its parts.
    ///
    /// * `from_model` / `to_model` - source and destination of the switch.
    /// * `in_flight_drained` - notifier awaited while the counter is non-zero.
    /// * `get_in_flight` - callback returning the current in-flight count.
    pub fn new(
        from_model: Option<String>,
        to_model: String,
        in_flight_drained: Arc<Notify>,
        get_in_flight: Box<dyn Fn() -> usize + Send + Sync>,
    ) -> Self {
        Self {
            from_model,
            to_model,
            in_flight_drained,
            get_in_flight,
        }
    }

    /// Resolves once `get_in_flight` reports zero.
    ///
    /// Returns immediately when the count is already zero; otherwise waits for a
    /// notification on `in_flight_drained` and re-checks, looping until the
    /// count reaches zero.
    pub async fn wait_for_in_flight(&self) {
        loop {
            // Create the `Notified` future *before* reading the counter. Tokio
            // guarantees a `Notified` receives `notify_waiters()` wakeups from the
            // moment it is created, so a notification that fires between the
            // check below and the `.await` is not lost. (Checking first and
            // calling `notified()` afterwards leaves a window in which such a
            // wakeup is missed and this task can wait forever.)
            let drained = self.in_flight_drained.notified();
            if (self.get_in_flight)() == 0 {
                return;
            }
            drained.await;
        }
    }

    /// Returns the current in-flight request count as reported by `get_in_flight`.
    pub fn in_flight_count(&self) -> usize {
        (self.get_in_flight)()
    }
}
/// Outcome returned by [`SwitchPolicy::on_pending_request`].
pub enum PolicyDecision {
    /// Request that the switch happen now.
    SwitchNow,
    /// Postpone the decision; carries a boxed, pinned `Send + 'static` future
    /// with unit output.
    ///
    /// NOTE(review): presumably the caller awaits this future and then
    /// re-evaluates the policy — the consumer is not visible in this file.
    Defer(Pin<Box<dyn Future<Output = ()> + Send + 'static>>),
}
/// Strategy interface deciding when and how to switch between models.
///
/// Implementors must be `Send + Sync`. The async methods go through
/// `#[async_trait]`.
#[async_trait]
pub trait SwitchPolicy: Send + Sync {
    /// Decides what to do about pending work described by `ctx`: switch
    /// immediately or defer via a future (see [`PolicyDecision`]).
    async fn on_pending_request(&self, ctx: &PolicyContext) -> PolicyDecision;
    /// Hook run with the [`SwitchContext`] of a switch; e.g. an implementation
    /// may wait here for in-flight requests to finish.
    ///
    /// NOTE(review): presumably invoked before the switch is carried out —
    /// the call site is not visible in this file.
    async fn prepare_switch(&self, ctx: &mut SwitchContext);
    /// Sleep level configured for this policy.
    ///
    /// NOTE(review): the meaning of the numeric levels is defined by the
    /// consumer, which is not visible here.
    fn sleep_level(&self) -> u8;
    /// Request timeout configured for this policy.
    fn request_timeout(&self) -> Duration;
    /// Minimum active duration configured for this policy.
    ///
    /// NOTE(review): presumably the minimum time a model stays active before
    /// another switch is allowed — confirm against the caller.
    fn min_active_duration(&self) -> Duration;
}
/// Policy configuration holding four settings: a sleep level, a request
/// timeout, a drain-before-switch flag and a minimum active duration.
pub struct FifoPolicy {
    /// Value reported by `sleep_level()`.
    sleep_level: u8,
    /// Value reported by `request_timeout()`.
    request_timeout: Duration,
    /// When `true`, `prepare_switch` waits for in-flight requests to finish.
    drain_before_switch: bool,
    /// Value reported by `min_active_duration()`.
    min_active_duration: Duration,
}

impl FifoPolicy {
    /// Creates a policy from explicit settings; each argument is stored
    /// unchanged in the field of the same name.
    pub fn new(
        sleep_level: u8,
        request_timeout: Duration,
        drain_before_switch: bool,
        min_active_duration: Duration,
    ) -> Self {
        FifoPolicy {
            min_active_duration,
            drain_before_switch,
            request_timeout,
            sleep_level,
        }
    }
}

impl Default for FifoPolicy {
    /// Defaults: sleep level 1, 60 s request timeout, draining enabled,
    /// 5 s minimum active duration.
    fn default() -> Self {
        let request_timeout = Duration::from_secs(60);
        let min_active_duration = Duration::from_secs(5);
        FifoPolicy {
            sleep_level: 1,
            request_timeout,
            drain_before_switch: true,
            min_active_duration,
        }
    }
}
#[async_trait]
impl SwitchPolicy for FifoPolicy {
    /// Always answers [`PolicyDecision::SwitchNow`]; the context is ignored.
    async fn on_pending_request(&self, _ctx: &PolicyContext) -> PolicyDecision {
        PolicyDecision::SwitchNow
    }

    /// Waits for in-flight requests to finish when `drain_before_switch` is
    /// set; otherwise returns immediately.
    async fn prepare_switch(&self, ctx: &mut SwitchContext) {
        // Guard clause: nothing to do unless draining was requested.
        if !self.drain_before_switch {
            return;
        }
        ctx.wait_for_in_flight().await;
    }

    /// Returns the configured sleep level.
    fn sleep_level(&self) -> u8 {
        self.sleep_level
    }

    /// Returns the configured request timeout.
    fn request_timeout(&self) -> Duration {
        self.request_timeout
    }

    /// Returns the configured minimum active duration.
    fn min_active_duration(&self) -> Duration {
        self.min_active_duration
    }
}