// rig/pipeline/try_op.rs

1use std::future::Future;
2
3use futures::stream;
4#[allow(unused_imports)] // Needed since this is used in a macro rule
5use futures::try_join;
6
7use super::op::{self};
8
9// ================================================================
10// Core TryOp trait
11// ================================================================
12pub trait TryOp: Send + Sync {
13    type Input: Send + Sync;
14    type Output: Send + Sync;
15    type Error: Send + Sync;
16
17    /// Execute the current op with the given input.
18    fn try_call(
19        &self,
20        input: Self::Input,
21    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
22
23    /// Execute the current op with the given inputs. `n` is the number of concurrent
24    /// inputs that will be processed concurrently.
25    /// If the op fails for one of the inputs, the entire operation will fail and the error will
26    /// be returned.
27    ///
28    /// # Example
29    /// ```rust
30    /// use rig::pipeline::{self, TryOp};
31    ///
32    /// let op = pipeline::new()
33    ///    .map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
34    ///
35    /// // Execute the pipeline concurrently with 2 inputs
36    /// let result = op.try_batch_call(2, vec![2, 4]).await;
37    /// assert_eq!(result, Ok(vec![3, 5]));
38    /// ```
39    fn try_batch_call<I>(
40        &self,
41        n: usize,
42        input: I,
43    ) -> impl Future<Output = Result<Vec<Self::Output>, Self::Error>> + Send
44    where
45        I: IntoIterator<Item = Self::Input> + Send,
46        I::IntoIter: Send,
47        Self: Sized,
48    {
49        use stream::{StreamExt, TryStreamExt};
50
51        async move {
52            stream::iter(input)
53                .map(|input| self.try_call(input))
54                .buffered(n)
55                .try_collect()
56                .await
57        }
58    }
59
60    /// Map the success return value (i.e., `Ok`) of the current op to a different value
61    /// using the provided closure.
62    ///
63    /// # Example
64    /// ```rust
65    /// use rig::pipeline::{self, TryOp};
66    ///
67    /// let op = pipeline::new()
68    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
69    ///     .map_ok(|x| x * 2);
70    ///
71    /// let result = op.try_call(2).await;
72    /// assert_eq!(result, Ok(4));
73    /// ```
74    fn map_ok<F, Output>(self, f: F) -> MapOk<Self, op::Map<F, Self::Output>>
75    where
76        F: Fn(Self::Output) -> Output + Send + Sync,
77        Output: Send + Sync,
78        Self: Sized,
79    {
80        MapOk::new(self, op::Map::new(f))
81    }
82
83    /// Map the error return value (i.e., `Err`) of the current op to a different value
84    /// using the provided closure.
85    ///
86    /// # Example
87    /// ```rust
88    /// use rig::pipeline::{self, TryOp};
89    ///
90    /// let op = pipeline::new()
91    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
92    ///     .map_err(|err| format!("Error: {}", err));
93    ///
94    /// let result = op.try_call(1).await;
95    /// assert_eq!(result, Err("Error: x is odd".to_string()));
96    /// ```
97    fn map_err<F, E>(self, f: F) -> MapErr<Self, op::Map<F, Self::Error>>
98    where
99        F: Fn(Self::Error) -> E + Send + Sync,
100        E: Send + Sync,
101        Self: Sized,
102    {
103        MapErr::new(self, op::Map::new(f))
104    }
105
106    /// Chain a function to the current op. The function will only be called
107    /// if the current op returns `Ok`. The function must return a `Future` with value
108    /// `Result<T, E>` where `E` is the same type as the error type of the current.
109    ///
110    /// # Example
111    /// ```rust
112    /// use rig::pipeline::{self, TryOp};
113    ///
114    /// let op = pipeline::new()
115    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
116    ///     .and_then(|x| async move { Ok(x * 2) });
117    ///
118    /// let result = op.try_call(2).await;
119    /// assert_eq!(result, Ok(4));
120    /// ```
121    fn and_then<F, Fut, Output>(self, f: F) -> AndThen<Self, op::Then<F, Self::Output>>
122    where
123        F: Fn(Self::Output) -> Fut + Send + Sync,
124        Fut: Future<Output = Result<Output, Self::Error>> + Send + Sync,
125        Output: Send + Sync,
126        Self: Sized,
127    {
128        AndThen::new(self, op::Then::new(f))
129    }
130
131    /// Chain a function `f` to the current op. The function `f` will only be called
132    /// if the current op returns `Err`. `f` must return a `Future` with value
133    /// `Result<T, E>` where `T` is the same type as the output type of the current op.
134    ///
135    /// # Example
136    /// ```rust
137    /// use rig::pipeline::{self, TryOp};
138    ///
139    /// let op = pipeline::new()
140    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
141    ///     .or_else(|err| async move { Err(format!("Error: {}", err)) });
142    ///
143    /// let result = op.try_call(1).await;
144    /// assert_eq!(result, Err("Error: x is odd".to_string()));
145    /// ```
146    fn or_else<F, Fut, E>(self, f: F) -> OrElse<Self, op::Then<F, Self::Error>>
147    where
148        F: Fn(Self::Error) -> Fut + Send + Sync,
149        Fut: Future<Output = Result<Self::Output, E>> + Send + Sync,
150        E: Send + Sync,
151        Self: Sized,
152    {
153        OrElse::new(self, op::Then::new(f))
154    }
155
156    /// Chain a new op `op` to the current op. The new op will be called with the success
157    /// return value of the current op (i.e.: `Ok` value). The chained op can be any type that
158    /// implements the `Op` trait.
159    ///
160    /// # Example
161    /// ```rust
162    /// use rig::pipeline::{self, TryOp};
163    ///
164    /// struct AddOne;
165    ///
166    /// impl Op for AddOne {
167    ///     type Input = i32;
168    ///     type Output = i32;
169    ///
170    ///     async fn call(&self, input: Self::Input) -> Self::Output {
171    ///         input + 1
172    ///     }
173    /// }
174    ///
175    /// let op = pipeline::new()
176    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
177    ///     .chain_ok(MyOp);
178    ///
179    /// let result = op.try_call(2).await;
180    /// assert_eq!(result, Ok(3));
181    /// ```
182    fn chain_ok<T>(self, op: T) -> TrySequential<Self, T>
183    where
184        T: op::Op<Input = Self::Output>,
185        Self: Sized,
186    {
187        TrySequential::new(self, op)
188    }
189}
190
191impl<Op, T, E> TryOp for Op
192where
193    Op: super::Op<Output = Result<T, E>>,
194    T: Send + Sync,
195    E: Send + Sync,
196{
197    type Input = Op::Input;
198    type Output = T;
199    type Error = E;
200
201    async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
202        self.call(input).await
203    }
204}
205
206// ================================================================
207// TryOp combinators
208// ================================================================
209pub struct MapOk<Op1, Op2> {
210    prev: Op1,
211    op: Op2,
212}
213
214impl<Op1, Op2> MapOk<Op1, Op2> {
215    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
216        Self { prev, op }
217    }
218}
219
220impl<Op1, Op2> op::Op for MapOk<Op1, Op2>
221where
222    Op1: TryOp,
223    Op2: super::Op<Input = Op1::Output>,
224{
225    type Input = Op1::Input;
226    type Output = Result<Op2::Output, Op1::Error>;
227
228    #[inline]
229    async fn call(&self, input: Self::Input) -> Self::Output {
230        match self.prev.try_call(input).await {
231            Ok(output) => Ok(self.op.call(output).await),
232            Err(err) => Err(err),
233        }
234    }
235}
236
237pub struct MapErr<Op1, Op2> {
238    prev: Op1,
239    op: Op2,
240}
241
242impl<Op1, Op2> MapErr<Op1, Op2> {
243    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
244        Self { prev, op }
245    }
246}
247
248// Result<T, E1> -> Result<T, E2>
249impl<Op1, Op2> op::Op for MapErr<Op1, Op2>
250where
251    Op1: TryOp,
252    Op2: super::Op<Input = Op1::Error>,
253{
254    type Input = Op1::Input;
255    type Output = Result<Op1::Output, Op2::Output>;
256
257    #[inline]
258    async fn call(&self, input: Self::Input) -> Self::Output {
259        match self.prev.try_call(input).await {
260            Ok(output) => Ok(output),
261            Err(err) => Err(self.op.call(err).await),
262        }
263    }
264}
265
266pub struct AndThen<Op1, Op2> {
267    prev: Op1,
268    op: Op2,
269}
270
271impl<Op1, Op2> AndThen<Op1, Op2> {
272    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
273        Self { prev, op }
274    }
275}
276
277impl<Op1, Op2> op::Op for AndThen<Op1, Op2>
278where
279    Op1: TryOp,
280    Op2: TryOp<Input = Op1::Output, Error = Op1::Error>,
281{
282    type Input = Op1::Input;
283    type Output = Result<Op2::Output, Op1::Error>;
284
285    #[inline]
286    async fn call(&self, input: Self::Input) -> Self::Output {
287        let output = self.prev.try_call(input).await?;
288        self.op.try_call(output).await
289    }
290}
291
292pub struct OrElse<Op1, Op2> {
293    prev: Op1,
294    op: Op2,
295}
296
297impl<Op1, Op2> OrElse<Op1, Op2> {
298    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
299        Self { prev, op }
300    }
301}
302
303impl<Op1, Op2> op::Op for OrElse<Op1, Op2>
304where
305    Op1: TryOp,
306    Op2: TryOp<Input = Op1::Error, Output = Op1::Output>,
307{
308    type Input = Op1::Input;
309    type Output = Result<Op1::Output, Op2::Error>;
310
311    #[inline]
312    async fn call(&self, input: Self::Input) -> Self::Output {
313        match self.prev.try_call(input).await {
314            Ok(output) => Ok(output),
315            Err(err) => self.op.try_call(err).await,
316        }
317    }
318}
319
320pub struct TrySequential<Op1, Op2> {
321    prev: Op1,
322    op: Op2,
323}
324
325impl<Op1, Op2> TrySequential<Op1, Op2> {
326    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
327        Self { prev, op }
328    }
329}
330
331impl<Op1, Op2> op::Op for TrySequential<Op1, Op2>
332where
333    Op1: TryOp,
334    Op2: op::Op<Input = Op1::Output>,
335{
336    type Input = Op1::Input;
337    type Output = Result<Op2::Output, Op1::Error>;
338
339    #[inline]
340    async fn call(&self, input: Self::Input) -> Self::Output {
341        match self.prev.try_call(input).await {
342            Ok(output) => Ok(self.op.call(output).await),
343            Err(err) => Err(err),
344        }
345    }
346}
347
348// TODO: Implement TryParallel
349// pub struct TryParallel<Op1, Op2> {
350//     op1: Op1,
351//     op2: Op2,
352// }
353
354// impl<Op1, Op2> TryParallel<Op1, Op2> {
355//     pub fn new(op1: Op1, op2: Op2) -> Self {
356//         Self { op1, op2 }
357//     }
358// }
359
360// impl<Op1, Op2> TryOp for TryParallel<Op1, Op2>
361// where
362//     Op1: TryOp,
363//     Op2: TryOp<Input = Op1::Input, Output = Op1::Output, Error = Op1::Error>,
364// {
365//     type Input = Op1::Input;
366//     type Output = (Op1::Output, Op2::Output);
367//     type Error = Op1::Error;
368
369//     #[inline]
370//     async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
371//         let (output1, output2) = tokio::join!(self.op1.try_call(input.clone()), self.op2.try_call(input));
372//         Ok((output1?, output2?))
373//     }
374// }
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::op::{map, then};

    #[tokio::test]
    async fn test_try_op() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let result = op.try_call(2).await.unwrap();
        assert_eq!(result, 2);
    }

    #[tokio::test]
    async fn test_try_op_err() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let result = op.try_call(1).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_try_batch_call() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
        let result = op.try_batch_call(2, vec![2, 4, 6]).await;
        assert_eq!(result, Ok(vec![3, 5, 7]));
    }

    #[tokio::test]
    async fn test_try_batch_call_err() {
        // A single failing input fails the whole batch.
        let op = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
        let result = op.try_batch_call(2, vec![2, 3, 4]).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_map_ok_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|x: i32| async move { x * 2 });
        let op3 = map(|x: i32| x - 1);

        let pipeline = MapOk::new(MapOk::new(op1, op2), op3);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_map_ok_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_ok(|x| x * 2)
            .map_ok(|x| x - 1);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_map_ok_passes_error_through() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_ok(|x| x * 2);

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_map_err_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|err: &str| async move { format!("Error: {}", err) });
        let op3 = map(|err: String| err.len());

        let pipeline = MapErr::new(MapErr::new(op1, op2), op3);

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err(15));
    }

    #[tokio::test]
    async fn test_map_err_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_err(|err| format!("Error: {}", err))
            .map_err(|err| err.len());

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err(15));
    }

    #[tokio::test]
    async fn test_map_err_passes_ok_through() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_err(|err| format!("Error: {}", err));

        let result = pipeline.try_call(2).await;
        assert_eq!(result, Ok(2));
    }

    #[tokio::test]
    async fn test_and_then_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|x: i32| async move { Ok(x * 2) });
        let op3 = map(|x: i32| Ok(x - 1));

        let pipeline = AndThen::new(AndThen::new(op1, op2), op3);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_and_then_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .and_then(|x| async move { Ok(x * 2) })
            .and_then(|x| async move { Ok(x - 1) });

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_and_then_short_circuits_on_error() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .and_then(|x| async move { Ok(x * 2) });

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_or_else_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|err: &str| async move { Err(format!("Error: {}", err)) });
        let op3 = map(|err: String| Ok::<i32, String>(err.len() as i32));

        let pipeline = OrElse::new(OrElse::new(op1, op2), op3);

        let result = pipeline.try_call(1).await.unwrap();
        assert_eq!(result, 15);
    }

    #[tokio::test]
    async fn test_or_else_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .or_else(|err| async move { Err(format!("Error: {}", err)) })
            .or_else(|err| async move { Ok::<i32, String>(err.len() as i32) });

        let result = pipeline.try_call(1).await.unwrap();
        assert_eq!(result, 15);
    }

    #[tokio::test]
    async fn test_or_else_passes_ok_through() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .or_else(|err| async move { Err(format!("Error: {}", err)) });

        let result = pipeline.try_call(2).await;
        assert_eq!(result, Ok(2));
    }

    #[tokio::test]
    async fn test_try_sequential_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = map(|x: i32| x + 1);

        let pipeline = TrySequential::new(op1, op2);

        let result = pipeline.try_call(2).await;
        assert_eq!(result, Ok(3));
    }

    #[tokio::test]
    async fn test_chain_ok() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .chain_ok(map(|x: i32| x + 1));

        assert_eq!(pipeline.try_call(2).await, Ok(3));
        assert_eq!(pipeline.try_call(1).await, Err("x is odd"));
    }
}