mockforge_grpc/reflection/mock_proxy/validation.rs

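//! Request validation and routing for the mock reflection proxy.
//!
//! This module validates gRPC requests and responses against the method schemas held in
//! the proxy's cache, enforces message size limits, and routes requests only to service
//! methods the proxy knows about.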
5
6use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
7use mockforge_core::openapi_routes::ValidationMode;
8use prost::bytes::Bytes as ProstBytes;
9use prost_reflect::ReflectMessage;
10use prost_reflect::{DynamicMessage, Kind, MessageDescriptor, Value};
11use tonic::{Request, Status};
12use tracing::debug;
13
14use prost_reflect::prost::Message;
15
16impl MockReflectionProxy {
17    /// Validate a request against the service method schema
18    pub async fn validate_request(
19        &self,
20        request: &Request<DynamicMessage>,
21        service_name: &str,
22        method_name: &str,
23    ) -> Result<(), Status> {
24        debug!("Validating request for {}/{}", service_name, method_name);
25
26        // Get method descriptor for validation
27        let method_descriptor = self.cache.get_method(service_name, method_name).await?;
28
29        // Get expected input descriptor
30        let expected_descriptor = method_descriptor.input();
31
32        // Get actual descriptor from the request message
33        let actual_descriptor = request.get_ref().descriptor();
34
35        // Check if the request descriptor matches the expected input type
36        if actual_descriptor.full_name() != expected_descriptor.full_name() {
37            return Err(Status::invalid_argument(format!(
38                "Request type mismatch: expected {}, got {}",
39                expected_descriptor.full_name(),
40                actual_descriptor.full_name()
41            )));
42        }
43
44        // Convert the typed message to DynamicMessage for field validation
45        let method_descriptor = self.cache.get_method(service_name, method_name).await?;
46        let expected_descriptor = method_descriptor.input();
47
48        let encoded = request.get_ref().encode_to_vec();
49        let dynamic_message =
50            DynamicMessage::decode(expected_descriptor.clone(), ProstBytes::from(encoded))
51                .map_err(|e| {
52                    Status::invalid_argument(format!(
53                        "Failed to decode request as DynamicMessage: {}",
54                        e
55                    ))
56                })?;
57
58        // Validate field types and presence
59        Self::validate_dynamic_message_fields(&dynamic_message, &expected_descriptor, "request")?;
60
61        debug!("Request validation passed for {}/{}", service_name, method_name);
62        Ok(())
63    }
64
65    /// Validate response data against the method's response schema
66    pub async fn validate_response(
67        &self,
68        response: &DynamicMessage,
69        service_name: &str,
70        method_name: &str,
71    ) -> Result<(), Status> {
72        debug!("Validating response for {}/{}", service_name, method_name);
73
74        // Get method descriptor for validation
75        let method_descriptor = self.cache.get_method(service_name, method_name).await?;
76
77        // Validate response against protobuf schema
78        let expected_descriptor = method_descriptor.output();
79
80        // Check if the response descriptor matches
81        if response.descriptor().full_name() != expected_descriptor.full_name() {
82            return Err(Status::invalid_argument(format!(
83                "Response type mismatch: expected {}, got {}",
84                expected_descriptor.full_name(),
85                response.descriptor().full_name()
86            )));
87        }
88
89        // Validate field types and presence
90        Self::validate_dynamic_message_fields(response, &expected_descriptor, "response")?;
91
92        debug!("Response validation passed for {}/{}", service_name, method_name);
93        Ok(())
94    }
95
96    /// Route a request to the appropriate handler
97    pub async fn route_request<T>(
98        &self,
99        request: Request<T>,
100    ) -> Result<(String, String, Request<T>), Status> {
101        // Extract service and method from request metadata
102        let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
103
104        // Validate that the service and method exist
105        let contains_service = self.cache.contains_service(&service_name).await;
106        if !contains_service {
107            return Err(Status::not_found(format!("Service {} not found", service_name)));
108        }
109
110        if self.cache.get_method(&service_name, &method_name).await.is_err() {
111            return Err(Status::not_found(format!(
112                "Method {} not found in service {}",
113                method_name, service_name
114            )));
115        }
116
117        Ok((service_name.to_string(), method_name.to_string(), request))
118    }
119
120    /// Check if a service method should be processed by this proxy
121    pub async fn can_handle_service_method(&self, service_name: &str, method_name: &str) -> bool {
122        // Check if service exists in cache
123        if !self.cache.contains_service(service_name).await {
124            return false;
125        }
126
127        // Check if method exists in service
128        if !self.cache.contains_method(service_name, method_name).await {
129            return false;
130        }
131
132        true
133    }
134
135    /// Validate service method signature compatibility
136    pub async fn validate_service_method_signature(
137        &self,
138        service_name: &str,
139        method_name: &str,
140        input_descriptor: MessageDescriptor,
141        output_descriptor: MessageDescriptor,
142    ) -> Result<(), Status> {
143        debug!("Validating signature for {}/{}", service_name, method_name);
144
145        // Check if method exists in cache
146        let cached_descriptor = self.cache.get_method(service_name, method_name).await?;
147
148        // Compare input/output types
149        if input_descriptor.full_name() != cached_descriptor.input().full_name() {
150            return Err(Status::invalid_argument(format!(
151                "Input type mismatch: expected {}, got {}",
152                cached_descriptor.input().full_name(),
153                input_descriptor.full_name()
154            )));
155        }
156
157        if output_descriptor.full_name() != cached_descriptor.output().full_name() {
158            return Err(Status::invalid_argument(format!(
159                "Output type mismatch: expected {}, got {}",
160                cached_descriptor.output().full_name(),
161                output_descriptor.full_name()
162            )));
163        }
164
165        // Validate field compatibility and check for breaking changes
166        Self::check_message_compatibility(&cached_descriptor.input(), &input_descriptor, "input")?;
167        Self::check_message_compatibility(
168            &cached_descriptor.output(),
169            &output_descriptor,
170            "output",
171        )?;
172
173        debug!("Signature validation passed for {}/{}", service_name, method_name);
174        Ok(())
175    }
176
177    /// Check if two message descriptors are compatible (no breaking changes)
178    fn check_message_compatibility(
179        expected: &MessageDescriptor,
180        provided: &MessageDescriptor,
181        message_type: &str,
182    ) -> Result<(), Status> {
183        for expected_field in expected.fields() {
184            let field_name = expected_field.name();
185            if let Some(provided_field) = provided.get_field_by_name(field_name) {
186                // Check if kinds match
187                if expected_field.kind() != provided_field.kind() {
188                    return Err(Status::invalid_argument(format!(
189                        "{} field '{}' type mismatch: expected {:?}, got {:?}",
190                        message_type,
191                        field_name,
192                        expected_field.kind(),
193                        provided_field.kind()
194                    )));
195                }
196
197                // For message types, check nested compatibility if full names differ
198                if let prost_reflect::Kind::Message(expected_msg) = expected_field.kind() {
199                    if let prost_reflect::Kind::Message(provided_msg) = provided_field.kind() {
200                        if expected_msg.full_name() != provided_msg.full_name() {
201                            // Recursively check nested messages
202                            Self::check_message_compatibility(
203                                &expected_msg,
204                                &provided_msg,
205                                &format!("{}.{}", message_type, field_name),
206                            )?;
207                        }
208                    }
209                }
210            } else {
211                return Err(Status::invalid_argument(format!(
212                    "Missing {} field '{}' in provided descriptor",
213                    message_type, field_name
214                )));
215            }
216        }
217
218        Ok(())
219    }
220
221    /// Validate fields of a DynamicMessage against its descriptor
222    fn validate_dynamic_message_fields(
223        message: &DynamicMessage,
224        descriptor: &MessageDescriptor,
225        context: &str,
226    ) -> Result<(), Status> {
227        for field in descriptor.fields() {
228            let field_name = field.name();
229
230            let value = message.get_field(&field);
231            let value_ref = value.as_ref();
232            // Check if the value kind matches the field kind
233            if !Self::value_matches_kind(value_ref, field.kind()) {
234                return Err(Status::invalid_argument(format!(
235                    "{} field '{}' has incorrect type: expected {:?}, got {:?}",
236                    context,
237                    field_name,
238                    field.kind(),
239                    value_ref
240                )));
241            }
242
243            // For nested messages, recursively validate
244            if let Kind::Message(expected_msg) = field.kind() {
245                if let Value::Message(ref nested_msg) = *value_ref {
246                    Self::validate_dynamic_message_fields(
247                        nested_msg,
248                        &expected_msg,
249                        &format!("{}.{}", context, field_name),
250                    )?;
251                }
252            }
253        }
254
255        Ok(())
256    }
257
258    /// Check if a Value matches a Kind
259    pub fn value_matches_kind(value: &Value, kind: prost_reflect::Kind) -> bool {
260        match *value {
261            prost_reflect::Value::Bool(_) => kind == prost_reflect::Kind::Bool,
262            prost_reflect::Value::I32(_) => matches!(
263                kind,
264                prost_reflect::Kind::Int32
265                    | prost_reflect::Kind::Sint32
266                    | prost_reflect::Kind::Sfixed32
267            ),
268            prost_reflect::Value::I64(_) => matches!(
269                kind,
270                prost_reflect::Kind::Int64
271                    | prost_reflect::Kind::Sint64
272                    | prost_reflect::Kind::Sfixed64
273            ),
274            prost_reflect::Value::U32(_) => {
275                matches!(kind, prost_reflect::Kind::Uint32 | prost_reflect::Kind::Fixed32)
276            }
277            prost_reflect::Value::U64(_) => {
278                matches!(kind, prost_reflect::Kind::Uint64 | prost_reflect::Kind::Fixed64)
279            }
280            prost_reflect::Value::F32(_) => kind == prost_reflect::Kind::Float,
281            prost_reflect::Value::F64(_) => kind == prost_reflect::Kind::Double,
282            prost_reflect::Value::String(_) => kind == prost_reflect::Kind::String,
283            prost_reflect::Value::Bytes(_) => kind == prost_reflect::Kind::Bytes,
284            prost_reflect::Value::Message(_) => matches!(kind, prost_reflect::Kind::Message(_)),
285            prost_reflect::Value::List(_) => matches!(kind, prost_reflect::Kind::Message(_)), // Lists are for repeated messages
286            _ => false,
287        }
288    }
289
290    /// Validate request size limits
291    pub async fn validate_request_size<T>(
292        &self,
293        request: &Request<T>,
294        max_size: usize,
295    ) -> Result<(), Status>
296    where
297        T: Message,
298    {
299        // Encode the request to get its serialized size
300        let encoded_size = request.get_ref().encode_to_vec().len();
301
302        // Check if the request size exceeds the configured limit
303        if encoded_size > max_size {
304            return Err(Status::resource_exhausted(format!(
305                "Request size {} bytes exceeds maximum allowed size of {} bytes",
306                encoded_size, max_size
307            )));
308        }
309
310        Ok(())
311    }
312
313    /// Validate response size limits
314    pub async fn validate_response_size(
315        &self,
316        response: &DynamicMessage,
317        max_size: usize,
318    ) -> Result<(), Status> {
319        // Encode the response to get its serialized size
320        let encoded_size = response.encode_to_vec().len();
321
322        // Check if the response size exceeds the configured limit
323        if encoded_size > max_size {
324            return Err(Status::resource_exhausted(format!(
325                "Response size {} bytes exceeds maximum allowed size of {} bytes",
326                encoded_size, max_size
327            )));
328        }
329
330        Ok(())
331    }
332
333    /// Check if request should be skipped for validation (admin endpoints, etc.)
334    pub fn should_skip_validation(&self, service_name: &str, method_name: &str) -> bool {
335        // Check admin skip prefixes from config
336        for prefix in &self.config.admin_skip_prefixes {
337            if service_name.starts_with(prefix) || method_name.starts_with(prefix) {
338                return true;
339            }
340        }
341
342        false
343    }
344
345    /// Apply validation mode for a service method
346    pub fn get_validation_mode_for_method(
347        &self,
348        service_name: &str,
349        method_name: &str,
350    ) -> ValidationMode {
351        // Check for method-specific overrides
352        if let Some(mode) = self.config.overrides.get(&format!("{}/{}", service_name, method_name))
353        {
354            return mode.clone();
355        }
356
357        // Check for service-specific overrides
358        if let Some(mode) = self.config.overrides.get(service_name) {
359            return mode.clone();
360        }
361
362        // Return default mode
363        self.config.request_mode.clone()
364    }
365}
366
#[cfg(test)]
mod tests {
    /// Smoke test: building the test harness is enough to show this module type-checks.
    #[test]
    fn test_module_compiles() {
        // Intentionally empty: successful compilation is the assertion.
    }
}