rig/pipeline/op.rs
1use std::future::Future;
2
3#[allow(unused_imports)] // Needed since this is used in a macro rule
4use futures::join;
5use futures::stream;
6
7// ================================================================
8// Core Op trait
9// ================================================================
10pub trait Op: Send + Sync {
11 type Input: Send + Sync;
12 type Output: Send + Sync;
13
14 fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + Send;
15
16 /// Execute the current pipeline with the given inputs. `n` is the number of concurrent
17 /// inputs that will be processed concurrently.
18 fn batch_call<I>(&self, n: usize, input: I) -> impl Future<Output = Vec<Self::Output>> + Send
19 where
20 I: IntoIterator<Item = Self::Input> + Send,
21 I::IntoIter: Send,
22 Self: Sized,
23 {
24 use futures::stream::StreamExt;
25
26 async move {
27 stream::iter(input)
28 .map(|input| self.call(input))
29 .buffered(n)
30 .collect()
31 .await
32 }
33 }
34
35 /// Chain a function `f` to the current op.
36 ///
37 /// # Example
38 /// ```rust
39 /// use rig::pipeline::{self, Op};
40 ///
41 /// let chain = pipeline::new()
42 /// .map(|(x, y)| x + y)
43 /// .map(|z| format!("Result: {z}!"));
44 ///
45 /// let result = chain.call((1, 2)).await;
46 /// assert_eq!(result, "Result: 3!");
47 /// ```
48 fn map<F, Input>(self, f: F) -> Sequential<Self, Map<F, Self::Output>>
49 where
50 F: Fn(Self::Output) -> Input + Send + Sync,
51 Input: Send + Sync,
52 Self: Sized,
53 {
54 Sequential::new(self, Map::new(f))
55 }
56
57 /// Same as `map` but for asynchronous functions
58 ///
59 /// # Example
60 /// ```rust
61 /// use rig::pipeline::{self, Op};
62 ///
63 /// let chain = pipeline::new()
64 /// .then(|email: String| async move {
65 /// email.split('@').next().unwrap().to_string()
66 /// })
67 /// .then(|username: String| async move {
68 /// format!("Hello, {}!", username)
69 /// });
70 ///
71 /// let result = chain.call("bob@gmail.com".to_string()).await;
72 /// assert_eq!(result, "Hello, bob!");
73 /// ```
74 fn then<F, Fut>(self, f: F) -> Sequential<Self, Then<F, Fut::Output>>
75 where
76 F: Fn(Self::Output) -> Fut + Send + Sync,
77 Fut: Future + Send + Sync,
78 Fut::Output: Send + Sync,
79 Self: Sized,
80 {
81 Sequential::new(self, Then::new(f))
82 }
83
84 /// Chain an arbitrary operation to the current op.
85 ///
86 /// # Example
87 /// ```rust
88 /// use rig::pipeline::{self, Op};
89 ///
90 /// struct AddOne;
91 ///
92 /// impl Op for AddOne {
93 /// type Input = i32;
94 /// type Output = i32;
95 ///
96 /// async fn call(&self, input: Self::Input) -> Self::Output {
97 /// input + 1
98 /// }
99 /// }
100 ///
101 /// let chain = pipeline::new()
102 /// .chain(AddOne);
103 ///
104 /// let result = chain.call(1).await;
105 /// assert_eq!(result, 2);
106 /// ```
107 fn chain<T>(self, op: T) -> Sequential<Self, T>
108 where
109 T: Op<Input = Self::Output>,
110 Self: Sized,
111 {
112 Sequential::new(self, op)
113 }
114
115 /// Chain a lookup operation to the current chain. The lookup operation expects the
116 /// current chain to output a query string. The lookup operation will use the query to
117 /// retrieve the top `n` documents from the index and return them with the query string.
118 ///
119 /// # Example
120 /// ```rust
121 /// use rig::chain::{self, Chain};
122 ///
123 /// let chain = chain::new()
124 /// .lookup(index, 2)
125 /// .chain(|(query, docs): (_, Vec<String>)| async move {
126 /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
127 /// });
128 ///
129 /// let result = chain.call("What is a flurbo?".to_string()).await;
130 /// ```
131 fn lookup<I, Input>(
132 self,
133 index: I,
134 n: usize,
135 ) -> Sequential<Self, Lookup<I, Self::Output, Input>>
136 where
137 I: vector_store::VectorStoreIndex,
138 Input: Send + Sync + for<'a> serde::Deserialize<'a>,
139 Self::Output: Into<String>,
140 Self: Sized,
141 {
142 Sequential::new(self, Lookup::new(index, n))
143 }
144
145 /// Chain a prompt operation to the current chain. The prompt operation expects the
146 /// current chain to output a string. The prompt operation will use the string to prompt
147 /// the given agent (or any other type that implements the `Prompt` trait) and return
148 /// the response.
149 ///
150 /// # Example
151 /// ```rust
152 /// use rig::chain::{self, Chain};
153 ///
154 /// let agent = &openai_client.agent("gpt-4").build();
155 ///
156 /// let chain = chain::new()
157 /// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
158 /// .prompt(agent);
159 ///
160 /// let result = chain.call("Alice".to_string()).await;
161 /// ```
162 fn prompt<P>(self, prompt: P) -> Sequential<Self, Prompt<P, Self::Output>>
163 where
164 P: completion::Prompt,
165 Self::Output: Into<String>,
166 Self: Sized,
167 {
168 Sequential::new(self, Prompt::new(prompt))
169 }
170}
171
172impl<T: Op> Op for &T {
173 type Input = T::Input;
174 type Output = T::Output;
175
176 #[inline]
177 async fn call(&self, input: Self::Input) -> Self::Output {
178 (*self).call(input).await
179 }
180}
181
182// ================================================================
183// Op combinators
184// ================================================================
/// Two ops composed in sequence: `prev` runs first and its output becomes `op`'s input.
pub struct Sequential<Op1, Op2> {
    /// The stage that runs first.
    prev: Op1,
    /// The stage that consumes `prev`'s output.
    op: Op2,
}

