use async_trait::async_trait;
use ferrum_types::{
BatchId, InferenceRequest, InferenceResponse, Priority, RequestId, RequestState, Result,
SchedulerConfig as TypesSchedulerConfig, SchedulerStats,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, time::Duration};
/// Core scheduling contract: queue inference requests and hand back batches
/// for execution.
///
/// All async methods take `&self`, so implementations must provide their own
/// interior mutability / synchronization.
#[async_trait]
pub trait Scheduler: Send + Sync {
    /// Enqueue a request; returns the id under which it is tracked.
    async fn submit(&self, request: InferenceRequest) -> Result<RequestId>;
    /// Assemble the next batch within the limits described by `hint`;
    /// `None` when nothing is schedulable right now.
    async fn next_batch(&self, hint: BatchHint) -> Option<BatchPlan>;
    /// Record that a request finished with `response` so its queue slot and
    /// resources can be released.
    async fn complete(&self, request_id: RequestId, response: &InferenceResponse) -> Result<()>;
    /// Cancel a request. NOTE(review): the returned bool presumably means
    /// "request was found and cancelled" — confirm with implementors.
    async fn cancel(&self, request_id: RequestId) -> Result<bool>;
    /// Change the priority of an already-submitted request.
    async fn update_priority(&self, request_id: RequestId, priority: Priority) -> Result<()>;
    /// Snapshot of current scheduler statistics.
    fn metrics(&self) -> SchedulerMetrics;
    /// The configuration this scheduler was built with.
    fn config(&self) -> &TypesSchedulerConfig;
    /// Look up the state of a request. The default implementation tracks
    /// nothing and always answers `None`.
    fn request_state(&self, request_id: &RequestId) -> Option<RequestState> {
        let _ = request_id; // unused by the default implementation
        None
    }
    /// Preempt a running request. Unsupported unless overridden.
    async fn preempt(&self, _request_id: RequestId) -> Result<PreemptionResult> {
        Err(ferrum_types::FerrumError::unsupported(
            "Preemption not supported",
        ))
    }
    /// Resume a previously preempted request. Unsupported unless overridden.
    async fn resume(&self, _request_id: RequestId) -> Result<()> {
        Err(ferrum_types::FerrumError::unsupported(
            "Resumption not supported",
        ))
    }
}
/// Limits the scheduler should respect when building the next batch.
#[derive(Debug, Clone)]
pub struct BatchHint {
    /// Maximum number of requests to place in the batch.
    pub max_batch_size: usize,
    /// Maximum total token budget for the batch.
    pub max_tokens: usize,
    /// Optional latency target for the batch, in milliseconds.
    pub target_latency_ms: Option<u64>,
    /// Optional currently-available memory budget — presumably in bytes;
    /// TODO confirm unit with callers.
    pub available_memory: Option<u64>,
    /// Additional hardware/resource constraints to honor.
    pub resource_constraints: ResourceConstraints,
}
impl BatchHint {
    /// Default per-request token budget assumed by [`BatchHint::simple`].
    /// Hoisted from a magic `2048` so the assumption is named and reusable.
    pub const DEFAULT_TOKENS_PER_REQUEST: usize = 2048;

    /// Convenience constructor: caps the batch at `max_batch_size` requests,
    /// budgets [`Self::DEFAULT_TOKENS_PER_REQUEST`] tokens per request, and
    /// applies no latency, memory, or resource constraints.
    pub fn simple(max_batch_size: usize) -> Self {
        Self {
            max_batch_size,
            // saturating_mul: a pathological batch size must not panic in
            // debug builds or wrap in release builds.
            max_tokens: max_batch_size.saturating_mul(Self::DEFAULT_TOKENS_PER_REQUEST),
            target_latency_ms: None,
            available_memory: None,
            resource_constraints: ResourceConstraints::default(),
        }
    }
}
/// Hardware/resource bounds a batch must fit within. `None` means unlimited.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConstraints {
    /// Maximum GPU memory — presumably bytes; TODO confirm unit.
    pub max_gpu_memory: Option<u64>,
    /// Maximum CPU memory — presumably bytes; TODO confirm unit.
    pub max_cpu_memory: Option<u64>,
    /// Maximum number of compute units to occupy.
    pub max_compute_units: Option<usize>,
    /// Devices the batch must run on; empty means no device requirement.
    pub required_devices: Vec<ferrum_types::Device>,
}
impl Default for ResourceConstraints {
fn default() -> Self {
Self {
max_gpu_memory: None,
max_cpu_memory: None,
max_compute_units: None,
required_devices: vec![],
}
}
}
/// A concrete batch the scheduler has decided to run.
#[derive(Debug, Clone)]
pub struct BatchPlan {
    /// Identifier of this batch.
    pub batch_id: BatchId,
    /// The requests included in the batch, with their scheduling state.
    pub requests: Vec<ScheduledRequest>,
    /// Longest sequence length present in the batch — TODO confirm whether
    /// this counts prompt tokens, generated tokens, or both.
    pub max_sequence_length: usize,
    /// Optional predicted execution time, in milliseconds.
    pub estimated_time_ms: Option<u64>,
    /// Aggregate resources the batch is expected to need.
    pub resource_requirements: BatchResourceRequirements,
    /// When this plan was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
impl BatchPlan {
    /// Sum of the per-request `max_tokens` budgets across the batch.
    pub fn total_tokens(&self) -> usize {
        self.requests
            .iter()
            .fold(0usize, |acc, scheduled| {
                acc + scheduled.request.sampling_params.max_tokens
            })
    }

