use std::collections::HashMap;
use crate::graph::{
error::GraphError,
node::Node,
persistence::{config::RunnableConfig, store::StoreBox},
state::{State, StateUpdate},
};
/// Runs the named nodes concurrently against the same input state and collects their updates.
///
/// Each node gets its own clone of `state` and of `store`; `config` is shared by reference.
/// The node futures are driven together with [`futures::future::join_all`], so they run
/// concurrently within the awaiting task. The returned `(node_name, update)` pairs are in the
/// same order as `node_names`.
///
/// # Errors
///
/// Returns [`GraphError::NodeNotFound`] if a name in `node_names` is not a key of `nodes`.
/// Otherwise returns the error from a node's `invoke_with_context`, propagated with `?`. When
/// several futures fail, the error reported is the first in `node_names` order. Every future is
/// still driven to completion before the error is returned, because `join_all` waits for all of
/// them.
pub async fn execute_nodes_parallel<S: State>(
    nodes: &HashMap<String, std::sync::Arc<dyn Node<S>>>,
    node_names: &[String],
    state: &S,
    config: Option<&RunnableConfig>,
    store: Option<StoreBox>,
) -> Result<Vec<(String, StateUpdate)>, GraphError> {
    let futures = node_names.iter().map(|node_name| {
        // Resolve and clone everything up front so each future owns its inputs.
        let node_opt = nodes.get(node_name).cloned();
        let node_name = node_name.clone();
        let state = state.clone();
        let store = store.clone();
        async move {
            let node = node_opt.ok_or_else(|| GraphError::NodeNotFound(node_name.clone()))?;
            let update = node.invoke_with_context(&state, config, store).await?;
            Ok::<(String, StateUpdate), GraphError>((node_name, update))
        }
    });
    // `join_all` yields results in input order; collecting into `Result` returns the first
    // `Err` in that order, or all updates if every node succeeded.
    futures::future::join_all(futures)
        .await
        .into_iter()
        .collect()
}
/// Folds a sequence of node updates into `state`, returning the resulting state.
///
/// Updates are applied one after another, in slice order, via `merge_single_update`. Each step
/// starts from the output of the previous one. With no updates, a clone of `state` is returned.
///
/// # Errors
///
/// Returns the first error produced by `merge_single_update`; later updates are not applied.
pub fn merge_state_updates<S: State>(
    state: &S,
    updates: &[(String, StateUpdate)],
) -> Result<S, GraphError> {
    updates
        .iter()
        .try_fold(state.clone(), |current_state, (node_name, update)| {
            log::debug!("Merging update from node: {}", node_name);
            merge_single_update(&current_state, update)
        })
}
/// Merges a single node's [`StateUpdate`] into `state`, returning the new state.
///
/// If the JSON form of `state` has a top-level `messages` key, the work is delegated to
/// `merge_messages_state_update`. Otherwise the update is serialized to JSON, deserialized as
/// an `S`, and combined with `state` via `State::merge`.
///
/// # Errors
///
/// * [`GraphError::SerializationError`] if `state` or `update` cannot be serialized.
/// * [`GraphError::ExecutionError`] if the serialized update cannot be deserialized as `S`.
///   NOTE(review): this presumably rejects partial updates unless every field of `S` has a serde
///   default — confirm against the `State` implementations.
/// * Any error from `merge_messages_state_update` for states with a `messages` key.
fn merge_single_update<S: State>(state: &S, update: &StateUpdate) -> Result<S, GraphError> {
    let state_json = serde_json::to_value(state).map_err(GraphError::SerializationError)?;
    if state_json.get("messages").is_some() {
        return merge_messages_state_update(state, update);
    }
    let update_json = serde_json::to_value(update).map_err(GraphError::SerializationError)?;
    // `update_json` is not used again, so it is moved into `from_value` rather than cloned.
    let update_state: S = serde_json::from_value(update_json).map_err(|_| {
        GraphError::ExecutionError("Cannot deserialize update as state".to_string())
    })?;
    Ok(state.merge(&update_state))
}
fn merge_messages_state_update<S: State>(state: &S, update: &StateUpdate) -> Result<S, GraphError> {
use crate::graph::state::{apply_update_to_messages_state, MessagesState};
let state_json = serde_json::to_value(state).map_err(GraphError::SerializationError)?;
let messages_state: MessagesState = if let Some(messages_value) = state_json.get("messages") {
if let Ok(messages) =
serde_json::from_value::<Vec<crate::schemas::messages::Message>>(messages_value.clone())
{
MessagesState::with_messages(messages)
} else {
MessagesState::new()
}
} else {
MessagesState::new()
};
let updated_state = apply_update_to_messages_state(&messages_state, update);
let updated_json =
serde_json::to_value(&updated_state).map_err(GraphError::SerializationError)?;
serde_json::from_value(updated_json).map_err(GraphError::SerializationError)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::node::Node;
    use crate::graph::{function_node, state::MessagesState};
    use std::sync::Arc;

    /// Two independent nodes run in parallel. Both updates come back, in `node_names` order.
    #[tokio::test]
    async fn test_execute_nodes_parallel() {
        let mut nodes: HashMap<String, Arc<dyn Node<MessagesState>>> = HashMap::new();
        nodes.insert(
            "node1".to_string(),
            Arc::new(function_node("node1", |_state| async move {
                let mut update = HashMap::new();
                update.insert(
                    "messages".to_string(),
                    serde_json::to_value(vec![crate::schemas::messages::Message::new_ai_message(
                        "Node1",
                    )])?,
                );
                Ok(update)
            })),
        );
        nodes.insert(
            "node2".to_string(),
            Arc::new(function_node("node2", |_state| async move {
                let mut update = HashMap::new();
                update.insert(
                    "messages".to_string(),
                    serde_json::to_value(vec![crate::schemas::messages::Message::new_ai_message(
                        "Node2",
                    )])?,
                );
                Ok(update)
            })),
        );
        let state = MessagesState::new();
        let results = execute_nodes_parallel(
            &nodes,
            &["node1".to_string(), "node2".to_string()],
            &state,
            None,
            None,
        )
        .await
        .unwrap();
        assert_eq!(results.len(), 2);
        // `join_all` preserves input order, so results line up with `node_names`.
        let names: Vec<&str> = results.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["node1", "node2"]);
    }

    /// Requesting a node that is not registered surfaces `GraphError::NodeNotFound`.
    #[tokio::test]
    async fn test_execute_nodes_parallel_missing_node() {
        let nodes: HashMap<String, Arc<dyn Node<MessagesState>>> = HashMap::new();
        let state = MessagesState::new();
        let result =
            execute_nodes_parallel(&nodes, &["missing".to_string()], &state, None, None).await;
        assert!(matches!(
            result,
            Err(GraphError::NodeNotFound(ref name)) if name == "missing"
        ));
    }
}