//! inference_lab/request/request.rs — the `Request` model for the inference simulator.

1use super::status::RequestStatus;
2use crate::config::ModelConfig;
3
/// Identifier of a KV-cache block (see `Request::kv_blocks`).
pub type BlockId = u32;
5
/// Request represents a single inference request in the simulation.
#[derive(Debug, Clone)]
pub struct Request {
    /// Unique request ID
    pub request_id: String,

    /// Client priority (lower value = higher priority)
    pub priority: i32,

    /// Arrival time (simulated time)
    pub arrival_time: f64,

    /// Current lifecycle status (e.g. Waiting, Running, Preempted)
    pub status: RequestStatus,

    /// Number of input (prompt) tokens
    pub num_prompt_tokens: u32,

    /// Maximum number of output tokens to generate
    pub max_output_tokens: u32,

    /// Actual number of output tokens to generate (sampled, may be less than max)
    /// This simulates hitting an EOS token
    pub target_output_tokens: u32,

    /// Number of tokens computed so far (prompt + generated)
    pub num_computed_tokens: u32,

    /// Number of output tokens generated so far
    pub num_output_tokens: u32,

    /// Total tokens this request will process: prompt + *target* output.
    /// Set once at construction and never updated afterwards.
    pub num_tokens: u32,

    /// Number of prefix-cached tokens (set by cache manager)
    pub num_cached_tokens: u32,

    /// Synthetic block hashes for prefix caching modeling
    /// In synthetic mode: pre-generated hashes (some shared, some unique)
    /// In semantic mode: will be computed from actual token content
    pub prompt_block_hashes: Vec<u64>,

    /// KV cache blocks allocated to this request
    pub kv_blocks: Vec<BlockId>,

    /// Number of times this request has been preempted
    pub num_preemptions: u32,

    /// Simulated time when the first output token was generated (TTFT tracking)
    pub first_token_time: Option<f64>,

    /// Simulated time when the request completed
    pub completion_time: Option<f64>,

    /// Per-token generation timestamps for decode tokens
    pub token_generation_times: Vec<f64>,

    /// Total time spent preempted (not running)
    pub preempted_time: f64,

