Skip to main content

lellm_graph/exec/
execution_loop.rs

1//! Graph 流式执行循环 — Sink 组装层。
2//!
3//! 职责:组装 Sink(Barrier/Checkpoint),调用 `graph.run_inline()`,
4//! 发射 `GraphEvent` 边界事件(GraphStart / GraphComplete / GraphError)。
5//!
6//! 执行逻辑统一由 `Graph::run_inline()` 负责,本模块不再包含执行循环。
7
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use serde::Serialize;
12use tokio_util::sync::CancellationToken;
13
14use crate::checkpoint::{Checkpoint, CheckpointSink, FrameInfo, TraceId};
15use crate::event::{BarrierDecisionMessage, GraphEvent};
16use crate::exec::execution_engine::ExecutionEngine;
17use crate::graph::{Graph, StepCallback};
18use crate::node::barrier_sink::ChannelBarrierSink;
19use crate::state::workflow_state::WorkflowState;
20use crate::state::{ExecutionEntry, GraphResult};
21
22// ─── CheckpointConfig ──────────────────────────────────────────
23
24/// Checkpoint 保存配置 — 传入 `run_execution_loop` 即可启用自动保存。
25pub struct CheckpointConfig<S: WorkflowState> {
26    /// 触发策略
27    pub trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy,
28    /// 保留策略
29    pub retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
30    /// 保存回调
31    save_fn: Arc<crate::checkpoint::checkpoint_policy::CheckpointSaveFn<S>>,
32    /// 图结构指纹
33    graph_hash: u64,
34    /// 存储后端引用(用于 prune)
35    store: Option<Arc<dyn crate::checkpoint::store::BlobCheckpointStore>>,
36}
37
38impl<S: WorkflowState> CheckpointConfig<S> {
39    pub fn new(
40        save_fn: impl Fn(
41            Checkpoint<S>,
42            TraceId,
43        ) -> std::pin::Pin<
44            Box<
45                dyn std::future::Future<
46                        Output = Result<(), crate::checkpoint::CheckpointStoreError>,
47                    > + Send,
48            >,
49        > + Send
50        + Sync
51        + 'static,
52        graph_hash: u64,
53    ) -> Self {
54        Self {
55            save_fn: Arc::new(Box::new(save_fn)),
56            trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy::default(),
57            retention: crate::checkpoint::checkpoint_policy::RetentionPolicy::default(),
58            graph_hash,
59            store: None,
60        }
61    }
62
63    pub fn with_trigger(
64        mut self,
65        trigger: crate::checkpoint::checkpoint_policy::TriggerPolicy,
66    ) -> Self {
67        self.trigger = trigger;
68        self
69    }
70
71    pub fn with_retention(
72        mut self,
73        retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
74    ) -> Self {
75        self.retention = retention;
76        self
77    }
78
79    pub fn with_store(
80        mut self,
81        store: Arc<dyn crate::checkpoint::store::BlobCheckpointStore>,
82    ) -> Self {
83        self.store = Some(store);
84        self
85    }
86
87    #[allow(deprecated)]
88    pub fn with_policy(mut self, policy: crate::checkpoint::CheckpointPolicy) -> Self {
89        self.trigger = policy.into();
90        self
91    }
92
93    pub async fn apply_retention(
94        &self,
95        trace_id: &TraceId,
96    ) -> Result<(), crate::checkpoint::CheckpointStoreError> {
97        if let Some(keep) = self.retention.prune_keep() {
98            if let Some(ref store) = self.store {
99                let pruned = store.prune(trace_id, keep).await?;
100                if pruned > 0 {
101                    tracing::debug!(pruned, keep, "checkpoint pruned");
102                }
103            }
104        }
105        Ok(())
106    }
107}
108
109// ─── CheckpointSaveSink ─────────────────────────────────────────
110
111/// Checkpoint 保存 Sink — 包装 CheckpointConfig 为 CheckpointSink。
112pub struct CheckpointSaveSink<S: WorkflowState> {
113    save_fn: Arc<crate::checkpoint::checkpoint_policy::CheckpointSaveFn<S>>,
114    graph_hash: u64,
115    trace_id: TraceId,
116    retention: crate::checkpoint::checkpoint_policy::RetentionPolicy,
117    store: Option<Arc<dyn crate::checkpoint::store::BlobCheckpointStore>>,
118}
119
120impl<S: WorkflowState> CheckpointSaveSink<S> {
121    pub fn new(config: CheckpointConfig<S>, trace_id: TraceId) -> Self {
122        Self {
123            save_fn: config.save_fn,
124            graph_hash: config.graph_hash,
125            trace_id,
126            retention: config.retention,
127            store: config.store,
128        }
129    }
130}
131
132impl<S: WorkflowState + 'static> CheckpointSink<S> for CheckpointSaveSink<S> {
133    fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
134        let save_fn = self.save_fn.clone();
135        let graph_hash = self.graph_hash;
136        let trace_id = self.trace_id;
137        let retention = self.retention.clone();
138        let store = self.store.clone();
139        let cp = Checkpoint::new(frame.node_id.clone(), state, graph_hash);
140
141        tokio::spawn(async move {
142            match save_fn(cp, trace_id).await {
143                Ok(()) => {
144                    if let Some(keep) = retention.prune_keep() {
145                        if let Some(ref s) = store {
146                            if let Err(e) = s.prune(&trace_id, keep).await {
147                                tracing::warn!(error = %e, "checkpoint retention failed");
148                            }
149                        }
150                    }
151                }
152                Err(e) => {
153                    tracing::warn!(error = %e, "checkpoint save failed");
154                }
155            }
156        });
157    }
158}
159
160// ─── EventStepCallback ──────────────────────────────────────────
161
162/// StepCallback 实现 — 用于 run_execution_loop 追踪执行日志。
163struct EventStepCallback {
164    start_time: Instant,
165    execution_log: Vec<ExecutionEntry>,
166}
167
168impl EventStepCallback {
169    fn new(start_time: Instant) -> Self {
170        Self {
171            start_time,
172            execution_log: Vec::new(),
173        }
174    }
175
176    fn into_log(self) -> Vec<ExecutionEntry> {
177        self.execution_log
178    }
179}
180
181impl StepCallback<'_> for EventStepCallback {
182    fn on_step(&mut self, node_name: &str, step: usize, duration: Duration) {
183        let node_end = self
184            .start_time
185            .checked_add(duration)
186            .unwrap_or(self.start_time);
187        self.execution_log.push(ExecutionEntry {
188            step,
189            node_name: node_name.to_string(),
190            start_time: self.start_time,
191            end_time: node_end,
192            success: true,
193            error: None,
194        });
195    }
196}
197
198// ─── run_execution_loop ─────────────────────────────────────────
199
200/// 运行 Graph 的流式执行循环。
201///
202/// 在 `tokio::spawn` 中调用,通过 channel 发射 `GraphEvent`。
203///
204/// # Sink 组装
205///
206/// ```text
207/// run_execution_loop
208///   ├── ChannelBarrierSink  — Barrier 等待 + 决策注入
209///   ├── CheckpointSaveSink  — Checkpoint 保存(可选)
210///   └── graph.run_inline()  — 唯一执行路径
211/// ```
212pub(crate) async fn run_execution_loop<S, M>(
213    graph: Arc<Graph<S, M>>,
214    state: S,
215    max_steps: usize,
216    trace_id: TraceId,
217    event_tx: tokio::sync::mpsc::Sender<GraphEvent<S>>,
218    decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
219    cancel_rx: tokio::sync::mpsc::Receiver<()>,
220    cancel: CancellationToken,
221    checkpoint: Option<CheckpointConfig<S>>,
222    _trace_sink: Option<crate::checkpoint::trace::MemoryTraceSink<S::Mutation>>,
223    restore_from: Option<Checkpoint<S>>,
224) where
225    S: WorkflowState + Clone + Send + Sync + Serialize + 'static,
226    S::Mutation: Clone + Send + Sync,
227    M: crate::state::workflow_state::MergeStrategy<S>,
228{
229    let start_time = Instant::now();
230
231    // 恢复路径:从 Checkpoint 恢复 State
232    let restore_state = restore_from.as_ref().map(|cp| S::restore(cp.state.clone()));
233    let mut engine_state = restore_state.unwrap_or(state);
234
235    // 组装 Barrier Sink
236    let mut barrier_sink = ChannelBarrierSink::new(decision_rx, cancel_rx, cancel.clone());
237
238    // 组装 Checkpoint Sink
239    let mut cp_sink: Option<CheckpointSaveSink<S>> =
240        checkpoint.map(|cfg| CheckpointSaveSink::new(cfg, trace_id));
241
242    // 发射 GraphStart
243    let _ = event_tx.send(GraphEvent::GraphStart { trace_id }).await;
244
245    // step_cb 在 Engine 外部创建,以便在 Engine drop 后获取 execution_log
246    let mut step_cb = EventStepCallback::new(start_time);
247
248    // 在块作用域中创建 Engine,限制借用生命周期
249    let result = {
250        let mut engine = ExecutionEngine::new(
251            &mut engine_state,
252            None,
253            cancel.clone(),
254            cp_sink.as_mut().map(|s| s as &mut dyn CheckpointSink<S>),
255            Some(&mut barrier_sink),
256        );
257        graph.run_inline(&mut engine, max_steps, &mut step_cb).await
258    };
259
260    // engine 已 drop,可以安全访问 engine_state
261    let final_state = engine_state;
262    let execution_log = step_cb.into_log();
263
264    match result {
265        Ok(()) => {
266            let duration = start_time.elapsed();
267            let graph_result = GraphResult {
268                trace_id,
269                state: final_state,
270                execution_log,
271                duration,
272                trace: None,
273            };
274            let _ = event_tx.try_send(GraphEvent::GraphComplete {
275                result: graph_result,
276            });
277        }
278        Err(error) => {
279            let _ = event_tx
280                .send(GraphEvent::GraphError {
281                    error,
282                    state: final_state,
283                })
284                .await;
285        }
286    }
287}
288
289// ─── send_complete (deprecated) ─────────────────────────────────
290
291/// 发送 GraphComplete 事件。
292///
293/// @deprecated — 由 run_execution_loop 内部处理。
294#[allow(dead_code)]
295pub(crate) fn send_complete<S: WorkflowState>(
296    event_tx: &tokio::sync::mpsc::Sender<GraphEvent<S>>,
297    trace_id: TraceId,
298    final_state: &S,
299    execution_log: Vec<ExecutionEntry>,
300    start_time: Instant,
301    trace_sink: Option<crate::checkpoint::trace::MemoryTraceSink<S::Mutation>>,
302) {
303    let duration = start_time.elapsed();
304    let trace = trace_sink.map(|sink| sink.into_trace());
305    let result = GraphResult {
306        trace_id,
307        state: final_state.clone(),
308        execution_log,
309        duration,
310        trace,
311    };
312    let _ = event_tx.try_send(GraphEvent::GraphComplete { result });
313}