1use super::status::RequestStatus;
2use crate::config::ModelConfig;
3
/// Identifier type for a block, used as the element type of `Request::kv_blocks`.
///
/// NOTE(review): presumably an index into a KV-cache block pool owned elsewhere —
/// confirm against the allocator that hands these out.
pub type BlockId = u32;
5
6#[derive(Debug, Clone)]
8pub struct Request {
9 pub request_id: String,
11
12 pub priority: i32,
14
15 pub arrival_time: f64,
17
18 pub status: RequestStatus,
20
21 pub num_prompt_tokens: u32,
23
24 pub max_output_tokens: u32,
26
27 pub target_output_tokens: u32,
30
31 pub num_computed_tokens: u32,
33
34 pub num_output_tokens: u32,
36
37 pub num_tokens: u32,
39
40 pub num_cached_tokens: u32,
42
43 pub prompt_block_hashes: Vec<u64>,
47
48 pub kv_blocks: Vec<BlockId>,
50
51 pub num_preemptions: u32,
53
54 pub first_token_time: Option<f64>,
56
57 pub completion_time: Option<f64>,
59
60 pub token_generation_times: Vec<f64>,
62
63 pub preempted_time: f64,
65
66 pub last_preempted_at: Option<f64>,
68}
69
70impl Request {
71 pub fn new_with_target(
73 request_id: String,
74 priority: i32,
75 arrival_time: f64,
76 num_prompt_tokens: u32,
77 max_output_tokens: u32,
78 target_output_tokens: u32,
79 ) -> Self {
80 Self {
81 request_id,
82 priority,
83 arrival_time,
84 status: RequestStatus::Waiting,
85 num_prompt_tokens,
86 max_output_tokens,
87 target_output_tokens,
88 num_computed_tokens: 0,
89 num_output_tokens: 0,
90 num_tokens: num_prompt_tokens + target_output_tokens,
91 num_cached_tokens: 0,
92 prompt_block_hashes: Vec::new(),
93 kv_blocks: Vec::new(),
94 num_preemptions: 0,
95 first_token_time: None,
96 completion_time: None,
97 token_generation_times: Vec::new(),
98 preempted_time: 0.0,
99 last_preempted_at: None,
100 }
101 }
102
103 pub fn new(
105 request_id: String,
106 priority: i32,
107 arrival_time: f64,
108 num_prompt_tokens: u32,
109 max_output_tokens: u32,
110 ) -> Self {
111 Self::new_with_target(
112 request_id,
113 priority,
114 arrival_time,
115 num_prompt_tokens,
116 max_output_tokens,
117 max_output_tokens, )
119 }
120
121 pub fn get_prompt_block_hashes(&self) -> &[u64] {
127 &self.prompt_block_hashes
128 }
129
130 pub fn is_prefill(&self) -> bool {
132 self.num_computed_tokens < self.num_prompt_tokens
133 }
134
135 pub fn tokens_to_process(&self) -> u32 {
137 if self.is_finished() {
138 return 0; }
140 self.num_tokens - self.num_computed_tokens
141 }
142
143 pub fn is_finished(&self) -> bool {
145 self.num_output_tokens >= self.target_output_tokens
146 }
147
148 pub fn total_tokens(&self) -> u32 {
150 self.num_prompt_tokens + self.max_output_tokens
151 }
152
153 pub fn remaining_tokens(&self) -> u32 {
155 self.num_prompt_tokens + self.max_output_tokens - self.num_computed_tokens
156 }
157
158 pub fn kv_cache_size(&self, model: &ModelConfig) -> u64 {
160 model.kv_cache_size_for_sequence(self.num_tokens)
161 }
162
163 pub fn record_generated_tokens(&mut self, num_new_tokens: u32, current_time: f64) {
165 self.num_computed_tokens += num_new_tokens;
167
168 if self.num_computed_tokens > self.num_prompt_tokens {
170 let new_output_tokens =
171 (self.num_computed_tokens - self.num_prompt_tokens).min(self.max_output_tokens); if self.first_token_time.is_none() && new_output_tokens > 0 {
175 self.first_token_time = Some(current_time);
176 }
177
178 self.num_output_tokens = new_output_tokens;
179 self.token_generation_times.push(current_time);
183 }
184 }
185
186 pub fn mark_preempted(&mut self, current_time: f64) {
188 self.status = RequestStatus::Preempted;
189 self.num_preemptions += 1;
190 self.last_preempted_at = Some(current_time);
191 }
192
193 pub fn resume(&mut self, current_time: f64) {
195 if let Some(preempted_at) = self.last_preempted_at {
196 self.preempted_time += current_time - preempted_at;
197 }
198 self.status = RequestStatus::Running;
199 self.last_preempted_at = None;
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a 100-token prompt with up to 50 output tokens.
    fn sample_request() -> Request {
        Request::new(String::from("req-1"), 0, 0.0, 100, 50)
    }

    /// Checks the counters that every `record_generated_tokens` step is expected to leave behind.
    fn assert_progress(req: &Request, computed: u32, output: u32, first_token: Option<f64>) {
        assert_eq!(req.num_computed_tokens, computed);
        assert_eq!(req.num_output_tokens, output);
        assert_eq!(req.num_tokens, 150);
        assert_eq!(req.first_token_time, first_token);
    }

    /// A freshly built request carries the constructor arguments and zeroed counters.
    #[test]
    fn test_request_creation() {
        let req = sample_request();

        // Values passed straight through from the constructor.
        assert_eq!(req.request_id, "req-1");
        assert_eq!(req.priority, 0);
        assert_eq!(req.arrival_time, 0.0);
        assert_eq!(req.num_prompt_tokens, 100);
        assert_eq!(req.max_output_tokens, 50);

        // Initial state chosen by the constructor.
        assert_eq!(req.status, RequestStatus::Waiting);
        assert_eq!(req.num_computed_tokens, 0);
        assert_eq!(req.num_output_tokens, 0);
        assert_eq!(req.num_tokens, 150);
    }

    /// Prefill lasts until the computed count reaches the prompt length.
    #[test]
    fn test_is_prefill() {
        let mut req = sample_request();

        for (computed, expected) in [(0, true), (50, true), (100, false)] {
            req.num_computed_tokens = computed;
            assert_eq!(req.is_prefill(), expected);
        }
    }

    /// The number of tokens left to process shrinks as the computed count grows.
    #[test]
    fn test_tokens_to_process() {
        let mut req = sample_request();

        for (computed, expected) in [(0, 150), (50, 100), (100, 50), (150, 0)] {
            req.num_computed_tokens = computed;
            assert_eq!(req.tokens_to_process(), expected);
        }
    }

    /// The request is finished once the output count reaches or exceeds the target.
    #[test]
    fn test_is_finished() {
        let mut req = sample_request();

        for (output, expected) in [(0, false), (25, false), (50, true), (60, true)] {
            req.num_output_tokens = output;
            assert_eq!(req.is_finished(), expected);
        }
    }

    /// Walks a request from mid-prefill through its first two output tokens.
    #[test]
    fn test_record_generated_tokens() {
        let mut req = sample_request();

        // Half of the prompt: no output yet, no first-token timestamp.
        req.record_generated_tokens(50, 1.0);
        assert_progress(&req, 50, 0, None);

        // Rest of the prompt plus one token: first output is recorded.
        req.record_generated_tokens(51, 2.0);
        assert_progress(&req, 101, 1, Some(2.0));
        assert_eq!(req.token_generation_times.len(), 1);

        // One more token: first-token timestamp must not move.
        req.record_generated_tokens(1, 3.0);
        assert_progress(&req, 102, 2, Some(2.0));
        assert_eq!(req.token_generation_times.len(), 2);
    }

    /// Two preempt/resume cycles accumulate preempted time and bump the counter.
    #[test]
    fn test_preemption() {
        let mut req = sample_request();
        req.status = RequestStatus::Running;

        // First cycle: preempted from t=5 to t=10.
        req.mark_preempted(5.0);
        assert_eq!(req.status, RequestStatus::Preempted);
        assert_eq!(req.num_preemptions, 1);
        assert_eq!(req.last_preempted_at, Some(5.0));

        req.resume(10.0);
        assert_eq!(req.status, RequestStatus::Running);
        assert_eq!(req.preempted_time, 5.0);
        assert!(req.last_preempted_at.is_none());

        // Second cycle: preempted from t=15 to t=20.
        req.mark_preempted(15.0);
        assert_eq!(req.num_preemptions, 2);

        req.resume(20.0);
        assert_eq!(req.preempted_time, 10.0);
    }
}