1use crate::compute::ComputeEngine;
2use crate::config::Config;
3use crate::kv_cache::KVCacheManager;
4use crate::metrics::MetricsCollector;
5use crate::request::RequestGenerator;
6use crate::scheduler::Scheduler;
7
/// One sampled point of simulator state, captured roughly every
/// `sample_interval` simulated seconds.
#[derive(Debug, Clone)]
pub struct TimeSeriesPoint {
    /// Simulated time of the sample, in seconds.
    pub time: f64,
    /// Total requests that have arrived so far.
    pub arrivals: u64,
    /// Requests currently running on the engine.
    pub running: usize,
    /// Requests queued but not yet scheduled.
    pub waiting: usize,
    /// KV-cache utilization in `[0, 1]`.
    pub kv_cache_util: f64,
    /// Running requests still in their prefill phase.
    pub num_prefilling: usize,
    /// Running requests in their decode phase.
    pub num_decoding: usize,
    /// Prefill tokens processed in the sampled iteration.
    pub prefill_tokens: u32,
    /// Decode tokens processed in the sampled iteration.
    pub decode_tokens: u32,
    /// Prefill token throughput (tokens/s) over the sample interval.
    pub input_throughput: f64,
    /// Decode token throughput (tokens/s) over the sample interval.
    pub output_throughput: f64,
    // NOTE(review): the simulator fills these two fields with interval
    // *means* from `get_interval_latencies`, not medians — confirm whether
    // the `p50` names are intentional.
    /// TTFT statistic over the sample interval.
    pub ttft_p50: f64,
    /// TPOT statistic over the sample interval.
    pub tpot_p50: f64,
}
24
/// Snapshot handed to the progress callback of `Simulator::run_with_callback`.
///
/// Slice fields borrow the simulator's internal buffers. The latency and
/// distribution samples contain only entries accumulated since the previous
/// callback ("deltas"); `time_series` is the full history.
pub struct ProgressInfo<'a> {
    /// Current simulated time, in seconds.
    pub current_time: f64,
    /// Requests completed so far.
    pub completed_requests: u64,
    /// Requests that have arrived so far.
    pub total_requests: u64,
    /// Requests currently running.
    pub running: usize,
    /// Requests currently queued.
    pub waiting: usize,
    /// KV-cache utilization in `[0, 1]`.
    pub kv_cache_util: f64,
    /// Full time-series history collected so far.
    pub time_series: Option<&'a [TimeSeriesPoint]>,
    /// Aggregate metrics summary computed at `current_time`.
    pub metrics: Option<crate::metrics::MetricsSummary>,
    /// `((ttft, ttft_ts), (e2e, e2e_ts), (tpot, tpot_ts))` value/timestamp
    /// slice pairs, new since the previous callback.
    pub latency_samples: Option<(
        (&'a [f64], &'a [f64]),
        (&'a [f64], &'a [f64]),
        (&'a [f64], &'a [f64]),
    )>,
    /// `(input_lengths, output_lengths)` of requests completed since the
    /// previous callback.
    pub distribution_samples: Option<(&'a [u32], &'a [u32])>,
}
41
/// Discrete-event simulator that drives requests through the scheduler and
/// compute model while collecting metrics and time-series samples.
pub struct Simulator {
    scheduler: Scheduler,
    compute_engine: ComputeEngine,
    request_generator: RequestGenerator,
    metrics: MetricsCollector,
    // Sampled state history, appended roughly every `sample_interval`.
    time_series_data: Vec<TimeSeriesPoint>,
    // Target spacing between time-series samples, in simulated seconds.
    sample_interval: f64,
    // Simulated time at which the next sample is due.
    next_sample_time: f64,
    // Simulated time of the previous sample (used for interval latencies).
    prev_sample_time: f64,

    // Current simulated time, in seconds.
    current_time: f64,
    // Scheduler iterations executed so far.
    iteration: u64,
    // Log progress every this many iterations (native builds only).
    log_interval: u64,

