1use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats};
32use crate::mocker::evictor::LRUEvictor;
33use crate::mocker::kv_manager::KvManager;
34use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
35use crate::mocker::protocols::{MoveBlock, OutputSignal, PrefillCost, block_response_to_kv_event};
36use crate::mocker::sequence::ActiveSequence;
37use crate::tokens::BlockHash;
38use crate::tokens::blocks::UniqueBlock;
39use std::collections::HashMap;
40use std::collections::VecDeque;
41use std::sync::Arc;
42use tokio::sync::{Mutex, mpsc};
43use tokio::time::Duration;
44use tokio_util::sync::CancellationToken;
45use uuid::Uuid;
46
/// A request tracked by the scheduler, in one of two lifecycle stages.
pub enum Request {
    /// Raw request as received; not yet converted into a token sequence.
    Direct(DirectRequest),
    /// Request that has been admitted and turned into an `ActiveSequence`.
    Active(ActiveSequence),
}
52
/// Internal bookkeeping for the scheduler's request queues and token budgets.
#[derive(Default)]
struct SchedulerState {
    // Requests admitted but not yet scheduled, in FIFO order.
    waiting: VecDeque<Uuid>,
    // Requests scheduled and awaiting (possibly chunked) prefill.
    prefill: VecDeque<Uuid>,
    // Requests currently decoding; LRU order picks preemption victims.
    decode: LRUEvictor<Uuid>,
    // Backing storage for all live requests, keyed by UUID.
    requests: HashMap<Uuid, Request>,
    // Cached prefill cost for each request in the prefill queue.
    prefill_costs: HashMap<Uuid, PrefillCost>,
    // Optional cap on tokens processed per forward pass (None = unlimited).
    max_num_batched_tokens: Option<usize>,
    // Tokens accounted to the current batch (prefill chunks plus, after
    // `reset_active_tokens`, one token per decoding sequence).
    active_tokens: usize,
    // Tokens owed to requests sitting in the prefill queue, not yet consumed.
    waiting_tokens: usize,
}
64
impl SchedulerState {
    /// Create an empty state with an optional per-batch token budget.
    fn new(max_num_batched_tokens: Option<usize>) -> Self {
        SchedulerState {
            max_num_batched_tokens,
            ..Default::default()
        }
    }

    /// True when no request is tracked at all (waiting, prefill, or decode).
    fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

