// rig_core/pipeline/mod.rs
//! This module defines a flexible pipeline API for defining a sequence of operations that
//! may or may not use AI components (e.g.: semantic search, LLM prompting, etc.).
//!
//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect,
//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box along
//! with general combinators.
//!
//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait.
//! The [Op] trait requires the implementation of only one method: `call`, which takes an input
//! and returns an output. The trait provides a wide range of combinators for chaining operations together.
//!
//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
//! the edges represent the data flow between operations. When invoking the pipeline on some input,
//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline).
//! The output of each op is then passed to the next op in the pipeline until the output reaches the
//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned
//! as the result of the pipeline.
//!
//! ## Basic Example
//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
//! the result as a string using the [map](Op::map) combinator method, which applies a simple function
//! op to the output of the previous op:
//! ```no_run
//! use rig_core::pipeline::{self, Op};
//!
//! # async fn run() {
//! let pipeline = pipeline::new()
//!     // op1: add two numbers
//!     .map(|(x, y)| x + y)
//!     // op2: format result
//!     .map(|z| format!("Result: {z}!"));
//!
//! let result = pipeline.call((1, 2)).await;
//! assert_eq!(result, "Result: 3!");
//! # }
//! ```
//!
//! This pipeline can be visualized as the following DAG:
//! ```text
//!          ┌─────────┐   ┌─────────┐
//! Input───►│   op1   ├──►│   op2   ├──►Output
//!          └─────────┘   └─────────┘
//! ```
//!
//! ## Parallel Operations
//! The pipeline API also provides the [`parallel!`](crate::parallel!) macro for running operations in parallel.
//! The macro takes a list of ops and turns them into a single op that will duplicate the input
//! and run each op concurrently. The results of each op are then collected and returned as a tuple.
//!
//! For example, the pipeline below runs two operations concurrently:
//! ```no_run
//! use rig_core::{pipeline::{self, Op, map}, parallel};
//!
//! # async fn run() {
//! let pipeline = pipeline::new()
//!     .chain(parallel!(
//!         // op1: add 1 to input
//!         map(|x| x + 1),
//!         // op2: subtract 1 from input
//!         map(|x| x - 1),
//!     ))
//!     // op3: format results
//!     .map(|(a, b)| format!("Results: {a}, {b}"));
//!
//! let result = pipeline.call(1).await;
//! assert_eq!(result, "Results: 2, 0");
//! # }
//! ```
//!
//! Notes:
//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
//!   for chaining arbitrary operations, as long as they implement the [Op] trait.
//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op.
//!
//! The pipeline above can be visualized as the following DAG:
//! ```text
//!           Input
//!             │
//!      ┌──────┴──────┐
//!      ▼             ▼
//! ┌─────────┐   ┌─────────┐
//! │   op1   │   │   op2   │
//! └────┬────┘   └────┬────┘
//!      └──────┬──────┘
//!             ▼
//!        ┌─────────┐
//!        │   op3   │
//!        └────┬────┘
//!             │
//!             ▼
//!          Output
//! ```
94pub mod agent_ops;
95pub mod op;
96pub mod try_op;
97#[macro_use]
98pub mod parallel;
99#[macro_use]
100pub mod conditional;
101
102use std::future::Future;
103
104pub use op::{Op, map, passthrough, then};
105pub use try_op::TryOp;
106
107use crate::{completion, extractor::Extractor, vector_store};
108
/// Builder used to start a pipeline, parameterized over the error type `E`
/// that fallible ops in the pipeline will produce.
pub struct PipelineBuilder<E> {
    // Zero-sized marker: the builder stores no data, it only pins `E`.
    _error: std::marker::PhantomData<E>,
}