    /// Start time of the current preemption, if currently preempted
    pub last_preempted_at: Option<f64>,
}
69
70impl Request {
71    /// Create a new request with a target output length
72    pub fn new_with_target(
73        request_id: String,
74        priority: i32,
75        arrival_time: f64,
76        num_prompt_tokens: u32,
77        max_output_tokens: u32,
78        target_output_tokens: u32,
79    ) -> Self {
80        Self {
81            request_id,
82            priority,
83            arrival_time,
84            status: RequestStatus::Waiting,
85            num_prompt_tokens,
86            max_output_tokens,
87            target_output_tokens,
88            num_computed_tokens: 0,
89            num_output_tokens: 0,
90            num_tokens: num_prompt_tokens + target_output_tokens,
91            num_cached_tokens: 0,
92            prompt_block_hashes: Vec::new(),
93            kv_blocks: Vec::new(),
94            num_preemptions: 0,
95            first_token_time: None,
96            completion_time: None,
97            token_generation_times: Vec::new(),
98            preempted_time: 0.0,
99            last_preempted_at: None,
100        }
101    }
102
103    /// Create a new request (target = max)
104    pub fn new(
105        request_id: String,
106        priority: i32,
107        arrival_time: f64,
108        num_prompt_tokens: u32,
109        max_output_tokens: u32,
110    ) -> Self {
111        Self::new_with_target(
112            request_id,
113            priority,
114            arrival_time,
115            num_prompt_tokens,
116            max_output_tokens,
117            max_output_tokens, // Target = max (used for synthetic workloads)
118        )
119    }
120
121    /// Get block hashes for the prompt
122    /// These should be thought of as 'incremental hashes' - i.e. the hash of block n is the hash
123    /// of all the tokens up to that block (not just that block alone).
124    /// In synthetic mode: returns pre-generated hashes
125    /// In semantic mode: will compute from actual token content
126    pub fn get_prompt_block_hashes(&self) -> &[u64] {
127        &self.prompt_block_hashes
128    }
129
130    /// Check if this is in prefill phase
131    pub fn is_prefill(&self) -> bool {
132        self.num_computed_tokens < self.num_prompt_tokens
133    }
134
135    /// Get number of tokens needed to process
136    pub fn tokens_to_process(&self) -> u32 {
137        if self.is_finished() {
138            return 0; // Don't process more if we've generated all output
139        }
140        self.num_tokens - self.num_computed_tokens
141    }
142
143    /// Check if request is done
144    pub fn is_finished(&self) -> bool {
145        self.num_output_tokens >= self.target_output_tokens
146    }
147
148    /// Get total tokens (prompt + max output)
149    pub fn total_tokens(&self) -> u32 {
150        self.num_prompt_tokens + self.max_output_tokens
151    }
152
153    /// Get remaining tokens to process
154    pub fn remaining_tokens(&self) -> u32 {
155        self.num_prompt_tokens + self.max_output_tokens - self.num_computed_tokens
156    }
157
158    /// Calculate KV cache requirement for this request
159    pub fn kv_cache_size(&self, model: &ModelConfig) -> u64 {
160        model.kv_cache_size_for_sequence(self.num_tokens)
161    }
162
163    /// Record that tokens were generated (update output token count and total)
164    pub fn record_generated_tokens(&mut self, num_new_tokens: u32, current_time: f64) {
165        // Update computed tokens
166        self.num_computed_tokens += num_new_tokens;
167
168        // If we've crossed into decode phase, update output tokens
169        if self.num_computed_tokens > self.num_prompt_tokens {
170            let new_output_tokens =
171                (self.num_computed_tokens - self.num_prompt_tokens).min(self.max_output_tokens); // Cap at max
172
173            // Record first token time if this is the first output token
174            if self.first_token_time.is_none() && new_output_tokens > 0 {
175                self.first_token_time = Some(current_time);
176            }
177
178            self.num_output_tokens = new_output_tokens;
179            // Note: num_tokens stays fixed at num_prompt_tokens + max_output_tokens
180
181            // Record generation times for each decode token
182            self.token_generation_times.push(current_time);
183        }
184    }
185
186    /// Mark request as preempted
187    pub fn mark_preempted(&mut self, current_time: f64) {
188        self.status = RequestStatus::Preempted;
189        self.num_preemptions += 1;
190        self.last_preempted_at = Some(current_time);
191    }
192
193    /// Resume a preempted request
194    pub fn resume(&mut self, current_time: f64) {
195        if let Some(preempted_at) = self.last_preempted_at {
196            self.preempted_time += current_time - preempted_at;
197        }
198        self.status = RequestStatus::Running;
199        self.last_preempted_at = None;
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// The request used by every test: 100 prompt tokens, 50 max output tokens.
    fn make_request() -> Request {
        Request::new("req-1".to_string(), 0, 0.0, 100, 50)
    }

    #[test]
    fn test_request_creation() {
        let req = make_request();

        assert_eq!(req.request_id, "req-1");
        assert_eq!(req.priority, 0);
        assert_eq!(req.arrival_time, 0.0);
        assert_eq!(req.status, RequestStatus::Waiting);
        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.num_computed_tokens, 0);
        assert_eq!(req.num_output_tokens, 0);
        assert_eq!(req.num_tokens, 150); // prompt + target (= max here)
    }

    #[test]
    fn test_is_prefill() {
        let mut req = make_request();

        // (computed tokens, still in prefill?)
        for (computed, expected) in [(0u32, true), (50, true), (100, false)] {
            req.num_computed_tokens = computed;
            assert_eq!(req.is_prefill(), expected, "computed = {}", computed);
        }
    }

    #[test]
    fn test_tokens_to_process() {
        let mut req = make_request();

        // (computed tokens, tokens left) against the 150-token total.
        for (computed, left) in [(0u32, 150u32), (50, 100), (100, 50), (150, 0)] {
            req.num_computed_tokens = computed;
            assert_eq!(req.tokens_to_process(), left, "computed = {}", computed);
        }
    }

    #[test]
    fn test_is_finished() {
        let mut req = make_request();

        // (output tokens, finished?) against the target of 50.
        for (output, done) in [(0u32, false), (25, false), (50, true), (60, true)] {
            req.num_output_tokens = output;
            assert_eq!(req.is_finished(), done, "output = {}", output);
        }
    }

    #[test]
    fn test_record_generated_tokens() {
        let mut req = make_request();

        // Halfway through prefill: no output yet, no TTFT recorded.
        req.record_generated_tokens(50, 1.0);
        assert_eq!(req.num_computed_tokens, 50);
        assert_eq!(req.num_output_tokens, 0);
        assert_eq!(req.num_tokens, 150);
        assert!(req.first_token_time.is_none());

        // Finish prefill and produce the first decode token.
        req.record_generated_tokens(51, 2.0);
        assert_eq!(req.num_computed_tokens, 101);
        assert_eq!(req.num_output_tokens, 1);
        assert_eq!(req.num_tokens, 150); // total stays fixed
        assert_eq!(req.first_token_time, Some(2.0));
        assert_eq!(req.token_generation_times.len(), 1);

        // One more decode step.
        req.record_generated_tokens(1, 3.0);
        assert_eq!(req.num_computed_tokens, 102);
        assert_eq!(req.num_output_tokens, 2);
        assert_eq!(req.num_tokens, 150); // still fixed
        assert_eq!(req.first_token_time, Some(2.0)); // TTFT is latched
        assert_eq!(req.token_generation_times.len(), 2);
    }

    #[test]
    fn test_preemption() {
        let mut req = make_request();
        req.status = RequestStatus::Running;

        // First preemption.
        req.mark_preempted(5.0);
        assert_eq!(req.status, RequestStatus::Preempted);
        assert_eq!(req.num_preemptions, 1);
        assert_eq!(req.last_preempted_at, Some(5.0));

        // Resume after 5 simulated seconds.
        req.resume(10.0);
        assert_eq!(req.status, RequestStatus::Running);
        assert_eq!(req.preempted_time, 5.0);
        assert!(req.last_preempted_at.is_none());

        // A second preempt/resume cycle accumulates preempted time.
        req.mark_preempted(15.0);
        assert_eq!(req.num_preemptions, 2);

        req.resume(20.0);
        assert_eq!(req.preempted_time, 10.0); // 5.0 + 5.0
    }
}