Skip to main content

lellm_graph/
barrier_node.rs

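//! Human-in-the-loop approval node.
//!
//! `BarrierNode` pauses the graph when it executes and waits for an external decision delivered through `GraphHandle::decide()`.
//! After receiving `GraphEvent::BarrierPaused`, the consumer sends a [`BarrierDecision`] through `GraphHandle`.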
5
6use async_trait::async_trait;
7
8use crate::delta::StateDelta;
9use crate::error::{GraphError, TerminalError};
10use crate::event::{BarrierDecision, BarrierId, GraphEvent};
11use crate::ids::SpanId;
12use crate::node::{FlowNode, NextStep, NodeMetadata, NodeOutput, StreamNodeResult};
13use crate::state::State;
14
/// What a barrier does when its timeout elapses without an external decision.
///
/// This is a plain fieldless value type, so it is `Copy` and comparable; callers can
/// match on it, compare it, or use it as a map key without cloning.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum BarrierDefaultAction {
    /// Treat the timeout as a rejection (the default).
    #[default]
    Reject,
    /// Treat the timeout as an approval.
    Approve,
    /// Skip the barrier on timeout and continue with the next step.
    Skip,
}
26
27/// Human-in-the-loop 审批节点。
28///
29/// 执行流程:
30/// 1. 返回 `StreamNodeResult::BarrierPaused`,executor 发射 `BarrierPaused` 事件
31/// 2. 消费者通过 `GraphHandle::decide(barrier_id, decision)` 提交决策
32/// 3. executor 的 `wait_barrier_decision()` 接收决策,调用 `apply_decision()` 应用
33///
34/// **阻塞模式不支持。** 调用 `execute()` 直接报错,引导使用 `execute_stream()`。
35#[derive(Debug, Clone)]
36pub struct BarrierNode {
37    pub name: String,
38    /// 超时时间(None = 无限等待)
39    pub timeout: Option<std::time::Duration>,
40    /// 超时默认行为
41    pub default_action: BarrierDefaultAction,
42    /// 拒绝原因写入 State 的 key 后缀(默认 "{name}.reject_reason")
43    pub reject_key: String,
44    /// 审批通过后写入 State 的标记 key(默认 "{name}.approved")
45    pub approve_key: String,
46}
47
48impl BarrierNode {
49    pub fn new(name: impl Into<String>) -> Self {
50        let name = name.into();
51        Self {
52            name: name.clone(),
53            timeout: None,
54            default_action: BarrierDefaultAction::default(),
55            reject_key: format!("{name}.reject_reason"),
56            approve_key: format!("{name}.approved"),
57        }
58    }
59
60    /// 设置超时时间。超时后按 `default_action` 处理。
61    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
62        self.timeout = Some(timeout);
63        self
64    }
65
66    /// 设置超时默认行为(默认 Reject)。
67    pub fn default_action(mut self, action: BarrierDefaultAction) -> Self {
68        self.default_action = action;
69        self
70    }
71
72    /// 设置拒绝原因写入 State 的 key(默认 "{name}.reject_reason")。
73    pub fn reject_key(mut self, key: impl Into<String>) -> Self {
74        self.reject_key = key.into();
75        self
76    }
77
78    /// 设置审批标记写入 State 的 key(默认 "{name}.approved")。
79    pub fn approve_key(mut self, key: impl Into<String>) -> Self {
80        self.approve_key = key.into();
81        self
82    }
83
84    /// 处理决策结果 — 返回 NextStep + StateDelta,不直接修改 State。
85    ///
86    /// 由 executor 在收到外部决策后调用。Executor 负责 apply deltas。
87    pub fn apply_decision(&self, decision: BarrierDecision) -> (NextStep, Vec<StateDelta>) {
88        match decision {
89            BarrierDecision::Approve => {
90                tracing::info!(barrier = %self.name, "approved");
91                let deltas = vec![
92                    StateDelta::put(&self.approve_key, serde_json::json!(true)),
93                    StateDelta::delete(&self.reject_key),
94                ];
95                (NextStep::GoToNext, deltas)
96            }
97            BarrierDecision::Reject { reason } => {
98                tracing::warn!(barrier = %self.name, reason = %reason, "rejected");
99                let deltas = vec![
100                    StateDelta::put(&self.reject_key, serde_json::json!(reason)),
101                    StateDelta::delete(&self.approve_key),
102                ];
103                (NextStep::GoToNext, deltas)
104            }
105            BarrierDecision::Modify { key, value } => {
106                tracing::info!(barrier = %self.name, key = %key, "state modified");
107                let deltas = vec![StateDelta::put(key, value)];
108                (NextStep::GoToNext, deltas)
109            }
110            BarrierDecision::Reroute { target } => {
111                tracing::info!(barrier = %self.name, target = %target, "rerouted");
112                (NextStep::Goto(target), vec![])
113            }
114        }
115    }
116}
117
118#[async_trait]
119impl FlowNode for BarrierNode {
120    /// 阻塞模式不支持 BarrierNode — 直接报错。
121    async fn execute(&self, _state: &State) -> Result<NodeOutput, GraphError> {
122        Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
123            "BarrierNode '{}' requires stream mode. Use GraphExecutor::execute_stream() for human-in-the-loop.",
124            self.name
125        ))))
126    }
127
128    /// 流式执行 — 返回 Pause,由 executor 发射事件并等待决策。
129    async fn execute_stream(
130        &self,
131        _state: &State,
132        _sink: &tokio::sync::mpsc::Sender<GraphEvent>,
133        span_id: SpanId,
134    ) -> Result<StreamNodeResult, GraphError> {
135        let node_name = self.name.clone();
136
137        // barrier_id 由 executor 的 DecisionRegistry 生成
138        // 这里传一个 placeholder,executor 会用 DecisionRegistry::next_id() 覆盖
139        let barrier_id = BarrierId::new(&node_name, 0);
140
141        // 返回 Pause,由 executor 发射 BarrierWaiting 事件
142        Ok(StreamNodeResult::Pause {
143            deltas: vec![],
144            barrier_id,
145            node_name,
146            span_id,
147            timeout: self.timeout,
148            default_action: self.default_action.clone(),
149        })
150    }
151
152    fn metadata_hint(&self) -> NodeMetadata {
153        // BarrierNode 是 Human-in-the-loop,权重高
154        NodeMetadata {
155            token_cost: 0.0,
156            has_side_effects: true, // 审批后可能触发外部操作
157        }
158    }
159}