// inference_lab/request/generator.rs

1use super::Request;
2use crate::config::WorkloadConfig;
3use rand::{rngs::StdRng, Rng, SeedableRng};
4use rand_distr::{Distribution, Exp};
5
6/// Generates requests based on workload configuration
/// Generates requests based on workload configuration
pub struct RequestGenerator {
    /// Workload parameters: arrival pattern/rate, length distributions,
    /// request budget, and RNG seed.
    workload: WorkloadConfig,
    /// Deterministic RNG seeded from `workload.seed`.
    rng: StdRng,
    /// Scheduled arrival time of the next open-loop request.
    next_arrival_time: f64,
    /// Count of requests emitted so far, compared against `workload.num_requests`.
    requests_generated: usize,
    /// Monotonically increasing counter used to build `"req-{id}"` request ids.
    next_request_id: u64,
    /// For closed-loop: track pending requests to generate when completions occur
    /// (each entry is the arrival time of a not-yet-emitted request).
    pending_closed_loop_requests: Vec<f64>,
}
16
17impl RequestGenerator {
18    pub fn new(workload: WorkloadConfig) -> Self {
19        let mut rng = StdRng::seed_from_u64(workload.seed);
20        let is_closed_loop = workload.arrival_pattern.to_lowercase() == "closed_loop";
21
22        // For closed-loop, initialize with N requests at time 0
23        let mut pending_closed_loop_requests = Vec::new();
24        if is_closed_loop {
25            if let Some(num_users) = workload.num_concurrent_users {
26                // Generate initial requests for all concurrent users
27                for _ in 0..num_users {
28                    pending_closed_loop_requests.push(0.0);
29                }
30            }
31        }
32
33        let next_arrival_time = if is_closed_loop && !pending_closed_loop_requests.is_empty() {
34            0.0 // Start immediately with the first batch
35        } else {
36            Self::sample_next_arrival(
37                0.0,
38                &workload.arrival_pattern,
39                workload.arrival_rate,
40                &mut rng,
41            )
42        };
43
44        Self {
45            workload,
46            rng,
47            next_arrival_time,
48            requests_generated: 0,
49            next_request_id: 0,
50            pending_closed_loop_requests,
51        }
52    }
53
54    /// Get the next request if its arrival time is before the given time
55    /// Returns None if no request is ready or all requests have been generated
56    pub fn next_if_before(&mut self, current_time: f64) -> Option<Request> {
57        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
58
59        // For closed-loop, check pending requests
60        if is_closed_loop {
61            // Check if we've generated all requests
62            if let Some(max_requests) = self.workload.num_requests {
63                if self.requests_generated >= max_requests {
64                    // Clear any remaining pending requests that won't be used
65                    self.pending_closed_loop_requests.clear();
66                    return None;
67                }
68            }
69
70            // Find the earliest pending request that has arrived
71            if let Some(pos) = self.pending_closed_loop_requests.iter().position(|&t| t <= current_time) {
72                let arrival_time = self.pending_closed_loop_requests.remove(pos);
73
74                // Generate request
75                let request_id = format!("req-{}", self.next_request_id);
76                self.next_request_id += 1;
77
78                let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
79                let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
80
81                let request = Request::new(
82                    request_id,
83                    0, // Default priority
84                    arrival_time,
85                    num_prompt_tokens,
86                    max_output_tokens,
87                );
88
89                self.requests_generated += 1;
90                return Some(request);
91            }
92            return None;
93        }
94
95        // Original logic for non-closed-loop patterns
96        // Check if we've generated all requests
97        if let Some(max_requests) = self.workload.num_requests {
98            if self.requests_generated >= max_requests {
99                return None;
100            }
101        }
102
103        // Check if next request has arrived
104        if self.next_arrival_time > current_time {
105            return None;
106        }
107
108        // Generate request
109        let request_id = format!("req-{}", self.next_request_id);
110        self.next_request_id += 1;
111
112        let num_prompt_tokens = self.workload.input_len_dist.sample(&mut self.rng);
113        let max_output_tokens = self.workload.output_len_dist.sample(&mut self.rng);
114
115        let request = Request::new(
116            request_id,
117            0, // Default priority
118            self.next_arrival_time,
119            num_prompt_tokens,
120            max_output_tokens,
121        );
122
123        self.requests_generated += 1;
124
125        // Sample next arrival time
126        self.next_arrival_time = Self::sample_next_arrival(
127            self.next_arrival_time,
128            &self.workload.arrival_pattern,
129            self.workload.arrival_rate,
130            &mut self.rng,
131        );
132
133        Some(request)
134    }
135
136    /// Sample the next arrival time based on the arrival pattern
137    fn sample_next_arrival(
138        current_time: f64,
139        pattern: &str,
140        rate: f64,
141        rng: &mut StdRng,
142    ) -> f64 {
143        match pattern.to_lowercase().as_str() {
144            "poisson" => {
145                // Poisson process: inter-arrival times are exponentially distributed
146                let exp = Exp::new(rate).unwrap();
147                let inter_arrival = exp.sample(rng);
148                current_time + inter_arrival
149            }
150            "uniform" => {
151                // Uniform: constant inter-arrival time
152                let inter_arrival = 1.0 / rate;
153                current_time + inter_arrival
154            }
155            "burst" => {
156                // Burst: requests arrive in bursts with gaps
157                // Simple implementation: alternate between fast and slow
158                if rng.gen_bool(0.2) {
159                    // 20% chance of burst
160                    current_time + rng.gen_range(0.001..0.01)
161                } else {
162                    current_time + rng.gen_range(0.5..2.0)
163                }
164            }
165            "fixed_rate" => {
166                // Fixed rate: exact inter-arrival time
167                current_time + 1.0 / rate
168            }
169            "batched" => {
170                // Batched: all requests arrive at time 0
171                0.0
172            }
173            _ => {
174                // Default to Poisson
175                let exp = Exp::new(rate).unwrap();
176                current_time + exp.sample(rng)
177            }
178        }
179    }
180
181    /// Check if all requests have been generated
182    pub fn is_finished(&self) -> bool {
183        if let Some(max_requests) = self.workload.num_requests {
184            let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
185            if is_closed_loop {
186                // For closed-loop, we're finished when we've generated max_requests
187                // AND have no pending requests
188                self.requests_generated >= max_requests && self.pending_closed_loop_requests.is_empty()
189            } else {
190                self.requests_generated >= max_requests
191            }
192        } else {
193            false
194        }
195    }
196
197    /// Called when a request completes (for closed-loop pattern)
198    /// Generates a new request for that "user slot" at the completion time
199    pub fn on_request_complete(&mut self, completion_time: f64) {
200        let is_closed_loop = self.workload.arrival_pattern.to_lowercase() == "closed_loop";
201        if !is_closed_loop {
202            return; // Only applicable to closed-loop
203        }
204
205        // Check if we should generate more requests
206        if let Some(max_requests) = self.workload.num_requests {
207            if self.requests_generated >= max_requests {
208                return; // Already generated all requested requests
209            }
210        }
211
212        // Add a new pending request at the completion time
213        self.pending_closed_loop_requests.push(completion_time);
214    }
215
    /// Get number of requests generated so far.
    ///
    /// Counts every request emitted by `next_if_before`, for both open-loop
    /// and closed-loop arrival patterns.
    pub fn num_generated(&self) -> usize {
        self.requests_generated
    }
220
    /// Peek at the next arrival time without generating the request.
    ///
    /// NOTE(review): this reflects only the open-loop schedule. For the
    /// "closed_loop" pattern, `next_arrival_time` is not updated as pending
    /// completion slots are queued, so this value can be stale in closed-loop
    /// mode — confirm against call sites before relying on it there.
    pub fn peek_next_arrival(&self) -> f64 {
        self.next_arrival_time
    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LengthDistribution;

    /// Build a minimal workload: fixed 100-token prompts, 50-token outputs,
    /// seed 42, and a hard `num_requests` cap.
    fn create_test_workload(pattern: &str, rate: f64, num_requests: usize) -> WorkloadConfig {
        WorkloadConfig {
            arrival_pattern: pattern.to_string(),
            arrival_rate: rate,
            num_concurrent_users: None,
            input_len_dist: LengthDistribution::Fixed { value: 100 },
            output_len_dist: LengthDistribution::Fixed { value: 50 },
            num_requests: Some(num_requests),
            duration_secs: None,
            seed: 42,
        }
    }

    /// Drain every request from the generator, advancing time in 10 s steps
    /// so all arrivals are guaranteed to be reached.
    fn drain_all(generator: &mut RequestGenerator) -> Vec<Request> {
        let mut collected = Vec::new();
        let mut now = 0.0;
        while !generator.is_finished() {
            now += 10.0;
            while let Some(req) = generator.next_if_before(now) {
                collected.push(req);
            }
        }
        collected
    }

    #[test]
    fn test_generator_creation() {
        let generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));

        assert_eq!(generator.num_generated(), 0);
        assert!(!generator.is_finished());
    }

    #[test]
    fn test_generate_requests() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 10.0, 5));

        let requests = drain_all(&mut generator);

        assert_eq!(requests.len(), 5);
        assert!(generator.is_finished());
    }

    #[test]
    fn test_arrival_ordering() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 5.0, 10));

        let requests = drain_all(&mut generator);

        // Arrival times must be monotonically non-decreasing.
        for pair in requests.windows(2) {
            assert!(pair[1].arrival_time >= pair[0].arrival_time);
        }
    }

    #[test]
    fn test_fixed_rate_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("fixed_rate", 2.0, 4));

        let requests = drain_all(&mut generator);

        assert_eq!(requests.len(), 4);

        // Inter-arrival gaps should be exactly 1/rate = 0.5 seconds.
        for pair in requests.windows(2) {
            let gap = pair[1].arrival_time - pair[0].arrival_time;
            assert!((gap - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_request_properties() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 1));

        let req = generator.next_if_before(10.0).unwrap();

        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.priority, 0);
        assert!(req.request_id.starts_with("req-"));
    }

    #[test]
    fn test_peek_next_arrival() {
        let mut generator = RequestGenerator::new(create_test_workload("poisson", 1.0, 10));

        let next_arrival = generator.peek_next_arrival();
        assert!(next_arrival > 0.0);

        // Generating the request should yield exactly the peeked time.
        let req = generator.next_if_before(next_arrival + 1.0).unwrap();
        assert_eq!(req.arrival_time, next_arrival);
    }
}
347}