1use crate::reflection::metrics::{record_error, record_success};
7use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
8use prost_reflect::{DynamicMessage, Kind, ReflectMessage};
9use std::time::Instant;
10use tokio::sync::mpsc;
11use tokio_stream::{wrappers::ReceiverStream, StreamExt};
12use tonic::{
13 metadata::{Ascii, MetadataKey, MetadataValue},
14 Code, Request, Status,
15};
16use tracing::error;
17
18impl MockReflectionProxy {
19 pub async fn preprocess_request<T>(&self, request: &mut Request<T>) -> Result<(), Status>
21 where
22 T: ReflectMessage,
23 {
24 let mut metadata_log = Vec::new();
26 for kv in request.metadata().iter() {
27 match kv {
28 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
29 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
30 }
31 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
32 metadata_log.push(format!("{}: <binary>", key));
33 }
34 }
35 }
36 tracing::debug!("Extracted request metadata: [{}]", metadata_log.join(", "));
37
38 let descriptor = request.get_ref().descriptor();
40 let mut buf = Vec::new();
41 request
42 .get_ref()
43 .encode(&mut buf)
44 .map_err(|_e| Status::internal("Failed to encode request".to_string()))?;
45 let dynamic_message = DynamicMessage::decode(descriptor.clone(), &buf[..])
46 .map_err(|_e| Status::internal("Failed to decode request".to_string()))?;
47 if let Err(e) = self.validate_request_message(&dynamic_message) {
48 return Err(Status::internal(format!("Request validation failed: {}", e)));
49 }
50 tracing::debug!("Request format validation passed");
51
52 request.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
55 request
56 .metadata_mut()
57 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
58
59 tracing::debug!("Applied request transformations: added processed and timestamp headers");
60
61 Ok(())
62 }
63
64 pub async fn log_request<T>(&self, request: &Request<T>, service_name: &str, method_name: &str)
66 where
67 T: ReflectMessage,
68 {
69 let start_time = Instant::now();
70
71 let mut metadata_log = Vec::new();
73 for kv in request.metadata().iter() {
74 match kv {
75 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
76 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
77 }
78 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
79 metadata_log.push(format!("{}: <binary>", key));
80 }
81 }
82 }
83 tracing::debug!(
84 "Request metadata for {}/{}: [{}]",
85 service_name,
86 method_name,
87 metadata_log.join(", ")
88 );
89
90 let request_size = request.get_ref().encoded_len();
92 tracing::debug!(
93 "Request size for {}/{}: {} bytes",
94 service_name,
95 method_name,
96 request_size
97 );
98
99 tracing::debug!(
101 "Request start time for {}/{}: {:?}",
102 service_name,
103 method_name,
104 start_time
105 );
106 }
107
108 pub async fn postprocess_response<T>(
110 &self,
111 response: &mut tonic::Response<T>,
112 service_name: &str,
113 method_name: &str,
114 ) -> Result<(), Status> {
115 let start = Instant::now();
116 response.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
118 response
119 .metadata_mut()
120 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
121
122 if self.config.response_transform.enabled {
133 for (key, value) in &self.config.response_transform.custom_headers {
135 let parsed_key: Option<MetadataKey<Ascii>> = key.parse().ok();
136 let parsed_value: Option<MetadataValue<Ascii>> = value.parse().ok();
137
138 match (parsed_key, parsed_value) {
139 (Some(k), Some(v)) => {
140 response.metadata_mut().insert(k, v);
141 }
142 (None, _) => {
143 tracing::warn!(
144 "Skipping invalid custom header key '{}' in response transform config",
145 key
146 );
147 }
148 (_, None) => {
149 tracing::warn!(
150 "Skipping invalid custom header value for key '{}' in response transform config",
151 key
152 );
153 }
154 }
155 }
156 }
157
158 let processing_time = start.elapsed().as_millis();
160 response
162 .metadata_mut()
163 .insert("x-mockforge-processing-time", processing_time.to_string().parse().unwrap());
164 tracing::debug!("Postprocessed response for {}/{}", service_name, method_name);
165
166 Ok(())
167 }
168
169 pub async fn postprocess_dynamic_response(
171 &self,
172 response: &mut tonic::Response<DynamicMessage>,
173 service_name: &str,
174 method_name: &str,
175 ) -> Result<(), Status> {
176 self.postprocess_response(response, service_name, method_name).await?;
178
179 if self.config.response_transform.enabled {
181 if let Some(ref overrides) = self.config.response_transform.overrides {
182 match self
183 .transform_dynamic_message(
184 &response.get_ref().clone(),
185 service_name,
186 method_name,
187 overrides,
188 )
189 .await
190 {
191 Ok(transformed_message) => {
192 *response.get_mut() = transformed_message;
194 tracing::debug!(
195 "Applied body transformations to response for {}/{}",
196 service_name,
197 method_name
198 );
199 }
200 Err(e) => {
201 tracing::warn!(
202 "Failed to transform response body for {}/{}: {}",
203 service_name,
204 method_name,
205 e
206 );
207 }
208 }
209 }
210
211 if self.config.response_transform.validate_responses {
213 if let Err(validation_error) = self
214 .validate_dynamic_message(response.get_ref(), service_name, method_name)
215 .await
216 {
217 tracing::warn!(
218 "Response validation failed for {}/{}: {}",
219 service_name,
220 method_name,
221 validation_error
222 );
223 }
224 }
225 }
226
227 Ok(())
228 }
229
230 async fn transform_dynamic_message(
232 &self,
233 message: &DynamicMessage,
234 service_name: &str,
235 method_name: &str,
236 overrides: &mockforge_core::overrides::Overrides,
237 ) -> Result<DynamicMessage, Box<dyn std::error::Error + Send + Sync>> {
238 use crate::dynamic::http_bridge::converters::ProtobufJsonConverter;
239
240 let descriptor_pool = self.service_registry.descriptor_pool();
242
243 let converter = ProtobufJsonConverter::new(descriptor_pool.clone());
245
246 let json_value = converter.protobuf_to_json(&message.descriptor(), message)?;
248
249 let mut json_value = serde_json::Value::Object(json_value.as_object().unwrap().clone());
251 overrides.apply_with_context(
252 &format!("{}/{}", service_name, method_name),
253 &[service_name.to_string()],
254 &format!("{}/{}", service_name, method_name),
255 &mut json_value,
256 &mockforge_core::conditions::ConditionContext::new(),
257 );
258
259 let transformed_message = converter.json_to_protobuf(&message.descriptor(), &json_value)?;
261
262 Ok(transformed_message)
263 }
264
265 pub async fn postprocess_streaming_dynamic_response(
267 &self,
268 response: &mut tonic::Response<
269 tokio_stream::wrappers::ReceiverStream<Result<DynamicMessage, Status>>,
270 >,
271 service_name: &str,
272 method_name: &str,
273 ) -> Result<(), Status> {
274 self.postprocess_response(response, service_name, method_name).await?;
276
277 if self.config.response_transform.enabled
278 && (self.config.response_transform.overrides.is_some()
279 || self.config.response_transform.validate_responses)
280 {
281 let (placeholder_tx, placeholder_rx) = mpsc::channel(1);
282 drop(placeholder_tx);
283 let mut original_stream =
284 std::mem::replace(response.get_mut(), ReceiverStream::new(placeholder_rx));
285
286 let (tx, rx) = mpsc::channel(16);
287 let proxy = self.clone();
288 let service_name = service_name.to_string();
289 let method_name = method_name.to_string();
290 let overrides = self.config.response_transform.overrides.clone();
291 let validate_responses = self.config.response_transform.validate_responses;
292
293 tokio::spawn(async move {
294 while let Some(item) = original_stream.next().await {
295 match item {
296 Ok(mut message) => {
297 if let Some(ref override_config) = overrides {
298 match proxy
299 .transform_dynamic_message(
300 &message,
301 &service_name,
302 &method_name,
303 override_config,
304 )
305 .await
306 {
307 Ok(transformed) => {
308 message = transformed;
309 }
310 Err(e) => {
311 tracing::warn!(
312 "Failed to transform streaming message for {}/{}: {}",
313 service_name,
314 method_name,
315 e
316 );
317 }
318 }
319 }
320
321 if validate_responses {
322 if let Err(e) = proxy
323 .validate_dynamic_message(&message, &service_name, &method_name)
324 .await
325 {
326 tracing::warn!(
327 "Streaming response validation failed for {}/{}: {}",
328 service_name,
329 method_name,
330 e
331 );
332 }
333 }
334
335 if tx.send(Ok(message)).await.is_err() {
336 break;
337 }
338 }
339 Err(status) => {
340 if tx.send(Err(status)).await.is_err() {
341 break;
342 }
343 }
344 }
345 }
346 });
347
348 *response.get_mut() = ReceiverStream::new(rx);
349 }
350
351 Ok(())
352 }
353
354 async fn validate_dynamic_message(
356 &self,
357 message: &DynamicMessage,
358 service_name: &str,
359 method_name: &str,
360 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
361 let _descriptor = message.descriptor();
363
364 self.validate_message_schema(message, service_name, method_name)?;
371
372 self.validate_business_rules(message, service_name, method_name)?;
374
375 self.validate_cross_field_rules(message, service_name, method_name)?;
377
378 self.validate_custom_rules(message, service_name, method_name)?;
380
381 tracing::debug!("Response validation passed for {}/{}", service_name, method_name);
382
383 Ok(())
384 }
385
386 fn validate_request_message(
388 &self,
389 message: &DynamicMessage,
390 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
391 self.validate_message_schema(message, "", "")?;
393 self.validate_business_rules(message, "", "")?;
395 self.validate_cross_field_rules(message, "", "")?;
397 self.validate_custom_rules(message, "", "")?;
399 tracing::debug!("Request validation passed");
400 Ok(())
401 }
402
403 fn validate_message_schema(
405 &self,
406 message: &DynamicMessage,
407 _service_name: &str,
408 _method_name: &str,
409 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
410 let descriptor = message.descriptor();
411
412 for field in descriptor.fields() {
414 let value = message.get_field(&field);
415 let value_ref = value.as_ref();
416
417 if !Self::value_matches_kind(value_ref, field.kind()) {
419 return Err(format!(
420 "{} field '{}' has incorrect type: expected {:?}, got {:?}",
421 "Message validation",
422 field.name(),
423 field.kind(),
424 value_ref
425 )
426 .into());
427 }
428
429 if let Kind::Message(expected_msg) = field.kind() {
431 if let prost_reflect::Value::Message(ref nested_msg) = *value_ref {
432 if nested_msg.descriptor() != expected_msg {
434 return Err(format!(
435 "{} field '{}' has incorrect message type",
436 "Message validation",
437 field.name()
438 )
439 .into());
440 }
441 }
442 }
443 }
444
445 Ok(())
446 }
447
448 fn validate_business_rules(
450 &self,
451 message: &DynamicMessage,
452 service_name: &str,
453 method_name: &str,
454 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
455 let descriptor = message.descriptor();
456
457 for field in descriptor.fields() {
458 let value = message.get_field(&field);
459 let field_value = value.as_ref();
460 let field_name = field.name().to_lowercase();
461
462 if field_name.contains("email") && field.kind() == Kind::String {
464 if let Some(email_str) = field_value.as_str() {
465 if !self.is_valid_email(email_str) {
466 return Err(format!(
467 "Invalid email format '{}' for field '{}' in {}/{}",
468 email_str,
469 field.name(),
470 service_name,
471 method_name
472 )
473 .into());
474 }
475 }
476 }
477
478 if field_name.contains("date") || field_name.contains("timestamp") {
480 match field.kind() {
481 Kind::String => {
482 if let Some(date_str) = field_value.as_str() {
483 if !self.is_valid_iso8601_date(date_str) {
484 return Err(format!(
485 "Invalid date format '{}' for field '{}' in {}/{}",
486 date_str,
487 field.name(),
488 service_name,
489 method_name
490 )
491 .into());
492 }
493 }
494 }
495 Kind::Int64 | Kind::Uint64 => {
496 if let Some(timestamp) = field_value.as_i64() {
498 if !(0..=4102444800).contains(×tamp) {
499 return Err(format!(
501 "Timestamp {} out of reasonable range for field '{}' in {}/{}",
502 timestamp,
503 field.name(),
504 service_name,
505 method_name
506 )
507 .into());
508 }
509 }
510 }
511 _ => {}
512 }
513 }
514
515 if field_name.contains("phone") && field.kind() == Kind::String {
517 if let Some(phone_str) = field_value.as_str() {
518 if !self.is_valid_phone_number(phone_str) {
519 return Err(format!(
520 "Invalid phone number format '{}' for field '{}' in {}/{}",
521 phone_str,
522 field.name(),
523 service_name,
524 method_name
525 )
526 .into());
527 }
528 }
529 }
530 }
531
532 Ok(())
533 }
534
535 fn validate_cross_field_rules(
537 &self,
538 message: &DynamicMessage,
539 service_name: &str,
540 method_name: &str,
541 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
542 let descriptor = message.descriptor();
543
544 let mut date_fields = Vec::new();
546 let mut timestamp_fields = Vec::new();
547
548 for field in descriptor.fields() {
549 let value = message.get_field(&field);
550 let field_value = value.as_ref();
551 let field_name = field.name().to_lowercase();
552
553 if field_name.contains("start")
554 && (field_name.contains("date") || field_name.contains("time"))
555 {
556 if let Some(value) = field_value.as_i64() {
557 date_fields.push(("start", value));
558 }
559 } else if field_name.contains("end")
560 && (field_name.contains("date") || field_name.contains("time"))
561 {
562 if let Some(value) = field_value.as_i64() {
563 date_fields.push(("end", value));
564 }
565 } else if field_name.contains("timestamp") {
566 if let Some(value) = field_value.as_i64() {
567 timestamp_fields.push((field.name().to_string(), value));
568 }
569 }
570 }
571
572 if date_fields.len() >= 2 {
574 let start_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "start").collect();
575 let end_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "end").collect();
576
577 for &(_, start_val) in &start_dates {
578 for &(_, end_val) in &end_dates {
579 if start_val >= end_val {
580 return Err(format!(
581 "Start date/time {} must be before end date/time {} in {}/{}",
582 start_val, end_val, service_name, method_name
583 )
584 .into());
585 }
586 }
587 }
588 }
589
590 if timestamp_fields.len() >= 2 {
592 let created_at = timestamp_fields
593 .iter()
594 .find(|(name, _)| name.to_lowercase().contains("created"));
595 let updated_at = timestamp_fields
596 .iter()
597 .find(|(name, _)| name.to_lowercase().contains("updated"));
598
599 if let (Some((_, created)), Some((_, updated))) = (created_at, updated_at) {
600 if created > updated {
601 return Err(format!(
602 "Created timestamp {} cannot be after updated timestamp {} in {}/{}",
603 created, updated, service_name, method_name
604 )
605 .into());
606 }
607 }
608 }
609
610 Ok(())
611 }
612
613 fn validate_custom_rules(
615 &self,
616 message: &DynamicMessage,
617 service_name: &str,
618 method_name: &str,
619 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
620 let descriptor = message.descriptor();
624
625 for field in descriptor.fields() {
626 let value = message.get_field(&field);
627 let field_value = value.as_ref();
628 let field_name = field.name().to_lowercase();
629
630 if field_name.ends_with("_id") || field_name == "id" {
632 match field.kind() {
633 Kind::Int32 | Kind::Int64 => {
634 if let Some(id_val) = field_value.as_i64() {
635 if id_val <= 0 {
636 return Err(format!(
637 "ID field '{}' must be positive, got {} in {}/{}",
638 field.name(),
639 id_val,
640 service_name,
641 method_name
642 )
643 .into());
644 }
645 }
646 }
647 Kind::Uint32 | Kind::Uint64 => {
648 if let Some(id_val) = field_value.as_u64() {
649 if id_val == 0 {
650 return Err(format!(
651 "ID field '{}' must be non-zero, got {} in {}/{}",
652 field.name(),
653 id_val,
654 service_name,
655 method_name
656 )
657 .into());
658 }
659 }
660 }
661 Kind::String => {
662 if let Some(id_str) = field_value.as_str() {
663 if id_str.trim().is_empty() {
664 return Err(format!(
665 "ID field '{}' cannot be empty in {}/{}",
666 field.name(),
667 service_name,
668 method_name
669 )
670 .into());
671 }
672 }
673 }
674 _ => {}
675 }
676 }
677
678 if field_name.contains("amount")
680 || field_name.contains("price")
681 || field_name.contains("cost")
682 {
683 if let Some(numeric_val) = field_value.as_f64() {
684 if numeric_val < 0.0 {
685 return Err(format!(
686 "Amount/price field '{}' cannot be negative, got {} in {}/{}",
687 field.name(),
688 numeric_val,
689 service_name,
690 method_name
691 )
692 .into());
693 }
694 }
695 }
696 }
697
698 Ok(())
699 }
700
701 fn is_valid_email(&self, email: &str) -> bool {
703 let parts: Vec<&str> = email.split('@').collect();
705 if parts.len() != 2 {
706 return false;
707 }
708
709 let local = parts[0];
710 let domain = parts[1];
711
712 if local.is_empty() || domain.is_empty() {
713 return false;
714 }
715
716 domain.contains('.') && !domain.starts_with('.') && !domain.ends_with('.')
718 }
719
720 fn is_valid_phone_number(&self, phone: &str) -> bool {
722 !phone.is_empty() && phone.len() >= 7 && phone.len() <= 15
724 }
725
726 fn is_valid_iso8601_date(&self, date_str: &str) -> bool {
728 chrono::DateTime::parse_from_rfc3339(date_str).is_ok()
731 || chrono::NaiveDate::parse_from_str(date_str, "%Y-%m-%d").is_ok()
732 || chrono::NaiveDateTime::parse_from_str(date_str, "%Y-%m-%d %H:%M:%S").is_ok()
733 }
734
735 pub async fn handle_error(
737 &self,
738 error: Status,
739 service_name: &str,
740 method_name: &str,
741 ) -> Status {
742 error!(
744 "Error in {}/{}: {} (code: {:?})",
745 service_name,
746 method_name,
747 error,
748 error.code()
749 );
750
751 match error.code() {
752 Code::InvalidArgument => Status::invalid_argument(format!(
753 "Invalid arguments provided to {}/{}",
754 service_name, method_name
755 )),
756 Code::NotFound => {
757 Status::not_found(format!("Resource not found in {}/{}", service_name, method_name))
758 }
759 Code::AlreadyExists => Status::already_exists(format!(
760 "Resource already exists in {}/{}",
761 service_name, method_name
762 )),
763 Code::PermissionDenied => Status::permission_denied(format!(
764 "Permission denied for {}/{}",
765 service_name, method_name
766 )),
767 Code::FailedPrecondition => Status::failed_precondition(format!(
768 "Precondition failed for {}/{}",
769 service_name, method_name
770 )),
771 Code::Aborted => {
772 Status::aborted(format!("Operation aborted for {}/{}", service_name, method_name))
773 }
774 Code::OutOfRange => Status::out_of_range(format!(
775 "Value out of range in {}/{}",
776 service_name, method_name
777 )),
778 Code::Unimplemented => Status::unimplemented(format!(
779 "Method {}/{} not implemented",
780 service_name, method_name
781 )),
782 Code::Internal => {
783 Status::internal(format!("Internal error in {}/{}", service_name, method_name))
784 }
785 Code::Unavailable => Status::unavailable(format!(
786 "Service {}/{} temporarily unavailable",
787 service_name, method_name
788 )),
789 Code::DataLoss => {
790 Status::data_loss(format!("Data loss occurred in {}/{}", service_name, method_name))
791 }
792 Code::Unauthenticated => Status::unauthenticated(format!(
793 "Authentication required for {}/{}",
794 service_name, method_name
795 )),
796 Code::DeadlineExceeded => Status::deadline_exceeded(format!(
797 "Request to {}/{} timed out",
798 service_name, method_name
799 )),
800 Code::ResourceExhausted => Status::resource_exhausted(format!(
801 "Rate limit exceeded for {}/{}",
802 service_name, method_name
803 )),
804 _ => {
805 let message = error.message();
806 if message.contains(service_name) && message.contains(method_name) {
807 error
808 } else {
809 Status::new(
810 error.code(),
811 format!("{}/{}: {}", service_name, method_name, message),
812 )
813 }
814 }
815 }
816 }
817
818 pub async fn collect_metrics(
820 &self,
821 service_name: &str,
822 method_name: &str,
823 duration: std::time::Duration,
824 success: bool,
825 ) {
826 let duration_ms = duration.as_millis() as u64;
827
828 if success {
829 record_success(service_name, method_name, duration_ms).await;
830 } else {
831 record_error(service_name, method_name).await;
832 }
833
834 tracing::debug!(
835 "Request {}/{} completed in {:?}, success: {}",
836 service_name,
837 method_name,
838 duration,
839 success
840 );
841 }
842}
843
#[cfg(test)]
mod tests {
    /// Smoke test with an empty body: it makes no assertions and passes
    /// whenever the test build of this module compiles.
    #[test]
    fn test_module_compiles() {
    }
}