1use crate::compute::ComputeEngine;
2use crate::config::Config;
3use crate::dataset::{BatchTokenizerFn, DatasetLoader};
4use crate::kv_cache::KVCacheManager;
5use crate::metrics::MetricsCollector;
6use crate::request::RequestGenerator;
7use crate::scheduler::Scheduler;
8
/// One sampled snapshot of simulator state, pushed onto the time series
/// whenever simulated time crosses the next sample boundary.
#[derive(Debug, Clone)]
pub struct TimeSeriesPoint {
    /// Simulated time at which the sample was taken.
    pub time: f64,
    /// Cumulative count of requests that have arrived so far.
    pub arrivals: u64,
    /// Requests currently in the scheduler's running set.
    pub running: usize,
    /// Requests currently waiting to be scheduled.
    pub waiting: usize,
    /// KV-cache utilization as reported by the cache manager.
    pub kv_cache_util: f64,
    /// Running requests still in their prefill phase.
    pub num_prefilling: usize,
    /// Running requests in their decode phase.
    pub num_decoding: usize,
    /// Prefill tokens scheduled in the sampled iteration.
    pub prefill_tokens: u32,
    /// Decode tokens scheduled in the sampled iteration.
    pub decode_tokens: u32,
    /// Prefill tokens per second over the sample interval.
    pub input_throughput: f64,
    /// Decode tokens per second over the sample interval.
    pub output_throughput: f64,
    /// Interval TTFT statistic (the simulator stores the interval mean here).
    pub ttft_p50: f64,
    /// Interval TPOT statistic (the simulator stores the interval mean here).
    pub tpot_p50: f64,
}
25
/// Snapshot handed to the progress callback during a simulation run.
///
/// Slice-typed fields borrow from the running `Simulator`. The latency and
/// distribution sample slices contain only the samples accumulated since the
/// previous callback (delta delivery), while `time_series` is the full series.
pub struct ProgressInfo<'a> {
    // Current simulated time.
    pub current_time: f64,
    pub completed_requests: u64,
    pub total_requests: u64,
    // Scheduler queue sizes at snapshot time.
    pub running: usize,
    pub waiting: usize,
    pub kv_cache_util: f64,
    // Full time series collected so far.
    pub time_series: Option<&'a [TimeSeriesPoint]>,
    pub metrics: Option<crate::metrics::MetricsSummary>,
    // ((ttft values, ttft timestamps), (e2e values, e2e timestamps),
    //  (tpot values, tpot timestamps)) — deltas since the last callback.
    pub latency_samples: Option<(
        (&'a [f64], &'a [f64]),
        (&'a [f64], &'a [f64]),
        (&'a [f64], &'a [f64]),
    )>,
    // (input lengths, output lengths) deltas since the last callback.
    pub distribution_samples: Option<(&'a [u32], &'a [u32])>,
}
42
/// Discrete-event simulator that drives requests from a generator through
/// the scheduler and compute-engine cost models while collecting metrics
/// and periodic time-series samples.
pub struct Simulator {
    scheduler: Scheduler,
    compute_engine: ComputeEngine,
    request_generator: RequestGenerator,
    metrics: MetricsCollector,
    // Time-series sampling state.
    time_series_data: Vec<TimeSeriesPoint>,
    // Simulated time between samples (set to 0.1 in the constructor).
    sample_interval: f64,
    next_sample_time: f64,
    // Time of the previous sample; written on every sample but never read
    // in this file (see NOTE in run_with_callback).
    prev_sample_time: f64,

    // Current simulated time.
    current_time: f64,
    // Loop-iteration counter; incremented each scheduling step, never read here.
    iteration: u64,

