use async_trait::async_trait;
use entelix_core::context::ExecutionContext;
use entelix_core::error::{Error, Result};
use serde_json::Value;
use tokio::sync::{mpsc, oneshot};
pub use entelix_core::ApprovalDecision;
/// A single operation awaiting an approval decision.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate has to build
/// it through [`ApprovalRequest::new`] instead of a struct literal.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ApprovalRequest {
    /// Caller-supplied identifier of the request (tests use a call id such as `"call-1"`).
    pub id: String,
    /// Name of the operation being requested (tests use a tool name such as `"echo"`).
    pub name: String,
    /// JSON arguments the operation would run with.
    pub input: Value,
}
impl ApprovalRequest {
    /// Builds a request from an identifier, an operation name and its JSON input.
    #[must_use]
    pub fn new(id: impl Into<String>, name: impl Into<String>, input: Value) -> Self {
        let (id, name) = (id.into(), name.into());
        Self { id, name, input }
    }
}
/// Decides whether a requested operation may proceed.
///
/// The `Send + Sync + 'static` bound lets a single approver value be shared
/// across threads and spawned tasks.
#[async_trait]
pub trait Approver: Send + Sync + 'static {
    /// Returns the decision for `request`, given the caller's execution context.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The implementations in this module express a
    /// refusal (including a timeout) as `Ok(ApprovalDecision::Reject { .. })`
    /// and reserve `Err` for failures such as cancellation.
    async fn decide(
        &self,
        request: &ApprovalRequest,
        ctx: &ExecutionContext,
    ) -> Result<ApprovalDecision>;
}
/// An [`Approver`] that approves every request unconditionally.
///
/// It ignores both the request and the execution context; in particular it
/// does not check for cancellation.
#[derive(Clone, Copy, Debug, Default)]
pub struct AlwaysApprove;
#[async_trait]
impl Approver for AlwaysApprove {
    /// Always returns `Ok(ApprovalDecision::Approve)`; never fails.
    async fn decide(
        &self,
        _request: &ApprovalRequest,
        _ctx: &ExecutionContext,
    ) -> Result<ApprovalDecision> {
        Ok(ApprovalDecision::Approve)
    }
}
/// Tuning knobs for [`ChannelApprover`].
#[derive(Clone, Copy, Debug)]
pub struct ChannelApproverConfig {
    /// Longest time to wait for an operator decision before the request is
    /// rejected as timed out. A context deadline, when present and earlier,
    /// takes precedence.
    pub timeout: std::time::Duration,
}
impl Default for ChannelApproverConfig {
    /// A five-minute timeout.
    fn default() -> Self {
        // Spelled with `from_secs` so the default does not depend on the
        // recently stabilized `Duration::from_mins`.
        const FIVE_MINUTES: std::time::Duration = std::time::Duration::from_secs(5 * 60);
        Self {
            timeout: FIVE_MINUTES,
        }
    }
}
/// A request delivered to the operator side of a [`ChannelApprover`].
///
/// The operator answers by sending an [`ApprovalDecision`] on `reply`.
/// If `reply` is dropped without sending, the waiting `decide` call fails with
/// a config error. If `decide` has already returned (timeout or cancellation),
/// its receiver is gone and `reply.send` hands the decision back as `Err`.
#[derive(Debug)]
pub struct PendingApproval {
    /// The request awaiting a decision.
    pub request: ApprovalRequest,
    /// One-shot channel back to the `decide` call that is waiting for this request.
    pub reply: oneshot::Sender<ApprovalDecision>,
}
/// An [`Approver`] that forwards each request to an out-of-band operator over a
/// bounded `tokio` mpsc channel and waits for the reply.
///
/// The constructors return the receiving half on which [`PendingApproval`]s
/// arrive. Clones share the same underlying channel, so every clone feeds the
/// same receiver.
#[derive(Clone)]
pub struct ChannelApprover {
    tx: mpsc::Sender<PendingApproval>,
    config: ChannelApproverConfig,
}
impl ChannelApprover {
    /// Creates an approver using `ChannelApproverConfig::default()`.
    ///
    /// `capacity` bounds how many requests may be queued before `decide` has
    /// to wait for the operator to drain the receiver.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 (or exceeds tokio's maximum channel capacity),
    /// as `tokio::sync::mpsc::channel` does.
    #[must_use]
    pub fn new(capacity: usize) -> (Self, mpsc::Receiver<PendingApproval>) {
        Self::with_config(capacity, ChannelApproverConfig::default())
    }
    /// Creates an approver with an explicit `config`, returning it together
    /// with the receiver on which pending approvals arrive.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 (or exceeds tokio's maximum channel capacity),
    /// as `tokio::sync::mpsc::channel` does.
    #[must_use]
    pub fn with_config(
        capacity: usize,
        config: ChannelApproverConfig,
    ) -> (Self, mpsc::Receiver<PendingApproval>) {
        let (tx, rx) = mpsc::channel(capacity);
        (Self { tx, config }, rx)
    }
    /// Returns the earlier of `now + config.timeout` and the context's own
    /// deadline, if it has one.
    fn deadline_for(&self, ctx: &ExecutionContext) -> tokio::time::Instant {
        let cfg_deadline = tokio::time::Instant::now() + self.config.timeout;
        ctx.deadline()
            .map_or(cfg_deadline, |ctx_deadline| ctx_deadline.min(cfg_deadline))
    }
}
impl std::fmt::Debug for ChannelApprover {
    /// Prints the configured timeout and whether the receiving half has been
    /// dropped or closed; the sender handle itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let receiver_gone = self.tx.is_closed();
        let mut out = f.debug_struct("ChannelApprover");
        out.field("timeout", &self.config.timeout);
        out.field("closed", &receiver_gone);
        out.finish()
    }
}
#[async_trait]
impl Approver for ChannelApprover {
async fn decide(
&self,
request: &ApprovalRequest,
ctx: &ExecutionContext,
) -> Result<ApprovalDecision> {
if ctx.is_cancelled() {
return Err(Error::Cancelled);
}
let (reply_tx, reply_rx) = oneshot::channel();
let pending = PendingApproval {
request: request.clone(),
reply: reply_tx,
};
self.tx.send(pending).await.map_err(|_| {
Error::config("ChannelApprover: receiver dropped before approval was requested")
})?;
let deadline = self.deadline_for(ctx);
let cancellation = ctx.cancellation().clone();
tokio::select! {
biased;
() = cancellation.cancelled() => Err(Error::Cancelled),
decision = reply_rx => decision.map_err(|_| {
Error::config("ChannelApprover: reply channel dropped without decision")
}),
() = tokio::time::sleep_until(deadline) => Ok(ApprovalDecision::Reject {
reason: format!(
"supervised approval timed out (no decision within {:?})",
self.config.timeout
),
}),
}
}
}
#[cfg(test)]
// Tests unwrap deliberately: any unexpected `Err`/`None` should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// The fixed request used by every test.
    fn sample_request() -> ApprovalRequest {
        ApprovalRequest::new("call-1", "echo", serde_json::json!({"x": 1}))
    }

    /// Runs `decide` for the sample request on a background task, so the test
    /// body can play the operator.
    fn spawn_decide(
        approver: &ChannelApprover,
    ) -> tokio::task::JoinHandle<Result<ApprovalDecision>> {
        let worker = approver.clone();
        tokio::spawn(async move {
            worker
                .decide(&sample_request(), &ExecutionContext::new())
                .await
        })
    }

    #[tokio::test]
    async fn always_approve_returns_approve() {
        let ctx = ExecutionContext::new();
        let outcome = AlwaysApprove
            .decide(&sample_request(), &ctx)
            .await
            .unwrap();
        assert!(matches!(outcome, ApprovalDecision::Approve));
    }

    #[tokio::test]
    async fn channel_approver_round_trips_approve() {
        let (approver, mut operator) = ChannelApprover::new(4);
        let task = spawn_decide(&approver);
        let incoming = operator.recv().await.unwrap();
        assert_eq!(incoming.request.id, "call-1");
        incoming.reply.send(ApprovalDecision::Approve).unwrap();
        let outcome = task.await.unwrap().unwrap();
        assert!(matches!(outcome, ApprovalDecision::Approve));
    }

    #[tokio::test]
    async fn channel_approver_round_trips_reject_with_reason() {
        let (approver, mut operator) = ChannelApprover::new(4);
        let task = spawn_decide(&approver);
        let incoming = operator.recv().await.unwrap();
        let denial = ApprovalDecision::Reject {
            reason: "operator denied".into(),
        };
        incoming.reply.send(denial).unwrap();
        let outcome = task.await.unwrap().unwrap();
        let ApprovalDecision::Reject { reason } = outcome else {
            panic!("expected Reject, got {outcome:?}");
        };
        assert_eq!(reason, "operator denied");
    }

    #[tokio::test]
    async fn channel_approver_times_out_when_operator_silent() {
        let config = ChannelApproverConfig {
            timeout: Duration::from_millis(50),
        };
        // Keep the receiver alive so the request is queued but never answered.
        let (approver, _operator) = ChannelApprover::with_config(4, config);
        let outcome = approver
            .decide(&sample_request(), &ExecutionContext::new())
            .await
            .unwrap();
        let ApprovalDecision::Reject { reason } = outcome else {
            panic!("expected Reject(timeout), got {outcome:?}");
        };
        assert!(reason.contains("timed out"), "{reason}");
    }

    #[tokio::test]
    async fn channel_approver_propagates_cancellation() {
        let (approver, _operator) = ChannelApprover::new(4);
        let ctx = ExecutionContext::new();
        let token = ctx.cancellation().clone();
        let worker = approver.clone();
        let task = tokio::spawn(async move { worker.decide(&sample_request(), &ctx).await });
        // Let the spawned `decide` get under way before cancelling it.
        tokio::time::sleep(Duration::from_millis(10)).await;
        token.cancel();
        let outcome = task.await.unwrap();
        assert!(matches!(outcome, Err(Error::Cancelled)));
    }

    #[tokio::test]
    async fn channel_approver_errors_when_receiver_dropped_before_request() {
        let (approver, operator) = ChannelApprover::new(4);
        drop(operator);
        let failure = approver
            .decide(&sample_request(), &ExecutionContext::new())
            .await
            .unwrap_err();
        assert!(failure.to_string().contains("receiver dropped"));
    }
}