113impl<E> PipelineBuilder<E> {
114    /// Add a function to the current pipeline
115    ///
116    /// # Example
117    /// ```no_run
118    /// use rig_core::pipeline::{self, Op};
119    ///
120    /// # async fn run() {
121    /// let pipeline = pipeline::new()
122    ///    .map(|(x, y)| x + y)
123    ///    .map(|z| format!("Result: {z}!"));
124    ///
125    /// let result = pipeline.call((1, 2)).await;
126    /// assert_eq!(result, "Result: 3!");
127    /// # }
128    /// ```
129    pub fn map<F, Input, Output>(self, f: F) -> op::Map<F, Input>
130    where
131        F: Fn(Input) -> Output + Send + Sync,
132        Input: Send + Sync,
133        Output: Send + Sync,
134        Self: Sized,
135    {
136        op::Map::new(f)
137    }
138
139    /// Same as `map` but for asynchronous functions
140    ///
141    /// # Example
142    /// ```no_run
143    /// use rig_core::pipeline::{self, Op};
144    ///
145    /// # async fn run() {
146    /// let pipeline = pipeline::new()
147    ///     .then(|email: String| async move {
148    ///         email.split('@').next().unwrap_or_default().to_string()
149    ///     })
150    ///     .then(|username: String| async move {
151    ///         format!("Hello, {}!", username)
152    ///     });
153    ///
154    /// let result = pipeline.call("bob@gmail.com".to_string()).await;
155    /// assert_eq!(result, "Hello, bob!");
156    /// # }
157    /// ```
158    pub fn then<F, Input, Fut>(self, f: F) -> op::Then<F, Input>
159    where
160        F: Fn(Input) -> Fut + Send + Sync,
161        Input: Send + Sync,
162        Fut: Future + Send + Sync,
163        Fut::Output: Send + Sync,
164        Self: Sized,
165    {
166        op::Then::new(f)
167    }
168
169    /// Add an arbitrary operation to the current pipeline.
170    ///
171    /// # Example
172    /// ```no_run
173    /// use rig_core::pipeline::{self, Op};
174    ///
175    /// # async fn run() {
176    /// struct MyOp;
177    ///
178    /// impl Op for MyOp {
179    ///     type Input = i32;
180    ///     type Output = i32;
181    ///
182    ///     async fn call(&self, input: Self::Input) -> Self::Output {
183    ///         input + 1
184    ///     }
185    /// }
186    ///
187    /// let pipeline = pipeline::new()
188    ///    .chain(MyOp);
189    ///
190    /// let result = pipeline.call(1).await;
191    /// assert_eq!(result, 2);
192    /// # }
193    /// ```
194    pub fn chain<T>(self, op: T) -> T
195    where
196        T: Op,
197        Self: Sized,
198    {
199        op
200    }
201
202    /// Chain a lookup operation to the current chain. The lookup operation expects the
203    /// current chain to output a query string. The lookup operation will use the query to
204    /// retrieve the top `n` documents from the index and return them with the query string.
205    ///
206    /// # Example
207    /// ```ignore
208    /// use rig_core::pipeline::{self, Op};
209    ///
210    /// let pipeline = pipeline::new()
211    ///     .lookup(index, 2)
212    ///     .pipeline(|(query, docs): (_, Vec<String>)| async move {
213    ///         format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
214    ///     });
215    ///
216    /// let result = pipeline.call("What is a flurbo?".to_string()).await;
217    /// ```
218    pub fn lookup<I, Input, Output>(self, index: I, n: usize) -> agent_ops::Lookup<I, Input, Output>
219    where
220        I: vector_store::VectorStoreIndex,
221        Output: Send + Sync + for<'a> serde::Deserialize<'a>,
222        Input: Into<String> + Send + Sync,
223        // E: From<vector_store::VectorStoreError> + Send + Sync,
224        Self: Sized,
225    {
226        agent_ops::Lookup::new(index, n)
227    }
228
229    /// Add a prompt operation to the current pipeline/op. The prompt operation expects the
230    /// current pipeline to output a string. The prompt operation will use the string to prompt
231    /// the given `agent`, which must implements the [Prompt](completion::Prompt) trait and return
232    /// the response.
233    ///
234    /// # Example
235    /// ```ignore
236    /// use rig_core::pipeline::{self, Op};
237    ///
238    /// let agent = &openai_client.agent("gpt-4").build();
239    ///
240    /// let pipeline = pipeline::new()
241    ///    .map(|name| format!("Find funny nicknames for the following name: {name}!"))
242    ///    .prompt(agent);
243    ///
244    /// let result = pipeline.call("Alice".to_string()).await;
245    /// ```
246    pub fn prompt<P, Input>(self, agent: P) -> agent_ops::Prompt<P, Input>
247    where
248        P: completion::Prompt,
249        Input: Into<String> + Send + Sync,
250        // E: From<completion::PromptError> + Send + Sync,
251        Self: Sized,
252    {
253        agent_ops::Prompt::new(agent)
254    }
255
256    /// Add an extract operation to the current pipeline/op. The extract operation expects the
257    /// current pipeline to output a string. The extract operation will use the given `extractor`
258    /// to extract information from the string in the form of the type `T` and return it.
259    ///
260    /// # Example
261    /// ```ignore
262    /// use rig_core::pipeline::{self, Op};
263    ///
264    /// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
265    /// struct Sentiment {
266    ///     /// The sentiment score of the text (0.0 = negative, 1.0 = positive)
267    ///     score: f64,
268    /// }
269    ///
270    /// let extractor = &openai_client.extractor::<Sentiment>("gpt-4").build();
271    ///
272    /// let pipeline = pipeline::new()
273    ///     .map(|text| format!("Analyze the sentiment of the following text: {text}!"))
274    ///     .extract(extractor);
275    ///
276    /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;
277    /// assert!(result.score > 0.5);
278    /// ```
279    pub fn extract<M, Input, Output>(
280        self,
281        extractor: Extractor<M, Output>,
282    ) -> agent_ops::Extract<M, Input, Output>
283    where
284        M: completion::CompletionModel,
285        Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
286        Input: Into<String> + Send + Sync,
287    {
288        agent_ops::Extract::new(extractor)
289    }
290}
291
292#[derive(Debug, thiserror::Error)]
293pub enum ChainError {
294    #[error("Failed to prompt agent: {0}")]
295    PromptError(#[from] Box<completion::PromptError>),
296
297    #[error("Failed to lookup documents: {0}")]
298    LookupError(#[from] vector_store::VectorStoreError),
299}
300
301/// Create a pipeline builder using Rig's default [`ChainError`] type.
302pub fn new() -> PipelineBuilder<ChainError> {
303    PipelineBuilder {
304        _error: std::marker::PhantomData,
305    }
306}
307
308/// Create a pipeline builder with a custom error type.
309pub fn with_error<E>() -> PipelineBuilder<E> {
310    PipelineBuilder {
311        _error: std::marker::PhantomData,
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{Foo, MockPromptModel, MockVectorStoreIndex};
    use parallel::parallel;

    /// A map-then-prompt pipeline driven through the infallible `call` path.
    #[tokio::test]
    async fn test_prompt_pipeline() {
        let agent = MockPromptModel;

        let pipeline = super::new()
            .map(|input| format!("User query: {input}"))
            .prompt(agent);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    /// The same pipeline, but built with a custom error type and run via `try_call`.
    #[tokio::test]
    async fn test_prompt_pipeline_error() {
        let agent = MockPromptModel;

        let pipeline = super::with_error::<()>()
            .map(|input| format!("User query: {input}"))
            .prompt(agent);

        let response = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    /// Lookup against the mock index, formatting the single retrieved document.
    #[tokio::test]
    async fn test_lookup_pipeline() {
        let index = MockVectorStoreIndex;

        let pipeline = super::new()
            .lookup::<_, _, Foo>(index, 1)
            .map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo));

        let response = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Top documents:\nbar");
    }

    /// RAG-style flow: fan out (passthrough + lookup), merge results, then prompt.
    #[tokio::test]
    async fn test_rag_pipeline() {
        let index = MockVectorStoreIndex;

        let pipeline = super::new()
            .chain(parallel!(
                passthrough(),
                agent_ops::lookup::<_, _, Foo>(index, 1),
            ))
            .map(|(query, maybe_docs)| match maybe_docs {
                Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo),
                Err(err) => format!("Error: {err}"),
            })
            .prompt(MockPromptModel);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(
            response,
            "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar"
        );
    }
}