    /// Enqueue a new direct request, assigning a fresh UUID if none was given,
    /// and return the UUID it is tracked under.
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
        self.requests.insert(uuid, Request::Direct(request));
        self.waiting.push_back(uuid);
        uuid
    }

    /// Pop the next waiting request (FIFO), removing it from the request map.
    ///
    /// # Panics
    /// Panics if a queued UUID has no backing entry in `requests`
    /// (an internal invariant violation).
    fn next(&mut self) -> Option<(Uuid, Request)> {
        let uuid = self.waiting.pop_front()?;
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
        Some((uuid, request))
    }

    /// Put a request back at the *head* of the waiting queue — used when a
    /// request could not be admitted this pass, or after a preemption.
    fn first_in_line(&mut self, uuid: Uuid, request: Request) {
        self.requests.insert(uuid, request);
        self.waiting.push_front(uuid);
    }

    /// Admit a request into the prefill queue: its not-yet-computed tokens are
    /// accounted as waiting and its prefill cost is cached for `try_prefill`.
    fn move_to_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: PrefillCost) {
        self.waiting_tokens += cost.new_tokens;
        self.requests.insert(uuid, Request::Active(active_seq));
        self.prefill.push_back(uuid);
        self.prefill_costs.insert(uuid, cost);
    }

    /// Attempt to (partially) prefill the next queued request.
    ///
    /// Returns `(prefill_compute, creation_signal, block_hashes, is_full_prefill)`:
    /// - `prefill_compute`: predicted compute cost (ms) of the chunk processed now;
    /// - `creation_signal`: block-allocation signal, present only the first time
    ///   the sequence is scheduled (taken from the sequence);
    /// - `block_hashes`: the sequence's block hashes, for KV-event relaying;
    /// - `is_full_prefill`: `false` when the token budget forced a partial
    ///   (chunked) prefill, in which case the request stays at the front of
    ///   the prefill queue with a reduced remaining cost.
    ///
    /// Returns `None` when the prefill queue is empty.
    fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, Vec<BlockHash>, bool)> {
        let uuid = self.prefill.pop_front()?;

        let mut prefill_cost = self
            .prefill_costs
            .remove(&uuid)
            .expect("Expects valid prefill cost.");

        let new_tokens = prefill_cost.new_tokens;

        // If a batched-token budget exists and this request would overflow it,
        // compute how many tokens can still fit in the current batch.
        let maybe_prefill_tokens = self.max_num_batched_tokens.and_then(|max_tokens| {
            let remaining_tokens = max_tokens - self.active_tokens;
            if prefill_cost.new_tokens > remaining_tokens {
                Some(remaining_tokens)
            } else {
                None
            }
        });

        let (prefill_compute, is_full_prefill) = if let Some(prefill_tokens) = maybe_prefill_tokens
        {
            // Chunked prefill: only `prefill_tokens` are processed this pass.
            let prefill_compute = prefill_cost.predict_prefill_compute(Some(prefill_tokens));
            prefill_cost.new_tokens -= prefill_tokens;
            // `new_tokens` is unsigned, so this really guards against the chunk
            // consuming the whole request (which should have been a full prefill).
            assert!(
                (prefill_cost.new_tokens > 0) && (prefill_compute > 0.0),
                "Encountered negative prefill tokens or prefill compute cost."
            );

            // Requeue at the front so the remainder is prefilled next pass.
            self.prefill.push_front(uuid);
            self.prefill_costs.insert(uuid, prefill_cost);

            // The batch is now saturated up to the budget.
            self.active_tokens = self.max_num_batched_tokens.unwrap();
            self.waiting_tokens -= prefill_tokens;

            (prefill_compute, false)
        } else {
            // Full prefill: the request graduates to the decode set.
            self.decode.insert(uuid);

            self.active_tokens += new_tokens;
            self.waiting_tokens -= new_tokens;

            (prefill_cost.predict_prefill_compute(None), true)
        };

        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };

        Some((
            prefill_compute,
            sequence.take_creation_signal(),
            sequence.block_hashes(),
            is_full_prefill,
        ))
    }

    /// After a forward pass, each decoding sequence contributes exactly one
    /// token to the next batch, so the active count resets to the decode size.
    fn reset_active_tokens(&mut self) {
        self.active_tokens = self.decode.len();
    }

    /// Borrow the active sequence for a decoding request, or `None` if the
    /// request is no longer in the decode set (e.g. it was preempted).
    ///
    /// # Panics
    /// Panics if the UUID is in the decode set but has no `Active` entry.
    fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
        if !self.decode.contains(&uuid) {
            return None;
        }
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        Some(sequence)
    }

    /// Number of requests currently being served (prefilling or decoding).
    fn num_active_requests(&self) -> usize {
        self.prefill.len() + self.decode.len()
    }

    /// Remove a finished request from all tracking structures. The single
    /// decode token it contributed to the batch is released.
    fn complete(&mut self, uuid: &Uuid) {
        tracing::trace!("Request {uuid} will complete");
        self.decode.remove(uuid);
        self.requests.remove(uuid);
        self.prefill_costs.remove(uuid);
        self.active_tokens -= 1;
    }

    /// Preempt the least-recently-used decoding request: reset its sequence,
    /// requeue it at the head of the waiting queue, and return the block
    /// signals that free its KV blocks.
    ///
    /// # Panics
    /// Panics if there is nothing to evict, or the evicted entry is not an
    /// `Active` request.
    fn preempt(&mut self) -> Vec<MoveBlock> {
        let uuid = self
            .decode
            .evict()
            .expect("Nothing to evict for preemption.");
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
        self.prefill_costs.remove(&uuid);
        self.active_tokens -= 1;
        tracing::warn!("Request {uuid} will be preempted");

        let Request::Active(mut active_sequence) = request else {
            panic!("Expected ActiveSequence in running queue")
        };
        // Reset generation progress and collect the signals that free its blocks.
        let signals = active_sequence.reset_with_signal();

        // Preempted requests go to the front so they are rescheduled first.
        self.first_in_line(uuid, Request::Active(active_sequence));

        signals
    }
}
237
/// Cloneable handle to a mock-engine scheduler whose event loop runs in a
/// background tokio task spawned by [`Scheduler::new`].
#[derive(Clone)]
pub struct Scheduler {
    // Shared queue/budget state, also read by the accessor methods below.
    state: Arc<Mutex<SchedulerState>>,
    // Shared KV-block manager used to track cache occupancy.
    kv_manager: Arc<Mutex<KvManager>>,
    // Channel for submitting new requests to the background loop.
    request_tx: mpsc::UnboundedSender<DirectRequest>,
    // Watch channel carrying the latest forward-pass metrics snapshot.
    metrics_rx: tokio::sync::watch::Receiver<ForwardPassMetrics>,
}
246
impl Scheduler {
    /// Build a scheduler and spawn its background simulation loop.
    ///
    /// Each loop iteration: (1) blocks for a request when idle, (2) either
    /// drains newly arrived requests or runs one scheduling pass (admission →
    /// prefill → decode), then (3) sleeps for the simulated forward-pass time
    /// scaled down by `args.speedup_ratio`.
    ///
    /// * `output_tx` — receives one `OutputSignal` per generated token.
    /// * `kv_events_tx` — when present, KV block events are relayed to it.
    /// * `cancellation_token` — stops the loop; a fresh token is used if `None`.
    ///
    /// # Panics
    /// Panics (at construction) if `args.speedup_ratio <= 0`.
    pub fn new(
        args: MockEngineArgs,
        dp_rank: Option<u32>,
        output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
        kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
        cancellation_token: Option<CancellationToken>,
    ) -> Self {
        let state = Arc::new(Mutex::new(SchedulerState::new(args.max_num_batched_tokens)));

        // Only create the block-response channel when KV events will be relayed.
        let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
            let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        let kv_manager = Arc::new(Mutex::new(KvManager::new_with_sender(
            args.num_gpu_blocks,
            args.block_size,
            block_resp_tx,
        )));
        // Rolling window (capped at 1000 entries) of per-request prefix-cache
        // hit rates, averaged into the published metrics.
        let hit_rates = Arc::new(Mutex::new(VecDeque::with_capacity(1000)));

        assert!(
            args.speedup_ratio > 0.0,
            "speedup_ratio must be greater than 0, got: {}",
            args.speedup_ratio
        );

        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
        let mut initial_metrics = ForwardPassMetrics::default();
        initial_metrics.worker_stats.data_parallel_rank = dp_rank;
        let (metrics_tx, metrics_rx) =
            tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics);

