use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use crate::{
agent::Agent, chain::Chain, language_models::llm::LLM, prompt::PromptArgs,
schemas::messages::Message,
};
use super::{
compiled::CompiledGraph,
error::GraphError,
persistence::{config::RunnableConfig, store::StoreBox},
state::State,
StateUpdate,
};
mod subgraph;
pub use subgraph::{SubgraphNode, SubgraphNodeWithTransform};
/// A unit of work in a graph: it reads the current state `S` and produces a
/// [`StateUpdate`]. How that update is merged into the state is decided by
/// the caller (not shown in this file).
///
/// Implementors must be `Send + Sync`; the async methods are provided via
/// `#[async_trait]`.
#[async_trait]
pub trait Node<S: State>: Send + Sync {
/// Runs the node using only the current `state`.
async fn invoke(&self, state: &S) -> Result<StateUpdate, GraphError>;
/// Runs the node with an optional run configuration and an optional store.
///
/// The default implementation ignores both `_config` and `_store` and simply
/// delegates to [`Node::invoke`]. Nodes whose work depends on either value
/// (for example [`FunctionNode`]) override this method.
async fn invoke_with_context(
&self,
state: &S,
_config: Option<&RunnableConfig>,
_store: Option<StoreBox>,
) -> Result<StateUpdate, GraphError> {
self.invoke(state).await
}
/// Returns the LLM backing this node, if any.
///
/// Defaults to `None`; [`LLMNode`] overrides it to return its model.
fn get_llm(&self) -> Option<Arc<dyn LLM>> {
None
}
/// Returns the compiled subgraph wrapped by this node, if any.
///
/// Defaults to `None`.
// NOTE(review): presumably overridden by the node types in the `subgraph`
// module — confirm there.
fn get_subgraph(&self) -> Option<Arc<CompiledGraph<S>>> {
None
}
}
/// Boxed, type-erased future returned by every [`FunctionNode`] callback.
type NodeFuture =
    std::pin::Pin<Box<dyn std::future::Future<Output = Result<StateUpdate, GraphError>> + Send>>;

/// The callback shape a [`FunctionNode`] was built with.
///
/// Exactly one shape is stored per node, so a node can never end up with no
/// callback or with several competing callbacks.
enum NodeFunction<S> {
    /// Built by [`FunctionNode::new`]: receives only the state.
    StateOnly(Arc<dyn Fn(&S) -> NodeFuture + Send + Sync>),
    /// Built by [`FunctionNode::with_config`]: receives the state and the run config.
    WithConfig(Arc<dyn Fn(&S, &RunnableConfig) -> NodeFuture + Send + Sync>),
    /// Built by [`FunctionNode::with_config_store`]: receives state, run config and store.
    WithConfigStore(Arc<dyn Fn(&S, &RunnableConfig, StoreBox) -> NodeFuture + Send + Sync>),
}

/// A graph node backed by an async closure.
///
/// Build it with [`FunctionNode::new`] (state only), [`FunctionNode::with_config`]
/// (state + [`RunnableConfig`]) or [`FunctionNode::with_config_store`]
/// (state + config + [`StoreBox`]). Nodes that need config or a store must be
/// run through [`Node::invoke_with_context`].
pub struct FunctionNode<S: State> {
    /// Human-readable node name, returned by [`FunctionNode::name`].
    name: String,
    /// The single callback this node runs.
    func: NodeFunction<S>,
}

impl<S: State> FunctionNode<S> {
    /// Creates a node whose callback only needs the current state.
    pub fn new<F, Fut>(name: String, func: F) -> Self
    where
        F: Fn(&S) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<StateUpdate, GraphError>> + Send + 'static,
    {
        Self {
            name,
            func: NodeFunction::StateOnly(Arc::new(move |state| Box::pin(func(state)))),
        }
    }

    /// Creates a node whose callback also needs the [`RunnableConfig`].
    pub fn with_config<F, Fut>(name: String, func: F) -> Self
    where
        F: Fn(&S, &RunnableConfig) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<StateUpdate, GraphError>> + Send + 'static,
    {
        Self {
            name,
            func: NodeFunction::WithConfig(Arc::new(move |state, config| {
                Box::pin(func(state, config))
            })),
        }
    }

