rig/pipeline/
op.rs

1use crate::wasm_compat::*;
2#[allow(unused_imports)] // Needed since this is used in a macro rule
3use futures::join;
4use futures::stream;
5use std::future::Future;
6
7// ================================================================
8// Core Op trait
9// ================================================================
10pub trait Op: WasmCompatSend + WasmCompatSync {
11    type Input: WasmCompatSend + WasmCompatSync;
12    type Output: WasmCompatSend + WasmCompatSync;
13
14    fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + WasmCompatSend;
15
16    /// Execute the current pipeline with the given inputs. `n` is the number of concurrent
17    /// inputs that will be processed concurrently.
18    fn batch_call<I>(
19        &self,
20        n: usize,
21        input: I,
22    ) -> impl Future<Output = Vec<Self::Output>> + WasmCompatSend
23    where
24        I: IntoIterator<Item = Self::Input> + WasmCompatSend,
25        I::IntoIter: WasmCompatSend,
26        Self: Sized,
27    {
28        use futures::stream::StreamExt;
29
30        async move {
31            stream::iter(input)
32                .map(|input| self.call(input))
33                .buffered(n)
34                .collect()
35                .await
36        }
37    }
38
39    /// Chain a function `f` to the current op.
40    ///
41    /// # Example
42    /// ```rust
43    /// use rig::pipeline::{self, Op};
44    ///
45    /// let chain = pipeline::new()
46    ///    .map(|(x, y)| x + y)
47    ///    .map(|z| format!("Result: {z}!"));
48    ///
49    /// let result = chain.call((1, 2)).await;
50    /// assert_eq!(result, "Result: 3!");
51    /// ```
52    fn map<F, Input>(self, f: F) -> Sequential<Self, Map<F, Self::Output>>
53    where
54        F: Fn(Self::Output) -> Input + WasmCompatSend + WasmCompatSync,
55        Input: WasmCompatSend + WasmCompatSync,
56        Self: Sized,
57    {
58        Sequential::new(self, Map::new(f))
59    }
60
61    /// Same as `map` but for asynchronous functions
62    ///
63    /// # Example
64    /// ```rust
65    /// use rig::pipeline::{self, Op};
66    ///
67    /// let chain = pipeline::new()
68    ///     .then(|email: String| async move {
69    ///         email.split('@').next().unwrap().to_string()
70    ///     })
71    ///     .then(|username: String| async move {
72    ///         format!("Hello, {}!", username)
73    ///     });
74    ///
75    /// let result = chain.call("bob@gmail.com".to_string()).await;
76    /// assert_eq!(result, "Hello, bob!");
77    /// ```
78    fn then<F, Fut>(self, f: F) -> Sequential<Self, Then<F, Fut::Output>>
79    where
80        F: Fn(Self::Output) -> Fut + Send + WasmCompatSync,
81        Fut: Future + WasmCompatSend + WasmCompatSync,
82        Fut::Output: WasmCompatSend + WasmCompatSync,
83        Self: Sized,
84    {
85        Sequential::new(self, Then::new(f))
86    }
87
88    /// Chain an arbitrary operation to the current op.
89    ///
90    /// # Example
91    /// ```rust
92    /// use rig::pipeline::{self, Op};
93    ///
94    /// struct AddOne;
95    ///
96    /// impl Op for AddOne {
97    ///     type Input = i32;
98    ///     type Output = i32;
99    ///
100    ///     async fn call(&self, input: Self::Input) -> Self::Output {
101    ///         input + 1
102    ///     }
103    /// }
104    ///
105    /// let chain = pipeline::new()
106    ///    .chain(AddOne);
107    ///
108    /// let result = chain.call(1).await;
109    /// assert_eq!(result, 2);
110    /// ```
111    fn chain<T>(self, op: T) -> Sequential<Self, T>
112    where
113        T: Op<Input = Self::Output>,
114        Self: Sized,
115    {
116        Sequential::new(self, op)
117    }
118
119    /// Chain a lookup operation to the current chain. The lookup operation expects the
120    /// current chain to output a query string. The lookup operation will use the query to
121    /// retrieve the top `n` documents from the index and return them with the query string.
122    ///
123    /// # Example
124    /// ```rust
125    /// use rig::chain::{self, Chain};
126    ///
127    /// let chain = chain::new()
128    ///     .lookup(index, 2)
129    ///     .chain(|(query, docs): (_, Vec<String>)| async move {
130    ///         format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
131    ///     });
132    ///
133    /// let result = chain.call("What is a flurbo?".to_string()).await;
134    /// ```
135    fn lookup<I, Input>(
136        self,
137        index: I,
138        n: usize,
139    ) -> Sequential<Self, Lookup<I, Self::Output, Input>>
140    where
141        I: vector_store::VectorStoreIndex,
142        Input: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
143        Self::Output: Into<String>,
144        Self: Sized,
145    {
146        Sequential::new(self, Lookup::new(index, n))
147    }
148
149    /// Chain a prompt operation to the current chain. The prompt operation expects the
150    /// current chain to output a string. The prompt operation will use the string to prompt
151    /// the given agent (or any other type that implements the `Prompt` trait) and return
152    /// the response.
153    ///
154    /// # Example
155    /// ```rust
156    /// use rig::chain::{self, Chain};
157    ///
158    /// let agent = &openai_client.agent("gpt-4").build();
159    ///
160    /// let chain = chain::new()
161    ///    .map(|name| format!("Find funny nicknames for the following name: {name}!"))
162    ///    .prompt(agent);
163    ///
164    /// let result = chain.call("Alice".to_string()).await;
165    /// ```
166    fn prompt<P>(self, prompt: P) -> Sequential<Self, Prompt<P, Self::Output>>
167    where
168        P: completion::Prompt,
169        Self::Output: Into<String>,
170        Self: Sized,
171    {
172        Sequential::new(self, Prompt::new(prompt))
173    }
174}
175
176impl<T: Op> Op for &T {
177    type Input = T::Input;
178    type Output = T::Output;
179
180    #[inline]
181    async fn call(&self, input: Self::Input) -> Self::Output {
182        (*self).call(input).await
183    }
184}
185
// ================================================================
// Op combinators
// ================================================================

