qubit_batch/execute/batch_executor.rs
1/*******************************************************************************
2 *
3 * Copyright (c) 2025 - 2026 Haixing Hu.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 *
7 * Licensed under the Apache License, Version 2.0.
8 *
9 ******************************************************************************/
10use std::sync::Arc;
11
12use crossbeam_queue::SegQueue;
13use qubit_function::{
14 Callable,
15 Runnable,
16};
17
18use crate::{
19 BatchExecutionError,
20 BatchOutcome,
21};
22
23use super::{
24 BatchCallResult,
25 callable_task::CallableTask,
26 for_each_task::ForEachTask,
27};
28
29/// Executes batches of fallible tasks.
30///
31/// Implementations consume the supplied iterator once, execute every observed
32/// task unless an explicitly declared count is exceeded, and return a
33/// [`BatchOutcome`] containing task-level successes, failures, panics, and
34/// elapsed time.
35///
36/// ```rust
37/// use qubit_batch::{
38/// BatchExecutor,
39/// SequentialBatchExecutor,
40/// };
41///
42/// let outcome = SequentialBatchExecutor::new()
43/// .for_each([1, 2, 3], |value| {
44/// assert!(value > 0);
45/// Ok::<(), &'static str>(())
46/// })
47/// .expect("array length should be exact");
48///
49/// assert!(outcome.is_success());
50/// ```
51pub trait BatchExecutor: Send + Sync {
52 /// Executes a batch of runnable tasks whose iterator exposes an exact
53 /// length.
54 ///
55 /// # Parameters
56 ///
57 /// * `tasks` - Task source for the batch. Its iterator must report the
58 /// remaining task count exactly.
59 ///
60 /// # Returns
61 ///
62 /// The result returned by [`Self::execute_with_count`] after deriving the
63 /// declared count from the iterator length.
64 ///
65 /// # Errors
66 ///
67 /// Returns [`BatchExecutionError`] only if the iterator violates its exact
68 /// length contract while being consumed.
69 ///
70 /// # Panics
71 ///
72 /// Panics from individual tasks are captured in [`BatchOutcome`].
73 /// Panics from the configured
74 /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
75 /// caller.
76 fn execute<T, E, I>(&self, tasks: I) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
77 where
78 I: IntoIterator<Item = T>,
79 I::IntoIter: ExactSizeIterator,
80 T: Runnable<E> + Send,
81 E: Send,
82 {
83 let tasks = tasks.into_iter();
84 let count = tasks.len();
85 self.execute_with_count(tasks, count)
86 }
87
88 /// Executes a batch of runnable tasks with an explicit declared count.
89 ///
90 /// # Parameters
91 ///
92 /// * `tasks` - Task source for the batch. It may be eager or lazy.
93 /// * `count` - Declared number of tasks expected from `tasks`.
94 ///
95 /// # Returns
96 ///
97 /// `Ok(BatchOutcome)` when the declared task count matches the source, or
98 /// `Err(BatchExecutionError)` when the source yields fewer or more tasks
99 /// than declared.
100 ///
101 /// # Errors
102 ///
103 /// Returns [`BatchExecutionError`] when the source task count does not
104 /// match `count`.
105 ///
106 /// # Panics
107 ///
108 /// Panics from individual tasks are captured in [`BatchOutcome`].
109 /// Panics from the configured
110 /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
111 /// caller.
112 fn execute_with_count<T, E, I>(
113 &self,
114 tasks: I,
115 count: usize,
116 ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
117 where
118 I: IntoIterator<Item = T>,
119 T: Runnable<E> + Send,
120 E: Send;
121
122 /// Executes callable tasks whose iterator exposes an exact length.
123 ///
124 /// # Parameters
125 ///
126 /// * `tasks` - Callable task source for the batch. Its iterator must report
127 /// the remaining callable count exactly.
128 ///
129 /// # Returns
130 ///
131 /// A [`BatchCallResult`] containing the normal execution summary plus
132 /// optional success values indexed by callable position.
133 ///
134 /// # Errors
135 ///
136 /// Returns [`BatchExecutionError`] only if the iterator violates its exact
137 /// length contract while being consumed.
138 ///
139 /// # Panics
140 ///
141 /// Panics from individual callables are captured in the execution result.
142 /// Panics from the configured
143 /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
144 /// caller.
145 fn call<C, R, E, I>(&self, tasks: I) -> Result<BatchCallResult<R, E>, BatchExecutionError<E>>
146 where
147 I: IntoIterator<Item = C>,
148 I::IntoIter: ExactSizeIterator,
149 C: Callable<R, E> + Send,
150 R: Send,
151 E: Send,
152 {
153 let tasks = tasks.into_iter();
154 let count = tasks.len();
155 self.call_with_count(tasks, count)
156 }
157
158 /// Executes callable tasks with an explicit declared count and collects
159 /// success values by index.
160 ///
161 /// # Parameters
162 ///
163 /// * `tasks` - Callable task source for the batch.
164 /// * `count` - Declared number of callables expected from `tasks`.
165 ///
166 /// # Returns
167 ///
168 /// A [`BatchCallResult`] containing the normal execution summary plus
169 /// optional success values indexed by callable position.
170 ///
171 /// # Errors
172 ///
173 /// Returns [`BatchExecutionError`] when the source callable count does not
174 /// match `count`.
175 ///
176 /// # Panics
177 ///
178 /// Panics from individual callables are captured in the execution result.
179 /// Panics from the configured
180 /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
181 /// caller.
182 fn call_with_count<C, R, E, I>(
183 &self,
184 tasks: I,
185 count: usize,
186 ) -> Result<BatchCallResult<R, E>, BatchExecutionError<E>>
187 where
188 I: IntoIterator<Item = C>,
189 C: Callable<R, E> + Send,
190 R: Send,
191 E: Send,
192 {
193 let outputs = Arc::new(SegQueue::new());
194 // This adapter is lazy: callables are wrapped as runnable tasks only
195 // when the executor consumes the iterator. The callables themselves are
196 // still executed later by `CallableTask::run`.
197 let runnable_tasks = tasks.into_iter().enumerate().map({
198 let outputs = Arc::clone(&outputs);
199 move |(index, callable)| CallableTask::new(callable, index, Arc::clone(&outputs))
200 });
201 let outcome = self.execute_with_count(runnable_tasks, count)?;
202 let values = collect_call_outputs(outputs, count);
203 Ok(BatchCallResult::new(outcome, values))
204 }
205
206 /// Applies `action` to every item whose iterator exposes an exact length.
207 ///
208 /// # Parameters
209 ///
210 /// * `items` - Item source to transform into runnable tasks.
211 /// * `action` - Fallible action applied to each item.
212 ///
213 /// # Returns
214 ///
215 /// The result returned by [`Self::for_each_with_count`] after deriving the
216 /// declared count from the iterator length.
217 ///
218 /// # Errors
219 ///
220 /// Returns [`BatchExecutionError`] only if the iterator violates its exact
221 /// length contract while being consumed.
222 fn for_each<Item, E, I, F>(
223 &self,
224 items: I,
225 action: F,
226 ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
227 where
228 I: IntoIterator<Item = Item>,
229 I::IntoIter: ExactSizeIterator,
230 Item: Send,
231 F: Fn(Item) -> Result<(), E> + Send + Sync,
232 E: Send,
233 {
234 let items = items.into_iter();
235 let count = items.len();
236 self.for_each_with_count(items, count, action)
237 }
238
239 /// Applies `action` to every item using an explicit declared count.
240 ///
241 /// # Parameters
242 ///
243 /// * `items` - Item source to transform into runnable tasks.
244 /// * `count` - Declared number of items expected from `items`.
245 /// * `action` - Fallible action applied to each item.
246 ///
247 /// # Returns
248 ///
249 /// The result returned by [`Self::execute_with_count`] for the derived task
250 /// batch.
251 ///
252 /// # Errors
253 ///
254 /// Returns [`BatchExecutionError`] when the source item count does not
255 /// match `count`.
256 fn for_each_with_count<Item, E, I, F>(
257 &self,
258 items: I,
259 count: usize,
260 action: F,
261 ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
262 where
263 I: IntoIterator<Item = Item>,
264 Item: Send,
265 F: Fn(Item) -> Result<(), E> + Send + Sync,
266 E: Send,
267 {
268 let action = Arc::new(action);
269 let tasks = items
270 .into_iter()
271 .map(move |item| ForEachTask::new(item, Arc::clone(&action)));
272 self.execute_with_count(tasks, count)
273 }
274}
275
276/// Consumes shared callable outputs into an indexed value vector.
277///
278/// # Parameters
279///
280/// * `outputs` - Shared output queue filled by callable wrappers.
281/// * `count` - Declared callable count used to size the result vector.
282///
283/// # Returns
284///
285/// Optional success values indexed by callable position.
286///
287/// # Panics
288///
289/// Panics if callable wrappers still hold references to `outputs`, or if a
290/// queued output index is outside the declared batch size.
291fn collect_call_outputs<R>(outputs: Arc<SegQueue<(usize, R)>>, count: usize) -> Vec<Option<R>> {
292 let outputs = match Arc::try_unwrap(outputs) {
293 Ok(outputs) => outputs,
294 Err(_) => panic!("callable output queue should have a single owner after execution"),
295 };
296 let mut values = Vec::with_capacity(count);
297 values.resize_with(count, || None);
298 while let Some((index, value)) = outputs.pop() {
299 let slot = values
300 .get_mut(index)
301 .expect("callable index must be within the declared count");
302 *slot = Some(value);
303 }
304 values
305}