    // Cursors into the metrics sample vectors marking what has already been
    // delivered to the progress callback, so each callback gets only deltas.
    last_sent_ttft_count: usize,
    last_sent_e2e_count: usize,
    last_sent_tpot_count: usize,
    last_sent_input_count: usize,
    last_sent_output_count: usize,
}
63
64impl Simulator {
65 pub fn new(config: Config) -> Result<Self, String> {
66 Self::new_with_tokenizer(config, None).map(|(sim, _)| sim)
67 }
68
69 pub fn new_with_tokenizer(
72 mut config: Config,
73 tokenizer: Option<BatchTokenizerFn>,
74 ) -> Result<(Self, Config), String> {
75 let kv_cache_manager = KVCacheManager::new(
76 config.hardware.kv_cache_capacity,
77 config.scheduler.block_size,
78 config.model.kv_cache_bytes_per_token,
79 true, );
81
82 let scheduler = Scheduler::new(
83 config.scheduler.clone(),
84 config.hardware.clone(),
85 config.model.clone(),
86 kv_cache_manager,
87 )?;
88
89 let compute_engine = ComputeEngine::new(config.hardware.clone(), config.model.clone());
90
91 let request_generator = if let Some(dataset_path) = &config.workload.dataset_path {
93 let tokenizer = tokenizer.ok_or_else(|| {
95 format!(
96 "Dataset path '{}' provided but no tokenizer function supplied",
97 dataset_path
98 )
99 })?;
100
101 if config.workload.num_requests.is_none() {
103 let total_entries = DatasetLoader::count_entries(dataset_path)
104 .map_err(|e| format!("Failed to count entries in '{}': {}", dataset_path, e))?;
105 config.workload.num_requests = Some(total_entries);
106 }
107
108 let dataset_iterator = DatasetLoader::from_file(dataset_path)
110 .map_err(|e| format!("Failed to load dataset from '{}': {}", dataset_path, e))?;
111
112 RequestGenerator::from_dataset(
114 config.workload.clone(),
115 dataset_iterator,
116 None,
117 tokenizer,
118 )
119 } else {
120 RequestGenerator::new(config.workload.clone())
122 };
123
124 let metrics = MetricsCollector::new(0.0);
125
126 let simulator = Self {
127 scheduler,
128 compute_engine,
129 request_generator,
130 metrics,
131 time_series_data: Vec::new(),
132 sample_interval: 0.1,
133 next_sample_time: 0.0,
134 prev_sample_time: 0.0,
135 current_time: 0.0,
136 iteration: 0,
137 last_sent_ttft_count: 0,
138 last_sent_e2e_count: 0,
139 last_sent_tpot_count: 0,
140 last_sent_input_count: 0,
141 last_sent_output_count: 0,
142 };
143
144 Ok((simulator, config))
145 }
146
147 pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<(), String>
149 where
150 F: FnMut(ProgressInfo),
151 {
152 let mut last_callback_time = 0.0;
153 let callback_interval = 1.0; loop {
156 self.iteration += 1;
157
158 while let Some(request) = self.request_generator.next_if_before(self.current_time) {
160 self.scheduler.add_request(request);
161 self.metrics.total_requests += 1;
162 }
163
164 let decision = self.scheduler.schedule(self.current_time);
166
167 let (iteration_time, bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
169 let running = self.scheduler.running_mut();
171 let mut batch_requests = Vec::new();
172 let mut tokens_per_req = Vec::new();
173
174 for (i, &idx) in decision.scheduled_new.iter().enumerate() {
175 if let Some(req) = running.get(idx) {
176 batch_requests.push(req);
177 tokens_per_req.push(decision.tokens_for_new[i]);
178 }
179 }
180
181 for (i, &idx) in decision.scheduled_running.iter().enumerate() {
182 if let Some(req) = running.get(idx) {
183 batch_requests.push(req);
184 tokens_per_req.push(decision.tokens_for_running[i]);
185 }
186 }
187
188 let iteration_time = self
189 .compute_engine
190 .calculate_iteration_time(&batch_requests, &tokens_per_req);
191
192 let bytes_transferred = self
193 .compute_engine
194 .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
195 let bandwidth_util = self
196 .compute_engine
197 .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
198
199 let flops_util = self.compute_engine.calculate_flops_utilization(
200 &batch_requests,
201 &tokens_per_req,
202 iteration_time,
203 );
204
205 (iteration_time, bandwidth_util, flops_util)
206 } else {
207 if !self.request_generator.is_finished() {
209 let next_arrival = self.request_generator.peek_next_arrival_time();
210 let time_until_next = (next_arrival - self.current_time).max(0.001);
211 (time_until_next, 0.0, 0.0)
212 } else {
213 (0.001, 0.0, 0.0)
215 }
216 };
217
218 self.current_time += iteration_time;
220
221 let mut prefilling_reqs = std::collections::HashSet::new();
223 for &idx in decision
224 .scheduled_new
225 .iter()
226 .chain(decision.scheduled_running.iter())
227 {
228 if let Some(request) = self.scheduler.running().get(idx) {
229 if request.is_prefill() {
230 prefilling_reqs.insert(idx);
231 }
232 }
233 }
234
235 for (i, &idx) in decision.scheduled_new.iter().enumerate() {
237 if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
238 request.record_generated_tokens(decision.tokens_for_new[i], self.current_time);
239 }
240 }
241 for (i, &idx) in decision.scheduled_running.iter().enumerate() {
242 if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
243 request
244 .record_generated_tokens(decision.tokens_for_running[i], self.current_time);
245 }
246 }
247
248 let kv_util = self.scheduler.kv_cache_manager().utilization();
250
251 self.metrics
252 .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
253
254 if self.current_time >= self.next_sample_time {
256 let running = self.scheduler.running();
258 let mut num_prefilling = 0;
259 let mut num_decoding = 0;
260 let mut prefill_tokens = 0;
261 let mut decode_tokens = 0;
262
263 for req in running {
264 if req.is_prefill() {
265 num_prefilling += 1;
266 } else {
267 num_decoding += 1;
268 }
269 }
270
271 for (i, &idx) in decision.scheduled_new.iter().enumerate() {
273 let tokens = decision.tokens_for_new[i];
274 if prefilling_reqs.contains(&idx) {
275 prefill_tokens += tokens;
276 } else {
277 decode_tokens += tokens;
278 }
279 }
280 for (i, &idx) in decision.scheduled_running.iter().enumerate() {
281 let tokens = decision.tokens_for_running[i];
282 if prefilling_reqs.contains(&idx) {
283 prefill_tokens += tokens;
284 } else {
285 decode_tokens += tokens;
286 }
287 }
288
289 let input_throughput = prefill_tokens as f64 / self.sample_interval;
291 let output_throughput = decode_tokens as f64 / self.sample_interval;
292
293 let (ttft_mean, tpot_mean) = self.metrics.get_interval_latencies();
295
296 self.time_series_data.push(TimeSeriesPoint {
297 time: self.current_time,
298 arrivals: self.metrics.total_requests,
299 running: self.scheduler.num_running(),
300 waiting: self.scheduler.num_waiting(),
301 kv_cache_util: kv_util,
302 num_prefilling,
303 num_decoding,
304 prefill_tokens,
305 decode_tokens,
306 input_throughput,
307 output_throughput,
308 ttft_p50: ttft_mean,
309 tpot_p50: tpot_mean,
310 });
311 self.prev_sample_time = self.current_time;
312 self.next_sample_time = self.current_time + self.sample_interval;
313 }
314
315 for request in decision.completed {
317 self.scheduler
319 .kv_cache_manager_mut()
320 .free_blocks(&request.kv_blocks);
321
322 self.metrics.record_request_completion(&request);
323
324 self.request_generator
326 .on_request_complete(self.current_time);
327 }
328
329 if self.current_time - last_callback_time >= callback_interval {
331 let kv_manager = self.scheduler.kv_cache_manager();
333 let summary = self.metrics.compute_summary(
334 self.current_time,
335 kv_manager.num_prefix_cache_hits,
336 kv_manager.num_prefix_cache_misses,
337 kv_manager.hit_size_sum,
338 kv_manager.hit_size_count,
339 );
340
341 let latency_samples = self.metrics.get_latency_samples();
343 let input_lengths = self.metrics.get_input_lengths();
344 let output_lengths = self.metrics.get_output_lengths();
345
346 let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
348 let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
349 let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
350 let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
351 let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
352 let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
353 let input_delta = &input_lengths[self.last_sent_input_count..];
354 let output_delta = &output_lengths[self.last_sent_output_count..];
355
356 let progress = ProgressInfo {
357 current_time: self.current_time,
358 completed_requests: self.metrics.completed_requests,
359 total_requests: self.metrics.total_requests,
360 running: self.scheduler.num_running(),
361 waiting: self.scheduler.num_waiting(),
362 kv_cache_util: kv_util,
363 time_series: Some(&self.time_series_data),
364 metrics: Some(summary),
365 latency_samples: Some((
366 (ttft_delta, ttft_timestamps_delta),
367 (e2e_delta, e2e_timestamps_delta),
368 (tpot_delta, tpot_timestamps_delta),
369 )),
370 distribution_samples: Some((input_delta, output_delta)),
371 };
372 callback(progress);
373
374 self.last_sent_ttft_count = latency_samples.0 .0.len();
376 self.last_sent_e2e_count = latency_samples.1 .0.len();
377 self.last_sent_tpot_count = latency_samples.2 .0.len();
378 self.last_sent_input_count = input_lengths.len();
379 self.last_sent_output_count = output_lengths.len();
380 last_callback_time = self.current_time;
381 }
382
383 if self.should_terminate() {
385 let kv_manager = self.scheduler.kv_cache_manager();
388 let summary = self.metrics.compute_summary(
389 self.current_time,
390 kv_manager.num_prefix_cache_hits,
391 kv_manager.num_prefix_cache_misses,
392 kv_manager.hit_size_sum,
393 kv_manager.hit_size_count,
394 );
395
396 let latency_samples = self.metrics.get_latency_samples();
398 let input_lengths = self.metrics.get_input_lengths();
399 let output_lengths = self.metrics.get_output_lengths();
400
401 let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
403 let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
404 let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
405 let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
406 let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
407 let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
408 let input_delta = &input_lengths[self.last_sent_input_count..];
409 let output_delta = &output_lengths[self.last_sent_output_count..];
410
411 let progress = ProgressInfo {
412 current_time: self.current_time,
413 completed_requests: self.metrics.completed_requests,
414 total_requests: self.metrics.total_requests,
415 running: self.scheduler.num_running(),
416 waiting: self.scheduler.num_waiting(),
417 kv_cache_util: kv_util,
418 time_series: Some(&self.time_series_data),
419 metrics: Some(summary),
420 latency_samples: Some((
421 (ttft_delta, ttft_timestamps_delta),
422 (e2e_delta, e2e_timestamps_delta),
423 (tpot_delta, tpot_timestamps_delta),
424 )),
425 distribution_samples: Some((input_delta, output_delta)),
426 };
427 callback(progress);
428 break;
429 }
430 }
431
432 Ok(())
433 }
434
435 pub fn get_metrics_summary(&mut self) -> crate::metrics::MetricsSummary {
436 let kv_manager = self.scheduler.kv_cache_manager();
437 self.metrics.compute_summary(
438 self.current_time,
439 kv_manager.num_prefix_cache_hits,
440 kv_manager.num_prefix_cache_misses,
441 kv_manager.hit_size_sum,
442 kv_manager.hit_size_count,
443 )
444 }
445
446 pub fn get_time_series_data(&self) -> &[TimeSeriesPoint] {
447 &self.time_series_data
448 }
449
    /// Input (prompt) token lengths recorded by the metrics collector.
    pub fn get_input_lengths(&self) -> &[u32] {
        self.metrics.get_input_lengths()
    }
