Skip to main content

rig_core/pipeline/
try_op.rs

1use std::future::Future;
2
3use futures::stream;
4#[allow(unused_imports)] // Needed since this is used in a macro rule
5use futures::try_join;
6
7use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
8
9use super::op::{self};
10
11// ================================================================
12// Core TryOp trait
13// ================================================================
14pub trait TryOp: WasmCompatSend + WasmCompatSync {
15    type Input: WasmCompatSend + WasmCompatSync;
16    type Output: WasmCompatSend + WasmCompatSync;
17    type Error: WasmCompatSend + WasmCompatSync;
18
19    /// Execute the current op with the given input.
20    fn try_call(
21        &self,
22        input: Self::Input,
23    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
24
25    /// Execute the current op with the given inputs. `n` is the number of concurrent
26    /// inputs that will be processed concurrently.
27    /// If the op fails for one of the inputs, the entire operation will fail and the error will
28    /// be returned.
29    ///
30    /// # Example
31    /// ```no_run
32    /// use rig_core::pipeline::{self, TryOp};
33    ///
34    /// # async fn run() {
35    /// let op = pipeline::new()
36    ///    .map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
37    ///
38    /// // Execute the pipeline concurrently with 2 inputs
39    /// let result = op.try_batch_call(2, vec![2, 4]).await;
40    /// assert_eq!(result, Ok(vec![3, 5]));
41    /// # }
42    /// ```
43    fn try_batch_call<I>(
44        &self,
45        n: usize,
46        input: I,
47    ) -> impl Future<Output = Result<Vec<Self::Output>, Self::Error>> + WasmCompatSend
48    where
49        I: IntoIterator<Item = Self::Input> + WasmCompatSend,
50        I::IntoIter: WasmCompatSend,
51        Self: Sized,
52    {
53        use stream::{StreamExt, TryStreamExt};
54
55        async move {
56            stream::iter(input)
57                .map(|input| self.try_call(input))
58                .buffered(n)
59                .try_collect()
60                .await
61        }
62    }
63
64    /// Map the success return value (i.e., `Ok`) of the current op to a different value
65    /// using the provided closure.
66    ///
67    /// # Example
68    /// ```no_run
69    /// use rig_core::pipeline::{self, TryOp};
70    ///
71    /// # async fn run() {
72    /// let op = pipeline::new()
73    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
74    ///     .map_ok(|x| x * 2);
75    ///
76    /// let result = op.try_call(2).await;
77    /// assert_eq!(result, Ok(4));
78    /// # }
79    /// ```
80    fn map_ok<F, Output>(self, f: F) -> MapOk<Self, op::Map<F, Self::Output>>
81    where
82        F: Fn(Self::Output) -> Output + WasmCompatSend + WasmCompatSync,
83        Output: WasmCompatSend + WasmCompatSync,
84        Self: Sized,
85    {
86        MapOk::new(self, op::Map::new(f))
87    }
88
89    /// Map the error return value (i.e., `Err`) of the current op to a different value
90    /// using the provided closure.
91    ///
92    /// # Example
93    /// ```no_run
94    /// use rig_core::pipeline::{self, TryOp};
95    ///
96    /// # async fn run() {
97    /// let op = pipeline::new()
98    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
99    ///     .map_err(|err| format!("Error: {}", err));
100    ///
101    /// let result = op.try_call(1).await;
102    /// assert_eq!(result, Err("Error: x is odd".to_string()));
103    /// # }
104    /// ```
105    fn map_err<F, E>(self, f: F) -> MapErr<Self, op::Map<F, Self::Error>>
106    where
107        F: Fn(Self::Error) -> E + WasmCompatSend + WasmCompatSync,
108        E: WasmCompatSend + WasmCompatSync,
109        Self: Sized,
110    {
111        MapErr::new(self, op::Map::new(f))
112    }
113
114    /// Chain a function to the current op. The function will only be called
115    /// if the current op returns `Ok`. The function must return a `Future` with value
116    /// `Result<T, E>` where `E` is the same type as the error type of the current.
117    ///
118    /// # Example
119    /// ```no_run
120    /// use rig_core::pipeline::{self, TryOp};
121    ///
122    /// # async fn run() {
123    /// let op = pipeline::new()
124    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
125    ///     .and_then(|x| async move { Ok(x * 2) });
126    ///
127    /// let result = op.try_call(2).await;
128    /// assert_eq!(result, Ok(4));
129    /// # }
130    /// ```
131    fn and_then<F, Fut, Output>(self, f: F) -> AndThen<Self, op::Then<F, Self::Output>>
132    where
133        F: Fn(Self::Output) -> Fut + WasmCompatSend + WasmCompatSync,
134        Fut: Future<Output = Result<Output, Self::Error>> + WasmCompatSend + WasmCompatSync,
135        Output: WasmCompatSend + WasmCompatSync,
136        Self: Sized,
137    {
138        AndThen::new(self, op::Then::new(f))
139    }
140
141    /// Chain a function `f` to the current op. The function `f` will only be called
142    /// if the current op returns `Err`. `f` must return a `Future` with value
143    /// `Result<T, E>` where `T` is the same type as the output type of the current op.
144    ///
145    /// # Example
146    /// ```no_run
147    /// use rig_core::pipeline::{self, TryOp};
148    ///
149    /// # async fn run() {
150    /// let op = pipeline::new()
151    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
152    ///     .or_else(|err| async move { Err(format!("Error: {}", err)) });
153    ///
154    /// let result = op.try_call(1).await;
155    /// assert_eq!(result, Err("Error: x is odd".to_string()));
156    /// # }
157    /// ```
158    fn or_else<F, Fut, E>(self, f: F) -> OrElse<Self, op::Then<F, Self::Error>>
159    where
160        F: Fn(Self::Error) -> Fut + WasmCompatSend + WasmCompatSync,
161        Fut: Future<Output = Result<Self::Output, E>> + WasmCompatSend + WasmCompatSync,
162        E: WasmCompatSend + WasmCompatSync,
163        Self: Sized,
164    {
165        OrElse::new(self, op::Then::new(f))
166    }
167
168    /// Chain a new op `op` to the current op. The new op will be called with the success
169    /// return value of the current op (i.e.: `Ok` value). The chained op can be any type that
170    /// implements the `Op` trait.
171    ///
172    /// # Example
173    /// ```no_run
174    /// use rig_core::pipeline::{self, Op, TryOp};
175    ///
176    /// # async fn run() {
177    /// struct AddOne;
178    ///
179    /// impl Op for AddOne {
180    ///     type Input = i32;
181    ///     type Output = i32;
182    ///
183    ///     async fn call(&self, input: Self::Input) -> Self::Output {
184    ///         input + 1
185    ///     }
186    /// }
187    ///
188    /// let op = pipeline::new()
189    ///     .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
190    ///     .chain_ok(AddOne);
191    ///
192    /// let result = op.try_call(2).await;
193    /// assert_eq!(result, Ok(3));
194    /// # }
195    /// ```
196    fn chain_ok<T>(self, op: T) -> TrySequential<Self, T>
197    where
198        T: op::Op<Input = Self::Output>,
199        Self: Sized,
200    {
201        TrySequential::new(self, op)
202    }
203}
204
205impl<Op, T, E> TryOp for Op
206where
207    Op: super::Op<Output = Result<T, E>>,
208    T: WasmCompatSend + WasmCompatSync,
209    E: WasmCompatSend + WasmCompatSync,
210{
211    type Input = Op::Input;
212    type Output = T;
213    type Error = E;
214
215    async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
216        self.call(input).await
217    }
218}
219
220// ================================================================
221// TryOp combinators
222// ================================================================
223pub struct MapOk<Op1, Op2> {
224    prev: Op1,
225    op: Op2,
226}
227
228impl<Op1, Op2> MapOk<Op1, Op2> {
229    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
230        Self { prev, op }
231    }
232}
233
234impl<Op1, Op2> op::Op for MapOk<Op1, Op2>
235where
236    Op1: TryOp,
237    Op2: super::Op<Input = Op1::Output>,
238{
239    type Input = Op1::Input;
240    type Output = Result<Op2::Output, Op1::Error>;
241
242    #[inline]
243    async fn call(&self, input: Self::Input) -> Self::Output {
244        match self.prev.try_call(input).await {
245            Ok(output) => Ok(self.op.call(output).await),
246            Err(err) => Err(err),
247        }
248    }
249}
250
251pub struct MapErr<Op1, Op2> {
252    prev: Op1,
253    op: Op2,
254}
255
256impl<Op1, Op2> MapErr<Op1, Op2> {
257    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
258        Self { prev, op }
259    }
260}
261
262// Result<T, E1> -> Result<T, E2>
263impl<Op1, Op2> op::Op for MapErr<Op1, Op2>
264where
265    Op1: TryOp,
266    Op2: super::Op<Input = Op1::Error>,
267{
268    type Input = Op1::Input;
269    type Output = Result<Op1::Output, Op2::Output>;
270
271    #[inline]
272    async fn call(&self, input: Self::Input) -> Self::Output {
273        match self.prev.try_call(input).await {
274            Ok(output) => Ok(output),
275            Err(err) => Err(self.op.call(err).await),
276        }
277    }
278}
279
280pub struct AndThen<Op1, Op2> {
281    prev: Op1,
282    op: Op2,
283}
284
285impl<Op1, Op2> AndThen<Op1, Op2> {
286    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
287        Self { prev, op }
288    }
289}
290
291impl<Op1, Op2> op::Op for AndThen<Op1, Op2>
292where
293    Op1: TryOp,
294    Op2: TryOp<Input = Op1::Output, Error = Op1::Error>,
295{
296    type Input = Op1::Input;
297    type Output = Result<Op2::Output, Op1::Error>;
298
299    #[inline]
300    async fn call(&self, input: Self::Input) -> Self::Output {
301        let output = self.prev.try_call(input).await?;
302        self.op.try_call(output).await
303    }
304}
305
306pub struct OrElse<Op1, Op2> {
307    prev: Op1,
308    op: Op2,
309}
310
311impl<Op1, Op2> OrElse<Op1, Op2> {
312    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
313        Self { prev, op }
314    }
315}
316
317impl<Op1, Op2> op::Op for OrElse<Op1, Op2>
318where
319    Op1: TryOp,
320    Op2: TryOp<Input = Op1::Error, Output = Op1::Output>,
321{
322    type Input = Op1::Input;
323    type Output = Result<Op1::Output, Op2::Error>;
324
325    #[inline]
326    async fn call(&self, input: Self::Input) -> Self::Output {
327        match self.prev.try_call(input).await {
328            Ok(output) => Ok(output),
329            Err(err) => self.op.try_call(err).await,
330        }
331    }
332}
333
334pub struct TrySequential<Op1, Op2> {
335    prev: Op1,
336    op: Op2,
337}
338
339impl<Op1, Op2> TrySequential<Op1, Op2> {
340    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
341        Self { prev, op }
342    }
343}
344
345impl<Op1, Op2> op::Op for TrySequential<Op1, Op2>
346where
347    Op1: TryOp,
348    Op2: op::Op<Input = Op1::Output>,
349{
350    type Input = Op1::Input;
351    type Output = Result<Op2::Output, Op1::Error>;
352
353    #[inline]
354    async fn call(&self, input: Self::Input) -> Self::Output {
355        match self.prev.try_call(input).await {
356            Ok(output) => Ok(self.op.call(output).await),
357            Err(err) => Err(err),
358        }
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::op::{map, then};

    #[tokio::test]
    async fn test_try_op() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let result = op.try_call(2).await.unwrap();
        assert_eq!(result, 2);
    }

    #[tokio::test]
    async fn test_try_op_err() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let result = op.try_call(1).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_try_batch_call_preserves_order() {
        let op = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
        let result = op.try_batch_call(2, vec![2, 4, 6, 8]).await;
        assert_eq!(result, Ok(vec![3, 5, 7, 9]));
    }

    #[tokio::test]
    async fn test_try_batch_call_err() {
        // A single failing input fails the whole batch.
        let op = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
        let result = op.try_batch_call(2, vec![2, 3, 4]).await;
        assert_eq!(result, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_map_ok_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|x: i32| async move { x * 2 });
        let op3 = map(|x: i32| x - 1);

        let pipeline = MapOk::new(MapOk::new(op1, op2), op3);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_map_ok_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_ok(|x| x * 2)
            .map_ok(|x| x - 1);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_map_ok_passes_err_through() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_ok(|x| x * 2);

        assert_eq!(pipeline.try_call(1).await, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_map_err_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|err: &str| async move { format!("Error: {err}") });
        let op3 = map(|err: String| err.len());

        let pipeline = MapErr::new(MapErr::new(op1, op2), op3);

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err(15));
    }

    #[tokio::test]
    async fn test_map_err_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_err(|err| format!("Error: {err}"))
            .map_err(|err| err.len());

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err(15));
    }

    #[tokio::test]
    async fn test_map_err_passes_ok_through() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .map_err(|err| format!("Error: {err}"));

        assert_eq!(pipeline.try_call(2).await, Ok(2));
    }

    #[tokio::test]
    async fn test_and_then_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|x: i32| async move { Ok(x * 2) });
        let op3 = map(|x: i32| Ok(x - 1));

        let pipeline = AndThen::new(AndThen::new(op1, op2), op3);

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_and_then_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .and_then(|x| async move { Ok(x * 2) })
            .and_then(|x| async move { Ok(x - 1) });

        let result = pipeline.try_call(2).await.unwrap();
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_and_then_short_circuits_on_err() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .and_then(|x| async move { Ok(x * 2) });

        assert_eq!(pipeline.try_call(1).await, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_or_else_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|err: &str| async move { Err(format!("Error: {err}")) });
        let op3 = map(|err: String| Ok::<i32, String>(err.len() as i32));

        let pipeline = OrElse::new(OrElse::new(op1, op2), op3);

        let result = pipeline.try_call(1).await.unwrap();
        assert_eq!(result, 15);
    }

    #[tokio::test]
    async fn test_or_else_chain() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .or_else(|err| async move { Err(format!("Error: {err}")) })
            .or_else(|err| async move { Ok::<i32, String>(err.len() as i32) });

        let result = pipeline.try_call(1).await.unwrap();
        assert_eq!(result, 15);
    }

    #[tokio::test]
    async fn test_or_else_passes_ok_through() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .or_else(|err| async move { Err(format!("Error: {err}")) });

        assert_eq!(pipeline.try_call(2).await, Ok(2));
    }

    #[tokio::test]
    async fn test_try_sequential_constructor() {
        let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") });
        let op2 = then(|x: i32| async move { x * 10 });

        let pipeline = TrySequential::new(op1, op2);

        assert_eq!(pipeline.try_call(2).await, Ok(20));
        assert_eq!(pipeline.try_call(3).await, Err("x is odd"));
    }

    #[tokio::test]
    async fn test_chain_ok() {
        let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") })
            .chain_ok(map(|x: i32| x + 1));

        assert_eq!(pipeline.try_call(2).await, Ok(3));
        assert_eq!(pipeline.try_call(1).await, Err("x is odd"));
    }
}