mockforge_grpc/reflection/mock_proxy/
middleware.rs

1//! Request processing middleware
2//!
3//! This module provides middleware for processing gRPC requests,
4//! including request transformation, logging, and metrics collection.
5
6use crate::reflection::metrics::{record_error, record_success};
7use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
8use prost_reflect::{DynamicMessage, Kind, ReflectMessage};
9use std::time::Instant;
10use tonic::{
11    metadata::{Ascii, MetadataKey, MetadataValue},
12    Code, Request, Status,
13};
14use tracing::error;
15
16impl MockReflectionProxy {
17    /// Apply request preprocessing middleware
18    pub async fn preprocess_request<T>(&self, request: &mut Request<T>) -> Result<(), Status>
19    where
20        T: prost_reflect::ReflectMessage,
21    {
22        // Extract metadata
23        let mut metadata_log = Vec::new();
24        for kv in request.metadata().iter() {
25            match kv {
26                tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
27                    metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
28                }
29                tonic::metadata::KeyAndValueRef::Binary(key, _) => {
30                    metadata_log.push(format!("{}: <binary>", key));
31                }
32            }
33        }
34        tracing::debug!("Extracted request metadata: [{}]", metadata_log.join(", "));
35
36        // Validate request format
37        let descriptor = request.get_ref().descriptor();
38        let mut buf = Vec::new();
39        request
40            .get_ref()
41            .encode(&mut buf)
42            .map_err(|_e| Status::internal("Failed to encode request".to_string()))?;
43        let dynamic_message = DynamicMessage::decode(descriptor.clone(), &buf[..])
44            .map_err(|_e| Status::internal("Failed to decode request".to_string()))?;
45        if let Err(e) = self.validate_request_message(&dynamic_message) {
46            return Err(Status::internal(format!("Request validation failed: {}", e)));
47        }
48        tracing::debug!("Request format validation passed");
49
50        // Apply request transformations
51        // Add mock-specific request headers
52        request.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
53        request
54            .metadata_mut()
55            .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
56
57        tracing::debug!("Applied request transformations: added processed and timestamp headers");
58
59        Ok(())
60    }
61
62    /// Apply request logging middleware
63    pub async fn log_request<T>(&self, request: &Request<T>, service_name: &str, method_name: &str)
64    where
65        T: prost_reflect::ReflectMessage,
66    {
67        let start_time = std::time::Instant::now();
68
69        // Log request metadata
70        let mut metadata_log = Vec::new();
71        for kv in request.metadata().iter() {
72            match kv {
73                tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
74                    metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
75                }
76                tonic::metadata::KeyAndValueRef::Binary(key, _) => {
77                    metadata_log.push(format!("{}: <binary>", key));
78                }
79            }
80        }
81        tracing::debug!(
82            "Request metadata for {}/{}: [{}]",
83            service_name,
84            method_name,
85            metadata_log.join(", ")
86        );
87
88        // Log request size
89        let request_size = request.get_ref().encoded_len();
90        tracing::debug!(
91            "Request size for {}/{}: {} bytes",
92            service_name,
93            method_name,
94            request_size
95        );
96
97        // Log request timing (start time)
98        tracing::debug!(
99            "Request start time for {}/{}: {:?}",
100            service_name,
101            method_name,
102            start_time
103        );
104    }
105
106    /// Apply response postprocessing middleware
107    pub async fn postprocess_response<T>(
108        &self,
109        response: &mut tonic::Response<T>,
110        service_name: &str,
111        method_name: &str,
112    ) -> Result<(), Status> {
113        let start = Instant::now();
114        // Add mock-specific response headers
115        response.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
116        response
117            .metadata_mut()
118            .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
119
120        // // Add processing timestamp for performance monitoring
121        // let processing_time = std::time::SystemTime::now()
122        //     .duration_since(std::time::UNIX_EPOCH)
123        //     .unwrap()
124        //     .as_millis();
125        // response
126        //     .metadata_mut()
127        //     .insert("x-mockforge-processing-time", MetadataValue::<Ascii>::from(processing_time.to_string()));
128
129        // Apply response transformations based on configuration
130        if self.config.response_transform.enabled {
131            // Add custom headers from configuration
132            for (key, value) in &self.config.response_transform.custom_headers {
133                let key: MetadataKey<Ascii> = key.parse().unwrap();
134                let value: MetadataValue<Ascii> = value.parse().unwrap();
135                response.metadata_mut().insert(key, value);
136            }
137        }
138
139        // Log response processing
140        let processing_time = start.elapsed().as_millis();
141        // Add processing timestamp for performance monitoring
142        response
143            .metadata_mut()
144            .insert("x-mockforge-processing-time", processing_time.to_string().parse().unwrap());
145        tracing::debug!("Postprocessed response for {}/{}", service_name, method_name);
146
147        Ok(())
148    }
149
150    /// Apply response postprocessing with body transformations for DynamicMessage responses
151    pub async fn postprocess_dynamic_response(
152        &self,
153        response: &mut tonic::Response<prost_reflect::DynamicMessage>,
154        service_name: &str,
155        method_name: &str,
156    ) -> Result<(), Status> {
157        // First apply basic postprocessing
158        self.postprocess_response(response, service_name, method_name).await?;
159
160        // Apply body transformations if enabled
161        if self.config.response_transform.enabled {
162            if let Some(ref overrides) = self.config.response_transform.overrides {
163                match self
164                    .transform_dynamic_message(
165                        &response.get_ref().clone(),
166                        service_name,
167                        method_name,
168                        overrides,
169                    )
170                    .await
171                {
172                    Ok(transformed_message) => {
173                        // Replace the response body
174                        *response.get_mut() = transformed_message;
175                        tracing::debug!(
176                            "Applied body transformations to response for {}/{}",
177                            service_name,
178                            method_name
179                        );
180                    }
181                    Err(e) => {
182                        tracing::warn!(
183                            "Failed to transform response body for {}/{}: {}",
184                            service_name,
185                            method_name,
186                            e
187                        );
188                    }
189                }
190            }
191
192            // Response validation
193            if self.config.response_transform.validate_responses {
194                if let Err(validation_error) = self
195                    .validate_dynamic_message(response.get_ref(), service_name, method_name)
196                    .await
197                {
198                    tracing::warn!(
199                        "Response validation failed for {}/{}: {}",
200                        service_name,
201                        method_name,
202                        validation_error
203                    );
204                }
205            }
206        }
207
208        Ok(())
209    }
210
211    /// Transform a DynamicMessage using JSON overrides
212    async fn transform_dynamic_message(
213        &self,
214        message: &prost_reflect::DynamicMessage,
215        service_name: &str,
216        method_name: &str,
217        overrides: &mockforge_core::overrides::Overrides,
218    ) -> Result<prost_reflect::DynamicMessage, Box<dyn std::error::Error + Send + Sync>> {
219        use crate::dynamic::http_bridge::converters::ProtobufJsonConverter;
220
221        // Get descriptor pool from service registry
222        let descriptor_pool = self.service_registry.descriptor_pool();
223
224        // Create a converter for JSON transformations
225        let converter = ProtobufJsonConverter::new(descriptor_pool.clone());
226
227        // Convert protobuf message to JSON
228        let json_value = converter.protobuf_to_json(&message.descriptor(), message)?;
229
230        // Apply overrides to the JSON
231        let mut json_value = serde_json::Value::Object(json_value.as_object().unwrap().clone());
232        overrides.apply_with_context(
233            &format!("{}/{}", service_name, method_name),
234            &[service_name.to_string()],
235            &format!("{}/{}", service_name, method_name),
236            &mut json_value,
237            &mockforge_core::conditions::ConditionContext::new(),
238        );
239
240        // Convert back to protobuf message
241        let transformed_message = converter.json_to_protobuf(&message.descriptor(), &json_value)?;
242
243        Ok(transformed_message)
244    }
245
246    /// Apply response postprocessing for streaming DynamicMessage responses
247    pub async fn postprocess_streaming_dynamic_response(
248        &self,
249        response: &mut tonic::Response<
250            tokio_stream::wrappers::ReceiverStream<
251                Result<prost_reflect::DynamicMessage, tonic::Status>,
252            >,
253        >,
254        service_name: &str,
255        method_name: &str,
256    ) -> Result<(), Status> {
257        // Apply basic postprocessing (headers only for streaming responses)
258        self.postprocess_response(response, service_name, method_name).await?;
259
260        // Note: Body transformation for streaming responses is complex and not yet implemented
261        // It would require creating a new stream that transforms each message individually,
262        // which involves significant async complexity and descriptor pool management.
263
264        if self.config.response_transform.enabled {
265            if self.config.response_transform.overrides.is_some() {
266                tracing::debug!(
267                    "Body transformation for streaming responses not yet implemented for {}/{}",
268                    service_name,
269                    method_name
270                );
271            }
272
273            if self.config.response_transform.validate_responses {
274                tracing::debug!(
275                    "Response validation for streaming responses not yet implemented for {}/{}",
276                    service_name,
277                    method_name
278                );
279            }
280        }
281
282        Ok(())
283    }
284
285    /// Validate a DynamicMessage response
286    async fn validate_dynamic_message(
287        &self,
288        message: &prost_reflect::DynamicMessage,
289        service_name: &str,
290        method_name: &str,
291    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
292        // Basic validation: check that required fields are present
293        let _descriptor = message.descriptor();
294
295        // Note: In proto3, all fields are effectively optional
296        // Required field validation removed as is_required() method is no longer available
297
298        // Schema validation against expected message structure
299        // For protobuf, the message structure is validated by the descriptor,
300        // but we can check field constraints
301        self.validate_message_schema(message, service_name, method_name)?;
302
303        // Business rule validation (e.g., email format, date ranges)
304        self.validate_business_rules(message, service_name, method_name)?;
305
306        // Cross-field validation
307        self.validate_cross_field_rules(message, service_name, method_name)?;
308
309        // Custom validation rules from configuration
310        self.validate_custom_rules(message, service_name, method_name)?;
311
312        tracing::debug!("Response validation passed for {}/{}", service_name, method_name);
313
314        Ok(())
315    }
316
317    /// Validate a request DynamicMessage
318    fn validate_request_message(
319        &self,
320        message: &DynamicMessage,
321    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
322        // Schema validation
323        self.validate_message_schema(message, "", "")?;
324        // Business rule validation
325        self.validate_business_rules(message, "", "")?;
326        // Cross-field validation
327        self.validate_cross_field_rules(message, "", "")?;
328        // Custom validation
329        self.validate_custom_rules(message, "", "")?;
330        tracing::debug!("Request validation passed");
331        Ok(())
332    }
333
334    /// Validate message schema constraints
335    fn validate_message_schema(
336        &self,
337        message: &DynamicMessage,
338        _service_name: &str,
339        _method_name: &str,
340    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
341        let descriptor = message.descriptor();
342
343        // Check field types and constraints
344        for field in descriptor.fields() {
345            let value = message.get_field(&field);
346            let value_ref = value.as_ref();
347
348            // Check if the value kind matches the field kind
349            if !Self::value_matches_kind(value_ref, field.kind()) {
350                return Err(format!(
351                    "{} field '{}' has incorrect type: expected {:?}, got {:?}",
352                    "Message validation",
353                    field.name(),
354                    field.kind(),
355                    value_ref
356                )
357                .into());
358            }
359
360            // For nested messages, recursively validate
361            if let Kind::Message(expected_msg) = field.kind() {
362                if let prost_reflect::Value::Message(ref nested_msg) = *value_ref {
363                    // Basic nested message validation - could be expanded
364                    if nested_msg.descriptor() != expected_msg {
365                        return Err(format!(
366                            "{} field '{}' has incorrect message type",
367                            "Message validation",
368                            field.name()
369                        )
370                        .into());
371                    }
372                }
373            }
374        }
375
376        Ok(())
377    }
378
379    /// Validate business rules (email format, date ranges, etc.)
380    fn validate_business_rules(
381        &self,
382        message: &DynamicMessage,
383        service_name: &str,
384        method_name: &str,
385    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
386        let descriptor = message.descriptor();
387
388        for field in descriptor.fields() {
389            let value = message.get_field(&field);
390            let field_value = value.as_ref();
391            let field_name = field.name().to_lowercase();
392
393            // Email validation
394            if field_name.contains("email") && field.kind() == Kind::String {
395                if let Some(email_str) = field_value.as_str() {
396                    if !self.is_valid_email(email_str) {
397                        return Err(format!(
398                            "Invalid email format '{}' for field '{}' in {}/{}",
399                            email_str,
400                            field.name(),
401                            service_name,
402                            method_name
403                        )
404                        .into());
405                    }
406                }
407            }
408
409            // Date/timestamp validation
410            if field_name.contains("date") || field_name.contains("timestamp") {
411                match field.kind() {
412                    Kind::String => {
413                        if let Some(date_str) = field_value.as_str() {
414                            if !self.is_valid_iso8601_date(date_str) {
415                                return Err(format!(
416                                    "Invalid date format '{}' for field '{}' in {}/{}",
417                                    date_str,
418                                    field.name(),
419                                    service_name,
420                                    method_name
421                                )
422                                .into());
423                            }
424                        }
425                    }
426                    Kind::Int64 | Kind::Uint64 => {
427                        // For timestamp fields, check reasonable range (1970-2100)
428                        if let Some(timestamp) = field_value.as_i64() {
429                            if !(0..=4102444800).contains(&timestamp) {
430                                // 2100-01-01
431                                return Err(format!(
432                                    "Timestamp {} out of reasonable range for field '{}' in {}/{}",
433                                    timestamp,
434                                    field.name(),
435                                    service_name,
436                                    method_name
437                                )
438                                .into());
439                            }
440                        }
441                    }
442                    _ => {}
443                }
444            }
445
446            // Phone number validation (basic)
447            if field_name.contains("phone") && field.kind() == Kind::String {
448                if let Some(phone_str) = field_value.as_str() {
449                    if !self.is_valid_phone_number(phone_str) {
450                        return Err(format!(
451                            "Invalid phone number format '{}' for field '{}' in {}/{}",
452                            phone_str,
453                            field.name(),
454                            service_name,
455                            method_name
456                        )
457                        .into());
458                    }
459                }
460            }
461        }
462
463        Ok(())
464    }
465
466    /// Validate cross-field rules
467    fn validate_cross_field_rules(
468        &self,
469        message: &DynamicMessage,
470        service_name: &str,
471        method_name: &str,
472    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
473        let descriptor = message.descriptor();
474
475        // Collect date/time fields for cross-validation
476        let mut date_fields = Vec::new();
477        let mut timestamp_fields = Vec::new();
478
479        for field in descriptor.fields() {
480            let value = message.get_field(&field);
481            let field_value = value.as_ref();
482            let field_name = field.name().to_lowercase();
483
484            if field_name.contains("start")
485                && (field_name.contains("date") || field_name.contains("time"))
486            {
487                if let Some(value) = field_value.as_i64() {
488                    date_fields.push(("start", value));
489                }
490            } else if field_name.contains("end")
491                && (field_name.contains("date") || field_name.contains("time"))
492            {
493                if let Some(value) = field_value.as_i64() {
494                    date_fields.push(("end", value));
495                }
496            } else if field_name.contains("timestamp") {
497                if let Some(value) = field_value.as_i64() {
498                    timestamp_fields.push((field.name().to_string(), value));
499                }
500            }
501        }
502
503        // Validate start_date < end_date
504        if date_fields.len() >= 2 {
505            let start_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "start").collect();
506            let end_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "end").collect();
507
508            for &(_, start_val) in &start_dates {
509                for &(_, end_val) in &end_dates {
510                    if start_val >= end_val {
511                        return Err(format!(
512                            "Start date/time {} must be before end date/time {} in {}/{}",
513                            start_val, end_val, service_name, method_name
514                        )
515                        .into());
516                    }
517                }
518            }
519        }
520
521        // Validate timestamp ranges (e.g., created_at <= updated_at)
522        if timestamp_fields.len() >= 2 {
523            let created_at = timestamp_fields
524                .iter()
525                .find(|(name, _)| name.to_lowercase().contains("created"));
526            let updated_at = timestamp_fields
527                .iter()
528                .find(|(name, _)| name.to_lowercase().contains("updated"));
529
530            if let (Some((_, created)), Some((_, updated))) = (created_at, updated_at) {
531                if created > updated {
532                    return Err(format!(
533                        "Created timestamp {} cannot be after updated timestamp {} in {}/{}",
534                        created, updated, service_name, method_name
535                    )
536                    .into());
537                }
538            }
539        }
540
541        Ok(())
542    }
543
544    /// Validate custom rules from configuration
545    fn validate_custom_rules(
546        &self,
547        message: &DynamicMessage,
548        service_name: &str,
549        method_name: &str,
550    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
551        // For now, implement basic custom validation based on field names and values
552        // In a full implementation, this would read from a configuration file
553
554        let descriptor = message.descriptor();
555
556        for field in descriptor.fields() {
557            let value = message.get_field(&field);
558            let field_value = value.as_ref();
559            let field_name = field.name().to_lowercase();
560
561            // Custom rule: ID fields should be positive
562            if field_name.ends_with("_id") || field_name == "id" {
563                match field.kind() {
564                    Kind::Int32 | Kind::Int64 => {
565                        if let Some(id_val) = field_value.as_i64() {
566                            if id_val <= 0 {
567                                return Err(format!(
568                                    "ID field '{}' must be positive, got {} in {}/{}",
569                                    field.name(),
570                                    id_val,
571                                    service_name,
572                                    method_name
573                                )
574                                .into());
575                            }
576                        }
577                    }
578                    Kind::Uint32 | Kind::Uint64 => {
579                        if let Some(id_val) = field_value.as_u64() {
580                            if id_val == 0 {
581                                return Err(format!(
582                                    "ID field '{}' must be non-zero, got {} in {}/{}",
583                                    field.name(),
584                                    id_val,
585                                    service_name,
586                                    method_name
587                                )
588                                .into());
589                            }
590                        }
591                    }
592                    Kind::String => {
593                        if let Some(id_str) = field_value.as_str() {
594                            if id_str.trim().is_empty() {
595                                return Err(format!(
596                                    "ID field '{}' cannot be empty in {}/{}",
597                                    field.name(),
598                                    service_name,
599                                    method_name
600                                )
601                                .into());
602                            }
603                        }
604                    }
605                    _ => {}
606                }
607            }
608
609            // Custom rule: Amount/price fields should be non-negative
610            if field_name.contains("amount")
611                || field_name.contains("price")
612                || field_name.contains("cost")
613            {
614                if let Some(numeric_val) = field_value.as_f64() {
615                    if numeric_val < 0.0 {
616                        return Err(format!(
617                            "Amount/price field '{}' cannot be negative, got {} in {}/{}",
618                            field.name(),
619                            numeric_val,
620                            service_name,
621                            method_name
622                        )
623                        .into());
624                    }
625                }
626            }
627        }
628
629        Ok(())
630    }
631
632    /// Validate email format (basic)
633    fn is_valid_email(&self, email: &str) -> bool {
634        // Basic email validation: contains @ and . with reasonable structure
635        let parts: Vec<&str> = email.split('@').collect();
636        if parts.len() != 2 {
637            return false;
638        }
639
640        let local = parts[0];
641        let domain = parts[1];
642
643        if local.is_empty() || domain.is_empty() {
644            return false;
645        }
646
647        // Domain should contain a dot
648        domain.contains('.') && !domain.starts_with('.') && !domain.ends_with('.')
649    }
650
651    /// Validate phone number format (basic)
652    fn is_valid_phone_number(&self, phone: &str) -> bool {
653        // Basic phone validation: not empty and reasonable length
654        !phone.is_empty() && phone.len() >= 7 && phone.len() <= 15
655    }
656
657    /// Validate ISO 8601 date format (basic)
658    fn is_valid_iso8601_date(&self, date_str: &str) -> bool {
659        // Basic ISO 8601 validation: YYYY-MM-DDTHH:MM:SSZ or similar
660        // For simplicity, check if it parses as a date
661        chrono::DateTime::parse_from_rfc3339(date_str).is_ok()
662            || chrono::NaiveDate::parse_from_str(date_str, "%Y-%m-%d").is_ok()
663            || chrono::NaiveDateTime::parse_from_str(date_str, "%Y-%m-%d %H:%M:%S").is_ok()
664    }
665
666    /// Apply error handling middleware
667    pub async fn handle_error(
668        &self,
669        error: Status,
670        service_name: &str,
671        method_name: &str,
672    ) -> Status {
673        // Log error details with context
674        error!(
675            "Error in {}/{}: {} (code: {:?})",
676            service_name,
677            method_name,
678            error,
679            error.code()
680        );
681
682        match error.code() {
683            Code::InvalidArgument => Status::invalid_argument(format!(
684                "Invalid arguments provided to {}/{}",
685                service_name, method_name
686            )),
687            Code::NotFound => {
688                Status::not_found(format!("Resource not found in {}/{}", service_name, method_name))
689            }
690            Code::AlreadyExists => Status::already_exists(format!(
691                "Resource already exists in {}/{}",
692                service_name, method_name
693            )),
694            Code::PermissionDenied => Status::permission_denied(format!(
695                "Permission denied for {}/{}",
696                service_name, method_name
697            )),
698            Code::FailedPrecondition => Status::failed_precondition(format!(
699                "Precondition failed for {}/{}",
700                service_name, method_name
701            )),
702            Code::Aborted => {
703                Status::aborted(format!("Operation aborted for {}/{}", service_name, method_name))
704            }
705            Code::OutOfRange => Status::out_of_range(format!(
706                "Value out of range in {}/{}",
707                service_name, method_name
708            )),
709            Code::Unimplemented => Status::unimplemented(format!(
710                "Method {}/{} not implemented",
711                service_name, method_name
712            )),
713            Code::Internal => {
714                Status::internal(format!("Internal error in {}/{}", service_name, method_name))
715            }
716            Code::Unavailable => Status::unavailable(format!(
717                "Service {}/{} temporarily unavailable",
718                service_name, method_name
719            )),
720            Code::DataLoss => {
721                Status::data_loss(format!("Data loss occurred in {}/{}", service_name, method_name))
722            }
723            Code::Unauthenticated => Status::unauthenticated(format!(
724                "Authentication required for {}/{}",
725                service_name, method_name
726            )),
727            Code::DeadlineExceeded => Status::deadline_exceeded(format!(
728                "Request to {}/{} timed out",
729                service_name, method_name
730            )),
731            Code::ResourceExhausted => Status::resource_exhausted(format!(
732                "Rate limit exceeded for {}/{}",
733                service_name, method_name
734            )),
735            _ => {
736                let message = error.message();
737                if message.contains(service_name) && message.contains(method_name) {
738                    error
739                } else {
740                    Status::new(
741                        error.code(),
742                        format!("{}/{}: {}", service_name, method_name, message),
743                    )
744                }
745            }
746        }
747    }
748
749    /// Apply metrics collection middleware
750    pub async fn collect_metrics(
751        &self,
752        service_name: &str,
753        method_name: &str,
754        duration: std::time::Duration,
755        success: bool,
756    ) {
757        let duration_ms = duration.as_millis() as u64;
758
759        if success {
760            record_success(service_name, method_name, duration_ms).await;
761        } else {
762            record_error(service_name, method_name).await;
763        }
764
765        tracing::debug!(
766            "Request {}/{} completed in {:?}, success: {}",
767            service_name,
768            method_name,
769            duration,
770            success
771        );
772    }
773}
774
#[cfg(test)]
mod tests {
    /// Smoke test: if this test binary builds, the middleware above type-checks.
    #[test]
    fn test_module_compiles() {
        // Deliberately empty — successful compilation is the whole assertion.
    }
}