use crate::error::GraphResult;
use crate::state::{GraphRunContext, GraphState};
use async_trait::async_trait;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
/// Value produced by [`BaseNode::run`].
///
/// The three variants carry, respectively, a boxed node, a node name, or a
/// final value of type `End`. How each variant is interpreted is decided by
/// the code that drives the nodes, which is not part of this file.
pub enum NodeResult<State, Deps, End> {
    /// A boxed node object.
    Next(Box<dyn BaseNode<State, Deps, End>>),
    /// A node referred to by name.
    NextNamed(String),
    /// A final value.
    End(End),
}

impl<State, Deps, End> NodeResult<State, Deps, End> {
    /// Boxes `node` and wraps it in [`NodeResult::Next`].
    pub fn next<N>(node: N) -> Self
    where
        N: BaseNode<State, Deps, End> + 'static,
    {
        let boxed: Box<dyn BaseNode<State, Deps, End>> = Box::new(node);
        NodeResult::Next(boxed)
    }

    /// Converts `name` into a `String` and wraps it in [`NodeResult::NextNamed`].
    pub fn next_named(name: impl Into<String>) -> Self {
        let target: String = name.into();
        NodeResult::NextNamed(target)
    }

    /// Wraps `value` in [`NodeResult::End`].
    pub fn end(value: End) -> Self {
        NodeResult::End(value)
    }
}
/// Tuple-struct wrapper around a single public value.
///
/// `Default` is derived, so `End<T>: Default` holds exactly when
/// `T: Default`, and the default wraps `T::default()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct End<T>(pub T);

impl<T> End<T> {
    /// Wraps `value`.
    pub fn new(value: T) -> Self {
        Self(value)
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Borrows the wrapped value.
    pub fn value(&self) -> &T {
        &self.0
    }
}
/// A node whose `run` method can be invoked through a `Box<dyn BaseNode<..>>`.
///
/// `State` and `Deps` are the type parameters of the `GraphRunContext` handed
/// to [`run`](BaseNode::run); `End` is the payload type of `NodeResult::End`.
/// `Deps` and `End` default to `()`. Implementors must be `Send + Sync`.
#[async_trait]
pub trait BaseNode<State, Deps = (), End = ()>: Send + Sync {
/// Returns the implementing type's name as produced by `std::any::type_name`.
fn type_name(&self) -> &'static str {
std::any::type_name::<Self>()
}
/// Returns the node's name. The default forwards to
/// [`type_name`](BaseNode::type_name); implementors may override it.
fn name(&self) -> &str {
self.type_name()
}
/// Runs the node with mutable access to the run context.
///
/// Returns a `NodeResult` wrapped in `GraphResult`.
async fn run(
&self,
ctx: &mut GraphRunContext<State, Deps>,
) -> GraphResult<NodeResult<State, Deps, End>>;
}
/// A named node that asynchronously turns one `State` value into another.
///
/// `State` must implement `GraphState`; implementors must be `Send + Sync`.
#[async_trait]
pub trait Node<State: GraphState>: Send + Sync {
/// Takes ownership of `state` and returns a `State` wrapped in `GraphResult`.
async fn execute(&self, state: State) -> GraphResult<State>;
/// Returns the node's name.
fn name(&self) -> &str;
}
/// A named wrapper around a function from `State` to a future that yields
/// `GraphResult<State>`.
///
/// `Fut` is not stored in any field; it is tied to the type through the
/// return type of `F` in the `where` clause below.
pub struct FunctionNode<State, F, Fut>
where
    F: Fn(State) -> Fut + Send + Sync,
    Fut: Future<Output = GraphResult<State>> + Send,
{
    /// Name given at construction.
    name: String,
    /// The wrapped function.
    func: F,
    /// Marks the otherwise unstored `State` parameter as used.
    _phantom: PhantomData<State>,
}

