inference_lab/request/generator.rs

use super::Request;
use crate::config::WorkloadConfig;
use crate::dataset::{BatchTokenizerFn, DatasetEntry, UnparsedEntry};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, Exp};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

/// Generates requests based on workload configuration
pub struct RequestGenerator {
    workload: WorkloadConfig,
    rng: StdRng,
    next_arrival_time: f64,
    requests_generated: usize,
    next_request_id: u64,
    /// For closed-loop: track pending requests to generate when completions occur
    pending_closed_loop_requests: Vec<f64>,
    /// Dataset receiver (if using dataset mode) - receives pre-loaded entries from background thread
    dataset_receiver: Option<Receiver<Option<DatasetEntry>>>,
    /// Track whether the dataset has been exhausted (received None from the channel)
    dataset_exhausted: bool,
}

impl RequestGenerator {
    pub fn new(workload: WorkloadConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, initialize with N requests at time 0
        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                // Generate initial requests for all concurrent users
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }

        let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
            0.0 // Start immediately with the first batch
        } else {
            Self::sample_next_arrival(
                0.0,
                &workload.arrival_pattern,
                workload.arrival_rate,
                &mut rng,
            )
        };

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: None,
            dataset_exhausted: false,
        }
    }

    /// Create a new generator from a dataset iterator.
    /// Spawns a background thread to read, parse, and batch-tokenize entries in parallel.
    /// The buffer size controls memory usage (entries buffered ahead of the simulation).
    pub fn from_dataset<I>(
        workload: WorkloadConfig,
        dataset_iterator: I,
        _total_entries: Option<usize>,
        tokenizer: BatchTokenizerFn,
    ) -> Self
    where
        I: Iterator<Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>>
            + Send
            + 'static,
    {
        let rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, initialize with N requests at time 0
        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }

        // In dataset mode the first request arrives at t=0 for every arrival
        // pattern (closed-loop starts its initial batch immediately)
        let next_arrival_time = 0.0;

        // Spawn background thread to load and batch-tokenize entries
        // Buffer size: 5000 entries (~10-50MB depending on token counts)
        let (sender, receiver) = sync_channel::<Option<DatasetEntry>>(5000);
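        // Note: sync_channel is bounded, so once 5000 entries are buffered the
        // loader thread blocks on send() until the simulation consumes entries.
        // This backpressure keeps memory bounded instead of loading the whole
        // dataset up front.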

        thread::spawn(move || {
            let batch_size: usize = std::env::var("TOKENIZER_BATCH_SIZE")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(32); // Default: 32; override via TOKENIZER_BATCH_SIZE
            let mut batch = Vec::with_capacity(batch_size);

            for result in dataset_iterator {
                match result {
                    Ok(Some(unparsed)) => {
                        batch.push(unparsed);

                        // Process batch when full
                        if batch.len() >= batch_size
                            && Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender)
                                .is_err()
                        {
                            // Receiver dropped (simulation ended early) or
                            // tokenization failed
                            break;
                        }
                    }
                    Ok(None) => {
                        // End of dataset - stop reading; the partial batch is
                        // flushed below
                        break;
                    }
                    Err(e) => {
                        eprintln!("Error loading dataset entry: {}", e);
                        break;
                    }
                }
            }

            // Flush any partial batch and signal end-of-dataset. This also
            // covers iterators that end without yielding an explicit Ok(None),
            // whose trailing entries would otherwise be silently dropped.
            if !batch.is_empty() {
                let _ = Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender);
            }
            let _ = sender.send(None);
        });

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: Some(receiver),
            dataset_exhausted: false,
        }
    }

    /// Check if using dataset mode
    pub fn is_dataset_mode(&self) -> bool {
        self.dataset_receiver.is_some()
    }

    /// Batch tokenize and send entries to the channel
    /// Returns Err if the receiver dropped (simulation ended)
    fn tokenize_and_send_batch(
        batch: &mut Vec<UnparsedEntry>,
        tokenizer: &BatchTokenizerFn,
        sender: &std::sync::mpsc::SyncSender<Option<DatasetEntry>>,
    ) -> Result<(), ()> {
        if batch.is_empty() {
            return Ok(());
        }

        // Tokenize the whole batch in one call, amortizing per-call tokenizer
        // overhead across all entries
        let prompt_inputs: Vec<_> = batch.iter().map(|e| e.prompt_input.clone()).collect();
        let all_tokens = match tokenizer(&prompt_inputs) {
            Ok(tokens) => tokens,
            Err(e) => {
                eprintln!("Batch tokenization failed: {}", e);
                return Err(());
            }
        };

        // Send tokenized entries
        for (unparsed, prompt_tokens) in batch.drain(..).zip(all_tokens) {
            let entry = DatasetEntry {
                request_id: unparsed.request_id,
                prompt_tokens,
                max_output_tokens: unparsed.max_output_tokens,
            };

            // Send to channel - returns Err if receiver dropped
            if sender.send(Some(entry)).is_err() {
                return Err(());
            }
        }
        Ok(())
    }

    /// Get the next scheduled arrival time
    pub fn peek_next_arrival_time(&self) -> f64 {
        self.next_arrival_time
    }

    /// Compute block hashes from token IDs (for real prefix caching)
    /// Uses incremental hashing: the hash of block i covers all tokens up to and including block i
    fn compute_block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
        let num_blocks = tokens.len().div_ceil(block_size);
        let mut hashes = Vec::with_capacity(num_blocks);

        for block_idx in 0..num_blocks {
            let end = ((block_idx + 1) * block_size).min(tokens.len());
            let block_tokens = &tokens[..end]; // All tokens up to this block

            // Hash all tokens cumulatively
            let mut hasher = DefaultHasher::new();
            block_tokens.hash(&mut hasher);
            hashes.push(hasher.finish());
        }

        hashes
    }
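
    // Design note: cumulative hashing means a block's hash matches only when
    // the entire token prefix up to that block matches, which is exactly what
    // a prefix cache needs (a cached KV block is reusable only if everything
    // before it is identical). See test_block_hashes_share_prefix below.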

    /// Get the next request if its arrival time is before the given time
    /// Returns None if no request is ready or all requests have been generated
    pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
        // Dataset mode
        if self.is_dataset_mode() {
            return self.next_from_dataset(current_time);
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, check pending requests
        if is_closed_loop {
            // Check if we've generated all requests
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    // Clear any remaining pending requests that won't be used
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            // Find a pending request that has arrived (first in insertion order)
            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                // Generate request
                let request_id = format!("req-{}", self.next_request_id);
                self.next_request_id += 1;

                let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
                let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

                let mut request = Request::new(
                    request_id,
                    0, // Default priority
                    arrival_time,
                    num_prompt_tokens,
                    max_output_tokens,
                );

                // Generate synthetic block hashes for the prompt. Hashes are
                // drawn uniformly from the full u64 range, so synthetic
                // requests effectively never share prefix blocks; real prefix
                // sharing is modeled only in dataset mode, where hashes come
                // from actual token IDs.
                let num_blocks = num_prompt_tokens.div_ceil(16) as usize; // Assume 16-token blocks
                request.prompt_block_hashes = (0..num_blocks)
                    .map(|_| self.rng.gen_range(0..u64::MAX))
                    .collect();

                self.requests_generated += 1;
                return Some(request);
            }
            return None;
        }

        // Non-closed-loop (open-loop) patterns
        // Check if we've generated all requests
        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return None;
            }
        }

        // Check if next request has arrived
        if self.next_arrival_time > current_time {
            return None;
        }

        // Generate request
        let request_id = format!("req-{}", self.next_request_id);
        self.next_request_id += 1;

        let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
        let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

        let mut request = Request::new(
            request_id,
            0, // Default priority
            self.next_arrival_time,
            num_prompt_tokens,
            max_output_tokens,
        );

        // Generate synthetic block hashes for the prompt. As in the
        // closed-loop path above, hashes are uniform over the full u64 range,
        // so synthetic requests effectively never share prefix blocks.
        let num_blocks = num_prompt_tokens.div_ceil(16) as usize; // Assume 16-token blocks
        request.prompt_block_hashes = (0..num_blocks)
            .map(|_| self.rng.gen_range(0..u64::MAX))
            .collect();

        self.requests_generated += 1;

        // Sample next arrival time
        self.next_arrival_time = Self::sample_next_arrival(
            self.next_arrival_time,
            &self.workload.arrival_pattern,
            self.workload.arrival_rate,
            &mut self.rng,
        );

        Some(request)
    }

    /// Sample the next arrival time based on the arrival pattern
    fn sample_next_arrival(current_time: f64, pattern: &str, rate: f64, rng: &mut StdRng) -> f64 {
        match pattern.to_lowercase().as_str() {
            "poisson" => {
                // Poisson process: inter-arrival times are exponentially distributed
                let exp = Exp::new(rate).unwrap();
                let inter_arrival = exp.sample(rng);
                current_time + inter_arrival
            }
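            // (Exp::new returns an Err for non-positive rates, so the unwrap
            // above assumes arrival_rate > 0.)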
            "uniform" => {
                // Uniform: constant inter-arrival time (same as "fixed_rate")
                let inter_arrival = 1.0 / rate;
                current_time + inter_arrival
            }
            "burst" => {
                // Burst: requests arrive in bursts with gaps.
                // Simple implementation: 20% of inter-arrival gaps are short
                // (a burst), the remaining 80% are long pauses
                if rng.gen_bool(0.2) {
                    current_time + rng.gen_range(0.001..0.01)
                } else {
                    current_time + rng.gen_range(0.5..2.0)
                }
            }
            "fixed_rate" => {
                // Fixed rate: exact inter-arrival time
                current_time + 1.0 / rate
            }
            "batched" => {
                // Batched: all requests arrive at time 0
                0.0
            }
            _ => {
                // Default to Poisson
                let exp = Exp::new(rate).unwrap();
                current_time + exp.sample(rng)
            }
        }
    }

    /// Get the next request from the dataset (receives from the background thread)
    fn next_from_dataset(&mut self, current_time: f64) -> Option<Request> {
        // If already exhausted, no more requests
        if self.dataset_exhausted {
            return None;
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        // For closed-loop, handle pending requests
        if is_closed_loop {
            // Check if we've generated all requests
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    // Clear any remaining pending requests that won't be used
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            // Find a pending request that has arrived (first in insertion order)
            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                // Receive from channel
                let entry = match self.dataset_receiver.as_ref()?.recv() {
                    Ok(Some(e)) => e,
                    Ok(None) => {
                        // End of dataset signaled
                        self.dataset_exhausted = true;
                        return None;
                    }
                    Err(_) => {
                        // Channel error (sender dropped)
                        self.dataset_exhausted = true;
                        return None;
                    }
                };

                // Sample the output length from the workload distribution
                let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);

                // Cap the sampled length at the dataset's max_output_tokens
                // when present, otherwise at a 16384-token default ceiling
                let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
                let target_output_tokens = sampled_output_len.min(max_output_tokens);

                let mut request = Request::new_with_target(
                    entry.request_id.clone(),
                    0,
                    arrival_time,
                    entry.num_prompt_tokens(),
                    max_output_tokens,
                    target_output_tokens,
                );

                request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
                self.requests_generated += 1;

                return Some(request);
            }
            return None;
        }

        // Non-closed-loop: check if it's time for the next arrival
        if self.next_arrival_time > current_time {
            return None;
        }

        // Receive from channel
        let entry = match self.dataset_receiver.as_ref()?.recv() {
            Ok(Some(e)) => e,
            Ok(None) => {
                // End of dataset signaled
                self.dataset_exhausted = true;
                return None;
            }
            Err(_) => {
                // Channel error (sender dropped)
                self.dataset_exhausted = true;
                return None;
            }
        };

        let arrival_time = self.next_arrival_time;

        // Sample the output length from the workload distribution
        let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);

        // Cap the sampled length at the dataset's max_output_tokens when
        // present, otherwise at a 16384-token default ceiling
        let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
        let target_output_tokens = sampled_output_len.min(max_output_tokens);

        let mut request = Request::new_with_target(
            entry.request_id.clone(),
            0,
            arrival_time,
            entry.num_prompt_tokens(),
            max_output_tokens,
            target_output_tokens,
        );

        request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
        self.requests_generated += 1;

        // Sample the next arrival time AFTER creating the request, but ONLY if
        // we haven't hit the request limit; this avoids scheduling a bogus
        // future arrival for a request that will never exist
        let should_sample_next = if let Some(max_requests) = self.workload.num_requests {
            self.requests_generated < max_requests
        } else {
            // No limit set, keep sampling (generation stops when the channel sends None)
            true
        };

        if should_sample_next {
            self.next_arrival_time = Self::sample_next_arrival(
                self.next_arrival_time,
                &self.workload.arrival_pattern,
                self.workload.arrival_rate,
                &mut self.rng,
            );
        }

        Some(request)
    }

    /// Check if all requests have been generated
    pub fn is_finished(&self) -> bool {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        // Dataset mode: check the request limit and dataset exhaustion
        if self.is_dataset_mode() {
            // If num_requests is set, check against that limit
            if let Some(max_requests) = self.workload.num_requests {
                if is_closed_loop {
                    // For closed-loop with a dataset, we're finished when we've
                    // generated max_requests AND have no pending requests, or
                    // when the dataset is exhausted
                    return (self.requests_generated >= max_requests
                        && self.pending_closed_loop_requests.is_empty())
                        || self.dataset_exhausted;
                } else {
                    return self.requests_generated >= max_requests;
                }
            }
            // Otherwise, check if the dataset has been fully consumed
            return self.dataset_exhausted;
        }

        // Synthetic workload mode
        if let Some(max_requests) = self.workload.num_requests {
            if is_closed_loop {
                // For closed-loop, we're finished when we've generated
                // max_requests AND have no pending requests
                self.requests_generated >= max_requests
                    && self.pending_closed_loop_requests.is_empty()
            } else {
                self.requests_generated >= max_requests
            }
        } else {
            false
        }
    }

    /// Called when a request completes (for the closed-loop pattern).
    /// Generates a new request for that "user slot" at the completion time.
    pub fn on_request_complete(&mut self, completion_time: f64) {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
        if !is_closed_loop {
            return; // Only applicable to closed-loop
        }

        // Check if we should generate more requests
        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return; // Already generated all requested requests
            }
        }

        // Add a new pending request at the completion time
        self.pending_closed_loop_requests.push(completion_time);
    }
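
    // Note: the replacement request becomes pending exactly at
    // completion_time, i.e. the model assumes zero user "think time" between
    // receiving a response and sending the next request; a think-time delay
    // could be added here to model interactive users.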

    /// Get number of requests generated so far
    pub fn num_generated(&self) -> usize {
        self.requests_generated
    }

    /// Peek at the next arrival time without generating the request
    /// (equivalent to [`Self::peek_next_arrival_time`])
    pub fn peek_next_arrival(&self) -> f64 {
        self.peek_next_arrival_time()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            dataset_path: None,
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    #[test]
    fn test_generator_creation() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let generator = RequestGenerator::new(workload);

        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let workload = create_test_workload("poisson", 10.0, 5);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            // Advance time significantly to ensure all requests arrive
            current_time += 10.0;

            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }
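
    // A sketch of the "batched" pattern, following the loop structure above:
    // every arrival is scheduled at t=0, so all requests are ready immediately.
    #[test]
    fn test_batched_arrivals() {
        let workload = create_test_workload("batched", 1.0, 3);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        while let Some(req) = generator.next_if_before(0.0) {
            requests.push(req);
        }

        assert_eq!(requests.len(), 3);
        assert!(requests.iter().all(|r| r.arrival_time == 0.0));
        assert!(generator.is_finished());
    }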

    #[test]
    fn test_arrival_ordering() {
        let workload = create_test_workload("poisson", 5.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        // Check that arrival times are monotonically non-decreasing
        for i in 1..requests.len() {
            assert!(requests[i].arrival_time >= requests[i - 1].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let workload = create_test_workload("fixed_rate", 2.0, 4);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 4);

        // Check that inter-arrival times are approximately 1/rate = 0.5 seconds
        for i in 1..requests.len() {
            let inter_arrival = requests[i].arrival_time - requests[i - 1].arrival_time;
            assert!((inter_arrival - 0.5).abs() < 1e-6);
        }
    }
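
    // A sketch of closed-loop behavior: N initial requests are pending at
    // t=0, and each completion reported via on_request_complete frees one
    // "user slot" at the completion time.
    #[test]
    fn test_closed_loop_generation() {
        let mut workload = create_test_workload("closed_loop", 1.0, 4);
        workload.num_concurrent_users = Some(2);
        let mut generator = RequestGenerator::new(workload);

        // Both initial user slots produce requests at t=0
        let first = generator.next_if_before(0.0).unwrap();
        let second = generator.next_if_before(0.0).unwrap();
        assert_eq!(first.arrival_time, 0.0);
        assert_eq!(second.arrival_time, 0.0);
        // No free slots until a completion occurs
        assert!(generator.next_if_before(0.0).is_none());

        // A completion at t=1.5 schedules a replacement request at t=1.5
        generator.on_request_complete(1.5);
        assert!(generator.next_if_before(1.0).is_none()); // not yet arrived
        let third = generator.next_if_before(2.0).unwrap();
        assert_eq!(third.arrival_time, 1.5);
    }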

    #[test]
    fn test_request_properties() {
        let workload = create_test_workload("poisson", 1.0, 1);
        let mut generator = RequestGenerator::new(workload);

        let req = generator.next_if_before(10.0).unwrap();

        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.priority, 0);
        assert!(req.request_id.starts_with("req-"));
    }
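
    // A sketch checking the synthetic block-hash shape: with 16-token blocks,
    // a 100-token prompt should carry ceil(100/16) = 7 block hashes.
    #[test]
    fn test_prompt_block_hash_count() {
        let workload = create_test_workload("poisson", 1.0, 1);
        let mut generator = RequestGenerator::new(workload);

        let req = generator.next_if_before(10.0).unwrap();
        assert_eq!(req.prompt_block_hashes.len(), 7);
    }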

    #[test]
    fn test_peek_next_arrival() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let next_arrival = generator.peek_next_arrival();
        assert!(next_arrival > 0.0);

        // Generate the request
        let req = generator.next_if_before(next_arrival + 1.0).unwrap();

        // Check that arrival time matches what we peeked
        assert_eq!(req.arrival_time, next_arrival);
    }
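
    // A sketch of the incremental block-hash property used for prefix
    // caching: because each block hash covers all tokens up to that block,
    // prompts sharing a prefix share their leading block hashes and diverge
    // after the first differing block.
    #[test]
    fn test_block_hashes_share_prefix() {
        let a: Vec<u32> = (0..48).collect(); // exactly 3 blocks of 16 tokens
        let mut b = a.clone();
        b[40] = 9999; // differ only within the third block

        let ha = RequestGenerator::compute_block_hashes(&a, 16);
        let hb = RequestGenerator::compute_block_hashes(&b, 16);

        assert_eq!(ha.len(), 3);
        assert_eq!(ha[0], hb[0]); // first block identical
        assert_eq!(ha[1], hb[1]); // first two blocks identical
        assert_ne!(ha[2], hb[2]); // third block hash reflects the change
    }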
}