bep/pipeline/
mod.rs

//! This module defines a flexible pipeline API for defining a sequence of operations that
//! may or may not use AI components (e.g. semantic search, LLM prompting, etc.).
3//!
4//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect,
//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box alongside
//! general combinators.
7//!
8//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait.
9//! The [Op] trait requires the implementation of only one method: `call`, which takes an input
10//! and returns an output. The trait provides a wide range of combinators for chaining operations together.
11//!
12//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
13//! the edges represent the data flow between operations. When invoking the pipeline on some input,
14//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline).
15//! The output of each op is then passed to the next op in the pipeline until the output reaches the
16//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned
17//! as the result of the pipeline.
18//!
19//! ## Basic Example
20//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
21//! the result as a string using the [map](Op::map) combinator method, which applies a simple function
22//! op to the output of the previous op:
23//! ```rust
24//! use bep::pipeline::{self, Op};
25//!
26//! let pipeline = pipeline::new()
27//!     // op1: add two numbers
28//!     .map(|(x, y)| x + y)
29//!     // op2: format result
30//!     .map(|z| format!("Result: {z}!"));
31//!
32//! let result = pipeline.call((1, 2)).await;
33//! assert_eq!(result, "Result: 3!");
34//! ```
35//!
36//! This pipeline can be visualized as the following DAG:
37//! ```text
38//!          ┌─────────┐   ┌─────────┐         
39//! Input───►│   op1   ├──►│   op2   ├──►Output
40//!          └─────────┘   └─────────┘         
41//! ```
42//!
43//! ## Parallel Operations
//! The pipeline API also provides a [parallel!](crate::parallel!) macro for running operations in parallel.
//! The macro takes a list of ops and turns them into a single op that will duplicate the input
//! and run each op concurrently. The results of each op are then collected and returned as a tuple.
47//!
//! For example, the pipeline below runs two operations concurrently:
49//! ```rust
50//! use bep::{pipeline::{self, Op, map}, parallel};
51//!
52//! let pipeline = pipeline::new()
53//!     .chain(parallel!(
54//!         // op1: add 1 to input
55//!         map(|x| x + 1),
56//!         // op2: subtract 1 from input
57//!         map(|x| x - 1),
58//!     ))
59//!     // op3: format results
60//!     .map(|(a, b)| format!("Results: {a}, {b}"));
61//!
62//! let result = pipeline.call(1).await;
//! assert_eq!(result, "Results: 2, 0");
64//! ```
65//!
66//! Notes:
67//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
68//!   for chaining arbitrary operations, as long as they implement the [Op] trait.
69//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op.
70//!
71//! The pipeline above can be visualized as the following DAG:
72//! ```text                 
73//!           Input            
74//!             │              
75//!      ┌──────┴──────┐       
76//!      ▼             ▼       
77//! ┌─────────┐   ┌─────────┐  
78//! │   op1   │   │   op2   │  
79//! └────┬────┘   └────┬────┘  
80//!      └──────┬──────┘       
81//!             ▼              
82//!        ┌─────────┐         
83//!        │   op3   │         
84//!        └────┬────┘         
85//!             │              
86//!             ▼              
87//!          Output           
88//! ```
89
90pub mod agent_ops;
91pub mod op;
92pub mod try_op;
93#[macro_use]
94pub mod parallel;
95
96use std::future::Future;
97
98pub use op::{map, passthrough, then, Op};
99pub use try_op::TryOp;
100
101use crate::{completion, extractor::Extractor, vector_store};
102
/// Entry point for building a pipeline.
///
/// The builder holds no runtime data: its only field is a zero-sized marker
/// recording the error type `E` chosen when the builder was created. Each
/// builder method consumes the builder and returns the first op of the
/// pipeline, onto which further ops can be chained.
pub struct PipelineBuilder<E> {
    /// Zero-sized marker tying the builder to its error type `E`.
    _error: std::marker::PhantomData<E>,
}
106
107impl<E> PipelineBuilder<E> {
108    /// Add a function to the current pipeline
109    ///
110    /// # Example
111    /// ```rust
112    /// use bep::pipeline::{self, Op};
113    ///
114    /// let pipeline = pipeline::new()
115    ///    .map(|(x, y)| x + y)
116    ///    .map(|z| format!("Result: {z}!"));
117    ///
118    /// let result = pipeline.call((1, 2)).await;
119    /// assert_eq!(result, "Result: 3!");
120    /// ```
121    pub fn map<F, Input, Output>(self, f: F) -> op::Map<F, Input>
122    where
123        F: Fn(Input) -> Output + Send + Sync,
124        Input: Send + Sync,
125        Output: Send + Sync,
126        Self: Sized,
127    {
128        op::Map::new(f)
129    }
130
131    /// Same as `map` but for asynchronous functions
132    ///
133    /// # Example
134    /// ```rust
135    /// use bep::pipeline::{self, Op};
136    ///
137    /// let pipeline = pipeline::new()
138    ///     .then(|email: String| async move {
139    ///         email.split('@').next().unwrap().to_string()
140    ///     })
141    ///     .then(|username: String| async move {
142    ///         format!("Hello, {}!", username)
143    ///     });
144    ///
145    /// let result = pipeline.call("bob@gmail.com".to_string()).await;
146    /// assert_eq!(result, "Hello, bob!");
147    /// ```
148    pub fn then<F, Input, Fut>(self, f: F) -> op::Then<F, Input>
149    where
150        F: Fn(Input) -> Fut + Send + Sync,
151        Input: Send + Sync,
152        Fut: Future + Send + Sync,
153        Fut::Output: Send + Sync,
154        Self: Sized,
155    {
156        op::Then::new(f)
157    }
158
159    /// Add an arbitrary operation to the current pipeline.
160    ///
161    /// # Example
162    /// ```rust
163    /// use bep::pipeline::{self, Op};
164    ///
165    /// struct MyOp;
166    ///
167    /// impl Op for MyOp {
168    ///     type Input = i32;
169    ///     type Output = i32;
170    ///
171    ///     async fn call(&self, input: Self::Input) -> Self::Output {
172    ///         input + 1
173    ///     }
174    /// }
175    ///
176    /// let pipeline = pipeline::new()
177    ///    .chain(MyOp);
178    ///
179    /// let result = pipeline.call(1).await;
180    /// assert_eq!(result, 2);
181    /// ```
182    pub fn chain<T>(self, op: T) -> T
183    where
184        T: Op,
185        Self: Sized,
186    {
187        op
188    }
189
190    /// Chain a lookup operation to the current chain. The lookup operation expects the
191    /// current chain to output a query string. The lookup operation will use the query to
192    /// retrieve the top `n` documents from the index and return them with the query string.
193    ///
194    /// # Example
195    /// ```rust
196    /// use bep::pipeline::{self, Op};
197    ///
198    /// let pipeline = pipeline::new()
199    ///     .lookup(index, 2)
200    ///     .pipeline(|(query, docs): (_, Vec<String>)| async move {
201    ///         format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
202    ///     });
203    ///
204    /// let result = pipeline.call("What is a flurbo?".to_string()).await;
205    /// ```
206    pub fn lookup<I, Input, Output>(self, index: I, n: usize) -> agent_ops::Lookup<I, Input, Output>
207    where
208        I: vector_store::VectorStoreIndex,
209        Output: Send + Sync + for<'a> serde::Deserialize<'a>,
210        Input: Into<String> + Send + Sync,
211        // E: From<vector_store::VectorStoreError> + Send + Sync,
212        Self: Sized,
213    {
214        agent_ops::Lookup::new(index, n)
215    }
216
217    /// Add a prompt operation to the current pipeline/op. The prompt operation expects the
218    /// current pipeline to output a string. The prompt operation will use the string to prompt
219    /// the given `agent`, which must implements the [Prompt](completion::Prompt) trait and return
220    /// the response.
221    ///
222    /// # Example
223    /// ```rust
224    /// use bep::pipeline::{self, Op};
225    ///
226    /// let agent = &openai_client.agent("gpt-4").build();
227    ///
228    /// let pipeline = pipeline::new()
229    ///    .map(|name| format!("Find funny nicknames for the following name: {name}!"))
230    ///    .prompt(agent);
231    ///
232    /// let result = pipeline.call("Alice".to_string()).await;
233    /// ```
234    pub fn prompt<P, Input>(self, agent: P) -> agent_ops::Prompt<P, Input>
235    where
236        P: completion::Prompt,
237        Input: Into<String> + Send + Sync,
238        // E: From<completion::PromptError> + Send + Sync,
239        Self: Sized,
240    {
241        agent_ops::Prompt::new(agent)
242    }
243
244    /// Add an extract operation to the current pipeline/op. The extract operation expects the
245    /// current pipeline to output a string. The extract operation will use the given `extractor`
246    /// to extract information from the string in the form of the type `T` and return it.
247    ///
248    /// # Example
249    /// ```rust
250    /// use bep::pipeline::{self, Op};
251    ///
252    /// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
253    /// struct Sentiment {
254    ///     /// The sentiment score of the text (0.0 = negative, 1.0 = positive)
255    ///     score: f64,
256    /// }
257    ///
258    /// let extractor = &openai_client.extractor::<Sentiment>("gpt-4").build();
259    ///
260    /// let pipeline = pipeline::new()
261    ///     .map(|text| format!("Analyze the sentiment of the following text: {text}!"))
262    ///     .extract(extractor);
263    ///
264    /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;
265    /// assert!(result.score > 0.5);
266    /// ```
267    pub fn extract<M, Input, Output>(
268        self,
269        extractor: Extractor<M, Output>,
270    ) -> agent_ops::Extract<M, Input, Output>
271    where
272        M: completion::CompletionModel,
273        Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
274        Input: Into<String> + Send + Sync,
275    {
276        agent_ops::Extract::new(extractor)
277    }
278}
279
280#[derive(Debug, thiserror::Error)]
281pub enum ChainError {
282    #[error("Failed to prompt agent: {0}")]
283    PromptError(#[from] completion::PromptError),
284
285    #[error("Failed to lookup documents: {0}")]
286    LookupError(#[from] vector_store::VectorStoreError),
287}
288
289pub fn new() -> PipelineBuilder<ChainError> {
290    PipelineBuilder {
291        _error: std::marker::PhantomData,
292    }
293}
294
295pub fn with_error<E>() -> PipelineBuilder<E> {
296    PipelineBuilder {
297        _error: std::marker::PhantomData,
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use agent_ops::tests::{Foo, MockIndex, MockModel};
    use parallel::parallel;

    /// A formatted query is forwarded to the mock model and its response is returned.
    #[tokio::test]
    async fn test_prompt_pipeline() {
        let pipeline = super::new()
            .map(|query| format!("User query: {}", query))
            .prompt(MockModel);

        let response = pipeline.call("What is a flurbo?").await;

        assert_eq!(
            response.expect("Failed to run chain"),
            "Mock response: User query: What is a flurbo?"
        );
    }

    /// Same flow as above, but built with a custom error type and run through `try_call`.
    #[tokio::test]
    async fn test_prompt_pipeline_error() {
        let pipeline = super::with_error::<()>()
            .map(|query| format!("User query: {}", query))
            .prompt(MockModel);

        let response = pipeline.try_call("What is a flurbo?").await;

        assert_eq!(
            response.expect("Failed to run chain"),
            "Mock response: User query: What is a flurbo?"
        );
    }

    /// The lookup op yields the documents of the mock index, which `map_ok` then formats.
    #[tokio::test]
    async fn test_lookup_pipeline() {
        let pipeline = super::new()
            .lookup::<_, _, Foo>(MockIndex, 1)
            .map_ok(|hits| format!("Top documents:\n{}", hits[0].2.foo));

        let formatted = pipeline.try_call("What is a flurbo?").await;

        assert_eq!(
            formatted.expect("Failed to run chain"),
            "Top documents:\nbar"
        );
    }

    /// The query is kept via `passthrough` while the lookup runs alongside it; both are
    /// merged into one prompt for the mock model.
    #[tokio::test]
    async fn test_rag_pipeline() {
        let retrieval = parallel!(
            passthrough(),
            agent_ops::lookup::<_, _, Foo>(MockIndex, 1),
        );

        let pipeline = super::new()
            .chain(retrieval)
            .map(|(query, lookup_result)| {
                // Turn a failed lookup into an error message instead of aborting.
                let docs = match lookup_result {
                    Ok(docs) => docs,
                    Err(err) => return format!("Error: {}", err),
                };
                format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo)
            })
            .prompt(MockModel);

        let response = pipeline.call("What is a flurbo?").await;

        assert_eq!(
            response.expect("Failed to run chain"),
            "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar"
        );
    }
}
380}