Skip to main content

lellm_graph/node/
subgraph_spec.rs

1//! SubgraphSpec — Builder 阶段的强类型 Subgraph 描述。
2//!
3//! # 设计理念
4//!
5//! ```text
6//! Builder 阶段:
7//!   SubgraphSpec<Outer, Inner, M, Lens>  (强类型)
8//!
9//! 编译阶段:
10//!   CompiledSubgraph<Outer>  (类型擦除 Inner/Lens/M)
11//!
12//! Engine 执行:
13//!   match node.kind {
14//!       NodeKind::Subgraph(spec) => self.execute_subgraph(spec).await,
15//!   }
16//! ```
17//!
18//! # 与 CompiledSubgraph 的区别
19//!
20//! - SubgraphSpec:Builder 阶段,强类型,包含 Graph + Lens
21//! - CompiledSubgraph:编译后,类型擦除,可存入 NodeKind
22//! - SubgraphSpec 实现 `StateProjector` trait,可转换为 CompiledSubgraph
23//!
24//! # 状态投影
25//!
26//! 通过 `StateLens` 从外层 State 投影出内层 State:
27//!
28//! ```text
29//! WorkflowState
30//!     ↓ StateLens
31//! &mut AgentState
32//!     ↓
33//! Agent Graph 操作
34//!     ↓ 借用结束
35//! WorkflowState 继续
36//! ```
37
38use std::future::Future;
39use std::marker::PhantomData;
40use std::pin::Pin;
41use std::sync::Arc;
42
43use super::compiled_subgraph::{CompiledSubgraph, StateProjector};
44use crate::error::GraphError;
45use crate::graph::Graph;
46use crate::state::state_lens::StateLens;
47use crate::state::workflow_state::{MergeStrategy, WorkflowState};
48use crate::stream_emitter::StreamSink;
49use tokio_util::sync::CancellationToken;
50
51/// Subgraph Builder 描述 — 强类型,包含 Graph + Lens。
52///
53/// # 泛型参数
54///
55/// - `Outer` — 外层 State 类型(如 WorkflowState)
56/// - `Inner` — 内层 State 类型(如 AgentState)
57/// - `M` — MergeStrategy 实现(用于 Graph)
58/// - `L` — StateLens 实现,用于状态投影
59///
60/// # 使用方式
61///
62/// ```ignore
63/// let spec = SubgraphSpec::new(agent_graph, AgentLens);
64/// let compiled: CompiledSubgraph<WorkflowState> = spec.compile();
65/// ```
66pub struct SubgraphSpec<
67    Outer: WorkflowState,
68    Inner: WorkflowState,
69    M: MergeStrategy<Inner>,
70    L: StateLens<Outer, Inner>,
71> {
72    /// 内层 Graph — Arc 共享,与 AgentBuilder::build() 返回类型一致(D10)。
73    pub graph: Arc<Graph<Inner, M>>,
74
75    /// 状态投影器
76    pub lens: L,
77
78    /// 最大执行步数
79    pub max_steps: usize,
80
81    /// PhantomData
82    _phantom: PhantomData<Outer>,
83}
84
85impl<
86    Outer: WorkflowState,
87    Inner: WorkflowState,
88    M: MergeStrategy<Inner>,
89    L: StateLens<Outer, Inner>,
90> SubgraphSpec<Outer, Inner, M, L>
91where
92    Outer: 'static,
93    Inner: 'static,
94    M: 'static,
95    L: 'static,
96{
97    /// 创建新的 SubgraphSpec。
98    ///
99    /// # 参数
100    ///
101    /// - `graph` — 内层 Graph(Arc 共享,与 AgentBuilder::build() 返回类型一致)
102    /// - `lens` — 状态投影器
103    ///
104    /// # 示例
105    ///
106    /// ```ignore
107    /// let agent_graph = AgentBuilder::new(model).tools([...]).build();
108    /// let spec = SubgraphSpec::new(agent_graph, AgentLens);
109    /// // agent_graph 仍然是 Arc<Graph<...>>,可直接传入,无需 clone
110    /// ```
111    pub fn new(graph: Arc<Graph<Inner, M>>, lens: L) -> Self {
112        Self {
113            graph,
114            lens,
115            max_steps: 1000, // 默认最大步数
116            _phantom: PhantomData,
117        }
118    }
119
120    /// 设置最大执行步数。
121    pub fn max_steps(mut self, max: usize) -> Self {
122        self.max_steps = max;
123        self
124    }
125
126    /// 通过 Lens 投影状态。
127    ///
128    /// 从外层 State 投影出内层 State 的可变引用。
129    pub fn project<'a>(&self, outer: &'a mut Outer) -> &'a mut Inner {
130        self.lens.get(outer)
131    }
132
133    /// 编译为 CompiledSubgraph — 类型擦除 Inner/Lens/M。
134    pub fn compile(self) -> CompiledSubgraph<Outer> {
135        let max_steps = self.max_steps;
136        CompiledSubgraph::new(Arc::new(self), max_steps)
137    }
138}
139
140// ─── StateProjector 实现 ──────────────────────────────────────
141
142impl<
143    Outer: WorkflowState,
144    Inner: WorkflowState,
145    M: MergeStrategy<Inner>,
146    L: StateLens<Outer, Inner>,
147> StateProjector<Outer> for SubgraphSpec<Outer, Inner, M, L>
148where
149    Inner: 'static,
150    M: 'static,
151    L: 'static,
152{
153    /// 执行 Subgraph — 投影状态 + 递归执行内层 Graph。
154    ///
155    /// # 执行流程
156    ///
157    /// 1. 通过 Lens 投影出内层 State(`&mut Inner`)
158    /// 2. 创建内层 ExecutionEngine(借用 `&mut Inner`)
159    /// 3. 调用 `graph.run_inline()`
160    /// 4. inner_engine drop → 借用释放 → outer 可继续使用
161    fn execute<'a>(
162        &'a self,
163        outer: &'a mut Outer,
164        stream: Option<Arc<dyn StreamSink>>,
165        cancel: CancellationToken,
166    ) -> Pin<Box<dyn Future<Output = Result<(), GraphError>> + Send + 'a>> {
167        Box::pin(async move {
168            // 1. 通过 Lens 投影出内层 State
169            let inner_ref = self.lens.get(outer);
170
171            // 2. 创建内层 ExecutionEngine(借用 inner_ref)
172            // Subgraph 内部不需要自动 checkpoint/barrier,传 None
173            let mut inner_engine =
174                crate::ExecutionEngine::new(inner_ref, stream, cancel, None, None);
175
176            // 3. 执行内层 Graph(Subgraph 内部不需要 step 回调)
177            let mut cb = crate::graph::NoopStepCallback;
178            self.graph
179                .run_inline(&mut inner_engine, self.max_steps, &mut cb)
180                .await?;
181
182            // 4. inner_engine drop → 借用释放 → outer 可继续使用
183            Ok(())
184        })
185    }
186
187    fn graph_name(&self) -> &str {
188        self.graph.name()
189    }
190
191    fn node_count(&self) -> usize {
192        self.graph.node_names().len()
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq)]
    struct OuterState {
        inner: InnerState,
    }

    #[derive(Debug, PartialEq)]
    struct InnerState {
        value: i32,
    }

    /// Lens projecting `OuterState` onto its `inner` field.
    struct TestLens;

    impl StateLens<OuterState, InnerState> for TestLens {
        fn get<'a>(&self, outer: &'a mut OuterState) -> &'a mut InnerState {
            &mut outer.inner
        }
    }

    /// Writes through the projected `&mut InnerState` must be visible on the
    /// outer state once the projection borrow ends.
    #[test]
    fn test_subgraph_spec_projection() {
        let mut outer = OuterState {
            inner: InnerState { value: 42 },
        };

        let lens = TestLens;
        let inner = lens.get(&mut outer);
        inner.value = 100;

        assert_eq!(
            outer,
            OuterState {
                inner: InnerState { value: 100 }
            }
        );
    }

    /// The lens can be reapplied after the previous borrow has ended, and each
    /// projection sees the result of the previous write.
    #[test]
    fn test_lens_projection_is_repeatable() {
        let mut outer = OuterState {
            inner: InnerState { value: 1 },
        };
        let lens = TestLens;

        lens.get(&mut outer).value += 1;
        lens.get(&mut outer).value *= 10;

        assert_eq!(outer.inner.value, 20);
    }
}