rig/pipeline/mod.rs
1//! This module defines a flexible pipeline API for defining a sequence of operations that
2//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc).
3//!
//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect,
//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box alongside
//! general combinators.
7//!
8//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait.
9//! The [Op] trait requires the implementation of only one method: `call`, which takes an input
10//! and returns an output. The trait provides a wide range of combinators for chaining operations together.
11//!
12//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
13//! the edges represent the data flow between operations. When invoking the pipeline on some input,
14//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline).
15//! The output of each op is then passed to the next op in the pipeline until the output reaches the
16//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned
17//! as the result of the pipeline.
18//!
19//! ## Basic Example
20//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
21//! the result as a string using the [map](Op::map) combinator method, which applies a simple function
22//! op to the output of the previous op:
23//! ```rust
24//! use rig::pipeline::{self, Op};
25//!
26//! let pipeline = pipeline::new()
27//! // op1: add two numbers
28//! .map(|(x, y)| x + y)
29//! // op2: format result
30//! .map(|z| format!("Result: {z}!"));
31//!
32//! let result = pipeline.call((1, 2)).await;
33//! assert_eq!(result, "Result: 3!");
34//! ```
35//!
36//! This pipeline can be visualized as the following DAG:
37//! ```text
//!          ┌─────────┐   ┌─────────┐
//! Input───►│   op1   ├──►│   op2   ├──►Output
//!          └─────────┘   └─────────┘
41//! ```
42//!
43//! ## Parallel Operations
//! The pipeline API also provides a [parallel!](crate::parallel!) macro for running operations in parallel.
//! The macro takes a list of ops and turns them into a single op that will duplicate the input
//! and run each op concurrently. The results of each op are then collected and returned as a tuple.
47//!
48//! For example, the pipeline below runs two operations concurrently:
49//! ```rust
50//! use rig::{pipeline::{self, Op, map}, parallel};
51//!
52//! let pipeline = pipeline::new()
53//! .chain(parallel!(
54//! // op1: add 1 to input
55//! map(|x| x + 1),
56//! // op2: subtract 1 from input
57//! map(|x| x - 1),
58//! ))
59//! // op3: format results
60//! .map(|(a, b)| format!("Results: {a}, {b}"));
61//!
62//! let result = pipeline.call(1).await;
//! assert_eq!(result, "Results: 2, 0");
64//! ```
65//!
66//! Notes:
67//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
68//! for chaining arbitrary operations, as long as they implement the [Op] trait.
69//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op.
70//!
71//! The pipeline above can be visualized as the following DAG:
72//! ```text
//!           Input
//!             │
//!      ┌──────┴──────┐
//!      ▼             ▼
//! ┌─────────┐   ┌─────────┐
//! │   op1   │   │   op2   │
//! └────┬────┘   └────┬────┘
//!      └──────┬──────┘
//!             ▼
//!        ┌─────────┐
//!        │   op3   │
//!        └────┬────┘
//!             │
//!             ▼
//!           Output
88//! ```
89
90pub mod agent_ops;
91pub mod op;
92pub mod try_op;
93#[macro_use]
94pub mod parallel;
95#[macro_use]
96pub mod conditional;
97
98use std::future::Future;
99
100pub use op::{map, passthrough, then, Op};
101pub use try_op::TryOp;
102
103use crate::{completion, extractor::Extractor, vector_store};
104
/// Entry point for building a pipeline of [Op]s.
///
/// A `PipelineBuilder` stores no data: it only carries the error type `E` selected via
/// [new] (which picks [ChainError]) or [with_error]. Each builder method consumes the
/// builder and returns the first op of the pipeline, onto which further ops are chained.
pub struct PipelineBuilder<E> {
    _error: std::marker::PhantomData<E>,
}

// The following impls are written by hand instead of derived: a derive would add
// `E: Debug` / `E: Clone` / `E: Copy` bounds, but the builder never holds a value of type `E`,
// so it can be printed and copied for every `E`.
impl<E> std::fmt::Debug for PipelineBuilder<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PipelineBuilder").finish()
    }
}

