// inference_lab/request/generator.rs
use super::Request;
2use crate::config::WorkloadConfig;
3use rand::{rngs::StdRng, Rng, SeedableRng};
4use rand_distr::{Distribution, Exp};
5
/// Generates simulated inference requests according to a configured arrival
/// pattern: open-loop (poisson / uniform / burst / fixed_rate / batched) or
/// completion-driven closed-loop.
pub struct RequestGenerator {
    // Workload parameters: arrival pattern/rate, token-length distributions,
    // optional request budget, concurrency, and RNG seed.
    workload: WorkloadConfig,
    // Deterministic RNG seeded from `workload.seed` for reproducible runs.
    rng: StdRng,
    // Simulation time of the next open-loop arrival (pre-sampled).
    next_arrival_time: f64,
    // Requests emitted so far; compared against `workload.num_requests`.
    requests_generated: usize,
    // Monotonic counter used to build `req-{n}` request ids.
    next_request_id: u64,
    // Closed-loop only: earliest times at which each idle "user" slot may
    // issue its next request (seeded with one 0.0 entry per concurrent user).
    pending_closed_loop_requests: Vec<f64>,
}
16
17impl RequestGenerator {
18 pub fn new(workload: WorkloadConfig) -> Self {
19 let mut rng = StdRng::seed_from_u64(workload.seed);
20 let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";
21
22 let mut pending_closed_loop_requests = Vec::new();
24 if is_closed_loop {
25 if let Some(num_users) = workload.num_concurrent_users {
26 for _ in 0..num_users {
28 pending_closed_loop_requests.push(0.0);
29 }
30 }
31 }
32
33 let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
34 0.0 } else {
36 Self::sample_next_arrival(
37 0.0,
38 &workload.arrival_pattern,
39 workload.arrival_rate,
40 &mut rng,
41 )
42 };
43
44 Self {
45 workload,
46 rng,
47 next_arrival_time,
48 requests_generated: 0,
49 next_request_id: 0,
50 pending_closed_loop_requests,
51 }
52 }
53
54 pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
57 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
58
59 if is_closed_loop {
61 if let Some(max_requests) = self.workload.num_requests {
63 if self.requests_generated >= max_requests {
64 self.pending_closed_loop_requests.clear();
66 return None;
67 }
68 }
69
70 if let Some(pos) = self.pending_closed_loop_requests.iter().position(|&t| t <= current_time) {
72 let arrival_time = self.pending_closed_loop_requests.remove(pos);
73
74 let request_id = format!("req-{}", self.next_request_id);
76 self.next_request_id += 1;
77
78 let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
79 let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
80
81 let request = Request::new(
82 request_id,
83 0, arrival_time,
85 num_prompt_tokens,
86 max_output_tokens,
87 );
88
89 self.requests_generated += 1;
90 return Some(request);
91 }
92 return None;
93 }
94
95 if let Some(max_requests) = self.workload.num_requests {
98 if self.requests_generated >= max_requests {
99 return None;
100 }
101 }
102
103 if self.next_arrival_time > current_time {
105 return None;
106 }
107
108 let request_id = format!("req-{}", self.next_request_id);
110 self.next_request_id += 1;
111
112 let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
113 let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
114
115 let request = Request::new(
116 request_id,
117 0, self.next_arrival_time,
119 num_prompt_tokens,
120 max_output_tokens,
121 );
122
123 self.requests_generated += 1;
124
125 self.next_arrival_time = Self::sample_next_arrival(
127 self.next_arrival_time,
128 &self.workload.arrival_pattern,
129 self.workload.arrival_rate,
130 &mut self.rng,
131 );
132
133 Some(request)
134 }
135
136 fn sample_next_arrival(
138 current_time: f64,
139 pattern: &str,
140 rate: f64,
141 rng: &mut StdRng,
142 ) -> f64 {
143 match pattern.to_lowercase().as_str() {
144 "poisson" => {
145 let exp = Exp::new(rate).unwrap();
147 let inter_arrival = exp.sample(rng);
148 current_time + inter_arrival
149 }
150 "uniform" => {
151 let inter_arrival = 1.0 / rate;
153 current_time + inter_arrival
154 }
155 "burst" => {
156 if rng.gen_bool(0.2) {
159 current_time + rng.gen_range(0.001..0.01)
161 } else {
162 current_time + rng.gen_range(0.5..2.0)
163 }
164 }
165 "fixed_rate" => {
166 current_time + 1.0 / rate
168 }
169 "batched" => {
170 0.0
172 }
173 _ => {
174 let exp = Exp::new(rate).unwrap();
176 current_time + exp.sample(rng)
177 }
178 }
179 }
180
181 pub fn is_finished(&self) -> bool {
183 if let Some(max_requests) = self.workload.num_requests {
184 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
185 if is_closed_loop {
186 self.requests_generated >= max_requests && self.pending_closed_loop_requests.is_empty()
189 } else {
190 self.requests_generated >= max_requests
191 }
192 } else {
193 false
194 }
195 }
196
197 pub fn on_request_complete(&mut self, completion_time: f64) {
200 let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
201 if !is_closed_loop {
202 return; }
204
205 if let Some(max_requests) = self.workload.num_requests {
207 if self.requests_generated >= max_requests {
208 return; }
210 }
211
212 self.pending_closed_loop_requests.push(completion_time);
214 }
215
216 pub fn num_generated(&self) -> usize {
218 self.requests_generated
219 }
220
221 pub fn peek_next_arrival(&self) -> f64 {
223 self.next_arrival_time
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    /// Minimal workload: fixed 100/50 token lengths, deterministic seed 42.
    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    /// Advances simulated time in 10s steps, collecting every request the
    /// generator produces until it reports itself finished.
    fn drain(generator: &mut RequestGenerator) -> Vec<Request> {
        let mut collected = Vec::new();
        let mut clock = 0.0;
        while !generator.is_finished() {
            clock += 10.0;
            while let Some(req) = generator.next_if_before(clock) {
                collected.push(req);
            }
        }
        collected
    }

    #[test]
    fn test_generator_creation() {
        let generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));
        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 10.0, 5));
        let requests = drain(&mut generator);
        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }

    #[test]
    fn test_arrival_ordering() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 5.0, 10));
        let requests = drain(&mut generator);
        // Arrival times must be non-decreasing across consecutive requests.
        for pair in requests.windows(2) {
            assert!(pair[1].arrival_time >= pair[0].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("fixed_rate", 2.0, 4));
        let requests = drain(&mut generator);
        assert_eq!(requests.len(), 4);
        // rate = 2.0 req/s => constant 0.5s inter-arrival gaps.
        for pair in requests.windows(2) {
            let inter_arrival = pair[1].arrival_time - pair[0].arrival_time;
            assert!((inter_arrival - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_request_properties() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 1));
        let req = generator.next_if_before(10.0).unwrap();
        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.priority, 0);
        assert!(req.request_id.starts_with("req-"));
    }

    #[test]
    fn test_peek_next_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));
        let next_arrival = generator.peek_next_arrival();
        assert!(next_arrival > 0.0);
        // The emitted request must arrive exactly at the peeked time.
        let req = generator.next_if_before(next_arrival + 1.0).unwrap();
        assert_eq!(req.arrival_time, next_arrival);
    }
}