
cognis_core/wrappers/batch_processor.rs

//! Explicit batch processor — extends the default [`Runnable::batch`]
//! with knobs that the trait method intentionally keeps simple:
//!
//! - **`max_concurrency`** — separate from `RunnableConfig`'s default,
//!   so the same runnable can be batched at different concurrencies in
//!   different call sites without per-config gymnastics.
//! - **`return_exceptions`** — when `true`, the processor returns
//!   `Vec<Result<O, CognisError>>` so partial failures surface
//!   per-input. The default `Runnable::batch` short-circuits on the
//!   first error.
//! - **`wave_delay`** — optional pause between scheduling waves, useful
//!   when the underlying service has rate-limit windows that aren't
//!   well-captured by per-call backoff.
//!
//! Customization: implement [`crate::Runnable`] for an entirely custom
//! batch strategy (e.g. server-side batching). Otherwise, use
//! `BatchProcessor::process` for the standard concurrency-bounded path.
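//!
//! # Example
//!
//! A rough sketch of the intended call shape; `Summarizer` and `documents`
//! are stand-ins for whatever runnable and inputs you are batching, while
//! the `BatchProcessor`/`BatchOptions` calls are the ones defined below.
//!
//! ```ignore
//! let processor = BatchProcessor::with_options(
//!     Summarizer::new(),
//!     BatchOptions::default().with_max_concurrency(8),
//! );
//!
//! // Fail-fast batch: the first error aborts the whole call.
//! let outputs = processor
//!     .process(documents, RunnableConfig::default())
//!     .await?;
//! ```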

use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;

use futures::stream::{self, StreamExt};

use crate::runnable::{Runnable, RunnableConfig};
use crate::{CognisError, Result};

/// Settings for [`BatchProcessor`].
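///
/// The `with_*` helpers chain off `Default`; the values below are
/// illustrative, not recommendations:
///
/// ```ignore
/// let opts = BatchOptions::default()
///     .with_max_concurrency(4)
///     .with_return_exceptions(true)
///     .with_wave_delay(std::time::Duration::from_millis(250));
/// ```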
#[derive(Debug, Clone, Default)]
pub struct BatchOptions {
    /// Maximum in-flight tasks. `0` falls back to the config default.
    pub max_concurrency: usize,
    /// When `true`, errors are collected per-input rather than
    /// short-circuiting the whole batch.
    pub return_exceptions: bool,
    /// Optional pause between waves of scheduled tasks. `None` schedules
    /// continuously up to `max_concurrency`.
    pub wave_delay: Option<Duration>,
}

impl BatchOptions {
    /// Override `max_concurrency` (`0` ⇒ use config's value).
    pub fn with_max_concurrency(mut self, n: usize) -> Self {
        self.max_concurrency = n;
        self
    }

    /// Toggle `return_exceptions`.
    pub fn with_return_exceptions(mut self, on: bool) -> Self {
        self.return_exceptions = on;
        self
    }

    /// Set a wave delay between scheduled tasks.
    pub fn with_wave_delay(mut self, d: Duration) -> Self {
        self.wave_delay = Some(d);
        self
    }
}

/// Explicit batch processor for a single inner runnable.
///
/// Differences from `inner.batch(...)`:
/// - Honors per-instance [`BatchOptions`].
/// - Exposes [`Self::process_with_results`], which returns a
///   `Vec<Result<O>>` so partial failures don't lose data.
pub struct BatchProcessor<R, I, O> {
    inner: R,
    options: BatchOptions,
    _phantom: PhantomData<fn(I) -> O>,
}

impl<R, I, O> BatchProcessor<R, I, O>
where
    R: Runnable<I, O>,
    I: Send + 'static,
    O: Send + 'static,
{
    /// Wrap with default options.
    pub fn new(inner: R) -> Self {
        Self {
            inner,
            options: BatchOptions::default(),
            _phantom: PhantomData,
        }
    }

    /// Wrap with explicit options.
    pub fn with_options(inner: R, options: BatchOptions) -> Self {
        Self {
            inner,
            options,
            _phantom: PhantomData,
        }
    }

    /// Borrow the inner runnable.
    pub fn inner(&self) -> &R {
        &self.inner
    }

    /// Mutably borrow options (e.g. to flip `return_exceptions` mid-life).
    pub fn options_mut(&mut self) -> &mut BatchOptions {
        &mut self.options
    }

    /// Resolve the concurrency cap: an explicit option wins, `0` falls back
    /// to the config's value (clamped to at least 1).
    fn effective_concurrency(&self, config: &RunnableConfig) -> usize {
        if self.options.max_concurrency == 0 {
            config.max_concurrency.max(1)
        } else {
            self.options.max_concurrency
        }
    }
}

impl<R, I, O> BatchProcessor<R, I, O>
where
    R: Runnable<I, O> + Sync,
    I: Send + 'static,
    O: Send + 'static,
{
    /// Run a batch and short-circuit on the first error. Comparable to
    /// `inner.batch(...)`, but honors `BatchOptions::max_concurrency` and
    /// `wave_delay`. With `return_exceptions` disabled (the default), results
    /// are drained as tasks complete, so the output order is completion order
    /// rather than input order; use [`Self::process_with_results`] when
    /// per-input ordering matters.
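    ///
    /// A sketch of the fail-fast shape, assuming `processor` wraps a
    /// runnable that rejects odd numbers:
    ///
    /// ```ignore
    /// // One failing input is enough to fail the whole call when
    /// // `return_exceptions` is false (the default).
    /// let out = processor.process(vec![1, 2, 3], RunnableConfig::default()).await;
    /// assert!(out.is_err());
    /// ```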
    pub async fn process(&self, inputs: Vec<I>, config: RunnableConfig) -> Result<Vec<O>> {
        if self.options.return_exceptions {
            // Delegate to the order-preserving path, then re-collect so the
            // first per-input Err aborts the combined result.
            let results = self.process_with_results(inputs, config).await;
            results.into_iter().collect()
        } else {
            let concurrency = self.effective_concurrency(&config);
            let wave = self.options.wave_delay;
            let cfg = Arc::new(config);
            let inner = &self.inner;
            stream::iter(inputs)
                .map(|input| {
                    let cfg = cfg.clone();
                    async move {
                        // Soft pacing: each task waits out the wave delay
                        // before its invocation is issued.
                        if let Some(d) = wave {
                            tokio::time::sleep(d).await;
                        }
                        inner
                            .invoke(input, RunnableConfig::clone_for_subcall(&cfg))
                            .await
                    }
                })
                .buffer_unordered(concurrency)
                .collect::<Vec<_>>()
                .await
                .into_iter()
                // Collecting `Vec<Result<O>>` into `Result<Vec<O>>` stops at
                // the first error encountered.
                .collect()
        }
    }

    /// Run a batch and **always** return per-input results — failures
    /// surface as `Err(CognisError)` entries, successful invocations as
    /// `Ok(O)`. Order matches the input order.
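    ///
    /// A sketch of the per-input handling this enables, again assuming a
    /// `processor` whose runnable rejects odd numbers:
    ///
    /// ```ignore
    /// let results = processor
    ///     .process_with_results(vec![1, 2, 3, 4], RunnableConfig::default())
    ///     .await;
    /// // Each result stays paired with the position of the input that produced it.
    /// for (input, result) in [1, 2, 3, 4].into_iter().zip(&results) {
    ///     match result {
    ///         Ok(out) => println!("{input} -> {out}"),
    ///         Err(e) => eprintln!("{input} failed: {e}"),
    ///     }
    /// }
    /// ```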
    pub async fn process_with_results(
        &self,
        inputs: Vec<I>,
        config: RunnableConfig,
    ) -> Vec<Result<O>> {
        let concurrency = self.effective_concurrency(&config);
        let wave = self.options.wave_delay;
        let cfg = Arc::new(config);
        let inner = &self.inner;
        // Tag each input with its original index so we can restore order
        // after `buffer_unordered` finishes.
        let tagged: Vec<(usize, I)> = inputs.into_iter().enumerate().collect();
        let total = tagged.len();
        let mut results: Vec<Option<Result<O>>> = (0..total).map(|_| None).collect();

        let mut stream = stream::iter(tagged)
            .map(|(i, input)| {
                let cfg = cfg.clone();
                async move {
                    if let Some(d) = wave {
                        tokio::time::sleep(d).await;
                    }
                    let res = inner
                        .invoke(input, RunnableConfig::clone_for_subcall(&cfg))
                        .await;
                    (i, res)
                }
            })
            .buffer_unordered(concurrency);

        // Drain completions as they arrive, writing each result back into the
        // slot reserved for its original position.
        while let Some((i, res)) = stream.next().await {
            results[i] = Some(res);
        }
        results
            .into_iter()
            .map(|r| {
                // Every slot should have been filled by the drain loop above;
                // an empty slot would indicate an internal scheduling bug.
                r.unwrap_or_else(|| {
                    Err(CognisError::Internal("batch slot was never filled".into()))
                })
            })
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Doubles its input and records how many invocations it has seen.
    struct Counter {
        seen: Arc<AtomicUsize>,
    }
    #[async_trait]
    impl Runnable<usize, usize> for Counter {
        async fn invoke(&self, input: usize, _: RunnableConfig) -> Result<usize> {
            self.seen.fetch_add(1, Ordering::SeqCst);
            Ok(input * 2)
        }
    }

    /// Succeeds on even inputs, fails on odd ones.
    struct Sometimes;
    #[async_trait]
    impl Runnable<usize, usize> for Sometimes {
        async fn invoke(&self, input: usize, _: RunnableConfig) -> Result<usize> {
            if input.is_multiple_of(2) {
                Ok(input)
            } else {
                Err(CognisError::Internal(format!("odd input {input}")))
            }
        }
    }

    #[tokio::test]
    async fn process_returns_all_results() {
        let r = Counter {
            seen: Arc::new(AtomicUsize::new(0)),
        };
        let bp: BatchProcessor<Counter, usize, usize> = BatchProcessor::new(r);
        let out = bp
            .process(vec![1, 2, 3], RunnableConfig::default())
            .await
            .unwrap();
        // The fail-fast branch drains tasks with `buffer_unordered`, so the
        // output order follows completion order; assert on membership rather
        // than position.
        assert_eq!(out.len(), 3);
        assert!(out.contains(&2));
        assert!(out.contains(&4));
        assert!(out.contains(&6));
    }

    #[tokio::test]
    async fn process_short_circuits_on_first_error() {
        let bp: BatchProcessor<Sometimes, usize, usize> = BatchProcessor::new(Sometimes);
        let res = bp.process(vec![1, 2, 3], RunnableConfig::default()).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn process_with_results_collects_partials() {
        let bp = BatchProcessor::with_options(
            Sometimes,
            BatchOptions::default()
                .with_return_exceptions(true)
                .with_max_concurrency(4),
        );
        let res = bp
            .process_with_results(vec![1, 2, 3, 4], RunnableConfig::default())
            .await;
        // Order matches input.
        assert_eq!(res.len(), 4);
        assert!(res[0].is_err());
        assert!(res[1].is_ok());
        assert!(res[2].is_err());
        assert!(res[3].is_ok());
    }

    #[tokio::test]
    async fn return_exceptions_via_process_returns_ordered_vec_when_all_succeed() {
        let bp = BatchProcessor::with_options(
            Sometimes,
            BatchOptions::default().with_return_exceptions(true),
        );
        // With `return_exceptions` set, process() delegates to
        // process_with_results() and then re-collects, so the first Err in
        // input order would still abort the combined result (mirroring the V1
        // semantics); callers who want the whole vec of partials should call
        // process_with_results directly. Here every input succeeds, so the
        // vector comes back intact and in input order.
        let res = bp.process(vec![2, 4], RunnableConfig::default()).await;
        assert_eq!(res.unwrap(), vec![2, 4]);
    }

    #[tokio::test]
    async fn explicit_max_concurrency_overrides_config() {
        let counter = Arc::new(AtomicUsize::new(0));
        let r = Counter {
            seen: counter.clone(),
        };
        let bp: BatchProcessor<Counter, usize, usize> =
            BatchProcessor::with_options(r, BatchOptions::default().with_max_concurrency(2));
        let _ = bp
            .process(vec![1, 2, 3, 4, 5], RunnableConfig::default())
            .await
            .unwrap();
        // Smoke test: with an explicit cap of 2 (instead of the config
        // default) every input is still processed; the in-flight bound itself
        // isn't instrumented here.
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
303}