rig/pipeline/conditional.rs
/// Creates an `Op` that dispatches its input to one of several sub-ops, selected by
/// the variant of the input enum.
///
/// The macro expands to a block expression that declares a private struct holding one
/// sub-op per listed variant, implements `Op` for that struct, and evaluates to an
/// instance of it. Calling the resulting op matches on the input enum and forwards the
/// variant's payload to the sub-op registered for that variant.
///
/// # Requirements
///
/// 1. The enum must be generic over a single type parameter, and every listed variant
///    must be a one-field tuple variant wrapping that parameter, e.g.
///
///    ```rust
///    enum MyEnum<T> {
///        VariantA(T),
///        VariantB(T),
///    }
///    ```
///
///    This allows all variants to share the same inner type (`T`).
/// 2. All sub-ops must have the same `Input` type (this `T`) and the same `Output`.
///    That is, for each variant, the corresponding op must implement
///    `Op<Input = T, Output = Out>`.
/// 3. Every variant of the enum must be listed: the generated `match` has no
///    catch-all arm, so an omitted variant is a compile error.
/// 4. The enum is matched as a bare identifier, so it must be in scope under that
///    name at the call site (a path such as `module::MyEnum` is not accepted).
///
/// # Example
/// ```rust
/// use rig::pipeline::*;
/// use rig::conditional;
///
/// #[tokio::main]
/// async fn main() {
///     #[derive(Debug)]
///     enum ExampleEnum<T> {
///         Variant1(T),
///         Variant2(T),
///     }
///
///     // Creates a pipeline Op that adds 1 if it's Variant1, or doubles if it's Variant2
///     let op1 = map(|x: i32| x + 1);
///     let op2 = map(|x: i32| x * 2);
///
///     let conditional = conditional!(ExampleEnum,
///         Variant1 => op1,
///         Variant2 => op2,
///     );
///
///     let result1 = conditional.call(ExampleEnum::Variant1(2)).await;
///     assert_eq!(result1, 3);
///
///     let result2 = conditional.call(ExampleEnum::Variant2(3)).await;
///     assert_eq!(result2, 6);
/// }
/// ```
#[macro_export]
macro_rules! conditional {
    ($enum:ident, $( $variant:ident => $op:expr ),+ $(,)?) => {
        {
            // Each variant identifier doubles as a generic parameter and as the name of
            // the field storing that variant's sub-op, hence the CamelCase field names.
            #[allow(non_snake_case)]
            struct ConditionalOp<$( $variant ),+> {
                $( $variant: $variant, )+
            }

            // Generic parameters are not hygienic in `macro_rules!`, so the helper
            // parameters carry a `__` prefix to avoid clashing with user variants named
            // e.g. `Value` or `Out`. The trait is reached through `$crate` so the
            // expansion does not depend on what the caller has imported.
            impl<__Value, __Out, $( $variant ),+> $crate::pipeline::Op
                for ConditionalOp<$( $variant ),+>
            where
                $( $variant: $crate::pipeline::Op<Input = __Value, Output = __Out> ),+,
                __Value: Send + Sync,
                __Out: Send + Sync,
            {
                type Input = $enum<__Value>;
                type Output = __Out;

                fn call(
                    &self,
                    input: Self::Input,
                ) -> impl ::std::future::Future<Output = Self::Output> + Send {
                    async move {
                        match input {
                            $(
                                // Fully qualified call: no trait import needed at the
                                // call site and no ambiguity with other `call` methods.
                                $enum::$variant(val) => {
                                    $crate::pipeline::Op::call(&self.$variant, val).await
                                }
                            ),+
                        }
                    }
                }
            }

            ConditionalOp { $( $variant: $op ),+ }
        }
    };
}
82
/// Creates a `TryOp` that dispatches its input to one of several fallible sub-ops,
/// selected by the variant of the input enum, and returns that sub-op's `Result`.
///
/// The macro expands to a block expression that declares a private struct holding one
/// sub-op per listed variant, implements `TryOp` for that struct, and evaluates to an
/// instance of it. `try_call` matches on the input enum and forwards the variant's
/// payload to the sub-op registered for that variant, returning its result unchanged.
///
/// # Requirements
///
/// 1. The enum must be generic over a single type parameter, and every listed variant
///    must be a one-field tuple variant wrapping that parameter, e.g.
///
///    ```rust
///    enum MyEnum<T> {
///        VariantA(T),
///        VariantB(T),
///    }
///    ```
///
///    This allows all variants to share the same inner type (`T`).
/// 2. All sub-ops must have the same `Input` type (this `T`), the same `Output` and
///    the same `Error`. That is, for each variant, the corresponding op must implement
///    `TryOp<Input = T, Output = Out, Error = E>`.
/// 3. Every variant of the enum must be listed: the generated `match` has no
///    catch-all arm, so an omitted variant is a compile error.
/// 4. The enum is matched as a bare identifier, so it must be in scope under that
///    name at the call site (a path such as `module::MyEnum` is not accepted).
///
/// # Example
/// ```rust
/// use rig::pipeline::*;
/// use rig::try_conditional;
///
/// #[tokio::main]
/// async fn main() {
///     #[derive(Debug)]
///     enum ExampleEnum<T> {
///         Variant1(T),
///         Variant2(T),
///     }
///
///     // Creates a pipeline TryOp that adds 1 or doubles, returning Ok(...) or Err(...)
///     let op1 = map(|x: i32| Ok::<_, String>(x + 1));
///     let op2 = map(|x: i32| Ok::<_, String>(x * 2));
///
///     let try_conditional = try_conditional!(ExampleEnum,
///         Variant1 => op1,
///         Variant2 => op2,
///     );
///
///     let result = try_conditional.try_call(ExampleEnum::Variant1(2)).await;
///     assert_eq!(result, Ok(3));
/// }
/// ```
#[macro_export]
macro_rules! try_conditional {
    ($enum:ident, $( $variant:ident => $op:expr ),+ $(,)?) => {
        {
            // Each variant identifier doubles as a generic parameter and as the name of
            // the field storing that variant's sub-op, hence the CamelCase field names.
            #[allow(non_snake_case)]
            struct TryConditionalOp<$( $variant ),+> {
                $( $variant: $variant, )+
            }

            // Generic parameters are not hygienic in `macro_rules!`, so the helper
            // parameters carry a `__` prefix to avoid clashing with user variants named
            // e.g. `Value`, `Out` or `Err`. The trait is reached through `$crate` so the
            // expansion does not depend on what the caller has imported.
            impl<__Value, __Out, __Err, $( $variant ),+> $crate::pipeline::TryOp
                for TryConditionalOp<$( $variant ),+>
            where
                $(
                    $variant: $crate::pipeline::TryOp<
                        Input = __Value,
                        Output = __Out,
                        Error = __Err,
                    >
                ),+,
                __Value: Send + Sync,
                __Out: Send + Sync,
                __Err: Send + Sync,
            {
                type Input = $enum<__Value>;
                type Output = __Out;
                type Error = __Err;

                async fn try_call(
                    &self,
                    input: Self::Input,
                ) -> Result<Self::Output, Self::Error> {
                    match input {
                        $(
                            // Fully qualified call: no trait import needed at the call
                            // site and no ambiguity with other `try_call` methods.
                            $enum::$variant(val) => {
                                $crate::pipeline::TryOp::try_call(&self.$variant, val).await
                            }
                        ),+
                    }
                }
            }

            TryConditionalOp { $( $variant: $op ),+ }
        }
    };
}
159
#[cfg(test)]
mod tests {
    use crate::pipeline::*;

