1use crate::reflection::metrics::{record_error, record_success};
7use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
8use prost_reflect::{DynamicMessage, Kind, ReflectMessage};
9use std::time::Instant;
10use tonic::{
11 metadata::{Ascii, MetadataKey, MetadataValue},
12 Code, Request, Status,
13};
14use tracing::error;
15
16impl MockReflectionProxy {
17 pub async fn preprocess_request<T>(&self, request: &mut Request<T>) -> Result<(), Status>
19 where
20 T: prost_reflect::ReflectMessage,
21 {
22 let mut metadata_log = Vec::new();
24 for kv in request.metadata().iter() {
25 match kv {
26 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
27 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
28 }
29 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
30 metadata_log.push(format!("{}: <binary>", key));
31 }
32 }
33 }
34 tracing::debug!("Extracted request metadata: [{}]", metadata_log.join(", "));
35
36 let descriptor = request.get_ref().descriptor();
38 let mut buf = Vec::new();
39 request
40 .get_ref()
41 .encode(&mut buf)
42 .map_err(|_e| Status::internal("Failed to encode request".to_string()))?;
43 let dynamic_message = DynamicMessage::decode(descriptor.clone(), &buf[..])
44 .map_err(|_e| Status::internal("Failed to decode request".to_string()))?;
45 if let Err(e) = self.validate_request_message(&dynamic_message) {
46 return Err(Status::internal(format!("Request validation failed: {}", e)));
47 }
48 tracing::debug!("Request format validation passed");
49
50 request.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
53 request
54 .metadata_mut()
55 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
56
57 tracing::debug!("Applied request transformations: added processed and timestamp headers");
58
59 Ok(())
60 }
61
62 pub async fn log_request<T>(&self, request: &Request<T>, service_name: &str, method_name: &str)
64 where
65 T: prost_reflect::ReflectMessage,
66 {
67 let start_time = std::time::Instant::now();
68
69 let mut metadata_log = Vec::new();
71 for kv in request.metadata().iter() {
72 match kv {
73 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
74 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
75 }
76 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
77 metadata_log.push(format!("{}: <binary>", key));
78 }
79 }
80 }
81 tracing::debug!(
82 "Request metadata for {}/{}: [{}]",
83 service_name,
84 method_name,
85 metadata_log.join(", ")
86 );
87
88 let request_size = request.get_ref().encoded_len();
90 tracing::debug!(
91 "Request size for {}/{}: {} bytes",
92 service_name,
93 method_name,
94 request_size
95 );
96
97 tracing::debug!(
99 "Request start time for {}/{}: {:?}",
100 service_name,
101 method_name,
102 start_time
103 );
104 }
105
106 pub async fn postprocess_response<T>(
108 &self,
109 response: &mut tonic::Response<T>,
110 service_name: &str,
111 method_name: &str,
112 ) -> Result<(), Status> {
113 let start = Instant::now();
114 response.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
116 response
117 .metadata_mut()
118 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
119
120 if self.config.response_transform.enabled {
131 for (key, value) in &self.config.response_transform.custom_headers {
133 let key: MetadataKey<Ascii> = key.parse().unwrap();
134 let value: MetadataValue<Ascii> = value.parse().unwrap();
135 response.metadata_mut().insert(key, value);
136 }
137 }
138
139 let processing_time = start.elapsed().as_millis();
141 response
143 .metadata_mut()
144 .insert("x-mockforge-processing-time", processing_time.to_string().parse().unwrap());
145 tracing::debug!("Postprocessed response for {}/{}", service_name, method_name);
146
147 Ok(())
148 }
149
150 pub async fn postprocess_dynamic_response(
152 &self,
153 response: &mut tonic::Response<prost_reflect::DynamicMessage>,
154 service_name: &str,
155 method_name: &str,
156 ) -> Result<(), Status> {
157 self.postprocess_response(response, service_name, method_name).await?;
159
160 if self.config.response_transform.enabled {
162 if let Some(ref overrides) = self.config.response_transform.overrides {
163 match self
164 .transform_dynamic_message(
165 &response.get_ref().clone(),
166 service_name,
167 method_name,
168 overrides,
169 )
170 .await
171 {
172 Ok(transformed_message) => {
173 *response.get_mut() = transformed_message;
175 tracing::debug!(
176 "Applied body transformations to response for {}/{}",
177 service_name,
178 method_name
179 );
180 }
181 Err(e) => {
182 tracing::warn!(
183 "Failed to transform response body for {}/{}: {}",
184 service_name,
185 method_name,
186 e
187 );
188 }
189 }
190 }
191
192 if self.config.response_transform.validate_responses {
194 if let Err(validation_error) = self
195 .validate_dynamic_message(response.get_ref(), service_name, method_name)
196 .await
197 {
198 tracing::warn!(
199 "Response validation failed for {}/{}: {}",
200 service_name,
201 method_name,
202 validation_error
203 );
204 }
205 }
206 }
207
208 Ok(())
209 }
210
211 async fn transform_dynamic_message(
213 &self,
214 message: &prost_reflect::DynamicMessage,
215 service_name: &str,
216 method_name: &str,
217 overrides: &mockforge_core::overrides::Overrides,
218 ) -> Result<prost_reflect::DynamicMessage, Box<dyn std::error::Error + Send + Sync>> {
219 use crate::dynamic::http_bridge::converters::ProtobufJsonConverter;
220
221 let descriptor_pool = self.service_registry.descriptor_pool();
223
224 let converter = ProtobufJsonConverter::new(descriptor_pool.clone());
226
227 let json_value = converter.protobuf_to_json(&message.descriptor(), message)?;
229
230 let mut json_value = serde_json::Value::Object(json_value.as_object().unwrap().clone());
232 overrides.apply_with_context(
233 &format!("{}/{}", service_name, method_name),
234 &[service_name.to_string()],
235 &format!("{}/{}", service_name, method_name),
236 &mut json_value,
237 &mockforge_core::conditions::ConditionContext::new(),
238 );
239
240 let transformed_message = converter.json_to_protobuf(&message.descriptor(), &json_value)?;
242
243 Ok(transformed_message)
244 }
245
246 pub async fn postprocess_streaming_dynamic_response(
248 &self,
249 response: &mut tonic::Response<
250 tokio_stream::wrappers::ReceiverStream<
251 Result<prost_reflect::DynamicMessage, tonic::Status>,
252 >,
253 >,
254 service_name: &str,
255 method_name: &str,
256 ) -> Result<(), Status> {
257 self.postprocess_response(response, service_name, method_name).await?;
259
260 if self.config.response_transform.enabled {
265 if self.config.response_transform.overrides.is_some() {
266 tracing::debug!(
267 "Body transformation for streaming responses not yet implemented for {}/{}",
268 service_name,
269 method_name
270 );
271 }
272
273 if self.config.response_transform.validate_responses {
274 tracing::debug!(
275 "Response validation for streaming responses not yet implemented for {}/{}",
276 service_name,
277 method_name
278 );
279 }
280 }
281
282 Ok(())
283 }
284
285 async fn validate_dynamic_message(
287 &self,
288 message: &prost_reflect::DynamicMessage,
289 service_name: &str,
290 method_name: &str,
291 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
292 let _descriptor = message.descriptor();
294
295 self.validate_message_schema(message, service_name, method_name)?;
302
303 self.validate_business_rules(message, service_name, method_name)?;
305
306 self.validate_cross_field_rules(message, service_name, method_name)?;
308
309 self.validate_custom_rules(message, service_name, method_name)?;
311
312 tracing::debug!("Response validation passed for {}/{}", service_name, method_name);
313
314 Ok(())
315 }
316
317 fn validate_request_message(
319 &self,
320 message: &DynamicMessage,
321 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
322 self.validate_message_schema(message, "", "")?;
324 self.validate_business_rules(message, "", "")?;
326 self.validate_cross_field_rules(message, "", "")?;
328 self.validate_custom_rules(message, "", "")?;
330 tracing::debug!("Request validation passed");
331 Ok(())
332 }
333
334 fn validate_message_schema(
336 &self,
337 message: &DynamicMessage,
338 _service_name: &str,
339 _method_name: &str,
340 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
341 let descriptor = message.descriptor();
342
343 for field in descriptor.fields() {
345 let value = message.get_field(&field);
346 let value_ref = value.as_ref();
347
348 if !Self::value_matches_kind(value_ref, field.kind()) {
350 return Err(format!(
351 "{} field '{}' has incorrect type: expected {:?}, got {:?}",
352 "Message validation",
353 field.name(),
354 field.kind(),
355 value_ref
356 )
357 .into());
358 }
359
360 if let Kind::Message(expected_msg) = field.kind() {
362 if let prost_reflect::Value::Message(ref nested_msg) = *value_ref {
363 if nested_msg.descriptor() != expected_msg {
365 return Err(format!(
366 "{} field '{}' has incorrect message type",
367 "Message validation",
368 field.name()
369 )
370 .into());
371 }
372 }
373 }
374 }
375
376 Ok(())
377 }
378
379 fn validate_business_rules(
381 &self,
382 message: &DynamicMessage,
383 service_name: &str,
384 method_name: &str,
385 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
386 let descriptor = message.descriptor();
387
388 for field in descriptor.fields() {
389 let value = message.get_field(&field);
390 let field_value = value.as_ref();
391 let field_name = field.name().to_lowercase();
392
393 if field_name.contains("email") && field.kind() == Kind::String {
395 if let Some(email_str) = field_value.as_str() {
396 if !self.is_valid_email(email_str) {
397 return Err(format!(
398 "Invalid email format '{}' for field '{}' in {}/{}",
399 email_str,
400 field.name(),
401 service_name,
402 method_name
403 )
404 .into());
405 }
406 }
407 }
408
409 if field_name.contains("date") || field_name.contains("timestamp") {
411 match field.kind() {
412 Kind::String => {
413 if let Some(date_str) = field_value.as_str() {
414 if !self.is_valid_iso8601_date(date_str) {
415 return Err(format!(
416 "Invalid date format '{}' for field '{}' in {}/{}",
417 date_str,
418 field.name(),
419 service_name,
420 method_name
421 )
422 .into());
423 }
424 }
425 }
426 Kind::Int64 | Kind::Uint64 => {
427 if let Some(timestamp) = field_value.as_i64() {
429 if !(0..=4102444800).contains(×tamp) {
430 return Err(format!(
432 "Timestamp {} out of reasonable range for field '{}' in {}/{}",
433 timestamp,
434 field.name(),
435 service_name,
436 method_name
437 )
438 .into());
439 }
440 }
441 }
442 _ => {}
443 }
444 }
445
446 if field_name.contains("phone") && field.kind() == Kind::String {
448 if let Some(phone_str) = field_value.as_str() {
449 if !self.is_valid_phone_number(phone_str) {
450 return Err(format!(
451 "Invalid phone number format '{}' for field '{}' in {}/{}",
452 phone_str,
453 field.name(),
454 service_name,
455 method_name
456 )
457 .into());
458 }
459 }
460 }
461 }
462
463 Ok(())
464 }
465
466 fn validate_cross_field_rules(
468 &self,
469 message: &DynamicMessage,
470 service_name: &str,
471 method_name: &str,
472 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
473 let descriptor = message.descriptor();
474
475 let mut date_fields = Vec::new();
477 let mut timestamp_fields = Vec::new();
478
479 for field in descriptor.fields() {
480 let value = message.get_field(&field);
481 let field_value = value.as_ref();
482 let field_name = field.name().to_lowercase();
483
484 if field_name.contains("start")
485 && (field_name.contains("date") || field_name.contains("time"))
486 {
487 if let Some(value) = field_value.as_i64() {
488 date_fields.push(("start", value));
489 }
490 } else if field_name.contains("end")
491 && (field_name.contains("date") || field_name.contains("time"))
492 {
493 if let Some(value) = field_value.as_i64() {
494 date_fields.push(("end", value));
495 }
496 } else if field_name.contains("timestamp") {
497 if let Some(value) = field_value.as_i64() {
498 timestamp_fields.push((field.name().to_string(), value));
499 }
500 }
501 }
502
503 if date_fields.len() >= 2 {
505 let start_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "start").collect();
506 let end_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "end").collect();
507
508 for &(_, start_val) in &start_dates {
509 for &(_, end_val) in &end_dates {
510 if start_val >= end_val {
511 return Err(format!(
512 "Start date/time {} must be before end date/time {} in {}/{}",
513 start_val, end_val, service_name, method_name
514 )
515 .into());
516 }
517 }
518 }
519 }
520
521 if timestamp_fields.len() >= 2 {
523 let created_at = timestamp_fields
524 .iter()
525 .find(|(name, _)| name.to_lowercase().contains("created"));
526 let updated_at = timestamp_fields
527 .iter()
528 .find(|(name, _)| name.to_lowercase().contains("updated"));
529
530 if let (Some((_, created)), Some((_, updated))) = (created_at, updated_at) {
531 if created > updated {
532 return Err(format!(
533 "Created timestamp {} cannot be after updated timestamp {} in {}/{}",
534 created, updated, service_name, method_name
535 )
536 .into());
537 }
538 }
539 }
540
541 Ok(())
542 }
543
544 fn validate_custom_rules(
546 &self,
547 message: &DynamicMessage,
548 service_name: &str,
549 method_name: &str,
550 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
551 let descriptor = message.descriptor();
555
556 for field in descriptor.fields() {
557 let value = message.get_field(&field);
558 let field_value = value.as_ref();
559 let field_name = field.name().to_lowercase();
560
561 if field_name.ends_with("_id") || field_name == "id" {
563 match field.kind() {
564 Kind::Int32 | Kind::Int64 => {
565 if let Some(id_val) = field_value.as_i64() {
566 if id_val <= 0 {
567 return Err(format!(
568 "ID field '{}' must be positive, got {} in {}/{}",
569 field.name(),
570 id_val,
571 service_name,
572 method_name
573 )
574 .into());
575 }
576 }
577 }
578 Kind::Uint32 | Kind::Uint64 => {
579 if let Some(id_val) = field_value.as_u64() {
580 if id_val == 0 {
581 return Err(format!(
582 "ID field '{}' must be non-zero, got {} in {}/{}",
583 field.name(),
584 id_val,
585 service_name,
586 method_name
587 )
588 .into());
589 }
590 }
591 }
592 Kind::String => {
593 if let Some(id_str) = field_value.as_str() {
594 if id_str.trim().is_empty() {
595 return Err(format!(
596 "ID field '{}' cannot be empty in {}/{}",
597 field.name(),
598 service_name,
599 method_name
600 )
601 .into());
602 }
603 }
604 }
605 _ => {}
606 }
607 }
608
609 if field_name.contains("amount")
611 || field_name.contains("price")
612 || field_name.contains("cost")
613 {
614 if let Some(numeric_val) = field_value.as_f64() {
615 if numeric_val < 0.0 {
616 return Err(format!(
617 "Amount/price field '{}' cannot be negative, got {} in {}/{}",
618 field.name(),
619 numeric_val,
620 service_name,
621 method_name
622 )
623 .into());
624 }
625 }
626 }
627 }
628
629 Ok(())
630 }
631
632 fn is_valid_email(&self, email: &str) -> bool {
634 let parts: Vec<&str> = email.split('@').collect();
636 if parts.len() != 2 {
637 return false;
638 }
639
640 let local = parts[0];
641 let domain = parts[1];
642
643 if local.is_empty() || domain.is_empty() {
644 return false;
645 }
646
647 domain.contains('.') && !domain.starts_with('.') && !domain.ends_with('.')
649 }
650
651 fn is_valid_phone_number(&self, phone: &str) -> bool {
653 !phone.is_empty() && phone.len() >= 7 && phone.len() <= 15
655 }
656
657 fn is_valid_iso8601_date(&self, date_str: &str) -> bool {
659 chrono::DateTime::parse_from_rfc3339(date_str).is_ok()
662 || chrono::NaiveDate::parse_from_str(date_str, "%Y-%m-%d").is_ok()
663 || chrono::NaiveDateTime::parse_from_str(date_str, "%Y-%m-%d %H:%M:%S").is_ok()
664 }
665
666 pub async fn handle_error(
668 &self,
669 error: Status,
670 service_name: &str,
671 method_name: &str,
672 ) -> Status {
673 error!(
675 "Error in {}/{}: {} (code: {:?})",
676 service_name,
677 method_name,
678 error,
679 error.code()
680 );
681
682 match error.code() {
683 Code::InvalidArgument => Status::invalid_argument(format!(
684 "Invalid arguments provided to {}/{}",
685 service_name, method_name
686 )),
687 Code::NotFound => {
688 Status::not_found(format!("Resource not found in {}/{}", service_name, method_name))
689 }
690 Code::AlreadyExists => Status::already_exists(format!(
691 "Resource already exists in {}/{}",
692 service_name, method_name
693 )),
694 Code::PermissionDenied => Status::permission_denied(format!(
695 "Permission denied for {}/{}",
696 service_name, method_name
697 )),
698 Code::FailedPrecondition => Status::failed_precondition(format!(
699 "Precondition failed for {}/{}",
700 service_name, method_name
701 )),
702 Code::Aborted => {
703 Status::aborted(format!("Operation aborted for {}/{}", service_name, method_name))
704 }
705 Code::OutOfRange => Status::out_of_range(format!(
706 "Value out of range in {}/{}",
707 service_name, method_name
708 )),
709 Code::Unimplemented => Status::unimplemented(format!(
710 "Method {}/{} not implemented",
711 service_name, method_name
712 )),
713 Code::Internal => {
714 Status::internal(format!("Internal error in {}/{}", service_name, method_name))
715 }
716 Code::Unavailable => Status::unavailable(format!(
717 "Service {}/{} temporarily unavailable",
718 service_name, method_name
719 )),
720 Code::DataLoss => {
721 Status::data_loss(format!("Data loss occurred in {}/{}", service_name, method_name))
722 }
723 Code::Unauthenticated => Status::unauthenticated(format!(
724 "Authentication required for {}/{}",
725 service_name, method_name
726 )),
727 Code::DeadlineExceeded => Status::deadline_exceeded(format!(
728 "Request to {}/{} timed out",
729 service_name, method_name
730 )),
731 Code::ResourceExhausted => Status::resource_exhausted(format!(
732 "Rate limit exceeded for {}/{}",
733 service_name, method_name
734 )),
735 _ => {
736 let message = error.message();
737 if message.contains(service_name) && message.contains(method_name) {
738 error
739 } else {
740 Status::new(
741 error.code(),
742 format!("{}/{}: {}", service_name, method_name, message),
743 )
744 }
745 }
746 }
747 }
748
749 pub async fn collect_metrics(
751 &self,
752 service_name: &str,
753 method_name: &str,
754 duration: std::time::Duration,
755 success: bool,
756 ) {
757 let duration_ms = duration.as_millis() as u64;
758
759 if success {
760 record_success(service_name, method_name, duration_ms).await;
761 } else {
762 record_error(service_name, method_name).await;
763 }
764
765 tracing::debug!(
766 "Request {}/{} completed in {:?}, success: {}",
767 service_name,
768 method_name,
769 duration,
770 success
771 );
772 }
773}
774
#[cfg(test)]
mod tests {
    /// Smoke test: succeeds as long as this module builds under `cargo test`.
    #[test]
    fn test_module_compiles() {
        // Intentionally empty; compilation is the assertion.
    }
}