    /// Number of requests in the batch.
    pub fn size(&self) -> usize {
        self.requests.len()
    }

    /// Whether the batch contains no requests.
    pub fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Highest priority among the batch's requests; `Priority::Low` for an
    /// empty batch.
    pub fn max_priority(&self) -> Priority {
        let mut best: Option<Priority> = None;
        for scheduled in &self.requests {
            let p = scheduled.request.priority;
            best = Some(match best {
                Some(current) if current >= p => current,
                _ => p,
            });
        }
        best.unwrap_or(Priority::Low)
    }
}
/// A request as tracked by the scheduler, including queue and resource state.
#[derive(Debug, Clone)]
pub struct ScheduledRequest {
    /// The underlying inference request.
    pub request: InferenceRequest,
    /// Current lifecycle state.
    pub state: RequestState,
    /// Position in the wait queue, if currently queued.
    pub queue_position: Option<usize>,
    /// Predicted remaining wait, if known.
    pub estimated_wait_time: Option<Duration>,
    /// Tokens processed so far (starts at 0).
    pub tokens_processed: usize,
    /// Resources currently held on behalf of this request.
    pub allocated_resources: AllocatedResources,
    /// When the request entered the scheduler.
    pub submitted_at: chrono::DateTime<chrono::Utc>,
    /// When processing began; `None` while still waiting.
    pub started_at: Option<chrono::DateTime<chrono::Utc>>,
}
impl ScheduledRequest {
pub fn new(request: InferenceRequest) -> Self {
Self {
request,
state: RequestState::Waiting,
queue_position: None,
estimated_wait_time: None,
tokens_processed: 0,
allocated_resources: AllocatedResources::default(),
submitted_at: chrono::Utc::now(),
started_at: None,
}
}
pub fn age(&self) -> Duration {
(chrono::Utc::now() - self.submitted_at)
.to_std()
.unwrap_or_default()
}
pub fn processing_time(&self) -> Option<Duration> {
self.started_at
.map(|start| (chrono::Utc::now() - start).to_std().unwrap_or_default())
}
}
/// Resources currently held by (or freed from) a request.
#[derive(Debug, Clone, Default)]
pub struct AllocatedResources {
    /// KV-cache blocks assigned to the request.
    pub kv_cache_blocks: Vec<ferrum_types::BlockId>,
    /// GPU memory held — presumably bytes; TODO confirm unit.
    pub gpu_memory: u64,
    /// CPU memory held — presumably bytes; TODO confirm unit.
    pub cpu_memory: u64,
    /// Compute units occupied.
    pub compute_units: usize,
}
/// Aggregate resources a [`BatchPlan`] is expected to consume.
#[derive(Debug, Clone)]
pub struct BatchResourceRequirements {
    /// GPU memory needed — presumably bytes; TODO confirm unit.
    pub gpu_memory: u64,
    /// CPU memory needed — presumably bytes; TODO confirm unit.
    pub cpu_memory: u64,
    /// KV-cache blocks needed.
    pub kv_cache_blocks: usize,
    /// Compute units needed.
    pub compute_units: usize,
}
/// Outcome of a [`Scheduler::preempt`] call.
#[derive(Debug, Clone)]
pub struct PreemptionResult {
    /// Whether the preemption took effect.
    pub success: bool,
    /// State captured so the request can later be resumed, if any.
    pub saved_state: Option<PreemptionState>,
    /// Resources released by the preemption.
    pub freed_resources: AllocatedResources,
}
/// Snapshot of a preempted request, sufficient to resume it later.
#[derive(Debug, Clone)]
pub struct PreemptionState {
    /// Serialized KV-cache checkpoint — encoding is implementation-defined;
    /// TODO confirm format with implementors.
    pub kv_cache_checkpoint: Vec<u8>,
    /// How many tokens had been processed at preemption time.
    pub tokens_processed: usize,
    /// Free-form generation state keyed by name.
    pub generation_state: HashMap<String, serde_json::Value>,
}
/// Re-export of the shared scheduler configuration type under its local name.
pub type SchedulerConfig = TypesSchedulerConfig;
/// Policy used to order the request queue.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SchedulingPolicy {
    /// First come, first served.
    FCFS,
    /// Order by request priority.
    Priority,
    /// Balance throughput across clients (see [`FairShareConfig`]).
    FairShare,
    /// Shortest job first.
    SJF,
    /// Schedule based on resource availability.
    ResourceAware,
    /// Schedule to meet SLA targets (see [`SlaConfig`]).
    SlaAware,
}
/// Configuration for fair-share scheduling across clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairShareConfig {
    /// Per-client share weights, keyed by client id.
    pub client_shares: HashMap<String, f32>,
    /// Share assigned to clients not listed in `client_shares`.
    pub default_share: f32,
    /// How strictly shares are enforced — presumably 0.0..=1.0; TODO confirm
    /// range with the scheduler implementation.
    pub enforcement_strictness: f32,
}
/// SLA-aware scheduling configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SlaConfig {
    /// Master switch for SLA enforcement.
    pub enabled: bool,
    /// SLA applied to clients with no per-client entry.
    pub default_sla: SlaRequirements,
    /// Per-client SLA overrides, keyed by client id.
    pub client_slas: HashMap<String, SlaRequirements>,
}
/// Service-level targets for a client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SlaRequirements {
    /// 95th-percentile latency ceiling, in milliseconds.
    pub max_latency_p95_ms: u64,
    /// 99th-percentile latency ceiling, in milliseconds.
    pub max_latency_p99_ms: u64,
    /// Minimum throughput, in requests per second.
    pub min_throughput_rps: f32,
    /// Required availability, as a percentage.
    pub availability_percent: f32,
}
/// Global and per-client resource ceilings. `None` means unlimited.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceLimits {
    /// GPU memory ceiling — presumably bytes; TODO confirm unit.
    pub max_gpu_memory: Option<u64>,
    /// CPU memory ceiling — presumably bytes; TODO confirm unit.
    pub max_cpu_memory: Option<u64>,
    /// Maximum number of KV-cache blocks.
    pub max_kv_cache_blocks: Option<usize>,
    /// Per-client overrides, keyed by client id.
    pub per_client_limits: HashMap<String, ClientResourceLimits>,
}
impl Default for ResourceLimits {
fn default() -> Self {
Self {
max_gpu_memory: None,
max_cpu_memory: None,
max_kv_cache_blocks: None,
per_client_limits: HashMap::new(),
}
}
}
/// Resource ceilings applied to a single client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientResourceLimits {
    /// Maximum number of in-flight requests for this client.
    pub max_concurrent_requests: usize,
    /// GPU memory ceiling — presumably bytes; TODO confirm unit.
    pub max_gpu_memory: Option<u64>,
    /// Rate limit, in requests per minute.
    pub max_requests_per_minute: Option<u32>,
}
/// Local alias: scheduler metrics are the shared stats type from `ferrum_types`.
pub type SchedulerMetrics = SchedulerStats;
/// Optional advanced capabilities layered on top of [`Scheduler`].
///
/// Unlike the base trait, the configuration methods take `&mut self`.
#[async_trait]
pub trait AdvancedScheduler: Scheduler {
    /// Turn on resource-aware scheduling with the given configuration.
    async fn enable_resource_awareness(&mut self, config: ResourceAwarenessConfig) -> Result<()>;
    /// Install a policy that decides whether incoming requests are admitted.
    async fn set_admission_policy(&mut self, policy: Box<dyn AdmissionPolicy>) -> Result<()>;
    /// Configure dynamic batch sizing behavior.
    async fn configure_dynamic_batching(&mut self, config: DynamicBatchingConfig) -> Result<()>;
    /// Analysis of current queue behavior (depth history, wait times, etc.).
    fn queue_analysis(&self) -> QueueAnalysis;
    /// Run a what-if simulation of the given workload against this
    /// scheduler's policies.
    async fn simulate_load(
        &self,
        workload: &SimulatedWorkload,
    ) -> Result<SchedulingSimulationResult>;
}
/// Settings for resource-aware scheduling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceAwarenessConfig {
    /// Consider memory availability when scheduling.
    pub enable_memory_awareness: bool,
    /// Consider compute availability when scheduling.
    pub enable_compute_awareness: bool,
    /// How far ahead to predict resource usage, in milliseconds.
    pub prediction_horizon_ms: u64,
    /// Headroom to reserve — presumably a fraction in 0.0..=1.0; TODO
    /// confirm interpretation with the implementation.
    pub safety_margin: f32,
}
/// Pluggable policy deciding whether an incoming request is admitted.
pub trait AdmissionPolicy: Send + Sync {
    /// Decide, given current scheduler metrics, whether to accept, reject,
    /// or delay `request`.
    fn should_admit(
        &self,
        request: &InferenceRequest,
        current_metrics: &SchedulerMetrics,
    ) -> AdmissionDecision;
    /// Human-readable policy name (for logging/diagnostics).
    fn name(&self) -> &str;
}
/// Result of an [`AdmissionPolicy::should_admit`] check.
#[derive(Debug, Clone)]
pub enum AdmissionDecision {
    /// Admit the request immediately.
    Accept,
    /// Refuse the request; the string carries the reason.
    Reject(String),
    /// Admit, but only after the given delay.
    AcceptWithDelay(Duration),
}
/// Tuning knobs for dynamic batch sizing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchingConfig {
    /// Smallest batch worth dispatching.
    pub min_batch_size: usize,
    /// Largest batch allowed.
    pub max_batch_size: usize,
    /// How long to wait for more requests before dispatching, in
    /// milliseconds.
    pub batch_timeout_ms: u64,
    /// Allow the scheduler to adapt batch size at runtime.
    pub enable_adaptive_sizing: bool,
    /// Utilization level adaptive sizing aims for — presumably a fraction
    /// in 0.0..=1.0; TODO confirm.
    pub target_utilization: f32,
}
/// Diagnostic snapshot of queue behavior over time.
#[derive(Debug, Clone)]
pub struct QueueAnalysis {
    /// Time series of (timestamp, queue depth) samples.
    pub queue_depth_history: Vec<(chrono::DateTime<chrono::Utc>, usize)>,
    /// Distribution of request wait times.
    pub wait_time_distribution: WaitTimeDistribution,
    /// Detected patterns in arriving traffic.
    pub request_patterns: RequestPatternAnalysis,
    /// Identified performance bottlenecks.
    pub bottlenecks: Vec<BottleneckAnalysis>,
}
/// Summary statistics of request wait times, all in milliseconds.
#[derive(Debug, Clone)]
pub struct WaitTimeDistribution {
    /// Median wait time.
    pub p50_ms: f64,
    /// 95th-percentile wait time.
    pub p95_ms: f64,
    /// 99th-percentile wait time.
    pub p99_ms: f64,
    /// Longest observed wait time.
    pub max_ms: f64,
    /// Mean wait time.
    pub mean_ms: f64,
}
/// Detected temporal patterns in the request stream.
#[derive(Debug, Clone)]
pub struct RequestPatternAnalysis {
    /// Timestamps at which traffic peaked.
    pub peak_times: Vec<chrono::DateTime<chrono::Utc>>,
    /// Overall direction of the arrival rate.
    pub rate_trend: RateTrend,
    /// Recurring hourly/daily/weekly load shape.
    pub seasonality: SeasonalityPattern,
}
/// Coarse classification of how the request rate is evolving.
#[derive(Debug, Clone, Copy)]
pub enum RateTrend {
    /// Rate is going up.
    Increasing,
    /// Rate is going down.
    Decreasing,
    /// Rate is roughly constant.
    Stable,
    /// Rate fluctuates without a clear direction.
    Volatile,
}
/// Recurring load shape at several time scales. Bucket counts and value
/// normalization are not evident from this file — TODO confirm whether
/// `hourly_pattern` has 24 entries, `daily_pattern` 7, etc.
#[derive(Debug, Clone)]
pub struct SeasonalityPattern {
    /// Relative load per hour-of-day bucket.
    pub hourly_pattern: Vec<f32>,
    /// Relative load per day bucket.
    pub daily_pattern: Vec<f32>,
    /// Relative load per week bucket.
    pub weekly_pattern: Vec<f32>,
}
/// A single identified bottleneck with suggested mitigation.
#[derive(Debug, Clone)]
pub struct BottleneckAnalysis {
    /// Which subsystem is the bottleneck.
    pub bottleneck_type: BottleneckType,
    /// How severe it is — presumably 0.0..=1.0; TODO confirm scale.
    pub severity: f32,
    /// Human-readable description of the problem.
    pub description: String,
    /// Suggested remediation.
    pub mitigation: String,
}
/// Subsystem categories a bottleneck can be attributed to.
#[derive(Debug, Clone, Copy)]
pub enum BottleneckType {
    /// Memory capacity or bandwidth.
    Memory,
    /// Compute throughput.
    Compute,
    /// Disk or other I/O.
    IO,
    /// The scheduler itself.
    Scheduling,
    /// Network transfer.
    Network,
}
/// Synthetic workload description for [`AdvancedScheduler::simulate_load`].
#[derive(Debug, Clone)]
pub struct SimulatedWorkload {
    /// How requests arrive over time.
    pub arrival_pattern: ArrivalPattern,
    /// How request sizes (token counts) are distributed.
    pub size_distribution: SizeDistribution,
    /// Length of the simulation, in seconds.
    pub duration_seconds: u64,
}
/// Request arrival process driving a simulated workload.
#[derive(Debug, Clone)]
pub enum ArrivalPattern {
    /// Fixed arrival rate, in requests per second.
    Constant { rate_rps: f32 },
    /// Poisson arrivals with mean rate `lambda`.
    Poisson { lambda: f32 },
    /// Alternating burst and quiet phases.
    Bursty {
        /// Rate during a burst, in requests per second.
        burst_rate: f32,
        /// Rate between bursts, in requests per second.
        quiet_rate: f32,
        /// Burst length, in seconds.
        burst_duration_s: f32,
    },
    /// Base rate modulated by peaks. The meaning of each `(f32, f32)` peak
    /// tuple is not evident from this file — TODO confirm (time, magnitude?)
    /// with the simulator.
    Seasonal {
        base_rate: f32,
        peaks: Vec<(f32, f32)>,
    },
}
/// Distribution of request sizes (in tokens) for simulated workloads.
#[derive(Debug, Clone)]
pub enum SizeDistribution {
    /// Every request has exactly this many tokens.
    Fixed { tokens: usize },
    /// Uniformly distributed between the two bounds.
    Uniform {
        min_tokens: usize,
        max_tokens: usize,
    },
    /// Normally distributed with the given mean and standard deviation.
    Normal { mean: f32, std_dev: f32 },
    /// Log-normally distributed with parameters `mu` and `sigma`.
    LogNormal { mu: f32, sigma: f32 },
}
/// Aggregated outcome of a scheduling simulation run.
#[derive(Debug, Clone)]
pub struct SchedulingSimulationResult {
    /// Requests generated during the simulation.
    pub total_requests: u64,
    /// Requests that completed successfully.
    pub successful_requests: u64,
    /// Requests that failed or were rejected.
    pub failed_requests: u64,
    /// Mean end-to-end latency, in milliseconds.
    pub avg_latency_ms: f64,
    /// 95th-percentile latency, in milliseconds.
    pub p95_latency_ms: f64,
    /// 99th-percentile latency, in milliseconds.
    pub p99_latency_ms: f64,
    /// Achieved throughput, in requests per second.
    pub throughput_rps: f32,
    /// Observed resource usage, if tracked.
    pub resource_utilization: Option<ResourceStats>,
    /// Bottlenecks detected during the run.
    pub bottlenecks: Vec<BottleneckAnalysis>,
}
/// Observed resource usage; each field is `None` when not measured.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ResourceStats {
    /// GPU memory used, in bytes.
    pub gpu_memory_bytes: Option<u64>,
    /// CPU memory used, in bytes.
    pub cpu_memory_bytes: Option<u64>,
    /// Compute utilization — presumably a fraction in 0.0..=1.0; TODO
    /// confirm scale.
    pub compute_utilization: Option<f32>,
}