pforge_runtime/
middleware.rs1use crate::{Error, Result};
2use serde_json::Value;
3use std::sync::Arc;
4
5#[async_trait::async_trait]
7pub trait Middleware: Send + Sync {
8 async fn before(&self, request: Value) -> Result<Value> {
11 Ok(request)
12 }
13
14 async fn after(&self, request: Value, response: Value) -> Result<Value> {
17 let _ = request;
18 Ok(response)
19 }
20
21 async fn on_error(&self, request: Value, error: Error) -> Result<Value> {
23 let _ = request;
24 Err(error)
25 }
26}
27
28pub struct MiddlewareChain {
30 middlewares: Vec<Arc<dyn Middleware>>,
31}
32
33impl MiddlewareChain {
34 pub fn new() -> Self {
35 Self {
36 middlewares: Vec::new(),
37 }
38 }
39
40 pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
42 self.middlewares.push(middleware);
43 }
44
45 pub async fn execute<F, Fut>(
47 &self,
48 mut request: Value,
49 handler: F,
50 ) -> Result<Value>
51 where
52 F: FnOnce(Value) -> Fut,
53 Fut: std::future::Future<Output = Result<Value>>,
54 {
55 for middleware in &self.middlewares {
57 request = middleware.before(request).await?;
58 }
59
60 let result = handler(request.clone()).await;
62
63 match result {
65 Ok(mut response) => {
66 for middleware in self.middlewares.iter().rev() {
67 response = middleware.after(request.clone(), response).await?;
68 }
69 Ok(response)
70 }
71 Err(error) => {
72 let mut current_error = error;
73 for middleware in self.middlewares.iter().rev() {
74 match middleware.on_error(request.clone(), current_error).await {
75 Ok(recovery_response) => return Ok(recovery_response),
76 Err(new_error) => current_error = new_error,
77 }
78 }
79 Err(current_error)
80 }
81 }
82 }
83}
84
85impl Default for MiddlewareChain {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
/// Middleware that echoes requests, responses and errors to stderr.
///
/// Each output line is prefixed with `[tag]`, so output from several loggers can be
/// told apart.
pub struct LoggingMiddleware {
    /// Label printed in brackets at the start of every log line.
    tag: String,
}

impl LoggingMiddleware {
    /// Builds a logger whose output lines are prefixed with `[tag]`.
    pub fn new(tag: impl Into<String>) -> Self {
        let tag: String = tag.into();
        LoggingMiddleware { tag }
    }
}
101
102#[async_trait::async_trait]
103impl Middleware for LoggingMiddleware {
104 async fn before(&self, request: Value) -> Result<Value> {
105 eprintln!("[{}] Request: {}", self.tag, serde_json::to_string(&request).unwrap_or_default());
106 Ok(request)
107 }
108
109 async fn after(&self, _request: Value, response: Value) -> Result<Value> {
110 eprintln!("[{}] Response: {}", self.tag, serde_json::to_string(&response).unwrap_or_default());
111 Ok(response)
112 }
113
114 async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
115 eprintln!("[{}] Error: {}", self.tag, error);
116 Err(error)
117 }
118}
119
/// Middleware that checks a request for a fixed set of required top-level fields.
pub struct ValidationMiddleware {
    /// Keys that must be present on the request object, checked in this order.
    required_fields: Vec<String>,
}

