use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use cognis_core::{CognisError, Result};
use cognis_llm::tools::{Tool, ToolInput, ToolOutput};
/// The verdict an [`Approver`] returns for a pending tool call.
///
/// Serialized as an internally tagged object whose `"type"` field holds the
/// snake_case variant name, e.g. `{"type": "reject", "reason": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Decision {
    /// Run the tool with the arguments it was originally called with.
    Approve,
    /// Refuse to run the tool; [`ApprovalGatedTool`] includes `reason` in the error it returns.
    Reject { reason: String },
    /// Run the tool, but with `args` substituted for the original arguments.
    Edit { args: serde_json::Value },
}

impl Decision {
    /// Shorthand for building [`Decision::Reject`] from anything string-like.
    pub fn reject(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        Decision::Reject { reason }
    }
}
/// Decides whether a tool call may proceed, must be refused, or should run with edited arguments.
///
/// The `Send + Sync` supertraits let an approver be shared behind an `Arc<dyn Approver>`,
/// which is how [`ApprovalGatedTool`] holds it.
#[async_trait]
pub trait Approver: Send + Sync {
    /// Inspects a pending call to `tool_name` with JSON `args` and returns a [`Decision`].
    ///
    /// # Errors
    /// Returning `Err` aborts the call; [`ApprovalGatedTool`] propagates it with `?`.
    async fn approve(&self, tool_name: &str, args: &serde_json::Value) -> Result<Decision>;
}
/// An [`Approver`] that approves every call unconditionally.
#[derive(Debug, Clone, Copy, Default)]
pub struct AutoApprove;

#[async_trait]
impl Approver for AutoApprove {
    /// Always answers [`Decision::Approve`], ignoring the tool name and arguments.
    async fn approve(&self, _tool_name: &str, _args: &serde_json::Value) -> Result<Decision> {
        Ok(Decision::Approve)
    }
}
/// An [`Approver`] that refuses every call with one fixed reason.
#[derive(Debug, Clone)]
pub struct RejectAll {
    /// Message attached to every [`Decision::Reject`] this approver produces.
    reason: String,
}

impl RejectAll {
    /// Creates an approver that rejects every call with `reason`.
    pub fn new(reason: impl Into<String>) -> Self {
        Self {
            reason: reason.into(),
        }
    }
}

impl Default for RejectAll {
    /// Rejects with a message stating that no approver was configured.
    fn default() -> Self {
        Self::new("approval required but no approver configured")
    }
}
#[async_trait]
impl Approver for RejectAll {
    /// Always answers [`Decision::Reject`] carrying this approver's configured reason.
    async fn approve(&self, _tool_name: &str, _args: &serde_json::Value) -> Result<Decision> {
        let reason = self.reason.clone();
        Ok(Decision::Reject { reason })
    }
}
/// An [`Approver`] that approves only tools whose name appears on a fixed list.
#[derive(Debug, Clone)]
pub struct AllowList {
    /// Tool names that are approved; matching is exact (plain string equality).
    allowed: Vec<String>,
}

impl AllowList {
    /// Builds an allow list from any iterable of string-like tool names.
    ///
    /// An empty iterable yields a list that rejects every tool.
    pub fn new<I, S>(allowed: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            allowed: allowed.into_iter().map(Into::into).collect(),
        }
    }
}
#[async_trait]
impl Approver for AllowList {
    /// Approves `tool_name` if it exactly matches an entry on the list; rejects it otherwise.
    async fn approve(&self, tool_name: &str, _args: &serde_json::Value) -> Result<Decision> {
        let listed = self.allowed.iter().any(|entry| entry == tool_name);
        if !listed {
            return Ok(Decision::reject(format!(
                "tool `{tool_name}` is not on the allow list"
            )));
        }
        Ok(Decision::Approve)
    }
}
/// Wraps a [`Tool`] so that every invocation is first put to an [`Approver`].
///
/// The wrapper reports the inner tool's name, description, argument schema and
/// `return_direct` flag, so it can stand in wherever the inner tool was used.
pub struct ApprovalGatedTool {
    /// The tool that actually runs once a call is approved (or edited).
    inner: Arc<dyn Tool>,
    /// Consulted before every run of `inner`.
    approver: Arc<dyn Approver>,
}

