// qubit_batch/process/impls/parallel_batch_processor.rs

/*******************************************************************************
 *
 *    Copyright (c) 2025 - 2026 Haixing Hu.
 *
 *    SPDX-License-Identifier: Apache-2.0
 *
 *    Licensed under the Apache License, Version 2.0.
 *
 ******************************************************************************/
10use std::{
11    num::NonZeroUsize,
12    sync::Arc,
13    thread,
14    time::Duration,
15};
16
17use qubit_function::{
18    ArcConsumer,
19    Consumer,
20};
21use qubit_progress::{
22    Progress,
23    reporter::ProgressReporter,
24};
25
26use crate::process::{
27    BatchProcessError,
28    BatchProcessResult,
29    BatchProcessState,
30    BatchProcessor,
31};
32use crate::utils::run_scoped_parallel;
33
34use super::parallel_batch_processor_builder::ParallelBatchProcessorBuilder;
35
/// Processes batch items with sequential fallback and scoped standard threads.
///
/// The processor stores the supplied consumer as an [`ArcConsumer`] so every
/// worker can share it safely. By default, small batches run sequentially to
/// avoid thread setup overhead. Larger batches use scoped worker threads for
/// each [`BatchProcessor::process`] call, therefore input items may borrow data
/// from the caller as long as they are [`Send`]. Running progress is reported
/// between items on the sequential path and from a scoped reporter thread on
/// the parallel path.
///
/// # Type Parameters
///
/// * `Item` - Item type consumed by the stored consumer.
///
/// ```rust
/// use std::{
///     sync::{
///         Arc,
///         atomic::{
///             AtomicUsize,
///             Ordering,
///         },
///     },
/// };
///
/// use qubit_batch::{
///     BatchProcessor,
///     ParallelBatchProcessor,
/// };
///
/// let total = Arc::new(AtomicUsize::new(0));
/// let total_for_consumer = Arc::clone(&total);
/// let mut processor = ParallelBatchProcessor::builder(move |item: &usize| {
///     total_for_consumer.fetch_add(*item, Ordering::Relaxed);
/// })
/// .thread_count(2)
/// .sequential_threshold(0)
/// .build()
/// .expect("parallel processor configuration should be valid");
///
/// let result = processor
///     .process([1, 2, 3])
///     .expect("array length should be exact");
///
/// assert!(result.is_success());
/// assert_eq!(total.load(Ordering::Relaxed), 6);
/// ```
pub struct ParallelBatchProcessor<Item> {
    /// Consumer shared by all scoped workers; cloned cheaply per batch.
    pub(crate) consumer: ArcConsumer<Item>,
    /// Fixed worker-thread count used by each processing call; capped at the
    /// batch size on the parallel path.
    pub(crate) thread_count: NonZeroUsize,
    /// Maximum batch size that still uses sequential processing.
    pub(crate) sequential_threshold: usize,
    /// Minimum interval between due-based running progress callbacks.
    pub(crate) report_interval: Duration,
    /// Reporter receiving batch lifecycle callbacks (started/running/finished/failed).
    pub(crate) reporter: Arc<dyn ProgressReporter>,
}
95
96impl<Item> ParallelBatchProcessor<Item> {
97    /// Default interval between progress callbacks.
98    pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);
99
100    /// Default maximum batch size that still uses sequential processing.
101    pub const DEFAULT_SEQUENTIAL_THRESHOLD: usize = 100;
102
103    /// Creates a parallel consumer-backed batch processor.
104    ///
105    /// # Parameters
106    ///
107    /// * `consumer` - Thread-safe consumer invoked once for each accepted item.
108    ///
109    /// # Returns
110    ///
111    /// A processor storing `consumer` as an [`ArcConsumer`] and using
112    /// [`Self::default_thread_count`] workers.
113    #[inline]
114    pub fn new<C>(consumer: C) -> Self
115    where
116        C: Consumer<Item> + Send + Sync + 'static,
117    {
118        Self::builder(consumer)
119            .build()
120            .expect("default parallel batch processor should build")
121    }
122
123    /// Creates a builder for configuring a parallel consumer-backed processor.
124    ///
125    /// # Parameters
126    ///
127    /// * `consumer` - Thread-safe consumer invoked once for each accepted item.
128    ///
129    /// # Returns
130    ///
131    /// A builder initialized with default settings.
132    #[inline]
133    pub fn builder<C>(consumer: C) -> ParallelBatchProcessorBuilder<Item>
134    where
135        C: Consumer<Item> + Send + Sync + 'static,
136    {
137        ParallelBatchProcessorBuilder::new(consumer)
138    }
139
140    /// Returns the default worker-thread count.
141    ///
142    /// # Returns
143    ///
144    /// The available CPU parallelism, or `1` if it cannot be detected.
145    #[inline]
146    pub fn default_thread_count() -> usize {
147        thread::available_parallelism()
148            .map(usize::from)
149            .unwrap_or(1)
150    }
151
152    /// Returns the configured worker-thread count.
153    ///
154    /// # Returns
155    ///
156    /// The maximum number of scoped worker threads used for one batch.
157    #[inline]
158    pub const fn thread_count(&self) -> usize {
159        self.thread_count.get()
160    }
161
162    /// Returns the configured sequential fallback threshold.
163    ///
164    /// # Returns
165    ///
166    /// The maximum item count that still runs sequentially.
167    #[inline]
168    pub const fn sequential_threshold(&self) -> usize {
169        self.sequential_threshold
170    }
171
172    /// Returns the configured progress-report interval.
173    ///
174    /// # Returns
175    ///
176    /// The minimum time between due-based running progress callbacks.
177    #[inline]
178    pub const fn report_interval(&self) -> Duration {
179        self.report_interval
180    }
181
182    /// Returns the configured progress reporter.
183    ///
184    /// # Returns
185    ///
186    /// A shared reference to the configured progress reporter.
187    #[inline]
188    pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
189        &self.reporter
190    }
191
192    /// Returns the stored consumer.
193    ///
194    /// # Returns
195    ///
196    /// A shared reference to the arc-backed consumer.
197    #[inline]
198    pub const fn consumer(&self) -> &ArcConsumer<Item> {
199        &self.consumer
200    }
201
202    /// Consumes this processor and returns the stored consumer.
203    ///
204    /// # Returns
205    ///
206    /// The arc-backed consumer used by this processor.
207    #[inline]
208    pub fn into_consumer(self) -> ArcConsumer<Item> {
209        self.consumer
210    }
211}
212
213impl<Item> BatchProcessor<Item> for ParallelBatchProcessor<Item>
214where
215    Item: Send,
216{
217    type Error = BatchProcessError;
218
219    /// Processes items sequentially for small batches or on scoped workers.
220    ///
221    /// # Parameters
222    ///
223    /// * `items` - Item source for the batch.
224    /// * `count` - Declared number of items expected from `items`.
225    ///
226    /// # Returns
227    ///
228    /// A result with completed and processed counts equal to the number of
229    /// consumer calls when the input source yields exactly `count` items.
230    ///
231    /// # Errors
232    ///
233    /// Returns [`BatchProcessError::CountShortfall`] when the source ends before
234    /// `count`, or [`BatchProcessError::CountExceeded`] when the source yields an
235    /// extra item. Extra items are observed but not passed to the consumer.
236    ///
237    /// # Panics
238    ///
239    /// Propagates any panic raised by the stored consumer from the caller thread
240    /// or a worker thread, or by the configured progress reporter.
241    fn process_with_count<I>(
242        &mut self,
243        items: I,
244        count: usize,
245    ) -> Result<BatchProcessResult, Self::Error>
246    where
247        I: IntoIterator<Item = Item>,
248    {
249        let state = Arc::new(BatchProcessState::new(count));
250        let mut progress = Progress::new(self.reporter.as_ref(), self.report_interval);
251        progress.report_started(state.progress_counters());
252
253        if count > 0 {
254            if count <= self.sequential_threshold {
255                self.process_sequential(items, count, state.as_ref(), &mut progress);
256            } else {
257                self.process_parallel_non_empty(items, count, Arc::clone(&state), &progress);
258            }
259        } else if items.into_iter().next().is_some() {
260            state.record_item_observed();
261        }
262
263        if state.observed_count() < count {
264            let failed = progress.report_failed(state.progress_counters());
265            let result = state.to_direct_result(failed.elapsed());
266            Err(BatchProcessError::CountShortfall {
267                expected: count,
268                actual: state.observed_count(),
269                result,
270            })
271        } else if state.observed_count() > count {
272            let failed = progress.report_failed(state.progress_counters());
273            let result = state.to_direct_result(failed.elapsed());
274            Err(BatchProcessError::CountExceeded {
275                expected: count,
276                observed_at_least: state.observed_count(),
277                result,
278            })
279        } else {
280            let finished = progress.report_finished(state.progress_counters());
281            let result = state.to_direct_result(finished.elapsed());
282            Ok(result)
283        }
284    }
285}
286
287impl<Item> ParallelBatchProcessor<Item>
288where
289    Item: Send,
290{
291    /// Processes a declared batch on the caller thread.
292    ///
293    /// # Parameters
294    ///
295    /// * `items` - Item source for the batch.
296    /// * `count` - Declared item count.
297    /// * `state` - Processing state updated by this method.
298    /// * `progress` - Progress run used for between-item running callbacks.
299    ///
300    /// # Panics
301    ///
302    /// Propagates any panic raised while invoking the stored consumer.
303    fn process_sequential<I>(
304        &self,
305        items: I,
306        count: usize,
307        state: &BatchProcessState,
308        progress: &mut Progress<'_>,
309    ) where
310        I: IntoIterator<Item = Item>,
311    {
312        for item in items {
313            let observed_count = state.record_item_observed();
314            if observed_count > count {
315                break;
316            }
317            state.record_item_started();
318            self.consumer.accept(&item);
319            state.record_item_processed();
320            let _ = progress.report_running_if_due(state.progress_counters());
321        }
322    }
323
324    /// Processes a non-empty declared batch through scoped workers.
325    ///
326    /// # Parameters
327    ///
328    /// * `items` - Item source for the batch.
329    /// * `count` - Declared item count.
330    /// * `state` - Shared processing state updated by producer and workers.
331    /// * `progress` - Progress run used to spawn the running reporter.
332    ///
333    /// # Panics
334    ///
335    /// Propagates any worker panic raised while invoking the stored consumer.
336    fn process_parallel_non_empty<I>(
337        &self,
338        items: I,
339        count: usize,
340        state: Arc<BatchProcessState>,
341        progress: &Progress<'_>,
342    ) where
343        I: IntoIterator<Item = Item>,
344    {
345        thread::scope(|scope| {
346            let reporter_state = Arc::clone(&state);
347            let running_progress =
348                progress.spawn_running_reporter(scope, move || reporter_state.progress_counters());
349            let running_point_handle = running_progress.point_handle();
350
351            let worker_count = self.thread_count.get().min(count);
352            let observer_state = Arc::clone(&state);
353            let worker_state = Arc::clone(&state);
354            let consumer = self.consumer.clone();
355            run_scoped_parallel(
356                items,
357                count,
358                worker_count,
359                move || observer_state.record_item_observed(),
360                move |_index, item| {
361                    worker_state.record_item_started();
362                    consumer.accept(&item);
363                    worker_state.record_item_processed();
364                    running_point_handle.report();
365                },
366            );
367            running_progress.stop_and_join();
368        });
369    }
370}