inference_lab/simulation/
simulator.rs

1use crate::compute::ComputeEngine;
2use crate::config::Config;
3use crate::kv_cache::KVCacheManager;
4use crate::metrics::MetricsCollector;
5use crate::request::RequestGenerator;
6use crate::scheduler::Scheduler;
7
/// One periodic sample of simulator state, used to plot behavior over simulated time.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TimeSeriesPoint {
    /// Simulated time at which the sample was taken, in seconds.
    pub time: f64,
    /// Cumulative number of requests that have arrived so far.
    pub arrivals: u64,
    /// Number of requests in the scheduler's running set.
    pub running: usize,
    /// Number of requests waiting to be scheduled.
    pub waiting: usize,
    /// Fraction of the KV cache in use (printed as a percentage by the progress log).
    pub kv_cache_util: f64,
    /// Running requests that are still in their prefill phase.
    pub num_prefilling: usize,
    /// Running requests that are past prefill (decoding).
    pub num_decoding: usize,
    /// Prefill tokens scheduled in the iteration that produced this sample.
    pub prefill_tokens: u32,
    /// Decode tokens scheduled in the iteration that produced this sample.
    pub decode_tokens: u32,
    /// Input (prefill) tokens per second over the sampling window.
    pub input_throughput: f64,
    /// Output (decode) tokens per second over the sampling window.
    pub output_throughput: f64,
    /// TTFT for the recent window (ms). Despite the `p50` name, the simulator
    /// stores the interval *mean* returned by the metrics collector here.
    pub ttft_p50: f64,
    /// TPOT for the recent window (ms). Despite the `p50` name, the simulator
    /// stores the interval *mean* returned by the metrics collector here.
    pub tpot_p50: f64,
}
24
25pub struct ProgressInfo<'a> {
26    pub current_time: f64,
27    pub completed_requests: u64,
28    pub total_requests: u64,
29    pub running: usize,
30    pub waiting: usize,
31    pub kv_cache_util: f64,
32    pub time_series: Option<&'a [TimeSeriesPoint]>,
33    pub metrics: Option<crate::metrics::MetricsSummary>,
34    pub latency_samples: Option<(
35        (&'a [f64], &'a [f64]),
36        (&'a [f64], &'a [f64]),
37        (&'a [f64], &'a [f64]),
38    )>,
39    pub distribution_samples: Option<(&'a [u32], &'a [u32])>, // (input_lengths, output_lengths)
40}
41
42pub struct Simulator {
43    scheduler: Scheduler,
44    compute_engine: ComputeEngine,
45    request_generator: RequestGenerator,
46    metrics: MetricsCollector,
47    time_series_data: Vec<TimeSeriesPoint>,
48    sample_interval: f64,
49    next_sample_time: f64,
50    prev_sample_time: f64,
51
52    current_time: f64,
53    iteration: u64,
54    log_interval: u64,
55
56    // Track last sent sample counts for streaming deltas (one per metric type)
57    last_sent_ttft_count: usize,
58    last_sent_e2e_count: usize,
59    last_sent_tpot_count: usize,
60    last_sent_input_count: usize,
61    last_sent_output_count: usize,
62}
63
64impl Simulator {
65    pub fn new(config: Config) -> Result<Self, String> {
66        let kv_cache_manager = KVCacheManager::new(
67            config.hardware.kv_cache_capacity,
68            config.scheduler.block_size,
69            config.model.kv_cache_bytes_per_token,
70            true, // enable_prefix_caching
71        );
72
73        let scheduler = Scheduler::new(
74            config.scheduler.clone(),
75            config.hardware.clone(),
76            config.model.clone(),
77            kv_cache_manager,
78        )?;
79
80        let compute_engine = ComputeEngine::new(config.hardware, config.model);
81        let request_generator = RequestGenerator::new(config.workload);
82        let metrics = MetricsCollector::new(0.0);
83
84        Ok(Self {
85            scheduler,
86            compute_engine,
87            request_generator,
88            metrics,
89            time_series_data: Vec::new(),
90            sample_interval: 0.1,
91            next_sample_time: 0.0,
92            prev_sample_time: 0.0,
93            current_time: 0.0,
94            iteration: 0,
95            log_interval: config.simulation.log_interval,
96            last_sent_ttft_count: 0,
97            last_sent_e2e_count: 0,
98            last_sent_tpot_count: 0,
99            last_sent_input_count: 0,
100            last_sent_output_count: 0,
101        })
102    }
103
104    /// Run the simulation
105    pub fn run(&mut self) {
106        #[cfg(not(target_arch = "wasm32"))]
107        println!("Starting simulation...\n");
108
109        loop {
110            self.iteration += 1;
111
112            // 1. Generate new arrivals up to current_time
113            while let Some(request) = self.request_generator.next_if_before(self.current_time) {
114                self.scheduler.add_request(request);
115                self.metrics.total_requests += 1;
116            }
117
118            // 2. Run scheduler
119            let decision = self.scheduler.schedule(self.current_time);
120
121            // 3. Calculate iteration time
122            let iteration_time = if decision.num_scheduled() > 0 {
123                // Build batch of scheduled requests
124                let running = self.scheduler.running_mut();
125                let batch_requests: Vec<&_> = decision
126                    .scheduled_new
127                    .iter()
128                    .chain(decision.scheduled_running.iter())
129                    .filter_map(|&idx| running.get(idx))
130                    .collect();
131
132                let tokens_per_req: Vec<u32> = batch_requests
133                    .iter()
134                    .enumerate()
135                    .map(|(i, _req)| {
136                        let idx = if i < decision.scheduled_new.len() {
137                            decision.scheduled_new[i]
138                        } else {
139                            decision.scheduled_running[i - decision.scheduled_new.len()]
140                        };
141                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
142                    })
143                    .collect();
144
145                self.compute_engine
146                    .calculate_iteration_time(&batch_requests, &tokens_per_req)
147            } else {
148                0.001 // Small time step when idle
149            };
150
151            // 4. Advance time
152            self.current_time += iteration_time;
153
154            // 5. Determine which requests were prefilling vs decoding BEFORE updating state
155            let mut prefilling_reqs = std::collections::HashSet::new();
156            for (&idx, &_tokens) in &decision.tokens_per_request {
157                if let Some(request) = self.scheduler.running().get(idx) {
158                    if request.is_prefill() {
159                        prefilling_reqs.insert(idx);
160                    }
161                }
162            }
163
164            // 6. Update request states
165            for (&idx, &tokens) in &decision.tokens_per_request {
166                if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
167                    request.record_generated_tokens(tokens, self.current_time);
168                }
169            }
170
171            // 7. Record iteration metrics (before moving completed requests)
172            let kv_util = self.scheduler.kv_cache_manager().utilization();
173
174            // Calculate bandwidth and flops utilization
175            let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
176                let running = self.scheduler.running_mut();
177                let batch_requests: Vec<&_> = decision
178                    .scheduled_new
179                    .iter()
180                    .chain(decision.scheduled_running.iter())
181                    .filter_map(|&idx| running.get(idx))
182                    .collect();
183
184                let tokens_per_req: Vec<u32> = batch_requests
185                    .iter()
186                    .enumerate()
187                    .map(|(i, _req)| {
188                        let idx = if i < decision.scheduled_new.len() {
189                            decision.scheduled_new[i]
190                        } else {
191                            decision.scheduled_running[i - decision.scheduled_new.len()]
192                        };
193                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
194                    })
195                    .collect();
196
197                let bytes_transferred = self
198                    .compute_engine
199                    .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
200                let bandwidth_util = self
201                    .compute_engine
202                    .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
203
204                let flops_util = self.compute_engine.calculate_flops_utilization(
205                    &batch_requests,
206                    &tokens_per_req,
207                    iteration_time,
208                );
209
210                (bandwidth_util, flops_util)
211            } else {
212                (0.0, 0.0)
213            };
214
215            self.metrics
216                .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
217
218            // 7. Record time-series data (BEFORE handling completed requests)
219            if self.current_time >= self.next_sample_time {
220                // Calculate prefill vs decode breakdown
221                let running = self.scheduler.running();
222                let mut num_prefilling = 0;
223                let mut num_decoding = 0;
224                let mut prefill_tokens = 0;
225                let mut decode_tokens = 0;
226
227                for req in running {
228                    if req.is_prefill() {
229                        num_prefilling += 1;
230                    } else {
231                        num_decoding += 1;
232                    }
233                }
234
235                // Count tokens scheduled in this iteration
236                // Use the prefilling_reqs set we captured before updating state
237                for (&idx, &tokens) in &decision.tokens_per_request {
238                    if prefilling_reqs.contains(&idx) {
239                        prefill_tokens += tokens;
240                    } else {
241                        decode_tokens += tokens;
242                    }
243                }
244
245                // Record throughput sample
246                self.metrics.record_throughput_sample(self.current_time);
247
248                // Calculate windowed throughput (tokens per second)
249                let input_throughput = prefill_tokens as f64 / self.sample_interval;
250                let output_throughput = decode_tokens as f64 / self.sample_interval;
251
252                // Get latency mean for events since last sample
253                let (ttft_mean, tpot_mean) = self
254                    .metrics
255                    .get_interval_latencies(self.prev_sample_time, self.current_time);
256
257                self.time_series_data.push(TimeSeriesPoint {
258                    time: self.current_time,
259                    arrivals: self.metrics.total_requests,
260                    running: self.scheduler.num_running(),
261                    waiting: self.scheduler.num_waiting(),
262                    kv_cache_util: kv_util,
263                    num_prefilling,
264                    num_decoding,
265                    prefill_tokens,
266                    decode_tokens,
267                    input_throughput,
268                    output_throughput,
269                    ttft_p50: ttft_mean,
270                    tpot_p50: tpot_mean,
271                });
272                self.prev_sample_time = self.current_time;
273                self.next_sample_time = self.current_time + self.sample_interval;
274            }
275
276            // 8. Handle completed requests
277            for request in decision.completed {
278                // Free KV cache blocks
279                self.scheduler
280                    .kv_cache_manager_mut()
281                    .free_blocks(&request.kv_blocks);
282
283                self.metrics.record_request_completion(&request);
284
285                // For closed-loop workloads, generate a new request when one completes
286                self.request_generator
287                    .on_request_complete(self.current_time);
288            }
289
290            // 9. Periodic logging
291            if self.iteration % self.log_interval == 0 {
292                self.log_progress();
293            }
294
295            // 10. Check termination conditions
296            if self.should_terminate() {
297                break;
298            }
299        }
300
301        #[cfg(not(target_arch = "wasm32"))]
302        println!("\nSimulation complete!");
303        #[cfg(not(target_arch = "wasm32"))]
304        self.print_final_metrics();
305    }
306
307    /// Run the simulation with progress callbacks
308    pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<(), String>
309    where
310        F: FnMut(ProgressInfo),
311    {
312        let mut last_callback_time = 0.0;
313        let callback_interval = 1.0; // Call callback every 1.0 seconds
314
315        loop {
316            self.iteration += 1;
317
318            // 1. Generate new arrivals up to current_time
319            while let Some(request) = self.request_generator.next_if_before(self.current_time) {
320                self.scheduler.add_request(request);
321                self.metrics.total_requests += 1;
322            }
323
324            // 2. Run scheduler
325            let decision = self.scheduler.schedule(self.current_time);
326
327            // 3. Calculate iteration time
328            let iteration_time = if decision.num_scheduled() > 0 {
329                // Build batch of scheduled requests
330                let running = self.scheduler.running_mut();
331                let batch_requests: Vec<&_> = decision
332                    .scheduled_new
333                    .iter()
334                    .chain(decision.scheduled_running.iter())
335                    .filter_map(|&idx| running.get(idx))
336                    .collect();
337
338                let tokens_per_req: Vec<u32> = batch_requests
339                    .iter()
340                    .enumerate()
341                    .map(|(i, _req)| {
342                        let idx = if i < decision.scheduled_new.len() {
343                            decision.scheduled_new[i]
344                        } else {
345                            decision.scheduled_running[i - decision.scheduled_new.len()]
346                        };
347                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
348                    })
349                    .collect();
350
351                self.compute_engine
352                    .calculate_iteration_time(&batch_requests, &tokens_per_req)
353            } else {
354                0.001 // Small time step when idle
355            };
356
357            // 4. Advance time
358            self.current_time += iteration_time;
359
360            // 5. Determine which requests were prefilling vs decoding BEFORE updating state
361            let mut prefilling_reqs = std::collections::HashSet::new();
362            for (&idx, &_tokens) in &decision.tokens_per_request {
363                if let Some(request) = self.scheduler.running().get(idx) {
364                    if request.is_prefill() {
365                        prefilling_reqs.insert(idx);
366                    }
367                }
368            }
369
370            // 6. Update request states
371            for (&idx, &tokens) in &decision.tokens_per_request {
372                if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
373                    request.record_generated_tokens(tokens, self.current_time);
374                }
375            }
376
377            // 7. Record iteration metrics (before moving completed requests)
378            let kv_util = self.scheduler.kv_cache_manager().utilization();
379
380            // Calculate bandwidth and flops utilization
381            let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
382                let running = self.scheduler.running_mut();
383                let batch_requests: Vec<&_> = decision
384                    .scheduled_new
385                    .iter()
386                    .chain(decision.scheduled_running.iter())
387                    .filter_map(|&idx| running.get(idx))
388                    .collect();
389
390                let tokens_per_req: Vec<u32> = batch_requests
391                    .iter()
392                    .enumerate()
393                    .map(|(i, _req)| {
394                        let idx = if i < decision.scheduled_new.len() {
395                            decision.scheduled_new[i]
396                        } else {
397                            decision.scheduled_running[i - decision.scheduled_new.len()]
398                        };
399                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
400                    })
401                    .collect();
402
403                let bytes_transferred = self
404                    .compute_engine
405                    .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
406                let bandwidth_util = self
407                    .compute_engine
408                    .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
409
410                let flops_util = self.compute_engine.calculate_flops_utilization(
411                    &batch_requests,
412                    &tokens_per_req,
413                    iteration_time,
414                );
415
416                (bandwidth_util, flops_util)
417            } else {
418                (0.0, 0.0)
419            };
420
421            self.metrics
422                .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
423
424            // 8. Record time-series data (BEFORE handling completed requests)
425            if self.current_time >= self.next_sample_time {
426                // Calculate prefill vs decode breakdown
427                let running = self.scheduler.running();
428                let mut num_prefilling = 0;
429                let mut num_decoding = 0;
430                let mut prefill_tokens = 0;
431                let mut decode_tokens = 0;
432
433                for req in running {
434                    if req.is_prefill() {
435                        num_prefilling += 1;
436                    } else {
437                        num_decoding += 1;
438                    }
439                }
440
441                // Count tokens scheduled in this iteration
442                for (&idx, &tokens) in &decision.tokens_per_request {
443                    if prefilling_reqs.contains(&idx) {
444                        prefill_tokens += tokens;
445                    } else {
446                        decode_tokens += tokens;
447                    }
448                }
449
450                // Record throughput sample
451                self.metrics.record_throughput_sample(self.current_time);
452
453                // Calculate windowed throughput (tokens per second)
454                let input_throughput = prefill_tokens as f64 / self.sample_interval;
455                let output_throughput = decode_tokens as f64 / self.sample_interval;
456
457                // Get latency mean for events since last sample
458                let (ttft_mean, tpot_mean) = self
459                    .metrics
460                    .get_interval_latencies(self.prev_sample_time, self.current_time);
461
462                self.time_series_data.push(TimeSeriesPoint {
463                    time: self.current_time,
464                    arrivals: self.metrics.total_requests,
465                    running: self.scheduler.num_running(),
466                    waiting: self.scheduler.num_waiting(),
467                    kv_cache_util: kv_util,
468                    num_prefilling,
469                    num_decoding,
470                    prefill_tokens,
471                    decode_tokens,
472                    input_throughput,
473                    output_throughput,
474                    ttft_p50: ttft_mean,
475                    tpot_p50: tpot_mean,
476                });
477                self.prev_sample_time = self.current_time;
478                self.next_sample_time = self.current_time + self.sample_interval;
479            }
480
481            // 9. Handle completed requests
482            for request in decision.completed {
483                // Free KV cache blocks
484                self.scheduler
485                    .kv_cache_manager_mut()
486                    .free_blocks(&request.kv_blocks);
487
488                self.metrics.record_request_completion(&request);
489
490                // For closed-loop workloads, generate a new request when one completes
491                self.request_generator
492                    .on_request_complete(self.current_time);
493            }
494
495            // 10. Send progress update if enough time has passed
496            if self.current_time - last_callback_time >= callback_interval {
497                let latency_samples = self.metrics.get_latency_samples();
498                let input_lengths = self.metrics.get_input_lengths();
499                let output_lengths = self.metrics.get_output_lengths();
500
501                // Only send new samples since last callback (delta) - track each metric separately
502                let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
503                let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
504                let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
505                let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
506                let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
507                let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
508                let input_delta = &input_lengths[self.last_sent_input_count..];
509                let output_delta = &output_lengths[self.last_sent_output_count..];
510
511                let progress = ProgressInfo {
512                    current_time: self.current_time,
513                    completed_requests: self.metrics.completed_requests,
514                    total_requests: self.metrics.total_requests,
515                    running: self.scheduler.num_running(),
516                    waiting: self.scheduler.num_waiting(),
517                    kv_cache_util: kv_util,
518                    time_series: Some(&self.time_series_data),
519                    metrics: Some(self.metrics.compute_summary(self.current_time)),
520                    latency_samples: Some((
521                        (ttft_delta, ttft_timestamps_delta),
522                        (e2e_delta, e2e_timestamps_delta),
523                        (tpot_delta, tpot_timestamps_delta),
524                    )),
525                    distribution_samples: Some((input_delta, output_delta)),
526                };
527                callback(progress);
528
529                // Update last sent sample counts for each metric
530                self.last_sent_ttft_count = latency_samples.0 .0.len();
531                self.last_sent_e2e_count = latency_samples.1 .0.len();
532                self.last_sent_tpot_count = latency_samples.2 .0.len();
533                self.last_sent_input_count = input_lengths.len();
534                self.last_sent_output_count = output_lengths.len();
535                last_callback_time = self.current_time;
536            }
537
538            // Note: Periodic logging removed for callback mode - callback replaces it
539
540            // 11. Check termination conditions
541            if self.should_terminate() {
542                // Send final progress update with any remaining samples
543                let latency_samples = self.metrics.get_latency_samples();
544                let input_lengths = self.metrics.get_input_lengths();
545                let output_lengths = self.metrics.get_output_lengths();
546
547                // Only send new samples since last callback (delta) - track each metric separately
548                let ttft_delta = &latency_samples.0 .0[self.last_sent_ttft_count..];
549                let ttft_timestamps_delta = &latency_samples.0 .1[self.last_sent_ttft_count..];
550                let e2e_delta = &latency_samples.1 .0[self.last_sent_e2e_count..];
551                let e2e_timestamps_delta = &latency_samples.1 .1[self.last_sent_e2e_count..];
552                let tpot_delta = &latency_samples.2 .0[self.last_sent_tpot_count..];
553                let tpot_timestamps_delta = &latency_samples.2 .1[self.last_sent_tpot_count..];
554                let input_delta = &input_lengths[self.last_sent_input_count..];
555                let output_delta = &output_lengths[self.last_sent_output_count..];
556
557                let progress = ProgressInfo {
558                    current_time: self.current_time,
559                    completed_requests: self.metrics.completed_requests,
560                    total_requests: self.metrics.total_requests,
561                    running: self.scheduler.num_running(),
562                    waiting: self.scheduler.num_waiting(),
563                    kv_cache_util: kv_util,
564                    time_series: Some(&self.time_series_data),
565                    metrics: Some(self.metrics.compute_summary(self.current_time)),
566                    latency_samples: Some((
567                        (ttft_delta, ttft_timestamps_delta),
568                        (e2e_delta, e2e_timestamps_delta),
569                        (tpot_delta, tpot_timestamps_delta),
570                    )),
571                    distribution_samples: Some((input_delta, output_delta)),
572                };
573                callback(progress);
574                break;
575            }
576        }
577
578        Ok(())
579    }
580
581    pub fn get_metrics_summary(&self) -> crate::metrics::MetricsSummary {
582        self.metrics.compute_summary(self.current_time)
583    }
584
585    pub fn get_time_series_data(&self) -> &[TimeSeriesPoint] {
586        &self.time_series_data
587    }
588
589    pub fn get_input_lengths(&self) -> &[u32] {
590        self.metrics.get_input_lengths()
591    }
592
593    pub fn get_output_lengths(&self) -> &[u32] {
594        self.metrics.get_output_lengths()
595    }
596
597    pub fn get_current_time(&self) -> f64 {
598        self.current_time
599    }
600
601    pub fn get_latency_samples(
602        &self,
603    ) -> (
604        (&[f64], &[f64]), // (ttft_samples, ttft_timestamps)
605        (&[f64], &[f64]), // (e2e_samples, e2e_timestamps)
606        (&[f64], &[f64]), // (tpot_samples, tpot_timestamps)
607    ) {
608        self.metrics.get_latency_samples()
609    }
610
611    fn log_progress(&self) {
612        #[cfg(not(target_arch = "wasm32"))]
613        println!(
614            "[{:.2}s] Iteration {}: {} running, {} waiting, {:.1}% KV cache used",
615            self.current_time,
616            self.iteration,
617            self.scheduler.num_running(),
618            self.scheduler.num_waiting(),
619            self.scheduler.kv_cache_manager().utilization() * 100.0,
620        );
621    }
622
623    fn should_terminate(&self) -> bool {
624        // Check if we've generated all requests and completed them all
625        self.request_generator.is_finished()
626            && self.scheduler.num_running() == 0
627            && self.scheduler.num_waiting() == 0
628    }
629
630    fn print_final_metrics(&self) {
631        let summary = self.metrics.compute_summary(self.current_time);
632        summary.print();
633    }
634}