use async_trait::async_trait;
use serde::{Serialize, Serializer};
use serde_json::Value;
use crate::ctx::HookCtx;
use crate::errors::ToolError;
/// Decision returned by a hook for a single invocation.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it; new decisions can be added later.
///
/// The hand-written `Serialize` impl encodes every variant as a flat map that
/// always carries an `abort` boolean and a `decision` string, followed by the
/// variant-specific fields.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookOutcome {
    /// Proceed unchanged. Serialized as `abort: false`, `decision: "allow"`.
    Continue,
    /// Legacy stop signal. Serialized as `abort: true`, `decision: "block"`
    /// plus `reason`, i.e. identical on the wire to a [`HookOutcome::Block`]
    /// whose `do_not_reply_again` is `false`.
    Abort {
        /// Human-readable explanation, always serialized.
        reason: String,
    },
    /// Stop with a reason. Serialized as `abort: true`, `decision: "block"`.
    Block {
        /// Human-readable explanation, always serialized.
        reason: String,
        /// Serialized as `do_not_reply_again: true` only when set; omitted
        /// otherwise. NOTE(review): presumably an anti-loop hint asking the
        /// host not to reply again — confirm with the consumer.
        do_not_reply_again: bool,
    },
    /// Carry a rewritten body. Serialized as `abort: false`,
    /// `decision: "transform"`. NOTE(review): presumably the host substitutes
    /// `transformed_body` for the original — confirm with the consumer.
    Transform {
        /// Replacement body, always serialized.
        transformed_body: String,
        /// Optional explanation; the `reason` key is omitted when `None`.
        reason: Option<String>,
        /// Serialized as `do_not_reply_again: true` only when set; omitted
        /// otherwise.
        do_not_reply_again: bool,
    },
}

impl HookOutcome {
    /// Builds a [`HookOutcome::Block`] with `reason` and
    /// `do_not_reply_again: false`.
    #[must_use]
    pub fn block(reason: impl Into<String>) -> Self {
        HookOutcome::Block {
            reason: reason.into(),
            do_not_reply_again: false,
        }
    }

    /// Builds a [`HookOutcome::Transform`] carrying `body`, with no reason and
    /// `do_not_reply_again: false`.
    #[must_use]
    pub fn transform(body: impl Into<String>) -> Self {
        HookOutcome::Transform {
            transformed_body: body.into(),
            reason: None,
            do_not_reply_again: false,
        }
    }
}
/// Serializes a [`HookOutcome`] as a flat map.
///
/// Every variant writes `abort` (bool) and then `decision` (`"allow"`,
/// `"block"` or `"transform"`), followed by its own fields. Optional fields
/// (`reason` on `Transform`, `do_not_reply_again` on `Block`/`Transform`) are
/// only written when `Some` / `true`; their absence means `None` / `false`.
impl Serialize for HookOutcome {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeMap;

        // Announce the exact entry count up front: some serializers only
        // support maps whose length is known in advance and reject `None`.
        // This MUST agree with the entries written in the match below.
        let len = match self {
            HookOutcome::Continue => 2,
            HookOutcome::Abort { .. } => 3,
            HookOutcome::Block {
                do_not_reply_again, ..
            } => 3 + usize::from(*do_not_reply_again),
            HookOutcome::Transform {
                reason,
                do_not_reply_again,
                ..
            } => 3 + usize::from(reason.is_some()) + usize::from(*do_not_reply_again),
        };

