// inference_lab/simulation/simulator.rs

use crate::compute::ComputeEngine;
use crate::config::Config;
use crate::kv_cache::KVCacheManager;
use crate::metrics::MetricsCollector;
use crate::request::RequestGenerator;
use crate::scheduler::Scheduler;

/// A single sampled point of simulation state, captured whenever simulated
/// time crosses the next sampling boundary (every `sample_interval` seconds).
#[derive(Debug, Clone)]
pub struct TimeSeriesPoint {
    /// Simulated time (seconds) at which this sample was taken.
    pub time: f64,
    /// Cumulative number of request arrivals so far.
    pub arrivals: u64,
    /// Number of requests currently running.
    pub running: usize,
    /// Number of requests currently waiting in the queue.
    pub waiting: usize,
    /// KV-cache utilization (fraction; multiplied by 100 for display).
    pub kv_cache_util: f64,
    /// Running requests still in the prefill phase.
    pub num_prefilling: usize,
    /// Running requests in the decode phase.
    pub num_decoding: usize,
    /// Prefill tokens scheduled in the sampled iteration.
    pub prefill_tokens: u32,
    /// Decode tokens scheduled in the sampled iteration.
    pub decode_tokens: u32,
    pub input_throughput: f64,  // Input tokens per second (windowed)
    pub output_throughput: f64, // Output tokens per second (windowed)
    // NOTE(review): despite the `_p50` names, these two fields are currently
    // populated with the interval *mean* from `get_interval_latencies` —
    // confirm whether a true p50 was intended.
    pub ttft_p50: f64, // TTFT in recent window (ms)
    pub tpot_p50: f64, // TPOT in recent window (ms)
}
25pub struct ProgressInfo<'a> {
26    pub current_time: f64,
27    pub completed_requests: u64,
28    pub total_requests: u64,
29    pub running: usize,
30    pub waiting: usize,
31    pub kv_cache_util: f64,
32    pub time_series: Option<&'a [TimeSeriesPoint]>,
33    pub metrics: Option<crate::metrics::MetricsSummary>,
34    pub latency_samples: Option<((&'a [f64], &'a [f64]), (&'a [f64], &'a [f64]), (&'a [f64], &'a [f64]))>,
35    pub distribution_samples: Option<(&'a [u32], &'a [u32])>,  // (input_lengths, output_lengths)
36}
37
38pub struct Simulator {
39    scheduler: Scheduler,
40    compute_engine: ComputeEngine,
41    request_generator: RequestGenerator,
42    metrics: MetricsCollector,
43    time_series_data: Vec<TimeSeriesPoint>,
44    sample_interval: f64,
45    next_sample_time: f64,
46    prev_sample_time: f64,
47
48    current_time: f64,
49    iteration: u64,
50    log_interval: u64,
51
52    // Track last sent sample counts for streaming deltas (one per metric type)
53    last_sent_ttft_count: usize,
54    last_sent_e2e_count: usize,
55    last_sent_tpot_count: usize,
56    last_sent_input_count: usize,
57    last_sent_output_count: usize,
58}
59
60impl Simulator {
61    pub fn new(config: Config) -> Result<Self, String> {
62        let kv_cache_manager = KVCacheManager::new(
63            config.hardware.kv_cache_capacity,
64            config.scheduler.block_size,
65            config.model.kv_cache_bytes_per_token,
66            true, // enable_prefix_caching
67        );
68
69        let scheduler = Scheduler::new(
70            config.scheduler.clone(),
71            config.hardware.clone(),
72            config.model.clone(),
73            kv_cache_manager,
74        )?;
75
76        let compute_engine = ComputeEngine::new(config.hardware, config.model);
77        let request_generator = RequestGenerator::new(config.workload);
78        let metrics = MetricsCollector::new(0.0);
79
80        Ok(Self {
81            scheduler,
82            compute_engine,
83            request_generator,
84            metrics,
85            time_series_data: Vec::new(),
86            sample_interval: 0.1,
87            next_sample_time: 0.0,
88            prev_sample_time: 0.0,
89            current_time: 0.0,
90            iteration: 0,
91            log_interval: config.simulation.log_interval,
92            last_sent_ttft_count: 0,
93            last_sent_e2e_count: 0,
94            last_sent_tpot_count: 0,
95            last_sent_input_count: 0,
96            last_sent_output_count: 0,
97        })
98    }
99
100    /// Run the simulation
101    pub fn run(&mut self) {
102        #[cfg(not(target_arch = "wasm32"))]
103        println!("Starting simulation...\n");
104
105        loop {
106            self.iteration += 1;
107
108            // 1. Generate new arrivals up to current_time
109            while let Some(request) = self.request_generator.next_if_before(self.current_time) {
110                self.scheduler.add_request(request);
111                self.metrics.total_requests += 1;
112            }
113
114            // 2. Run scheduler
115            let decision = self.scheduler.schedule(self.current_time);
116
117            // 3. Calculate iteration time
118            let iteration_time = if decision.num_scheduled() > 0 {
119                // Build batch of scheduled requests
120                let running = self.scheduler.running_mut();
121                let batch_requests: Vec<&_> = decision
122                    .scheduled_new
123                    .iter()
124                    .chain(decision.scheduled_running.iter())
125                    .filter_map(|&idx| running.get(idx))
126                    .collect();
127
128                let tokens_per_req: Vec<u32> = batch_requests
129                    .iter()
130                    .enumerate()
131                    .map(|(i, _req)| {
132                        let idx = if i < decision.scheduled_new.len() {
133                            decision.scheduled_new[i]
134                        } else {
135                            decision.scheduled_running[i - decision.scheduled_new.len()]
136                        };
137                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
138                    })
139                    .collect();
140
141                self.compute_engine
142                    .calculate_iteration_time(&batch_requests, &tokens_per_req)
143            } else {
144                0.001 // Small time step when idle
145            };
146
147            // 4. Advance time
148            self.current_time += iteration_time;
149
150            // 5. Determine which requests were prefilling vs decoding BEFORE updating state
151            let mut prefilling_reqs = std::collections::HashSet::new();
152            for (&idx, &_tokens) in &decision.tokens_per_request {
153                if let Some(request) = self.scheduler.running().get(idx) {
154                    if request.is_prefill() {
155                        prefilling_reqs.insert(idx);
156                    }
157                }
158            }
159
160            // 6. Update request states
161            for (&idx, &tokens) in &decision.tokens_per_request {
162                if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
163                    request.record_generated_tokens(tokens, self.current_time);
164                }
165            }
166
167            // 7. Record iteration metrics (before moving completed requests)
168            let kv_util = self.scheduler.kv_cache_manager().utilization();
169
170            // Calculate bandwidth and flops utilization
171            let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
172                let running = self.scheduler.running_mut();
173                let batch_requests: Vec<&_> = decision
174                    .scheduled_new
175                    .iter()
176                    .chain(decision.scheduled_running.iter())
177                    .filter_map(|&idx| running.get(idx))
178                    .collect();
179
180                let tokens_per_req: Vec<u32> = batch_requests
181                    .iter()
182                    .enumerate()
183                    .map(|(i, _req)| {
184                        let idx = if i < decision.scheduled_new.len() {
185                            decision.scheduled_new[i]
186                        } else {
187                            decision.scheduled_running[i - decision.scheduled_new.len()]
188                        };
189                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
190                    })
191                    .collect();
192
193                let bytes_transferred = self
194                    .compute_engine
195                    .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
196                let bandwidth_util = self
197                    .compute_engine
198                    .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
199
200                let flops_util = self
201                    .compute_engine
202                    .calculate_flops_utilization(&batch_requests, &tokens_per_req, iteration_time);
203
204                (bandwidth_util, flops_util)
205            } else {
206                (0.0, 0.0)
207            };
208
209            self.metrics
210                .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
211
212            // 7. Record time-series data (BEFORE handling completed requests)
213            if self.current_time >= self.next_sample_time {
214                // Calculate prefill vs decode breakdown
215                let running = self.scheduler.running();
216                let mut num_prefilling = 0;
217                let mut num_decoding = 0;
218                let mut prefill_tokens = 0;
219                let mut decode_tokens = 0;
220
221                for req in running {
222                    if req.is_prefill() {
223                        num_prefilling += 1;
224                    } else {
225                        num_decoding += 1;
226                    }
227                }
228
229                // Count tokens scheduled in this iteration
230                // Use the prefilling_reqs set we captured before updating state
231                for (&idx, &tokens) in &decision.tokens_per_request {
232                    if prefilling_reqs.contains(&idx) {
233                        prefill_tokens += tokens;
234                    } else {
235                        decode_tokens += tokens;
236                    }
237                }
238
239                // Record throughput sample
240                self.metrics.record_throughput_sample(self.current_time);
241
242                // Calculate windowed throughput (tokens per second)
243                let input_throughput = prefill_tokens as f64 / self.sample_interval;
244                let output_throughput = decode_tokens as f64 / self.sample_interval;
245
246                // Get latency mean for events since last sample
247                let (ttft_mean, tpot_mean) = self.metrics.get_interval_latencies(
248                    self.prev_sample_time,
249                    self.current_time
250                );
251
252                self.time_series_data.push(TimeSeriesPoint {
253                    time: self.current_time,
254                    arrivals: self.metrics.total_requests,
255                    running: self.scheduler.num_running(),
256                    waiting: self.scheduler.num_waiting(),
257                    kv_cache_util: kv_util,
258                    num_prefilling,
259                    num_decoding,
260                    prefill_tokens,
261                    decode_tokens,
262                    input_throughput,
263                    output_throughput,
264                    ttft_p50: ttft_mean,
265                    tpot_p50: tpot_mean,
266                });
267                self.prev_sample_time = self.current_time;
268                self.next_sample_time = self.current_time + self.sample_interval;
269            }
270
271            // 8. Handle completed requests
272            for request in decision.completed {
273                // Free KV cache blocks
274                self.scheduler
275                    .kv_cache_manager_mut()
276                    .free_blocks(&request.kv_blocks);
277
278                self.metrics.record_request_completion(&request);
279
280                // For closed-loop workloads, generate a new request when one completes
281                self.request_generator.on_request_complete(self.current_time);
282            }
283
284            // 9. Periodic logging
285            if self.iteration % self.log_interval == 0 {
286                self.log_progress();
287            }
288
289            // 10. Check termination conditions
290            if self.should_terminate() {
291                break;
292            }
293        }
294
295        #[cfg(not(target_arch = "wasm32"))]
296        println!("\nSimulation complete!");
297        #[cfg(not(target_arch = "wasm32"))]
298        self.print_final_metrics();
299    }
300
301    /// Run the simulation with progress callbacks
302    pub fn run_with_callback<F>(&mut self, mut callback: F) -> Result<(), String>
303    where
304        F: FnMut(ProgressInfo),
305    {
306        let mut last_callback_time = 0.0;
307        let callback_interval = 1.0; // Call callback every 1.0 seconds
308
309        loop {
310            self.iteration += 1;
311
312            // 1. Generate new arrivals up to current_time
313            while let Some(request) = self.request_generator.next_if_before(self.current_time) {
314                self.scheduler.add_request(request);
315                self.metrics.total_requests += 1;
316            }
317
318            // 2. Run scheduler
319            let decision = self.scheduler.schedule(self.current_time);
320
321            // 3. Calculate iteration time
322            let iteration_time = if decision.num_scheduled() > 0 {
323                // Build batch of scheduled requests
324                let running = self.scheduler.running_mut();
325                let batch_requests: Vec<&_> = decision
326                    .scheduled_new
327                    .iter()
328                    .chain(decision.scheduled_running.iter())
329                    .filter_map(|&idx| running.get(idx))
330                    .collect();
331
332                let tokens_per_req: Vec<u32> = batch_requests
333                    .iter()
334                    .enumerate()
335                    .map(|(i, _req)| {
336                        let idx = if i < decision.scheduled_new.len() {
337                            decision.scheduled_new[i]
338                        } else {
339                            decision.scheduled_running[i - decision.scheduled_new.len()]
340                        };
341                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
342                    })
343                    .collect();
344
345                self.compute_engine
346                    .calculate_iteration_time(&batch_requests, &tokens_per_req)
347            } else {
348                0.001 // Small time step when idle
349            };
350
351            // 4. Advance time
352            self.current_time += iteration_time;
353
354            // 5. Determine which requests were prefilling vs decoding BEFORE updating state
355            let mut prefilling_reqs = std::collections::HashSet::new();
356            for (&idx, &_tokens) in &decision.tokens_per_request {
357                if let Some(request) = self.scheduler.running().get(idx) {
358                    if request.is_prefill() {
359                        prefilling_reqs.insert(idx);
360                    }
361                }
362            }
363
364            // 6. Update request states
365            for (&idx, &tokens) in &decision.tokens_per_request {
366                if let Some(request) = self.scheduler.running_mut().get_mut(idx) {
367                    request.record_generated_tokens(tokens, self.current_time);
368                }
369            }
370
371            // 7. Record iteration metrics (before moving completed requests)
372            let kv_util = self.scheduler.kv_cache_manager().utilization();
373
374            // Calculate bandwidth and flops utilization
375            let (bandwidth_util, flops_util) = if decision.num_scheduled() > 0 {
376                let running = self.scheduler.running_mut();
377                let batch_requests: Vec<&_> = decision
378                    .scheduled_new
379                    .iter()
380                    .chain(decision.scheduled_running.iter())
381                    .filter_map(|&idx| running.get(idx))
382                    .collect();
383
384                let tokens_per_req: Vec<u32> = batch_requests
385                    .iter()
386                    .enumerate()
387                    .map(|(i, _req)| {
388                        let idx = if i < decision.scheduled_new.len() {
389                            decision.scheduled_new[i]
390                        } else {
391                            decision.scheduled_running[i - decision.scheduled_new.len()]
392                        };
393                        *decision.tokens_per_request.get(&idx).unwrap_or(&0)
394                    })
395                    .collect();
396
397                let bytes_transferred = self
398                    .compute_engine
399                    .calculate_bytes_transferred(&batch_requests, &tokens_per_req);
400                let bandwidth_util = self
401                    .compute_engine
402                    .calculate_bandwidth_utilization(bytes_transferred, iteration_time);
403
404                let flops_util = self
405                    .compute_engine
406                    .calculate_flops_utilization(&batch_requests, &tokens_per_req, iteration_time);
407
408                (bandwidth_util, flops_util)
409            } else {
410                (0.0, 0.0)
411            };
412
413            self.metrics
414                .record_iteration_metrics(kv_util, flops_util, bandwidth_util);
415
416            // 8. Record time-series data (BEFORE handling completed requests)
417            if self.current_time >= self.next_sample_time {
418                // Calculate prefill vs decode breakdown
419                let running = self.scheduler.running();
420                let mut num_prefilling = 0;
421                let mut num_decoding = 0;
422                let mut prefill_tokens = 0;
423                let mut decode_tokens = 0;
424
425                for req in running {
426                    if req.is_prefill() {
427                        num_prefilling += 1;
428                    } else {
429                        num_decoding += 1;
430                    }
431                }
432
433                // Count tokens scheduled in this iteration
434                for (&idx, &tokens) in &decision.tokens_per_request {
435                    if prefilling_reqs.contains(&idx) {
436                        prefill_tokens += tokens;
437                    } else {
438                        decode_tokens += tokens;
439                    }
440                }
441
442                // Record throughput sample
443                self.metrics.record_throughput_sample(self.current_time);
444
445                // Calculate windowed throughput (tokens per second)
446                let input_throughput = prefill_tokens as f64 / self.sample_interval;
447                let output_throughput = decode_tokens as f64 / self.sample_interval;
448
449                // Get latency mean for events since last sample
450                let (ttft_mean, tpot_mean) = self.metrics.get_interval_latencies(
451                    self.prev_sample_time,
452                    self.current_time
453                );
454
455                self.time_series_data.push(TimeSeriesPoint {
456                    time: self.current_time,
457                    arrivals: self.metrics.total_requests,
458                    running: self.scheduler.num_running(),
459                    waiting: self.scheduler.num_waiting(),
460                    kv_cache_util: kv_util,
461                    num_prefilling,
462                    num_decoding,
463                    prefill_tokens,
464                    decode_tokens,
465                    input_throughput,
466                    output_throughput,
467                    ttft_p50: ttft_mean,
468                    tpot_p50: tpot_mean,
469                });
470                self.prev_sample_time = self.current_time;
471                self.next_sample_time = self.current_time + self.sample_interval;
472            }
473
474            // 9. Handle completed requests
475            for request in decision.completed {
476                // Free KV cache blocks
477                self.scheduler
478                    .kv_cache_manager_mut()
479                    .free_blocks(&request.kv_blocks);
480
481                self.metrics.record_request_completion(&request);
482
483                // For closed-loop workloads, generate a new request when one completes
484                self.request_generator.on_request_complete(self.current_time);
485            }
486
487            // 10. Send progress update if enough time has passed
488            if self.current_time - last_callback_time >= callback_interval {
489                let latency_samples = self.metrics.get_latency_samples();
490                let input_lengths = self.metrics.get_input_lengths();
491                let output_lengths = self.metrics.get_output_lengths();
492
493                // Only send new samples since last callback (delta) - track each metric separately
494                let ttft_delta = &latency_samples.0.0[self.last_sent_ttft_count..];
495                let ttft_timestamps_delta = &latency_samples.0.1[self.last_sent_ttft_count..];
496                let e2e_delta = &latency_samples.1.0[self.last_sent_e2e_count..];
497                let e2e_timestamps_delta = &latency_samples.1.1[self.last_sent_e2e_count..];
498                let tpot_delta = &latency_samples.2.0[self.last_sent_tpot_count..];
499                let tpot_timestamps_delta = &latency_samples.2.1[self.last_sent_tpot_count..];
500                let input_delta = &input_lengths[self.last_sent_input_count..];
501                let output_delta = &output_lengths[self.last_sent_output_count..];
502
503                let progress = ProgressInfo {
504                    current_time: self.current_time,
505                    completed_requests: self.metrics.completed_requests,
506                    total_requests: self.metrics.total_requests,
507                    running: self.scheduler.num_running(),
508                    waiting: self.scheduler.num_waiting(),
509                    kv_cache_util: kv_util,
510                    time_series: Some(&self.time_series_data),
511                    metrics: Some(self.metrics.compute_summary(self.current_time)),
512                    latency_samples: Some(((ttft_delta, ttft_timestamps_delta), (e2e_delta, e2e_timestamps_delta), (tpot_delta, tpot_timestamps_delta))),
513                    distribution_samples: Some((input_delta, output_delta)),
514                };
515                callback(progress);
516
517                // Update last sent sample counts for each metric
518                self.last_sent_ttft_count = latency_samples.0.0.len();
519                self.last_sent_e2e_count = latency_samples.1.0.len();
520                self.last_sent_tpot_count = latency_samples.2.0.len();
521                self.last_sent_input_count = input_lengths.len();
522                self.last_sent_output_count = output_lengths.len();
523                last_callback_time = self.current_time;
524            }
525
526            // Note: Periodic logging removed for callback mode - callback replaces it
527
528            // 11. Check termination conditions
529            if self.should_terminate() {
530                // Send final progress update with any remaining samples
531                let latency_samples = self.metrics.get_latency_samples();
532                let input_lengths = self.metrics.get_input_lengths();
533                let output_lengths = self.metrics.get_output_lengths();
534
535                // Only send new samples since last callback (delta) - track each metric separately
536                let ttft_delta = &latency_samples.0.0[self.last_sent_ttft_count..];
537                let ttft_timestamps_delta = &latency_samples.0.1[self.last_sent_ttft_count..];
538                let e2e_delta = &latency_samples.1.0[self.last_sent_e2e_count..];
539                let e2e_timestamps_delta = &latency_samples.1.1[self.last_sent_e2e_count..];
540                let tpot_delta = &latency_samples.2.0[self.last_sent_tpot_count..];
541                let tpot_timestamps_delta = &latency_samples.2.1[self.last_sent_tpot_count..];
542                let input_delta = &input_lengths[self.last_sent_input_count..];
543                let output_delta = &output_lengths[self.last_sent_output_count..];
544
545                let progress = ProgressInfo {
546                    current_time: self.current_time,
547                    completed_requests: self.metrics.completed_requests,
548                    total_requests: self.metrics.total_requests,
549                    running: self.scheduler.num_running(),
550                    waiting: self.scheduler.num_waiting(),
551                    kv_cache_util: kv_util,
552                    time_series: Some(&self.time_series_data),
553                    metrics: Some(self.metrics.compute_summary(self.current_time)),
554                    latency_samples: Some(((ttft_delta, ttft_timestamps_delta), (e2e_delta, e2e_timestamps_delta), (tpot_delta, tpot_timestamps_delta))),
555                    distribution_samples: Some((input_delta, output_delta)),
556                };
557                callback(progress);
558                break;
559            }
560        }
561
562        Ok(())
563    }
564
565    pub fn get_metrics_summary(&self) -> crate::metrics::MetricsSummary {
566        self.metrics.compute_summary(self.current_time)
567    }
568
569    pub fn get_time_series_data(&self) -> &[TimeSeriesPoint] {
570        &self.time_series_data
571    }
572
573    pub fn get_input_lengths(&self) -> &[u32] {
574        self.metrics.get_input_lengths()
575    }
576
577    pub fn get_output_lengths(&self) -> &[u32] {
578        self.metrics.get_output_lengths()
579    }
580
581    pub fn get_current_time(&self) -> f64 {
582        self.current_time
583    }
584
585    pub fn get_latency_samples(&self) -> (
586        (&[f64], &[f64]), // (ttft_samples, ttft_timestamps)
587        (&[f64], &[f64]), // (e2e_samples, e2e_timestamps)
588        (&[f64], &[f64]), // (tpot_samples, tpot_timestamps)
589    ) {
590        self.metrics.get_latency_samples()
591    }
592
593    fn log_progress(&self) {
594        #[cfg(not(target_arch = "wasm32"))]
595        println!(
596            "[{:.2}s] Iteration {}: {} running, {} waiting, {:.1}% KV cache used",
597            self.current_time,
598            self.iteration,
599            self.scheduler.num_running(),
600            self.scheduler.num_waiting(),
601            self.scheduler.kv_cache_manager().utilization() * 100.0,
602        );
603    }
604
605    fn should_terminate(&self) -> bool {
606        // Check if we've generated all requests and completed them all
607        self.request_generator.is_finished()
608            && self.scheduler.num_running() == 0
609            && self.scheduler.num_waiting() == 0
610    }
611
612    fn print_final_metrics(&self) {
613        let summary = self.metrics.compute_summary(self.current_time);
614        summary.print();
615    }
616}