pforge_runtime/
middleware.rs

1use crate::{Error, Result};
2use serde_json::Value;
3use std::sync::Arc;
4
5/// Middleware trait for request/response processing
6#[async_trait::async_trait]
7pub trait Middleware: Send + Sync {
8    /// Process request before handler execution
9    /// Returns modified request or error
10    async fn before(&self, request: Value) -> Result<Value> {
11        Ok(request)
12    }
13
14    /// Process response after handler execution
15    /// Returns modified response or error
16    async fn after(&self, request: Value, response: Value) -> Result<Value> {
17        let _ = request;
18        Ok(response)
19    }
20
21    /// Handle errors from handler or downstream middleware
22    async fn on_error(&self, request: Value, error: Error) -> Result<Value> {
23        let _ = request;
24        Err(error)
25    }
26}
27
28/// Middleware chain manages ordered middleware execution
29pub struct MiddlewareChain {
30    middlewares: Vec<Arc<dyn Middleware>>,
31}
32
33impl MiddlewareChain {
34    pub fn new() -> Self {
35        Self {
36            middlewares: Vec::new(),
37        }
38    }
39
40    /// Add middleware to the chain
41    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
42        self.middlewares.push(middleware);
43    }
44
45    /// Execute middleware chain around a handler
46    pub async fn execute<F, Fut>(
47        &self,
48        mut request: Value,
49        handler: F,
50    ) -> Result<Value>
51    where
52        F: FnOnce(Value) -> Fut,
53        Fut: std::future::Future<Output = Result<Value>>,
54    {
55        // Execute "before" phase in order
56        for middleware in &self.middlewares {
57            request = middleware.before(request).await?;
58        }
59
60        // Execute handler
61        let result = handler(request.clone()).await;
62
63        // Execute "after" phase in reverse order or "on_error" if failed
64        match result {
65            Ok(mut response) => {
66                for middleware in self.middlewares.iter().rev() {
67                    response = middleware.after(request.clone(), response).await?;
68                }
69                Ok(response)
70            }
71            Err(error) => {
72                let mut current_error = error;
73                for middleware in self.middlewares.iter().rev() {
74                    match middleware.on_error(request.clone(), current_error).await {
75                        Ok(recovery_response) => return Ok(recovery_response),
76                        Err(new_error) => current_error = new_error,
77                    }
78                }
79                Err(current_error)
80            }
81        }
82    }
83}
84
85impl Default for MiddlewareChain {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
/// Middleware that logs the requests, responses and errors passing through it.
pub struct LoggingMiddleware {
    /// Label placed in square brackets at the start of every log line.
    tag: String,
}

impl LoggingMiddleware {
    /// Creates a logger whose output lines are prefixed with `tag`.
    pub fn new(tag: impl Into<String>) -> Self {
        let tag = tag.into();
        LoggingMiddleware { tag }
    }
}
101
102#[async_trait::async_trait]
103impl Middleware for LoggingMiddleware {
104    async fn before(&self, request: Value) -> Result<Value> {
105        eprintln!("[{}] Request: {}", self.tag, serde_json::to_string(&request).unwrap_or_default());
106        Ok(request)
107    }
108
109    async fn after(&self, _request: Value, response: Value) -> Result<Value> {
110        eprintln!("[{}] Response: {}", self.tag, serde_json::to_string(&response).unwrap_or_default());
111        Ok(response)
112    }
113
114    async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
115        eprintln!("[{}] Error: {}", self.tag, error);
116        Err(error)
117    }
118}
119
/// Middleware that checks a request for a set of required top-level keys.
pub struct ValidationMiddleware {
    /// Names of the keys a request is required to contain.
    required_fields: Vec<String>,
}

