mockforge_core/protocol_abstraction/
middleware.rs

1//! Unified middleware implementations for common patterns across protocols
2
3use super::{Protocol, ProtocolMiddleware, ProtocolRequest, ProtocolResponse};
4use crate::{request_logger::log_request_global, Result};
5use std::time::Instant;
6
/// Protocol-agnostic middleware that records every request/response pair.
///
/// The same instance can be installed for any [`Protocol`]; see the
/// [`ProtocolMiddleware`] implementation for what is recorded.
pub struct LoggingMiddleware {
    /// Name reported through [`ProtocolMiddleware::name`].
    name: String,
    /// Whether request bodies should be logged.
    // The flag is accepted by the constructor but not read by the middleware
    // code in this file, hence the lint suppression.
    #[allow(dead_code)]
    log_bodies: bool,
}
15
16impl LoggingMiddleware {
17    /// Create a new logging middleware
18    pub fn new(log_bodies: bool) -> Self {
19        Self {
20            name: "LoggingMiddleware".to_string(),
21            log_bodies,
22        }
23    }
24}
25
26#[async_trait::async_trait]
27impl ProtocolMiddleware for LoggingMiddleware {
28    fn name(&self) -> &str {
29        &self.name
30    }
31
32    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
33        // Add timestamp to request metadata
34        let timestamp = chrono::Utc::now().to_rfc3339();
35        request.metadata.insert("x-mockforge-request-time".to_string(), timestamp);
36
37        // Store start time for duration calculation
38        request.metadata.insert(
39            "x-mockforge-request-start".to_string(),
40            Instant::now().elapsed().as_millis().to_string(),
41        );
42
43        tracing::debug!(
44            protocol = %request.protocol,
45            operation = %request.operation,
46            path = %request.path,
47            "Processing request through logging middleware"
48        );
49
50        Ok(())
51    }
52
53    async fn process_response(
54        &self,
55        request: &ProtocolRequest,
56        response: &mut ProtocolResponse,
57    ) -> Result<()> {
58        let duration_ms = if let Some(start) = request.metadata.get("x-mockforge-request-start") {
59            let start: u128 = start.parse().unwrap_or(0);
60            Instant::now().elapsed().as_millis() - start
61        } else {
62            0
63        };
64
65        // Create appropriate log entry based on protocol
66        let log_entry = match request.protocol {
67            Protocol::Http => crate::create_http_log_entry(
68                &request.operation,
69                &request.path,
70                response.status.as_code().unwrap_or(0) as u16,
71                duration_ms as u64,
72                request.client_ip.clone(),
73                request.metadata.get("user-agent").cloned(),
74                request.metadata.clone(),
75                response.body.len() as u64,
76                if !response.status.is_success() {
77                    Some(format!("Error response: {:?}", response.status))
78                } else {
79                    None
80                },
81            ),
82            Protocol::Grpc => {
83                // Extract service and method from operation (e.g., "greeter.SayHello")
84                let parts: Vec<&str> = request.operation.split('.').collect();
85                let (service, method) = if parts.len() == 2 {
86                    (parts[0], parts[1])
87                } else {
88                    ("unknown", request.operation.as_str())
89                };
90                crate::create_grpc_log_entry(
91                    service,
92                    method,
93                    response.status.as_code().unwrap_or(0) as u16,
94                    duration_ms as u64,
95                    request.client_ip.clone(),
96                    request.body.as_ref().map(|b| b.len() as u64).unwrap_or(0),
97                    response.body.len() as u64,
98                    if !response.status.is_success() {
99                        Some(format!("Error response: {:?}", response.status))
100                    } else {
101                        None
102                    },
103                )
104            }
105            Protocol::GraphQL => crate::create_http_log_entry(
106                "GraphQL",
107                &request.path,
108                if response.status.is_success() {
109                    200
110                } else {
111                    400
112                },
113                duration_ms as u64,
114                request.client_ip.clone(),
115                request.metadata.get("user-agent").cloned(),
116                request.metadata.clone(),
117                response.body.len() as u64,
118                None,
119            ),
120            Protocol::WebSocket => crate::create_websocket_log_entry(
121                &request.operation,
122                &request.path,
123                response.status.as_code().unwrap_or(0) as u16,
124                request.client_ip.clone(),
125                response.body.len() as u64,
126                if !response.status.is_success() {
127                    Some(format!("Error response: {:?}", response.status))
128                } else {
129                    None
130                },
131            ),
132            Protocol::Smtp => crate::create_http_log_entry(
133                "SMTP",
134                &request.path,
135                response.status.as_code().unwrap_or(250) as u16,
136                duration_ms as u64,
137                request.client_ip.clone(),
138                None,
139                request.metadata.clone(),
140                response.body.len() as u64,
141                if !response.status.is_success() {
142                    Some(format!("SMTP Error: {:?}", response.status))
143                } else {
144                    None
145                },
146            ),
147            Protocol::Mqtt => crate::create_http_log_entry(
148                "MQTT",
149                &request.topic.clone().unwrap_or_else(|| request.path.clone()),
150                if response.status.is_success() {
151                    200
152                } else {
153                    500
154                },
155                duration_ms as u64,
156                request.client_ip.clone(),
157                None,
158                request.metadata.clone(),
159                response.body.len() as u64,
160                if !response.status.is_success() {
161                    Some(format!("MQTT Error: {:?}", response.status))
162                } else {
163                    None
164                },
165            ),
166            Protocol::Ftp => crate::create_http_log_entry(
167                "FTP",
168                &request.path,
169                response.status.as_code().unwrap_or(226) as u16,
170                duration_ms as u64,
171                request.client_ip.clone(),
172                None,
173                request.metadata.clone(),
174                response.body.len() as u64,
175                if !response.status.is_success() {
176                    Some(format!("FTP Error: {:?}", response.status))
177                } else {
178                    None
179                },
180            ),
181            Protocol::Kafka => crate::create_http_log_entry(
182                "Kafka",
183                &request.topic.clone().unwrap_or_else(|| request.path.clone()),
184                response.status.as_code().unwrap_or(0) as u16,
185                duration_ms as u64,
186                request.client_ip.clone(),
187                None,
188                request.metadata.clone(),
189                response.body.len() as u64,
190                if !response.status.is_success() {
191                    Some(format!("Kafka Error: {:?}", response.status))
192                } else {
193                    None
194                },
195            ),
196            Protocol::RabbitMq | Protocol::Amqp => crate::create_http_log_entry(
197                "AMQP",
198                &request.routing_key.clone().unwrap_or_else(|| request.path.clone()),
199                response.status.as_code().unwrap_or(200) as u16,
200                duration_ms as u64,
201                request.client_ip.clone(),
202                None,
203                request.metadata.clone(),
204                response.body.len() as u64,
205                if !response.status.is_success() {
206                    Some(format!("AMQP Error: {:?}", response.status))
207                } else {
208                    None
209                },
210            ),
211        };
212
213        // Log to centralized logger
214        log_request_global(log_entry).await;
215
216        tracing::debug!(
217            protocol = %request.protocol,
218            operation = %request.operation,
219            path = %request.path,
220            duration_ms = duration_ms,
221            success = response.status.is_success(),
222            "Request processed"
223        );
224
225        Ok(())
226    }
227
228    fn supports_protocol(&self, _protocol: Protocol) -> bool {
229        // Logging middleware supports all protocols
230        true
231    }
232}
233
/// Protocol-agnostic middleware that reports per-request metrics.
///
/// See the [`ProtocolMiddleware`] implementation for the values it emits.
pub struct MetricsMiddleware {
    /// Name reported through [`ProtocolMiddleware::name`].
    name: String,
}
239
240impl MetricsMiddleware {
241    /// Create a new metrics middleware
242    pub fn new() -> Self {
243        Self {
244            name: "MetricsMiddleware".to_string(),
245        }
246    }
247}
248
249impl Default for MetricsMiddleware {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255#[async_trait::async_trait]
256impl ProtocolMiddleware for MetricsMiddleware {
257    fn name(&self) -> &str {
258        &self.name
259    }
260
261    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
262        // Store start time for metrics calculation
263        request.metadata.insert(
264            "x-mockforge-metrics-start".to_string(),
265            std::time::Instant::now().elapsed().as_millis().to_string(),
266        );
267
268        tracing::debug!(
269            protocol = %request.protocol,
270            operation = %request.operation,
271            "Metrics: request started"
272        );
273
274        Ok(())
275    }
276
277    async fn process_response(
278        &self,
279        request: &ProtocolRequest,
280        response: &mut ProtocolResponse,
281    ) -> Result<()> {
282        let duration_ms = if let Some(start) = request.metadata.get("x-mockforge-metrics-start") {
283            let start: u128 = start.parse().unwrap_or(0);
284            Instant::now().elapsed().as_millis() - start
285        } else {
286            0
287        };
288
289        let status_code = response.status.as_code().unwrap_or(0);
290
291        tracing::info!(
292            protocol = %request.protocol,
293            operation = %request.operation,
294            status_code = status_code,
295            duration_ms = duration_ms,
296            response_size = response.body.len(),
297            success = response.status.is_success(),
298            "Metrics: request completed"
299        );
300
301        Ok(())
302    }
303
304    fn supports_protocol(&self, _protocol: Protocol) -> bool {
305        // Metrics middleware supports all protocols
306        true
307    }
308}
309
310/// Latency injection middleware for simulating delays
311pub struct LatencyMiddleware {
312    /// Middleware name
313    name: String,
314    /// Latency injector
315    injector: crate::latency::LatencyInjector,
316}
317
318impl LatencyMiddleware {
319    /// Create a new latency middleware
320    pub fn new(injector: crate::latency::LatencyInjector) -> Self {
321        Self {
322            name: "LatencyMiddleware".to_string(),
323            injector,
324        }
325    }
326}
327
328#[async_trait::async_trait]
329impl ProtocolMiddleware for LatencyMiddleware {
330    fn name(&self) -> &str {
331        &self.name
332    }
333
334    async fn process_request(&self, request: &mut ProtocolRequest) -> Result<()> {
335        // Extract tags from request metadata
336        let tags: Vec<String> = request
337            .metadata
338            .get("x-mockforge-tags")
339            .map(|t| t.split(',').map(|s| s.trim().to_string()).collect())
340            .unwrap_or_default();
341
342        // Inject latency
343        self.injector.inject_latency(&tags).await?;
344
345        Ok(())
346    }
347
348    async fn process_response(
349        &self,
350        _request: &ProtocolRequest,
351        _response: &mut ProtocolResponse,
352    ) -> Result<()> {
353        // No post-processing needed for latency
354        Ok(())
355    }
356
357    fn supports_protocol(&self, _protocol: Protocol) -> bool {
358        // Latency middleware supports all protocols
359        true
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal HTTP `GET /test` request with no body, metadata or client IP.
    fn bare_http_request() -> ProtocolRequest {
        ProtocolRequest {
            protocol: Protocol::Http,
            pattern: crate::MessagePattern::RequestResponse,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            topic: None,
            routing_key: None,
            partition: None,
            qos: None,
            metadata: HashMap::new(),
            body: None,
            client_ip: None,
        }
    }

    #[test]
    fn test_logging_middleware_creation() {
        let logging = LoggingMiddleware::new(true);
        assert_eq!(logging.name(), "LoggingMiddleware");
        for protocol in [Protocol::Http, Protocol::GraphQL, Protocol::Grpc] {
            assert!(logging.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_metrics_middleware_creation() {
        let metrics = MetricsMiddleware::new();
        assert_eq!(metrics.name(), "MetricsMiddleware");
        for protocol in [Protocol::Http, Protocol::GraphQL] {
            assert!(metrics.supports_protocol(protocol));
        }
    }

    #[test]
    fn test_latency_middleware_creation() {
        let latency = LatencyMiddleware::new(crate::latency::LatencyInjector::default());
        assert_eq!(latency.name(), "LatencyMiddleware");
        assert!(latency.supports_protocol(Protocol::Http));
    }

    #[tokio::test]
    async fn test_logging_middleware_process_request() {
        let logging = LoggingMiddleware::new(false);
        let mut request = bare_http_request();

        let outcome = logging.process_request(&mut request).await;

        assert!(outcome.is_ok());
        assert!(request.metadata.contains_key("x-mockforge-request-time"));
    }
}
411        assert!(result.is_ok());
412        assert!(request.metadata.contains_key("x-mockforge-request-time"));
413    }
414}