use crate::types::events::{ApprovalRequestEvent, PatchApprovalRequestEvent};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Decision returned for an approval request.
///
/// Serialized in `snake_case`: `"approved"`, `"approved_for_session"`, `"denied"`,
/// `"abort"`. [`Default`] yields [`ApprovalDecision::Denied`], so code that falls back
/// to `ApprovalDecision::default()` denies rather than approves.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApprovalDecision {
    /// Approve the request. Wire value: `"approved"`.
    Approved,
    /// Approve for the session. Wire value: `"approved_for_session"`.
    ///
    /// NOTE(review): presumably also covers similar later requests in the same
    /// session; the exact scope is decided by the receiving side — confirm.
    ApprovedForSession,
    /// Reject the request. This is the default variant. Wire value: `"denied"`.
    #[default]
    Denied,
    /// Wire value: `"abort"`.
    ///
    /// NOTE(review): presumably aborts more than just this request (e.g. the
    /// current turn) — confirm against the consumer of the response.
    Abort,
}
/// Result produced by an approval callback.
///
/// Carries the approver's [`ApprovalDecision`] and, when the `unstable` feature is
/// enabled, an optional replacement command. The [`Default`] outcome is a denial
/// with no replacement command.
#[derive(Debug, Clone, Default)]
pub struct ApprovalOutcome {
    /// The approver's decision.
    pub decision: ApprovalDecision,
    /// Replacement command text attached via `with_updated_command`.
    ///
    /// NOTE(review): how this value is applied is not visible here; the
    /// `ApprovalResponse` wire message does not carry it.
    #[cfg(feature = "unstable")]
    pub updated_command: Option<String>,
}

impl ApprovalOutcome {
    /// Creates an outcome carrying only `decision` (no replacement command).
    #[must_use]
    pub fn new(decision: ApprovalDecision) -> Self {
        Self {
            decision,
            #[cfg(feature = "unstable")]
            updated_command: None,
        }
    }

    /// Attaches `command` as the replacement command, overwriting any previous one,
    /// and returns the updated outcome.
    #[cfg(feature = "unstable")]
    #[must_use]
    pub fn with_updated_command(mut self, command: impl Into<String>) -> Self {
        self.updated_command = Some(command.into());
        self
    }
}

impl From<ApprovalDecision> for ApprovalOutcome {
    /// Equivalent to [`ApprovalOutcome::new`].
    fn from(decision: ApprovalDecision) -> Self {
        // Delegate so the feature-gated field list is spelled out in exactly one place.
        Self::new(decision)
    }
}
/// Input passed to an [`ApprovalCallback`].
#[derive(Debug, Clone)]
pub struct ApprovalContext {
    /// The request event the callback must decide on.
    pub request: ApprovalRequestEvent,
    /// Thread identifier associated with the request, if any.
    /// NOTE(review): when this is `None` is not visible here — confirm with callers.
    pub thread_id: Option<String>,
}
/// Callback invoked to decide an approval request.
///
/// Takes an [`ApprovalContext`] and returns a boxed `Send` future that resolves to
/// an [`ApprovalOutcome`]. The closure lives behind an `Arc` and is `Fn + Send + Sync`,
/// so a single callback can be cloned cheaply, invoked repeatedly, and shared across
/// threads.
pub type ApprovalCallback = Arc<
    dyn Fn(ApprovalContext) -> Pin<Box<dyn Future<Output = ApprovalOutcome> + Send>> + Send + Sync,
>;
/// Input passed to a [`PatchApprovalCallback`].
#[derive(Debug, Clone)]
pub struct PatchApprovalContext {
    /// The patch approval request event the callback must decide on.
    pub request: PatchApprovalRequestEvent,
    /// Thread identifier associated with the request, if any.
    /// NOTE(review): when this is `None` is not visible here — confirm with callers.
    pub thread_id: Option<String>,
}
/// Callback invoked to decide a patch approval request.
///
/// Takes a [`PatchApprovalContext`] and returns a boxed `Send` future that resolves
/// to an [`ApprovalOutcome`]. Like [`ApprovalCallback`], the closure lives behind an
/// `Arc` and is `Fn + Send + Sync`, so it can be invoked repeatedly and shared across
/// threads.
pub type PatchApprovalCallback = Arc<
    dyn Fn(PatchApprovalContext) -> Pin<Box<dyn Future<Output = ApprovalOutcome> + Send>>
        + Send
        + Sync,
>;
/// Outgoing wire message answering an exec approval request.
///
/// Fields serialize in declaration order (`op`, `id`, `decision`, `turn_id`);
/// `turn_id` is omitted entirely when it is `None`.
#[derive(Debug, Serialize)]
pub(crate) struct ApprovalResponse {
    /// Operation tag; `"ExecApproval"` when built with [`ApprovalResponse::new`].
    pub op: String,
    /// Identifier of the request being answered.
    pub id: String,
    /// The decision being reported.
    pub decision: ApprovalDecision,
    /// Optional turn identifier; left `None` by `new`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub turn_id: Option<String>,
}

