use crate::priority::Priority;
use kapsl_engine_api::{BinaryTensorPacket, EngineError, EngineMetrics, InferenceRequest};
use parking_lot::RwLock;
use std::cmp::Reverse;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Strategy a [`ReplicaPool`] uses to choose the replica that receives a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolStrategy {
/// Cycle through replica slots by index, driven by a shared atomic counter.
RoundRobin,
/// Pick the healthy replica with the smallest routing key (KV-cache pressure
/// tier when paged KV metrics are reported, then total queue depth).
LeastLoaded,
/// Hash the request's `session_id` onto a replica slot; requests without a
/// session id fall back to round-robin selection.
Sticky,
}
/// Point-in-time snapshot of one replica, as returned by the pool's
/// `stats` and `get_replica_stats` accessors.
#[derive(Debug, Clone)]
pub struct ReplicaStats {
/// Identifier the replica was registered with.
pub replica_id: u32,
/// Number of requests the pool has attributed to this replica so far.
pub requests_total: u64,
/// Queue depth pair exactly as reported by the replica's `get_queue_depth()`.
/// NOTE(review): the routing code names the components `(high, low)` — confirm
/// the intended meaning against the scheduler implementation.
pub queue_depth: (usize, usize),
/// Result of the replica's `is_healthy()` at snapshot time.
pub healthy: bool,
}
/// A set of scheduler replicas plus the strategy used to route requests
/// across them.
pub struct ReplicaPool<T> {
// Registered replicas, guarded by a read/write lock so replicas can be
// added or removed while requests are being routed.
replicas: Arc<RwLock<Vec<PooledReplica<T>>>>,
// Selection strategy fixed at construction time.
strategy: PoolStrategy,
// Monotonic counter used by round-robin selection (taken modulo pool size).
round_robin_counter: AtomicUsize,
}
/// Internal bookkeeping entry for one replica held by a `ReplicaPool`.
struct PooledReplica<T> {
// Identifier supplied by the caller of `add_replica`.
replica_id: u32,
// Shared handle to the scheduler that actually runs inference.
scheduler: Arc<T>,
// Count of requests attributed to this replica; bumped with relaxed
// atomics when the pool selects it or fails over to it.
requests_total: AtomicUsize,
}
impl<T> ReplicaPool<T>
where
T: ReplicaScheduler + Send + Sync + 'static,
{
/// Derives the memory-related part of a routing key from a replica's metrics.
///
/// Returns `None` when the replica reports zero total KV-cache blocks.
/// Otherwise returns `(pressure_tier, free_blocks, free_bytes)` where:
/// - `pressure_tier` buckets block utilisation in permille:
///   below 700 → 0, 700..=849 → 1, 850..=949 → 2, 950 and above → 3;
/// - `free_blocks` is the reported free count clamped to the total;
/// - `free_bytes` is capacity minus used bytes, saturating at zero.
fn paged_kv_routing_key(metrics: &EngineMetrics) -> Option<(u8, usize, usize)> {
    let capacity_blocks = metrics.kv_cache_blocks_total;
    if capacity_blocks == 0 {
        return None;
    }

    // Clamp so an over-reported free count can never exceed the total.
    let free_blocks = metrics.kv_cache_blocks_free.min(capacity_blocks);
    let occupied_blocks = capacity_blocks.saturating_sub(free_blocks);
    // `capacity_blocks` is non-zero here, so the division is always defined.
    let permille = occupied_blocks.saturating_mul(1000) / capacity_blocks;

    let pressure_tier: u8 = if permille < 700 {
        0
    } else if permille < 850 {
        1
    } else if permille < 950 {
        2
    } else {
        3
    };

    let capacity_bytes = metrics.kv_cache_bytes_capacity;
    let free_bytes = capacity_bytes.saturating_sub(metrics.kv_cache_bytes_used);

    Some((pressure_tier, free_blocks, free_bytes))
}
/// Reports whether memory-aware routing should be used for `replicas`.
///
/// True as soon as one healthy replica exposes paged KV-cache metrics
/// (i.e. `paged_kv_routing_key` yields a key for it); false otherwise,
/// including for an empty slice.
fn memory_routing_enabled(replicas: &[PooledReplica<T>]) -> bool {
    for candidate in replicas {
        if !candidate.scheduler.is_healthy() {
            continue;
        }
        let snapshot = candidate.scheduler.get_metrics();
        if Self::paged_kv_routing_key(&snapshot).is_some() {
            return true;
        }
    }
    false
}
/// Builds the ordering key used to rank `replica`; smaller keys are preferred.
///
/// The tuple compares, in order:
/// 1. KV-cache pressure tier (`0` when `memory_aware` is false, `u8::MAX`
///    when memory-aware routing is on but this replica reports no paged KV
///    metrics, so such replicas sort last);
/// 2. total queue depth (both components of `get_queue_depth()` summed);
/// 3. free KV blocks, wrapped in `Reverse` so more free blocks sorts first;
/// 4. free KV bytes, likewise reversed;
/// 5. requests already attributed to the replica, as a final tie-breaker.
///
/// Metrics are only fetched from the scheduler when `memory_aware` is true.
fn routing_key(
    &self,
    replica: &PooledReplica<T>,
    memory_aware: bool,
) -> (u8, usize, Reverse<usize>, Reverse<usize>, usize) {
    let (high, low) = replica.scheduler.get_queue_depth();
    // Saturate rather than overflow, matching the saturating arithmetic used
    // for the KV-cache part of the key.
    let total_depth = high.saturating_add(low);

    let (pressure_tier, free_blocks, free_bytes) = if memory_aware {
        Self::paged_kv_routing_key(&replica.scheduler.get_metrics())
            .unwrap_or((u8::MAX, 0, 0))
    } else {
        (0, 0, 0)
    };

    (
        pressure_tier,
        total_depth,
        Reverse(free_blocks),
        Reverse(free_bytes),
        replica.requests_total.load(Ordering::Relaxed),
    )
}
/// Creates an empty pool that will route requests using `strategy`.
pub fn new(strategy: PoolStrategy) -> Self {
    let replicas = Arc::new(RwLock::new(Vec::new()));
    let round_robin_counter = AtomicUsize::new(0);
    Self {
        replicas,
        strategy,
        round_robin_counter,
    }
}
/// Registers `scheduler` under `replica_id` with a zeroed request counter.
///
/// The entry is appended to the end of the replica list; no check for an
/// already-registered `replica_id` is performed.
pub fn add_replica(&self, replica_id: u32, scheduler: Arc<T>) {
    let entry = PooledReplica {
        replica_id,
        scheduler,
        requests_total: AtomicUsize::new(0),
    };
    self.replicas.write().push(entry);
}
/// Removes the first replica registered under `replica_id`.
///
/// Returns `true` if an entry was removed and `false` if no replica with
/// that id was present.
pub fn remove_replica(&self, replica_id: u32) -> bool {
    let mut guard = self.replicas.write();
    let found = guard
        .iter()
        .position(|entry| entry.replica_id == replica_id);
    match found {
        Some(slot) => {
            guard.remove(slot);
            true
        }
        None => false,
    }
}
/// Returns how many replicas are currently registered, healthy or not.
pub fn size(&self) -> usize {
    let guard = self.replicas.read();
    guard.len()
}
/// Returns a stats snapshot for the first replica registered under
/// `replica_id`, or `None` if the pool holds no such replica.
pub fn get_replica_stats(&self, replica_id: u32) -> Option<ReplicaStats> {
    let guard = self.replicas.read();
    for entry in guard.iter() {
        if entry.replica_id != replica_id {
            continue;
        }
        let requests_total = entry.requests_total.load(Ordering::Relaxed) as u64;
        let queue_depth = entry.scheduler.get_queue_depth();
        let healthy = entry.scheduler.is_healthy();
        return Some(ReplicaStats {
            replica_id: entry.replica_id,
            requests_total,
            queue_depth,
            healthy,
        });
    }
    None
}
/// Returns the number of registered replicas; same value as [`Self::size`].
pub fn get_replica_count(&self) -> usize {
    self.size()
}
/// Counts the registered replicas whose scheduler currently reports healthy.
pub fn get_healthy_replica_count(&self) -> usize {
    let guard = self.replicas.read();
    let mut healthy = 0;
    for entry in guard.iter() {
        if entry.scheduler.is_healthy() {
            healthy += 1;
        }
    }
    healthy
}
/// Returns a stats snapshot for every registered replica, in the same order
/// the replicas are stored in the pool.
pub fn stats(&self) -> Vec<ReplicaStats> {
    self.replicas
        .read()
        .iter()
        .map(|entry| ReplicaStats {
            replica_id: entry.replica_id,
            requests_total: entry.requests_total.load(Ordering::Relaxed) as u64,
            queue_depth: entry.scheduler.get_queue_depth(),
            healthy: entry.scheduler.is_healthy(),
        })
        .collect()
}
/// Routes `request` to one replica chosen by the pool strategy and, if that
/// attempt fails, retries on the other replicas that were healthy at
/// selection time.
///
/// Fallback replicas are tried in ascending routing-key order (see
/// `routing_key`). A fallback that has become unhealthy by the time it is
/// reached is skipped.
///
/// # Errors
///
/// - An `overloaded` error when the pool has no replicas.
/// - Otherwise the error from the *initially selected* replica when neither
///   it nor any fallback succeeds; errors from fallback attempts are dropped.
pub async fn execute(
    &self,
    request: InferenceRequest,
    priority: Priority,
    force_cpu: bool,
) -> Result<BinaryTensorPacket, EngineError> {
    // All lock-protected work happens inside this block so the read guard is
    // released before the first `.await`.
    let (selected_replica_id, selected_scheduler, fallback_schedulers) = {
        let replicas = self.replicas.read();
        if replicas.is_empty() {
            return Err(EngineError::overloaded(
                "No replicas available in pool".to_string(),
            ));
        }

        let selected_idx = match self.strategy {
            PoolStrategy::RoundRobin => self.select_round_robin(&replicas),
            PoolStrategy::LeastLoaded => self.select_least_loaded(&replicas),
            PoolStrategy::Sticky => self.select_sticky(&replicas, &request),
        };
        let selected = &replicas[selected_idx];
        selected.requests_total.fetch_add(1, Ordering::Relaxed);

        let mut fallbacks = Vec::new();
        if replicas.len() > 1 {
            // Only probe replica metrics when there is something to rank.
            let memory_aware = Self::memory_routing_enabled(&replicas);
            fallbacks.extend(
                replicas
                    .iter()
                    .enumerate()
                    .filter(|(idx, other)| {
                        *idx != selected_idx && other.scheduler.is_healthy()
                    })
                    .map(|(_, other)| {
                        (
                            self.routing_key(other, memory_aware),
                            other.replica_id,
                            Arc::clone(&other.scheduler),
                        )
                    }),
            );
            // Stable sort: replicas with equal keys keep their pool order.
            fallbacks.sort_by_key(|(key, _, _)| *key);
        }

        (
            selected.replica_id,
            Arc::clone(&selected.scheduler),
            fallbacks
                .into_iter()
                .map(|(_, replica_id, scheduler)| (replica_id, scheduler))
                .collect::<Vec<(u32, Arc<T>)>>(),
        )
    };

    let first_error = match selected_scheduler
        .infer(&request, priority, force_cpu)
        .await
    {
        Ok(response) => return Ok(response),
        Err(err) => err,
    };

    if fallback_schedulers.is_empty() {
        return Err(first_error);
    }

    log::warn!(
        "Request failed on replica {}, attempting failover",
        selected_replica_id
    );
    for (replica_id, scheduler) in fallback_schedulers {
        // Health may have changed since the snapshot was taken.
        if !scheduler.is_healthy() {
            continue;
        }
        log::info!("Failing over to replica {}", replica_id);
        if let Ok(response) = scheduler.infer(&request, priority, force_cpu).await {
            self.record_failover_success(replica_id);
            return Ok(response);
        }
    }

    Err(first_error)
}

/// Attributes one request to `replica_id` after a successful failover.
///
/// Takes a blocking read lock so the increment is never skipped under lock
/// contention. Does nothing if the replica has been removed in the meantime.
fn record_failover_success(&self, replica_id: u32) {
    let replicas = self.replicas.read();
    if let Some(replica) = replicas.iter().find(|r| r.replica_id == replica_id) {
        replica.requests_total.fetch_add(1, Ordering::Relaxed);
    }
}
/// Returns the next replica index in round-robin order.
///
/// Advances the shared counter by one and maps the previous value onto
/// `replicas` by modulo. `replicas` must be non-empty (division by zero
/// otherwise).
fn select_round_robin(&self, replicas: &[PooledReplica<T>]) -> usize {
    let ticket = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
    let slots = replicas.len();
    ticket % slots
}
fn select_least_loaded(&self, replicas: &[PooledReplica<T>]) -> usize {
let memory_aware = Self::memory_routing_enabled(replicas);
let mut best_idx = 0;
let mut best_key: Option<(u8, usize, Reverse<usize>, Reverse<usize>, usize)> = None;
for (idx, replica) in replicas.iter().enumerate() {
if !replica.scheduler.is_healthy() {
continue;
}
let key = self.routing_key(replica, memory_aware);
if best_key.as_ref().map(|best| key < *best).unwrap_or(true) {
best_key = Some(key);
best_idx = idx;
}
}
best_idx
}
/// Returns the replica index for session-affine routing.
///
/// A request carrying a `session_id` is mapped to `hash(session_id) % len`
/// using `DefaultHasher`; a request without one is routed round-robin.
/// `replicas` must be non-empty.
fn select_sticky(&self, replicas: &[PooledReplica<T>], request: &InferenceRequest) -> usize {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    match request.session_id {
        Some(ref session_id) => {
            let mut state = DefaultHasher::new();
            session_id.hash(&mut state);
            let digest = state.finish() as usize;
            digest % replicas.len()
        }
        None => self.select_round_robin(replicas),
    }
}
}
#[async_trait::async_trait]
impl<T> ReplicaScheduler for ReplicaPool<T>
where
T: ReplicaScheduler + Send + Sync + 'static,
{
/// Sums the queue-depth pairs of all replicas component-wise.
///
/// Uses a non-blocking lock attempt; if the replica list is currently
/// locked for writing the result is `(0, 0)`.
fn get_queue_depth(&self) -> (usize, usize) {
    match self.replicas.try_read() {
        None => (0, 0),
        Some(replicas) => replicas
            .iter()
            .fold((0, 0), |(first_sum, second_sum), replica| {
                let (first, second) = replica.scheduler.get_queue_depth();
                (first_sum + first, second_sum + second)
            }),
    }
}
/// Reports the pool as healthy when at least one replica is healthy.
///
/// An empty pool is reported healthy, and so is a pool whose replica list
/// cannot be read without blocking.
fn is_healthy(&self) -> bool {
    match self.replicas.try_read() {
        None => true,
        Some(replicas) => {
            replicas.is_empty() || replicas.iter().any(|r| r.scheduler.is_healthy())
        }
    }
}
/// Aggregates the metrics of all replicas into a single `EngineMetrics`.
///
/// Memory, throughput and every KV-cache counter are summed across replicas;
/// `gpu_utilization` is the arithmetic mean over all replicas (healthy or
/// not); `queue_depth` is the sum of both components of every replica's
/// `get_queue_depth()`. Fields not listed below keep their `Default` value.
///
/// Uses a non-blocking lock attempt: if the replica list cannot be read
/// immediately, all aggregated fields are zero.
fn get_metrics(&self) -> kapsl_engine_api::EngineMetrics {
// Accumulators; their numeric types are inferred from the metric fields.
let mut total_memory = 0;
let mut total_gpu_util = 0.0;
let mut total_throughput = 0.0;
let mut total_kv_bytes_used = 0;
let mut total_kv_bytes_capacity = 0;
let mut total_kv_blocks_total = 0;
let mut total_kv_blocks_free = 0;
let mut total_kv_sequences = 0;
let mut total_kv_evicted_blocks = 0;
let mut total_kv_evicted_sequences = 0;
let mut total_kv_packed_layers = 0;
let mut cpu_q = 0;
let mut gpu_q = 0;
// Number of replicas seen; stays 0 when the lock could not be taken.
let mut count = 0;
if let Some(replicas) = self.replicas.try_read() {
count = replicas.len();
for replica in replicas.iter() {
let m = replica.scheduler.get_metrics();
// NOTE(review): the pair is named (cq, gq) here but (high, low) in the
// routing code — confirm which reading of `get_queue_depth` is intended.
let (cq, gq) = replica.scheduler.get_queue_depth();
total_memory += m.memory_usage;
total_gpu_util += m.gpu_utilization;
total_throughput += m.throughput;
total_kv_bytes_used += m.kv_cache_bytes_used;
total_kv_bytes_capacity += m.kv_cache_bytes_capacity;
total_kv_blocks_total += m.kv_cache_blocks_total;
total_kv_blocks_free += m.kv_cache_blocks_free;
total_kv_sequences += m.kv_cache_sequences;
total_kv_evicted_blocks += m.kv_cache_evicted_blocks;
total_kv_evicted_sequences += m.kv_cache_evicted_sequences;
total_kv_packed_layers += m.kv_cache_packed_layers;
cpu_q += cq;
gpu_q += gq;
}
}
kapsl_engine_api::EngineMetrics {
memory_usage: total_memory,
// Average rather than sum; guard against dividing by zero replicas.
gpu_utilization: if count > 0 {
total_gpu_util / count as f64
} else {
0.0
},
throughput: total_throughput,
queue_depth: cpu_q + gpu_q,
kv_cache_bytes_used: total_kv_bytes_used,
kv_cache_bytes_capacity: total_kv_bytes_capacity,
kv_cache_blocks_total: total_kv_blocks_total,
kv_cache_blocks_free: total_kv_blocks_free,
kv_cache_sequences: total_kv_sequences,
kv_cache_evicted_blocks: total_kv_evicted_blocks,
kv_cache_evicted_sequences: total_kv_evicted_sequences,
kv_cache_packed_layers: total_kv_packed_layers,
..kapsl_engine_api::EngineMetrics::default()
}
}
/// Returns the model info of the first replica (in pool order) that
/// provides one.
///
/// Yields `None` when no replica reports model info or when the replica
/// list cannot be read without blocking.
fn model_info(&self) -> Option<kapsl_engine_api::EngineModelInfo> {
    let replicas = self.replicas.try_read()?;
    replicas
        .iter()
        .find_map(|replica| replica.scheduler.model_info())
}
/// Runs a single inference through the pool by delegating to `execute`
/// with a clone of `request`.
async fn infer(
    &self,
    request: &InferenceRequest,
    priority: Priority,
    force_cpu: bool,
) -> Result<BinaryTensorPacket, EngineError> {
    let owned_request = request.clone();
    self.execute(owned_request, priority, force_cpu).await
}
/// Opens a streaming inference on one replica chosen by the pool strategy,
/// failing over to the other replicas if the stream cannot be opened.
///
/// The selected replica is attempted only if it reports healthy. Fallback
/// replicas (those healthy at selection time) are then tried in ascending
/// routing-key order, skipping any that have since become unhealthy.
/// Failover only covers *opening* the stream; errors yielded later by the
/// returned stream are not retried here.
///
/// # Errors
///
/// - An `overloaded` error ("No replicas available in pool") for an empty pool.
/// - An `overloaded` error ("All replicas failed or overloaded") when no
///   replica could open the stream; the individual failures are logged.
async fn infer_stream(
    &self,
    request: InferenceRequest,
    priority: Priority,
    force_cpu: bool,
) -> Result<
    std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>,
    >,
    EngineError,
> {
    // All lock-protected work happens inside this block so the read guard is
    // released before the first `.await`.
    let (selected_replica_id, selected_scheduler, fallback_schedulers) = {
        let replicas = self.replicas.read();
        if replicas.is_empty() {
            return Err(EngineError::overloaded(
                "No replicas available in pool".to_string(),
            ));
        }

        let selected_idx = match self.strategy {
            PoolStrategy::RoundRobin => self.select_round_robin(&replicas),
            PoolStrategy::LeastLoaded => self.select_least_loaded(&replicas),
            PoolStrategy::Sticky => self.select_sticky(&replicas, &request),
        };
        let selected = &replicas[selected_idx];
        // NOTE(review): the counter is bumped before the health check below,
        // so a selected-but-unhealthy replica is counted even though no
        // request is sent to it — confirm this is the intended accounting.
        selected.requests_total.fetch_add(1, Ordering::Relaxed);

        let mut fallbacks = Vec::new();
        if replicas.len() > 1 {
            // Only probe replica metrics when there is something to rank.
            let memory_aware = Self::memory_routing_enabled(&replicas);
            for (idx, other) in replicas.iter().enumerate() {
                if idx == selected_idx || !other.scheduler.is_healthy() {
                    continue;
                }
                fallbacks.push((
                    self.routing_key(other, memory_aware),
                    other.replica_id,
                    Arc::clone(&other.scheduler),
                ));
            }
            // Stable sort: replicas with equal keys keep their pool order.
            fallbacks.sort_by_key(|(key, _, _)| *key);
        }

        (
            selected.replica_id,
            Arc::clone(&selected.scheduler),
            fallbacks
                .into_iter()
                .map(|(_, replica_id, scheduler)| (replica_id, scheduler))
                .collect::<Vec<(u32, Arc<T>)>>(),
        )
    };

    if selected_scheduler.is_healthy() {
        match selected_scheduler
            .infer_stream(request.clone(), priority, force_cpu)
            .await
        {
            Ok(stream) => return Ok(stream),
            Err(e) => {
                log::warn!(
                    "Streaming request failed on replica {}: {}, attempting failover",
                    selected_replica_id,
                    e
                );
            }
        }
    }

    for (replica_id, scheduler) in fallback_schedulers {
        // Health may have changed since the snapshot was taken.
        if !scheduler.is_healthy() {
            continue;
        }
        log::info!("Failing over streaming request to replica {}", replica_id);
        match scheduler
            .infer_stream(request.clone(), priority, force_cpu)
            .await
        {
            Ok(stream) => {
                {
                    // Blocking read so the failover is always counted, even
                    // under lock contention. The guard is dropped before
                    // returning and is never held across an `.await`.
                    let replicas = self.replicas.read();
                    if let Some(replica) =
                        replicas.iter().find(|r| r.replica_id == replica_id)
                    {
                        replica.requests_total.fetch_add(1, Ordering::Relaxed);
                    }
                }
                return Ok(stream);
            }
            Err(e) => {
                // Record why this fallback was rejected before trying the next.
                log::warn!(
                    "Streaming failover to replica {} failed: {}",
                    replica_id,
                    e
                );
            }
        }
    }

    Err(EngineError::overloaded(
        "All replicas failed or overloaded".to_string(),
    ))
}
}
/// Interface a replica must expose to be managed by a `ReplicaPool`.
///
/// `ReplicaPool` itself implements this trait, so pools can be used wherever
/// a single scheduler is expected.
#[async_trait::async_trait]
pub trait ReplicaScheduler: Send + Sync {
/// Returns the current queue depth as a pair of counts.
/// NOTE(review): the pool's routing code treats the pair as `(high, low)`
/// and sums both — confirm the meaning with the concrete scheduler.
fn get_queue_depth(&self) -> (usize, usize);
/// Returns whether the replica should currently receive traffic.
fn is_healthy(&self) -> bool;
/// Returns a snapshot of the replica's engine metrics.
fn get_metrics(&self) -> kapsl_engine_api::EngineMetrics;
/// Returns model information if the replica can provide it.
/// The default implementation returns `None`.
fn model_info(&self) -> Option<kapsl_engine_api::EngineModelInfo> {
None
}
/// Runs one inference for `request` and returns the resulting packet.
async fn infer(
&self,
request: &InferenceRequest,
priority: Priority,
force_cpu: bool,
) -> Result<BinaryTensorPacket, EngineError>;
/// Starts a streaming inference for `request`.
///
/// The outer `Result` reports whether the stream could be opened; each
/// stream item is itself a `Result` for one output packet.
async fn infer_stream(
&self,
request: InferenceRequest,
priority: Priority,
force_cpu: bool,
) -> Result<
std::pin::Pin<
Box<dyn futures::Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>,
>,
EngineError,
>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use kapsl_engine_api::BinaryTensorPacket;
    use kapsl_engine_api::TensorDtype;