        let state_clone = state.clone();
        let kv_manager_clone = kv_manager.clone();
        let output_tx_clone = output_tx.clone();
        let cancel_token_clone = cancellation_token.unwrap_or_default().clone();

        tokio::spawn(async move {
            // Skipped the scheduling pass while a previous admission is still
            // in flight; re-enabled once decoding sequences exist again.
            let mut should_schedule = true;

            loop {
                {
                    let state_guard = state_clone.lock().await;

                    // Idle: park on the request channel instead of spinning.
                    // The guard is dropped first so `receive` can re-lock.
                    if state_guard.is_empty() {
                        drop(state_guard);
                        let Some(request) = request_rx.recv().await else {
                            tracing::warn!("request sender is dropped");
                            break;
                        };
                        let mut state_guard = state_clone.lock().await;
                        state_guard.receive(request);
                    }
                }

                // `biased` gives priority to draining newly arrived requests
                // before spending a pass on scheduling.
                tokio::select! {
                    biased;

                    Some(request) = request_rx.recv() => {
                        let mut state = state_clone.lock().await;
                        state.receive(request);
                    }

                    // yield_now always completes, so this branch runs whenever
                    // there is no pending request and no cancellation.
                    _ = tokio::task::yield_now() => {
                        if !should_schedule {
                            continue;
                        }

                        let mut state_guard = state_clone.lock().await;
                        let kv_manager_guard = kv_manager_clone.lock().await;

                        // Running totals used to test each candidate against
                        // the block / token / sequence budgets.
                        let mut current_blocks = kv_manager_guard.num_active_blocks();
                        let mut current_tokens = state_guard.active_tokens + state_guard.waiting_tokens;
                        let mut current_seqs = state_guard.num_active_requests();

                        // Admit waiting requests until one would bust a budget.
                        while let Some((uuid, request)) = state_guard.next() {
                            let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching);

                            let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence);
                            let total_tokens = active_sequence.len();
                            let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
                            let new_tokens = prefill_cost.new_tokens;

                            current_blocks += new_blocks;
                            current_tokens += new_tokens;
                            current_seqs += 1;

                            // Block budget keeps a `watermark` fraction of capacity free.
                            let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64;
                            // With chunked prefill the new tokens can be split
                            // across passes, so they are excluded from the check.
                            let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens};
                            let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit);
                            let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);

                            if !(under_block_budget && under_token_budget && under_seq_budget) {
                                // Doesn't fit: put it back at the head and stop admitting.
                                state_guard.first_in_line(uuid, Request::Active(active_sequence));
                                break;
                            }

                            // Fraction of the prompt served from the prefix cache.
                            let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
                            {
                                let mut hit_rates_guard = hit_rates.lock().await;
                                hit_rates_guard.push_back(hit_rate);
                                if hit_rates_guard.len() > 1000 {
                                    hit_rates_guard.pop_front();
                                }
                            }

                            state_guard.move_to_prefill(uuid, active_sequence, prefill_cost);
                            should_schedule = false;
                        }
                    }

                    _ = cancel_token_clone.cancelled() => {
                        break;
                    }
                }

                let mut state_guard = state_clone.lock().await;
                let mut kv_manager_guard = kv_manager_clone.lock().await;

                // Simulated decode latency (ms) as a quadratic in the fraction
                // of active KV blocks; the constants are an empirical fit.
                let active_perc = kv_manager_guard.get_active_perc();
                let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
                let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);

                // Prefill phase: run queued prefills, accumulating their
                // simulated compute time, until a chunked (partial) prefill
                // exhausts the token budget or the queue empties.
                while let Some((
                    prefill_compute,
                    maybe_creation_signal,
                    block_hashes,
                    is_full_prefill,
                )) = state_guard.try_prefill()
                {
                    total_time += Duration::from_secs_f64(prefill_compute / 1000.0);

                    // First scheduling of a sequence allocates its blocks.
                    if let Some(creation_signal) = maybe_creation_signal {
                        if !process_signals(
                            &mut kv_manager_guard,
                            std::slice::from_ref(&creation_signal),
                        ) {
                            panic!("Block allocation for prefilling cannot fail.");
                        }

                        // Relay any resulting block events to subscribers.
                        if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
                            while let Ok(event) = rx.try_recv() {
                                let _ =
                                    relay_tx.send(block_response_to_kv_event(event, &block_hashes));
                            }
                        }
                    };

                    // A partial prefill means the batch token budget is spent.
                    if !is_full_prefill {
                        break;
                    }
                }

                state_guard.reset_active_tokens();

                // Publish a metrics snapshot after the prefill phase.
                {
                    let hit_rates_guard = hit_rates.lock().await;
                    let metrics = get_fwd_pass_metrics(
                        &state_guard,
                        &kv_manager_guard,
                        &hit_rates_guard,
                        dp_rank,
                    );
                    let _ = metrics_tx.send(metrics);
                }

                // Decode phase: generate one token per decoding sequence.
                let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
                if !uuids.is_empty() {
                    should_schedule = true
                };
                for uuid in uuids {
                    // Skip sequences preempted earlier in this same pass.
                    let Some(sequence) = state_guard.run(uuid) else {
                        continue;
                    };
                    let signals = sequence.generate();

                    // Allocation failure on generation: roll back this token
                    // and preempt the LRU sequence to free blocks.
                    if !process_signals(&mut kv_manager_guard, &signals) {
                        sequence.pop();
                        for signal in state_guard.preempt() {
                            kv_manager_guard.process(&signal);
                        }
                        continue;
                    }

                    if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
                        while let Ok(event) = rx.try_recv() {
                            let _ = relay_tx
                                .send(block_response_to_kv_event(event, &sequence.block_hashes()));
                        }
                    }

                    let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
                    // Only signal tokens that are newly generated (a resumed
                    // preempted sequence replays already-generated tokens).
                    let should_output =
                        sequence.generated_tokens() > sequence.already_generated_tokens();

                    let mut send_failed = false;
                    if should_output {
                        send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
                            tx.send(OutputSignal {
                                uuid,
                                completed: is_complete,
                            })
                            .is_err()
                        });
                    }

                    // Receiver dropped: free the sequence's blocks immediately.
                    if send_failed {
                        for signal in &sequence.free_signal() {
                            kv_manager_guard.process(signal);
                        }
                    }

                    // Publish a metrics snapshot after each sequence update.
                    {
                        let hit_rates_guard = hit_rates.lock().await;
                        let metrics = get_fwd_pass_metrics(
                            &state_guard,
                            &kv_manager_guard,
                            &hit_rates_guard,
                            dp_rank,
                        );
                        let _ = metrics_tx.send(metrics);
                    }

                    if send_failed || is_complete {
                        state_guard.complete(&uuid);
                        continue;
                    }
                }

                // Release the locks before sleeping out the simulated pass time.
                drop(kv_manager_guard);
                drop(state_guard);
                let adjusted_time =
                    Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
                if adjusted_time.as_millis() > 0 {
                    tokio::time::sleep(adjusted_time).await;
                }
            }
        });

        Self {
            state,
            kv_manager,
            request_tx,
            metrics_rx,
        }
    }

    /// Submit a request to the background loop. Errors from a closed channel
    /// are ignored (the loop has already shut down).
    // NOTE(review): no await point here — `async` is kept for API compatibility.
    pub async fn receive(&self, request: DirectRequest) {
        let _ = self.request_tx.send(request);
    }

    /// Clone of the request submission channel, for callers that want to
    /// submit without holding the `Scheduler` handle.
    pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        self.request_tx.clone()
    }

    /// Number of requests admitted but not yet scheduled.
    pub async fn waiting_count(&self) -> usize {
        let state = self.state.lock().await;
        state.waiting.len()
    }

    /// Number of requests currently decoding.
    pub async fn running_count(&self) -> usize {
        let state = self.state.lock().await;
        state.decode.len()
    }

    /// Tokens owed to requests in the prefill queue.
    pub async fn waiting_tokens(&self) -> usize {
        let state = self.state.lock().await;
        state.waiting_tokens
    }

    /// Tokens accounted to the current batch.
    pub async fn active_tokens(&self) -> usize {
        let state = self.state.lock().await;
        state.active_tokens
    }

    /// Fraction of KV capacity currently in use.
    pub async fn kv_usage_perc(&self) -> f64 {
        let kv_manager = self.kv_manager.lock().await;
        kv_manager.current_capacity_perc()
    }

    /// Watch receiver for forward-pass metrics snapshots.
    pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> {
        self.metrics_rx.clone()
    }
}
570
571fn get_fwd_pass_metrics(
573 state: &SchedulerState,
574 kv_manager: &KvManager,
575 hit_rates: &VecDeque<f32>,
576 dp_rank: Option<u32>,
577) -> ForwardPassMetrics {
578 let request_active_slots = state.decode.len() as u64;
580 let num_requests_waiting = state.waiting.len() as u64;
581
582 let active_blocks_count = kv_manager.active_blocks().len() as u64;
584 let total_capacity = kv_manager.max_capacity() as u64;
585 let gpu_cache_usage_perc = if total_capacity > 0 {
586 active_blocks_count as f32 / total_capacity as f32
587 } else {
588 0.0
589 };
590
591 let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() {
593 0.0
594 } else {
595 let sum: f32 = hit_rates.iter().sum();
596 sum / hit_rates.len() as f32
597 };
598
599 let worker_stats = WorkerStats {
600 data_parallel_rank: dp_rank,
601 request_active_slots,
602 request_total_slots: 1024, num_requests_waiting,
604 };
605
606 let kv_stats = KvStats {
607 kv_active_blocks: active_blocks_count,
608 kv_total_blocks: total_capacity,
609 gpu_cache_usage_perc,
610 gpu_prefix_cache_hit_rate,
611 };
612
613 let spec_decode_stats = None;
614
615 ForwardPassMetrics {
616 worker_stats,
617 kv_stats,
618 spec_decode_stats,
619 }
620}
621
622fn get_active_sequence(
624 request: Request,
625 block_size: usize,
626 enable_prefix_caching: bool,
627) -> ActiveSequence {
628 if let Request::Active(active_seq) = request {
629 return active_seq;
630 }
631
632 let Request::Direct(direct_request) = request else {
633 unreachable!("Request must be either Direct or Active");
634 };
635
636 ActiveSequence::new(
637 direct_request.tokens,
638 direct_request.max_output_tokens,
639 Some(block_size),
640 enable_prefix_caching,
641 )
642}
643
644fn process_signals(
652 kv_manager_guard: &mut tokio::sync::MutexGuard<'_, KvManager>,
653 signals: &[MoveBlock],
654) -> bool {
655 for signal in signals {
656 if kv_manager_guard.process(signal) {
657 continue;
658 }
659
660 let MoveBlock::Use(blocks) = signal else {
662 panic!(
663 "Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
664 );
665 };
666
667 let num_blocks = blocks.len();
669 let num_active_blocks = kv_manager_guard.num_active_blocks();
670 if num_blocks != 1 {
671 panic!(
672 "Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
673 );
674 }
675
676 if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
678 panic!("Failed signal is Invalid. Generation block has to be partial.");
679 }
680
681 return false;
682 }
683
684 true
685}
686
#[cfg(test)]
mod tests {
    //! Integration-style tests driving the scheduler end-to-end through its
    //! public API, using the speedup ratio to keep wall-clock time short.
    use super::*;
    use rstest::rstest;
    use std::time::Duration;
    use tokio::time::interval;

