use super::ordering_field::OrderingField;
use super::{FilterBackend, FilterResult};
use async_trait::async_trait;
use reinhardt_db::orm::{
Cond, Expr, Lookup, Model, MySqlQueryBuilder, Query, QueryFieldCompiler, QueryStatementBuilder,
};
use std::collections::HashMap;
use std::marker::PhantomData;
/// A [`FilterBackend`] that narrows a SQL query with typed field lookups and
/// appends typed ordering.
///
/// Conditions combine as follows. Every entry in `lookups` must hold (`AND`).
/// Each OR group holds when any one of its lookups holds. The groups are then
/// `AND`ed with each other and with `lookups`.
pub struct QueryFilter<M: Model> {
    /// Lookups that must all match (combined with `AND`).
    lookups: Vec<Lookup<M>>,
    /// Groups of alternative lookups; a group matches when any member matches (`OR`).
    or_groups: Vec<Vec<Lookup<M>>>,
    /// Ordering fields, emitted in the order they were added.
    ordering: Vec<OrderingField<M>>,
    /// Type marker for the model `M`.
    _phantom: PhantomData<M>,
}
impl<M: Model> QueryFilter<M> {
    /// Creates an empty filter with no lookups, OR groups, or ordering.
    pub fn new() -> Self {
        Self {
            lookups: Vec::new(),
            or_groups: Vec::new(),
            ordering: Vec::new(),
            _phantom: PhantomData,
        }
    }

    /// Adds a lookup that must match; it is `AND`ed with all other conditions.
    #[must_use]
    pub fn with_lookup(mut self, lookup: Lookup<M>) -> Self {
        self.lookups.push(lookup);
        self
    }

    /// Adds several lookups that must all match (`AND`).
    #[must_use]
    pub fn add_all(mut self, lookups: Vec<Lookup<M>>) -> Self {
        self.lookups.extend(lookups);
        self
    }

    /// Appends one ordering field; fields are emitted in the order they are added.
    #[must_use]
    pub fn order_by(mut self, field: OrderingField<M>) -> Self {
        self.ordering.push(field);
        self
    }

    /// Appends several ordering fields, preserving their order.
    #[must_use]
    pub fn order_by_all(mut self, fields: Vec<OrderingField<M>>) -> Self {
        self.ordering.extend(fields);
        self
    }

    /// Adds a group of alternatives.
    ///
    /// The group matches when any of its lookups matches (`OR`). The group as a
    /// whole is `AND`ed with the rest of the filter. An empty group is ignored.
    #[must_use]
    pub fn add_or_group(mut self, lookups: Vec<Lookup<M>>) -> Self {
        if !lookups.is_empty() {
            self.or_groups.push(lookups);
        }
        self
    }

    /// Adds one OR group per entry of `term_lookups`, exactly as repeated calls
    /// to [`QueryFilter::add_or_group`] would; empty groups are ignored.
    #[must_use]
    pub fn add_multi_term(self, term_lookups: Vec<Vec<Lookup<M>>>) -> Self {
        term_lookups.into_iter().fold(self, Self::add_or_group)
    }

    /// Returns the lookups that must all match.
    pub fn lookups(&self) -> &[Lookup<M>] {
        &self.lookups
    }

    /// Returns the OR groups; each inner slice is one group of alternatives.
    pub fn or_groups(&self) -> &[Vec<Lookup<M>>] {
        &self.or_groups
    }

    /// Returns the ordering fields in the order they were added.
    pub fn ordering(&self) -> &[OrderingField<M>] {
        &self.ordering
    }

