anchor_chain/parallel_node.rs
//! Provides a node that runs the same input through several nodes concurrently.
//!
//! The `ParallelNode` struct represents a node that passes a clone of its input
//! to each of its child nodes and awaits all of them together. The outputs, in
//! the same order as the child nodes, are then combined using a provided
//! function to produce the final output. If any child node fails, its error is
//! returned and the combination function is not called.
//!
//! # Example
8//! ```rust,no_run
9//! use async_trait::async_trait;
10//! use futures::{future::BoxFuture, Future};
11//! use std::collections::HashMap;
12//!
13//! use anchor_chain::{
14//! chain::ChainBuilder,
15//! models::{claude_3::Claude3Bedrock, openai::OpenAIModel},
16//! parallel_node::{ParallelNode, to_boxed_future},
17//! nodes::prompt::Prompt,
18//! };
19//!
20//! #[tokio::main]
21//! async fn main() {
22//! let gpt3 =
23//! Box::new(OpenAIModel::new_gpt3_5_turbo("You are a helpful assistant".to_string()).await);
24//! let claude3 = Box::new(Claude3Bedrock::new("You are a helpful assistant".to_string()).await);
25//!
26//! let concat_fn = to_boxed_future(|outputs: Vec<String>| {
27//! Ok(outputs
28//! .iter()
29//! .enumerate()
30//! .map(|(i, output)| format!("Output {}:\n```\n{}\n```\n", i + 1, output))
31//! .collect::<Vec<String>>()
32//! .concat())
33//! });
34//!
35//!
36//! let chain = ChainBuilder::new()
37//! .link(Prompt::new("{{ input }}"))
38//! .link(ParallelNode::new(vec![gpt3, claude3], concat_fn))
39//! .build();
40//!
41//! let output = chain
42//! .process(HashMap::from([("input".to_string(), "Write a hello world program in Rust".to_string())]))
43//! .await
44//! .expect("Error processing chain");
45//! println!("{}", output);
46//! }
47//! ```
48
49use async_trait::async_trait;
50use futures::future::try_join_all;
51use futures::{future::BoxFuture, FutureExt};
52use std::fmt;
53#[cfg(feature = "tracing")]
54use tracing::{instrument, Instrument};
55
56use crate::error::AnchorChainError;
57use crate::node::Node;
58
59/// A function that combines the output of multiple nodes.
60///
61/// The function takes a vector of outputs from multiple nodes and returns a
62/// `Result` containing the final output. The BoxFuture can be created using
63/// the `to_boxed_future` helper function.
64type CombinationFunction<I, O> =
65 Box<dyn Fn(Vec<I>) -> BoxFuture<'static, Result<O, AnchorChainError>> + Send + Sync>;
66
67/// A node that processes input through multiple nodes in parallel.
68///
69/// The `ParallelNode` struct represents a node that processes input through
70/// multiple nodes in parallel. The output of each node is then combined using
71/// a provided function to produce the final output.
72pub struct ParallelNode<I, O, C>
73where
74 I: Clone + Send + Sync,
75 O: Send + Sync,
76 C: Send + Sync,
77{
78 /// The nodes that will process the input in parallel.
79 pub nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
80 /// The function to process the output of the nodes.
81 pub function: CombinationFunction<O, C>,
82}
83
84impl<I, O, C> ParallelNode<I, O, C>
85where
86 I: Clone + Send + Sync,
87 O: Send + Sync,
88 C: Send + Sync,
89{
90 /// Creates a new `ParallelNode` with the provided nodes and combination
91 /// function.
92 ///
93 /// The combination function can be defined using the helper function `to_boxed_future`.
94 ///
95 /// # Example
96 /// // Using PassThroughNode as an example node
97 /// ```rust
98 /// use anchor_chain::{
99 /// node::NoOpNode,
100 /// parallel_node::ParallelNode,
101 /// parallel_node::to_boxed_future
102 /// };
103 ///
104 /// #[tokio::main]
105 /// async fn main() {
106 /// let node1 = Box::new(NoOpNode::new());
107 /// let node2 = Box::new(NoOpNode::new());
108 /// let concat_fn = to_boxed_future(|outputs: Vec<String>| {
109 /// Ok(outputs
110 /// .iter()
111 /// .enumerate()
112 /// .map(|(i, output)| format!("Output {}:\n```\n{}\n```\n", i + 1, output))
113 /// .collect::<Vec<String>>()
114 /// .concat())
115 /// });
116 /// let parallel_node = ParallelNode::new(vec![node1, node2], concat_fn);
117 /// }
118 pub fn new(
119 nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
120 function: CombinationFunction<O, C>,
121 ) -> Self {
122 ParallelNode { nodes, function }
123 }
124}
125
126#[async_trait]
127impl<I, O, C> Node for ParallelNode<I, O, C>
128where
129 I: Clone + Send + Sync + fmt::Debug,
130 O: Send + Sync + fmt::Debug,
131 C: Send + Sync + fmt::Debug,
132{
133 type Input = I;
134 type Output = C;
135
136 /// Processes the given input through nodes in parallel.
137 ///
138 /// The input is processed by each node in parallel, and the results are combined
139 /// using the provided function to produce the final output.
140 #[cfg_attr(feature = "tracing", instrument)]
141 async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
142 let futures = self.nodes.iter().map(|node| {
143 let input_clone = input.clone();
144 async move { node.process(input_clone).await }
145 });
146
147 let results = try_join_all(futures);
148
149 #[cfg(feature = "tracing")]
150 let results = results.instrument(tracing::info_span!("Joining parallel node futures"));
151
152 let results = results.await?;
153
154 let combined_results = (self.function)(results);
155
156 #[cfg(feature = "tracing")]
157 let combined_results =
158 combined_results.instrument(tracing::info_span!("Combining parallel node outputs"));
159
160 combined_results.await
161 }
162}
163
164impl<I, O, C> fmt::Debug for ParallelNode<I, O, C>
165where
166 I: fmt::Debug + Clone + Send + Sync,
167 O: fmt::Debug + Send + Sync,
168 C: fmt::Debug + Send + Sync,
169{
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("ParallelNode")
172 .field("nodes", &self.nodes)
173 // Unable to debug print closures
174 .field("function", &format_args!("<function/closure>"))
175 .finish()
176 }
177}
178
179/// Converts a function into a `BoxFuture` that can be used in a `ParallelNode`.
180///
181/// This function takes a function that processes input and returns a `Result` and
182/// converts it into a boxed future.
183pub fn to_boxed_future<F, I, O>(
184 f: F,
185) -> Box<dyn Fn(I) -> BoxFuture<'static, Result<O, AnchorChainError>> + Send + Sync>
186where
187 F: Fn(I) -> Result<O, AnchorChainError> + Send + Sync + Clone + 'static,
188 I: Send + 'static,
189{
190 Box::new(move |input| {
191 let f_clone = f.clone();
192 async move { f_clone(input) }.boxed()
193 })
194}