    /// Scheduler stub with a fixed queue depth, health flag and metrics.
    /// An unhealthy stub fails every inference call.
    struct MockScheduler {
        queue_depth: (usize, usize),
        healthy: bool,
        metrics: kapsl_engine_api::EngineMetrics,
    }

    impl MockScheduler {
        /// Stub whose metrics only carry the summed queue depth.
        fn new(queue_depth: (usize, usize), healthy: bool) -> Self {
            let (first, second) = queue_depth;
            let metrics = kapsl_engine_api::EngineMetrics {
                queue_depth: first + second,
                ..kapsl_engine_api::EngineMetrics::default()
            };
            Self::with_metrics(queue_depth, healthy, metrics)
        }

        /// Stub that reports exactly the supplied metrics.
        fn with_metrics(
            queue_depth: (usize, usize),
            healthy: bool,
            metrics: kapsl_engine_api::EngineMetrics,
        ) -> Self {
            Self {
                queue_depth,
                healthy,
                metrics,
            }
        }

        /// Error returned by every inference call on an unhealthy stub.
        fn unhealthy_error() -> EngineError {
            EngineError::InferenceError {
                reason: "Unhealthy replica".to_string(),
                source: None,
            }
        }
    }

    #[async_trait::async_trait]
    impl ReplicaScheduler for MockScheduler {
        fn get_queue_depth(&self) -> (usize, usize) {
            self.queue_depth
        }

