use crate::{Results, YOLOModel, source::SourceMeta};
use image::DynamicImage;
/// Buffers images together with their paths and source metadata, and runs them
/// through a [`YOLOModel`] once `batch_size` items have been collected.
///
/// After each inference pass the callback `F` is invoked with the per-image
/// results and the buffered images, paths and metadata of that batch.
///
/// The `FnMut` bound on `F` is intentionally not repeated on the struct: it is
/// enforced by the `impl` block that constructs and drives the processor, so
/// code that merely names the type does not have to restate it.
pub struct BatchProcessor<'a, F> {
    /// Model used for inference; mutably borrowed for the processor's lifetime.
    model: &'a mut YOLOModel,
    /// Number of buffered images that triggers an inference pass.
    batch_size: usize,
    /// Images waiting to be processed.
    images: Vec<DynamicImage>,
    /// Path pushed alongside each entry of `images` (same index).
    paths: Vec<String>,
    /// Source metadata pushed alongside each entry of `images` (same index).
    metas: Vec<SourceMeta>,
    /// Invoked once per processed batch with the results and the batch contents.
    callback: F,
}
impl<'a, F> BatchProcessor<'a, F>
where
    F: FnMut(Vec<Vec<Results>>, &[DynamicImage], &[String], &[SourceMeta]),
{
    /// Creates an empty processor that runs `model` on batches of `batch_size` images.
    ///
    /// `callback` is invoked once per processed batch with the inference results
    /// followed by the images, paths and metadata that made up the batch.
    ///
    /// A `batch_size` of `0` or `1` causes every added image to be processed
    /// immediately, because the buffer length already meets the threshold after
    /// a single [`add`](Self::add).
    pub fn new(model: &'a mut YOLOModel, batch_size: usize, callback: F) -> Self {
        Self {
            model,
            batch_size,
            images: Vec::with_capacity(batch_size),
            paths: Vec::with_capacity(batch_size),
            metas: Vec::with_capacity(batch_size),
            callback,
        }
    }

    /// Buffers one image with its path and metadata, processing the batch as
    /// soon as the buffer holds at least `batch_size` images.
    pub fn add(&mut self, image: DynamicImage, path: String, meta: SourceMeta) {
        self.images.push(image);
        self.paths.push(path);
        self.metas.push(meta);
        if self.images.len() >= self.batch_size {
            self.process();
        }
    }

    /// Processes whatever is currently buffered, even if the batch is not full.
    ///
    /// Does nothing (and does not invoke the callback) when the buffer is empty.
    pub fn flush(&mut self) {
        self.process();
    }

    /// Runs inference on the buffered images, hands the outcome to the callback
    /// and then empties the buffers so the next batch starts fresh.
    fn process(&mut self) {
        if self.images.is_empty() {
            return;
        }
        let batch_results = self.run_inference();
        (self.callback)(batch_results, &self.images, &self.paths, &self.metas);
        self.images.clear();
        self.paths.clear();
        self.metas.clear();
    }

    /// Returns one result list per buffered image.
    ///
    /// Batch inference is attempted first. If it fails, the cause is reported on
    /// stderr and every image is retried individually; an image that still fails
    /// is reported on stderr and contributes an empty result list, so the output
    /// keeps one entry per buffered image.
    fn run_inference(&mut self) -> Vec<Vec<Results>> {
        let batch_error = match self.model.predict_batch(&self.images, &self.paths) {
            Ok(batch_results) => return batch_results,
            Err(e) => e,
        };
        eprintln!("WARNING ⚠️ Batch inference failed. Falling back to single-image inference...");
        // Surface the cause instead of discarding it, so a persistent batch
        // failure can be diagnosed from the logs.
        eprintln!("Batch inference error: {batch_error}");

        let mut fallback_results = Vec::with_capacity(self.images.len());
        // `images` and `paths` are always pushed together in `add`, so zipping
        // pairs each image with its own path without indexing.
        for (img, path) in self.images.iter().zip(&self.paths) {
            match self.model.predict_image(img, path.clone()) {
                Ok(results) => fallback_results.push(results),
                Err(e) => {
                    eprintln!("Error processing {path}: {e}");
                    fallback_results.push(Vec::new());
                }
            }
        }
        fallback_results
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Opens the first sample image that can be read, or produces a blank
    /// 640x640 RGB image when neither candidate loads.
    fn load_test_image() -> DynamicImage {
        ["assets/bus.jpg", "assets/zidane.jpg"]
            .iter()
            .find_map(|&candidate| image::open(candidate).ok())
            .unwrap_or_else(|| DynamicImage::new_rgb8(640, 640))
    }

    /// Returns a zeroed shared counter as `(observer, handle_for_the_callback)`.
    fn shared_counter() -> (Rc<RefCell<i32>>, Rc<RefCell<i32>>) {
        let observer = Rc::new(RefCell::new(0));
        let handle = Rc::clone(&observer);
        (observer, handle)
    }

    /// Builds the single-frame metadata record used by the tests.
    fn sample_meta() -> SourceMeta {
        SourceMeta {
            path: "test.jpg".to_string(),
            frame_idx: 0,
            total_frames: Some(1),
            fps: None,
        }
    }

    /// With a batch size of 1 every `add` fires the callback, and a trailing
    /// `flush` on the then-empty buffer fires nothing.
    #[test]
    #[serial]
    fn test_batch_processor_with_model() {
        let mut model = YOLOModel::load("yolo26n.onnx").expect("Model should load");
        let (hits, hits_in_callback) = shared_counter();
        let mut processor =
            BatchProcessor::new(&mut model, 1, move |_results, _images, _paths, _metas| {
                *hits_in_callback.borrow_mut() += 1;
            });

        let first = load_test_image();
        let second = load_test_image();

        processor.add(first, "img1.jpg".to_string(), sample_meta());
        assert_eq!(*hits.borrow(), 1);

        processor.add(second, "img2.jpg".to_string(), sample_meta());
        assert_eq!(*hits.borrow(), 2);

        processor.flush();
        assert_eq!(*hits.borrow(), 2);
    }

    /// Flushing a processor that never received an image must not call back.
    #[test]
    #[serial]
    fn test_batch_processor_empty_flush() {
        let mut model = YOLOModel::load("yolo26n.onnx").expect("Model should load");
        let (hits, hits_in_callback) = shared_counter();
        let mut processor =
            BatchProcessor::new(&mut model, 1, move |_results, _images, _paths, _metas| {
                *hits_in_callback.borrow_mut() += 1;
            });

        processor.flush();

        assert_eq!(*hits.borrow(), 0);
    }

    /// Three images at batch size 1 yield exactly three callbacks in total.
    #[test]
    #[serial]
    fn test_batch_processor_callback_count() {
        let mut model = YOLOModel::load("yolo26n.onnx").expect("Model should load");
        let (hits, hits_in_callback) = shared_counter();
        let mut processor =
            BatchProcessor::new(&mut model, 1, move |_results, _images, _paths, _metas| {
                *hits_in_callback.borrow_mut() += 1;
            });

        for name in (0..3).map(|i| format!("img{i}.jpg")) {
            processor.add(load_test_image(), name, sample_meta());
        }
        processor.flush();

        assert_eq!(*hits.borrow(), 3);
    }
}