        let mut map = serializer.serialize_map(Some(len))?;
        match self {
            HookOutcome::Continue => {
                map.serialize_entry("abort", &false)?;
                map.serialize_entry("decision", "allow")?;
            }
            HookOutcome::Abort { reason } => {
                map.serialize_entry("abort", &true)?;
                map.serialize_entry("decision", "block")?;
                map.serialize_entry("reason", reason)?;
            }
            HookOutcome::Block {
                reason,
                do_not_reply_again,
            } => {
                map.serialize_entry("abort", &true)?;
                map.serialize_entry("decision", "block")?;
                map.serialize_entry("reason", reason)?;
                if *do_not_reply_again {
                    map.serialize_entry("do_not_reply_again", &true)?;
                }
            }
            HookOutcome::Transform {
                transformed_body,
                reason,
                do_not_reply_again,
            } => {
                map.serialize_entry("abort", &false)?;
                map.serialize_entry("decision", "transform")?;
                map.serialize_entry("transformed_body", transformed_body)?;
                if let Some(r) = reason {
                    map.serialize_entry("reason", r)?;
                }
                if *do_not_reply_again {
                    map.serialize_entry("do_not_reply_again", &true)?;
                }
            }
        }
        map.end()
    }
}
/// An asynchronous hook: receives JSON arguments plus a [`HookCtx`] and
/// decides what happens next.
///
/// Implementations return the [`HookOutcome`] to apply, or a [`ToolError`]
/// when the hook itself fails. Every implementation must be `Send + Sync`.
#[async_trait]
pub trait HookHandler: Send + Sync {
    /// Runs the hook once. NOTE(review): the schema of `args` is set by the
    /// caller of the hook point — confirm per hook.
    async fn call(&self, args: Value, ctx: HookCtx) -> Result<HookOutcome, ToolError>;
}