    // Exercises all combinations of shared-prefix tokens, prefix caching,
    // and chunked prefill, asserting exact token counts and drained budgets.
    #[rstest]
    #[case::case_1(false, false, false)]
    #[case::case_2(false, true, false)]
    #[case::case_3(true, false, false)]
    #[case::case_4(true, true, false)]
    #[case::case_5(false, false, true)]
    #[case::case_6(false, true, true)]
    #[case::case_7(true, false, true)]
    #[case::case_8(true, true, true)]
    #[tokio::test]
    async fn test_scheduler_token_generation_patterns(
        #[case] use_shared_tokens: bool,
        #[case] enable_prefix_caching: bool,
        #[case] enable_chunked_prefill: bool,
    ) {
        // set_var is unsafe since Rust 2024 (process-global mutation).
        unsafe { std::env::set_var("RUST_LOG", "debug") };

        let kv_capacity: usize = 500;
        let block_size: usize = 64;
        let num_requests: usize = 200;
        let input_len: usize = 1000;
        let max_output_tokens: usize = 100;

        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        let args = MockEngineArgs::builder()
            .num_gpu_blocks(kv_capacity)
            .block_size(block_size)
            .speedup_ratio(10.0)
            .enable_prefix_caching(enable_prefix_caching)
            .enable_chunked_prefill(enable_chunked_prefill)
            .build()
            .unwrap();

        let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);

