kizzasi_inference/
batch.rs

//! Continuous batching for high-throughput inference
//!
//! This module implements continuous batching (also known as iteration-level scheduling),
//! which allows dynamic batching of inference requests to maximize GPU/CPU utilization
//! while maintaining low latency for individual requests.
//!
//! # Key Features
//!
//! - **Dynamic batch formation**: Requests are grouped on-the-fly
//! - **Early exit**: Completed sequences leave the batch immediately
//! - **Variable-length sequences**: Different requests can have different lengths
//! - **Priority scheduling**: High-priority requests can skip the queue
//!
//! # References
//!
//! - Orca paper: <https://www.usenix.org/system/files/osdi22-yu.pdf>
//! - vLLM: <https://arxiv.org/abs/2309.06180>
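//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore` so it is not run as a doctest; the crate
//! path and the `EngineConfig::new(3, 3)` setup mirror the unit tests at the bottom
//! of this file and are illustrative only):
//!
//! ```ignore
//! use kizzasi_inference::batch::{BatchConfig, BatchScheduler};
//! use kizzasi_inference::engine::EngineConfig;
//! use scirs2_core::ndarray::Array1;
//!
//! let mut scheduler = BatchScheduler::new(BatchConfig::new(), EngineConfig::new(3, 3)).unwrap();
//! let id = scheduler.submit(Array1::from_vec(vec![1.0, 2.0, 3.0]), 5);
//!
//! // Drive every pending and active request to completion.
//! let responses = scheduler.process_all().unwrap();
//! assert!(responses.iter().any(|r| r.request_id == id && r.is_complete));
//! ```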

use crate::engine::{EngineConfig, InferenceEngine};
use crate::error::InferenceResult;
use scirs2_core::ndarray::Array1;
use std::collections::VecDeque;
use std::time::{Duration, Instant};

/// Configuration for continuous batching
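///
/// # Example
///
/// A builder sketch; the values are illustrative, not tuned recommendations:
///
/// ```ignore
/// let config = BatchConfig::new()
///     .max_batch_size(16)
///     .max_wait_ms(5)
///     .min_batch_size(4)
///     .with_priority();
/// assert!(config.enable_priority);
/// ```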
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BatchConfig {
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Maximum waiting time before forming a batch (milliseconds)
    pub max_wait_ms: u64,
    /// Minimum batch size before processing (0 = process immediately)
    pub min_batch_size: usize,
    /// Enable priority-based scheduling
    pub enable_priority: bool,
    /// Maximum sequence length
    pub max_seq_len: usize,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            max_wait_ms: 10,
            min_batch_size: 1,
            enable_priority: false,
            max_seq_len: 2048,
        }
    }
}

impl BatchConfig {
    /// Create a new batch configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum batch size
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size;
        self
    }

    /// Set maximum wait time
    pub fn max_wait_ms(mut self, ms: u64) -> Self {
        self.max_wait_ms = ms;
        self
    }

    /// Set minimum batch size
    pub fn min_batch_size(mut self, size: usize) -> Self {
        self.min_batch_size = size;
        self
    }

    /// Enable priority scheduling
    pub fn with_priority(mut self) -> Self {
        self.enable_priority = true;
        self
    }
}

/// Priority level for inference requests
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}

/// A single inference request in the batch
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Unique request ID
    pub id: u64,
    /// Input data
    pub input: Array1<f32>,
    /// Maximum number of steps to generate
    pub max_steps: usize,
    /// Priority level
    pub priority: Priority,
    /// Timestamp when request was received
    pub received_at: Instant,
    /// Current step number
    pub current_step: usize,
}

impl BatchRequest {
    /// Create a new batch request
    pub fn new(id: u64, input: Array1<f32>, max_steps: usize) -> Self {
        Self {
            id,
            input,
            max_steps,
            priority: Priority::Normal,
            received_at: Instant::now(),
            current_step: 0,
        }
    }

    /// Set priority
    pub fn with_priority(mut self, priority: Priority) -> Self {
        self.priority = priority;
        self
    }

    /// Check if request is complete
    pub fn is_complete(&self) -> bool {
        self.current_step >= self.max_steps
    }

    /// Get waiting time in milliseconds
    pub fn wait_time_ms(&self) -> u64 {
        self.received_at.elapsed().as_millis() as u64
    }
}

/// Response from a batch inference request
#[derive(Debug, Clone)]
pub struct BatchResponse {
    /// Request ID
    pub request_id: u64,
    /// Generated outputs (currently only the final step's output is returned)
    pub outputs: Vec<Array1<f32>>,
    /// Number of steps completed
    pub steps_completed: usize,
    /// Whether the request is complete
    pub is_complete: bool,
    /// Time spent in the final batch step for this request, in microseconds
    pub inference_time_us: u64,
}

/// Continuous batching scheduler
pub struct BatchScheduler {
    config: BatchConfig,
    engine: InferenceEngine,
    /// Queue of pending requests
    pending: VecDeque<BatchRequest>,
    /// Currently processing requests
    active: Vec<BatchRequest>,
    /// Completed responses
    completed: Vec<BatchResponse>,
    /// Next request ID
    next_id: u64,
    /// Last batch formation time
    last_batch_time: Instant,
}

impl BatchScheduler {
    /// Create a new batch scheduler
    pub fn new(config: BatchConfig, engine_config: EngineConfig) -> InferenceResult<Self> {
        let engine = InferenceEngine::new(engine_config);

        Ok(Self {
            config,
            engine,
            pending: VecDeque::new(),
            active: Vec::new(),
            completed: Vec::new(),
            next_id: 0,
            last_batch_time: Instant::now(),
        })
    }

    /// Submit a new inference request
    pub fn submit(&mut self, input: Array1<f32>, max_steps: usize) -> u64 {
        let id = self.next_id;
        self.next_id += 1;

        let request = BatchRequest::new(id, input, max_steps);
        self.pending.push_back(request);

        id
    }

    /// Submit a request with priority
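    ///
    /// # Example
    ///
    /// A sketch of priority queuing; assumes `scheduler` was built with
    /// `BatchConfig::new().with_priority()` and `input` is an `Array1<f32>`:
    ///
    /// ```ignore
    /// let _low = scheduler.submit_with_priority(input.clone(), 5, Priority::Low);
    /// let _high = scheduler.submit_with_priority(input.clone(), 5, Priority::High);
    /// // The high-priority request is now at the front of the pending queue,
    /// // ahead of the earlier low-priority one.
    /// ```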
    pub fn submit_with_priority(
        &mut self,
        input: Array1<f32>,
        max_steps: usize,
        priority: Priority,
    ) -> u64 {
        let id = self.next_id;
        self.next_id += 1;

        let request = BatchRequest::new(id, input, max_steps).with_priority(priority);

        // Insert based on priority if enabled
        if self.config.enable_priority {
            let insert_pos = self
                .pending
                .iter()
                .position(|r| r.priority < priority)
                .unwrap_or(self.pending.len());
            self.pending.insert(insert_pos, request);
        } else {
            self.pending.push_back(request);
        }

        id
    }

    /// Check if it's time to form a new batch
    fn should_form_batch(&self) -> bool {
        if self.pending.is_empty() {
            return false;
        }

        // Check if we have enough requests
        if self.pending.len() >= self.config.min_batch_size {
            return true;
        }

        // Check if we've waited long enough
        let wait_time = self.last_batch_time.elapsed();
        wait_time >= Duration::from_millis(self.config.max_wait_ms)
    }

    /// Form a new batch from pending requests
    fn form_batch(&mut self) {
        // Fill the free slots in the active batch, capped by the number of pending requests
        let batch_size = self
            .config
            .max_batch_size
            .saturating_sub(self.active.len())
            .min(self.pending.len());

        for _ in 0..batch_size {
            if let Some(request) = self.pending.pop_front() {
                self.active.push(request);
            }
        }

        self.last_batch_time = Instant::now();
    }

    /// Process one inference step for all active requests, returning responses
    /// for any requests that completed during this step
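    ///
    /// # Example
    ///
    /// A manual driving loop, as an alternative to [`process_all`](Self::process_all);
    /// a sketch that assumes `scheduler` is set up as in the module-level example:
    ///
    /// ```ignore
    /// let mut finished = Vec::new();
    /// while scheduler.stats().pending_requests > 0 || scheduler.stats().active_requests > 0 {
    ///     finished.extend(scheduler.step().unwrap());
    /// }
    /// ```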
    pub fn step(&mut self) -> InferenceResult<Vec<BatchResponse>> {
        // Form batch if needed
        if self.should_form_batch() {
            self.form_batch();
        }

        if self.active.is_empty() {
            return Ok(Vec::new());
        }

        let start = Instant::now();
        let mut responses = Vec::new();

        // Process each active request
        let mut i = 0;
        while i < self.active.len() {
            let request = &mut self.active[i];

            // Run one inference step
            let output = self.engine.step(&request.input)?;

            request.current_step += 1;
            request.input = output.clone(); // Use output as next input

            // Check if complete
            if request.is_complete() {
                let completed_request = self.active.remove(i);
                let inference_time = start.elapsed().as_micros() as u64;

                responses.push(BatchResponse {
                    request_id: completed_request.id,
                    // Only the final step's output is returned
                    outputs: vec![output],
                    steps_completed: completed_request.current_step,
                    is_complete: true,
                    inference_time_us: inference_time,
                });
            } else {
                i += 1;
            }
        }

        // Record completed responses so `stats()` reflects them
        self.completed.extend(responses.iter().cloned());

        Ok(responses)
    }

    /// Process all active and pending requests until completion
    pub fn process_all(&mut self) -> InferenceResult<Vec<BatchResponse>> {
        let mut all_responses = Vec::new();

        while !self.pending.is_empty() || !self.active.is_empty() {
            let responses = self.step()?;
            all_responses.extend(responses);
        }

        Ok(all_responses)
    }

    /// Get statistics about the scheduler
    pub fn stats(&self) -> SchedulerStats {
        SchedulerStats {
            pending_requests: self.pending.len(),
            active_requests: self.active.len(),
            completed_requests: self.completed.len(),
            total_submitted: self.next_id,
        }
    }

    /// Reset the scheduler
    pub fn reset(&mut self) {
        self.pending.clear();
        self.active.clear();
        self.completed.clear();
        self.engine.reset();
    }
}

/// Statistics about the batch scheduler
#[derive(Debug, Clone)]
pub struct SchedulerStats {
    /// Requests waiting in the pending queue
    pub pending_requests: usize,
    /// Requests currently being processed
    pub active_requests: usize,
    /// Requests that have completed
    pub completed_requests: usize,
    /// Total number of requests submitted so far
    pub total_submitted: u64,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_config() {
        let config = BatchConfig::new()
            .max_batch_size(16)
            .max_wait_ms(5)
            .min_batch_size(4)
            .with_priority();

        assert_eq!(config.max_batch_size, 16);
        assert_eq!(config.max_wait_ms, 5);
        assert_eq!(config.min_batch_size, 4);
        assert!(config.enable_priority);
    }

    #[test]
    fn test_batch_request() {
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let request = BatchRequest::new(1, input, 10);

        assert_eq!(request.id, 1);
        assert_eq!(request.max_steps, 10);
        assert_eq!(request.current_step, 0);
        assert!(!request.is_complete());
    }

    #[test]
    fn test_priority_ordering() {
        assert!(Priority::Critical > Priority::High);
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);
    }

    #[test]
    fn test_scheduler_creation() {
        let batch_config = BatchConfig::new();
        let engine_config = EngineConfig::new(3, 3);

        let scheduler = BatchScheduler::new(batch_config, engine_config);
        assert!(scheduler.is_ok());
    }

    #[test]
    fn test_scheduler_submit() {
        let batch_config = BatchConfig::new();
        let engine_config = EngineConfig::new(3, 3);
        let mut scheduler = BatchScheduler::new(batch_config, engine_config).unwrap();

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let id = scheduler.submit(input, 5);

        assert_eq!(id, 0);
        assert_eq!(scheduler.stats().pending_requests, 1);
    }

    #[test]
    fn test_scheduler_priority() {
        let batch_config = BatchConfig::new().with_priority();
        let engine_config = EngineConfig::new(3, 3);
        let mut scheduler = BatchScheduler::new(batch_config, engine_config).unwrap();

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Submit with different priorities
        let _id1 = scheduler.submit_with_priority(input.clone(), 5, Priority::Low);
        let _id2 = scheduler.submit_with_priority(input.clone(), 5, Priority::High);
        let _id3 = scheduler.submit_with_priority(input.clone(), 5, Priority::Normal);

        // High priority should be first
        assert_eq!(scheduler.pending[0].priority, Priority::High);
        assert_eq!(scheduler.stats().pending_requests, 3);
    }

    #[test]
    fn test_scheduler_stats() {
        let batch_config = BatchConfig::new();
        let engine_config = EngineConfig::new(3, 3);
        let mut scheduler = BatchScheduler::new(batch_config, engine_config).unwrap();

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        scheduler.submit(input.clone(), 5);
        scheduler.submit(input.clone(), 5);

        let stats = scheduler.stats();
        assert_eq!(stats.pending_requests, 2);
        assert_eq!(stats.total_submitted, 2);
    }
}