use sea_orm::{DatabaseBackend, sea_query::SimpleExpr};
/// Upper bound, in bytes (`str::len`), on how much of a search query the
/// full-text condition builders in this file embed in the generated SQL.
/// Longer queries are cut down to this size before being escaped.
const MAX_SEARCH_QUERY_LENGTH: usize = 10_000;
/// Escapes the characters that carry special meaning in a SQL `LIKE` pattern.
///
/// Each backslash, `%` and `_` found in `input` is emitted with a backslash in
/// front of it; every other character is copied through unchanged.
fn escape_like_wildcards(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        // The escape character itself must be escaped too, otherwise a
        // user-supplied backslash would swallow the character that follows it.
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Builds the search condition for `query` over the full-text searchable
/// columns declared by the resource type `T`.
///
/// The SQL dialect is chosen from `backend`: PostgreSQL and MySQL each have a
/// dedicated builder, and every other backend goes through the generic
/// fallback builder.
///
/// Returns `None` when `T` declares no searchable columns, or when the chosen
/// builder produces no condition (for example for an empty `query`).
#[must_use]
pub fn build_fulltext_condition<T: crate::traits::CRUDResource>(
    query: &str,
    backend: DatabaseBackend,
) -> Option<SimpleExpr> {
    let searchable = T::fulltext_searchable_columns();
    if searchable.is_empty() {
        return None;
    }

    if matches!(backend, DatabaseBackend::Postgres) {
        return build_postgres_fulltext_condition(query, &searchable);
    }
    if matches!(backend, DatabaseBackend::MySql) {
        return build_mysql_fulltext_condition(query, &searchable);
    }
    build_fallback_fulltext_condition(query, &searchable)
}
/// Builds a case-insensitive substring condition over `columns` for PostgreSQL.
///
/// Every column is cast to `text`, `NULL` is replaced by the empty string, and
/// the pieces are joined with single spaces. The concatenation is then matched
/// with `ILIKE '%<query>%' ESCAPE '\'`.
///
/// Before being embedded in the SQL text the query is:
/// 1. capped at [`MAX_SEARCH_QUERY_LENGTH`] bytes, backing up to the nearest
///    UTF-8 character boundary so the cut never lands inside a character,
/// 2. trimmed of surrounding whitespace,
/// 3. escaped for `LIKE` (`\`, `%`, `_`), and
/// 4. made literal-safe by doubling single quotes.
///
/// Returns `None` when `columns` or `query` is empty.
fn build_postgres_fulltext_condition(
    query: &str,
    columns: &[(&'static str, impl sea_orm::ColumnTrait)],
) -> Option<SimpleExpr> {
    if columns.is_empty() || query.is_empty() {
        return None;
    }

    let concat_sql = columns
        .iter()
        .map(|(name, _)| format!("COALESCE({name}::text, '')"))
        .collect::<Vec<_>>()
        .join(" || ' ' || ");

    // `str::len` counts bytes, so the cap may fall in the middle of a
    // multi-byte character; slicing there would panic. Index 0 is always a
    // boundary, so the loop terminates.
    let mut end = query.len().min(MAX_SEARCH_QUERY_LENGTH);
    while !query.is_char_boundary(end) {
        end -= 1;
    }
    let sanitized_query = query[..end].trim();

    // NOTE(review): quote doubling alone is only sufficient when backslashes are
    // ordinary characters inside string literals (PostgreSQL's
    // `standard_conforming_strings = on`) - confirm for the target deployment.
    let escaped_query = escape_like_wildcards(sanitized_query).replace('\'', "''");
    let search_sql = format!("({concat_sql}) ILIKE '%{escaped_query}%' ESCAPE '\\'");
    Some(SimpleExpr::Custom(search_sql))
}
/// Builds a case-insensitive substring condition over `columns` for MySQL.
///
/// Every column is cast to `CHAR` with `NULL` replaced by the empty string. A
/// single column is used directly; several columns are combined with
/// `CONCAT(..., ' ', ...)`. Both sides are wrapped in `UPPER(...)` and compared
/// with `LIKE ... ESCAPE '\\'`.
///
/// Before being embedded in the SQL text the query is:
/// 1. capped at [`MAX_SEARCH_QUERY_LENGTH`] bytes, backing up to the nearest
///    UTF-8 character boundary so the cut never lands inside a character,
/// 2. trimmed of surrounding whitespace,
/// 3. escaped for `LIKE` (`\`, `%`, `_`),
/// 4. escaped for the string literal: backslashes doubled, single quotes doubled.
///
/// Returns `None` when `columns` or `query` is empty.
fn build_mysql_fulltext_condition(
    query: &str,
    columns: &[(&'static str, impl sea_orm::ColumnTrait)],
) -> Option<SimpleExpr> {
    if columns.is_empty() || query.is_empty() {
        return None;
    }

    let concat_parts: Vec<String> = columns
        .iter()
        .map(|(name, _)| format!("COALESCE(CAST({name} AS CHAR), '')"))
        .collect();
    let concat_sql = match concat_parts.as_slice() {
        [only] => only.clone(),
        parts => format!("CONCAT({})", parts.join(", ' ', ")),
    };

    // `str::len` counts bytes, so the cap may fall in the middle of a
    // multi-byte character; slicing there would panic. Index 0 is always a
    // boundary, so the loop terminates.
    let mut end = query.len().min(MAX_SEARCH_QUERY_LENGTH);
    while !query.is_char_boundary(end) {
        end -= 1;
    }
    let sanitized_query = query[..end].trim();

    // Two escaping layers apply. The LIKE layer prefixes `\`, `%` and `_` with
    // a backslash. The literal layer then doubles every backslash, because
    // MySQL (in its default mode, which the `ESCAPE '\\'` clause below already
    // relies on) consumes one level of backslashes while parsing the string
    // literal. Without the second layer a searched-for `\` collapses to a
    // single backslash and ends up escaping the trailing `%` wildcard.
    let escaped_query = escape_like_wildcards(sanitized_query)
        .replace('\\', "\\\\")
        .replace('\'', "''");
    let search_sql =
        format!("UPPER({concat_sql}) LIKE UPPER('%{escaped_query}%') ESCAPE '\\\\'");
    Some(SimpleExpr::Custom(search_sql))
}
/// Builds a case-insensitive substring condition over `columns` for backends
/// without a dedicated builder.
///
/// Every column is cast to `TEXT` with `NULL` replaced by the empty string, and
/// the pieces are joined with `|| ' ' ||`. Both sides are wrapped in
/// `UPPER(...)` and compared with `LIKE ... ESCAPE '\'`.
///
/// Before being embedded in the SQL text the query is:
/// 1. capped at [`MAX_SEARCH_QUERY_LENGTH`] bytes, backing up to the nearest
///    UTF-8 character boundary so the cut never lands inside a character,
/// 2. trimmed of surrounding whitespace,
/// 3. escaped for `LIKE` (`\`, `%`, `_`), and
/// 4. made literal-safe by doubling single quotes.
///
/// Returns `None` when `columns` or `query` is empty.
fn build_fallback_fulltext_condition(
    query: &str,
    columns: &[(&'static str, impl sea_orm::ColumnTrait)],
) -> Option<SimpleExpr> {
    if columns.is_empty() || query.is_empty() {
        return None;
    }

    // `x || NULL` is NULL in standard SQL, so without COALESCE a single NULL
    // column would make the whole concatenation NULL and the row could never
    // match. This mirrors what the PostgreSQL and MySQL builders do.
    let concat_sql = columns
        .iter()
        .map(|(name, _)| format!("COALESCE(CAST({name} AS TEXT), '')"))
        .collect::<Vec<_>>()
        .join(" || ' ' || ");

    // `str::len` counts bytes, so the cap may fall in the middle of a
    // multi-byte character; slicing there would panic. Index 0 is always a
    // boundary, so the loop terminates.
    let mut end = query.len().min(MAX_SEARCH_QUERY_LENGTH);
    while !query.is_char_boundary(end) {
        end -= 1;
    }
    let sanitized_query = query[..end].trim();

    let escaped_query = escape_like_wildcards(sanitized_query).replace('\'', "''");
    let like_sql = format!("UPPER({concat_sql}) LIKE UPPER('%{escaped_query}%') ESCAPE '\\'");
    Some(SimpleExpr::Custom(like_sql))
}
/// Builds a case-insensitive "contains" condition on the column named `key`.
///
/// The column is referenced through `Expr::col`, so the name goes through the
/// query builder as an identifier rather than being spliced into SQL text.
/// `trimmed_value` has its `LIKE` metacharacters (`\`, `%`, `_`) escaped with a
/// backslash, is uppercased, and is wrapped in `%...%`; it is compared against
/// `UPPER(<column>)`.
///
/// The backslash is declared explicitly as the pattern's escape character.
/// Not every database has a default `LIKE` escape character (SQLite has none),
/// so without the `ESCAPE` clause the escaped wildcards would not be treated
/// as literals on those backends.
///
/// An empty `trimmed_value` yields the pattern `%%`.
#[must_use]
pub fn build_like_condition(key: &str, trimmed_value: &str) -> SimpleExpr {
    use sea_orm::sea_query::{Alias, Expr, ExprTrait, Func, LikeExpr};

    let column = Expr::col(Alias::new(key));
    let escaped_value = escape_like_wildcards(trimmed_value);
    let pattern = format!("%{}%", escaped_value.to_uppercase());
    Func::upper(column).like(LikeExpr::new(pattern).escape('\\'))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the `Debug` rendering of the condition built for `key` / `value`.
    fn debug_of(key: &str, value: &str) -> String {
        format!("{:?}", build_like_condition(key, value))
    }

    // The column name must appear as a `Column(...)` AST node in the expression.
    #[test]
    fn test_column_names_use_expr_col() {
        let sql = debug_of("user_name", "test");
        let has_column_node = sql.contains("Column(") && sql.contains("user_name");
        assert!(
            has_column_node,
            "Column should be wrapped in Column() AST node, got: {sql}"
        );
    }

    #[test]
    fn test_column_names_wrapped_safely() {
        let sql = debug_of("test_column", "value");
        assert!(sql.contains("Column("), "Should use Expr::col() wrapper");
    }

    // Hostile-looking values must still travel as a string `Value` node.
    #[test]
    fn test_search_query_value_safe() {
        let malicious_values = ["'; DROP TABLE users; --", "' OR '1'='1"];
        for malicious_value in malicious_values.iter() {
            let sql = debug_of("title", malicious_value);
            assert!(
                sql.contains("Value(String"),
                "Values should be wrapped safely: {sql}"
            );
        }
    }

    #[test]
    fn test_search_query_length_limit() {
        let very_long_query = "a".repeat(20_000);
        let cut = very_long_query.len().min(MAX_SEARCH_QUERY_LENGTH);
        let sanitized = &very_long_query[..cut];
        assert!(
            sanitized.len() <= MAX_SEARCH_QUERY_LENGTH,
            "Query should be truncated to max length"
        );
    }

    // Table of (input, expected output, failure message), checked in order.
    #[test]
    fn test_wildcard_escaping() {
        let cases = [
            ("test", "test", "Normal text should pass through"),
            ("test%", "test\\%", "% should be escaped"),
            ("test_value", "test\\_value", "_ should be escaped"),
            ("100%", "100\\%", "% in middle should be escaped"),
            ("%_", "\\%\\_", "Both wildcards should be escaped"),
            ("\\", "\\\\", "Backslash should be escaped"),
            (
                "\\%",
                "\\\\\\%",
                "Backslash and % should both be escaped",
            ),
        ];
        for &(input, expected, why) in cases.iter() {
            assert_eq!(escape_like_wildcards(input), expected, "{why}");
        }
    }

    // The `Debug` output escapes the backslash again, hence the doubled
    // backslashes in the expected fragments.
    #[test]
    fn test_like_condition_prevents_wildcard_injection() {
        let sql_percent = debug_of("title", "test%");
        assert!(
            sql_percent.contains("\\\\%"),
            "% should be escaped in SQL: {sql_percent}"
        );

        let sql_underscore = debug_of("title", "test_value");
        assert!(
            sql_underscore.contains("\\\\_"),
            "_ should be escaped in SQL: {sql_underscore}"
        );

        let sql_just_percent = debug_of("title", "%");
        assert!(
            sql_just_percent.contains("\\\\%"),
            "Single % should be escaped: {sql_just_percent}"
        );
    }

    #[test]
    fn test_build_like_condition_empty_value() {
        let sql = debug_of("field", "");
        assert!(sql.contains("field"), "Should include field name");
    }

    #[test]
    fn test_build_like_condition_case_insensitive() {
        let sql = debug_of("title", "TeSt");
        let mentions_upper = sql.contains("Upper") || sql.contains("UPPER");
        assert!(
            mentions_upper,
            "Should use UPPER for case insensitivity: {sql}"
        );
    }

    #[test]
    fn test_build_like_condition_special_chars() {
        let sql = debug_of("title", "test@email.com");
        assert!(sql.contains("title"), "Should handle special characters");
    }

    #[test]
    fn test_like_condition_empty_query_matches_all() {
        let sql = debug_of("field", "");
        let is_match_all = sql.contains("%%") || sql.contains("%\"");
        assert!(is_match_all, "Empty query should produce match-all pattern");
    }

    #[test]
    fn test_like_condition_whitespace_query() {
        let sql = debug_of("field", " ");
        assert!(sql.contains("field"), "Should include field name");
    }

    #[test]
    fn test_like_condition_case_insensitive_pattern() {
        let sql = debug_of("field", "MiXeD CaSe");
        assert!(
            sql.contains("MIXED CASE"),
            "Pattern should be uppercased for case-insensitive match: {sql}"
        );
    }

    #[test]
    fn test_max_search_query_length_constant() {
        let expected_limit: usize = 10_000;
        assert_eq!(
            MAX_SEARCH_QUERY_LENGTH, expected_limit,
            "Max query length should be 10,000"
        );
    }

    #[test]
    fn test_escape_like_wildcards_empty() {
        let escaped = escape_like_wildcards("");
        assert_eq!(escaped, "", "Empty string should pass through");
    }
}