1use crate::reflection::metrics::{record_error, record_success};
7use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
8use prost_reflect::{DynamicMessage, Kind, ReflectMessage};
9use std::time::Instant;
10use tokio::sync::mpsc;
11use tokio_stream::{wrappers::ReceiverStream, StreamExt};
12use tonic::{
13 metadata::{Ascii, MetadataKey, MetadataValue},
14 Code, Request, Status,
15};
16use tracing::error;
17
18impl MockReflectionProxy {
19 pub async fn preprocess_request<T>(&self, request: &mut Request<T>) -> Result<(), Status>
21 where
22 T: ReflectMessage,
23 {
24 let mut metadata_log = Vec::new();
26 for kv in request.metadata().iter() {
27 match kv {
28 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
29 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
30 }
31 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
32 metadata_log.push(format!("{}: <binary>", key));
33 }
34 }
35 }
36 tracing::debug!("Extracted request metadata: [{}]", metadata_log.join(", "));
37
38 let descriptor = request.get_ref().descriptor();
40 let mut buf = Vec::new();
41 request
42 .get_ref()
43 .encode(&mut buf)
44 .map_err(|_e| Status::internal("Failed to encode request".to_string()))?;
45 let dynamic_message = DynamicMessage::decode(descriptor.clone(), &buf[..])
46 .map_err(|_e| Status::internal("Failed to decode request".to_string()))?;
47 if let Err(e) = self.validate_request_message(&dynamic_message) {
48 return Err(Status::internal(format!("Request validation failed: {}", e)));
49 }
50 tracing::debug!("Request format validation passed");
51
52 request.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
55 request
56 .metadata_mut()
57 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
58
59 tracing::debug!("Applied request transformations: added processed and timestamp headers");
60
61 Ok(())
62 }
63
64 pub async fn log_request<T>(&self, request: &Request<T>, service_name: &str, method_name: &str)
66 where
67 T: ReflectMessage,
68 {
69 let start_time = Instant::now();
70
71 let mut metadata_log = Vec::new();
73 for kv in request.metadata().iter() {
74 match kv {
75 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
76 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
77 }
78 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
79 metadata_log.push(format!("{}: <binary>", key));
80 }
81 }
82 }
83 tracing::debug!(
84 "Request metadata for {}/{}: [{}]",
85 service_name,
86 method_name,
87 metadata_log.join(", ")
88 );
89
90 let request_size = request.get_ref().encoded_len();
92 tracing::debug!(
93 "Request size for {}/{}: {} bytes",
94 service_name,
95 method_name,
96 request_size
97 );
98
99 tracing::debug!(
101 "Request start time for {}/{}: {:?}",
102 service_name,
103 method_name,
104 start_time
105 );
106 }
107
108 pub async fn postprocess_response<T>(
110 &self,
111 response: &mut tonic::Response<T>,
112 service_name: &str,
113 method_name: &str,
114 ) -> Result<(), Status> {
115 let start = Instant::now();
116 response.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
118 response
119 .metadata_mut()
120 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
121
122 if self.config.response_transform.enabled {
133 for (key, value) in &self.config.response_transform.custom_headers {
135 let parsed_key: Option<MetadataKey<Ascii>> = key.parse().ok();
136 let parsed_value: Option<MetadataValue<Ascii>> = value.parse().ok();
137
138 match (parsed_key, parsed_value) {
139 (Some(k), Some(v)) => {
140 response.metadata_mut().insert(k, v);
141 }
142 (None, _) => {
143 tracing::warn!(
144 "Skipping invalid custom header key '{}' in response transform config",
145 key
146 );
147 }
148 (_, None) => {
149 tracing::warn!(
150 "Skipping invalid custom header value for key '{}' in response transform config",
151 key
152 );
153 }
154 }
155 }
156 }
157
158 let processing_time = start.elapsed().as_millis();
160 response
162 .metadata_mut()
163 .insert("x-mockforge-processing-time", processing_time.to_string().parse().unwrap());
164 tracing::debug!("Postprocessed response for {}/{}", service_name, method_name);
165
166 Ok(())
167 }
168
169 pub async fn postprocess_dynamic_response(
171 &self,
172 response: &mut tonic::Response<DynamicMessage>,
173 service_name: &str,
174 method_name: &str,
175 ) -> Result<(), Status> {
176 self.postprocess_response(response, service_name, method_name).await?;
178
179 if self.config.response_transform.enabled {
181 if let Some(ref overrides) = self.config.response_transform.overrides {
182 match self
183 .transform_dynamic_message(
184 &response.get_ref().clone(),
185 service_name,
186 method_name,
187 overrides,
188 )
189 .await
190 {
191 Ok(transformed_message) => {
192 *response.get_mut() = transformed_message;
194 tracing::debug!(
195 "Applied body transformations to response for {}/{}",
196 service_name,
197 method_name
198 );
199 }
200 Err(e) => {
201 tracing::warn!(
202 "Failed to transform response body for {}/{}: {}",
203 service_name,
204 method_name,
205 e
206 );
207 }
208 }
209 }
210
211 if self.config.response_transform.validate_responses {
213 if let Err(validation_error) = self
214 .validate_dynamic_message(response.get_ref(), service_name, method_name)
215 .await
216 {
217 tracing::warn!(
218 "Response validation failed for {}/{}: {}",
219 service_name,
220 method_name,
221 validation_error
222 );
223 }
224 }
225 }
226
227 Ok(())
228 }
229
230 async fn transform_dynamic_message(
232 &self,
233 message: &DynamicMessage,
234 service_name: &str,
235 method_name: &str,
236 overrides: &mockforge_core::overrides::Overrides,
237 ) -> Result<DynamicMessage, Box<dyn std::error::Error + Send + Sync>> {
238 use crate::dynamic::http_bridge::converters::ProtobufJsonConverter;
239
240 let descriptor_pool = self.service_registry.descriptor_pool();
242
243 let converter = ProtobufJsonConverter::new(descriptor_pool.clone());
245
246 let json_value = converter.protobuf_to_json(&message.descriptor(), message)?;
248
249 let mut json_value = serde_json::Value::Object(json_value.as_object().unwrap().clone());
251 overrides.apply_with_context(
252 &format!("{}/{}", service_name, method_name),
253 &[service_name.to_string()],
254 &format!("{}/{}", service_name, method_name),
255 &mut json_value,
256 &mockforge_core::conditions::ConditionContext::new(),
257 );
258
259 let transformed_message = converter.json_to_protobuf(&message.descriptor(), &json_value)?;
261
262 Ok(transformed_message)
263 }
264
265 pub async fn postprocess_streaming_dynamic_response(
267 &self,
268 response: &mut tonic::Response<ReceiverStream<Result<DynamicMessage, Status>>>,
269 service_name: &str,
270 method_name: &str,
271 ) -> Result<(), Status> {
272 self.postprocess_response(response, service_name, method_name).await?;
274
275 if self.config.response_transform.enabled
276 && (self.config.response_transform.overrides.is_some()
277 || self.config.response_transform.validate_responses)
278 {
279 let (placeholder_tx, placeholder_rx) = mpsc::channel(1);
280 drop(placeholder_tx);
281 let mut original_stream =
282 std::mem::replace(response.get_mut(), ReceiverStream::new(placeholder_rx));
283
284 let (tx, rx) = mpsc::channel(16);
285 let proxy = self.clone();
286 let service_name = service_name.to_string();
287 let method_name = method_name.to_string();
288 let overrides = self.config.response_transform.overrides.clone();
289 let validate_responses = self.config.response_transform.validate_responses;
290
291 tokio::spawn(async move {
292 while let Some(item) = original_stream.next().await {
293 match item {
294 Ok(mut message) => {
295 if let Some(ref override_config) = overrides {
296 match proxy
297 .transform_dynamic_message(
298 &message,
299 &service_name,
300 &method_name,
301 override_config,
302 )
303 .await
304 {
305 Ok(transformed) => {
306 message = transformed;
307 }
308 Err(e) => {
309 tracing::warn!(
310 "Failed to transform streaming message for {}/{}: {}",
311 service_name,
312 method_name,
313 e
314 );
315 }
316 }
317 }
318
319 if validate_responses {
320 if let Err(e) = proxy
321 .validate_dynamic_message(&message, &service_name, &method_name)
322 .await
323 {
324 tracing::warn!(
325 "Streaming response validation failed for {}/{}: {}",
326 service_name,
327 method_name,
328 e
329 );
330 }
331 }
332
333 if tx.send(Ok(message)).await.is_err() {
334 break;
335 }
336 }
337 Err(status) => {
338 if tx.send(Err(status)).await.is_err() {
339 break;
340 }
341 }
342 }
343 }
344 });
345
346 *response.get_mut() = ReceiverStream::new(rx);
347 }
348
349 Ok(())
350 }
351
352 async fn validate_dynamic_message(
354 &self,
355 message: &DynamicMessage,
356 service_name: &str,
357 method_name: &str,
358 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
359 let _descriptor = message.descriptor();
361
362 self.validate_message_schema(message, service_name, method_name)?;
369
370 self.validate_business_rules(message, service_name, method_name)?;
372
373 self.validate_cross_field_rules(message, service_name, method_name)?;
375
376 self.validate_custom_rules(message, service_name, method_name)?;
378
379 tracing::debug!("Response validation passed for {}/{}", service_name, method_name);
380
381 Ok(())
382 }
383
384 fn validate_request_message(
386 &self,
387 message: &DynamicMessage,
388 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
389 self.validate_message_schema(message, "", "")?;
391 self.validate_business_rules(message, "", "")?;
393 self.validate_cross_field_rules(message, "", "")?;
395 self.validate_custom_rules(message, "", "")?;
397 tracing::debug!("Request validation passed");
398 Ok(())
399 }
400
401 fn validate_message_schema(
403 &self,
404 message: &DynamicMessage,
405 _service_name: &str,
406 _method_name: &str,
407 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
408 let descriptor = message.descriptor();
409
410 for field in descriptor.fields() {
412 let value = message.get_field(&field);
413 let value_ref = value.as_ref();
414
415 if !Self::value_matches_kind(value_ref, field.kind()) {
417 return Err(format!(
418 "{} field '{}' has incorrect type: expected {:?}, got {:?}",
419 "Message validation",
420 field.name(),
421 field.kind(),
422 value_ref
423 )
424 .into());
425 }
426
427 if let Kind::Message(expected_msg) = field.kind() {
429 if let prost_reflect::Value::Message(ref nested_msg) = *value_ref {
430 if nested_msg.descriptor() != expected_msg {
432 return Err(format!(
433 "{} field '{}' has incorrect message type",
434 "Message validation",
435 field.name()
436 )
437 .into());
438 }
439 }
440 }
441 }
442
443 Ok(())
444 }
445
446 fn validate_business_rules(
448 &self,
449 message: &DynamicMessage,
450 service_name: &str,
451 method_name: &str,
452 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
453 let descriptor = message.descriptor();
454
455 for field in descriptor.fields() {
456 let value = message.get_field(&field);
457 let field_value = value.as_ref();
458 let field_name = field.name().to_lowercase();
459
460 if field_name.contains("email") && field.kind() == Kind::String {
462 if let Some(email_str) = field_value.as_str() {
463 if !self.is_valid_email(email_str) {
464 return Err(format!(
465 "Invalid email format '{}' for field '{}' in {}/{}",
466 email_str,
467 field.name(),
468 service_name,
469 method_name
470 )
471 .into());
472 }
473 }
474 }
475
476 if field_name.contains("date") || field_name.contains("timestamp") {
478 match field.kind() {
479 Kind::String => {
480 if let Some(date_str) = field_value.as_str() {
481 if !self.is_valid_iso8601_date(date_str) {
482 return Err(format!(
483 "Invalid date format '{}' for field '{}' in {}/{}",
484 date_str,
485 field.name(),
486 service_name,
487 method_name
488 )
489 .into());
490 }
491 }
492 }
493 Kind::Int64 | Kind::Uint64 => {
494 if let Some(timestamp) = field_value.as_i64() {
496 if !(0..=4102444800).contains(×tamp) {
497 return Err(format!(
499 "Timestamp {} out of reasonable range for field '{}' in {}/{}",
500 timestamp,
501 field.name(),
502 service_name,
503 method_name
504 )
505 .into());
506 }
507 }
508 }
509 _ => {}
510 }
511 }
512
513 if field_name.contains("phone") && field.kind() == Kind::String {
515 if let Some(phone_str) = field_value.as_str() {
516 if !self.is_valid_phone_number(phone_str) {
517 return Err(format!(
518 "Invalid phone number format '{}' for field '{}' in {}/{}",
519 phone_str,
520 field.name(),
521 service_name,
522 method_name
523 )
524 .into());
525 }
526 }
527 }
528 }
529
530 Ok(())
531 }
532
533 fn validate_cross_field_rules(
535 &self,
536 message: &DynamicMessage,
537 service_name: &str,
538 method_name: &str,
539 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
540 let descriptor = message.descriptor();
541
542 let mut date_fields = Vec::new();
544 let mut timestamp_fields = Vec::new();
545
546 for field in descriptor.fields() {
547 let value = message.get_field(&field);
548 let field_value = value.as_ref();
549 let field_name = field.name().to_lowercase();
550
551 if field_name.contains("start")
552 && (field_name.contains("date") || field_name.contains("time"))
553 {
554 if let Some(value) = field_value.as_i64() {
555 date_fields.push(("start", value));
556 }
557 } else if field_name.contains("end")
558 && (field_name.contains("date") || field_name.contains("time"))
559 {
560 if let Some(value) = field_value.as_i64() {
561 date_fields.push(("end", value));
562 }
563 } else if field_name.contains("timestamp") {
564 if let Some(value) = field_value.as_i64() {
565 timestamp_fields.push((field.name().to_string(), value));
566 }
567 }
568 }
569
570 if date_fields.len() >= 2 {
572 let start_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "start").collect();
573 let end_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "end").collect();
574
575 for &(_, start_val) in &start_dates {
576 for &(_, end_val) in &end_dates {
577 if start_val >= end_val {
578 return Err(format!(
579 "Start date/time {} must be before end date/time {} in {}/{}",
580 start_val, end_val, service_name, method_name
581 )
582 .into());
583 }
584 }
585 }
586 }
587
588 if timestamp_fields.len() >= 2 {
590 let created_at = timestamp_fields
591 .iter()
592 .find(|(name, _)| name.to_lowercase().contains("created"));
593 let updated_at = timestamp_fields
594 .iter()
595 .find(|(name, _)| name.to_lowercase().contains("updated"));
596
597 if let (Some((_, created)), Some((_, updated))) = (created_at, updated_at) {
598 if created > updated {
599 return Err(format!(
600 "Created timestamp {} cannot be after updated timestamp {} in {}/{}",
601 created, updated, service_name, method_name
602 )
603 .into());
604 }
605 }
606 }
607
608 Ok(())
609 }
610
611 fn validate_custom_rules(
613 &self,
614 message: &DynamicMessage,
615 service_name: &str,
616 method_name: &str,
617 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
618 let descriptor = message.descriptor();
622
623 for field in descriptor.fields() {
624 let value = message.get_field(&field);
625 let field_value = value.as_ref();
626 let field_name = field.name().to_lowercase();
627
628 if field_name.ends_with("_id") || field_name == "id" {
630 match field.kind() {
631 Kind::Int32 | Kind::Int64 => {
632 if let Some(id_val) = field_value.as_i64() {
633 if id_val <= 0 {
634 return Err(format!(
635 "ID field '{}' must be positive, got {} in {}/{}",
636 field.name(),
637 id_val,
638 service_name,
639 method_name
640 )
641 .into());
642 }
643 }
644 }
645 Kind::Uint32 | Kind::Uint64 => {
646 if let Some(id_val) = field_value.as_u64() {
647 if id_val == 0 {
648 return Err(format!(
649 "ID field '{}' must be non-zero, got {} in {}/{}",
650 field.name(),
651 id_val,
652 service_name,
653 method_name
654 )
655 .into());
656 }
657 }
658 }
659 Kind::String => {
660 if let Some(id_str) = field_value.as_str() {
661 if id_str.trim().is_empty() {
662 return Err(format!(
663 "ID field '{}' cannot be empty in {}/{}",
664 field.name(),
665 service_name,
666 method_name
667 )
668 .into());
669 }
670 }
671 }
672 _ => {}
673 }
674 }
675
676 if field_name.contains("amount")
678 || field_name.contains("price")
679 || field_name.contains("cost")
680 {
681 if let Some(numeric_val) = field_value.as_f64() {
682 if numeric_val < 0.0 {
683 return Err(format!(
684 "Amount/price field '{}' cannot be negative, got {} in {}/{}",
685 field.name(),
686 numeric_val,
687 service_name,
688 method_name
689 )
690 .into());
691 }
692 }
693 }
694 }
695
696 Ok(())
697 }
698
699 fn is_valid_email(&self, email: &str) -> bool {
701 let parts: Vec<&str> = email.split('@').collect();
703 if parts.len() != 2 {
704 return false;
705 }
706
707 let local = parts[0];
708 let domain = parts[1];
709
710 if local.is_empty() || domain.is_empty() {
711 return false;
712 }
713
714 domain.contains('.') && !domain.starts_with('.') && !domain.ends_with('.')
716 }
717
718 fn is_valid_phone_number(&self, phone: &str) -> bool {
720 !phone.is_empty() && phone.len() >= 7 && phone.len() <= 15
722 }
723
724 fn is_valid_iso8601_date(&self, date_str: &str) -> bool {
726 chrono::DateTime::parse_from_rfc3339(date_str).is_ok()
729 || chrono::NaiveDate::parse_from_str(date_str, "%Y-%m-%d").is_ok()
730 || chrono::NaiveDateTime::parse_from_str(date_str, "%Y-%m-%d %H:%M:%S").is_ok()
731 }
732
733 pub async fn handle_error(
735 &self,
736 error: Status,
737 service_name: &str,
738 method_name: &str,
739 ) -> Status {
740 error!(
742 "Error in {}/{}: {} (code: {:?})",
743 service_name,
744 method_name,
745 error,
746 error.code()
747 );
748
749 match error.code() {
750 Code::InvalidArgument => Status::invalid_argument(format!(
751 "Invalid arguments provided to {}/{}",
752 service_name, method_name
753 )),
754 Code::NotFound => {
755 Status::not_found(format!("Resource not found in {}/{}", service_name, method_name))
756 }
757 Code::AlreadyExists => Status::already_exists(format!(
758 "Resource already exists in {}/{}",
759 service_name, method_name
760 )),
761 Code::PermissionDenied => Status::permission_denied(format!(
762 "Permission denied for {}/{}",
763 service_name, method_name
764 )),
765 Code::FailedPrecondition => Status::failed_precondition(format!(
766 "Precondition failed for {}/{}",
767 service_name, method_name
768 )),
769 Code::Aborted => {
770 Status::aborted(format!("Operation aborted for {}/{}", service_name, method_name))
771 }
772 Code::OutOfRange => Status::out_of_range(format!(
773 "Value out of range in {}/{}",
774 service_name, method_name
775 )),
776 Code::Unimplemented => Status::unimplemented(format!(
777 "Method {}/{} not implemented",
778 service_name, method_name
779 )),
780 Code::Internal => {
781 Status::internal(format!("Internal error in {}/{}", service_name, method_name))
782 }
783 Code::Unavailable => Status::unavailable(format!(
784 "Service {}/{} temporarily unavailable",
785 service_name, method_name
786 )),
787 Code::DataLoss => {
788 Status::data_loss(format!("Data loss occurred in {}/{}", service_name, method_name))
789 }
790 Code::Unauthenticated => Status::unauthenticated(format!(
791 "Authentication required for {}/{}",
792 service_name, method_name
793 )),
794 Code::DeadlineExceeded => Status::deadline_exceeded(format!(
795 "Request to {}/{} timed out",
796 service_name, method_name
797 )),
798 Code::ResourceExhausted => Status::resource_exhausted(format!(
799 "Rate limit exceeded for {}/{}",
800 service_name, method_name
801 )),
802 _ => {
803 let message = error.message();
804 if message.contains(service_name) && message.contains(method_name) {
805 error
806 } else {
807 Status::new(
808 error.code(),
809 format!("{}/{}: {}", service_name, method_name, message),
810 )
811 }
812 }
813 }
814 }
815
816 pub async fn collect_metrics(
818 &self,
819 service_name: &str,
820 method_name: &str,
821 duration: std::time::Duration,
822 success: bool,
823 ) {
824 let duration_ms = duration.as_millis() as u64;
825
826 if success {
827 record_success(service_name, method_name, duration_ms).await;
828 } else {
829 record_error(service_name, method_name).await;
830 }
831
832 tracing::debug!(
833 "Request {}/{} completed in {:?}, success: {}",
834 service_name,
835 method_name,
836 duration,
837 success
838 );
839 }
840}
841
#[cfg(test)]
mod tests {
    // Placeholder test: the body is empty, so it passes whenever this module
    // compiles. It exercises none of the methods defined above.
    #[test]
    fn test_module_compiles() {
    }
}