impl ApprovalResponse {
    /// Builds an `ExecApproval` response for `request_id` carrying `decision`,
    /// with no `turn_id`.
    pub fn new(request_id: String, decision: ApprovalDecision) -> Self {
        let op = String::from("ExecApproval");
        Self {
            id: request_id,
            turn_id: None,
            op,
            decision,
        }
    }
}
/// Outgoing wire message answering a patch approval request.
///
/// Fields serialize in declaration order (`op`, `id`, `decision`).
#[derive(Debug, Serialize)]
pub(crate) struct PatchApprovalResponse {
    /// Operation tag; `"PatchApproval"` when built with [`PatchApprovalResponse::new`].
    pub op: String,
    /// Identifier of the request being answered.
    pub id: String,
    /// The decision being reported.
    pub decision: ApprovalDecision,
}

impl PatchApprovalResponse {
    /// Builds a `PatchApproval` response for `request_id` carrying `decision`.
    pub fn new(request_id: String, decision: ApprovalDecision) -> Self {
        Self {
            decision,
            id: request_id,
            op: "PatchApproval".to_owned(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` into a `serde_json::Value` so assertions compare the exact
    /// structure instead of substrings (e.g. `"approved"` is a substring of
    /// `"approved_for_session"`).
    fn to_json<T: serde::Serialize + ?Sized>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).expect("test value must serialize")
    }

    /// Every decision variant, used by the exhaustive wire-format tests below.
    const ALL_DECISIONS: [ApprovalDecision; 4] = [
        ApprovalDecision::Approved,
        ApprovalDecision::ApprovedForSession,
        ApprovalDecision::Denied,
        ApprovalDecision::Abort,
    ];

    #[test]
    fn approval_response_serializes() {
        let response = ApprovalResponse::new("req-1".into(), ApprovalDecision::Approved);
        // Exact match also proves `turn_id` is omitted when `None`.
        assert_eq!(
            to_json(&response),
            serde_json::json!({"op": "ExecApproval", "id": "req-1", "decision": "approved"})
        );
    }

    #[test]
    fn approval_response_includes_turn_id_when_set() {
        let mut response = ApprovalResponse::new("req-1".into(), ApprovalDecision::Denied);
        response.turn_id = Some("turn-7".into());
        assert_eq!(
            to_json(&response),
            serde_json::json!({
                "op": "ExecApproval",
                "id": "req-1",
                "decision": "denied",
                "turn_id": "turn-7"
            })
        );
    }

    #[test]
    fn approval_decision_wire_names() {
        let expected = ["approved", "approved_for_session", "denied", "abort"];
        for (decision, wire) in ALL_DECISIONS.iter().zip(expected) {
            assert_eq!(to_json(decision), serde_json::json!(wire));
        }
    }

    #[test]
    fn approval_decision_round_trip() {
        for decision in &ALL_DECISIONS {
            let json = serde_json::to_string(decision).unwrap();
            let parsed: ApprovalDecision = serde_json::from_str(&json).unwrap();
            // Compare variants without requiring `PartialEq` on the enum.
            assert_eq!(
                std::mem::discriminant(&parsed),
                std::mem::discriminant(decision),
                "round trip changed variant for {json}"
            );
        }
    }

    #[test]
    fn patch_approval_response_serializes() {
        let response = PatchApprovalResponse::new("req-2".into(), ApprovalDecision::Denied);
        assert_eq!(
            to_json(&response),
            serde_json::json!({"op": "PatchApproval", "id": "req-2", "decision": "denied"})
        );
    }

    #[test]
    fn default_decision_is_denied() {
        assert!(matches!(
            ApprovalDecision::default(),
            ApprovalDecision::Denied
        ));
    }

    #[test]
    fn approval_outcome_from_decision() {
        let outcome: ApprovalOutcome = ApprovalDecision::Approved.into();
        assert!(matches!(outcome.decision, ApprovalDecision::Approved));
    }

    #[cfg(feature = "unstable")]
    #[test]
    fn approval_outcome_with_updated_command() {
        let outcome = ApprovalOutcome::new(ApprovalDecision::Approved)
            .with_updated_command("safe-command --flag");
        assert!(matches!(outcome.decision, ApprovalDecision::Approved));
        assert_eq!(
            outcome.updated_command,
            Some("safe-command --flag".to_string())
        );
    }

    #[cfg(feature = "unstable")]
    #[test]
    fn approval_outcome_wire_compatible() {
        let outcome =
            ApprovalOutcome::new(ApprovalDecision::Approved).with_updated_command("overridden");
        let response = ApprovalResponse::new("req-1".into(), outcome.decision);
        // The updated command must not leak into the wire message.
        assert_eq!(
            to_json(&response),
            serde_json::json!({"op": "ExecApproval", "id": "req-1", "decision": "approved"})
        );
    }
}