/// Blanket implementation so that any callable `Fn(Value, HookCtx)` returning
/// a `Send` future of `Result<HookOutcome, ToolError>` (for example an
/// `async fn` with that signature) can be used directly as a [`HookHandler`].
#[async_trait]
impl<Func, Fut> HookHandler for Func
where
    Func: Fn(Value, HookCtx) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<HookOutcome, ToolError>> + Send,
{
    async fn call(&self, args: Value, ctx: HookCtx) -> Result<HookOutcome, ToolError> {
        // Calling the wrapped callable only builds its future; awaiting drives it.
        let pending = (self)(args, ctx);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn continue_serialises_to_abort_false() {
        let v = serde_json::to_value(&HookOutcome::Continue).unwrap();
        assert_eq!(v["abort"], false);
        assert!(v.get("reason").is_none());
    }

    #[test]
    fn abort_serialises_with_reason() {
        let v = serde_json::to_value(&HookOutcome::Abort {
            reason: "spam".into(),
        })
        .unwrap();
        assert_eq!(v["abort"], true);
        assert_eq!(v["reason"], "spam");
    }

    #[tokio::test]
    async fn blanket_impl_async_fn() {
        async fn h(_args: Value, _ctx: HookCtx) -> Result<HookOutcome, ToolError> {
            Ok(HookOutcome::Abort {
                reason: "policy".into(),
            })
        }
        let ctx = HookCtx {
            agent_id: "a".into(),
            binding: None,
            inbound: None,
            #[cfg(feature = "admin")]
            admin: None,
        };
        let out = HookHandler::call(&h, Value::Null, ctx).await.unwrap();
        assert!(matches!(out, HookOutcome::Abort { .. }));
    }

    // Deliberately has no wildcard arm: inside this crate `#[non_exhaustive]`
    // does not apply, so adding a variant must break this test at compile time.
    #[test]
    fn outcome_pattern_match_is_exhaustive_internal_use() {
        let o = HookOutcome::Continue;
        match o {
            HookOutcome::Continue => {}
            HookOutcome::Abort { .. } => {}
            HookOutcome::Block { .. } => {}
            HookOutcome::Transform { .. } => {}
        }
    }

    #[test]
    fn block_serialises_with_decision_field() {
        let out = HookOutcome::Block {
            reason: "anti-loop".into(),
            do_not_reply_again: false,
        };
        let v = serde_json::to_value(&out).unwrap();
        assert_eq!(v["decision"], "block");
        assert_eq!(v["reason"], "anti-loop");
        assert_eq!(v["abort"], true);
        assert!(v.get("do_not_reply_again").is_none());
    }

    #[test]
    fn block_with_anti_loop_flag_serialises_field() {
        let out = HookOutcome::Block {
            reason: "loop".into(),
            do_not_reply_again: true,
        };
        let v = serde_json::to_value(&out).unwrap();
        assert_eq!(v["do_not_reply_again"], true);
    }

    #[test]
    fn transform_serialises_with_body_and_no_abort() {
        let out = HookOutcome::Transform {
            transformed_body: "Hasta luego".into(),
            reason: Some("opt-out keyword".into()),
            do_not_reply_again: true,
        };
        let v = serde_json::to_value(&out).unwrap();
        assert_eq!(v["decision"], "transform");
        assert_eq!(v["transformed_body"], "Hasta luego");
        assert_eq!(v["reason"], "opt-out keyword");
        assert_eq!(v["do_not_reply_again"], true);
        assert_eq!(v["abort"], false);
    }

    #[test]
    fn transform_without_reason_omits_field() {
        let out = HookOutcome::Transform {
            transformed_body: "redacted".into(),
            reason: None,
            do_not_reply_again: false,
        };
        let v = serde_json::to_value(&out).unwrap();
        assert!(v.get("reason").is_none());
        assert!(v.get("do_not_reply_again").is_none());
    }

    #[test]
    fn continue_serialises_decision_allow() {
        let v = serde_json::to_value(&HookOutcome::Continue).unwrap();
        assert_eq!(v["decision"], "allow");
        assert_eq!(v["abort"], false);
    }

    #[test]
    fn legacy_abort_serialises_decision_block() {
        let out = HookOutcome::Abort {
            reason: "spam".into(),
        };
        let v = serde_json::to_value(&out).unwrap();
        assert_eq!(v["decision"], "block");
        assert_eq!(v["abort"], true);
        assert_eq!(v["reason"], "spam");
    }

    // Pins the exact key set of every variant so that optional fields cannot
    // silently appear or disappear from the serialized map.
    #[test]
    fn serialised_maps_contain_exactly_expected_keys() {
        fn keys(o: &HookOutcome) -> Vec<String> {
            let v = serde_json::to_value(o).unwrap();
            let mut k: Vec<String> = v.as_object().unwrap().keys().cloned().collect();
            k.sort();
            k
        }

        assert_eq!(keys(&HookOutcome::Continue), ["abort", "decision"]);
        assert_eq!(
            keys(&HookOutcome::Abort {
                reason: "r".into()
            }),
            ["abort", "decision", "reason"]
        );
        assert_eq!(
            keys(&HookOutcome::block("r")),
            ["abort", "decision", "reason"]
        );
        assert_eq!(
            keys(&HookOutcome::Block {
                reason: "r".into(),
                do_not_reply_again: true,
            }),
            ["abort", "decision", "do_not_reply_again", "reason"]
        );
        assert_eq!(
            keys(&HookOutcome::transform("b")),
            ["abort", "decision", "transformed_body"]
        );
        assert_eq!(
            keys(&HookOutcome::Transform {
                transformed_body: "b".into(),
                reason: Some("r".into()),
                do_not_reply_again: true,
            }),
            [
                "abort",
                "decision",
                "do_not_reply_again",
                "reason",
                "transformed_body"
            ]
        );
    }

    #[test]
    fn block_helper_constructor_defaults_to_no_anti_loop() {
        let out = HookOutcome::block("rate-limit");
        assert_eq!(
            out,
            HookOutcome::Block {
                reason: "rate-limit".into(),
                do_not_reply_again: false,
            }
        );
    }

    #[test]
    fn transform_helper_constructor_minimal() {
        let out = HookOutcome::transform("[redacted]");
        assert_eq!(
            out,
            HookOutcome::Transform {
                transformed_body: "[redacted]".into(),
                reason: None,
                do_not_reply_again: false,
            }
        );
    }
}