mockforge_core/protocol_abstraction/
middleware.rs

1//! Unified middleware implementations for common patterns across protocols
2
3use super::{Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse};
4use crate::{request_logger::log_request_global, Result};
5use std::time::Instant;
6
7/// Logging middleware that works across all protocols
8pub struct LoggingMiddleware {
9    /// Middleware name
10    name: String,
11    /// Whether to log request bodies
12    #[allow(dead_code)]
13    log_bodies: bool,
14}
15
16impl LoggingMiddleware {
17    /// Create a new logging middleware
18    pub fn new(log_bodies: bool) -> Self {
19        Self {
20            name: "LoggingMiddleware".to_string(),
21            log_bodies,
22        }
23    }
24}
25
26#[async_trait::async_trait]
27impl ProtocolMiddleware for LoggingMiddleware {
28    fn name(&self) -> &str {
29        &self.name
30    }
31
32    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
33        // Add timestamp to request metadata
34        let timestamp = chrono::Utc::now().to_rfc3339();
35        request.metadata.insert("x-mockforge-request-time".to_string(), timestamp);
36
37        // Store start time for duration calculation
38        request.metadata.insert(
39            "x-mockforge-request-start".to_string(),
40            Instant::now().elapsed().as_millis().to_string(),
41        );
42
43        tracing::debug!(
44            protocol = %request.protocol,
45            operation = %request.operation,
46            path = %request.path,
47            "Processing request through logging middleware"
48        );
49
50        Ok(())
51    }
52
53    async fn process_response(
54        &self,
55        request: &ProtocolRequest,
56        response: &mut ProtocolResponse,
57    ) -> Result<()> {
58        let duration_ms = if let Some(start) = request.metadata.get("x-mockforge-request-start") {
59            let start: u128 = start.parse().unwrap_or(0);
60            Instant::now().elapsed().as_millis() - start
61        } else {
62            0
63        };
64
65        // Create appropriate log entry based on protocol
66        let log_entry = match request.protocol {
67            Protocol::Http => crate::create_http_log_entry(
68                &request.operation,
69                &request.path,
70                response.status.as_code().unwrap_or(0) as u16,
71                duration_ms as u64,
72                request.client_ip.clone(),
73                request.metadata.get("user-agent").cloned(),
74                request.metadata.clone(),
75                response.body.len() as u64,
76                if !response.status.is_success() {
77                    Some(format!("Error response: {:?}", response.status))
78                } else {
79                    None
80                },
81            ),
82            Protocol::Grpc => {
83                // Extract service and method from operation (e.g., "greeter.SayHello")
84                let parts: Vec<&str> = request.operation.split('.').collect();
85                let (service, method) = if parts.len() == 2 {
86                    (parts[0], parts[1])
87                } else {
88                    ("unknown", request.operation.as_str())
89                };
90                crate::create_grpc_log_entry(
91                    service,
92                    method,
93                    response.status.as_code().unwrap_or(0) as u16,
94                    duration_ms as u64,
95                    request.client_ip.clone(),
96                    request.body.as_ref().map(|b| b.len() as u64).unwrap_or(0),
97                    response.body.len() as u64,
98                    if !response.status.is_success() {
99                        Some(format!("Error response: {:?}", response.status))
100                    } else {
101                        None
102                    },
103                )
104            }
105            Protocol::GraphQL => crate::create_http_log_entry(
106                "GraphQL",
107                &request.path,
108                if response.status.is_success() {
109                    200
110                } else {
111                    400
112                },
113                duration_ms as u64,
114                request.client_ip.clone(),
115                request.metadata.get("user-agent").cloned(),
116                request.metadata.clone(),
117                response.body.len() as u64,
118                None,
119            ),
120            Protocol::WebSocket => crate::create_websocket_log_entry(
121                &request.operation,
122                &request.path,
123                response.status.as_code().unwrap_or(0) as u16,
124                request.client_ip.clone(),
125                response.body.len() as u64,
126                if !response.status.is_success() {
127                    Some(format!("Error response: {:?}", response.status))
128                } else {
129                    None
130                },
131            ),
132            Protocol::Smtp => crate::create_http_log_entry(
133                "SMTP",
134                &request.path,
135                response.status.as_code().unwrap_or(250) as u16,
136                duration_ms as u64,
137                request.client_ip.clone(),
138                None,
139                request.metadata.clone(),
140                response.body.len() as u64,
141                if !response.status.is_success() {
142                    Some(format!("SMTP Error: {:?}", response.status))
143                } else {
144                    None
145                },
146            ),
147            Protocol::Mqtt => crate::create_http_log_entry(
148                "MQTT",
149                &request.topic.clone().unwrap_or_else(|| request.path.clone()),
150                if response.status.is_success() {
151                    200
152                } else {
153                    500
154                },
155                duration_ms as u64,
156                request.client_ip.clone(),
157                None,
158                request.metadata.clone(),
159                response.body.len() as u64,
160                if !response.status.is_success() {
161                    Some(format!("MQTT Error: {:?}", response.status))
162                } else {
163                    None
164                },
165            ),
166            Protocol::Ftp => crate::create_http_log_entry(
167                "FTP",
168                &request.path,
169                response.status.as_code().unwrap_or(226) as u16,
170                duration_ms as u64,
171                request.client_ip.clone(),
172                None,
173                request.metadata.clone(),
174                response.body.len() as u64,
175                if !response.status.is_success() {
176                    Some(format!("FTP Error: {:?}", response.status))
177                } else {
178                    None
179                },
180            ),
181            Protocol::Kafka => crate::create_http_log_entry(
182                "Kafka",
183                &request.topic.clone().unwrap_or_else(|| request.path.clone()),
184                response.status.as_code().unwrap_or(0) as u16,
185                duration_ms as u64,
186                request.client_ip.clone(),
187                None,
188                request.metadata.clone(),
189                response.body.len() as u64,
190                if !response.status.is_success() {
191                    Some(format!("Kafka Error: {:?}", response.status))
192                } else {
193                    None
194                },
195            ),
196            Protocol::RabbitMq | Protocol::Amqp => crate::create_http_log_entry(
197                "AMQP",
198                &request.routing_key.clone().unwrap_or_else(|| request.path.clone()),
199                response.status.as_code().unwrap_or(200) as u16,
200                duration_ms as u64,
201                request.client_ip.clone(),
202                None,
203                request.metadata.clone(),
204                response.body.len() as u64,
205                if !response.status.is_success() {
206                    Some(format!("AMQP Error: {:?}", response.status))
207                } else {
208                    None
209                },
210            ),
211            Protocol::Tcp => crate::create_http_log_entry(
212                "TCP",
213                &request.path,
214                response.status.as_code().unwrap_or(0) as u16,
215                duration_ms as u64,
216                request.client_ip.clone(),
217                None,
218                request.metadata.clone(),
219                response.body.len() as u64,
220                if !response.status.is_success() {
221                    Some(format!("TCP Error: {:?}", response.status))
222                } else {
223                    None
224                },
225            ),
226        };
227
228        // Log to centralized logger
229        log_request_global(log_entry).await;
230
231        tracing::debug!(
232            protocol = %request.protocol,
233            operation = %request.operation,
234            path = %request.path,
235            duration_ms = duration_ms,
236            success = response.status.is_success(),
237            "Request processed"
238        );
239
240        Ok(())
241    }
242
243    fn supports_protocol(&self, _protocol: Protocol) -> bool {
244        // Logging middleware supports all protocols
245        true
246    }
247}
248
249/// Metrics middleware that collects metrics across all protocols
250pub struct MetricsMiddleware {
251    /// Middleware name
252    name: String,
253}
254
255impl MetricsMiddleware {
256    /// Create a new metrics middleware
257    pub fn new() -> Self {
258        Self {
259            name: "MetricsMiddleware".to_string(),
260        }
261    }
262}
263
264impl Default for MetricsMiddleware {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270#[async_trait::async_trait]
271impl ProtocolMiddleware for MetricsMiddleware {
272    fn name(&self) -> &str {
273        &self.name
274    }
275
276    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
277        // Store start time for metrics calculation
278        request.metadata.insert(
279            "x-mockforge-metrics-start".to_string(),
280            std::time::Instant::now().elapsed().as_millis().to_string(),
281        );
282
283        tracing::debug!(
284            protocol = %request.protocol,
285            operation = %request.operation,
286            "Metrics: request started"
287        );
288
289        Ok(())
290    }
291
292    async fn process_response(
293        &self,
294        request: &ProtocolRequest,
295        response: &mut ProtocolResponse,
296    ) -> Result<()> {
297        let duration_ms = if let Some(start) = request.metadata.get("x-mockforge-metrics-start") {
298            let start: u128 = start.parse().unwrap_or(0);
299            Instant::now().elapsed().as_millis() - start
300        } else {
301            0
302        };
303
304        let status_code = response.status.as_code().unwrap_or(0);
305
306        tracing::info!(
307            protocol = %request.protocol,
308            operation = %request.operation,
309            status_code = status_code,
310            duration_ms = duration_ms,
311            response_size = response.body.len(),
312            success = response.status.is_success(),
313            "Metrics: request completed"
314        );
315
316        Ok(())
317    }
318
319    fn supports_protocol(&self, _protocol: Protocol) -> bool {
320        // Metrics middleware supports all protocols
321        true
322    }
323}
324
325/// Latency injection middleware for simulating delays
326pub struct LatencyMiddleware {
327    /// Middleware name
328    name: String,
329    /// Latency injector
330    injector: crate::latency::LatencyInjector,
331}
332
333impl LatencyMiddleware {
334    /// Create a new latency middleware
335    pub fn new(injector: crate::latency::LatencyInjector) -> Self {
336        Self {
337            name: "LatencyMiddleware".to_string(),
338            injector,
339        }
340    }
341}
342
343#[async_trait::async_trait]
344impl ProtocolMiddleware for LatencyMiddleware {
345    fn name(&self) -> &str {
346        &self.name
347    }
348
349    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
350        // Extract tags from request metadata
351        let tags: Vec<String> = request
352            .metadata
353            .get("x-mockforge-tags")
354            .map(|t| t.split(',').map(|s| s.trim().to_string()).collect())
355            .unwrap_or_default();
356
357        // Inject latency
358        self.injector.inject_latency(&tags).await?;
359
360        Ok(())
361    }
362
363    async fn process_response(
364        &self,
365        _request: &ProtocolRequest,
366        _response: &mut ProtocolResponse,
367    ) -> Result<()> {
368        // No post-processing needed for latency
369        Ok(())
370    }
371
372    fn supports_protocol(&self, _protocol: Protocol) -> bool {
373        // Latency middleware supports all protocols
374        true
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal HTTP `GET /test` request with no body, metadata or client IP.
    fn http_get_request() -> ProtocolRequest {
        ProtocolRequest {
            protocol: Protocol::Http,
            pattern: crate::MessagePattern::RequestResponse,
            operation: String::from("GET"),
            path: String::from("/test"),
            topic: None,
            routing_key: None,
            partition: None,
            qos: None,
            metadata: HashMap::new(),
            body: None,
            client_ip: None,
        }
    }

    #[test]
    fn test_logging_middleware_creation() {
        let logging = LoggingMiddleware::new(true);
        assert_eq!(logging.name(), "LoggingMiddleware");
        for protocol in [Protocol::Http, Protocol::GraphQL, Protocol::Grpc] {
            assert!(logging.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_metrics_middleware_creation() {
        let metrics = MetricsMiddleware::new();
        assert_eq!(metrics.name(), "MetricsMiddleware");
        for protocol in [Protocol::Http, Protocol::GraphQL] {
            assert!(metrics.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_latency_middleware_creation() {
        let latency = LatencyMiddleware::new(crate::latency::LatencyInjector::default());
        assert_eq!(latency.name(), "LatencyMiddleware");
        assert!(latency.supports_protocol(Protocol::Http));
    }

    #[tokio::test]
    async fn test_logging_middleware_process_request() {
        let logging = LoggingMiddleware::new(false);
        let mut request = http_get_request();

        let outcome = logging.process_request(&mut request).await;

        assert!(outcome.is_ok());
        assert!(request.metadata.contains_key("x-mockforge-request-time"));
    }
}