1use crate::reflection::metrics::{record_error, record_success};
7use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
8use prost_reflect::{DynamicMessage, Kind, ReflectMessage};
9use std::time::Instant;
10use tonic::{
11 metadata::{Ascii, MetadataKey, MetadataValue},
12 Code, Request, Status,
13};
14use tracing::error;
15
16impl MockReflectionProxy {
17 pub async fn preprocess_request<T>(&self, request: &mut Request<T>) -> Result<(), Status>
19 where
20 T: prost_reflect::ReflectMessage,
21 {
22 let mut metadata_log = Vec::new();
24 for kv in request.metadata().iter() {
25 match kv {
26 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
27 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
28 }
29 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
30 metadata_log.push(format!("{}: <binary>", key));
31 }
32 }
33 }
34 tracing::debug!("Extracted request metadata: [{}]", metadata_log.join(", "));
35
36 let descriptor = request.get_ref().descriptor();
38 let mut buf = Vec::new();
39 request
40 .get_ref()
41 .encode(&mut buf)
42 .map_err(|_e| Status::internal("Failed to encode request".to_string()))?;
43 let dynamic_message = DynamicMessage::decode(descriptor.clone(), &buf[..])
44 .map_err(|_e| Status::internal("Failed to decode request".to_string()))?;
45 if let Err(e) = self.validate_request_message(&dynamic_message) {
46 return Err(Status::internal(format!("Request validation failed: {}", e)));
47 }
48 tracing::debug!("Request format validation passed");
49
50 request.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
53 request
54 .metadata_mut()
55 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
56
57 tracing::debug!("Applied request transformations: added processed and timestamp headers");
58
59 Ok(())
60 }
61
62 pub async fn log_request<T>(&self, request: &Request<T>, service_name: &str, method_name: &str)
64 where
65 T: prost_reflect::ReflectMessage,
66 {
67 let start_time = std::time::Instant::now();
68
69 let mut metadata_log = Vec::new();
71 for kv in request.metadata().iter() {
72 match kv {
73 tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
74 metadata_log.push(format!("{}: {}", key, value.to_str().unwrap_or("<binary>")));
75 }
76 tonic::metadata::KeyAndValueRef::Binary(key, _) => {
77 metadata_log.push(format!("{}: <binary>", key));
78 }
79 }
80 }
81 tracing::debug!(
82 "Request metadata for {}/{}: [{}]",
83 service_name,
84 method_name,
85 metadata_log.join(", ")
86 );
87
88 let request_size = request.get_ref().encoded_len();
90 tracing::debug!(
91 "Request size for {}/{}: {} bytes",
92 service_name,
93 method_name,
94 request_size
95 );
96
97 tracing::debug!(
99 "Request start time for {}/{}: {:?}",
100 service_name,
101 method_name,
102 start_time
103 );
104 }
105
106 pub async fn postprocess_response<T>(
108 &self,
109 response: &mut tonic::Response<T>,
110 service_name: &str,
111 method_name: &str,
112 ) -> Result<(), Status> {
113 let start = Instant::now();
114 response.metadata_mut().insert("x-mockforge-processed", "true".parse().unwrap());
116 response
117 .metadata_mut()
118 .insert("x-mockforge-timestamp", chrono::Utc::now().to_rfc3339().parse().unwrap());
119
120 if self.config.response_transform.enabled {
131 for (key, value) in &self.config.response_transform.custom_headers {
133 let parsed_key: Option<MetadataKey<Ascii>> = key.parse().ok();
134 let parsed_value: Option<MetadataValue<Ascii>> = value.parse().ok();
135
136 match (parsed_key, parsed_value) {
137 (Some(k), Some(v)) => {
138 response.metadata_mut().insert(k, v);
139 }
140 (None, _) => {
141 tracing::warn!(
142 "Skipping invalid custom header key '{}' in response transform config",
143 key
144 );
145 }
146 (_, None) => {
147 tracing::warn!(
148 "Skipping invalid custom header value for key '{}' in response transform config",
149 key
150 );
151 }
152 }
153 }
154 }
155
156 let processing_time = start.elapsed().as_millis();
158 response
160 .metadata_mut()
161 .insert("x-mockforge-processing-time", processing_time.to_string().parse().unwrap());
162 tracing::debug!("Postprocessed response for {}/{}", service_name, method_name);
163
164 Ok(())
165 }
166
167 pub async fn postprocess_dynamic_response(
169 &self,
170 response: &mut tonic::Response<prost_reflect::DynamicMessage>,
171 service_name: &str,
172 method_name: &str,
173 ) -> Result<(), Status> {
174 self.postprocess_response(response, service_name, method_name).await?;
176
177 if self.config.response_transform.enabled {
179 if let Some(ref overrides) = self.config.response_transform.overrides {
180 match self
181 .transform_dynamic_message(
182 &response.get_ref().clone(),
183 service_name,
184 method_name,
185 overrides,
186 )
187 .await
188 {
189 Ok(transformed_message) => {
190 *response.get_mut() = transformed_message;
192 tracing::debug!(
193 "Applied body transformations to response for {}/{}",
194 service_name,
195 method_name
196 );
197 }
198 Err(e) => {
199 tracing::warn!(
200 "Failed to transform response body for {}/{}: {}",
201 service_name,
202 method_name,
203 e
204 );
205 }
206 }
207 }
208
209 if self.config.response_transform.validate_responses {
211 if let Err(validation_error) = self
212 .validate_dynamic_message(response.get_ref(), service_name, method_name)
213 .await
214 {
215 tracing::warn!(
216 "Response validation failed for {}/{}: {}",
217 service_name,
218 method_name,
219 validation_error
220 );
221 }
222 }
223 }
224
225 Ok(())
226 }
227
228 async fn transform_dynamic_message(
230 &self,
231 message: &prost_reflect::DynamicMessage,
232 service_name: &str,
233 method_name: &str,
234 overrides: &mockforge_core::overrides::Overrides,
235 ) -> Result<prost_reflect::DynamicMessage, Box<dyn std::error::Error + Send + Sync>> {
236 use crate::dynamic::http_bridge::converters::ProtobufJsonConverter;
237
238 let descriptor_pool = self.service_registry.descriptor_pool();
240
241 let converter = ProtobufJsonConverter::new(descriptor_pool.clone());
243
244 let json_value = converter.protobuf_to_json(&message.descriptor(), message)?;
246
247 let mut json_value = serde_json::Value::Object(json_value.as_object().unwrap().clone());
249 overrides.apply_with_context(
250 &format!("{}/{}", service_name, method_name),
251 &[service_name.to_string()],
252 &format!("{}/{}", service_name, method_name),
253 &mut json_value,
254 &mockforge_core::conditions::ConditionContext::new(),
255 );
256
257 let transformed_message = converter.json_to_protobuf(&message.descriptor(), &json_value)?;
259
260 Ok(transformed_message)
261 }
262
263 pub async fn postprocess_streaming_dynamic_response(
265 &self,
266 response: &mut tonic::Response<
267 tokio_stream::wrappers::ReceiverStream<
268 Result<prost_reflect::DynamicMessage, tonic::Status>,
269 >,
270 >,
271 service_name: &str,
272 method_name: &str,
273 ) -> Result<(), Status> {
274 self.postprocess_response(response, service_name, method_name).await?;
276
277 if self.config.response_transform.enabled {
282 if self.config.response_transform.overrides.is_some() {
283 tracing::debug!(
284 "Body transformation for streaming responses not yet implemented for {}/{}",
285 service_name,
286 method_name
287 );
288 }
289
290 if self.config.response_transform.validate_responses {
291 tracing::debug!(
292 "Response validation for streaming responses not yet implemented for {}/{}",
293 service_name,
294 method_name
295 );
296 }
297 }
298
299 Ok(())
300 }
301
302 async fn validate_dynamic_message(
304 &self,
305 message: &prost_reflect::DynamicMessage,
306 service_name: &str,
307 method_name: &str,
308 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
309 let _descriptor = message.descriptor();
311
312 self.validate_message_schema(message, service_name, method_name)?;
319
320 self.validate_business_rules(message, service_name, method_name)?;
322
323 self.validate_cross_field_rules(message, service_name, method_name)?;
325
326 self.validate_custom_rules(message, service_name, method_name)?;
328
329 tracing::debug!("Response validation passed for {}/{}", service_name, method_name);
330
331 Ok(())
332 }
333
334 fn validate_request_message(
336 &self,
337 message: &DynamicMessage,
338 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
339 self.validate_message_schema(message, "", "")?;
341 self.validate_business_rules(message, "", "")?;
343 self.validate_cross_field_rules(message, "", "")?;
345 self.validate_custom_rules(message, "", "")?;
347 tracing::debug!("Request validation passed");
348 Ok(())
349 }
350
351 fn validate_message_schema(
353 &self,
354 message: &DynamicMessage,
355 _service_name: &str,
356 _method_name: &str,
357 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
358 let descriptor = message.descriptor();
359
360 for field in descriptor.fields() {
362 let value = message.get_field(&field);
363 let value_ref = value.as_ref();
364
365 if !Self::value_matches_kind(value_ref, field.kind()) {
367 return Err(format!(
368 "{} field '{}' has incorrect type: expected {:?}, got {:?}",
369 "Message validation",
370 field.name(),
371 field.kind(),
372 value_ref
373 )
374 .into());
375 }
376
377 if let Kind::Message(expected_msg) = field.kind() {
379 if let prost_reflect::Value::Message(ref nested_msg) = *value_ref {
380 if nested_msg.descriptor() != expected_msg {
382 return Err(format!(
383 "{} field '{}' has incorrect message type",
384 "Message validation",
385 field.name()
386 )
387 .into());
388 }
389 }
390 }
391 }
392
393 Ok(())
394 }
395
396 fn validate_business_rules(
398 &self,
399 message: &DynamicMessage,
400 service_name: &str,
401 method_name: &str,
402 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
403 let descriptor = message.descriptor();
404
405 for field in descriptor.fields() {
406 let value = message.get_field(&field);
407 let field_value = value.as_ref();
408 let field_name = field.name().to_lowercase();
409
410 if field_name.contains("email") && field.kind() == Kind::String {
412 if let Some(email_str) = field_value.as_str() {
413 if !self.is_valid_email(email_str) {
414 return Err(format!(
415 "Invalid email format '{}' for field '{}' in {}/{}",
416 email_str,
417 field.name(),
418 service_name,
419 method_name
420 )
421 .into());
422 }
423 }
424 }
425
426 if field_name.contains("date") || field_name.contains("timestamp") {
428 match field.kind() {
429 Kind::String => {
430 if let Some(date_str) = field_value.as_str() {
431 if !self.is_valid_iso8601_date(date_str) {
432 return Err(format!(
433 "Invalid date format '{}' for field '{}' in {}/{}",
434 date_str,
435 field.name(),
436 service_name,
437 method_name
438 )
439 .into());
440 }
441 }
442 }
443 Kind::Int64 | Kind::Uint64 => {
444 if let Some(timestamp) = field_value.as_i64() {
446 if !(0..=4102444800).contains(×tamp) {
447 return Err(format!(
449 "Timestamp {} out of reasonable range for field '{}' in {}/{}",
450 timestamp,
451 field.name(),
452 service_name,
453 method_name
454 )
455 .into());
456 }
457 }
458 }
459 _ => {}
460 }
461 }
462
463 if field_name.contains("phone") && field.kind() == Kind::String {
465 if let Some(phone_str) = field_value.as_str() {
466 if !self.is_valid_phone_number(phone_str) {
467 return Err(format!(
468 "Invalid phone number format '{}' for field '{}' in {}/{}",
469 phone_str,
470 field.name(),
471 service_name,
472 method_name
473 )
474 .into());
475 }
476 }
477 }
478 }
479
480 Ok(())
481 }
482
483 fn validate_cross_field_rules(
485 &self,
486 message: &DynamicMessage,
487 service_name: &str,
488 method_name: &str,
489 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
490 let descriptor = message.descriptor();
491
492 let mut date_fields = Vec::new();
494 let mut timestamp_fields = Vec::new();
495
496 for field in descriptor.fields() {
497 let value = message.get_field(&field);
498 let field_value = value.as_ref();
499 let field_name = field.name().to_lowercase();
500
501 if field_name.contains("start")
502 && (field_name.contains("date") || field_name.contains("time"))
503 {
504 if let Some(value) = field_value.as_i64() {
505 date_fields.push(("start", value));
506 }
507 } else if field_name.contains("end")
508 && (field_name.contains("date") || field_name.contains("time"))
509 {
510 if let Some(value) = field_value.as_i64() {
511 date_fields.push(("end", value));
512 }
513 } else if field_name.contains("timestamp") {
514 if let Some(value) = field_value.as_i64() {
515 timestamp_fields.push((field.name().to_string(), value));
516 }
517 }
518 }
519
520 if date_fields.len() >= 2 {
522 let start_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "start").collect();
523 let end_dates: Vec<_> = date_fields.iter().filter(|(t, _)| *t == "end").collect();
524
525 for &(_, start_val) in &start_dates {
526 for &(_, end_val) in &end_dates {
527 if start_val >= end_val {
528 return Err(format!(
529 "Start date/time {} must be before end date/time {} in {}/{}",
530 start_val, end_val, service_name, method_name
531 )
532 .into());
533 }
534 }
535 }
536 }
537
538 if timestamp_fields.len() >= 2 {
540 let created_at = timestamp_fields
541 .iter()
542 .find(|(name, _)| name.to_lowercase().contains("created"));
543 let updated_at = timestamp_fields
544 .iter()
545 .find(|(name, _)| name.to_lowercase().contains("updated"));
546
547 if let (Some((_, created)), Some((_, updated))) = (created_at, updated_at) {
548 if created > updated {
549 return Err(format!(
550 "Created timestamp {} cannot be after updated timestamp {} in {}/{}",
551 created, updated, service_name, method_name
552 )
553 .into());
554 }
555 }
556 }
557
558 Ok(())
559 }
560
561 fn validate_custom_rules(
563 &self,
564 message: &DynamicMessage,
565 service_name: &str,
566 method_name: &str,
567 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
568 let descriptor = message.descriptor();
572
573 for field in descriptor.fields() {
574 let value = message.get_field(&field);
575 let field_value = value.as_ref();
576 let field_name = field.name().to_lowercase();
577
578 if field_name.ends_with("_id") || field_name == "id" {
580 match field.kind() {
581 Kind::Int32 | Kind::Int64 => {
582 if let Some(id_val) = field_value.as_i64() {
583 if id_val <= 0 {
584 return Err(format!(
585 "ID field '{}' must be positive, got {} in {}/{}",
586 field.name(),
587 id_val,
588 service_name,
589 method_name
590 )
591 .into());
592 }
593 }
594 }
595 Kind::Uint32 | Kind::Uint64 => {
596 if let Some(id_val) = field_value.as_u64() {
597 if id_val == 0 {
598 return Err(format!(
599 "ID field '{}' must be non-zero, got {} in {}/{}",
600 field.name(),
601 id_val,
602 service_name,
603 method_name
604 )
605 .into());
606 }
607 }
608 }
609 Kind::String => {
610 if let Some(id_str) = field_value.as_str() {
611 if id_str.trim().is_empty() {
612 return Err(format!(
613 "ID field '{}' cannot be empty in {}/{}",
614 field.name(),
615 service_name,
616 method_name
617 )
618 .into());
619 }
620 }
621 }
622 _ => {}
623 }
624 }
625
626 if field_name.contains("amount")
628 || field_name.contains("price")
629 || field_name.contains("cost")
630 {
631 if let Some(numeric_val) = field_value.as_f64() {
632 if numeric_val < 0.0 {
633 return Err(format!(
634 "Amount/price field '{}' cannot be negative, got {} in {}/{}",
635 field.name(),
636 numeric_val,
637 service_name,
638 method_name
639 )
640 .into());
641 }
642 }
643 }
644 }
645
646 Ok(())
647 }
648
649 fn is_valid_email(&self, email: &str) -> bool {
651 let parts: Vec<&str> = email.split('@').collect();
653 if parts.len() != 2 {
654 return false;
655 }
656
657 let local = parts[0];
658 let domain = parts[1];
659
660 if local.is_empty() || domain.is_empty() {
661 return false;
662 }
663
664 domain.contains('.') && !domain.starts_with('.') && !domain.ends_with('.')
666 }
667
668 fn is_valid_phone_number(&self, phone: &str) -> bool {
670 !phone.is_empty() && phone.len() >= 7 && phone.len() <= 15
672 }
673
674 fn is_valid_iso8601_date(&self, date_str: &str) -> bool {
676 chrono::DateTime::parse_from_rfc3339(date_str).is_ok()
679 || chrono::NaiveDate::parse_from_str(date_str, "%Y-%m-%d").is_ok()
680 || chrono::NaiveDateTime::parse_from_str(date_str, "%Y-%m-%d %H:%M:%S").is_ok()
681 }
682
683 pub async fn handle_error(
685 &self,
686 error: Status,
687 service_name: &str,
688 method_name: &str,
689 ) -> Status {
690 error!(
692 "Error in {}/{}: {} (code: {:?})",
693 service_name,
694 method_name,
695 error,
696 error.code()
697 );
698
699 match error.code() {
700 Code::InvalidArgument => Status::invalid_argument(format!(
701 "Invalid arguments provided to {}/{}",
702 service_name, method_name
703 )),
704 Code::NotFound => {
705 Status::not_found(format!("Resource not found in {}/{}", service_name, method_name))
706 }
707 Code::AlreadyExists => Status::already_exists(format!(
708 "Resource already exists in {}/{}",
709 service_name, method_name
710 )),
711 Code::PermissionDenied => Status::permission_denied(format!(
712 "Permission denied for {}/{}",
713 service_name, method_name
714 )),
715 Code::FailedPrecondition => Status::failed_precondition(format!(
716 "Precondition failed for {}/{}",
717 service_name, method_name
718 )),
719 Code::Aborted => {
720 Status::aborted(format!("Operation aborted for {}/{}", service_name, method_name))
721 }
722 Code::OutOfRange => Status::out_of_range(format!(
723 "Value out of range in {}/{}",
724 service_name, method_name
725 )),
726 Code::Unimplemented => Status::unimplemented(format!(
727 "Method {}/{} not implemented",
728 service_name, method_name
729 )),
730 Code::Internal => {
731 Status::internal(format!("Internal error in {}/{}", service_name, method_name))
732 }
733 Code::Unavailable => Status::unavailable(format!(
734 "Service {}/{} temporarily unavailable",
735 service_name, method_name
736 )),
737 Code::DataLoss => {
738 Status::data_loss(format!("Data loss occurred in {}/{}", service_name, method_name))
739 }
740 Code::Unauthenticated => Status::unauthenticated(format!(
741 "Authentication required for {}/{}",
742 service_name, method_name
743 )),
744 Code::DeadlineExceeded => Status::deadline_exceeded(format!(
745 "Request to {}/{} timed out",
746 service_name, method_name
747 )),
748 Code::ResourceExhausted => Status::resource_exhausted(format!(
749 "Rate limit exceeded for {}/{}",
750 service_name, method_name
751 )),
752 _ => {
753 let message = error.message();
754 if message.contains(service_name) && message.contains(method_name) {
755 error
756 } else {
757 Status::new(
758 error.code(),
759 format!("{}/{}: {}", service_name, method_name, message),
760 )
761 }
762 }
763 }
764 }
765
766 pub async fn collect_metrics(
768 &self,
769 service_name: &str,
770 method_name: &str,
771 duration: std::time::Duration,
772 success: bool,
773 ) {
774 let duration_ms = duration.as_millis() as u64;
775
776 if success {
777 record_success(service_name, method_name, duration_ms).await;
778 } else {
779 record_error(service_name, method_name).await;
780 }
781
782 tracing::debug!(
783 "Request {}/{} completed in {:?}, success: {}",
784 service_name,
785 method_name,
786 duration,
787 success
788 );
789 }
790}
791
#[cfg(test)]
mod tests {
    /// Smoke test: its only purpose is to make sure this module builds under `cfg(test)`.
    #[test]
    fn test_module_compiles() {
        // Intentionally empty; successful compilation is the assertion.
    }
}