use async_trait::async_trait;
use super::state::{StateSchema, StateUpdate};
use super::errors::GraphError;
use std::future::Future;
use std::pin::Pin;
use std::marker::PhantomData;
/// A single executable step in a state graph.
///
/// Implementors receive a shared reference to the current state `S` plus an
/// optional per-invocation [`NodeConfig`], and produce either a
/// [`StateUpdate`] or a [`GraphError`]. The `Send + Sync` supertraits allow a
/// node to be shared across threads/tasks.
#[async_trait]
pub trait GraphNode<S: StateSchema>: Send + Sync {
    /// Identifier of this node.
    fn name(&self) -> &str;

    /// Runs the node against `state`.
    ///
    /// # Errors
    ///
    /// Returns a [`GraphError`] when the node fails.
    async fn execute(
        &self,
        state: &S,
        config: Option<NodeConfig>,
    ) -> Result<StateUpdate<S>, GraphError>;
}
/// Per-invocation options that may be handed to [`GraphNode::execute`].
///
/// The derived `Default` yields `recursion_limit = 0`, an empty `metadata`
/// map and `debug = false`.
// NOTE(review): a derived default of 0 for `recursion_limit` may be surprising
// if the graph runner treats it as a hard step cap — confirm this is intended.
#[derive(Clone, Debug, Default)]
pub struct NodeConfig {
    /// Recursion/step limit; presumably enforced by the graph runner (it is
    /// not read anywhere in this file) — TODO confirm semantics.
    pub recursion_limit: usize,
    /// Free-form JSON metadata keyed by string.
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
    /// Debug flag; its effect is decided by consumers outside this file.
    pub debug: bool,
}
/// Outcome of running a node: a state update on success, a [`GraphError`] otherwise.
pub type NodeResult<S> = Result<StateUpdate<S>, GraphError>;

/// Type-erased, heap-allocated async node function.
///
/// It borrows the state and returns a boxed `Send` future that resolves to a
/// [`NodeResult`]. The boxed future has the default `'static` object
/// lifetime of `Box<dyn …>`, so it cannot borrow from the state argument.
pub type AsyncNodeFn<S> = Box<
    dyn Fn(&S) -> Pin<Box<dyn Future<Output = Result<StateUpdate<S>, GraphError>> + Send>>
        + Send
        + Sync,
>;
/// An async callable that turns a borrowed state into a boxed, `Send` future.
///
/// Any `Fn(&S) -> Fut` that is `Send + Sync + 'static` and whose future is
/// `Send + 'static` implements this trait through the blanket impl in this
/// block. Because `Fut` must be `'static`, the future cannot borrow `state`;
/// clone whatever it needs before entering the `async` block.
pub trait AsyncFn<S: StateSchema>: Send + Sync {
    /// Invokes the callable and boxes the future it returns.
    fn call(&self, state: &S) -> Pin<Box<dyn Future<Output = NodeResult<S>> + Send>>;
}

impl<S, Func, Fut> AsyncFn<S> for Func
where
    S: StateSchema,
    Func: Fn(&S) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = NodeResult<S>> + Send + 'static,
{
    fn call(&self, state: &S) -> Pin<Box<dyn Future<Output = NodeResult<S>> + Send>> {
        let pending = (self)(state);
        Box::pin(pending)
    }
}
/// A graph node backed by an [`AsyncFn`].
///
/// `execute` forwards the state to the wrapped function and awaits the future
/// it returns. The per-call [`NodeConfig`] is currently ignored.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named without repeating them.
pub struct AsyncNode<S, F> {
    name: String,
    func: F,
    // No `S` is ever stored. `fn() -> S` keeps `S` in the type (covariantly,
    // like `PhantomData<S>`) without making the node's `Send`/`Sync` or
    // drop-check depend on `S`.
    _marker: PhantomData<fn() -> S>,
}

impl<S: StateSchema, F: AsyncFn<S>> AsyncNode<S, F> {
    /// Creates a node named `name` that runs `func`.
    pub fn new(name: impl Into<String>, func: F) -> Self {
        Self {
            name: name.into(),
            func,
            _marker: PhantomData,
        }
    }
}

