use proc_macro2::TokenStream;
use quote::quote;
struct SortVariant {
ident: syn::Ident,
sql_fragment: String,
}
pub fn expand_sort_enum(_attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
let input: syn::ItemEnum = syn::parse2(item)?;
for variant in &input.variants {
if !matches!(variant.fields, syn::Fields::Unit) {
return Err(syn::Error::new_spanned(
variant,
"sort enum only supports unit variants (no fields)",
));
}
}
if input.variants.is_empty() {
return Err(syn::Error::new_spanned(
&input,
"sort enum requires at least one variant",
));
}
let variants = parse_sort_variants(&input)?;
let enum_name = &input.ident;
let vis = &input.vis;
let enum_attrs: Vec<_> = input.attrs.iter().collect();
let variant_defs = input.variants.iter().map(|v| {
let attrs: Vec<_> = v
.attrs
.iter()
.filter(|a| !a.path().is_ident("sql"))
.collect();
let ident = &v.ident;
quote! { #(#attrs)* #ident }
});
let sql_arms: Vec<TokenStream> = variants
.iter()
.map(|v| {
let ident = &v.ident;
let sql_fragment = &v.sql_fragment;
quote! { #enum_name::#ident => #sql_fragment }
})
.collect();
let display_arms: Vec<TokenStream> = variants
.iter()
.map(|v| {
let ident = &v.ident;
let sql_fragment = &v.sql_fragment;
quote! { #enum_name::#ident => #sql_fragment, }
})
.collect();
Ok(quote! {
#(#enum_attrs)*
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#vis enum #enum_name {
#(#variant_defs,)*
}
impl #enum_name {
pub fn sql(&self) -> &'static str {
match self {
#(#sql_arms,)*
}
}
}
impl ::std::fmt::Display for #enum_name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
let label = match self {
#(#display_arms)*
};
f.write_str(label)
}
}
})
}
fn parse_sort_variants(input: &syn::ItemEnum) -> Result<Vec<SortVariant>, syn::Error> {
let mut variants = Vec::with_capacity(input.variants.len());
for variant in &input.variants {
let sql_fragment = extract_sql_attr(variant)?;
validate_sort_fragment(&sql_fragment, variant)?;
variants.push(SortVariant {
ident: variant.ident.clone(),
sql_fragment,
});
}
for (i, a) in variants.iter().enumerate() {
for b in variants.iter().skip(i + 1) {
if a.sql_fragment == b.sql_fragment {
return Err(syn::Error::new_spanned(
&input.variants[i],
format!(
"duplicate SQL fragment \"{}\" on variants `{}` and `{}`",
a.sql_fragment, a.ident, b.ident
),
));
}
}
}
Ok(variants)
}
fn validate_sort_fragment(fragment: &str, variant: &syn::Variant) -> Result<(), syn::Error> {
if fragment.contains(';') {
return Err(syn::Error::new_spanned(
variant,
format!(
"sort SQL fragment contains a semicolon, which is not allowed: \"{}\"",
fragment
),
));
}
if fragment.contains("--") || fragment.contains("/*") {
return Err(syn::Error::new_spanned(
variant,
format!(
"sort SQL fragment contains a SQL comment, which is not allowed: \"{}\"",
fragment
),
));
}
const DANGEROUS_KEYWORDS: &[&str] = &[
"DROP", "DELETE", "INSERT", "UPDATE", "CREATE", "ALTER", "TRUNCATE", "GRANT", "REVOKE",
"EXECUTE", "COPY",
];
let upper = fragment.to_ascii_uppercase();
for keyword in DANGEROUS_KEYWORDS {
for (pos, _) in upper.match_indices(keyword) {
let before_ok = pos == 0
|| !upper.as_bytes()[pos - 1].is_ascii_alphanumeric()
&& upper.as_bytes()[pos - 1] != b'_';
let after_pos = pos + keyword.len();
let after_ok = after_pos >= upper.len()
|| !upper.as_bytes()[after_pos].is_ascii_alphanumeric()
&& upper.as_bytes()[after_pos] != b'_';
if before_ok && after_ok {
return Err(syn::Error::new_spanned(
variant,
format!(
"sort SQL fragment contains disallowed keyword `{keyword}`: \"{fragment}\". \
Sort fragments must contain only ORDER BY expressions (column references, \
ASC/DESC, NULLS FIRST/LAST)."
),
));
}
}
}
Ok(())
}
fn extract_sql_attr(variant: &syn::Variant) -> Result<String, syn::Error> {
for attr in &variant.attrs {
if attr.path().is_ident("sql") {
let label: syn::LitStr = attr.parse_args()?;
let value = label.value();
if value.is_empty() {
return Err(syn::Error::new_spanned(
attr,
"SQL fragment cannot be empty",
));
}
return Ok(value);
}
}
Err(syn::Error::new_spanned(
variant,
format!(
"variant `{}` is missing #[sql(\"...\")] attribute",
variant.ident
),
))
}
#[cfg(test)]
mod tests {
use super::*;
fn parse_sort(tokens: TokenStream) -> TokenStream {
expand_sort_enum(TokenStream::new(), tokens).expect("should parse")
}
#[test]
fn basic_sort_enum_generates_code() {
let input = quote! {
enum TicketSort {
#[sql("t.updated_at DESC, t.id DESC")]
UpdatedAt,
#[sql("t.deadline ASC NULLS LAST, t.id ASC")]
Deadline,
#[sql("t.id DESC")]
Id,
}
};
let output = parse_sort(input);
let code = output.to_string();
assert!(code.contains("enum TicketSort"), "missing enum: {code}");
assert!(code.contains("fn sql"), "missing sql method: {code}");
assert!(
code.contains("t.updated_at DESC"),
"missing sql fragment: {code}"
);
assert!(code.contains("Display"), "missing Display: {code}");
assert!(code.contains("Debug"), "missing Debug derive: {code}");
assert!(code.contains("Clone"), "missing Clone derive: {code}");
assert!(code.contains("Copy"), "missing Copy derive: {code}");
}
#[test]
fn missing_sql_attr_errors() {
let input = quote! {
enum Sort {
#[sql("a DESC")]
A,
B,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("missing #[sql"), "unexpected error: {err}");
}
#[test]
fn empty_sql_fragment_errors() {
let input = quote! {
enum Sort {
#[sql("")]
A,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("cannot be empty"), "unexpected error: {err}");
}
#[test]
fn non_unit_variant_errors() {
let input = quote! {
enum Sort {
#[sql("a")]
A(i32),
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("unit variants"), "unexpected error: {err}");
}
#[test]
fn empty_enum_errors() {
let input = quote! {
enum Sort {}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
}
#[test]
fn duplicate_sql_fragment_errors() {
let input = quote! {
enum Sort {
#[sql("a DESC")]
A,
#[sql("a DESC")]
B,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("duplicate SQL fragment"),
"unexpected error: {err}"
);
}
#[test]
fn visibility_preserved() {
let input = quote! {
pub enum Sort {
#[sql("a")]
A,
}
};
let output = parse_sort(input);
let code = output.to_string();
assert!(code.contains("pub enum Sort"), "visibility lost: {code}");
}
#[test]
fn no_encode_impl_generated() {
let input = quote! {
enum Sort {
#[sql("a")]
A,
}
};
let output = parse_sort(input);
let code = output.to_string();
assert!(
!code.contains("Encode"),
"sort enum should not have Encode: {code}"
);
}
#[test]
fn single_variant() {
let input = quote! {
enum Sort {
#[sql("id ASC")]
Id,
}
};
let output = parse_sort(input);
let code = output.to_string();
assert!(code.contains("id ASC"), "missing fragment: {code}");
}
#[test]
fn struct_not_accepted() {
let input = quote! {
struct NotAnEnum {
field: i32,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err(), "structs should be rejected");
}
#[test]
fn sort_fragment_with_semicolon_rejected() {
let input = quote! {
enum Sort {
#[sql("id ASC; DROP TABLE users")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("semicolon"), "should mention semicolon: {err}");
}
#[test]
fn sort_fragment_with_comment_rejected() {
let input = quote! {
enum Sort {
#[sql("id ASC -- sneaky")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("comment"), "should mention comment: {err}");
}
#[test]
fn sort_fragment_with_block_comment_rejected() {
let input = quote! {
enum Sort {
#[sql("id ASC /* hidden */")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("comment"), "should mention comment: {err}");
}
#[test]
fn sort_fragment_with_drop_rejected() {
let input = quote! {
enum Sort {
#[sql("(SELECT DROP FROM evil)")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("DROP"), "should mention DROP: {err}");
}
#[test]
fn sort_fragment_with_delete_rejected() {
let input = quote! {
enum Sort {
#[sql("DELETE FROM users")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("DELETE"), "should mention DELETE: {err}");
}
#[test]
fn sort_fragment_keyword_in_column_name_accepted() {
let input = quote! {
enum Sort {
#[sql("updated_at DESC")]
UpdatedAt,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(
result.is_ok(),
"column name containing keyword substring should be OK"
);
}
#[test]
fn sort_fragment_valid_complex_accepted() {
let input = quote! {
enum Sort {
#[sql("t.deadline ASC NULLS LAST, t.id ASC")]
Deadline,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(
result.is_ok(),
"valid complex sort fragment should be accepted"
);
}
#[test]
fn sort_fragment_created_at_desc_nulls_last() {
let input = quote! {
enum Sort {
#[sql("created_at DESC NULLS LAST")]
CreatedAt,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(
result.is_ok(),
"created_at DESC NULLS LAST should be accepted"
);
}
#[test]
fn sort_fragment_updated_at_asc_not_confused_with_update() {
let input = quote! {
enum Sort {
#[sql("updated_at ASC")]
UpdatedAt,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(
result.is_ok(),
"updated_at should not be confused with UPDATE keyword"
);
}
#[test]
fn sort_fragment_insert_rejected() {
let input = quote! {
enum Sort {
#[sql("INSERT INTO t VALUES (1)")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("INSERT"), "should mention INSERT: {err}");
}
#[test]
fn sort_fragment_standalone_update_rejected() {
let input = quote! {
enum Sort {
#[sql("UPDATE t SET x = 1")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("UPDATE"), "should mention UPDATE: {err}");
}
#[test]
fn sort_fragment_truncate_rejected() {
let input = quote! {
enum Sort {
#[sql("TRUNCATE users")]
Bad,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_err());
}
#[test]
fn sort_fragment_execute_substring_in_column_accepted() {
let input = quote! {
enum Sort {
#[sql("executed_at DESC")]
ExecutedAt,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(
result.is_ok(),
"executed_at should be accepted (EXECUTE is substring, not word)"
);
}
#[test]
fn sort_fragment_created_substring_accepted() {
let input = quote! {
enum Sort {
#[sql("created_by ASC")]
CreatedBy,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(result.is_ok(), "created_by should be accepted");
}
#[test]
fn sort_fragment_deleted_at_accepted() {
let input = quote! {
enum Sort {
#[sql("deleted_at DESC NULLS LAST")]
DeletedAt,
}
};
let result = expand_sort_enum(TokenStream::new(), input);
assert!(
result.is_ok(),
"deleted_at should be accepted (DELETE is substring, not standalone word)"
);
}
}