        fn is_healthy(&self) -> bool {
            self.healthy
        }

        fn get_metrics(&self) -> kapsl_engine_api::EngineMetrics {
            self.metrics.clone()
        }

        async fn infer(
            &self,
            request: &InferenceRequest,
            _priority: Priority,
            _force_cpu: bool,
        ) -> Result<BinaryTensorPacket, EngineError> {
            if self.healthy {
                // Echo the input back as the "inference" result.
                Ok(request.input.clone())
            } else {
                Err(Self::unhealthy_error())
            }
        }

        async fn infer_stream(
            &self,
            request: InferenceRequest,
            _priority: Priority,
            _force_cpu: bool,
        ) -> Result<
            std::pin::Pin<
                Box<dyn futures::Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>,
            >,
            EngineError,
        > {
            if !self.healthy {
                return Err(Self::unhealthy_error());
            }
            // Single-item stream that echoes the input.
            let only_item: Result<BinaryTensorPacket, EngineError> = Ok(request.input.clone());
            Ok(Box::pin(futures::stream::once(async move { only_item })))
        }
    }

    /// Minimal 1x1 float32 request shared by all tests.
    fn unit_request() -> InferenceRequest {
        InferenceRequest::new(BinaryTensorPacket {
            shape: vec![1, 1],
            dtype: TensorDtype::Float32,
            data: vec![0, 0, 0, 0],
        })
    }

