use super::Request;
use crate::config::WorkloadConfig;
use crate::dataset::{BatchTokenizerFn, DatasetEntry, UnparsedEntry};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, Exp};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

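/// Generates the stream of requests that drives a simulation run.
///
/// Requests are either synthesized from the workload's length distributions
/// or replayed from a dataset that is tokenized on a background thread.
/// Arrival times follow the configured arrival pattern; the `closed_loop`
/// pattern instead models a fixed pool of concurrent users, each of which
/// issues its next request only after the previous one completes.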
pub struct RequestGenerator {
    workload: WorkloadConfig,
    rng: StdRng,
    next_arrival_time: f64,
    requests_generated: usize,
    next_request_id: u64,
    pending_closed_loop_requests: Vec<f64>,
    dataset_receiver: Option<Receiver<Option<DatasetEntry>>>,
    dataset_exhausted: bool,
}

impl RequestGenerator {
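    /// Creates a generator that synthesizes requests from the workload's
    /// length distributions. For the `closed_loop` pattern every configured
    /// concurrent user starts with a request ready at t = 0; otherwise the
    /// first arrival is sampled from the configured arrival pattern.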
    pub fn new(workload: WorkloadConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }

        let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
            0.0
        } else {
            Self::sample_next_arrival(
                0.0,
                &workload.arrival_pattern,
                workload.arrival_rate,
                &mut rng,
            )
        };

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: None,
            dataset_exhausted: false,
        }
    }

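    /// Creates a generator that replays prompts from a dataset. Entries are
    /// tokenized in batches (size taken from `TOKENIZER_BATCH_SIZE`, default
    /// 32) on a background thread and streamed through a bounded channel.
    /// The `_total_entries` hint is currently unused.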
    pub fn from_dataset<I>(
        workload: WorkloadConfig,
        dataset_iterator: I,
        _total_entries: Option<usize>,
        tokenizer: BatchTokenizerFn,
    ) -> Self
    where
        I: Iterator<Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>>
            + Send
            + 'static,
    {
        let rng = StdRng::seed_from_u64(workload.seed);
        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";

        let mut pending_closed_loop_requests = Vec::new();
        if is_closed_loop {
            if let Some(num_users) = workload.num_concurrent_users {
                pending_closed_loop_requests = vec![0.0; num_users];
            }
        }

        let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
            0.0
        } else {
            0.0
        };

        let (sender, receiver) = sync_channel::<Option<DatasetEntry>>(5000);

        thread::spawn(move || {
            let batch_size: usize = std::env::var("TOKENIZER_BATCH_SIZE")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(32);
            let mut batch = Vec::with_capacity(batch_size);

            for result in dataset_iterator {
                match result {
                    Ok(Some(unparsed)) => {
                        batch.push(unparsed);

                        if batch.len() >= batch_size {
                            if Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender)
                                .is_err()
                            {
                                break;
                            }
                        }
                    }
                    Ok(None) => {
                        if !batch.is_empty() {
                            let _ = Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender);
                        }
                        let _ = sender.send(None);
                        break;
                    }
                    Err(e) => {
                        eprintln!("Error loading dataset entry: {}", e);
                        break;
                    }
                }
            }
        });

        Self {
            workload,
            rng,
            next_arrival_time,
            requests_generated: 0,
            next_request_id: 0,
            pending_closed_loop_requests,
            dataset_receiver: Some(receiver),
            dataset_exhausted: false,
        }
    }

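    /// Returns `true` when requests are replayed from a dataset rather than
    /// synthesized from the configured length distributions.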
    pub fn is_dataset_mode(&self) -> bool {
        self.dataset_receiver.is_some()
    }

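    /// Tokenizes the buffered prompts as a single batch and forwards the
    /// resulting entries to the simulation side; on success the batch is
    /// drained. Returns `Err(())` if tokenization fails or the receiver has
    /// hung up.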
    fn tokenize_and_send_batch(
        batch: &mut Vec<UnparsedEntry>,
        tokenizer: &BatchTokenizerFn,
        sender: &std::sync::mpsc::SyncSender<Option<DatasetEntry>>,
    ) -> Result<(), ()> {
        if batch.is_empty() {
            return Ok(());
        }

        let prompt_inputs: Vec<_> = batch.iter().map(|e| e.prompt_input.clone()).collect();
        let all_tokens = match tokenizer(&prompt_inputs) {
            Ok(tokens) => tokens,
            Err(e) => {
                eprintln!("Batch tokenization failed: {}", e);
                return Err(());
            }
        };

        for (unparsed, prompt_tokens) in batch.drain(..).zip(all_tokens.into_iter()) {
            let entry = DatasetEntry {
                request_id: unparsed.request_id,
                prompt_tokens,
                max_output_tokens: unparsed.max_output_tokens,
            };

            if sender.send(Some(entry)).is_err() {
                return Err(());
            }
        }
        Ok(())
    }

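    /// Returns the time of the next scheduled arrival without consuming it.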
    pub fn peek_next_arrival_time(&self) -> f64 {
        self.next_arrival_time
    }

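    /// Hashes the prompt into `block_size`-token blocks for prefix matching.
    /// Each block's hash covers all tokens from the start of the prompt
    /// through the end of that block, so prompts that share a prefix produce
    /// identical leading hashes.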
    fn compute_block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
        let num_blocks = tokens.len().div_ceil(block_size);
        let mut hashes = Vec::with_capacity(num_blocks);

        for block_idx in 0..num_blocks {
            let end = ((block_idx + 1) * block_size).min(tokens.len());
            let block_tokens = &tokens[..end];
            let mut hasher = DefaultHasher::new();
            block_tokens.hash(&mut hasher);
            hashes.push(hasher.finish());
        }

        hashes
    }

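    /// Returns the next request if one arrives at or before `current_time`,
    /// advancing the generator's state; otherwise returns `None`.
    ///
    /// In `closed_loop` mode a request is produced only when a user slot has
    /// become free by `current_time`. Synthetic requests are assigned random
    /// block hashes, so they never share a prompt prefix with each other.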
    pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
        if self.is_dataset_mode() {
            return self.next_from_dataset(current_time);
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        if is_closed_loop {
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                let request_id = format!("req-{}", self.next_request_id);
                self.next_request_id += 1;

                let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
                let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

                let mut request = Request::new(
                    request_id,
                    0,
                    arrival_time,
                    num_prompt_tokens,
                    max_output_tokens,
                );

                let num_blocks = num_prompt_tokens.div_ceil(16) as usize;
                request.prompt_block_hashes = (0..num_blocks)
                    .map(|_| self.rng.gen_range(0..u64::MAX))
                    .collect();

                self.requests_generated += 1;
                return Some(request);
            }
            return None;
        }

        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return None;
            }
        }

        if self.next_arrival_time > current_time {
            return None;
        }

        let request_id = format!("req-{}", self.next_request_id);
        self.next_request_id += 1;

        let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
        let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);

        let mut request = Request::new(
            request_id,
            0,
            self.next_arrival_time,
            num_prompt_tokens,
            max_output_tokens,
        );

        let num_blocks = num_prompt_tokens.div_ceil(16) as usize;
        request.prompt_block_hashes = (0..num_blocks)
            .map(|_| self.rng.gen_range(0..u64::MAX))
            .collect();

        self.requests_generated += 1;

        self.next_arrival_time = Self::sample_next_arrival(
            self.next_arrival_time,
            &self.workload.arrival_pattern,
            self.workload.arrival_rate,
            &mut self.rng,
        );

        Some(request)
    }

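    /// Returns the absolute time of the arrival that follows `current_time`.
    ///
    /// `poisson` (and any unrecognized pattern) draws an exponential
    /// inter-arrival time with the given rate, `uniform` and `fixed_rate` use
    /// a constant gap of `1 / rate`, `burst` mixes short (0.001-0.01 s) and
    /// long (0.5-2.0 s) gaps, and `batched` places every arrival at t = 0.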
    fn sample_next_arrival(current_time: f64, pattern: &str, rate: f64, rng: &mut StdRng) -> f64 {
        match pattern.to_lowercase().as_str() {
            "poisson" => {
                let exp = Exp::new(rate).unwrap();
                let inter_arrival = exp.sample(rng);
                current_time + inter_arrival
            }
            "uniform" => {
                let inter_arrival = 1.0 / rate;
                current_time + inter_arrival
            }
            "burst" => {
                if rng.gen_bool(0.2) {
                    current_time + rng.gen_range(0.001..0.01)
                } else {
                    current_time + rng.gen_range(0.5..2.0)
                }
            }
            "fixed_rate" => current_time + 1.0 / rate,
            "batched" => 0.0,
            _ => {
                let exp = Exp::new(rate).unwrap();
                current_time + exp.sample(rng)
            }
        }
    }

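    /// Dataset-mode counterpart of `next_if_before`: pulls the next tokenized
    /// entry from the background channel, caps the sampled output length at
    /// the entry's `max_output_tokens` (16384 if unspecified), and derives
    /// block hashes from the real prompt tokens in 16-token blocks.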
    fn next_from_dataset(&mut self, current_time: f64) -> Option<Request> {
        if self.dataset_exhausted {
            return None;
        }

        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        if is_closed_loop {
            if let Some(max_requests) = self.workload.num_requests {
                if self.requests_generated >= max_requests {
                    self.pending_closed_loop_requests.clear();
                    return None;
                }
            }

            if let Some(pos) = self
                .pending_closed_loop_requests
                .iter()
                .position(|&t| t <= current_time)
            {
                let arrival_time = self.pending_closed_loop_requests.remove(pos);

                let entry = match self.dataset_receiver.as_ref()?.recv() {
                    Ok(Some(e)) => e,
                    Ok(None) => {
                        self.dataset_exhausted = true;
                        return None;
                    }
                    Err(_) => {
                        self.dataset_exhausted = true;
                        return None;
                    }
                };

                let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);

                let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
                let target_output_tokens = sampled_output_len.min(max_output_tokens);

                let mut request = Request::new_with_target(
                    entry.request_id.clone(),
                    0,
                    arrival_time,
                    entry.num_prompt_tokens(),
                    max_output_tokens,
                    target_output_tokens,
                );

                request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
                self.requests_generated += 1;

                return Some(request);
            }
            return None;
        }

        if self.next_arrival_time > current_time {
            return None;
        }

        let entry = match self.dataset_receiver.as_ref()?.recv() {
            Ok(Some(e)) => e,
            Ok(None) => {
                self.dataset_exhausted = true;
                return None;
            }
            Err(_) => {
                self.dataset_exhausted = true;
                return None;
            }
        };

        let arrival_time = self.next_arrival_time;

        let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);

        let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
        let target_output_tokens = sampled_output_len.min(max_output_tokens);

        let mut request = Request::new_with_target(
            entry.request_id.clone(),
            0,
            arrival_time,
            entry.num_prompt_tokens(),
            max_output_tokens,
            target_output_tokens,
        );

        request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
        self.requests_generated += 1;

        let should_sample_next = if let Some(max_requests) = self.workload.num_requests {
            self.requests_generated < max_requests
        } else {
            true
        };

        if should_sample_next {
            self.next_arrival_time = Self::sample_next_arrival(
                self.next_arrival_time,
                &self.workload.arrival_pattern,
                self.workload.arrival_rate,
                &mut self.rng,
            );
        }

        Some(request)
    }

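    /// Returns `true` once no further requests can be produced, taking into
    /// account the configured request cap, any pending closed-loop users, and
    /// (in dataset mode) whether the dataset has been exhausted.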
    pub fn is_finished(&self) -> bool {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";

        if self.is_dataset_mode() {
            if let Some(max_requests) = self.workload.num_requests {
                if is_closed_loop {
                    return (self.requests_generated >= max_requests
                        && self.pending_closed_loop_requests.is_empty())
                        || self.dataset_exhausted;
                } else {
                    return self.requests_generated >= max_requests;
                }
            }
            return self.dataset_exhausted;
        }

        if let Some(max_requests) = self.workload.num_requests {
            if is_closed_loop {
                self.requests_generated >= max_requests
                    && self.pending_closed_loop_requests.is_empty()
            } else {
                self.requests_generated >= max_requests
            }
        } else {
            false
        }
    }

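    /// Records that a request finished at `completion_time`. Only meaningful
    /// for the `closed_loop` pattern: the freed user becomes eligible to issue
    /// its next request at that time, unless the request cap has been reached.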
    pub fn on_request_complete(&mut self, completion_time: f64) {
        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
        if !is_closed_loop {
            return;
        }

        if let Some(max_requests) = self.workload.num_requests {
            if self.requests_generated >= max_requests {
                return;
            }
        }

        self.pending_closed_loop_requests.push(completion_time);
    }

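    /// Number of requests generated so far.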
    pub fn num_generated(&self) -> usize {
        self.requests_generated
    }

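    /// Returns the next scheduled arrival time (same value as
    /// `peek_next_arrival_time`).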
    pub fn peek_next_arrival(&self) -> f64 {
        self.next_arrival_time
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            dataset_path: None,
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    #[test]
    fn test_generator_creation() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let generator = RequestGenerator::new(workload);

        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let workload = create_test_workload("poisson", 10.0, 5);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;

            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }

    #[test]
    fn test_arrival_ordering() {
        let workload = create_test_workload("poisson", 5.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        for i in 1..requests.len() {
            assert!(requests[i].arrival_time >= requests[i - 1].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let workload = create_test_workload("fixed_rate", 2.0, 4);
        let mut generator = RequestGenerator::new(workload);

        let mut requests = Vec::new();
        let mut current_time = 0.0;

        while !generator.is_finished() {
            current_time += 10.0;
            while let Some(req) = generator.next_if_before(current_time) {
                requests.push(req);
            }
        }

        assert_eq!(requests.len(), 4);

        for i in 1..requests.len() {
            let inter_arrival = requests[i].arrival_time - requests[i - 1].arrival_time;
            assert!((inter_arrival - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_request_properties() {
        let workload = create_test_workload("poisson", 1.0, 1);
        let mut generator = RequestGenerator::new(workload);

        let req = generator.next_if_before(10.0).unwrap();

        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.priority, 0);
        assert!(req.request_id.starts_with("req-"));
    }

    #[test]
    fn test_peek_next_arrival() {
        let workload = create_test_workload("poisson", 1.0, 10);
        let mut generator = RequestGenerator::new(workload);

        let next_arrival = generator.peek_next_arrival();
        assert!(next_arrival > 0.0);

        let req = generator.next_if_before(next_arrival + 1.0).unwrap();

        assert_eq!(req.arrival_time, next_arrival);
    }
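
    // Sketch of a closed-loop check: with two concurrent users, two requests
    // are ready at t = 0, and a third becomes available only after
    // on_request_complete() frees a user slot.
    #[test]
    fn test_closed_loop_generation() {
        let mut workload = create_test_workload("closed_loop", 1.0, 4);
        workload.num_concurrent_users = Some(2);
        let mut generator = RequestGenerator::new(workload);

        // Both users issue their first request at t = 0.
        let first = generator.next_if_before(0.0).unwrap();
        let second = generator.next_if_before(0.0).unwrap();
        assert_eq!(first.arrival_time, 0.0);
        assert_eq!(second.arrival_time, 0.0);

        // No free user slot, so no further requests yet.
        assert!(generator.next_if_before(10.0).is_none());

        // A completion at t = 1.5 re-arms one user slot.
        generator.on_request_complete(1.5);
        assert!(generator.next_if_before(1.0).is_none());
        let third = generator.next_if_before(2.0).unwrap();
        assert_eq!(third.arrival_time, 1.5);
    }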
}