use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::NodeError;
use crate::graph::{NodeExecutor, NodeOutput};
use crate::state::SharedState;
/// Type-erased, reference-counted async body of a [`FunctionNode`].
///
/// Each invocation receives the [`SharedState`] handle and yields a pinned, boxed `Send`
/// future that resolves to the node's [`NodeOutput`] or a [`NodeError`]. The callable is
/// `Sync + Send` behind an `Arc`, so a single body can be shared and called repeatedly.
pub type AsyncNodeFn = Arc<
    dyn Fn(SharedState) -> Pin<Box<dyn Future<Output = Result<NodeOutput, NodeError>> + Send>>
        + Sync
        + Send,
>;
/// A graph node whose behavior is supplied by an async closure.
///
/// Build one with [`FunctionNode::new`] or a convenience constructor, optionally attach a
/// description with [`FunctionNode::with_description`], and run it through the
/// [`NodeExecutor`] trait.
pub struct FunctionNode {
    /// Identifier returned by `NodeExecutor::id`.
    id: String,
    /// Type-erased async body invoked on every execution.
    func: AsyncNodeFn,
    /// Optional text returned by `NodeExecutor::description`.
    description: Option<String>,
}

// `AsyncNodeFn` is a closure trait object without a `Debug` impl, so `Debug` cannot be
// derived. Show the inspectable fields and mark the output as non-exhaustive for the
// elided closure.
impl std::fmt::Debug for FunctionNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FunctionNode")
            .field("id", &self.id)
            .field("description", &self.description)
            .finish_non_exhaustive()
    }
}
impl FunctionNode {
    /// Creates a node that runs the async closure `func` every time it is executed.
    ///
    /// The closure receives the [`SharedState`] handle and returns a future resolving to the
    /// node's [`NodeOutput`] or a [`NodeError`]. It is type-erased into an [`AsyncNodeFn`],
    /// so nodes built from different closures share one concrete type. The new node has no
    /// description.
    ///
    /// `func` is bounded by `Fn` rather than `FnOnce` because a node may be executed more
    /// than once.
    #[must_use]
    pub fn new<F, Fut>(id: impl Into<String>, func: F) -> Self
    where
        F: Fn(SharedState) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<NodeOutput, NodeError>> + Send + 'static,
    {
        Self {
            id: id.into(),
            // `Box::pin` erases the concrete `Fut` so the closure matches `AsyncNodeFn`.
            func: Arc::new(move |state| Box::pin(func(state))),
            description: None,
        }
    }

    /// Attaches a human-readable description, replacing any previous one.
    ///
    /// Consumes and returns `self` so it can be chained after a constructor.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Creates a node that does no work and always returns [`NodeOutput::cont`].
    ///
    /// The shared state is neither read nor written.
    #[must_use]
    pub fn passthrough(id: impl Into<String>) -> Self {
        Self::new(id, |_state| async move { Ok(NodeOutput::cont()) })
    }

    /// Creates a node that does no work and always returns [`NodeOutput::finish`].
    ///
    /// The shared state is neither read nor written. NOTE(review): presumably this marks the
    /// run as terminal; confirm against `NodeOutput::finish`.
    #[must_use]
    pub fn finish(id: impl Into<String>) -> Self {
        Self::new(id, |_state| async move { Ok(NodeOutput::finish()) })
    }

    /// Creates a node that does no work and always returns
    /// [`NodeOutput::continue_to`] with `target`.
    ///
    /// The shared state is neither read nor written.
    #[must_use]
    pub fn route_to(id: impl Into<String>, target: impl Into<String>) -> Self {
        let target = target.into();
        Self::new(id, move |_state| {
            // The closure is `Fn` and may run many times, so each call gives its future an
            // owned copy of the target instead of moving the captured value out.
            let target = target.clone();
            async move { Ok(NodeOutput::continue_to(target)) }
        })
    }
}
#[async_trait]
impl NodeExecutor for FunctionNode {
    /// Returns the identifier this node was constructed with.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Invokes the wrapped async closure with `state` and awaits its result.
    ///
    /// Whatever the closure resolves to, `Ok` or `Err`, is returned unchanged.
    async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
        let body = &self.func;
        let pending = body(state);
        pending.await
    }

    /// Returns the description attached with `FunctionNode::with_description`, if any.
    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::AgentState;
    use std::sync::RwLock;

    /// Builds a fresh, empty shared state for a single test.
    fn fresh_state() -> SharedState {
        Arc::new(RwLock::new(AgentState::new()))
    }

    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |state| async move {
            let mut guard = state
                .write()
                .map_err(|e| NodeError::execution_failed(e.to_string()))?;
            guard.set_context("executed", true);
            Ok(NodeOutput::cont())
        });
        let state = fresh_state();
        let result = node.execute(state.clone()).await.unwrap();
        assert!(!result.is_terminal());
        let guard = state.read().unwrap();
        assert_eq!(guard.get_context::<bool>("executed"), Some(true));
    }

    #[tokio::test]
    async fn test_function_node_propagates_error() {
        let node = FunctionNode::new("fail", |_state| async move {
            Err(NodeError::execution_failed("boom".to_string()))
        });
        assert!(node.execute(fresh_state()).await.is_err());
    }

    #[tokio::test]
    async fn test_passthrough_node() {
        let node = FunctionNode::passthrough("pass");
        let result = node.execute(fresh_state()).await.unwrap();
        assert!(!result.is_terminal());
        assert!(result.target().is_none());
    }

    #[tokio::test]
    async fn test_finish_node() {
        let node = FunctionNode::finish("done");
        let result = node.execute(fresh_state()).await.unwrap();
        assert!(result.is_terminal());
    }

    #[tokio::test]
    async fn test_route_to_node_is_reusable() {
        let node = FunctionNode::route_to("router", "next");
        let state = fresh_state();
        // The body clones its captured target per call, so every execution must still route.
        for _ in 0..2 {
            let result = node.execute(state.clone()).await.unwrap();
            assert!(result.target().is_some());
        }
    }

    #[test]
    fn test_id_and_description() {
        let node = FunctionNode::passthrough("pass").with_description("does nothing");
        assert_eq!(node.id(), "pass");
        assert_eq!(node.description(), Some("does nothing"));
        assert_eq!(FunctionNode::passthrough("bare").description(), None);
    }
}