Skip to main content

enact_core/kernel/
interrupt.rs

1//! Interruptable execution for plan approval flows
2//!
3//! Kernel-owned interrupt handling logic.
4
5use super::ids::ExecutionId;
6use crate::runner::approval_policy::ApprovalPolicy;
7use crate::streaming::{EventEmitter, StreamEvent};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13/// Interrupt reason
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum InterruptReason {
16    /// Waiting for plan approval
17    PlanApproval { plan: Value },
18    /// Waiting for user input
19    UserInput { prompt: String },
20    /// Custom interrupt
21    Custom { reason: String, data: Value },
22}
23
24/// User decision for an interrupt
25#[derive(Debug, Clone)]
26pub enum InterruptDecision {
27    /// Approve and continue
28    Approve,
29    /// Reject and abort/replan
30    Reject { reason: Option<String> },
31    /// Provide input and continue
32    Input { value: Value },
33}
34
35/// Interruptable execution handler
36pub struct InterruptableRunner {
37    execution_id: ExecutionId,
38    emitter: EventEmitter,
39    interrupt_tx: mpsc::Sender<InterruptDecision>,
40    interrupt_rx: Option<mpsc::Receiver<InterruptDecision>>,
41    /// Optional approval policy for plan approval
42    approval_policy: Option<Arc<dyn ApprovalPolicy>>,
43}
44
45impl InterruptableRunner {
46    pub fn new() -> Self {
47        let (tx, rx) = mpsc::channel(1);
48        Self {
49            execution_id: ExecutionId::new(),
50            emitter: EventEmitter::new(),
51            interrupt_tx: tx,
52            interrupt_rx: Some(rx),
53            approval_policy: None,
54        }
55    }
56
57    /// Set the approval policy for this runner
58    pub fn with_approval_policy<P: ApprovalPolicy + 'static>(mut self, policy: P) -> Self {
59        self.approval_policy = Some(Arc::new(policy));
60        self
61    }
62
63    /// Set the approval policy from an Arc
64    pub fn with_approval_policy_arc(mut self, policy: Arc<dyn ApprovalPolicy>) -> Self {
65        self.approval_policy = Some(policy);
66        self
67    }
68
69    /// Check if a plan requires approval based on the configured policy
70    /// Returns the approval reason if approval is required, None otherwise
71    pub fn check_plan_approval(&self, plan: &Value) -> Option<String> {
72        self.approval_policy
73            .as_ref()
74            .and_then(|p| p.approval_reason(plan))
75    }
76
77    /// Check if a plan requires approval (boolean)
78    pub fn requires_approval(&self, plan: &Value) -> bool {
79        self.approval_policy
80            .as_ref()
81            .map(|p| p.requires_approval(plan))
82            .unwrap_or(false)
83    }
84
85    /// Get the approval policy name (for logging)
86    pub fn policy_name(&self) -> Option<&str> {
87        self.approval_policy.as_ref().map(|p| p.name())
88    }
89
90    /// Get the execution ID
91    pub fn execution_id(&self) -> &ExecutionId {
92        &self.execution_id
93    }
94
95    pub fn emitter(&self) -> &EventEmitter {
96        &self.emitter
97    }
98
99    /// Get the decision sender (for external use to send decisions)
100    pub fn decision_sender(&self) -> mpsc::Sender<InterruptDecision> {
101        self.interrupt_tx.clone()
102    }
103
104    /// Send approval decision
105    pub async fn approve(&self) {
106        let _ = self.interrupt_tx.send(InterruptDecision::Approve).await;
107    }
108
109    /// Send rejection decision
110    pub async fn reject(&self, reason: Option<String>) {
111        let _ = self
112            .interrupt_tx
113            .send(InterruptDecision::Reject { reason })
114            .await;
115    }
116
117    /// Request approval if the policy requires it
118    ///
119    /// This combines policy check and approval request into a single method.
120    /// Returns Ok(true) if:
121    /// - No approval policy is set
122    /// - Policy doesn't require approval for this plan
123    /// - Approval is requested and granted
124    ///
125    /// Returns Ok(false) if approval is rejected.
126    /// Returns Err if there's a communication error.
127    pub async fn request_approval_if_needed(&mut self, plan: &Value) -> anyhow::Result<bool> {
128        // No policy or policy doesn't require approval
129        if !self.requires_approval(plan) {
130            return Ok(true);
131        }
132
133        // Log the approval reason
134        if let Some(reason) = self.check_plan_approval(plan) {
135            tracing::info!(
136                execution_id = %self.execution_id,
137                policy = ?self.policy_name(),
138                reason = %reason,
139                "Plan requires approval"
140            );
141        }
142
143        // Request approval
144        self.request_approval(plan.clone()).await
145    }
146
147    /// Request approval and wait for decision
148    pub async fn request_approval(&mut self, plan: Value) -> anyhow::Result<bool> {
149        // Emit waiting event
150        self.emitter.emit(StreamEvent::text_delta(
151            self.execution_id.as_str(),
152            format!(
153                "Waiting for approval: {}",
154                serde_json::to_string_pretty(&plan)?
155            ),
156        ));
157
158        // Wait for decision
159        if let Some(ref mut rx) = self.interrupt_rx {
160            if let Some(decision) = rx.recv().await {
161                match decision {
162                    InterruptDecision::Approve => {
163                        self.emitter.emit(StreamEvent::text_delta(
164                            self.execution_id.as_str(),
165                            "Plan approved, continuing execution",
166                        ));
167                        Ok(true)
168                    }
169                    InterruptDecision::Reject { reason } => {
170                        self.emitter.emit(StreamEvent::text_delta(
171                            self.execution_id.as_str(),
172                            format!("Plan rejected: {:?}", reason),
173                        ));
174                        Ok(false)
175                    }
176                    InterruptDecision::Input { .. } => Ok(true),
177                }
178            } else {
179                anyhow::bail!("Interrupt channel closed")
180            }
181        } else {
182            anyhow::bail!("Interrupt receiver not available")
183        }
184    }
185}
186
187impl Default for InterruptableRunner {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::approval_policy::{
        AlwaysApprovePolicy, AlwaysRequireApprovalPolicy, ThresholdApprovalPolicy,
    };
    use serde_json::json;

    /// Builds `{"steps": ["s1", ..., "sN"]}` with `count` steps.
    fn plan_with_steps(count: usize) -> Value {
        let steps: Vec<String> = (1..=count).map(|i| format!("s{}", i)).collect();
        json!({ "steps": steps })
    }

    #[test]
    fn test_runner_without_policy() {
        let runner = InterruptableRunner::new();
        let plan = plan_with_steps(3);

        // No policy configured: nothing ever needs approval.
        assert!(!runner.requires_approval(&plan));
        assert_eq!(runner.check_plan_approval(&plan), None);
        assert_eq!(runner.policy_name(), None);
    }

    #[test]
    fn test_runner_with_always_approve_policy() {
        let plan = plan_with_steps(3);
        let runner = InterruptableRunner::new().with_approval_policy(AlwaysApprovePolicy);

        assert_eq!(runner.policy_name(), Some("always_approve"));
        assert!(!runner.requires_approval(&plan));
        assert_eq!(runner.check_plan_approval(&plan), None);
    }

    #[test]
    fn test_runner_with_always_require_policy() {
        let plan = plan_with_steps(1);
        let runner = InterruptableRunner::new().with_approval_policy(AlwaysRequireApprovalPolicy);

        assert_eq!(runner.policy_name(), Some("always_require"));
        assert!(runner.requires_approval(&plan));
        assert!(runner.check_plan_approval(&plan).is_some());
    }

    #[test]
    fn test_runner_with_threshold_policy() {
        let runner =
            InterruptableRunner::new().with_approval_policy(ThresholdApprovalPolicy::new(2));

        // At the threshold: no approval needed.
        assert!(!runner.requires_approval(&plan_with_steps(2)));

        // Above the threshold: approval needed, and the reason names the step count.
        let over_threshold = plan_with_steps(3);
        assert!(runner.requires_approval(&over_threshold));
        let reason = runner
            .check_plan_approval(&over_threshold)
            .expect("plans above the threshold should report a reason");
        assert!(reason.contains("3 steps"));
    }

    #[test]
    fn test_runner_with_policy_arc() {
        let shared: Arc<dyn ApprovalPolicy> = Arc::new(AlwaysRequireApprovalPolicy);
        let runner = InterruptableRunner::new().with_approval_policy_arc(shared);

        assert!(runner.requires_approval(&json!({})));
    }
}