1use crate::compute::ComputeEngine;
2use crate::config::Config;
3use crate::kv_cache::KVCacheManager;
4use crate::metrics::MetricsCollector;
5use crate::request::RequestGenerator;
6use crate::scheduler::Scheduler;
7
/// One sample of simulator state, recorded roughly every
/// `sample_interval` simulated seconds by the main loop.
#[derive(Debug, Clone)]
pub struct TimeSeriesPoint {
    /// Simulated time of the sample, in seconds.
    pub time: f64,
    /// Total requests admitted so far.
    pub arrivals: u64,
    /// Requests running at sample time.
    pub running: usize,
    /// Requests waiting to be scheduled at sample time.
    pub waiting: usize,
    /// KV-cache utilization in `[0, 1]`.
    pub kv_cache_util: f64,
    /// Running requests still in the prefill phase.
    pub num_prefilling: usize,
    /// Running requests in the decode phase.
    pub num_decoding: usize,
    /// Prefill tokens processed in the sampled iteration.
    pub prefill_tokens: u32,
    /// Decode tokens generated in the sampled iteration.
    pub decode_tokens: u32,
    /// Prefill token throughput (tokens per `sample_interval`).
    pub input_throughput: f64,
    /// Decode token throughput (tokens per `sample_interval`).
    pub output_throughput: f64,
    /// Interval TTFT statistic (a mean is stored under this `p50` name).
    pub ttft_p50: f64,
    /// Interval TPOT statistic (a mean is stored under this `p50` name).
    pub tpot_p50: f64,
}
24
25pub struct ProgressInfo<'a> {
26 pub current_time: f64,
27 pub completed_requests: u64,
28 pub total_requests: u64,
29 pub running: usize,
30 pub waiting: usize,
31 pub kv_cache_util: f64,
32 pub time_series: Option<&'a [TimeSeriesPoint]>,
33 pub metrics: Option<crate::metrics::MetricsSummary>,
34 pub latency_samples: Option<((&'a [f64], &'a [f64]), (&'a [f64], &'a [f64]), (&'a [f64], &'a [f64]))>,
35 pub distribution_samples: Option<(&'a [u32], &'a [u32])>, }
37
/// Discrete-event simulator that drives the scheduler / compute-engine
/// loop and collects metrics and time-series samples.
pub struct Simulator {
    // Core components, wired together from a `Config` in `new`.
    scheduler: Scheduler,
    compute_engine: ComputeEngine,
    request_generator: RequestGenerator,
    metrics: MetricsCollector,
    // Time-series sampling: a `TimeSeriesPoint` is appended once the
    // clock passes `next_sample_time`; `prev_sample_time` bounds the
    // interval used for latency statistics.
    time_series_data: Vec<TimeSeriesPoint>,
    sample_interval: f64,
    next_sample_time: f64,
    prev_sample_time: f64,

    // Simulated clock (seconds) and iteration counter; `log_interval`
    // is how often (in iterations) `run` prints progress.
    current_time: f64,
    iteration: u64,
    log_interval: u64,

