rig_core/pipeline/op.rs
1use crate::wasm_compat::*;
2#[allow(unused_imports)] // Needed since this is used in a macro rule
3use futures::join;
4use futures::stream;
5use std::future::Future;
6
7// ================================================================
8// Core Op trait
9// ================================================================
10pub trait Op: WasmCompatSend + WasmCompatSync {
11 type Input: WasmCompatSend + WasmCompatSync;
12 type Output: WasmCompatSend + WasmCompatSync;
13
14 fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + WasmCompatSend;
15
16 /// Execute the current pipeline with the given inputs. `n` is the number of concurrent
17 /// inputs that will be processed concurrently.
18 fn batch_call<I>(
19 &self,
20 n: usize,
21 input: I,
22 ) -> impl Future<Output = Vec<Self::Output>> + WasmCompatSend
23 where
24 I: IntoIterator<Item = Self::Input> + WasmCompatSend,
25 I::IntoIter: WasmCompatSend,
26 Self: Sized,
27 {
28 use futures::stream::StreamExt;
29
30 async move {
31 stream::iter(input)
32 .map(|input| self.call(input))
33 .buffered(n)
34 .collect()
35 .await
36 }
37 }
38
39 /// Chain a function `f` to the current op.
40 ///
41 /// # Example
42 /// ```no_run
43 /// use rig_core::pipeline::{self, Op};
44 ///
45 /// # async fn run() {
46 /// let chain = pipeline::new()
47 /// .map(|(x, y)| x + y)
48 /// .map(|z| format!("Result: {z}!"));
49 ///
50 /// let result = chain.call((1, 2)).await;
51 /// assert_eq!(result, "Result: 3!");
52 /// # }
53 /// ```
54 fn map<F, Input>(self, f: F) -> Sequential<Self, Map<F, Self::Output>>
55 where
56 F: Fn(Self::Output) -> Input + WasmCompatSend + WasmCompatSync,
57 Input: WasmCompatSend + WasmCompatSync,
58 Self: Sized,
59 {
60 Sequential::new(self, Map::new(f))
61 }
62
63 /// Same as `map` but for asynchronous functions
64 ///
65 /// # Example
66 /// ```no_run
67 /// use rig_core::pipeline::{self, Op};
68 ///
69 /// # async fn run() {
70 /// let chain = pipeline::new()
71 /// .then(|email: String| async move {
72 /// email.split('@').next().unwrap().to_string()
73 /// })
74 /// .then(|username: String| async move {
75 /// format!("Hello, {}!", username)
76 /// });
77 ///
78 /// let result = chain.call("bob@gmail.com".to_string()).await;
79 /// assert_eq!(result, "Hello, bob!");
80 /// # }
81 /// ```
82 fn then<F, Fut>(self, f: F) -> Sequential<Self, Then<F, Fut::Output>>
83 where
84 F: Fn(Self::Output) -> Fut + Send + WasmCompatSync,
85 Fut: Future + WasmCompatSend + WasmCompatSync,
86 Fut::Output: WasmCompatSend + WasmCompatSync,
87 Self: Sized,
88 {
89 Sequential::new(self, Then::new(f))
90 }
91
92 /// Chain an arbitrary operation to the current op.
93 ///
94 /// # Example
95 /// ```no_run
96 /// use rig_core::pipeline::{self, Op};
97 ///
98 /// # async fn run() {
99 /// struct AddOne;
100 ///
101 /// impl Op for AddOne {
102 /// type Input = i32;
103 /// type Output = i32;
104 ///
105 /// async fn call(&self, input: Self::Input) -> Self::Output {
106 /// input + 1
107 /// }
108 /// }
109 ///
110 /// let chain = pipeline::new()
111 /// .chain(AddOne);
112 ///
113 /// let result = chain.call(1).await;
114 /// assert_eq!(result, 2);
115 /// # }
116 /// ```
117 fn chain<T>(self, op: T) -> Sequential<Self, T>
118 where
119 T: Op<Input = Self::Output>,
120 Self: Sized,
121 {
122 Sequential::new(self, op)
123 }
124
125 /// Chain a lookup operation to the current chain. The lookup operation expects the
126 /// current chain to output a query string. The lookup operation will use the query to
127 /// retrieve the top `n` documents from the index and return them with the query string.
128 ///
129 /// # Example
130 /// ```ignore
131 /// use rig_core::chain::{self, Chain};
132 ///
133 /// let chain = chain::new()
134 /// .lookup(index, 2)
135 /// .chain(|(query, docs): (_, Vec<String>)| async move {
136 /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
137 /// });
138 ///
139 /// let result = chain.call("What is a flurbo?".to_string()).await;
140 /// ```
141 fn lookup<I, Input>(
142 self,
143 index: I,
144 n: usize,
145 ) -> Sequential<Self, Lookup<I, Self::Output, Input>>
146 where
147 I: vector_store::VectorStoreIndex,
148 Input: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
149 Self::Output: Into<String>,
150 Self: Sized,
151 {
152 Sequential::new(self, Lookup::new(index, n))
153 }
154
155 /// Chain a prompt operation to the current chain. The prompt operation expects the
156 /// current chain to output a string. The prompt operation will use the string to prompt
157 /// the given agent (or any other type that implements the `Prompt` trait) and return
158 /// the response.
159 ///
160 /// # Example
161 /// ```ignore
162 /// use rig_core::chain::{self, Chain};
163 ///
164 /// let agent = &openai_client.agent("gpt-4").build();
165 ///
166 /// let chain = chain::new()
167 /// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
168 /// .prompt(agent);
169 ///
170 /// let result = chain.call("Alice".to_string()).await;
171 /// ```
172 fn prompt<P>(self, prompt: P) -> Sequential<Self, Prompt<P, Self::Output>>
173 where
174 P: completion::Prompt,
175 Self::Output: Into<String>,
176 Self: Sized,
177 {
178 Sequential::new(self, Prompt::new(prompt))
179 }
180}
181
182impl<T: Op> Op for &T {
183 type Input = T::Input;
184 type Output = T::Output;
185
186 #[inline]
187 async fn call(&self, input: Self::Input) -> Self::Output {
188 (*self).call(input).await
189 }
190}
191
192// ================================================================
193// Op combinators
194// ================================================================
/// Two ops composed in series: `prev` runs first and its output is fed into `op`.
pub struct Sequential<Op1, Op2> {
    /// The step that runs first.
    prev: Op1,
    /// The step that consumes `prev`'s output.
    op: Op2,
}

