use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use wesichain_core::{state::StateSchema, WesichainError};
use crate::{GraphContext, GraphNode, GraphState, StateUpdate};
/// Outcome of an approval request.
///
/// Sent back by whoever handles an [`ApprovalRequest`], or synthesized by
/// [`ApprovalGate`] from its [`ApprovalDefault`] when a request times out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApprovalDecision {
    /// Approval was granted; `comment` is optional free-form feedback.
    Approved { comment: Option<String> },
    /// Approval was refused for the given `reason`.
    Denied { reason: String },
    /// The responder supplied replacement input instead of a plain yes/no.
    /// NOTE(review): presumably the input the action should use instead — confirm.
    Modified { new_input: String },
}
/// Decision an [`ApprovalGate`] falls back to when its timeout elapses before
/// a decision arrives.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ApprovalDefault {
    /// Treat a timeout as a denial. This is the `Default` variant.
    #[default]
    Deny,
    /// Treat a timeout as an approval.
    Approve,
}
/// A pending approval request, delivered through an [`ApprovalChannel`].
///
/// The handler answers by sending exactly one [`ApprovalDecision`] on
/// `respond`. Dropping `respond` without sending makes the waiting
/// [`ApprovalGate::request`] call fail.
pub struct ApprovalRequest {
    /// Run identifier supplied by the caller of `ApprovalGate::request`.
    pub run_id: String,
    /// Checkpoint identifier supplied by the caller of `ApprovalGate::request`.
    pub checkpoint_id: String,
    /// The gate's prompt with `{action}` replaced by the action description.
    pub prompt: String,
    /// One-shot reply slot for the decision.
    pub respond: oneshot::Sender<ApprovalDecision>,
}
/// Receiving end of an [`ApprovalGate`]'s request queue.
///
/// Returned alongside the gate by [`ApprovalGate::new`].
pub struct ApprovalChannel {
    /// Receiver paired with the gate's internal sender.
    rx: mpsc::Receiver<ApprovalRequest>,
}
impl ApprovalChannel {
    /// Waits for the next pending approval request.
    ///
    /// Returns `None` once the underlying tokio `mpsc` channel is closed and
    /// drained, following `mpsc::Receiver::recv` semantics.
    pub async fn recv(&mut self) -> Option<ApprovalRequest> {
        let receiver = &mut self.rx;
        receiver.recv().await
    }
}
/// Checkpoint that forwards approval requests over a bounded channel and
/// waits for a decision.
///
/// Build one with [`ApprovalGate::new`], which also returns the
/// [`ApprovalChannel`] on which the requests arrive.
pub struct ApprovalGate {
    /// Prompt template; every `{action}` is replaced with the action description.
    pub prompt: String,
    /// How long to wait for a decision; `None` waits indefinitely.
    pub timeout: Option<Duration>,
    /// Decision to fall back to when `timeout` elapses.
    pub default: ApprovalDefault,
    /// Sending half of the request queue.
    tx: mpsc::Sender<ApprovalRequest>,
}
impl ApprovalGate {
    /// Creates a gate with the given prompt template, plus the channel on
    /// which its approval requests will be delivered.
    ///
    /// Any `{action}` placeholder in `prompt` is substituted with the action
    /// description passed to [`ApprovalGate::request`]. The new gate has no
    /// timeout and uses [`ApprovalDefault::Deny`] as its timeout fallback.
    pub fn new(prompt: impl Into<String>) -> (Self, ApprovalChannel) {
        // Bounded queue: at most 16 requests can be pending before senders
        // must wait for the consumer to catch up.
        let (tx, rx) = mpsc::channel(16);
        let gate = Self {
            prompt: prompt.into(),
            timeout: None,
            default: ApprovalDefault::Deny,
            tx,
        };
        (gate, ApprovalChannel { rx })
    }

    /// Sets how long [`ApprovalGate::request`] may wait in total, and the
    /// decision to return once that time has elapsed.
    #[must_use]
    pub fn with_timeout(mut self, duration: Duration, default: ApprovalDefault) -> Self {
        self.timeout = Some(duration);
        self.default = default;
        self
    }

    /// Sends an approval request over the gate's channel and waits for the
    /// decision.
    ///
    /// When a timeout is configured, it bounds the whole exchange. That
    /// includes waiting for room in a full request queue as well as waiting
    /// for the reply. On expiry the configured [`ApprovalDefault`] is turned
    /// into a synthetic decision and returned as `Ok`.
    ///
    /// # Errors
    ///
    /// Returns `WesichainError::Custom` in two cases:
    /// - the [`ApprovalChannel`] has been dropped;
    /// - the handler drops the reply sender without answering.
    pub async fn request(
        &self,
        run_id: impl Into<String>,
        checkpoint_id: impl Into<String>,
        action_description: impl Into<String>,
    ) -> Result<ApprovalDecision, WesichainError> {
        let prompt = self.prompt.replace("{action}", &action_description.into());
        let (resp_tx, resp_rx) = oneshot::channel();
        let req = ApprovalRequest {
            run_id: run_id.into(),
            checkpoint_id: checkpoint_id.into(),
            prompt,
            respond: resp_tx,
        };

        // Enqueueing and waiting form a single future. A configured timeout
        // therefore also covers a `send` that is stuck on a full bounded queue;
        // previously that wait was unbounded.
        let exchange = async move {
            self.tx.send(req).await.map_err(|_| {
                WesichainError::Custom("ApprovalChannel receiver dropped".to_string())
            })?;
            resp_rx.await.map_err(|_| {
                WesichainError::Custom(
                    "Approval responder dropped without sending a decision".to_string(),
                )
            })
        };

        match self.timeout {
            Some(dur) => tokio::time::timeout(dur, exchange)
                .await
                .unwrap_or_else(|_elapsed| Ok(self.timeout_decision())),
            None => exchange.await,
        }
    }

    /// Builds the decision returned when the timeout elapses, based on
    /// `self.default`.
    fn timeout_decision(&self) -> ApprovalDecision {
        match self.default {
            ApprovalDefault::Approve => ApprovalDecision::Approved {
                comment: Some("auto-approved (timeout)".to_string()),
            },
            ApprovalDefault::Deny => ApprovalDecision::Denied {
                reason: "timed out — auto-denied".to_string(),
            },
        }
    }
}
#[async_trait::async_trait]
impl<S> GraphNode<S> for ApprovalGate
where
S: StateSchema<Update = S> + Send,
{
async fn invoke_with_context(
&self,
input: GraphState<S>,
ctx: &GraphContext,
) -> Result<StateUpdate<S>, WesichainError> {
let checkpoint_id = ctx.node_id.clone();
self.request(&ctx.node_id, &checkpoint_id, "agent action").await?;
Ok(StateUpdate::new(input.data))
}
}
/// Snapshot pairing an approval decision with the run and checkpoint it
/// belongs to.
#[derive(Clone, Debug)]
pub struct ApprovalState {
    /// Identifier of the run the decision applies to.
    pub run_id: String,
    /// Identifier of the checkpoint the decision applies to.
    pub checkpoint_id: String,
    /// The decision that was reached.
    pub decision: ApprovalDecision,
}