mockforge_grpc/reflection/mock_proxy/
validation.rs1use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
7use mockforge_core::openapi_routes::ValidationMode;
8use prost::bytes::Bytes as ProstBytes;
9use prost_reflect::ReflectMessage;
10use prost_reflect::{DynamicMessage, Kind, MessageDescriptor, Value};
11use tonic::{Request, Status};
12use tracing::debug;
13
14use prost_reflect::prost::Message;
15
16impl MockReflectionProxy {
17 pub async fn validate_request(
19 &self,
20 request: &Request<DynamicMessage>,
21 service_name: &str,
22 method_name: &str,
23 ) -> Result<(), Status> {
24 debug!("Validating request for {}/{}", service_name, method_name);
25
26 let method_descriptor = self.cache.get_method(service_name, method_name).await?;
28
29 let expected_descriptor = method_descriptor.input();
31
32 let actual_descriptor = request.get_ref().descriptor();
34
35 if actual_descriptor.full_name() != expected_descriptor.full_name() {
37 return Err(Status::invalid_argument(format!(
38 "Request type mismatch: expected {}, got {}",
39 expected_descriptor.full_name(),
40 actual_descriptor.full_name()
41 )));
42 }
43
44 let method_descriptor = self.cache.get_method(service_name, method_name).await?;
46 let expected_descriptor = method_descriptor.input();
47
48 let encoded = request.get_ref().encode_to_vec();
49 let dynamic_message =
50 DynamicMessage::decode(expected_descriptor.clone(), ProstBytes::from(encoded))
51 .map_err(|e| {
52 Status::invalid_argument(format!(
53 "Failed to decode request as DynamicMessage: {}",
54 e
55 ))
56 })?;
57
58 Self::validate_dynamic_message_fields(&dynamic_message, &expected_descriptor, "request")?;
60
61 debug!("Request validation passed for {}/{}", service_name, method_name);
62 Ok(())
63 }
64
65 pub async fn validate_response(
67 &self,
68 response: &DynamicMessage,
69 service_name: &str,
70 method_name: &str,
71 ) -> Result<(), Status> {
72 debug!("Validating response for {}/{}", service_name, method_name);
73
74 let method_descriptor = self.cache.get_method(service_name, method_name).await?;
76
77 let expected_descriptor = method_descriptor.output();
79
80 if response.descriptor().full_name() != expected_descriptor.full_name() {
82 return Err(Status::invalid_argument(format!(
83 "Response type mismatch: expected {}, got {}",
84 expected_descriptor.full_name(),
85 response.descriptor().full_name()
86 )));
87 }
88
89 Self::validate_dynamic_message_fields(response, &expected_descriptor, "response")?;
91
92 debug!("Response validation passed for {}/{}", service_name, method_name);
93 Ok(())
94 }
95
96 pub async fn route_request<T>(
98 &self,
99 request: Request<T>,
100 ) -> Result<(String, String, Request<T>), Status> {
101 let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
103
104 let contains_service = self.cache.contains_service(&service_name).await;
106 if !contains_service {
107 return Err(Status::not_found(format!("Service {} not found", service_name)));
108 }
109
110 if self.cache.get_method(&service_name, &method_name).await.is_err() {
111 return Err(Status::not_found(format!(
112 "Method {} not found in service {}",
113 method_name, service_name
114 )));
115 }
116
117 Ok((service_name.to_string(), method_name.to_string(), request))
118 }
119
120 pub async fn can_handle_service_method(&self, service_name: &str, method_name: &str) -> bool {
122 if !self.cache.contains_service(service_name).await {
124 return false;
125 }
126
127 if !self.cache.contains_method(service_name, method_name).await {
129 return false;
130 }
131
132 true
133 }
134
135 pub async fn validate_service_method_signature(
137 &self,
138 service_name: &str,
139 method_name: &str,
140 input_descriptor: MessageDescriptor,
141 output_descriptor: MessageDescriptor,
142 ) -> Result<(), Status> {
143 debug!("Validating signature for {}/{}", service_name, method_name);
144
145 let cached_descriptor = self.cache.get_method(service_name, method_name).await?;
147
148 if input_descriptor.full_name() != cached_descriptor.input().full_name() {
150 return Err(Status::invalid_argument(format!(
151 "Input type mismatch: expected {}, got {}",
152 cached_descriptor.input().full_name(),
153 input_descriptor.full_name()
154 )));
155 }
156
157 if output_descriptor.full_name() != cached_descriptor.output().full_name() {
158 return Err(Status::invalid_argument(format!(
159 "Output type mismatch: expected {}, got {}",
160 cached_descriptor.output().full_name(),
161 output_descriptor.full_name()
162 )));
163 }
164
165 Self::check_message_compatibility(&cached_descriptor.input(), &input_descriptor, "input")?;
167 Self::check_message_compatibility(
168 &cached_descriptor.output(),
169 &output_descriptor,
170 "output",
171 )?;
172
173 debug!("Signature validation passed for {}/{}", service_name, method_name);
174 Ok(())
175 }
176
177 fn check_message_compatibility(
179 expected: &MessageDescriptor,
180 provided: &MessageDescriptor,
181 message_type: &str,
182 ) -> Result<(), Status> {
183 for expected_field in expected.fields() {
184 let field_name = expected_field.name();
185 if let Some(provided_field) = provided.get_field_by_name(field_name) {
186 if expected_field.kind() != provided_field.kind() {
188 return Err(Status::invalid_argument(format!(
189 "{} field '{}' type mismatch: expected {:?}, got {:?}",
190 message_type,
191 field_name,
192 expected_field.kind(),
193 provided_field.kind()
194 )));
195 }
196
197 if let prost_reflect::Kind::Message(expected_msg) = expected_field.kind() {
199 if let prost_reflect::Kind::Message(provided_msg) = provided_field.kind() {
200 if expected_msg.full_name() != provided_msg.full_name() {
201 Self::check_message_compatibility(
203 &expected_msg,
204 &provided_msg,
205 &format!("{}.{}", message_type, field_name),
206 )?;
207 }
208 }
209 }
210 } else {
211 return Err(Status::invalid_argument(format!(
212 "Missing {} field '{}' in provided descriptor",
213 message_type, field_name
214 )));
215 }
216 }
217
218 Ok(())
219 }
220
221 fn validate_dynamic_message_fields(
223 message: &DynamicMessage,
224 descriptor: &MessageDescriptor,
225 context: &str,
226 ) -> Result<(), Status> {
227 for field in descriptor.fields() {
228 let field_name = field.name();
229
230 let value = message.get_field(&field);
231 let value_ref = value.as_ref();
232 if !Self::value_matches_kind(value_ref, field.kind()) {
234 return Err(Status::invalid_argument(format!(
235 "{} field '{}' has incorrect type: expected {:?}, got {:?}",
236 context,
237 field_name,
238 field.kind(),
239 value_ref
240 )));
241 }
242
243 if let Kind::Message(expected_msg) = field.kind() {
245 if let Value::Message(ref nested_msg) = *value_ref {
246 Self::validate_dynamic_message_fields(
247 nested_msg,
248 &expected_msg,
249 &format!("{}.{}", context, field_name),
250 )?;
251 }
252 }
253 }
254
255 Ok(())
256 }
257
258 pub fn value_matches_kind(value: &Value, kind: prost_reflect::Kind) -> bool {
260 match *value {
261 prost_reflect::Value::Bool(_) => kind == prost_reflect::Kind::Bool,
262 prost_reflect::Value::I32(_) => matches!(
263 kind,
264 prost_reflect::Kind::Int32
265 | prost_reflect::Kind::Sint32
266 | prost_reflect::Kind::Sfixed32
267 ),
268 prost_reflect::Value::I64(_) => matches!(
269 kind,
270 prost_reflect::Kind::Int64
271 | prost_reflect::Kind::Sint64
272 | prost_reflect::Kind::Sfixed64
273 ),
274 prost_reflect::Value::U32(_) => {
275 matches!(kind, prost_reflect::Kind::Uint32 | prost_reflect::Kind::Fixed32)
276 }
277 prost_reflect::Value::U64(_) => {
278 matches!(kind, prost_reflect::Kind::Uint64 | prost_reflect::Kind::Fixed64)
279 }
280 prost_reflect::Value::F32(_) => kind == prost_reflect::Kind::Float,
281 prost_reflect::Value::F64(_) => kind == prost_reflect::Kind::Double,
282 prost_reflect::Value::String(_) => kind == prost_reflect::Kind::String,
283 prost_reflect::Value::Bytes(_) => kind == prost_reflect::Kind::Bytes,
284 prost_reflect::Value::Message(_) => matches!(kind, prost_reflect::Kind::Message(_)),
285 prost_reflect::Value::List(_) => matches!(kind, prost_reflect::Kind::Message(_)), _ => false,
287 }
288 }
289
290 pub async fn validate_request_size<T>(
292 &self,
293 request: &Request<T>,
294 max_size: usize,
295 ) -> Result<(), Status>
296 where
297 T: Message,
298 {
299 let encoded_size = request.get_ref().encode_to_vec().len();
301
302 if encoded_size > max_size {
304 return Err(Status::resource_exhausted(format!(
305 "Request size {} bytes exceeds maximum allowed size of {} bytes",
306 encoded_size, max_size
307 )));
308 }
309
310 Ok(())
311 }
312
313 pub async fn validate_response_size(
315 &self,
316 response: &DynamicMessage,
317 max_size: usize,
318 ) -> Result<(), Status> {
319 let encoded_size = response.encode_to_vec().len();
321
322 if encoded_size > max_size {
324 return Err(Status::resource_exhausted(format!(
325 "Response size {} bytes exceeds maximum allowed size of {} bytes",
326 encoded_size, max_size
327 )));
328 }
329
330 Ok(())
331 }
332
333 pub fn should_skip_validation(&self, service_name: &str, method_name: &str) -> bool {
335 for prefix in &self.config.admin_skip_prefixes {
337 if service_name.starts_with(prefix) || method_name.starts_with(prefix) {
338 return true;
339 }
340 }
341
342 false
343 }
344
345 pub fn get_validation_mode_for_method(
347 &self,
348 service_name: &str,
349 method_name: &str,
350 ) -> ValidationMode {
351 if let Some(mode) = self.config.overrides.get(&format!("{}/{}", service_name, method_name))
353 {
354 return mode.clone();
355 }
356
357 if let Some(mode) = self.config.overrides.get(service_name) {
359 return mode.clone();
360 }
361
362 self.config.request_mode.clone()
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar values are accepted for every protobuf kind they can encode.
    #[test]
    fn value_matches_kind_accepts_matching_scalars() {
        assert!(MockReflectionProxy::value_matches_kind(&Value::Bool(true), Kind::Bool));
        assert!(MockReflectionProxy::value_matches_kind(&Value::I32(1), Kind::Int32));
        assert!(MockReflectionProxy::value_matches_kind(&Value::I32(1), Kind::Sint32));
        assert!(MockReflectionProxy::value_matches_kind(&Value::I64(1), Kind::Sfixed64));
        assert!(MockReflectionProxy::value_matches_kind(&Value::U64(1), Kind::Fixed64));
        assert!(MockReflectionProxy::value_matches_kind(&Value::F64(1.0), Kind::Double));
        assert!(MockReflectionProxy::value_matches_kind(
            &Value::String("x".to_string()),
            Kind::String
        ));
    }

    /// Scalar values are rejected for kinds with a different representation.
    #[test]
    fn value_matches_kind_rejects_mismatched_scalars() {
        assert!(!MockReflectionProxy::value_matches_kind(&Value::I32(1), Kind::Int64));
        assert!(!MockReflectionProxy::value_matches_kind(&Value::U32(1), Kind::Int32));
        assert!(!MockReflectionProxy::value_matches_kind(&Value::F32(1.0), Kind::Double));
        assert!(!MockReflectionProxy::value_matches_kind(&Value::Bool(false), Kind::Uint32));
        assert!(!MockReflectionProxy::value_matches_kind(
            &Value::String("x".to_string()),
            Kind::Bytes
        ));
    }
}