use super::{DatabaseDialect, FilterBackend, FilterError, FilterResult};
use async_trait::async_trait;
use reinhardt_query::SimpleExpr;
use reinhardt_query::prelude::{
Alias, Cond, Expr, ExprTrait, MySqlQueryBuilder, Order, PostgresQueryBuilder, Query,
QueryStatementBuilder,
};
use std::collections::HashMap;
use std::sync::Arc;
/// Returns the byte offset of the first standalone, ASCII-case-insensitive
/// occurrence of `keyword` in `sql`, or `None` if there is none.
///
/// An occurrence is standalone when it is preceded by the start of the string,
/// ASCII whitespace, or `)`, and followed by the end of the string or ASCII
/// whitespace. Multi-word keywords such as `"ORDER BY"` must match the exact
/// spacing given in `keyword`.
///
/// This is a purely lexical scan: string literals, quoted identifiers and
/// parenthesised subqueries are not skipped, so a keyword inside any of them
/// can be reported too.
fn find_sql_keyword(sql: &str, keyword: &str) -> Option<usize> {
    let text = sql.as_bytes();
    let word = keyword.as_bytes();
    // The keyword cannot fit at all when it is longer than the input.
    let last_start = text.len().checked_sub(word.len())?;

    (0..=last_start).find(|&start| {
        let end = start + word.len();
        let boundary_before =
            start == 0 || text[start - 1] == b')' || text[start - 1].is_ascii_whitespace();
        let boundary_after = end >= text.len() || text[end].is_ascii_whitespace();
        text[start..end].eq_ignore_ascii_case(word) && boundary_before && boundary_after
    })
}
/// Returns the byte offset at which the clause that begins at `start_pos` ends.
///
/// The end is the earliest position at or after `start_pos` where any of
/// `end_keywords` is found by [`find_sql_keyword`]. When none of them occur,
/// the clause runs to the end of `sql` and `sql.len()` is returned.
///
/// # Panics
///
/// Panics if `end_keywords` is non-empty and `start_pos` is past the end of
/// `sql` or not on a UTF-8 character boundary.
fn find_clause_end(sql: &str, start_pos: usize, end_keywords: &[&str]) -> usize {
    let mut clause_end = sql.len();
    for keyword in end_keywords {
        if let Some(offset) = find_sql_keyword(&sql[start_pos..], keyword) {
            clause_end = clause_end.min(start_pos + offset);
        }
    }
    clause_end
}
/// Composite backend that runs a sequence of other [`FilterBackend`]s in order.
///
/// Each registered filter receives the SQL produced by the previous one. The
/// first error returned by any filter stops the chain and is propagated.
#[derive(Default)]
pub struct CustomFilterBackend {
    /// Filters applied in the order they were added.
    filters: Vec<Arc<dyn FilterBackend>>,
}

impl CustomFilterBackend {
    /// Creates a backend with no registered filters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `filter` to the end of the chain.
    pub fn add_filter(&mut self, filter: Box<dyn FilterBackend>) {
        let shared: Arc<dyn FilterBackend> = filter.into();
        self.filters.push(shared);
    }

    /// Returns how many filters are registered.
    pub fn filter_count(&self) -> usize {
        self.filters.len()
    }
}

#[async_trait]
impl FilterBackend for CustomFilterBackend {
    /// Threads `sql` through every registered filter, in insertion order.
    ///
    /// With no filters registered, `sql` is returned unchanged.
    async fn filter_queryset(
        &self,
        query_params: &HashMap<String, String>,
        sql: String,
    ) -> FilterResult<String> {
        let mut current = sql;
        for backend in self.filters.iter() {
            current = backend.filter_queryset(query_params, current).await?;
        }
        Ok(current)
    }
}
/// Filter backend that narrows a query with a `LIKE` search over a fixed set
/// of columns.
///
/// The search term is read from the query parameter named `param_name`. Each
/// configured field yields one `LIKE` expression, and the backend's
/// `FilterBackend` implementation combines them with `OR`.
pub struct SimpleSearchBackend {
    /// Name of the query parameter that carries the search term.
    param_name: String,
    /// Columns to search, in registration order.
    fields: Vec<String>,
    /// Dialect used to render the generated condition.
    dialect: DatabaseDialect,
}

