
cbtop/continuous_batcher/request.rs

//! Inference request and sequence types for continuous batching.

use std::time::Instant;

use crate::paged_kv::SeqId;

/// Token type (simplified - u32 vocabulary index).
pub type Token = u32;

/// Request priority level.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Priority(pub u8);

impl Default for Priority {
    fn default() -> Self {
        Priority(128) // Middle priority
    }
}
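
// Because `Priority` derives `Ord` on its inner `u8`, comparisons reduce to
// plain integer comparisons. Whether a larger value means "more urgent" is a
// scheduler-level convention (assumed here), not something this type enforces:
//
//     assert!(Priority(200) > Priority::default()); // 200 > 128
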
/// Inference request.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Unique request ID
    pub id: SeqId,
    /// Input tokens (prompt)
    pub input_tokens: Vec<Token>,
    /// Maximum output tokens to generate
    pub max_new_tokens: usize,
    /// Request priority
    pub priority: Priority,
    /// Arrival timestamp
    pub arrival_time: Instant,
    /// Estimated total tokens (input + output)
    pub estimated_tokens: usize,
}

impl InferenceRequest {
    /// Create a new inference request.
    pub fn new(id: SeqId, input_tokens: Vec<Token>, max_new_tokens: usize) -> Self {
        let estimated_tokens = input_tokens.len() + max_new_tokens;
        Self {
            id,
            input_tokens,
            max_new_tokens,
            priority: Priority::default(),
            arrival_time: Instant::now(),
            estimated_tokens,
        }
    }

    /// Create request with priority.
    pub fn with_priority(mut self, priority: Priority) -> Self {
        self.priority = priority;
        self
    }

    /// Input sequence length.
    pub fn input_len(&self) -> usize {
        self.input_tokens.len()
    }
}
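
// Usage sketch for the builder-style construction above. It assumes `SeqId`
// accepts an integer literal (e.g., a `u64`/`usize` alias); if it is a
// newtype in `crate::paged_kv`, adjust the constructor accordingly.
//
//     let req = InferenceRequest::new(7, vec![1, 2, 3], 32)
//         .with_priority(Priority(200));
//     assert_eq!(req.input_len(), 3);
//     assert_eq!(req.estimated_tokens, 35); // 3 input + 32 max new
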
/// Sequence group (request + generation state).
#[derive(Debug, Clone)]
pub struct SequenceGroup {
    /// Original request
    pub request: InferenceRequest,
    /// Generated tokens so far
    pub output_tokens: Vec<Token>,
    /// Is generation complete?
    pub is_finished: bool,
    /// Last access timestamp (for LRU)
    pub last_access: Instant,
    /// Number of decode steps
    pub num_steps: usize,
}

impl SequenceGroup {
    /// Create new sequence group from request.
    pub fn new(request: InferenceRequest) -> Self {
        Self {
            request,
            output_tokens: Vec::new(),
            is_finished: false,
            last_access: Instant::now(),
            num_steps: 0,
        }
    }

    /// Total tokens (input + output so far).
    pub fn total_tokens(&self) -> usize {
        self.request.input_tokens.len() + self.output_tokens.len()
    }

    /// Remaining tokens to generate.
    pub fn remaining_tokens(&self) -> usize {
        self.request
            .max_new_tokens
            .saturating_sub(self.output_tokens.len())
    }

    /// Add generated token.
    pub fn add_token(&mut self, token: Token) {
        self.output_tokens.push(token);
        self.last_access = Instant::now();
        self.num_steps += 1;

        // Check if finished
        if self.output_tokens.len() >= self.request.max_new_tokens {
            self.is_finished = true;
        }
    }

    /// Mark as finished.
    pub fn finish(&mut self) {
        self.is_finished = true;
    }
}
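
#[cfg(test)]
mod tests {
    use super::*;

    // Sketch of a lifecycle test for `SequenceGroup`. It assumes `SeqId`
    // accepts an integer literal (e.g., a `u64`/`usize` alias in
    // `crate::paged_kv`); adapt the constructor if it is a newtype.
    #[test]
    fn sequence_group_finishes_at_max_new_tokens() {
        let request = InferenceRequest::new(1, vec![10, 20], 2);
        let mut group = SequenceGroup::new(request);

        assert_eq!(group.total_tokens(), 2);
        assert_eq!(group.remaining_tokens(), 2);
        assert!(!group.is_finished);

        group.add_token(30);
        assert!(!group.is_finished);
        assert_eq!(group.remaining_tokens(), 1);

        // Reaching `max_new_tokens` flips `is_finished` inside `add_token`.
        group.add_token(40);
        assert!(group.is_finished);
        assert_eq!(group.total_tokens(), 4);
        assert_eq!(group.num_steps, 2);
    }
}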