Skip to main content

qubit_batch/execute/
batch_executor.rs

/*******************************************************************************
 *
 *    Copyright (c) 2025 - 2026 Haixing Hu.
 *
 *    SPDX-License-Identifier: Apache-2.0
 *
 *    Licensed under the Apache License, Version 2.0.
 *
 ******************************************************************************/
10use std::sync::Arc;
11
12use crossbeam_queue::SegQueue;
13use qubit_function::{
14    Callable,
15    Runnable,
16};
17
18use crate::{
19    BatchExecutionError,
20    BatchOutcome,
21};
22
23use super::{
24    BatchCallResult,
25    callable_task::CallableTask,
26    for_each_task::ForEachTask,
27};
28
29/// Executes batches of fallible tasks.
30///
31/// Implementations consume the supplied iterator once, execute every observed
32/// task unless an explicitly declared count is exceeded, and return a
33/// [`BatchOutcome`] containing task-level successes, failures, panics, and
34/// elapsed time.
35///
36/// ```rust
37/// use qubit_batch::{
38///     BatchExecutor,
39///     SequentialBatchExecutor,
40/// };
41///
42/// let outcome = SequentialBatchExecutor::new()
43///     .for_each([1, 2, 3], |value| {
44///         assert!(value > 0);
45///         Ok::<(), &'static str>(())
46///     })
47///     .expect("array length should be exact");
48///
49/// assert!(outcome.is_success());
50/// ```
51pub trait BatchExecutor: Send + Sync {
52    /// Executes a batch of runnable tasks whose iterator exposes an exact
53    /// length.
54    ///
55    /// # Parameters
56    ///
57    /// * `tasks` - Task source for the batch. Its iterator must report the
58    ///   remaining task count exactly.
59    ///
60    /// # Returns
61    ///
62    /// The result returned by [`Self::execute_with_count`] after deriving the
63    /// declared count from the iterator length.
64    ///
65    /// # Errors
66    ///
67    /// Returns [`BatchExecutionError`] only if the iterator violates its exact
68    /// length contract while being consumed.
69    ///
70    /// # Panics
71    ///
72    /// Panics from individual tasks are captured in [`BatchOutcome`].
73    /// Panics from the configured
74    /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
75    /// caller.
76    fn execute<T, E, I>(&self, tasks: I) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
77    where
78        I: IntoIterator<Item = T>,
79        I::IntoIter: ExactSizeIterator,
80        T: Runnable<E> + Send,
81        E: Send,
82    {
83        let tasks = tasks.into_iter();
84        let count = tasks.len();
85        self.execute_with_count(tasks, count)
86    }
87
88    /// Executes a batch of runnable tasks with an explicit declared count.
89    ///
90    /// # Parameters
91    ///
92    /// * `tasks` - Task source for the batch. It may be eager or lazy.
93    /// * `count` - Declared number of tasks expected from `tasks`.
94    ///
95    /// # Returns
96    ///
97    /// `Ok(BatchOutcome)` when the declared task count matches the source, or
98    /// `Err(BatchExecutionError)` when the source yields fewer or more tasks
99    /// than declared.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`BatchExecutionError`] when the source task count does not
104    /// match `count`.
105    ///
106    /// # Panics
107    ///
108    /// Panics from individual tasks are captured in [`BatchOutcome`].
109    /// Panics from the configured
110    /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
111    /// caller.
112    fn execute_with_count<T, E, I>(
113        &self,
114        tasks: I,
115        count: usize,
116    ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
117    where
118        I: IntoIterator<Item = T>,
119        T: Runnable<E> + Send,
120        E: Send;
121
122    /// Executes callable tasks whose iterator exposes an exact length.
123    ///
124    /// # Parameters
125    ///
126    /// * `tasks` - Callable task source for the batch. Its iterator must report
127    ///   the remaining callable count exactly.
128    ///
129    /// # Returns
130    ///
131    /// A [`BatchCallResult`] containing the normal execution summary plus
132    /// optional success values indexed by callable position.
133    ///
134    /// # Errors
135    ///
136    /// Returns [`BatchExecutionError`] only if the iterator violates its exact
137    /// length contract while being consumed.
138    ///
139    /// # Panics
140    ///
141    /// Panics from individual callables are captured in the execution result.
142    /// Panics from the configured
143    /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
144    /// caller.
145    fn call<C, R, E, I>(&self, tasks: I) -> Result<BatchCallResult<R, E>, BatchExecutionError<E>>
146    where
147        I: IntoIterator<Item = C>,
148        I::IntoIter: ExactSizeIterator,
149        C: Callable<R, E> + Send,
150        R: Send,
151        E: Send,
152    {
153        let tasks = tasks.into_iter();
154        let count = tasks.len();
155        self.call_with_count(tasks, count)
156    }
157
158    /// Executes callable tasks with an explicit declared count and collects
159    /// success values by index.
160    ///
161    /// # Parameters
162    ///
163    /// * `tasks` - Callable task source for the batch.
164    /// * `count` - Declared number of callables expected from `tasks`.
165    ///
166    /// # Returns
167    ///
168    /// A [`BatchCallResult`] containing the normal execution summary plus
169    /// optional success values indexed by callable position.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`BatchExecutionError`] when the source callable count does not
174    /// match `count`.
175    ///
176    /// # Panics
177    ///
178    /// Panics from individual callables are captured in the execution result.
179    /// Panics from the configured
180    /// [`qubit_progress::reporter::ProgressReporter`] are propagated to the
181    /// caller.
182    fn call_with_count<C, R, E, I>(
183        &self,
184        tasks: I,
185        count: usize,
186    ) -> Result<BatchCallResult<R, E>, BatchExecutionError<E>>
187    where
188        I: IntoIterator<Item = C>,
189        C: Callable<R, E> + Send,
190        R: Send,
191        E: Send,
192    {
193        let outputs = Arc::new(SegQueue::new());
194        // This adapter is lazy: callables are wrapped as runnable tasks only
195        // when the executor consumes the iterator. The callables themselves are
196        // still executed later by `CallableTask::run`.
197        let runnable_tasks = tasks.into_iter().enumerate().map({
198            let outputs = Arc::clone(&outputs);
199            move |(index, callable)| CallableTask::new(callable, index, Arc::clone(&outputs))
200        });
201        let outcome = self.execute_with_count(runnable_tasks, count)?;
202        let values = collect_call_outputs(outputs, count);
203        Ok(BatchCallResult::new(outcome, values))
204    }
205
206    /// Applies `action` to every item whose iterator exposes an exact length.
207    ///
208    /// # Parameters
209    ///
210    /// * `items` - Item source to transform into runnable tasks.
211    /// * `action` - Fallible action applied to each item.
212    ///
213    /// # Returns
214    ///
215    /// The result returned by [`Self::for_each_with_count`] after deriving the
216    /// declared count from the iterator length.
217    ///
218    /// # Errors
219    ///
220    /// Returns [`BatchExecutionError`] only if the iterator violates its exact
221    /// length contract while being consumed.
222    fn for_each<Item, E, I, F>(
223        &self,
224        items: I,
225        action: F,
226    ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
227    where
228        I: IntoIterator<Item = Item>,
229        I::IntoIter: ExactSizeIterator,
230        Item: Send,
231        F: Fn(Item) -> Result<(), E> + Send + Sync,
232        E: Send,
233    {
234        let items = items.into_iter();
235        let count = items.len();
236        self.for_each_with_count(items, count, action)
237    }
238
239    /// Applies `action` to every item using an explicit declared count.
240    ///
241    /// # Parameters
242    ///
243    /// * `items` - Item source to transform into runnable tasks.
244    /// * `count` - Declared number of items expected from `items`.
245    /// * `action` - Fallible action applied to each item.
246    ///
247    /// # Returns
248    ///
249    /// The result returned by [`Self::execute_with_count`] for the derived task
250    /// batch.
251    ///
252    /// # Errors
253    ///
254    /// Returns [`BatchExecutionError`] when the source item count does not
255    /// match `count`.
256    fn for_each_with_count<Item, E, I, F>(
257        &self,
258        items: I,
259        count: usize,
260        action: F,
261    ) -> Result<BatchOutcome<E>, BatchExecutionError<E>>
262    where
263        I: IntoIterator<Item = Item>,
264        Item: Send,
265        F: Fn(Item) -> Result<(), E> + Send + Sync,
266        E: Send,
267    {
268        let action = Arc::new(action);
269        let tasks = items
270            .into_iter()
271            .map(move |item| ForEachTask::new(item, Arc::clone(&action)));
272        self.execute_with_count(tasks, count)
273    }
274}
275
276/// Consumes shared callable outputs into an indexed value vector.
277///
278/// # Parameters
279///
280/// * `outputs` - Shared output queue filled by callable wrappers.
281/// * `count` - Declared callable count used to size the result vector.
282///
283/// # Returns
284///
285/// Optional success values indexed by callable position.
286///
287/// # Panics
288///
289/// Panics if callable wrappers still hold references to `outputs`, or if a
290/// queued output index is outside the declared batch size.
291fn collect_call_outputs<R>(outputs: Arc<SegQueue<(usize, R)>>, count: usize) -> Vec<Option<R>> {
292    let outputs = match Arc::try_unwrap(outputs) {
293        Ok(outputs) => outputs,
294        Err(_) => panic!("callable output queue should have a single owner after execution"),
295    };
296    let mut values = Vec::with_capacity(count);
297    values.resize_with(count, || None);
298    while let Some((index, value)) = outputs.pop() {
299        let slot = values
300            .get_mut(index)
301            .expect("callable index must be within the declared count");
302        *slot = Some(value);
303    }
304    values
305}