#![allow(dead_code)]
use crate::operations::queue::priority_queue::RequestMetadata;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// Outcome of a successful worker assignment for a single request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssignmentResult {
    /// Identifier copied from the request that was assigned.
    pub request_id: String,

    /// Id of the worker selected to handle the request.
    pub assigned_worker_id: u32,

    /// Estimated processing time for the request, in milliseconds.
    pub estimated_duration_ms: u32,
}
/// Pressure level reported by `LoadBalancer::check_backpressure`.
///
/// The level is derived from queue utilisation and from whether the available
/// GPU memory has dropped below the balancer's configured minimum.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
pub enum BackpressureStatus {
    /// Neither the queue nor GPU memory is under pressure.
    Healthy,

    /// The queue is filling up or GPU memory is below the configured minimum.
    Elevated,

    /// The queue is close to (or beyond) its configured maximum depth.
    Critical,
}
/// Policy a `LoadBalancer` uses to choose the worker for a request.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
pub enum AssignmentStrategy {
    /// Choose the worker with the lowest reported active-request count.
    LeastLoaded,

    /// Choose the worker with the lowest reported estimated completion value.
    EarliestCompletion,

    /// Choose a worker without consulting the load or completion metrics.
    RoundRobin,
}
/// A batch of request ids that share a model and a priority level.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestGroup {
    /// Ids of the requests contained in this batch.
    pub requests: Vec<String>,

    /// Model the batch is destined for.
    pub model_id: String,

    /// Token total attributed to the whole batch.
    pub total_tokens: u32,

    /// Number of requests in the batch.
    pub batch_size: usize,

    /// Numeric priority shared by every request in the batch.
    pub priority: u8,
}
/// Assigns requests to registered workers and reports queue back-pressure.
///
/// The balancer stores the most recent metrics reported for every worker and
/// picks a worker according to the configured [`AssignmentStrategy`].
#[derive(Debug)]
pub struct LoadBalancer {
    /// Policy used by [`LoadBalancer::assign_request`] to pick a worker.
    strategy: AssignmentStrategy,
    /// Most recently reported active-request count, keyed by worker id.
    worker_load: HashMap<u32, u32>,
    /// Most recently reported estimated completion value, keyed by worker id.
    worker_eta: HashMap<u32, u64>,
    /// Most recently reported GPU memory figure in MB, keyed by worker id.
    worker_gpu_memory: HashMap<u32, u32>,
    /// Queue depth that counts as 100% utilisation for back-pressure checks.
    max_queue_depth: usize,
    /// Minimum available GPU memory (MB) below which assignment is refused.
    min_gpu_memory_free_mb: u32,
    /// Batch grouping window in milliseconds (configuration only; no method
    /// in this type reads it yet).
    batch_grouping_window_ms: u32,
    /// Upper bound on the number of requests placed in one `RequestGroup`.
    max_batch_size: usize,
    /// Groups awaiting dispatch (no method in this type reads it yet).
    pending_groups: VecDeque<RequestGroup>,
    /// Lowest worker id eligible for the next round-robin pick.
    ///
    /// Atomic so the cursor can advance behind `&self` without making the
    /// balancer `!Sync`.
    round_robin_cursor: std::sync::atomic::AtomicU32,
}

impl LoadBalancer {
    /// Creates a balancer with no workers and the default limits:
    /// queue depth 10 000, 512 MB minimum GPU memory, 50 ms grouping window
    /// and batches of at most 32 requests.
    pub fn new(strategy: AssignmentStrategy) -> Self {
        Self {
            strategy,
            worker_load: HashMap::new(),
            worker_eta: HashMap::new(),
            worker_gpu_memory: HashMap::new(),
            max_queue_depth: 10_000,
            min_gpu_memory_free_mb: 512,
            batch_grouping_window_ms: 50,
            max_batch_size: 32,
            pending_groups: VecDeque::new(),
            round_robin_cursor: std::sync::atomic::AtomicU32::new(0),
        }
    }

    /// Sets the queue depth that counts as full for back-pressure checks.
    #[must_use]
    pub fn with_max_queue_depth(mut self, depth: usize) -> Self {
        self.max_queue_depth = depth;
        self
    }

    /// Sets the minimum available GPU memory (MB) required for assignment.
    #[must_use]
    pub fn with_min_gpu_memory_mb(mut self, memory: u32) -> Self {
        self.min_gpu_memory_free_mb = memory;
        self
    }

    /// Sets the batch grouping window in milliseconds.
    #[must_use]
    pub fn with_batch_grouping_window_ms(mut self, window: u32) -> Self {
        self.batch_grouping_window_ms = window;
        self
    }

