use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use synaptic_core::SynapticError;
use crate::{AgentMiddleware, ToolCallRequest, ToolCaller};
/// A human (or human-proxy) decision point consulted before a tool runs.
///
/// `HumanInTheLoopMiddleware` calls [`ApprovalCallback::approve`] for every
/// tool call it gates and interprets the result as follows:
///
/// * `Ok(true)`  — the call is forwarded to the next tool caller.
/// * `Ok(false)` — the call is skipped and a rejection message is returned
///   to the agent as a JSON string value instead of the tool's output.
/// * `Err(_)`    — the error is propagated and the tool is not executed.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait ApprovalCallback: Send + Sync {
    /// Decides whether the tool named `tool_name` may run with `arguments`.
    ///
    /// # Errors
    ///
    /// Any [`SynapticError`] returned here aborts the gated tool call.
    async fn approve(
        &self,
        tool_name: &str,
        arguments: &Value,
    ) -> Result<bool, SynapticError>;
}
/// Middleware that asks an [`ApprovalCallback`] before letting tool calls run.
///
/// Construct it with `HumanInTheLoopMiddleware::new` to gate every tool, or
/// `HumanInTheLoopMiddleware::for_tools` to gate only specific tool names.
pub struct HumanInTheLoopMiddleware {
    /// Shared approver consulted for each gated tool call.
    callback: Arc<dyn ApprovalCallback>,
    /// Names of the tools that require approval.
    ///
    /// An empty set is special: it means *every* tool requires approval.
    tools: HashSet<String>,
}
impl HumanInTheLoopMiddleware {
    /// Creates a middleware that requires approval for **every** tool call.
    ///
    /// This is the same as calling [`HumanInTheLoopMiddleware::for_tools`]
    /// with an empty list.
    pub fn new(callback: Arc<dyn ApprovalCallback>) -> Self {
        Self::for_tools(callback, Vec::new())
    }

    /// Creates a middleware that requires approval only for the named tools.
    ///
    /// Duplicate names are collapsed.
    ///
    /// Note: an empty `tools` list does **not** mean "gate nothing". An empty
    /// set is interpreted as "gate every tool", so the result behaves exactly
    /// like [`HumanInTheLoopMiddleware::new`].
    pub fn for_tools(callback: Arc<dyn ApprovalCallback>, tools: Vec<String>) -> Self {
        let gated: HashSet<String> = tools.into_iter().collect();
        Self {
            callback,
            tools: gated,
        }
    }
}
#[async_trait]
impl AgentMiddleware for HumanInTheLoopMiddleware {
    /// Gates a tool call behind the configured [`ApprovalCallback`].
    ///
    /// A call is gated when no tool names were configured (the set is empty)
    /// or when its name is in the configured set. For a gated call the
    /// callback is awaited:
    ///
    /// * approved — the request is forwarded to `next`;
    /// * rejected — `next` is never called, and the rejection is reported as
    ///   `Ok(Value::String(..))` so the agent sees a message rather than an error.
    ///
    /// Ungated calls are forwarded to `next` without consulting the callback.
    ///
    /// # Errors
    ///
    /// Propagates any error from the approval callback or from `next`.
    async fn wrap_tool_call(
        &self,
        request: ToolCallRequest,
        next: &dyn ToolCaller,
    ) -> Result<Value, SynapticError> {
        let tool_name = &request.call.name;
        let gated = self.tools.is_empty() || self.tools.contains(tool_name);

        // `&&` short-circuits, so the approver is only asked about gated calls.
        if gated && !self.callback.approve(tool_name, &request.call.arguments).await? {
            return Ok(Value::String(format!(
                "Tool call '{}' was rejected by human review.",
                tool_name
            )));
        }

        next.call(request).await
    }
}