// inference_lab/request/request.rs

1use super::status::RequestStatus;
2use crate::config::ModelConfig;
3
/// Identifier of a single KV-cache block (see `Request::kv_blocks`).
pub type BlockId = u32;
5
/// Request represents a single inference request in the simulation
#[derive(Debug, Clone)]
pub struct Request {
    /// Unique request ID
    pub request_id: String,

    /// Client priority (lower = higher priority)
    pub priority: i32,

    /// Arrival time (simulated time)
    pub arrival_time: f64,

    /// Request status (new requests start as `Waiting`)
    pub status: RequestStatus,

    /// Number of input (prompt) tokens
    pub num_prompt_tokens: u32,

    /// Maximum number of output tokens to generate
    pub max_output_tokens: u32,

    /// Number of tokens computed so far (prefill + decode)
    pub num_computed_tokens: u32,

    /// Number of output tokens generated so far (capped at `max_output_tokens`)
    pub num_output_tokens: u32,

    /// Total tokens so far: prompt + output tokens generated so far
    pub num_tokens: u32,

    /// Number of prefix-cached tokens
    pub num_cached_tokens: u32,

    /// KV cache blocks allocated to this request
    pub kv_blocks: Vec<BlockId>,

    /// Number of times this request has been preempted
    pub num_preemptions: u32,

    /// Time when first token was generated (TTFT tracking); `None` until
    /// the first decode token is produced
    pub first_token_time: Option<f64>,

    /// Time when request completed; `None` while still in flight
    pub completion_time: Option<f64>,

    /// Per-token generation times (one timestamp per decode step)
    pub token_generation_times: Vec<f64>,

    /// Cumulative time spent preempted (not running)
    pub preempted_time: f64,

