rig/pipeline/op.rs
1use crate::wasm_compat::*;
2#[allow(unused_imports)] // Needed since this is used in a macro rule
3use futures::join;
4use futures::stream;
5use std::future::Future;
6
7// ================================================================
8// Core Op trait
9// ================================================================
10pub trait Op: WasmCompatSend + WasmCompatSync {
11 type Input: WasmCompatSend + WasmCompatSync;
12 type Output: WasmCompatSend + WasmCompatSync;
13
14 fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + WasmCompatSend;
15
16 /// Execute the current pipeline with the given inputs. `n` is the number of concurrent
17 /// inputs that will be processed concurrently.
18 fn batch_call<I>(
19 &self,
20 n: usize,
21 input: I,
22 ) -> impl Future<Output = Vec<Self::Output>> + WasmCompatSend
23 where
24 I: IntoIterator<Item = Self::Input> + WasmCompatSend,
25 I::IntoIter: WasmCompatSend,
26 Self: Sized,
27 {
28 use futures::stream::StreamExt;
29
30 async move {
31 stream::iter(input)
32 .map(|input| self.call(input))
33 .buffered(n)
34 .collect()
35 .await
36 }
37 }
38
39 /// Chain a function `f` to the current op.
40 ///
41 /// # Example
42 /// ```rust
43 /// use rig::pipeline::{self, Op};
44 ///
45 /// let chain = pipeline::new()
46 /// .map(|(x, y)| x + y)
47 /// .map(|z| format!("Result: {z}!"));
48 ///
49 /// let result = chain.call((1, 2)).await;
50 /// assert_eq!(result, "Result: 3!");
51 /// ```
52 fn map<F, Input>(self, f: F) -> Sequential<Self, Map<F, Self::Output>>
53 where
54 F: Fn(Self::Output) -> Input + WasmCompatSend + WasmCompatSync,
55 Input: WasmCompatSend + WasmCompatSync,
56 Self: Sized,
57 {
58 Sequential::new(self, Map::new(f))
59 }
60
61 /// Same as `map` but for asynchronous functions
62 ///
63 /// # Example
64 /// ```rust
65 /// use rig::pipeline::{self, Op};
66 ///
67 /// let chain = pipeline::new()
68 /// .then(|email: String| async move {
69 /// email.split('@').next().unwrap().to_string()
70 /// })
71 /// .then(|username: String| async move {
72 /// format!("Hello, {}!", username)
73 /// });
74 ///
75 /// let result = chain.call("bob@gmail.com".to_string()).await;
76 /// assert_eq!(result, "Hello, bob!");
77 /// ```
78 fn then<F, Fut>(self, f: F) -> Sequential<Self, Then<F, Fut::Output>>
79 where
80 F: Fn(Self::Output) -> Fut + Send + WasmCompatSync,
81 Fut: Future + WasmCompatSend + WasmCompatSync,
82 Fut::Output: WasmCompatSend + WasmCompatSync,
83 Self: Sized,
84 {
85 Sequential::new(self, Then::new(f))
86 }
87
88 /// Chain an arbitrary operation to the current op.
89 ///
90 /// # Example
91 /// ```rust
92 /// use rig::pipeline::{self, Op};
93 ///
94 /// struct AddOne;
95 ///
96 /// impl Op for AddOne {
97 /// type Input = i32;
98 /// type Output = i32;
99 ///
100 /// async fn call(&self, input: Self::Input) -> Self::Output {
101 /// input + 1
102 /// }
103 /// }
104 ///
105 /// let chain = pipeline::new()
106 /// .chain(AddOne);
107 ///
108 /// let result = chain.call(1).await;
109 /// assert_eq!(result, 2);
110 /// ```
111 fn chain<T>(self, op: T) -> Sequential<Self, T>
112 where
113 T: Op<Input = Self::Output>,
114 Self: Sized,
115 {
116 Sequential::new(self, op)
117 }
118
119 /// Chain a lookup operation to the current chain. The lookup operation expects the
120 /// current chain to output a query string. The lookup operation will use the query to
121 /// retrieve the top `n` documents from the index and return them with the query string.
122 ///
123 /// # Example
124 /// ```rust
125 /// use rig::chain::{self, Chain};
126 ///
127 /// let chain = chain::new()
128 /// .lookup(index, 2)
129 /// .chain(|(query, docs): (_, Vec<String>)| async move {
130 /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
131 /// });
132 ///
133 /// let result = chain.call("What is a flurbo?".to_string()).await;
134 /// ```
135 fn lookup<I, Input>(
136 self,
137 index: I,
138 n: usize,
139 ) -> Sequential<Self, Lookup<I, Self::Output, Input>>
140 where
141 I: vector_store::VectorStoreIndex,
142 Input: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
143 Self::Output: Into<String>,
144 Self: Sized,
145 {
146 Sequential::new(self, Lookup::new(index, n))
147 }
148
149 /// Chain a prompt operation to the current chain. The prompt operation expects the
150 /// current chain to output a string. The prompt operation will use the string to prompt
151 /// the given agent (or any other type that implements the `Prompt` trait) and return
152 /// the response.
153 ///
154 /// # Example
155 /// ```rust
156 /// use rig::chain::{self, Chain};
157 ///
158 /// let agent = &openai_client.agent("gpt-4").build();
159 ///
160 /// let chain = chain::new()
161 /// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
162 /// .prompt(agent);
163 ///
164 /// let result = chain.call("Alice".to_string()).await;
165 /// ```
166 fn prompt<P>(self, prompt: P) -> Sequential<Self, Prompt<P, Self::Output>>
167 where
168 P: completion::Prompt,
169 Self::Output: Into<String>,
170 Self: Sized,
171 {
172 Sequential::new(self, Prompt::new(prompt))
173 }
174}
175
176impl<T: Op> Op for &T {
177 type Input = T::Input;
178 type Output = T::Output;
179
180 #[inline]
181 async fn call(&self, input: Self::Input) -> Self::Output {
182 (*self).call(input).await
183 }
184}
185
186// ================================================================
187// Op combinators
188// ================================================================
/// Op combinator that runs `prev` first and then feeds its output into `op`.
pub struct Sequential<Op1, Op2> {
    /// Stage that runs first.
    prev: Op1,
    /// Stage that receives the output of `prev`.
    op: Op2,
}

