1use crate::compute::ComputeEngine;
2use crate::config::Config;
3use crate::dataset::{BatchTokenizerFn, DatasetLoader};
4use crate::kv_cache::KVCacheManager;
5use crate::metrics::{LatencySampleTriplet, MetricsCollector};
6use crate::request::RequestGenerator;
7use crate::scheduler::Scheduler;
8
/// One sampled snapshot of simulator state, recorded roughly every
/// `sample_interval` seconds of simulated time.
#[derive(Debug, Clone)]
pub struct TimeSeriesPoint {
    /// Simulated time (seconds) at which this sample was taken.
    pub time: f64,
    /// Cumulative number of requests that have arrived so far.
    pub arrivals: u64,
    /// Requests currently in the running batch.
    pub running: usize,
    /// Requests waiting in the scheduler queue.
    pub waiting: usize,
    /// KV-cache utilization fraction.
    pub kv_cache_util: f64,
    /// Running requests still in the prefill phase.
    pub num_prefilling: usize,
    /// Running requests in the decode phase.
    pub num_decoding: usize,
    /// Prefill tokens processed in the sampled iteration.
    pub prefill_tokens: u32,
    /// Decode tokens processed in the sampled iteration.
    pub decode_tokens: u32,
    /// Input-token throughput over the sample interval.
    pub input_throughput: f64,
    /// Output-token throughput over the sample interval.
    pub output_throughput: f64,
    /// TTFT for the interval (stored value is the interval mean, despite the name).
    pub ttft_p50: f64,
    /// TPOT for the interval (stored value is the interval mean, despite the name).
    pub tpot_p50: f64,
}
26pub struct ProgressInfo<'a> {
27 pub current_time: f64,
28 pub completed_requests: u64,
29 pub total_requests: u64,
30 pub running: usize,
31 pub waiting: usize,
32 pub kv_cache_util: f64,
33 pub time_series: Option<&'a [TimeSeriesPoint]>,
34 pub metrics: Option<crate::metrics::MetricsSummary>,
35 pub latency_samples: Option<LatencySampleTriplet<'a>>,
36 pub distribution_samples: Option<(&'a [u32], &'a [u32])>, }
38
/// Discrete-event simulator tying together request generation, scheduling,
/// compute-cost modeling, and metrics collection.
pub struct Simulator {
    // Owns the waiting/running queues and the KV-cache manager.
    scheduler: Scheduler,
    // Models per-iteration latency, bytes transferred, and FLOPs utilization.
    compute_engine: ComputeEngine,
    // Produces request arrivals (synthetic or dataset-driven).
    request_generator: RequestGenerator,
    // Accumulates per-iteration and per-request statistics.
    metrics: MetricsCollector,
    // Periodic samples pushed during the run; exposed via get_time_series_data().
    time_series_data: Vec<TimeSeriesPoint>,
    // Simulated seconds between time-series samples (0.1 as constructed).
    sample_interval: f64,
    // Next simulated time at which a sample should be taken.
    next_sample_time: f64,
    // Simulated time of the previous sample.
    prev_sample_time: f64,

    // Current simulated time in seconds.
    current_time: f64,
    // Number of scheduler iterations executed so far.
    iteration: u64,

