Skip to main content

lellm_graph/
test_executor.rs

1//! 测试用执行器 — 替代已删除的 SimpleExecutor。
2//!
3//! 提供两种执行模式:
4//! - `execute()` — 阻塞执行,返回 `GraphResult`
5//! - `execute_stream()` — 流式执行,返回 `GraphExecution { stream, handle }`
6
7use std::sync::Arc;
8use std::time::Instant;
9
10use tokio_util::sync::CancellationToken;
11
12use crate::error::GraphError;
13use crate::event::{GraphExecution, GraphHandle};
14use crate::exec::execution_engine::{ExecutionEngine, ExecutorState, NextAction};
15use crate::graph::Graph;
16use crate::ids::TraceId;
17use crate::node::{BarrierNode, ConditionNode, FlowNode, LeafNode, NodeKind};
18use crate::state::{ExecutionEntry, GraphResult, State};
19
20// ─── SimpleExecutor 兼容层 ────────────────────────────────────────
21
/// Compatibility layer exposing the `SimpleExecutor` API for use in tests.
///
/// Only supports `Graph<State, StateMerge>` (the default generic parameters).
#[derive(Debug, Clone)]
pub struct SimpleExecutor {
    /// Upper bound on the number of steps a single run may take before it is
    /// aborted.
    max_steps: usize,
}
28
29impl Default for SimpleExecutor {
30    fn default() -> Self {
31        Self { max_steps: 100 }
32    }
33}
34
35impl SimpleExecutor {
36    pub fn new(max_steps: usize) -> Self {
37        Self { max_steps }
38    }
39
40    pub async fn execute(
41        &self,
42        graph: Arc<Graph>,
43        mut state: State,
44    ) -> Result<GraphResult, GraphError> {
45        let trace_id = TraceId::new();
46        let start_time = Instant::now();
47        let mut execution_log: Vec<ExecutionEntry> = Vec::new();
48
49        let cancel = CancellationToken::new();
50        // TestExecutor 不需要自动 checkpoint
51        let mut engine = ExecutionEngine::new(&mut state, None, cancel, None, None);
52
53        // 执行循环 — 与 run_inline 一致,但记录 ExecutionEntry
54        let mut current = graph.start_node().to_string();
55        let mut step: usize = 0;
56
57        loop {
58            step += 1;
59            if step > self.max_steps {
60                return Err(GraphError::Terminal(
61                    crate::error::TerminalError::StepsExceeded {
62                        limit: self.max_steps,
63                    },
64                ));
65            }
66
67            let node = match graph.nodes.get(&current) {
68                Some(n) => n,
69                None => {
70                    return Err(GraphError::Terminal(
71                        crate::error::TerminalError::NodeNotFound(current.clone()),
72                    ));
73                }
74            };
75
76            let node_name = current.clone();
77            let node_start = Instant::now();
78
79            // 根据 NodeKind 分发执行
80            match node {
81                NodeKind::Task(n) => {
82                    let mut ctx = engine.build_node_context();
83                    n.execute(&mut ctx).await?;
84                }
85                NodeKind::Condition(n) => {
86                    let mut ctx = engine.build_leaf_context();
87                    <ConditionNode as LeafNode>::execute(n, &mut ctx).await?;
88                }
89                NodeKind::Barrier(n) => {
90                    let mut ctx = engine.build_leaf_context();
91                    <BarrierNode as LeafNode>::execute(n, &mut ctx).await?;
92                }
93                NodeKind::External(n) => {
94                    let mut ctx = engine.build_node_context();
95                    n.execute(&mut ctx).await?;
96                }
97                NodeKind::ExternalLeaf(n) => {
98                    let mut ctx = engine.build_leaf_context();
99                    n.execute(&mut ctx).await?;
100                }
101                NodeKind::Parallel(p) => {
102                    // ExecutorOperation 直接接收 &mut ExecutionEngine
103                    p.execute(&mut engine).await?;
104                }
105                NodeKind::Subgraph(_subgraph) => {
106                    // TODO: 实现 Subgraph 执行
107                    // 由 ExecutionEngine 负责 Frame 管理、状态投影、Checkpoint 和恢复
108                    tracing::warn!("Subgraph execution not yet implemented");
109                }
110            }
111
112            let node_duration = node_start.elapsed();
113
114            execution_log.push(ExecutionEntry {
115                step,
116                node_name,
117                start_time: node_start,
118                end_time: start_time.checked_add(node_duration).unwrap_or(start_time),
119                success: true,
120                error: None,
121            });
122
123            // commit mutations (Unit of Work) — 对 Parallel 是空操作
124            // (replace_state 已经直接替换了状态,mutation buffer 为空)
125            engine.commit();
126
127            // 提取控制信号
128            let (next_action, _signal) = engine.take_control();
129
130            // 处理路由
131            match next_action {
132                NextAction::End => break,
133                NextAction::Goto(target) => {
134                    current = target;
135                }
136                NextAction::Next => {
137                    if current == graph.end_node() {
138                        break;
139                    }
140                    current = graph.resolve_next_inline(&current, engine.state())?;
141                }
142            }
143        }
144
145        let duration = start_time.elapsed();
146        let final_state = state;
147
148        Ok(GraphResult {
149            trace_id,
150            state: final_state,
151            execution_log,
152            duration,
153            trace: None,
154        })
155    }
156
157    pub fn execute_stream(&self, graph: Arc<Graph>, state: State) -> GraphExecution<State> {
158        self.execute_stream_with_restore(graph, state, None)
159    }
160
161    pub fn execute_stream_with_restore(
162        &self,
163        graph: Arc<Graph>,
164        state: State,
165        restore_from: Option<crate::checkpoint::Checkpoint<State>>,
166    ) -> GraphExecution<State> {
167        let (event_tx, event_rx) = tokio::sync::mpsc::channel(256);
168        let (decision_tx, decision_rx) = tokio::sync::mpsc::channel(256);
169        let (cancel_tx, cancel_rx) = tokio::sync::mpsc::channel(1);
170
171        let trace_id = TraceId::new();
172        let cancel = CancellationToken::new();
173
174        let handle = GraphHandle::new(decision_tx, cancel_tx);
175
176        tokio::spawn(crate::exec::execution_loop::run_execution_loop(
177            graph,
178            state,
179            self.max_steps,
180            trace_id,
181            event_tx,
182            decision_rx,
183            cancel_rx,
184            cancel,
185            None, // checkpoint
186            None, // trace_sink
187            restore_from,
188        ));
189
190        GraphExecution {
191            stream: event_rx,
192            handle,
193        }
194    }
195}