/// Composition of two ops: `prev` runs first and its output is fed into `op`.
///
/// Values of this type are produced by the chaining methods on [`Op`]
/// (`map`, `then`, `chain`, `lookup`, `prompt`).
pub struct Sequential<Op1, Op2> {
    /// The op that runs first.
    prev: Op1,
    /// The op that consumes `prev`'s output.
    op: Op2,
}

impl<Op1, Op2> Sequential<Op1, Op2> {
    /// Pairs `prev` with `op` so that `op` runs on the output of `prev`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        Sequential { op, prev }
    }
}
199
200impl<Op1, Op2> Op for Sequential<Op1, Op2>
201where
202    Op1: Op,
203    Op2: Op<Input = Op1::Output>,
204{
205    type Input = Op1::Input;
206    type Output = Op2::Output;
207
208    #[inline]
209    async fn call(&self, input: Self::Input) -> Self::Output {
210        let prev = self.prev.call(input).await;
211        self.op.call(prev).await
212    }
213}
214
215use super::agent_ops::{Lookup, Prompt};
216use crate::{completion, vector_store};
217
// ================================================================
// Core Op implementations
// ================================================================

/// An [`Op`] that applies a synchronous function `F` to its input.
///
/// `Input` is the argument type of `F`. It is only tracked through `PhantomData`,
/// because no value of that type is stored.
pub struct Map<F, Input> {
    /// The wrapped function.
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Map<F, Input> {
    /// Wraps `f` without calling it.
    pub(crate) fn new(f: F) -> Self {
        Map {
            f,
            _t: Default::default(),
        }
    }
}
234
235impl<F, Input, Output> Op for Map<F, Input>
236where
237    F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
238    Input: WasmCompatSend + WasmCompatSync,
239    Output: WasmCompatSend + WasmCompatSync,
240{
241    type Input = Input;
242    type Output = Output;
243
244    #[inline]
245    async fn call(&self, input: Self::Input) -> Self::Output {
246        (self.f)(input)
247    }
248}
249
250pub fn map<F, Input, Output>(f: F) -> Map<F, Input>
251where
252    F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
253    Input: WasmCompatSend + WasmCompatSync,
254    Output: WasmCompatSend + WasmCompatSync,
255{
256    Map::new(f)
257}
258
/// An [`Op`] that hands its input back unchanged (the identity op).
pub struct Passthrough<T> {
    _t: std::marker::PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Creates the identity op for values of type `T`.
    pub(crate) fn new() -> Self {
        Passthrough {
            _t: Default::default(),
        }
    }
}
270
271impl<T> Op for Passthrough<T>
272where
273    T: WasmCompatSend + WasmCompatSync,
274{
275    type Input = T;
276    type Output = T;
277
278    async fn call(&self, input: Self::Input) -> Self::Output {
279        input
280    }
281}
282
283pub fn passthrough<T>() -> Passthrough<T>
284where
285    T: WasmCompatSend + WasmCompatSync,
286{
287    Passthrough::new()
288}
289
/// An [`Op`] that applies an asynchronous function `F` (one returning a future)
/// to its input and awaits the result.
///
/// `Input` is the argument type of `F`. It is only tracked through `PhantomData`,
/// because no value of that type is stored.
pub struct Then<F, Input> {
    /// The wrapped function.
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Then<F, Input> {
    /// Wraps `f` without calling it.
    pub(crate) fn new(f: F) -> Self {
        Then {
            f,
            _t: Default::default(),
        }
    }
}
303
304impl<F, Input, Fut> Op for Then<F, Input>
305where
306    F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
307    Input: WasmCompatSend + WasmCompatSync,
308    Fut: Future + WasmCompatSend,
309    Fut::Output: WasmCompatSend + WasmCompatSync,
310{
311    type Input = Input;
312    type Output = Fut::Output;
313
314    #[inline]
315    async fn call(&self, input: Self::Input) -> Self::Output {
316        (self.f)(input).await
317    }
318}
319
320pub fn then<F, Input, Fut>(f: F) -> Then<F, Input>
321where
322    F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
323    Input: WasmCompatSend + WasmCompatSync,
324    Fut: Future + WasmCompatSend,
325    Fut::Output: WasmCompatSend + WasmCompatSync,
326{
327    Then::new(f)
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_sequential_constructor() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 2);
        let op3 = map(|x: i32| x * 3);