    // High-water marks for delta-encoding the sample streams sent to the
    // progress callback: only entries past these indices are sent each time.
    last_sent_ttft_count: usize,
    last_sent_e2e_count: usize,
    last_sent_tpot_count: usize,
    last_sent_input_count: usize,
    last_sent_output_count: usize,
}
59
60impl Simulator {
61 pub fn new(config: Config) -> Result<Self, String> {
62 Self::new_with_tokenizer(config, None).map(|(sim, _)| sim)
63 }
64
    /// Constructs a simulator from `config`, optionally with a batch-tokenizer
    /// callback that is required when the workload reads from a dataset file.
    ///
    /// Returns the simulator together with the (possibly amended) config:
    /// when a dataset is used and `num_requests` is unset, it is filled in
    /// with the dataset's total entry count.
    ///
    /// # Errors
    /// Returns `Err` when a dataset path is configured but no tokenizer is
    /// supplied, when the dataset cannot be read or counted, or when
    /// scheduler construction fails.
    pub fn new_with_tokenizer(
        mut config: Config,
        tokenizer: Option<BatchTokenizerFn>,
    ) -> Result<(Self, Config), String> {
        let kv_cache_manager = KVCacheManager::new(
            config.hardware.kv_cache_capacity,
            config.scheduler.block_size,
            config.model.kv_cache_bytes_per_token,
            // NOTE(review): flag semantics not visible here — presumably
            // enables prefix caching; confirm against KVCacheManager::new.
            true, );

        let scheduler = Scheduler::new(
            config.scheduler.clone(),
            config.hardware.clone(),
            config.model.clone(),
            kv_cache_manager,
        )?;

        let compute_engine = ComputeEngine::new(config.hardware.clone(), config.model.clone());

        // Dataset-driven workloads need a tokenizer; synthetic workloads do not.
        let request_generator = if let Some(dataset_path) = &config.workload.dataset_path {
            let tokenizer = tokenizer.ok_or_else(|| {
                format!(
                    "Dataset path '{}' provided but no tokenizer function supplied",
                    dataset_path
                )
            })?;

            // Default num_requests to the full dataset size when unspecified,
            // so the caller gets the amended value back.
            if config.workload.num_requests.is_none() {
                let total_entries = DatasetLoader::count_entries(dataset_path)
                    .map_err(|e| format!("Failed to count entries in '{}': {}", dataset_path, e))?;
                config.workload.num_requests = Some(total_entries);
            }

            let dataset_iterator = DatasetLoader::from_file(dataset_path)
                .map_err(|e| format!("Failed to load dataset from '{}': {}", dataset_path, e))?;

            RequestGenerator::from_dataset(
                config.workload.clone(),
                dataset_iterator,
                None,
                tokenizer,
            )
        } else {
            RequestGenerator::new(config.workload.clone())
        };

        let metrics = MetricsCollector::new(0.0);

        let simulator = Self {
            scheduler,
            compute_engine,
            request_generator,
            metrics,
            time_series_data: Vec::new(),
            // Time-series sampling cadence in simulated seconds.
            sample_interval: 0.1,
            next_sample_time: 0.0,
            prev_sample_time: 0.0,
            current_time: 0.0,
            iteration: 0,
            // Nothing has been sent to a progress callback yet.
            last_sent_ttft_count: 0,
            last_sent_e2e_count: 0,
            last_sent_tpot_count: 0,
            last_sent_input_count: 0,
            last_sent_output_count: 0,
        };

        Ok((simulator, config))
    }
142
    /// Drives the simulation to completion, invoking `callback` with a
    /// [`ProgressInfo`] snapshot roughly once per simulated second and once
    /// more on termination.
    ///
    /// Each loop iteration: admit newly-arrived requests, ask the scheduler
    /// for a batch, advance simulated time by the modeled iteration time,
    /// credit generated tokens to requests, sample time-series data when due,
    /// free resources of completed requests, and fire the callback when due.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`; the `Result` signature leaves room
    /// for future failure modes.
    pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<(), String>
    where
        F: FnMut(ProgressInfo),
    {
        let mut last_callback_time = 0.0;
        // Simulated seconds between progress callbacks.
        let callback_interval = 1.0; loop {
            self.iteration += 1;

            // Admit every request whose arrival time has already passed.
            while let Some(request) = self.request_generator.next_if_before(self.current_time) {
                self.scheduler.add_request(request);
                self.metrics.total_requests += 1;
            }

            let decision = self.scheduler.schedule(self.current_time);

            // Cost the scheduled batch; with nothing scheduled, fast-forward
            // toward the next arrival (or tick minimally) at zero utilization.
            let (iteration_time, bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
                let running = self.scheduler.running_mut();
                let mut batch_requests = Vec::new();
                let mut tokens_per_req = Vec::new();

                // Collect newly-scheduled requests and their token budgets.
                for (i, &idx) in decision.scheduled_new.iter().enumerate() {
                    if let Some(req) = running.get(idx) {
                        batch_requests.push(req);
                        tokens_per_req.push(decision.tokens_for_new[i]);
                    }
                }

                // Then the already-running requests scheduled this iteration.
                for (i, &idx) in decision.scheduled_running.iter().enumerate() {
                    if let Some(req) = running.get(idx) {
                        batch_requests.push(req);
                        tokens_per_req.push(decision.tokens_for_running[i]);
                    }
                }

                let iteration_time = self
                    .compute_engine
                    .calculate_iteration_time(&batch_requests, &tokens_per_req);

                let bytes_transferred = self
                    .compute_engine
                    .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
                let bandwidth_util = self
                    .compute_engine
                    .calculate_bandwidth_utilization(bytes_transferred, iteration_time);

                let flops_util = self.compute_engine.calculate_flops_utilization(
                    &batch_requests,
                    &tokens_per_req,
                    iteration_time,
                );

                (iteration_time, bandwidth_util, flops_util)
            } else {
                if !self.request_generator.is_finished() {
                    // Idle: jump to the next arrival, but always advance at
                    // least 1 ms so time cannot stall.
                    let next_arrival = self.request_generator.peek_next_arrival_time();
                    let time_until_next = (next_arrival - self.current_time).max(0.001);
                    (time_until_next, 0.0, 0.0)
                } else {
                    // No arrivals left; tick minimally while work drains.
                    (0.001, 0.0, 0.0)
                }
            };

            self.current_time += iteration_time;

            // Snapshot which scheduled requests are in prefill BEFORE token
            // recording below, which may flip them into the decode phase.
            let mut prefilling_reqs = std::collections::HashSet::new();
            for &idx in decision
                .scheduled_new
                .iter()
                .chain(decision.scheduled_running.iter())
            {
                if let Some(request) = self.scheduler.running().get(idx) {
                    if request.is_prefill() {
                        prefilling_reqs.insert(idx);
                    }
                }
            }

            // Credit this iteration's tokens to each scheduled request.
            for (i, &idx) in decision.scheduled_new.iter().enumerate() {
                if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
                    request.record_generated_tokens(decision.tokens_for_new[i], self.current_time);
                }
            }
            for (i, &idx) in decision.scheduled_running.iter().enumerate() {
                if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
                    request
                        .record_generated_tokens(decision.tokens_for_running[i], self.current_time);
                }
            }

            let kv_util = self.scheduler.kv_cache_manager().utilization();

            self.metrics
                .record_iteration_metrics(kv_util, flops_util, bandwidth_util);

            // Periodic time-series sample.
            if self.current_time >= self.next_sample_time {
                let running = self.scheduler.running();
                let mut num_prefilling = 0;
                let mut num_decoding = 0;
                let mut prefill_tokens = 0;
                let mut decode_tokens = 0;

                // Count phase membership over the whole running set.
                for req in running {
                    if req.is_prefill() {
                        num_prefilling += 1;
                    } else {
                        num_decoding += 1;
                    }
                }

                // Split this iteration's tokens into prefill vs decode using
                // the pre-update prefill snapshot taken above.
                for (i, &idx) in decision.scheduled_new.iter().enumerate() {
                    let tokens = decision.tokens_for_new[i];
                    if prefilling_reqs.contains(&idx) {
                        prefill_tokens += tokens;
                    } else {
                        decode_tokens += tokens;
                    }
                }
                for (i, &idx) in decision.scheduled_running.iter().enumerate() {
                    let tokens = decision.tokens_for_running[i];
                    if prefilling_reqs.contains(&idx) {
                        prefill_tokens += tokens;
                    } else {
                        decode_tokens += tokens;
                    }
                }

                // NOTE(review): divides one iteration's tokens by the full
                // sample interval — an approximation when iterations are much
                // shorter than the interval; confirm this is intended.
                let input_throughput = prefill_tokens as f64 / self.sample_interval;
                let output_throughput = decode_tokens as f64 / self.sample_interval;

                let (ttft_mean, tpot_mean) = self.metrics.get_interval_latencies();

                self.time_series_data.push(TimeSeriesPoint {
                    time: self.current_time,
                    arrivals: self.metrics.total_requests,
                    running: self.scheduler.num_running(),
                    waiting: self.scheduler.num_waiting(),
                    kv_cache_util: kv_util,
                    num_prefilling,
                    num_decoding,
                    prefill_tokens,
                    decode_tokens,
                    input_throughput,
                    output_throughput,
                    // Interval means stored under p50-named fields.
                    ttft_p50: ttft_mean,
                    tpot_p50: tpot_mean,
                });
                self.prev_sample_time = self.current_time;
                self.next_sample_time = self.current_time + self.sample_interval;
            }

            // Release KV blocks and record stats for requests that finished.
            for request in decision.completed {
                self.scheduler
                    .kv_cache_manager_mut()
                    .free_blocks(&request.kv_blocks);

                self.metrics.record_request_completion(&request);

                self.request_generator
                    .on_request_complete(self.current_time);
            }

            // Periodic progress callback: sends only the latency/length
            // samples added since the previous callback (delta encoding via
            // the last_sent_* high-water marks).
            if self.current_time - last_callback_time >= callback_interval {
                let kv_manager = self.scheduler.kv_cache_manager();
                let summary = self.metrics.compute_summary(
                    self.current_time,
                    kv_manager.num_prefix_cache_hits,
                    kv_manager.num_prefix_cache_misses,
                    kv_manager.hit_size_sum,
                    kv_manager.hit_size_count,
                );

                let latency_samples = self.metrics.get_latency_samples();
                let input_lengths = self.metrics.get_input_lengths();
                let output_lengths = self.metrics.get_output_lengths();

                // Each stream is a (values, timestamps) pair; slice off only
                // the suffix not yet delivered.
                let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
                let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
                let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
                let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
                let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
                let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
                let input_delta = &input_lengths[self.last_sent_input_count..];
                let output_delta = &output_lengths[self.last_sent_output_count..];

                let progress = ProgressInfo {
                    current_time: self.current_time,
                    completed_requests: self.metrics.completed_requests,
                    total_requests: self.metrics.total_requests,
                    running: self.scheduler.num_running(),
                    waiting: self.scheduler.num_waiting(),
                    kv_cache_util: kv_util,
                    time_series: Some(&self.time_series_data),
                    metrics: Some(summary),
                    latency_samples: Some((
                        (ttft_delta, ttft_timestamps_delta),
                        (e2e_delta, e2e_timestamps_delta),
                        (tpot_delta, tpot_timestamps_delta),
                    )),
                    distribution_samples: Some((input_delta, output_delta)),
                };
                callback(progress);

                // Advance the high-water marks past everything just sent.
                self.last_sent_ttft_count = latency_samples.0 .0.len();
                self.last_sent_e2e_count = latency_samples.1 .0.len();
                self.last_sent_tpot_count = latency_samples.2 .0.len();
                self.last_sent_input_count = input_lengths.len();
                self.last_sent_output_count = output_lengths.len();
                last_callback_time = self.current_time;
            }

            // Final callback + exit once the generator is exhausted and the
            // scheduler fully drained. Mirrors the periodic callback above but
            // without advancing the high-water marks (nothing follows).
            if self.should_terminate() {
                let kv_manager = self.scheduler.kv_cache_manager();
                let summary = self.metrics.compute_summary(
                    self.current_time,
                    kv_manager.num_prefix_cache_hits,
                    kv_manager.num_prefix_cache_misses,
                    kv_manager.hit_size_sum,
                    kv_manager.hit_size_count,
                );

                let latency_samples = self.metrics.get_latency_samples();
                let input_lengths = self.metrics.get_input_lengths();
                let output_lengths = self.metrics.get_output_lengths();

                let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
                let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
                let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
                let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
                let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
                let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
                let input_delta = &input_lengths[self.last_sent_input_count..];
                let output_delta = &output_lengths[self.last_sent_output_count..];

                let progress = ProgressInfo {
                    current_time: self.current_time,
                    completed_requests: self.metrics.completed_requests,
                    total_requests: self.metrics.total_requests,
                    running: self.scheduler.num_running(),
                    waiting: self.scheduler.num_waiting(),
                    kv_cache_util: kv_util,
                    time_series: Some(&self.time_series_data),
                    metrics: Some(summary),
                    latency_samples: Some((
                        (ttft_delta, ttft_timestamps_delta),
                        (e2e_delta, e2e_timestamps_delta),
                        (tpot_delta, tpot_timestamps_delta),
                    )),
                    distribution_samples: Some((input_delta, output_delta)),
                };
                callback(progress);
                break;
            }
        }

        Ok(())
    }
