ultralytics_inference/batch.rs
1// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
3//! Batch processing module for YOLO inference.
4//!
5//! This module provides the [`BatchProcessor`] struct, which abstracts the logic for
6//! buffering images and running batch inference. It handles:
7//!
8//! - **Buffering**: Collects images until the batch size is reached
9//! - **Batch inference**: Runs inference on the full batch
10//! - **Automatic fallback**: Falls back to single-image inference if batch fails
11//! - **Callback invocation**: Invokes a user-provided callback with results
12//!
13//! # Usage
14//!
15//! ```no_run
16//! use ultralytics_inference::{YOLOModel, batch::BatchProcessor};
17//!
18//! let mut model = YOLOModel::load("yolo26n.onnx")?;
19//! let mut processor = BatchProcessor::new(&mut model, 4, |results, images, paths, metas| {
20//! for (idx, result_vec) in results.iter().enumerate() {
21//! println!("Image {}: {} detections", paths[idx], result_vec.len());
22//! }
23//! });
24//!
25//! // Add images as they become available
26//! // processor.add(image, path, meta);
27//!
28//! // Don't forget to flush remaining images
29//! processor.flush();
30//! # Ok::<(), Box<dyn std::error::Error>>(())
31//! ```
32
33use crate::{Results, YOLOModel, source::SourceMeta};
34use image::DynamicImage;
35
36/// A processor for handling batch inference.
37///
38/// This struct manages collecting images into batches, running inference (with fallback),
39/// and invoking a callback with the results.
40///
41/// # Example
42///
43/// ```no_run
44/// use ultralytics_inference::{YOLOModel, batch::BatchProcessor};
45///
46/// fn main() -> Result<(), Box<dyn std::error::Error>> {
47/// let mut model = YOLOModel::load("yolo26n.onnx")?;
48/// let batch_size = 4;
49///
50/// let mut processor = BatchProcessor::new(&mut model, batch_size, |results, images, paths, metas| {
51/// println!("Processed batch of {} images", results.len());
52/// });
53///
54/// // Add images...
55/// // processor.add(image, path, meta);
56///
57/// processor.flush();
58/// Ok(())
59/// }
60/// ```
61pub struct BatchProcessor<'a, F>
62where
63 F: FnMut(Vec<Vec<Results>>, &[DynamicImage], &[String], &[SourceMeta]),
64{
65 model: &'a mut YOLOModel,
66 batch_size: usize,
67 images: Vec<DynamicImage>,
68 paths: Vec<String>,
69 metas: Vec<SourceMeta>,
70 callback: F,
71}
72
73impl<'a, F> BatchProcessor<'a, F>
74where
75 F: FnMut(Vec<Vec<Results>>, &[DynamicImage], &[String], &[SourceMeta]),
76{
77 /// Create a new `BatchProcessor`.
78 ///
79 /// # Arguments
80 ///
81 /// * `model` - Mutable reference to the [`YOLOModel`] for inference.
82 /// * `batch_size` - Maximum number of images to collect before processing.
83 /// * `callback` - Closure invoked with batch results. Receives:
84 /// - `Vec<Vec<Results>>` - Results for each image in the batch
85 /// - `&[DynamicImage]` - The batch images
86 /// - `&[String]` - Paths for each image
87 /// - `&[SourceMeta]` - Metadata for each image
88 ///
89 /// # Returns
90 ///
91 /// A new `BatchProcessor` instance.
92 pub fn new(model: &'a mut YOLOModel, batch_size: usize, callback: F) -> Self {
93 Self {
94 model,
95 batch_size,
96 images: Vec::with_capacity(batch_size),
97 paths: Vec::with_capacity(batch_size),
98 metas: Vec::with_capacity(batch_size),
99 callback,
100 }
101 }
102
103 /// Add an image to the batch.
104 ///
105 /// If the batch becomes full (reaches `batch_size`), it is automatically processed
106 /// and the callback is invoked.
107 ///
108 /// # Arguments
109 ///
110 /// * `image` - The image to add.
111 /// * `path` - Path or identifier for this image.
112 /// * `meta` - Source metadata for this image.
113 pub fn add(&mut self, image: DynamicImage, path: String, meta: SourceMeta) {
114 self.images.push(image);
115 self.paths.push(path);
116 self.metas.push(meta);
117
118 if self.images.len() >= self.batch_size {
119 self.process();
120 }
121 }
122
123 /// Process any remaining images in the batch.
124 ///
125 /// This should be called after all images have been added to ensure
126 /// the last partial batch is processed. Has no effect if the batch is empty.
127 pub fn flush(&mut self) {
128 self.process();
129 }
130
131 fn process(&mut self) {
132 if self.images.is_empty() {
133 return;
134 }
135
136 let batch_results = self.run_inference();
137 (self.callback)(batch_results, &self.images, &self.paths, &self.metas);
138
139 self.images.clear();
140 self.paths.clear();
141 self.metas.clear();
142 }
143
144 fn run_inference(&mut self) -> Vec<Vec<Results>> {
145 if let Ok(batch_results) = self.model.predict_batch(&self.images, &self.paths) {
146 return batch_results;
147 }
148
149 eprintln!("WARNING ⚠️ Batch inference failed. Falling back to single-image inference...");
150
151 let mut fallback_results = Vec::with_capacity(self.images.len());
152 for (idx, img) in self.images.iter().enumerate() {
153 let path = &self.paths[idx];
154 match self.model.predict_image(img, path.clone()) {
155 Ok(results) => fallback_results.push(results),
156 Err(e) => {
157 eprintln!("Error processing {path}: {e}");
158 fallback_results.push(Vec::new());
159 }
160 }
161 }
162 fallback_results
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Open a sample image from `assets/`, substituting a blank 640x640 RGB image
    /// when neither sample file can be opened.
    fn load_test_image() -> DynamicImage {
        if let Ok(img) = image::open("assets/bus.jpg") {
            return img;
        }
        if let Ok(img) = image::open("assets/zidane.jpg") {
            return img;
        }
        DynamicImage::new_rgb8(640, 640)
    }

    /// Load the model shared by every test in this module.
    fn load_model() -> YOLOModel {
        YOLOModel::load("yolo26n.onnx").expect("Model should load")
    }

    /// Metadata describing a single-frame source, reused across tests.
    fn sample_meta() -> SourceMeta {
        SourceMeta {
            path: "test.jpg".to_string(),
            frame_idx: 0,
            total_frames: Some(1),
            fps: None,
        }
    }

    /// A shared invocation counter plus a second handle to move into the callback.
    fn new_counter() -> (Rc<RefCell<i32>>, Rc<RefCell<i32>>) {
        let counter = Rc::new(RefCell::new(0));
        let handle = Rc::clone(&counter);
        (counter, handle)
    }

    /// Each `add` with `batch_size=1` fires the callback at once, and a trailing
    /// `flush` on the then-empty buffer fires nothing.
    ///
    /// `batch_size=1` is used because the default yolo26n.onnx model only supports batch=1.
    #[test]
    #[serial]
    fn test_batch_processor_with_model() {
        let mut model = load_model();
        let (hits, hits_in_callback) = new_counter();

        let mut processor =
            BatchProcessor::new(&mut model, 1, move |_results, _images, _paths, _metas| {
                *hits_in_callback.borrow_mut() += 1;
            });

        // First image fills the batch immediately.
        processor.add(load_test_image(), "img1.jpg".to_string(), sample_meta());
        assert_eq!(*hits.borrow(), 1);

        // Second image fills a fresh batch.
        processor.add(load_test_image(), "img2.jpg".to_string(), sample_meta());
        assert_eq!(*hits.borrow(), 2);

        // Nothing is buffered any more, so flushing must not invoke the callback.
        processor.flush();
        assert_eq!(*hits.borrow(), 2);
    }

    /// Flushing a processor that never received an image leaves the callback untouched.
    #[test]
    #[serial]
    fn test_batch_processor_empty_flush() {
        let mut model = load_model();
        let (hits, hits_in_callback) = new_counter();

        let mut processor =
            BatchProcessor::new(&mut model, 1, move |_results, _images, _paths, _metas| {
                *hits_in_callback.borrow_mut() += 1;
            });

        processor.flush();
        assert_eq!(*hits.borrow(), 0);
    }

    /// Three images with `batch_size=1` yield exactly three callback invocations.
    #[test]
    #[serial]
    fn test_batch_processor_callback_count() {
        let mut model = load_model();
        let (hits, hits_in_callback) = new_counter();

        // `batch_size=1` works with the default model (which only supports batch=1).
        let mut processor =
            BatchProcessor::new(&mut model, 1, move |_results, _images, _paths, _metas| {
                *hits_in_callback.borrow_mut() += 1;
            });

        let meta = sample_meta();
        for name in (0..3).map(|i| format!("img{i}.jpg")) {
            processor.add(load_test_image(), name, meta.clone());
        }
        processor.flush();

        // One callback per image; the final flush finds an empty buffer.
        assert_eq!(*hits.borrow(), 3);
    }
}