Skip to main content

nuro_graph/
node.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use nuro_core::{Agent, AgentContext, AgentInput, Result};
8
9use crate::GraphStateTrait;
10
11/// 节点执行上下文:提供一个简单的、基于字符串 key 的类型安全存取接口,
12/// 方便在节点之间共享少量辅助数据(如 LLM Provider、计数器等)。
13#[derive(Default)]
14pub struct NodeContext {
15    data: HashMap<String, Box<dyn Any + Send + Sync>>,    
16}
17
18impl NodeContext {
19    pub fn new() -> Self {
20        Self { data: HashMap::new() }
21    }
22
23    /// 按 key 存入一个任意类型的值。
24    pub fn insert<T>(&mut self, key: impl Into<String>, value: T)
25    where
26        T: Send + Sync + 'static,
27    {
28        self.data.insert(key.into(), Box::new(value));
29    }
30
31    /// 按 key 以引用形式取出指定类型的值。
32    pub fn get<T>(&self, key: &str) -> Option<&T>
33    where
34        T: 'static,
35    {
36        self.data.get(key).and_then(|b| b.downcast_ref::<T>())
37    }
38}
39
40/// 图节点抽象:给定当前状态与上下文,返回一个状态增量。
41#[async_trait]
42pub trait GraphNode<S>: Send + Sync
43where
44    S: GraphStateTrait,
45{
46    async fn run(&self, state: &S, ctx: &mut NodeContext) -> Result<S::Update>;
47}
48
49/// 使用闭包实现的节点适配器。
50///
51/// 闭包为同步函数:方便在 demo 与简单业务中快速定义节点逻辑。
52pub struct FnNode<S, F>
53where
54    S: GraphStateTrait,
55    F: Fn(&S, &mut NodeContext) -> S::Update + Send + Sync + 'static,
56{
57    f: F,
58    _marker: PhantomData<S>,
59}
60
61impl<S, F> FnNode<S, F>
62where
63    S: GraphStateTrait,
64    F: Fn(&S, &mut NodeContext) -> S::Update + Send + Sync + 'static,
65{
66    pub fn new(f: F) -> Self {
67        Self { f, _marker: PhantomData }
68    }
69}
70
71#[async_trait]
72impl<S, F> GraphNode<S> for FnNode<S, F>
73where
74    S: GraphStateTrait,
75    F: Fn(&S, &mut NodeContext) -> S::Update + Send + Sync + 'static,
76{
77    async fn run(&self, state: &S, ctx: &mut NodeContext) -> Result<S::Update> {
78        Ok((self.f)(state, ctx))
79    }
80}
81
82/// 使用 `Agent` 适配为图节点的占位实现。
83///
84/// 当前版本中,`AgentNode` 只负责调用底层 Agent,并丢弃结果,返回
85/// `Default::default()` 作为状态增量,以保证编译通过。
86/// 未来版本会扩展为可配置的输入/输出映射逻辑。
87pub struct AgentNode<A, S>
88where
89    A: Agent + 'static,
90    S: GraphStateTrait,
91    S::Update: Default,
92{
93    agent: Arc<A>,
94    _marker: PhantomData<S>,
95}
96
97impl<A, S> AgentNode<A, S>
98where
99    A: Agent + 'static,
100    S: GraphStateTrait,
101    S::Update: Default,
102{
103    pub fn new(agent: A) -> Self {
104        Self {
105            agent: Arc::new(agent),
106            _marker: PhantomData,
107        }
108    }
109}
110
111#[async_trait]
112impl<A, S> GraphNode<S> for AgentNode<A, S>
113where
114    A: Agent + 'static,
115    S: GraphStateTrait,
116    S::Update: Default,
117{
118    async fn run(&self, _state: &S, _ctx: &mut NodeContext) -> Result<S::Update> {
119        // 占位实现:目前不从 Agent 结果构造状态增量,只是走一遍调用,
120        // 以验证协议与类型形状。未来可以在这里接入真正的映射逻辑。
121        let mut ctx = AgentContext::new();
122        let _ = self
123            .agent
124            .invoke(AgentInput::Text("(graph node input)".to_string()), &mut ctx)
125            .await;
126
127        Ok(S::Update::default())
128    }
129}