impl ApprovalGatedTool {
    /// Gates `inner` behind `approver`.
    pub fn new(inner: Arc<dyn Tool>, approver: Arc<dyn Approver>) -> Self {
        ApprovalGatedTool { inner, approver }
    }
}
#[async_trait]
impl Tool for ApprovalGatedTool {
    fn name(&self) -> &str {
        self.inner.name()
    }

    fn description(&self) -> &str {
        self.inner.description()
    }

    fn args_schema(&self) -> Option<serde_json::Value> {
        self.inner.args_schema()
    }

    fn return_direct(&self) -> bool {
        self.inner.return_direct()
    }

    /// Consults the approver, then runs, refuses, or runs the inner tool with edited args.
    ///
    /// The approver sees the call's arguments as JSON (via `ToolInput::into_json`). On
    /// [`Decision::Edit`] the replacement JSON is converted back as follows:
    /// - an object becomes `ToolInput::Structured` with the same fields;
    /// - a string becomes `ToolInput::Text` holding the bare string (no JSON quotes);
    /// - any other value becomes `ToolInput::Text` holding its JSON rendering.
    ///
    /// # Errors
    /// Propagates any error from the approver or the inner tool. Returns
    /// `CognisError::Tool`, named after the inner tool, when the approver rejects the call.
    async fn _run(&self, input: ToolInput) -> Result<ToolOutput> {
        let tool_name = self.inner.name();
        // `input` is still needed for the approved path, so hand the approver a copy.
        let args_json = input.clone().into_json();
        match self.approver.approve(tool_name, &args_json).await? {
            Decision::Approve => self.inner._run(input).await,
            Decision::Reject { reason } => Err(CognisError::Tool {
                name: tool_name.to_string(),
                reason: format!("rejected by approver: {reason}"),
            }),
            Decision::Edit { args } => {
                let edited = match args {
                    serde_json::Value::Object(fields) => {
                        let map: std::collections::HashMap<String, serde_json::Value> =
                            fields.into_iter().collect();
                        ToolInput::Structured(map)
                    }
                    // `Value::to_string` renders JSON, so a string edit would otherwise
                    // reach the tool wrapped in literal double quotes.
                    serde_json::Value::String(text) => ToolInput::Text(text),
                    other => ToolInput::Text(other.to_string()),
                };
                self.inner._run(edited).await
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that hands its input straight back as JSON content.
    struct Echo;

    #[async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "echoes"
        }

        fn args_schema(&self) -> Option<serde_json::Value> {
            None
        }

        async fn _run(&self, input: ToolInput) -> Result<ToolOutput> {
            let echoed = input.into_json();
            Ok(ToolOutput::Content(echoed))
        }
    }

    /// Approver that always swaps the call's arguments for `{"replaced": true}`.
    struct Editor;

    #[async_trait]
    impl Approver for Editor {
        async fn approve(&self, _tool_name: &str, _args: &serde_json::Value) -> Result<Decision> {
            let replacement = serde_json::json!({"replaced": true});
            Ok(Decision::Edit { args: replacement })
        }
    }

    /// Wraps an `Echo` tool behind `approver`.
    fn gated(approver: impl Approver + 'static) -> ApprovalGatedTool {
        ApprovalGatedTool::new(Arc::new(Echo), Arc::new(approver))
    }

    #[tokio::test]
    async fn auto_approve_passes_through() {
        let tool = gated(AutoApprove);
        let output = tool._run(ToolInput::Text("hi".into())).await.unwrap();
        assert_eq!(output.as_string(), "\"hi\"");
    }

    #[tokio::test]
    async fn reject_all_blocks() {
        let tool = gated(RejectAll::default());
        let failure = tool._run(ToolInput::Text("hi".into())).await.unwrap_err();
        assert!(failure.to_string().contains("rejected"));
    }

    #[tokio::test]
    async fn allow_list_filters_by_tool_name() {
        let permitted = gated(AllowList::new(["echo"]));
        assert!(permitted._run(ToolInput::Text("a".into())).await.is_ok());

        let denied = gated(AllowList::new(["other"]));
        assert!(denied._run(ToolInput::Text("a".into())).await.is_err());
    }

    #[tokio::test]
    async fn edit_substitutes_args() {
        let tool = gated(Editor);
        let output = tool._run(ToolInput::Text("ignored".into())).await.unwrap();
        let value: serde_json::Value = if let ToolOutput::Content(inner) = output {
            inner
        } else {
            panic!()
        };
        assert_eq!(value["replaced"], true);
    }
}