    /// Paged-KV metrics for a cache of 100 blocks / 1000 bytes.
    fn paged_metrics(blocks_free: usize, bytes_used: usize) -> kapsl_engine_api::EngineMetrics {
        kapsl_engine_api::EngineMetrics {
            kv_cache_blocks_total: 100,
            kv_cache_blocks_free: blocks_free,
            kv_cache_bytes_capacity: 1_000,
            kv_cache_bytes_used: bytes_used,
            ..kapsl_engine_api::EngineMetrics::default()
        }
    }

    #[tokio::test]
    async fn test_round_robin_distribution() {
        let pool = ReplicaPool::new(PoolStrategy::RoundRobin);
        for id in 0..3 {
            pool.add_replica(id, Arc::new(MockScheduler::new((0, 0), true)));
        }

        let request = unit_request();
        for _ in 0..9 {
            let _ = pool
                .execute(request.clone(), Priority::Throughput, false)
                .await;
        }

        // Nine requests over three replicas: three each.
        for snapshot in pool.stats() {
            assert_eq!(snapshot.requests_total, 3);
        }
    }

    #[tokio::test]
    async fn test_least_loaded_selection() {
        let pool = ReplicaPool::new(PoolStrategy::LeastLoaded);
        pool.add_replica(0, Arc::new(MockScheduler::new((10, 5), true)));
        pool.add_replica(1, Arc::new(MockScheduler::new((2, 1), true)));
        pool.add_replica(2, Arc::new(MockScheduler::new((5, 3), true)));

        let _ = pool
            .execute(unit_request(), Priority::Throughput, false)
            .await;

        // Replica 1 has the shallowest queue and must take the request.
        let stats = pool.stats();
        assert_eq!(stats[1].requests_total, 1);
        assert_eq!(stats[0].requests_total, 0);
        assert_eq!(stats[2].requests_total, 0);
    }

