rig/pipeline/
op.rs

1use std::future::Future;
2
3#[allow(unused_imports)] // Needed since this is used in a macro rule
4use futures::join;
5use futures::stream;
6
7// ================================================================
8// Core Op trait
9// ================================================================
10pub trait Op: Send + Sync {
11    type Input: Send + Sync;
12    type Output: Send + Sync;
13
14    fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + Send;
15
16    /// Execute the current pipeline with the given inputs. `n` is the number of concurrent
17    /// inputs that will be processed concurrently.
18    fn batch_call<I>(&self, n: usize, input: I) -> impl Future<Output = Vec<Self::Output>> + Send
19    where
20        I: IntoIterator<Item = Self::Input> + Send,
21        I::IntoIter: Send,
22        Self: Sized,
23    {
24        use futures::stream::StreamExt;
25
26        async move {
27            stream::iter(input)
28                .map(|input| self.call(input))
29                .buffered(n)
30                .collect()
31                .await
32        }
33    }
34
35    /// Chain a function `f` to the current op.
36    ///
37    /// # Example
38    /// ```rust
39    /// use rig::pipeline::{self, Op};
40    ///
41    /// let chain = pipeline::new()
42    ///    .map(|(x, y)| x + y)
43    ///    .map(|z| format!("Result: {z}!"));
44    ///
45    /// let result = chain.call((1, 2)).await;
46    /// assert_eq!(result, "Result: 3!");
47    /// ```
48    fn map<F, Input>(self, f: F) -> Sequential<Self, Map<F, Self::Output>>
49    where
50        F: Fn(Self::Output) -> Input + Send + Sync,
51        Input: Send + Sync,
52        Self: Sized,
53    {
54        Sequential::new(self, Map::new(f))
55    }
56
57    /// Same as `map` but for asynchronous functions
58    ///
59    /// # Example
60    /// ```rust
61    /// use rig::pipeline::{self, Op};
62    ///
63    /// let chain = pipeline::new()
64    ///     .then(|email: String| async move {
65    ///         email.split('@').next().unwrap().to_string()
66    ///     })
67    ///     .then(|username: String| async move {
68    ///         format!("Hello, {}!", username)
69    ///     });
70    ///
71    /// let result = chain.call("bob@gmail.com".to_string()).await;
72    /// assert_eq!(result, "Hello, bob!");
73    /// ```
74    fn then<F, Fut>(self, f: F) -> Sequential<Self, Then<F, Fut::Output>>
75    where
76        F: Fn(Self::Output) -> Fut + Send + Sync,
77        Fut: Future + Send + Sync,
78        Fut::Output: Send + Sync,
79        Self: Sized,
80    {
81        Sequential::new(self, Then::new(f))
82    }
83
84    /// Chain an arbitrary operation to the current op.
85    ///
86    /// # Example
87    /// ```rust
88    /// use rig::pipeline::{self, Op};
89    ///
90    /// struct AddOne;
91    ///
92    /// impl Op for AddOne {
93    ///     type Input = i32;
94    ///     type Output = i32;
95    ///
96    ///     async fn call(&self, input: Self::Input) -> Self::Output {
97    ///         input + 1
98    ///     }
99    /// }
100    ///
101    /// let chain = pipeline::new()
102    ///    .chain(AddOne);
103    ///
104    /// let result = chain.call(1).await;
105    /// assert_eq!(result, 2);
106    /// ```
107    fn chain<T>(self, op: T) -> Sequential<Self, T>
108    where
109        T: Op<Input = Self::Output>,
110        Self: Sized,
111    {
112        Sequential::new(self, op)
113    }
114
115    /// Chain a lookup operation to the current chain. The lookup operation expects the
116    /// current chain to output a query string. The lookup operation will use the query to
117    /// retrieve the top `n` documents from the index and return them with the query string.
118    ///
119    /// # Example
120    /// ```rust
121    /// use rig::chain::{self, Chain};
122    ///
123    /// let chain = chain::new()
124    ///     .lookup(index, 2)
125    ///     .chain(|(query, docs): (_, Vec<String>)| async move {
126    ///         format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
127    ///     });
128    ///
129    /// let result = chain.call("What is a flurbo?".to_string()).await;
130    /// ```
131    fn lookup<I, Input>(
132        self,
133        index: I,
134        n: usize,
135    ) -> Sequential<Self, Lookup<I, Self::Output, Input>>
136    where
137        I: vector_store::VectorStoreIndex,
138        Input: Send + Sync + for<'a> serde::Deserialize<'a>,
139        Self::Output: Into<String>,
140        Self: Sized,
141    {
142        Sequential::new(self, Lookup::new(index, n))
143    }
144
145    /// Chain a prompt operation to the current chain. The prompt operation expects the
146    /// current chain to output a string. The prompt operation will use the string to prompt
147    /// the given agent (or any other type that implements the `Prompt` trait) and return
148    /// the response.
149    ///
150    /// # Example
151    /// ```rust
152    /// use rig::chain::{self, Chain};
153    ///
154    /// let agent = &openai_client.agent("gpt-4").build();
155    ///
156    /// let chain = chain::new()
157    ///    .map(|name| format!("Find funny nicknames for the following name: {name}!"))
158    ///    .prompt(agent);
159    ///
160    /// let result = chain.call("Alice".to_string()).await;
161    /// ```
162    fn prompt<P>(self, prompt: P) -> Sequential<Self, Prompt<P, Self::Output>>
163    where
164        P: completion::Prompt,
165        Self::Output: Into<String>,
166        Self: Sized,
167    {
168        Sequential::new(self, Prompt::new(prompt))
169    }
170}
171
172impl<T: Op> Op for &T {
173    type Input = T::Input;
174    type Output = T::Output;
175
176    #[inline]
177    async fn call(&self, input: Self::Input) -> Self::Output {
178        (*self).call(input).await
179    }
180}
181
182// ================================================================
183// Op combinators
184// ================================================================
/// Two operations run one after the other: `prev` first, then `op` on its result.
pub struct Sequential<Op1, Op2> {
    /// Operation executed first.
    prev: Op1,
    /// Operation executed second, on the output of `prev`.
    op: Op2,
}

