Skip to main content

agent_tool_middleware/
lib.rs

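/*!
agent-tool-middleware: a middleware pipeline for processing LLM agent tool calls.

Each [`Middleware`] can rewrite a call's arguments before the tool runs (`pre`)
and its result afterwards (`post`). [`MiddlewarePipeline::run`] applies the `pre`
hooks in registration order and the `post` hooks in reverse order, then records
the call.

```rust
use agent_tool_middleware::{MiddlewarePipeline, LogMiddleware};
use serde_json::json;

let mut pipe = MiddlewarePipeline::new();
pipe.add(LogMiddleware::new());
let (args, result) = pipe.run("search", json!({"q": "rust"}), json!({"hits": 5}));
assert_eq!(args["q"], "rust");
assert_eq!(result["hits"], 5);
```
*/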
14
15use serde_json::Value;
16
17/// A middleware that can modify args before a call and result after.
18pub trait Middleware {
19    fn name(&self) -> &str;
20    fn pre(&self, tool: &str, args: Value) -> Value { let _ = tool; args }
21    fn post(&self, tool: &str, result: Value) -> Value { let _ = tool; result }
22}
23
24/// Logging middleware — records every call.
25pub struct LogMiddleware {
26    pub log: std::sync::Mutex<Vec<(String, Value, Value)>>,
27}
28
29impl LogMiddleware {
30    pub fn new() -> Self { Self { log: std::sync::Mutex::new(Vec::new()) } }
31
32    pub fn entries(&self) -> std::sync::MutexGuard<Vec<(String, Value, Value)>> {
33        self.log.lock().unwrap()
34    }
35}
36
37impl Default for LogMiddleware {
38    fn default() -> Self { Self::new() }
39}
40
41impl Middleware for LogMiddleware {
42    fn name(&self) -> &str { "log" }
43
44    fn post(&self, tool: &str, result: Value) -> Value {
45        // Log is recorded by the pipeline's post hook.
46        let _ = tool;
47        result
48    }
49}
50
51/// Middleware that adds a field to every args object.
52pub struct InjectFieldMiddleware {
53    pub field: String,
54    pub value: Value,
55}
56
57impl InjectFieldMiddleware {
58    pub fn new(field: &str, value: Value) -> Self {
59        Self { field: field.to_string(), value }
60    }
61}
62
63impl Middleware for InjectFieldMiddleware {
64    fn name(&self) -> &str { "inject_field" }
65
66    fn pre(&self, _tool: &str, mut args: Value) -> Value {
67        if let Some(obj) = args.as_object_mut() {
68            obj.insert(self.field.clone(), self.value.clone());
69        }
70        args
71    }
72}
73
74/// A call record in the pipeline log.
75#[derive(Debug, Clone)]
76pub struct CallRecord {
77    pub tool: String,
78    pub args: Value,
79    pub result: Value,
80}
81
82/// A pipeline of middleware applied to every tool call.
83pub struct MiddlewarePipeline {
84    middleware: Vec<Box<dyn Middleware>>,
85    log: Vec<CallRecord>,
86}
87
88impl MiddlewarePipeline {
89    pub fn new() -> Self { Self { middleware: Vec::new(), log: Vec::new() } }
90
91    pub fn add<M: Middleware + 'static>(&mut self, m: M) {
92        self.middleware.push(Box::new(m));
93    }
94
95    /// Run all pre-hooks, then post-hooks, return final (args, result).
96    pub fn run(&mut self, tool: &str, args: Value, result: Value) -> (Value, Value) {
97        let mut a = args;
98        for m in &self.middleware {
99            a = m.pre(tool, a);
100        }
101        let mut r = result;
102        for m in self.middleware.iter().rev() {
103            r = m.post(tool, r);
104        }
105        self.log.push(CallRecord { tool: tool.to_string(), args: a.clone(), result: r.clone() });
106        (a, r)
107    }
108
109    pub fn call_log(&self) -> &[CallRecord] { &self.log }
110    pub fn call_count(&self) -> usize { self.log.len() }
111    pub fn middleware_count(&self) -> usize { self.middleware.len() }
112}
113
114impl Default for MiddlewarePipeline {
115    fn default() -> Self { Self::new() }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test helper: appends its label to a `"trace"` string field on the args
    /// (in `pre`) and on the result (in `post`), exposing hook order.
    struct Trace(&'static str);

    fn append_label(mut v: Value, label: &str) -> Value {
        let trace = format!("{}{}", v["trace"].as_str().unwrap_or(""), label);
        v["trace"] = json!(trace);
        v
    }

    impl Middleware for Trace {
        fn name(&self) -> &str {
            self.0
        }

        fn pre(&self, _tool: &str, args: Value) -> Value {
            append_label(args, self.0)
        }

        fn post(&self, _tool: &str, result: Value) -> Value {
            append_label(result, self.0)
        }
    }

    #[test]
    fn no_middleware_passthrough() {
        let mut p = MiddlewarePipeline::new();
        let (a, r) = p.run("fn", json!({"q": 1}), json!({"ok": true}));
        assert_eq!(a["q"], 1);
        assert_eq!(r["ok"], true);
    }

    #[test]
    fn inject_field_pre() {
        let mut p = MiddlewarePipeline::new();
        p.add(InjectFieldMiddleware::new("version", json!("1.0")));
        let (a, _) = p.run("fn", json!({"q": "x"}), json!({}));
        assert_eq!(a["version"], "1.0");
    }

    #[test]
    fn inject_field_preserves_existing() {
        let mut p = MiddlewarePipeline::new();
        p.add(InjectFieldMiddleware::new("v", json!(99)));
        let (a, _) = p.run("fn", json!({"q": "x"}), json!({}));
        assert_eq!(a["q"], "x");
        assert_eq!(a["v"], 99);
    }

    #[test]
    fn call_log_recorded() {
        let mut p = MiddlewarePipeline::new();
        p.run("search", json!({}), json!({}));
        p.run("fetch", json!({}), json!({}));
        assert_eq!(p.call_count(), 2);
        assert_eq!(p.call_log()[0].tool, "search");
    }

    #[test]
    fn middleware_count() {
        let mut p = MiddlewarePipeline::new();
        p.add(InjectFieldMiddleware::new("x", json!(1)));
        p.add(InjectFieldMiddleware::new("y", json!(2)));
        assert_eq!(p.middleware_count(), 2);
    }

    #[test]
    fn multiple_inject_fields() {
        let mut p = MiddlewarePipeline::new();
        p.add(InjectFieldMiddleware::new("a", json!(1)));
        p.add(InjectFieldMiddleware::new("b", json!(2)));
        let (args, _) = p.run("fn", json!({}), json!({}));
        assert_eq!(args["a"], 1);
        assert_eq!(args["b"], 2);
    }

    #[test]
    fn log_middleware_name() {
        let m = LogMiddleware::new();
        assert_eq!(m.name(), "log");
    }

    #[test]
    fn inject_field_name() {
        let m = InjectFieldMiddleware::new("x", json!(1));
        assert_eq!(m.name(), "inject_field");
    }

    #[test]
    fn call_record_fields() {
        let mut p = MiddlewarePipeline::new();
        p.run("my_tool", json!({"arg": "val"}), json!({"result": 42}));
        let rec = &p.call_log()[0];
        assert_eq!(rec.tool, "my_tool");
        assert_eq!(rec.result["result"], 42);
    }

    #[test]
    fn call_record_captures_final_args() {
        let mut p = MiddlewarePipeline::new();
        p.add(InjectFieldMiddleware::new("k", json!("v")));
        p.run("t", json!({"q": 1}), json!(null));
        let rec = &p.call_log()[0];
        // The log stores args *after* the pre hooks ran.
        assert_eq!(rec.args, json!({"q": 1, "k": "v"}));
        assert!(rec.result.is_null());
    }

    #[test]
    fn post_hooks_run_in_reverse_order() {
        let mut p = MiddlewarePipeline::new();
        p.add(Trace("a"));
        p.add(Trace("b"));
        let (a, r) = p.run("t", json!({}), json!({}));
        assert_eq!(a["trace"], "ab");
        assert_eq!(r["trace"], "ba");
    }

    #[test]
    fn default_pipeline_is_empty() {
        let p = MiddlewarePipeline::default();
        assert_eq!(p.middleware_count(), 0);
        assert_eq!(p.call_count(), 0);
        assert!(p.call_log().is_empty());
    }

    #[test]
    fn non_object_args_unchanged() {
        let mut p = MiddlewarePipeline::new();
        p.add(InjectFieldMiddleware::new("x", json!(1)));
        let (a, _) = p.run("fn", json!([1, 2, 3]), json!({}));
        // Array isn't touched by inject
        assert_eq!(a, json!([1, 2, 3]));
    }

    #[test]
    fn empty_pipeline_no_overhead() {
        let mut p = MiddlewarePipeline::new();
        let (a, r) = p.run("t", json!("x"), json!("y"));
        assert_eq!(a, "x");
        assert_eq!(r, "y");
    }
}