pub mod batch;
pub mod engine;
pub mod kv_cache_manager;
pub mod request;
pub mod scheduler;
pub use batch::{
BatchStats, BatchedRequest, DecodeTask, IterationPlan, PrefillTask, ScheduledBatch, TokenBudget,
};
pub use engine::{GenerationResult, ServingEngine, ServingEngineConfig, ServingMetrics};
pub use kv_cache_manager::{
KvCacheAllocation, KvCacheManager, KvCacheManagerStats, KvCachePoolConfig,
};
pub use request::{
CompletedRequest, FinishReason, InferenceRequest, Priority, RequestId, RequestState,
RunningRequest, TokenOutput,
};
pub use scheduler::{
ContinuousBatchScheduler, IterationScheduler, PreemptionMode, PriorityPolicy, RequestQueue,
SchedulerConfig, SchedulerStats,
};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::{GenerateParams, NoopBackend};
    use std::sync::Arc;

    /// End-to-end smoke test: submit several requests, drive the engine with
    /// `run_iteration` until every request reports complete, then verify the
    /// engine's metrics recorded the processed requests.
    #[test]
    fn test_full_serving_flow() {
        let backend = Arc::new(NoopBackend);
        let config = ServingEngineConfig {
            kv_cache: KvCachePoolConfig {
                num_slots: 8,
                max_seq_len: 256,
                block_size: 16,
                total_blocks: 128,
                num_kv_heads: 2,
                head_dim: 64,
                num_layers: 4,
            },
            max_concurrent_requests: 8,
            ..Default::default()
        };
        let engine = ServingEngine::new(backend, config);

        // Submit three small requests with distinct, non-overlapping prompts.
        let mut request_ids = Vec::new();
        for i in 0..3 {
            let params = GenerateParams::default().with_max_tokens(5);
            let prompt: Vec<u32> = (0..10).map(|j| (i * 10 + j) as u32).collect();
            let request = InferenceRequest::new(prompt, params);
            let id = engine.submit(request).unwrap();
            request_ids.push(id);
        }

        // Drive the engine with a hard iteration cap so a scheduling bug
        // cannot hang the test. Per-iteration outputs are intentionally
        // ignored; this test only checks overall completion and metrics.
        let max_iterations = 100;
        let mut all_complete = false;
        for _ in 0..max_iterations {
            let _ = engine.run_iteration().unwrap();
            if request_ids.iter().all(|id| engine.is_complete(*id)) {
                all_complete = true;
                break;
            }
        }
        // Previously the test passed silently when the cap was hit; fail
        // explicitly instead so an engine that never completes is caught.
        assert!(
            all_complete,
            "requests did not complete within {} iterations",
            max_iterations
        );

        // Drain the results; their contents are not asserted on here.
        for id in &request_ids {
            let _ = engine.get_result(*id);
        }
        let metrics = engine.metrics();
        assert!(metrics.total_requests_processed > 0);
    }

    /// A request arriving while another is already running should still be
    /// scheduled — the core continuous-batching property.
    #[test]
    fn test_scheduler_continuous_batching() {
        let scheduler_config = SchedulerConfig::default();
        let kv_config = KvCachePoolConfig {
            num_slots: 4,
            max_seq_len: 128,
            block_size: 16,
            total_blocks: 32,
            num_kv_heads: 2,
            head_dim: 64,
            num_layers: 4,
        };
        let mut scheduler = ContinuousBatchScheduler::new(scheduler_config, kv_config);
        let mut queue = RequestQueue::new();
        let params = GenerateParams::default().with_max_tokens(10);

        // First request: expect a prefill phase and one running request.
        let request1 = InferenceRequest::new(vec![1, 2, 3], params.clone());
        queue.add(request1);
        let batch1 = scheduler.schedule(&mut queue);
        assert!(batch1.has_prefill);
        assert_eq!(queue.running_count(), 1);

        // Second request joins mid-flight; the next batch must not be empty.
        let request2 = InferenceRequest::new(vec![4, 5, 6], params);
        queue.add(request2);
        let batch2 = scheduler.schedule(&mut queue);
        assert!(!batch2.is_empty());
    }

    /// With `PriorityBased` policy, scheduling a mixed-priority queue must
    /// still produce a non-empty batch.
    #[test]
    fn test_priority_scheduling() {
        let scheduler_config = SchedulerConfig {
            priority_policy: PriorityPolicy::PriorityBased,
            ..Default::default()
        };
        let kv_config = KvCachePoolConfig::default();
        let mut scheduler = ContinuousBatchScheduler::new(scheduler_config, kv_config);
        let mut queue = RequestQueue::new();

        let low =
            InferenceRequest::new(vec![1], GenerateParams::default()).with_priority(Priority::Low);
        queue.add(low);
        let high =
            InferenceRequest::new(vec![2], GenerateParams::default()).with_priority(Priority::High);
        queue.add(high);

        let batch = scheduler.schedule(&mut queue);
        assert!(!batch.is_empty());
    }

    /// Allocation lifecycle: two allocations get distinct slots, extension
    /// updates the tracked length, and freeing removes the allocation.
    #[test]
    fn test_kv_cache_allocation() {
        let config = KvCachePoolConfig {
            num_slots: 4,
            max_seq_len: 128,
            block_size: 16,
            total_blocks: 32,
            num_kv_heads: 2,
            head_dim: 64,
            num_layers: 4,
        };
        let mut manager = KvCacheManager::new(config);

        let id1 = RequestId::new();
        let slot1 = manager.allocate(id1, 64).unwrap();
        let id2 = RequestId::new();
        let slot2 = manager.allocate(id2, 64).unwrap();
        assert_ne!(slot1, slot2);

        // NOTE(review): after allocate(64) + extend(32) the expected
        // `current_length` is 32, which implies `allocate` reserves capacity
        // while `extend` advances the live length — confirm against
        // KvCacheManager's contract.
        manager.extend(id1, 32).unwrap();
        let allocation = manager.get_allocation(id1).unwrap();
        assert_eq!(allocation.current_length, 32);

        manager.free(id1);
        assert!(manager.get_allocation(id1).is_none());
        let stats = manager.stats();
        assert_eq!(stats.active_allocations, 1);
    }

    /// `IterationPlan` accounting: one prefill of 5 tokens plus one decode of
    /// 1 token → 2 requests, 6 tokens, and a scheduled batch flagged for both
    /// prefill and decode work.
    #[test]
    fn test_iteration_plan() {
        let plan = IterationPlan {
            prefill_tasks: vec![PrefillTask {
                request_id: RequestId::new(),
                tokens: vec![1, 2, 3, 4, 5],
                start_position: 0,
                kv_cache_slot: 0,
                block_table: vec![0],
            }],
            decode_tasks: vec![DecodeTask {
                request_id: RequestId::new(),
                input_token: 10,
                position: 5,
                kv_cache_slot: 1,
                block_table: vec![1],
                context_len: 5,
            }],
            evicted_requests: vec![],
            swap_out_requests: vec![],
            swap_in_requests: vec![],
        };
        assert!(plan.has_work());
        assert_eq!(plan.total_requests(), 2);
        assert_eq!(plan.total_tokens(), 6);

        let batch = plan.to_scheduled_batch(1);
        assert_eq!(batch.batch_id, 1);
        assert!(batch.has_prefill);
        assert!(batch.has_decode);
    }
}