    // High-water marks of samples already delivered through
    // `run_with_callback`, so each callback receives only deltas.
    last_sent_ttft_count: usize,
    last_sent_e2e_count: usize,
    last_sent_tpot_count: usize,
    last_sent_input_count: usize,
    last_sent_output_count: usize,
}
59
60impl Simulator {
61 pub fn new(config: Config) -> Result<Self, String> {
62 let kv_cache_manager = KVCacheManager::new(
63 config.hardware.kv_cache_capacity,
64 config.scheduler.block_size,
65 config.model.kv_cache_bytes_per_token,
66 true, );
68
69 let scheduler = Scheduler::new(
70 config.scheduler.clone(),
71 config.hardware.clone(),
72 config.model.clone(),
73 kv_cache_manager,
74 )?;
75
76 let compute_engine = ComputeEngine::new(config.hardware, config.model);
77 let request_generator = RequestGenerator::new(config.workload);
78 let metrics = MetricsCollector::new(0.0);
79
80 Ok(Self {
81 scheduler,
82 compute_engine,
83 request_generator,
84 metrics,
85 time_series_data: Vec::new(),
86 sample_interval: 0.1,
87 next_sample_time: 0.0,
88 prev_sample_time: 0.0,
89 current_time: 0.0,
90 iteration: 0,
91 log_interval: config.simulation.log_interval,
92 last_sent_ttft_count: 0,
93 last_sent_e2e_count: 0,
94 last_sent_tpot_count: 0,
95 last_sent_input_count: 0,
96 last_sent_output_count: 0,
97 })
98 }
99
100 pub fn run(&mut self) {
102 #[cfg(not(target_arch = "wasm32"))]
103 println!("Starting simulation...\n");
104
105 loop {
106 self.iteration += 1;
107
108 while let Some(request) = self.request_generator.next_if_before(self.current_time) {
110 self.scheduler.add_request(request);
111 self.metrics.total_requests += 1;
112 }
113
114 let decision = self.scheduler.schedule(self.current_time);
116
117 let iteration_time = if decision.num_scheduled() > 0 {
119 let running = self.scheduler.running_mut();
121 let batch_requests: Vec<&_> = decision
122 .scheduled_new
123 .iter()
124 .chain(decision.scheduled_running.iter())
125 .filter_map(|&idx| running.get(idx))
126 .collect();
127
128 let tokens_per_req: Vec<u32> = batch_requests
129 .iter()
130 .enumerate()
131 .map(|(i, _req)| {
132 let idx = if i < decision.scheduled_new.len() {
133 decision.scheduled_new[i]
134 } else {
135 decision.scheduled_running[i - decision.scheduled_new.len()]
136 };
137 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
138 })
139 .collect();
140
141 self.compute_engine
142 .calculate_iteration_time(&batch_requests, &tokens_per_req)
143 } else {
144 0.001 };
146
147 self.current_time += iteration_time;
149
150 let mut prefilling_reqs = std::collections::HashSet::new();
152 for (&idx, &_tokens) in &decision.tokens_per_request {
153 if let Some(request) = self.scheduler.running().get(idx) {
154 if request.is_prefill() {
155 prefilling_reqs.insert(idx);
156 }
157 }
158 }
159
160 for (&idx, &tokens) in &decision.tokens_per_request {
162 if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
163 request.record_generated_tokens(tokens, self.current_time);
164 }
165 }
166
167 let kv_util = self.scheduler.kv_cache_manager().utilization();
169
170 let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
172 let running = self.scheduler.running_mut();
173 let batch_requests: Vec<&_> = decision
174 .scheduled_new
175 .iter()
176 .chain(decision.scheduled_running.iter())
177 .filter_map(|&idx| running.get(idx))
178 .collect();
179
180 let tokens_per_req: Vec<u32> = batch_requests
181 .iter()
182 .enumerate()
183 .map(|(i, _req)| {
184 let idx = if i < decision.scheduled_new.len() {
185 decision.scheduled_new[i]
186 } else {
187 decision.scheduled_running[i - decision.scheduled_new.len()]
188 };
189 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
190 })
191 .collect();
192
193 let bytes_transferred = self
194 .compute_engine
195 .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
196 let bandwidth_util = self
197 .compute_engine
198 .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
199
200 let flops_util = self
201 .compute_engine
202 .calculate_flops_utilization(&batch_requests, &tokens_per_req, iteration_time);
203
204 (bandwidth_util, flops_util)
205 } else {
206 (0.0, 0.0)
207 };
208
209 self.metrics
210 .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
211
212 if self.current_time >= self.next_sample_time {
214 let running = self.scheduler.running();
216 let mut num_prefilling = 0;
217 let mut num_decoding = 0;
218 let mut prefill_tokens = 0;
219 let mut decode_tokens = 0;
220
221 for req in running {
222 if req.is_prefill() {
223 num_prefilling += 1;
224 } else {
225 num_decoding += 1;
226 }
227 }
228
229 for (&idx, &tokens) in &decision.tokens_per_request {
232 if prefilling_reqs.contains(&idx) {
233 prefill_tokens += tokens;
234 } else {
235 decode_tokens += tokens;
236 }
237 }
238
239 self.metrics.record_throughput_sample(self.current_time);
241
242 let input_throughput = prefill_tokens as f64 / self.sample_interval;
244 let output_throughput = decode_tokens as f64 / self.sample_interval;
245
246 let (ttft_mean, tpot_mean) = self.metrics.get_interval_latencies(
248 self.prev_sample_time,
249 self.current_time
250 );
251
252 self.time_series_data.push(TimeSeriesPoint {
253 time: self.current_time,
254 arrivals: self.metrics.total_requests,
255 running: self.scheduler.num_running(),
256 waiting: self.scheduler.num_waiting(),
257 kv_cache_util: kv_util,
258 num_prefilling,
259 num_decoding,
260 prefill_tokens,
261 decode_tokens,
262 input_throughput,
263 output_throughput,
264 ttft_p50: ttft_mean,
265 tpot_p50: tpot_mean,
266 });
267 self.prev_sample_time = self.current_time;
268 self.next_sample_time = self.current_time + self.sample_interval;
269 }
270
271 for request in decision.completed {
273 self.scheduler
275 .kv_cache_manager_mut()
276 .free_blocks(&request.kv_blocks);
277
278 self.metrics.record_request_completion(&request);
279
280 self.request_generator.on_request_complete(self.current_time);
282 }
283
284 if self.iteration % self.log_interval == 0 {
286 self.log_progress();
287 }
288
289 if self.should_terminate() {
291 break;
292 }
293 }
294
295 #[cfg(not(target_arch = "wasm32"))]
296 println!("\nSimulation complete!");
297 #[cfg(not(target_arch = "wasm32"))]
298 self.print_final_metrics();
299 }
300
301 pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<(), String>
303 where
304 F: FnMut(ProgressInfo),
305 {
306 let mut last_callback_time = 0.0;
307 let callback_interval = 1.0; loop {
310 self.iteration += 1;
311
312 while let Some(request) = self.request_generator.next_if_before(self.current_time) {
314 self.scheduler.add_request(request);
315 self.metrics.total_requests += 1;
316 }
317
318 let decision = self.scheduler.schedule(self.current_time);
320
321 let iteration_time = if decision.num_scheduled() > 0 {
323 let running = self.scheduler.running_mut();
325 let batch_requests: Vec<&_> = decision
326 .scheduled_new
327 .iter()
328 .chain(decision.scheduled_running.iter())
329 .filter_map(|&idx| running.get(idx))
330 .collect();
331
332 let tokens_per_req: Vec<u32> = batch_requests
333 .iter()
334 .enumerate()
335 .map(|(i, _req)| {
336 let idx = if i < decision.scheduled_new.len() {
337 decision.scheduled_new[i]
338 } else {
339 decision.scheduled_running[i - decision.scheduled_new.len()]
340 };
341 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
342 })
343 .collect();
344
345 self.compute_engine
346 .calculate_iteration_time(&batch_requests, &tokens_per_req)
347 } else {
348 0.001 };
350
351 self.current_time += iteration_time;
353
354 let mut prefilling_reqs = std::collections::HashSet::new();
356 for (&idx, &_tokens) in &decision.tokens_per_request {
357 if let Some(request) = self.scheduler.running().get(idx) {
358 if request.is_prefill() {
359 prefilling_reqs.insert(idx);
360 }
361 }
362 }
363
364 for (&idx, &tokens) in &decision.tokens_per_request {
366 if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
367 request.record_generated_tokens(tokens, self.current_time);
368 }
369 }
370
371 let kv_util = self.scheduler.kv_cache_manager().utilization();
373
374 let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
376 let running = self.scheduler.running_mut();
377 let batch_requests: Vec<&_> = decision
378 .scheduled_new
379 .iter()
380 .chain(decision.scheduled_running.iter())
381 .filter_map(|&idx| running.get(idx))
382 .collect();
383
384 let tokens_per_req: Vec<u32> = batch_requests
385 .iter()
386 .enumerate()
387 .map(|(i, _req)| {
388 let idx = if i < decision.scheduled_new.len() {
389 decision.scheduled_new[i]
390 } else {
391 decision.scheduled_running[i - decision.scheduled_new.len()]
392 };
393 *decision.tokens_per_request.get(&idx).unwrap_or(&0)
394 })
395 .collect();
396
397 let bytes_transferred = self
398 .compute_engine
399 .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
400 let bandwidth_util = self
401 .compute_engine
402 .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
403
404 let flops_util = self
405 .compute_engine
406 .calculate_flops_utilization(&batch_requests, &tokens_per_req, iteration_time);
407
408 (bandwidth_util, flops_util)
409 } else {
410 (0.0, 0.0)
411 };
412
413 self.metrics
414 .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
415
416 if self.current_time >= self.next_sample_time {
418 let running = self.scheduler.running();
420 let mut num_prefilling = 0;
421 let mut num_decoding = 0;
422 let mut prefill_tokens = 0;
423 let mut decode_tokens = 0;
424
425 for req in running {
426 if req.is_prefill() {
427 num_prefilling += 1;
428 } else {
429 num_decoding += 1;
430 }
431 }
432
433 for (&idx, &tokens) in &decision.tokens_per_request {
435 if prefilling_reqs.contains(&idx) {
436 prefill_tokens += tokens;
437 } else {
438 decode_tokens += tokens;
439 }
440 }
441
442 self.metrics.record_throughput_sample(self.current_time);
444
445 let input_throughput = prefill_tokens as f64 / self.sample_interval;
447 let output_throughput = decode_tokens as f64 / self.sample_interval;
448
449 let (ttft_mean, tpot_mean) = self.metrics.get_interval_latencies(
451 self.prev_sample_time,
452 self.current_time
453 );
454
455 self.time_series_data.push(TimeSeriesPoint {
456 time: self.current_time,
457 arrivals: self.metrics.total_requests,
458 running: self.scheduler.num_running(),
459 waiting: self.scheduler.num_waiting(),
460 kv_cache_util: kv_util,
461 num_prefilling,
462 num_decoding,
463 prefill_tokens,
464 decode_tokens,
465 input_throughput,
466 output_throughput,
467 ttft_p50: ttft_mean,
468 tpot_p50: tpot_mean,
469 });
470 self.prev_sample_time = self.current_time;
471 self.next_sample_time = self.current_time + self.sample_interval;
472 }
473
474 for request in decision.completed {
476 self.scheduler
478 .kv_cache_manager_mut()
479 .free_blocks(&request.kv_blocks);
480
481 self.metrics.record_request_completion(&request);
482
483 self.request_generator.on_request_complete(self.current_time);
485 }
486
487 if self.current_time - last_callback_time >= callback_interval {
489 let latency_samples = self.metrics.get_latency_samples();
490 let input_lengths = self.metrics.get_input_lengths();
491 let output_lengths = self.metrics.get_output_lengths();
492
493 let ttft_delta = &latency_samples.0.0[self.last_sent_ttft_count..];
495 let ttft_timestamps_delta = &latency_samples.0.1[self.last_sent_ttft_count..];
496 let e2e_delta = &latency_samples.1.0[self.last_sent_e2e_count..];
497 let e2e_timestamps_delta = &latency_samples.1.1[self.last_sent_e2e_count..];
498 let tpot_delta = &latency_samples.2.0[self.last_sent_tpot_count..];
499 let tpot_timestamps_delta = &latency_samples.2.1[self.last_sent_tpot_count..];
500 let input_delta = &input_lengths[self.last_sent_input_count..];
501 let output_delta = &output_lengths[self.last_sent_output_count..];
502
503 let progress = ProgressInfo {
504 current_time: self.current_time,
505 completed_requests: self.metrics.completed_requests,
506 total_requests: self.metrics.total_requests,
507 running: self.scheduler.num_running(),
508 waiting: self.scheduler.num_waiting(),
509 kv_cache_util: kv_util,
510 time_series: Some(&self.time_series_data),
511 metrics: Some(self.metrics.compute_summary(self.current_time)),
512 latency_samples: Some(((ttft_delta, ttft_timestamps_delta), (e2e_delta, e2e_timestamps_delta), (tpot_delta, tpot_timestamps_delta))),
513 distribution_samples: Some((input_delta, output_delta)),
514 };
515 callback(progress);
516
517 self.last_sent_ttft_count = latency_samples.0.0.len();
519 self.last_sent_e2e_count = latency_samples.1.0.len();
520 self.last_sent_tpot_count = latency_samples.2.0.len();
521 self.last_sent_input_count = input_lengths.len();
522 self.last_sent_output_count = output_lengths.len();
523 last_callback_time = self.current_time;
524 }
525
526 if self.should_terminate() {
530 let latency_samples = self.metrics.get_latency_samples();
532 let input_lengths = self.metrics.get_input_lengths();
533 let output_lengths = self.metrics.get_output_lengths();
534
535 let ttft_delta = &latency_samples.0.0[self.last_sent_ttft_count..];
537 let ttft_timestamps_delta = &latency_samples.0.1[self.last_sent_ttft_count..];
538 let e2e_delta = &latency_samples.1.0[self.last_sent_e2e_count..];
539 let e2e_timestamps_delta = &latency_samples.1.1[self.last_sent_e2e_count..];
540 let tpot_delta = &latency_samples.2.0[self.last_sent_tpot_count..];
541 let tpot_timestamps_delta = &latency_samples.2.1[self.last_sent_tpot_count..];
542 let input_delta = &input_lengths[self.last_sent_input_count..];
543 let output_delta = &output_lengths[self.last_sent_output_count..];
544
545 let progress = ProgressInfo {
546 current_time: self.current_time,
547 completed_requests: self.metrics.completed_requests,
548 total_requests: self.metrics.total_requests,
549 running: self.scheduler.num_running(),
550 waiting: self.scheduler.num_waiting(),
551 kv_cache_util: kv_util,
552 time_series: Some(&self.time_series_data),
553 metrics: Some(self.metrics.compute_summary(self.current_time)),
554 latency_samples: Some(((ttft_delta, ttft_timestamps_delta), (e2e_delta, e2e_timestamps_delta), (tpot_delta, tpot_timestamps_delta))),
555 distribution_samples: Some((input_delta, output_delta)),
556 };
557 callback(progress);
558 break;
559 }
560 }
561
562 Ok(())
563 }
564
    /// Computes an aggregate metrics summary as of the current simulated
    /// time.
    pub fn get_metrics_summary(&self) -> crate::metrics::MetricsSummary {
        self.metrics.compute_summary(self.current_time)
    }