impl<Op1, Op2> Sequential<Op1, Op2> {
    /// Assemble a two-stage op from its stages, in execution order.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        Self { op, prev }
    }
}
199
200impl<Op1, Op2> Op for Sequential<Op1, Op2>
201where
202 Op1: Op,
203 Op2: Op<Input = Op1::Output>,
204{
205 type Input = Op1::Input;
206 type Output = Op2::Output;
207
208 #[inline]
209 async fn call(&self, input: Self::Input) -> Self::Output {
210 let prev = self.prev.call(input).await;
211 self.op.call(prev).await
212 }
213}
214
215use super::agent_ops::{Lookup, Prompt};
216use crate::{completion, vector_store};
217
218// ================================================================
219// Core Op implementations
220// ================================================================
/// Op that applies a synchronous function `f` to its input.
///
/// `Input` is only recorded as `PhantomData`; it lets the `Op` impl for `Map` name
/// the argument type that `f` is called with.
pub struct Map<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Map<F, Input> {
    /// Wrap `f` as an op.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Map { f, _t }
    }
}
234
235impl<F, Input, Output> Op for Map<F, Input>
236where
237 F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
238 Input: WasmCompatSend + WasmCompatSync,
239 Output: WasmCompatSend + WasmCompatSync,
240{
241 type Input = Input;
242 type Output = Output;
243
244 #[inline]
245 async fn call(&self, input: Self::Input) -> Self::Output {
246 (self.f)(input)
247 }
248}
249
250pub fn map<F, Input, Output>(f: F) -> Map<F, Input>
251where
252 F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
253 Input: WasmCompatSend + WasmCompatSync,
254 Output: WasmCompatSend + WasmCompatSync,
255{
256 Map::new(f)
257}
258
/// Identity op: hands its input back unchanged.
pub struct Passthrough<T> {
    _t: std::marker::PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Create an identity op for values of type `T`.
    pub(crate) fn new() -> Self {
        Passthrough {
            _t: Default::default(),
        }
    }
}
270
271impl<T> Op for Passthrough<T>
272where
273 T: WasmCompatSend + WasmCompatSync,
274{
275 type Input = T;
276 type Output = T;
277
278 async fn call(&self, input: Self::Input) -> Self::Output {
279 input
280 }
281}
282
283pub fn passthrough<T>() -> Passthrough<T>
284where
285 T: WasmCompatSend + WasmCompatSync,
286{
287 Passthrough::new()
288}
289
/// Op that applies an asynchronous function `f` to its input and awaits the result.
///
/// `Input` is only recorded as `PhantomData`; it lets the `Op` impl for `Then` name
/// the argument type that `f` is called with.
pub struct Then<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Then<F, Input> {
    /// Wrap the async function `f` as an op.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Then { f, _t }
    }
}
303
304impl<F, Input, Fut> Op for Then<F, Input>
305where
306 F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
307 Input: WasmCompatSend + WasmCompatSync,
308 Fut: Future + WasmCompatSend,
309 Fut::Output: WasmCompatSend + WasmCompatSync,
310{
311 type Input = Input;
312 type Output = Fut::Output;
313
314 #[inline]
315 async fn call(&self, input: Self::Input) -> Self::Output {
316 (self.f)(input).await
317 }
318}
319
320pub fn then<F, Input, Fut>(f: F) -> Then<F, Input>
321where
322 F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
323 Input: WasmCompatSend + WasmCompatSync,
324 Fut: Future + WasmCompatSend,
325 Fut::Output: WasmCompatSend + WasmCompatSync,
326{
327 Then::new(f)
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_sequential_constructor() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 2);
        let op3 = map(|x: i32| x * 3);

        let pipeline = Sequential::new(Sequential::new(op1, op2), op3);

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_sequential_chain() {
        let pipeline = map(|x: i32| x + 1)
            .map(|x| x * 2)
            .then(|x| async move { x * 3 });

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_passthrough() {
        let op = passthrough::<String>();

        let result = op.call("unchanged".to_string()).await;
        assert_eq!(result, "unchanged");
    }

    #[tokio::test]
    async fn test_reference_is_op() {
        let add_one = map(|x: i32| x + 1);

        // `&add_one` implements `Op`, so the same op can be chained twice without
        // being moved.
        let pipeline = (&add_one).chain(&add_one);

        let result = pipeline.call(1).await;
        assert_eq!(result, 3);
    }

    #[tokio::test]
    async fn test_batch_call_preserves_order() {
        let op = then(|x: i32| async move { x * 10 });

        let result = op.batch_call(2, vec![1, 2, 3, 4]).await;
        assert_eq!(result, vec![10, 20, 30, 40]);
    }
}