use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Field, Fields, GenericArgument, PathArguments, Type};
use crate::utils::{determine_visibility, get_holy_string_value};
/// One sanitization step parsed from a field's `sanitize` attribute value.
///
/// A field's rules are kept in the order they were written and are expanded in
/// that same order.
enum SanitizeRule {
    // ---- Rules accepted on `String` / `Option<String>` fields ----
    /// `trim`: remove leading and trailing whitespace.
    Trim,
    /// `lowercase`: convert the string to lowercase.
    Lowercase,
    /// `uppercase`: convert the string to uppercase.
    Uppercase,
    /// `truncate(n)`: keep at most `n` bytes, backing off to a UTF-8 char boundary.
    Truncate(usize),
    /// `alphanumeric`: keep only alphanumeric characters.
    Alphanumeric,
    /// `escape_html`: the HTML-escaping rule.
    EscapeHtml,
    /// `nul_strip`: remove NUL (`\0`) characters.
    NulStrip,
    /// `control_strip`: remove control, bidi-override/isolate and zero-width characters.
    ControlStrip,
    /// `slug`: turn the string into a lowercase, dash-separated ASCII slug.
    Slug,

    // ---- Rules accepted on numeric fields ----
    /// `clamp(min, max)`: clamp the value into `[min, max]`. The bounds are kept as
    /// raw token streams and spliced verbatim into the generated code.
    Clamp(proc_macro2::TokenStream, proc_macro2::TokenStream),
}
/// Coarse classification of a field's declared type.
///
/// Used to decide which sanitize rules a field may carry and how the generated
/// code reaches the value.
enum FieldTypeKind {
    /// A path type whose last segment is `String`.
    String,
    /// `Option<String>`; rules are applied to the inner value only when it is `Some`.
    OptionString,
    /// A primitive integer or floating-point type.
    Numeric,
    /// Any other type; no sanitize rule accepts it.
    Other,
}
/// Returns `true` when `ty` is a path type whose final segment is `String`.
///
/// Only the last segment is checked, so `std::string::String` also matches.
/// Generic arguments on that segment are ignored.
fn is_string_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "String"),
        _ => false,
    }
}
/// Classifies a field type by the last segment of its path.
///
/// - `String` → `FieldTypeKind::String`
/// - `Option<T>` whose first generic argument is a `String` path → `FieldTypeKind::OptionString`
/// - primitive integer/float names, including `i128`/`u128` → `FieldTypeKind::Numeric`
/// - everything else (references, tuples, other `Option`s, user types) → `FieldTypeKind::Other`
///
/// Only the last path segment is examined, so qualified paths such as
/// `std::string::String` are recognized. Type aliases are not resolved.
fn classify_type(ty: &Type) -> FieldTypeKind {
    let Type::Path(type_path) = ty else {
        return FieldTypeKind::Other;
    };
    let Some(segment) = type_path.path.segments.last() else {
        return FieldTypeKind::Other;
    };
    let ident = segment.ident.to_string();

    if ident == "Option" {
        // Only `Option<String>` gets special handling; any other payload is `Other`.
        let wraps_string = match &segment.arguments {
            PathArguments::AngleBracketed(args) => matches!(
                args.args.first(),
                Some(GenericArgument::Type(inner)) if is_string_type(inner)
            ),
            _ => false,
        };
        return if wraps_string {
            FieldTypeKind::OptionString
        } else {
            FieldTypeKind::Other
        };
    }

    match ident.as_str() {
        "String" => FieldTypeKind::String,
        // 128-bit integers support `clamp` just like the narrower primitives.
        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
        | "u128" | "usize" | "f32" | "f64" => FieldTypeKind::Numeric,
        _ => FieldTypeKind::Other,
    }
}
/// Splits a comma-separated rule list into individual rule strings.
///
/// Commas nested inside parentheses (e.g. the arguments of `clamp(0, 10)`) do
/// not split. A stray `)` never drives the nesting depth below zero. Every piece
/// is trimmed, and empty pieces are dropped.
fn split_rules(input: &str) -> Vec<String> {
    /// Pushes the trimmed `segment` onto `out` unless it is blank.
    fn flush(segment: &str, out: &mut Vec<String>) {
        let piece = segment.trim();
        if !piece.is_empty() {
            out.push(piece.to_owned());
        }
    }

    let mut out = Vec::new();
    let mut start = 0usize;
    let mut nesting = 0u32;
    for (idx, c) in input.char_indices() {
        if c == '(' {
            nesting += 1;
        } else if c == ')' {
            nesting = nesting.saturating_sub(1);
        } else if c == ',' && nesting == 0 {
            flush(&input[start..idx], &mut out);
            start = idx + c.len_utf8();
        }
    }
    flush(&input[start..], &mut out);
    out
}
fn parse_sanitize_rules(
raw: &str,
span: proc_macro2::Span,
) -> Result<Vec<SanitizeRule>, syn::Error> {
let tokens = split_rules(raw);
let mut rules = Vec::new();
for token in &tokens {
let rule = if token == "trim" {
SanitizeRule::Trim
} else if token == "lowercase" {
SanitizeRule::Lowercase
} else if token == "uppercase" {
SanitizeRule::Uppercase
} else if token == "alphanumeric" {
SanitizeRule::Alphanumeric
} else if token == "escape_html" {
SanitizeRule::EscapeHtml
} else if token == "nul_strip" {
SanitizeRule::NulStrip
} else if token == "control_strip" {
SanitizeRule::ControlStrip
} else if token == "slug" {
SanitizeRule::Slug
} else if let Some(inner) = token
.strip_prefix("truncate(")
.and_then(|s| s.strip_suffix(')'))
{
let n: usize = inner.trim().parse().map_err(|_| {
syn::Error::new(span, format!("invalid truncate length: '{}'", inner.trim()))
})?;
SanitizeRule::Truncate(n)
} else if let Some(inner) = token
.strip_prefix("clamp(")
.and_then(|s| s.strip_suffix(')'))
{
let parts: Vec<&str> = inner.splitn(2, ',').collect();
if parts.len() != 2 {
return Err(syn::Error::new(
span,
format!(
"clamp requires two arguments: clamp(min,max), got '{}'",
token
),
));
}
let min_raw = parts[0].trim();
let max_raw = parts[1].trim();
let min_ts: proc_macro2::TokenStream = min_raw.parse().map_err(|_| {
syn::Error::new(span, format!("invalid clamp min argument: '{}'", min_raw))
})?;
let max_ts: proc_macro2::TokenStream = max_raw.parse().map_err(|_| {
syn::Error::new(span, format!("invalid clamp max argument: '{}'", max_raw))
})?;
SanitizeRule::Clamp(min_ts, max_ts)
} else {
return Err(syn::Error::new(
span,
format!("unknown sanitize rule: '{}'", token),
));
};
rules.push(rule);
}
Ok(rules)
}
fn validate_rule_for_type(
rule: &SanitizeRule,
type_kind: &FieldTypeKind,
field_name: &syn::Ident,
span: proc_macro2::Span,
) -> Result<(), syn::Error> {
match rule {
SanitizeRule::Trim
| SanitizeRule::Lowercase
| SanitizeRule::Uppercase
| SanitizeRule::Truncate(_)
| SanitizeRule::Alphanumeric
| SanitizeRule::EscapeHtml
| SanitizeRule::NulStrip
| SanitizeRule::ControlStrip
| SanitizeRule::Slug => {
if !matches!(
type_kind,
FieldTypeKind::String | FieldTypeKind::OptionString
) {
let rule_name = match rule {
SanitizeRule::Trim => "trim",
SanitizeRule::Lowercase => "lowercase",
SanitizeRule::Uppercase => "uppercase",
SanitizeRule::Truncate(_) => "truncate",
SanitizeRule::Alphanumeric => "alphanumeric",
SanitizeRule::EscapeHtml => "escape_html",
SanitizeRule::NulStrip => "nul_strip",
SanitizeRule::ControlStrip => "control_strip",
SanitizeRule::Slug => "slug",
_ => unreachable!(),
};
return Err(syn::Error::new(
span,
format!(
"sanitize rule '{}' is only valid for String fields, but field '{}' has a numeric type",
rule_name, field_name
),
));
}
}
SanitizeRule::Clamp(_, _) => {
if !matches!(type_kind, FieldTypeKind::Numeric) {
return Err(syn::Error::new(
span,
format!(
"sanitize rule 'clamp' is only valid for numeric fields, but field '{}' has type String",
field_name
),
));
}
}
}
Ok(())
}
fn rule_to_tokens(
access: &proc_macro2::TokenStream,
rule: &SanitizeRule,
) -> proc_macro2::TokenStream {
match rule {
SanitizeRule::Trim => quote! {
#access = #access.trim().to_string();
},
SanitizeRule::Lowercase => quote! {
#access = #access.to_lowercase();
},
SanitizeRule::Uppercase => quote! {
#access = #access.to_uppercase();
},
SanitizeRule::Truncate(n) => quote! {
if #access.len() > #n {
let mut __end = #n;
while __end > 0 && !#access.is_char_boundary(__end) {
__end -= 1;
}
#access.truncate(__end);
}
},
SanitizeRule::Alphanumeric => quote! {
#access = #access.chars().filter(|c| c.is_alphanumeric()).collect();
},
SanitizeRule::EscapeHtml => quote! {
#access = #access
.replace('&', "&")
.replace('<', "<")
.replace('>', ">")
.replace('"', """)
.replace('\'', "'");
},
SanitizeRule::NulStrip => quote! {
if #access.contains('\0') {
#access = #access.replace('\0', "");
}
},
SanitizeRule::ControlStrip => quote! {
#access = #access
.chars()
.filter(|c| {
let cp = *c as u32;
!c.is_control()
&& !(0x202A..=0x202E).contains(&cp)
&& !(0x2066..=0x2069).contains(&cp)
&& !matches!(*c, '\u{200B}'..='\u{200D}' | '\u{FEFF}')
})
.collect();
},
SanitizeRule::Slug => quote! {
#access = {
let lower = #access.to_lowercase();
let mut out = String::with_capacity(lower.len());
let mut last_dash = false;
for ch in lower.chars() {
if ch.is_ascii_alphanumeric() {
out.push(ch);
last_dash = false;
} else if !last_dash {
out.push('-');
last_dash = true;
}
}
out.trim_matches('-').to_string()
};
},
SanitizeRule::Clamp(min, max) => quote! {
#access = #access.clamp(#min, #max);
},
}
}
/// Builds the body of the generated `sanitize_<field>` method for one field.
///
/// Returns `Ok(None)` when the field has no `sanitize` attribute. Otherwise the
/// attribute's rules are parsed, checked against the field's type, and expanded
/// in order. For `Option<String>` fields the statements are wrapped in
/// `if let Some(__s) = self.<field>.as_mut() { ... }`, so a `None` is left untouched.
///
/// # Panics
/// Panics if `field` has no identifier. The caller in this file only passes
/// fields of structs with named fields.
///
/// # Errors
/// Propagates rule-parsing and type-validation errors for the attribute.
fn process_field(
    field: &Field,
) -> Result<Option<(syn::Ident, proc_macro2::TokenStream)>, syn::Error> {
    let (raw_rules, span) = match get_holy_string_value(&field.attrs, "sanitize") {
        Some(found) => found,
        None => return Ok(None),
    };
    let field_name = field.ident.clone().unwrap();
    let type_kind = classify_type(&field.ty);
    let rules = parse_sanitize_rules(&raw_rules, span)?;
    rules
        .iter()
        .try_for_each(|rule| validate_rule_for_type(rule, &type_kind, &field_name, span))?;

    let wraps_option = matches!(type_kind, FieldTypeKind::OptionString);
    let access = if wraps_option {
        quote! { (*__s) }
    } else {
        quote! { self.#field_name }
    };
    let statements: Vec<proc_macro2::TokenStream> = rules
        .iter()
        .map(|rule| rule_to_tokens(&access, rule))
        .collect();

    let body = if wraps_option {
        quote! {
            if let Some(__s) = self.#field_name.as_mut() {
                #(#statements)*
            }
        }
    } else {
        quote! { #(#statements)* }
    };
    Ok(Some((field_name, body)))
}
pub fn impl_sanitize_macro(ast: &DeriveInput) -> Result<TokenStream, syn::Error> {
let struct_name = &ast.ident;
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
let fields = match &ast.data {
Data::Struct(data) => match &data.fields {
Fields::Named(named) => &named.named,
_ => {
return Err(syn::Error::new_spanned(
ast,
"Sanitize macro only supports structs with named fields",
));
}
},
_ => {
return Err(syn::Error::new_spanned(
ast,
"Sanitize macro only supports structs",
));
}
};
let mut per_field_methods = Vec::new();
let mut all_field_calls = Vec::new();
for field in fields.iter() {
let Some((field_name, body)) = process_field(field)? else {
continue;
};
let sanitize_method_name =
syn::Ident::new(&format!("sanitize_{}", field_name), field_name.span());
let method_vis = determine_visibility(&field.vis, &field.attrs)?;
per_field_methods.push(quote! {
#method_vis fn #sanitize_method_name(&mut self) {
#body
}
});
all_field_calls.push(quote! {
self.#sanitize_method_name();
});
}
if per_field_methods.is_empty() {
return Ok(TokenStream::from(quote! {}));
}
let expanded = quote! {
impl #impl_generics #struct_name #ty_generics #where_clause {
pub fn sanitize(&mut self) {
#(#all_field_calls)*
}
#(#per_field_methods)*
}
};
Ok(TokenStream::from(expanded))
}