enact_core/kernel/
interrupt.rs1use super::ids::ExecutionId;
6use crate::runner::approval_policy::ApprovalPolicy;
7use crate::streaming::{EventEmitter, StreamEvent};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum InterruptReason {
16 PlanApproval { plan: Value },
18 UserInput { prompt: String },
20 Custom { reason: String, data: Value },
22}
23
24#[derive(Debug, Clone)]
26pub enum InterruptDecision {
27 Approve,
29 Reject { reason: Option<String> },
31 Input { value: Value },
33}
34
35pub struct InterruptableRunner {
37 execution_id: ExecutionId,
38 emitter: EventEmitter,
39 interrupt_tx: mpsc::Sender<InterruptDecision>,
40 interrupt_rx: Option<mpsc::Receiver<InterruptDecision>>,
41 approval_policy: Option<Arc<dyn ApprovalPolicy>>,
43}
44
45impl InterruptableRunner {
46 pub fn new() -> Self {
47 let (tx, rx) = mpsc::channel(1);
48 Self {
49 execution_id: ExecutionId::new(),
50 emitter: EventEmitter::new(),
51 interrupt_tx: tx,
52 interrupt_rx: Some(rx),
53 approval_policy: None,
54 }
55 }
56
57 pub fn with_approval_policy<P: ApprovalPolicy + 'static>(mut self, policy: P) -> Self {
59 self.approval_policy = Some(Arc::new(policy));
60 self
61 }
62
63 pub fn with_approval_policy_arc(mut self, policy: Arc<dyn ApprovalPolicy>) -> Self {
65 self.approval_policy = Some(policy);
66 self
67 }
68
69 pub fn check_plan_approval(&self, plan: &Value) -> Option<String> {
72 self.approval_policy
73 .as_ref()
74 .and_then(|p| p.approval_reason(plan))
75 }
76
77 pub fn requires_approval(&self, plan: &Value) -> bool {
79 self.approval_policy
80 .as_ref()
81 .map(|p| p.requires_approval(plan))
82 .unwrap_or(false)
83 }
84
85 pub fn policy_name(&self) -> Option<&str> {
87 self.approval_policy.as_ref().map(|p| p.name())
88 }
89
90 pub fn execution_id(&self) -> &ExecutionId {
92 &self.execution_id
93 }
94
95 pub fn emitter(&self) -> &EventEmitter {
96 &self.emitter
97 }
98
99 pub fn decision_sender(&self) -> mpsc::Sender<InterruptDecision> {
101 self.interrupt_tx.clone()
102 }
103
104 pub async fn approve(&self) {
106 let _ = self.interrupt_tx.send(InterruptDecision::Approve).await;
107 }
108
109 pub async fn reject(&self, reason: Option<String>) {
111 let _ = self
112 .interrupt_tx
113 .send(InterruptDecision::Reject { reason })
114 .await;
115 }
116
117 pub async fn request_approval_if_needed(&mut self, plan: &Value) -> anyhow::Result<bool> {
128 if !self.requires_approval(plan) {
130 return Ok(true);
131 }
132
133 if let Some(reason) = self.check_plan_approval(plan) {
135 tracing::info!(
136 execution_id = %self.execution_id,
137 policy = ?self.policy_name(),
138 reason = %reason,
139 "Plan requires approval"
140 );
141 }
142
143 self.request_approval(plan.clone()).await
145 }
146
147 pub async fn request_approval(&mut self, plan: Value) -> anyhow::Result<bool> {
149 self.emitter.emit(StreamEvent::text_delta(
151 self.execution_id.as_str(),
152 format!(
153 "Waiting for approval: {}",
154 serde_json::to_string_pretty(&plan)?
155 ),
156 ));
157
158 if let Some(ref mut rx) = self.interrupt_rx {
160 if let Some(decision) = rx.recv().await {
161 match decision {
162 InterruptDecision::Approve => {
163 self.emitter.emit(StreamEvent::text_delta(
164 self.execution_id.as_str(),
165 "Plan approved, continuing execution",
166 ));
167 Ok(true)
168 }
169 InterruptDecision::Reject { reason } => {
170 self.emitter.emit(StreamEvent::text_delta(
171 self.execution_id.as_str(),
172 format!("Plan rejected: {:?}", reason),
173 ));
174 Ok(false)
175 }
176 InterruptDecision::Input { .. } => Ok(true),
177 }
178 } else {
179 anyhow::bail!("Interrupt channel closed")
180 }
181 } else {
182 anyhow::bail!("Interrupt receiver not available")
183 }
184 }
185}
186
187impl Default for InterruptableRunner {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::approval_policy::{
        AlwaysApprovePolicy, AlwaysRequireApprovalPolicy, ThresholdApprovalPolicy,
    };
    use serde_json::json;

    /// Three-step plan shared by several tests below.
    fn three_step_plan() -> Value {
        json!({"steps": ["s1", "s2", "s3"]})
    }

    #[test]
    fn test_runner_without_policy() {
        let runner = InterruptableRunner::new();
        let plan = three_step_plan();

        // No policy configured: nothing requires approval and no name is reported.
        assert!(!runner.requires_approval(&plan));
        assert_eq!(runner.check_plan_approval(&plan), None);
        assert_eq!(runner.policy_name(), None);
    }

    #[test]
    fn test_runner_with_always_approve_policy() {
        let plan = three_step_plan();
        let runner = InterruptableRunner::new().with_approval_policy(AlwaysApprovePolicy);

        assert!(!runner.requires_approval(&plan));
        assert_eq!(runner.check_plan_approval(&plan), None);
        assert_eq!(runner.policy_name(), Some("always_approve"));
    }

    #[test]
    fn test_runner_with_always_require_policy() {
        let single_step = json!({"steps": ["s1"]});
        let runner = InterruptableRunner::new().with_approval_policy(AlwaysRequireApprovalPolicy);

        assert!(runner.requires_approval(&single_step));
        assert!(runner.check_plan_approval(&single_step).is_some());
        assert_eq!(runner.policy_name(), Some("always_require"));
    }

    #[test]
    fn test_runner_with_threshold_policy() {
        let runner =
            InterruptableRunner::new().with_approval_policy(ThresholdApprovalPolicy::new(2));

        let at_threshold = json!({"steps": ["s1", "s2"]});
        assert!(!runner.requires_approval(&at_threshold));

        let over_threshold = three_step_plan();
        assert!(runner.requires_approval(&over_threshold));

        let explanation = runner.check_plan_approval(&over_threshold).unwrap();
        assert!(explanation.contains("3 steps"));
    }

    #[test]
    fn test_runner_with_policy_arc() {
        let shared: Arc<dyn ApprovalPolicy> = Arc::new(AlwaysRequireApprovalPolicy);
        let runner = InterruptableRunner::new().with_approval_policy_arc(shared);

        assert!(runner.requires_approval(&json!({})));
    }
}