impl SimpleSearchBackend {
    /// Creates a backend for `param_name` with no search fields and the
    /// default [`DatabaseDialect`].
    pub fn new(param_name: impl Into<String>) -> Self {
        let param_name = param_name.into();
        Self {
            param_name,
            fields: Vec::new(),
            dialect: DatabaseDialect::default(),
        }
    }

    /// Registers `field` as an additional column to search (builder style).
    pub fn with_field(mut self, field: impl Into<String>) -> Self {
        let field = field.into();
        self.fields.push(field);
        self
    }

    /// Selects the SQL dialect used to render the condition (builder style).
    pub fn with_dialect(mut self, dialect: DatabaseDialect) -> Self {
        self.dialect = dialect;
        self
    }

    /// Prefixes every `\`, `%` and `_` with a backslash so that `LIKE` treats
    /// the term as literal text.
    ///
    /// NOTE(review): this relies on `\` being the LIKE escape character. That is
    /// the default in MySQL and PostgreSQL; confirm it if other dialects are
    /// added.
    fn escape_like_pattern(pattern: &str) -> String {
        let mut escaped = String::with_capacity(pattern.len());
        for ch in pattern.chars() {
            if matches!(ch, '\\' | '%' | '_') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    /// Builds one "column contains term" expression per configured field,
    /// using the LIKE-escaped search term.
    fn build_search_conditions(&self, search_query: &str) -> Vec<SimpleExpr> {
        let pattern = Self::escape_like_pattern(search_query);
        let mut conditions = Vec::with_capacity(self.fields.len());
        for column in &self.fields {
            conditions.push(Expr::col(Alias::new(column)).contains(pattern.as_str()));
        }
        conditions
    }
}
#[async_trait]
impl FilterBackend for SimpleSearchBackend {
    /// Narrows `sql` with a search condition when `param_name` is present in
    /// `query_params`; otherwise returns `sql` unchanged.
    ///
    /// The condition (per-field `LIKE` expressions joined with `OR`) is rendered
    /// by the dialect's query builder and merged into `sql` as follows:
    /// - If `sql` already has a `WHERE`, the result is
    ///   `WHERE (<existing>) AND (<search>)`. Any trailing `GROUP BY`, `HAVING`,
    ///   `ORDER BY`, `LIMIT` or `OFFSET` is kept after it.
    /// - Otherwise a new `WHERE (<search>)` is inserted before the first such
    ///   trailing clause, or appended when there is none.
    ///
    /// Trailing-clause keywords that sit inside parentheses (subqueries, or
    /// conditions wrapped by earlier filters) are not treated as clause ends.
    /// `WHERE` itself is located lexically with `find_sql_keyword`, so a `WHERE`
    /// inside a subquery or string literal can still be picked up.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::InvalidParameter`] when the search parameter is
    /// present but no search fields are configured.
    async fn filter_queryset(
        &self,
        query_params: &HashMap<String, String>,
        sql: String,
    ) -> FilterResult<String> {
        // Clauses that may follow WHERE. The search condition has to land before
        // whichever of them comes first.
        const TRAILING_CLAUSES: [&str; 5] = ["GROUP BY", "ORDER BY", "LIMIT", "OFFSET", "HAVING"];

        // Like `find_clause_end`, but ignores keyword hits inside a parenthesised
        // group. A hit counts as nested when the text after it closes a `)` it
        // never opened. Parentheses inside string literals are not told apart.
        fn top_level_clause_end(sql: &str, from: usize, keywords: &[&str]) -> usize {
            let mut search_from = from;
            loop {
                let hit = find_clause_end(sql, search_from, keywords);
                let mut depth = 0i32;
                let nested = sql[hit..].bytes().any(|b| {
                    match b {
                        b'(' => depth += 1,
                        b')' => depth -= 1,
                        _ => {}
                    }
                    depth < 0
                });
                if !nested {
                    return hit;
                }
                // Keyword hits start on an ASCII byte, so `hit + 1` is still a
                // char boundary; resume the scan just past this hit.
                search_from = hit + 1;
            }
        }

        let search_query = match query_params.get(&self.param_name) {
            Some(query) => query,
            None => return Ok(sql),
        };
        if self.fields.is_empty() {
            return Err(FilterError::InvalidParameter(
                "No search fields configured".to_string(),
            ));
        }

        let condition = self
            .build_search_conditions(search_query)
            .into_iter()
            .fold(Cond::any(), |acc, expr| acc.add(expr));

        // Render through the query builder so values are quoted and escaped for
        // the dialect, then keep only the text after its WHERE keyword.
        let rendered = match self.dialect {
            DatabaseDialect::MySQL => Query::select()
                .expr(Expr::val(1))
                .cond_where(condition)
                .to_string(MySqlQueryBuilder),
            DatabaseDialect::PostgreSQL => Query::select()
                .expr(Expr::val(1))
                .cond_where(condition)
                .to_string(PostgresQueryBuilder),
        };
        let condition_str = rendered
            .split_once("WHERE ")
            .map_or("", |(_, rest)| rest);

        if let Some(where_pos) = find_sql_keyword(&sql, "WHERE") {
            let after_keyword = where_pos + "WHERE".len();
            let content_start = sql[after_keyword..]
                .bytes()
                .position(|b| !b.is_ascii_whitespace())
                .map_or(after_keyword, |p| after_keyword + p);
            let end_pos = top_level_clause_end(&sql, content_start, &TRAILING_CLAUSES);
            let existing_where = sql[content_start..end_pos].trim();
            let remainder = sql[end_pos..].trim();
            let prefix = &sql[..where_pos];
            let merged = format!(
                "{}WHERE ({}) AND ({}) {}",
                prefix, existing_where, condition_str, remainder
            );
            // Drop the separator left dangling when nothing follows the WHERE.
            return Ok(merged.trim_end().to_string());
        }

        // No WHERE yet. Appending after a trailing clause would produce e.g.
        // `... ORDER BY id WHERE (...)`, which is not valid SQL.
        let insert_at = top_level_clause_end(&sql, 0, &TRAILING_CLAUSES);
        if insert_at == sql.len() {
            return Ok(format!("{} WHERE ({})", sql, condition_str));
        }
        let head = sql[..insert_at].trim_end();
        let tail = sql[insert_at..].trim();
        Ok(format!("{} WHERE ({}) {}", head, condition_str, tail))
    }
}
/// Filter backend that orders a query by one column taken from a query
/// parameter, restricted to an allow-list of column names.
///
/// A leading `-` on the parameter value (e.g. `-created_at`) requests
/// descending order; otherwise ascending order is used.
pub struct SimpleOrderingBackend {
    /// Name of the query parameter that carries the ordering field.
    param_name: String,
    /// Column names that may be used for ordering; any other value is rejected.
    allowed_fields: Vec<String>,
}

impl SimpleOrderingBackend {
    /// Creates a backend for `param_name` with an empty allow-list. Every
    /// requested field is rejected until one is allowed.
    pub fn new(param_name: impl Into<String>) -> Self {
        let param_name = param_name.into();
        Self {
            param_name,
            allowed_fields: Vec::new(),
        }
    }

    /// Adds `field` to the allow-list (builder style).
    pub fn allow_field(mut self, field: impl Into<String>) -> Self {
        let field = field.into();
        self.allowed_fields.push(field);
        self
    }

    /// Renders `<field> ASC` or `<field> DESC`.
    ///
    /// `field` is inserted verbatim, so callers must validate it first.
    fn build_order_clause(&self, field: &str, order: Order) -> String {
        let direction = match order {
            Order::Desc => "DESC",
            Order::Asc => "ASC",
        };
        [field, direction].join(" ")
    }
}
#[async_trait]
impl FilterBackend for SimpleOrderingBackend {
    /// Adds ordering by the field named in `param_name` when it is present in
    /// `query_params`; otherwise returns `sql` unchanged.
    ///
    /// A leading `-` selects `DESC`; otherwise `ASC` is used. If `sql` already
    /// has an `ORDER BY`, the new criterion is appended after the existing ones
    /// (`ORDER BY <existing>, <new>`). Otherwise a new `ORDER BY` is inserted
    /// before a trailing top-level `LIMIT`/`OFFSET`, or appended when there is
    /// none. `LIMIT`/`OFFSET` hits inside parentheses (e.g. subqueries) are not
    /// treated as clause ends.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::InvalidParameter`] when the requested field is
    /// not in the allow-list.
    async fn filter_queryset(
        &self,
        query_params: &HashMap<String, String>,
        sql: String,
    ) -> FilterResult<String> {
        // Clauses that must come after ORDER BY.
        const TRAILING_CLAUSES: [&str; 2] = ["LIMIT", "OFFSET"];

        // Like `find_clause_end`, but ignores keyword hits inside a parenthesised
        // group. A hit counts as nested when the text after it closes a `)` it
        // never opened. Parentheses inside string literals are not told apart.
        fn top_level_clause_end(sql: &str, from: usize, keywords: &[&str]) -> usize {
            let mut search_from = from;
            loop {
                let hit = find_clause_end(sql, search_from, keywords);
                let mut depth = 0i32;
                let nested = sql[hit..].bytes().any(|b| {
                    match b {
                        b'(' => depth += 1,
                        b')' => depth -= 1,
                        _ => {}
                    }
                    depth < 0
                });
                if !nested {
                    return hit;
                }
                // Keyword hits start on an ASCII byte, so `hit + 1` is still a
                // char boundary; resume the scan just past this hit.
                search_from = hit + 1;
            }
        }

        let ordering = match query_params.get(&self.param_name) {
            Some(ordering) => ordering,
            None => return Ok(sql),
        };
        let (field, order) = match ordering.strip_prefix('-') {
            Some(field_name) => (field_name, Order::Desc),
            None => (ordering.as_str(), Order::Asc),
        };
        // The field is spliced into the SQL verbatim, so this allow-list check
        // is what keeps arbitrary request text out of the statement.
        if !self.allowed_fields.iter().any(|allowed| allowed == field) {
            return Err(FilterError::InvalidParameter(format!(
                "Field '{}' is not allowed for ordering",
                field
            )));
        }
        let order_expr = self.build_order_clause(field, order);

        if let Some(order_pos) = find_sql_keyword(&sql, "ORDER BY") {
            let after_keyword = order_pos + "ORDER BY".len();
            let content_start = sql[after_keyword..]
                .bytes()
                .position(|b| !b.is_ascii_whitespace())
                .map_or(after_keyword, |p| after_keyword + p);
            let end_pos = top_level_clause_end(&sql, content_start, &TRAILING_CLAUSES);
            let existing_order = sql[content_start..end_pos].trim_end();
            let remainder = sql[end_pos..].trim();
            let prefix = &sql[..order_pos];
            let merged = format!(
                "{}ORDER BY {}, {} {}",
                prefix, existing_order, order_expr, remainder
            );
            return Ok(merged.trim_end().to_string());
        }

        // No ORDER BY yet. It must precede LIMIT/OFFSET, because
        // `... LIMIT 10 ORDER BY x` is not valid SQL.
        let insert_at = top_level_clause_end(&sql, 0, &TRAILING_CLAUSES);
        if insert_at == sql.len() {
            return Ok(format!("{} ORDER BY {}", sql, order_expr));
        }
        let head = sql[..insert_at].trim_end();
        let tail = sql[insert_at..].trim();
        Ok(format!("{} ORDER BY {} {}", head, order_expr, tail))
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the SQL keyword helpers and for the composite, search and
    // ordering filter backends defined in this module.
    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case("SELECT * FROM t WHERE a = 1", "WHERE", Some(16))]
    #[case("select * from t where a = 1", "WHERE", Some(16))]
    #[case("SELECT nowhere FROM t", "WHERE", None)]
    #[case("SELECT * FROM t WHEREVER", "WHERE", None)]
    #[case("SELECT * FROM (SELECT 1)WHERE x", "WHERE", Some(24))]
    #[case("SELECT ROW_NUMBER() OVER (ORDER BY id) FROM t", "ORDER BY", None)]
    fn test_find_sql_keyword(
        #[case] sql: &str,
        #[case] keyword: &str,
        #[case] expected: Option<usize>,
    ) {
        assert_eq!(find_sql_keyword(sql, keyword), expected);
    }

    #[rstest]
    fn test_find_clause_end_returns_earliest_keyword_or_len() {
        let sql = "SELECT * FROM t WHERE a = 1 ORDER BY id LIMIT 5";
        let start = sql.find("a = 1").unwrap();
        assert_eq!(
            find_clause_end(sql, start, &["LIMIT", "ORDER BY"]),
            sql.find("ORDER BY").unwrap()
        );
        assert_eq!(find_clause_end(sql, start, &["GROUP BY"]), sql.len());
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_filter_backend_empty() {
        let backend = CustomFilterBackend::new();
        let params = HashMap::new();
        let sql = "SELECT * FROM users".to_string();
        let count = backend.filter_count();
        let result = backend.filter_queryset(&params, sql.clone()).await.unwrap();
        assert_eq!(count, 0);
        assert_eq!(result, sql);
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_filter_backend_with_filters() {
        let mut backend = CustomFilterBackend::new();
        backend.add_filter(Box::new(
            SimpleSearchBackend::new("search").with_field("name"),
        ));
        let mut params = HashMap::new();
        params.insert("search".to_string(), "john".to_string());
        let sql = "SELECT * FROM users".to_string();
        let count = backend.filter_count();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert_eq!(count, 1);
        assert!(result.contains("WHERE"));
        assert!(result.contains("`name` LIKE '%john%'"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_search_backend() {
        let backend = SimpleSearchBackend::new("search")
            .with_field("title")
            .with_field("content");
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(result.contains("WHERE"));
        assert!(result.contains("`title` LIKE '%rust%'"));
        assert!(result.contains("`content` LIKE '%rust%'"));
        assert!(result.contains("OR"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_search_backend_no_query() {
        let backend = SimpleSearchBackend::new("search").with_field("title");
        let params = HashMap::new();
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql.clone()).await.unwrap();
        assert_eq!(result, sql);
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_search_backend_no_fields() {
        let backend = SimpleSearchBackend::new("search");
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_search_backend_postgres() {
        let backend = SimpleSearchBackend::new("search")
            .with_field("title")
            .with_dialect(DatabaseDialect::PostgreSQL);
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(result.contains("WHERE"));
        assert!(result.contains("\"title\" LIKE '%rust%'"));
    }

    #[rstest]
    #[case("' OR '1'='1")]
    #[case("'; DROP TABLE articles; --")]
    #[case("' UNION SELECT * FROM users --")]
    #[tokio::test]
    async fn test_simple_search_backend_sql_injection_prevention(#[case] payload: &str) {
        let backend = SimpleSearchBackend::new("search").with_field("title");
        let mut params = HashMap::new();
        params.insert("search".to_string(), payload.to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(
            result.contains("LIKE"),
            "Result should contain LIKE clause for payload: {}",
            payload
        );
        // Balanced single quotes mean the payload stayed inside a string literal.
        let single_quote_count = result.matches('\'').count();
        assert!(
            single_quote_count % 2 == 0,
            "SQL injection vulnerability: unbalanced single quotes in result for payload: {}. Result: {}",
            payload,
            result
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_ordering_backend_asc() {
        let backend = SimpleOrderingBackend::new("ordering")
            .allow_field("created_at")
            .allow_field("title");
        let mut params = HashMap::new();
        params.insert("ordering".to_string(), "created_at".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(result.contains("ORDER BY created_at ASC"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_ordering_backend_desc() {
        let backend = SimpleOrderingBackend::new("ordering")
            .allow_field("created_at")
            .allow_field("title");
        let mut params = HashMap::new();
        params.insert("ordering".to_string(), "-created_at".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(result.contains("ORDER BY created_at DESC"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_ordering_backend_invalid_field() {
        let backend = SimpleOrderingBackend::new("ordering").allow_field("created_at");
        let mut params = HashMap::new();
        params.insert("ordering".to_string(), "invalid_field".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_ordering_backend_no_query() {
        let backend = SimpleOrderingBackend::new("ordering").allow_field("created_at");
        let params = HashMap::new();
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql.clone()).await.unwrap();
        assert_eq!(result, sql);
    }

    #[rstest]
    #[tokio::test]
    async fn test_chained_filters() {
        let mut backend = CustomFilterBackend::new();
        backend.add_filter(Box::new(
            SimpleSearchBackend::new("search").with_field("title"),
        ));
        backend.add_filter(Box::new(
            SimpleOrderingBackend::new("ordering").allow_field("created_at"),
        ));
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        params.insert("ordering".to_string(), "-created_at".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(result.contains("WHERE"));
        assert!(result.contains("`title` LIKE '%rust%'"));
        assert!(result.contains("ORDER BY created_at DESC"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_ordering_preserves_existing_order_by() {
        let backend = SimpleOrderingBackend::new("ordering").allow_field("title");
        let mut params = HashMap::new();
        params.insert("ordering".to_string(), "title".to_string());
        let sql = "SELECT * FROM articles ORDER BY created_at ASC".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(
            result.contains("ORDER BY created_at ASC, title ASC"),
            "Expected existing ORDER BY to be preserved with new criteria appended, got: {}",
            result
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_ordering_preserves_existing_order_by_with_limit() {
        let backend = SimpleOrderingBackend::new("ordering").allow_field("title");
        let mut params = HashMap::new();
        params.insert("ordering".to_string(), "-title".to_string());
        let sql = "SELECT * FROM articles ORDER BY created_at ASC LIMIT 10".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(
            result.contains("ORDER BY created_at ASC, title DESC"),
            "Expected existing ORDER BY to be preserved, got: {}",
            result
        );
        assert!(
            result.contains("LIMIT 10"),
            "Expected LIMIT clause to be preserved, got: {}",
            result
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_search_preserves_existing_where_clause() {
        let backend = SimpleSearchBackend::new("search").with_field("title");
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        let sql = "SELECT * FROM articles WHERE status = 'published'".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(
            result.contains("status = 'published'"),
            "Expected original WHERE condition to be preserved, got: {}",
            result
        );
        assert!(
            result.contains("`title` LIKE '%rust%'"),
            "Expected search condition to be added, got: {}",
            result
        );
        assert!(
            result.contains("AND"),
            "Expected AND joining original and search conditions, got: {}",
            result
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_search_preserves_complex_where_clause() {
        let backend = SimpleSearchBackend::new("search").with_field("name");
        let mut params = HashMap::new();
        params.insert("search".to_string(), "john".to_string());
        let sql = "SELECT * FROM users WHERE (age > 18 AND active = true) ORDER BY id".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(
            result.contains("age > 18 AND active = true"),
            "Expected complex WHERE condition to be preserved, got: {}",
            result
        );
        assert!(
            result.contains("`name` LIKE '%john%'"),
            "Expected search condition to be added, got: {}",
            result
        );
        assert!(
            result.contains("ORDER BY"),
            "Expected ORDER BY clause to be preserved, got: {}",
            result
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_search_adds_where_to_empty_query() {
        let backend = SimpleSearchBackend::new("search").with_field("title");
        let mut params = HashMap::new();
        params.insert("search".to_string(), "rust".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert!(
            result.contains("WHERE"),
            "Expected WHERE clause to be added, got: {}",
            result
        );
        assert!(
            result.contains("`title` LIKE '%rust%'"),
            "Expected search condition, got: {}",
            result
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_ordering_adds_order_by_to_empty_query() {
        let backend = SimpleOrderingBackend::new("ordering").allow_field("created_at");
        let mut params = HashMap::new();
        params.insert("ordering".to_string(), "-created_at".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = backend.filter_queryset(&params, sql).await.unwrap();
        assert_eq!(
            result, "SELECT * FROM articles ORDER BY created_at DESC",
            "Expected ORDER BY to be appended to query without existing ORDER BY"
        );
    }
}