1use super::Request;
2use crate::config::WorkloadConfig;
3use crate::dataset::{BatchTokenizerFn, DatasetEntry, UnparsedEntry};
4use rand::{rngs::StdRng, Rng, SeedableRng};
5use rand_distr::{Distribution, Exp};
6use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8use std::sync::mpsc::{sync_channel, Receiver};
9use std::thread;
10
/// Generates simulated requests according to a configured arrival pattern,
/// either synthetically (sampled prompt/output lengths) or replayed from a
/// dataset tokenized on a background thread.
pub struct RequestGenerator {
    // Workload parameters: arrival pattern/rate, length distributions, seed, caps.
    workload: WorkloadConfig,
    // Deterministic RNG seeded from `workload.seed` so runs are reproducible.
    rng: StdRng,
    // Simulation time at which the next open-loop request becomes available.
    next_arrival_time: f64,
    // Count of requests handed out so far; compared against `workload.num_requests`.
    requests_generated: usize,
    // Monotonic id used to build "req-{id}" identifiers in synthetic mode.
    next_request_id: u64,
    // Closed-loop mode: one entry per idle user, holding the time that user
    // became ready to issue its next request.
    pending_closed_loop_requests: Vec<f64>,
    // Present only in dataset mode: receives tokenized entries from the
    // background tokenizer thread; a `None` message marks end of dataset.
    dataset_receiver: Option<Receiver<Option<DatasetEntry>>>,
    // Set once the dataset channel delivered its end marker or disconnected.
    dataset_exhausted: bool,
}
25
26impl RequestGenerator {
27 pub fn new(workload: WorkloadConfig) -> Self {
28 let mut rng = StdRng::seed_from_u64(workload.seed);
29 let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";
30
31 let mut pending_closed_loop_requests = Vec::new();
33 if is_closed_loop {
34 if let Some(num_users) = workload.num_concurrent_users {
35 pending_closed_loop_requests = vec![0.0; num_users]
37 }
38 };
39
40 let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
41 0.0 } else {
43 Self::sample_next_arrival(
44 0.0,
45 &workload.arrival_pattern,
46 workload.arrival_rate,
47 &mut rng,
48 )
49 };
50
51 Self {
52 workload,
53 rng,
54 next_arrival_time,
55 requests_generated: 0,
56 next_request_id: 0,
57 pending_closed_loop_requests,
58 dataset_receiver: None,
59 dataset_exhausted: false,
60 }
61 }
62
63 pub fn from_dataset<I>(
67 workload: WorkloadConfig,
68 dataset_iterator: I,
69 _total_entries: Option<usize>,
70 tokenizer: BatchTokenizerFn,
71 ) -> Self
72 where
73 I: Iterator<Item = Result<Option<UnparsedEntry>, Box<dyn std::error::Error>>>
74 + Send
75 + 'static,
76 {
77 let rng = StdRng::seed_from_u64(workload.seed);
78 let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";
79
80 let mut pending_closed_loop_requests = Vec::new();
82 if is_closed_loop {
83 if let Some(num_users) = workload.num_concurrent_users {
84 pending_closed_loop_requests = vec![0.0; num_users]
85 }
86 };
87
88 let next_arrival_time = 0.0; let (sender, receiver) = sync_channel::<Option<DatasetEntry>>(5000);
93
94 thread::spawn(move || {
95 let batch_size: usize = std::env::var("TOKENIZER_BATCH_SIZE")
96 .ok()
97 .and_then(|s| s.parse().ok())
98 .unwrap_or(32); let mut batch = Vec::with_capacity(batch_size);
100
101 for result in dataset_iterator {
102 match result {
103 Ok(Some(unparsed)) => {
104 batch.push(unparsed);
105
106 if batch.len() >= batch_size
108 && Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender)
109 .is_err()
110 {
111 break;
113 }
114 }
115 Ok(None) => {
116 if !batch.is_empty() {
118 let _ = Self::tokenize_and_send_batch(&mut batch, &tokenizer, &sender);
119 }
120 let _ = sender.send(None);
121 break;
122 }
123 Err(e) => {
124 eprintln!("Error loading dataset entry: {}", e);
125 break;
126 }
127 }
128 }
129 });
130
131 Self {
132 workload,
133 rng,
134 next_arrival_time,
135 requests_generated: 0,
136 next_request_id: 0,
137 pending_closed_loop_requests,
138 dataset_receiver: Some(receiver),
139 dataset_exhausted: false,
140 }
141 }
142
143 pub fn is_dataset_mode(&self) -> bool {
145 self.dataset_receiver.is_some()
146 }
147
148 fn tokenize_and_send_batch(
151 batch: &mut Vec<UnparsedEntry>,
152 tokenizer: &BatchTokenizerFn,
153 sender: &std::sync::mpsc::SyncSender<Option<DatasetEntry>>,
154 ) -> Result<(), ()> {
155 if batch.is_empty() {
156 return Ok(());
157 }
158
159 let prompt_inputs: Vec<_> = batch.iter().map(|e| e.prompt_input.clone()).collect();
161 let all_tokens = match tokenizer(&prompt_inputs) {
162 Ok(tokens) => tokens,
163 Err(e) => {
164 eprintln!("Batch tokenization failed: {}", e);
165 return Err(());
166 }
167 };
168
169 for (unparsed, prompt_tokens) in batch.drain(..).zip(all_tokens.into_iter()) {
171 let entry = DatasetEntry {
172 request_id: unparsed.request_id,
173 prompt_tokens,
174 max_output_tokens: unparsed.max_output_tokens,
175 };
176
177 if sender.send(Some(entry)).is_err() {
179 return Err(());
180 }
181 }
182 Ok(())
183 }
184
    /// Returns the scheduled time of the next open-loop arrival without
    /// consuming it.
    /// NOTE(review): duplicates `peek_next_arrival` below — callers should
    /// converge on one of the two.
    pub fn peek_next_arrival_time(&self) -> f64 {
        self.next_arrival_time
    }
