use super::ids::ExecutionId;
use crate::runner::approval_policy::ApprovalPolicy;
use crate::streaming::{EventEmitter, StreamEvent};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Reason an execution was interrupted; serializable with serde.
///
/// `PartialEq` is derived so reasons can be compared directly (e.g. with `assert_eq!`).
// NOTE(review): not constructed in this module — confirm where it is produced/consumed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InterruptReason {
    /// Interrupted to approve `plan` (arbitrary JSON).
    PlanApproval { plan: Value },
    /// Interrupted to ask the user for input, described by `prompt`.
    UserInput { prompt: String },
    /// Any other reason: a free-form description plus attached JSON `data`.
    Custom { reason: String, data: Value },
}
/// A decision delivered to a waiting [`InterruptableRunner`] through its interrupt channel.
///
/// `PartialEq` is derived so decisions can be compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub enum InterruptDecision {
    /// Accept the pending plan and continue.
    Approve,
    /// Decline the pending plan, optionally with a human-readable reason.
    Reject { reason: Option<String> },
    /// Free-form JSON input. `request_approval` currently treats this as approval and
    /// does not inspect `value`.
    Input { value: Value },
}
/// Runner that can block on an approval decision delivered over an in-process channel.
///
/// It owns both halves of a bounded decision channel (capacity 1, created in `new`).
/// Decisions are sent via `approve`, `reject`, or a sender obtained from `decision_sender`,
/// and are consumed by `request_approval`.
pub struct InterruptableRunner {
    /// Identifier attached to the stream events this runner emits and to its log lines.
    execution_id: ExecutionId,
    /// Destination for the text events emitted while waiting for / receiving a decision.
    emitter: EventEmitter,
    /// Sending half of the decision channel; cloned out by `decision_sender`.
    interrupt_tx: mpsc::Sender<InterruptDecision>,
    /// Receiving half of the decision channel; `request_approval` errors if this is `None`.
    interrupt_rx: Option<mpsc::Receiver<InterruptDecision>>,
    /// Optional policy deciding which plans need approval; `None` means no plan does.
    approval_policy: Option<Arc<dyn ApprovalPolicy>>,
}
impl InterruptableRunner {
    /// Creates a runner with a fresh [`ExecutionId`], a new [`EventEmitter`], no approval
    /// policy, and a decision channel that buffers a single decision.
    pub fn new() -> Self {
        // Capacity 1: at most one decision can be queued ahead of `request_approval`.
        let (tx, rx) = mpsc::channel(1);
        Self {
            execution_id: ExecutionId::new(),
            emitter: EventEmitter::new(),
            interrupt_tx: tx,
            interrupt_rx: Some(rx),
            approval_policy: None,
        }
    }

