use std::collections::HashSet;

use mockforge_openapi::openapi_routes::ValidationMode;
use prost::bytes::Bytes as ProstBytes;
use prost_reflect::prost::Message;
use prost_reflect::ReflectMessage;
use prost_reflect::{DynamicMessage, FieldDescriptor, Kind, MapKey, MessageDescriptor, Value};
use tonic::{Request, Status};
use tracing::debug;

use crate::reflection::mock_proxy::proxy::MockReflectionProxy;
impl MockReflectionProxy {
pub async fn validate_request(
&self,
request: &Request<DynamicMessage>,
service_name: &str,
method_name: &str,
) -> Result<(), Status> {
debug!("Validating request for {}/{}", service_name, method_name);
let method_descriptor = self.cache.get_method(service_name, method_name).await?;
let expected_descriptor = method_descriptor.input();
let actual_descriptor = request.get_ref().descriptor();
if actual_descriptor.full_name() != expected_descriptor.full_name() {
return Err(Status::invalid_argument(format!(
"Request type mismatch: expected {}, got {}",
expected_descriptor.full_name(),
actual_descriptor.full_name()
)));
}
let method_descriptor = self.cache.get_method(service_name, method_name).await?;
let expected_descriptor = method_descriptor.input();
let encoded = request.get_ref().encode_to_vec();
let dynamic_message =
DynamicMessage::decode(expected_descriptor.clone(), ProstBytes::from(encoded))
.map_err(|e| {
Status::invalid_argument(format!(
"Failed to decode request as DynamicMessage: {}",
e
))
})?;
Self::validate_dynamic_message_fields(&dynamic_message, &expected_descriptor, "request")?;
debug!("Request validation passed for {}/{}", service_name, method_name);
Ok(())
}
pub async fn validate_response(
&self,
response: &DynamicMessage,
service_name: &str,
method_name: &str,
) -> Result<(), Status> {
debug!("Validating response for {}/{}", service_name, method_name);
let method_descriptor = self.cache.get_method(service_name, method_name).await?;
let expected_descriptor = method_descriptor.output();
if response.descriptor().full_name() != expected_descriptor.full_name() {
return Err(Status::invalid_argument(format!(
"Response type mismatch: expected {}, got {}",
expected_descriptor.full_name(),
response.descriptor().full_name()
)));
}
Self::validate_dynamic_message_fields(response, &expected_descriptor, "response")?;
debug!("Response validation passed for {}/{}", service_name, method_name);
Ok(())
}
pub async fn route_request<T>(
&self,
request: Request<T>,
) -> Result<(String, String, Request<T>), Status> {
let (service_name, method_name) = self.extract_service_method_from_request(&request)?;
let contains_service = self.cache.contains_service(&service_name).await;
if !contains_service {
return Err(Status::not_found(format!("Service {} not found", service_name)));
}
if self.cache.get_method(&service_name, &method_name).await.is_err() {
return Err(Status::not_found(format!(
"Method {} not found in service {}",
method_name, service_name
)));
}
Ok((service_name.to_string(), method_name.to_string(), request))
}
pub async fn can_handle_service_method(&self, service_name: &str, method_name: &str) -> bool {
if !self.cache.contains_service(service_name).await {
return false;
}
if !self.cache.contains_method(service_name, method_name).await {
return false;
}
true
}
pub async fn validate_service_method_signature(
&self,
service_name: &str,
method_name: &str,
input_descriptor: MessageDescriptor,
output_descriptor: MessageDescriptor,
) -> Result<(), Status> {
debug!("Validating signature for {}/{}", service_name, method_name);
let cached_descriptor = self.cache.get_method(service_name, method_name).await?;
if input_descriptor.full_name() != cached_descriptor.input().full_name() {
return Err(Status::invalid_argument(format!(
"Input type mismatch: expected {}, got {}",
cached_descriptor.input().full_name(),
input_descriptor.full_name()
)));
}
if output_descriptor.full_name() != cached_descriptor.output().full_name() {
return Err(Status::invalid_argument(format!(
"Output type mismatch: expected {}, got {}",
cached_descriptor.output().full_name(),
output_descriptor.full_name()
)));
}
Self::check_message_compatibility(&cached_descriptor.input(), &input_descriptor, "input")?;
Self::check_message_compatibility(
&cached_descriptor.output(),
&output_descriptor,
"output",
)?;
debug!("Signature validation passed for {}/{}", service_name, method_name);
Ok(())
}
fn check_message_compatibility(
expected: &MessageDescriptor,
provided: &MessageDescriptor,
message_type: &str,
) -> Result<(), Status> {
for expected_field in expected.fields() {
let field_name = expected_field.name();
if let Some(provided_field) = provided.get_field_by_name(field_name) {
if expected_field.kind() != provided_field.kind() {
return Err(Status::invalid_argument(format!(
"{} field '{}' type mismatch: expected {:?}, got {:?}",
message_type,
field_name,
expected_field.kind(),
provided_field.kind()
)));
}
if let Kind::Message(expected_msg) = expected_field.kind() {
if let Kind::Message(provided_msg) = provided_field.kind() {
if expected_msg.full_name() != provided_msg.full_name() {
Self::check_message_compatibility(
&expected_msg,
&provided_msg,
&format!("{}.{}", message_type, field_name),
)?;
}
}
}
} else {
return Err(Status::invalid_argument(format!(
"Missing {} field '{}' in provided descriptor",
message_type, field_name
)));
}
}
Ok(())
}
fn validate_dynamic_message_fields(
message: &DynamicMessage,
descriptor: &MessageDescriptor,
context: &str,
) -> Result<(), Status> {
for field in descriptor.fields() {
let field_name = field.name();
let value = message.get_field(&field);
let value_ref = value.as_ref();
if !Self::value_matches_kind(value_ref, field.kind()) {
return Err(Status::invalid_argument(format!(
"{} field '{}' has incorrect type: expected {:?}, got {:?}",
context,
field_name,
field.kind(),
value_ref
)));
}
if let Kind::Message(expected_msg) = field.kind() {
if let Value::Message(ref nested_msg) = *value_ref {
Self::validate_dynamic_message_fields(
nested_msg,
&expected_msg,
&format!("{}.{}", context, field_name),
)?;
}
}
}
Ok(())
}
pub fn value_matches_kind(value: &Value, kind: Kind) -> bool {
match *value {
Value::Bool(_) => kind == Kind::Bool,
Value::I32(_) => matches!(kind, Kind::Int32 | Kind::Sint32 | Kind::Sfixed32),
Value::I64(_) => matches!(kind, Kind::Int64 | Kind::Sint64 | Kind::Sfixed64),
Value::U32(_) => {
matches!(kind, Kind::Uint32 | Kind::Fixed32)
}
Value::U64(_) => {
matches!(kind, Kind::Uint64 | Kind::Fixed64)
}
Value::F32(_) => kind == Kind::Float,
Value::F64(_) => kind == Kind::Double,
Value::String(_) => kind == Kind::String,
Value::Bytes(_) => kind == Kind::Bytes,
Value::Message(_) => matches!(kind, Kind::Message(_)),
Value::List(_) => matches!(kind, Kind::Message(_)), _ => false,
}
}
pub async fn validate_request_size<T>(
&self,
request: &Request<T>,
max_size: usize,
) -> Result<(), Status>
where
T: Message,
{
let encoded_size = request.get_ref().encode_to_vec().len();
if encoded_size > max_size {
return Err(Status::resource_exhausted(format!(
"Request size {} bytes exceeds maximum allowed size of {} bytes",
encoded_size, max_size
)));
}
Ok(())
}
pub async fn validate_response_size(
&self,
response: &DynamicMessage,
max_size: usize,
) -> Result<(), Status> {
let encoded_size = response.encode_to_vec().len();
if encoded_size > max_size {
return Err(Status::resource_exhausted(format!(
"Response size {} bytes exceeds maximum allowed size of {} bytes",
encoded_size, max_size
)));
}
Ok(())
}
pub fn should_skip_validation(&self, service_name: &str, method_name: &str) -> bool {
for prefix in &self.config.admin_skip_prefixes {
if service_name.starts_with(prefix) || method_name.starts_with(prefix) {
return true;
}
}
false
}
pub fn get_validation_mode_for_method(
&self,
service_name: &str,
method_name: &str,
) -> ValidationMode {
if let Some(mode) = self.config.overrides.get(&format!("{}/{}", service_name, method_name))
{
return mode.clone();
}
if let Some(mode) = self.config.overrides.get(service_name) {
return mode.clone();
}
self.config.request_mode.clone()
}
}
#[cfg(test)]
mod tests {
    /// Compile-only smoke test: if the test harness runs this at all, the code
    /// in this file type-checked under `cfg(test)`.
    #[test]
    fn test_module_compiles() {
        // Intentionally empty; a successful build is the check.
    }
}