189
190 fn compute_block_hashes(tokens: &[u32], block_size: usize) -> Vec<u64> {
193 let num_blocks = tokens.len().div_ceil(block_size);
194 let mut hashes = Vec::with_capacity(num_blocks);
195
196 for block_idx in 0..num_blocks {
197 let end = ((block_idx + 1) * block_size).min(tokens.len());
198 let block_tokens = &tokens[..end]; let mut hasher = DefaultHasher::new();
202 block_tokens.hash(&mut hasher);
203 hashes.push(hasher.finish());
204 }
205
206 hashes
207 }
208
209 pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
212 if self.is_dataset_mode() {
214 return self.next_from_dataset(current_time);
215 }
216
217 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
218
219 if is_closed_loop {
221 if let Some(max_requests) = self.workload.num_requests {
223 if self.requests_generated >= max_requests {
224 self.pending_closed_loop_requests.clear();
226 return None;
227 }
228 }
229
230 if let Some(pos) = self
232 .pending_closed_loop_requests
233 .iter()
234 .position(|&t| t <= current_time)
235 {
236 let arrival_time = self.pending_closed_loop_requests.remove(pos);
237
238 let request_id = format!("req-{}", self.next_request_id);
240 self.next_request_id += 1;
241
242 let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
243 let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
244
245 let mut request = Request::new(
246 request_id,
247 0, arrival_time,
249 num_prompt_tokens,
250 max_output_tokens,
251 );
252
253 let num_blocks = num_prompt_tokens.div_ceil(16) as usize; request.prompt_block_hashes = (0..num_blocks)
258 .map(|_| self.rng.gen_range(0..u64::MAX))
259 .collect();
260
261 self.requests_generated += 1;
262 return Some(request);
263 }
264 return None;
265 }
266
267 if let Some(max_requests) = self.workload.num_requests {
270 if self.requests_generated >= max_requests {
271 return None;
272 }
273 }
274
275 if self.next_arrival_time > current_time {
277 return None;
278 }
279
280 let request_id = format!("req-{}", self.next_request_id);
282 self.next_request_id += 1;
283
284 let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
285 let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
286
287 let mut request = Request::new(
288 request_id,
289 0, self.next_arrival_time,
291 num_prompt_tokens,
292 max_output_tokens,
293 );
294
295 let num_blocks = num_prompt_tokens.div_ceil(16) as usize; request.prompt_block_hashes = (0..num_blocks)
300 .map(|_| self.rng.gen_range(0..u64::MAX))
301 .collect();
302
303 self.requests_generated += 1;
304
305 self.next_arrival_time = Self::sample_next_arrival(
307 self.next_arrival_time,
308 &self.workload.arrival_pattern,
309 self.workload.arrival_rate,
310 &mut self.rng,
311 );
312
313 Some(request)
314 }
315
316 fn sample_next_arrival(current_time: f64, pattern: &str, rate: f64, rng: &mut StdRng) -> f64 {
318 match pattern.to_lowercase().as_str() {
319 "poisson" => {
320 let exp = Exp::new(rate).unwrap();
322 let inter_arrival = exp.sample(rng);
323 current_time + inter_arrival
324 }
325 "uniform" => {
326 let inter_arrival = 1.0 / rate;
328 current_time + inter_arrival
329 }
330 "burst" => {
331 if rng.gen_bool(0.2) {
334 current_time + rng.gen_range(0.001..0.01)
336 } else {
337 current_time + rng.gen_range(0.5..2.0)
338 }
339 }
340 "fixed_rate" => {
341 current_time + 1.0 / rate
343 }
344 "batched" => {
345 0.0
347 }
348 _ => {
349 let exp = Exp::new(rate).unwrap();
351 current_time + exp.sample(rng)
352 }
353 }
354 }
355
356 fn next_from_dataset(&mut self, current_time: f64) -> Option<Request> {
358 if self.dataset_exhausted {
360 return None;
361 }
362
363 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
364
365 if is_closed_loop {
367 if let Some(max_requests) = self.workload.num_requests {
369 if self.requests_generated >= max_requests {
370 self.pending_closed_loop_requests.clear();
372 return None;
373 }
374 }
375
376 if let Some(pos) = self
378 .pending_closed_loop_requests
379 .iter()
380 .position(|&t| t <= current_time)
381 {
382 let arrival_time = self.pending_closed_loop_requests.remove(pos);
383
384 let entry = match self.dataset_receiver.as_ref()?.recv() {
386 Ok(Some(e)) => e,
387 Ok(None) => {
388 self.dataset_exhausted = true;
390 return None;
391 }
392 Err(_) => {
393 self.dataset_exhausted = true;
395 return None;
396 }
397 };
398
399 let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);
401
402 let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
404 let target_output_tokens = sampled_output_len.min(max_output_tokens);
405
406 let mut request = Request::new_with_target(
407 entry.request_id.clone(),
408 0,
409 arrival_time,
410 entry.num_prompt_tokens(),
411 max_output_tokens,
412 target_output_tokens,
413 );
414
415 request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
416 self.requests_generated += 1;
417
418 return Some(request);
419 }
420 return None;
421 }
422
423 if self.next_arrival_time > current_time {
426 return None;
427 }
428
429 let entry = match self.dataset_receiver.as_ref()?.recv() {
431 Ok(Some(e)) => e,
432 Ok(None) => {
433 self.dataset_exhausted = true;
435 return None;
436 }
437 Err(_) => {
438 self.dataset_exhausted = true;
440 return None;
441 }
442 };
443
444 let arrival_time = self.next_arrival_time;
445
446 let sampled_output_len = self.workload.output_len_dist.sample(&mut self.rng);
448
449 let max_output_tokens = entry.max_output_tokens.unwrap_or(16384);
451 let target_output_tokens = sampled_output_len.min(max_output_tokens);
452
453 let mut request = Request::new_with_target(
454 entry.request_id.clone(),
455 0,
456 arrival_time,
457 entry.num_prompt_tokens(),
458 max_output_tokens,
459 target_output_tokens,
460 );
461
462 request.prompt_block_hashes = Self::compute_block_hashes(&entry.prompt_tokens, 16);
463 self.requests_generated += 1;
464
465 let should_sample_next = if let Some(max_requests) = self.workload.num_requests {
469 self.requests_generated < max_requests
470 } else {
471 true
473 };
474
475 if should_sample_next {
476 self.next_arrival_time = Self::sample_next_arrival(
477 self.next_arrival_time,
478 &self.workload.arrival_pattern,
479 self.workload.arrival_rate,
480 &mut self.rng,
481 );
482 }
483
484 Some(request)
485 }
486
487 pub fn is_finished(&self) -> bool {
489 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
490
491 if self.is_dataset_mode() {
493 if let Some(max_requests) = self.workload.num_requests {
495 if is_closed_loop {
496 return (self.requests_generated >= max_requests
499 && self.pending_closed_loop_requests.is_empty())
500 || self.dataset_exhausted;
501 } else {
502 return self.requests_generated >= max_requests;
503 }
504 }
505 return self.dataset_exhausted;
507 }
508
509 if let Some(max_requests) = self.workload.num_requests {
511 if is_closed_loop {
512 self.requests_generated >= max_requests
515 && self.pending_closed_loop_requests.is_empty()
516 } else {
517 self.requests_generated >= max_requests
518 }
519 } else {
520 false
521 }
522 }
523
524 pub fn on_request_complete(&mut self, completion_time: f64) {
527 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
528 if !is_closed_loop {
529 return; }
531
532 if let Some(max_requests) = self.workload.num_requests {
534 if self.requests_generated >= max_requests {
535 return; }
537 }
538
539 self.pending_closed_loop_requests.push(completion_time);
541 }
542
    /// Number of requests handed out so far (all modes).
    pub fn num_generated(&self) -> usize {
        self.requests_generated
    }