    /// Installs `policy` as the approval policy (builder style), replacing any previous one.
    pub fn with_approval_policy<P: ApprovalPolicy + 'static>(mut self, policy: P) -> Self {
        self.approval_policy = Some(Arc::new(policy));
        self
    }

    /// Like [`Self::with_approval_policy`], but takes a policy that is already behind an `Arc`.
    pub fn with_approval_policy_arc(mut self, policy: Arc<dyn ApprovalPolicy>) -> Self {
        self.approval_policy = Some(policy);
        self
    }

    /// Returns the installed policy's reason for requiring approval of `plan`.
    ///
    /// `None` when no policy is installed or the policy supplies no reason.
    pub fn check_plan_approval(&self, plan: &Value) -> Option<String> {
        self.approval_policy
            .as_ref()
            .and_then(|p| p.approval_reason(plan))
    }

    /// Returns whether the installed policy requires approval for `plan`.
    ///
    /// Always `false` when no policy is installed.
    pub fn requires_approval(&self, plan: &Value) -> bool {
        self.approval_policy
            .as_ref()
            .map(|p| p.requires_approval(plan))
            .unwrap_or(false)
    }

    /// Name reported by the installed policy, or `None` when no policy is installed.
    pub fn policy_name(&self) -> Option<&str> {
        self.approval_policy.as_ref().map(|p| p.name())
    }

    /// Identifier of this execution.
    pub fn execution_id(&self) -> &ExecutionId {
        &self.execution_id
    }

    /// Emitter used for this runner's stream events.
    pub fn emitter(&self) -> &EventEmitter {
        &self.emitter
    }

    /// Returns a clone of the decision sender.
    ///
    /// Decisions sent through it are received by [`Self::request_approval`].
    pub fn decision_sender(&self) -> mpsc::Sender<InterruptDecision> {
        self.interrupt_tx.clone()
    }

    /// Queues an [`InterruptDecision::Approve`] for the next [`Self::request_approval`].
    ///
    /// Best effort: a send error (receiver gone) is ignored. The channel holds one decision,
    /// so this awaits while a previously queued decision has not yet been received.
    pub async fn approve(&self) {
        let _ = self.interrupt_tx.send(InterruptDecision::Approve).await;
    }

    /// Queues an [`InterruptDecision::Reject`] with an optional `reason`.
    ///
    /// Same best-effort and backpressure behavior as [`Self::approve`].
    pub async fn reject(&self, reason: Option<String>) {
        let _ = self
            .interrupt_tx
            .send(InterruptDecision::Reject { reason })
            .await;
    }

    /// Requests approval for `plan` only if the installed policy demands it.
    ///
    /// Returns `Ok(true)` immediately when no approval is required. Otherwise it logs the
    /// policy's reason (if any) and delegates to [`Self::request_approval`].
    ///
    /// # Errors
    /// Same as [`Self::request_approval`].
    pub async fn request_approval_if_needed(&mut self, plan: &Value) -> anyhow::Result<bool> {
        if !self.requires_approval(plan) {
            return Ok(true);
        }
        if let Some(reason) = self.check_plan_approval(plan) {
            tracing::info!(
                execution_id = %self.execution_id,
                policy = ?self.policy_name(),
                reason = %reason,
                "Plan requires approval"
            );
        }
        self.request_approval(plan.clone()).await
    }

    /// Emits a "Waiting for approval" event with `plan` as pretty-printed JSON, then waits for
    /// the next decision on the interrupt channel.
    ///
    /// A decision queued before this call (e.g. via [`Self::approve`]) is consumed at once.
    ///
    /// Returns:
    /// - `Ok(true)` for `Approve`, and for `Input` (whose value is currently ignored);
    /// - `Ok(false)` for `Reject`.
    ///
    /// Approve and Reject each emit a follow-up text event.
    ///
    /// # Errors
    /// - If `plan` cannot be serialized.
    /// - If the receiver is not available.
    /// - If the channel reports it is closed.
    pub async fn request_approval(&mut self, plan: Value) -> anyhow::Result<bool> {
        self.emitter.emit(StreamEvent::text_delta(
            self.execution_id.as_str(),
            format!(
                "Waiting for approval: {}",
                serde_json::to_string_pretty(&plan)?
            ),
        ));

        let rx = self
            .interrupt_rx
            .as_mut()
            .ok_or_else(|| anyhow::anyhow!("Interrupt receiver not available"))?;
        // `recv` yields `None` only once every sender is dropped; `self.interrupt_tx` keeps one
        // alive, so this error is defensive rather than expected.
        let decision = rx
            .recv()
            .await
            .ok_or_else(|| anyhow::anyhow!("Interrupt channel closed"))?;

        match decision {
            InterruptDecision::Approve => {
                self.emitter.emit(StreamEvent::text_delta(
                    self.execution_id.as_str(),
                    "Plan approved, continuing execution",
                ));
                Ok(true)
            }
            InterruptDecision::Reject { reason } => {
                // Use Display on the reason text so users see plain text,
                // not Debug output such as `Some("...")` or `None`.
                let message = match reason {
                    Some(reason) => format!("Plan rejected: {}", reason),
                    None => String::from("Plan rejected"),
                };
                self.emitter
                    .emit(StreamEvent::text_delta(self.execution_id.as_str(), message));
                Ok(false)
            }
            InterruptDecision::Input { .. } => Ok(true),
        }
    }
}
impl Default for InterruptableRunner {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::approval_policy::{
        AlwaysApprovePolicy, AlwaysRequireApprovalPolicy, ThresholdApprovalPolicy,
    };
    use serde_json::json;

    /// Builds `{"steps": ["s1", "s2", ..., "s<count>"]}`.
    fn plan_with_steps(count: usize) -> Value {
        let steps: Vec<String> = (1..=count).map(|i| format!("s{}", i)).collect();
        json!({ "steps": steps })
    }

    #[test]
    fn test_runner_without_policy() {
        let runner = InterruptableRunner::new();
        let plan = plan_with_steps(3);

        assert!(!runner.requires_approval(&plan));
        assert_eq!(runner.check_plan_approval(&plan), None);
        assert_eq!(runner.policy_name(), None);
    }

    #[test]
    fn test_runner_with_always_approve_policy() {
        let runner = InterruptableRunner::new().with_approval_policy(AlwaysApprovePolicy);
        let plan = plan_with_steps(3);

        assert!(!runner.requires_approval(&plan));
        assert_eq!(runner.check_plan_approval(&plan), None);
        assert_eq!(runner.policy_name(), Some("always_approve"));
    }

    #[test]
    fn test_runner_with_always_require_policy() {
        let runner = InterruptableRunner::new().with_approval_policy(AlwaysRequireApprovalPolicy);
        let plan = plan_with_steps(1);

        assert!(runner.requires_approval(&plan));
        assert!(runner.check_plan_approval(&plan).is_some());
        assert_eq!(runner.policy_name(), Some("always_require"));
    }

    #[test]
    fn test_runner_with_threshold_policy() {
        let runner =
            InterruptableRunner::new().with_approval_policy(ThresholdApprovalPolicy::new(2));

        // At the threshold: no approval needed.
        assert!(!runner.requires_approval(&plan_with_steps(2)));

        // One step over: approval needed, and the reason mentions the step count.
        let large_plan = plan_with_steps(3);
        assert!(runner.requires_approval(&large_plan));
        let reason = runner
            .check_plan_approval(&large_plan)
            .expect("a plan over the threshold should come with a reason");
        assert!(reason.contains("3 steps"));
    }

    #[test]
    fn test_runner_with_policy_arc() {
        let shared: Arc<dyn ApprovalPolicy> = Arc::new(AlwaysRequireApprovalPolicy);
        let runner = InterruptableRunner::new().with_approval_policy_arc(shared);
        let empty_plan = json!({});

        assert!(runner.requires_approval(&empty_plan));
    }
}