    /// Creates a node whose callback needs both the [`RunnableConfig`] and a [`StoreBox`].
    pub fn with_config_store<F, Fut>(name: String, func: F) -> Self
    where
        F: Fn(&S, &RunnableConfig, StoreBox) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<StateUpdate, GraphError>> + Send + 'static,
    {
        Self {
            name,
            func: NodeFunction::WithConfigStore(Arc::new(move |state, config, store| {
                Box::pin(func(state, config, store))
            })),
        }
    }

    /// Returns the name this node was created with.
    pub fn name(&self) -> &str {
        &self.name
    }
}

#[async_trait]
impl<S: State> Node<S> for FunctionNode<S> {
    /// Runs a state-only callback.
    ///
    /// # Errors
    /// Returns [`GraphError::ExecutionError`] when the node was built with
    /// `with_config` or `with_config_store`, since no context is available
    /// here. Otherwise it returns whatever the callback returns.
    async fn invoke(&self, state: &S) -> Result<StateUpdate, GraphError> {
        match &self.func {
            NodeFunction::StateOnly(func) => func(state).await,
            NodeFunction::WithConfig(_) | NodeFunction::WithConfigStore(_) => {
                Err(GraphError::ExecutionError(
                    "Node requires config or store, use invoke_with_context".to_string(),
                ))
            }
        }
    }

    /// Runs the callback, supplying whatever context its shape requires.
    ///
    /// - config+store callbacks need both `config` and `store` to be `Some`;
    /// - config callbacks need `config` to be `Some` (`store` is ignored);
    /// - state-only callbacks ignore both.
    ///
    /// # Errors
    /// Returns [`GraphError::ExecutionError`] when a required piece of context
    /// is missing; otherwise whatever the callback returns.
    async fn invoke_with_context(
        &self,
        state: &S,
        config: Option<&RunnableConfig>,
        store: Option<StoreBox>,
    ) -> Result<StateUpdate, GraphError> {
        match &self.func {
            NodeFunction::WithConfigStore(func) => match (config, store) {
                // `store` is already owned here, so it is moved straight into the callback.
                (Some(config), Some(store)) => func(state, config, store).await,
                _ => Err(GraphError::ExecutionError(
                    "Node requires both config and store".to_string(),
                )),
            },
            NodeFunction::WithConfig(func) => match config {
                Some(config) => func(state, config).await,
                None => Err(GraphError::ExecutionError(
                    "Node requires config".to_string(),
                )),
            },
            NodeFunction::StateOnly(func) => func(state).await,
        }
    }
}
/// A graph node that feeds one value from the state into a [`Chain`] and
/// returns the chain's generation as a state update.
///
/// The keys it reads and writes are chosen in [`ChainNode::new`].
pub struct ChainNode {
    /// Chain invoked on every run.
    chain: Arc<dyn Chain>,
    /// State key read as the chain input; also used as the prompt-argument name.
    input_key: String,
    /// Key under which the chain's generation is written to the update.
    output_key: String,
}

