use crate::common::evictor::LRUEvictor;
use crate::common::perf_model::PerfModel;
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost,
WorkerType,
};
use crate::common::running_mean::RunningMean;
use crate::common::sequence::ActiveSequence;
use crate::common::utils::sleep_until_precise;
use crate::kv_manager::KvManager;
use dynamo_kv_router::protocols::DpRank;
use dynamo_tokens::blocks::UniqueBlock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use validator::Validate;
/// Point-in-time metrics snapshot published by a mocker scheduler instance
/// through its watch channel.
#[derive(Clone, Default, Debug)]
pub struct MockerMetrics {
    /// Data-parallel rank this scheduler simulates.
    pub dp_rank: DpRank,
    /// Number of KV-cache blocks currently held by in-flight sequences.
    pub active_decode_blocks: u64,
}
/// A unit of work tracked by the scheduler.
pub enum Request {
    /// Raw request as received; not yet admitted by the scheduler.
    Direct(DirectRequest),
    /// Request that has been converted into a running token sequence
    /// (admitted, or re-queued after preemption).
    Active(ActiveSequence),
}
/// Mutable bookkeeping for the simulated engine: request queues for each
/// stage (waiting -> prefill -> decode) plus token-budget accounting.
#[derive(Default)]
struct SchedulerState {
    /// Requests not yet admitted, in FIFO order.
    waiting: VecDeque<Uuid>,
    /// Admitted requests awaiting (possibly chunked) prefill.
    prefill: VecDeque<Uuid>,
    /// Requests currently decoding; LRU order is used to pick preemption
    /// victims.
    decode: LRUEvictor<Uuid>,
    /// All tracked requests, keyed by UUID.
    requests: HashMap<Uuid, Request>,
    /// Remaining prefill cost for each admitted request.
    prefill_costs: HashMap<Uuid, PrefillCost>,
    /// Optional per-step token budget (mirrors vLLM's
    /// `max_num_batched_tokens`); `None` disables the limit.
    max_num_batched_tokens: Option<usize>,
    /// Tokens consumed in the current step by prefills and decodes.
    active_tokens: usize,
    /// New tokens still owed by requests sitting in the prefill queue.
    waiting_tokens: usize,
}
impl SchedulerState {
    /// Creates an empty state with an optional per-step token budget.
    fn new(max_num_batched_tokens: Option<usize>) -> Self {
        SchedulerState {
            max_num_batched_tokens,
            ..Default::default()
        }
    }

    /// True when no request is tracked in any stage.
    fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