    #[tokio::test]
    async fn test_sticky_routing() {
        let pool = ReplicaPool::new(PoolStrategy::Sticky);
        for id in 0..3 {
            pool.add_replica(id, Arc::new(MockScheduler::new((0, 0), true)));
        }

        let request = unit_request().with_session_id("session123");
        for _ in 0..5 {
            let _ = pool
                .execute(request.clone(), Priority::Throughput, false)
                .await;
        }

        // Every request of the session must land on the same replica.
        let stats = pool.stats();
        let total_requests: u64 = stats.iter().map(|s| s.requests_total).sum();
        assert_eq!(total_requests, 5);
        assert!(stats.iter().any(|s| s.requests_total == 5));
    }

    #[tokio::test]
    async fn test_failover() {
        let pool = ReplicaPool::new(PoolStrategy::RoundRobin);
        pool.add_replica(0, Arc::new(MockScheduler::new((0, 0), false)));
        pool.add_replica(1, Arc::new(MockScheduler::new((0, 0), true)));

        let outcome = pool
            .execute(unit_request(), Priority::Throughput, false)
            .await;
        assert!(outcome.is_ok());

        let stats = pool.stats();
        assert_eq!(stats[1].requests_total, 1);
    }

    #[tokio::test]
    async fn test_streaming_failover() {
        let pool = ReplicaPool::new(PoolStrategy::RoundRobin);
        pool.add_replica(0, Arc::new(MockScheduler::new((0, 0), false)));
        pool.add_replica(1, Arc::new(MockScheduler::new((0, 0), true)));

        let outcome = pool
            .infer_stream(unit_request(), Priority::LatencyCritical, false)
            .await;
        assert!(
            outcome.is_ok(),
            "Streaming request should succeed via failover"
        );

        let mut stream = outcome.unwrap();
        let first = stream.next().await;
        assert!(first.is_some());
        assert!(first.unwrap().is_ok());

        let stats = pool.stats();
        assert!(stats[1].requests_total >= 1);
    }