    /// Compiles all lookups and OR groups into a SQL boolean expression.
    ///
    /// The returned text omits the leading `WHERE`. Returns `None` when the
    /// filter has neither lookups nor OR groups. The expression is rendered
    /// through `MySqlQueryBuilder`, so it uses that builder's quoting and
    /// escaping.
    fn compile_where_clause(&self) -> Option<String> {
        if self.lookups.is_empty() && self.or_groups.is_empty() {
            return None;
        }

        let required = self.lookups.iter().fold(Cond::all(), |cond, lookup| {
            cond.add(QueryFieldCompiler::compile_to_expr(lookup))
        });

        let combined = self
            .or_groups
            .iter()
            .fold(required, |cond, group| match group.as_slice() {
                // Never stored by the builders above; kept as a defensive no-op.
                [] => cond,
                // A lone alternative needs no OR wrapper.
                [only] => cond.add(QueryFieldCompiler::compile_to_expr(only)),
                alternatives => {
                    let any = alternatives.iter().fold(Cond::any(), |any, lookup| {
                        any.add(QueryFieldCompiler::compile_to_expr(lookup))
                    });
                    cond.add(any)
                }
            });

        // Render a throwaway `SELECT 1 WHERE ...` and keep only the condition text.
        let rendered = Query::select()
            .expr(Expr::val(1))
            .cond_where(combined)
            .to_string(MySqlQueryBuilder);
        rendered
            .split_once("WHERE ")
            .map(|(_, conditions)| conditions.to_string())
    }

    /// Joins the ordering fields into a comma-separated sort list.
    ///
    /// The returned text omits the `ORDER BY` keyword. Returns `None` when no
    /// ordering was requested.
    fn compile_order_clause(&self) -> Option<String> {
        if self.ordering.is_empty() {
            return None;
        }
        let order_parts: Vec<String> = self.ordering.iter().map(|field| field.to_sql()).collect();
        Some(order_parts.join(", "))
    }

    /// Merges the sort list `new_order` into `sql`.
    ///
    /// If `sql` contains `ORDER BY`, `new_order` is appended after the existing
    /// sort keys. Any trailing `LIMIT`/`OFFSET`/`;` is kept after it. Otherwise
    /// ` ORDER BY new_order` is appended to the end of `sql`.
    ///
    /// Keyword matching is a plain, case-sensitive substring search. The first
    /// `ORDER BY` found is used, even if it is inside a string literal or a
    /// subquery.
    fn append_order_clause(&self, sql: &str, new_order: &str) -> String {
        const KEYWORD: &str = "ORDER BY";
        match sql.find(KEYWORD) {
            None => format!("{} ORDER BY {}", sql, new_order),
            Some(order_by_pos) => {
                let (before_order, after_order) = sql.split_at(order_by_pos + KEYWORD.len());
                // The existing sort list ends at the first trailing clause marker.
                let end_pos = ["LIMIT", "OFFSET", ";"]
                    .iter()
                    .filter_map(|marker| after_order.find(*marker))
                    .min()
                    .unwrap_or(after_order.len());
                let existing_order = after_order[..end_pos].trim();
                let remaining = after_order[end_pos..].trim();
                // An empty existing list must not produce `ORDER BY , new`.
                let sort_keys = if existing_order.is_empty() {
                    new_order.to_string()
                } else {
                    format!("{}, {}", existing_order, new_order)
                };
                if remaining.is_empty() {
                    format!("{} {}", before_order, sort_keys)
                } else {
                    format!("{} {} {}", before_order, sort_keys, remaining)
                }
            }
        }
    }
}
impl<M: Model> Default for QueryFilter<M> {
fn default() -> Self {
Self::new()
}
}
/// Top-level keywords that may follow a `WHERE` clause. A `WHERE` condition
/// must end before the first of them.
const WHERE_CLAUSE_TERMINATORS: [&str; 6] = ["GROUP BY", "HAVING", "ORDER BY", "LIMIT", "OFFSET", ";"];

/// Top-level keywords that may follow an `ORDER BY` clause.
const ORDER_CLAUSE_TERMINATORS: [&str; 3] = ["LIMIT", "OFFSET", ";"];

