1use serde_json::Value;
16
17pub trait Middleware {
19 fn name(&self) -> &str;
20 fn pre(&self, tool: &str, args: Value) -> Value { let _ = tool; args }
21 fn post(&self, tool: &str, result: Value) -> Value { let _ = tool; result }
22}
23
24pub struct LogMiddleware {
26 pub log: std::sync::Mutex<Vec<(String, Value, Value)>>,
27}
28
29impl LogMiddleware {
30 pub fn new() -> Self { Self { log: std::sync::Mutex::new(Vec::new()) } }
31
32 pub fn entries(&self) -> std::sync::MutexGuard<Vec<(String, Value, Value)>> {
33 self.log.lock().unwrap()
34 }
35}
36
37impl Default for LogMiddleware {
38 fn default() -> Self { Self::new() }
39}
40
41impl Middleware for LogMiddleware {
42 fn name(&self) -> &str { "log" }
43
44 fn post(&self, tool: &str, result: Value) -> Value {
45 let _ = tool;
47 result
48 }
49}
50
51pub struct InjectFieldMiddleware {
53 pub field: String,
54 pub value: Value,
55}
56
57impl InjectFieldMiddleware {
58 pub fn new(field: &str, value: Value) -> Self {
59 Self { field: field.to_string(), value }
60 }
61}
62
63impl Middleware for InjectFieldMiddleware {
64 fn name(&self) -> &str { "inject_field" }
65
66 fn pre(&self, _tool: &str, mut args: Value) -> Value {
67 if let Some(obj) = args.as_object_mut() {
68 obj.insert(self.field.clone(), self.value.clone());
69 }
70 args
71 }
72}
73
74#[derive(Debug, Clone)]
76pub struct CallRecord {
77 pub tool: String,
78 pub args: Value,
79 pub result: Value,
80}
81
82pub struct MiddlewarePipeline {
84 middleware: Vec<Box<dyn Middleware>>,
85 log: Vec<CallRecord>,
86}
87
88impl MiddlewarePipeline {
89 pub fn new() -> Self { Self { middleware: Vec::new(), log: Vec::new() } }
90
91 pub fn add<M: Middleware + 'static>(&mut self, m: M) {
92 self.middleware.push(Box::new(m));
93 }
94
95 pub fn run(&mut self, tool: &str, args: Value, result: Value) -> (Value, Value) {
97 let mut a = args;
98 for m in &self.middleware {
99 a = m.pre(tool, a);
100 }
101 let mut r = result;
102 for m in self.middleware.iter().rev() {
103 r = m.post(tool, r);
104 }
105 self.log.push(CallRecord { tool: tool.to_string(), args: a.clone(), result: r.clone() });
106 (a, r)
107 }
108
109 pub fn call_log(&self) -> &[CallRecord] { &self.log }
110 pub fn call_count(&self) -> usize { self.log.len() }
111 pub fn middleware_count(&self) -> usize { self.middleware.len() }
112}
113
114impl Default for MiddlewarePipeline {
115 fn default() -> Self { Self::new() }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// With nothing registered, both values come back untouched.
    #[test]
    fn no_middleware_passthrough() {
        let mut pipeline = MiddlewarePipeline::new();
        let (args, result) = pipeline.run("fn", json!({"q": 1}), json!({"ok": true}));
        assert_eq!(args["q"], 1);
        assert_eq!(result["ok"], true);
    }

    /// The configured field shows up in the transformed arguments.
    #[test]
    fn inject_field_pre() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add(InjectFieldMiddleware::new("version", json!("1.0")));
        let (args, _) = pipeline.run("fn", json!({"q": "x"}), json!({}));
        assert_eq!(args["version"], "1.0");
    }

    /// Injection adds a key without disturbing the keys already present.
    #[test]
    fn inject_field_preserves_existing() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add(InjectFieldMiddleware::new("v", json!(99)));
        let (args, _) = pipeline.run("fn", json!({"q": "x"}), json!({}));
        assert_eq!(args["q"], "x");
        assert_eq!(args["v"], 99);
    }

    /// Every run appends one record, in call order.
    #[test]
    fn call_log_recorded() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.run("search", json!({}), json!({}));
        pipeline.run("fetch", json!({}), json!({}));
        assert_eq!(pipeline.call_count(), 2);
        assert_eq!(pipeline.call_log()[0].tool, "search");
    }

    /// The count reflects the number of `add` calls.
    #[test]
    fn middleware_count() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add(InjectFieldMiddleware::new("x", json!(1)));
        pipeline.add(InjectFieldMiddleware::new("y", json!(2)));
        assert_eq!(pipeline.middleware_count(), 2);
    }

    /// Two injectors both leave their field in the arguments.
    #[test]
    fn multiple_inject_fields() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add(InjectFieldMiddleware::new("a", json!(1)));
        pipeline.add(InjectFieldMiddleware::new("b", json!(2)));
        let (injected, _) = pipeline.run("fn", json!({}), json!({}));
        assert_eq!(injected["a"], 1);
        assert_eq!(injected["b"], 2);
    }

    #[test]
    fn log_middleware_name() {
        let logger = LogMiddleware::new();
        assert_eq!(logger.name(), "log");
    }

    #[test]
    fn inject_field_name() {
        let injector = InjectFieldMiddleware::new("x", json!(1));
        assert_eq!(injector.name(), "inject_field");
    }

    /// A record carries the tool name and the result of the run.
    #[test]
    fn call_record_fields() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.run("my_tool", json!({"arg": "val"}), json!({"result": 42}));
        let record = &pipeline.call_log()[0];
        assert_eq!(record.tool, "my_tool");
        assert_eq!(record.result["result"], 42);
    }

    /// Field injection leaves non-object arguments (here an array) alone.
    #[test]
    fn non_object_args_unchanged() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add(InjectFieldMiddleware::new("x", json!(1)));
        let (args, _) = pipeline.run("fn", json!([1, 2, 3]), json!({}));
        assert_eq!(args, json!([1, 2, 3]));
    }

    /// Scalar JSON values also pass straight through an empty pipeline.
    #[test]
    fn empty_pipeline_no_overhead() {
        let mut pipeline = MiddlewarePipeline::new();
        let (args, result) = pipeline.run("t", json!("x"), json!("y"));
        assert_eq!(args, "x");
        assert_eq!(result, "y");
    }
}