430
431 pub fn get_metrics_summary(&mut self) -> crate::metrics::MetricsSummary {
432 let kv_manager = self.scheduler.kv_cache_manager();
433 self.metrics.compute_summary(
434 self.current_time,
435 kv_manager.num_prefix_cache_hits,
436 kv_manager.num_prefix_cache_misses,
437 kv_manager.hit_size_sum,
438 kv_manager.hit_size_count,
439 )
440 }
441
442 pub fn get_time_series_data(&self) -> &[TimeSeriesPoint] {
443 &self.time_series_data
444 }
445
    /// Input (prompt) token lengths recorded by the metrics collector.
    pub fn get_input_lengths(&self) -> &[u32] {
        self.metrics.get_input_lengths()
    }
449
    /// Output (generated) token lengths recorded by the metrics collector.
    pub fn get_output_lengths(&self) -> &[u32] {
        self.metrics.get_output_lengths()
    }
453
    /// Current simulated time in seconds.
    pub fn get_current_time(&self) -> f64 {
        self.current_time
    }
457
    /// All latency samples (TTFT, E2E, TPOT) collected so far, each paired
    /// with its timestamps.
    pub fn get_latency_samples(&self) -> LatencySampleTriplet<'_> {
        self.metrics.get_latency_samples()
    }