impl ChainNode {
    /// Wraps `chain` in a node.
    ///
    /// `input_key` falls back to `"input"` and `output_key` to `"output"` when `None`.
    pub fn new(
        chain: Arc<dyn Chain>,
        input_key: Option<String>,
        output_key: Option<String>,
    ) -> Self {
        let input_key = input_key.unwrap_or_else(|| String::from("input"));
        let output_key = output_key.unwrap_or_else(|| String::from("output"));
        ChainNode {
            chain,
            input_key,
            output_key,
        }
    }
}
/// Runs the wrapped chain on a single value pulled from the state.
///
/// The input is looked up in this order:
/// 1. the state field named `input_key`;
/// 2. otherwise, the `content` of the last element of the state's `messages` array.
///
/// If neither is present, the chain is called with empty prompt arguments.
/// The chain's `generation` is serialized under `output_key` in the returned update.
#[async_trait]
impl<S: State> Node<S> for ChainNode {
    async fn invoke(&self, state: &S) -> Result<StateUpdate, GraphError> {
        let state_json = serde_json::to_value(state).map_err(GraphError::SerializationError)?;

        let prompt_args = {
            let mut args = PromptArgs::new();
            let chosen = state_json.get(&self.input_key).or_else(|| {
                state_json
                    .get("messages")
                    .and_then(serde_json::Value::as_array)
                    .and_then(|history| history.last())
                    .and_then(|last| last.get("content"))
            });
            if let Some(value) = chosen {
                args.insert(self.input_key.clone(), value.clone());
            }
            args
        };

        let output = self.chain.call(prompt_args).await?;
        let generation = serde_json::to_value(output.generation)?;
        Ok(HashMap::from([(self.output_key.clone(), generation)]))
    }
}
/// A graph node that sends the state's messages to an [`LLM`] and returns
/// the model's reply as a single AI message.
pub struct LLMNode {
    /// Model used for generation; also exposed via `Node::get_llm`.
    llm: Arc<dyn LLM>,
}

impl LLMNode {
    /// Wraps `llm` in a node.
    pub fn new(llm: Arc<dyn LLM>) -> Self {
        LLMNode { llm }
    }
}
/// Generates a reply from the state's `messages`.
///
/// The `messages` field is deserialized as `Vec<Message>`. When the state has
/// no `messages` field, a single empty human message is sent instead. The
/// reply is returned under `"messages"` as a one-element array holding the
/// AI message.
#[async_trait]
impl<S: State> Node<S> for LLMNode {
    async fn invoke(&self, state: &S) -> Result<StateUpdate, GraphError> {
        let state_json = serde_json::to_value(state).map_err(GraphError::SerializationError)?;

        let messages: Vec<Message> = match state_json.get("messages") {
            Some(raw) => {
                serde_json::from_value(raw.clone()).map_err(GraphError::SerializationError)?
            }
            None => vec![Message::new_human_message("")],
        };

        let generated = self.llm.generate(&messages).await?;
        let reply = Message::new_ai_message(&generated.generation);
        let reply_json = serde_json::to_value(vec![reply])?;
        Ok(HashMap::from([("messages".to_string(), reply_json)]))
    }

    fn get_llm(&self) -> Option<Arc<dyn LLM>> {
        Some(Arc::clone(&self.llm))
    }
}
/// A graph node that asks an [`Agent`] to plan a single step from the state.
pub struct AgentNode {
    /// Agent whose `plan` is called on every run.
    agent: Arc<dyn Agent>,
}

