bep/pipeline/mod.rs
//! This module defines a flexible pipeline API for defining a sequence of operations that
//! may or may not use AI components (e.g. semantic search, LLM prompting, etc.).
3//!
//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect,
//! but is implemented with idiomatic Rust patterns and provides some AI-specific ops out-of-the-box
//! alongside general combinators.
7//!
8//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait.
9//! The [Op] trait requires the implementation of only one method: `call`, which takes an input
10//! and returns an output. The trait provides a wide range of combinators for chaining operations together.
11//!
12//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
13//! the edges represent the data flow between operations. When invoking the pipeline on some input,
14//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline).
15//! The output of each op is then passed to the next op in the pipeline until the output reaches the
16//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned
17//! as the result of the pipeline.
18//!
19//! ## Basic Example
20//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
21//! the result as a string using the [map](Op::map) combinator method, which applies a simple function
22//! op to the output of the previous op:
23//! ```rust
24//! use bep::pipeline::{self, Op};
25//!
26//! let pipeline = pipeline::new()
27//! // op1: add two numbers
28//! .map(|(x, y)| x + y)
29//! // op2: format result
30//! .map(|z| format!("Result: {z}!"));
31//!
32//! let result = pipeline.call((1, 2)).await;
33//! assert_eq!(result, "Result: 3!");
34//! ```
35//!
36//! This pipeline can be visualized as the following DAG:
37//! ```text
//!          ┌─────────┐   ┌─────────┐
//! Input───►│   op1   ├──►│   op2   ├──►Output
//!          └─────────┘   └─────────┘
41//! ```
42//!
43//! ## Parallel Operations
//! The pipeline API also provides a [parallel!](crate::parallel!) macro for running operations in parallel.
//! The macro takes a list of ops and turns them into a single op that will duplicate the input
//! and run each op concurrently. The results of each op are then collected and returned as a tuple.
//!
//! For example, the pipeline below runs two operations concurrently:
49//! ```rust
50//! use bep::{pipeline::{self, Op, map}, parallel};
51//!
52//! let pipeline = pipeline::new()
53//! .chain(parallel!(
54//! // op1: add 1 to input
55//! map(|x| x + 1),
56//! // op2: subtract 1 from input
57//! map(|x| x - 1),
58//! ))
59//! // op3: format results
60//! .map(|(a, b)| format!("Results: {a}, {b}"));
61//!
62//! let result = pipeline.call(1).await;
//! assert_eq!(result, "Results: 2, 0");
64//! ```
65//!
66//! Notes:
67//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
68//! for chaining arbitrary operations, as long as they implement the [Op] trait.
69//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op.
70//!
71//! The pipeline above can be visualized as the following DAG:
72//! ```text
//!           Input
//!             │
//!      ┌──────┴──────┐
//!      ▼             ▼
//! ┌─────────┐   ┌─────────┐
//! │   op1   │   │   op2   │
//! └────┬────┘   └────┬────┘
//!      └──────┬──────┘
//!             ▼
//!        ┌─────────┐
//!        │   op3   │
//!        └────┬────┘
//!             │
//!             ▼
//!           Output
88//! ```
89
90pub mod agent_ops;
91pub mod op;
92pub mod try_op;
93#[macro_use]
94pub mod parallel;
95
96use std::future::Future;
97
98pub use op::{map, passthrough, then, Op};
99pub use try_op::TryOp;
100
101use crate::{completion, extractor::Extractor, vector_store};
102
/// Starting point for building a pipeline; obtain one via [new] or [with_error].
///
/// The builder stores no data at runtime. `E` is only a type-level marker for the error
/// type the pipeline was created with; none of the builder methods currently constrain it.
pub struct PipelineBuilder<E> {
    /// Zero-sized marker that records `E` without holding a value of it.
    _error: ::std::marker::PhantomData<E>,
}
106
107impl<E> PipelineBuilder<E> {
108 /// Add a function to the current pipeline
109 ///
110 /// # Example
111 /// ```rust
112 /// use bep::pipeline::{self, Op};
113 ///
114 /// let pipeline = pipeline::new()
115 /// .map(|(x, y)| x + y)
116 /// .map(|z| format!("Result: {z}!"));
117 ///
118 /// let result = pipeline.call((1, 2)).await;
119 /// assert_eq!(result, "Result: 3!");
120 /// ```
121 pub fn map<F, Input, Output>(self, f: F) -> op::Map<F, Input>
122 where
123 F: Fn(Input) -> Output + Send + Sync,
124 Input: Send + Sync,
125 Output: Send + Sync,
126 Self: Sized,
127 {
128 op::Map::new(f)
129 }
130
131 /// Same as `map` but for asynchronous functions
132 ///
133 /// # Example
134 /// ```rust
135 /// use bep::pipeline::{self, Op};
136 ///
137 /// let pipeline = pipeline::new()
138 /// .then(|email: String| async move {
139 /// email.split('@').next().unwrap().to_string()
140 /// })
141 /// .then(|username: String| async move {
142 /// format!("Hello, {}!", username)
143 /// });
144 ///
145 /// let result = pipeline.call("bob@gmail.com".to_string()).await;
146 /// assert_eq!(result, "Hello, bob!");
147 /// ```
148 pub fn then<F, Input, Fut>(self, f: F) -> op::Then<F, Input>
149 where
150 F: Fn(Input) -> Fut + Send + Sync,
151 Input: Send + Sync,
152 Fut: Future + Send + Sync,
153 Fut::Output: Send + Sync,
154 Self: Sized,
155 {
156 op::Then::new(f)
157 }
158
159 /// Add an arbitrary operation to the current pipeline.
160 ///
161 /// # Example
162 /// ```rust
163 /// use bep::pipeline::{self, Op};
164 ///
165 /// struct MyOp;
166 ///
167 /// impl Op for MyOp {
168 /// type Input = i32;
169 /// type Output = i32;
170 ///
171 /// async fn call(&self, input: Self::Input) -> Self::Output {
172 /// input + 1
173 /// }
174 /// }
175 ///
176 /// let pipeline = pipeline::new()
177 /// .chain(MyOp);
178 ///
179 /// let result = pipeline.call(1).await;
180 /// assert_eq!(result, 2);
181 /// ```
182 pub fn chain<T>(self, op: T) -> T
183 where
184 T: Op,
185 Self: Sized,
186 {
187 op
188 }
189
190 /// Chain a lookup operation to the current chain. The lookup operation expects the
191 /// current chain to output a query string. The lookup operation will use the query to
192 /// retrieve the top `n` documents from the index and return them with the query string.
193 ///
194 /// # Example
195 /// ```rust
196 /// use bep::pipeline::{self, Op};
197 ///
198 /// let pipeline = pipeline::new()
199 /// .lookup(index, 2)
200 /// .pipeline(|(query, docs): (_, Vec<String>)| async move {
201 /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
202 /// });
203 ///
204 /// let result = pipeline.call("What is a flurbo?".to_string()).await;
205 /// ```
206 pub fn lookup<I, Input, Output>(self, index: I, n: usize) -> agent_ops::Lookup<I, Input, Output>
207 where
208 I: vector_store::VectorStoreIndex,
209 Output: Send + Sync + for<'a> serde::Deserialize<'a>,
210 Input: Into<String> + Send + Sync,
211 // E: From<vector_store::VectorStoreError> + Send + Sync,
212 Self: Sized,
213 {
214 agent_ops::Lookup::new(index, n)
215 }
216
217 /// Add a prompt operation to the current pipeline/op. The prompt operation expects the
218 /// current pipeline to output a string. The prompt operation will use the string to prompt
219 /// the given `agent`, which must implements the [Prompt](completion::Prompt) trait and return
220 /// the response.
221 ///
222 /// # Example
223 /// ```rust
224 /// use bep::pipeline::{self, Op};
225 ///
226 /// let agent = &openai_client.agent("gpt-4").build();
227 ///
228 /// let pipeline = pipeline::new()
229 /// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
230 /// .prompt(agent);
231 ///
232 /// let result = pipeline.call("Alice".to_string()).await;
233 /// ```
234 pub fn prompt<P, Input>(self, agent: P) -> agent_ops::Prompt<P, Input>
235 where
236 P: completion::Prompt,
237 Input: Into<String> + Send + Sync,
238 // E: From<completion::PromptError> + Send + Sync,
239 Self: Sized,
240 {
241 agent_ops::Prompt::new(agent)
242 }
243
244 /// Add an extract operation to the current pipeline/op. The extract operation expects the
245 /// current pipeline to output a string. The extract operation will use the given `extractor`
246 /// to extract information from the string in the form of the type `T` and return it.
247 ///
248 /// # Example
249 /// ```rust
250 /// use bep::pipeline::{self, Op};
251 ///
252 /// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
253 /// struct Sentiment {
254 /// /// The sentiment score of the text (0.0 = negative, 1.0 = positive)
255 /// score: f64,
256 /// }
257 ///
258 /// let extractor = &openai_client.extractor::<Sentiment>("gpt-4").build();
259 ///
260 /// let pipeline = pipeline::new()
261 /// .map(|text| format!("Analyze the sentiment of the following text: {text}!"))
262 /// .extract(extractor);
263 ///
264 /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;
265 /// assert!(result.score > 0.5);
266 /// ```
267 pub fn extract<M, Input, Output>(
268 self,
269 extractor: Extractor<M, Output>,
270 ) -> agent_ops::Extract<M, Input, Output>
271 where
272 M: completion::CompletionModel,
273 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
274 Input: Into<String> + Send + Sync,
275 {
276 agent_ops::Extract::new(extractor)
277 }
278}
279
280#[derive(Debug, thiserror::Error)]
281pub enum ChainError {
282 #[error("Failed to prompt agent: {0}")]
283 PromptError(#[from] completion::PromptError),
284
285 #[error("Failed to lookup documents: {0}")]
286 LookupError(#[from] vector_store::VectorStoreError),
287}
288
289pub fn new() -> PipelineBuilder<ChainError> {
290 PipelineBuilder {
291 _error: std::marker::PhantomData,
292 }
293}
294
295pub fn with_error<E>() -> PipelineBuilder<E> {
296 PipelineBuilder {
297 _error: std::marker::PhantomData,
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use agent_ops::tests::{Foo, MockIndex, MockModel};
    use parallel::parallel;

    // A `map` step followed by `prompt` should send the formatted query to the model.
    #[tokio::test]
    async fn test_prompt_pipeline() {
        let pipeline = super::new()
            .map(|input| format!("User query: {}", input))
            .prompt(MockModel);

        let output = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(output, "Mock response: User query: What is a flurbo?");
    }

    // Same pipeline built with a custom error type, driven through `try_call`.
    #[tokio::test]
    async fn test_prompt_pipeline_error() {
        let pipeline = super::with_error::<()>()
            .map(|input| format!("User query: {}", input))
            .prompt(MockModel);

        let output = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(output, "Mock response: User query: What is a flurbo?");
    }

    // A lookup step's successful output can be post-processed with `map_ok`.
    #[tokio::test]
    async fn test_lookup_pipeline() {
        let pipeline = super::new()
            .lookup::<_, _, Foo>(MockIndex, 1)
            .map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo));

        let output = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(output, "Top documents:\nbar");
    }

    // RAG: keep the query via `passthrough` while looking up documents in parallel,
    // then build the prompt from both.
    #[tokio::test]
    async fn test_rag_pipeline() {
        let pipeline = super::new()
            .chain(parallel!(
                passthrough(),
                agent_ops::lookup::<_, _, Foo>(MockIndex, 1),
            ))
            .map(|(query, lookup_result)| match lookup_result {
                Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo),
                Err(e) => format!("Error: {}", e),
            })
            .prompt(MockModel);

        let output = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(
            output,
            "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar"
        );
    }
}