pforge_runtime/
middleware.rs

1use crate::{Error, Result};
2use serde_json::Value;
3use std::sync::Arc;
4
5/// Middleware trait for request/response processing
6#[async_trait::async_trait]
7pub trait Middleware: Send + Sync {
8    /// Process request before handler execution
9    /// Returns modified request or error
10    async fn before(&self, request: Value) -> Result<Value> {
11        Ok(request)
12    }
13
14    /// Process response after handler execution
15    /// Returns modified response or error
16    async fn after(&self, request: Value, response: Value) -> Result<Value> {
17        let _ = request;
18        Ok(response)
19    }
20
21    /// Handle errors from handler or downstream middleware
22    async fn on_error(&self, request: Value, error: Error) -> Result<Value> {
23        let _ = request;
24        Err(error)
25    }
26}
27
28/// Middleware chain manages ordered middleware execution
29pub struct MiddlewareChain {
30    middlewares: Vec<Arc<dyn Middleware>>,
31}
32
33impl MiddlewareChain {
34    pub fn new() -> Self {
35        Self {
36            middlewares: Vec::new(),
37        }
38    }
39
40    /// Add middleware to the chain
41    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
42        self.middlewares.push(middleware);
43    }
44
45    /// Execute middleware chain around a handler
46    pub async fn execute<F, Fut>(&self, mut request: Value, handler: F) -> Result<Value>
47    where
48        F: FnOnce(Value) -> Fut,
49        Fut: std::future::Future<Output = Result<Value>>,
50    {
51        // Execute "before" phase in order
52        for middleware in &self.middlewares {
53            request = middleware.before(request).await?;
54        }
55
56        // Execute handler
57        let result = handler(request.clone()).await;
58
59        // Execute "after" phase in reverse order or "on_error" if failed
60        match result {
61            Ok(mut response) => {
62                for middleware in self.middlewares.iter().rev() {
63                    response = middleware.after(request.clone(), response).await?;
64                }
65                Ok(response)
66            }
67            Err(error) => {
68                let mut current_error = error;
69                for middleware in self.middlewares.iter().rev() {
70                    match middleware.on_error(request.clone(), current_error).await {
71                        Ok(recovery_response) => return Ok(recovery_response),
72                        Err(new_error) => current_error = new_error,
73                    }
74                }
75                Err(current_error)
76            }
77        }
78    }
79}
80
81impl Default for MiddlewareChain {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87/// Logging middleware - logs requests and responses
88pub struct LoggingMiddleware {
89    tag: String,
90}
91
92impl LoggingMiddleware {
93    pub fn new(tag: impl Into<String>) -> Self {
94        Self { tag: tag.into() }
95    }
96}
97
98#[async_trait::async_trait]
99impl Middleware for LoggingMiddleware {
100    async fn before(&self, request: Value) -> Result<Value> {
101        eprintln!(
102            "[{}] Request: {}",
103            self.tag,
104            serde_json::to_string(&request).unwrap_or_default()
105        );
106        Ok(request)
107    }
108
109    async fn after(&self, _request: Value, response: Value) -> Result<Value> {
110        eprintln!(
111            "[{}] Response: {}",
112            self.tag,
113            serde_json::to_string(&response).unwrap_or_default()
114        );
115        Ok(response)
116    }
117
118    async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
119        eprintln!("[{}] Error: {}", self.tag, error);
120        Err(error)
121    }
122}
123
124/// Validation middleware - validates request structure
125pub struct ValidationMiddleware {
126    required_fields: Vec<String>,
127}
128
129impl ValidationMiddleware {
130    pub fn new(required_fields: Vec<String>) -> Self {
131        Self { required_fields }
132    }
133}
134
135#[async_trait::async_trait]
136impl Middleware for ValidationMiddleware {
137    async fn before(&self, request: Value) -> Result<Value> {
138        if let Value::Object(obj) = &request {
139            for field in &self.required_fields {
140                if !obj.contains_key(field) {
141                    return Err(Error::Handler(format!("Missing required field: {}", field)));
142                }
143            }
144        }
145        Ok(request)
146    }
147}
148
149/// Transform middleware - applies transformations to request/response
150pub struct TransformMiddleware<BeforeFn, AfterFn>
151where
152    BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
153    AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
154{
155    before_fn: BeforeFn,
156    after_fn: AfterFn,
157}
158
159impl<BeforeFn, AfterFn> TransformMiddleware<BeforeFn, AfterFn>
160where
161    BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
162    AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
163{
164    pub fn new(before_fn: BeforeFn, after_fn: AfterFn) -> Self {
165        Self {
166            before_fn,
167            after_fn,
168        }
169    }
170}
171
172#[async_trait::async_trait]
173impl<BeforeFn, AfterFn> Middleware for TransformMiddleware<BeforeFn, AfterFn>
174where
175    BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
176    AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
177{
178    async fn before(&self, request: Value) -> Result<Value> {
179        (self.before_fn)(request)
180    }
181
182    async fn after(&self, _request: Value, response: Value) -> Result<Value> {
183        (self.after_fn)(response)
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Test middleware that appends its tag to `"before_order"` on the request
    /// and to `"after_order"` on the response.
    ///
    /// This lets tests assert the exact order in which hooks ran, not just
    /// that they ran.
    struct TestMiddleware {
        tag: String,
    }

    /// Appends `tag` to the JSON array stored under `key` in `value`.
    ///
    /// Creates the array if it is absent. Non-object values are left untouched.
    fn record_tag(value: &mut Value, key: &str, tag: &str) {
        if let Value::Object(obj) = value {
            let slot = obj.entry(key).or_insert_with(|| json!([]));
            if let Value::Array(items) = slot {
                items.push(Value::String(tag.to_string()));
            }
        }
    }

    #[async_trait::async_trait]
    impl Middleware for TestMiddleware {
        async fn before(&self, mut request: Value) -> Result<Value> {
            record_tag(&mut request, "before_order", &self.tag);
            Ok(request)
        }

        async fn after(&self, _request: Value, mut response: Value) -> Result<Value> {
            record_tag(&mut response, "after_order", &self.tag);
            Ok(response)
        }
    }

    #[tokio::test]
    async fn test_middleware_chain_execution_order() {
        let mut chain = MiddlewareChain::new();

        chain.add(Arc::new(TestMiddleware {
            tag: "first".to_string(),
        }));
        chain.add(Arc::new(TestMiddleware {
            tag: "second".to_string(),
        }));

        let request = json!({});
        let result = chain
            .execute(request, |req| async move {
                // "before" hooks run in registration order.
                assert_eq!(req["before_order"], json!(["first", "second"]));
                Ok(json!({}))
            })
            .await
            .unwrap();

        // "after" hooks run in reverse registration order.
        assert_eq!(result["after_order"], json!(["second", "first"]));
    }

    #[tokio::test]
    async fn test_before_error_skips_handler() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(ValidationMiddleware::new(vec!["input".to_string()])));

        let handler_ran = AtomicBool::new(false);
        let result = chain
            .execute(json!({}), |req| {
                let handler_ran = &handler_ran;
                async move {
                    handler_ran.store(true, Ordering::SeqCst);
                    Ok(req)
                }
            })
            .await;

        // A failing `before` hook aborts the chain before the handler runs.
        assert!(!handler_ran.load(Ordering::SeqCst));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Missing required field"));
    }

    #[tokio::test]
    async fn test_validation_middleware() {
        let middleware = ValidationMiddleware::new(vec!["name".to_string(), "age".to_string()]);

        // Valid request
        let valid_request = json!({"name": "Alice", "age": 30});
        let result = middleware.before(valid_request).await;
        assert!(result.is_ok());

        // Invalid request - missing field
        let invalid_request = json!({"name": "Alice"});
        let result = middleware.before(invalid_request).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Missing required field"));
    }

    #[tokio::test]
    async fn test_transform_middleware() {
        let middleware = TransformMiddleware::new(
            |mut req| {
                if let Value::Object(ref mut obj) = req {
                    if let Some(Value::String(s)) = obj.get("name") {
                        obj.insert("name".to_string(), Value::String(s.to_uppercase()));
                    }
                }
                Ok(req)
            },
            |mut resp| {
                if let Value::Object(ref mut obj) = resp {
                    obj.insert("transformed".to_string(), Value::Bool(true));
                }
                Ok(resp)
            },
        );

        let request = json!({"name": "alice"});
        let transformed = middleware.before(request).await.unwrap();
        assert_eq!(transformed["name"], "ALICE");

        let response = json!({});
        let transformed = middleware.after(json!({}), response).await.unwrap();
        assert_eq!(transformed["transformed"], true);
    }

    #[tokio::test]
    async fn test_error_handling_middleware() {
        struct RecoveryMiddleware;

        #[async_trait::async_trait]
        impl Middleware for RecoveryMiddleware {
            async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
                // Attempt to recover from specific errors
                if error.to_string().contains("recoverable") {
                    Ok(json!({"recovered": true}))
                } else {
                    Err(error)
                }
            }
        }

        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(RecoveryMiddleware));

        // Recoverable error
        let result = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("recoverable error".to_string()))
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap()["recovered"], true);

        // Non-recoverable error
        let result = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("fatal error".to_string()))
            })
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_multiple_middleware_composition() {
        let mut chain = MiddlewareChain::new();

        chain.add(Arc::new(ValidationMiddleware::new(vec!["input".to_string()])));
        chain.add(Arc::new(TransformMiddleware::new(
            |mut req| {
                if let Value::Object(ref mut obj) = req {
                    if let Some(Value::Number(n)) = obj.get("input") {
                        obj.insert(
                            "doubled".to_string(),
                            Value::Number(serde_json::Number::from(n.as_i64().unwrap() * 2)),
                        );
                    }
                }
                Ok(req)
            },
            Ok,
        )));

        let request = json!({"input": 5});
        let result = chain
            .execute(request, |req| async move {
                assert_eq!(req["doubled"], 10);
                Ok(json!({"result": req["doubled"].as_i64().unwrap() + 1}))
            })
            .await
            .unwrap();

        assert_eq!(result["result"], 11);
    }
}