qubit_batch/execute/impls/parallel_batch_executor.rs
1/*******************************************************************************
2 *
3 * Copyright (c) 2025 - 2026 Haixing Hu.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 *
7 * Licensed under the Apache License, Version 2.0.
8 *
9 ******************************************************************************/
10use std::sync::Arc;
11use std::thread;
12use std::time::Duration;
13
14use qubit_function::Runnable;
15use qubit_progress::{
16 Progress,
17 reporter::ProgressReporter,
18};
19
20use crate::BatchExecutionError;
21use crate::BatchOutcome;
22use crate::execute::{
23 BatchExecutionState,
24 BatchExecutor,
25 SequentialBatchExecutor,
26};
27use crate::utils::run_scoped_parallel;
28
29use super::ParallelBatchExecutorBuildError;
30use super::ParallelBatchExecutorBuilder;
31use super::indexed_task::run_parallel_task;
32
33/// Fixed-width parallel batch executor backed by scoped standard threads.
34///
35/// The executor creates scoped worker threads for each parallel batch run and
36/// shuts them down before [`BatchExecutor::execute`] returns. Because the
37/// workers are scoped to the call, tasks may borrow data from the caller and do
38/// not need to be `'static`.
39///
40/// [`Default`] uses [`Self::DEFAULT_SEQUENTIAL_THRESHOLD`], so batches with at
41/// most 100 declared tasks run through [`SequentialBatchExecutor`] to avoid
42/// thread setup overhead. Configure `sequential_threshold(0)` through
43/// [`Self::builder`] when every non-empty batch should use parallel workers.
44///
45/// ```rust
46/// use qubit_batch::{
47/// BatchExecutor,
48/// ParallelBatchExecutor,
49/// };
50///
51/// let executor = ParallelBatchExecutor::builder()
52/// .thread_count(2)
53/// .sequential_threshold(0)
54/// .build()
55/// .expect("parallel executor configuration should be valid");
56///
57/// let outcome = executor
58/// .for_each(0..4, |value| {
59/// assert!(value < 4);
60/// Ok::<(), &'static str>(())
61/// })
62/// .expect("range length should be exact");
63///
64/// assert!(outcome.is_success());
65/// ```
66#[derive(Clone)]
67pub struct ParallelBatchExecutor {
68 /// Number of worker threads used for parallel executions.
69 pub(crate) thread_count: usize,
70 /// Maximum batch size that still uses sequential execution.
71 pub(crate) sequential_threshold: usize,
72 /// Minimum interval between progress callbacks.
73 pub(crate) report_interval: Duration,
74 /// Reporter receiving batch lifecycle callbacks.
75 pub(crate) reporter: Arc<dyn ProgressReporter>,
76}
77
78impl ParallelBatchExecutor {
79 /// Default interval between progress callbacks.
80 pub const DEFAULT_REPORT_INTERVAL: Duration = Duration::from_secs(5);
81
82 /// Default maximum batch size that still uses sequential execution.
83 pub const DEFAULT_SEQUENTIAL_THRESHOLD: usize = 100;
84
85 /// Returns the default worker-thread count.
86 ///
87 /// # Returns
88 ///
89 /// The available CPU parallelism, or `1` if it cannot be detected.
90 #[inline]
91 pub fn default_thread_count() -> usize {
92 thread::available_parallelism()
93 .map(usize::from)
94 .unwrap_or(1)
95 }
96
97 /// Creates a builder for configuring a parallel batch executor.
98 ///
99 /// # Returns
100 ///
101 /// A builder initialized with default settings.
102 #[inline]
103 pub fn builder() -> ParallelBatchExecutorBuilder {
104 ParallelBatchExecutorBuilder::default()
105 }
106
107 /// Creates a parallel batch executor with `thread_count` workers.
108 ///
109 /// # Parameters
110 ///
111 /// * `thread_count` - Number of scoped worker threads to use.
112 ///
113 /// # Returns
114 ///
115 /// A configured parallel batch executor.
116 ///
117 /// # Errors
118 ///
119 /// Returns [`ParallelBatchExecutorBuildError::ZeroThreadCount`] when
120 /// `thread_count` is zero.
121 #[inline]
122 pub fn new(thread_count: usize) -> Result<Self, ParallelBatchExecutorBuildError> {
123 Self::builder().thread_count(thread_count).build()
124 }
125
126 /// Returns the configured worker-thread count.
127 ///
128 /// # Returns
129 ///
130 /// The maximum number of scoped worker threads used for one batch.
131 #[inline]
132 pub const fn thread_count(&self) -> usize {
133 self.thread_count
134 }
135
136 /// Returns the configured sequential fallback threshold.
137 ///
138 /// # Returns
139 ///
140 /// The maximum task count that still runs sequentially.
141 #[inline]
142 pub const fn sequential_threshold(&self) -> usize {
143 self.sequential_threshold
144 }
145
146 /// Returns the configured progress-report interval.
147 ///
148 /// # Returns
149 ///
150 /// The minimum interval between due-based running progress callbacks.
151 #[inline]
152 pub const fn report_interval(&self) -> Duration {
153 self.report_interval
154 }
155
156 /// Returns the progress reporter used by this executor.
157 ///
158 /// # Returns
159 ///
160 /// A shared reference to the configured progress reporter.
161 #[inline]
162 pub fn reporter(&self) -> &Arc<dyn ProgressReporter> {
163 &self.reporter
164 }
165
166 /// Creates a sequential executor with matching progress configuration.
167 ///
168 /// # Returns
169 ///
170 /// A sequential executor used for small batches.
171 fn sequential_executor(&self) -> SequentialBatchExecutor {
172 SequentialBatchExecutor::builder()
173 .report_interval(self.report_interval)
174 .reporter_arc(Arc::clone(&self.reporter))
175 .build()
176 }
177}
178
179impl Default for ParallelBatchExecutor {
180 /// Creates a default parallel batch executor.
181 ///
182 /// # Returns
183 ///
184 /// A default-configured parallel batch executor.
185 ///
186 /// # Panics
187 ///
188 /// Panics if the default configuration fails validation.
189 fn default() -> Self {
190 Self::builder()
191 .build()
192 .expect("default parallel batch executor should build")
193 }
194}
195
196impl BatchExecutor for ParallelBatchExecutor {
197 /// Executes the batch on scoped standard threads when the batch is large
198 /// enough.
199 ///
200 /// # Parameters
201 ///
202 /// * `tasks` - Task source for the batch.
203 /// * `count` - Declared task count expected from `tasks`.
204 ///
205 /// # Returns
206 ///
207 /// A structured batch result when the declared task count matches, or a
208 /// batch-count mismatch error with the attached partial result.
209 ///
210 /// # Errors
211 ///
212 /// Returns [`BatchExecutionError`] when `tasks` yields fewer or more tasks
213 /// than `count`.
214 ///
215 /// # Panics
216 ///
217 /// Panics from tasks are captured in the result. Panics from the configured
218 /// progress reporter are propagated to the caller.
219 fn execute_with_count<T, E, I>(
220 &self,
221 tasks: I,
222 count: usize,
223 ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
224 where
225 I: IntoIterator<Item = T>,
226 T: Runnable<E> + Send,
227 E: Send,
228 {
229 if count <= self.sequential_threshold || self.thread_count <= 1 {
230 return self.sequential_executor().execute_with_count(tasks, count);
231 }
232
233 let state = Arc::new(BatchExecutionState::new(count));
234 let progress = Progress::new(self.reporter.as_ref(), self.report_interval);
235 progress.report_started(state.progress_counters());
236 let mut actual_count = 0usize;
237 let worker_count = self.thread_count.min(count);
238
239 thread::scope(|scope| {
240 let reporter_state = Arc::clone(&state);
241 let running_progress =
242 progress.spawn_running_reporter(scope, move || reporter_state.progress_counters());
243 let running_point_handle = running_progress.point_handle();
244
245 let observer_state = Arc::clone(&state);
246 let worker_state = Arc::clone(&state);
247 actual_count = run_scoped_parallel(
248 tasks,
249 count,
250 worker_count,
251 move || observer_state.record_task_observed(),
252 move |index, task| {
253 run_parallel_task(&worker_state, index, task);
254 running_point_handle.report();
255 },
256 );
257 running_progress.stop_and_join();
258 });
259
260 let state = Arc::into_inner(state)
261 .expect("parallel batch execution state should have a single owner");
262 if actual_count < count {
263 let failed = progress.report_failed(state.progress_counters());
264 let result = state.into_outcome(failed.elapsed());
265 Err(BatchExecutionError::CountShortfall {
266 expected: count,
267 actual: actual_count,
268 outcome: result,
269 })
270 } else if actual_count > count {
271 let failed = progress.report_failed(state.progress_counters());
272 let result = state.into_outcome(failed.elapsed());
273 Err(BatchExecutionError::CountExceeded {
274 expected: count,
275 observed_at_least: actual_count,
276 outcome: result,
277 })
278 } else {
279 let finished = progress.report_finished(state.progress_counters());
280 let result = state.into_outcome(finished.elapsed());
281 Ok(result)
282 }
283 }
284}