1use std::any::Any;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use nuro_core::{Agent, AgentContext, AgentInput, Result};
8
9use crate::GraphStateTrait;
10
/// Type-erased, string-keyed scratch storage handed to graph nodes.
///
/// Any `Send + Sync + 'static` value can be stored and later retrieved by its
/// concrete type. Lookups for a missing key, or with a type that does not match
/// the stored value, return `None` instead of panicking.
#[derive(Default)]
pub struct NodeContext {
    data: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl NodeContext {
    /// Creates an empty context (identical to [`NodeContext::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `key`.
    ///
    /// Any value previously stored under `key` is replaced, regardless of its type.
    pub fn insert<T>(&mut self, key: impl Into<String>, value: T)
    where
        T: Send + Sync + 'static,
    {
        self.data.insert(key.into(), Box::new(value));
    }

    /// Returns a shared reference to the value stored under `key`.
    ///
    /// Returns `None` if the key is absent or the stored value is not a `T`.
    pub fn get<T>(&self, key: &str) -> Option<&T>
    where
        T: 'static,
    {
        self.data.get(key).and_then(|b| b.downcast_ref::<T>())
    }

    /// Returns a mutable reference to the value stored under `key`.
    ///
    /// This lets a node update a value in place instead of re-inserting it.
    /// Returns `None` if the key is absent or the stored value is not a `T`.
    pub fn get_mut<T>(&mut self, key: &str) -> Option<&mut T>
    where
        T: 'static,
    {
        self.data.get_mut(key).and_then(|b| b.downcast_mut::<T>())
    }

    /// Returns `true` if a value of any type is stored under `key`.
    pub fn contains_key(&self, key: &str) -> bool {
        self.data.contains_key(key)
    }
}
39
40#[async_trait]
42pub trait GraphNode<S>: Send + Sync
43where
44 S: GraphStateTrait,
45{
46 async fn run(&self, state: &S, ctx: &mut NodeContext) -> Result<S::Update>;
47}
48
49pub struct FnNode<S, F>
53where
54 S: GraphStateTrait,
55 F: Fn(&S, &mut NodeContext) -> S::Update + Send + Sync + 'static,
56{
57 f: F,
58 _marker: PhantomData<S>,
59}
60
61impl<S, F> FnNode<S, F>
62where
63 S: GraphStateTrait,
64 F: Fn(&S, &mut NodeContext) -> S::Update + Send + Sync + 'static,
65{
66 pub fn new(f: F) -> Self {
67 Self { f, _marker: PhantomData }
68 }
69}
70
71#[async_trait]
72impl<S, F> GraphNode<S> for FnNode<S, F>
73where
74 S: GraphStateTrait,
75 F: Fn(&S, &mut NodeContext) -> S::Update + Send + Sync + 'static,
76{
77 async fn run(&self, state: &S, ctx: &mut NodeContext) -> Result<S::Update> {
78 Ok((self.f)(state, ctx))
79 }
80}
81
82pub struct AgentNode<A, S>
88where
89 A: Agent + 'static,
90 S: GraphStateTrait,
91 S::Update: Default,
92{
93 agent: Arc<A>,
94 _marker: PhantomData<S>,
95}
96
97impl<A, S> AgentNode<A, S>
98where
99 A: Agent + 'static,
100 S: GraphStateTrait,
101 S::Update: Default,
102{
103 pub fn new(agent: A) -> Self {
104 Self {
105 agent: Arc::new(agent),
106 _marker: PhantomData,
107 }
108 }
109}
110
111#[async_trait]
112impl<A, S> GraphNode<S> for AgentNode<A, S>
113where
114 A: Agent + 'static,
115 S: GraphStateTrait,
116 S::Update: Default,
117{
118 async fn run(&self, _state: &S, _ctx: &mut NodeContext) -> Result<S::Update> {
119 let mut ctx = AgentContext::new();
122 let _ = self
123 .agent
124 .invoke(AgentInput::Text("(graph node input)".to_string()), &mut ctx)
125 .await;
126
127 Ok(S::Update::default())
128 }
129}