anchor_chain/
parallel_node.rs

//! Provides a structure for processing input through multiple nodes in parallel.
//!
//! The `ParallelNode` struct represents a node that processes input through
//! multiple nodes in parallel. The output of each node is then combined using
//! a provided function to produce the final output.
//!
//! # Example
//!
//! ```rust,no_run
//! use async_trait::async_trait;
//! use futures::{future::BoxFuture, Future};
//! use std::collections::HashMap;
//!
//! use anchor_chain::{
//!     chain::ChainBuilder,
//!     models::{claude_3::Claude3Bedrock, openai::OpenAIModel},
//!     parallel_node::{ParallelNode, to_boxed_future},
//!     nodes::prompt::Prompt,
//! };
//!
//! #[tokio::main]
//! async fn main() {
//!     let gpt3 =
//!         Box::new(OpenAIModel::new_gpt3_5_turbo("You are a helpful assistant".to_string()).await);
//!     let claude3 = Box::new(Claude3Bedrock::new("You are a helpful assistant".to_string()).await);
//!
//!     let concat_fn = to_boxed_future(|outputs: Vec<String>| {
//!         Ok(outputs
//!             .iter()
//!             .enumerate()
//!             .map(|(i, output)| format!("Output {}:\n```\n{}\n```\n", i + 1, output))
//!             .collect::<Vec<String>>()
//!             .concat())
//!     });
//!
//!
//!     let chain = ChainBuilder::new()
//!         .link(Prompt::new("{{ input }}"))
//!         .link(ParallelNode::new(vec![gpt3, claude3], concat_fn))
//!         .build();
//!
//!     let output = chain
//!         .process(HashMap::from([("input".to_string(), "Write a hello world program in Rust".to_string())]))
//!         .await
//!         .expect("Error processing chain");
//!     println!("{}", output);
//! }
//! ```
48
49use async_trait::async_trait;
50use futures::future::try_join_all;
51use futures::{future::BoxFuture, FutureExt};
52use std::fmt;
53#[cfg(feature = "tracing")]
54use tracing::{instrument, Instrument};
55
56use crate::error::AnchorChainError;
57use crate::node::Node;
58
59/// A function that combines the output of multiple nodes.
60///
61/// The function takes a vector of outputs from multiple nodes and returns a
62/// `Result` containing the final output. The BoxFuture can be created using
63/// the `to_boxed_future` helper function.
64type CombinationFunction<I, O> =
65    Box<dyn Fn(Vec<I>) -> BoxFuture<'static, Result<O, AnchorChainError>> + Send + Sync>;
66
67/// A node that processes input through multiple nodes in parallel.
68///
69/// The `ParallelNode` struct represents a node that processes input through
70/// multiple nodes in parallel. The output of each node is then combined using
71/// a provided function to produce the final output.
72pub struct ParallelNode<I, O, C>
73where
74    I: Clone + Send + Sync,
75    O: Send + Sync,
76    C: Send + Sync,
77{
78    /// The nodes that will process the input in parallel.
79    pub nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
80    /// The function to process the output of the nodes.
81    pub function: CombinationFunction<O, C>,
82}
83
84impl<I, O, C> ParallelNode<I, O, C>
85where
86    I: Clone + Send + Sync,
87    O: Send + Sync,
88    C: Send + Sync,
89{
90    /// Creates a new `ParallelNode` with the provided nodes and combination
91    /// function.
92    ///
93    /// The combination function can be defined using the helper function `to_boxed_future`.
94    ///
95    /// # Example
96    /// // Using PassThroughNode as an example node
97    /// ```rust
98    /// use anchor_chain::{
99    ///     node::NoOpNode,
100    ///     parallel_node::ParallelNode,
101    ///     parallel_node::to_boxed_future
102    /// };
103    ///
104    /// #[tokio::main]
105    /// async fn main() {
106    ///     let node1 = Box::new(NoOpNode::new());
107    ///     let node2 = Box::new(NoOpNode::new());
108    ///     let concat_fn = to_boxed_future(|outputs: Vec<String>| {
109    ///         Ok(outputs
110    ///            .iter()
111    ///            .enumerate()
112    ///            .map(|(i, output)| format!("Output {}:\n```\n{}\n```\n", i + 1, output))
113    ///            .collect::<Vec<String>>()
114    ///            .concat())
115    ///     });
116    ///     let parallel_node = ParallelNode::new(vec![node1, node2], concat_fn);
117    /// }
118    pub fn new(
119        nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
120        function: CombinationFunction<O, C>,
121    ) -> Self {
122        ParallelNode { nodes, function }
123    }
124}
125
126#[async_trait]
127impl<I, O, C> Node for ParallelNode<I, O, C>
128where
129    I: Clone + Send + Sync + fmt::Debug,
130    O: Send + Sync + fmt::Debug,
131    C: Send + Sync + fmt::Debug,
132{
133    type Input = I;
134    type Output = C;
135
136    /// Processes the given input through nodes in parallel.
137    ///
138    /// The input is processed by each node in parallel, and the results are combined
139    /// using the provided function to produce the final output.
140    #[cfg_attr(feature = "tracing", instrument)]
141    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
142        let futures = self.nodes.iter().map(|node| {
143            let input_clone = input.clone();
144            async move { node.process(input_clone).await }
145        });
146
147        let results = try_join_all(futures);
148
149        #[cfg(feature = "tracing")]
150        let results = results.instrument(tracing::info_span!("Joining parallel node futures"));
151
152        let results = results.await?;
153
154        let combined_results = (self.function)(results);
155
156        #[cfg(feature = "tracing")]
157        let combined_results =
158            combined_results.instrument(tracing::info_span!("Combining parallel node outputs"));
159
160        combined_results.await
161    }
162}
163
164impl<I, O, C> fmt::Debug for ParallelNode<I, O, C>
165where
166    I: fmt::Debug + Clone + Send + Sync,
167    O: fmt::Debug + Send + Sync,
168    C: fmt::Debug + Send + Sync,
169{
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        f.debug_struct("ParallelNode")
172            .field("nodes", &self.nodes)
173            // Unable to debug print closures
174            .field("function", &format_args!("<function/closure>"))
175            .finish()
176    }
177}
178
179/// Converts a function into a `BoxFuture` that can be used in a `ParallelNode`.
180///
181/// This function takes a function that processes input and returns a `Result` and
182/// converts it into a boxed future.
183pub fn to_boxed_future<F, I, O>(
184    f: F,
185) -> Box<dyn Fn(I) -> BoxFuture<'static, Result<O, AnchorChainError>> + Send + Sync>
186where
187    F: Fn(I) -> Result<O, AnchorChainError> + Send + Sync + Clone + 'static,
188    I: Send + 'static,
189{
190    Box::new(move |input| {
191        let f_clone = f.clone();
192        async move { f_clone(input) }.boxed()
193    })
194}