inference_lab/request/
generator.rs1use super::Request;
2use crate::config::WorkloadConfig;
3use rand::{rngs::StdRng, Rng, SeedableRng};
4use rand_distr::{Distribution, Exp};
5
/// Generates simulated inference requests according to a workload's arrival
/// pattern: open-loop patterns (e.g. "poisson", "uniform", "fixed_rate",
/// "burst", "batched") or "closed_loop" (one outstanding request per user).
pub struct RequestGenerator {
    // Workload parameters: arrival pattern/rate, length distributions, limits.
    workload: WorkloadConfig,
    // Deterministic RNG seeded from `workload.seed` for reproducible runs.
    rng: StdRng,
    // Simulation time of the next open-loop arrival (unused in closed loop).
    next_arrival_time: f64,
    // Count of requests handed out so far; compared against `num_requests`.
    requests_generated: usize,
    // Monotonic counter used to build unique "req-{n}" ids.
    next_request_id: u64,
    // Closed loop only: earliest time each idle "user slot" may issue again.
    pending_closed_loop_requests: Vec<f64>,
}
16
17impl RequestGenerator {
18 pub fn new(workload: WorkloadConfig) -> Self {
19 let mut rng = StdRng::seed_from_u64(workload.seed);
20 let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";
21
22 let mut pending_closed_loop_requests = Vec::new();
24 if is_closed_loop {
25 if let Some(num_users) = workload.num_concurrent_users {
26 for _ in 0..num_users {
28 pending_closed_loop_requests.push(0.0);
29 }
30 }
31 }
32
33 let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
34 0.0 } else {
36 Self::sample_next_arrival(
37 0.0,
38 &workload.arrival_pattern,
39 workload.arrival_rate,
40 &mut rng,
41 )
42 };
43
44 Self {
45 workload,
46 rng,
47 next_arrival_time,
48 requests_generated: 0,
49 next_request_id: 0,
50 pending_closed_loop_requests,
51 }
52 }
53
54 pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
57 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
58
59 if is_closed_loop {
61 if let Some(max_requests) = self.workload.num_requests {
63 if self.requests_generated >= max_requests {
64 self.pending_closed_loop_requests.clear();
66 return None;
67 }
68 }
69
70 if let Some(pos) = self
72 .pending_closed_loop_requests
73 .iter()
74 .position(|&t| t <= current_time)
75 {
76 let arrival_time = self.pending_closed_loop_requests.remove(pos);
77
78 let request_id = format!("req-{}", self.next_request_id);
80 self.next_request_id += 1;
81
82 let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
83 let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
84
85 let request = Request::new(
86 request_id,
87 0, arrival_time,
89 num_prompt_tokens,
90 max_output_tokens,
91 );
92
93 self.requests_generated += 1;
94 return Some(request);
95 }
96 return None;
97 }
98
99 if let Some(max_requests) = self.workload.num_requests {
102 if self.requests_generated >= max_requests {
103 return None;
104 }
105 }
106
107 if self.next_arrival_time > current_time {
109 return None;
110 }
111
112 let request_id = format!("req-{}", self.next_request_id);
114 self.next_request_id += 1;
115
116 let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
117 let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
118
119 let request = Request::new(
120 request_id,
121 0, self.next_arrival_time,
123 num_prompt_tokens,
124 max_output_tokens,
125 );
126
127 self.requests_generated += 1;
128
129 self.next_arrival_time = Self::sample_next_arrival(
131 self.next_arrival_time,
132 &self.workload.arrival_pattern,
133 self.workload.arrival_rate,
134 &mut self.rng,
135 );
136
137 Some(request)
138 }
139
140 fn sample_next_arrival(current_time: f64, pattern: &str, rate: f64, rng: &mut StdRng) -> f64 {
142 match pattern.to_lowercase().as_str() {
143 "poisson" => {
144 let exp = Exp::new(rate).unwrap();
146 let inter_arrival = exp.sample(rng);
147 current_time + inter_arrival
148 }
149 "uniform" => {
150 let inter_arrival = 1.0 / rate;
152 current_time + inter_arrival
153 }
154 "burst" => {
155 if rng.gen_bool(0.2) {
158 current_time + rng.gen_range(0.001..0.01)
160 } else {
161 current_time + rng.gen_range(0.5..2.0)
162 }
163 }
164 "fixed_rate" => {
165 current_time + 1.0 / rate
167 }
168 "batched" => {
169 0.0
171 }
172 _ => {
173 let exp = Exp::new(rate).unwrap();
175 current_time + exp.sample(rng)
176 }
177 }
178 }
179
180 pub fn is_finished(&self) -> bool {
182 if let Some(max_requests) = self.workload.num_requests {
183 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
184 if is_closed_loop {
185 self.requests_generated >= max_requests
188 && self.pending_closed_loop_requests.is_empty()
189 } else {
190 self.requests_generated >= max_requests
191 }
192 } else {
193 false
194 }
195 }
196
197 pub fn on_request_complete(&mut self, completion_time: f64) {
200 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
201 if !is_closed_loop {
202 return; }
204
205 if let Some(max_requests) = self.workload.num_requests {
207 if self.requests_generated >= max_requests {
208 return; }
210 }
211
212 self.pending_closed_loop_requests.push(completion_time);
214 }
215
216 pub fn num_generated(&self) -> usize {
218 self.requests_generated
219 }
220
221 pub fn peek_next_arrival(&self) -> f64 {
223 self.next_arrival_time
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    /// Builds a minimal open-loop workload with fixed token lengths and a
    /// deterministic seed.
    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    /// Drains every request by advancing a simulated clock in 10 s steps.
    fn drain(generator: &mut RequestGenerator) -> Vec<Request> {
        let mut collected = Vec::new();
        let mut clock = 0.0;
        while !generator.is_finished() {
            clock += 10.0;
            while let Some(request) = generator.next_if_before(clock) {
                collected.push(request);
            }
        }
        collected
    }

    #[test]
    fn test_generator_creation() {
        let generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));

        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 10.0, 5));

        let requests = drain(&mut generator);

        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }

    #[test]
    fn test_arrival_ordering() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 5.0, 10));

        let requests = drain(&mut generator);

        // Arrivals must come out in non-decreasing time order.
        for pair in requests.windows(2) {
            assert!(pair[1].arrival_time >= pair[0].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("fixed_rate", 2.0, 4));

        let requests = drain(&mut generator);

        assert_eq!(requests.len(), 4);

        // rate = 2.0 implies an exact inter-arrival gap of 0.5 s.
        for pair in requests.windows(2) {
            let gap = pair[1].arrival_time - pair[0].arrival_time;
            assert!((gap - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_request_properties() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 1));

        let request = generator.next_if_before(10.0).unwrap();

        assert_eq!(request.num_prompt_tokens, 100);
        assert_eq!(request.max_output_tokens, 50);
        assert_eq!(request.priority, 0);
        assert!(request.request_id.starts_with("req-"));
    }

    #[test]
    fn test_peek_next_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));

        let expected_arrival = generator.peek_next_arrival();
        assert!(expected_arrival > 0.0);

        let request = generator.next_if_before(expected_arrival + 1.0).unwrap();

        assert_eq!(request.arrival_time, expected_arrival);
    }
}