use proc_macro2::TokenStream;
use quote::quote;
use prax_schema::ast::{Field, FieldType, ScalarType, TypeModifier};
use super::{pascal_ident, snake_ident};
use crate::types::{field_type_to_rust, supports_comparison, supports_in_op, supports_string_ops};
pub fn generate_field_filters(field: &Field, _model_name: &str) -> TokenStream {
let field_name = snake_ident(field.name());
let field_type = field_type_to_rust(&field.field_type, &TypeModifier::Required);
let where_variant = pascal_ident(field.name());
let col_name = field
.attributes
.iter()
.find(|a| a.name() == "map")
.and_then(|a| a.first_arg())
.and_then(|v| v.as_string())
.map(|s| s.to_string())
.unwrap_or_else(|| field.name().to_string());
let is_optional = field.modifier.is_optional();
let mut ops = vec![
quote! {
pub fn equals(value: #field_type) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::Equals(value))
}
},
quote! {
pub fn not(value: #field_type) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::Not(value))
}
},
];
if is_optional {
ops.push(quote! {
pub fn is_null() -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::IsNull)
}
});
ops.push(quote! {
pub fn is_not_null() -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::IsNotNull)
}
});
}
if supports_in_op(&field.field_type) {
ops.push(quote! {
pub fn in_(values: Vec<#field_type>) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::In(values))
}
});
ops.push(quote! {
pub fn not_in(values: Vec<#field_type>) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::NotIn(values))
}
});
}
if let FieldType::Scalar(scalar) = &field.field_type {
if supports_comparison(scalar) {
ops.push(quote! {
pub fn gt(value: #field_type) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::Gt(value))
}
});
ops.push(quote! {
pub fn gte(value: #field_type) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::Gte(value))
}
});
ops.push(quote! {
pub fn lt(value: #field_type) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::Lt(value))
}
});
ops.push(quote! {
pub fn lte(value: #field_type) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::Lte(value))
}
});
}
if supports_string_ops(scalar) {
ops.push(quote! {
pub fn contains(value: impl Into<String>) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::Contains(value.into()))
}
});
ops.push(quote! {
pub fn starts_with(value: impl Into<String>) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::StartsWith(value.into()))
}
});
ops.push(quote! {
pub fn ends_with(value: impl Into<String>) -> super::WhereParam {
super::WhereParam::#where_variant(WhereOp::EndsWith(value.into()))
}
});
}
}
let where_op_variants = generate_where_op_variants(&field.field_type, is_optional);
let mut to_filter_arms = vec![
quote! { Self::Equals(v) => Filter::Equals(col, v.into()), },
quote! { Self::Not(v) => Filter::NotEquals(col, v.into()), },
];
if is_optional {
to_filter_arms.push(quote! { Self::IsNull => Filter::IsNull(col), });
to_filter_arms.push(quote! { Self::IsNotNull => Filter::IsNotNull(col), });
}
if supports_in_op(&field.field_type) {
to_filter_arms.push(quote! {
Self::In(vs) => Filter::In(col, vs.into_iter().map(Into::into).collect()),
});
to_filter_arms.push(quote! {
Self::NotIn(vs) => Filter::NotIn(col, vs.into_iter().map(Into::into).collect()),
});
}
if let FieldType::Scalar(scalar) = &field.field_type {
if supports_comparison(scalar) {
to_filter_arms.push(quote! { Self::Gt(v) => Filter::Gt(col, v.into()), });
to_filter_arms.push(quote! { Self::Gte(v) => Filter::Gte(col, v.into()), });
to_filter_arms.push(quote! { Self::Lt(v) => Filter::Lt(col, v.into()), });
to_filter_arms.push(quote! { Self::Lte(v) => Filter::Lte(col, v.into()), });
}
if supports_string_ops(scalar) {
to_filter_arms.push(quote! {
Self::Contains(v) => Filter::Contains(col, FilterValue::String(v)),
});
to_filter_arms.push(quote! {
Self::StartsWith(v) => Filter::StartsWith(col, FilterValue::String(v)),
});
to_filter_arms.push(quote! {
Self::EndsWith(v) => Filter::EndsWith(col, FilterValue::String(v)),
});
}
}
let _ = field_name; quote! {
#[derive(Debug, Clone)]
pub enum WhereOp {
#where_op_variants
}
impl WhereOp {
pub fn to_filter(self) -> prax_query::filter::Filter {
use prax_query::filter::{Filter, FilterValue};
use std::borrow::Cow;
let col: Cow<'static, str> = Cow::Borrowed(#col_name);
match self {
#(#to_filter_arms)*
}
}
}
#(#ops)*
}
}
fn generate_where_op_variants(field_type: &FieldType, is_optional: bool) -> TokenStream {
let base_type = match field_type {
FieldType::Scalar(s) => match s {
ScalarType::Int => quote! { i32 },
ScalarType::BigInt => quote! { i64 },
ScalarType::Float => quote! { f64 },
ScalarType::Decimal => quote! { rust_decimal::Decimal },
ScalarType::Boolean => quote! { bool },
ScalarType::String => quote! { String },
ScalarType::DateTime => quote! { chrono::DateTime<chrono::Utc> },
ScalarType::Date => quote! { chrono::NaiveDate },
ScalarType::Time => quote! { chrono::NaiveTime },
ScalarType::Json => quote! { serde_json::Value },
ScalarType::Bytes => quote! { Vec<u8> },
ScalarType::Uuid => quote! { uuid::Uuid },
ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
quote! { String }
}
ScalarType::Vector(_) | ScalarType::HalfVector(_) => quote! { Vec<f32> },
ScalarType::SparseVector(_) => quote! { Vec<(u32, f32)> },
ScalarType::Bit(_) => quote! { Vec<u8> },
},
FieldType::Enum(name) | FieldType::Model(name) | FieldType::Composite(name) => {
let ident = quote::format_ident!("{}", name.to_string());
quote! { super::super::#ident }
}
FieldType::Unsupported(_) => quote! { String },
};
let mut variants = vec![quote! { Equals(#base_type) }, quote! { Not(#base_type) }];
if is_optional {
variants.push(quote! { IsNull });
variants.push(quote! { IsNotNull });
}
if supports_in_op(field_type) {
variants.push(quote! { In(Vec<#base_type>) });
variants.push(quote! { NotIn(Vec<#base_type>) });
}
if let FieldType::Scalar(scalar) = field_type {
if supports_comparison(scalar) {
variants.push(quote! { Gt(#base_type) });
variants.push(quote! { Gte(#base_type) });
variants.push(quote! { Lt(#base_type) });
variants.push(quote! { Lte(#base_type) });
}
if supports_string_ops(scalar) {
variants.push(quote! { Contains(String) });
variants.push(quote! { StartsWith(String) });
variants.push(quote! { EndsWith(String) });
}
}
quote! {
#(#variants,)*
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use prax_schema::ast::{Ident, Span};

    /// Builds an attribute-free scalar field using placeholder `Span::new(0, 0)` spans.
    fn scalar_field(name: &str, scalar: ScalarType, modifier: TypeModifier) -> Field {
        Field::new(
            Ident::new(name, Span::new(0, 0)),
            FieldType::Scalar(scalar),
            modifier,
            vec![],
            Span::new(0, 0),
        )
    }

    /// Renders the generated filters for `field` (model name `"User"`) as a string.
    fn render(field: &Field) -> String {
        generate_field_filters(field, "User").to_string()
    }

    /// Asserts that `code` contains `pub fn <name>` for every listed name.
    ///
    /// This is a plain substring check, so e.g. `pub fn gt` is also satisfied
    /// by `pub fn gte`.
    fn assert_declares(code: &str, fn_names: &[&str]) {
        for name in fn_names {
            let needle = format!("pub fn {}", name);
            assert!(code.contains(&needle), "missing `{}` in generated code", needle);
        }
    }

    #[test]
    fn test_generate_string_field_filters() {
        let field = scalar_field("name", ScalarType::String, TypeModifier::Required);
        let code = render(&field);
        assert_declares(
            &code,
            &["equals", "not", "contains", "starts_with", "ends_with"],
        );
    }

    #[test]
    fn test_generate_int_field_filters() {
        let field = scalar_field("age", ScalarType::Int, TypeModifier::Required);
        let code = render(&field);
        assert_declares(&code, &["gt", "gte", "lt", "lte"]);
    }

    #[test]
    fn test_generate_optional_field_filters() {
        let field = scalar_field("bio", ScalarType::String, TypeModifier::Optional);
        let code = render(&field);
        assert_declares(&code, &["is_null", "is_not_null"]);
    }
}