impl ValidationMiddleware {
    /// Builds a validator that requires every name in `required_fields`.
    pub fn new(required_fields: Vec<String>) -> Self {
        ValidationMiddleware { required_fields }
    }
}
130
131#[async_trait::async_trait]
132impl Middleware for ValidationMiddleware {
133 async fn before(&self, request: Value) -> Result<Value> {
134 if let Value::Object(obj) = &request {
135 for field in &self.required_fields {
136 if !obj.contains_key(field) {
137 return Err(Error::Handler(format!(
138 "Missing required field: {}",
139 field
140 )));
141 }
142 }
143 }
144 Ok(request)
145 }
146}
147
/// Middleware built from two synchronous closures.
///
/// `before_fn` rewrites each request before the handler runs. `after_fn` rewrites each
/// successful response. Either closure may fail by returning an error.
///
/// The closure bounds (`Fn(Value) -> Result<Value> + Send + Sync`) are declared on the
/// `impl` blocks rather than on the struct itself. That way the type can be named
/// (e.g. in fields or signatures) without repeating them.
pub struct TransformMiddleware<BeforeFn, AfterFn> {
    /// Applied to the request before the handler is called.
    before_fn: BeforeFn,
    /// Applied to the handler's successful response.
    after_fn: AfterFn,
}
157
158impl<BeforeFn, AfterFn> TransformMiddleware<BeforeFn, AfterFn>
159where
160 BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
161 AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
162{
163 pub fn new(before_fn: BeforeFn, after_fn: AfterFn) -> Self {
164 Self { before_fn, after_fn }
165 }
166}
167
168#[async_trait::async_trait]
169impl<BeforeFn, AfterFn> Middleware for TransformMiddleware<BeforeFn, AfterFn>
170where
171 BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
172 AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
173{
174 async fn before(&self, request: Value) -> Result<Value> {
175 (self.before_fn)(request)
176 }
177
178 async fn after(&self, _request: Value, response: Value) -> Result<Value> {
179 (self.after_fn)(response)
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test double that marks and traces every hook it runs.
    ///
    /// `before` sets `<tag>_before: true` on the request and appends `tag` to the
    /// request's `before_order` array. `after` does the same on the response with
    /// `<tag>_after` and `after_order`. The arrays let tests assert exact hook ordering.
    struct TestMiddleware {
        tag: String,
    }

    impl TestMiddleware {
        /// Sets `<tag>_<phase>: true` and appends `tag` to `<phase>_order`.
        /// Non-object values are left untouched.
        fn mark(&self, value: &mut Value, phase: &str) {
            if let Value::Object(obj) = value {
                obj.insert(format!("{}_{}", self.tag, phase), Value::Bool(true));
                let trace = obj
                    .entry(format!("{}_order", phase))
                    .or_insert_with(|| Value::Array(Vec::new()));
                if let Value::Array(items) = trace {
                    items.push(Value::String(self.tag.clone()));
                }
            }
        }
    }

    #[async_trait::async_trait]
    impl Middleware for TestMiddleware {
        async fn before(&self, mut request: Value) -> Result<Value> {
            self.mark(&mut request, "before");
            Ok(request)
        }

        async fn after(&self, _request: Value, mut response: Value) -> Result<Value> {
            self.mark(&mut response, "after");
            Ok(response)
        }
    }

    #[tokio::test]
    async fn test_middleware_chain_execution_order() {
        let mut chain = MiddlewareChain::new();

        chain.add(Arc::new(TestMiddleware { tag: "first".to_string() }));
        chain.add(Arc::new(TestMiddleware { tag: "second".to_string() }));

        let request = json!({});
        let result = chain
            .execute(request, |req| async move {
                assert!(req["first_before"].as_bool().unwrap_or(false));
                assert!(req["second_before"].as_bool().unwrap_or(false));
                // `before` hooks must run in insertion order.
                assert_eq!(req["before_order"], json!(["first", "second"]));
                Ok(json!({}))
            })
            .await
            .unwrap();

        assert!(result["second_after"].as_bool().unwrap_or(false));
        assert!(result["first_after"].as_bool().unwrap_or(false));
        // `after` hooks must run in reverse insertion order.
        assert_eq!(result["after_order"], json!(["second", "first"]));
    }

    #[tokio::test]
    async fn test_before_error_short_circuits_handler() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(ValidationMiddleware::new(vec!["input".to_string()])));

        // If the handler ran, its own error would surface instead of the validation one.
        let error = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("handler must not run".to_string()))
            })
            .await
            .unwrap_err();

        assert!(error.to_string().contains("Missing required field"));
    }

    #[tokio::test]
    async fn test_validation_middleware() {
        let middleware = ValidationMiddleware::new(vec!["name".to_string(), "age".to_string()]);

        let valid_request = json!({"name": "Alice", "age": 30});
        let result = middleware.before(valid_request).await;
        assert!(result.is_ok());

        let invalid_request = json!({"name": "Alice"});
        let result = middleware.before(invalid_request).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing required field"));
    }

    #[tokio::test]
    async fn test_transform_middleware() {
        let middleware = TransformMiddleware::new(
            |mut req| {
                if let Value::Object(ref mut obj) = req {
                    if let Some(Value::String(s)) = obj.get("name") {
                        obj.insert("name".to_string(), Value::String(s.to_uppercase()));
                    }
                }
                Ok(req)
            },
            |mut resp| {
                if let Value::Object(ref mut obj) = resp {
                    obj.insert("transformed".to_string(), Value::Bool(true));
                }
                Ok(resp)
            },
        );

        let request = json!({"name": "alice"});
        let transformed = middleware.before(request).await.unwrap();
        assert_eq!(transformed["name"], "ALICE");

        let response = json!({});
        let transformed = middleware.after(json!({}), response).await.unwrap();
        assert_eq!(transformed["transformed"], true);
    }

    #[tokio::test]
    async fn test_error_handling_middleware() {
        /// Recovers only from errors whose message mentions "recoverable".
        struct RecoveryMiddleware;

        #[async_trait::async_trait]
        impl Middleware for RecoveryMiddleware {
            async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
                if error.to_string().contains("recoverable") {
                    Ok(json!({"recovered": true}))
                } else {
                    Err(error)
                }
            }
        }

        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(RecoveryMiddleware));

        let result = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("recoverable error".to_string()))
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap()["recovered"], true);

        let result = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("fatal error".to_string()))
            })
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_multiple_middleware_composition() {
        let mut chain = MiddlewareChain::new();

        chain.add(Arc::new(ValidationMiddleware::new(vec!["input".to_string()])));
        chain.add(Arc::new(TransformMiddleware::new(
            |mut req| {
                if let Value::Object(ref mut obj) = req {
                    if let Some(Value::Number(n)) = obj.get("input") {
                        obj.insert(
                            "doubled".to_string(),
                            Value::Number(serde_json::Number::from(n.as_i64().unwrap() * 2)),
                        );
                    }
                }
                Ok(req)
            },
            |resp| Ok(resp),
        )));

        let request = json!({"input": 5});
        let result = chain
            .execute(request, |req| async move {
                assert_eq!(req["doubled"], 10);
                Ok(json!({"result": req["doubled"].as_i64().unwrap() + 1}))
            })
            .await
            .unwrap();

        assert_eq!(result["result"], 11);
    }
}