453
    /// Output (generated) token lengths recorded by the metrics collector.
    pub fn get_output_lengths(&self) -> &[u32] {
        self.metrics.get_output_lengths()
    }
457
    /// Current simulated time.
    pub fn get_current_time(&self) -> f64 {
        self.current_time
    }
461
462 pub fn get_latency_samples(
463 &self,
464 ) -> (
465 (&[f64], &[f64]), (&[f64], &[f64]), (&[f64], &[f64]), ) {
469 self.metrics.get_latency_samples()
470 }
471
472 fn should_terminate(&self) -> bool {
473 self.request_generator.is_finished()
475 && self.scheduler.num_running() == 0
476 && self.scheduler.num_waiting() == 0
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    // Small, fast workload: 10 requests arriving at 10 req/s.
    fn create_minimal_test_config() -> Config {
        let mut config = Config::test_default();
        config.workload.num_requests = Some(10);
        config.workload.arrival_rate = 10.0;
        config
    }

    #[test]
    fn test_simulation_completes_all_requests() {
        let config = create_minimal_test_config();
        let mut simulator = Simulator::new(config).unwrap();

        simulator.run_with_callback(|_| {}).unwrap();

        let summary = simulator.get_metrics_summary();

        // Every generated request must have run to completion.
        assert_eq!(summary.completed_requests, summary.total_requests);
        assert_eq!(summary.completed_requests, 10);
    }

    #[test]
    fn test_simulation_time_progresses() {
        let config = create_minimal_test_config();
        let mut simulator = Simulator::new(config).unwrap();

        let start_time = simulator.get_current_time();
        simulator.run_with_callback(|_| {}).unwrap();
        let end_time = simulator.get_current_time();

        // Simulated time must advance while work is processed.
        assert!(end_time > start_time);
    }

    #[test]
    fn test_simulation_metrics_reasonable() {
        let config = create_minimal_test_config();
        let mut simulator = Simulator::new(config).unwrap();

        simulator.run_with_callback(|_| {}).unwrap();

        let summary = simulator.get_metrics_summary();

        // Latency means must be positive and finite.
        assert!(summary.ttft_mean > 0.0 && summary.ttft_mean.is_finite());
        assert!(summary.e2e_mean > 0.0 && summary.e2e_mean.is_finite());
        assert!(summary.per_token_mean > 0.0 && summary.per_token_mean.is_finite());

        // Percentiles must be monotonically ordered.
        assert!(summary.ttft_min <= summary.ttft_p50);
        assert!(summary.ttft_p50 <= summary.ttft_p90);
        assert!(summary.ttft_p90 <= summary.ttft_p99);

        assert!(summary.e2e_min <= summary.e2e_p50);
        assert!(summary.e2e_p50 <= summary.e2e_p90);
        assert!(summary.e2e_p90 <= summary.e2e_p99);

        // Utilizations are fractions in [0, 1].
        assert!(summary.avg_kv_cache_util >= 0.0 && summary.avg_kv_cache_util <= 1.0);
        assert!(summary.avg_flops_util >= 0.0 && summary.avg_flops_util <= 1.0);
        assert!(summary.avg_bandwidth_util >= 0.0 && summary.avg_bandwidth_util <= 1.0);

        // Throughputs must be strictly positive after a completed run.
        assert!(summary.input_tokens_per_sec > 0.0);
        assert!(summary.output_tokens_per_sec > 0.0);
        assert!(summary.requests_per_sec > 0.0);
    }

    #[test]
    fn test_simulation_no_infinite_loop() {
        let config = create_minimal_test_config();
        let mut simulator = Simulator::new(config).unwrap();

        // Counts progress callbacks (fired roughly once per simulated
        // second), so the bound below also caps simulated duration.
        let mut iteration_count = 0;
        simulator
            .run_with_callback(|_| {
                iteration_count += 1;
            })
            .unwrap();

        assert!(iteration_count < 1000);
        assert!(iteration_count > 0);
    }

    #[test]
    fn test_simulation_with_fcfs_policy() {
        let mut config = create_minimal_test_config();
        config.scheduler.policy = "fcfs".to_string();

        let mut simulator = Simulator::new(config).unwrap();
        simulator.run_with_callback(|_| {}).unwrap();

        let summary = simulator.get_metrics_summary();
        assert_eq!(summary.completed_requests, 10);
    }

    #[test]
    fn test_simulation_with_sjf_policy() {
        let mut config = create_minimal_test_config();
        config.scheduler.policy = "sjf".to_string();

        let mut simulator = Simulator::new(config).unwrap();
        simulator.run_with_callback(|_| {}).unwrap();

        let summary = simulator.get_metrics_summary();
        assert_eq!(summary.completed_requests, 10);
    }

    #[test]
    fn test_simulation_with_priority_policy() {
        let mut config = create_minimal_test_config();
        config.scheduler.policy = "priority".to_string();

        let mut simulator = Simulator::new(config).unwrap();
        simulator.run_with_callback(|_| {}).unwrap();

        let summary = simulator.get_metrics_summary();
        assert_eq!(summary.completed_requests, 10);
    }

    #[test]
    fn test_simulation_different_policies_produce_results() {
        // Same seed for both runs so only the policy differs.
        let mut config_fcfs = create_minimal_test_config();
        config_fcfs.scheduler.policy = "fcfs".to_string();
        config_fcfs.workload.seed = 42;

        let mut config_sjf = create_minimal_test_config();
        config_sjf.scheduler.policy = "sjf".to_string();
        config_sjf.workload.seed = 42;

        let mut sim_fcfs = Simulator::new(config_fcfs).unwrap();
        let mut sim_sjf = Simulator::new(config_sjf).unwrap();

        sim_fcfs.run_with_callback(|_| {}).unwrap();
        sim_sjf.run_with_callback(|_| {}).unwrap();

        let summary_fcfs = sim_fcfs.get_metrics_summary();
        let summary_sjf = sim_sjf.get_metrics_summary();

        assert_eq!(summary_fcfs.completed_requests, 10);
        assert_eq!(summary_sjf.completed_requests, 10);

        assert!(summary_fcfs.e2e_mean > 0.0);
        assert!(summary_sjf.e2e_mean > 0.0);
    }

    #[test]
    fn test_simulation_with_chunked_prefill() {
        let mut config = create_minimal_test_config();
        config.scheduler.enable_chunked_prefill = true;
        config.scheduler.long_prefill_token_threshold = 512;

        let mut simulator = Simulator::new(config).unwrap();
        simulator.run_with_callback(|_| {}).unwrap();

        let summary = simulator.get_metrics_summary();
        assert_eq!(summary.completed_requests, 10);
    }

    #[test]
    fn test_simulation_preemption_metrics() {
        let config = create_minimal_test_config();
        let mut simulator = Simulator::new(config).unwrap();

        simulator.run_with_callback(|_| {}).unwrap();

        let summary = simulator.get_metrics_summary();

        // Preemption count is non-negative; zero is acceptable for this load.
        assert!(summary.preemptions_per_request_mean >= 0.0);
    }

    #[test]
    fn test_simulation_time_series_collected() {
        let config = create_minimal_test_config();
        let mut simulator = Simulator::new(config).unwrap();

        simulator.run_with_callback(|_| {}).unwrap();

        let time_series = simulator.get_time_series_data();

        assert!(!time_series.is_empty());

        // Samples must be in non-decreasing time order.
        for i in 1..time_series.len() {
            assert!(time_series[i].time >= time_series[i - 1].time);
        }
    }

    #[test]
    fn test_simulation_latency_samples_collected() {
        let config = create_minimal_test_config();
        let mut simulator = Simulator::new(config).unwrap();

        simulator.run_with_callback(|_| {}).unwrap();

        let ((ttft, ttft_ts), (e2e, e2e_ts), (tpot, tpot_ts)) = simulator.get_latency_samples();

        assert!(!ttft.is_empty());
        assert!(!e2e.is_empty());

        // Each value vector is parallel to its timestamp vector.
        assert_eq!(ttft.len(), ttft_ts.len());
        assert_eq!(e2e.len(), e2e_ts.len());
        assert_eq!(tpot.len(), tpot_ts.len());
    }
}