
inference_lab/request/generator.rs

use super::Request;
use crate::config::WorkloadConfig;
use crate::dataset::{BatchTokenizerFn, DatasetEntry, UnparsedEntry};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, Exp};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

/// Generates requests based on workload configuration
pub struct RequestGenerator {
    workload: WorkloadConfig,
    rng: StdRng,
    next_arrival_time: f64,
    requests_generated: usize,
    next_request_id: u64,
    /// For closed-loop: track pending requests to generate when completions occur
    pending_closed_loop_requests: Vec<f64>,
    /// Dataset receiver (if using dataset mode) - receives pre-loaded entries from background thread
    dataset_receiver: Option<Receiver<Option<DatasetEntry>>>,
    /// Track if dataset has been exhausted (received None from channel)
    dataset_exhausted: bool,
}

impl RequestGenerator {
    pub fn new(workload: WorkloadConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, initialize with N requests at time 0
        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                // Generate initial requests for all concurrent users
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }

        let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
            0.0 // Start immediately with the first batch
        } else {
            Self::sample_next_arrival(
                0.0,
                &workload.arrival_pattern,
                workload.arrival_rate,
                &mut rng,
            )
        };

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: None,
            dataset_exhausted: false,
        }
    }

    /// Create a new generator from a dataset iterator.
    /// Spawns a background thread to read, parse, and batch-tokenize entries in parallel.
    /// The channel buffer size controls memory usage (entries buffered ahead of the simulation).
    pub fn from_dataset<I>(
        workload: WorkloadConfig,
        dataset_iterator: I,
        _total_entries: Option<usize>,
        tokenizer: BatchTokenizerFn,
    ) -> Self
    where
        I: Iterator<Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>>
            + Send
            + 'static,
    {
        let rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, initialize with N requests at time 0
        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }

        // In dataset mode the first request arrives at t = 0, whether or not
        // we start with a closed-loop batch
        let next_arrival_time = 0.0;

        // Spawn background thread to load and batch-tokenize entries
        // Buffer size: 5000 entries (~10-50MB depending on token counts)
        let (sender, receiver) = sync_channel::<Option<DatasetEntry>>(5000);
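        // `sync_channel` is bounded: once 5000 tokenized entries are buffered
        // ahead of the simulation, the loader thread blocks on `send`, which
        // is what keeps memory use bounded.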

        thread::spawn(move || {
            let batch_size: usize = std::env::var("TOKENIZER_BATCH_SIZE")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(32); // Default: 32, balancing tokenizer throughput against buffering latency
            let mut batch = Vec::with_capacity(batch_size);

            for result in dataset_iterator {
                match result {
                    Ok(Some(unparsed)) => {
                        batch.push(unparsed);

                        // Process the batch when full; a send error means the
                        // receiver dropped because the simulation ended early
                        if batch.len() >= batch_size
                            && Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender)
                                .is_err()
                        {
                            return;
                        }
                    }
                    // End of dataset
                    Ok(None) => break,
                    Err(e) => {
                        eprintln!("Error loading dataset entry: {}", e);
                        break;
                    }
                }
            }

            // Flush the remaining partial batch and send the completion signal.
            // Doing this after the loop also covers iterators that stop yielding
            // items without an explicit Ok(None) end marker.
            if !batch.is_empty() {
                let _ = Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender);
            }
            let _ = sender.send(None);
        });

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: Some(receiver),
            dataset_exhausted: false,
        }
    }

    /// Check if using dataset mode
    pub fn is_dataset_mode(&self) -> bool {
        self.dataset_receiver.is_some()
    }

    /// Batch tokenize and send entries to the channel.
    /// Returns Err if the receiver dropped (simulation ended).
    fn tokenize_and_send_batch(
        batch: &mut Vec<UnparsedEntry>,
        tokenizer: &BatchTokenizerFn,
        sender: &std::sync::mpsc::SyncSender<Option<DatasetEntry>>,
    ) -> Result<(), ()> {
        if batch.is_empty() {
            return Ok(());
        }

        // Collect all message arrays to tokenize together
        let message_arrays: Vec<&[_]> = batch.iter().map(|e| e.messages.as_slice()).collect();

        // Batch tokenize all entries at once, amortizing per-call tokenizer
        // overhead across the whole batch
        let all_tokens = match tokenizer(&message_arrays) {
            Ok(tokens) => tokens,
            Err(e) => {
                eprintln!("Batch tokenization failed: {}", e);
                return Err(());
            }
        };

        // Send tokenized entries
        for (unparsed, prompt_tokens) in batch.drain(..).zip(all_tokens.into_iter()) {
            let entry = DatasetEntry {
                request_id: unparsed.request_id,
                prompt_tokens,
                max_output_tokens: unparsed.max_output_tokens,
            };

            // Send to channel - returns Err if the receiver dropped
            if sender.send(Some(entry)).is_err() {
                return Err(());
            }
        }
        Ok(())
    }
    /// Compute block hashes from token IDs (for real prefix caching).
    /// Cumulative prefix hashing: the hash of block i covers all tokens up to and including block i.
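    ///
    /// For example, with `block_size = 2`, tokens `[a, b, c, d]` produce
    /// `hash([a, b])` and `hash([a, b, c, d])`, so two prompts that share a
    /// prefix also share their leading block hashes.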
    fn compute_block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
        let num_blocks = tokens.len().div_ceil(block_size);
        let mut hashes = Vec::with_capacity(num_blocks);

        for block_idx in 0..num_blocks {
            let end = ((block_idx + 1) * block_size).min(tokens.len());
            let block_tokens = &tokens[..end]; // All tokens up to this block

            // Hash all tokens cumulatively
            let mut hasher = DefaultHasher::new();
            block_tokens.hash(&mut hasher);
            hashes.push(hasher.finish());
        }

        hashes
    }

    /// Get the next request if its arrival time is before the given time.
    /// Returns None if no request is ready or all requests have been generated.
    pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
        // Dataset mode
        if self.is_dataset_mode() {
            return self.next_from_dataset(current_time);
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, check pending requests
        if is_closed_loop {
            // Check if we've generated all requests
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    // Clear any remaining pending requests that won't be used
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            // Find the earliest pending request that has arrived
            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                // Generate request
                let request_id = format!("req-{}", self.next_request_id);
                self.next_request_id += 1;

                let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
                let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

                let mut request = Request::new(
                    request_id,
                    0, // Default priority
                    arrival_time,
                    num_prompt_tokens,
                    max_output_tokens,
                );

                // Generate block hashes for the prompt. Synthetic mode draws
                // hashes uniformly at random, so prompts effectively never
                // share prefixes; dataset mode computes real hashes from tokens.
                let num_blocks = num_prompt_tokens.div_ceil(16) as usize; // Assume 16-token blocks
                request.prompt_block_hashes = (0..num_blocks)
                    .map(|_| self.rng.gen_range(0..u64::MAX))
                    .collect();

                self.requests_generated += 1;
                return Some(request);
            }
            return None;
        }

        // Original logic for non-closed-loop patterns
        // Check if we've generated all requests
        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return None;
            }
        }

        // Check if next request has arrived
        if self.next_arrival_time > current_time {
            return None;
        }

        // Generate request
        let request_id = format!("req-{}", self.next_request_id);
        self.next_request_id += 1;

        let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
        let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

        let mut request = Request::new(
            request_id,
            0, // Default priority
            self.next_arrival_time,
            num_prompt_tokens,
            max_output_tokens,
        );

        // Generate block hashes for the prompt (random, as above: synthetic
        // prompts effectively never share prefixes)
        let num_blocks = num_prompt_tokens.div_ceil(16) as usize; // Assume 16-token blocks
        request.prompt_block_hashes = (0..num_blocks)
            .map(|_| self.rng.gen_range(0..u64::MAX))
            .collect();

        self.requests_generated += 1;

        // Sample next arrival time
        self.next_arrival_time = Self::sample_next_arrival(
            self.next_arrival_time,
            &self.workload.arrival_pattern,
            self.workload.arrival_rate,
            &mut self.rng,
        );

        Some(request)
    }

    /// Sample the next arrival time based on the arrival pattern
    fn sample_next_arrival(current_time: f64, pattern: &str, rate: f64, rng: &mut StdRng) -> f64 {
        match pattern.to_lowercase().as_str() {
            "poisson" => {
                // Poisson process: inter-arrival times are exponentially distributed
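                // With rate λ, Exp(λ) inter-arrivals have mean 1/λ, so the
                // long-run arrival rate converges to `rate` requests per unit time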
                let exp = Exp::new(rate).unwrap();
                let inter_arrival = exp.sample(rng);
                current_time + inter_arrival
            }
            "uniform" | "fixed_rate" => {
                // Constant inter-arrival time of exactly 1/rate
                current_time + 1.0 / rate
            }
            "burst" => {
                // Burst: requests arrive in bursts with gaps
                // Simple implementation: alternate between fast and slow
                if rng.gen_bool(0.2) {
                    // 20% chance of burst
                    current_time + rng.gen_range(0.001..0.01)
                } else {
                    current_time + rng.gen_range(0.5..2.0)
                }
            }
            "batched" => {
                // Batched: all requests arrive at time 0
                0.0
            }
            _ => {
                // Default to Poisson
                let exp = Exp::new(rate).unwrap();
                current_time + exp.sample(rng)
            }
        }
    }

    /// Get next request from dataset (receives from background thread)
    fn next_from_dataset(&mut self, current_time: f64) -> Option<Request> {
        // If already exhausted, no more requests
        if self.dataset_exhausted {
            return None;
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, handle pending requests
        if is_closed_loop {
            // Check if we've generated all requests
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    // Clear any remaining pending requests that won't be used
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            // Find the earliest pending request that has arrived
            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                // Receive from channel
                let entry = match self.dataset_receiver.as_ref()?.recv() {
                    Ok(Some(e)) => e,
                    Ok(None) => {
                        // End of dataset signaled
                        self.dataset_exhausted = true;
                        return None;
                    }
                    Err(_) => {
                        // Channel error (sender dropped)
                        self.dataset_exhausted = true;
                        return None;
                    }
                };

                // Sample actual output length from distribution
                let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);

                // Cap at the dataset's max_output_tokens if present; otherwise
                // fall back to a default ceiling of 16384 tokens
                let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
                let target_output_tokens = sampled_output_len.min(max_output_tokens);

                let mut request = Request::new_with_target(
                    entry.request_id.clone(),
                    0,
                    arrival_time,
                    entry.num_prompt_tokens(),
                    max_output_tokens,
                    target_output_tokens,
                );

                request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
                self.requests_generated += 1;

                return Some(request);
            }
            return None;
        }

        // Non-closed-loop: original logic
        // Check if it's time for next arrival
        if self.next_arrival_time > current_time {
            return None;
        }

        // Receive from channel
        let entry = match self.dataset_receiver.as_ref()?.recv() {
            Ok(Some(e)) => e,
            Ok(None) => {
                // End of dataset signaled
                self.dataset_exhausted = true;
                return None;
            }
            Err(_) => {
                // Channel error (sender dropped)
                self.dataset_exhausted = true;
                return None;
            }
        };

        let arrival_time = self.next_arrival_time;

        // Sample actual output length from distribution
        let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);

        // Cap at the dataset's max_output_tokens if present; otherwise fall
        // back to a default ceiling of 16384 tokens
        let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
        let target_output_tokens = sampled_output_len.min(max_output_tokens);

        let mut request = Request::new_with_target(
            entry.request_id.clone(),
            0,
            arrival_time,
            entry.num_prompt_tokens(),
            max_output_tokens,
            target_output_tokens,
        );

        request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
        self.requests_generated += 1;

        // Sample the next arrival time AFTER creating the request, but ONLY if
        // we haven't hit the request limit; this prevents sampling a bogus
        // future time for a request that will never exist
        let should_sample_next = if let Some(max_requests) = self.workload.num_requests {
            self.requests_generated < max_requests
        } else {
            // No limit set, keep sampling (will stop when the channel sends None)
            true
        };

        if should_sample_next {
            self.next_arrival_time = Self::sample_next_arrival(
                self.next_arrival_time,
                &self.workload.arrival_pattern,
                self.workload.arrival_rate,
                &mut self.rng,
            );
        }

        Some(request)
    }

    /// Check if all requests have been generated
    pub fn is_finished(&self) -> bool {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        // Dataset mode: check if we've hit the limit or the dataset is exhausted
        if self.is_dataset_mode() {
            // If num_requests is set, check against that limit
            if let Some(max_requests) = self.workload.num_requests {
                if is_closed_loop {
                    // For closed-loop with dataset, we're finished when we've generated
                    // max_requests AND have no pending requests (or the dataset is exhausted)
                    return (self.requests_generated >= max_requests
                        && self.pending_closed_loop_requests.is_empty())
                        || self.dataset_exhausted;
                } else {
                    return self.requests_generated >= max_requests;
                }
            }
            // Otherwise, check if the dataset has been fully consumed
            return self.dataset_exhausted;
        }

        // Synthetic workload mode
        if let Some(max_requests) = self.workload.num_requests {
            if is_closed_loop {
                // For closed-loop, we're finished when we've generated max_requests
                // AND have no pending requests
                self.requests_generated >= max_requests
                    && self.pending_closed_loop_requests.is_empty()
            } else {
                self.requests_generated >= max_requests
            }
        } else {
            false
        }
    }

    /// Called when a request completes (for the closed-loop pattern).
    /// Generates a new request for that "user slot" at the completion time.
    pub fn on_request_complete(&mut self, completion_time: f64) {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
        if !is_closed_loop {
            return; // Only applicable to closed-loop
        }

        // Check if we should generate more requests
        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return; // Already generated all requested requests
            }
        }

        // Add a new pending request at the completion time
        self.pending_closed_loop_requests.push(completion_time);
    }

    /// Get number of requests generated so far
    pub fn num_generated(&self) -> usize {
        self.requests_generated
    }

    /// Peek at the next arrival time without generating the request
    pub fn peek_next_arrival(&self) -> f64 {
        self.next_arrival_time
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            dataset_path: None,
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    #[test]
    fn test_generator_creation() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let generator = RequestGenerator::new(workload);

        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let workload = create_test_workload("poisson", 10.0, 5);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            // Advance time significantly to ensure all requests arrive
            current_time += 10.0;

            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }

    #[test]
    fn test_arrival_ordering() {
        let workload = create_test_workload("poisson", 5.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        // Check that arrival times are monotonically increasing
        for i in 1..requests.len() {
            assert!(requests[i].arrival_time >= requests[i - 1].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let workload = create_test_workload("fixed_rate", 2.0, 4);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 4);

        // Check that inter-arrival times are approximately 1/rate = 0.5 seconds
        for i in 1..requests.len() {
            let inter_arrival = requests[i].arrival_time - requests[i - 1].arrival_time;
            assert!((inter_arrival - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_request_properties() {
        let workload = create_test_workload("poisson", 1.0, 1);
        let mut generator = RequestGenerator::new(workload);

        let req = generator.next_if_before(10.0).unwrap();

        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.priority, 0);
        assert!(req.request_id.starts_with("req-"));
    }
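
    // Illustrative sketch of the closed-loop contract implemented above:
    // `num_concurrent_users` seeds the initial batch at t = 0, and each
    // completion re-arms one user slot via `on_request_complete`.
    #[test]
    fn test_closed_loop_generation() {
        let mut workload = create_test_workload("closed_loop", 1.0, 4);
        workload.num_concurrent_users = Some(2);
        let mut generator = RequestGenerator::new(workload);

        // Both user slots fire at t = 0, then the pool is empty
        assert!(generator.next_if_before(0.0).is_some());
        assert!(generator.next_if_before(0.0).is_some());
        assert!(generator.next_if_before(0.0).is_none());

        // A completion re-arms one slot at the completion time
        generator.on_request_complete(1.5);
        assert!(generator.next_if_before(1.0).is_none());
        assert!(generator.next_if_before(2.0).is_some());
        assert_eq!(generator.num_generated(), 3);
    }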

    #[test]
    fn test_peek_next_arrival() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let next_arrival = generator.peek_next_arrival();
        assert!(next_arrival > 0.0);

        // Generate the request
        let req = generator.next_if_before(next_arrival + 1.0).unwrap();

        // Check that the arrival time matches what we peeked
        assert_eq!(req.arrival_time, next_arrival);
    }
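
    // Sketch of the cumulative-prefix property of `compute_block_hashes`:
    // prompts that share a token prefix share their leading block hashes and
    // diverge from the first differing block onward.
    #[test]
    fn test_block_hashes_share_prefix() {
        let a: Vec<u32> = (0..64).collect();
        let mut b = a.clone();
        b[48] = 9999; // Diverge inside the fourth 16-token block

        let ha = RequestGenerator::compute_block_hashes(&a, 16);
        let hb = RequestGenerator::compute_block_hashes(&b, 16);

        assert_eq!(ha.len(), 4);
        assert_eq!(ha[..3], hb[..3]); // Shared-prefix blocks match
        assert_ne!(ha[3], hb[3]); // The divergent block differs
    }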
}