use super::Request;
use crate::config::WorkloadConfig;
use crate::dataset::{BatchTokenizerFn, DatasetEntry, UnparsedEntry};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, Exp};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

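/// Produces simulated requests according to a `WorkloadConfig`, either synthetically
/// from the configured length distributions or by replaying a tokenized dataset.
/// Supports open-loop arrival patterns as well as a closed-loop mode with a fixed
/// number of concurrent users.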
pub struct RequestGenerator {
    workload: WorkloadConfig,
    rng: StdRng,
    next_arrival_time: f64,
    requests_generated: usize,
    next_request_id: u64,
    /// Times at which closed-loop user slots become free to issue their next request.
    pending_closed_loop_requests: Vec<f64>,
    /// Channel of pre-tokenized entries from the background tokenizer thread (dataset mode).
    dataset_receiver: Option<Receiver<Option<DatasetEntry>>>,
    /// Set once the dataset iterator or its channel is exhausted.
    dataset_exhausted: bool,
}

impl RequestGenerator {
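    /// Creates a generator for synthetic workloads: prompt and output lengths are
    /// sampled from the configured distributions.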
    pub fn new(workload: WorkloadConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        // In closed-loop mode, seed one pending request per concurrent user at t = 0.
        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }

        // Closed-loop users start immediately; otherwise sample the first arrival.
        let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
            0.0
        } else {
            Self::sample_next_arrival(
                0.0,
                &workload.arrival_pattern,
                workload.arrival_rate,
                &mut rng,
            )
        };

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: None,
            dataset_exhausted: false,
        }
    }
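    /// Creates a generator that replays dataset entries. A background thread tokenizes
    /// entries in batches (size taken from `TOKENIZER_BATCH_SIZE`, default 32) and
    /// streams the results through a bounded channel.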
    pub fn from_dataset<I>(
        workload: WorkloadConfig,
        dataset_iterator: I,
        _total_entries: Option<usize>,
        tokenizer: BatchTokenizerFn,
    ) -> Self
    where
        I: Iterator<Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>>
            + Send
            + 'static,
    {
        let rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        // In closed-loop mode, seed one pending request per concurrent user at t = 0.
        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }
        // In dataset mode the first request always arrives at t = 0; later arrivals are
        // sampled from the configured pattern as entries are consumed.
        let next_arrival_time = 0.0;

        let (sender, receiver) = sync_channel::<Option<DatasetEntry>>(5000);

        // Tokenize on a background thread so request generation is never blocked on it.
        thread::spawn(move || {
            let batch_size: usize = std::env::var("TOKENIZER_BATCH_SIZE")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(32);
            let mut batch = Vec::with_capacity(batch_size);

            for result in dataset_iterator {
                match result {
                    Ok(Some(unparsed)) => {
                        batch.push(unparsed);

                        if batch.len() >= batch_size {
                            if Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender)
                                .is_err()
                            {
                                break;
                            }
                        }
                    }
                    Ok(None) => {
                        // End of dataset: flush any partial batch, then signal exhaustion.
                        if !batch.is_empty() {
                            let _ = Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender);
                        }
                        let _ = sender.send(None);
                        break;
                    }
                    Err(e) => {
                        eprintln!("Error loading dataset entry: {}", e);
                        break;
                    }
                }
            }
        });

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: Some(receiver),
            dataset_exhausted: false,
        }
    }

    /// True when the generator replays a dataset rather than synthesizing requests.
    pub fn is_dataset_mode(&self) -> bool {
        self.dataset_receiver.is_some()
    }

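    /// Tokenizes the buffered entries in a single batch call and forwards each resulting
    /// `DatasetEntry` through the channel. Returns `Err(())` if tokenization fails or the
    /// receiver has been dropped.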
    fn tokenize_and_send_batch(
        batch: &mut Vec<UnparsedEntry>,
        tokenizer: &BatchTokenizerFn,
        sender: &std::sync::mpsc::SyncSender<Option<DatasetEntry>>,
    ) -> Result<(), ()> {
        if batch.is_empty() {
            return Ok(());
        }

        let message_arrays: Vec<&[_]> = batch.iter().map(|e| e.messages.as_slice()).collect();

        let all_tokens = match tokenizer(&message_arrays) {
            Ok(tokens) => tokens,
            Err(e) => {
                eprintln!("Batch tokenization failed: {}", e);
                return Err(());
            }
        };

        for (unparsed, prompt_tokens) in batch.drain(..).zip(all_tokens.into_iter()) {
            let entry = DatasetEntry {
                request_id: unparsed.request_id,
                prompt_tokens,
                max_output_tokens: unparsed.max_output_tokens,
            };

            if sender.send(Some(entry)).is_err() {
                return Err(());
            }
        }
        Ok(())
    }

    /// Returns the next scheduled arrival time without consuming it.
    pub fn peek_next_arrival_time(&self) -> f64 {
        self.next_arrival_time
    }

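    /// Computes one hash per block of `block_size` tokens. Each hash covers the entire
    /// token prefix up to the end of its block, so two prompts produce matching hashes
    /// only where they share that prefix (prefix-cache-style block hashing).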
    fn compute_block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
        let num_blocks = tokens.len().div_ceil(block_size);
        let mut hashes = Vec::with_capacity(num_blocks);

        for block_idx in 0..num_blocks {
            let end = ((block_idx + 1) * block_size).min(tokens.len());
            // Hash the cumulative prefix rather than the block in isolation.
            let block_tokens = &tokens[..end];
            let mut hasher = DefaultHasher::new();
            block_tokens.hash(&mut hasher);
            hashes.push(hasher.finish());
        }

        hashes
    }
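    /// Returns the next request whose arrival time is at or before `current_time`, or
    /// `None` if no request is due yet, the request budget is exhausted, or (in
    /// closed-loop mode) no user slot is currently idle.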
    pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
        if self.is_dataset_mode() {
            return self.next_from_dataset(current_time);
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        if is_closed_loop {
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            // Issue a request for the first user whose previous request has completed.
            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                let request_id = format!("req-{}", self.next_request_id);
                self.next_request_id += 1;

                let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
                let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

                let mut request = Request::new(
                    request_id,
                    0, // priority
                    arrival_time,
                    num_prompt_tokens,
                    max_output_tokens,
                );

                // Assign a random hash per 16-token block of the synthetic prompt.
                let num_blocks = num_prompt_tokens.div_ceil(16) as usize;
                request.prompt_block_hashes = (0..num_blocks)
                    .map(|_| self.rng.gen_range(0..u64::MAX))
                    .collect();

                self.requests_generated += 1;
                return Some(request);
            }
            return None;
        }

        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return None;
            }
        }

        if self.next_arrival_time > current_time {
            return None;
        }

        let request_id = format!("req-{}", self.next_request_id);
        self.next_request_id += 1;

        let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
        let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

        let mut request = Request::new(
            request_id,
            0, // priority
            self.next_arrival_time,
            num_prompt_tokens,
            max_output_tokens,
        );

        // Assign a random hash per 16-token block of the synthetic prompt.
        let num_blocks = num_prompt_tokens.div_ceil(16) as usize;
        request.prompt_block_hashes = (0..num_blocks)
            .map(|_| self.rng.gen_range(0..u64::MAX))
            .collect();

        self.requests_generated += 1;

        self.next_arrival_time = Self::sample_next_arrival(
            self.next_arrival_time,
            &self.workload.arrival_pattern,
            self.workload.arrival_rate,
            &mut self.rng,
        );

        Some(request)
    }
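    /// Returns the absolute time of the next arrival for the given pattern and rate;
    /// unrecognized patterns fall back to Poisson arrivals.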
    fn sample_next_arrival(current_time: f64, pattern: &str, rate: f64, rng: &mut StdRng) -> f64 {
        match pattern.to_lowercase().as_str() {
            "poisson" => {
                // Exponential inter-arrival times with mean 1 / rate.
                let exp = Exp::new(rate).unwrap();
                let inter_arrival = exp.sample(rng);
                current_time + inter_arrival
            }
            "uniform" => {
                // Evenly spaced arrivals.
                let inter_arrival = 1.0 / rate;
                current_time + inter_arrival
            }
            "burst" => {
                // 20% of inter-arrival gaps are very short (a burst); the rest are long.
                if rng.gen_bool(0.2) {
                    current_time + rng.gen_range(0.001..0.01)
                } else {
                    current_time + rng.gen_range(0.5..2.0)
                }
            }
            "fixed_rate" => current_time + 1.0 / rate,
            // All requests arrive together at t = 0.
            "batched" => 0.0,
            _ => {
                // Default to Poisson arrivals.
                let exp = Exp::new(rate).unwrap();
                current_time + exp.sample(rng)
            }
        }
    }
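    /// Dataset-mode counterpart of `next_if_before`: pairs the sampled arrival time with
    /// the next pre-tokenized entry received from the background tokenizer thread.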
    fn next_from_dataset(&mut self, current_time: f64) -> Option<Request> {
        if self.dataset_exhausted {
            return None;
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        if is_closed_loop {
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                let entry = match self.dataset_receiver.as_ref()?.recv() {
                    Ok(Some(e)) => e,
                    Ok(None) => {
                        self.dataset_exhausted = true;
                        return None;
                    }
                    Err(_) => {
                        self.dataset_exhausted = true;
                        return None;
                    }
                };

                // Cap the sampled output length at the dataset entry's own limit.
                let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);
                let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
                let target_output_tokens = sampled_output_len.min(max_output_tokens);

                let mut request = Request::new_with_target(
                    entry.request_id.clone(),
                    0, // priority
                    arrival_time,
                    entry.num_prompt_tokens(),
                    max_output_tokens,
                    target_output_tokens,
                );

                request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
                self.requests_generated += 1;

                return Some(request);
            }
            return None;
        }

        if self.next_arrival_time > current_time {
            return None;
        }

        let entry = match self.dataset_receiver.as_ref()?.recv() {
            Ok(Some(e)) => e,
            Ok(None) => {
                self.dataset_exhausted = true;
                return None;
            }
            Err(_) => {
                self.dataset_exhausted = true;
                return None;
            }
        };

        let arrival_time = self.next_arrival_time;

        // Cap the sampled output length at the dataset entry's own limit.
        let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);
        let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
        let target_output_tokens = sampled_output_len.min(max_output_tokens);

        let mut request = Request::new_with_target(
            entry.request_id.clone(),
            0, // priority
            arrival_time,
            entry.num_prompt_tokens(),
            max_output_tokens,
            target_output_tokens,
        );

        request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
        self.requests_generated += 1;

        // Only advance the arrival clock if more requests remain to be generated.
        let should_sample_next = if let Some(max_requests) = self.workload.num_requests {
            self.requests_generated < max_requests
        } else {
            true
        };

        if should_sample_next {
            self.next_arrival_time = Self::sample_next_arrival(
                self.next_arrival_time,
                &self.workload.arrival_pattern,
                self.workload.arrival_rate,
                &mut self.rng,
            );
        }

        Some(request)
    }
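    /// Returns true once the generator will produce no further requests.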
    pub fn is_finished(&self) -> bool {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        if self.is_dataset_mode() {
            if let Some(max_requests) = self.workload.num_requests {
                if is_closed_loop {
                    return (self.requests_generated >= max_requests
                        && self.pending_closed_loop_requests.is_empty())
                        || self.dataset_exhausted;
                } else {
                    return self.requests_generated >= max_requests;
                }
            }
            return self.dataset_exhausted;
        }

        if let Some(max_requests) = self.workload.num_requests {
            if is_closed_loop {
                self.requests_generated >= max_requests
                    && self.pending_closed_loop_requests.is_empty()
            } else {
                self.requests_generated >= max_requests
            }
        } else {
            false
        }
    }
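    /// Closed-loop hook: records that a request finished at `completion_time`, freeing
    /// that user slot to issue its next request. No-op for open-loop patterns or once
    /// the request budget has been reached.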
    pub fn on_request_complete(&mut self, completion_time: f64) {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
        if !is_closed_loop {
            return;
        }

        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return;
            }
        }

        self.pending_closed_loop_requests.push(completion_time);
    }

    /// Total number of requests generated so far.
    pub fn num_generated(&self) -> usize {
        self.requests_generated
    }

    /// Returns the next scheduled arrival time (same value as `peek_next_arrival_time`).
    pub fn peek_next_arrival(&self) -> f64 {
        self.next_arrival_time
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            dataset_path: None,
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    #[test]
    fn test_generator_creation() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let generator = RequestGenerator::new(workload);

        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let workload = create_test_workload("poisson", 10.0, 5);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;

            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }

    #[test]
    fn test_arrival_ordering() {
        let workload = create_test_workload("poisson", 5.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        // Arrival times must be non-decreasing.
        for i in 1..requests.len() {
            assert!(requests[i].arrival_time >= requests[i - 1].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let workload = create_test_workload("fixed_rate", 2.0, 4);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 4);

        // At 2 requests/sec, fixed-rate inter-arrival times should be exactly 0.5s.
        for i in 1..requests.len() {
            let inter_arrival = requests[i].arrival_time - requests[i - 1].arrival_time;
            assert!((inter_arrival - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_request_properties() {
        let workload = create_test_workload("poisson", 1.0, 1);
        let mut generator = RequestGenerator::new(workload);

        let req = generator.next_if_before(10.0).unwrap();

        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.priority, 0);
        assert!(req.request_id.starts_with("req-"));
    }

    #[test]
    fn test_peek_next_arrival() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let next_arrival = generator.peek_next_arrival();
        assert!(next_arrival > 0.0);

        let req = generator.next_if_before(next_arrival + 1.0).unwrap();

        assert_eq!(req.arrival_time, next_arrival);
    }
}