Skip to main content

lellm_graph/
execution_loop.rs

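//! Streaming graph execution loop — the core logic behind `SimpleExecutor::execute_stream()`.
//!
//! Contains:
//! - the execution loop (node scheduling, routing, mutation consumption)
//! - emission of the `GraphComplete` / `GraphError` events
//! - the checkpoint save path (triggered according to the configured `TriggerPolicy`)
//!
//! Barrier waiting and decision application live in the [`crate::barrier_wait`] module.
//! Resuming is supported by passing a `Checkpoint` as `restore_from` (a Barrier at the
//! resume point is re-waited); loading checkpoints from a store is not handled here.
//!
//! v0.4+: `run_execution_loop<S, M>` is generic and supports any `WorkflowState`.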
12
13use std::sync::Arc;
14use std::time::Instant;
15
16use serde::Serialize;
17use tokio_util::sync::CancellationToken;
18
19use crate::barrier_wait::{
20    BarrierOutcome, apply_barrier_decision_generic, wait_for_barrier_decision,
21};
22use crate::checkpoint::{Checkpoint, NodeId, TraceId};
23use crate::checkpoint_policy::{RetentionPolicy, TriggerPolicy};
24use crate::error::GraphError;
25use crate::event::{BarrierDecision, BarrierDecisionMessage, GraphEvent};
26use crate::execution_engine::{ExecutionEngine, ExecutionSignal, ExecutorState, NextAction};
27use crate::graph::Graph;
28use crate::ids::SpanId;
29use crate::node::{BarrierNode, ConditionNode, FlowNode, LeafNode, NodeKind};
30use crate::state::{ExecutionEntry, GraphResult};
31use crate::trace::{MemoryTraceSink, TraceSink, TraceStep};
32use crate::workflow_state::{MergeStrategy, WorkflowState};
33
34// ─── CheckpointConfig ─────────────────────────────────────────
35
36/// Checkpoint 保存回调 — 传入 `run_execution_loop` 即可启用自动保存。
37///
38/// v0.4 只做 save path,restore 留给 v0.5。
39type CheckpointSaveFn<S> = Box<
40    dyn Fn(
41            Checkpoint<S>,
42            TraceId,
43        ) -> std::pin::Pin<
44            Box<
45                dyn std::future::Future<
46                        Output = Result<(), crate::checkpoint::CheckpointStoreError>,
47                    > + Send,
48            >,
49        > + Send
50        + Sync,
51>;
52
53/// Checkpoint 保存配置 — 三层策略。
54///
55/// ```text
56/// CheckpointConfig
57///   ├── trigger:    何时保存(EveryNode / BarrierOnly / Manual / OnMutation)
58///   ├── retention:  保留多少个(KeepAll / KeepLatest(N) / TimeBased)
59///   ├── save_fn:    保存回调
60///   └── store:      存储后端引用(用于 prune)
61/// ```
62pub struct CheckpointConfig<S: WorkflowState> {
63    /// 触发策略
64    pub trigger: TriggerPolicy,
65    /// 保留策略
66    pub retention: RetentionPolicy,
67    /// 保存回调
68    save_fn: CheckpointSaveFn<S>,
69    /// 图结构指纹
70    graph_hash: u64,
71    /// 存储后端引用(用于 prune)
72    store: Option<std::sync::Arc<dyn crate::store::BlobCheckpointStore>>,
73}
74
75impl<S: WorkflowState> CheckpointConfig<S> {
76    /// 创建 CheckpointConfig。
77    ///
78    /// `save_fn` 接收 `(Checkpoint<S>, TraceId)` 并异步保存。
79    /// 通常由调用方组合 `TypedCheckpointStore` + `SerdeCheckpointCodec` 构造。
80    pub fn new(
81        save_fn: impl Fn(
82            Checkpoint<S>,
83            TraceId,
84        ) -> std::pin::Pin<
85            Box<
86                dyn std::future::Future<
87                        Output = Result<(), crate::checkpoint::CheckpointStoreError>,
88                    > + Send,
89            >,
90        > + Send
91        + Sync
92        + 'static,
93        graph_hash: u64,
94    ) -> Self {
95        Self {
96            save_fn: Box::new(save_fn),
97            trigger: TriggerPolicy::default(),
98            retention: RetentionPolicy::default(),
99            graph_hash,
100            store: None,
101        }
102    }
103
104    /// 设置触发策略。
105    pub fn with_trigger(mut self, trigger: TriggerPolicy) -> Self {
106        self.trigger = trigger;
107        self
108    }
109
110    /// 设置保留策略。
111    pub fn with_retention(mut self, retention: RetentionPolicy) -> Self {
112        self.retention = retention;
113        self
114    }
115
116    /// 设置存储后端引用(用于 prune)。
117    pub fn with_store(
118        mut self,
119        store: std::sync::Arc<dyn crate::store::BlobCheckpointStore>,
120    ) -> Self {
121        self.store = Some(store);
122        self
123    }
124
125    /// 向后兼容 — 设置旧的 CheckpointPolicy(自动转换为 TriggerPolicy)。
126    #[allow(deprecated)]
127    pub fn with_policy(mut self, policy: crate::checkpoint::CheckpointPolicy) -> Self {
128        self.trigger = policy.into();
129        self
130    }
131
132    /// 根据策略判断是否应该保存。
133    pub fn should_save(&self, has_mutations: bool, is_barrier: bool) -> bool {
134        match self.trigger {
135            TriggerPolicy::EveryNode => true,
136            TriggerPolicy::BarrierOnly => is_barrier,
137            TriggerPolicy::Manual => false,
138            TriggerPolicy::OnMutation => has_mutations,
139        }
140    }
141
142    /// 执行保留策略(prune)。
143    pub async fn apply_retention(
144        &self,
145        trace_id: &TraceId,
146    ) -> Result<(), crate::checkpoint::CheckpointStoreError> {
147        if let Some(keep) = self.retention.prune_keep() {
148            if let Some(ref store) = self.store {
149                let pruned = store.prune(trace_id, keep).await?;
150                if pruned > 0 {
151                    tracing::debug!(pruned, keep, "checkpoint pruned");
152                }
153            }
154        }
155        Ok(())
156    }
157}
158
159/// 运行 Graph 的流式执行循环。
160///
161/// 在 `tokio::spawn` 中调用,通过 channel 发射 `GraphEvent`。
162///
163/// # 泛型
164///
165/// - `S` — 类型化状态
166/// - `M` — 合并策略
167///
168/// # 恢复路径
169///
170/// `restore_from` 包含 Checkpoint 时:
171/// 1. 使用 Checkpoint 中的 state 和 current_node
172/// 2. 如果 current_node 是 Barrier → 立即 Re-Wait(等待新决策)
173/// 3. decision 属于 Control Plane,不回放旧决策
174pub(crate) async fn run_execution_loop<S, M>(
175    graph: Arc<Graph<S, M>>,
176    state: S,
177    max_steps: usize,
178    trace_id: TraceId,
179    event_tx: tokio::sync::mpsc::Sender<GraphEvent<S>>,
180    mut decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
181    mut cancel_rx: tokio::sync::mpsc::Receiver<()>,
182    cancel: CancellationToken,
183    checkpoint: Option<CheckpointConfig<S>>,
184    trace_sink: Option<MemoryTraceSink<S::Mutation>>,
185    restore_from: Option<Checkpoint<S>>,
186) where
187    S: WorkflowState + Clone + Send + Sync + Serialize + 'static,
188    S::Mutation: Clone + Send + Sync,
189    M: MergeStrategy<S>,
190{
191    let start_time = Instant::now();
192    let mut execution_log: Vec<ExecutionEntry> = Vec::new();
193
194    // 恢复路径:使用 Checkpoint 中的 state 和 current_node
195    // P0-1: checkpoint.state 是 S::Checkpoint,需要通过 restore() 转换为 S
196    let restore_state = restore_from.as_ref().map(|cp| S::restore(cp.state.clone()));
197    let mut engine_state = restore_state.unwrap_or(state);
198    let mut engine = ExecutionEngine::new(&mut engine_state, None, cancel.clone());
199    let mut current = if let Some(ref cp) = restore_from {
200        cp.current_node.0.clone()
201    } else {
202        graph.start_node().to_string()
203    };
204    let mut step: usize = 0;
205    // 缓存通配决策 — 一次发送,多次匹配
206    let mut wildcard_cache = std::collections::HashMap::new();
207    // TraceSink — 审计日志
208    let mut trace_sink = trace_sink;
209
210    let _ = event_tx.send(GraphEvent::GraphStart { trace_id }).await;
211
212    // Barrier Re-Wait 恢复路径
213    // 如果从 Checkpoint 恢复且当前节点是 Barrier,立即重新等待决策。
214    // Decision 属于 Control Plane,不回放旧决策。
215    let mut skip_first_execution = false;
216    if restore_from.is_some() {
217        if let Some(node) = graph.nodes.get(&current) {
218            if matches!(node, NodeKind::Barrier(_)) {
219                let span_id = SpanId::new();
220                let barrier_id = crate::event::BarrierId::new(&current, 0);
221
222                let _ = event_tx
223                    .send(GraphEvent::BarrierWaiting {
224                        barrier_id: barrier_id.clone(),
225                        node_name: current.clone(),
226                        span_id,
227                    })
228                    .await;
229
230                // 提取 Barrier 的 timeout(如果有)
231                let barrier_timeout = if let NodeKind::Barrier(bn) = node {
232                    bn.timeout
233                } else {
234                    None
235                };
236
237                let outcome = wait_for_barrier_decision(
238                    &mut decision_rx,
239                    &mut cancel_rx,
240                    &cancel,
241                    &barrier_id,
242                    barrier_timeout,
243                    &mut wildcard_cache,
244                )
245                .await;
246
247                match outcome {
248                    BarrierOutcome::Decision(d) => {
249                        let _ = event_tx
250                            .send(GraphEvent::BarrierResolved {
251                                barrier_id,
252                                decision: d.clone(),
253                            })
254                            .await;
255
256                        let reroute_target =
257                            apply_barrier_decision_generic(engine.state_mut(), &current, &d);
258
259                        if let Some(target) = reroute_target {
260                            current = target;
261                        }
262                        // 决策已应用,跳过 Barrier 节点的执行(只负责 pause)
263                        skip_first_execution = true;
264                    }
265                    BarrierOutcome::TimedOut => {
266                        let _ = event_tx
267                            .send(GraphEvent::BarrierResolved {
268                                barrier_id,
269                                decision: BarrierDecision::Reject {
270                                    reason: "timeout on restore".into(),
271                                },
272                            })
273                            .await;
274                        apply_barrier_decision_generic(
275                            engine.state_mut(),
276                            &current,
277                            &BarrierDecision::Reject {
278                                reason: "timeout on restore".into(),
279                            },
280                        );
281                        skip_first_execution = true;
282                    }
283                    BarrierOutcome::Cancelled => {
284                        let _ = event_tx
285                            .send(GraphEvent::GraphError {
286                                error: GraphError::Terminal(
287                                    crate::error::TerminalError::BarrierCancelled {
288                                        node: current.clone(),
289                                    },
290                                ),
291                                state: engine.state().clone(),
292                            })
293                            .await;
294                        return;
295                    }
296                }
297            }
298        }
299    }
300
301    loop {
302        // 如果跳过了第一次执行(Barrier Re-Wait),直接进入路由解析
303        if step == 0 && skip_first_execution {
304            // Re-Wait 完成后,直接解析下一步(不执行 Barrier 节点本身)
305            match graph.resolve_next_inline(&current, engine.state()) {
306                Ok(target) => current = target,
307                Err(_) => {
308                    // 没有下一节点,检查是否到达终点
309                    if current == graph.end_node() {
310                        send_complete(
311                            &event_tx,
312                            trace_id,
313                            engine.state(),
314                            execution_log,
315                            start_time,
316                            trace_sink.take(),
317                        );
318                        break;
319                    }
320                    // 否则继续执行当前节点之后的流程
321                    // (实际上不会走到这里,因为 resolve_next 应该成功)
322                }
323            }
324            step += 1;
325            continue;
326        }
327        if cancel.is_cancelled() {
328            let _ = event_tx
329                .send(GraphEvent::GraphError {
330                    error: GraphError::Terminal(crate::error::TerminalError::BarrierCancelled {
331                        node: "execution cancelled".into(),
332                    }),
333                    state: engine.state().clone(),
334                })
335                .await;
336            break;
337        }
338
339        step += 1;
340        if step > max_steps {
341            let _ = event_tx
342                .send(GraphEvent::GraphError {
343                    error: GraphError::Terminal(crate::error::TerminalError::StepsExceeded {
344                        limit: max_steps,
345                    }),
346                    state: engine.state().clone(),
347                })
348                .await;
349            break;
350        }
351
352        let node = match graph.nodes.get(&current) {
353            Some(n) => n,
354            None => {
355                let _ = event_tx
356                    .send(GraphEvent::GraphError {
357                        error: GraphError::Terminal(crate::error::TerminalError::NodeNotFound(
358                            current.clone(),
359                        )),
360                        state: engine.state().clone(),
361                    })
362                    .await;
363                break;
364            }
365        };
366
367        let node_name = current.clone();
368        let span_id = SpanId::new();
369        let node_start = Instant::now();
370
371        let _ = event_tx
372            .send(GraphEvent::NodeStart {
373                node_name: node_name.clone(),
374                trace_id,
375                span_id,
376                step,
377            })
378            .await;
379
380        // 执行节点 — 根据 NodeKind 分发
381        let is_barrier = matches!(node, NodeKind::Barrier(_));
382        let node_ok = match node {
383            NodeKind::Task(n) => {
384                let mut ctx = engine.build_node_context();
385                n.execute(&mut ctx).await.is_ok()
386            }
387            NodeKind::Condition(n) => {
388                let mut ctx = engine.build_leaf_context();
389                <ConditionNode<S> as LeafNode<S>>::execute(n, &mut ctx)
390                    .await
391                    .is_ok()
392            }
393            NodeKind::Barrier(n) => {
394                let mut ctx = engine.build_leaf_context();
395                <BarrierNode<S> as LeafNode<S>>::execute(n, &mut ctx)
396                    .await
397                    .is_ok()
398            }
399            NodeKind::External(n) => {
400                let mut ctx = engine.build_node_context();
401                n.execute(&mut ctx).await.is_ok()
402            }
403            NodeKind::ExternalLeaf(n) => {
404                let mut ctx = engine.build_leaf_context();
405                n.execute(&mut ctx).await.is_ok()
406            }
407            NodeKind::Parallel(p) => p.execute(&mut engine).await.is_ok(),
408            NodeKind::Subgraph(subgraph) => {
409                // Subgraph 执行 — 通过 StateProjector 投影状态 + 递归执行内层 Graph
410                let stream = engine.stream_sink();
411                let cancel = engine.cancel_token().clone();
412                subgraph
413                    .execute(engine.state_mut(), stream, cancel)
414                    .await
415                    .is_ok()
416            }
417        };
418
419        if !node_ok {
420            let _ = event_tx
421                .send(GraphEvent::NodeEnd {
422                    node_name: node_name.clone(),
423                    trace_id,
424                    span_id,
425                    success: false,
426                    duration: node_start.elapsed(),
427                })
428                .await;
429
430            let _ = event_tx
431                .send(GraphEvent::GraphError {
432                    error: GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
433                        node: node_name,
434                        source: "node execution failed".into(),
435                    }),
436                    state: engine.state().clone(),
437                })
438                .await;
439            break;
440        }
441
442        // 消费 FlowEvent 缓冲 → 转发为 GraphEvent::Node
443        let flow_events = engine.take_flow_events();
444        for fe in flow_events {
445            let _ = event_tx
446                .send(GraphEvent::Node {
447                    span_id,
448                    node_name: node_name.clone(),
449                    event: fe,
450                })
451                .await;
452        }
453
454        // commit mutations (Unit of Work) — 三段式流水线
455        // P1: take batch → P3: trace/mutation-log → apply to state
456        let commit_batch = engine.take_commit_batch();
457        let has_mutations = !commit_batch.is_empty();
458
459        // TraceSink 记录(如果启用)
460        if let Some(ref mut sink) = trace_sink {
461            if !commit_batch.is_empty() {
462                sink.record_step(TraceStep {
463                    step,
464                    node_id: NodeId(node_name.clone()),
465                    mutations: commit_batch.clone(),
466                });
467            }
468        }
469
470        engine.apply_batch_to_state(commit_batch);
471
472        let node_duration = node_start.elapsed();
473        execution_log.push(ExecutionEntry {
474            step,
475            node_name: node_name.clone(),
476            start_time,
477            end_time: start_time.checked_add(node_duration).unwrap_or(start_time),
478            success: true,
479            error: None,
480        });
481
482        let _ = event_tx
483            .send(GraphEvent::NodeEnd {
484                node_name: node_name.clone(),
485                trace_id,
486                span_id,
487                success: true,
488                duration: node_duration,
489            })
490            .await;
491
492        // 提取控制信号
493        let (next_action, signal) = engine.take_control();
494
495        // 处理 Barrier 信号
496        let mut next_action = next_action;
497
498        if let Some(ExecutionSignal::Pause {
499            barrier_id,
500            timeout,
501        }) = signal
502        {
503            let _ = event_tx
504                .send(GraphEvent::BarrierWaiting {
505                    barrier_id: barrier_id.clone(),
506                    node_name: node_name.clone(),
507                    span_id,
508                })
509                .await;
510
511            let outcome = wait_for_barrier_decision(
512                &mut decision_rx,
513                &mut cancel_rx,
514                &cancel,
515                &barrier_id,
516                timeout,
517                &mut wildcard_cache,
518            )
519            .await;
520
521            match outcome {
522                BarrierOutcome::Decision(d) => {
523                    let _ = event_tx
524                        .send(GraphEvent::BarrierResolved {
525                            barrier_id,
526                            decision: d.clone(),
527                        })
528                        .await;
529
530                    let reroute_target =
531                        apply_barrier_decision_generic(engine.state_mut(), &node_name, &d);
532
533                    if let Some(target) = reroute_target {
534                        current = target;
535                        continue;
536                    }
537                    next_action = NextAction::Next;
538                }
539                BarrierOutcome::TimedOut => {
540                    // 超时 → 应用默认 Reject 决策
541                    let _ = event_tx
542                        .send(GraphEvent::BarrierResolved {
543                            barrier_id,
544                            decision: BarrierDecision::Reject {
545                                reason: "timeout".into(),
546                            },
547                        })
548                        .await;
549                    apply_barrier_decision_generic(
550                        engine.state_mut(),
551                        &node_name,
552                        &BarrierDecision::Reject {
553                            reason: "timeout".into(),
554                        },
555                    );
556                    next_action = NextAction::Next;
557                }
558                BarrierOutcome::Cancelled => {
559                    let _ = event_tx
560                        .send(GraphEvent::GraphError {
561                            error: GraphError::Terminal(
562                                crate::error::TerminalError::BarrierCancelled {
563                                    node: node_name.clone(),
564                                },
565                            ),
566                            state: engine.state().clone(),
567                        })
568                        .await;
569                    break;
570                }
571            }
572        }
573
574        // 处理路由
575        match next_action {
576            NextAction::End => {
577                send_complete(
578                    &event_tx,
579                    trace_id,
580                    engine.state(),
581                    execution_log,
582                    start_time,
583                    trace_sink.take(),
584                );
585                break;
586            }
587            NextAction::Goto(target) => {
588                current = target;
589            }
590            NextAction::Next => {
591                if current == graph.end_node() {
592                    send_complete(
593                        &event_tx,
594                        trace_id,
595                        engine.state(),
596                        execution_log,
597                        start_time,
598                        trace_sink.take(),
599                    );
600                    break;
601                }
602                match graph.resolve_next_inline(&current, engine.state()) {
603                    Ok(target) => current = target,
604                    Err(e) => {
605                        let _ = event_tx
606                            .send(GraphEvent::GraphError {
607                                error: e,
608                                state: engine.state().clone(),
609                            })
610                            .await;
611                        break;
612                    }
613                }
614            }
615        }
616
617        // Checkpoint Save Path — 路由已解析,current 是下一个节点
618        if let Some(ref cp_config) = checkpoint {
619            if cp_config.should_save(has_mutations, is_barrier) {
620                // P0-1: Checkpoint::new 现在接受 &S,内部调用 snapshot() 进行投影
621                let cp = Checkpoint::new(&current, engine.state(), cp_config.graph_hash);
622                match (cp_config.save_fn)(cp, trace_id).await {
623                    Ok(()) => {
624                        // 保存成功后,应用保留策略
625                        if let Err(e) = cp_config.apply_retention(&trace_id).await {
626                            tracing::warn!(error = %e, "checkpoint retention failed");
627                        }
628                    }
629                    Err(e) => {
630                        tracing::warn!(error = %e, "checkpoint save failed");
631                    }
632                }
633            }
634        }
635    }
636}
637
638/// 发送 GraphComplete 事件。
639pub(crate) fn send_complete<S: WorkflowState>(
640    event_tx: &tokio::sync::mpsc::Sender<GraphEvent<S>>,
641    trace_id: TraceId,
642    final_state: &S,
643    execution_log: Vec<ExecutionEntry>,
644    start_time: Instant,
645    trace_sink: Option<MemoryTraceSink<S::Mutation>>,
646) {
647    let duration = start_time.elapsed();
648    let trace = trace_sink.map(|sink| sink.into_trace());
649    let result = GraphResult {
650        trace_id,
651        state: final_state.clone(),
652        execution_log,
653        duration,
654        trace,
655    };
656    let _ = event_tx.try_send(GraphEvent::GraphComplete { result });
657}