rig/pipeline/
conditional.rs

/// Creates an `Op` that conditionally dispatches to one of multiple sub-ops
/// based on the variant of the input enum.
///
/// **Important Requirements**:
/// 1. The enum must be defined as a single-type-parameter wrapper, e.g.
///    ```rust
///    enum MyEnum<T> {
///        VariantA(T),
///        VariantB(T),
///    }
///    ```
///    This allows all variants to share the same inner type (`T`).
/// 2. All sub-ops must have the same `Input` type (this `T`) and the same `Output`.
///    That is, for each variant, the corresponding op must implement
///    `Op<Input = T, Output = Out>`.
/// 3. Every variant of the enum must be listed exactly once. The generated `match`
///    is exhaustive, so a missing variant is a compile error; a repeated variant is
///    rejected as a duplicate field / generic parameter.
/// 4. The enum is named by a plain identifier that must be in scope at the call
///    site, and that name must differ from every listed variant name, because each
///    variant name is also used as a generic parameter of the generated op.
///
/// The generated op names the `Op` trait through this crate's own path, so
/// invoking the macro does not require `Op` to be imported. Calling `.call(...)`
/// on the result with method syntax still needs the trait in scope, as usual.
///
/// # Example
/// ```rust
/// use rig::pipeline::*;
/// use rig::conditional;
///
/// #[tokio::main]
/// async fn main() {
///     #[derive(Debug)]
///     enum ExampleEnum<T> {
///         Variant1(T),
///         Variant2(T),
///     }
///
///     // Creates a pipeline Op that adds 1 if it's Variant1, or doubles if it's Variant2
///     let op1 = map(|x: i32| x + 1);
///     let op2 = map(|x: i32| x * 2);
///
///     let conditional = conditional!(ExampleEnum,
///         Variant1 => op1,
///         Variant2 => op2,
///     );
///
///     let result1 = conditional.call(ExampleEnum::Variant1(2)).await;
///     assert_eq!(result1, 3);
///
///     let result2 = conditional.call(ExampleEnum::Variant2(3)).await;
///     assert_eq!(result2, 6);
/// }
/// ```
#[macro_export]
macro_rules! conditional {
    ($enum:ident, $( $variant:ident => $op:expr ),+ $(,)?) => {
        {
            // One field (and one generic parameter) per variant, each named after the
            // variant so the `match` arm below can reach its op as `self.$variant`.
            // Variant names are CamelCase, hence the lint allowance on the fields.
            #[allow(non_snake_case)]
            struct ConditionalOp<$($variant),+> {
                $(
                    $variant: $variant,
                )+
            }

            // `macro_rules!` hygiene does not apply to type names, so the shared
            // input/output parameters use names that cannot clash with a caller's
            // enum or variant names (e.g. a variant called `Value` or `Out`).
            impl<__ConditionalInput, __ConditionalOutput, $($variant),+> $crate::pipeline::Op
                for ConditionalOp<$($variant),+>
            where
                $(
                    $variant: $crate::pipeline::Op<
                        Input = __ConditionalInput,
                        Output = __ConditionalOutput,
                    >
                ),+,
                __ConditionalInput: Send + Sync,
                __ConditionalOutput: Send + Sync,
            {
                type Input = $enum<__ConditionalInput>;
                type Output = __ConditionalOutput;

                fn call(
                    &self,
                    input: Self::Input,
                ) -> impl ::core::future::Future<Output = Self::Output> + Send {
                    async move {
                        // Exhaustive over the listed variants: each one forwards its
                        // payload to the op registered for it.
                        match input {
                            $(
                                $enum::$variant(val) => self.$variant.call(val).await
                            ),+
                        }
                    }
                }
            }

            ConditionalOp { $($variant: $op),+ }
        }
    };
}
82
/// Creates a `TryOp` that conditionally dispatches to one of multiple sub-ops
/// based on the variant of the input enum, returning a `Result`.
///
/// **Important Requirements**:
/// 1. The enum must be defined as a single-type-parameter wrapper, e.g.
///    ```rust
///    enum MyEnum<T> {
///        VariantA(T),
///        VariantB(T),
///    }
///    ```
///    This allows all variants to share the same inner type (`T`).
/// 2. All sub-ops must have the same `Input` type (this `T`), the same `Output`,
///    and the same `Error`. That is, for each variant, the corresponding op must
///    implement `TryOp<Input = T, Output = Out, Error = E>`.
/// 3. Every variant of the enum must be listed exactly once. The generated `match`
///    is exhaustive, so a missing variant is a compile error; a repeated variant is
///    rejected as a duplicate field / generic parameter.
/// 4. The enum is named by a plain identifier that must be in scope at the call
///    site, and that name must differ from every listed variant name, because each
///    variant name is also used as a generic parameter of the generated op.
///
/// Whichever sub-op is selected, its `Result` (including any `Err`) is returned
/// unchanged. The generated op names `TryOp` and `Result` by fully-qualified paths,
/// so neither an import of `TryOp` nor a crate-local `Result` alias at the call
/// site affects the expansion. Calling `.try_call(...)` with method syntax still
/// needs the trait in scope, as usual.
///
/// # Example
/// ```rust
/// use rig::pipeline::*;
/// use rig::try_conditional;
///
/// #[tokio::main]
/// async fn main() {
///     #[derive(Debug)]
///     enum ExampleEnum<T> {
///         Variant1(T),
///         Variant2(T),
///     }
///
///     // Creates a pipeline TryOp that adds 1 or doubles, returning Ok(...) or Err(...)
///     let op1 = map(|x: i32| Ok::<_, String>(x + 1));
///     let op2 = map(|x: i32| Ok::<_, String>(x * 2));
///
///     let try_conditional = try_conditional!(ExampleEnum,
///         Variant1 => op1,
///         Variant2 => op2,
///     );
///
///     let result1 = try_conditional.try_call(ExampleEnum::Variant1(2)).await;
///     assert_eq!(result1, Ok(3));
///
///     let result2 = try_conditional.try_call(ExampleEnum::Variant2(3)).await;
///     assert_eq!(result2, Ok(6));
/// }
/// ```
#[macro_export]
macro_rules! try_conditional {
    ($enum:ident, $( $variant:ident => $op:expr ),+ $(,)?) => {
        {
            // One field (and one generic parameter) per variant, each named after the
            // variant so the `match` arm below can reach its op as `self.$variant`.
            // Variant names are CamelCase, hence the lint allowance on the fields.
            #[allow(non_snake_case)]
            struct TryConditionalOp<$( $variant ),+> {
                $( $variant: $variant ),+
            }

            // `macro_rules!` hygiene does not apply to type names, so the shared
            // parameters use names that cannot clash with a caller's enum or variant
            // names (e.g. a variant called `Value` or `Err`).
            impl<__ConditionalInput, __ConditionalOutput, __ConditionalError, $( $variant ),+>
                $crate::pipeline::TryOp for TryConditionalOp<$( $variant ),+>
            where
                $(
                    $variant: $crate::pipeline::TryOp<
                        Input = __ConditionalInput,
                        Output = __ConditionalOutput,
                        Error = __ConditionalError,
                    >
                ),+,
                __ConditionalInput: Send + Sync,
                __ConditionalOutput: Send + Sync,
                __ConditionalError: Send + Sync,
            {
                type Input = $enum<__ConditionalInput>;
                type Output = __ConditionalOutput;
                type Error = __ConditionalError;

                // Fully-qualified `Result`: callers frequently shadow `Result` with a
                // one-parameter alias, which would reject the two arguments used here.
                async fn try_call(
                    &self,
                    input: Self::Input,
                ) -> ::core::result::Result<Self::Output, Self::Error> {
                    match input {
                        $(
                            $enum::$variant(val) => self.$variant.try_call(val).await
                        ),+
                    }
                }
            }

            TryConditionalOp { $($variant: $op),+ }
        }
    };
}
159
#[cfg(test)]
mod tests {
    use crate::pipeline::*;

