// qubit_batch/process/impls/parallel_batch_processor.rs
/*******************************************************************************
 *
 * Copyright (c) 2025 - 2026 Haixing Hu.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0.
 *
 ******************************************************************************/
10use std::{
11 num::NonZeroUsize,
12 sync::Arc,
13 thread,
14 time::Duration,
15};
16
17use qubit_function::{
18 ArcConsumer,
19 Consumer,
20};
21use qubit_progress::{
22 Progress,
23 reporter::ProgressReporter,
24};
25
26use crate::process::{
27 BatchProcessError,
28 BatchProcessResult,
29 BatchProcessState,
30 BatchProcessor,
31};
32use crate::utils::run_scoped_parallel;
33
34use super::parallel_batch_processor_builder::ParallelBatchProcessorBuilder;
35
/// Processes batch items with sequential fallback and scoped standard threads.
///
/// The processor stores the supplied consumer as an [`ArcConsumer`] so every
/// worker can share it safely. By default, small batches run sequentially to
/// avoid thread setup overhead. Larger batches use scoped worker threads for
/// each [`BatchProcessor::process`] call, therefore input items may borrow data
/// from the caller as long as they are [`Send`]. Running progress is reported
/// between items on the sequential path and from a scoped reporter thread on
/// the parallel path.
///
/// # Type Parameters
///
/// * `Item` - Item type consumed by the stored consumer.
///
/// # Examples
///
/// ```rust
/// use std::{
///     sync::{
///         Arc,
///         atomic::{
///             AtomicUsize,
///             Ordering,
///         },
///     },
/// };
///
/// use qubit_batch::{
///     BatchProcessor,
///     ParallelBatchProcessor,
/// };
///
/// let total = Arc::new(AtomicUsize::new(0));
/// let total_for_consumer = Arc::clone(&total);
/// let mut processor = ParallelBatchProcessor::builder(move |item: &usize| {
///     total_for_consumer.fetch_add(*item, Ordering::Relaxed);
/// })
/// .thread_count(2)
/// .sequential_threshold(0)
/// .build()
/// .expect("parallel processor configuration should be valid");
///
/// let result = processor
///     .process([1, 2, 3])
///     .expect("array length should be exact");
///
/// assert!(result.is_success());
/// assert_eq!(total.load(Ordering::Relaxed), 6);
/// ```
pub struct ParallelBatchProcessor<Item> {
    /// Consumer shared by all scoped workers; arc-backed so clones are cheap.
    pub(crate) consumer: ArcConsumer<Item>,
    /// Fixed worker-thread count used by each processing call; non-zero by
    /// construction so the parallel path always has at least one worker.
    pub(crate) thread_count: NonZeroUsize,
    /// Maximum batch size that still uses sequential processing.
    pub(crate) sequential_threshold: usize,
    /// Minimum interval between running progress callbacks.
    pub(crate) report_interval: Duration,
    /// Reporter receiving batch lifecycle callbacks (started/running/finished/failed).
    pub(crate) reporter: Arc<dyn ProgressReporter>,
}
95
96impl<Item> ParallelBatchProcessor<Item> {
97 /// Default interval between progress callbacks.
98 pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);
99
100 /// Default maximum batch size that still uses sequential processing.
101 pub const DEFAULT_SEQUENTIAL_THRESHOLD: usize = 100;
102
103 /// Creates a parallel consumer-backed batch processor.
104 ///
105 /// # Parameters
106 ///
107 /// * `consumer` - Thread-safe consumer invoked once for each accepted item.
108 ///
109 /// # Returns
110 ///
111 /// A processor storing `consumer` as an [`ArcConsumer`] and using
112 /// [`Self::default_thread_count`] workers.
113 #[inline]
114 pub fn new<C>(consumer: C) -> Self
115 where
116 C: Consumer<Item> + Send + Sync + 'static,
117 {
118 Self::builder(consumer)
119 .build()
120 .expect("default parallel batch processor should build")
121 }
122
123 /// Creates a builder for configuring a parallel consumer-backed processor.
124 ///
125 /// # Parameters
126 ///
127 /// * `consumer` - Thread-safe consumer invoked once for each accepted item.
128 ///
129 /// # Returns
130 ///
131 /// A builder initialized with default settings.
132 #[inline]
133 pub fn builder<C>(consumer: C) -> ParallelBatchProcessorBuilder<Item>
134 where
135 C: Consumer<Item> + Send + Sync + 'static,
136 {
137 ParallelBatchProcessorBuilder::new(consumer)
138 }
139
140 /// Returns the default worker-thread count.
141 ///
142 /// # Returns
143 ///
144 /// The available CPU parallelism, or `1` if it cannot be detected.
145 #[inline]
146 pub fn default_thread_count() -> usize {
147 thread::available_parallelism()
148 .map(usize::from)
149 .unwrap_or(1)
150 }
151
152 /// Returns the configured worker-thread count.
153 ///
154 /// # Returns
155 ///
156 /// The maximum number of scoped worker threads used for one batch.
157 #[inline]
158 pub const fn thread_count(&self) -> usize {
159 self.thread_count.get()
160 }
161
162 /// Returns the configured sequential fallback threshold.
163 ///
164 /// # Returns
165 ///
166 /// The maximum item count that still runs sequentially.
167 #[inline]
168 pub const fn sequential_threshold(&self) -> usize {
169 self.sequential_threshold
170 }
171
172 /// Returns the configured progress-report interval.
173 ///
174 /// # Returns
175 ///
176 /// The minimum time between due-based running progress callbacks.
177 #[inline]
178 pub const fn report_interval(&self) -> Duration {
179 self.report_interval
180 }
181
182 /// Returns the configured progress reporter.
183 ///
184 /// # Returns
185 ///
186 /// A shared reference to the configured progress reporter.
187 #[inline]
188 pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
189 &self.reporter
190 }
191
192 /// Returns the stored consumer.
193 ///
194 /// # Returns
195 ///
196 /// A shared reference to the arc-backed consumer.
197 #[inline]
198 pub const fn consumer(&self) -> &ArcConsumer<Item> {
199 &self.consumer
200 }
201
202 /// Consumes this processor and returns the stored consumer.
203 ///
204 /// # Returns
205 ///
206 /// The arc-backed consumer used by this processor.
207 #[inline]
208 pub fn into_consumer(self) -> ArcConsumer<Item> {
209 self.consumer
210 }
211}
212
impl<Item> BatchProcessor<Item> for ParallelBatchProcessor<Item>
where
    Item: Send,
{
    type Error = BatchProcessError;

    /// Processes items sequentially for small batches or on scoped workers.
    ///
    /// # Parameters
    ///
    /// * `items` - Item source for the batch.
    /// * `count` - Declared number of items expected from `items`.
    ///
    /// # Returns
    ///
    /// A result with completed and processed counts equal to the number of
    /// consumer calls when the input source yields exactly `count` items.
    ///
    /// # Errors
    ///
    /// Returns [`BatchProcessError::CountShortfall`] when the source ends before
    /// `count`, or [`BatchProcessError::CountExceeded`] when the source yields an
    /// extra item. Extra items are observed but not passed to the consumer.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised by the stored consumer from the caller thread
    /// or a worker thread, or by the configured progress reporter.
    fn process_with_count<I>(
        &mut self,
        items: I,
        count: usize,
    ) -> Result<BatchProcessResult, Self::Error>
    where
        I: IntoIterator<Item = Item>,
    {
        // Shared per-batch state; wrapped in an `Arc` so the parallel path can
        // hand clones to scoped worker threads.
        let state = Arc::new(BatchProcessState::new(count));
        let mut progress = Progress::new(self.reporter.as_ref(), self.report_interval);
        progress.report_started(state.progress_counters());

        if count > 0 {
            if count <= self.sequential_threshold {
                // Small batch: stay on the caller thread to skip thread setup.
                self.process_sequential(items, count, state.as_ref(), &mut progress);
            } else {
                self.process_parallel_non_empty(items, count, Arc::clone(&state), &progress);
            }
        } else if items.into_iter().next().is_some() {
            // Declared-empty batch: probe for a single stray item so the
            // excess check below can fail without consuming further items.
            state.record_item_observed();
        }

        // Reconcile the declared count with what the source actually yielded.
        if state.observed_count() < count {
            // Source ended early: report failure and attach the partial result.
            let failed = progress.report_failed(state.progress_counters());
            let result = state.to_direct_result(failed.elapsed());
            Err(BatchProcessError::CountShortfall {
                expected: count,
                actual: state.observed_count(),
                result,
            })
        } else if state.observed_count() > count {
            // Source yielded extra items; processing stopped once the excess
            // was detected, so the observed count is only a lower bound.
            let failed = progress.report_failed(state.progress_counters());
            let result = state.to_direct_result(failed.elapsed());
            Err(BatchProcessError::CountExceeded {
                expected: count,
                observed_at_least: state.observed_count(),
                result,
            })
        } else {
            // Exact match: finish the progress run and return the counters.
            let finished = progress.report_finished(state.progress_counters());
            let result = state.to_direct_result(finished.elapsed());
            Ok(result)
        }
    }
}
286
impl<Item> ParallelBatchProcessor<Item>
where
    Item: Send,
{
    /// Processes a declared batch on the caller thread.
    ///
    /// # Parameters
    ///
    /// * `items` - Item source for the batch.
    /// * `count` - Declared item count.
    /// * `state` - Processing state updated by this method.
    /// * `progress` - Progress run used for between-item running callbacks.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised while invoking the stored consumer.
    fn process_sequential<I>(
        &self,
        items: I,
        count: usize,
        state: &BatchProcessState,
        progress: &mut Progress<'_>,
    ) where
        I: IntoIterator<Item = Item>,
    {
        for item in items {
            // Count the item before using it; once the source exceeds the
            // declared count, stop without passing the extra item on.
            let observed_count = state.record_item_observed();
            if observed_count > count {
                break;
            }
            state.record_item_started();
            self.consumer.accept(&item);
            state.record_item_processed();
            // Between-item running report, throttled by the report interval;
            // the returned snapshot is not needed here.
            let _ = progress.report_running_if_due(state.progress_counters());
        }
    }

    /// Processes a non-empty declared batch through scoped workers.
    ///
    /// # Parameters
    ///
    /// * `items` - Item source for the batch.
    /// * `count` - Declared item count.
    /// * `state` - Shared processing state updated by producer and workers.
    /// * `progress` - Progress run used to spawn the running reporter.
    ///
    /// # Panics
    ///
    /// Propagates any worker panic raised while invoking the stored consumer.
    fn process_parallel_non_empty<I>(
        &self,
        items: I,
        count: usize,
        state: Arc<BatchProcessState>,
        progress: &Progress<'_>,
    ) where
        I: IntoIterator<Item = Item>,
    {
        // A scope lets workers borrow non-'static data (e.g. `self.consumer`
        // clones and `state`) and joins every spawned thread on exit.
        thread::scope(|scope| {
            // Dedicated scoped thread emitting running-progress callbacks
            // while the workers consume items.
            let reporter_state = Arc::clone(&state);
            let running_progress =
                progress.spawn_running_reporter(scope, move || reporter_state.progress_counters());
            let running_point_handle = running_progress.point_handle();

            // Never spawn more workers than there are declared items.
            let worker_count = self.thread_count.get().min(count);
            let observer_state = Arc::clone(&state);
            let worker_state = Arc::clone(&state);
            let consumer = self.consumer.clone();
            run_scoped_parallel(
                items,
                count,
                worker_count,
                // Producer-side hook counting every item pulled from `items`.
                move || observer_state.record_item_observed(),
                // Worker body: process one item, then nudge the reporter so a
                // due progress callback can fire promptly.
                move |_index, item| {
                    worker_state.record_item_started();
                    consumer.accept(&item);
                    worker_state.record_item_processed();
                    running_point_handle.report();
                },
            );
            // Stop the reporter thread and wait for it before the scope ends.
            running_progress.stop_and_join();
        });
    }
}