use super::db_sql;
use super::{
CTE, FrameBound, FrameType, Order, QueryBuilder, UnionClause, UnionType, WindowFunction,
WindowFunctionType,
};
use crate::columns::ColumnLike;
use crate::config::DatabaseType;
#[cfg(feature = "fulltext")]
use crate::fulltext::{FullTextSearchBuilder, SearchMode};
use crate::internal::Value;
use crate::model::Model as ModelTrait;
use std::time::Duration;
// Minimal model shared by the query-builder tests in this module. The
// `tideorm::model` attribute binds it to the `query_test_users` table; its two
// fields (`id`, `name`) are the columns that query validation reports as known.
#[tideorm::model(table = "query_test_users")]
struct QueryTestUser {
    // Primary key, marked auto-increment for the model macro.
    #[tideorm(primary_key, auto_increment)]
    id: i64,
    name: String,
}
#[test]
fn test_quote_char() {
    // Postgres and SQLite use ANSI double quotes; the MySQL family uses backticks.
    let cases = [
        (DatabaseType::Postgres, '"'),
        (DatabaseType::MySQL, '`'),
        (DatabaseType::MariaDB, '`'),
        (DatabaseType::SQLite, '"'),
    ];
    for (db, expected) in cases {
        assert_eq!(db_sql::quote_char(db), expected);
    }
}
#[test]
fn test_quote_ident() {
    // Plain identifiers are wrapped in the dialect's quote character.
    for (db, expected) in [
        (DatabaseType::Postgres, "\"column\""),
        (DatabaseType::MySQL, "`column`"),
        (DatabaseType::MariaDB, "`column`"),
        (DatabaseType::SQLite, "\"column\""),
    ] {
        assert_eq!(db_sql::quote_ident(db, "column"), expected);
    }

    // An embedded quote character is escaped by doubling it.
    let pg_escaped = db_sql::quote_ident(DatabaseType::Postgres, "col\"umn");
    assert_eq!(pg_escaped, "\"col\"\"umn\"");
    let mysql_escaped = db_sql::quote_ident(DatabaseType::MySQL, "col`umn");
    assert_eq!(mysql_escaped, "`col``umn`");
}
#[test]
fn test_json_contains_postgres() {
    let payload = r#"{"key": "value"}"#;
    let fragment = db_sql::json_contains(DatabaseType::Postgres, "metadata", payload);
    // Postgres uses the jsonb containment operator on a double-quoted column.
    for needle in ["@>", "\"metadata\""] {
        assert!(fragment.contains(needle));
    }
}
#[test]
fn test_json_contains_mysql() {
    let payload = r#"{"key": "value"}"#;
    // MySQL and MariaDB share the same JSON_CONTAINS rendering.
    for db in [DatabaseType::MySQL, DatabaseType::MariaDB] {
        let fragment = db_sql::json_contains(db, "metadata", payload);
        assert!(fragment.contains("JSON_CONTAINS"));
        assert!(fragment.contains("`metadata`"));
    }
}
#[test]
fn test_json_contains_sqlite() {
    let fragment = db_sql::json_contains(DatabaseType::SQLite, "metadata", "test_value");
    // SQLite has no containment operator, so the rendering goes through json_each.
    let expected_parts = ["json_each", "\"metadata\""];
    assert!(expected_parts.iter().all(|part| fragment.contains(*part)));
}
#[test]
fn test_json_key_exists_postgres() {
    // The key is emitted as a single-quoted literal after the `?` operator.
    let expected = "\"data\" ? 'email'";
    let actual = db_sql::json_key_exists(DatabaseType::Postgres, "data", "email");
    assert_eq!(actual, expected);
}
#[test]
fn test_json_key_exists_mysql() {
    for db in [DatabaseType::MySQL, DatabaseType::MariaDB] {
        let fragment = db_sql::json_key_exists(db, "data", "email");
        // The key becomes a quoted member of a JSON path.
        assert!(fragment.contains("JSON_CONTAINS_PATH"));
        assert!(fragment.contains("$.\"email\""));
    }
}
#[test]
fn test_json_key_exists_sqlite() {
    let fragment = db_sql::json_key_exists(DatabaseType::SQLite, "data", "email");
    for needle in ["json_extract", "$.\"email\"", "IS NOT NULL"] {
        assert!(fragment.contains(needle), "missing {needle:?} in {fragment}");
    }
}
#[test]
fn test_json_path_exists_postgres() {
    // Postgres evaluates JSON paths with the jsonb `@?` operator.
    let fragment = db_sql::json_path_exists(DatabaseType::Postgres, "data", "$.user.name");
    assert!(fragment.contains("@?"), "expected `@?` in {fragment}");
}
#[test]
fn test_json_path_exists_mysql() {
    let path = "$.user.name";
    for db in [DatabaseType::MySQL, DatabaseType::MariaDB] {
        let fragment = db_sql::json_path_exists(db, "data", path);
        // Each dotted segment is re-emitted as a quoted JSON path member.
        assert!(fragment.contains("JSON_CONTAINS_PATH"));
        assert!(fragment.contains("$.\"user\".\"name\""));
    }
}
#[test]
fn test_json_path_exists_sqlite() {
    let fragment = db_sql::json_path_exists(DatabaseType::SQLite, "data", "$.user.name");
    let uses_extract = fragment.contains("json_extract");
    let has_quoted_path = fragment.contains("$.\"user\".\"name\"");
    assert!(uses_extract);
    assert!(has_quoted_path);
}
#[test]
fn test_array_contains_postgres() {
    let roles: Vec<String> = ["'admin'", "'user'"].iter().map(|r| r.to_string()).collect();
    let fragment = db_sql::array_contains(DatabaseType::Postgres, "roles", &roles);
    // Postgres compares against an ARRAY[...] literal with the containment operator.
    assert!(fragment.contains("@>"));
    assert!(fragment.contains("ARRAY["));
}
#[test]
fn test_array_contains_mysql() {
    let roles: Vec<String> = ["'admin'", "'user'"].iter().map(|r| r.to_string()).collect();
    // Both MySQL-family dialects store arrays as JSON documents.
    for db in [DatabaseType::MySQL, DatabaseType::MariaDB] {
        let fragment = db_sql::array_contains(db, "roles", &roles);
        assert!(fragment.contains("JSON_CONTAINS"));
    }
}
#[test]
fn test_array_contains_sqlite() {
    let roles = vec![String::from("'admin'"), String::from("'user'")];
    let fragment = db_sql::array_contains(DatabaseType::SQLite, "roles", &roles);
    // SQLite searches JSON-encoded arrays through json_each.
    assert!(fragment.contains("json_each"));
}
#[test]
fn test_array_overlaps_postgres() {
    let tags: Vec<String> = vec!["'a'".into(), "'b'".into()];
    let fragment = db_sql::array_overlaps(DatabaseType::Postgres, "tags", &tags);
    // Overlap is the `&&` operator applied to an ARRAY[...] literal.
    for needle in ["&&", "ARRAY["] {
        assert!(fragment.contains(needle));
    }
}
#[test]
fn test_array_overlaps_mysql() {
    let tags: Vec<String> = ["'a'", "'b'"].iter().map(|t| (*t).to_owned()).collect();
    for db in [DatabaseType::MySQL, DatabaseType::MariaDB] {
        // With no native overlap operator, each element gets its own OR-ed check.
        assert!(db_sql::array_overlaps(db, "tags", &tags).contains(" OR "));
    }
}
#[test]
fn test_array_overlaps_sqlite() {
    let tags = vec![String::from("'a'"), String::from("'b'")];
    let rendered = db_sql::array_overlaps(DatabaseType::SQLite, "tags", &tags);
    assert!(rendered.contains(" OR "), "per-element checks should be OR-ed: {rendered}");
}
#[test]
fn test_format_column_simple() {
    // A bare column name is wrapped in the dialect's identifier quote. SQLite was
    // previously uncovered here even though it shares Postgres' double-quote style
    // (see test_quote_char / test_quote_ident).
    let cases = [
        (DatabaseType::Postgres, "\"name\""),
        (DatabaseType::MySQL, "`name`"),
        (DatabaseType::MariaDB, "`name`"),
        (DatabaseType::SQLite, "\"name\""),
    ];
    for (db, expected) in cases {
        assert_eq!(db_sql::format_column(db, "name"), expected);
    }
}
#[test]
fn test_format_column_dotted() {
    // `table.column` references are quoted part by part. SQLite was previously
    // uncovered here even though it shares Postgres' double-quote style.
    let cases = [
        (DatabaseType::Postgres, "\"users\".\"name\""),
        (DatabaseType::MySQL, "`users`.`name`"),
        (DatabaseType::MariaDB, "`users`.`name`"),
        (DatabaseType::SQLite, "\"users\".\"name\""),
    ];
    for (db, expected) in cases {
        assert_eq!(db_sql::format_column(db, "users.name"), expected);
    }
}
#[test]
fn test_format_column_expression() {
    // Anything that is not a plain identifier passes through unquoted.
    let expression = "COUNT(*)";
    let rendered = db_sql::format_column(DatabaseType::Postgres, expression);
    assert_eq!(rendered, expression);
}
#[test]
fn test_format_identifier_reference_quotes_reserved_words() {
    // Reserved words are still usable as identifiers once they are quoted.
    let pg = db_sql::format_identifier_reference(DatabaseType::Postgres, "order");
    assert_eq!(pg.as_deref(), Some("\"order\""));

    let mysql = db_sql::format_identifier_reference(DatabaseType::MySQL, "users.group");
    assert_eq!(mysql.as_deref(), Some("`users`.`group`"));
}
#[test]
fn test_cast_to_float() {
    // Each dialect has its own name for a double-precision float type.
    let float_types = [
        (DatabaseType::Postgres, "FLOAT8"),
        (DatabaseType::MySQL, "DOUBLE"),
        (DatabaseType::MariaDB, "DOUBLE"),
        (DatabaseType::SQLite, "REAL"),
    ];
    for (db, float_type) in float_types {
        let expected = format!("CAST(value AS {float_type})");
        assert_eq!(db_sql::cast_to_float(db, "value"), expected.as_str());
    }
}
#[test]
fn test_sql_injection_prevention() {
    // A single quote inside a JSON payload is doubled, so the literal stays closed.
    let contains_sql = db_sql::json_contains(DatabaseType::Postgres, "data", "O'Brien");
    assert!(contains_sql.contains("O''Brien"));

    // A hostile key stays inside the quoted path literal instead of becoming new SQL.
    let hostile_key = "key'; DROP TABLE--";
    let expected = "JSON_CONTAINS_PATH(`data`, 'one', '$.\"key''; DROP TABLE--\"')";
    for db in [DatabaseType::MySQL, DatabaseType::MariaDB] {
        assert_eq!(db_sql::json_key_exists(db, "data", hostile_key), expected);
    }
}
#[test]
fn test_json_path_injection_is_rejected_for_mysql_and_sqlite() {
    let path = "$.user') OR 1=1 --";
    // A path that cannot be parsed safely collapses to an always-false predicate.
    let always_false = "0 = 1";

    let mysql_exists = db_sql::json_path_exists(DatabaseType::MySQL, "data", path);
    let mysql_missing = db_sql::json_path_not_exists(DatabaseType::MySQL, "data", path);
    let sqlite_exists = db_sql::json_path_exists(DatabaseType::SQLite, "data", path);
    let sqlite_missing = db_sql::json_path_not_exists(DatabaseType::SQLite, "data", path);

    assert_eq!(mysql_exists, always_false);
    assert_eq!(mysql_missing, always_false);
    assert_eq!(sqlite_exists, always_false);
    assert_eq!(sqlite_missing, always_false);
}
#[test]
fn test_json_path_special_keys_are_quoted_safely() {
    // A bracketed key containing a dot stays one quoted member instead of being split.
    let bracketed_path = "$['weird.key'][0].name";
    let expected = "JSON_CONTAINS_PATH(`data`, 'one', '$.\"weird.key\"[0].\"name\"')";
    let actual = db_sql::json_path_exists(DatabaseType::MySQL, "data", bracketed_path);
    assert_eq!(actual, expected);
}
#[test]
fn test_mysql_array_literals_are_json_encoded() {
    let roles: Vec<String> = ["'ad\"min'", "'slash\\user'"]
        .iter()
        .map(ToString::to_string)
        .collect();
    // The SQL-quoted inputs are re-encoded as JSON strings, escaping `"` and `\`.
    let expected_contains = "JSON_CONTAINS(`roles`, '[\"ad\\\"min\",\"slash\\\\user\"]')";
    let expected_overlaps =
        "(JSON_CONTAINS(`roles`, '\"ad\\\"min\"') OR JSON_CONTAINS(`roles`, '\"slash\\\\user\"'))";

    let contains_sql = db_sql::array_contains(DatabaseType::MySQL, "roles", &roles);
    assert_eq!(contains_sql, expected_contains);
    let overlaps_sql = db_sql::array_overlaps(DatabaseType::MySQL, "roles", &roles);
    assert_eq!(overlaps_sql, expected_overlaps);
}
#[test]
fn test_join_identifier_validation_accepts_safe_values() {
    let table = db_sql::validate_identifier("JOIN table", "users");
    let alias = db_sql::validate_identifier("JOIN alias", "author_1");
    let column = db_sql::validate_join_column("posts.user_id");
    assert!(table.is_ok(), "plain table names are valid");
    assert!(alias.is_ok(), "alphanumeric aliases are valid");
    assert!(column.is_ok(), "qualified column references are valid");
}
#[test]
fn test_join_identifier_validation_rejects_injection() {
    // Each rejection message names the kind of JOIN identifier that failed.
    let table_err = db_sql::validate_identifier("JOIN table", "users; DROP TABLE users; --")
        .expect_err("a stacked statement is not a table name");
    let alias_err = db_sql::validate_identifier("JOIN alias", "author --")
        .expect_err("a trailing comment is not an alias");
    let column_err = db_sql::validate_join_column("posts.user_id OR 1=1")
        .expect_err("a boolean expression is not a column reference");

    assert!(table_err.contains("unsafe JOIN table"));
    assert!(alias_err.contains("unsafe JOIN alias"));
    assert!(column_err.contains("unsafe JOIN column reference"));
}
#[test]
fn test_raw_sql_fragment_validation_rejects_injection_tokens() {
    let context = "WHERE raw SQL";
    // Statement separators and line comments are both refused.
    for fragment in ["1 = 1; DROP TABLE users", "1 = 1 -- comment"] {
        let message = db_sql::validate_raw_sql_fragment(context, fragment).unwrap_err();
        assert!(message.contains("unsafe WHERE raw SQL"), "unexpected error: {message}");
    }
}
#[test]
fn test_subquery_validation_rejects_non_select_sql() {
    let outcome = db_sql::validate_subquery_sql("DELETE FROM users");
    let message = outcome.expect_err("non-SELECT subqueries must be rejected");
    assert!(message.contains("unsafe subquery"));
}
#[tokio::test]
async fn test_where_raw_rejects_unsafe_sql_before_db_lookup() {
    // Stacked statements are refused during validation, so no database is needed.
    let query = QueryTestUser::query().where_raw("1 = 1; DROP TABLE users");
    let message = match query.count().await {
        Ok(_) => panic!("where_raw() with a stacked statement must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(message.contains("unsafe WHERE raw SQL"), "unexpected error: {message}");
}
#[tokio::test]
async fn test_or_where_raw_rejects_unsafe_sql_before_db_lookup() {
    let grouped = QueryTestUser::query()
        .begin_or()
        .or_where_raw("1 = 1; DROP TABLE users")
        .end_or();
    let outcome = grouped.count().await;
    let err = outcome.expect_err("or_where_raw() with a stacked statement must not succeed");
    let message = err.to_string();
    assert!(message.contains("unsafe WHERE raw SQL"), "unexpected error: {message}");
}
#[tokio::test]
async fn test_having_rejects_unsafe_sql_before_db_lookup() {
    let grouped = QueryTestUser::query().group_by("id");
    let query = grouped.having("COUNT(*) > 0; DROP TABLE users");
    let message = match query.count().await {
        Ok(_) => panic!("having() with a stacked statement must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(message.contains("unsafe HAVING raw SQL"), "unexpected error: {message}");
}
#[tokio::test]
async fn test_select_raw_rejects_unsafe_sql_before_db_lookup() {
    let query = QueryTestUser::query().select_raw("id; DROP TABLE users");
    let err = query
        .count()
        .await
        .expect_err("select_raw() with a stacked statement must not succeed");
    let message = err.to_string();
    assert!(message.contains("unsafe SELECT raw SQL"), "unexpected error: {message}");
}
#[tokio::test]
async fn test_where_in_subquery_rejects_invalid_nested_query_before_db_lookup() {
    let outer = QueryTestUser::query();
    // The nested query is itself invalid; the error must surface from the outer call.
    let nested = QueryTestUser::query()
        .select(vec!["id"])
        .where_raw("1 = 1; DROP TABLE users");
    let query = outer.where_in_subquery("id", nested);
    let message = match query.count().await {
        Ok(_) => panic!("an invalid nested query must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(
        message.contains("invalid subquery for where_in_subquery()"),
        "unexpected error: {message}"
    );
}
#[tokio::test]
async fn test_select_subquery_rejects_invalid_nested_query_before_db_lookup() {
    let outer = QueryTestUser::query();
    let nested = QueryTestUser::query()
        .select(vec!["id"])
        .where_raw("1 = 1; DROP TABLE users");
    let query = outer.select_subquery(nested, "nested_id");
    let outcome = query.count().await;
    let message = outcome
        .expect_err("an invalid nested query must not succeed")
        .to_string();
    assert!(
        message.contains("invalid subquery for select_subquery()"),
        "unexpected error: {message}"
    );
}
#[tokio::test]
async fn test_select_subquery_rejects_unsafe_alias_before_db_lookup() {
    let outer = QueryTestUser::query();
    let nested = QueryTestUser::query().select(vec!["id"]);
    // A quote character inside the alias would break out of the quoted identifier.
    let query = outer.select_subquery(nested, "bad\"alias");
    let message = match query.count().await {
        Ok(_) => panic!("an alias containing a quote must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(message.contains("unsafe SELECT alias"), "unexpected error: {message}");
}
#[tokio::test]
async fn test_union_raw_rejects_invalid_subquery_before_db_lookup() {
    let query = QueryTestUser::query().union_raw("DELETE FROM users");
    let message = match query.count().await {
        Ok(_) => panic!("union_raw() with a DELETE must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(
        message.contains("invalid subquery for union_raw()"),
        "unexpected error: {message}"
    );
}
#[tokio::test]
async fn test_union_all_raw_rejects_invalid_subquery_before_db_lookup() {
    let query = QueryTestUser::query().union_all_raw("DELETE FROM users");
    let err = query
        .count()
        .await
        .expect_err("union_all_raw() with a DELETE must not succeed");
    let message = err.to_string();
    assert!(
        message.contains("invalid subquery for union_all_raw()"),
        "unexpected error: {message}"
    );
}
#[tokio::test]
async fn test_with_cte_rejects_non_select_sql_before_db_lookup() {
    let destructive = CTE::new(
        "active_users",
        "DELETE FROM users RETURNING id".to_string(),
    );
    let query = QueryTestUser::query().with_cte(destructive);
    let message = match query.count().await {
        Ok(_) => panic!("a data-modifying CTE must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(message.contains("invalid CTE for with_cte()"), "unexpected error: {message}");
}
#[tokio::test]
async fn test_with_cte_columns_rejects_non_select_sql_before_db_lookup() {
    let destructive_body = "DELETE FROM users RETURNING id";
    let query = QueryTestUser::query().with_cte_columns("active_users", vec!["id"], destructive_body);
    let outcome = query.count().await;
    let message = outcome
        .expect_err("a data-modifying CTE body must not succeed")
        .to_string();
    assert!(
        message.contains("invalid subquery for with_cte_columns()"),
        "unexpected error: {message}"
    );
}
#[tokio::test]
async fn test_with_recursive_cte_rejects_non_select_sql_before_db_lookup() {
    let anchor = "SELECT id FROM users";
    // Only the recursive half is destructive; the error must point at it specifically.
    let recursive_part = "DELETE FROM users RETURNING id";
    let query = QueryTestUser::query().with_recursive_cte("user_tree", vec!["id"], anchor, recursive_part);
    let message = match query.count().await {
        Ok(_) => panic!("a data-modifying recursive query must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(
        message.contains("invalid subquery for with_recursive_cte() recursive query"),
        "unexpected error: {message}"
    );
}
#[tokio::test]
async fn test_lag_rejects_unsafe_default_expression_before_db_lookup() {
    // The default expression is raw SQL, so a stacked statement must be caught.
    let unsafe_default = Some("0; DROP TABLE users");
    let query = QueryTestUser::query().lag(
        "previous_id",
        "id",
        1,
        unsafe_default,
        "name",
        "id",
        Order::Asc,
    );
    let message = match query.count().await {
        Ok(_) => panic!("an unsafe LAG default must not succeed"),
        Err(err) => err.to_string(),
    };
    assert!(
        message.contains("unsafe LAG/LEAD default expression"),
        "unexpected error: {message}"
    );
}
#[tokio::test]
async fn test_custom_window_expression_rejects_unsafe_sql_before_db_lookup() {
    let injected = WindowFunctionType::Custom("SUM(id); DROP TABLE users".to_string());
    let window = WindowFunction::new(injected, "unsafe_sum");
    let query = QueryTestUser::query().window(window);
    let err = query
        .count()
        .await
        .expect_err("a custom window expression with a stacked statement must not succeed");
    let message = err.to_string();
    assert!(
        message.contains("unsafe window function expression"),
        "unexpected error: {message}"
    );
}
#[test]
fn test_union_type_sql() {
    for (kind, keyword) in [(UnionType::Union, "UNION"), (UnionType::UnionAll, "UNION ALL")] {
        assert_eq!(kind.as_sql(), keyword);
    }
}
#[test]
fn test_union_clause_creation() {
    let branch_sql = "SELECT * FROM users WHERE active = true".to_string();
    let clause = UnionClause {
        query_sql: branch_sql,
        union_type: UnionType::Union,
    };
    assert_eq!(clause.union_type, UnionType::Union);
    assert!(clause.query_sql.contains("active = true"));
}
#[test]
fn test_frame_bound_sql() {
    let expectations = [
        (FrameBound::UnboundedPreceding, "UNBOUNDED PRECEDING"),
        (FrameBound::UnboundedFollowing, "UNBOUNDED FOLLOWING"),
        (FrameBound::CurrentRow, "CURRENT ROW"),
        (FrameBound::Preceding(5), "5 PRECEDING"),
        (FrameBound::Following(3), "3 FOLLOWING"),
    ];
    for (bound, expected) in expectations {
        assert_eq!(bound.as_sql(), expected);
    }
}
#[test]
fn test_frame_type_sql() {
    let expectations = [
        (FrameType::Rows, "ROWS"),
        (FrameType::Range, "RANGE"),
        (FrameType::Groups, "GROUPS"),
    ];
    for (frame, keyword) in expectations {
        assert_eq!(frame.as_sql(), keyword);
    }
}
#[test]
fn test_window_function_type_row_number() {
    let function = WindowFunctionType::RowNumber;
    let expected = "ROW_NUMBER()";
    assert_eq!(function.as_sql(), expected);
}
#[test]
fn test_window_function_type_rank() {
    let function = WindowFunctionType::Rank;
    let expected = "RANK()";
    assert_eq!(function.as_sql(), expected);
}
#[test]
fn test_window_function_type_dense_rank() {
    let function = WindowFunctionType::DenseRank;
    let expected = "DENSE_RANK()";
    assert_eq!(function.as_sql(), expected);
}
#[test]
fn test_window_function_type_ntile() {
    // NTILE carries its bucket count through into the SQL.
    let quartiles = WindowFunctionType::Ntile(4);
    assert_eq!(quartiles.as_sql(), "NTILE(4)");
}
#[test]
fn test_window_function_type_lag() {
    let default_value = Some("0".to_string());
    let sql = WindowFunctionType::Lag("price".to_string(), Some(1), default_value).as_sql();
    for fragment in ["LAG", "\"price\"", "1"] {
        assert!(sql.contains(fragment), "missing {fragment:?} in {sql}");
    }
}
#[test]
fn test_window_function_type_lead() {
    let column = "date".to_string();
    let sql = WindowFunctionType::Lead(column, Some(1), None).as_sql();
    assert!(sql.contains("LEAD"), "sql: {sql}");
    assert!(sql.contains("\"date\""), "sql: {sql}");
}
#[test]
fn test_window_function_type_first_value() {
    let column = "amount".to_string();
    let rendered = WindowFunctionType::FirstValue(column).as_sql();
    assert_eq!(rendered, "FIRST_VALUE(\"amount\")");
}
#[test]
fn test_window_function_type_last_value() {
    let column = "total".to_string();
    let rendered = WindowFunctionType::LastValue(column).as_sql();
    assert_eq!(rendered, "LAST_VALUE(\"total\")");
}
#[test]
fn test_window_function_type_sum() {
    let column = "amount".to_string();
    let rendered = WindowFunctionType::Sum(column).as_sql();
    assert_eq!(rendered, "SUM(\"amount\")");
}
#[test]
fn test_window_function_type_count() {
    // Without a column COUNT falls back to `*`; with one it quotes the column.
    let star = WindowFunctionType::Count(None).as_sql();
    let by_id = WindowFunctionType::Count(Some("id".to_string())).as_sql();
    assert_eq!(star, "COUNT(*)");
    assert_eq!(by_id, "COUNT(\"id\")");
}
#[test]
fn test_window_function_basic() {
    let row_number = WindowFunction::new(WindowFunctionType::RowNumber, "row_num");
    let sql = row_number.to_sql();
    for expected in ["ROW_NUMBER()", "OVER", "AS \"row_num\""] {
        assert!(sql.contains(expected), "missing {expected:?} in {sql}");
    }
}
#[test]
fn test_window_function_with_partition() {
    let partitioned =
        WindowFunction::new(WindowFunctionType::RowNumber, "row_num").partition_by("category");
    let rendered = partitioned.to_sql();
    assert!(rendered.contains("PARTITION BY \"category\""), "rendered: {rendered}");
}
#[test]
fn test_window_function_with_order() {
    let ranked = WindowFunction::new(WindowFunctionType::Rank, "rank").order_by("score", Order::Desc);
    let rendered = ranked.to_sql();
    assert!(rendered.contains("ORDER BY \"score\" DESC"), "rendered: {rendered}");
}
#[test]
fn test_window_function_with_frame() {
    let summed = WindowFunctionType::Sum("amount".to_string());
    let running_total = WindowFunction::new(summed, "running_total")
        .order_by("date", Order::Asc)
        .frame(FrameType::Rows, FrameBound::UnboundedPreceding, FrameBound::CurrentRow);
    let rendered = running_total.to_sql();
    assert!(
        rendered.contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"),
        "rendered: {rendered}"
    );
}
#[test]
fn test_window_function_full() {
    let summed = WindowFunctionType::Sum("sales".to_string());
    let sql = WindowFunction::new(summed, "total_sales")
        .partition_by("region")
        .order_by("month", Order::Asc)
        .frame(FrameType::Range, FrameBound::UnboundedPreceding, FrameBound::CurrentRow)
        .to_sql();
    let expected_parts = [
        "SUM(\"sales\")",
        "PARTITION BY \"region\"",
        "ORDER BY \"month\" ASC",
        "RANGE BETWEEN",
        "AS \"total_sales\"",
    ];
    let missing: Vec<&str> = expected_parts
        .iter()
        .copied()
        .filter(|part| !sql.contains(*part))
        .collect();
    assert!(missing.is_empty(), "missing {missing:?} in {sql}");
}
#[test]
fn test_cte_basic() {
    let body = "SELECT * FROM users WHERE active = true".to_string();
    let cte = CTE::new("active_users", body);
    let sql = cte.to_sql();
    // Quoted name, then `AS (` followed by the body.
    for needle in ["\"active_users\"", "AS (", "active = true"] {
        assert!(sql.contains(needle), "sql: {sql}");
    }
}
#[test]
fn test_cte_with_columns() {
    let column_names = vec!["user_id", "total", "count"];
    let body = "SELECT user_id, SUM(amount), COUNT(*) FROM orders GROUP BY user_id".to_string();
    let cte = CTE::with_columns("user_stats", column_names, body);
    let sql = cte.to_sql();
    assert!(sql.contains("\"user_stats\""));
    // The explicit column list is rendered as quoted identifiers.
    assert!(sql.contains("(\"user_id\", \"total\", \"count\")"));
    assert!(sql.contains("GROUP BY"));
}
#[test]
fn test_cte_recursive() {
    let base = CTE::new("tree", "SELECT 1 UNION ALL SELECT 2".to_string());
    let recursive_cte = base.recursive();
    assert!(recursive_cte.recursive, "recursive() should set the recursive flag");
}
#[test]
fn test_cte_name_quoting() {
    let body = "SELECT 1".to_string();
    let cte = CTE::new("my_cte", body);
    let rendered = cte.to_sql();
    assert!(rendered.starts_with("\"my_cte\""), "rendered: {rendered}");
}
#[test]
fn test_build_where_sql_includes_or_groups() {
    let expected = "\"status\" = 'active' AND (\"role\" = 'admin' OR \"role\" = 'moderator')";
    let query = QueryBuilder::<QueryTestUser>::new()
        .where_eq("status", "active")
        .or_where(|group| group.where_eq("role", "admin").where_eq("role", "moderator"));
    // The OR group is AND-ed onto the plain filters as one parenthesized clause.
    let actual = query.build_where_sql_for_db(DatabaseType::Postgres);
    assert_eq!(actual, expected);
}
#[test]
fn test_build_where_sql_includes_typed_columns_in_or_groups() {
    let expected = "(\"name\" = 'alice' OR \"id\" = 7)";
    let query = QueryBuilder::<QueryTestUser>::new().or_where(|group| {
        group
            .where_eq(QueryTestUser::columns.name, "alice")
            .where_eq(QueryTestUser::columns.id, 7)
    });
    let actual = query.build_where_sql_for_db(DatabaseType::Postgres);
    assert_eq!(actual, expected);
}
#[test]
fn test_begin_or_where_eq_accepts_typed_columns() {
    // The explicit OR group wraps its AND-ed members in an extra pair of parentheses.
    let expected = "((\"name\" = 'alice' AND \"id\" = 7))";
    let query = QueryBuilder::<QueryTestUser>::new()
        .begin_or_where_eq(QueryTestUser::columns.name, "alice")
        .and_where_eq(QueryTestUser::columns.id, 7)
        .end_or();
    let actual = query.build_where_sql_for_db(DatabaseType::Postgres);
    assert_eq!(actual, expected);
}
#[test]
fn test_build_where_sql_escapes_inner_quotes_in_column_names() {
    let nested_column = "profile.na\"me";
    let query = QueryBuilder::<QueryTestUser>::new().where_eq(nested_column, "active");
    // Each dotted part is quoted separately and the embedded quote is doubled.
    let expected = "\"profile\".\"na\"\"me\" = 'active'";
    assert_eq!(query.build_where_sql_for_db(DatabaseType::Postgres), expected);
}
#[test]
fn test_query_validation_rejects_unknown_model_column_in_where() {
    let message = QueryBuilder::<QueryTestUser>::new()
        .where_eq("naem", "alice")
        .ensure_query_is_valid()
        .expect_err("unknown where column should invalidate query")
        .to_string();
    // The error names the bad column and lists the model's real columns.
    for fragment in ["unknown WHERE column 'naem'", "known columns: id, name"] {
        assert!(message.contains(fragment), "unexpected error: {message}");
    }
}
#[test]
fn test_query_validation_rejects_unknown_self_qualified_column() {
    // Qualifying with the model's own table name must not bypass the column check.
    let qualified = "query_test_users.naem";
    let err = QueryBuilder::<QueryTestUser>::new()
        .where_eq(qualified, "alice")
        .ensure_query_is_valid()
        .expect_err("unknown self-qualified column should invalidate query");
    let expected = format!("unknown WHERE column '{qualified}'");
    assert!(err.to_string().contains(expected.as_str()));
}
#[test]
fn test_query_validation_allows_joined_table_column_references() {
    let query = QueryBuilder::<QueryTestUser>::new()
        .inner_join("profiles", "query_test_users.id", "profiles.user_id")
        .where_eq("profiles.active", true)
        .order_by("profiles.created_at", Order::Desc);
    // Columns of a joined table are outside the model's schema and must not be flagged.
    if let Err(err) = query.ensure_query_is_valid() {
        panic!("joined-table column references should remain allowed: {err}");
    }
}
#[test]
fn test_query_validation_rejects_unknown_model_column_in_order_by() {
    let column = "naem";
    let err = QueryBuilder::<QueryTestUser>::new()
        .order_by(column, Order::Asc)
        .ensure_query_is_valid()
        .expect_err("unknown order-by column should invalidate query");
    let expected = format!("unknown ORDER BY column '{column}'");
    assert!(err.to_string().contains(expected.as_str()));
}
#[test]
fn test_build_select_sql_with_params_parameterizes_read_filters() {
    let (sql, params) = QueryBuilder::<QueryTestUser>::new()
        .select_raw("COUNT(*) as total")
        .where_eq("status", "active")
        .or_where(|group| group.where_eq("role", "admin").where_eq("role", "moderator"))
        .where_in("department", vec!["engineering", "design"])
        .order_by("name", Order::Asc)
        .limit(10)
        .build_select_sql_with_params_for_db(DatabaseType::Postgres);
    // Plain filters are numbered first, then the OR group; LIMIT stays inline.
    let expected_sql = "SELECT COUNT(*) as total FROM \"query_test_users\" WHERE \"status\" = $1 AND \"department\" IN ($2, $3) AND (\"role\" = $4 OR \"role\" = $5) ORDER BY \"name\" ASC LIMIT 10";
    assert_eq!(sql, expected_sql);
    assert_eq!(params.len(), 5, "one bound value per filter literal");
}
#[test]
fn test_build_select_sql_with_params_uses_mysql_identifier_quoting() {
    let (sql, params) = QueryBuilder::<QueryTestUser>::new()
        .select(vec!["id", "name"])
        .where_eq("status", "active")
        .order_by("name", Order::Asc)
        .limit(5)
        .build_select_sql_with_params_for_db(DatabaseType::MySQL);
    // Selected columns are table-qualified; placeholders are positional `?`.
    let expected_sql = "SELECT `query_test_users`.`id`, `query_test_users`.`name` FROM `query_test_users` WHERE `status` = ? ORDER BY `name` ASC LIMIT 5";
    assert_eq!(sql, expected_sql);
    assert_eq!(params.len(), 1);
}
#[test]
fn test_build_select_sql_with_params_inlines_postgres_array_predicates() {
    let values = || vec!["ops'", "core"];
    let query = QueryBuilder::<QueryTestUser>::new()
        .where_array_contains("tags", values())
        .where_array_contained_by("tags", values())
        .where_array_overlaps("tags", values());
    let (sql, params) = query.build_select_sql_with_params_for_db(DatabaseType::Postgres);
    // Array literals are inlined with quote-doubling rather than bound as parameters.
    for operator in ["@>", "<@", "&&"] {
        let fragment = format!("\"tags\" {operator} ARRAY['ops''','core']");
        assert!(sql.contains(fragment.as_str()), "sql: {sql}");
    }
    assert!(!sql.contains("$1"), "sql: {sql}");
    assert!(params.is_empty(), "params: {params:?}");
}
#[test]
fn test_build_select_sql_with_params_inlines_postgres_integer_array_predicates() {
    let (sql, params) = QueryBuilder::<QueryTestUser>::new()
        .where_array_contains("scores", vec![5, 7])
        .where_array_overlaps("scores", vec![3, 5])
        .build_select_sql_with_params_for_db(DatabaseType::Postgres);
    // Integer arrays are inlined unquoted and produce no bound parameters.
    for fragment in ["\"scores\" @> ARRAY[5,7]", "\"scores\" && ARRAY[3,5]"] {
        assert!(sql.contains(fragment), "sql: {sql}");
    }
    assert!(!sql.contains("$1"), "sql: {sql}");
    assert!(params.is_empty(), "params: {params:?}");
}
#[test]
fn test_build_select_sql_with_params_parameterizes_postgres_json_predicates() {
    let (sql, params) = QueryBuilder::<QueryTestUser>::new()
        .where_json_contains("data", serde_json::json!({"role": "admin'"}))
        .where_json_key_exists("data", "unsafe'key")
        .where_json_path_exists("data", "$.user.name")
        .build_select_sql_with_params_for_db(DatabaseType::Postgres);
    for clause in ["\"data\" @> $1", "\"data\" ? $2", "\"data\" @? ($3::jsonpath)"] {
        assert!(sql.contains(clause), "sql: {sql}");
    }
    // None of the user-supplied values may leak into the SQL text.
    for raw_input in ["admin'", "unsafe'key"] {
        assert!(!sql.contains(raw_input), "sql: {sql}");
    }
    assert_eq!(params.len(), 3);
}
#[test]
fn test_build_select_sql_with_params_parameterizes_mysql_json_predicates() {
    let (sql, params) = QueryBuilder::<QueryTestUser>::new()
        .where_json_contains("data", serde_json::json!({"role": "admin'"}))
        .where_json_key_exists("data", "unsafe'key")
        .where_json_path_exists("data", "$.user.name")
        .build_select_sql_with_params_for_db(DatabaseType::MySQL);
    for clause in [
        "JSON_CONTAINS(`data`, CAST(? AS JSON))",
        "JSON_CONTAINS_PATH(`data`, 'one', ?)",
    ] {
        assert!(sql.contains(clause), "sql: {sql}");
    }
    for raw_input in ["admin'", "unsafe'key"] {
        assert!(!sql.contains(raw_input), "sql: {sql}");
    }
    assert_eq!(params.len(), 3);
    // The JSON document is serialized compactly before it is bound.
    let first_param = params.first();
    assert!(
        matches!(first_param, Some(Value::String(Some(json))) if json == "{\"role\":\"admin'\"}")
    );
}
#[test]
fn test_build_select_sql_with_params_parameterizes_sqlite_json_predicates() {
    let (sql, params) = QueryBuilder::<QueryTestUser>::new()
        .where_json_contains("data", serde_json::json!("admin'"))
        .where_json_path_exists("data", "$.user.name")
        .build_select_sql_with_params_for_db(DatabaseType::SQLite);
    let contains_clause = "EXISTS (SELECT 1 FROM json_each(\"data\") WHERE value = ?)";
    let path_clause = "json_extract(\"data\", ?) IS NOT NULL";
    assert!(sql.contains(contains_clause), "sql: {sql}");
    assert!(sql.contains(path_clause), "sql: {sql}");
    assert!(!sql.contains("admin'"), "sql: {sql}");
    assert_eq!(params.len(), 2);
    // A scalar JSON string is bound as its plain text value.
    let first_param = params.first();
    assert!(matches!(first_param, Some(Value::String(Some(value))) if value == "admin'"));
}
#[test]
fn test_build_select_sql_with_params_parameterizes_mysql_array_predicates() {
    let query = QueryBuilder::<QueryTestUser>::new()
        .where_array_contains("tags", vec!["ops'", "core"])
        .where_array_overlaps("tags", vec!["ops'", "core"]);
    let (sql, params) = query.build_select_sql_with_params_for_db(DatabaseType::MySQL);

    let contains_clause = "JSON_CONTAINS(`tags`, CAST(? AS JSON))";
    let overlaps_clause = format!("({contains_clause} OR {contains_clause})");
    assert!(sql.contains(contains_clause), "sql: {sql}");
    assert!(sql.contains(overlaps_clause.as_str()), "sql: {sql}");
    assert!(!sql.contains("ops'"), "sql: {sql}");

    // Contains binds the whole array; overlaps binds one JSON scalar per element.
    let expected_json = ["[\"ops'\",\"core\"]", "\"ops'\"", "\"core\""];
    assert_eq!(params.len(), expected_json.len());
    for (index, json_text) in expected_json.iter().copied().enumerate() {
        assert!(
            matches!(params.get(index), Some(Value::String(Some(json))) if json == json_text),
            "unexpected param {index}: {params:?}"
        );
    }
}
#[test]
fn test_build_select_sql_with_params_parameterizes_sqlite_array_predicates() {
    let query = QueryBuilder::<QueryTestUser>::new()
        .where_array_contained_by("tags", vec!["ops'", "core"])
        .where_array_overlaps("tags", vec!["ops'", "core"]);
    let (sql, params) = query.build_select_sql_with_params_for_db(DatabaseType::SQLite);

    let contained_by_clause =
        "NOT EXISTS (SELECT 1 FROM json_each(\"tags\") WHERE value NOT IN (?, ?))";
    let exists_clause = "EXISTS (SELECT 1 FROM json_each(\"tags\") WHERE value = ?)";
    let overlaps_clause = format!("({exists_clause} OR {exists_clause})");
    assert!(sql.contains(contained_by_clause), "sql: {sql}");
    assert!(sql.contains(overlaps_clause.as_str()), "sql: {sql}");
    assert!(!sql.contains("ops'"), "sql: {sql}");
    assert_eq!(params.len(), 4);
}
#[test]
fn test_build_select_sql_with_params_quotes_reserved_identifiers() {
    // `order` and `group` are SQL keywords and must come out quoted in every clause.
    let base = QueryBuilder::<QueryTestUser>::new()
        .select(vec!["order as group"])
        .where_eq("group", "active")
        .group_by("group")
        .order_by("order", Order::Desc)
        .limit(5);

    let expected_postgres = "SELECT \"query_test_users\".\"order\" AS \"group\" FROM \"query_test_users\" WHERE \"group\" = $1 GROUP BY \"group\" ORDER BY \"order\" DESC LIMIT 5";
    let expected_mysql = "SELECT `query_test_users`.`order` AS `group` FROM `query_test_users` WHERE `group` = ? GROUP BY `group` ORDER BY `order` DESC LIMIT 5";

    let (pg_sql, pg_params) = base
        .clone()
        .build_select_sql_with_params_for_db(DatabaseType::Postgres);
    assert_eq!(pg_sql, expected_postgres);
    assert_eq!(pg_params.len(), 1);

    let (my_sql, my_params) = base.build_select_sql_with_params_for_db(DatabaseType::MySQL);
    assert_eq!(my_sql, expected_mysql);
    assert_eq!(my_params.len(), 1);
}
#[test]
fn test_build_select_sql_with_params_uses_escape_clause_for_typed_literal_like_helpers() {
    let name_column = crate::columns::Column::<String>::new("name");
    let query =
        QueryBuilder::<QueryTestUser>::new().where_col(name_column.contains(r"100%_\done"));
    // User-supplied `%`, `_` and `\` are escaped, then wrapped for a substring match.
    let expected_pattern = r"%100\%\_\\done%";

    let (postgres_sql, postgres_params) = query
        .clone()
        .build_select_sql_with_params_for_db(DatabaseType::Postgres);
    for needle in [" LIKE ", "ESCAPE '\\'", "$1"] {
        assert!(postgres_sql.contains(needle), "postgres sql: {postgres_sql}");
    }
    assert_eq!(postgres_params.len(), 1, "postgres params: {postgres_params:?}");
    assert!(matches!(
        postgres_params.first(),
        Some(Value::String(Some(value))) if value == expected_pattern
    ));

    // MySQL needs the backslash itself escaped inside the ESCAPE literal.
    let (mysql_sql, mysql_params) = query.build_select_sql_with_params_for_db(DatabaseType::MySQL);
    for needle in [" LIKE ", "ESCAPE '\\\\'", "?"] {
        assert!(mysql_sql.contains(needle), "mysql sql: {mysql_sql}");
    }
    assert_eq!(mysql_params.len(), 1, "mysql params: {mysql_params:?}");
    assert!(matches!(
        mysql_params.first(),
        Some(Value::String(Some(value))) if value == expected_pattern
    ));
}
#[test]
fn test_build_select_sql_with_params_uses_escape_clause_for_query_contains_helpers() {
    let (sql, params) = QueryBuilder::<QueryTestUser>::new()
        .where_contains("name", r"100%_\done")
        .or_where_starts_with("name", r"lead%_")
        .begin_or_where_ends_with("name", r"tail%_")
        .end_or()
        .build_select_sql_with_params_for_db(DatabaseType::Postgres);

    // Every LIKE placeholder carries an explicit ESCAPE clause.
    for placeholder in ["$1", "$2", "$3"] {
        let clause = format!("{placeholder} ESCAPE '\\'");
        assert!(sql.contains(clause.as_str()), "sql: {sql}");
    }

    // contains -> %x%, starts_with -> x%, ends_with -> %x, all with wildcards escaped.
    let expected_patterns = [r"%100\%\_\\done%", r"lead\%\_%", r"%tail\%\_"];
    assert_eq!(params.len(), expected_patterns.len(), "params: {params:?}");
    for (index, pattern) in expected_patterns.iter().copied().enumerate() {
        assert!(
            matches!(params.get(index), Some(Value::String(Some(value))) if value == pattern),
            "unexpected param {index}: {params:?}"
        );
    }
}
#[test]
fn test_consolidate_preserves_full_query_fragment_state() {
    let ttl = Duration::from_secs(30);
    let original = QueryBuilder::<QueryTestUser>::new()
        .where_eq("name", "alice")
        .or_where_eq("name", "bob")
        .select(vec!["id", "name"])
        .select_raw("COUNT(*) AS total_count")
        .order_desc("id")
        .limit(5)
        .offset(10)
        .union_all(QueryBuilder::<QueryTestUser>::new().where_eq("name", "carol"))
        .window(
            WindowFunction::new(WindowFunctionType::RowNumber, "row_num")
                .order_by("id", Order::Asc),
        )
        .with_cte(CTE::new(
            "active_users",
            "SELECT id FROM query_test_users WHERE name IS NOT NULL".to_string(),
        ))
        .cache_with_key("fragment-key", ttl);

    let fragment = original.consolidate();

    // Filters and projections.
    assert_eq!(fragment.condition_count(), 2);
    assert_eq!(fragment.or_groups.len(), 1);
    let expected_columns = ["id".to_string(), "name".to_string()];
    assert_eq!(fragment.select_columns.as_deref(), Some(&expected_columns[..]));
    assert_eq!(fragment.raw_select_expressions, vec!["COUNT(*) AS total_count"]);

    // Pagination and composed parts.
    assert_eq!(fragment.limit_value, Some(5));
    assert_eq!(fragment.offset_value, Some(10));
    assert_eq!(fragment.unions.len(), 1);
    assert_eq!(fragment.window_functions.len(), 1);
    assert_eq!(fragment.ctes.len(), 1);

    // Caching configuration.
    assert_eq!(fragment.cache_key.as_deref(), Some("fragment-key"));
    let Some(cache_options) = fragment.cache_options.as_ref() else {
        panic!("cache options should be preserved");
    };
    assert_eq!(cache_options.ttl, ttl);

    // Rebuilding from the fragment must reproduce the same SQL.
    let rebuilt = QueryBuilder::<QueryTestUser>::from_fragment(&fragment);
    assert_eq!(rebuilt.build_sql_preview(), original.build_sql_preview());
}
#[test]
fn test_window_function_sql_uses_postgres_identifier_quoting() {
    let builder = QueryBuilder::<QueryTestUser>::new()
        .first_value("first_name", "na\"me", "id", "na\"me", Order::Asc);
    let (sql, params) = builder.build_select_sql_with_params_for_db(DatabaseType::Postgres);
    // Embedded double quotes are doubled in every identifier position.
    let expected = "FIRST_VALUE(\"na\"\"me\") OVER (PARTITION BY \"id\" ORDER BY \"na\"\"me\" ASC) AS \"first_name\"";
    assert!(params.is_empty(), "params: {params:?}");
    assert!(sql.contains(expected), "sql: {sql}");
}
#[test]
fn test_window_function_sql_uses_mysql_identifier_quoting() {
    let builder = QueryBuilder::<QueryTestUser>::new()
        .first_value("first_name", "na`me", "id", "na`me", Order::Asc);
    let (sql, params) = builder.build_select_sql_with_params_for_db(DatabaseType::MySQL);
    // Embedded backticks are doubled in every identifier position.
    let expected =
        "FIRST_VALUE(`na``me`) OVER (PARTITION BY `id` ORDER BY `na``me` ASC) AS `first_name`";
    assert!(params.is_empty(), "params: {params:?}");
    assert!(sql.contains(expected), "sql: {sql}");
}
#[cfg(feature = "fulltext")]
#[test]
fn test_fulltext_build_postgres_sql_parameterizes_query_and_escapes_identifiers() {
    let builder = FullTextSearchBuilder::<QueryTestUser>::new(&["na\"me", "bio"], "o'hai")
        .language("en'g\"lish");
    let (sql, params) = builder
        .build_sql(DatabaseType::Postgres)
        .expect("postgres full-text SQL should build");

    for fragment in [
        "SELECT * FROM \"query_test_users\"",
        "COALESCE(\"na\"\"me\", '')",
        "plainto_tsquery(CAST($1 AS regconfig), $2)",
    ] {
        assert!(sql.contains(fragment), "sql: {sql}");
    }

    // Both the language and the search text travel as bound parameters.
    let language_param = params.first();
    let query_param = params.get(1);
    assert!(
        matches!(language_param, Some(Value::String(Some(language))) if language == "en'g\"lish")
    );
    assert!(matches!(query_param, Some(Value::String(Some(query))) if query == "o'hai"));
}
#[cfg(feature = "fulltext")]
#[test]
fn test_fulltext_build_postgres_ranked_sql_binds_prefix_query_and_min_rank() {
    let builder = FullTextSearchBuilder::<QueryTestUser>::new(&["name"], "quick fox")
        .mode(SearchMode::Prefix)
        .with_ranking()
        .min_rank(0.75);
    let (sql, params) = builder
        .build_ranked_sql(DatabaseType::Postgres)
        .expect("ranked postgres full-text SQL should build");

    assert!(sql.contains("to_tsquery(CAST($1 AS regconfig), $2)"), "sql: {sql}");
    assert!(sql.contains(" >= $4"), "sql: {sql}");

    let language_param = params.first();
    let tsquery_param = params.get(1);
    let min_rank_param = params.get(3);
    assert!(
        matches!(language_param, Some(Value::String(Some(language))) if language == "english")
    );
    // Prefix mode turns each term into a `term:*` lexeme, AND-ed together.
    assert!(
        matches!(tsquery_param, Some(Value::String(Some(query))) if query == "quick:* & fox:*")
    );
    assert!(
        matches!(min_rank_param, Some(Value::Double(Some(rank))) if (*rank - 0.75).abs() < f64::EPSILON)
    );
}
#[cfg(feature = "fulltext")]
#[test]
fn test_fulltext_build_mysql_ranked_sql_uses_bound_values_for_all_dynamic_inputs() {
    let builder = FullTextSearchBuilder::<QueryTestUser>::new(&["na`me", "bio"], "+urgent term")
        .mode(SearchMode::Boolean)
        .with_ranking()
        .min_rank(0.5);
    let (sql, params) = builder
        .build_ranked_sql(DatabaseType::MySQL)
        .expect("ranked mysql full-text SQL should build");

    let match_clause = "MATCH(`na``me`, `bio`) AGAINST(? IN BOOLEAN MODE)";
    let rank_filter = format!("AND {match_clause} >= ?");
    assert!(sql.contains(match_clause), "sql: {sql}");
    assert!(sql.contains(rank_filter.as_str()), "sql: {sql}");

    // Every MATCH occurrence binds the search text; the rank threshold sits at index 2.
    assert_eq!(params.len(), 4);
    for index in [0, 1, 3] {
        assert!(
            matches!(params.get(index), Some(Value::String(Some(query))) if query == "+urgent term"),
            "unexpected param {index}: {params:?}"
        );
    }
    assert!(
        matches!(params.get(2), Some(Value::Double(Some(rank))) if (*rank - 0.5).abs() < f64::EPSILON)
    );
}
#[cfg(feature = "fulltext")]
#[test]
fn test_fulltext_build_sqlite_sql_binds_escaped_fts_query() {
    let builder =
        FullTextSearchBuilder::<QueryTestUser>::new(&["name", "bio"], "say \"hello\" to it's")
            .limit(5)
            .offset(2);
    let (sql, params) = builder
        .build_sql(DatabaseType::SQLite)
        .expect("sqlite full-text SQL should build");

    for fragment in [
        "SELECT t.* FROM \"query_test_users\" t",
        "INNER JOIN \"query_test_users_fts\" fts",
        "WHERE \"query_test_users_fts\" MATCH ?",
        "LIMIT ? OFFSET ?",
    ] {
        assert!(sql.contains(fragment), "sql: {sql}");
    }

    // Quotes in the search text are doubled before binding; paging values are bound too.
    let escaped_query = "say \"\"hello\"\" to it''s";
    assert!(matches!(params.first(), Some(Value::String(Some(query))) if query == escaped_query));
    assert!(matches!(params.get(1), Some(Value::BigInt(Some(limit))) if *limit == 5));
    assert!(matches!(params.get(2), Some(Value::BigInt(Some(offset))) if *offset == 2));
}