dynamo_llm/mocker/scheduler.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Asynchronous Scheduler for LLM Request Management
//!
//! This module implements an asynchronous scheduler that handles three main functions:
//! 1. Receiving new requests and placing them in the waiting queue
//! 2. Scheduling waiting requests against available KV cache resources
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler uses a watermark-based approach to determine if there's sufficient
//! KV cache space for new requests. It also enforces a batched tokens budget to prevent
//! oversubscription of computational resources. Only requests that can be allocated
//! these resources are moved from waiting to running state.
//!
//! ## Request Simulation
//! The simulation models two key phases:
//! - Prefill phase: Uses a quadratic cost function: (cached_tokens + new_tokens) * new_tokens
//! - Decode phase: Uses a cost function proportional to active KV blocks (linear)
//!
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! LRU-based preemption strategy where the oldest running request is evicted and
//! returned to the front of the waiting queue to be rescheduled first.
//!
//! ## NOTE
//! The current prefill and decode timing models are rough approximations and remain a WIP.
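//!
//! As an illustrative (not calibrated) example of the prefill cost shape: a request
//! with 128 cached tokens and 128 new tokens costs on the order of
//! (128 + 128) * 128 = 32768 token-units, while a fully uncached request of the same
//! total length costs 256 * 256 = 65536.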

use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats};
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager;
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
use crate::mocker::protocols::{MoveBlock, OutputSignal, PrefillCost, block_response_to_kv_event};
use crate::mocker::sequence::ActiveSequence;
use crate::tokens::BlockHash;
use crate::tokens::blocks::UniqueBlock;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;

/// Enum representing either a direct request or an active sequence
pub enum Request {
    Direct(DirectRequest),
    Active(ActiveSequence),
}

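/// Bookkeeping for request queues and token budgets.
///
/// Illustrative flow (summarizing the methods below): `receive` enqueues a request
/// into `waiting`; `move_to_prefill` reserves its tokens in `waiting_tokens`;
/// `try_prefill` shifts that budget into `active_tokens` and the request into
/// `decode`; `complete` and `preempt` release the resources again.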
#[derive(Default)]
struct SchedulerState {
    waiting: VecDeque<Uuid>,
    prefill: VecDeque<Uuid>,
    decode: LRUEvictor<Uuid>,
    requests: HashMap<Uuid, Request>,
    prefill_costs: HashMap<Uuid, PrefillCost>,
    max_num_batched_tokens: Option<usize>,
    active_tokens: usize,
    waiting_tokens: usize,
}

impl SchedulerState {
    fn new(max_num_batched_tokens: Option<usize>) -> Self {
        SchedulerState {
            max_num_batched_tokens,
            ..Default::default()
        }
    }

    fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

    /// Register a DirectRequest under its UUID (generating one if absent) and push it to waiting.
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        // Use the provided UUID if available, otherwise generate a new one
        let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
        self.requests.insert(uuid, Request::Direct(request));
        self.waiting.push_back(uuid);
        uuid
    }

    /// Get the next UUID from the waiting queue and its associated Request.
    fn next(&mut self) -> Option<(Uuid, Request)> {
        let uuid = self.waiting.pop_front()?;
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
        Some((uuid, request))
    }

    /// Return a UUID and its Request to the front of the waiting queue.
    fn first_in_line(&mut self, uuid: Uuid, request: Request) {
        self.requests.insert(uuid, request);
        self.waiting.push_front(uuid);
    }

    /// Move a UUID and its ActiveSequence to the prefill queue, reserving its token budget.
    fn move_to_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: PrefillCost) {
        self.waiting_tokens += cost.new_tokens;
        self.requests.insert(uuid, Request::Active(active_seq));
        self.prefill.push_back(uuid);
        self.prefill_costs.insert(uuid, cost);
    }

    /// Try a (chunked) prefill and move the request to the decode queue on completion.
    ///
    /// Returns `Some((prefill_compute, creation_signal, block_hashes, is_full_prefill))` where:
    /// - `prefill_compute`: The compute time in milliseconds for this prefill operation
    /// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
    /// - `block_hashes`: Block hashes of the sequence being prefilled
    /// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
    fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, Vec<BlockHash>, bool)> {
        let uuid = self.prefill.pop_front()?;

        // Remove and extract prefill_compute from prefill_costs
        let mut prefill_cost = self
            .prefill_costs
            .remove(&uuid)
            .expect("Expects valid prefill cost.");

        let new_tokens = prefill_cost.new_tokens;

        let maybe_prefill_tokens = self.max_num_batched_tokens.and_then(|max_tokens| {
            let remaining_tokens = max_tokens - self.active_tokens;
            if prefill_cost.new_tokens > remaining_tokens {
                Some(remaining_tokens)
            } else {
                None
            }
        });

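        // Illustrative chunking example (numbers are made up): with
        // max_num_batched_tokens = 256, active_tokens = 200, and a request with
        // new_tokens = 100, only remaining_tokens = 56 are prefilled this pass;
        // the other 44 stay queued at the front for a later pass.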
        let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens
        {
            let prefill_compute = prefill_cost.predict_prefill_compute(Some(prefill_tokens));
            prefill_cost.new_tokens -= prefill_tokens;
            assert!(
                (prefill_cost.new_tokens > 0) && (prefill_compute > 0.0),
                "Encountered non-positive prefill tokens or prefill compute cost."
            );

            self.prefill.push_front(uuid);
            self.prefill_costs.insert(uuid, prefill_cost);

            self.active_tokens = self.max_num_batched_tokens.unwrap();
            self.waiting_tokens -= prefill_tokens;

            (prefill_compute, false)
        } else {
            // The full remaining prefill fits in the budget; transfer to decode
            self.decode.insert(uuid);

            self.active_tokens += new_tokens;
            self.waiting_tokens -= new_tokens;

            (prefill_cost.predict_prefill_compute(None), true)
        };

        // NOTE: the current behavior allocates the KV blocks for the entire sequence,
        // even if only a chunk is prefilled
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };

        Some((
            prefill_compute,
            sequence.take_creation_signal(),
            sequence.block_hashes(),
            is_full_prefill,
        ))
    }

    /// Assumes all (chunked) prefills have completed, so active tokens drop to 1 per decoding sequence.
    fn reset_active_tokens(&mut self) {
        self.active_tokens = self.decode.len();
    }

    fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
        if !self.decode.contains(&uuid) {
            return None;
        }
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        Some(sequence)
    }

    fn num_active_requests(&self) -> usize {
        self.prefill.len() + self.decode.len()
    }

    /// Remove a UUID and its associated Request from collections.
    fn complete(&mut self, uuid: &Uuid) {
        tracing::trace!("Request {uuid} will complete");
        self.decode.remove(uuid);
        self.requests.remove(uuid);
        self.prefill_costs.remove(uuid);
        self.active_tokens -= 1;
    }

    /// Preempt the oldest running request by evicting it from the decode set, resetting the
    /// sequence, and returning it to the front of the waiting queue.
    /// Returns the MoveBlock signals produced by reset_with_signal; panics if nothing is running.
    fn preempt(&mut self) -> Vec<MoveBlock> {
        // Evict the oldest UUID from decode
        let uuid = self
            .decode
            .evict()
            .expect("Nothing to evict for preemption.");
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
        self.prefill_costs.remove(&uuid);
        self.active_tokens -= 1;
        tracing::warn!("Request {uuid} will be preempted");

        // Reset the sequence to get the reset signals, then insert it back
        // into the requests map at the front of the waiting queue
        let Request::Active(mut active_sequence) = request else {
            panic!("Expected ActiveSequence in decode set")
        };
        let signals = active_sequence.reset_with_signal();

        // Note: For preemption, we don't compute hit rate since we don't have access to new_tokens
        // and the sequence is being reset anyway. Hit rate tracking is primarily for new scheduling attempts.

        self.first_in_line(uuid, Request::Active(active_sequence));

        signals
    }
}

/// Manages scheduling of requests using KvManager resources
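///
/// A minimal usage sketch (argument values are illustrative and mirror the tests below):
///
/// ```ignore
/// let args = MockEngineArgs::builder()
///     .num_gpu_blocks(500)
///     .block_size(64)
///     .speedup_ratio(10.0)
///     .build()
///     .unwrap();
/// let (output_tx, mut output_rx) = mpsc::unbounded_channel();
/// let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);
/// scheduler.receive(request).await; // request: DirectRequest
/// ```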
#[derive(Clone)]
pub struct Scheduler {
    state: Arc<Mutex<SchedulerState>>,
    kv_manager: Arc<Mutex<KvManager>>,
    request_tx: mpsc::UnboundedSender<DirectRequest>,
    metrics_rx: tokio::sync::watch::Receiver<ForwardPassMetrics>,
}

impl Scheduler {
    /// Create a new Scheduler with the given parameters
    pub fn new(
        args: MockEngineArgs,
        dp_rank: Option<u32>,
        output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
        kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
        cancellation_token: Option<CancellationToken>,
    ) -> Self {
        let state = Arc::new(Mutex::new(SchedulerState::new(args.max_num_batched_tokens)));

        // Create internal channel for KV events only if needed
        let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
            let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        let kv_manager = Arc::new(Mutex::new(KvManager::new_with_sender(
            args.num_gpu_blocks,
            args.block_size,
            block_resp_tx,
        )));
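        // Rolling window of the last 1000 per-request prefix-cache hit rates;
        // get_fwd_pass_metrics averages it into gpu_prefix_cache_hit_rate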
        let hit_rates = Arc::new(Mutex::new(VecDeque::with_capacity(1000)));

        // Assert speedup_ratio is greater than 0
        assert!(
            args.speedup_ratio > 0.0,
            "speedup_ratio must be greater than 0, got: {}",
            args.speedup_ratio
        );

        // Create channel for request handling
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
        let mut initial_metrics = ForwardPassMetrics::default();
        initial_metrics.worker_stats.data_parallel_rank = dp_rank;
        let (metrics_tx, metrics_rx) =
            tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics);

        // Create clones for the background task
        let state_clone = state.clone();
        let kv_manager_clone = kv_manager.clone();
        let output_tx_clone = output_tx.clone();
        let cancel_token_clone = cancellation_token.unwrap_or_default();

        // Spawn main background task with cancellation token
        tokio::spawn(async move {
            let mut should_schedule = true;

            loop {
                {
                    let state_guard = state_clone.lock().await;

                    // If idle, block until at least one request arrives so no redundant work is done
                    // TODO: clean this up? double lock acquisition is ugly, but needed to not hold the lock forever
                    if state_guard.is_empty() {
                        drop(state_guard);
                        let Some(request) = request_rx.recv().await else {
                            tracing::warn!("request sender is dropped");
                            break;
                        };
                        let mut state_guard = state_clone.lock().await;
                        state_guard.receive(request);
                    }
                }

                tokio::select! {
                    biased;

                    // Enqueue new request
                    Some(request) = request_rx.recv() => {
                        let mut state = state_clone.lock().await;
                        state.receive(request);
                    }

                    // Try scheduling requests - runs on normal interval or after simulation
                    _ = tokio::task::yield_now() => {
                        // Skip if we just ran scheduling after simulation to prevent consecutive runs
                        if !should_schedule {
                            continue;
                        }

                        let mut state_guard = state_clone.lock().await;
                        let kv_manager_guard = kv_manager_clone.lock().await;

                        // Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
                        // schedule anymore.
                        let mut current_blocks = kv_manager_guard.num_active_blocks();
                        let mut current_tokens = state_guard.active_tokens + state_guard.waiting_tokens;
                        let mut current_seqs = state_guard.num_active_requests();

                        while let Some((uuid, request)) = state_guard.next() {
                            let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching);

                            // Update predictive budgets
                            let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence);
                            let total_tokens = active_sequence.len();
                            // This is conservative: it assumes no cache hit, so we never over-schedule
                            let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
                            let new_tokens = prefill_cost.new_tokens;

                            current_blocks += new_blocks;
                            current_tokens += new_tokens;
                            current_seqs += 1;

                            // Check the various budgets to see if scheduling is possible
                            let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64;
                            // If chunked prefill is enabled, we can be under the token budget when scheduling
                            let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens};
                            let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit);
                            let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
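                            // Worked example with illustrative numbers: with max_capacity = 500
                            // blocks and watermark = 0.01, scheduling proceeds only while the
                            // projected block count stays <= 495; a 1000-token request with
                            // block_size = 64 projects ceil(1000 / 64) = 16 new blocks under the
                            // no-cache-hit assumption above.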

                            // Cannot schedule, put first in line instead
                            if !(under_block_budget && under_token_budget && under_seq_budget) {
                                state_guard.first_in_line(uuid, Request::Active(active_sequence));
                                break;
                            }

                            // Compute and store hit rate
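                            // e.g. a 1000-token request with 360 newly computed tokens records a
                            // hit rate of 1.0 - 360/1000 = 0.64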
                            let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
                            {
                                let mut hit_rates_guard = hit_rates.lock().await;
                                hit_rates_guard.push_back(hit_rate);
                                if hit_rates_guard.len() > 1000 {
                                    hit_rates_guard.pop_front();
                                }
                            }

                            state_guard.move_to_prefill(uuid, active_sequence, prefill_cost);
                            should_schedule = false;
                        }
                    }

                    // Check for cancellation
                    _ = cancel_token_clone.cancelled() => {
                        break;
                    }
                }

                // Simulates prefill + decode
                let mut state_guard = state_clone.lock().await;
                let mut kv_manager_guard = kv_manager_clone.lock().await;

                // Base time needed for decoding using active percentage and quadratic formula
                let active_perc = kv_manager_guard.get_active_perc();
                let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
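                // At the endpoints of this fitted curve (active_perc = 0.0 and 1.0):
                // ~19.4 ms and -5.47 + 43.88 + 19.44 ≈ 57.9 ms per forward pass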
                let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);

                // Process prefilling
                while let Some((
                    prefill_compute,
                    maybe_creation_signal,
                    block_hashes,
                    is_full_prefill,
                )) = state_guard.try_prefill()
                {
                    // NOTE: Prefill cost/time is always incremented for new blocks, even if they
                    // could be cached by other requests in the same batch. This matches vLLM behavior.
                    total_time += Duration::from_secs_f64(prefill_compute / 1000.0);

                    if let Some(creation_signal) = maybe_creation_signal {
                        if !process_signals(
                            &mut kv_manager_guard,
                            std::slice::from_ref(&creation_signal),
                        ) {
                            panic!("Block allocation for prefilling cannot fail.");
                        }

                        // Drain KV events and forward to relay after prefill signal processing
                        if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
                            while let Ok(event) = rx.try_recv() {
                                let _ =
                                    relay_tx.send(block_response_to_kv_event(event, &block_hashes));
                            }
                        }
                    }

                    // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
                    if !is_full_prefill {
                        break;
                    }
                }

                state_guard.reset_active_tokens();

                {
                    let hit_rates_guard = hit_rates.lock().await;
                    let metrics = get_fwd_pass_metrics(
                        &state_guard,
                        &kv_manager_guard,
                        &hit_rates_guard,
                        dp_rank,
                    );
                    let _ = metrics_tx.send(metrics);
                }

                // Process decoding
                let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
                if !uuids.is_empty() {
                    should_schedule = true;
                }
                for uuid in uuids {
                    let Some(sequence) = state_guard.run(uuid) else {
                        continue;
                    };
                    let signals = sequence.generate();

                    // Process all signals with the KvManager, preempting on failure
                    if !process_signals(&mut kv_manager_guard, &signals) {
                        sequence.pop(); // revert the failed generation op
                        for signal in state_guard.preempt() {
                            kv_manager_guard.process(&signal);
                        }
                        continue;
                    }

                    // Drain KV events and forward to relay after decode signal processing
                    if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
                        while let Ok(event) = rx.try_recv() {
                            let _ = relay_tx
                                .send(block_response_to_kv_event(event, &sequence.block_hashes()));
                        }
                    }

                    // Check completion and send notification
                    let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
                    let should_output =
                        sequence.generated_tokens() > sequence.already_generated_tokens();

                    let mut send_failed = false;
                    if should_output {
                        send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
                            tx.send(OutputSignal {
                                uuid,
                                completed: is_complete,
                            })
                            .is_err()
                        });
                    }

                    if send_failed {
                        for signal in &sequence.free_signal() {
                            kv_manager_guard.process(signal);
                        }
                    }

                    {
                        let hit_rates_guard = hit_rates.lock().await;
                        let metrics = get_fwd_pass_metrics(
                            &state_guard,
                            &kv_manager_guard,
                            &hit_rates_guard,
                            dp_rank,
                        );
                        let _ = metrics_tx.send(metrics);
                    }

                    if send_failed || is_complete {
                        state_guard.complete(&uuid);
                        continue;
                    }
                }

                // Sleep once for the adjusted duration
                drop(kv_manager_guard);
                drop(state_guard);
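                // e.g. with speedup_ratio = 10.0, a 50 ms simulated step sleeps only
                // 5 ms of wall-clock time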
                let adjusted_time =
                    Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
                if adjusted_time.as_millis() > 0 {
                    tokio::time::sleep(adjusted_time).await;
                }
            }
        });

        Self {
            state,
            kv_manager,
            request_tx,
            metrics_rx,
        }
    }

    /// Add a new request to the waiting queue
    pub async fn receive(&self, request: DirectRequest) {
        let _ = self.request_tx.send(request);
    }

    pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        self.request_tx.clone()
    }

    pub async fn waiting_count(&self) -> usize {
        let state = self.state.lock().await;
        state.waiting.len()
    }

    pub async fn running_count(&self) -> usize {
        let state = self.state.lock().await;
        state.decode.len()
    }

    pub async fn waiting_tokens(&self) -> usize {
        let state = self.state.lock().await;
        state.waiting_tokens
    }

    pub async fn active_tokens(&self) -> usize {
        let state = self.state.lock().await;
        state.active_tokens
    }

    pub async fn kv_usage_perc(&self) -> f64 {
        let kv_manager = self.kv_manager.lock().await;
        kv_manager.current_capacity_perc()
    }

    /// Get a watch receiver for forward pass metrics
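    ///
    /// A minimal polling sketch (mirrors the tests below):
    ///
    /// ```ignore
    /// let metrics_rx = scheduler.metrics_receiver();
    /// let metrics = metrics_rx.borrow().clone();
    /// tracing::debug!("Forward Pass Metrics: {metrics:#?}");
    /// ```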
    pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> {
        self.metrics_rx.clone()
    }
}

/// Calculate forward pass metrics from current state
fn get_fwd_pass_metrics(
    state: &SchedulerState,
    kv_manager: &KvManager,
    hit_rates: &VecDeque<f32>,
    dp_rank: Option<u32>,
) -> ForwardPassMetrics {
    // Get state metrics
    let request_active_slots = state.decode.len() as u64;
    let num_requests_waiting = state.waiting.len() as u64;

    // Get KV manager metrics
    let active_blocks_count = kv_manager.active_blocks().len() as u64;
    let total_capacity = kv_manager.max_capacity() as u64;
    let gpu_cache_usage_perc = if total_capacity > 0 {
        active_blocks_count as f32 / total_capacity as f32
    } else {
        0.0
    };

    // Get hit rate metrics
    let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() {
        0.0
    } else {
        let sum: f32 = hit_rates.iter().sum();
        sum / hit_rates.len() as f32
    };

    let worker_stats = WorkerStats {
        data_parallel_rank: dp_rank,
        request_active_slots,
        request_total_slots: 1024, // vLLM max_num_seqs for GPUs with >= 70 GB VRAM; otherwise 256, with 128 as the fallback
        num_requests_waiting,
    };

    let kv_stats = KvStats {
        kv_active_blocks: active_blocks_count,
        kv_total_blocks: total_capacity,
        gpu_cache_usage_perc,
        gpu_prefix_cache_hit_rate,
    };

    let spec_decode_stats = None;

    ForwardPassMetrics {
        worker_stats,
        kv_stats,
        spec_decode_stats,
    }
}

/// Convert a Request to an ActiveSequence
fn get_active_sequence(
    request: Request,
    block_size: usize,
    enable_prefix_caching: bool,
) -> ActiveSequence {
    if let Request::Active(active_seq) = request {
        return active_seq;
    }

    let Request::Direct(direct_request) = request else {
        unreachable!("Request must be either Direct or Active");
    };

    ActiveSequence::new(
        direct_request.tokens,
        direct_request.max_output_tokens,
        Some(block_size),
        enable_prefix_caching,
    )
}

/// Processes MoveBlock signals with the KvManager.
///
/// When a signal fails, this function verifies that the failure is an expected case:
/// specifically, a single signal attempting to create a single partial (generation) block.
/// This validation is important because in normal operation, the only legitimate failure
/// case should be when trying to acquire a new generation block - any other failure would
/// indicate an unexpected state in the system.
fn process_signals(
    kv_manager_guard: &mut tokio::sync::MutexGuard<'_, KvManager>,
    signals: &[MoveBlock],
) -> bool {
    for signal in signals {
        if kv_manager_guard.process(signal) {
            continue;
        }

        // Check we have a Use signal with blocks
        let MoveBlock::Use(blocks) = signal else {
            panic!(
                "Failed signal is invalid. It has to fail on a generation signal, but failed on {signal:?}"
            );
        };

        // Verify the signal contains exactly one block
        let num_blocks = blocks.len();
        let num_active_blocks = kv_manager_guard.num_active_blocks();
        if num_blocks != 1 {
            panic!(
                "Failed signal is invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
            );
        }

        // Verify the block is a PartialBlock (generation block)
        if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
            panic!("Failed signal is invalid. The generation block has to be partial.");
        }

        return false;
    }

    true
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;
    use tokio::time::interval;

    #[rstest]
    #[case::case_1(false, false, false)]
    #[case::case_2(false, true, false)]
    #[case::case_3(true, false, false)]
    #[case::case_4(true, true, false)]
    #[case::case_5(false, false, true)]
    #[case::case_6(false, true, true)]
    #[case::case_7(true, false, true)]
    #[case::case_8(true, true, true)]
    #[tokio::test]
    async fn test_scheduler_token_generation_patterns(
        #[case] use_shared_tokens: bool,
        #[case] enable_prefix_caching: bool,
        #[case] enable_chunked_prefill: bool,
    ) {
        unsafe { std::env::set_var("RUST_LOG", "debug") };

        let kv_capacity: usize = 500;
        let block_size: usize = 64;
        let num_requests: usize = 200;
        let input_len: usize = 1000;
        let max_output_tokens: usize = 100;

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args using the builder, including enable_prefix_caching
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(kv_capacity)
            .block_size(block_size)
            .speedup_ratio(10.0)
            .enable_prefix_caching(enable_prefix_caching)
            .enable_chunked_prefill(enable_chunked_prefill)
            .build()
            .unwrap();

        // Create scheduler with the args struct
        let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);

        // Create shared tokens for caching case
        let shared_tokens = if use_shared_tokens {
            Some(
                (0..input_len / 2)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };

        // Create test requests
        for _ in 0..num_requests {
            let input_tokens = if let Some(ref shared) = shared_tokens {
                // For caching case: use shared tokens for first half, random for second half
                let mut tokens = shared.clone();
                tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
                tokens
            } else {
                // For random case: create unique random token vector for each request
                (0..input_len)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>()
            };

            let request = DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
                dp_rank: None,
            };
            scheduler.receive(request).await;
        }

        let start_time = std::time::Instant::now();

        // Collect all generated tokens (should be num_requests * max_output_tokens)
        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;

        // Set up a timeout that ends collection if no tokens are received for 2 seconds
        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);

        // Get metrics receiver
        let metrics_rx = scheduler.metrics_receiver();

        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that prints forward pass metrics
                _ = debug_interval.tick() => {
                    let _metrics = metrics_rx.borrow().clone();
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
                }

                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }

                _ = &mut timeout => {
                    // Stop collecting when the timeout occurs
                    break;
                }
            }
        }

        // Calculate and print elapsed time
        let elapsed = start_time.elapsed();
        println!(
            "Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
            if use_shared_tokens {
                "caching"
            } else {
                "random"
            }
        );

        // Assert that we received the expected number of tokens
        assert!(
            received_tokens == expected_tokens,
            "Received {received_tokens} tokens but expected exactly {expected_tokens}"
        );

        let active_tokens = scheduler.active_tokens().await;
        assert!(
            active_tokens == 0,
            "Scheduler still has {active_tokens} active tokens but expected 0"
        );

        let waiting_tokens = scheduler.waiting_tokens().await;
        assert!(
            waiting_tokens == 0,
            "Scheduler still has {waiting_tokens} waiting tokens but expected 0"
        );
    }

    #[tokio::test]
    async fn test_cache_hit_rate_with_identical_requests() {
        let block_size: usize = 64;
        let max_output_tokens: usize = 10;
        let speedup_ratio = 10.0;
        let num_requests = 10;
        let token_length = 65;
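
        // With identical 65-token prompts and block_size = 64, every request after the
        // first should reuse the one full cached block, so its recorded hit rate is
        // roughly 1.0 - 1/65 ≈ 0.98 and the rolling average lands well above the 0.8
        // asserted below.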

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(100) // Large enough to not be a constraint
            .block_size(block_size)
            .speedup_ratio(speedup_ratio)
            .build()
            .unwrap();

        // Create scheduler
        let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);

        // Create identical tokens for all requests
        let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();

        // Send all requests with identical tokens
        for _ in 0..num_requests {
            let request = DirectRequest {
                tokens: identical_tokens.clone(),
                max_output_tokens,
                uuid: None,
                dp_rank: None,
            };
            scheduler.receive(request).await;
            // Sleep for 0.1 second after each request
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        // Collect all generated tokens
        let mut received_tokens = 0;

        // Set up a timeout that resets to 0.5 seconds on each received token
        let timeout = tokio::time::sleep(Duration::from_millis(500));
        tokio::pin!(timeout);

        // Get metrics receiver
        let metrics_rx = scheduler.metrics_receiver();

        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that prints forward pass metrics
                _ = debug_interval.tick() => {
                    let _metrics = metrics_rx.borrow().clone();
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
                }

                Some(_signal) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_millis(500)));
                }

                _ = &mut timeout => {
                    // Break when timeout occurs (no more tokens for 0.5 seconds)
                    break;
                }
            }
        }

        // Wait a bit for final metrics update
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Verify forward pass metrics
        let metrics = metrics_rx.borrow().clone();

        assert_eq!(
            metrics.worker_stats.num_requests_waiting, 0,
            "Expected no waiting requests, got {}",
            metrics.worker_stats.num_requests_waiting
        );

        assert!(
            metrics.kv_stats.gpu_prefix_cache_hit_rate > 0.8,
            "Expected cache hit rate > 0.8, got {}",
            metrics.kv_stats.gpu_prefix_cache_hit_rate
        );

        println!(
            "Test passed! Cache hit rate: {:.3}",
            metrics.kv_stats.gpu_prefix_cache_hit_rate
        );
        println!("Received {received_tokens} tokens");
    }

    #[tokio::test]
    async fn test_receiver_drop_cleans_up_resources() {
        let block_size: usize = 64;
        let input_tokens = 256;
        let max_output_tokens = 200; // More than we'll receive

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(10) // Enough for 256 tokens (4 blocks)
            .block_size(block_size)
            .speedup_ratio(100.0) // Fast simulation
            .build()
            .unwrap();

        // Create scheduler
        let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);

        // Create request with 256 tokens
        let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
        let request = DirectRequest {
            tokens,
            max_output_tokens,
            uuid: None,
            dp_rank: None,
        };

        scheduler.receive(request).await;

        // Receive exactly 129 tokens
        let mut received_count = 0;
        while received_count < 129 {
            if let Some(_signal) = output_rx.recv().await {
                received_count += 1;
            } else {
                panic!("Channel closed before receiving 129 tokens");
            }
        }

        // Drop the receiver immediately
        drop(output_rx);

        // Wait for 1 second to allow cleanup
        tokio::time::sleep(Duration::from_secs(1)).await;

        // Check forward pass metrics
        let metrics_rx = scheduler.metrics_receiver();
        let metrics = metrics_rx.borrow().clone();

        assert_eq!(
            metrics.kv_stats.gpu_cache_usage_perc,
            0.0,
            "Expected GPU cache usage to be 0%, got {}%",
            metrics.kv_stats.gpu_cache_usage_perc * 100.0
        );

        assert_eq!(
            metrics.kv_stats.kv_active_blocks, 0,
            "Expected 0 active blocks, got {}",
            metrics.kv_stats.kv_active_blocks
        );
    }
}