impl<State, F, Fut> FunctionNode<State, F, Fut>
where
    F: Fn(State) -> Fut + Send + Sync,
    Fut: Future<Output = GraphResult<State>> + Send,
{
    /// Builds a node from a name (converted into a `String`) and the function
    /// to wrap.
    pub fn new(name: impl Into<String>, func: F) -> Self {
        let name: String = name.into();
        FunctionNode {
            name,
            func,
            _phantom: PhantomData,
        }
    }
}
/// `Node` implementation that delegates to the wrapped function.
#[async_trait]
impl<State, F, Fut> Node<State> for FunctionNode<State, F, Fut>
where
    State: GraphState,
    F: Fn(State) -> Fut + Send + Sync,
    Fut: Future<Output = GraphResult<State>> + Send,
{
    /// Calls the wrapped function with `state` and awaits the future it
    /// returns, passing the result through unchanged.
    async fn execute(&self, state: State) -> GraphResult<State> {
        let pending = (self.func)(state);
        pending.await
    }

    /// Returns the name supplied to `FunctionNode::new`.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Bundles a name, a reference-counted agent, and a callback of the form
/// `Fn(State, &Agent) -> State`.
///
/// In the code visible in this file only [`AgentNode::agent`] reads the stored
/// data; `name` and `update_state` are stored but never read here, which is
/// presumably why `dead_code` is allowed.
#[allow(dead_code)]
pub struct AgentNode<State, Agent, UpdateFn>
where
    UpdateFn: Fn(State, &Agent) -> State + Send + Sync,
{
    /// Name given at construction.
    name: String,
    /// The agent, placed behind an `Arc` by `new`.
    agent: Arc<Agent>,
    /// Callback taking a state and the agent and returning a state.
    update_state: UpdateFn,
    /// Marks the otherwise unstored `State` parameter as used.
    _phantom: PhantomData<State>,
}

impl<State, Agent, UpdateFn> AgentNode<State, Agent, UpdateFn>
where
    UpdateFn: Fn(State, &Agent) -> State + Send + Sync,
{
    /// Builds a node, converting `name` into a `String` and moving `agent`
    /// into a fresh `Arc`.
    pub fn new(name: impl Into<String>, agent: Agent, update_state: UpdateFn) -> Self {
        let name: String = name.into();
        let agent = Arc::new(agent);
        AgentNode {
            name,
            agent,
            update_state,
            _phantom: PhantomData,
        }
    }

    /// Borrows the agent stored in this node.
    pub fn agent(&self) -> &Agent {
        self.agent.as_ref()
    }
}
/// Wraps a function that maps a borrowed `State` to a `String`.
pub struct RouterNode<State, F>
where
    F: Fn(&State) -> String + Send + Sync,
{
    // Stored by `new` but not read by any method in this block.
    #[allow(dead_code)]
    name: String,
    /// The routing function invoked by `route`.
    router: F,
    /// Marks the otherwise unstored `State` parameter as used.
    _phantom: PhantomData<State>,
}

impl<State, F> RouterNode<State, F>
where
    F: Fn(&State) -> String + Send + Sync,
{
    /// Builds a router from a name (converted into a `String`) and the routing
    /// function.
    pub fn new(name: impl Into<String>, router: F) -> Self {
        let name: String = name.into();
        RouterNode {
            name,
            router,
            _phantom: PhantomData,
        }
    }

    /// Applies the routing function to `state` and returns its `String` result.
    pub fn route(&self, state: &State) -> String {
        let decide = &self.router;
        decide(state)
    }
}
/// Holds a predicate over `&State` together with two boxed nodes.
///
/// No method visible in this file evaluates the predicate or runs either
/// node; the fields are only written by `new`, which is presumably why
/// `dead_code` is allowed.
#[allow(dead_code)]
pub struct ConditionalNode<State, Cond, Then, Else>
where
    Cond: Fn(&State) -> bool + Send + Sync,
    Then: BaseNode<State> + 'static,
    Else: BaseNode<State> + 'static,
{
    /// Name given at construction.
    name: String,
    /// Predicate over a borrowed state.
    condition: Cond,
    /// Node passed as `then_node` to `new`.
    then_node: Box<Then>,
    /// Node passed as `else_node` to `new`.
    else_node: Box<Else>,
    /// Marks the otherwise unstored `State` parameter as used.
    _phantom: PhantomData<State>,
}

impl<State, Cond, Then, Else> ConditionalNode<State, Cond, Then, Else>
where
    Cond: Fn(&State) -> bool + Send + Sync,
    Then: BaseNode<State> + 'static,
    Else: BaseNode<State> + 'static,
{
    /// Builds the node, converting `name` into a `String` and boxing both
    /// branch nodes.
    pub fn new(name: impl Into<String>, condition: Cond, then_node: Then, else_node: Else) -> Self {
        let name: String = name.into();
        let then_node = Box::new(then_node);
        let else_node = Box::new(else_node);
        ConditionalNode {
            name,
            condition,
            then_node,
            else_node,
            _phantom: PhantomData,
        }
    }
}
/// Pairs a name with a boxed [`BaseNode`] trait object.
///
/// `Deps` and `End` default to `()`.
pub struct NodeDef<State, Deps = (), End = ()> {
    /// Name associated with the node.
    pub name: String,
    /// The node itself, stored as a trait object.
    pub node: Box<dyn BaseNode<State, Deps, End>>,
}

impl<State, Deps, End> NodeDef<State, Deps, End> {
    /// Builds a definition, converting `name` into a `String` and boxing
    /// `node` as a trait object.
    pub fn new<N>(name: impl Into<String>, node: N) -> Self
    where
        N: BaseNode<State, Deps, End> + 'static,
    {
        let name: String = name.into();
        let node: Box<dyn BaseNode<State, Deps, End>> = Box::new(node);
        NodeDef { name, node }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state type used by the tests below.
    #[derive(Debug, Clone, Default)]
    struct TestState {
        value: i32,
    }

    /// `End` exposes its payload both by reference and by value.
    #[test]
    fn test_end_marker() {
        let end = End::new(42);
        assert_eq!(end.value(), &42);
        assert_eq!(end.into_inner(), 42);
    }

    /// The helper constructors produce the expected variant and payload.
    #[test]
    fn test_node_result_variants() {
        let next_named: NodeResult<TestState, (), i32> = NodeResult::next_named("next");
        // `NodeResult` has no `Debug`/`PartialEq`, so check by pattern.
        assert!(matches!(&next_named, NodeResult::NextNamed(name) if name == "next"));

        let end: NodeResult<TestState, (), i32> = NodeResult::end(42);
        assert!(matches!(end, NodeResult::End(42)));
    }

    /// `RouterNode::route` returns whatever the routing closure returns.
    #[test]
    fn test_router_node() {
        let router = RouterNode::new("router", |state: &TestState| {
            if state.value > 0 {
                "positive".to_string()
            } else {
                "negative".to_string()
            }
        });
        assert_eq!(router.route(&TestState { value: 1 }), "positive");
        assert_eq!(router.route(&TestState { value: -1 }), "negative");
    }
}