    // High-water marks of samples already delivered through the progress
    // callback, so each callback carries only new ("delta") samples.
    last_sent_ttft_count: usize,
    last_sent_e2e_count: usize,
    last_sent_tpot_count: usize,
    last_sent_input_count: usize,
    last_sent_output_count: usize,
}
63
64impl Simulator {
65 pub fn new(config: Config) -> Result<Self, String> {
66 let kv_cache_manager = KVCacheManager::new(
67 config.hardware.kv_cache_capacity,
68 config.scheduler.block_size,
69 config.model.kv_cache_bytes_per_token,
70 true, );
72
73 let scheduler = Scheduler::new(
74 config.scheduler.clone(),
75 config.hardware.clone(),
76 config.model.clone(),
77 kv_cache_manager,
78 )?;
79
80 let compute_engine = ComputeEngine::new(config.hardware, config.model);
81 let request_generator = RequestGenerator::new(config.workload);
82 let metrics = MetricsCollector::new(0.0);
83
84 Ok(Self {
85 scheduler,
86 compute_engine,
87 request_generator,
88 metrics,
89 time_series_data: Vec::new(),
90 sample_interval: 0.1,
91 next_sample_time: 0.0,
92 prev_sample_time: 0.0,
93 current_time: 0.0,
94 iteration: 0,
95 log_interval: config.simulation.log_interval,
96 last_sent_ttft_count: 0,
97 last_sent_e2e_count: 0,
98 last_sent_tpot_count: 0,
99 last_sent_input_count: 0,
100 last_sent_output_count: 0,
101 })
102 }
103
104 pub fn run(&mut self) {
106 #[cfg(not(target_arch = "wasm32"))]
107 println!("Starting simulation...\n");
108
109 loop {
110 self.iteration += 1;
111
112 while let Some(request) = self.request_generator.next_if_before(self.current_time) {
114 self.scheduler.add_request(request);
115 self.metrics.total_requests += 1;
116 }
117
118 let decision = self.scheduler.schedule(self.current_time);
120
121 let iteration_time = if decision.num_scheduled() > 0 {
123 let running = self.scheduler.running_mut();
125 let batch_requests: Vec<&_> = decision
126 .scheduled_new
127 .iter()
128 .chain(decision.scheduled_running.iter())
129 .filter_map(|&idx| running.get(idx))
130 .collect();
131
132 let tokens_per_req: Vec<u32> = batch_requests
133 .iter()
134 .enumerate()
135 .map(|(i, _req)| {
136 let idx = if i < decision.scheduled_new.len() {
137 decision.scheduled_new[i]
138 } else {
139 decision.scheduled_running[i - decision.scheduled_new.len()]
140 };
141 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
142 })
143 .collect();
144
145 self.compute_engine
146 .calculate_iteration_time(&batch_requests, &tokens_per_req)
147 } else {
148 0.001 };
150
151 self.current_time += iteration_time;
153
154 let mut prefilling_reqs = std::collections::HashSet::new();
156 for (&idx, &_tokens) in &decision.tokens_per_request {
157 if let Some(request) = self.scheduler.running().get(idx) {
158 if request.is_prefill() {
159 prefilling_reqs.insert(idx);
160 }
161 }
162 }
163
164 for (&idx, &tokens) in &decision.tokens_per_request {
166 if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
167 request.record_generated_tokens(tokens, self.current_time);
168 }
169 }
170
171 let kv_util = self.scheduler.kv_cache_manager().utilization();
173
174 let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
176 let running = self.scheduler.running_mut();
177 let batch_requests: Vec<&_> = decision
178 .scheduled_new
179 .iter()
180 .chain(decision.scheduled_running.iter())
181 .filter_map(|&idx| running.get(idx))
182 .collect();
183
184 let tokens_per_req: Vec<u32> = batch_requests
185 .iter()
186 .enumerate()
187 .map(|(i, _req)| {
188 let idx = if i < decision.scheduled_new.len() {
189 decision.scheduled_new[i]
190 } else {
191 decision.scheduled_running[i - decision.scheduled_new.len()]
192 };
193 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
194 })
195 .collect();
196
197 let bytes_transferred = self
198 .compute_engine
199 .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
200 let bandwidth_util = self
201 .compute_engine
202 .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
203
204 let flops_util = self.compute_engine.calculate_flops_utilization(
205 &batch_requests,
206 &tokens_per_req,
207 iteration_time,
208 );
209
210 (bandwidth_util, flops_util)
211 } else {
212 (0.0, 0.0)
213 };
214
215 self.metrics
216 .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
217
218 if self.current_time >= self.next_sample_time {
220 let running = self.scheduler.running();
222 let mut num_prefilling = 0;
223 let mut num_decoding = 0;
224 let mut prefill_tokens = 0;
225 let mut decode_tokens = 0;
226
227 for req in running {
228 if req.is_prefill() {
229 num_prefilling += 1;
230 } else {
231 num_decoding += 1;
232 }
233 }
234
235 for (&idx, &tokens) in &decision.tokens_per_request {
238 if prefilling_reqs.contains(&idx) {
239 prefill_tokens += tokens;
240 } else {
241 decode_tokens += tokens;
242 }
243 }
244
245 self.metrics.record_throughput_sample(self.current_time);
247
248 let input_throughput = prefill_tokens as f64 / self.sample_interval;
250 let output_throughput = decode_tokens as f64 / self.sample_interval;
251
252 let (ttft_mean, tpot_mean) = self
254 .metrics
255 .get_interval_latencies(self.prev_sample_time, self.current_time);
256
257 self.time_series_data.push(TimeSeriesPoint {
258 time: self.current_time,
259 arrivals: self.metrics.total_requests,
260 running: self.scheduler.num_running(),
261 waiting: self.scheduler.num_waiting(),
262 kv_cache_util: kv_util,
263 num_prefilling,
264 num_decoding,
265 prefill_tokens,
266 decode_tokens,
267 input_throughput,
268 output_throughput,
269 ttft_p50: ttft_mean,
270 tpot_p50: tpot_mean,
271 });
272 self.prev_sample_time = self.current_time;
273 self.next_sample_time = self.current_time + self.sample_interval;
274 }
275
276 for request in decision.completed {
278 self.scheduler
280 .kv_cache_manager_mut()
281 .free_blocks(&request.kv_blocks);
282
283 self.metrics.record_request_completion(&request);
284
285 self.request_generator
287 .on_request_complete(self.current_time);
288 }
289
290 if self.iteration % self.log_interval == 0 {
292 self.log_progress();
293 }
294
295 if self.should_terminate() {
297 break;
298 }
299 }
300
301 #[cfg(not(target_arch = "wasm32"))]
302 println!("\nSimulation complete!");
303 #[cfg(not(target_arch = "wasm32"))]
304 self.print_final_metrics();
305 }
306
307 pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<(), String>
309 where
310 F: FnMut(ProgressInfo),
311 {
312 let mut last_callback_time = 0.0;
313 let callback_interval = 1.0; loop {
316 self.iteration += 1;
317
318 while let Some(request) = self.request_generator.next_if_before(self.current_time) {
320 self.scheduler.add_request(request);
321 self.metrics.total_requests += 1;
322 }
323
324 let decision = self.scheduler.schedule(self.current_time);
326
327 let iteration_time = if decision.num_scheduled() > 0 {
329 let running = self.scheduler.running_mut();
331 let batch_requests: Vec<&_> = decision
332 .scheduled_new
333 .iter()
334 .chain(decision.scheduled_running.iter())
335 .filter_map(|&idx| running.get(idx))
336 .collect();
337
338 let tokens_per_req: Vec<u32> = batch_requests
339 .iter()
340 .enumerate()
341 .map(|(i, _req)| {
342 let idx = if i < decision.scheduled_new.len() {
343 decision.scheduled_new[i]
344 } else {
345 decision.scheduled_running[i - decision.scheduled_new.len()]
346 };
347 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
348 })
349 .collect();
350
351 self.compute_engine
352 .calculate_iteration_time(&batch_requests, &tokens_per_req)
353 } else {
354 0.001 };
356
357 self.current_time += iteration_time;
359
360 let mut prefilling_reqs = std::collections::HashSet::new();
362 for (&idx, &_tokens) in &decision.tokens_per_request {
363 if let Some(request) = self.scheduler.running().get(idx) {
364 if request.is_prefill() {
365 prefilling_reqs.insert(idx);
366 }
367 }
368 }
369
370 for (&idx, &tokens) in &decision.tokens_per_request {
372 if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
373 request.record_generated_tokens(tokens, self.current_time);
374 }
375 }
376
377 let kv_util = self.scheduler.kv_cache_manager().utilization();
379
380 let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
382 let running = self.scheduler.running_mut();
383 let batch_requests: Vec<&_> = decision
384 .scheduled_new
385 .iter()
386 .chain(decision.scheduled_running.iter())
387 .filter_map(|&idx| running.get(idx))
388 .collect();
389
390 let tokens_per_req: Vec<u32> = batch_requests
391 .iter()
392 .enumerate()
393 .map(|(i, _req)| {
394 let idx = if i < decision.scheduled_new.len() {
395 decision.scheduled_new[i]
396 } else {
397 decision.scheduled_running[i - decision.scheduled_new.len()]
398 };
399 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
400 })
401 .collect();
402
403 let bytes_transferred = self
404 .compute_engine
405 .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
406 let bandwidth_util = self
407 .compute_engine
408 .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
409
410 let flops_util = self.compute_engine.calculate_flops_utilization(
411 &batch_requests,
412 &tokens_per_req,
413 iteration_time,
414 );
415
416 (bandwidth_util, flops_util)
417 } else {
418 (0.0, 0.0)
419 };
420
421 self.metrics
422 .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
423
424 if self.current_time >= self.next_sample_time {
426 let running = self.scheduler.running();
428 let mut num_prefilling = 0;
429 let mut num_decoding = 0;
430 let mut prefill_tokens = 0;
431 let mut decode_tokens = 0;
432
433 for req in running {
434 if req.is_prefill() {
435 num_prefilling += 1;
436 } else {
437 num_decoding += 1;
438 }
439 }
440
441 for (&idx, &tokens) in &decision.tokens_per_request {
443 if prefilling_reqs.contains(&idx) {
444 prefill_tokens += tokens;
445 } else {
446 decode_tokens += tokens;
447 }
448 }
449
450 self.metrics.record_throughput_sample(self.current_time);
452
453 let input_throughput = prefill_tokens as f64 / self.sample_interval;
455 let output_throughput = decode_tokens as f64 / self.sample_interval;
456
457 let (ttft_mean, tpot_mean) = self
459 .metrics
460 .get_interval_latencies(self.prev_sample_time, self.current_time);
461
462 self.time_series_data.push(TimeSeriesPoint {
463 time: self.current_time,
464 arrivals: self.metrics.total_requests,
465 running: self.scheduler.num_running(),
466 waiting: self.scheduler.num_waiting(),
467 kv_cache_util: kv_util,
468 num_prefilling,
469 num_decoding,
470 prefill_tokens,
471 decode_tokens,
472 input_throughput,
473 output_throughput,
474 ttft_p50: ttft_mean,
475 tpot_p50: tpot_mean,
476 });
477 self.prev_sample_time = self.current_time;
478 self.next_sample_time = self.current_time + self.sample_interval;
479 }
480
481 for request in decision.completed {
483 self.scheduler
485 .kv_cache_manager_mut()
486 .free_blocks(&request.kv_blocks);
487
488 self.metrics.record_request_completion(&request);
489
490 self.request_generator
492 .on_request_complete(self.current_time);
493 }
494
495 if self.current_time - last_callback_time >= callback_interval {
497 let latency_samples = self.metrics.get_latency_samples();
498 let input_lengths = self.metrics.get_input_lengths();
499 let output_lengths = self.metrics.get_output_lengths();
500
501 let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
503 let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
504 let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
505 let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
506 let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
507 let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
508 let input_delta = &input_lengths[self.last_sent_input_count..];
509 let output_delta = &output_lengths[self.last_sent_output_count..];
510
511 let progress = ProgressInfo {
512 current_time: self.current_time,
513 completed_requests: self.metrics.completed_requests,
514 total_requests: self.metrics.total_requests,
515 running: self.scheduler.num_running(),
516 waiting: self.scheduler.num_waiting(),
517 kv_cache_util: kv_util,
518 time_series: Some(&self.time_series_data),
519 metrics: Some(self.metrics.compute_summary(self.current_time)),
520 latency_samples: Some((
521 (ttft_delta, ttft_timestamps_delta),
522 (e2e_delta, e2e_timestamps_delta),
523 (tpot_delta, tpot_timestamps_delta),
524 )),
525 distribution_samples: Some((input_delta, output_delta)),
526 };
527 callback(progress);
528
529 self.last_sent_ttft_count = latency_samples.0 .0.len();
531 self.last_sent_e2e_count = latency_samples.1 .0.len();
532 self.last_sent_tpot_count = latency_samples.2 .0.len();
533 self.last_sent_input_count = input_lengths.len();
534 self.last_sent_output_count = output_lengths.len();
535 last_callback_time = self.current_time;
536 }
537
538 if self.should_terminate() {
542 let latency_samples = self.metrics.get_latency_samples();
544 let input_lengths = self.metrics.get_input_lengths();
545 let output_lengths = self.metrics.get_output_lengths();
546
547 let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
549 let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
550 let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
551 let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
552 let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
553 let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
554 let input_delta = &input_lengths[self.last_sent_input_count..];
555 let output_delta = &output_lengths[self.last_sent_output_count..];
556
557 let progress = ProgressInfo {
558 current_time: self.current_time,
559 completed_requests: self.metrics.completed_requests,
560 total_requests: self.metrics.total_requests,
561 running: self.scheduler.num_running(),
562 waiting: self.scheduler.num_waiting(),
563 kv_cache_util: kv_util,
564 time_series: Some(&self.time_series_data),
565 metrics: Some(self.metrics.compute_summary(self.current_time)),
566 latency_samples: Some((
567 (ttft_delta, ttft_timestamps_delta),
568 (e2e_delta, e2e_timestamps_delta),
569 (tpot_delta, tpot_timestamps_delta),
570 )),
571 distribution_samples: Some((input_delta, output_delta)),
572 };
573 callback(progress);
574 break;
575 }
576 }
577
578 Ok(())
579 }
580
581 pub fn get_metrics_summary(&self) -> crate::metrics::MetricsSummary {
582 self.metrics.compute_summary(self.current_time)
583 }
584
585 pub fn get_time_series_data(&self) -> &[TimeSeriesPoint] {
586 &self.time_series_data
587 }
588
589 pub fn get_input_lengths(&self) -> &[u32] {
590 self.metrics.get_input_lengths()
591 }
592
593 pub fn get_output_lengths(&self) -> &[u32] {
594 self.metrics.get_output_lengths()
595 }
596
597 pub fn get_current_time(&self) -> f64 {
598 self.current_time
599 }
600
601 pub fn get_latency_samples(
602 &self,
603 ) -> (
604 (&[f64], &[f64]), (&[f64], &[f64]), (&[f64], &[f64]), ) {
608 self.metrics.get_latency_samples()
609 }
610
611 fn log_progress(&self) {
612 #[cfg(not(target_arch = "wasm32"))]
613 println!(
614 "[{:.2}s] Iteration {}: {} running, {} waiting, {:.1}% KV cache used",
615 self.current_time,
616 self.iteration,
617 self.scheduler.num_running(),
618 self.scheduler.num_waiting(),
619 self.scheduler.kv_cache_manager().utilization() * 100.0,
620 );
621 }
622
623 fn should_terminate(&self) -> bool {
624 self.request_generator.is_finished()
626 && self.scheduler.num_running() == 0
627 && self.scheduler.num_waiting() == 0
628 }
629
630 fn print_final_metrics(&self) {
631 let summary = self.metrics.compute_summary(self.current_time);
632 summary.print();
633 }
634}