Skip to main content

lellm_graph/
barrier_node.rs

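//! Human-in-the-loop approval node (审批节点).
//!
//! `BarrierNode` pauses the graph when it executes and waits for an external decision
//! submitted through `GraphHandle::decide()`.
//! After receiving `GraphEvent::BarrierPaused`, the consumer sends a [`BarrierDecision`]
//! through `GraphHandle`.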
5
6use async_trait::async_trait;
7
8use crate::error::{GraphError, TerminalError};
9use crate::event::{BarrierDecision, BarrierId, GraphEvent, SpanId};
10use crate::node::{GraphNode, NextStep, StreamNodeResult};
11use crate::state::State;
12
/// Behaviour applied when a barrier times out before an external decision arrives.
///
/// The default variant (via `Default`) is [`BarrierDefaultAction::Reject`].
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum BarrierDefaultAction {
    /// Treat the timeout as a rejection (the default).
    #[default]
    Reject,
    /// Treat the timeout as an approval.
    Approve,
    /// Skip the barrier on timeout and continue to the next step.
    Skip,
}
24
25/// Human-in-the-loop 审批节点。
26///
27/// 执行流程:
28/// 1. 返回 `StreamNodeResult::BarrierPaused`,executor 发射 `BarrierPaused` 事件
29/// 2. 消费者通过 `GraphHandle::decide(barrier_id, decision)` 提交决策
30/// 3. executor 的 `wait_barrier_decision()` 接收决策,调用 `apply_decision()` 应用
31///
32/// **阻塞模式不支持。** 调用 `execute()` 直接报错,引导使用 `execute_stream()`。
33pub struct BarrierNode {
34    pub name: String,
35    /// 超时时间(None = 无限等待)
36    pub timeout: Option<std::time::Duration>,
37    /// 超时默认行为
38    pub default_action: BarrierDefaultAction,
39    /// 拒绝原因写入 State 的 key 后缀(默认 "{name}.reject_reason")
40    pub reject_key: String,
41    /// 审批通过后写入 State 的标记 key(默认 "{name}.approved")
42    pub approve_key: String,
43}
44
45impl BarrierNode {
46    pub fn new(name: impl Into<String>) -> Self {
47        let name = name.into();
48        Self {
49            name: name.clone(),
50            timeout: None,
51            default_action: BarrierDefaultAction::default(),
52            reject_key: format!("{name}.reject_reason"),
53            approve_key: format!("{name}.approved"),
54        }
55    }
56
57    /// 设置超时时间。超时后按 `default_action` 处理。
58    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
59        self.timeout = Some(timeout);
60        self
61    }
62
63    /// 设置超时默认行为(默认 Reject)。
64    pub fn default_action(mut self, action: BarrierDefaultAction) -> Self {
65        self.default_action = action;
66        self
67    }
68
69    /// 设置拒绝原因写入 State 的 key(默认 "{name}.reject_reason")。
70    pub fn reject_key(mut self, key: impl Into<String>) -> Self {
71        self.reject_key = key.into();
72        self
73    }
74
75    /// 设置审批标记写入 State 的 key(默认 "{name}.approved")。
76    pub fn approve_key(mut self, key: impl Into<String>) -> Self {
77        self.approve_key = key.into();
78        self
79    }
80
81    /// 处理决策结果 — 写入 State 并返回 NextStep。
82    ///
83    /// 由 executor 在收到外部决策后调用。
84    pub fn apply_decision(
85        &self,
86        decision: BarrierDecision,
87        state: &mut State,
88    ) -> Result<NextStep, GraphError> {
89        match decision {
90            BarrierDecision::Approve => {
91                tracing::info!(barrier = %self.name, "approved");
92                state.insert(self.approve_key.clone(), serde_json::json!(true));
93                state.remove(&self.reject_key);
94                Ok(NextStep::GoToNext)
95            }
96            BarrierDecision::Reject { reason } => {
97                tracing::warn!(barrier = %self.name, reason = %reason, "rejected");
98                state.insert(self.reject_key.clone(), serde_json::json!(reason));
99                state.remove(&self.approve_key);
100                Ok(NextStep::GoToNext)
101            }
102            BarrierDecision::Modify { key, value } => {
103                tracing::info!(barrier = %self.name, key = %key, "state modified");
104                state.insert(key, value);
105                Ok(NextStep::GoToNext)
106            }
107            BarrierDecision::Reroute { target } => {
108                tracing::info!(barrier = %self.name, target = %target, "rerouted");
109                Ok(NextStep::Goto(target))
110            }
111        }
112    }
113}
114
115#[async_trait]
116impl GraphNode for BarrierNode {
117    /// 阻塞模式不支持 BarrierNode — 直接报错。
118    async fn execute(&self, _state: &mut State) -> Result<NextStep, GraphError> {
119        Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
120            "BarrierNode '{}' requires stream mode. Use GraphExecutor::execute_stream() for human-in-the-loop.",
121            self.name
122        ))))
123    }
124
125    /// 流式执行 — 返回 BarrierPaused,由 executor 发射事件并等待决策。
126    async fn execute_stream(
127        &self,
128        _state: &mut State,
129        _sink: &tokio::sync::mpsc::Sender<GraphEvent>,
130        span_id: SpanId,
131    ) -> Result<StreamNodeResult, GraphError> {
132        let node_name = self.name.clone();
133
134        // barrier_id 由 executor 的 DecisionRegistry 生成
135        // 这里传一个 placeholder,executor 会用 DecisionRegistry::next_id() 覆盖
136        let barrier_id = BarrierId::new(&node_name, 0);
137
138        // 返回 BarrierPaused,由 executor 发射 BarrierWaiting 事件
139        Ok(StreamNodeResult::BarrierPaused {
140            barrier_id,
141            node_name,
142            span_id,
143            timeout: self.timeout,
144            default_action: self.default_action.clone(),
145        })
146    }
147}