impl<Op1, Op2> Sequential<Op1, Op2> {
    /// Compose `prev` followed by `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        Sequential { prev, op }
    }
}
195
196impl<Op1, Op2> Op for Sequential<Op1, Op2>
197where
198 Op1: Op,
199 Op2: Op<Input = Op1::Output>,
200{
201 type Input = Op1::Input;
202 type Output = Op2::Output;
203
204 #[inline]
205 async fn call(&self, input: Self::Input) -> Self::Output {
206 let prev = self.prev.call(input).await;
207 self.op.call(prev).await
208 }
209}
210
211use crate::{completion, vector_store};
212
213use super::agent_ops::{Lookup, Prompt};
214
215// ================================================================
216// Core Op implementations
217// ================================================================
/// Op that applies a synchronous function `f` to an input of type `Input`.
pub struct Map<F, Input> {
    f: F,
    // `Input` must appear in the struct's type so the `Op` impl can expose it as
    // the associated `Input` type.
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Map<F, Input> {
    /// Wrap `f` as a `Map` op.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Self { f, _t }
    }
}
231
232impl<F, Input, Output> Op for Map<F, Input>
233where
234 F: Fn(Input) -> Output + Send + Sync,
235 Input: Send + Sync,
236 Output: Send + Sync,
237{
238 type Input = Input;
239 type Output = Output;
240
241 #[inline]
242 async fn call(&self, input: Self::Input) -> Self::Output {
243 (self.f)(input)
244 }
245}
246
247pub fn map<F, Input, Output>(f: F) -> Map<F, Input>
248where
249 F: Fn(Input) -> Output + Send + Sync,
250 Input: Send + Sync,
251 Output: Send + Sync,
252{
253 Map::new(f)
254}
255
/// Identity op: yields its input unchanged.
pub struct Passthrough<T> {
    _t: std::marker::PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Create a passthrough op for values of type `T`.
    pub(crate) fn new() -> Self {
        Passthrough {
            _t: Default::default(),
        }
    }
}
267
268impl<T> Op for Passthrough<T>
269where
270 T: Send + Sync,
271{
272 type Input = T;
273 type Output = T;
274
275 async fn call(&self, input: Self::Input) -> Self::Output {
276 input
277 }
278}
279
280pub fn passthrough<T>() -> Passthrough<T>
281where
282 T: Send + Sync,
283{
284 Passthrough::new()
285}
286
/// Op that applies an asynchronous function `f` to an input of type `Input` and awaits it.
pub struct Then<F, Input> {
    f: F,
    // `Input` must appear in the struct's type so the `Op` impl can expose it as
    // the associated `Input` type.
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Then<F, Input> {
    /// Wrap `f` as a `Then` op.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Self { f, _t }
    }
}
300
301impl<F, Input, Fut> Op for Then<F, Input>
302where
303 F: Fn(Input) -> Fut + Send + Sync,
304 Input: Send + Sync,
305 Fut: Future + Send,
306 Fut::Output: Send + Sync,
307{
308 type Input = Input;
309 type Output = Fut::Output;
310
311 #[inline]
312 async fn call(&self, input: Self::Input) -> Self::Output {
313 (self.f)(input).await
314 }
315}
316
317pub fn then<F, Input, Fut>(f: F) -> Then<F, Input>
318where
319 F: Fn(Input) -> Fut + Send + Sync,
320 Input: Send + Sync,
321 Fut: Future + Send,
322 Fut::Output: Send + Sync,
323{
324 Then::new(f)
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_sequential_constructor() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 2);
        let op3 = map(|x: i32| x * 3);

        let pipeline = Sequential::new(Sequential::new(op1, op2), op3);

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_sequential_chain() {
        let pipeline = map(|x: i32| x + 1)
            .map(|x| x * 2)
            .then(|x| async move { x * 3 });

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_passthrough() {
        let op = passthrough::<String>();
        assert_eq!(op.call("unchanged".to_string()).await, "unchanged");
    }

    #[tokio::test]
    async fn test_reference_is_op() {
        // `&T` is itself an `Op`, so a single op can be borrowed for two stages.
        let add_one = map(|x: i32| x + 1);
        let pipeline = Sequential::new(&add_one, &add_one);
        assert_eq!(pipeline.call(1).await, 3);
    }

    #[tokio::test]
    async fn test_batch_call_preserves_order() {
        let op = then(|x: i32| async move { x * 10 });
        let result = op.batch_call(2, vec![1, 2, 3, 4]).await;
        assert_eq!(result, vec![10, 20, 30, 40]);
    }

    // The tests below are disabled: they reference `Parallel`, `flatten`, and `parallel!`,
    // none of which are defined in this module.

    // #[tokio::test]
    // async fn test_flatten() {
    //     let op = Parallel::new(
    //         Parallel::new(
    //             map(|x: i32| x + 1),
    //             map(|x: i32| x * 2),
    //         ),
    //         map(|x: i32| x * 3),
    //     );

    //     let pipeline = flatten::<_, (_, _, _)>(op);

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 2, 3));
    // }

    // #[tokio::test]
    // async fn test_parallel_macro() {
    //     let op1 = map(|x: i32| x + 1);
    //     let op2 = map(|x: i32| x * 3);
    //     let op3 = map(|x: i32| format!("{} is the number!", x));
    //     let op4 = map(|x: i32| x - 1);

    //     let pipeline = parallel!(op1, op2, op3, op4);

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 3, "1 is the number!".to_string(), 0));
    // }

    // #[tokio::test]
    // async fn test_parallel_join() {
    //     let op3 = map(|x: i32| format!("{} is the number!", x));

    //     let pipeline = Sequential::new(
    //         map(|x: i32| x + 1),
    //         then(|x| {
    //             // let op1 = map(|x: i32| x * 2);
    //             // let op2 = map(|x: i32| x * 3);
    //             let op3 = &op3;

    //             async move {
    //                 join!(
    //                     (&map(|x: i32| x * 2)).call(x),
    //                     {
    //                         let op = map(|x: i32| x * 3);
    //                         op.call(x)
    //                     },
    //                     op3.call(x),
    //                 )
    //             }}),
    //     );

    //     let result = pipeline.call(1).await;
    //     assert_eq!(result, (2, 3, "1 is the number!".to_string()));
    // }

    // #[test]
    // fn test_flatten() {
    //     let x = (1, (2, (3, 4)));
    //     let result = flatten!(0, 1, 1, 1, 1);
    //     assert_eq!(result, (1, 2, 3, 4));
    // }
}