Skip to main content

rig_core/pipeline/
op.rs

1use crate::wasm_compat::*;
2#[allow(unused_imports)] // Needed since this is used in a macro rule
3use futures::join;
4use futures::stream;
5use std::future::Future;
6
7// ================================================================
8// Core Op trait
9// ================================================================
10pub trait Op: WasmCompatSend + WasmCompatSync {
11    type Input: WasmCompatSend + WasmCompatSync;
12    type Output: WasmCompatSend + WasmCompatSync;
13
14    fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + WasmCompatSend;
15
16    /// Execute the current pipeline with the given inputs. `n` is the number of concurrent
17    /// inputs that will be processed concurrently.
18    fn batch_call<I>(
19        &self,
20        n: usize,
21        input: I,
22    ) -> impl Future<Output = Vec<Self::Output>> + WasmCompatSend
23    where
24        I: IntoIterator<Item = Self::Input> + WasmCompatSend,
25        I::IntoIter: WasmCompatSend,
26        Self: Sized,
27    {
28        use futures::stream::StreamExt;
29
30        async move {
31            stream::iter(input)
32                .map(|input| self.call(input))
33                .buffered(n)
34                .collect()
35                .await
36        }
37    }
38
39    /// Chain a function `f` to the current op.
40    ///
41    /// # Example
42    /// ```no_run
43    /// use rig_core::pipeline::{self, Op};
44    ///
45    /// # async fn run() {
46    /// let chain = pipeline::new()
47    ///    .map(|(x, y)| x + y)
48    ///    .map(|z| format!("Result: {z}!"));
49    ///
50    /// let result = chain.call((1, 2)).await;
51    /// assert_eq!(result, "Result: 3!");
52    /// # }
53    /// ```
54    fn map<F, Input>(self, f: F) -> Sequential<Self, Map<F, Self::Output>>
55    where
56        F: Fn(Self::Output) -> Input + WasmCompatSend + WasmCompatSync,
57        Input: WasmCompatSend + WasmCompatSync,
58        Self: Sized,
59    {
60        Sequential::new(self, Map::new(f))
61    }
62
63    /// Same as `map` but for asynchronous functions
64    ///
65    /// # Example
66    /// ```no_run
67    /// use rig_core::pipeline::{self, Op};
68    ///
69    /// # async fn run() {
70    /// let chain = pipeline::new()
71    ///     .then(|email: String| async move {
72    ///         email.split('@').next().unwrap().to_string()
73    ///     })
74    ///     .then(|username: String| async move {
75    ///         format!("Hello, {}!", username)
76    ///     });
77    ///
78    /// let result = chain.call("bob@gmail.com".to_string()).await;
79    /// assert_eq!(result, "Hello, bob!");
80    /// # }
81    /// ```
82    fn then<F, Fut>(self, f: F) -> Sequential<Self, Then<F, Fut::Output>>
83    where
84        F: Fn(Self::Output) -> Fut + Send + WasmCompatSync,
85        Fut: Future + WasmCompatSend + WasmCompatSync,
86        Fut::Output: WasmCompatSend + WasmCompatSync,
87        Self: Sized,
88    {
89        Sequential::new(self, Then::new(f))
90    }
91
92    /// Chain an arbitrary operation to the current op.
93    ///
94    /// # Example
95    /// ```no_run
96    /// use rig_core::pipeline::{self, Op};
97    ///
98    /// # async fn run() {
99    /// struct AddOne;
100    ///
101    /// impl Op for AddOne {
102    ///     type Input = i32;
103    ///     type Output = i32;
104    ///
105    ///     async fn call(&self, input: Self::Input) -> Self::Output {
106    ///         input + 1
107    ///     }
108    /// }
109    ///
110    /// let chain = pipeline::new()
111    ///    .chain(AddOne);
112    ///
113    /// let result = chain.call(1).await;
114    /// assert_eq!(result, 2);
115    /// # }
116    /// ```
117    fn chain<T>(self, op: T) -> Sequential<Self, T>
118    where
119        T: Op<Input = Self::Output>,
120        Self: Sized,
121    {
122        Sequential::new(self, op)
123    }
124
125    /// Chain a lookup operation to the current chain. The lookup operation expects the
126    /// current chain to output a query string. The lookup operation will use the query to
127    /// retrieve the top `n` documents from the index and return them with the query string.
128    ///
129    /// # Example
130    /// ```ignore
131    /// use rig_core::chain::{self, Chain};
132    ///
133    /// let chain = chain::new()
134    ///     .lookup(index, 2)
135    ///     .chain(|(query, docs): (_, Vec<String>)| async move {
136    ///         format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
137    ///     });
138    ///
139    /// let result = chain.call("What is a flurbo?".to_string()).await;
140    /// ```
141    fn lookup<I, Input>(
142        self,
143        index: I,
144        n: usize,
145    ) -> Sequential<Self, Lookup<I, Self::Output, Input>>
146    where
147        I: vector_store::VectorStoreIndex,
148        Input: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
149        Self::Output: Into<String>,
150        Self: Sized,
151    {
152        Sequential::new(self, Lookup::new(index, n))
153    }
154
155    /// Chain a prompt operation to the current chain. The prompt operation expects the
156    /// current chain to output a string. The prompt operation will use the string to prompt
157    /// the given agent (or any other type that implements the `Prompt` trait) and return
158    /// the response.
159    ///
160    /// # Example
161    /// ```ignore
162    /// use rig_core::chain::{self, Chain};
163    ///
164    /// let agent = &openai_client.agent("gpt-4").build();
165    ///
166    /// let chain = chain::new()
167    ///    .map(|name| format!("Find funny nicknames for the following name: {name}!"))
168    ///    .prompt(agent);
169    ///
170    /// let result = chain.call("Alice".to_string()).await;
171    /// ```
172    fn prompt<P>(self, prompt: P) -> Sequential<Self, Prompt<P, Self::Output>>
173    where
174        P: completion::Prompt,
175        Self::Output: Into<String>,
176        Self: Sized,
177    {
178        Sequential::new(self, Prompt::new(prompt))
179    }
180}
181
182impl<T: Op> Op for &T {
183    type Input = T::Input;
184    type Output = T::Output;
185
186    #[inline]
187    async fn call(&self, input: Self::Input) -> Self::Output {
188        (*self).call(input).await
189    }
190}
191
192// ================================================================
193// Op combinators
194// ================================================================
/// Runs `prev` and then feeds its output into `op`.
///
/// Produced by the `Op` combinators. It implements `Op` whenever `Op2::Input` equals
/// `Op1::Output`.
pub struct Sequential<Op1, Op2> {
    /// First stage; receives the pipeline input.
    prev: Op1,
    /// Second stage; receives the output of `prev`.
    op: Op2,
}