impl AgentNode {
    /// Wraps `agent` in a node.
    pub fn new(agent: Arc<dyn Agent>) -> Self {
        AgentNode { agent }
    }
}
/// Plans one agent step from the state.
///
/// The `"input"` prompt argument comes from the state's `input` field or,
/// failing that, from the `content` of the last element of `messages`. The
/// agent is always called with empty intermediate steps.
///
/// - A finish event yields `{"output": <finish output>}`.
/// - An action event yields `{"action": <first action>}`. Any further actions
///   are not included, and an empty action list yields an empty update.
#[async_trait]
impl<S: State> Node<S> for AgentNode {
    async fn invoke(&self, state: &S) -> Result<StateUpdate, GraphError> {
        let state_json = serde_json::to_value(state).map_err(GraphError::SerializationError)?;

        let prompt_args = {
            let mut args = PromptArgs::new();
            let chosen = state_json.get("input").or_else(|| {
                state_json
                    .get("messages")
                    .and_then(serde_json::Value::as_array)
                    .and_then(|history| history.last())
                    .and_then(|last| last.get("content"))
            });
            if let Some(value) = chosen {
                args.insert("input".to_string(), value.clone());
            }
            args
        };

        let intermediate_steps = vec![];
        let update = match self.agent.plan(&intermediate_steps, prompt_args).await? {
            crate::schemas::agent::AgentEvent::Finish(finish) => HashMap::from([(
                "output".to_string(),
                serde_json::Value::String(finish.output.clone()),
            )]),
            crate::schemas::agent::AgentEvent::Action(actions) => match actions.first() {
                Some(action) => {
                    HashMap::from([("action".to_string(), serde_json::to_value(action)?)])
                }
                None => HashMap::new(),
            },
        };
        Ok(update)
    }
}
/// Builds a state-only [`FunctionNode`]; shorthand for [`FunctionNode::new`]
/// that accepts any `Into<String>` name.
pub fn function_node<S: State, F, Fut>(name: impl Into<String>, func: F) -> FunctionNode<S>
where
    F: Fn(&S) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<StateUpdate, GraphError>> + Send + 'static,
{
    let node_name: String = name.into();
    FunctionNode::new(node_name, func)
}
/// Builds a [`FunctionNode`] whose callback also receives the
/// [`RunnableConfig`]; shorthand for [`FunctionNode::with_config`].
pub fn function_node_with_config<S: State, F, Fut>(
    name: impl Into<String>,
    func: F,
) -> FunctionNode<S>
where
    F: Fn(&S, &RunnableConfig) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<StateUpdate, GraphError>> + Send + 'static,
{
    let node_name: String = name.into();
    FunctionNode::with_config(node_name, func)
}
/// Builds a [`FunctionNode`] whose callback receives both the
/// [`RunnableConfig`] and a [`StoreBox`]; shorthand for
/// [`FunctionNode::with_config_store`].
pub fn function_node_with_store<S: State, F, Fut>(
    name: impl Into<String>,
    func: F,
) -> FunctionNode<S>
where
    F: Fn(&S, &RunnableConfig, StoreBox) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<StateUpdate, GraphError>> + Send + 'static,
{
    let node_name: String = name.into();
    FunctionNode::with_config_store(node_name, func)
}
#[cfg(test)]
mod tests_subgraph;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::state::MessagesState;

    /// A state-only node runs through both entry points and returns its update.
    #[tokio::test]
    async fn test_function_node() {
        let node = function_node("test_node", |_state: &MessagesState| async move {
            let mut update = HashMap::new();
            update.insert(
                "messages".to_string(),
                serde_json::to_value(vec![Message::new_ai_message("Hello from node")])?,
            );
            Ok(update)
        });
        assert_eq!(node.name(), "test_node");

        let state = MessagesState::new();
        let update = node
            .invoke(&state)
            .await
            .expect("state-only node should succeed via invoke");
        assert!(update.contains_key("messages"));

        // A state-only callback needs no context, so missing config/store is fine.
        let update = node
            .invoke_with_context(&state, None, None)
            .await
            .expect("state-only node should succeed via invoke_with_context");
        assert!(update.contains_key("messages"));
    }

    /// A config node cannot run without a config, through either entry point.
    #[tokio::test]
    async fn test_config_node_requires_config() {
        let node = function_node_with_config(
            "config_node",
            |_state: &MessagesState, _config: &RunnableConfig| async move {
                Ok::<StateUpdate, GraphError>(HashMap::new())
            },
        );
        let state = MessagesState::new();
        assert!(matches!(
            node.invoke(&state).await,
            Err(GraphError::ExecutionError(_))
        ));
        assert!(matches!(
            node.invoke_with_context(&state, None, None).await,
            Err(GraphError::ExecutionError(_))
        ));
    }

    /// A config+store node cannot run when the context is missing.
    #[tokio::test]
    async fn test_store_node_requires_config_and_store() {
        let node = function_node_with_store(
            "store_node",
            |_state: &MessagesState, _config: &RunnableConfig, _store: StoreBox| async move {
                Ok::<StateUpdate, GraphError>(HashMap::new())
            },
        );
        let state = MessagesState::new();
        assert!(matches!(
            node.invoke(&state).await,
            Err(GraphError::ExecutionError(_))
        ));
        assert!(matches!(
            node.invoke_with_context(&state, None, None).await,
            Err(GraphError::ExecutionError(_))
        ));
    }
}