1use std::marker::PhantomData;
20use std::sync::Arc;
21use std::time::Duration;
22
23use futures::stream::{self, StreamExt};
24
25use crate::runnable::{Runnable, RunnableConfig};
26use crate::{CognisError, Result};
27
/// Tuning knobs for a [`BatchProcessor`].
///
/// The `Default` value is `max_concurrency == 0`, `return_exceptions == false`
/// and no wave delay.
#[derive(Clone, Debug, Default)]
pub struct BatchOptions {
    /// Upper bound on concurrently running invocations.
    ///
    /// `0` means "not set": the processor then falls back to the
    /// `max_concurrency` of the `RunnableConfig` passed to the call,
    /// clamped to at least 1.
    pub max_concurrency: usize,
    /// When `true`, `BatchProcessor::process` runs the batch through
    /// `process_with_results` and folds the per-item results afterwards.
    pub return_exceptions: bool,
    /// When set, each invocation sleeps for this long before calling the
    /// inner runnable.
    pub wave_delay: Option<Duration>,
}
40
41impl BatchOptions {
42 pub fn with_max_concurrency(mut self, n: usize) -> Self {
44 self.max_concurrency = n;
45 self
46 }
47
48 pub fn with_return_exceptions(mut self, on: bool) -> Self {
50 self.return_exceptions = on;
51 self
52 }
53
54 pub fn with_wave_delay(mut self, d: Duration) -> Self {
56 self.wave_delay = Some(d);
57 self
58 }
59}
60
61pub struct BatchProcessor<R, I, O> {
68 inner: R,
69 options: BatchOptions,
70 _phantom: PhantomData<fn(I) -> O>,
71}
72
73impl<R, I, O> BatchProcessor<R, I, O>
74where
75 R: Runnable<I, O>,
76 I: Send + 'static,
77 O: Send + 'static,
78{
79 pub fn new(inner: R) -> Self {
81 Self {
82 inner,
83 options: BatchOptions::default(),
84 _phantom: PhantomData,
85 }
86 }
87
88 pub fn with_options(inner: R, options: BatchOptions) -> Self {
90 Self {
91 inner,
92 options,
93 _phantom: PhantomData,
94 }
95 }
96
97 pub fn inner(&self) -> &R {
99 &self.inner
100 }
101
102 pub fn options_mut(&mut self) -> &mut BatchOptions {
104 &mut self.options
105 }
106
107 fn effective_concurrency(&self, config: &RunnableConfig) -> usize {
108 if self.options.max_concurrency == 0 {
109 config.max_concurrency.max(1)
110 } else {
111 self.options.max_concurrency
112 }
113 }
114}
115
116impl<R, I, O> BatchProcessor<R, I, O>
117where
118 R: Runnable<I, O> + Sync,
119 I: Send + 'static,
120 O: Send + 'static,
121{
122 pub async fn process(&self, inputs: Vec<I>, config: RunnableConfig) -> Result<Vec<O>> {
126 if self.options.return_exceptions {
127 let results = self.process_with_results(inputs, config).await;
128 results.into_iter().collect()
129 } else {
130 let concurrency = self.effective_concurrency(&config);
131 let wave = self.options.wave_delay;
132 let cfg = Arc::new(config);
133 let inner = &self.inner;
134 stream::iter(inputs)
135 .map(|input| {
136 let cfg = cfg.clone();
137 async move {
138 if let Some(d) = wave {
139 tokio::time::sleep(d).await;
140 }
141 inner
142 .invoke(input, RunnableConfig::clone_for_subcall(&cfg))
143 .await
144 }
145 })
146 .buffer_unordered(concurrency)
147 .collect::<Vec<_>>()
148 .await
149 .into_iter()
150 .collect()
151 }
152 }
153
154 pub async fn process_with_results(
158 &self,
159 inputs: Vec<I>,
160 config: RunnableConfig,
161 ) -> Vec<Result<O>> {
162 let concurrency = self.effective_concurrency(&config);
163 let wave = self.options.wave_delay;
164 let cfg = Arc::new(config);
165 let inner = &self.inner;
166 let tagged: Vec<(usize, I)> = inputs.into_iter().enumerate().collect();
169 let total = tagged.len();
170 let mut results: Vec<Option<Result<O>>> = (0..total).map(|_| None).collect::<Vec<_>>();
171
172 let mut stream = stream::iter(tagged)
173 .map(|(i, input)| {
174 let cfg = cfg.clone();
175 async move {
176 if let Some(d) = wave {
177 tokio::time::sleep(d).await;
178 }
179 let res = inner
180 .invoke(input, RunnableConfig::clone_for_subcall(&cfg))
181 .await;
182 (i, res)
183 }
184 })
185 .buffer_unordered(concurrency);
186
187 while let Some((i, res)) = stream.next().await {
188 results[i] = Some(res);
189 }
190 results
191 .into_iter()
192 .map(|r| {
193 r.unwrap_or_else(|| {
194 Err(CognisError::Internal("batch slot was never filled".into()))
195 })
196 })
197 .collect()
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Doubles its input and counts how many times it has been invoked.
    struct Counter {
        seen: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Runnable<usize, usize> for Counter {
        async fn invoke(&self, input: usize, _: RunnableConfig) -> Result<usize> {
            self.seen.fetch_add(1, Ordering::SeqCst);
            Ok(input * 2)
        }
    }

    /// Echoes even inputs and fails on odd ones.
    struct Sometimes;

    #[async_trait]
    impl Runnable<usize, usize> for Sometimes {
        async fn invoke(&self, input: usize, _: RunnableConfig) -> Result<usize> {
            if !input.is_multiple_of(2) {
                return Err(CognisError::Internal(format!("odd input {input}")));
            }
            Ok(input)
        }
    }

    #[tokio::test]
    async fn process_returns_results_in_order() {
        let runnable = Counter {
            seen: Arc::new(AtomicUsize::new(0)),
        };
        let processor: BatchProcessor<Counter, usize, usize> = BatchProcessor::new(runnable);

        let doubled = processor
            .process(vec![1, 2, 3], RunnableConfig::default())
            .await
            .unwrap();

        assert_eq!(doubled.len(), 3);
        // Only membership is checked here; positions are not asserted.
        for expected in [2, 4, 6] {
            assert!(doubled.contains(&expected));
        }
    }

    #[tokio::test]
    async fn process_short_circuits_on_first_error() {
        let processor: BatchProcessor<Sometimes, usize, usize> = BatchProcessor::new(Sometimes);

        let outcome = processor
            .process(vec![1, 2, 3], RunnableConfig::default())
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn process_with_results_collects_partials() {
        let options = BatchOptions::default()
            .with_return_exceptions(true)
            .with_max_concurrency(4);
        let processor = BatchProcessor::with_options(Sometimes, options);

        let outcomes = processor
            .process_with_results(vec![1, 2, 3, 4], RunnableConfig::default())
            .await;

        assert_eq!(outcomes.len(), 4);
        // Odd inputs (positions 0 and 2) fail, even inputs succeed.
        let succeeded: Vec<bool> = outcomes.iter().map(Result::is_ok).collect();
        assert_eq!(succeeded, [false, true, false, true]);
    }

    #[tokio::test]
    async fn return_exceptions_via_process_short_circuits_only_when_all_fail() {
        let options = BatchOptions::default().with_return_exceptions(true);
        let processor = BatchProcessor::with_options(Sometimes, options);

        let outcome = processor
            .process(vec![2, 4], RunnableConfig::default())
            .await;

        // Every input is even, so the folded result is `Ok`.
        assert_eq!(outcome.unwrap(), vec![2, 4]);
    }

    #[tokio::test]
    async fn explicit_max_concurrency_overrides_config() {
        let calls = Arc::new(AtomicUsize::new(0));
        let runnable = Counter {
            seen: calls.clone(),
        };
        let options = BatchOptions::default().with_max_concurrency(2);
        let processor: BatchProcessor<Counter, usize, usize> =
            BatchProcessor::with_options(runnable, options);

        let _ = processor
            .process(vec![1, 2, 3, 4, 5], RunnableConfig::default())
            .await
            .unwrap();

        // Each of the five inputs was invoked exactly once.
        assert_eq!(calls.load(Ordering::SeqCst), 5);
    }
}