use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Field, Fields};
/// Derives `dnf::DnfEvaluable` for a struct with named fields.
///
/// The generated impl provides:
/// - `evaluate_field`: matches each directly comparable field by its query name
///   (honouring `#[dnf(rename = "...")]`). For a dotted path `outer.inner` with no
///   direct match, it forwards `inner` to the nested field whose query name is `outer`.
/// - `field_value`: returns the field as a `dnf::Value` for value-convertible types.
/// - `fields`: one `dnf::FieldInfo` per non-skipped field.
/// - `validate_field_path`: resolves a (possibly dotted) path to a `dnf::FieldKind`.
///
/// Enums, unions, tuple structs and unit structs are rejected with a compile error.
#[proc_macro_derive(DnfEvaluable, attributes(dnf))]
pub fn derive_dnf_evaluable(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    // Forward the struct's lifetimes, type parameters and where-clause to the impl.
    // Without this, e.g. `struct User<'a> { name: &'a str }` (a field type the helpers
    // explicitly support) expands to an impl that does not compile.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let fields = match &input.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => &fields.named,
            _ => {
                return syn::Error::new_spanned(
                    &input,
                    "DnfEvaluable can only be derived for structs with named fields",
                )
                .to_compile_error()
                .into();
            }
        },
        _ => {
            return syn::Error::new_spanned(&input, "DnfEvaluable can only be derived for structs")
                .to_compile_error()
                .into();
        }
    };
    let match_arms = fields.iter().filter_map(generate_field_match_arm);
    let nested_match_arms = fields.iter().filter_map(generate_nested_field_match_arm);
    let field_infos = fields.iter().filter_map(generate_field_info);
    let field_value_arms = fields.iter().filter_map(generate_field_value_arm);
    let validate_path_arms = fields.iter().filter_map(generate_validate_field_path_arm);
    let expanded = quote! {
        impl #impl_generics dnf::DnfEvaluable for #name #ty_generics #where_clause {
            fn evaluate_field(
                &self,
                field_name: &str,
                operator: &dnf::Op,
                value: &dnf::Value
            ) -> bool {
                match field_name {
                    #(#match_arms)*
                    _ => {
                        if let Some(dot_pos) = field_name.find('.') {
                            let (outer, inner) = field_name.split_at(dot_pos);
                            let inner = &inner[1..];
                            match outer {
                                #(#nested_match_arms)*
                                _ => false,
                            }
                        } else {
                            false
                        }
                    }
                }
            }
            fn field_value(&self, field_name: &str) -> Option<dnf::Value> {
                match field_name {
                    #(#field_value_arms)*
                    _ => None,
                }
            }
            fn fields() -> impl Iterator<Item = dnf::FieldInfo> {
                [
                    #(#field_infos),*
                ].into_iter()
            }
            fn validate_field_path(path: &str) -> Option<dnf::FieldKind> {
                if let Some(dot) = path.find('.') {
                    let (head, tail) = path.split_at(dot);
                    let tail = &tail[1..];
                    match head {
                        #(#validate_path_arms)*
                        _ => {
                            // `tail` is only read by the generated arms; silence the
                            // unused-variable lint when there are none.
                            let _ = tail;
                            <Self as dnf::DnfEvaluable>::fields()
                                .find(|f| f.name() == head)
                                .map(|f| f.kind())
                        }
                    }
                } else {
                    <Self as dnf::DnfEvaluable>::fields()
                        .find(|f| f.name() == path)
                        .map(|f| f.kind())
                }
            }
        }
    };
    TokenStream::from(expanded)
}
/// Builds the `evaluate_field` match arm for a directly comparable field.
///
/// Returns `None` for:
/// - fields without an identifier;
/// - fields marked `#[dnf(skip)]`;
/// - fields reached through dotted paths instead. These are fields marked
///   `#[dnf(nested)]`, and collections of custom structs per `is_nested_type`
///   (unless `#[dnf(iter)]` is present).
///
/// Otherwise emits `"<query name>" => <comparison>,`, where the query name honours
/// `#[dnf(rename = "...")]`.
fn generate_field_match_arm(field: &Field) -> Option<proc_macro2::TokenStream> {
    let ident = field.ident.as_ref()?;
    if has_skip_attribute(field) {
        return None;
    }

    let ty = &field.ty;
    let compact_ty = quote!(#ty).to_string().replace(' ', "");
    let delegated_elsewhere = has_nested_attribute(field)
        || (get_iter_attribute(field).is_none() && is_nested_type(&compact_ty));
    if delegated_elsewhere {
        return None;
    }

    let query_name = get_rename_attribute(field).unwrap_or_else(|| ident.to_string());
    let comparison = generate_value_conversion(field, ident, ty);
    Some(quote! {
        #query_name => #comparison,
    })
}
/// Builds the `field_value` match arm that exposes a field as a `dnf::Value`.
///
/// The same fields as in `generate_field_match_arm` are excluded (skipped or
/// path-delegated), as well as any type `is_value_convertible` rejects.
/// - An `Option<T>` field yields `Some(Value::from(&t))` when set and
///   `Some(Value::None)` when unset.
/// - Any other field yields `Some(Value::from(&field))`.
fn generate_field_value_arm(field: &Field) -> Option<proc_macro2::TokenStream> {
    let ident = field.ident.as_ref()?;
    if has_skip_attribute(field) {
        return None;
    }

    let ty = &field.ty;
    let compact_ty = quote!(#ty).to_string().replace(' ', "");
    let delegated_elsewhere = has_nested_attribute(field)
        || (get_iter_attribute(field).is_none() && is_nested_type(&compact_ty));
    if delegated_elsewhere || !is_value_convertible(&compact_ty) {
        return None;
    }

    let query_name = get_rename_attribute(field).unwrap_or_else(|| ident.to_string());
    let is_optional = compact_ty.starts_with("Option<");
    let extraction = if is_optional {
        quote! {
            match &self.#ident {
                Some(v) => Some(dnf::Value::from(v)),
                None => Some(dnf::Value::None),
            }
        }
    } else {
        quote! {
            Some(dnf::Value::from(&self.#ident))
        }
    };
    Some(quote! {
        #query_name => #extraction,
    })
}
/// Reports whether a field of this type can be converted into a `dnf::Value`.
///
/// `type_str` is the field type with all whitespace removed (e.g. `&'a str` arrives
/// as `&'astr`). Accepted:
/// - integer and float primitives, `bool` and `String`;
/// - `&str` / `&'a str` / `&mut str`, `Cow<str>` / `Cow<'a, str>` and `Box<str>`;
/// - `Vec<P>` and `HashSet<P>` of those primitives (no floats in a `HashSet`);
/// - `Option<T>` of anything accepted above.
fn is_value_convertible(type_str: &str) -> bool {
    const PRIMITIVES: [&str; 14] = [
        "i8", "i16", "i32", "i64", "isize", "u8", "u16", "u32", "u64", "usize", "f32", "f64",
        "bool", "String",
    ];

    /// Matches borrowed/boxed string-slice types in their whitespace-free spelling.
    ///
    /// This deliberately does not use a substring search for "str". That would
    /// wrongly accept types such as `&Constraint`, `&Registry` or `&Instrument`.
    fn is_str_wrapper(t: &str) -> bool {
        if let Some(rest) = t.strip_prefix('&') {
            // `&'a str` normalises to `&'astr`: the lifetime is glued onto `str`.
            return rest == "str"
                || rest == "mutstr"
                || (rest.starts_with('\'') && rest.ends_with("str"));
        }
        if let Some(arg) = t.strip_prefix("Cow<").and_then(|s| s.strip_suffix('>')) {
            return arg == "str" || (arg.starts_with('\'') && arg.ends_with(",str"));
        }
        t == "Box<str>"
    }

    if PRIMITIVES.contains(&type_str) || is_str_wrapper(type_str) {
        return true;
    }
    if let Some(elem) = type_str.strip_prefix("Vec<").and_then(|s| s.strip_suffix('>')) {
        return PRIMITIVES.contains(&elem);
    }
    if let Some(elem) = type_str
        .strip_prefix("HashSet<")
        .and_then(|s| s.strip_suffix('>'))
    {
        // f32/f64 are not `Eq + Hash`, so they are excluded as set elements.
        return PRIMITIVES.contains(&elem) && elem != "f32" && elem != "f64";
    }
    if let Some(payload) = type_str
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix('>'))
    {
        return is_value_convertible(payload);
    }
    false
}
fn is_nested_type(type_str: &str) -> bool {
if type_str.starts_with("Vec<") {
if let Some(inner) = type_str
.strip_prefix("Vec<")
.and_then(|s| s.strip_suffix(">"))
{
return !is_primitive_or_builtin(inner);
}
}
if type_str.starts_with("Option<Vec<") {
if let Some(inner) = type_str
.strip_prefix("Option<Vec<")
.and_then(|s| s.strip_suffix(">>"))
{
return !is_primitive_or_builtin(inner);
}
}
if is_map_type(type_str) {
if let Some((_, value_type)) = extract_map_types(type_str) {
return !is_primitive_or_builtin(&value_type);
}
}
if type_str.starts_with("Option<HashMap<") || type_str.starts_with("Option<BTreeMap<") {
if let Some(inner) = type_str
.strip_prefix("Option<")
.and_then(|s| s.strip_suffix(">"))
{
if let Some((_, value_type)) = extract_map_types(inner) {
return !is_primitive_or_builtin(&value_type);
}
}
}
false
}
/// Returns `true` when the whitespace-free type string names an unqualified
/// `HashMap<..>` or `BTreeMap<..>` (not one wrapped in `Option`).
fn is_map_type(type_str: &str) -> bool {
    const MAP_PREFIXES: [&str; 2] = ["HashMap<", "BTreeMap<"];
    MAP_PREFIXES
        .iter()
        .any(|prefix| type_str.starts_with(prefix))
}
/// Splits `HashMap<K, V>` / `BTreeMap<K, V>` into its trimmed `(K, V)` strings.
///
/// The split happens at the first comma that is not nested inside `<>`, `()` or `[]`.
/// Tuple, array and fn-pointer keys therefore stay intact, e.g.
/// `HashMap<(String,u32),i32>` yields `("(String,u32)", "i32")`.
/// Returns `None` if the string is not a map type or has no top-level comma.
fn extract_map_types(type_str: &str) -> Option<(String, String)> {
    let args = type_str
        .strip_prefix("HashMap<")
        .or_else(|| type_str.strip_prefix("BTreeMap<"))?
        .strip_suffix('>')?;

    let mut depth: isize = 0;
    let mut prev = '\0';
    let mut split = None;
    for (i, c) in args.char_indices() {
        match c {
            '<' | '(' | '[' => depth += 1,
            // The `>` of a `->` return arrow (e.g. `fn(u8) -> u8`) closes nothing.
            '>' if prev == '-' => {}
            '>' | ')' | ']' => depth -= 1,
            ',' if depth == 0 => {
                split = Some(i);
                break;
            }
            _ => {}
        }
        prev = c;
    }

    let split = split?;
    let key = args[..split].trim().to_string();
    let value = args[split + 1..].trim().to_string();
    Some((key, value))
}
/// Returns `true` if a map key type is string-like: `String`, `str`, `&str`, or a
/// lifetime-qualified reference ending in `str` (e.g. `&'a str` / `&'astr`).
///
/// Surrounding whitespace in `key_type` is ignored.
fn is_string_key(key_type: &str) -> bool {
    let key = key_type.trim();
    if key == "String" || key == "str" || key == "&str" {
        return true;
    }
    // Any `&'<lifetime> ...str` spelling, with or without a space before `str`.
    key.starts_with("&'") && key.ends_with("str")
}
fn has_skip_attribute(field: &Field) -> bool {
for attr in &field.attrs {
if attr.path().is_ident("dnf") {
let mut has_skip = false;
let _ = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("skip") {
has_skip = true;
}
Ok(())
});
if has_skip {
return true;
}
}
}
false
}
fn has_nested_attribute(field: &Field) -> bool {
for attr in &field.attrs {
if attr.path().is_ident("dnf") {
let mut has_nested = false;
let _ = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("nested") {
has_nested = true;
}
Ok(())
});
if has_nested {
return true;
}
}
}
false
}
fn generate_nested_field_match_arm(field: &Field) -> Option<proc_macro2::TokenStream> {
let field_name = field.ident.as_ref()?;
let field_type = &field.ty;
if has_skip_attribute(field) {
return None;
}
let type_str = quote!(#field_type).to_string().replace(" ", "");
let has_iter = get_iter_attribute(field).is_some();
if has_iter {
return None;
}
if !has_nested_attribute(field) && !is_nested_type(&type_str) {
return None;
}
let query_name = get_rename_attribute(field).unwrap_or_else(|| field_name.to_string());
let delegation_code = if type_str.starts_with("Vec<") {
quote! {
self.#field_name.iter().any(|item| item.evaluate_field(inner, operator, value))
}
} else if type_str.starts_with("Option<Vec<") {
quote! {
match &self.#field_name {
Some(vec) => vec.iter().any(|item| item.evaluate_field(inner, operator, value)),
None => false,
}
}
} else if type_str.starts_with("HashMap<") || type_str.starts_with("BTreeMap<") {
quote! {
if let Some(rest) = inner.strip_prefix("@values.") {
self.#field_name.values().any(|item| item.evaluate_field(rest, operator, value))
} else if inner == "@keys" {
operator.any(self.#field_name.keys(), value)
} else if inner.starts_with("[\"") {
if let Some(end_bracket) = inner.find("\"]") {
let key = &inner[2..end_bracket];
let rest = inner.get(end_bracket + 2..).unwrap_or("").trim_start_matches('.');
if rest.is_empty() {
false
} else {
match self.#field_name.get(key) {
Some(item) => item.evaluate_field(rest, operator, value),
None => false,
}
}
} else {
false
}
} else {
false
}
}
} else if type_str.starts_with("Option<HashMap<") || type_str.starts_with("Option<BTreeMap<") {
quote! {
match &self.#field_name {
Some(map) => {
if let Some(rest) = inner.strip_prefix("@values.") {
map.values().any(|item| item.evaluate_field(rest, operator, value))
} else if inner == "@keys" {
operator.any(map.keys(), value)
} else if inner.starts_with("[\"") {
if let Some(end_bracket) = inner.find("\"]") {
let key = &inner[2..end_bracket];
let rest = inner.get(end_bracket + 2..).unwrap_or("").trim_start_matches('.');
if rest.is_empty() {
false
} else {
match map.get(key) {
Some(item) => item.evaluate_field(rest, operator, value),
None => false,
}
}
} else {
false
}
} else {
false
}
},
None => false,
}
}
} else if type_str.starts_with("Option<") {
quote! {
match &self.#field_name {
Some(inner_val) => inner_val.evaluate_field(inner, operator, value),
None => false,
}
}
} else {
quote! {
self.#field_name.evaluate_field(inner, operator, value)
}
};
Some(quote! {
#query_name => #delegation_code,
})
}
/// Builds the `validate_field_path` arm that forwards the path tail to a nested
/// struct field's own `validate_field_path`.
///
/// The field qualifies only if it carries `#[dnf(nested)]`, is not skipped and is not
/// a collection (`Vec`, maps, sets, or their `Option` forms). For `Option<T>` the
/// delegation target is `T`; otherwise it is the field type itself.
fn generate_validate_field_path_arm(field: &Field) -> Option<proc_macro2::TokenStream> {
    let field_name = field.ident.as_ref()?;
    let field_type = &field.ty;
    if has_skip_attribute(field) || !has_nested_attribute(field) {
        return None;
    }
    let type_str = quote!(#field_type).to_string().replace(" ", "");
    const COLLECTION_PREFIXES: [&str; 8] = [
        "Vec<",
        "Option<Vec<",
        "HashMap<",
        "BTreeMap<",
        "Option<HashMap<",
        "Option<BTreeMap<",
        "HashSet<",
        "BTreeSet<",
    ];
    if COLLECTION_PREFIXES.iter().any(|p| type_str.starts_with(p)) {
        return None;
    }
    let query_name = get_rename_attribute(field).unwrap_or_else(|| field_name.to_string());
    // Take the type from the syntax tree rather than re-parsing the whitespace-stripped
    // string. Stripping glues keywords and lifetimes onto neighbouring tokens
    // (`dyn Trait` -> `dynTrait`, `&mut T` -> `&mutT`), which names the wrong type.
    let inner_type = option_payload_type(field_type).unwrap_or(field_type);
    Some(quote! {
        #query_name => <#inner_type as dnf::DnfEvaluable>::validate_field_path(tail),
    })
}

/// Returns `T` for an unqualified `Option<T>`, looking through invisible/paren
/// groups; returns `None` for any other type.
fn option_payload_type(ty: &syn::Type) -> Option<&syn::Type> {
    match ty {
        syn::Type::Group(group) => option_payload_type(&group.elem),
        syn::Type::Paren(paren) => option_payload_type(&paren.elem),
        syn::Type::Path(type_path)
            if type_path.qself.is_none()
                && type_path.path.leading_colon.is_none()
                && type_path.path.segments.len() == 1 =>
        {
            let segment = type_path.path.segments.first()?;
            if segment.ident != "Option" {
                return None;
            }
            match &segment.arguments {
                syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
                    match args.args.first()? {
                        syn::GenericArgument::Type(inner) => Some(inner),
                        _ => None,
                    }
                }
                _ => None,
            }
        }
        _ => None,
    }
}
/// Returns the string from the first `#[dnf(rename = "...")]` found on `field`.
///
/// Parsing is best-effort: malformed attribute contents are ignored, not reported.
fn get_rename_attribute(field: &Field) -> Option<String> {
    for attr in &field.attrs {
        if attr.path().is_ident("dnf") {
            let mut rename_value = None;
            let _ = attr.parse_nested_meta(|meta| {
                if meta.path.is_ident("rename") {
                    if let Ok(value) = meta.value() {
                        if let Ok(lit_str) = value.parse::<syn::LitStr>() {
                            rename_value = Some(lit_str.value());
                        }
                    }
                } else if meta.input.peek(syn::Token![=]) {
                    // Consume other keys' values (e.g. `iter = "items"`). Otherwise
                    // `parse_nested_meta` aborts before reaching a later `rename`.
                    meta.value()?.parse::<syn::Lit>()?;
                }
                Ok(())
            });
            if let Some(name) = rename_value {
                return Some(name);
            }
        }
    }
    None
}
/// Looks for a `#[dnf(iter)]` / `#[dnf(iter = "method")]` flag on `field`.
///
/// Returns:
/// - `None` if the flag is absent;
/// - `Some(None)` for a bare `iter`;
/// - `Some(Some(method))` when a custom iterator method name is given.
///
/// Parsing is best-effort: malformed attribute contents are ignored, not reported.
fn get_iter_attribute(field: &Field) -> Option<Option<String>> {
    for attr in &field.attrs {
        if attr.path().is_ident("dnf") {
            let mut has_iter = false;
            let mut iter_method = None;
            let _ = attr.parse_nested_meta(|meta| {
                if meta.path.is_ident("iter") {
                    has_iter = true;
                    if let Ok(value) = meta.value() {
                        if let Ok(lit_str) = value.parse::<syn::LitStr>() {
                            iter_method = Some(lit_str.value());
                        }
                    }
                } else if meta.input.peek(syn::Token![=]) {
                    // Consume other keys' values (e.g. `rename = "x"`). Otherwise
                    // `parse_nested_meta` aborts before reaching a later `iter`.
                    meta.value()?.parse::<syn::Lit>()?;
                }
                Ok(())
            });
            if has_iter {
                return Some(iter_method);
            }
        }
    }
    None
}
fn generate_value_conversion(
field: &Field,
field_name: &syn::Ident,
_field_type: &syn::Type,
) -> proc_macro2::TokenStream {
if let Some(iter_method) = get_iter_attribute(field) {
let method = iter_method.unwrap_or_else(|| "iter".to_string());
let method_ident = syn::Ident::new(&method, field_name.span());
return quote! {
operator.any(self.#field_name.#method_ident(), value)
};
}
quote! {
dnf::DnfField::evaluate(&self.#field_name, operator, value)
}
}
/// Reports whether a whitespace-free type string is a primitive or builtin type
/// that is compared directly rather than traversed as a nested struct.
///
/// Accepted:
/// - integer and float primitives, `bool` and `String`;
/// - `&str` / `&'a str` / `&mut str`, `Cow<str>` / `Cow<'a, str>` and `Box<str>`;
/// - `Vec<T>` and `HashSet<T>` of accepted types (no floats in a `HashSet`);
/// - `HashMap`/`BTreeMap` with a string-like key and an accepted value type.
fn is_primitive_or_builtin(type_str: &str) -> bool {
    const PRIMITIVES: [&str; 14] = [
        "i8", "i16", "i32", "i64", "isize", "u8", "u16", "u32", "u64", "usize", "f32", "f64",
        "bool", "String",
    ];

    /// Matches borrowed/boxed string-slice types in their whitespace-free spelling.
    ///
    /// This deliberately does not use a substring search for "str". That would
    /// wrongly classify types such as `&Constraint`, `&Registry` or `&Instrument` as
    /// builtin.
    fn is_str_wrapper(t: &str) -> bool {
        if let Some(rest) = t.strip_prefix('&') {
            // `&'a str` normalises to `&'astr`: the lifetime is glued onto `str`.
            return rest == "str"
                || rest == "mutstr"
                || (rest.starts_with('\'') && rest.ends_with("str"));
        }
        if let Some(arg) = t.strip_prefix("Cow<").and_then(|s| s.strip_suffix('>')) {
            return arg == "str" || (arg.starts_with('\'') && arg.ends_with(",str"));
        }
        t == "Box<str>"
    }

    if PRIMITIVES.contains(&type_str) || is_str_wrapper(type_str) {
        return true;
    }
    if let Some(elem) = type_str.strip_prefix("Vec<").and_then(|s| s.strip_suffix('>')) {
        return is_primitive_or_builtin(elem);
    }
    if let Some(elem) = type_str
        .strip_prefix("HashSet<")
        .and_then(|s| s.strip_suffix('>'))
    {
        // f32/f64 are not `Eq + Hash`, so they are excluded as set elements.
        if elem == "f32" || elem == "f64" {
            return false;
        }
        return is_primitive_or_builtin(elem);
    }
    if is_map_type(type_str) {
        if let Some((key_type, value_type)) = extract_map_types(type_str) {
            return is_string_key(&key_type) && is_primitive_or_builtin(&value_type);
        }
    }
    false
}
/// Builds the `dnf::FieldInfo::with_kind(name, type, kind)` expression for one field.
///
/// Skipped fields and fields without an identifier yield `None`. The name honours
/// `#[dnf(rename = "...")]`, and the type string keeps `quote`'s token spacing.
/// Kind selection:
/// - `Iter` if `#[dnf(iter)]` is present;
/// - otherwise `Map` for `HashMap`/`BTreeMap`;
/// - `Iter` for `Vec`/`HashSet`/`BTreeSet`;
/// - `Scalar` for everything else.
fn generate_field_info(field: &Field) -> Option<proc_macro2::TokenStream> {
    if has_skip_attribute(field) {
        return None;
    }
    let ident = field.ident.as_ref()?;
    let ty = &field.ty;

    let query_name = get_rename_attribute(field).unwrap_or_else(|| ident.to_string());
    let display_type = quote!(#ty).to_string();
    let compact_type = display_type.replace(' ', "");

    const SEQUENCE_PREFIXES: [&str; 3] = ["Vec<", "HashSet<", "BTreeSet<"];
    let has_iter_attr = get_iter_attribute(field).is_some();
    let kind = if !has_iter_attr && is_map_type(&compact_type) {
        quote! { dnf::FieldKind::Map }
    } else if has_iter_attr || SEQUENCE_PREFIXES.iter().any(|p| compact_type.starts_with(p)) {
        quote! { dnf::FieldKind::Iter }
    } else {
        quote! { dnf::FieldKind::Scalar }
    };

    Some(quote! {
        dnf::FieldInfo::with_kind(#query_name, #display_type, #kind)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as a struct definition and returns its first named field, if any.
    fn first_named_field(src: &str) -> Option<Field> {
        let tokens: proc_macro2::TokenStream = src.parse().unwrap();
        let parsed: DeriveInput = syn::parse2(tokens).unwrap();
        match parsed.data {
            Data::Struct(data) => match data.fields {
                Fields::Named(named) => named.named.into_iter().next(),
                _ => None,
            },
            _ => None,
        }
    }

    /// Runs `generate_value_conversion` on `field` and renders the tokens as a string.
    fn conversion_of(field: &Field) -> String {
        generate_value_conversion(field, field.ident.as_ref().unwrap(), &field.ty).to_string()
    }

    #[test]
    fn test_all_field_types_use_dnf_field() {
        let types = [
            "String",
            "u32",
            "i64",
            "f64",
            "bool",
            "Vec<String>",
            "HashSet<i32>",
            "Score",
            "CustomEnum",
            "MyStruct",
        ];
        for type_str in types {
            let source = format!("struct User {{ field: {} }}", type_str);
            if let Some(field) = first_named_field(&source) {
                let rendered = conversion_of(&field);
                assert!(
                    rendered.contains("DnfField :: evaluate"),
                    "Type {} should use DnfField::evaluate(), got: {}",
                    type_str,
                    rendered
                );
            }
        }
    }

    #[test]
    fn test_iter_attribute_generates_any() {
        let field = first_named_field("struct User { #[dnf(iter)] field: LinkedList<String> }")
            .expect("Expected struct with named fields");
        let rendered = conversion_of(&field);
        assert!(
            rendered.contains("any") && rendered.contains(". iter ()"),
            "Expected any with .iter(), got: {}",
            rendered
        );
    }

    #[test]
    fn test_iter_attribute_with_custom_method() {
        let field =
            first_named_field("struct User { #[dnf(iter = \"items\")] field: CustomList<i32> }")
                .expect("Expected struct with named fields");
        let rendered = conversion_of(&field);
        assert!(
            rendered.contains("any") && rendered.contains(". items ()"),
            "Expected any with .items(), got: {}",
            rendered
        );
    }
}