#[async_trait]
impl<S: StateSchema, F: AsyncFn<S>> GraphNode<S> for AsyncNode<S, F> {
    async fn execute(&self, state: &S, _config: Option<NodeConfig>) -> NodeResult<S> {
        self.func.call(state).await
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// A graph node backed by a closure that already returns a boxed future.
///
/// The closure returns `Pin<Box<dyn Future<…> + Send>>`. Since a `Box<dyn …>`
/// defaults to the `'static` object lifetime, that future cannot borrow the
/// state and must own whatever it needs. The per-call [`NodeConfig`] is ignored.
///
/// Trait bounds live on the `impl` blocks rather than on the struct.
pub struct FunctionNode<S, F> {
    name: String,
    func: F,
    // No `S` is ever stored. `fn() -> S` keeps `S` in the type (covariantly)
    // without coupling the node's auto traits or drop-check to `S`.
    _marker: PhantomData<fn() -> S>,
}

impl<S: StateSchema, F> FunctionNode<S, F>
where
    F: Fn(&S) -> Pin<Box<dyn Future<Output = Result<StateUpdate<S>, GraphError>> + Send>> + Send + Sync,
{
    /// Creates a node named `name` that runs `func`.
    pub fn new(name: impl Into<String>, func: F) -> Self {
        Self {
            name: name.into(),
            func,
            _marker: PhantomData,
        }
    }
}

#[async_trait]
impl<S: StateSchema, F> GraphNode<S> for FunctionNode<S, F>
where
    F: Fn(&S) -> Pin<Box<dyn Future<Output = Result<StateUpdate<S>, GraphError>> + Send>> + Send + Sync,
{
    async fn execute(
        &self,
        state: &S,
        _config: Option<NodeConfig>,
    ) -> Result<StateUpdate<S>, GraphError> {
        (self.func)(state).await
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// A graph node backed by a plain synchronous function.
///
/// `func` is called directly inside `execute`, on whichever task polls the
/// returned future. Long-running or blocking work therefore stalls that
/// executor thread. The per-call [`NodeConfig`] is ignored.
///
/// Trait bounds live on the `impl` blocks rather than on the struct.
pub struct SyncNode<S, F> {
    name: String,
    func: F,
    // No `S` is ever stored. `fn() -> S` keeps `S` in the type (covariantly)
    // without coupling the node's auto traits or drop-check to `S`.
    _marker: PhantomData<fn() -> S>,
}

impl<S: StateSchema, F> SyncNode<S, F>
where
    F: Fn(&S) -> Result<StateUpdate<S>, GraphError> + Send + Sync,
{
    /// Creates a node named `name` that runs `func`.
    pub fn new(name: impl Into<String>, func: F) -> Self {
        Self {
            name: name.into(),
            func,
            _marker: PhantomData,
        }
    }
}

#[async_trait]
impl<S: StateSchema, F> GraphNode<S> for SyncNode<S, F>
where
    F: Fn(&S) -> Result<StateUpdate<S>, GraphError> + Send + Sync,
{
    async fn execute(
        &self,
        state: &S,
        _config: Option<NodeConfig>,
    ) -> Result<StateUpdate<S>, GraphError> {
        (self.func)(state)
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// A no-op marker node, e.g. the graph's `START`/`END` boundaries.
///
/// Executing it never inspects the state: it always returns
/// `StateUpdate::unchanged()`.
#[derive(Debug, Clone)]
pub struct SentinelNode {
    name: String,
}

impl SentinelNode {
    /// Sentinel named after the parent module's `START` constant.
    pub fn start() -> Self {
        Self { name: super::START.to_string() }
    }

    /// Sentinel named after the parent module's `END` constant.
    pub fn end() -> Self {
        Self { name: super::END.to_string() }
    }

    /// Sentinel with an arbitrary name.
    pub fn custom(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}

#[async_trait]
impl<S: StateSchema> GraphNode<S> for SentinelNode {
    async fn execute(
        &self,
        // Deliberately unused: sentinels never look at the state.
        _state: &S,
        _config: Option<NodeConfig>,
    ) -> Result<StateUpdate<S>, GraphError> {
        Ok(StateUpdate::unchanged())
    }

    fn name(&self) -> &str {
        &self.name
    }
}
#[cfg(test)]
mod tests {
    use super::super::state::AgentState;
    use super::*;

    /// A `SyncNode` wrapping an infallible closure should execute successfully.
    #[tokio::test]
    async fn test_sync_node() {
        let initial = AgentState::new(String::from("Hello"));
        let echo = SyncNode::new("test", |s: &AgentState| {
            let fresh = AgentState::new(s.input.clone());
            Ok(StateUpdate::full(fresh))
        });
        let outcome = echo.execute(&initial, None).await;
        assert!(matches!(outcome, Ok(_)));
    }

    /// Sentinel constructors must pick up the parent module's START / END names.
    #[test]
    fn test_sentinel_nodes() {
        let start_node: SentinelNode = SentinelNode::start();
        let end_node: SentinelNode = SentinelNode::end();
        assert_eq!(GraphNode::<AgentState>::name(&start_node), super::super::START);
        assert_eq!(GraphNode::<AgentState>::name(&end_node), super::super::END);
    }
}