#[cfg(feature = "sea-orm")]
use sea_orm::{sea_query::{SimpleExpr, ExprTrait}, Condition};
use crate::ast::{Comparator, Expression, Filter, Restriction, Value};
use crate::error::{FilterError, Result};
/// Resolves a filter field name to a SeaORM column expression.
///
/// The name is first parsed verbatim as `C`. If that fails, it is rewritten
/// from `snake_case` to `PascalCase` (each `_`-separated word gets an upper-cased
/// first character, e.g. `created_at` -> `CreatedAt`) and parsed once more.
///
/// # Errors
///
/// Returns [`FilterError::InvalidField`] containing the original field name and
/// the parse error of the `PascalCase` attempt when neither spelling parses.
pub fn column_from_str<C>(field: &str) -> Result<SimpleExpr>
where
    C: std::str::FromStr + sea_orm::IntoSimpleExpr,
    <C as std::str::FromStr>::Err: std::fmt::Display,
{
    // Exact spelling wins, so callers can also pass PascalCase names directly.
    match field.parse::<C>() {
        Ok(column) => Ok(column.into_simple_expr()),
        Err(_) => {
            let mut pascal = String::with_capacity(field.len());
            for segment in field.split('_') {
                let mut letters = segment.chars();
                // Empty segments (from `__` or leading/trailing `_`) contribute nothing.
                if let Some(head) = letters.next() {
                    pascal.extend(head.to_uppercase());
                    pascal.push_str(letters.as_str());
                }
            }
            match pascal.parse::<C>() {
                Ok(column) => Ok(column.into_simple_expr()),
                Err(e) => Err(FilterError::InvalidField(format!("{}: {}", field, e))),
            }
        }
    }
}
/// Conversion of a parsed filter into a SeaORM [`Condition`].
///
/// The type parameter `C` is the column type used to resolve field names
/// (see [`column_from_str`]): it must be parseable from a string, with a
/// displayable parse error, and convertible into a SeaORM simple expression.
pub trait ToSeaOrmCondition {
/// Builds a [`Condition`] equivalent to this filter, resolving fields as `C`.
///
/// # Errors
///
/// Returns an error when the filter cannot be expressed as a condition over
/// columns of type `C` (for example, an unknown field name).
fn to_condition<C>(&self) -> Result<Condition>
where
C: std::str::FromStr + sea_orm::IntoSimpleExpr,
<C as std::str::FromStr>::Err: std::fmt::Display;
}
/// Converts a [`Filter`] by lowering its root expression tree.
impl ToSeaOrmCondition for Filter {
    fn to_condition<C>(&self) -> Result<Condition>
    where
        C: std::str::FromStr + sea_orm::IntoSimpleExpr,
        <C as std::str::FromStr>::Err: std::fmt::Display,
    {
        let root = &self.expression;
        expression_to_condition::<C>(root)
    }
}
/// Recursively lowers an [`Expression`] tree into a SeaORM [`Condition`].
///
/// - `AND` nodes become a `Condition::all()` group and `OR` nodes a
///   `Condition::any()` group, each holding both converted operands.
/// - `NOT` negates the converted operand.
/// - Leaf restrictions are delegated to [`restriction_to_condition`].
///
/// # Errors
///
/// Propagates the first error from either operand (left is converted before
/// right). Returns [`FilterError::UnsupportedOperation`] for
/// `Expression::Sequence`.
fn expression_to_condition<C>(expr: &Expression) -> Result<Condition>
where
    C: std::str::FromStr + sea_orm::IntoSimpleExpr,
    <C as std::str::FromStr>::Err: std::fmt::Display,
{
    match expr {
        Expression::Restriction(restriction) => restriction_to_condition::<C>(restriction),
        Expression::Not(inner) => expression_to_condition::<C>(inner).map(|cond| cond.not()),
        Expression::And(lhs, rhs) => Ok(Condition::all()
            .add(expression_to_condition::<C>(lhs)?)
            .add(expression_to_condition::<C>(rhs)?)),
        Expression::Or(lhs, rhs) => Ok(Condition::any()
            .add(expression_to_condition::<C>(lhs)?)
            .add(expression_to_condition::<C>(rhs)?)),
        Expression::Sequence(_) => Err(FilterError::UnsupportedOperation(
            "Sequences are not yet supported in SeaORM conversion".to_string(),
        )),
    }
}
fn restriction_to_condition<C>(restriction: &Restriction) -> Result<Condition>
where
C: std::str::FromStr + sea_orm::IntoSimpleExpr,
<C as std::str::FromStr>::Err: std::fmt::Display,
{
let column = column_from_str::<C>(&restriction.field)?;
let condition = match (&restriction.comparator, &restriction.value) {
(Comparator::Equal, Value::String(s)) => {
Condition::all().add(column.eq(s.as_str()))
}
(Comparator::Equal, Value::Number(n)) => {
if n.fract() == 0.0 {
Condition::all().add(column.eq(*n as i64))
} else {
Condition::all().add(column.eq(*n))
}
}
(Comparator::Equal, Value::Boolean(b)) => {
Condition::all().add(column.eq(*b))
}
(Comparator::Equal, Value::Null) => {
Condition::all().add(column.is_null())
}
(Comparator::NotEqual, Value::String(s)) => {
Condition::all().add(column.ne(s.as_str()))
}
(Comparator::NotEqual, Value::Number(n)) => {
if n.fract() == 0.0 {
Condition::all().add(column.ne(*n as i64))
} else {
Condition::all().add(column.ne(*n))
}
}
(Comparator::NotEqual, Value::Boolean(b)) => {
Condition::all().add(column.ne(*b))
}
(Comparator::NotEqual, Value::Null) => {
Condition::all().add(column.is_not_null())
}
(Comparator::GreaterThan, Value::Number(n)) => {
if n.fract() == 0.0 {
Condition::all().add(column.gt(*n as i64))
} else {
Condition::all().add(column.gt(*n))
}
}
(Comparator::GreaterThan, Value::String(s)) => {
Condition::all().add(column.gt(s.as_str()))
}
(Comparator::GreaterThanOrEqual, Value::Number(n)) => {
if n.fract() == 0.0 {
Condition::all().add(column.gte(*n as i64))
} else {
Condition::all().add(column.gte(*n))
}
}
(Comparator::GreaterThanOrEqual, Value::String(s)) => {
Condition::all().add(column.gte(s.as_str()))
}
(Comparator::LessThan, Value::Number(n)) => {
if n.fract() == 0.0 {
Condition::all().add(column.lt(*n as i64))
} else {
Condition::all().add(column.lt(*n))
}
}
(Comparator::LessThan, Value::String(s)) => {
Condition::all().add(column.lt(s.as_str()))
}
(Comparator::LessThanOrEqual, Value::Number(n)) => {
if n.fract() == 0.0 {
Condition::all().add(column.lte(*n as i64))
} else {
Condition::all().add(column.lte(*n))
}
}
(Comparator::LessThanOrEqual, Value::String(s)) => {
Condition::all().add(column.lte(s.as_str()))
}
(Comparator::Has, Value::String(s)) => {
Condition::all().add(column.like(format!("%{}%", s)))
}
_ => {
return Err(FilterError::UnsupportedOperation(format!(
"Unsupported combination: {} with {:?}",
restriction.comparator, restriction.value
)))
}
};
Ok(condition)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_filter;
    use sea_orm::sea_query::{Iden, SimpleExpr};

    /// Test column enum.
    ///
    /// It parses from PascalCase variant names and renders snake_case SQL
    /// identifiers.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum Column {
        Id,
        Name,
        Email,
        Age,
        Active,
        CreatedAt,
        UserName,
    }

    impl std::str::FromStr for Column {
        type Err = String;

        fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
            match s {
                "Id" => Ok(Column::Id),
                "Name" => Ok(Column::Name),
                "Email" => Ok(Column::Email),
                "Age" => Ok(Column::Age),
                "Active" => Ok(Column::Active),
                "CreatedAt" => Ok(Column::CreatedAt),
                "UserName" => Ok(Column::UserName),
                _ => Err(format!("Unknown column: {}", s)),
            }
        }
    }

    impl Iden for Column {
        fn unquoted(&self, s: &mut dyn std::fmt::Write) {
            write!(
                s,
                "{}",
                match self {
                    Column::Id => "id",
                    Column::Name => "name",
                    Column::Email => "email",
                    Column::Age => "age",
                    Column::Active => "active",
                    Column::CreatedAt => "created_at",
                    Column::UserName => "user_name",
                }
            )
            .unwrap();
        }
    }

    impl sea_orm::IntoSimpleExpr for Column {
        fn into_simple_expr(self) -> SimpleExpr {
            SimpleExpr::Column(sea_orm::sea_query::ColumnRef::Column(
                sea_orm::sea_query::DynIden::new(self),
            ))
        }
    }

    /// Parses `filter_str` and converts the result using the test `Column` enum.
    ///
    /// Panics with the offending input if the parser rejects it.
    fn convert(filter_str: &str) -> crate::error::Result<sea_orm::Condition> {
        parse_filter(filter_str)
            .unwrap_or_else(|e| panic!("failed to parse {:?}: {}", filter_str, e))
            .to_condition::<Column>()
    }

    #[test]
    fn test_simple_string_filter() {
        assert!(convert("name = \"John\"").is_ok());
    }

    #[test]
    fn test_number_filter() {
        assert!(convert("age > 18").is_ok());
    }

    #[test]
    fn test_float_number_filter() {
        assert!(convert("age > 18.5").is_ok());
    }

    #[test]
    fn test_boolean_filter() {
        assert!(convert("active = true").is_ok());
    }

    #[test]
    fn test_null_filter() {
        assert!(convert("email = NULL").is_ok());
    }

    #[test]
    fn test_not_null_filter() {
        assert!(convert("email != NULL").is_ok());
    }

    #[test]
    fn test_and_expression() {
        assert!(convert("age > 18 AND active = true").is_ok());
    }

    #[test]
    fn test_or_expression() {
        assert!(convert("name = \"John\" OR name = \"Jane\"").is_ok());
    }

    #[test]
    fn test_not_expression() {
        assert!(convert("NOT active = false").is_ok());
    }

    #[test]
    fn test_has_operator() {
        assert!(convert("email : \"@example.com\"").is_ok());
    }

    #[test]
    fn test_complex_filter() {
        let result = convert("(name = \"Alice\" OR name = \"Bob\") AND age >= 21 AND active = true");
        assert!(result.is_ok());
    }

    #[test]
    fn test_all_comparators() {
        let cases = [
            "age = 25",
            "age != 30",
            "age > 20",
            "age >= 21",
            "age < 50",
            "age <= 49",
            "name : \"John\"",
        ];
        for filter_str in cases {
            assert!(convert(filter_str).is_ok(), "Failed on: {}", filter_str);
        }
    }

    #[test]
    fn test_snake_case_to_pascal_case_conversion() {
        assert!(convert("user_name = \"test\"").is_ok());
    }

    #[test]
    fn test_created_at_snake_case() {
        assert!(convert("created_at > \"2024-01-01\"").is_ok());
    }

    #[test]
    fn test_invalid_field() {
        assert!(matches!(
            convert("invalid_field = \"value\""),
            Err(crate::error::FilterError::InvalidField(_))
        ));
    }

    #[test]
    fn test_unsupported_combinations() {
        // Ordering comparators are only accepted for numbers and strings; other
        // value kinds must be rejected rather than silently coerced.
        for filter_str in ["active > true", "active <= false", "email > NULL"] {
            assert!(
                matches!(
                    convert(filter_str),
                    Err(crate::error::FilterError::UnsupportedOperation(_))
                ),
                "Expected UnsupportedOperation for: {}",
                filter_str
            );
        }
    }

    #[test]
    fn test_dotted_field_rejected_by_parser() {
        // The parser currently rejects dotted field paths before conversion runs.
        match parse_filter("user.profile.name = \"John\"") {
            Ok(_) => panic!("expected a parse error for a dotted field path"),
            Err(e) => assert!(
                e.to_string().contains("expected comparator"),
                "unexpected error: {}",
                e
            ),
        }
    }

    #[test]
    fn test_not_equal_variations() {
        let cases = [
            "name != \"John\"",
            "age != 25",
            "active != true",
            "email != NULL",
        ];
        for filter_str in cases {
            assert!(convert(filter_str).is_ok(), "Failed on: {}", filter_str);
        }
    }

    #[test]
    fn test_string_comparisons() {
        let cases = [
            "name > \"Alice\"",
            "name >= \"Bob\"",
            "name < \"Zebra\"",
            "name <= \"Zoe\"",
        ];
        for filter_str in cases {
            assert!(convert(filter_str).is_ok(), "Failed on: {}", filter_str);
        }
    }
}