        let pipeline = Sequential::new(Sequential::new(op1, op2), op3);

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_sequential_chain() {
        let pipeline = map(|x: i32| x + 1)
            .map(|x| x * 2)
            .then(|x| async move { x * 3 });

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_passthrough_is_identity() {
        let op = passthrough::<i32>();
        assert_eq!(op.call(42).await, 42);
    }

    #[tokio::test]
    async fn test_then_may_change_type() {
        // The free `then` takes the closure's argument type as its `Input`, so the
        // future's output type is free to differ from it.
        let op = then(|s: String| async move { s.len() });
        assert_eq!(op.call("abc".to_string()).await, 3);
    }

    #[tokio::test]
    async fn test_op_by_reference() {
        // `chain(&op)` goes through the `impl Op for &T` blanket impl and leaves `op` usable.
        let op = map(|x: i32| x + 1);
        let pipeline = map(|x: i32| x * 2).chain(&op);

        assert_eq!(pipeline.call(3).await, 7);
        assert_eq!(op.call(1).await, 2);
    }

    #[tokio::test]
    async fn test_batch_call_preserves_order() {
        let op = map(|x: i32| x * 10);
        let result = op.batch_call(2, vec![1, 2, 3, 4]).await;
        assert_eq!(result, vec![10, 20, 30, 40]);
    }
}