568
569 pub fn get_time_series_data(&self) -> &[TimeSeriesPoint] {
570 &self.time_series_data
571 }
572
    /// Returns the recorded input (prompt) lengths, one per request.
    pub fn get_input_lengths(&self) -> &[u32] {
        self.metrics.get_input_lengths()
    }
576
    /// Returns the recorded output (generated) lengths, one per request.
    pub fn get_output_lengths(&self) -> &[u32] {
        self.metrics.get_output_lengths()
    }
580
    /// Returns the current simulated time, in seconds.
    pub fn get_current_time(&self) -> f64 {
        self.current_time
    }
584
585 pub fn get_latency_samples(&self) -> (
586 (&[f64], &[f64]), (&[f64], &[f64]), (&[f64], &[f64]), ) {
590 self.metrics.get_latency_samples()
591 }
592
    /// Prints a one-line progress summary to stdout.
    /// Compiled out entirely on wasm32 targets.
    fn log_progress(&self) {
        #[cfg(not(target_arch = "wasm32"))]
        println!(
            "[{:.2}s] Iteration {}: {} running, {} waiting, {:.1}% KV cache used",
            self.current_time,
            self.iteration,
            self.scheduler.num_running(),
            self.scheduler.num_waiting(),
            // utilization() is a fraction in [0, 1]; scale to percent.
            self.scheduler.kv_cache_manager().utilization() * 100.0,
        );
    }
604
605 fn should_terminate(&self) -> bool {
606 self.request_generator.is_finished()
608 && self.scheduler.num_running() == 0
609 && self.scheduler.num_waiting() == 0
610 }
611
612 fn print_final_metrics(&self) {
613 let summary = self.metrics.compute_summary(self.current_time);
614 summary.print();
615 }
616}