/// Finds the first top-level occurrence of `needle` in `sql`.
///
/// Returns its byte offset, or `None` if there is no such occurrence. Text
/// inside `'...'`, `"..."` or `` `...` `` quotes is skipped, and so is text
/// nested inside parentheses (e.g. subqueries). Matching is case-sensitive.
fn find_top_level(sql: &str, needle: &str) -> Option<usize> {
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut depth: usize = 0;
    for (idx, ch) in sql.char_indices() {
        if let Some(open) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' && open != '`' {
                // MySQL-style backslash escape inside a string literal.
                escaped = true;
            } else if ch == open {
                // Also handles doubled quotes (`''`): close, then reopen.
                quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' | '`' => quote = Some(ch),
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            _ if depth == 0 && sql[idx..].starts_with(needle) => return Some(idx),
            _ => {}
        }
    }
    None
}

/// Returns the byte offset of the earliest top-level occurrence of any of
/// `keywords` in `sql`, or `sql.len()` if none occurs.
fn first_top_level(sql: &str, keywords: &[&str]) -> usize {
    keywords
        .iter()
        .filter_map(|keyword| find_top_level(sql, keyword))
        .min()
        .unwrap_or(sql.len())
}

/// Joins `head`, `clause` and `tail` with single spaces, trimming whitespace
/// at the seams. No space is inserted before a tail that starts with `;`.
fn splice_clause(head: &str, clause: &str, tail: &str) -> String {
    let head = head.trim_end();
    let tail = tail.trim_start();
    let mut out = String::with_capacity(head.len() + clause.len() + tail.len() + 2);
    out.push_str(head);
    out.push(' ');
    out.push_str(clause);
    if !tail.is_empty() {
        if !tail.starts_with(';') {
            out.push(' ');
        }
        out.push_str(tail);
    }
    out
}

/// Adds `condition` to the top-level `WHERE` clause of `sql`.
///
/// If there is no top-level `WHERE`, one is created before any
/// `GROUP BY`/`HAVING`/`ORDER BY`/`LIMIT`/`OFFSET`/`;`. If one exists, the
/// result is `WHERE (condition) AND (existing)`.
fn add_where_condition(sql: &str, condition: &str) -> String {
    match find_top_level(sql, "WHERE") {
        Some(pos) => {
            let rest = &sql[pos + "WHERE".len()..];
            let end = first_top_level(rest, &WHERE_CLAUSE_TERMINATORS);
            let existing = rest[..end].trim();
            let clause = if existing.is_empty() {
                format!("WHERE {}", condition)
            } else {
                // Parenthesize both sides so a top-level OR in either one
                // cannot bind across the AND.
                format!("WHERE ({}) AND ({})", condition, existing)
            };
            splice_clause(&sql[..pos], &clause, &rest[end..])
        }
        None => {
            let end = first_top_level(sql, &WHERE_CLAUSE_TERMINATORS);
            splice_clause(&sql[..end], &format!("WHERE {}", condition), &sql[end..])
        }
    }
}

#[async_trait]
impl<M: Model> FilterBackend for QueryFilter<M> {
    /// Applies this filter's conditions and ordering to `sql`; `_params` is unused.
    ///
    /// Keywords are located only at the top level of `sql`, outside quotes and
    /// parentheses, so string literals and subqueries are left untouched. A new
    /// `ORDER BY` goes before any trailing `LIMIT`/`OFFSET`/`;`. An existing
    /// top-level `ORDER BY` is extended.
    async fn filter_queryset(
        &self,
        _params: &HashMap<String, String>,
        mut sql: String,
    ) -> FilterResult<String> {
        if let Some(where_clause) = self.compile_where_clause() {
            sql = add_where_condition(&sql, &where_clause);
        }
        if let Some(order_clause) = self.compile_order_clause() {
            sql = match find_top_level(&sql, "ORDER BY") {
                // Give the helper only the text from the top-level ORDER BY
                // onwards, so it cannot modify a subquery's ORDER BY.
                Some(pos) => format!(
                    "{}{}",
                    &sql[..pos],
                    self.append_order_clause(&sql[pos..], &order_clause)
                ),
                None => {
                    let end = first_top_level(&sql, &ORDER_CLAUSE_TERMINATORS);
                    splice_clause(&sql[..end], &format!("ORDER BY {}", order_clause), &sql[end..])
                }
            };
        }
        Ok(sql)
    }
}