pforge_runtime/
middleware.rs1use crate::{Error, Result};
2use serde_json::Value;
3use std::sync::Arc;
4
5#[async_trait::async_trait]
7pub trait Middleware: Send + Sync {
8 async fn before(&self, request: Value) -> Result<Value> {
11 Ok(request)
12 }
13
14 async fn after(&self, request: Value, response: Value) -> Result<Value> {
17 let _ = request;
18 Ok(response)
19 }
20
21 async fn on_error(&self, request: Value, error: Error) -> Result<Value> {
23 let _ = request;
24 Err(error)
25 }
26}
27
28pub struct MiddlewareChain {
30 middlewares: Vec<Arc<dyn Middleware>>,
31}
32
33impl MiddlewareChain {
34 pub fn new() -> Self {
35 Self {
36 middlewares: Vec::new(),
37 }
38 }
39
40 pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
42 self.middlewares.push(middleware);
43 }
44
45 pub async fn execute<F, Fut>(&self, mut request: Value, handler: F) -> Result<Value>
47 where
48 F: FnOnce(Value) -> Fut,
49 Fut: std::future::Future<Output = Result<Value>>,
50 {
51 for middleware in &self.middlewares {
53 request = middleware.before(request).await?;
54 }
55
56 let result = handler(request.clone()).await;
58
59 match result {
61 Ok(mut response) => {
62 for middleware in self.middlewares.iter().rev() {
63 response = middleware.after(request.clone(), response).await?;
64 }
65 Ok(response)
66 }
67 Err(error) => {
68 let mut current_error = error;
69 for middleware in self.middlewares.iter().rev() {
70 match middleware.on_error(request.clone(), current_error).await {
71 Ok(recovery_response) => return Ok(recovery_response),
72 Err(new_error) => current_error = new_error,
73 }
74 }
75 Err(current_error)
76 }
77 }
78 }
79}
80
81impl Default for MiddlewareChain {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
/// Middleware that writes requests, responses, and errors to stderr.
///
/// Each line is prefixed with `[tag]` so output from several loggers can be
/// told apart. Derives `Debug`/`Clone` so it can be inspected and duplicated
/// like other public configuration types.
#[derive(Debug, Clone)]
pub struct LoggingMiddleware {
    /// Prefix printed in brackets at the start of every log line.
    tag: String,
}
91
92impl LoggingMiddleware {
93 pub fn new(tag: impl Into<String>) -> Self {
94 Self { tag: tag.into() }
95 }
96}
97
98#[async_trait::async_trait]
99impl Middleware for LoggingMiddleware {
100 async fn before(&self, request: Value) -> Result<Value> {
101 eprintln!(
102 "[{}] Request: {}",
103 self.tag,
104 serde_json::to_string(&request).unwrap_or_default()
105 );
106 Ok(request)
107 }
108
109 async fn after(&self, _request: Value, response: Value) -> Result<Value> {
110 eprintln!(
111 "[{}] Response: {}",
112 self.tag,
113 serde_json::to_string(&response).unwrap_or_default()
114 );
115 Ok(response)
116 }
117
118 async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
119 eprintln!("[{}] Error: {}", self.tag, error);
120 Err(error)
121 }
122}
123
/// Middleware that rejects requests missing any of a fixed set of top-level
/// JSON object keys.
///
/// Derives `Debug`/`Clone` so the configured field list can be inspected and
/// the validator reused.
#[derive(Debug, Clone)]
pub struct ValidationMiddleware {
    /// Keys that must be present in every request, checked in this order.
    required_fields: Vec<String>,
}
128
129impl ValidationMiddleware {
130 pub fn new(required_fields: Vec<String>) -> Self {
131 Self { required_fields }
132 }
133}
134
135#[async_trait::async_trait]
136impl Middleware for ValidationMiddleware {
137 async fn before(&self, request: Value) -> Result<Value> {
138 if let Value::Object(obj) = &request {
139 for field in &self.required_fields {
140 if !obj.contains_key(field) {
141 return Err(Error::Handler(format!("Missing required field: {}", field)));
142 }
143 }
144 }
145 Ok(request)
146 }
147}
148
/// Middleware assembled from two plain functions or closures.
///
/// `before_fn` rewrites requests and `after_fn` rewrites successful responses.
/// The `Fn(Value) -> Result<Value> + Send + Sync` requirements live on the
/// `impl` blocks that need them rather than on the type definition. This
/// follows the Rust API guideline against bounds on struct definitions and
/// avoids repeating them wherever the type is merely named.
pub struct TransformMiddleware<BeforeFn, AfterFn> {
    /// Transform applied by the `before` hook.
    before_fn: BeforeFn,
    /// Transform applied by the `after` hook; it receives only the response.
    after_fn: AfterFn,
}
158
159impl<BeforeFn, AfterFn> TransformMiddleware<BeforeFn, AfterFn>
160where
161 BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
162 AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
163{
164 pub fn new(before_fn: BeforeFn, after_fn: AfterFn) -> Self {
165 Self {
166 before_fn,
167 after_fn,
168 }
169 }
170}
171
172#[async_trait::async_trait]
173impl<BeforeFn, AfterFn> Middleware for TransformMiddleware<BeforeFn, AfterFn>
174where
175 BeforeFn: Fn(Value) -> Result<Value> + Send + Sync,
176 AfterFn: Fn(Value) -> Result<Value> + Send + Sync,
177{
178 async fn before(&self, request: Value) -> Result<Value> {
179 (self.before_fn)(request)
180 }
181
182 async fn after(&self, _request: Value, response: Value) -> Result<Value> {
183 (self.after_fn)(response)
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Sets `key` to `true` on `value` when it is a JSON object; otherwise a no-op.
    fn stamp(value: &mut Value, key: String) {
        if let Some(obj) = value.as_object_mut() {
            obj.insert(key, Value::Bool(true));
        }
    }

    /// Test double that marks `<tag>_before` on requests and `<tag>_after` on
    /// responses, so it is visible which hooks ran.
    struct TestMiddleware {
        tag: String,
    }

    #[async_trait::async_trait]
    impl Middleware for TestMiddleware {
        async fn before(&self, mut request: Value) -> Result<Value> {
            stamp(&mut request, format!("{}_before", self.tag));
            Ok(request)
        }

        async fn after(&self, _request: Value, mut response: Value) -> Result<Value> {
            stamp(&mut response, format!("{}_after", self.tag));
            Ok(response)
        }
    }

    #[tokio::test]
    async fn test_middleware_chain_execution_order() {
        let mut chain = MiddlewareChain::new();
        for tag in ["first", "second"] {
            chain.add(Arc::new(TestMiddleware {
                tag: tag.to_string(),
            }));
        }

        let response = chain
            .execute(json!({}), |req| async move {
                // Both `before` hooks must have run before the handler is reached.
                for flag in ["first_before", "second_before"] {
                    assert!(req[flag].as_bool().unwrap_or(false));
                }
                Ok(json!({}))
            })
            .await
            .unwrap();

        for flag in ["second_after", "first_after"] {
            assert!(response[flag].as_bool().unwrap_or(false));
        }
    }

    #[tokio::test]
    async fn test_validation_middleware() {
        let validator = ValidationMiddleware::new(vec!["name".into(), "age".into()]);

        let accepted = validator.before(json!({"name": "Alice", "age": 30})).await;
        assert!(accepted.is_ok());

        let rejected = validator.before(json!({"name": "Alice"})).await;
        assert!(rejected.is_err());
        let message = rejected.unwrap_err().to_string();
        assert!(message.contains("Missing required field"));
    }

    #[tokio::test]
    async fn test_transform_middleware() {
        let middleware = TransformMiddleware::new(
            |mut req| {
                if let Some(obj) = req.as_object_mut() {
                    let upper = obj
                        .get("name")
                        .and_then(Value::as_str)
                        .map(str::to_uppercase);
                    if let Some(upper) = upper {
                        obj.insert("name".to_string(), Value::String(upper));
                    }
                }
                Ok(req)
            },
            |mut resp| {
                if let Some(obj) = resp.as_object_mut() {
                    obj.insert("transformed".to_string(), Value::Bool(true));
                }
                Ok(resp)
            },
        );

        let upper = middleware.before(json!({"name": "alice"})).await.unwrap();
        assert_eq!(upper["name"], "ALICE");

        let stamped = middleware.after(json!({}), json!({})).await.unwrap();
        assert_eq!(stamped["transformed"], true);
    }

    #[tokio::test]
    async fn test_error_handling_middleware() {
        /// Turns errors mentioning "recoverable" into a canned success response.
        struct RecoveryMiddleware;

        #[async_trait::async_trait]
        impl Middleware for RecoveryMiddleware {
            async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
                let recoverable = error.to_string().contains("recoverable");
                if recoverable {
                    Ok(json!({"recovered": true}))
                } else {
                    Err(error)
                }
            }
        }

        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(RecoveryMiddleware));

        let recovered = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("recoverable error".to_string()))
            })
            .await;
        assert!(recovered.is_ok());
        assert_eq!(recovered.unwrap()["recovered"], true);

        let fatal = chain
            .execute(json!({}), |_| async {
                Err(Error::Handler("fatal error".to_string()))
            })
            .await;
        assert!(fatal.is_err());
    }

    #[tokio::test]
    async fn test_multiple_middleware_composition() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(ValidationMiddleware::new(vec!["input".to_string()])));
        chain.add(Arc::new(TransformMiddleware::new(
            |mut req| {
                if let Some(obj) = req.as_object_mut() {
                    if let Some(Value::Number(n)) = obj.get("input") {
                        let doubled = n.as_i64().unwrap() * 2;
                        obj.insert("doubled".to_string(), Value::from(doubled));
                    }
                }
                Ok(req)
            },
            Ok,
        )));

        let response = chain
            .execute(json!({"input": 5}), |req| async move {
                assert_eq!(req["doubled"], 10);
                let bumped = req["doubled"].as_i64().unwrap() + 1;
                Ok(json!({ "result": bumped }))
            })
            .await
            .unwrap();

        assert_eq!(response["result"], 11);
    }
}