    /// Sets the maximum number of requests per group.
    ///
    /// A value of `0` is treated as `1` by [`LoadBalancer::group_requests`].
    #[must_use]
    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size;
        self
    }

    /// Registers `worker_id` with all metrics set to zero.
    ///
    /// Registering an id that is already known resets its metrics to zero.
    pub fn register_worker(&mut self, worker_id: u32) {
        self.worker_load.insert(worker_id, 0);
        self.worker_eta.insert(worker_id, 0);
        self.worker_gpu_memory.insert(worker_id, 0);
    }

    /// Removes `worker_id` and its metrics; unknown ids are ignored.
    pub fn unregister_worker(&mut self, worker_id: u32) {
        self.worker_load.remove(&worker_id);
        self.worker_eta.remove(&worker_id);
        self.worker_gpu_memory.remove(&worker_id);
    }

    /// Records the latest metrics for `worker_id`.
    ///
    /// The values are inserted unconditionally, so reporting metrics for an
    /// id that was never registered (or was unregistered) makes it eligible
    /// for assignment.
    pub fn update_worker_metrics(
        &mut self,
        worker_id: u32,
        active_requests: u32,
        estimated_completion_ms: u64,
        gpu_memory_mb: u32,
    ) {
        self.worker_load.insert(worker_id, active_requests);
        self.worker_eta.insert(worker_id, estimated_completion_ms);
        self.worker_gpu_memory.insert(worker_id, gpu_memory_mb);
    }

    /// Picks a worker for `request` using the configured strategy.
    ///
    /// Returns `None` when `available_gpu_memory_mb` is below the configured
    /// minimum or when no worker is known. The duration estimate assumes a
    /// fixed throughput of 50 tokens per second and saturates at `u32::MAX`.
    ///
    /// Under [`AssignmentStrategy::RoundRobin`] a successful call advances the
    /// internal rotation cursor; the other strategies do not modify any state.
    pub fn assign_request(
        &self,
        request: &RequestMetadata,
        available_gpu_memory_mb: u32,
    ) -> Option<AssignmentResult> {
        const ESTIMATED_TOKENS_PER_SEC: u64 = 50;

        if available_gpu_memory_mb < self.min_gpu_memory_free_mb {
            return None;
        }

        // Every picker yields `None` for an empty worker set, so no separate
        // emptiness check is needed before dispatching on the strategy.
        let assigned_worker_id = match self.strategy {
            AssignmentStrategy::LeastLoaded => self.find_least_loaded_worker(),
            AssignmentStrategy::EarliestCompletion => self.find_earliest_completion_worker(),
            AssignmentStrategy::RoundRobin => self.find_next_worker(),
        }?;

        // Multiply before dividing, in 64 bits: dividing first rounds the
        // estimate down to whole seconds, and 32-bit math can overflow.
        let duration_ms = u64::from(request.estimated_tokens) * 1000 / ESTIMATED_TOKENS_PER_SEC;
        let estimated_duration_ms = duration_ms.min(u64::from(u32::MAX)) as u32;

        Some(AssignmentResult {
            request_id: request.request_id.clone(),
            assigned_worker_id,
            estimated_duration_ms,
        })
    }

    /// Returns the worker with the fewest active requests.
    ///
    /// Ties go to the lowest worker id so the result does not depend on hash
    /// map iteration order.
    fn find_least_loaded_worker(&self) -> Option<u32> {
        self.worker_load
            .iter()
            .map(|(&id, &load)| (load, id))
            .min()
            .map(|(_, id)| id)
    }

    /// Returns the worker with the smallest estimated completion value.
    ///
    /// Ties go to the lowest worker id.
    fn find_earliest_completion_worker(&self) -> Option<u32> {
        self.worker_eta
            .iter()
            .map(|(&id, &eta)| (eta, id))
            .min()
            .map(|(_, id)| id)
    }

    /// Returns the next worker in ascending-id rotation and advances the cursor.
    ///
    /// The pick is the smallest id at or above the cursor, wrapping to the
    /// smallest id overall. Keying the rotation on ids rather than positions
    /// keeps it stable when workers are added or removed between calls.
    /// Concurrent callers may occasionally receive the same worker; the cursor
    /// update is best-effort, not a strict ticket.
    fn find_next_worker(&self) -> Option<u32> {
        use std::sync::atomic::Ordering;

        let cursor = self.round_robin_cursor.load(Ordering::Relaxed);
        let chosen = self
            .worker_load
            .keys()
            .copied()
            .filter(|&id| id >= cursor)
            .min()
            .or_else(|| self.worker_load.keys().copied().min())?;

        // Wrapping: after `u32::MAX` the rotation restarts from the lowest id.
        self.round_robin_cursor
            .store(chosen.wrapping_add(1), Ordering::Relaxed);
        Some(chosen)
    }

    /// Splits `requests` into per-priority batches of at most `max_batch_size`.
    ///
    /// Requests keep their input order within a priority. Groups are returned
    /// in ascending numeric priority so the output is deterministic. Every
    /// group is tagged with `model_id`, and `total_tokens` assumes a flat
    /// 256 tokens per request.
    // NOTE(review): the per-request token estimate is not consulted here;
    // confirm whether the flat 256-token figure is intentional.
    pub fn group_requests(
        &mut self,
        requests: Vec<RequestMetadata>,
        model_id: &str,
    ) -> Vec<RequestGroup> {
        const ASSUMED_TOKENS_PER_REQUEST: u32 = 256;

        // A zero limit would never consume any request and loop forever.
        let batch_limit = self.max_batch_size.max(1);

        let mut ids_by_priority: HashMap<u8, Vec<String>> = HashMap::new();
        for request in requests {
            ids_by_priority
                .entry(request.priority as u8)
                .or_default()
                .push(request.request_id);
        }

        let mut buckets: Vec<(u8, Vec<String>)> = ids_by_priority.into_iter().collect();
        buckets.sort_unstable_by_key(|(priority, _)| *priority);

        let mut result = Vec::new();
        for (priority, request_ids) in buckets {
            let mut remaining = request_ids.into_iter();
            loop {
                // Take the next batch by value: no cloning and no repeated
                // shifting of the remaining ids.
                let batch: Vec<String> = remaining.by_ref().take(batch_limit).collect();
                if batch.is_empty() {
                    break;
                }
                let batch_size = batch.len();
                let total_tokens = (batch_size.min(u32::MAX as usize) as u32)
                    .saturating_mul(ASSUMED_TOKENS_PER_REQUEST);
                result.push(RequestGroup {
                    requests: batch,
                    model_id: model_id.to_string(),
                    total_tokens,
                    batch_size,
                    priority,
                });
            }
        }
        result
    }

    /// Classifies current pressure from queue depth and available GPU memory.
    ///
    /// With GPU memory below the configured minimum the result is `Critical`
    /// at 80% queue utilisation or more and `Elevated` otherwise. With enough
    /// memory it is `Critical` from 90%, `Elevated` from 70%, else `Healthy`.
    pub fn check_backpressure(
        &self,
        current_queue_depth: usize,
        available_gpu_memory_mb: u32,
    ) -> BackpressureStatus {
        let queue_utilization = match (current_queue_depth, self.max_queue_depth) {
            // An empty queue exerts no pressure; this also avoids 0/0 = NaN
            // when the configured capacity is zero.
            (0, _) => 0.0,
            // Anything queued against a zero-capacity queue is over capacity.
            (_, 0) => f32::INFINITY,
            (depth, capacity) => depth as f32 / capacity as f32,
        };
        let memory_low = available_gpu_memory_mb < self.min_gpu_memory_free_mb;

        if memory_low {
            if queue_utilization >= 0.8 {
                BackpressureStatus::Critical
            } else {
                BackpressureStatus::Elevated
            }
        } else if queue_utilization >= 0.9 {
            BackpressureStatus::Critical
        } else if queue_utilization >= 0.7 {
            BackpressureStatus::Elevated
        } else {
            BackpressureStatus::Healthy
        }
    }

    /// Summarises the metrics currently stored for all workers.
    ///
    /// Totals are accumulated in 64 bits and saturate at `u32::MAX`, so large
    /// fleets cannot trigger an arithmetic overflow.
    pub fn load_stats(&self) -> LoadStats {
        let clamp_to_u32 = |total: u64| total.min(u64::from(u32::MAX)) as u32;

        let load_sum: u64 = self.worker_load.values().map(|&load| u64::from(load)).sum();
        let gpu_sum: u64 = self
            .worker_gpu_memory
            .values()
            .map(|&mb| u64::from(mb))
            .sum();

        let worker_count = self.worker_load.len();
        let avg_load_per_worker = if worker_count == 0 {
            0.0
        } else {
            load_sum as f32 / worker_count as f32
        };

        LoadStats {
            total_load: clamp_to_u32(load_sum),
            worker_count,
            avg_load_per_worker,
            total_gpu_memory_used_mb: clamp_to_u32(gpu_sum),
        }
    }
}
impl Default for LoadBalancer {
fn default() -> Self {
Self::new(AssignmentStrategy::LeastLoaded)
}
}
/// Aggregate view of the per-worker metrics held by a `LoadBalancer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadStats {
    /// Sum of the active-request counts of all workers.
    pub total_load: u32,

    /// Number of workers with a recorded load.
    pub worker_count: usize,

    /// `total_load` divided by `worker_count`; `0.0` when there are no workers.
    pub avg_load_per_worker: f32,

    /// Sum of the GPU memory figures (MB) recorded for all workers.
    pub total_gpu_memory_used_mb: u32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operations::queue::Priority;

    /// Builds a balancer with `strategy` and registers every id in `worker_ids`.
    fn balancer_with_workers(strategy: AssignmentStrategy, worker_ids: &[u32]) -> LoadBalancer {
        let mut balancer = LoadBalancer::new(strategy);
        for &id in worker_ids {
            balancer.register_worker(id);
        }
        balancer
    }

    /// A fresh balancer starts with the default queue depth.
    #[test]
    fn test_load_balancer_creation() {
        let balancer = LoadBalancer::new(AssignmentStrategy::LeastLoaded);
        assert_eq!(balancer.max_queue_depth, 10_000);
    }

    /// Every registered worker shows up in the load table.
    #[test]
    fn test_worker_registration() {
        let balancer = balancer_with_workers(AssignmentStrategy::LeastLoaded, &[1, 2, 3]);
        for id in [1u32, 2, 3].iter() {
            assert!(balancer.worker_load.contains_key(id));
        }
    }

    /// The worker with the fewest active requests wins under `LeastLoaded`.
    #[test]
    fn test_least_loaded_assignment() {
        let mut balancer = balancer_with_workers(AssignmentStrategy::LeastLoaded, &[1, 2, 3]);
        let metrics: [(u32, u32, u64); 3] = [(1, 10, 1000), (2, 5, 500), (3, 15, 2000)];
        for &(id, active, eta) in metrics.iter() {
            balancer.update_worker_metrics(id, active, eta, 4096);
        }

        assert_eq!(balancer.find_least_loaded_worker(), Some(2));
    }

    /// The worker with the smallest ETA wins under `EarliestCompletion`.
    #[test]
    fn test_earliest_completion_assignment() {
        let mut balancer =
            balancer_with_workers(AssignmentStrategy::EarliestCompletion, &[1, 2, 3]);
        let metrics: [(u32, u32, u64); 3] = [(1, 5, 2000), (2, 10, 500), (3, 8, 1500)];
        for &(id, active, eta) in metrics.iter() {
            balancer.update_worker_metrics(id, active, eta, 4096);
        }

        assert_eq!(balancer.find_earliest_completion_worker(), Some(2));
    }

    /// Back-pressure rises with queue depth and with low GPU memory.
    #[test]
    fn test_backpressure_detection() {
        let balancer = LoadBalancer::new(AssignmentStrategy::LeastLoaded);

        // (queue depth, available GPU memory in MB, expected status)
        let cases = [
            (100usize, 1024u32, BackpressureStatus::Healthy),
            (7000, 1024, BackpressureStatus::Elevated),
            (9500, 100, BackpressureStatus::Critical),
        ];
        for &(depth, memory_mb, expected) in cases.iter() {
            assert_eq!(balancer.check_backpressure(depth, memory_mb), expected);
        }
    }

    /// Grouping yields at least one group tagged with the requested model.
    #[test]
    fn test_request_grouping() {
        let mut balancer = LoadBalancer::new(AssignmentStrategy::LeastLoaded);
        let requests: Vec<RequestMetadata> = (0..5)
            .map(|i| {
                RequestMetadata::new(
                    format!("req_{}", i),
                    "user".to_string(),
                    Priority::Normal,
                    "model".to_string(),
                )
            })
            .collect();

        let groups = balancer.group_requests(requests, "model");

        assert!(!groups.is_empty());
        assert_eq!(groups[0].model_id, "model");
    }

    /// Totals and averages reflect the most recently reported metrics.
    #[test]
    fn test_load_stats() {
        let mut balancer = balancer_with_workers(AssignmentStrategy::LeastLoaded, &[1, 2]);
        balancer.update_worker_metrics(1, 5, 1000, 4096);
        balancer.update_worker_metrics(2, 10, 2000, 4096);

        let stats = balancer.load_stats();

        assert_eq!(stats.total_load, 15);
        assert_eq!(stats.worker_count, 2);
        assert_eq!(stats.avg_load_per_worker, 7.5);
    }
}