pub(crate) mod bool_rules;
pub(crate) mod bytes;
pub(crate) mod duration;
pub(crate) mod enum_rules;
pub(crate) mod field_mask;
pub(crate) mod map;
pub(crate) mod number;
pub(crate) mod repeated;
pub(crate) mod string;
pub(crate) mod timestamp;
use proc_macro2::{Ident, TokenStream};
use prost_reflect::{DescriptorPool, FieldDescriptor};
use quote::quote;
use prost_protovalidate_types::field_rules;
use crate::Error;
use crate::naming::NamingContext;
/// Generates the validation statements for the type-specific rules attached to one field.
///
/// The shape of the generated code depends on how the field is read on `self`:
///
/// - Singular, non-map `google.protobuf.*Value` wrapper fields: the scalar checks read
///   `_wkt.value` inside `if let Some(ref _wkt) = self.<field>`, so an unset wrapper is
///   not checked.
/// - Fields accepted by [`field_storage_is_option_scalar`]: the scalar checks read
///   `(*_val)` inside `if let Some(ref _val) = self.<field>`.
/// - Every other field: scalar and enum rules are checked directly against
///   `self.<field>`; repeated, map, duration, timestamp, field-mask and any rules are
///   delegated to their dedicated generators.
///
/// An empty vector means there is nothing to check for this field.
///
/// # Errors
///
/// Returns the error produced by [`generate_scalar_type_checks`] when a rule it does not
/// handle is attached to a wrapper or optional-scalar field, and propagates any error
/// from the repeated and map generators.
pub(crate) fn generate_type_rules(
    type_rules: &field_rules::Type,
    field: &FieldDescriptor,
    field_ident: &Ident,
    proto_name: &str,
    pool: &DescriptorPool,
    naming: &NamingContext,
) -> Result<Vec<TokenStream>, Error> {
    if is_wkt_wrapper(field) && !field.is_list() && !field.is_map() {
        // Wrapper messages are never enums, so no defined enum values are passed.
        let checks =
            generate_scalar_type_checks(type_rules, &quote!(_wkt.value), proto_name, &[])?;
        return Ok(guard_present(field_ident, &quote!(_wkt), checks));
    }
    if field_storage_is_option_scalar(field) {
        // `_val` is bound with `ref`, so the checks operate on the dereferenced payload.
        let checks = generate_scalar_type_checks(
            type_rules,
            &quote!((*_val)),
            proto_name,
            &defined_enum_values(&field.kind()),
        )?;
        return Ok(guard_present(field_ident, &quote!(_val), checks));
    }
    match type_rules {
        // Scalar and enum rules share the same dispatcher as the optional paths above.
        // The variants are listed explicitly (no catch-all) so that a new rule kind
        // still forces a compile error here.
        field_rules::Type::Bool(_)
        | field_rules::Type::Float(_)
        | field_rules::Type::Double(_)
        | field_rules::Type::Int32(_)
        | field_rules::Type::Int64(_)
        | field_rules::Type::Uint32(_)
        | field_rules::Type::Uint64(_)
        | field_rules::Type::Sint32(_)
        | field_rules::Type::Sint64(_)
        | field_rules::Type::Fixed32(_)
        | field_rules::Type::Fixed64(_)
        | field_rules::Type::Sfixed32(_)
        | field_rules::Type::Sfixed64(_)
        | field_rules::Type::String(_)
        | field_rules::Type::Bytes(_)
        | field_rules::Type::Enum(_) => generate_scalar_type_checks(
            type_rules,
            &quote!(self.#field_ident),
            proto_name,
            &defined_enum_values(&field.kind()),
        ),
        field_rules::Type::Repeated(r) => {
            repeated::generate(r, field, field_ident, proto_name, pool, naming)
        }
        field_rules::Type::Map(r) => map::generate(r, field, field_ident, proto_name, pool, naming),
        field_rules::Type::Duration(r) => Ok(duration::generate(r, field_ident, proto_name)),
        field_rules::Type::Timestamp(r) => Ok(timestamp::generate(r, field_ident, proto_name)),
        field_rules::Type::FieldMask(r) => Ok(field_mask::generate(r, field_ident, proto_name)),
        field_rules::Type::Any(r) => Ok(generate_any_rules(r, field_ident, proto_name)),
    }
}

/// Wraps `checks` in `if let Some(ref <binding>) = self.<field_ident> { ... }` so they
/// only run when the field is set. Returns no statements when `checks` is empty.
fn guard_present(
    field_ident: &Ident,
    binding: &TokenStream,
    checks: Vec<TokenStream>,
) -> Vec<TokenStream> {
    if checks.is_empty() {
        return Vec::new();
    }
    vec![quote! {
        if let ::core::option::Option::Some(ref #binding) = self.#field_ident {
            #(#checks)*
        }
    }]
}
pub(crate) fn generate_scalar_type_checks(
type_rules: &field_rules::Type,
value_access: &TokenStream,
proto_name: &str,
defined_values: &[i32],
) -> Result<Vec<TokenStream>, Error> {
match type_rules {
field_rules::Type::Bool(r) => Ok(bool_rules::generate(r, value_access, proto_name)),
field_rules::Type::Float(r) => Ok(number::generate_float(r, value_access, proto_name)),
field_rules::Type::Double(r) => Ok(number::generate_double(r, value_access, proto_name)),
field_rules::Type::Int32(r) => Ok(number::generate_int32(r, value_access, proto_name)),
field_rules::Type::Int64(r) => Ok(number::generate_int64(r, value_access, proto_name)),
field_rules::Type::Uint32(r) => Ok(number::generate_uint32(r, value_access, proto_name)),
field_rules::Type::Uint64(r) => Ok(number::generate_uint64(r, value_access, proto_name)),
field_rules::Type::Sint32(r) => Ok(number::generate_sint32(r, value_access, proto_name)),
field_rules::Type::Sint64(r) => Ok(number::generate_sint64(r, value_access, proto_name)),
field_rules::Type::Fixed32(r) => Ok(number::generate_fixed32(r, value_access, proto_name)),
field_rules::Type::Fixed64(r) => Ok(number::generate_fixed64(r, value_access, proto_name)),
field_rules::Type::Sfixed32(r) => {
Ok(number::generate_sfixed32(r, value_access, proto_name))
}
field_rules::Type::Sfixed64(r) => {
Ok(number::generate_sfixed64(r, value_access, proto_name))
}
field_rules::Type::String(r) => Ok(string::generate(r, value_access, proto_name)),
field_rules::Type::Bytes(r) => Ok(bytes::generate(r, value_access, proto_name)),
field_rules::Type::Enum(r) => Ok(enum_rules::generate(
r,
value_access,
proto_name,
defined_values,
)),
_ => Err(Error::Codegen(format!(
"unsupported item/key/value rule type for field {proto_name}"
))),
}
}
/// Returns the numbers of every value declared by the enum behind `field_kind`.
///
/// Non-enum kinds yield an empty vector.
pub(crate) fn defined_enum_values(field_kind: &prost_reflect::Kind) -> Vec<i32> {
    match field_kind.as_enum() {
        Some(descriptor) => descriptor.values().map(|value| value.number()).collect(),
        None => Vec::new(),
    }
}
/// Reports whether `field` is a presence-tracking, non-required, singular, non-message
/// field.
///
/// `generate_type_rules` reads fields matching this predicate through
/// `if let Some(ref _val) = self.<field>`.
///
/// NOTE(review): prost-reflect also reports presence for oneof members; confirm that
/// callers never route oneof fields through this predicate.
pub(crate) fn field_storage_is_option_scalar(field: &FieldDescriptor) -> bool {
    if !field.supports_presence() || field.is_required() {
        return false;
    }
    if field.is_list() || field.is_map() {
        return false;
    }
    field.kind().as_message().is_none()
}
/// Returns `true` when the field's kind is one of the nine `google.protobuf.*Value`
/// wrapper messages listed below. Returns `false` for any non-message kind.
fn is_wkt_wrapper(field: &FieldDescriptor) -> bool {
    const WRAPPER_TYPE_NAMES: [&str; 9] = [
        "google.protobuf.BoolValue",
        "google.protobuf.BytesValue",
        "google.protobuf.DoubleValue",
        "google.protobuf.FloatValue",
        "google.protobuf.Int32Value",
        "google.protobuf.Int64Value",
        "google.protobuf.StringValue",
        "google.protobuf.UInt32Value",
        "google.protobuf.UInt64Value",
    ];
    match field.kind().as_message() {
        Some(message) => WRAPPER_TYPE_NAMES.contains(&message.full_name()),
        None => false,
    }
}
/// Generates the `any.in` / `any.not_in` checks for an `Any` rule set.
///
/// Each generated check runs only when the field is set, and compares the payload's
/// `type_url` against the configured list. An empty list produces no check. The
/// allow-list check, when present, comes before the block-list check.
fn generate_any_rules(
    r: &prost_protovalidate_types::AnyRules,
    field_ident: &Ident,
    proto_name: &str,
) -> Vec<TokenStream> {
    let allowed = &r.r#in;
    let blocked = &r.not_in;

    let allow_check = (!allowed.is_empty()).then(|| {
        quote! {
            if let ::core::option::Option::Some(ref _any) = self.#field_ident {
                if ![#(#allowed),*].contains(&_any.type_url.as_str()) {
                    violations.push(::prost_protovalidate::Violation::new(
                        #proto_name, "any.in", "type_url must be in the allow list",
                    ));
                }
            }
        }
    });

    let block_check = (!blocked.is_empty()).then(|| {
        quote! {
            if let ::core::option::Option::Some(ref _any) = self.#field_ident {
                if [#(#blocked),*].contains(&_any.type_url.as_str()) {
                    violations.push(::prost_protovalidate::Violation::new(
                        #proto_name, "any.not_in", "type_url must not be in the block list",
                    ));
                }
            }
        }
    });

    allow_check.into_iter().chain(block_check).collect()
}