impl<Op1, Op2> Sequential<Op1, Op2> {
    /// Pair `prev` with the operation `op` that follows it.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        Sequential { prev, op }
    }
}
195
196impl<Op1, Op2> Op for Sequential<Op1, Op2>
197where
198    Op1: Op,
199    Op2: Op<Input = Op1::Output>,
200{
201    type Input = Op1::Input;
202    type Output = Op2::Output;
203
204    #[inline]
205    async fn call(&self, input: Self::Input) -> Self::Output {
206        let prev = self.prev.call(input).await;
207        self.op.call(prev).await
208    }
209}
210
211use crate::{completion, vector_store};
212
213use super::agent_ops::{Lookup, Prompt};
214
215// ================================================================
216// Core Op implementations
217// ================================================================
/// Operation that applies a synchronous function to its input.
///
/// `Input` is the argument type of `f`; it is tracked with `PhantomData` only, no value
/// of that type is stored.
pub struct Map<F, Input> {
    /// Function applied on every call.
    f: F,
    /// Marker tying the op to its input type.
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Map<F, Input> {
    /// Wrap `f` in a `Map`.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Map { f, _t }
    }
}
231
232impl<F, Input, Output> Op for Map<F, Input>
233where
234    F: Fn(Input) -> Output + Send + Sync,
235    Input: Send + Sync,
236    Output: Send + Sync,
237{
238    type Input = Input;
239    type Output = Output;
240
241    #[inline]
242    async fn call(&self, input: Self::Input) -> Self::Output {
243        (self.f)(input)
244    }
245}
246
247pub fn map<F, Input, Output>(f: F) -> Map<F, Input>
248where
249    F: Fn(Input) -> Output + Send + Sync,
250    Input: Send + Sync,
251    Output: Send + Sync,
252{
253    Map::new(f)
254}
255
/// Operation that hands its input back unchanged.
///
/// The struct stores no data; `T` is only tracked through `PhantomData`.
pub struct Passthrough<T> {
    /// Marker for the value type passed through.
    _t: std::marker::PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Create a new `Passthrough`.
    pub(crate) fn new() -> Self {
        Passthrough {
            _t: Default::default(),
        }
    }
}
267
268impl<T> Op for Passthrough<T>
269where
270    T: Send + Sync,
271{
272    type Input = T;
273    type Output = T;
274
275    async fn call(&self, input: Self::Input) -> Self::Output {
276        input
277    }
278}
279
280pub fn passthrough<T>() -> Passthrough<T>
281where
282    T: Send + Sync,
283{
284    Passthrough::new()
285}
286
/// Operation that applies an asynchronous function to its input.
///
/// `Input` is the argument type of `f`; it is tracked with `PhantomData` only, no value
/// of that type is stored.
pub struct Then<F, Input> {
    /// Function producing the future that is awaited on every call.
    f: F,
    /// Marker tying the op to its input type.
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Then<F, Input> {
    /// Wrap `f` in a `Then`.
    pub(crate) fn new(f: F) -> Self {
        Then {
            _t: std::marker::PhantomData,
            f,
        }
    }
}
300
301impl<F, Input, Fut> Op for Then<F, Input>
302where
303    F: Fn(Input) -> Fut + Send + Sync,
304    Input: Send + Sync,
305    Fut: Future + Send,
306    Fut::Output: Send + Sync,
307{
308    type Input = Input;
309    type Output = Fut::Output;
310
311    #[inline]
312    async fn call(&self, input: Self::Input) -> Self::Output {
313        (self.f)(input).await
314    }
315}
316
317pub fn then<F, Input, Fut>(f: F) -> Then<F, Input>
318where
319    F: Fn(Input) -> Fut + Send + Sync,
320    Input: Send + Sync,
321    Fut: Future + Send,
322    Fut::Output: Send + Sync,
323{
324    Then::new(f)
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Nesting `Sequential` by hand runs the stages left to right: ((1 + 1) * 2) * 3.
    #[tokio::test]
    async fn test_sequential_constructor() {
        let add_one = map(|x: i32| x + 1);
        let double = map(|x: i32| x * 2);
        let triple = map(|x: i32| x * 3);

        let first_two = Sequential::new(add_one, double);
        let pipeline = Sequential::new(first_two, triple);

        assert_eq!(pipeline.call(1).await, 12);
    }

    /// The same pipeline assembled with the `map` / `then` combinators.
    #[tokio::test]
    async fn test_sequential_chain() {
        let pipeline = map(|x: i32| x + 1)
            .map(|x| x * 2)
            .then(|x| async move { x * 3 });

        assert_eq!(pipeline.call(1).await, 12);
    }

    // NOTE(review): the disabled tests below reference `Parallel`, `flatten`,
    // `parallel!` and `flatten!`, none of which are defined in this file. They are
    // presumably pending work; confirm before re-enabling. Also note that two of them
    // share the name `test_flatten`.

    // #[tokio::test]
    // async fn test_flatten() {
    //     let op = Parallel::new(
    //         Parallel::new(
    //             map(|x: i32| x + 1),
    //             map(|x: i32| x * 2),
    //         ),
    //         map(|x: i32| x * 3),
    //     );

    //     let pipeline = flatten::<_, (_, _, _)>(op);

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 2, 3));
    // }

    // #[tokio::test]
    // async fn test_parallel_macro() {
    //     let op1 = map(|x: i32| x + 1);
    //     let op2 = map(|x: i32| x * 3);
    //     let op3 = map(|x: i32| format!("{} is the number!", x));
    //     let op4 = map(|x: i32| x - 1);

    //     let pipeline = parallel!(op1, op2, op3, op4);

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 3, "1 is the number!".to_string(), 0));
    // }

    // #[tokio::test]
    // async fn test_parallel_join() {
    //     let op3 = map(|x: i32| format!("{} is the number!", x));

    //     let pipeline = Sequential::new(
    //         map(|x: i32| x + 1),
    //         then(|x| {
    //             // let op1 = map(|x: i32| x * 2);
    //             // let op2 = map(|x: i32| x * 3);
    //             let op3 = &op3;

    //             async move {
    //             join!(
    //                 (&map(|x: i32| x * 2)).call(x),
    //                 {
    //                     let op = map(|x: i32| x * 3);
    //                     op.call(x)
    //                 },
    //                 op3.call(x),
    //             )
    //         }}),
    //     );

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 3, "1 is the number!".to_string()));
    // }

    // #[test]
    // fn test_flatten() {
    //     let x = (1, (2, (3, 4)));
    //     let result = flatten!(0, 1, 1, 1, 1);
    //     assert_eq!(result, (1, 2, 3, 4));
    // }
}