cbtop/continuous_batcher/
request.rs1use std::time::Instant;
4
5use crate::paged_kv::SeqId;
6
/// Token identifier, a 32-bit unsigned integer.
///
/// Element type of both the prompt (`input_tokens`) and the generated
/// output (`output_tokens`) vectors in this file.
pub type Token = u32;
9
/// Priority level attached to an inference request.
///
/// A thin newtype over `u8`. The derived ordering and equality compare the
/// wrapped value directly. Whether a larger number means "more urgent" is
/// decided by the code that consumes this type, which is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Priority(pub u8);
13
14impl Default for Priority {
15 fn default() -> Self {
16 Priority(128) }
18}
19
20#[derive(Debug, Clone)]
22pub struct InferenceRequest {
23 pub id: SeqId,
25 pub input_tokens: Vec<Token>,
27 pub max_new_tokens: usize,
29 pub priority: Priority,
31 pub arrival_time: Instant,
33 pub estimated_tokens: usize,
35}
36
37impl InferenceRequest {
38 pub fn new(id: SeqId, input_tokens: Vec<Token>, max_new_tokens: usize) -> Self {
40 let estimated_tokens = input_tokens.len() + max_new_tokens;
41 Self {
42 id,
43 input_tokens,
44 max_new_tokens,
45 priority: Priority::default(),
46 arrival_time: Instant::now(),
47 estimated_tokens,
48 }
49 }
50
51 pub fn with_priority(mut self, priority: Priority) -> Self {
53 self.priority = priority;
54 self
55 }
56
57 pub fn input_len(&self) -> usize {
59 self.input_tokens.len()
60 }
61}
62
63#[derive(Debug, Clone)]
65pub struct SequenceGroup {
66 pub request: InferenceRequest,
68 pub output_tokens: Vec<Token>,
70 pub is_finished: bool,
72 pub last_access: Instant,
74 pub num_steps: usize,
76}
77
78impl SequenceGroup {
79 pub fn new(request: InferenceRequest) -> Self {
81 Self {
82 request,
83 output_tokens: Vec::new(),
84 is_finished: false,
85 last_access: Instant::now(),
86 num_steps: 0,
87 }
88 }
89
90 pub fn total_tokens(&self) -> usize {
92 self.request.input_tokens.len() + self.output_tokens.len()
93 }
94
95 pub fn remaining_tokens(&self) -> usize {
97 self.request
98 .max_new_tokens
99 .saturating_sub(self.output_tokens.len())
100 }
101
102 pub fn add_token(&mut self, token: Token) {
104 self.output_tokens.push(token);
105 self.last_access = Instant::now();
106 self.num_steps += 1;
107
108 if self.output_tokens.len() >= self.request.max_new_tokens {
110 self.is_finished = true;
111 }
112 }
113
114 pub fn finish(&mut self) {
116 self.is_finished = true;
117 }
118}