impl<Op1, Op2> Sequential<Op1, Op2> {
    /// Compose `prev` followed by `op`; nothing runs until the result is called.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        Sequential { prev, op }
    }
}
205
206impl<Op1, Op2> Op for Sequential<Op1, Op2>
207where
208 Op1: Op,
209 Op2: Op<Input = Op1::Output>,
210{
211 type Input = Op1::Input;
212 type Output = Op2::Output;
213
214 #[inline]
215 async fn call(&self, input: Self::Input) -> Self::Output {
216 let prev = self.prev.call(input).await;
217 self.op.call(prev).await
218 }
219}
220
221use super::agent_ops::{Lookup, Prompt};
222use crate::{completion, vector_store};
223
224// ================================================================
225// Core Op implementations
226// ================================================================
/// An op that applies a synchronous function `f` to its input.
///
/// `Input` is only a type marker (`PhantomData`). It lets the `Op` impl name the
/// argument type of `f`.
pub struct Map<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Map<F, Input> {
    /// Wrap `f`; it is not invoked until the op is called.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Map { f, _t }
    }
}
240
241impl<F, Input, Output> Op for Map<F, Input>
242where
243 F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
244 Input: WasmCompatSend + WasmCompatSync,
245 Output: WasmCompatSend + WasmCompatSync,
246{
247 type Input = Input;
248 type Output = Output;
249
250 #[inline]
251 async fn call(&self, input: Self::Input) -> Self::Output {
252 (self.f)(input)
253 }
254}
255
256pub fn map<F, Input, Output>(f: F) -> Map<F, Input>
257where
258 F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
259 Input: WasmCompatSend + WasmCompatSync,
260 Output: WasmCompatSend + WasmCompatSync,
261{
262 Map::new(f)
263}
264
/// An identity op: its output is exactly its input.
pub struct Passthrough<T> {
    _t: std::marker::PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Create the identity op for values of type `T`.
    pub(crate) fn new() -> Self {
        Passthrough {
            _t: Default::default(),
        }
    }
}
276
277impl<T> Op for Passthrough<T>
278where
279 T: WasmCompatSend + WasmCompatSync,
280{
281 type Input = T;
282 type Output = T;
283
284 async fn call(&self, input: Self::Input) -> Self::Output {
285 input
286 }
287}
288
289pub fn passthrough<T>() -> Passthrough<T>
290where
291 T: WasmCompatSend + WasmCompatSync,
292{
293 Passthrough::new()
294}
295
/// An op that applies an asynchronous function `f` to its input and awaits the result.
///
/// `Input` is only a type marker (`PhantomData`) naming the argument type of `f`.
pub struct Then<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Then<F, Input> {
    /// Wrap `f`; it is not invoked until the op is called.
    pub(crate) fn new(f: F) -> Self {
        let _t = std::marker::PhantomData;
        Then { f, _t }
    }
}
309
310impl<F, Input, Fut> Op for Then<F, Input>
311where
312 F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
313 Input: WasmCompatSend + WasmCompatSync,
314 Fut: Future + WasmCompatSend,
315 Fut::Output: WasmCompatSend + WasmCompatSync,
316{
317 type Input = Input;
318 type Output = Fut::Output;
319
320 #[inline]
321 async fn call(&self, input: Self::Input) -> Self::Output {
322 (self.f)(input).await
323 }
324}
325
326pub fn then<F, Input, Fut>(f: F) -> Then<F, Input>
327where
328 F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
329 Input: WasmCompatSend + WasmCompatSync,
330 Fut: Future + WasmCompatSend,
331 Fut::Output: WasmCompatSend + WasmCompatSync,
332{
333 Then::new(f)
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_sequential_constructor() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 2);
        let op3 = map(|x: i32| x * 3);

        let pipeline = Sequential::new(Sequential::new(op1, op2), op3);

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_sequential_chain() {
        let pipeline = map(|x: i32| x + 1)
            .map(|x| x * 2)
            .then(|x| async move { x * 3 });

        let result = pipeline.call(1).await;
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_passthrough_returns_input() {
        let op = passthrough::<String>();
        assert_eq!(op.call("unchanged".to_string()).await, "unchanged");
    }

    #[tokio::test]
    async fn test_then_constructor() {
        let op = then(|x: i32| async move { x - 1 });
        assert_eq!(op.call(10).await, 9);
    }

    #[tokio::test]
    async fn test_reference_is_op() {
        // `&Map<..>` is an `Op` through the blanket `impl Op for &T`.
        let add_one = map(|x: i32| x + 1);
        let pipeline = Sequential::new(&add_one, &add_one);
        assert_eq!(pipeline.call(1).await, 3);
    }

    #[tokio::test]
    async fn test_batch_call_preserves_input_order() {
        let op = map(|x: i32| x * 10);
        let results = op.batch_call(2, vec![1, 2, 3, 4]).await;
        assert_eq!(results, vec![10, 20, 30, 40]);
    }
}