461
462 fn should_terminate(&self) -> bool {
463 self.request_generator.is_finished()
465 && self.scheduler.num_running() == 0
466 && self.scheduler.num_waiting() == 0
467 }
468}
469
470#[cfg(test)]
471mod tests {
472 use super::*;
473
474 fn create_minimal_test_config() -> Config {
475 let mut config = Config::test_default();
476 config.workload.num_requests = Some(10); config.workload.arrival_rate = 10.0; config
479 }
480
481 #[test]
482 fn test_simulation_completes_all_requests() {
483 let config = create_minimal_test_config();
484 let mut simulator = Simulator::new(config).unwrap();
485
486 simulator.run_with_callback(|_| {}).unwrap();
487
488 let summary = simulator.get_metrics_summary();
489
490 assert_eq!(summary.completed_requests, summary.total_requests);
492 assert_eq!(summary.completed_requests, 10);
493 }
494
495 #[test]
496 fn test_simulation_time_progresses() {
497 let config = create_minimal_test_config();
498 let mut simulator = Simulator::new(config).unwrap();
499
500 let start_time = simulator.get_current_time();
501 simulator.run_with_callback(|_| {}).unwrap();
502 let end_time = simulator.get_current_time();
503
504 assert!(end_time > start_time);
506 }
507
508 #[test]
509 fn test_simulation_metrics_reasonable() {
510 let config = create_minimal_test_config();
511 let mut simulator = Simulator::new(config).unwrap();
512
513 simulator.run_with_callback(|_| {}).unwrap();
514
515 let summary = simulator.get_metrics_summary();
516
517 assert!(summary.ttft_mean > 0.0 && summary.ttft_mean.is_finite());
519 assert!(summary.e2e_mean > 0.0 && summary.e2e_mean.is_finite());
520 assert!(summary.per_token_mean > 0.0 && summary.per_token_mean.is_finite());
521
522 assert!(summary.ttft_min <= summary.ttft_p50);
524 assert!(summary.ttft_p50 <= summary.ttft_p90);
525 assert!(summary.ttft_p90 <= summary.ttft_p99);
526
527 assert!(summary.e2e_min <= summary.e2e_p50);
528 assert!(summary.e2e_p50 <= summary.e2e_p90);
529 assert!(summary.e2e_p90 <= summary.e2e_p99);
530
531 assert!(summary.avg_kv_cache_util >= 0.0 && summary.avg_kv_cache_util <= 1.0);
533 assert!(summary.avg_flops_util >= 0.0 && summary.avg_flops_util <= 1.0);
534 assert!(summary.avg_bandwidth_util >= 0.0 && summary.avg_bandwidth_util <= 1.0);
535
536 assert!(summary.input_tokens_per_sec > 0.0);
538 assert!(summary.output_tokens_per_sec > 0.0);
539 assert!(summary.requests_per_sec > 0.0);
540 }
541
542 #[test]
543 fn test_simulation_no_infinite_loop() {
544 let config = create_minimal_test_config();
545 let mut simulator = Simulator::new(config).unwrap();
546
547 let mut iteration_count = 0;
549 simulator
550 .run_with_callback(|_| {
551 iteration_count += 1;
552 })
553 .unwrap();
554
555 assert!(iteration_count < 1000);
558 assert!(iteration_count > 0);
559 }
560
561 #[test]
562 fn test_simulation_with_fcfs_policy() {
563 let mut config = create_minimal_test_config();
564 config.scheduler.policy = "fcfs".to_string();
565
566 let mut simulator = Simulator::new(config).unwrap();
567 simulator.run_with_callback(|_| {}).unwrap();
568
569 let summary = simulator.get_metrics_summary();
570 assert_eq!(summary.completed_requests, 10);
571 }
572
573 #[test]
574 fn test_simulation_with_sjf_policy() {
575 let mut config = create_minimal_test_config();
576 config.scheduler.policy = "sjf".to_string();
577
578 let mut simulator = Simulator::new(config).unwrap();
579 simulator.run_with_callback(|_| {}).unwrap();
580
581 let summary = simulator.get_metrics_summary();
582 assert_eq!(summary.completed_requests, 10);
583 }
584
585 #[test]
586 fn test_simulation_with_priority_policy() {
587 let mut config = create_minimal_test_config();
588 config.scheduler.policy = "priority".to_string();
589
590 let mut simulator = Simulator::new(config).unwrap();
591 simulator.run_with_callback(|_| {}).unwrap();
592
593 let summary = simulator.get_metrics_summary();
594 assert_eq!(summary.completed_requests, 10);
595 }
596
597 #[test]
598 fn test_simulation_different_policies_produce_results() {
599 let mut config_fcfs = create_minimal_test_config();
600 config_fcfs.scheduler.policy = "fcfs".to_string();
601 config_fcfs.workload.seed = 42;
602
603 let mut config_sjf = create_minimal_test_config();
604 config_sjf.scheduler.policy = "sjf".to_string();
605 config_sjf.workload.seed = 42; let mut sim_fcfs = Simulator::new(config_fcfs).unwrap();
608 let mut sim_sjf = Simulator::new(config_sjf).unwrap();
609
610 sim_fcfs.run_with_callback(|_| {}).unwrap();
611 sim_sjf.run_with_callback(|_| {}).unwrap();
612
613 let summary_fcfs = sim_fcfs.get_metrics_summary();
614 let summary_sjf = sim_sjf.get_metrics_summary();
615
616 assert_eq!(summary_fcfs.completed_requests, 10);
618 assert_eq!(summary_sjf.completed_requests, 10);
619
620 assert!(summary_fcfs.e2e_mean > 0.0);
623 assert!(summary_sjf.e2e_mean > 0.0);
624 }
625
626 #[test]
627 fn test_simulation_with_chunked_prefill() {
628 let mut config = create_minimal_test_config();
629 config.scheduler.enable_chunked_prefill = true;
630 config.scheduler.long_prefill_token_threshold = 512;
631
632 let mut simulator = Simulator::new(config).unwrap();
633 simulator.run_with_callback(|_| {}).unwrap();
634
635 let summary = simulator.get_metrics_summary();
636 assert_eq!(summary.completed_requests, 10);
637 }
638
639 #[test]
640 fn test_simulation_preemption_metrics() {
641 let config = create_minimal_test_config();
642 let mut simulator = Simulator::new(config).unwrap();
643
644 simulator.run_with_callback(|_| {}).unwrap();
645
646 let summary = simulator.get_metrics_summary();
647
648 assert!(summary.preemptions_per_request_mean >= 0.0);
650 }
651
652 #[test]
653 fn test_simulation_time_series_collected() {
654 let config = create_minimal_test_config();
655 let mut simulator = Simulator::new(config).unwrap();
656
657 simulator.run_with_callback(|_| {}).unwrap();
658
659 let time_series = simulator.get_time_series_data();
660
661 assert!(!time_series.is_empty());
663
664 for i in 1..time_series.len() {
666 assert!(time_series[i].time >= time_series[i - 1].time);
667 }
668 }
669
670 #[test]
671 fn test_simulation_latency_samples_collected() {
672 let config = create_minimal_test_config();
673 let mut simulator = Simulator::new(config).unwrap();
674
675 simulator.run_with_callback(|_| {}).unwrap();
676
677 let ((ttft, ttft_ts), (e2e, e2e_ts), (tpot, tpot_ts)) = simulator.get_latency_samples();
678
679 assert!(!ttft.is_empty());
681 assert!(!e2e.is_empty());
682
683 assert_eq!(ttft.len(), ttft_ts.len());
685 assert_eq!(e2e.len(), e2e_ts.len());
686 assert_eq!(tpot.len(), tpot_ts.len());
687 }
688}