Skip to main content

inference_lab/request/
generator.rs

1use super::Request;
2use crate::config::WorkloadConfig;
3use crate::dataset::{BatchTokenizerFn, DatasetEntry, UnparsedEntry};
4use rand::{rngs::StdRng, Rng, SeedableRng};
5use rand_distr::{Distribution, Exp};
6use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8use std::sync::mpsc::{sync_channel, Receiver};
9use std::thread;
10
11/// Generates requests based on workload configuration
12pub struct RequestGenerator {
13    workload: WorkloadConfig,
14    rng: StdRng,
15    next_arrival_time: f64,
16    requests_generated: usize,
17    next_request_id: u64,
18    /// For closed-loop: track pending requests to generate when completions occur
19    pending_closed_loop_requests: Vec<f64>,
20    /// Dataset receiver (if using dataset mode) - receives pre-loaded entries from background thread
21    dataset_receiver: Option<Receiver<Option<DatasetEntry>>>,
22    /// Track if dataset has been exhausted (received None from channel)
23    dataset_exhausted: bool,
24}
25
26impl RequestGenerator {
27    pub fn new(workload: WorkloadConfig) -> Self {
28        let mut rng = StdRng::seed_from_u64(workload.seed);
29        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";
30
31        // For closed-loop, initialize with N requests at time 0
32        let mut pending_closed_loop_requests = Vec::new();
33        if is_closed_loop {
34            if let Some(num_users) = workload.num_concurrent_users {
35                // Generate initial requests for all concurrent users
36                pending_closed_loop_requests = vec![0.0; num_users]
37            }
38        };
39
40        let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
41            0.0 // Start immediately with the first batch
42        } else {
43            Self::sample_next_arrival(
44                0.0,
45                &workload.arrival_pattern,
46                workload.arrival_rate,
47                &mut rng,
48            )
49        };
50
51        Self {
52            workload,
53            rng,
54            next_arrival_time,
55            requests_generated: 0,
56            next_request_id: 0,
57            pending_closed_loop_requests,
58            dataset_receiver: None,
59            dataset_exhausted: false,
60        }
61    }
62
63    /// Create a new generator from a dataset iterator
64    /// Spawns a background thread to read, parse, and batch-tokenize entries in parallel
65    /// Buffer size controls memory usage (entries buffered ahead of simulation)
66    pub fn from_dataset<I>(
67        workload: WorkloadConfig,
68        dataset_iterator: I,
69        _total_entries: Option<usize>,
70        tokenizer: BatchTokenizerFn,
71    ) -> Self
72    where
73        I: Iterator<Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>>
74            + Send
75            + 'static,
76    {
77        let rng = StdRng::seed_from_u64(workload.seed);
78        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";
79
80        // For closed-loop, initialize with N requests at time 0
81        let mut pending_closed_loop_requests = Vec::new();
82        if is_closed_loop {
83            if let Some(num_users) = workload.num_concurrent_users {
84                pending_closed_loop_requests = vec![0.0; num_users]
85            }
86        };
87
88        let next_arrival_time = 0.0; // First dataset request arrives at t=0
89
90        // Spawn background thread to load and batch-tokenize entries
91        // Buffer size: 5000 entries (~10-50MB depending on token counts)
92        let (sender, receiver) = sync_channel::<Option<DatasetEntry>>(5000);
93
94        thread::spawn(move || {
95            let batch_size: usize = std::env::var("TOKENIZER_BATCH_SIZE")
96                .ok()
97                .and_then(|s| s.parse().ok())
98                .unwrap_or(32); // Default: 32 (optimal for latency/throughput balance)
99            let mut batch = Vec::with_capacity(batch_size);
100
101            for result in dataset_iterator {
102                match result {
103                    Ok(Some(unparsed)) => {
104                        batch.push(unparsed);
105
106                        // Process batch when full
107                        if batch.len() >= batch_size
108                            && Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender)
109                                .is_err()
110                        {
111                            // Receiver dropped, simulation ended early
112                            break;
113                        }
114                    }
115                    Ok(None) => {
116                        // End of dataset - flush remaining batch and send completion signal
117                        if !batch.is_empty() {
118                            let _ = Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender);
119                        }
120                        let _ = sender.send(None);
121                        break;
122                    }
123                    Err(e) => {
124                        eprintln!("Error loading dataset entry: {}", e);
125                        break;
126                    }
127                }
128            }
129        });
130
131        Self {
132            workload,
133            rng,
134            next_arrival_time,
135            requests_generated: 0,
136            next_request_id: 0,
137            pending_closed_loop_requests,
138            dataset_receiver: Some(receiver),
139            dataset_exhausted: false,
140        }
141    }
142
143    /// Check if using dataset mode
144    pub fn is_dataset_mode(&self) -> bool {
145        self.dataset_receiver.is_some()
146    }
147
148    /// Batch tokenize and send entries to the channel
149    /// Returns Err if the receiver dropped (simulation ended)
150    fn tokenize_and_send_batch(
151        batch: &mut Vec<UnparsedEntry>,
152        tokenizer: &BatchTokenizerFn,
153        sender: &std::sync::mpsc::SyncSender<Option<DatasetEntry>>,
154    ) -> Result<(), ()> {
155        if batch.is_empty() {
156            return Ok(());
157        }
158
159        // Batch tokenize all entries at once (much faster!)
160        let prompt_inputs: Vec<_> = batch.iter().map(|e| e.prompt_input.clone()).collect();
161        let all_tokens = match tokenizer(&prompt_inputs) {
162            Ok(tokens) => tokens,
163            Err(e) => {
164                eprintln!("Batch tokenization failed: {}", e);
165                return Err(());
166            }
167        };
168
169        // Send tokenized entries
170        for (unparsed, prompt_tokens) in batch.drain(..).zip(all_tokens.into_iter()) {
171            let entry = DatasetEntry {
172                request_id: unparsed.request_id,
173                prompt_tokens,
174                max_output_tokens: unparsed.max_output_tokens,
175            };
176
177            // Send to channel - returns Err if receiver dropped
178            if sender.send(Some(entry)).is_err() {
179                return Err(());
180            }
181        }
182        Ok(())
183    }
184
185    /// Get the next scheduled arrival time
186    pub fn peek_next_arrival_time(&self) -> f64 {
187        self.next_arrival_time
188    }
189
190    /// Compute block hashes from token IDs (for real prefix caching)
191    /// Uses incremental hashing: hash of block i includes all tokens up to block i
192    fn compute_block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
193        let num_blocks = tokens.len().div_ceil(block_size);
194        let mut hashes = Vec::with_capacity(num_blocks);
195
196        for block_idx in 0..num_blocks {
197            let end = ((block_idx + 1) * block_size).min(tokens.len());
198            let block_tokens = &tokens[..end]; // All tokens up to this block
199
200            // Hash all tokens cumulatively
201            let mut hasher = DefaultHasher::new();
202            block_tokens.hash(&mut hasher);
203            hashes.push(hasher.finish());
204        }
205
206        hashes
207    }
208
209    /// Get the next request if its arrival time is before the given time
210    /// Returns None if no request is ready or all requests have been generated
211    pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
212        // Dataset mode
213        if self.is_dataset_mode() {
214            return self.next_from_dataset(current_time);
215        }
216
217        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
218
219        // For closed-loop, check pending requests
220        if is_closed_loop {
221            // Check if we've generated all requests
222            if let Some(max_requests) = self.workload.num_requests {
223                if self.requests_generated >= max_requests {
224                    // Clear any remaining pending requests that won't be used
225                    self.pending_closed_loop_requests.clear();
226                    return None;
227                }
228            }
229
230            // Find the earliest pending request that has arrived
231            if let Some(pos) = self
232                .pending_closed_loop_requests
233                .iter()
234                .position(|&t| t <= current_time)
235            {
236                let arrival_time = self.pending_closed_loop_requests.remove(pos);
237
238                // Generate request
239                let request_id = format!("req-{}", self.next_request_id);
240                self.next_request_id += 1;
241
242                let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
243                let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
244
245                let mut request = Request::new(
246                    request_id,
247                    0, // Default priority
248                    arrival_time,
249                    num_prompt_tokens,
250                    max_output_tokens,
251                );
252
253                // Generate block hashes for the prompt
254                // For now, sample from a small range to get realistic cache hit rates
255                // First few blocks more likely to be shared (simulating common system prompts)
256                let num_blocks = num_prompt_tokens.div_ceil(16) as usize; // Assume 16-token blocks
257                request.prompt_block_hashes = (0..num_blocks)
258                    .map(|_| self.rng.gen_range(0..u64::MAX))
259                    .collect();
260
261                self.requests_generated += 1;
262                return Some(request);
263            }
264            return None;
265        }
266
267        // Original logic for non-closed-loop patterns
268        // Check if we've generated all requests
269        if let Some(max_requests) = self.workload.num_requests {
270            if self.requests_generated >= max_requests {
271                return None;
272            }
273        }
274
275        // Check if next request has arrived
276        if self.next_arrival_time > current_time {
277            return None;
278        }
279
280        // Generate request
281        let request_id = format!("req-{}", self.next_request_id);
282        self.next_request_id += 1;
283
284        let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
285        let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
286
287        let mut request = Request::new(
288            request_id,
289            0, // Default priority
290            self.next_arrival_time,
291            num_prompt_tokens,
292            max_output_tokens,
293        );
294
295        // Generate block hashes for the prompt
296        // For now, sample from a small range to get realistic cache hit rates
297        // First few blocks more likely to be shared (simulating common system prompts)
298        let num_blocks = num_prompt_tokens.div_ceil(16) as usize; // Assume 16-token blocks
299        request.prompt_block_hashes = (0..num_blocks)
300            .map(|_| self.rng.gen_range(0..u64::MAX))
301            .collect();
302
303        self.requests_generated += 1;
304
305        // Sample next arrival time
306        self.next_arrival_time = Self::sample_next_arrival(
307            self.next_arrival_time,
308            &self.workload.arrival_pattern,
309            self.workload.arrival_rate,
310            &mut self.rng,
311        );
312
313        Some(request)
314    }
315
316    /// Sample the next arrival time based on the arrival pattern
317    fn sample_next_arrival(current_time: f64, pattern: &str, rate: f64, rng: &mut StdRng) -> f64 {
318        match pattern.to_lowercase().as_str() {
319            "poisson" => {
320                // Poisson process: inter-arrival times are exponentially distributed
321                let exp = Exp::new(rate).unwrap();
322                let inter_arrival = exp.sample(rng);
323                current_time + inter_arrival
324            }
325            "uniform" => {
326                // Uniform: constant inter-arrival time
327                let inter_arrival = 1.0 / rate;
328                current_time + inter_arrival
329            }
330            "burst" => {
331                // Burst: requests arrive in bursts with gaps
332                // Simple implementation: alternate between fast and slow
333                if rng.gen_bool(0.2) {
334                    // 20% chance of burst
335                    current_time + rng.gen_range(0.001..0.01)
336                } else {
337                    current_time + rng.gen_range(0.5..2.0)
338                }
339            }
340            "fixed_rate" => {
341                // Fixed rate: exact inter-arrival time
342                current_time + 1.0 / rate
343            }
344            "batched" => {
345                // Batched: all requests arrive at time 0
346                0.0
347            }
348            _ => {
349                // Default to Poisson
350                let exp = Exp::new(rate).unwrap();
351                current_time + exp.sample(rng)
352            }
353        }
354    }
355
356    /// Get next request from dataset (receives from background thread)
357    fn next_from_dataset(&mut self, current_time: f64) -> Option<Request> {
358        // If already exhausted, no more requests
359        if self.dataset_exhausted {
360            return None;
361        }
362
363        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
364
365        // For closed-loop, handle pending requests
366        if is_closed_loop {
367            // Check if we've generated all requests
368            if let Some(max_requests) = self.workload.num_requests {
369                if self.requests_generated >= max_requests {
370                    // Clear any remaining pending requests that won't be used
371                    self.pending_closed_loop_requests.clear();
372                    return None;
373                }
374            }
375
376            // Find the earliest pending request that has arrived
377            if let Some(pos) = self
378                .pending_closed_loop_requests
379                .iter()
380                .position(|&t| t <= current_time)
381            {
382                let arrival_time = self.pending_closed_loop_requests.remove(pos);
383
384                // Receive from channel
385                let entry = match self.dataset_receiver.as_ref()?.recv() {
386                    Ok(Some(e)) => e,
387                    Ok(None) => {
388                        // End of dataset signaled
389                        self.dataset_exhausted = true;
390                        return None;
391                    }
392                    Err(_) => {
393                        // Channel error (sender dropped)
394                        self.dataset_exhausted = true;
395                        return None;
396                    }
397                };
398
399                // Sample actual output length from distribution
400                let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);
401
402                // If max_output_tokens is specified in the dataset, cap at that; otherwise use sampled value
403                let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
404                let target_output_tokens = sampled_output_len.min(max_output_tokens);
405
406                let mut request = Request::new_with_target(
407                    entry.request_id.clone(),
408                    0,
409                    arrival_time,
410                    entry.num_prompt_tokens(),
411                    max_output_tokens,
412                    target_output_tokens,
413                );
414
415                request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
416                self.requests_generated += 1;
417
418                return Some(request);
419            }
420            return None;
421        }
422
423        // Non-closed-loop: original logic
424        // Check if it's time for next arrival
425        if self.next_arrival_time > current_time {
426            return None;
427        }
428
429        // Receive from channel
430        let entry = match self.dataset_receiver.as_ref()?.recv() {
431            Ok(Some(e)) => e,
432            Ok(None) => {
433                // End of dataset signaled
434                self.dataset_exhausted = true;
435                return None;
436            }
437            Err(_) => {
438                // Channel error (sender dropped)
439                self.dataset_exhausted = true;
440                return None;
441            }
442        };
443
444        let arrival_time = self.next_arrival_time;
445
446        // Sample actual output length from distribution
447        let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);
448
449        // If max_output_tokens is specified in the dataset, cap at that; otherwise use sampled value
450        let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
451        let target_output_tokens = sampled_output_len.min(max_output_tokens);
452
453        let mut request = Request::new_with_target(
454            entry.request_id.clone(),
455            0,
456            arrival_time,
457            entry.num_prompt_tokens(),
458            max_output_tokens,
459            target_output_tokens,
460        );
461
462        request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
463        self.requests_generated += 1;
464
465        // Sample next arrival time AFTER creating the request
466        // but ONLY if we haven't hit the request limit
467        // This prevents sampling a bogus future time for a request that won't exist
468        let should_sample_next = if let Some(max_requests) = self.workload.num_requests {
469            self.requests_generated < max_requests
470        } else {
471            // No limit set, keep sampling (will stop when channel sends None)
472            true
473        };
474
475        if should_sample_next {
476            self.next_arrival_time = Self::sample_next_arrival(
477                self.next_arrival_time,
478                &self.workload.arrival_pattern,
479                self.workload.arrival_rate,
480                &mut self.rng,
481            );
482        }
483
484        Some(request)
485    }
486
487    /// Check if all requests have been generated
488    pub fn is_finished(&self) -> bool {
489        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
490
491        // Dataset mode: check if we've hit limit or dataset is exhausted
492        if self.is_dataset_mode() {
493            // If num_requests is set, check against that limit
494            if let Some(max_requests) = self.workload.num_requests {
495                if is_closed_loop {
496                    // For closed-loop with dataset, we're finished when we've generated max_requests
497                    // AND have no pending requests (or dataset is exhausted)
498                    return (self.requests_generated >= max_requests
499                        && self.pending_closed_loop_requests.is_empty())
500                        || self.dataset_exhausted;
501                } else {
502                    return self.requests_generated >= max_requests;
503                }
504            }
505            // Otherwise, check if dataset has been fully consumed
506            return self.dataset_exhausted;
507        }
508
509        // Synthetic workload mode
510        if let Some(max_requests) = self.workload.num_requests {
511            if is_closed_loop {
512                // For closed-loop, we're finished when we've generated max_requests
513                // AND have no pending requests
514                self.requests_generated >= max_requests
515                    && self.pending_closed_loop_requests.is_empty()
516            } else {
517                self.requests_generated >= max_requests
518            }
519        } else {
520            false
521        }
522    }
523
524    /// Called when a request completes (for closed-loop pattern)
525    /// Generates a new request for that "user slot" at the completion time
526    pub fn on_request_complete(&mut self, completion_time: f64) {
527        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
528        if !is_closed_loop {
529            return; // Only applicable to closed-loop
530        }
531
532        // Check if we should generate more requests
533        if let Some(max_requests) = self.workload.num_requests {
534            if self.requests_generated >= max_requests {
535                return; // Already generated all requested requests
536            }
537        }
538
539        // Add a new pending request at the completion time
540        self.pending_closed_loop_requests.push(completion_time);
541    }
542
543    /// Get number of requests generated so far
544    pub fn num_generated(&self) -> usize {
545        self.requests_generated
546    }
547
548    /// Peek at the next arrival time without generating the request
549    pub fn peek_next_arrival(&self) -> f64 {
550        self.next_arrival_time
551    }
552}
553
554#[cfg(test)]
555mod tests {
556    use super::*;
557    use crate::config::LengthDistribution;
558
559    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
560        WorkloadConfig {
561            dataset_path: None,
562            arrival_pattern: pattern.to_string(),
563            arrival_rate: rate,
564            num_concurrent_users: None,
565            input_len_dist: LengthDistribution::Fixed { value: 100 },
566            output_len_dist: LengthDistribution::Fixed { value: 50 },
567            num_requests: Some(num_requests),
568            duration_secs: None,
569            seed: 42,
570        }
571    }
572
573    #[test]
574    fn test_generator_creation() {
575        let workload = create_test_workload("poisson", 1.0, 10);
576        let generator = RequestGenerator::new(workload);
577
578        assert_eq!(generator.num_generated(), 0);
579        assert!(!generator.is_finished());
580    }
581
582    #[test]
583    fn test_generate_requests() {
584        let workload = create_test_workload("poisson", 10.0, 5);
585        let mut generator = RequestGenerator::new(workload);
586
587        let mut requests = Vec::new();
588        let mut current_time = 0.0;
589
590        while !generator.is_finished() {
591            // Advance time significantly to ensure all requests arrive
592            current_time += 10.0;
593
594            while let Some(req) = generator.next_if_before(current_time) {
595                requests.push(req);
596            }
597        }
598
599        assert_eq!(requests.len(), 5);
600        assert!(generator.is_finished());
601    }
602
603    #[test]
604    fn test_arrival_ordering() {
605        let workload = create_test_workload("poisson", 5.0, 10);
606        let mut generator = RequestGenerator::new(workload);
607
608        let mut requests = Vec::new();
609        let mut current_time = 0.0;
610
611        while !generator.is_finished() {
612            current_time += 10.0;
613            while let Some(req) = generator.next_if_before(current_time) {
614                requests.push(req);
615            }
616        }
617
618        // Check that arrival times are monotonically increasing
619        for i in 1..requests.len() {
620            assert!(requests[i].arrival_time >= requests[i - 1].arrival_time);
621        }
622    }
623
624    #[test]
625    fn test_fixed_rate_arrival() {
626        let workload = create_test_workload("fixed_rate", 2.0, 4);
627        let mut generator = RequestGenerator::new(workload);
628
629        let mut requests = Vec::new();
630        let mut current_time = 0.0;
631
632        while !generator.is_finished() {
633            current_time += 10.0;
634            while let Some(req) = generator.next_if_before(current_time) {
635                requests.push(req);
636            }
637        }
638
639        assert_eq!(requests.len(), 4);
640
641        // Check that inter-arrival times are approximately 1/rate = 0.5 seconds
642        for i in 1..requests.len() {
643            let inter_arrival = requests[i].arrival_time - requests[i - 1].arrival_time;
644            assert!((inter_arrival - 0.5).abs() < 1e-6);
645        }
646    }
647
648    #[test]
649    fn test_request_properties() {
650        let workload = create_test_workload("poisson", 1.0, 1);
651        let mut generator = RequestGenerator::new(workload);
652
653        let req = generator.next_if_before(10.0).unwrap();
654
655        assert_eq!(req.num_prompt_tokens, 100);
656        assert_eq!(req.max_output_tokens, 50);
657        assert_eq!(req.priority, 0);
658        assert!(req.request_id.starts_with("req-"));
659    }
660
661    #[test]
662    fn test_peek_next_arrival() {
663        let workload = create_test_workload("poisson", 1.0, 10);
664        let mut generator = RequestGenerator::new(workload);
665
666        let next_arrival = generator.peek_next_arrival();
667        assert!(next_arrival > 0.0);
668
669        // Generate the request
670        let req = generator.next_if_before(next_arrival + 1.0).unwrap();
671
672        // Check that arrival time matches what we peeked
673        assert_eq!(req.arrival_time, next_arrival);
674    }
675}