impl<E> Clone for PipelineBuilder<E> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<E> Copy for PipelineBuilder<E> {}
108
109impl<E> PipelineBuilder<E> {
110 /// Add a function to the current pipeline
111 ///
112 /// # Example
113 /// ```rust
114 /// use rig::pipeline::{self, Op};
115 ///
116 /// let pipeline = pipeline::new()
117 /// .map(|(x, y)| x + y)
118 /// .map(|z| format!("Result: {z}!"));
119 ///
120 /// let result = pipeline.call((1, 2)).await;
121 /// assert_eq!(result, "Result: 3!");
122 /// ```
123 pub fn map<F, Input, Output>(self, f: F) -> op::Map<F, Input>
124 where
125 F: Fn(Input) -> Output + Send + Sync,
126 Input: Send + Sync,
127 Output: Send + Sync,
128 Self: Sized,
129 {
130 op::Map::new(f)
131 }
132
133 /// Same as `map` but for asynchronous functions
134 ///
135 /// # Example
136 /// ```rust
137 /// use rig::pipeline::{self, Op};
138 ///
139 /// let pipeline = pipeline::new()
140 /// .then(|email: String| async move {
141 /// email.split('@').next().unwrap().to_string()
142 /// })
143 /// .then(|username: String| async move {
144 /// format!("Hello, {}!", username)
145 /// });
146 ///
147 /// let result = pipeline.call("bob@gmail.com".to_string()).await;
148 /// assert_eq!(result, "Hello, bob!");
149 /// ```
150 pub fn then<F, Input, Fut>(self, f: F) -> op::Then<F, Input>
151 where
152 F: Fn(Input) -> Fut + Send + Sync,
153 Input: Send + Sync,
154 Fut: Future + Send + Sync,
155 Fut::Output: Send + Sync,
156 Self: Sized,
157 {
158 op::Then::new(f)
159 }
160
161 /// Add an arbitrary operation to the current pipeline.
162 ///
163 /// # Example
164 /// ```rust
165 /// use rig::pipeline::{self, Op};
166 ///
167 /// struct MyOp;
168 ///
169 /// impl Op for MyOp {
170 /// type Input = i32;
171 /// type Output = i32;
172 ///
173 /// async fn call(&self, input: Self::Input) -> Self::Output {
174 /// input + 1
175 /// }
176 /// }
177 ///
178 /// let pipeline = pipeline::new()
179 /// .chain(MyOp);
180 ///
181 /// let result = pipeline.call(1).await;
182 /// assert_eq!(result, 2);
183 /// ```
184 pub fn chain<T>(self, op: T) -> T
185 where
186 T: Op,
187 Self: Sized,
188 {
189 op
190 }
191
192 /// Chain a lookup operation to the current chain. The lookup operation expects the
193 /// current chain to output a query string. The lookup operation will use the query to
194 /// retrieve the top `n` documents from the index and return them with the query string.
195 ///
196 /// # Example
197 /// ```rust
198 /// use rig::pipeline::{self, Op};
199 ///
200 /// let pipeline = pipeline::new()
201 /// .lookup(index, 2)
202 /// .pipeline(|(query, docs): (_, Vec<String>)| async move {
203 /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
204 /// });
205 ///
206 /// let result = pipeline.call("What is a flurbo?".to_string()).await;
207 /// ```
208 pub fn lookup<I, Input, Output>(self, index: I, n: usize) -> agent_ops::Lookup<I, Input, Output>
209 where
210 I: vector_store::VectorStoreIndex,
211 Output: Send + Sync + for<'a> serde::Deserialize<'a>,
212 Input: Into<String> + Send + Sync,
213 // E: From<vector_store::VectorStoreError> + Send + Sync,
214 Self: Sized,
215 {
216 agent_ops::Lookup::new(index, n)
217 }
218
219 /// Add a prompt operation to the current pipeline/op. The prompt operation expects the
220 /// current pipeline to output a string. The prompt operation will use the string to prompt
221 /// the given `agent`, which must implements the [Prompt](completion::Prompt) trait and return
222 /// the response.
223 ///
224 /// # Example
225 /// ```rust
226 /// use rig::pipeline::{self, Op};
227 ///
228 /// let agent = &openai_client.agent("gpt-4").build();
229 ///
230 /// let pipeline = pipeline::new()
231 /// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
232 /// .prompt(agent);
233 ///
234 /// let result = pipeline.call("Alice".to_string()).await;
235 /// ```
236 pub fn prompt<P, Input>(self, agent: P) -> agent_ops::Prompt<P, Input>
237 where
238 P: completion::Prompt,
239 Input: Into<String> + Send + Sync,
240 // E: From<completion::PromptError> + Send + Sync,
241 Self: Sized,
242 {
243 agent_ops::Prompt::new(agent)
244 }
245
246 /// Add an extract operation to the current pipeline/op. The extract operation expects the
247 /// current pipeline to output a string. The extract operation will use the given `extractor`
248 /// to extract information from the string in the form of the type `T` and return it.
249 ///
250 /// # Example
251 /// ```rust
252 /// use rig::pipeline::{self, Op};
253 ///
254 /// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
255 /// struct Sentiment {
256 /// /// The sentiment score of the text (0.0 = negative, 1.0 = positive)
257 /// score: f64,
258 /// }
259 ///
260 /// let extractor = &openai_client.extractor::<Sentiment>("gpt-4").build();
261 ///
262 /// let pipeline = pipeline::new()
263 /// .map(|text| format!("Analyze the sentiment of the following text: {text}!"))
264 /// .extract(extractor);
265 ///
266 /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;
267 /// assert!(result.score > 0.5);
268 /// ```
269 pub fn extract<M, Input, Output>(
270 self,
271 extractor: Extractor<M, Output>,
272 ) -> agent_ops::Extract<M, Input, Output>
273 where
274 M: completion::CompletionModel,
275 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
276 Input: Into<String> + Send + Sync,
277 {
278 agent_ops::Extract::new(extractor)
279 }
280}
281
282#[derive(Debug, thiserror::Error)]
283pub enum ChainError {
284 #[error("Failed to prompt agent: {0}")]
285 PromptError(#[from] completion::PromptError),
286
287 #[error("Failed to lookup documents: {0}")]
288 LookupError(#[from] vector_store::VectorStoreError),
289}
290
291pub fn new() -> PipelineBuilder<ChainError> {
292 PipelineBuilder {
293 _error: std::marker::PhantomData,
294 }
295}
296
297pub fn with_error<E>() -> PipelineBuilder<E> {
298 PipelineBuilder {
299 _error: std::marker::PhantomData,
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use agent_ops::tests::{Foo, MockIndex, MockModel};
    use parallel::parallel;

    /// `map` followed by `prompt`: the formatted query reaches the mock model.
    #[tokio::test]
    async fn test_prompt_pipeline() {
        let pipeline = super::new()
            .map(|query| format!("User query: {}", query))
            .prompt(MockModel);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    /// Same pipeline as above, but built with a custom error type and run via `try_call`.
    #[tokio::test]
    async fn test_prompt_pipeline_error() {
        let pipeline = super::with_error::<()>()
            .map(|query| format!("User query: {}", query))
            .prompt(MockModel);

        let response = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    /// A lookup followed by `map_ok` formats the first retrieved document.
    #[tokio::test]
    async fn test_lookup_pipeline() {
        let pipeline = super::new()
            .lookup::<_, _, Foo>(MockIndex, 1)
            .map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo));

        let response = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Top documents:\nbar");
    }

    /// RAG shape: the query is passed through in parallel with a lookup, the two results are
    /// merged into one prompt, and the prompt is sent to the mock model.
    #[tokio::test]
    async fn test_rag_pipeline() {
        let pipeline = super::new()
            .chain(parallel!(
                passthrough(),
                agent_ops::lookup::<_, _, Foo>(MockIndex, 1),
            ))
            .map(|(query, lookup_result)| match lookup_result {
                Err(err) => format!("Error: {}", err),
                Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo),
            })
            .prompt(MockModel);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        let expected = "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar";
        assert_eq!(response, expected);
    }
}