    /// Enqueues a direct request at the back of the waiting queue, assigning
    /// a fresh UUID when the request does not carry one.
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
        self.requests.insert(uuid, Request::Direct(request));
        self.waiting.push_back(uuid);
        uuid
    }

    /// Removes and returns the next waiting request, if any.
    ///
    /// # Panics
    /// Panics if the waiting queue references a UUID missing from `requests`
    /// (an internal-consistency bug).
    fn next(&mut self) -> Option<(Uuid, Request)> {
        let uuid = self.waiting.pop_front()?;
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
        Some((uuid, request))
    }

    /// Puts a request back at the front of the waiting queue (used when
    /// admission fails or after preemption, so it is retried first).
    fn first_in_line(&mut self, uuid: Uuid, request: Request) {
        self.requests.insert(uuid, request);
        self.waiting.push_front(uuid);
    }

    /// Admits a sequence into the prefill queue, recording its cost and
    /// counting its new tokens against the pending (waiting) token total.
    fn move_to_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: PrefillCost) {
        self.waiting_tokens += cost.new_tokens;
        self.requests.insert(uuid, Request::Active(active_seq));
        self.prefill.push_back(uuid);
        self.prefill_costs.insert(uuid, cost);
    }

    /// Pops the next prefill and returns
    /// `(compute_estimate, optional_creation_signal, is_full_prefill)`.
    ///
    /// When the token budget cannot cover the whole prefill, only the
    /// remaining budget is consumed (chunked prefill): the request stays at
    /// the front of the queue with its cost reduced, and `is_full_prefill`
    /// is `false`. Otherwise the request moves into the decode set.
    fn try_prefill(&mut self, perf_model: &PerfModel) -> Option<(f64, Option<MoveBlock>, bool)> {
        let uuid = self.prefill.pop_front()?;
        let mut prefill_cost = self
            .prefill_costs
            .remove(&uuid)
            .expect("Expects valid prefill cost.");
        let new_tokens = prefill_cost.new_tokens;
        // `Some(budget_left)` when this prefill must be chunked.
        // NOTE(review): assumes `active_tokens <= max_tokens` here; the
        // subtraction would underflow otherwise — confirm the invariant.
        let maybe_prefill_tokens = self.max_num_batched_tokens.and_then(|max_tokens| {
            let remaining_tokens = max_tokens - self.active_tokens;
            if prefill_cost.new_tokens > remaining_tokens {
                Some(remaining_tokens)
            } else {
                None
            }
        });
        let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens
        {
            // Chunked path: prefill only `prefill_tokens` now and requeue
            // the remainder at the front for the next step.
            let prefill_compute =
                prefill_cost.predict_prefill_compute(Some(prefill_tokens), perf_model);
            prefill_cost.new_tokens -= prefill_tokens;
            assert!(
                prefill_cost.new_tokens > 0,
                "Encountered negative prefill tokens."
            );
            self.prefill.push_front(uuid);
            self.prefill_costs.insert(uuid, prefill_cost);
            // The chunk consumes the entire remaining token budget.
            self.active_tokens = self.max_num_batched_tokens.unwrap();
            self.waiting_tokens -= prefill_tokens;
            (prefill_compute, false)
        } else {
            // Full prefill: the request starts decoding from the next step.
            self.decode.insert(uuid);
            self.active_tokens += new_tokens;
            self.waiting_tokens -= new_tokens;
            (prefill_cost.predict_prefill_compute(None, perf_model), true)
        };
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        Some((
            prefill_compute,
            sequence.take_creation_signal(),
            is_full_prefill,
        ))
    }

    /// Resets the per-step token budget: each decoding sequence contributes
    /// exactly one token to the next step.
    fn reset_active_tokens(&mut self) {
        self.active_tokens = self.decode.len();
    }

    /// Returns the sequence for `uuid` if it is still in the decode set;
    /// `None` if it was preempted or completed earlier in this pass.
    fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
        if !self.decode.contains(&uuid) {
            return None;
        }
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        Some(sequence)
    }

    /// Number of requests currently prefilling or decoding.
    fn num_active_requests(&self) -> usize {
        self.prefill.len() + self.decode.len()
    }

    /// Removes a finished (or cancelled) request from all tracking maps and
    /// releases its one-token decode budget (see `reset_active_tokens`).
    fn complete(&mut self, uuid: &Uuid) {
        tracing::trace!("Request {uuid} will complete");
        self.decode.remove(uuid);
        self.requests.remove(uuid);
        self.prefill_costs.remove(uuid);
        self.active_tokens -= 1;
    }

    /// Evicts the least-recently-used decode request, resets its sequence,
    /// and requeues it at the front of the waiting line. Returns the signals
    /// that free its KV blocks.
    ///
    /// # Panics
    /// Panics if the decode set is empty or holds a non-active request.
    fn preempt(&mut self) -> Vec<MoveBlock> {
        let uuid = self
            .decode
            .evict()
            .expect("Nothing to evict for preemption.");
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
        self.prefill_costs.remove(&uuid);
        self.active_tokens -= 1;
        tracing::warn!("Request {uuid} will be preempted");
        let Request::Active(mut active_sequence) = request else {
            panic!("Expected ActiveSequence in running queue")
        };
        let signals = active_sequence.reset_with_signal();
        self.first_in_line(uuid, Request::Active(active_sequence));
        signals
    }
}
/// Cloneable handle to a background task that simulates an engine's
/// scheduling loop (admission, prefill, decode, metrics publication).
#[derive(Clone)]
pub struct Scheduler {
    /// Channel for submitting requests to the background simulation task.
    request_tx: mpsc::UnboundedSender<DirectRequest>,
    /// Watch channel carrying the latest metrics snapshot.
    metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
}
impl Scheduler {
    /// Spawns the background simulation loop and returns a handle to it.
    ///
    /// The loop repeatedly: receives new requests, admits them against block
    /// and token budgets (`try_schedule`), simulates prefill and decode
    /// passes (sleeping to model their latency), and publishes fresh metrics
    /// on the watch channel. It exits when the cancellation token fires (or,
    /// per `receive_requests`, when shutdown is signalled).
    ///
    /// # Panics
    /// Panics if `args` fails validation.
    pub fn new(
        args: MockEngineArgs,
        dp_rank: u32,
        output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
        kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
        cancellation_token: Option<CancellationToken>,
    ) -> Self {
        args.validate().expect("invalid MockEngineArgs");
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
        let initial_metrics = MockerMetrics {
            dp_rank,
            active_decode_blocks: 0,
        };
        let (metrics_tx, metrics_rx) =
            tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
        // A default (never-cancelled) token keeps the loop alive when the
        // caller does not supply one. `unwrap_or_default` already yields an
        // owned token, so no extra clone is needed.
        let cancel_token = cancellation_token.unwrap_or_default();
        tokio::spawn(async move {
            let mut state = SchedulerState::new(args.max_num_batched_tokens);
            let mut kv_manager = KvManager::new_with_event_sink(
                args.num_gpu_blocks,
                args.block_size,
                kv_event_sink,
                dp_rank,
            );
            // Rolling window of prefix-cache hit rates observed at admission.
            let mut hit_rates = RunningMean::new(1000);
            loop {
                if receive_requests(&mut state, &mut request_rx, &cancel_token)
                    .await
                    .is_none()
                {
                    break;
                }
                try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
                simulate_prefill(
                    &mut state,
                    &mut kv_manager,
                    &args.perf_model,
                    args.worker_type,
                    args.speedup_ratio,
                )
                .await;
                simulate_decode(
                    &mut state,
                    &mut kv_manager,
                    &output_tx,
                    &args.perf_model,
                    args.block_size,
                    args.speedup_ratio,
                )
                .await;
                // Best-effort metrics publication; ignore closed receivers.
                let _ = metrics_tx.send(MockerMetrics {
                    dp_rank,
                    active_decode_blocks: kv_manager.num_active_blocks() as u64,
                });
            }
        });
        Self {
            request_tx,
            metrics_rx,
        }
    }

