Skip to main content

lellm_graph/
execution_loop.rs

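//! Graph streaming execution loop — the core logic behind `SimpleExecutor::execute_stream()`.
//!
//! Contents:
//! - The execution loop (node scheduling, routing, mutation consumption)
//! - Emission of the `GraphComplete` / `GraphError` terminal events
//! - The checkpoint save path (triggered by `TriggerPolicy`, formerly `CheckpointPolicy`)
//!
//! Barrier waiting and decision application live in the [`barrier_wait`] module.
//! A general checkpoint restore path was scoped to v0.5; note that `run_execution_loop`
//! already accepts a `restore_from` checkpoint and re-waits on a restored barrier.
//!
//! v0.4+: `run_execution_loop<S, M>` is generic and supports any `WorkflowState`.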
12
13use std::sync::Arc;
14use std::time::Instant;
15
16use serde::Serialize;
17use tokio_util::sync::CancellationToken;
18
19use crate::barrier_wait::{
20    BarrierOutcome, apply_barrier_decision_generic, wait_for_barrier_decision,
21};
22use crate::checkpoint::{Checkpoint, NodeId, TraceId};
23use crate::checkpoint_policy::{RetentionPolicy, TriggerPolicy};
24use crate::error::GraphError;
25use crate::event::{BarrierDecision, BarrierDecisionMessage, GraphEvent};
26use crate::execution_engine::{ExecutionEngine, ExecutionSignal, ExecutorState, NextAction};
27use crate::graph::Graph;
28use crate::ids::SpanId;
29use crate::node::{BarrierNode, ConditionNode, ExecutorOperation, FlowNode, LeafNode, NodeKind};
30use crate::state::{ExecutionEntry, GraphResult};
31use crate::trace::{MemoryTraceSink, TraceSink, TraceStep};
32use crate::workflow_state::{MergeStrategy, WorkflowState};
33
34// ─── CheckpointConfig ─────────────────────────────────────────
35
36/// Checkpoint 保存回调 — 传入 `run_execution_loop` 即可启用自动保存。
37///
38/// v0.4 只做 save path,restore 留给 v0.5。
39type CheckpointSaveFn<S> = Box<
40    dyn Fn(
41            Checkpoint<S>,
42            TraceId,
43        ) -> std::pin::Pin<
44            Box<
45                dyn std::future::Future<
46                        Output = Result<(), crate::checkpoint::CheckpointStoreError>,
47                    > + Send,
48            >,
49        > + Send
50        + Sync,
51>;
52
53/// Checkpoint 保存配置 — 三层策略。
54///
55/// ```text
56/// CheckpointConfig
57///   ├── trigger:    何时保存(EveryNode / BarrierOnly / Manual / OnMutation)
58///   ├── retention:  保留多少个(KeepAll / KeepLatest(N) / TimeBased)
59///   ├── save_fn:    保存回调
60///   └── store:      存储后端引用(用于 prune)
61/// ```
62pub struct CheckpointConfig<S: WorkflowState> {
63    /// 触发策略
64    pub trigger: TriggerPolicy,
65    /// 保留策略
66    pub retention: RetentionPolicy,
67    /// 保存回调
68    save_fn: CheckpointSaveFn<S>,
69    /// 图结构指纹
70    graph_hash: u64,
71    /// 存储后端引用(用于 prune)
72    store: Option<std::sync::Arc<dyn crate::store::BlobCheckpointStore>>,
73}
74
75impl<S: WorkflowState> CheckpointConfig<S> {
76    /// 创建 CheckpointConfig。
77    ///
78    /// `save_fn` 接收 `(Checkpoint<S>, TraceId)` 并异步保存。
79    /// 通常由调用方组合 `TypedCheckpointStore` + `SerdeCheckpointCodec` 构造。
80    pub fn new(
81        save_fn: impl Fn(
82            Checkpoint<S>,
83            TraceId,
84        ) -> std::pin::Pin<
85            Box<
86                dyn std::future::Future<
87                        Output = Result<(), crate::checkpoint::CheckpointStoreError>,
88                    > + Send,
89            >,
90        > + Send
91        + Sync
92        + 'static,
93        graph_hash: u64,
94    ) -> Self {
95        Self {
96            save_fn: Box::new(save_fn),
97            trigger: TriggerPolicy::default(),
98            retention: RetentionPolicy::default(),
99            graph_hash,
100            store: None,
101        }
102    }
103
104    /// 设置触发策略。
105    pub fn with_trigger(mut self, trigger: TriggerPolicy) -> Self {
106        self.trigger = trigger;
107        self
108    }
109
110    /// 设置保留策略。
111    pub fn with_retention(mut self, retention: RetentionPolicy) -> Self {
112        self.retention = retention;
113        self
114    }
115
116    /// 设置存储后端引用(用于 prune)。
117    pub fn with_store(
118        mut self,
119        store: std::sync::Arc<dyn crate::store::BlobCheckpointStore>,
120    ) -> Self {
121        self.store = Some(store);
122        self
123    }
124
125    /// 向后兼容 — 设置旧的 CheckpointPolicy(自动转换为 TriggerPolicy)。
126    #[allow(deprecated)]
127    pub fn with_policy(mut self, policy: crate::checkpoint::CheckpointPolicy) -> Self {
128        self.trigger = policy.into();
129        self
130    }
131
132    /// 根据策略判断是否应该保存。
133    pub fn should_save(&self, has_mutations: bool, is_barrier: bool) -> bool {
134        match self.trigger {
135            TriggerPolicy::EveryNode => true,
136            TriggerPolicy::BarrierOnly => is_barrier,
137            TriggerPolicy::Manual => false,
138            TriggerPolicy::OnMutation => has_mutations,
139        }
140    }
141
142    /// 执行保留策略(prune)。
143    pub async fn apply_retention(
144        &self,
145        trace_id: &TraceId,
146    ) -> Result<(), crate::checkpoint::CheckpointStoreError> {
147        if let Some(keep) = self.retention.prune_keep() {
148            if let Some(ref store) = self.store {
149                let pruned = store.prune(trace_id, keep).await?;
150                if pruned > 0 {
151                    tracing::debug!(pruned, keep, "checkpoint pruned");
152                }
153            }
154        }
155        Ok(())
156    }
157}
158
159/// 运行 Graph 的流式执行循环。
160///
161/// 在 `tokio::spawn` 中调用,通过 channel 发射 `GraphEvent`。
162///
163/// # 泛型
164///
165/// - `S` — 类型化状态
166/// - `M` — 合并策略
167///
168/// # 恢复路径
169///
170/// `restore_from` 包含 Checkpoint 时:
171/// 1. 使用 Checkpoint 中的 state 和 current_node
172/// 2. 如果 current_node 是 Barrier → 立即 Re-Wait(等待新决策)
173/// 3. decision 属于 Control Plane,不回放旧决策
174pub(crate) async fn run_execution_loop<S, M>(
175    graph: Arc<Graph<S, M>>,
176    state: S,
177    max_steps: usize,
178    trace_id: TraceId,
179    event_tx: tokio::sync::mpsc::Sender<GraphEvent<S>>,
180    mut decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
181    mut cancel_rx: tokio::sync::mpsc::Receiver<()>,
182    cancel: CancellationToken,
183    checkpoint: Option<CheckpointConfig<S>>,
184    trace_sink: Option<MemoryTraceSink<S::Mutation>>,
185    restore_from: Option<Checkpoint<S>>,
186) where
187    S: WorkflowState + Clone + Send + Sync + Serialize + 'static,
188    S::Mutation: Clone + Send + Sync,
189    M: MergeStrategy<S>,
190{
191    let start_time = Instant::now();
192    let mut execution_log: Vec<ExecutionEntry> = Vec::new();
193
194    // 恢复路径:使用 Checkpoint 中的 state 和 current_node
195    let restore_state = restore_from.as_ref().map(|cp| cp.state.clone());
196    let engine_state = restore_state.unwrap_or(state);
197    let mut engine = ExecutionEngine::new(engine_state, None, cancel.clone());
198    let mut current = if let Some(ref cp) = restore_from {
199        cp.current_node.0.clone()
200    } else {
201        graph.start_node().to_string()
202    };
203    let mut step: usize = 0;
204    // 缓存通配决策 — 一次发送,多次匹配
205    let mut wildcard_cache = std::collections::HashMap::new();
206    // TraceSink — 审计日志
207    let mut trace_sink = trace_sink;
208
209    let _ = event_tx.send(GraphEvent::GraphStart { trace_id }).await;
210
211    // Barrier Re-Wait 恢复路径
212    // 如果从 Checkpoint 恢复且当前节点是 Barrier,立即重新等待决策。
213    // Decision 属于 Control Plane,不回放旧决策。
214    let mut skip_first_execution = false;
215    if restore_from.is_some() {
216        if let Some(node) = graph.nodes.get(&current) {
217            if matches!(node, NodeKind::Barrier(_)) {
218                let span_id = SpanId::new();
219                let barrier_id = crate::event::BarrierId::new(&current, 0);
220
221                let _ = event_tx
222                    .send(GraphEvent::BarrierWaiting {
223                        barrier_id: barrier_id.clone(),
224                        node_name: current.clone(),
225                        span_id,
226                    })
227                    .await;
228
229                // 提取 Barrier 的 timeout(如果有)
230                let barrier_timeout = if let NodeKind::Barrier(bn) = node {
231                    bn.timeout
232                } else {
233                    None
234                };
235
236                let outcome = wait_for_barrier_decision(
237                    &mut decision_rx,
238                    &mut cancel_rx,
239                    &cancel,
240                    &barrier_id,
241                    barrier_timeout,
242                    &mut wildcard_cache,
243                )
244                .await;
245
246                match outcome {
247                    BarrierOutcome::Decision(d) => {
248                        let _ = event_tx
249                            .send(GraphEvent::BarrierResolved {
250                                barrier_id,
251                                decision: d.clone(),
252                            })
253                            .await;
254
255                        let reroute_target =
256                            apply_barrier_decision_generic(engine.state_mut(), &current, &d);
257
258                        if let Some(target) = reroute_target {
259                            current = target;
260                        }
261                        // 决策已应用,跳过 Barrier 节点的执行(只负责 pause)
262                        skip_first_execution = true;
263                    }
264                    BarrierOutcome::TimedOut => {
265                        let _ = event_tx
266                            .send(GraphEvent::BarrierResolved {
267                                barrier_id,
268                                decision: BarrierDecision::Reject {
269                                    reason: "timeout on restore".into(),
270                                },
271                            })
272                            .await;
273                        apply_barrier_decision_generic(
274                            engine.state_mut(),
275                            &current,
276                            &BarrierDecision::Reject {
277                                reason: "timeout on restore".into(),
278                            },
279                        );
280                        skip_first_execution = true;
281                    }
282                    BarrierOutcome::Cancelled => {
283                        let _ = event_tx
284                            .send(GraphEvent::GraphError {
285                                error: GraphError::Terminal(
286                                    crate::error::TerminalError::BarrierCancelled {
287                                        node: current.clone(),
288                                    },
289                                ),
290                                state: engine.state().clone(),
291                            })
292                            .await;
293                        return;
294                    }
295                }
296            }
297        }
298    }
299
300    loop {
301        // 如果跳过了第一次执行(Barrier Re-Wait),直接进入路由解析
302        if step == 0 && skip_first_execution {
303            // Re-Wait 完成后,直接解析下一步(不执行 Barrier 节点本身)
304            match graph.resolve_next_inline(&current, engine.state()) {
305                Ok(target) => current = target,
306                Err(_) => {
307                    // 没有下一节点,检查是否到达终点
308                    if current == graph.end_node() {
309                        send_complete(
310                            &event_tx,
311                            trace_id,
312                            engine,
313                            execution_log,
314                            start_time,
315                            trace_sink.take(),
316                        );
317                        break;
318                    }
319                    // 否则继续执行当前节点之后的流程
320                    // (实际上不会走到这里,因为 resolve_next 应该成功)
321                }
322            }
323            step += 1;
324            continue;
325        }
326        if cancel.is_cancelled() {
327            let _ = event_tx
328                .send(GraphEvent::GraphError {
329                    error: GraphError::Terminal(crate::error::TerminalError::BarrierCancelled {
330                        node: "execution cancelled".into(),
331                    }),
332                    state: engine.state().clone(),
333                })
334                .await;
335            break;
336        }
337
338        step += 1;
339        if step > max_steps {
340            let _ = event_tx
341                .send(GraphEvent::GraphError {
342                    error: GraphError::Terminal(crate::error::TerminalError::StepsExceeded {
343                        limit: max_steps,
344                    }),
345                    state: engine.state().clone(),
346                })
347                .await;
348            break;
349        }
350
351        let node = match graph.nodes.get(&current) {
352            Some(n) => n,
353            None => {
354                let _ = event_tx
355                    .send(GraphEvent::GraphError {
356                        error: GraphError::Terminal(crate::error::TerminalError::NodeNotFound(
357                            current.clone(),
358                        )),
359                        state: engine.state().clone(),
360                    })
361                    .await;
362                break;
363            }
364        };
365
366        let node_name = current.clone();
367        let span_id = SpanId::new();
368        let node_start = Instant::now();
369
370        let _ = event_tx
371            .send(GraphEvent::NodeStart {
372                node_name: node_name.clone(),
373                trace_id,
374                span_id,
375                step,
376            })
377            .await;
378
379        // 执行节点 — 根据 NodeKind 分发
380        let is_barrier = matches!(node, NodeKind::Barrier(_));
381        let node_ok = match node {
382            NodeKind::Task(n) => {
383                let mut ctx = engine.build_node_context();
384                n.execute(&mut ctx).await.is_ok()
385            }
386            NodeKind::Condition(n) => {
387                let mut ctx = engine.build_leaf_context();
388                <ConditionNode<S> as LeafNode<S>>::execute(n, &mut ctx)
389                    .await
390                    .is_ok()
391            }
392            NodeKind::Barrier(n) => {
393                let mut ctx = engine.build_leaf_context();
394                <BarrierNode<S> as LeafNode<S>>::execute(n, &mut ctx)
395                    .await
396                    .is_ok()
397            }
398            NodeKind::External(n) => {
399                let mut ctx = engine.build_node_context();
400                n.execute(&mut ctx).await.is_ok()
401            }
402            NodeKind::ExternalLeaf(n) => {
403                let mut ctx = engine.build_leaf_context();
404                n.execute(&mut ctx).await.is_ok()
405            }
406            NodeKind::Parallel(p) => p.execute(&mut engine).await.is_ok(),
407        };
408
409        if !node_ok {
410            let _ = event_tx
411                .send(GraphEvent::NodeEnd {
412                    node_name: node_name.clone(),
413                    trace_id,
414                    span_id,
415                    success: false,
416                    duration: node_start.elapsed(),
417                })
418                .await;
419
420            let _ = event_tx
421                .send(GraphEvent::GraphError {
422                    error: GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
423                        node: node_name,
424                        source: "node execution failed".into(),
425                    }),
426                    state: engine.state().clone(),
427                })
428                .await;
429            break;
430        }
431
432        // 消费 FlowEvent 缓冲 → 转发为 GraphEvent::Node
433        let flow_events = engine.take_flow_events();
434        for fe in flow_events {
435            let _ = event_tx
436                .send(GraphEvent::Node {
437                    span_id,
438                    node_name: node_name.clone(),
439                    event: fe,
440                })
441                .await;
442        }
443
444        // commit mutations (Unit of Work) — 三段式流水线
445        // P1: take batch → P3: trace/mutation-log → apply to state
446        let commit_batch = engine.take_commit_batch();
447        let has_mutations = !commit_batch.is_empty();
448
449        // TraceSink 记录(如果启用)
450        if let Some(ref mut sink) = trace_sink {
451            if !commit_batch.is_empty() {
452                sink.record_step(TraceStep {
453                    step,
454                    node_id: NodeId(node_name.clone()),
455                    mutations: commit_batch.clone(),
456                });
457            }
458        }
459
460        engine.apply_batch_to_state(commit_batch);
461
462        let node_duration = node_start.elapsed();
463        execution_log.push(ExecutionEntry {
464            step,
465            node_name: node_name.clone(),
466            start_time,
467            end_time: start_time.checked_add(node_duration).unwrap_or(start_time),
468            success: true,
469            error: None,
470        });
471
472        let _ = event_tx
473            .send(GraphEvent::NodeEnd {
474                node_name: node_name.clone(),
475                trace_id,
476                span_id,
477                success: true,
478                duration: node_duration,
479            })
480            .await;
481
482        // 提取控制信号
483        let (next_action, signal) = engine.take_control();
484
485        // 处理 Barrier 信号
486        let mut next_action = next_action;
487
488        if let Some(ExecutionSignal::Pause {
489            barrier_id,
490            timeout,
491        }) = signal
492        {
493            let _ = event_tx
494                .send(GraphEvent::BarrierWaiting {
495                    barrier_id: barrier_id.clone(),
496                    node_name: node_name.clone(),
497                    span_id,
498                })
499                .await;
500
501            let outcome = wait_for_barrier_decision(
502                &mut decision_rx,
503                &mut cancel_rx,
504                &cancel,
505                &barrier_id,
506                timeout,
507                &mut wildcard_cache,
508            )
509            .await;
510
511            match outcome {
512                BarrierOutcome::Decision(d) => {
513                    let _ = event_tx
514                        .send(GraphEvent::BarrierResolved {
515                            barrier_id,
516                            decision: d.clone(),
517                        })
518                        .await;
519
520                    let reroute_target =
521                        apply_barrier_decision_generic(engine.state_mut(), &node_name, &d);
522
523                    if let Some(target) = reroute_target {
524                        current = target;
525                        continue;
526                    }
527                    next_action = NextAction::Next;
528                }
529                BarrierOutcome::TimedOut => {
530                    // 超时 → 应用默认 Reject 决策
531                    let _ = event_tx
532                        .send(GraphEvent::BarrierResolved {
533                            barrier_id,
534                            decision: BarrierDecision::Reject {
535                                reason: "timeout".into(),
536                            },
537                        })
538                        .await;
539                    apply_barrier_decision_generic(
540                        engine.state_mut(),
541                        &node_name,
542                        &BarrierDecision::Reject {
543                            reason: "timeout".into(),
544                        },
545                    );
546                    next_action = NextAction::Next;
547                }
548                BarrierOutcome::Cancelled => {
549                    let _ = event_tx
550                        .send(GraphEvent::GraphError {
551                            error: GraphError::Terminal(
552                                crate::error::TerminalError::BarrierCancelled {
553                                    node: node_name.clone(),
554                                },
555                            ),
556                            state: engine.state().clone(),
557                        })
558                        .await;
559                    break;
560                }
561            }
562        }
563
564        // 处理路由
565        match next_action {
566            NextAction::End => {
567                send_complete(
568                    &event_tx,
569                    trace_id,
570                    engine,
571                    execution_log,
572                    start_time,
573                    trace_sink.take(),
574                );
575                break;
576            }
577            NextAction::Goto(target) => {
578                current = target;
579            }
580            NextAction::Next => {
581                if current == graph.end_node() {
582                    send_complete(
583                        &event_tx,
584                        trace_id,
585                        engine,
586                        execution_log,
587                        start_time,
588                        trace_sink.take(),
589                    );
590                    break;
591                }
592                match graph.resolve_next_inline(&current, engine.state()) {
593                    Ok(target) => current = target,
594                    Err(e) => {
595                        let _ = event_tx
596                            .send(GraphEvent::GraphError {
597                                error: e,
598                                state: engine.state().clone(),
599                            })
600                            .await;
601                        break;
602                    }
603                }
604            }
605        }
606
607        // Checkpoint Save Path — 路由已解析,current 是下一个节点
608        if let Some(ref cp_config) = checkpoint {
609            if cp_config.should_save(has_mutations, is_barrier) {
610                let cp = Checkpoint::new(&current, engine.state().clone(), cp_config.graph_hash);
611                match (cp_config.save_fn)(cp, trace_id).await {
612                    Ok(()) => {
613                        // 保存成功后,应用保留策略
614                        if let Err(e) = cp_config.apply_retention(&trace_id).await {
615                            tracing::warn!(error = %e, "checkpoint retention failed");
616                        }
617                    }
618                    Err(e) => {
619                        tracing::warn!(error = %e, "checkpoint save failed");
620                    }
621                }
622            }
623        }
624    }
625}
626
627/// 发送 GraphComplete 事件。
628pub(crate) fn send_complete<S: WorkflowState>(
629    event_tx: &tokio::sync::mpsc::Sender<GraphEvent<S>>,
630    trace_id: TraceId,
631    engine: ExecutionEngine<S>,
632    execution_log: Vec<ExecutionEntry>,
633    start_time: Instant,
634    trace_sink: Option<MemoryTraceSink<S::Mutation>>,
635) {
636    let duration = start_time.elapsed();
637    let final_state = engine.into_state();
638    let trace = trace_sink.map(|sink| sink.into_trace());
639    let result = GraphResult {
640        trace_id,
641        state: final_state,
642        execution_log,
643        duration,
644        trace,
645    };
646    let _ = event_tx.try_send(GraphEvent::GraphComplete { result });
647}