impl ValidationMiddleware {
    /// Creates a validator that requires every name in `required_fields`.
    pub fn new(required_fields: Vec<String>) -> Self {
        ValidationMiddleware { required_fields }
    }
}
130
131#[async_trait::async_trait]
132impl Middleware for ValidationMiddleware {
133    async fn before(&self, request: Value) -> Result<Value> {
134        if let Value::Object(obj) = &request {
135            for field in &self.required_fields {
136                if !obj.contains_key(field) {
137                    return Err(Error::Handler(format!(
138                        "Missing required field: {}",
139                        field
140                    )));
141                }
142            }
143        }
144        Ok(request)
145    }
146}
147
148/// Transform middleware - applies transformations to request/response
149pub struct TransformMiddleware<BeforeFn, AfterFn>
150where
151    BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
152    AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
153{
154    before_fn: BeforeFn,
155    after_fn: AfterFn,
156}
157
158impl<BeforeFn, AfterFn> TransformMiddleware<BeforeFn, AfterFn>
159where
160    BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
161    AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
162{
163    pub fn new(before_fn: BeforeFn, after_fn: AfterFn) -> Self {
164        Self { before_fn, after_fn }
165    }
166}
167
168#[async_trait::async_trait]
169impl<BeforeFn, AfterFn> Middleware for TransformMiddleware<BeforeFn, AfterFn>
170where
171    BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
172    AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
173{
174    async fn before(&self, request: Value) -> Result<Value> {
175        (self.before_fn)(request)
176    }
177
178    async fn after(&self, _request: Value, response: Value) -> Result<Value> {
179        (self.after_fn)(response)
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test double that stamps every JSON object it sees.
    ///
    /// For phase `p` it sets `"<tag>_<p>": true` and appends its tag to the
    /// `"<p>_order"` array, so tests can assert both that a hook ran and the
    /// order in which the hooks ran.
    struct TestMiddleware {
        tag: String,
    }

    impl TestMiddleware {
        /// Builds a chain-ready instance with the given tag.
        fn shared(tag: &str) -> Arc<dyn Middleware> {
            Arc::new(Self {
                tag: tag.to_string(),
            })
        }

        /// Records this middleware's visit for `phase` on `value`.
        /// Values that are not JSON objects are left untouched.
        fn mark(&self, value: &mut Value, phase: &str) {
            if let Value::Object(obj) = value {
                obj.insert(format!("{}_{}", self.tag, phase), Value::Bool(true));
                let order = obj
                    .entry(format!("{}_order", phase))
                    .or_insert_with(|| Value::Array(Vec::new()));
                if let Value::Array(seen) = order {
                    seen.push(Value::String(self.tag.clone()));
                }
            }
        }
    }

    #[async_trait::async_trait]
    impl Middleware for TestMiddleware {
        async fn before(&self, mut request: Value) -> Result<Value> {
            self.mark(&mut request, "before");
            Ok(request)
        }

        async fn after(&self, _request: Value, mut response: Value) -> Result<Value> {
            self.mark(&mut response, "after");
            Ok(response)
        }
    }

    #[tokio::test]
    async fn test_middleware_chain_execution_order() {
        let mut chain = MiddlewareChain::new();
        chain.add(TestMiddleware::shared("first"));
        chain.add(TestMiddleware::shared("second"));

        let result = chain
            .execute(json!({}), |req| async move {
                // The handler sees both "before" stamps, applied in insertion order.
                assert!(req["first_before"].as_bool().unwrap_or(false));
                assert!(req["second_before"].as_bool().unwrap_or(false));
                assert_eq!(req["before_order"], json!(["first", "second"]));
                Ok(json!({}))
            })
            .await
            .unwrap();

        // The response carries both "after" stamps, applied in reverse order.
        assert!(result["second_after"].as_bool().unwrap_or(false));
        assert!(result["first_after"].as_bool().unwrap_or(false));
        assert_eq!(result["after_order"], json!(["second", "first"]));
    }

    #[tokio::test]
    async fn test_validation_middleware() {
        let middleware = ValidationMiddleware::new(vec!["name".to_string(), "age".to_string()]);

        // Every required field present.
        let accepted = middleware.before(json!({"name": "Alice", "age": 30})).await;
        assert!(accepted.is_ok());

        // "age" is missing.
        let rejected = middleware.before(json!({"name": "Alice"})).await;
        assert!(rejected.is_err());
        assert!(rejected
            .unwrap_err()
            .to_string()
            .contains("Missing required field"));
    }

    #[tokio::test]
    async fn test_transform_middleware() {
        let middleware = TransformMiddleware::new(
            |mut req| {
                if let Value::Object(obj) = &mut req {
                    // Compute the new value first so the read borrow of `obj`
                    // has ended before the insert.
                    let upper = obj
                        .get("name")
                        .and_then(Value::as_str)
                        .map(str::to_uppercase);
                    if let Some(upper) = upper {
                        obj.insert("name".to_string(), Value::String(upper));
                    }
                }
                Ok(req)
            },
            |mut resp| {
                if let Value::Object(obj) = &mut resp {
                    obj.insert("transformed".to_string(), Value::Bool(true));
                }
                Ok(resp)
            },
        );

        let transformed = middleware.before(json!({"name": "alice"})).await.unwrap();
        assert_eq!(transformed["name"], "ALICE");

        let transformed = middleware.after(json!({}), json!({})).await.unwrap();
        assert_eq!(transformed["transformed"], true);
    }

    #[tokio::test]
    async fn test_error_handling_middleware() {
        /// Turns errors whose message contains "recoverable" into a response.
        struct RecoveryMiddleware;

        #[async_trait::async_trait]
        impl Middleware for RecoveryMiddleware {
            async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
                if error.to_string().contains("recoverable") {
                    Ok(json!({"recovered": true}))
                } else {
                    Err(error)
                }
            }
        }

        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(RecoveryMiddleware));

        // A recoverable failure is replaced by the recovery response.
        let result = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("recoverable error".to_string()))
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap()["recovered"], true);

        // Any other failure is passed through to the caller.
        let result = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("fatal error".to_string()))
            })
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_multiple_middleware_composition() {
        let mut chain = MiddlewareChain::new();

        chain.add(Arc::new(ValidationMiddleware::new(vec!["input".to_string()])));
        chain.add(Arc::new(TransformMiddleware::new(
            |mut req| {
                if let Value::Object(obj) = &mut req {
                    // Read first, then write, so the two borrows never overlap.
                    let doubled = obj.get("input").and_then(Value::as_i64).map(|n| n * 2);
                    if let Some(doubled) = doubled {
                        obj.insert("doubled".to_string(), Value::from(doubled));
                    }
                }
                Ok(req)
            },
            |resp| Ok(resp),
        )));

        let result = chain
            .execute(json!({"input": 5}), |req| async move {
                assert_eq!(req["doubled"], 10);
                Ok(json!({"result": req["doubled"].as_i64().unwrap() + 1}))
            })
            .await
            .unwrap();

        assert_eq!(result["result"], 11);
    }
}