    /// Start time of the current preemption, if currently preempted
    pub last_preempted_at: Option<f64>,
}
60
61impl Request {
62    /// Create a new request
63    pub fn new(
64        request_id: String,
65        priority: i32,
66        arrival_time: f64,
67        num_prompt_tokens: u32,
68        max_output_tokens: u32,
69    ) -> Self {
70        Self {
71            request_id,
72            priority,
73            arrival_time,
74            status: RequestStatus::Waiting,
75            num_prompt_tokens,
76            max_output_tokens,
77            num_computed_tokens: 0,
78            num_output_tokens: 0,
79            num_tokens: num_prompt_tokens + max_output_tokens, // Total tokens to process
80            num_cached_tokens: 0,
81            kv_blocks: Vec::new(),
82            num_preemptions: 0,
83            first_token_time: None,
84            completion_time: None,
85            token_generation_times: Vec::new(),
86            preempted_time: 0.0,
87            last_preempted_at: None,
88        }
89    }
90
91    /// Check if this is in prefill phase
92    pub fn is_prefill(&self) -> bool {
93        self.num_computed_tokens < self.num_prompt_tokens
94    }
95
96    /// Get number of tokens needed to process
97    pub fn tokens_to_process(&self) -> u32 {
98        if self.is_finished() {
99            return 0; // Don't process more if we've generated all output
100        }
101        self.num_tokens - self.num_computed_tokens
102    }
103
104    /// Check if request is done
105    pub fn is_finished(&self) -> bool {
106        self.num_output_tokens >= self.max_output_tokens
107    }
108
109    /// Calculate KV cache requirement for this request
110    pub fn kv_cache_size(&self, model: &ModelConfig) -> u64 {
111        model.kv_cache_size_for_sequence(self.num_tokens)
112    }
113
114    /// Record that tokens were generated (update output token count and total)
115    pub fn record_generated_tokens(&mut self, num_new_tokens: u32, current_time: f64) {
116        // Update computed tokens
117        self.num_computed_tokens += num_new_tokens;
118
119        // If we've crossed into decode phase, update output tokens
120        if self.num_computed_tokens > self.num_prompt_tokens {
121            let new_output_tokens = (self.num_computed_tokens - self.num_prompt_tokens)
122                .min(self.max_output_tokens); // Cap at max
123
124            // Record first token time if this is the first output token
125            if self.first_token_time.is_none() && new_output_tokens > 0 {
126                self.first_token_time = Some(current_time);
127            }
128
129            self.num_output_tokens = new_output_tokens;
130            // Note: num_tokens stays fixed at num_prompt_tokens + max_output_tokens
131
132            // Record generation times for each decode token
133            self.token_generation_times.push(current_time);
134        }
135    }
136
137    /// Mark request as preempted
138    pub fn mark_preempted(&mut self, current_time: f64) {
139        self.status = RequestStatus::Preempted;
140        self.num_preemptions += 1;
141        self.last_preempted_at = Some(current_time);
142    }
143
144    /// Resume a preempted request
145    pub fn resume(&mut self, current_time: f64) {
146        if let Some(preempted_at) = self.last_preempted_at {
147            self.preempted_time += current_time - preempted_at;
148        }
149        self.status = RequestStatus::Running;
150        self.last_preempted_at = None;
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 100 prompt tokens, 50 max output tokens, priority 0,
    /// arriving at t = 0.
    fn sample_request() -> Request {
        Request::new("req-1".to_string(), 0, 0.0, 100, 50)
    }

    #[test]
    fn test_request_creation() {
        let req = sample_request();

        assert_eq!(req.request_id, "req-1");
        assert_eq!(req.priority, 0);
        assert_eq!(req.arrival_time, 0.0);
        assert_eq!(req.status, RequestStatus::Waiting);
        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);
        assert_eq!(req.num_computed_tokens, 0);
        assert_eq!(req.num_output_tokens, 0);
        assert_eq!(req.num_tokens, 100);
    }

    #[test]
    fn test_is_prefill() {
        let mut req = sample_request();

        // Prefill lasts until every prompt token has been computed.
        for (computed, expect_prefill) in [(0, true), (50, true), (100, false)] {
            req.num_computed_tokens = computed;
            assert_eq!(req.is_prefill(), expect_prefill);
        }
    }

    #[test]
    fn test_tokens_to_process() {
        let mut req = sample_request();

        // Remaining work shrinks as prompt tokens get computed.
        for (computed, remaining) in [(0, 100), (50, 50), (100, 0)] {
            req.num_computed_tokens = computed;
            assert_eq!(req.tokens_to_process(), remaining);
        }
    }

    #[test]
    fn test_is_finished() {
        let mut req = sample_request();

        // Finished once output reaches (or exceeds) max_output_tokens.
        for (generated, done) in [(0, false), (25, false), (50, true), (60, true)] {
            req.num_output_tokens = generated;
            assert_eq!(req.is_finished(), done);
        }
    }

    #[test]
    fn test_record_generated_tokens() {
        let mut req = sample_request();

        // Halfway through prefill: no output yet, no TTFT recorded.
        req.record_generated_tokens(50, 1.0);
        assert_eq!(req.num_computed_tokens, 50);
        assert_eq!(req.num_output_tokens, 0);
        assert_eq!(req.num_tokens, 100);
        assert!(req.first_token_time.is_none());

        // Finish prefill and produce the first decode token.
        req.record_generated_tokens(51, 2.0);
        assert_eq!(req.num_computed_tokens, 101);
        assert_eq!(req.num_output_tokens, 1);
        assert_eq!(req.num_tokens, 101);
        assert_eq!(req.first_token_time, Some(2.0));
        assert_eq!(req.token_generation_times.len(), 1);

        // One more decode step.
        req.record_generated_tokens(1, 3.0);
        assert_eq!(req.num_computed_tokens, 102);
        assert_eq!(req.num_output_tokens, 2);
        assert_eq!(req.num_tokens, 102);
        assert_eq!(req.first_token_time, Some(2.0)); // TTFT is sticky
        assert_eq!(req.token_generation_times.len(), 2);
    }

    #[test]
    fn test_preemption() {
        let mut req = sample_request();
        req.status = RequestStatus::Running;

        // First cycle: preempted at t=5, resumed at t=10 → 5 units preempted.
        req.mark_preempted(5.0);
        assert_eq!(req.status, RequestStatus::Preempted);
        assert_eq!(req.num_preemptions, 1);
        assert_eq!(req.last_preempted_at, Some(5.0));

        req.resume(10.0);
        assert_eq!(req.status, RequestStatus::Running);
        assert_eq!(req.preempted_time, 5.0);
        assert!(req.last_preempted_at.is_none());

        // Second cycle: t=15 → t=20 accumulates another 5 units.
        req.mark_preempted(15.0);
        assert_eq!(req.num_preemptions, 2);

        req.resume(20.0);
        assert_eq!(req.preempted_time, 10.0);
    }
}
268}