    /// Each variant is routed to its own op.
    #[tokio::test]
    async fn test_conditional_op() {
        enum ExampleEnum<T> {
            Variant1(T),
            Variant2(T),
        }

        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 2);

        let conditional = conditional!(ExampleEnum,
            Variant1 => op1,
            Variant2 => op2
        );

        let result1 = conditional.call(ExampleEnum::Variant1(2)).await;
        assert_eq!(result1, 3);

        let result2 = conditional.call(ExampleEnum::Variant2(3)).await;
        assert_eq!(result2, 6);
    }

    /// Each variant is routed to its own fallible op (success path).
    #[tokio::test]
    async fn test_try_conditional_op() {
        enum ExampleEnum<T> {
            Variant1(T),
            Variant2(T),
        }

        let op1 = map(|x: i32| Ok::<_, String>(x + 1));
        let op2 = map(|x: i32| Ok::<_, String>(x * 2));

        let try_conditional = try_conditional!(ExampleEnum,
            Variant1 => op1,
            Variant2 => op2
        );

        let result1 = try_conditional.try_call(ExampleEnum::Variant1(2)).await;
        assert_eq!(result1, Ok(3));

        let result2 = try_conditional.try_call(ExampleEnum::Variant2(3)).await;
        assert_eq!(result2, Ok(6));
    }

    /// An `Err` from the selected sub-op is returned unchanged, and only that
    /// variant's op decides the outcome.
    #[tokio::test]
    async fn test_try_conditional_op_propagates_error() {
        enum ExampleEnum<T> {
            Variant1(T),
            Variant2(T),
        }

        let op1 = map(|x: i32| Ok::<_, String>(x + 1));
        let op2 = map(|x: i32| {
            if x % 2 == 0 {
                Ok(x / 2)
            } else {
                Err(format!("{x} is odd"))
            }
        });

        let try_conditional = try_conditional!(ExampleEnum,
            Variant1 => op1,
            Variant2 => op2
        );

        let ok_from_first = try_conditional.try_call(ExampleEnum::Variant1(3)).await;
        assert_eq!(ok_from_first, Ok(4));

        let ok_from_second = try_conditional.try_call(ExampleEnum::Variant2(4)).await;
        assert_eq!(ok_from_second, Ok(2));

        let err_from_second = try_conditional.try_call(ExampleEnum::Variant2(3)).await;
        assert_eq!(err_from_second, Err("3 is odd".to_string()));
    }
}