mod callback_adapter;
mod circuit_breaker;
mod context_editing;
mod human_in_the_loop;
mod model_call_limit;
mod model_fallback;
mod security;
mod ssrf_guard;
mod summarization;
mod todo_list;
mod tool_call_limit;
mod tool_retry;
pub use callback_adapter::CallbackMiddleware;
pub use circuit_breaker::{CircuitBreakerConfig, CircuitBreakerMiddleware, CircuitState};
pub use context_editing::{ContextEditingMiddleware, ContextStrategy};
pub use human_in_the_loop::{ApprovalCallback, HumanInTheLoopMiddleware};
pub use model_call_limit::ModelCallLimitMiddleware;
pub use model_fallback::ModelFallbackMiddleware;
pub use security::{
ConfirmationPolicy, RiskLevel, RuleBasedAnalyzer, SecurityAnalyzer,
SecurityConfirmationCallback, SecurityMiddleware, ThresholdConfirmationPolicy,
};
pub use ssrf_guard::{SsrfGuardConfig, SsrfGuardMiddleware};
pub use summarization::SummarizationMiddleware;
pub use todo_list::TodoListMiddleware;
pub use tool_call_limit::ToolCallLimitMiddleware;
pub use tool_retry::ToolRetryMiddleware;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use synaptic_core::{
ChatModel, ChatRequest, ChatResponse, Message, SynapticError, TokenUsage, ToolCall, ToolChoice,
ToolDefinition,
};
/// Request handed to the model layer before it is converted into a
/// [`ChatRequest`] by [`ModelRequest::to_chat_request`].
#[derive(Debug, Clone)]
pub struct ModelRequest {
/// Conversation messages; copied into the chat request after the optional
/// system prompt.
pub messages: Vec<Message>,
/// Tool definitions forwarded through `ChatRequest::with_tools`.
pub tools: Vec<ToolDefinition>,
/// Optional tool-choice setting; forwarded only when it is `Some`.
pub tool_choice: Option<ToolChoice>,
/// Optional system prompt; when `Some`, it becomes the first message of the
/// chat request.
pub system_prompt: Option<String>,
}
impl ModelRequest {
pub fn to_chat_request(&self) -> ChatRequest {
let mut messages = Vec::new();
if let Some(ref prompt) = self.system_prompt {
messages.push(Message::system(prompt));
}
messages.extend(self.messages.clone());
let mut req = ChatRequest::new(messages).with_tools(self.tools.clone());
if let Some(ref choice) = self.tool_choice {
req = req.with_tool_choice(choice.clone());
}
req
}
}
/// Model output as seen by middleware: the produced message plus optional
/// token usage. Built from a [`ChatResponse`] via `From`.
#[derive(Debug, Clone)]
pub struct ModelResponse {
/// Message taken from the chat response.
pub message: Message,
/// Token usage taken from the chat response, if it carried any.
pub usage: Option<TokenUsage>,
}
impl From<ChatResponse> for ModelResponse {
fn from(resp: ChatResponse) -> Self {
Self {
message: resp.message,
usage: resp.usage,
}
}
}
/// Request passed through the tool-call middleware chain; wraps a single
/// [`ToolCall`].
#[derive(Debug, Clone)]
pub struct ToolCallRequest {
/// The tool call to execute.
pub call: ToolCall,
}
#[derive(Debug, Clone)]
pub struct FileOp {
pub path: String,
pub kind: FileOpKind,
}
/// Kind of access a [`FileOp`] performs.
///
/// Implements `Hash` in addition to `Eq`, so it can be used as a key in
/// `HashMap` / `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FileOpKind {
    /// Read access.
    Read,
    /// Write access.
    Write,
    /// Deletion.
    Delete,
}
/// Outcome of a file operation, as passed to the `after_file_op` middleware
/// hook.
///
/// Compares by value, so it can be used with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileOpResult {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Error description, if any. Presumably set when `success` is `false`;
    /// this type does not enforce that.
    pub error: Option<String>,
}
/// Verdict returned by the `before_file_op` middleware hook.
///
/// Compares by value, so callers can write `decision == FileOpDecision::Allow`
/// or use `assert_eq!` instead of `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileOpDecision {
    /// The operation may proceed.
    Allow,
    /// The operation is rejected; the payload is the message supplied by the
    /// middleware that denied it.
    Deny(String),
}
/// Description of a command invocation, as passed to the
/// `before_command` / `after_command` middleware hooks.
///
/// Compares by value, so it can be used with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandOp {
    /// Command to run.
    pub command: String,
    /// Arguments for the command, in order.
    pub args: Vec<String>,
    /// Working directory for the command, if one is specified.
    pub working_dir: Option<String>,
}
/// Outcome of a command invocation, as passed to the `after_command`
/// middleware hook.
///
/// Compares by value, so it can be used with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandResult {
    /// Exit code recorded for the command.
    pub exit_code: i32,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
}
/// Verdict returned by the `before_command` middleware hook.
///
/// Compares by value, so callers can write `decision == CommandDecision::Allow`
/// or use `assert_eq!` instead of `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandDecision {
    /// The command may run.
    Allow,
    /// The command is rejected; the payload is the message supplied by the
    /// middleware that denied it.
    Deny(String),
}
/// Something that can answer a [`ModelRequest`].
///
/// This is the `next` handle given to `AgentMiddleware::wrap_model_call`, and
/// also the interface of the base caller at the end of the chain.
#[async_trait]
pub trait ModelCaller: Send + Sync {
/// Processes `request` and returns the model response or an error.
async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError>;
}
/// Something that can execute a [`ToolCallRequest`].
///
/// This is the `next` handle given to `AgentMiddleware::wrap_tool_call`, and
/// also the interface of the base caller at the end of the chain.
#[async_trait]
pub trait ToolCaller: Send + Sync {
/// Executes `request` and returns the tool's JSON result or an error.
async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError>;
}
/// Hook interface for agent middleware.
///
/// Every method has a default implementation that does nothing (or simply
/// delegates / allows), so implementors override only the hooks they need.
/// [`MiddlewareChain`] runs the `before_*` hooks in registration order and the
/// `after_*` hooks in reverse registration order.
#[async_trait]
pub trait AgentMiddleware: Send + Sync {
/// Hook for `MiddlewareChain::run_before_agent`; may mutate the message
/// list. Default: no-op.
async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
Ok(())
}
/// Hook for `MiddlewareChain::run_after_agent`; may mutate the message
/// list. Default: no-op.
async fn after_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
Ok(())
}
/// Hook run before a model call; may mutate the outgoing request.
/// Default: no-op.
async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
Ok(())
}
/// Hook run after a model call; sees the request read-only and may mutate
/// the response. Default: no-op.
async fn after_model(
&self,
_request: &ModelRequest,
_response: &mut ModelResponse,
) -> Result<(), SynapticError> {
Ok(())
}
/// Wraps the model call itself. `next` is the remainder of the chain; an
/// implementation decides whether and how to invoke it.
/// Default: forwards `request` to `next` unchanged.
async fn wrap_model_call(
&self,
request: ModelRequest,
next: &dyn ModelCaller,
) -> Result<ModelResponse, SynapticError> {
next.call(request).await
}
/// Wraps a tool call. `next` is the remainder of the chain; an
/// implementation decides whether and how to invoke it.
/// Default: forwards `request` to `next` unchanged.
async fn wrap_tool_call(
&self,
request: ToolCallRequest,
next: &dyn ToolCaller,
) -> Result<Value, SynapticError> {
next.call(request).await
}
/// Hook consulted before a file operation. Returning anything other than
/// `FileOpDecision::Allow` makes `MiddlewareChain::run_before_file_op` stop
/// and return that decision. Default: allow.
async fn before_file_op(&self, _op: &FileOp) -> Result<FileOpDecision, SynapticError> {
Ok(FileOpDecision::Allow)
}
/// Hook run after a file operation with its result. Default: no-op.
async fn after_file_op(
&self,
_op: &FileOp,
_result: &FileOpResult,
) -> Result<(), SynapticError> {
Ok(())
}
/// Hook consulted before a command. Returning anything other than
/// `CommandDecision::Allow` makes `MiddlewareChain::run_before_command`
/// stop and return that decision. Default: allow.
async fn before_command(&self, _cmd: &CommandOp) -> Result<CommandDecision, SynapticError> {
Ok(CommandDecision::Allow)
}
/// Hook run after a command with its result. Default: no-op.
async fn after_command(
&self,
_cmd: &CommandOp,
_result: &CommandResult,
) -> Result<(), SynapticError> {
Ok(())
}
}
/// Ordered collection of [`AgentMiddleware`] instances plus the logic to run
/// their hooks.
pub struct MiddlewareChain {
/// Middleware in registration order; `before_*` hooks iterate forwards and
/// `after_*` hooks iterate in reverse.
middlewares: Vec<Arc<dyn AgentMiddleware>>,
}
impl MiddlewareChain {
pub fn new(middlewares: Vec<Arc<dyn AgentMiddleware>>) -> Self {
Self { middlewares }
}
pub fn is_empty(&self) -> bool {
self.middlewares.is_empty()
}
pub async fn run_before_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
for mw in &self.middlewares {
mw.before_agent(messages).await?;
}
Ok(())
}
pub async fn run_after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
for mw in self.middlewares.iter().rev() {
mw.after_agent(messages).await?;
}
Ok(())
}
pub async fn run_before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
for mw in &self.middlewares {
mw.before_model(request).await?;
}
Ok(())
}
pub async fn run_after_model(
&self,
request: &ModelRequest,
response: &mut ModelResponse,
) -> Result<(), SynapticError> {
for mw in self.middlewares.iter().rev() {
mw.after_model(request, response).await?;
}
Ok(())
}
pub async fn call_model(
&self,
mut request: ModelRequest,
base: &dyn ModelCaller,
) -> Result<ModelResponse, SynapticError> {
self.run_before_model(&mut request).await?;
let mut response = if self.middlewares.is_empty() {
base.call(request.clone()).await?
} else {
let chain = WrapModelChain {
middlewares: &self.middlewares,
index: 0,
base,
};
chain.call(request.clone()).await?
};
self.run_after_model(&request, &mut response).await?;
Ok(response)
}
pub async fn call_tool(
&self,
request: ToolCallRequest,
base: &dyn ToolCaller,
) -> Result<Value, SynapticError> {
if self.middlewares.is_empty() {
base.call(request).await
} else {
let chain = WrapToolChain {
middlewares: &self.middlewares,
index: 0,
base,
};
chain.call(request).await
}
}
pub async fn run_before_file_op(&self, op: &FileOp) -> Result<FileOpDecision, SynapticError> {
for mw in &self.middlewares {
match mw.before_file_op(op).await? {
FileOpDecision::Allow => continue,
deny => return Ok(deny),
}
}
Ok(FileOpDecision::Allow)
}
pub async fn run_after_file_op(
&self,
op: &FileOp,
result: &FileOpResult,
) -> Result<(), SynapticError> {
for mw in self.middlewares.iter().rev() {
mw.after_file_op(op, result).await?;
}
Ok(())
}
pub async fn run_before_command(
&self,
cmd: &CommandOp,
) -> Result<CommandDecision, SynapticError> {
for mw in &self.middlewares {
match mw.before_command(cmd).await? {
CommandDecision::Allow => continue,
deny => return Ok(deny),
}
}
Ok(CommandDecision::Allow)
}
pub async fn run_after_command(
&self,
cmd: &CommandOp,
result: &CommandResult,
) -> Result<(), SynapticError> {
for mw in self.middlewares.iter().rev() {
mw.after_command(cmd, result).await?;
}
Ok(())
}
}
/// Cursor over the middleware slice used to nest `wrap_model_call` layers.
/// Each instance represents "the chain from `index` onwards".
struct WrapModelChain<'a> {
/// All middleware of the owning chain.
middlewares: &'a [Arc<dyn AgentMiddleware>],
/// Position of the middleware this instance dispatches to.
index: usize,
/// Caller invoked once `index` runs past the end of `middlewares`.
base: &'a dyn ModelCaller,
}
#[async_trait]
impl ModelCaller for WrapModelChain<'_> {
    /// Dispatches to the middleware at `index`, handing it a `next` caller
    /// positioned one step further along; falls through to `base` when no
    /// middleware is left.
    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
        match self.middlewares.get(self.index) {
            // Past the last middleware: hit the real caller.
            None => self.base.call(request).await,
            Some(current) => {
                let rest = WrapModelChain {
                    middlewares: self.middlewares,
                    index: self.index + 1,
                    base: self.base,
                };
                current.wrap_model_call(request, &rest).await
            }
        }
    }
}
/// Cursor over the middleware slice used to nest `wrap_tool_call` layers.
/// Each instance represents "the chain from `index` onwards".
struct WrapToolChain<'a> {
/// All middleware of the owning chain.
middlewares: &'a [Arc<dyn AgentMiddleware>],
/// Position of the middleware this instance dispatches to.
index: usize,
/// Caller invoked once `index` runs past the end of `middlewares`.
base: &'a dyn ToolCaller,
}
#[async_trait]
impl ToolCaller for WrapToolChain<'_> {
    /// Dispatches to the middleware at `index`, handing it a `next` caller
    /// positioned one step further along; falls through to `base` when no
    /// middleware is left.
    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
        match self.middlewares.get(self.index) {
            // Past the last middleware: hit the real caller.
            None => self.base.call(request).await,
            Some(current) => {
                let rest = WrapToolChain {
                    middlewares: self.middlewares,
                    index: self.index + 1,
                    base: self.base,
                };
                current.wrap_tool_call(request, &rest).await
            }
        }
    }
}
/// [`ModelCaller`] that forwards requests to a shared [`ChatModel`].
pub struct BaseChatModelCaller {
/// The model that receives the converted chat request.
model: Arc<dyn ChatModel>,
}
impl BaseChatModelCaller {
pub fn new(model: Arc<dyn ChatModel>) -> Self {
Self { model }
}
}
#[async_trait]
impl ModelCaller for BaseChatModelCaller {
async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
let chat_request = request.to_chat_request();
let response = self.model.chat(chat_request).await?;
Ok(response.into())
}
}
// Unit tests for chain construction, request conversion, and the default
// file/command hook behavior.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
// Middleware that counts `before_model` / `after_model` invocations and
// leaves every other hook at its default.
struct CountingMiddleware {
before_count: AtomicUsize,
after_count: AtomicUsize,
}
impl CountingMiddleware {
// Starts with both counters at zero.
fn new() -> Self {
Self {
before_count: AtomicUsize::new(0),
after_count: AtomicUsize::new(0),
}
}
}
#[async_trait]
impl AgentMiddleware for CountingMiddleware {
async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
self.before_count.fetch_add(1, Ordering::SeqCst);
Ok(())
}
async fn after_model(
&self,
_request: &ModelRequest,
_response: &mut ModelResponse,
) -> Result<(), SynapticError> {
self.after_count.fetch_add(1, Ordering::SeqCst);
Ok(())
}
}
// A chain built from one middleware is not empty.
#[test]
fn middleware_chain_creation() {
let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
let chain = MiddlewareChain::new(vec![mw]);
assert!(!chain.is_empty());
}
// A chain built from an empty vector reports itself as empty.
#[test]
fn empty_middleware_chain() {
let chain = MiddlewareChain::new(vec![]);
assert!(chain.is_empty());
}
// With a system prompt, the chat request starts with a system message
// followed by the original human message.
#[test]
fn model_request_to_chat_request() {
let req = ModelRequest {
messages: vec![Message::human("hello")],
tools: vec![],
tool_choice: None,
system_prompt: Some("You are helpful.".to_string()),
};
let chat_req = req.to_chat_request();
assert_eq!(chat_req.messages.len(), 2);
assert!(chat_req.messages[0].is_system());
assert!(chat_req.messages[1].is_human());
}
// Without a system prompt, only the original message is forwarded.
#[test]
fn model_request_without_system_prompt() {
let req = ModelRequest {
messages: vec![Message::human("hello")],
tools: vec![],
tool_choice: None,
system_prompt: None,
};
let chat_req = req.to_chat_request();
assert_eq!(chat_req.messages.len(), 1);
}
// A middleware that does not override `before_file_op` allows the operation.
#[tokio::test]
async fn file_hook_default_allows() {
let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
let chain = MiddlewareChain::new(vec![mw]);
let op = FileOp {
path: "/tmp/test".to_string(),
kind: FileOpKind::Write,
};
let decision = chain.run_before_file_op(&op).await.unwrap();
assert!(matches!(decision, FileOpDecision::Allow));
}
// A middleware that does not override `before_command` allows the command.
#[tokio::test]
async fn command_hook_default_allows() {
let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
let chain = MiddlewareChain::new(vec![mw]);
let cmd = CommandOp {
command: "ls".to_string(),
args: vec![],
working_dir: None,
};
let decision = chain.run_before_command(&cmd).await.unwrap();
assert!(matches!(decision, CommandDecision::Allow));
}
}