// rig_core/pipeline/mod.rs
1//! This module defines a flexible pipeline API for defining a sequence of operations that
2//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc).
3//!
4//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect,
5//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box along
6//! general combinators.
7//!
8//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait.
9//! The [Op] trait requires the implementation of only one method: `call`, which takes an input
10//! and returns an output. The trait provides a wide range of combinators for chaining operations together.
11//!
12//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
13//! the edges represent the data flow between operations. When invoking the pipeline on some input,
14//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline).
15//! The output of each op is then passed to the next op in the pipeline until the output reaches the
16//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned
17//! as the result of the pipeline.
18//!
19//! ## Basic Example
20//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
21//! the result as a string using the [map](Op::map) combinator method, which applies a simple function
22//! op to the output of the previous op:
23//! ```no_run
24//! use rig_core::pipeline::{self, Op};
25//!
26//! # async fn run() {
27//! let pipeline = pipeline::new()
28//! // op1: add two numbers
29//! .map(|(x, y)| x + y)
30//! // op2: format result
31//! .map(|z| format!("Result: {z}!"));
32//!
33//! let result = pipeline.call((1, 2)).await;
34//! assert_eq!(result, "Result: 3!");
35//! # }
36//! ```
37//!
38//! This pipeline can be visualized as the following DAG:
39//! ```text
40//! ┌─────────┐ ┌─────────┐
41//! Input───►│ op1 ├──►│ op2 ├──►Output
42//! └─────────┘ └─────────┘
43//! ```
44//!
45//! ## Parallel Operations
46//! The pipeline API also provides the [`parallel!`](crate::parallel!) macro for running operations in parallel.
47//! The macro takes a list of ops and turns them into a single op that will duplicate the input
48//! and run each op concurrently. The results of each op are then collected and returned as a tuple.
49//!
50//! For example, the pipeline below runs two operations concurrently:
51//! ```no_run
52//! use rig_core::{pipeline::{self, Op, map}, parallel};
53//!
54//! # async fn run() {
55//! let pipeline = pipeline::new()
56//! .chain(parallel!(
57//! // op1: add 1 to input
58//! map(|x| x + 1),
59//! // op2: subtract 1 from input
60//! map(|x| x - 1),
61//! ))
62//! // op3: format results
63//! .map(|(a, b)| format!("Results: {a}, {b}"));
64//!
65//! let result = pipeline.call(1).await;
66//! assert_eq!(result, "Results: 2, 0");
67//! # }
68//! ```
69//!
70//! Notes:
71//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
72//! for chaining arbitrary operations, as long as they implement the [Op] trait.
73//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op.
74//!
75//! The pipeline above can be visualized as the following DAG:
76//! ```text
77//! Input
78//! │
79//! ┌──────┴──────┐
80//! ▼ ▼
81//! ┌─────────┐ ┌─────────┐
82//! │ op1 │ │ op2 │
83//! └────┬────┘ └────┬────┘
84//! └──────┬──────┘
85//! ▼
86//! ┌─────────┐
87//! │ op3 │
88//! └────┬────┘
89//! │
90//! ▼
91//! Output
92//! ```
93
94pub mod agent_ops;
95pub mod op;
96pub mod try_op;
97#[macro_use]
98pub mod parallel;
99#[macro_use]
100pub mod conditional;
101
102use std::future::Future;
103
104pub use op::{Op, map, passthrough, then};
105pub use try_op::TryOp;
106
107use crate::{completion, extractor::Extractor, vector_store};
108
/// Entry point for building a pipeline of ops.
///
/// The type parameter `E` is the error type intended for fallible ops added
/// through this builder (see [`new`] for the default and [`with_error`] for a
/// custom error type). The builder itself holds no data.
pub struct PipelineBuilder<E> {
    // Zero-sized marker that lets the builder carry the error type `E`
    // without storing a value of that type.
    _error: std::marker::PhantomData<E>,
}
112
113impl<E> PipelineBuilder<E> {
114 /// Add a function to the current pipeline
115 ///
116 /// # Example
117 /// ```no_run
118 /// use rig_core::pipeline::{self, Op};
119 ///
120 /// # async fn run() {
121 /// let pipeline = pipeline::new()
122 /// .map(|(x, y)| x + y)
123 /// .map(|z| format!("Result: {z}!"));
124 ///
125 /// let result = pipeline.call((1, 2)).await;
126 /// assert_eq!(result, "Result: 3!");
127 /// # }
128 /// ```
129 pub fn map<F, Input, Output>(self, f: F) -> op::Map<F, Input>
130 where
131 F: Fn(Input) -> Output + Send + Sync,
132 Input: Send + Sync,
133 Output: Send + Sync,
134 Self: Sized,
135 {
136 op::Map::new(f)
137 }
138
139 /// Same as `map` but for asynchronous functions
140 ///
141 /// # Example
142 /// ```no_run
143 /// use rig_core::pipeline::{self, Op};
144 ///
145 /// # async fn run() {
146 /// let pipeline = pipeline::new()
147 /// .then(|email: String| async move {
148 /// email.split('@').next().unwrap_or_default().to_string()
149 /// })
150 /// .then(|username: String| async move {
151 /// format!("Hello, {}!", username)
152 /// });
153 ///
154 /// let result = pipeline.call("bob@gmail.com".to_string()).await;
155 /// assert_eq!(result, "Hello, bob!");
156 /// # }
157 /// ```
158 pub fn then<F, Input, Fut>(self, f: F) -> op::Then<F, Input>
159 where
160 F: Fn(Input) -> Fut + Send + Sync,
161 Input: Send + Sync,
162 Fut: Future + Send + Sync,
163 Fut::Output: Send + Sync,
164 Self: Sized,
165 {
166 op::Then::new(f)
167 }
168
169 /// Add an arbitrary operation to the current pipeline.
170 ///
171 /// # Example
172 /// ```no_run
173 /// use rig_core::pipeline::{self, Op};
174 ///
175 /// # async fn run() {
176 /// struct MyOp;
177 ///
178 /// impl Op for MyOp {
179 /// type Input = i32;
180 /// type Output = i32;
181 ///
182 /// async fn call(&self, input: Self::Input) -> Self::Output {
183 /// input + 1
184 /// }
185 /// }
186 ///
187 /// let pipeline = pipeline::new()
188 /// .chain(MyOp);
189 ///
190 /// let result = pipeline.call(1).await;
191 /// assert_eq!(result, 2);
192 /// # }
193 /// ```
194 pub fn chain<T>(self, op: T) -> T
195 where
196 T: Op,
197 Self: Sized,
198 {
199 op
200 }
201
202 /// Chain a lookup operation to the current chain. The lookup operation expects the
203 /// current chain to output a query string. The lookup operation will use the query to
204 /// retrieve the top `n` documents from the index and return them with the query string.
205 ///
206 /// # Example
207 /// ```ignore
208 /// use rig_core::pipeline::{self, Op};
209 ///
210 /// let pipeline = pipeline::new()
211 /// .lookup(index, 2)
212 /// .pipeline(|(query, docs): (_, Vec<String>)| async move {
213 /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
214 /// });
215 ///
216 /// let result = pipeline.call("What is a flurbo?".to_string()).await;
217 /// ```
218 pub fn lookup<I, Input, Output>(self, index: I, n: usize) -> agent_ops::Lookup<I, Input, Output>
219 where
220 I: vector_store::VectorStoreIndex,
221 Output: Send + Sync + for<'a> serde::Deserialize<'a>,
222 Input: Into<String> + Send + Sync,
223 // E: From<vector_store::VectorStoreError> + Send + Sync,
224 Self: Sized,
225 {
226 agent_ops::Lookup::new(index, n)
227 }
228
229 /// Add a prompt operation to the current pipeline/op. The prompt operation expects the
230 /// current pipeline to output a string. The prompt operation will use the string to prompt
231 /// the given `agent`, which must implements the [Prompt](completion::Prompt) trait and return
232 /// the response.
233 ///
234 /// # Example
235 /// ```ignore
236 /// use rig_core::pipeline::{self, Op};
237 ///
238 /// let agent = &openai_client.agent("gpt-4").build();
239 ///
240 /// let pipeline = pipeline::new()
241 /// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
242 /// .prompt(agent);
243 ///
244 /// let result = pipeline.call("Alice".to_string()).await;
245 /// ```
246 pub fn prompt<P, Input>(self, agent: P) -> agent_ops::Prompt<P, Input>
247 where
248 P: completion::Prompt,
249 Input: Into<String> + Send + Sync,
250 // E: From<completion::PromptError> + Send + Sync,
251 Self: Sized,
252 {
253 agent_ops::Prompt::new(agent)
254 }
255
256 /// Add an extract operation to the current pipeline/op. The extract operation expects the
257 /// current pipeline to output a string. The extract operation will use the given `extractor`
258 /// to extract information from the string in the form of the type `T` and return it.
259 ///
260 /// # Example
261 /// ```ignore
262 /// use rig_core::pipeline::{self, Op};
263 ///
264 /// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
265 /// struct Sentiment {
266 /// /// The sentiment score of the text (0.0 = negative, 1.0 = positive)
267 /// score: f64,
268 /// }
269 ///
270 /// let extractor = &openai_client.extractor::<Sentiment>("gpt-4").build();
271 ///
272 /// let pipeline = pipeline::new()
273 /// .map(|text| format!("Analyze the sentiment of the following text: {text}!"))
274 /// .extract(extractor);
275 ///
276 /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;
277 /// assert!(result.score > 0.5);
278 /// ```
279 pub fn extract<M, Input, Output>(
280 self,
281 extractor: Extractor<M, Output>,
282 ) -> agent_ops::Extract<M, Input, Output>
283 where
284 M: completion::CompletionModel,
285 Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
286 Input: Into<String> + Send + Sync,
287 {
288 agent_ops::Extract::new(extractor)
289 }
290}
291
/// Default error type for pipelines created with [`new`].
///
/// Each variant wraps the error produced by one of the built-in AI ops
/// (prompting an agent, or looking up documents in a vector store).
#[derive(Debug, thiserror::Error)]
pub enum ChainError {
    /// A prompt op failed. The underlying error is boxed to keep the enum small.
    #[error("Failed to prompt agent: {0}")]
    PromptError(#[from] Box<completion::PromptError>),

    /// A lookup op failed to retrieve documents from the vector store.
    #[error("Failed to lookup documents: {0}")]
    LookupError(#[from] vector_store::VectorStoreError),
}
300
301/// Create a pipeline builder using Rig's default [`ChainError`] type.
302pub fn new() -> PipelineBuilder<ChainError> {
303 PipelineBuilder {
304 _error: std::marker::PhantomData,
305 }
306}
307
308/// Create a pipeline builder with a custom error type.
309pub fn with_error<E>() -> PipelineBuilder<E> {
310 PipelineBuilder {
311 _error: std::marker::PhantomData,
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{Foo, MockPromptModel, MockVectorStoreIndex};
    use parallel::parallel;

    // A map -> prompt pipeline forwards the formatted query to the agent.
    #[tokio::test]
    async fn test_prompt_pipeline() {
        let agent = MockPromptModel;

        let pipeline = super::new()
            .map(|input| format!("User query: {input}"))
            .prompt(agent);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    // Same pipeline built with a custom error type, driven via `try_call`.
    #[tokio::test]
    async fn test_prompt_pipeline_error() {
        let agent = MockPromptModel;

        let pipeline = super::with_error::<()>()
            .map(|input| format!("User query: {input}"))
            .prompt(agent);

        let response = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Mock response: User query: What is a flurbo?");
    }

    // A lookup pipeline retrieves documents and formats the top hit.
    #[tokio::test]
    async fn test_lookup_pipeline() {
        let store = MockVectorStoreIndex;

        let pipeline = super::new()
            .lookup::<_, _, Foo>(store, 1)
            .map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo));

        let response = pipeline
            .try_call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(response, "Top documents:\nbar");
    }

    // Full RAG flow: fan out (passthrough + lookup), merge, then prompt.
    #[tokio::test]
    async fn test_rag_pipeline() {
        let store = MockVectorStoreIndex;

        let pipeline = super::new()
            .chain(parallel!(
                passthrough(),
                agent_ops::lookup::<_, _, Foo>(store, 1),
            ))
            .map(|(query, maybe_docs)| match maybe_docs {
                Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo),
                Err(err) => format!("Error: {err}"),
            })
            .prompt(MockPromptModel);

        let response = pipeline
            .call("What is a flurbo?")
            .await
            .expect("Failed to run chain");

        assert_eq!(
            response,
            "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar"
        );
    }
}
394}