use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use super::Node;
use crate::graph::{
compiled::CompiledGraph,
error::GraphError,
persistence::{config::RunnableConfig, store::StoreBox},
state::State,
StateUpdate,
};
/// A graph node that executes a compiled subgraph operating on the same state type `S`
/// as the parent graph.
///
/// The subgraph is stored behind an [`Arc`] so shared handles to it can be cloned cheaply.
pub struct SubgraphNode<S>
where
    S: State + 'static,
{
    /// Name of this node, as supplied to `new`.
    name: String,
    /// The compiled subgraph executed when this node is invoked.
    subgraph: Arc<CompiledGraph<S>>,
}
impl<S> SubgraphNode<S>
where
    S: State + 'static,
{
    /// Wraps an already-compiled `subgraph` in a node called `name`.
    ///
    /// Ownership of the subgraph moves into a shared [`Arc`].
    pub fn new(name: impl Into<String>, subgraph: CompiledGraph<S>) -> Self {
        let subgraph = Arc::new(subgraph);
        let name = name.into();
        Self { name, subgraph }
    }

    /// Returns the name this node was created with.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Borrows the wrapped compiled subgraph.
    pub fn subgraph(&self) -> &CompiledGraph<S> {
        Arc::as_ref(&self.subgraph)
    }
}
#[async_trait]
impl<S: State + 'static> Node<S> for SubgraphNode<S> {
    /// Runs the subgraph on a clone of `state` without a runnable config.
    ///
    /// Equivalent to `invoke_with_context(state, None, None)`. It delegates so the
    /// state-to-update conversion lives in exactly one place.
    async fn invoke(&self, state: &S) -> Result<StateUpdate, GraphError> {
        self.invoke_with_context(state, None, None).await
    }

    /// Runs the subgraph on a clone of `state` and converts its final state into an update.
    ///
    /// With `Some(config)`, the subgraph runs through `invoke_with_config`; otherwise it runs
    /// through plain `invoke`. The final state is serialized with `serde_json`, and each
    /// top-level field of the resulting JSON object becomes one entry of the returned update.
    ///
    /// NOTE(review): `_store` is not forwarded to the subgraph; confirm this is intended.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the subgraph run. Returns
    /// [`GraphError::SerializationError`] if the final state cannot be serialized to JSON.
    async fn invoke_with_context(
        &self,
        state: &S,
        config: Option<&RunnableConfig>,
        _store: Option<StoreBox>,
    ) -> Result<StateUpdate, GraphError> {
        let final_state = match config {
            Some(config) => {
                self.subgraph
                    .invoke_with_config(Some(state.clone()), config)
                    .await?
            }
            None => self.subgraph.invoke(state.clone()).await?,
        };

        let update: StateUpdate =
            match serde_json::to_value(final_state).map_err(GraphError::SerializationError)? {
                serde_json::Value::Object(fields) => fields.into_iter().collect(),
                // NOTE(review): a final state that does not serialize to a JSON object has no
                // named fields, so it yields an empty update and the subgraph's result is
                // dropped. Confirm this is intended rather than an error.
                _ => HashMap::new(),
            };
        Ok(update)
    }

    /// Returns a shared handle to the wrapped subgraph.
    fn get_subgraph(&self) -> Option<Arc<CompiledGraph<S>>> {
        Some(Arc::clone(&self.subgraph))
    }
}
/// A graph node that runs a subgraph with its own state type `SubState`.
///
/// It bridges to the parent graph's `ParentState` through two user-supplied conversion
/// closures.
pub struct SubgraphNodeWithTransform<ParentState, SubState>
where
    ParentState: State + 'static,
    SubState: State + 'static,
{
    /// Name of this node, as supplied to `new`.
    name: String,
    /// The compiled subgraph executed when this node is invoked.
    subgraph: Arc<CompiledGraph<SubState>>,
    /// Builds the subgraph's input state from the parent state.
    transform_in: Arc<dyn Fn(&ParentState) -> Result<SubState, GraphError> + Send + Sync>,
    /// Converts the subgraph's final state into an update for the parent graph.
    transform_out: Arc<dyn Fn(&SubState) -> Result<StateUpdate, GraphError> + Send + Sync>,
}
impl<ParentState, SubState> SubgraphNodeWithTransform<ParentState, SubState>
where
    ParentState: State + 'static,
    SubState: State + 'static,
{
    /// Creates a node called `name` that runs `subgraph`.
    ///
    /// `transform_in` derives the subgraph's input from the parent state.
    /// `transform_out` turns the subgraph's final state into a parent [`StateUpdate`].
    /// The subgraph and both closures are moved into shared [`Arc`]s.
    pub fn new(
        name: impl Into<String>,
        subgraph: CompiledGraph<SubState>,
        transform_in: impl Fn(&ParentState) -> Result<SubState, GraphError> + Send + Sync + 'static,
        transform_out: impl Fn(&SubState) -> Result<StateUpdate, GraphError> + Send + Sync + 'static,
    ) -> Self {
        let subgraph = Arc::new(subgraph);
        let name = name.into();
        Self {
            name,
            subgraph,
            transform_in: Arc::new(transform_in),
            transform_out: Arc::new(transform_out),
        }
    }

    /// Returns the name this node was created with.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Borrows the wrapped compiled subgraph.
    pub fn subgraph(&self) -> &CompiledGraph<SubState> {
        Arc::as_ref(&self.subgraph)
    }
}
#[async_trait]
impl<ParentState, SubState> Node<ParentState> for SubgraphNodeWithTransform<ParentState, SubState>
where
    ParentState: State + 'static,
    SubState: State + 'static,
{
    /// Runs the node without a runnable config.
    ///
    /// This is the same as `invoke_with_context(state, None, None)`.
    async fn invoke(&self, state: &ParentState) -> Result<StateUpdate, GraphError> {
        self.invoke_with_context(state, None, None).await
    }

    /// Runs the subgraph between the two conversion closures.
    ///
    /// 1. `transform_in` maps `state` into the subgraph's state.
    /// 2. The subgraph runs through `invoke_with_config` when `config` is present, or
    ///    through plain `invoke` otherwise.
    /// 3. `transform_out` maps the final sub-state back to a parent update.
    ///
    /// Errors from either closure or from the subgraph run are propagated unchanged.
    ///
    /// NOTE(review): `_store` is not forwarded to the subgraph; confirm this is intended.
    async fn invoke_with_context(
        &self,
        state: &ParentState,
        config: Option<&RunnableConfig>,
        _store: Option<StoreBox>,
    ) -> Result<StateUpdate, GraphError> {
        let input = (self.transform_in)(state)?;
        let output = match config {
            Some(cfg) => self.subgraph.invoke_with_config(Some(input), cfg).await?,
            None => self.subgraph.invoke(input).await?,
        };
        (self.transform_out)(&output)
    }
}