    /// Submits a request to the simulation task. Errors from a closed
    /// channel (task already stopped) are ignored.
    pub async fn receive(&self, request: DirectRequest) {
        let _ = self.request_tx.send(request);
    }

    /// Returns a sender that feeds requests directly into the task.
    pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        self.request_tx.clone()
    }

    /// Returns a receiver observing the latest metrics snapshot.
    pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
        self.metrics_rx.clone()
    }
}
/// Pulls pending requests into the scheduler state.
///
/// Blocks (on the channel and the cancellation token) only when the
/// scheduler is completely idle; otherwise it drains whatever is immediately
/// available so the simulation loop keeps making progress.
///
/// Returns `None` when the scheduler should shut down: either the
/// cancellation token fired, or every request sender was dropped while the
/// scheduler was idle. (Previously a closed channel disabled the `recv`
/// select arm, parking the task forever unless a token was supplied.)
async fn receive_requests(
    state: &mut SchedulerState,
    request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
    cancel_token: &CancellationToken,
) -> Option<()> {
    if cancel_token.is_cancelled() {
        return None;
    }
    if state.is_empty() {
        // Nothing in flight: wait for the next request or a shutdown signal.
        tokio::select! {
            biased;
            _ = cancel_token.cancelled() => {
                return None;
            }
            maybe_request = request_rx.recv() => {
                // `None` means all senders are gone; shut the task down
                // instead of waiting on a channel that can never yield.
                let request = maybe_request?;
                state.receive(request);
                return Some(());
            }
        }
    }
    // Busy: opportunistically drain without blocking the simulation loop.
    // A `Disconnected` error simply ends the drain; remaining work is still
    // processed and the idle path above handles the eventual shutdown.
    while let Ok(request) = request_rx.try_recv() {
        state.receive(request);
    }
    Some(())
}
/// Drains the prefill queue, accumulating the modeled prefill compute time
/// and applying block-creation signals to the KV manager.
///
/// Stops after a partial (chunked) prefill, since that consumed the whole
/// token budget for this step. Decode-only workers skip the compute cost but
/// still perform block allocation. Returns the total modeled compute time;
/// when `speedup_ratio > 0` the function sleeps for
/// `total_time / speedup_ratio` to emulate wall-clock latency.
async fn simulate_prefill(
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
    perf_model: &PerfModel,
    worker_type: WorkerType,
    speedup_ratio: f64,
) -> Duration {
    let start_time = Instant::now();
    let mut total_time = Duration::ZERO;
    while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
        state.try_prefill(perf_model)
    {
        // NOTE(review): `prefill_compute` appears to be milliseconds (it is
        // divided by 1000 to get seconds) — confirm against PerfModel.
        if worker_type != WorkerType::Decode {
            total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
        }
        if let Some(creation_signal) = maybe_creation_signal
            && !process_signals(kv_manager, std::slice::from_ref(&creation_signal))
        {
            panic!("Block allocation for prefilling cannot fail.");
        }
        // A chunked prefill exhausted the token budget for this step.
        if !is_full_prefill {
            break;
        }
    }
    if speedup_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
        let deadline = start_time + sleep_duration;
        sleep_until_precise(deadline).await;
    }
    total_time
}
/// Simulates one decode step for every sequence currently in the decode set.
///
/// The step duration comes from the perf model, driven by the number of
/// active KV tokens and the mean context length across decoding sequences.
/// Each sequence generates one token; on block-allocation failure the token
/// is rolled back and the LRU sequence is preempted. Sequences that finish,
/// or whose output receiver has been dropped, are completed and freed.
///
/// Returns the modeled decode time; when `speedup_ratio > 0` the function
/// sleeps for `total_time / speedup_ratio` to emulate wall-clock latency.
async fn simulate_decode(
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
    perf_model: &PerfModel,
    block_size: usize,
    speedup_ratio: f64,
) -> Duration {
    let start_time = Instant::now();
    let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
    // Mean sequence length across the decode set (0 when empty).
    let total_length: usize = state
        .decode
        .keys()
        .map(|uuid| {
            if let Request::Active(seq) = state.requests.get(uuid).unwrap() {
                seq.len()
            } else {
                0
            }
        })
        .sum();
    let count = state.decode.len();
    let context_length = if count > 0 { total_length / count } else { 0 };
    // NOTE(review): `predict_decode_time` appears to return milliseconds (it
    // is divided by 1000 to get seconds) — confirm against PerfModel.
    let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
    let total_time = Duration::from_secs_f64(decoding_time / 1000.0);
    // Decode accounts for one token per active sequence in the next step.
    state.reset_active_tokens();
    // Snapshot UUIDs: preemption below mutates the decode set mid-iteration.
    let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
    for uuid in uuids {
        // Skip sequences preempted earlier in this pass.
        let Some(sequence) = state.run(uuid) else {
            continue;
        };
        let signals = sequence.generate();
        if !process_signals(kv_manager, &signals) {
            // Allocation failed: roll back the speculative token, then
            // preempt the LRU request and release its blocks.
            sequence.pop();
            for signal in state.preempt() {
                kv_manager.process(&signal);
            }
            continue;
        }
        let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
        // A send error means the receiver was dropped; treat the request as
        // cancelled and free its blocks below.
        let send_failed = output_tx.as_ref().is_some_and(|tx| {
            tx.send(OutputSignal {
                uuid,
                completed: is_complete,
            })
            .is_err()
        });
        if send_failed {
            for signal in &sequence.free_signal() {
                kv_manager.process(signal);
            }
        }
        if send_failed || is_complete {
            state.complete(&uuid);
        }
    }
    if speedup_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
        let deadline = start_time + sleep_duration;
        sleep_until_precise(deadline).await;
    }
    total_time
}
/// Admits waiting requests into the prefill queue while the block, token,
/// and sequence budgets allow; stops at the first request that would exceed
/// any budget and puts it back at the front of the line.
///
/// Also records each admitted request's prefix-cache hit rate into
/// `hit_rates`. Returns the number of requests scheduled this step.
fn try_schedule(
    state: &mut SchedulerState,
    kv_manager: &KvManager,
    hit_rates: &mut RunningMean<f32>,
    args: &MockEngineArgs,
) -> usize {
    let mut scheduled_count = 0;
    let mut current_blocks = kv_manager.num_active_blocks();
    let mut current_tokens = state.active_tokens + state.waiting_tokens;
    let mut current_seqs = state.num_active_requests();
    while let Some((uuid, request)) = state.next() {
        // Convert a direct request into an active sequence on first sight;
        // re-admitted (preempted) requests arrive already converted.
        let active_sequence = match request {
            Request::Active(active_seq) => active_seq,
            Request::Direct(direct_request) => ActiveSequence::new(
                direct_request.tokens,
                direct_request.max_output_tokens,
                Some(args.block_size),
                args.enable_prefix_caching,
                args.zmq_kv_events_port.is_some(),
            ),
        };
        let prefill_cost = kv_manager.get_prefill_cost(&active_sequence);
        let total_tokens = active_sequence.len();
        // `usize::div_ceil` avoids the lossy usize -> u32 round-trip the
        // previous cast-based computation performed.
        let new_blocks = total_tokens.div_ceil(args.block_size);
        let new_tokens = prefill_cost.new_tokens;
        current_blocks += new_blocks;
        current_tokens += new_tokens;
        current_seqs += 1;
        // Keep a watermark fraction of KV capacity free.
        let under_block_budget =
            current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64;
        // With chunked prefill the new tokens can be split across steps, so
        // they are excluded from the budget comparison. The subtraction
        // cannot underflow: `new_tokens` was just added above.
        let comparison_tokens = if args.enable_chunked_prefill {
            current_tokens - new_tokens
        } else {
            current_tokens
        };
        let under_token_budget = args
            .max_num_batched_tokens
            .is_none_or(|limit| comparison_tokens <= limit);
        let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
        if !(under_block_budget && under_token_budget && under_seq_budget) {
            // Over budget: requeue at the front so it is retried first.
            state.first_in_line(uuid, Request::Active(active_sequence));
            break;
        }
        // Fraction of the prompt already covered by cached blocks.
        let hit_rate = if !active_sequence.is_empty() {
            1.0 - (new_tokens as f32 / active_sequence.len() as f32)
        } else {
            0.0
        };
        hit_rates.push(hit_rate);
        state.move_to_prefill(uuid, active_sequence, prefill_cost);
        scheduled_count += 1;
    }
    scheduled_count
}
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
for signal in signals {
if kv_manager.process(signal) {
continue;
}
let MoveBlock::Use(blocks, _hashes, ..) = signal else {
panic!(
"Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
);
};
let num_blocks = blocks.len();
let num_active_blocks = kv_manager.num_active_blocks();
if num_blocks != 1 {
panic!(
"Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
);
}
if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
panic!("Failed signal is Invalid. Generation block has to be partial.");
}
return false;
}
true
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;
    use tokio::time::interval;

    /// Asserts that the scheduler has released every KV block.
    fn assert_scheduler_idle(metrics: &MockerMetrics) {
        assert_eq!(
            metrics.active_decode_blocks, 0,
            "Expected 0 active blocks, got {}",
            metrics.active_decode_blocks
        );
    }

    // Exercises every combination of shared-prefix tokens, prefix caching,
    // and chunked prefill, and checks that exactly the expected number of
    // output tokens is produced and all blocks are released afterwards.
    #[rstest]
    #[case::case_1(false, false, false)]
    #[case::case_2(false, true, false)]
    #[case::case_3(true, false, false)]
    #[case::case_4(true, true, false)]
    #[case::case_5(false, false, true)]
    #[case::case_6(false, true, true)]
    #[case::case_7(true, false, true)]
    #[case::case_8(true, true, true)]
    #[tokio::test]
    async fn test_scheduler_token_generation_patterns(
        #[case] use_shared_tokens: bool,
        #[case] enable_prefix_caching: bool,
        #[case] enable_chunked_prefill: bool,
    ) {
        // SAFETY-adjacent note: mutating the environment is only done in
        // this single-process test context.
        unsafe { std::env::set_var("RUST_LOG", "debug") };
        let kv_capacity: usize = 500;
        let block_size: usize = 64;
        let num_requests: usize = 200;
        let input_len: usize = 1000;
        let max_output_tokens: usize = 100;
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(kv_capacity)
            .block_size(block_size)
            .speedup_ratio(10.0)
            .enable_prefix_caching(enable_prefix_caching)
            .enable_chunked_prefill(enable_chunked_prefill)
            .build()
            .unwrap();
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
        // Optional shared prefix: first half of every prompt is identical,
        // so prefix caching should produce hits across requests.
        let shared_tokens = if use_shared_tokens {
            Some(
                (0..input_len / 2)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };
        for _ in 0..num_requests {
            let input_tokens = if let Some(ref shared) = shared_tokens {
                let mut tokens = shared.clone();
                tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
                tokens
            } else {
                (0..input_len)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>()
            };
            let request = DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
                dp_rank: 0,
            };
            scheduler.receive(request).await;
        }
        let start_time = std::time::Instant::now();
        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;
        // Idle timeout: reset on every received token, fires when output
        // stalls for 2 seconds.
        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);
        let metrics_rx = scheduler.metrics_receiver();
        let mut debug_interval = interval(Duration::from_millis(500));
        loop {
            tokio::select! {
                biased;
                _ = debug_interval.tick() => {
                    let _metrics = metrics_rx.borrow().clone();
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
                }
                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }
                _ = &mut timeout => {
                    break;
                }
            }
        }
        let elapsed = start_time.elapsed();
        println!(
            "Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
            if use_shared_tokens {
                "caching"
            } else {
                "random"
            }
        );
        assert!(
            received_tokens == expected_tokens,
            "Received {received_tokens} tokens but expected exactly {expected_tokens}"
        );
        // Give the scheduler a moment to publish its final metrics.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let metrics = scheduler.metrics_receiver().borrow().clone();
        assert_scheduler_idle(&metrics);
    }

    // Identical prompts submitted back-to-back should all complete and the
    // scheduler should end with no active blocks (prefix caching reuse).
    #[tokio::test]
    async fn test_cache_hit_rate_with_identical_requests() {
        let block_size: usize = 64;
        let max_output_tokens: usize = 10;
        let speedup_ratio = 10.0;
        let num_requests = 10;
        // 65 tokens: one full block plus one token, so a partial block is
        // always in play.
        let token_length = 65;
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(100)
            .block_size(block_size)
            .speedup_ratio(speedup_ratio)
            .build()
            .unwrap();
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
        let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();
        for _ in 0..num_requests {
            let request = DirectRequest {
                tokens: identical_tokens.clone(),
                max_output_tokens,
                uuid: None,
                dp_rank: 0,
            };
            scheduler.receive(request).await;
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
        let mut received_tokens = 0;
        let timeout = tokio::time::sleep(Duration::from_millis(500));
        tokio::pin!(timeout);
        let metrics_rx = scheduler.metrics_receiver();
        let mut debug_interval = interval(Duration::from_millis(500));
        loop {
            tokio::select! {
                biased;
                _ = debug_interval.tick() => {
                    let _metrics = metrics_rx.borrow().clone();
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
                }
                Some(_signal) = output_rx.recv() => {
                    received_tokens += 1;
                    timeout.set(tokio::time::sleep(Duration::from_millis(500)));
                }
                _ = &mut timeout => {
                    break;
                }
            }
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
        let metrics = metrics_rx.borrow().clone();
        assert_scheduler_idle(&metrics);
        println!("Test passed! Received {received_tokens} tokens");
    }

    // Dropping the output receiver mid-generation must cause the scheduler
    // to cancel the request and free all of its KV blocks.
    #[tokio::test]
    async fn test_receiver_drop_cleans_up_resources() {
        let block_size: usize = 64;
        let input_tokens = 256;
        let max_output_tokens = 200;
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(10)
            .block_size(block_size)
            .speedup_ratio(100.0)
            .build()
            .unwrap();
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
        let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
        let request = DirectRequest {
            tokens,
            max_output_tokens,
            uuid: None,
            dp_rank: 0,
        };
        scheduler.receive(request).await;
        let mut received_count = 0;
        // Consume part of the output, then drop the receiver mid-stream.
        while received_count < 129 {
            if let Some(_signal) = output_rx.recv().await {
                received_count += 1;
            } else {
                panic!("Channel closed before receiving 129 tokens");
            }
        }
        drop(output_rx);
        tokio::time::sleep(Duration::from_secs(1)).await;
        let metrics_rx = scheduler.metrics_receiver();
        let metrics = metrics_rx.borrow().clone();
        assert_scheduler_idle(&metrics);
    }
}