    #[tokio::test]
    async fn test_queue_depth_aggregation() {
        let pool = ReplicaPool::new(PoolStrategy::RoundRobin);
        pool.add_replica(0, Arc::new(MockScheduler::new((10, 5), true)));
        pool.add_replica(1, Arc::new(MockScheduler::new((2, 1), true)));
        pool.add_replica(2, Arc::new(MockScheduler::new((5, 3), true)));

        let (high, low) = pool.get_queue_depth();
        assert_eq!(high, 17);
        assert_eq!(low, 9);
    }

    #[tokio::test]
    async fn test_least_loaded_prefers_lower_kv_pressure_when_paged_metrics_exist() {
        let pool = ReplicaPool::new(PoolStrategy::LeastLoaded);
        // Replica 0: idle queue but 96% of KV blocks in use.
        let high_pressure_metrics = paged_metrics(4, 960);
        // Replica 1: one queued request but only 68% of KV blocks in use.
        let low_pressure_metrics = paged_metrics(32, 680);

        pool.add_replica(
            0,
            Arc::new(MockScheduler::with_metrics(
                (0, 0),
                true,
                high_pressure_metrics,
            )),
        );
        pool.add_replica(
            1,
            Arc::new(MockScheduler::with_metrics(
                (1, 0),
                true,
                low_pressure_metrics,
            )),
        );

        let _ = pool
            .execute(unit_request(), Priority::Throughput, false)
            .await;

        let stats = pool.stats();
        assert_eq!(stats[0].requests_total, 0);
        assert_eq!(stats[1].requests_total, 1);
    }

    #[tokio::test]
    async fn test_failover_prefers_more_kv_headroom() {
        let pool = ReplicaPool::new(PoolStrategy::RoundRobin);
        let failing_metrics = kapsl_engine_api::EngineMetrics::default();
        let worse_headroom = paged_metrics(8, 920);
        let better_headroom = paged_metrics(40, 600);

        pool.add_replica(
            0,
            Arc::new(MockScheduler::with_metrics((0, 0), false, failing_metrics)),
        );
        pool.add_replica(
            1,
            Arc::new(MockScheduler::with_metrics((0, 0), true, worse_headroom)),
        );
        pool.add_replica(
            2,
            Arc::new(MockScheduler::with_metrics((0, 0), true, better_headroom)),
        );

        let outcome = pool
            .execute(unit_request(), Priority::Throughput, false)
            .await;
        assert!(outcome.is_ok());

        // Round-robin picks the failing replica 0 first; failover must then
        // go to the replica with more KV headroom.
        let stats = pool.stats();
        assert_eq!(stats[1].requests_total, 0);
        assert_eq!(stats[2].requests_total, 1);
    }
}