547
    /// Returns the next scheduled arrival time without consuming it.
    /// NOTE(review): identical to `peek_next_arrival_time` — consider
    /// removing one of the duplicates.
    pub fn peek_next_arrival(&self) -> f64 {
        self.next_arrival_time
    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    /// Synthetic workload with fixed 100-token prompts, 50-token outputs,
    /// and a deterministic seed.
    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            dataset_path: None,
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    /// Advances simulated time in 10 s steps until the generator reports
    /// done, collecting every produced request.
    fn drain_generator(generator: &mut RequestGenerator) -> Vec<Request> {
        let mut requests = Vec::new();
        let mut now = 0.0;
        while !generator.is_finished() {
            now += 10.0;
            while let Some(req) = generator.next_if_before(now) {
                requests.push(req);
            }
        }
        requests
    }

    #[test]
    fn test_generator_creation() {
        let generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));
        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 10.0, 5));
        let requests = drain_generator(&mut generator);
        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }

    #[test]
    fn test_arrival_ordering() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 5.0, 10));
        let requests = drain_generator(&mut generator);
        // Arrival times must be non-decreasing.
        for pair in requests.windows(2) {
            assert!(pair[1].arrival_time >= pair[0].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("fixed_rate", 2.0, 4));
        let requests = drain_generator(&mut generator);
        assert_eq!(requests.len(), 4);
        // At rate 2.0 consecutive arrivals must be exactly 0.5 s apart.
        for pair in requests.windows(2) {
            let inter_arrival = pair[1].arrival_time - pair[0].arrival_time;
            assert!((inter_arrival - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_request_properties() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 1));
        let req = generator.next_if_before(10.0).unwrap();
        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.priority, 0);
        assert!(req.request_id.starts_with("req-"));
    }

    #[test]
    fn test_peek_next_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));
        let next_arrival = generator.peek_next_arrival();
        assert!(next_arrival > 0.0);

        // Advancing just past the peeked time must yield exactly that arrival.
        let req = generator.next_if_before(next_arrival + 1.0).unwrap();
        assert_eq!(req.arrival_time, next_arrival);
    }
}