impl<Op1, Op2> Sequential<Op1, Op2> {
    /// Pair two stages; `prev` runs first, then `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        Sequential { prev, op }
    }
}
205
206impl<Op1, Op2> Op for Sequential<Op1, Op2>
207where
208    Op1: Op,
209    Op2: Op<Input = Op1::Output>,
210{
211    type Input = Op1::Input;
212    type Output = Op2::Output;
213
214    #[inline]
215    async fn call(&self, input: Self::Input) -> Self::Output {
216        let prev = self.prev.call(input).await;
217        self.op.call(prev).await
218    }
219}
220
221use super::agent_ops::{Lookup, Prompt};
222use crate::{completion, vector_store};
223
224// ================================================================
225// Core Op implementations
226// ================================================================
/// Op that applies a synchronous function `f` to its input.
///
/// `Input` is the argument type of `f`. It is carried only as `PhantomData` so that the
/// `Op` impl can name it, because impl type parameters must appear in the implementing type.
pub struct Map<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Map<F, Input> {
    /// Wrap `f`; its signature is only checked by the `Op` impl.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Map { f, _t }
    }
}
240
241impl<F, Input, Output> Op for Map<F, Input>
242where
243    F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
244    Input: WasmCompatSend + WasmCompatSync,
245    Output: WasmCompatSend + WasmCompatSync,
246{
247    type Input = Input;
248    type Output = Output;
249
250    #[inline]
251    async fn call(&self, input: Self::Input) -> Self::Output {
252        (self.f)(input)
253    }
254}
255
256pub fn map<F, Input, Output>(f: F) -> Map<F, Input>
257where
258    F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
259    Input: WasmCompatSend + WasmCompatSync,
260    Output: WasmCompatSend + WasmCompatSync,
261{
262    Map::new(f)
263}
264
/// Identity op: returns its input unchanged.
pub struct Passthrough<T> {
    _t: std::marker::PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Create the identity op for values of type `T`.
    pub(crate) fn new() -> Self {
        Passthrough {
            _t: Default::default(),
        }
    }
}
276
277impl<T> Op for Passthrough<T>
278where
279    T: WasmCompatSend + WasmCompatSync,
280{
281    type Input = T;
282    type Output = T;
283
284    async fn call(&self, input: Self::Input) -> Self::Output {
285        input
286    }
287}
288
289pub fn passthrough<T>() -> Passthrough<T>
290where
291    T: WasmCompatSend + WasmCompatSync,
292{
293    Passthrough::new()
294}
295
/// Op that applies an asynchronous function `f` to its input and awaits the result.
///
/// `Input` is the argument type of `f`. It is carried only as `PhantomData` so that the
/// `Op` impl can name it.
pub struct Then<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Then<F, Input> {
    /// Wrap `f`; its signature is only checked by the `Op` impl.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Then { f, _t }
    }
}
309
310impl<F, Input, Fut> Op for Then<F, Input>
311where
312    F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
313    Input: WasmCompatSend + WasmCompatSync,
314    Fut: Future + WasmCompatSend,
315    Fut::Output: WasmCompatSend + WasmCompatSync,
316{
317    type Input = Input;
318    type Output = Fut::Output;
319
320    #[inline]
321    async fn call(&self, input: Self::Input) -> Self::Output {
322        (self.f)(input).await
323    }
324}
325
326pub fn then<F, Input, Fut>(f: F) -> Then<F, Input>
327where
328    F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
329    Input: WasmCompatSend + WasmCompatSync,
330    Fut: Future + WasmCompatSend,
331    Fut::Output: WasmCompatSend + WasmCompatSync,
332{
333    Then::new(f)
334}
335
// Unit tests for the sequential combinators and the core ops defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_sequential_constructor() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 2);
        let op3 = map(|x: i32| x * 3);

        let pipeline = Sequential::new(Sequential::new(op1, op2), op3);

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_sequential_chain() {
        let pipeline = map(|x: i32| x + 1)
            .map(|x| x * 2)
            .then(|x| async move { x * 3 });

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_passthrough_is_identity() {
        assert_eq!(passthrough::<&str>().call("unchanged").await, "unchanged");

        let pipeline = passthrough::<i32>().map(|x| x + 1);
        assert_eq!(pipeline.call(1).await, 2);
    }

    #[tokio::test]
    async fn test_batch_call_preserves_input_order() {
        let op = map(|x: i32| x * 2);

        let result = op.batch_call(2, vec![1, 2, 3, 4]).await;
        assert_eq!(result, vec![2, 4, 6, 8]);
    }

    #[tokio::test]
    async fn test_borrowed_op_can_be_chained() {
        let add_one = map(|x: i32| x + 1);

        // `&add_one` is itself an `Op`, so it can be chained without moving `add_one`.
        let pipeline = (&add_one).map(|x| x * 10);

        assert_eq!(pipeline.call(1).await, 20);
        assert_eq!(add_one.call(1).await, 2);
    }

    // NOTE(review): the tests below are disabled. They use `Parallel`, `parallel!` and
    // `flatten`, none of which are defined in this module. Two of them are also named
    // `test_flatten`, so one must be renamed before re-enabling.

    // #[tokio::test]
    // async fn test_flatten() {
    //     let op = Parallel::new(
    //         Parallel::new(
    //             map(|x: i32| x + 1),
    //             map(|x: i32| x * 2),
    //         ),
    //         map(|x: i32| x * 3),
    //     );

    //     let pipeline = flatten::<_, (_, _, _)>(op);

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 2, 3));
    // }

    // #[tokio::test]
    // async fn test_parallel_macro() {
    //     let op1 = map(|x: i32| x + 1);
    //     let op2 = map(|x: i32| x * 3);
    //     let op3 = map(|x: i32| format!("{} is the number!", x));
    //     let op4 = map(|x: i32| x - 1);

    //     let pipeline = parallel!(op1, op2, op3, op4);

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 3, "1 is the number!".to_string(), 0));
    // }

    // #[tokio::test]
    // async fn test_parallel_join() {
    //     let op3 = map(|x: i32| format!("{} is the number!", x));

    //     let pipeline = Sequential::new(
    //         map(|x: i32| x + 1),
    //         then(|x| {
    //             // let op1 = map(|x: i32| x * 2);
    //             // let op2 = map(|x: i32| x * 3);
    //             let op3 = &op3;

    //             async move {
    //             join!(
    //                 (&map(|x: i32| x * 2)).call(x),
    //                 {
    //                     let op = map(|x: i32| x * 3);
    //                     op.call(x)
    //                 },
    //                 op3.call(x),
    //             )
    //         }}),
    //     );

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 3, "1 is the number!".to_string()));
    // }

    // #[test]
    // fn test_flatten() {
    //     let x = (1, (2, (3, 4)));
    //     let result = flatten!(0, 1, 1, 1, 1);
    //     assert_eq!(result, (1, 2, 3, 4));
    // }
}