// rig/pipeline/try_op.rs

1use std::future::Future;
2
3use futures::stream;
4#[allow(unused_imports)] // Needed since this is used in a macro rule
5use futures::try_join;
6
7use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
8
9use super::op::{self};
10
11// ================================================================
12// Core TryOp trait
13// ================================================================
14pub trait TryOp: WasmCompatSend + WasmCompatSync {
15    type Input: WasmCompatSend + WasmCompatSync;
16    type Output: WasmCompatSend + WasmCompatSync;
17    type Error: WasmCompatSend + WasmCompatSync;
18
19    /// Execute the current op with the given input.
20    fn try_call(
21        &self,
22        input: Self::Input,
23    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
24
25    /// Execute the current op with the given inputs. `n` is the number of concurrent
26    /// inputs that will be processed concurrently.
27    /// If the op fails for one of the inputs, the entire operation will fail and the error will
28    /// be returned.
29    ///
30    /// # Example
31    /// ```rust
32    /// use rig::pipeline::{self, TryOp};
33    ///
34    /// let op = pipeline::new()
35    ///    .map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
36    ///
37    /// // Execute the pipeline concurrently with 2 inputs
38    /// let result = op.try_batch_call(2, vec![2, 4]).await;
39    /// assert_eq!(result, Ok(vec![3, 5]));
40    /// ```
41    fn try_batch_call<I>(
42        &self,
43        n: usize,
44        input: I,
45    ) -> impl Future<Output = Result<Vec<Self::Output>, Self::Error>> + WasmCompatSend
46    where
47        I: IntoIterator<Item = Self::Input> + WasmCompatSend,
48        I::IntoIter: WasmCompatSend,
49        Self: Sized,
50    {
51        use stream::{StreamExt, TryStreamExt};
52
53        async move {
54            stream::iter(input)
55                .map(|input| self.try_call(input))
56                .buffered(n)
57                .try_collect()
58                .await
59        }
60    }
61
62    /// Map the success return value (i.e., `Ok`) of the current op to a different value
63    /// using the provided closure.
64    ///
65    /// # Example
66    /// ```rust
67    /// use rig::pipeline::{self, TryOp};
68    ///
69    /// let op = pipeline::new()
70    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
71    ///     .map_ok(|x| x * 2);
72    ///
73    /// let result = op.try_call(2).await;
74    /// assert_eq!(result, Ok(4));
75    /// ```
76    fn map_ok<F, Output>(self, f: F) -> MapOk<Self, op::Map<F, Self::Output>>
77    where
78        F: Fn(Self::Output) -> Output + WasmCompatSend + WasmCompatSync,
79        Output: WasmCompatSend + WasmCompatSync,
80        Self: Sized,
81    {
82        MapOk::new(self, op::Map::new(f))
83    }
84
85    /// Map the error return value (i.e., `Err`) of the current op to a different value
86    /// using the provided closure.
87    ///
88    /// # Example
89    /// ```rust
90    /// use rig::pipeline::{self, TryOp};
91    ///
92    /// let op = pipeline::new()
93    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
94    ///     .map_err(|err| format!("Error: {}", err));
95    ///
96    /// let result = op.try_call(1).await;
97    /// assert_eq!(result, Err("Error: x is odd".to_string()));
98    /// ```
99    fn map_err<F, E>(self, f: F) -> MapErr<Self, op::Map<F, Self::Error>>
100    where
101        F: Fn(Self::Error) -> E + WasmCompatSend + WasmCompatSync,
102        E: WasmCompatSend + WasmCompatSync,
103        Self: Sized,
104    {
105        MapErr::new(self, op::Map::new(f))
106    }
107
108    /// Chain a function to the current op. The function will only be called
109    /// if the current op returns `Ok`. The function must return a `Future` with value
110    /// `Result<T, E>` where `E` is the same type as the error type of the current.
111    ///
112    /// # Example
113    /// ```rust
114    /// use rig::pipeline::{self, TryOp};
115    ///
116    /// let op = pipeline::new()
117    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
118    ///     .and_then(|x| async move { Ok(x * 2) });
119    ///
120    /// let result = op.try_call(2).await;
121    /// assert_eq!(result, Ok(4));
122    /// ```
123    fn and_then<F, Fut, Output>(self, f: F) -> AndThen<Self, op::Then<F, Self::Output>>
124    where
125        F: Fn(Self::Output) -> Fut + WasmCompatSend + WasmCompatSync,
126        Fut: Future<Output = Result<Output, Self::Error>> + WasmCompatSend + WasmCompatSync,
127        Output: WasmCompatSend + WasmCompatSync,
128        Self: Sized,
129    {
130        AndThen::new(self, op::Then::new(f))
131    }
132
133    /// Chain a function `f` to the current op. The function `f` will only be called
134    /// if the current op returns `Err`. `f` must return a `Future` with value
135    /// `Result<T, E>` where `T` is the same type as the output type of the current op.
136    ///
137    /// # Example
138    /// ```rust
139    /// use rig::pipeline::{self, TryOp};
140    ///
141    /// let op = pipeline::new()
142    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
143    ///     .or_else(|err| async move { Err(format!("Error: {}", err)) });
144    ///
145    /// let result = op.try_call(1).await;
146    /// assert_eq!(result, Err("Error: x is odd".to_string()));
147    /// ```
148    fn or_else<F, Fut, E>(self, f: F) -> OrElse<Self, op::Then<F, Self::Error>>
149    where
150        F: Fn(Self::Error) -> Fut + WasmCompatSend + WasmCompatSync,
151        Fut: Future<Output = Result<Self::Output, E>> + WasmCompatSend + WasmCompatSync,
152        E: WasmCompatSend + WasmCompatSync,
153        Self: Sized,
154    {
155        OrElse::new(self, op::Then::new(f))
156    }
157
158    /// Chain a new op `op` to the current op. The new op will be called with the success
159    /// return value of the current op (i.e.: `Ok` value). The chained op can be any type that
160    /// implements the `Op` trait.
161    ///
162    /// # Example
163    /// ```rust
164    /// use rig::pipeline::{self, TryOp};
165    ///
166    /// struct AddOne;
167    ///
168    /// impl Op for AddOne {
169    ///     type Input = i32;
170    ///     type Output = i32;
171    ///
172    ///     async fn call(&self, input: Self::Input) -> Self::Output {
173    ///         input + 1
174    ///     }
175    /// }
176    ///
177    /// let op = pipeline::new()
178    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
179    ///     .chain_ok(MyOp);
180    ///
181    /// let result = op.try_call(2).await;
182    /// assert_eq!(result, Ok(3));
183    /// ```
184    fn chain_ok<T>(self, op: T) -> TrySequential<Self, T>
185    where
186        T: op::Op<Input = Self::Output>,
187        Self: Sized,
188    {
189        TrySequential::new(self, op)
190    }
191}
192
193impl<Op, T, E> TryOp for Op
194where
195    Op: super::Op<Output = Result<T, E>>,
196    T: WasmCompatSend + WasmCompatSync,
197    E: WasmCompatSend + WasmCompatSync,
198{
199    type Input = Op::Input;
200    type Output = T;
201    type Error = E;
202
203    async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
204        self.call(input).await
205    }
206}
207
208// ================================================================
209// TryOp combinators
210// ================================================================
/// Combinator returned by [`TryOp::map_ok`]: runs `prev`, then transforms its `Ok`
/// value with `op`. An `Err` from `prev` is forwarded untouched.
pub struct MapOk<Op1, Op2> {
    prev: Op1,
    op: Op2,
}

impl<Op1, Op2> MapOk<Op1, Op2> {
    /// Pairs the fallible op `prev` with the success-value transformer `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        MapOk { prev, op }
    }
}
221
222impl<Op1, Op2> op::Op for MapOk<Op1, Op2>
223where
224    Op1: TryOp,
225    Op2: super::Op<Input = Op1::Output>,
226{
227    type Input = Op1::Input;
228    type Output = Result<Op2::Output, Op1::Error>;
229
230    #[inline]
231    async fn call(&self, input: Self::Input) -> Self::Output {
232        match self.prev.try_call(input).await {
233            Ok(output) => Ok(self.op.call(output).await),
234            Err(err) => Err(err),
235        }
236    }
237}
238
/// Combinator returned by [`TryOp::map_err`]: runs `prev`, then transforms its `Err`
/// value with `op`. An `Ok` from `prev` is forwarded untouched.
pub struct MapErr<Op1, Op2> {
    prev: Op1,
    op: Op2,
}

impl<Op1, Op2> MapErr<Op1, Op2> {
    /// Pairs the fallible op `prev` with the error transformer `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        MapErr { prev, op }
    }
}
249
250// Result<T, E1> -> Result<T, E2>
251impl<Op1, Op2> op::Op for MapErr<Op1, Op2>
252where
253    Op1: TryOp,
254    Op2: super::Op<Input = Op1::Error>,
255{
256    type Input = Op1::Input;
257    type Output = Result<Op1::Output, Op2::Output>;
258
259    #[inline]
260    async fn call(&self, input: Self::Input) -> Self::Output {
261        match self.prev.try_call(input).await {
262            Ok(output) => Ok(output),
263            Err(err) => Err(self.op.call(err).await),
264        }
265    }
266}
267
/// Combinator returned by [`TryOp::and_then`]: runs `prev` and, on success, feeds the
/// `Ok` value into the fallible op `op`. Both ops share the same error type.
pub struct AndThen<Op1, Op2> {
    prev: Op1,
    op: Op2,
}

impl<Op1, Op2> AndThen<Op1, Op2> {
    /// Pairs the fallible op `prev` with the fallible continuation `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        AndThen { prev, op }
    }
}
278
279impl<Op1, Op2> op::Op for AndThen<Op1, Op2>
280where
281    Op1: TryOp,
282    Op2: TryOp<Input = Op1::Output, Error = Op1::Error>,
283{
284    type Input = Op1::Input;
285    type Output = Result<Op2::Output, Op1::Error>;
286
287    #[inline]
288    async fn call(&self, input: Self::Input) -> Self::Output {
289        let output = self.prev.try_call(input).await?;
290        self.op.try_call(output).await
291    }
292}
293
/// Combinator returned by [`TryOp::or_else`]: runs `prev` and, on failure, feeds the
/// `Err` value into the fallible recovery op `op`. Both ops share the same output type.
pub struct OrElse<Op1, Op2> {
    prev: Op1,
    op: Op2,
}

impl<Op1, Op2> OrElse<Op1, Op2> {
    /// Pairs the fallible op `prev` with the error-recovery op `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        OrElse { prev, op }
    }
}
304
305impl<Op1, Op2> op::Op for OrElse<Op1, Op2>
306where
307    Op1: TryOp,
308    Op2: TryOp<Input = Op1::Error, Output = Op1::Output>,
309{
310    type Input = Op1::Input;
311    type Output = Result<Op1::Output, Op2::Error>;
312
313    #[inline]
314    async fn call(&self, input: Self::Input) -> Self::Output {
315        match self.prev.try_call(input).await {
316            Ok(output) => Ok(output),
317            Err(err) => self.op.try_call(err).await,
318        }
319    }
320}
321
/// Combinator returned by [`TryOp::chain_ok`]: runs the fallible op `prev` and, on
/// success, passes the `Ok` value to the infallible op `op`.
pub struct TrySequential<Op1, Op2> {
    prev: Op1,
    op: Op2,
}

impl<Op1, Op2> TrySequential<Op1, Op2> {
    /// Pairs the fallible op `prev` with the infallible follow-up op `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        TrySequential { prev, op }
    }
}
332
333impl<Op1, Op2> op::Op for TrySequential<Op1, Op2>
334where
335    Op1: TryOp,
336    Op2: op::Op<Input = Op1::Output>,
337{
338    type Input = Op1::Input;
339    type Output = Result<Op2::Output, Op1::Error>;
340
341    #[inline]
342    async fn call(&self, input: Self::Input) -> Self::Output {
343        match self.prev.try_call(input).await {
344            Ok(output) => Ok(self.op.call(output).await),
345            Err(err) => Err(err),
346        }
347    }
348}
349
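// TODO: Implement TryParallel. Note: the sketch below calls `input.clone()`, so it also needs an
// `Op1::Input: Clone` bound; `futures::try_join!` (imported above) would short-circuit on the first
// error instead of waiting for both ops to finish.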
351// pub struct TryParallel<Op1, Op2> {
352//     op1: Op1,
353//     op2: Op2,
354// }
355
356// impl<Op1, Op2> TryParallel<Op1, Op2> {
357//     pub fn new(op1: Op1, op2: Op2) -> Self {
358//         Self { op1, op2 }
359//     }
360// }
361
362// impl<Op1, Op2> TryOp for TryParallel<Op1, Op2>
363// where
364//     Op1: TryOp,
365//     Op2: TryOp<Input = Op1::Input, Output = Op1::Output, Error = Op1::Error>,
366// {
367//     type Input = Op1::Input;
368//     type Output = (Op1::Output, Op2::Output);
369//     type Error = Op1::Error;
370
371//     #[inline]
372//     async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
373//         let (output1, output2) = tokio::join!(self.op1.try_call(input.clone()), self.op2.try_call(input));
374//         Ok((output1?, output2?))
375//     }
376// }
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::op::{map, then};

    #[tokio::test]
    async fn test_try_op() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let result = op.try_call(2).await.unwrap();
        assert_eq!(result, 2);
    }

    #[tokio::test]
    async fn test_map_ok_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|x: i32| async move { x * 2 });
        let op3 = map(|x: i32| x - 1);

        let pipeline = MapOk::new(MapOk::new(op1, op2), op3);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_map_ok_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_ok(|x| x * 2)
            .map_ok(|x| x - 1);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    // `map_ok` must not touch the error branch.
    #[tokio::test]
    async fn test_map_ok_passes_err_through() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_ok(|x| x * 2);

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_map_err_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|err: &str| async move { format!("Error: {err}") });
        let op3 = map(|err: String| err.len());

        let pipeline = MapErr::new(MapErr::new(op1, op2), op3);

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err(15));
    }

    #[tokio::test]
    async fn test_map_err_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_err(|err| format!("Error: {err}"))
            .map_err(|err| err.len());

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err(15));
    }

    #[tokio::test]
    async fn test_and_then_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|x: i32| async move { Ok(x * 2) });
        let op3 = map(|x: i32| Ok(x - 1));

        let pipeline = AndThen::new(AndThen::new(op1, op2), op3);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_and_then_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .and_then(|x| async move { Ok(x * 2) })
            .and_then(|x| async move { Ok(x - 1) });

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    // `and_then` must short-circuit: the continuation never runs on `Err`.
    #[tokio::test]
    async fn test_and_then_short_circuits_on_err() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .and_then(|x| async move { Ok(x * 2) });

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_or_else_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|err: &str| async move { Err(format!("Error: {err}")) });
        let op3 = map(|err: String| Ok::<i32, String>(err.len() as i32));

        let pipeline = OrElse::new(OrElse::new(op1, op2), op3);

        let result = pipeline.try_call(1).await.unwrap();
        assert_eq!(result, 15);
    }

    #[tokio::test]
    async fn test_or_else_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .or_else(|err| async move { Err(format!("Error: {err}")) })
            .or_else(|err| async move { Ok::<i32, String>(err.len() as i32) });

        let result = pipeline.try_call(1).await.unwrap();
        assert_eq!(result, 15);
    }

    // `chain_ok` (TrySequential) runs the follow-up op on `Ok` and skips it on `Err`.
    #[tokio::test]
    async fn test_chain_ok() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .chain_ok(map(|x: i32| x + 1));

        assert_eq!(pipeline.try_call(2).await, Ok(3));
        assert_eq!(pipeline.try_call(1).await, Err("x is odd"));
    }

    // Outputs come back in input order.
    #[tokio::test]
    async fn test_try_batch_call() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });

        let result = op.try_batch_call(2, vec![2, 4, 6]).await;
        assert_eq!(result, Ok(vec![3, 5, 7]));
    }

    // A single failing input fails the whole batch.
    #[tokio::test]
    async fn test_try_batch_call_err() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });

        let result = op.try_batch_call(2, vec![2, 3, 4]).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_try_batch_call_empty() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });

        let result = op.try_batch_call(2, Vec::new()).await;
        assert_eq!(result, Ok(vec![]));
    }
}