Skip to main content

mockforge_core/protocol_abstraction/
middleware.rs

1//! Unified middleware implementations for common patterns across protocols
2
3use super::{Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse};
4use crate::{request_logger::log_request_global, Result};
5use std::time::Instant;
6
/// Protocol-agnostic middleware that logs each request and its response.
pub struct LoggingMiddleware {
    /// Name this middleware reports about itself
    name: String,
    /// When `true`, debug traces also carry the request/response bodies
    log_bodies: bool,
}
14
15impl LoggingMiddleware {
16    /// Create a new logging middleware
17    pub fn new(log_bodies: bool) -> Self {
18        Self {
19            name: "LoggingMiddleware".to_string(),
20            log_bodies,
21        }
22    }
23}
24
25#[async_trait::async_trait]
26impl ProtocolMiddleware for LoggingMiddleware {
27    fn name(&self) -> &str {
28        &self.name
29    }
30
31    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
32        // Add timestamp to request metadata
33        let timestamp = chrono::Utc::now().to_rfc3339();
34        request.metadata.insert("x-mockforge-request-time".to_string(), timestamp);
35
36        // Store start time for duration calculation
37        request.metadata.insert(
38            "x-mockforge-request-start".to_string(),
39            Instant::now().elapsed().as_millis().to_string(),
40        );
41
42        if self.log_bodies {
43            tracing::debug!(
44                protocol = %request.protocol,
45                operation = %request.operation,
46                path = %request.path,
47                body_size = request.body.as_ref().map(|b| b.len()).unwrap_or(0),
48                body = ?request.body.as_deref().and_then(|b| std::str::from_utf8(b).ok()),
49                "Processing request through logging middleware (with body)"
50            );
51        } else {
52            tracing::debug!(
53                protocol = %request.protocol,
54                operation = %request.operation,
55                path = %request.path,
56                "Processing request through logging middleware"
57            );
58        }
59
60        Ok(())
61    }
62
63    async fn process_response(
64        &self,
65        request: &ProtocolRequest,
66        response: &mut ProtocolResponse,
67    ) -> Result<()> {
68        let duration_ms = if let Some(start) = request.metadata.get("x-mockforge-request-start") {
69            let start: u128 = start.parse().unwrap_or(0);
70            Instant::now().elapsed().as_millis() - start
71        } else {
72            0
73        };
74
75        // Create appropriate log entry based on protocol
76        let log_entry = match request.protocol {
77            Protocol::Http => crate::create_http_log_entry(
78                &request.operation,
79                &request.path,
80                response.status.as_code().unwrap_or(0) as u16,
81                duration_ms as u64,
82                request.client_ip.clone(),
83                request.metadata.get("user-agent").cloned(),
84                request.metadata.clone(),
85                response.body.len() as u64,
86                if !response.status.is_success() {
87                    Some(format!("Error response: {:?}", response.status))
88                } else {
89                    None
90                },
91            ),
92            Protocol::Grpc => {
93                // Extract service and method from operation (e.g., "greeter.SayHello")
94                let parts: Vec<&str> = request.operation.split('.').collect();
95                let (service, method) = if parts.len() == 2 {
96                    (parts[0], parts[1])
97                } else {
98                    ("unknown", request.operation.as_str())
99                };
100                crate::create_grpc_log_entry(
101                    service,
102                    method,
103                    response.status.as_code().unwrap_or(0) as u16,
104                    duration_ms as u64,
105                    request.client_ip.clone(),
106                    request.body.as_ref().map(|b| b.len() as u64).unwrap_or(0),
107                    response.body.len() as u64,
108                    if !response.status.is_success() {
109                        Some(format!("Error response: {:?}", response.status))
110                    } else {
111                        None
112                    },
113                )
114            }
115            Protocol::GraphQL => crate::create_http_log_entry(
116                "GraphQL",
117                &request.path,
118                if response.status.is_success() {
119                    200
120                } else {
121                    400
122                },
123                duration_ms as u64,
124                request.client_ip.clone(),
125                request.metadata.get("user-agent").cloned(),
126                request.metadata.clone(),
127                response.body.len() as u64,
128                None,
129            ),
130            Protocol::WebSocket => crate::create_websocket_log_entry(
131                &request.operation,
132                &request.path,
133                response.status.as_code().unwrap_or(0) as u16,
134                request.client_ip.clone(),
135                response.body.len() as u64,
136                if !response.status.is_success() {
137                    Some(format!("Error response: {:?}", response.status))
138                } else {
139                    None
140                },
141            ),
142            Protocol::Smtp => crate::create_http_log_entry(
143                "SMTP",
144                &request.path,
145                response.status.as_code().unwrap_or(250) as u16,
146                duration_ms as u64,
147                request.client_ip.clone(),
148                None,
149                request.metadata.clone(),
150                response.body.len() as u64,
151                if !response.status.is_success() {
152                    Some(format!("SMTP Error: {:?}", response.status))
153                } else {
154                    None
155                },
156            ),
157            Protocol::Mqtt => crate::create_http_log_entry(
158                "MQTT",
159                &request.topic.clone().unwrap_or_else(|| request.path.clone()),
160                if response.status.is_success() {
161                    200
162                } else {
163                    500
164                },
165                duration_ms as u64,
166                request.client_ip.clone(),
167                None,
168                request.metadata.clone(),
169                response.body.len() as u64,
170                if !response.status.is_success() {
171                    Some(format!("MQTT Error: {:?}", response.status))
172                } else {
173                    None
174                },
175            ),
176            Protocol::Ftp => crate::create_http_log_entry(
177                "FTP",
178                &request.path,
179                response.status.as_code().unwrap_or(226) as u16,
180                duration_ms as u64,
181                request.client_ip.clone(),
182                None,
183                request.metadata.clone(),
184                response.body.len() as u64,
185                if !response.status.is_success() {
186                    Some(format!("FTP Error: {:?}", response.status))
187                } else {
188                    None
189                },
190            ),
191            Protocol::Kafka => crate::create_http_log_entry(
192                "Kafka",
193                &request.topic.clone().unwrap_or_else(|| request.path.clone()),
194                response.status.as_code().unwrap_or(0) as u16,
195                duration_ms as u64,
196                request.client_ip.clone(),
197                None,
198                request.metadata.clone(),
199                response.body.len() as u64,
200                if !response.status.is_success() {
201                    Some(format!("Kafka Error: {:?}", response.status))
202                } else {
203                    None
204                },
205            ),
206            Protocol::RabbitMq | Protocol::Amqp => crate::create_http_log_entry(
207                "AMQP",
208                &request.routing_key.clone().unwrap_or_else(|| request.path.clone()),
209                response.status.as_code().unwrap_or(200) as u16,
210                duration_ms as u64,
211                request.client_ip.clone(),
212                None,
213                request.metadata.clone(),
214                response.body.len() as u64,
215                if !response.status.is_success() {
216                    Some(format!("AMQP Error: {:?}", response.status))
217                } else {
218                    None
219                },
220            ),
221            Protocol::Tcp => crate::create_http_log_entry(
222                "TCP",
223                &request.path,
224                response.status.as_code().unwrap_or(0) as u16,
225                duration_ms as u64,
226                request.client_ip.clone(),
227                None,
228                request.metadata.clone(),
229                response.body.len() as u64,
230                if !response.status.is_success() {
231                    Some(format!("TCP Error: {:?}", response.status))
232                } else {
233                    None
234                },
235            ),
236        };
237
238        // Log to centralized logger
239        log_request_global(log_entry).await;
240
241        if self.log_bodies {
242            tracing::debug!(
243                protocol = %request.protocol,
244                operation = %request.operation,
245                path = %request.path,
246                duration_ms = duration_ms,
247                success = response.status.is_success(),
248                response_body_size = response.body.len(),
249                response_body = ?std::str::from_utf8(&response.body).ok(),
250                "Request processed (with body)"
251            );
252        } else {
253            tracing::debug!(
254                protocol = %request.protocol,
255                operation = %request.operation,
256                path = %request.path,
257                duration_ms = duration_ms,
258                success = response.status.is_success(),
259                "Request processed"
260            );
261        }
262
263        Ok(())
264    }
265
266    fn supports_protocol(&self, _protocol: Protocol) -> bool {
267        // Logging middleware supports all protocols
268        true
269    }
270}
271
/// Protocol-agnostic middleware that reports per-request metrics.
pub struct MetricsMiddleware {
    /// Name this middleware reports about itself
    name: String,
}
277
278impl MetricsMiddleware {
279    /// Create a new metrics middleware
280    pub fn new() -> Self {
281        Self {
282            name: "MetricsMiddleware".to_string(),
283        }
284    }
285}
286
287impl Default for MetricsMiddleware {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
293#[async_trait::async_trait]
294impl ProtocolMiddleware for MetricsMiddleware {
295    fn name(&self) -> &str {
296        &self.name
297    }
298
299    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
300        // Store start time for metrics calculation
301        request.metadata.insert(
302            "x-mockforge-metrics-start".to_string(),
303            Instant::now().elapsed().as_millis().to_string(),
304        );
305
306        tracing::debug!(
307            protocol = %request.protocol,
308            operation = %request.operation,
309            "Metrics: request started"
310        );
311
312        Ok(())
313    }
314
315    async fn process_response(
316        &self,
317        request: &ProtocolRequest,
318        response: &mut ProtocolResponse,
319    ) -> Result<()> {
320        let duration_ms = if let Some(start) = request.metadata.get("x-mockforge-metrics-start") {
321            let start: u128 = start.parse().unwrap_or(0);
322            Instant::now().elapsed().as_millis() - start
323        } else {
324            0
325        };
326
327        let status_code = response.status.as_code().unwrap_or(0);
328
329        tracing::info!(
330            protocol = %request.protocol,
331            operation = %request.operation,
332            status_code = status_code,
333            duration_ms = duration_ms,
334            response_size = response.body.len(),
335            success = response.status.is_success(),
336            "Metrics: request completed"
337        );
338
339        Ok(())
340    }
341
342    fn supports_protocol(&self, _protocol: Protocol) -> bool {
343        // Metrics middleware supports all protocols
344        true
345    }
346}
347
348/// Latency injection middleware for simulating delays
349pub struct LatencyMiddleware {
350    /// Middleware name
351    name: String,
352    /// Latency injector
353    injector: crate::latency::LatencyInjector,
354}
355
356impl LatencyMiddleware {
357    /// Create a new latency middleware
358    pub fn new(injector: crate::latency::LatencyInjector) -> Self {
359        Self {
360            name: "LatencyMiddleware".to_string(),
361            injector,
362        }
363    }
364}
365
366#[async_trait::async_trait]
367impl ProtocolMiddleware for LatencyMiddleware {
368    fn name(&self) -> &str {
369        &self.name
370    }
371
372    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
373        // Extract tags from request metadata
374        let tags: Vec<String> = request
375            .metadata
376            .get("x-mockforge-tags")
377            .map(|t| t.split(',').map(|s| s.trim().to_string()).collect())
378            .unwrap_or_default();
379
380        // Inject latency
381        self.injector.inject_latency(&tags).await?;
382
383        Ok(())
384    }
385
386    async fn process_response(
387        &self,
388        _request: &ProtocolRequest,
389        _response: &mut ProtocolResponse,
390    ) -> Result<()> {
391        // No post-processing needed for latency
392        Ok(())
393    }
394
395    fn supports_protocol(&self, _protocol: Protocol) -> bool {
396        // Latency middleware supports all protocols
397        true
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal HTTP `GET /test` request with no body and empty metadata.
    fn bare_http_request() -> ProtocolRequest {
        ProtocolRequest {
            protocol: Protocol::Http,
            pattern: crate::MessagePattern::RequestResponse,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            topic: None,
            routing_key: None,
            partition: None,
            qos: None,
            metadata: HashMap::new(),
            body: None,
            client_ip: None,
        }
    }

    #[test]
    fn test_logging_middleware_creation() {
        let logging = LoggingMiddleware::new(true);
        assert_eq!(logging.name(), "LoggingMiddleware");
        assert!(logging.supports_protocol(Protocol::Http));
        assert!(logging.supports_protocol(Protocol::GraphQL));
        assert!(logging.supports_protocol(Protocol::Grpc));
    }

    #[test]
    fn test_metrics_middleware_creation() {
        let metrics = MetricsMiddleware::new();
        assert_eq!(metrics.name(), "MetricsMiddleware");
        assert!(metrics.supports_protocol(Protocol::Http));
        assert!(metrics.supports_protocol(Protocol::GraphQL));
    }

    #[test]
    fn test_latency_middleware_creation() {
        let latency = LatencyMiddleware::new(crate::latency::LatencyInjector::default());
        assert_eq!(latency.name(), "LatencyMiddleware");
        assert!(latency.supports_protocol(Protocol::Http));
    }

    #[tokio::test]
    async fn test_logging_middleware_process_request() {
        let logging = LoggingMiddleware::new(false);
        let mut request = bare_http_request();

        let outcome = logging.process_request(&mut request).await;

        assert!(outcome.is_ok());
        assert!(request.metadata.contains_key("x-mockforge-request-time"));
    }
}