rig/pipeline/
mod.rs

1//! This module defines a flexible pipeline API for defining a sequence of operations that
2//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc).
3//!
4//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect,
5//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box alongside
6//! general combinators.
7//!
8//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait.
9//! The [Op] trait requires the implementation of only one method: `call`, which takes an input
10//! and returns an output. The trait provides a wide range of combinators for chaining operations together.
11//!
12//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
13//! the edges represent the data flow between operations. When invoking the pipeline on some input,
14//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline).
15//! The output of each op is then passed to the next op in the pipeline until the output reaches the
16//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned
17//! as the result of the pipeline.
18//!
19//! ## Basic Example
20//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
21//! the result as a string using the [map](Op::map) combinator method, which applies a simple function
22//! op to the output of the previous op:
23//! ```rust
24//! use rig::pipeline::{self, Op};
25//!
26//! let pipeline = pipeline::new()
27//!     // op1: add two numbers
28//!     .map(|(x, y)| x + y)
29//!     // op2: format result
30//!     .map(|z| format!("Result: {z}!"));
31//!
32//! let result = pipeline.call((1, 2)).await;
33//! assert_eq!(result, "Result: 3!");
34//! ```
35//!
36//! This pipeline can be visualized as the following DAG:
37//! ```text
38//!          ┌─────────┐   ┌─────────┐         
39//! Input───►│   op1   ├──►│   op2   ├──►Output
40//!          └─────────┘   └─────────┘         
41//! ```
42//!
43//! ## Parallel Operations
44//! The pipeline API also provides a [parallel!](crate::parallel!) macro for running operations in parallel.
45//! The macro takes a list of ops and turns them into a single op that will duplicate the input
46//! and run each op concurrently. The results of each op are then collected and returned as a tuple.
47//!
48//! For example, the pipeline below runs two operations concurrently:
49//! ```rust
50//! use rig::{pipeline::{self, Op, map}, parallel};
51//!
52//! let pipeline = pipeline::new()
53//!     .chain(parallel!(
54//!         // op1: add 1 to input
55//!         map(|x| x + 1),
56//!         // op2: subtract 1 from input
57//!         map(|x| x - 1),
58//!     ))
59//!     // op3: format results
60//!     .map(|(a, b)| format!("Results: {a}, {b}"));
61//!
62//! let result = pipeline.call(1).await;
63//! assert_eq!(result, "Results: 2, 0");
64//! ```
65//!
66//! Notes:
67//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
68//!   for chaining arbitrary operations, as long as they implement the [Op] trait.
69//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op.
70//!
71//! The pipeline above can be visualized as the following DAG:
72//! ```text                 
73//!           Input            
74//!             │              
75//!      ┌──────┴──────┐       
76//!      ▼             ▼       
77//! ┌─────────┐   ┌─────────┐  
78//! │   op1   │   │   op2   │  
79//! └────┬────┘   └────┬────┘  
80//!      └──────┬──────┘       
81//!             ▼              
82//!        ┌─────────┐         
83//!        │   op3   │         
84//!        └────┬────┘         
85//!             │              
86//!             ▼              
87//!          Output           
88//! ```
89
90pub mod agent_ops;
91pub mod op;
92pub mod try_op;
93#[macro_use]
94pub mod parallel;
95#[macro_use]
96pub mod conditional;
97
98use std::future::Future;
99
100pub use op::{map, passthrough, then, Op};
101pub use try_op::TryOp;
102
103use crate::{completion, extractor::Extractor, vector_store};
104
/// Entry point for building a pipeline.
///
/// The builder stores no runtime data: its only field is a zero-sized marker
/// that records the error type parameter `E` at the type level. Its methods
/// consume the builder and return the first op of the pipeline.
pub struct PipelineBuilder<E> {
    /// Zero-sized marker tying the builder to the error type `E`.
    _error: std::marker::PhantomData<E>,
}
108
109impl<E> PipelineBuilder<E> {
110    /// Add a function to the current pipeline
111    ///
112    /// # Example
113    /// ```rust
114    /// use rig::pipeline::{self, Op};
115    ///
116    /// let pipeline = pipeline::new()
117    ///    .map(|(x, y)| x + y)
118    ///    .map(|z| format!("Result: {z}!"));
119    ///
120    /// let result = pipeline.call((1, 2)).await;
121    /// assert_eq!(result, "Result: 3!");
122    /// ```
123    pub fn map<F, Input, Output>(self, f: F) -> op::Map<F, Input>
124    where
125        F: Fn(Input) -> Output + Send + Sync,
126        Input: Send + Sync,
127        Output: Send + Sync,
128        Self: Sized,
129    {
130        op::Map::new(f)
131    }
132
133    /// Same as `map` but for asynchronous functions
134    ///
135    /// # Example
136    /// ```rust
137    /// use rig::pipeline::{self, Op};
138    ///
139    /// let pipeline = pipeline::new()
140    ///     .then(|email: String| async move {
141    ///         email.split('@').next().unwrap().to_string()
142    ///     })
143    ///     .then(|username: String| async move {
144    ///         format!("Hello, {}!", username)
145    ///     });
146    ///
147    /// let result = pipeline.call("bob@gmail.com".to_string()).await;
148    /// assert_eq!(result, "Hello, bob!");
149    /// ```
150    pub fn then<F, Input, Fut>(self, f: F) -> op::Then<F, Input>
151    where
152        F: Fn(Input) -> Fut + Send + Sync,
153        Input: Send + Sync,
154        Fut: Future + Send + Sync,
155        Fut::Output: Send + Sync,
156        Self: Sized,
157    {
158        op::Then::new(f)
159    }
160
161    /// Add an arbitrary operation to the current pipeline.
162    ///
163    /// # Example
164    /// ```rust
165    /// use rig::pipeline::{self, Op};
166    ///
167    /// struct MyOp;
168    ///
169    /// impl Op for MyOp {
170    ///     type Input = i32;
171    ///     type Output = i32;
172    ///
173    ///     async fn call(&self, input: Self::Input) -> Self::Output {
174    ///         input + 1
175    ///     }
176    /// }
177    ///
178    /// let pipeline = pipeline::new()
179    ///    .chain(MyOp);
180    ///
181    /// let result = pipeline.call(1).await;
182    /// assert_eq!(result, 2);
183    /// ```
184    pub fn chain<T>(self, op: T) -> T
185    where
186        T: Op,
187        Self: Sized,
188    {
189        op
190    }
191
192    /// Chain a lookup operation to the current chain. The lookup operation expects the
193    /// current chain to output a query string. The lookup operation will use the query to
194    /// retrieve the top `n` documents from the index and return them with the query string.
195    ///
196    /// # Example
197    /// ```rust
198    /// use rig::pipeline::{self, Op};
199    ///
200    /// let pipeline = pipeline::new()
201    ///     .lookup(index, 2)
202    ///     .pipeline(|(query, docs): (_, Vec<String>)| async move {
203    ///         format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
204    ///     });
205    ///
206    /// let result = pipeline.call("What is a flurbo?".to_string()).await;
207    /// ```
208    pub fn lookup<I, Input, Output>(self, index: I, n: usize) -> agent_ops::Lookup<I, Input, Output>
209    where
210        I: vector_store::VectorStoreIndex,
211        Output: Send + Sync + for<'a> serde::Deserialize<'a>,
212        Input: Into<String> + Send + Sync,
213        // E: From<vector_store::VectorStoreError> + Send + Sync,
214        Self: Sized,
215    {
216        agent_ops::Lookup::new(index, n)
217    }
218
219    /// Add a prompt operation to the current pipeline/op. The prompt operation expects the
220    /// current pipeline to output a string. The prompt operation will use the string to prompt
221    /// the given `agent`, which must implements the [Prompt](completion::Prompt) trait and return
222    /// the response.
223    ///
224    /// # Example
225    /// ```rust
226    /// use rig::pipeline::{self, Op};
227    ///
228    /// let agent = &openai_client.agent("gpt-4").build();
229    ///
230    /// let pipeline = pipeline::new()
231    ///    .map(|name| format!("Find funny nicknames for the following name: {name}!"))
232    ///    .prompt(agent);
233    ///
234    /// let result = pipeline.call("Alice".to_string()).await;
235    /// ```
236    pub fn prompt<P, Input>(self, agent: P) -> agent_ops::Prompt<P, Input>
237    where
238        P: completion::Prompt,
239        Input: Into<String> + Send + Sync,
240        // E: From<completion::PromptError> + Send + Sync,
241        Self: Sized,
242    {
243        agent_ops::Prompt::new(agent)
244    }
245
246    /// Add an extract operation to the current pipeline/op. The extract operation expects the
247    /// current pipeline to output a string. The extract operation will use the given `extractor`
248    /// to extract information from the string in the form of the type `T` and return it.
249    ///
250    /// # Example
251    /// ```rust
252    /// use rig::pipeline::{self, Op};
253    ///
254    /// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
255    /// struct Sentiment {
256    ///     /// The sentiment score of the text (0.0 = negative, 1.0 = positive)
257    ///     score: f64,
258    /// }
259    ///
260    /// let extractor = &openai_client.extractor::<Sentiment>("gpt-4").build();
261    ///
262    /// let pipeline = pipeline::new()
263    ///     .map(|text| format!("Analyze the sentiment of the following text: {text}!"))
264    ///     .extract(extractor);
265    ///
266    /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;
267    /// assert!(result.score > 0.5);
268    /// ```
269    pub fn extract<M, Input, Output>(
270        self,
271        extractor: Extractor<M, Output>,
272    ) -> agent_ops::Extract<M, Input, Output>
273    where
274        M: completion::CompletionModel,
275        Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
276        Input: Into<String> + Send + Sync,
277    {
278        agent_ops::Extract::new(extractor)
279    }
280}
281
282#[derive(Debug, thiserror::Error)]
283pub enum ChainError {
284    #[error("Failed to prompt agent: {0}")]
285    PromptError(#[from] completion::PromptError),
286
287    #[error("Failed to lookup documents: {0}")]
288    LookupError(#[from] vector_store::VectorStoreError),
289}
290
291pub fn new() -> PipelineBuilder<ChainError> {
292    PipelineBuilder {
293        _error: std::marker::PhantomData,
294    }
295}
296
297pub fn with_error<E>() -> PipelineBuilder<E> {
298    PipelineBuilder {
299        _error: std::marker::PhantomData,
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use agent_ops::tests::{Foo, MockIndex, MockModel};
    use parallel::parallel;

    /// A `map` op followed by a `prompt` op, run through `call`.
    #[tokio::test]
    async fn test_prompt_pipeline() {
        let pipeline = super::new()
            .map(|input| format!("User query: {input}"))
            .prompt(MockModel);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    /// Same pipeline as above, but built with a custom error type parameter
    /// and run through `try_call`.
    #[tokio::test]
    async fn test_prompt_pipeline_error() {
        let pipeline = super::with_error::<()>()
            .map(|input| format!("User query: {input}"))
            .prompt(MockModel);

        let response = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    /// A `lookup` op whose successful result is formatted with `map_ok`.
    #[tokio::test]
    async fn test_lookup_pipeline() {
        let pipeline = super::new()
            .lookup::<_, _, Foo>(MockIndex, 1)
            .map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo));

        let summary = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(summary, "Top documents:\nbar");
    }

    /// RAG-style pipeline: the query is passed through alongside a lookup,
    /// both are merged into a single prompt string, and the mock model is prompted.
    #[tokio::test]
    async fn test_rag_pipeline() {
        let pipeline = super::new()
            .chain(parallel!(
                passthrough(),
                agent_ops::lookup::<_, _, Foo>(MockIndex, 1),
            ))
            .map(|(query, lookup_result)| match lookup_result {
                Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo),
                Err(err) => format!("Error: {err}"),
            })
            .prompt(MockModel);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(
            response,
            "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar"
        );
    }
}