        // Optional common prefix (half the prompt) shared across requests to
        // exercise the prefix cache.
        let shared_tokens = if use_shared_tokens {
            Some(
                (0..input_len / 2)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };

        for _ in 0..num_requests {
            let input_tokens = if let Some(ref shared) = shared_tokens {
                // Shared prefix + random suffix.
                let mut tokens = shared.clone();
                tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
                tokens
            } else {
                (0..input_len)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>()
            };

            let request = DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
                dp_rank: None,
            };
            scheduler.receive(request).await;
        }

        let start_time = std::time::Instant::now();

        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;

        // Idle timeout: the loop ends 2s after the last token arrives.
        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);

        let metrics_rx = scheduler.metrics_receiver();

        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                _ = debug_interval.tick() => {
                    let _metrics = metrics_rx.borrow().clone();
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
                }

                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    // Each token resets the idle timeout.
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }

                _ = &mut timeout => {
                    break;
                }
            }
        }

        let elapsed = start_time.elapsed();
        println!(
            "Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
            if use_shared_tokens {
                "caching"
            } else {
                "random"
            }
        );

        assert!(
            received_tokens == expected_tokens,
            "Received {received_tokens} tokens but expected exactly {expected_tokens}"
        );

        // Both token budgets must be fully drained once all requests finish.
        let active_tokens = scheduler.active_tokens().await;
        assert!(
            active_tokens == 0,
            "Scheduler still have {active_tokens} active tokens but expected 0"
        );

        let waiting_tokens = scheduler.waiting_tokens().await;
        assert!(
            waiting_tokens == 0,
            "Scheduler still have {waiting_tokens} waiting tokens but expected 0"
        );
    }

    // Identical prompts submitted back-to-back should produce a high prefix
    // cache hit rate in the published metrics.
    #[tokio::test]
    async fn test_cache_hit_rate_with_identical_requests() {
        let block_size: usize = 64;
        let max_output_tokens: usize = 10;
        let speedup_ratio = 10.0;
        let num_requests = 10;
        // One token past a block boundary, so a partial block is also used.
        let token_length = 65;

        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        let args = MockEngineArgs::builder()
            .num_gpu_blocks(100)
            .block_size(block_size)
            .speedup_ratio(speedup_ratio)
            .build()
            .unwrap();

        let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);

        let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();

        for _ in 0..num_requests {
            let request = DirectRequest {
                tokens: identical_tokens.clone(),
                max_output_tokens,
                uuid: None,
                dp_rank: None,
            };
            scheduler.receive(request).await;
            // Space out submissions so earlier requests populate the cache.
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        let mut received_tokens = 0;

        let timeout = tokio::time::sleep(Duration::from_millis(500));
        tokio::pin!(timeout);

        let metrics_rx = scheduler.metrics_receiver();

        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                _ = debug_interval.tick() => {
                    let _metrics = metrics_rx.borrow().clone();
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
                }

                Some(_signal) = output_rx.recv() => {
                    received_tokens += 1;
                    timeout.set(tokio::time::sleep(Duration::from_millis(500)));
                }

                _ = &mut timeout => {
                    break;
                }
            }
        }

        // Give the background loop time to publish its final metrics.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let metrics = metrics_rx.borrow().clone();

        assert_eq!(
            metrics.worker_stats.num_requests_waiting, 0,
            "Expected no waiting requests, got {}",
            metrics.worker_stats.num_requests_waiting
        );

        assert!(
            metrics.kv_stats.gpu_prefix_cache_hit_rate > 0.8,
            "Expected cache hit rate > 0.8, got {}",
            metrics.kv_stats.gpu_prefix_cache_hit_rate
        );

        println!(
            "Test passed! Cache hit rate: {:.3}",
            metrics.kv_stats.gpu_prefix_cache_hit_rate
        );
        println!("Received {received_tokens} tokens");
    }

    // Dropping the output receiver mid-generation must free all KV blocks.
    #[tokio::test]
    async fn test_receiver_drop_cleans_up_resources() {
        let block_size: usize = 64;
        let input_tokens = 256;
        let max_output_tokens = 200;
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        let args = MockEngineArgs::builder()
            .num_gpu_blocks(10)
            .block_size(block_size)
            .speedup_ratio(100.0)
            .build()
            .unwrap();

        let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);

        let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
        let request = DirectRequest {
            tokens,
            max_output_tokens,
            uuid: None,
            dp_rank: None,
        };

        scheduler.receive(request).await;

        // Consume part of the output, then drop the receiver mid-stream.
        let mut received_count = 0;
        while received_count < 129 {
            if let Some(_signal) = output_rx.recv().await {
                received_count += 1;
            } else {
                panic!("Channel closed before receiving 129 tokens");
            }
        }

        drop(output_rx);

        // Allow the scheduler to notice the failed send and free blocks.
        tokio::time::sleep(Duration::from_secs(1)).await;

        let metrics_rx = scheduler.metrics_receiver();
        let metrics = metrics_rx.borrow().clone();

        assert_eq!(
            metrics.kv_stats.gpu_cache_usage_perc,
            0.0,
            "Expected GPU cache usage to be 0%, got {}%",
            metrics.kv_stats.gpu_cache_usage_perc * 100.0
        );

        assert_eq!(
            metrics.kv_stats.kv_active_blocks, 0,
            "Expected 0 active blocks, got {}",
            metrics.kv_stats.kv_active_blocks
        );
    }
}