use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use juncture_core::state::messages::Message;
use crate::prebuilt::messages_state::MessagesState;
/// Result of a "before" hook (`before_model` / `before_tool`).
///
/// `AgentMiddlewareChain` keeps calling middlewares while they return
/// `Continue` and returns the first `ShortCircuit` it encounters.
#[derive(Debug)]
pub enum MiddlewareAction {
/// Do not interrupt; the next middleware in the chain is consulted.
Continue,
/// Stop consulting further middlewares and hand the carried message back to the caller.
ShortCircuit(Message),
}
#[async_trait]
pub trait AgentMiddleware: Send + Sync + fmt::Debug {
async fn before_model(&self, _state: &MessagesState) -> MiddlewareAction {
MiddlewareAction::Continue
}
async fn after_model(&self, _state: &MessagesState, response: &Message) -> Message {
response.clone()
}
async fn before_tool(
&self,
_tool_name: &str,
_arguments: &serde_json::Value,
) -> MiddlewareAction {
MiddlewareAction::Continue
}
async fn after_tool(&self, _tool_name: &str, result: &Message) -> Message {
result.clone()
}
async fn on_error(&self, _error: &str) -> Option<Message> {
None
}
}
/// Middleware that overrides no hooks, so every call uses the
/// `AgentMiddleware` default implementations.
#[derive(Debug)]
pub struct NopMiddleware;
// Intentionally empty: all hooks fall back to the trait's default bodies.
#[async_trait]
impl AgentMiddleware for NopMiddleware {}
/// Middleware whose `before_model` hook short-circuits when the most recent
/// messages keep issuing the same tool call with identical arguments.
#[derive(Debug)]
pub struct LoopDetectionMiddleware {
/// Repetition threshold. The `before_model` hook looks at the last
/// `max_repetitions * 2` messages and requires at least this many
/// tool-calling messages among them.
max_repetitions: usize,
}
impl LoopDetectionMiddleware {
/// Creates a detector with the given repetition threshold.
#[must_use]
pub const fn new(max_repetitions: usize) -> Self {
Self { max_repetitions }
}
}
#[async_trait]
impl AgentMiddleware for LoopDetectionMiddleware {
    /// Short-circuits when the recent history repeats one identical tool call.
    ///
    /// The last `max_repetitions * 2` messages are inspected. From each
    /// message that carries tool calls, only the first tool call is
    /// considered. If at least `max_repetitions` such calls are found and
    /// they all share the same name and arguments, a
    /// [`MiddlewareAction::ShortCircuit`] with an explanatory AI message is
    /// returned; otherwise [`MiddlewareAction::Continue`].
    ///
    /// A threshold of `0` disables detection (previously it caused an
    /// out-of-bounds panic on the empty list of tool calls).
    async fn before_model(&self, state: &MessagesState) -> MiddlewareAction {
        // With a zero threshold the `len() >= max_repetitions` test below would
        // trivially hold for an empty list, so bail out explicitly.
        if self.max_repetitions == 0 {
            return MiddlewareAction::Continue;
        }

        // Saturate instead of overflowing for very large thresholds; a
        // saturated window simply never fits the history.
        let window = self.max_repetitions.saturating_mul(2);
        if state.messages.len() < window {
            return MiddlewareAction::Continue;
        }

        // Newest-first list of (tool name, arguments) for the first tool call
        // of every message in the window that has one.
        let tool_calls: Vec<(&str, &serde_json::Value)> = state
            .messages
            .iter()
            .rev()
            .take(window)
            .filter_map(|m| {
                m.tool_calls
                    .first()
                    .map(|tc| (tc.name.as_str(), &tc.arguments))
            })
            .collect();

        // No tool calls at all means there is nothing that could be looping.
        let Some(&(name, arguments)) = tool_calls.first() else {
            return MiddlewareAction::Continue;
        };

        let repeated = tool_calls.len() >= self.max_repetitions
            && tool_calls
                .iter()
                .all(|&(other_name, other_args)| other_name == name && other_args == arguments);

        if repeated {
            return MiddlewareAction::ShortCircuit(Message::ai(format!(
                "Loop detected: tool '{}' called {} times with identical arguments. Stopping.",
                name, self.max_repetitions
            )));
        }
        MiddlewareAction::Continue
    }
}
/// Stateless middleware that rewrites error-looking tool results and offers
/// a recovery message from `on_error` (see its `AgentMiddleware` impl).
#[derive(Debug)]
pub struct ToolErrorHandlingMiddleware;
impl ToolErrorHandlingMiddleware {
/// Creates the middleware; it carries no configuration.
#[must_use]
pub const fn new() -> Self {
Self
}
}
impl Default for ToolErrorHandlingMiddleware {
/// Equivalent to [`ToolErrorHandlingMiddleware::new`].
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AgentMiddleware for ToolErrorHandlingMiddleware {
    /// Wraps error-looking tool output in a message naming the failing tool.
    ///
    /// When `result` has text content beginning with `"Error:"` or
    /// `"error:"`, a new tool-result message is built. It reuses the original
    /// `tool_call_id` (or the default value when it is absent) and embeds the
    /// original text. Any other result is returned as an unmodified copy.
    async fn after_tool(&self, tool_name: &str, result: &Message) -> Message {
        match &result.content {
            crate::llm::Content::Text(text)
                if text.starts_with("Error:") || text.starts_with("error:") =>
            {
                let call_id = result.tool_call_id.clone().unwrap_or_default();
                let explanation = format!(
                    "Tool '{tool_name}' failed: {text}\nPlease try a different approach or tool."
                );
                Message::tool_result(call_id, explanation)
            }
            _ => result.clone(),
        }
    }

    /// Always supplies a recovery: an AI message that embeds `error`.
    async fn on_error(&self, error: &str) -> Option<Message> {
        let recovery_text =
            format!("An error occurred: {error}\nLet me try a different approach.");
        Some(Message::ai(recovery_text))
    }
}
/// Ordered collection of middlewares that are invoked together.
///
/// Middlewares are stored behind `Arc`, so cloning the chain clones the
/// pointers and the clone shares the same middleware instances.
#[derive(Clone)]
pub struct AgentMiddlewareChain {
/// Middlewares in registration order (the order `with` was called).
middlewares: Vec<Arc<dyn AgentMiddleware>>,
}
impl AgentMiddlewareChain {
    /// Creates a chain with no middlewares registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            middlewares: Vec::new(),
        }
    }

    /// Appends `middleware` to the end of the chain and returns the chain,
    /// allowing builder-style chaining.
    #[must_use]
    pub fn with<M: AgentMiddleware + 'static>(mut self, middleware: M) -> Self {
        let shared: Arc<dyn AgentMiddleware> = Arc::new(middleware);
        self.middlewares.push(shared);
        self
    }

    /// Number of registered middlewares.
    #[must_use]
    pub fn len(&self) -> usize {
        self.middlewares.len()
    }

    /// `true` when no middleware has been registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.middlewares.is_empty()
    }

    /// Calls `before_model` on each middleware in registration order.
    ///
    /// Returns the first [`MiddlewareAction::ShortCircuit`] produced, without
    /// consulting later middlewares; returns `Continue` if none short-circuits.
    pub async fn run_before_model(&self, state: &MessagesState) -> MiddlewareAction {
        for middleware in &self.middlewares {
            let action = middleware.before_model(state).await;
            match action {
                MiddlewareAction::ShortCircuit(_) => return action,
                MiddlewareAction::Continue => {}
            }
        }
        MiddlewareAction::Continue
    }

    /// Threads `response` through every `after_model` hook in reverse
    /// registration order, feeding each hook the previous hook's output.
    pub async fn run_after_model(&self, state: &MessagesState, response: &Message) -> Message {
        let mut current = response.clone();
        for middleware in self.middlewares.iter().rev() {
            current = middleware.after_model(state, &current).await;
        }
        current
    }

    /// Calls `before_tool` on each middleware in registration order.
    ///
    /// Returns the first [`MiddlewareAction::ShortCircuit`] produced, without
    /// consulting later middlewares; returns `Continue` if none short-circuits.
    pub async fn run_before_tool(
        &self,
        tool_name: &str,
        arguments: &serde_json::Value,
    ) -> MiddlewareAction {
        for middleware in &self.middlewares {
            let action = middleware.before_tool(tool_name, arguments).await;
            match action {
                MiddlewareAction::ShortCircuit(_) => return action,
                MiddlewareAction::Continue => {}
            }
        }
        MiddlewareAction::Continue
    }

    /// Threads `result` through every `after_tool` hook in reverse
    /// registration order, feeding each hook the previous hook's output.
    pub async fn run_after_tool(&self, tool_name: &str, result: &Message) -> Message {
        let mut current = result.clone();
        for middleware in self.middlewares.iter().rev() {
            current = middleware.after_tool(tool_name, &current).await;
        }
        current
    }

    /// Asks each middleware, in registration order, for a recovery message
    /// and returns the first one offered, or `None` if no middleware offers one.
    pub async fn run_on_error(&self, error: &str) -> Option<Message> {
        for middleware in &self.middlewares {
            let recovery = middleware.on_error(error).await;
            if recovery.is_some() {
                return recovery;
            }
        }
        None
    }
}
impl Default for AgentMiddlewareChain {
/// Returns an empty chain, same as [`AgentMiddlewareChain::new`].
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for AgentMiddlewareChain {
    /// Formats the chain as `AgentMiddlewareChain { count: N }`, reporting
    /// only the number of middlewares rather than each middleware's own
    /// `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.middlewares.len();
        let mut builder = f.debug_struct("AgentMiddlewareChain");
        builder.field("count", &count);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only middleware whose `before_model` hook always short-circuits.
    #[derive(Debug)]
    struct AlwaysShortCircuit;

    #[async_trait]
    impl AgentMiddleware for AlwaysShortCircuit {
        async fn before_model(&self, _state: &MessagesState) -> MiddlewareAction {
            MiddlewareAction::ShortCircuit(Message::ai("stop"))
        }
    }

    #[test]
    fn test_nop_middleware_debug() {
        let mw = NopMiddleware;
        let debug = format!("{mw:?}");
        assert_eq!(debug, "NopMiddleware");
    }

    #[test]
    fn test_loop_detection_middleware_new() {
        let mw = LoopDetectionMiddleware::new(3);
        assert_eq!(mw.max_repetitions, 3);
    }

    #[test]
    fn test_tool_error_handling_middleware_default() {
        // Go through `Default` so the impl this test is named after is exercised.
        let mw = ToolErrorHandlingMiddleware::default();
        assert_eq!(format!("{mw:?}"), "ToolErrorHandlingMiddleware");
    }

    #[test]
    fn test_middleware_chain_new() {
        let chain = AgentMiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    #[test]
    fn test_middleware_chain_with() {
        let chain = AgentMiddlewareChain::new()
            .with(NopMiddleware)
            .with(LoopDetectionMiddleware::new(3));
        assert_eq!(chain.len(), 2);
        assert!(!chain.is_empty());
    }

    #[test]
    fn test_middleware_chain_default() {
        let chain = AgentMiddlewareChain::default();
        assert!(chain.is_empty());
    }

    #[test]
    fn test_middleware_chain_clone() {
        let chain = AgentMiddlewareChain::new().with(NopMiddleware);
        let cloned = chain.clone();
        // The clone must stay usable after the original is gone.
        drop(chain);
        assert_eq!(cloned.len(), 1);
    }

    #[test]
    fn test_middleware_chain_debug() {
        let chain = AgentMiddlewareChain::new().with(NopMiddleware);
        let debug = format!("{chain:?}");
        assert!(debug.contains("AgentMiddlewareChain"));
        assert!(debug.contains("count: 1"));
    }

    // Purely synchronous: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_middleware_action_debug() {
        let cont = MiddlewareAction::Continue;
        assert_eq!(format!("{cont:?}"), "Continue");
        let sc = MiddlewareAction::ShortCircuit(Message::ai("test"));
        let debug = format!("{sc:?}");
        assert!(debug.contains("ShortCircuit"));
    }

    #[tokio::test]
    async fn test_chain_run_before_model_continue() {
        let chain = AgentMiddlewareChain::new().with(NopMiddleware);
        let state = MessagesState::default();
        let result = chain.run_before_model(&state).await;
        assert!(matches!(result, MiddlewareAction::Continue));
    }

    #[tokio::test]
    async fn test_chain_run_before_model_short_circuit() {
        // A pass-through middleware followed by one that short-circuits: the
        // chain must surface the short-circuit message.
        let chain = AgentMiddlewareChain::new()
            .with(NopMiddleware)
            .with(AlwaysShortCircuit);
        let state = MessagesState::default();
        match chain.run_before_model(&state).await {
            MiddlewareAction::ShortCircuit(msg) => assert_eq!(msg.content_text(), "stop"),
            MiddlewareAction::Continue => panic!("expected the chain to short-circuit"),
        }
    }

    #[tokio::test]
    async fn test_chain_run_after_model_passthrough() {
        let chain = AgentMiddlewareChain::new().with(NopMiddleware);
        let state = MessagesState::default();
        let response = Message::ai("hello");
        let result = chain.run_after_model(&state, &response).await;
        assert_eq!(result.content_text(), "hello");
    }

    #[tokio::test]
    async fn test_chain_run_before_tool_continue() {
        let chain = AgentMiddlewareChain::new().with(NopMiddleware);
        let args = serde_json::json!({});
        let result = chain.run_before_tool("test_tool", &args).await;
        assert!(matches!(result, MiddlewareAction::Continue));
    }

    #[tokio::test]
    async fn test_chain_run_after_tool_passthrough() {
        let chain = AgentMiddlewareChain::new().with(NopMiddleware);
        let result_msg = Message::tool_result("call_1", "result");
        let result = chain.run_after_tool("test_tool", &result_msg).await;
        assert_eq!(result.content_text(), "result");
    }

    #[tokio::test]
    async fn test_chain_run_on_error_none() {
        let chain = AgentMiddlewareChain::new().with(NopMiddleware);
        let result = chain.run_on_error("test error").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_tool_error_handling_recovery() {
        let chain = AgentMiddlewareChain::new().with(ToolErrorHandlingMiddleware::new());
        let msg = chain
            .run_on_error("something broke")
            .await
            .expect("ToolErrorHandlingMiddleware always offers a recovery message");
        assert!(msg.content_text().contains("something broke"));
    }

    #[tokio::test]
    async fn test_tool_error_handling_normal_result() {
        let chain = AgentMiddlewareChain::new().with(ToolErrorHandlingMiddleware::new());
        let result_msg = Message::tool_result("call_1", "success");
        let result = chain.run_after_tool("test_tool", &result_msg).await;
        assert_eq!(result.content_text(), "success");
    }

    #[tokio::test]
    async fn test_tool_error_handling_error_result() {
        let chain = AgentMiddlewareChain::new().with(ToolErrorHandlingMiddleware::new());
        let result_msg = Message::tool_result("call_1", "Error: something failed");
        let result = chain.run_after_tool("test_tool", &result_msg).await;
        assert!(result.content_text().contains("test_tool"));
        assert!(result.content_text().contains("Error: something failed"));
    }

    #[tokio::test]
    async fn test_loop_detection_no_loop() {
        let chain = AgentMiddlewareChain::new().with(LoopDetectionMiddleware::new(3));
        let state = MessagesState {
            messages: vec![Message::human("hello")],
        };
        let result = chain.run_before_model(&state).await;
        assert!(matches!(result, MiddlewareAction::Continue));
    }
}