1use crate::engine::{EngineConfig, InferenceEngine};
20use crate::error::InferenceResult;
21use scirs2_core::ndarray::Array1;
22use std::collections::VecDeque;
23use std::time::{Duration, Instant};
24
/// Configuration for the batching scheduler.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BatchConfig {
    /// Maximum number of requests processed together in one batch.
    pub max_batch_size: usize,
    /// Maximum time (milliseconds) to wait before dispatching a partial batch.
    pub max_wait_ms: u64,
    /// Minimum number of pending requests required to form a batch before
    /// the wait deadline elapses.
    pub min_batch_size: usize,
    /// When true, higher-priority requests are dequeued ahead of lower ones.
    pub enable_priority: bool,
    /// Maximum sequence length accepted per request.
    pub max_seq_len: usize,
}
39
/// Defaults: batches of up to 32 requests, dispatched after at most 10 ms
/// of waiting, FIFO (non-priority) ordering, 2048-token sequence cap.
impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            // Dispatch a partial batch after 10 ms rather than wait forever.
            max_wait_ms: 10,
            // min of 1 means any pending request immediately qualifies.
            min_batch_size: 1,
            enable_priority: false,
            max_seq_len: 2048,
        }
    }
}
51
52impl BatchConfig {
53 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn max_batch_size(mut self, size: usize) -> Self {
60 self.max_batch_size = size;
61 self
62 }
63
64 pub fn max_wait_ms(mut self, ms: u64) -> Self {
66 self.max_wait_ms = ms;
67 self
68 }
69
70 pub fn min_batch_size(mut self, size: usize) -> Self {
72 self.min_batch_size = size;
73 self
74 }
75
76 pub fn with_priority(mut self) -> Self {
78 self.enable_priority = true;
79 self
80 }
81}
82
/// Request scheduling priority.
///
/// Ordering derives from the discriminant values, so
/// `Critical > High > Normal > Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
91
/// A single inference request tracked by the scheduler.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Scheduler-assigned unique id (monotonically increasing).
    pub id: u64,
    /// Current input vector; overwritten with each step's output while the
    /// request is active (autoregressive feedback).
    pub input: Array1<f32>,
    /// Total number of engine steps to run before the request completes.
    pub max_steps: usize,
    /// Scheduling priority (only honored when `enable_priority` is set).
    pub priority: Priority,
    /// Timestamp captured at construction; used for wait-time accounting.
    pub received_at: Instant,
    /// Number of engine steps executed so far.
    pub current_step: usize,
}
108
109impl BatchRequest {
110 pub fn new(id: u64, input: Array1<f32>, max_steps: usize) -> Self {
112 Self {
113 id,
114 input,
115 max_steps,
116 priority: Priority::Normal,
117 received_at: Instant::now(),
118 current_step: 0,
119 }
120 }
121
122 pub fn with_priority(mut self, priority: Priority) -> Self {
124 self.priority = priority;
125 self
126 }
127
128 pub fn is_complete(&self) -> bool {
130 self.current_step >= self.max_steps
131 }
132
133 pub fn wait_time_ms(&self) -> u64 {
135 self.received_at.elapsed().as_millis() as u64
136 }
137}
138
/// Result returned for a request once it finishes.
#[derive(Debug, Clone)]
pub struct BatchResponse {
    /// Id of the originating [`BatchRequest`].
    pub request_id: u64,
    /// Output vectors produced for the request.
    pub outputs: Vec<Array1<f32>>,
    /// Number of engine steps the request actually ran.
    pub steps_completed: usize,
    /// Whether the request ran to its full step budget.
    pub is_complete: bool,
    /// Inference wall time in microseconds, measured from the start of the
    /// scheduler step in which the request completed.
    pub inference_time_us: u64,
}
153
/// Batches incoming inference requests and drives them through an
/// [`InferenceEngine`], one step per request per call to `step()`.
pub struct BatchScheduler {
    /// Batching policy (sizes, wait deadline, priority flag).
    config: BatchConfig,
    /// Underlying engine that performs the per-request inference steps.
    engine: InferenceEngine,
    /// Requests waiting to be admitted into the active batch (front = next).
    pending: VecDeque<BatchRequest>,
    /// Requests currently being stepped.
    active: Vec<BatchRequest>,
    /// Responses of finished requests, retained until `reset()`.
    completed: Vec<BatchResponse>,
    /// Next request id to hand out; also the count of submissions so far.
    next_id: u64,
    /// When the last batch was formed; drives the max-wait deadline.
    last_batch_time: Instant,
}
169
170impl BatchScheduler {
171 pub fn new(config: BatchConfig, engine_config: EngineConfig) -> InferenceResult<Self> {
173 let engine = InferenceEngine::new(engine_config);
174
175 Ok(Self {
176 config,
177 engine,
178 pending: VecDeque::new(),
179 active: Vec::new(),
180 completed: Vec::new(),
181 next_id: 0,
182 last_batch_time: Instant::now(),
183 })
184 }
185
186 pub fn submit(&mut self, input: Array1<f32>, max_steps: usize) -> u64 {
188 let id = self.next_id;
189 self.next_id += 1;
190
191 let request = BatchRequest::new(id, input, max_steps);
192 self.pending.push_back(request);
193
194 id
195 }
196
197 pub fn submit_with_priority(
199 &mut self,
200 input: Array1<f32>,
201 max_steps: usize,
202 priority: Priority,
203 ) -> u64 {
204 let id = self.next_id;
205 self.next_id += 1;
206
207 let request = BatchRequest::new(id, input, max_steps).with_priority(priority);
208
209 if self.config.enable_priority {
211 let insert_pos = self
212 .pending
213 .iter()
214 .position(|r| r.priority < priority)
215 .unwrap_or(self.pending.len());
216 self.pending.insert(insert_pos, request);
217 } else {
218 self.pending.push_back(request);
219 }
220
221 id
222 }
223
224 fn should_form_batch(&self) -> bool {
226 if self.pending.is_empty() {
227 return false;
228 }
229
230 if self.pending.len() >= self.config.min_batch_size {
232 return true;
233 }
234
235 let wait_time = self.last_batch_time.elapsed();
237 wait_time >= Duration::from_millis(self.config.max_wait_ms)
238 }
239
240 fn form_batch(&mut self) {
242 let batch_size = self
243 .config
244 .max_batch_size
245 .min(self.pending.len())
246 .min(self.config.max_batch_size - self.active.len());
247
248 for _ in 0..batch_size {
249 if let Some(request) = self.pending.pop_front() {
250 self.active.push(request);
251 }
252 }
253
254 self.last_batch_time = Instant::now();
255 }
256
257 pub fn step(&mut self) -> InferenceResult<Vec<BatchResponse>> {
259 if self.should_form_batch() {
261 self.form_batch();
262 }
263
264 if self.active.is_empty() {
265 return Ok(Vec::new());
266 }
267
268 let start = Instant::now();
269 let mut responses = Vec::new();
270
271 let mut i = 0;
273 while i < self.active.len() {
274 let request = &mut self.active[i];
275
276 let output = self.engine.step(&request.input)?;
278
279 request.current_step += 1;
280 request.input = output.clone(); if request.is_complete() {
284 let completed_request = self.active.remove(i);
285 let inference_time = start.elapsed().as_micros() as u64;
286
287 responses.push(BatchResponse {
288 request_id: completed_request.id,
289 outputs: vec![output],
290 steps_completed: completed_request.current_step,
291 is_complete: true,
292 inference_time_us: inference_time,
293 });
294 } else {
295 i += 1;
296 }
297 }
298
299 Ok(responses)
300 }
301
302 pub fn process_all(&mut self) -> InferenceResult<Vec<BatchResponse>> {
304 let mut all_responses = Vec::new();
305
306 while !self.pending.is_empty() || !self.active.is_empty() {
307 let responses = self.step()?;
308 all_responses.extend(responses);
309 }
310
311 Ok(all_responses)
312 }
313
314 pub fn stats(&self) -> SchedulerStats {
316 SchedulerStats {
317 pending_requests: self.pending.len(),
318 active_requests: self.active.len(),
319 completed_requests: self.completed.len(),
320 total_submitted: self.next_id,
321 }
322 }
323
324 pub fn reset(&mut self) {
326 self.pending.clear();
327 self.active.clear();
328 self.completed.clear();
329 self.engine.reset();
330 }
331}
332
/// Point-in-time snapshot of scheduler queue state.
#[derive(Debug, Clone)]
pub struct SchedulerStats {
    /// Requests waiting to be admitted into a batch.
    pub pending_requests: usize,
    /// Requests currently being stepped.
    pub active_requests: usize,
    /// Requests recorded as finished since the last `reset()`.
    pub completed_requests: usize,
    /// Total number of requests ever submitted (equals the next id).
    pub total_submitted: u64,
}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_config() {
        let cfg = BatchConfig::new()
            .max_batch_size(16)
            .max_wait_ms(5)
            .min_batch_size(4)
            .with_priority();

        assert_eq!(cfg.max_batch_size, 16);
        assert_eq!(cfg.max_wait_ms, 5);
        assert_eq!(cfg.min_batch_size, 4);
        assert!(cfg.enable_priority);
    }

    #[test]
    fn test_batch_request() {
        let req = BatchRequest::new(1, Array1::from_vec(vec![1.0, 2.0, 3.0]), 10);

        assert_eq!(req.id, 1);
        assert_eq!(req.max_steps, 10);
        assert_eq!(req.current_step, 0);
        assert!(!req.is_complete());
    }

    #[test]
    fn test_priority_ordering() {
        // Each priority must strictly dominate the one below it.
        let ascending = [
            Priority::Low,
            Priority::Normal,
            Priority::High,
            Priority::Critical,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    #[test]
    fn test_scheduler_creation() {
        let result = BatchScheduler::new(BatchConfig::new(), EngineConfig::new(3, 3));
        assert!(result.is_ok());
    }

    #[test]
    fn test_scheduler_submit() {
        let mut sched = BatchScheduler::new(BatchConfig::new(), EngineConfig::new(3, 3)).unwrap();

        let first_id = sched.submit(Array1::from_vec(vec![1.0, 2.0, 3.0]), 5);

        assert_eq!(first_id, 0);
        assert_eq!(sched.stats().pending_requests, 1);
    }

    #[test]
    fn test_scheduler_priority() {
        let mut sched =
            BatchScheduler::new(BatchConfig::new().with_priority(), EngineConfig::new(3, 3))
                .unwrap();

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let _low = sched.submit_with_priority(input.clone(), 5, Priority::Low);
        let _high = sched.submit_with_priority(input.clone(), 5, Priority::High);
        let _normal = sched.submit_with_priority(input.clone(), 5, Priority::Normal);

        // The High submission must jump ahead of the earlier Low one.
        assert_eq!(sched.pending[0].priority, Priority::High);
        assert_eq!(sched.stats().pending_requests, 3);
    }

    #[test]
    fn test_scheduler_stats() {
        let mut sched = BatchScheduler::new(BatchConfig::new(), EngineConfig::new(3, 3)).unwrap();

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        sched.submit(input.clone(), 5);
        sched.submit(input.clone(), 5);

        let snapshot = sched.stats();
        assert_eq!(snapshot.pending_requests, 2);
        assert_eq!(snapshot.total_submitted, 2);
    }
}