    /// `conditional!` forwards each variant's payload to the op registered for it.
    #[tokio::test]
    async fn test_conditional_op() {
        enum ExampleEnum<T> {
            Variant1(T),
            Variant2(T),
        }

        // Variant1 increments its payload, Variant2 doubles it.
        let dispatcher = conditional!(
            ExampleEnum,
            Variant1 => map(|n: i32| n + 1),
            Variant2 => map(|n: i32| n * 2)
        );

        assert_eq!(dispatcher.call(ExampleEnum::Variant1(2)).await, 3);
        assert_eq!(dispatcher.call(ExampleEnum::Variant2(3)).await, 6);
    }

    /// `try_conditional!` forwards each variant's payload to the fallible op registered
    /// for it and returns that op's `Result` as-is.
    #[tokio::test]
    async fn test_try_conditional_op() {
        enum ExampleEnum<T> {
            Variant1(T),
            Variant2(T),
        }

        // Both ops always succeed; the error type is pinned to `String`.
        let dispatcher = try_conditional!(
            ExampleEnum,
            Variant1 => map(|n: i32| Ok::<_, String>(n + 1)),
            Variant2 => map(|n: i32| Ok::<_, String>(n * 2))
        );

        assert_eq!(dispatcher.try_call(ExampleEnum::Variant1(2)).await, Ok(3));
        assert_eq!(dispatcher.try_call(ExampleEnum::Variant2(3)).await, Ok(6));
    }
}