#![allow(clippy::unwrap_used)]
use super::*;
use crate::compiler::{
aggregate_types::HavingOperator,
aggregation::{AggregateSelection, AggregationRequest, GroupBySelection},
fact_table::{DimensionColumn, FactTableMetadata, FilterColumn, MeasureColumn, SqlType},
};
fn create_aggregation_test_metadata() -> crate::compiler::fact_table::FactTableMetadata {
use crate::compiler::fact_table::{DimensionColumn, FilterColumn, MeasureColumn, SqlType};
crate::compiler::fact_table::FactTableMetadata {
table_name: "tf_sales".to_string(),
measures: vec![MeasureColumn {
name: "revenue".to_string(),
sql_type: SqlType::Decimal,
nullable: false,
}],
dimensions: DimensionColumn {
name: "data".to_string(),
paths: vec![],
},
denormalized_filters: vec![FilterColumn {
name: "customer_id".to_string(),
sql_type: SqlType::BigInt,
indexed: true,
}],
calendar_dimensions: vec![],
}
}
/// Plans the canonical `tf_sales` aggregation that most tests here start from.
///
/// The plan has the following shape:
/// - It groups by the `category` JSONB path and a daily bucket of `occurred_at`.
/// - It computes `COUNT(*)` (alias `count`) and `SUM(revenue)` (alias `revenue_sum`).
/// - It sets `limit = Some(10)`.
///
/// Note that the JSONB dimension column is named `dimensions`, not `data`.
fn create_test_plan() -> AggregationPlan {
    let revenue_measure = MeasureColumn {
        name: "revenue".to_string(),
        sql_type: SqlType::Decimal,
        nullable: false,
    };
    let occurred_at_filter = FilterColumn {
        name: "occurred_at".to_string(),
        sql_type: SqlType::Timestamp,
        indexed: true,
    };
    let metadata = FactTableMetadata {
        table_name: "tf_sales".to_string(),
        measures: vec![revenue_measure],
        dimensions: DimensionColumn { name: "dimensions".to_string(), paths: vec![] },
        denormalized_filters: vec![occurred_at_filter],
        calendar_dimensions: vec![],
    };

    let category = GroupBySelection::Dimension {
        path: "category".to_string(),
        alias: "category".to_string(),
    };
    let day = GroupBySelection::TemporalBucket {
        column: "occurred_at".to_string(),
        bucket: TemporalBucket::Day,
        alias: "day".to_string(),
    };
    let row_count = AggregateSelection::Count { alias: "count".to_string() };
    let revenue_sum = AggregateSelection::MeasureAggregate {
        measure: "revenue".to_string(),
        function: AggregateFunction::Sum,
        alias: "revenue_sum".to_string(),
    };

    let request = AggregationRequest {
        table_name: "tf_sales".to_string(),
        where_clause: None,
        group_by: vec![category, day],
        aggregates: vec![row_count, revenue_sum],
        having: vec![],
        order_by: vec![],
        limit: Some(10),
        offset: None,
    };

    crate::compiler::aggregation::AggregationPlanner::plan(request, metadata).unwrap()
}
#[test]
fn test_postgres_sql_generation() {
    // PostgreSQL: `->>` JSONB extraction, DATE_TRUNC buckets, GROUP BY and LIMIT.
    let plan = create_test_plan();
    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();
    for fragment in [
        "dimensions->>'category'",
        "DATE_TRUNC('day', occurred_at)",
        "COUNT(*)",
        "SUM(revenue)",
        "GROUP BY",
        "LIMIT 10",
    ] {
        assert!(rendered.sql.contains(fragment));
    }
}
#[test]
fn test_mysql_sql_generation() {
    // MySQL: JSON_EXTRACT wrapped in JSON_UNQUOTE, DATE_FORMAT buckets.
    let plan = create_test_plan();
    let rendered = AggregationSqlGenerator::new(DatabaseType::MySQL)
        .generate_parameterized(&plan)
        .unwrap();
    for fragment in [
        "JSON_UNQUOTE(JSON_EXTRACT(dimensions, '$.category'))",
        "DATE_FORMAT(occurred_at",
        "COUNT(*)",
        "SUM(revenue)",
    ] {
        assert!(rendered.sql.contains(fragment));
    }
}
#[test]
fn test_sqlite_sql_generation() {
    // SQLite: json_extract for dimensions, strftime for temporal buckets.
    let plan = create_test_plan();
    let rendered = AggregationSqlGenerator::new(DatabaseType::SQLite)
        .generate_parameterized(&plan)
        .unwrap();
    for fragment in [
        "json_extract(dimensions, '$.category')",
        "strftime",
        "COUNT(*)",
        "SUM(revenue)",
    ] {
        assert!(rendered.sql.contains(fragment));
    }
}
#[test]
fn test_sqlserver_sql_generation() {
    // SQL Server: JSON_VALUE for dimensions, CAST(... AS DATE) for day buckets.
    let plan = create_test_plan();
    let rendered = AggregationSqlGenerator::new(DatabaseType::SQLServer)
        .generate_parameterized(&plan)
        .unwrap();
    for fragment in [
        "JSON_VALUE(dimensions, '$.category')",
        "CAST(occurred_at AS DATE)",
        "COUNT(*)",
        "SUM(revenue)",
    ] {
        assert!(rendered.sql.contains(fragment));
    }
}
#[test]
fn test_having_clause() {
    // A numeric HAVING threshold is emitted as the first positional parameter.
    let sum_of_revenue = AggregateExpression::MeasureAggregate {
        column: "revenue".to_string(),
        function: AggregateFunction::Sum,
        alias: "revenue_sum".to_string(),
    };
    let mut plan = create_test_plan();
    plan.having_conditions = vec![ValidatedHavingCondition {
        aggregate: sum_of_revenue,
        operator: HavingOperator::Gt,
        value: serde_json::json!(1000),
    }];

    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();

    assert!(rendered.sql.contains("HAVING SUM(revenue) > $1"));
    assert_eq!(rendered.params, vec![serde_json::json!(1000)]);
}
#[test]
fn test_order_by_clause() {
    use crate::compiler::aggregation::OrderByClause;

    // ORDER BY references the quoted aggregate alias, not the raw expression.
    let by_revenue_desc = OrderByClause::new("revenue_sum".to_string(), OrderDirection::Desc);
    let mut plan = create_test_plan();
    plan.request.order_by = vec![by_revenue_desc];

    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();
    assert!(rendered.sql.contains("ORDER BY \"revenue_sum\" DESC"));
}
mod native_where {
    use fraiseql_db::where_clause::{WhereClause, WhereOperator};

    use super::*;

    /// Builds a `NativeField` equality predicate on `column` with cast `pg_cast`.
    fn native_eq(column: &str, pg_cast: &str, value: serde_json::Value) -> WhereClause {
        WhereClause::NativeField {
            column: column.to_string(),
            pg_cast: pg_cast.to_string(),
            operator: WhereOperator::Eq,
            value,
        }
    }

    /// Returns the shared test plan with its WHERE clause replaced by a single
    /// native-column equality predicate.
    fn plan_with_native_where(
        column: &str,
        pg_cast: &str,
        value: serde_json::Value,
    ) -> AggregationPlan {
        let mut plan = create_test_plan();
        plan.request.where_clause = Some(native_eq(column, pg_cast, value));
        plan
    }

    #[test]
    fn postgres_native_uuid_where() {
        let plan = plan_with_native_where("order_id", "uuid", serde_json::json!("abc-123"));
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#""order_id" = $1::uuid"#), "got: {sql}");
    }

    #[test]
    fn postgres_native_int_where() {
        let plan = plan_with_native_where("customer_id", "int8", serde_json::json!(42));
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#""customer_id" = $1::int8"#), "got: {sql}");
    }

    #[test]
    fn postgres_native_no_cast_where() {
        // An empty cast string must not produce a `::` suffix.
        let plan = plan_with_native_where("status", "", serde_json::json!("active"));
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#""status" = $1"#), "got: {sql}");
        assert!(!sql.contains("::"), "unexpected cast: {sql}");
    }

    #[test]
    fn mysql_native_where() {
        // MySQL ignores the PostgreSQL cast and quotes with backticks.
        let plan = plan_with_native_where("customer_id", "int8", serde_json::json!(42));
        let generator = AggregationSqlGenerator::new(DatabaseType::MySQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains("`customer_id` = ?"), "got: {sql}");
    }

    #[test]
    fn sqlite_native_where() {
        let plan = plan_with_native_where("customer_id", "int8", serde_json::json!(42));
        let generator = AggregationSqlGenerator::new(DatabaseType::SQLite);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#""customer_id" = ?"#), "got: {sql}");
    }

    #[test]
    fn sqlserver_native_where() {
        let plan = plan_with_native_where("customer_id", "int8", serde_json::json!(42));
        let generator = AggregationSqlGenerator::new(DatabaseType::SQLServer);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains("[customer_id] = @P1"), "got: {sql}");
    }

    #[test]
    fn and_wrapping_native_field() {
        // Placeholders are numbered in clause order across an AND.
        let both = WhereClause::And(vec![
            native_eq("customer_id", "int8", serde_json::json!(1)),
            native_eq("status", "", serde_json::json!("active")),
        ]);
        let mut plan = create_test_plan();
        plan.request.where_clause = Some(both);

        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#""customer_id" = $1::int8"#), "got: {sql}");
        assert!(sql.contains(r#""status" = $2"#), "got: {sql}");
    }
}
mod native_groupby {
    use super::*;
    use crate::compiler::aggregation::GroupByExpression;

    /// Native (non-JSONB) group-by on `customer_id` with an `int8` cast.
    fn native_customer_id() -> GroupByExpression {
        GroupByExpression::NativeColumn {
            column: "customer_id".to_string(),
            pg_cast: "int8".to_string(),
            alias: "customer_id".to_string(),
        }
    }

    /// Shared test plan whose grouping is replaced by the native `customer_id` column.
    fn plan_with_native_groupby() -> AggregationPlan {
        let mut plan = create_test_plan();
        plan.group_by_expressions = vec![native_customer_id()];
        plan
    }

    // NOTE: `create_test_plan` names its JSONB column `dimensions`, so a negative
    // check for `data->>'customer_id'` would pass even if the generator wrongly
    // fell back to JSONB. Match any `->>` extraction of the column instead.

    #[test]
    fn postgres_native_groupby_select_clause() {
        let plan = plan_with_native_groupby();
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#""customer_id" AS customer_id"#), "got: {sql}");
        assert!(!sql.contains("->>'customer_id'"), "unexpected JSONB ref: {sql}");
    }

    #[test]
    fn postgres_native_groupby_group_by_clause() {
        let plan = plan_with_native_groupby();
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#"GROUP BY "customer_id""#), "got: {sql}");
        assert!(!sql.contains("->>'customer_id'"), "unexpected JSONB ref: {sql}");
    }

    #[test]
    fn mixed_native_and_jsonb_groupby() {
        let mut plan = create_test_plan();
        plan.group_by_expressions = vec![
            native_customer_id(),
            GroupByExpression::JsonbPath {
                jsonb_column: "data".to_string(),
                path: "status".to_string(),
                alias: "status".to_string(),
            },
        ];
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#""customer_id" AS customer_id"#), "got: {sql}");
        assert!(sql.contains("data->>'status'"), "got: {sql}");
        assert!(sql.contains("AS status"), "got: {sql}");
        // The previous bare `"customer_id"` check was already implied by the SELECT
        // assertion; verify the native column actually leads the GROUP BY list.
        assert!(sql.contains(r#"GROUP BY "customer_id""#), "got: {sql}");
        assert!(!sql.contains("->>'customer_id'"), "unexpected JSONB ref: {sql}");
    }

    #[test]
    fn mysql_native_groupby() {
        let plan = plan_with_native_groupby();
        let generator = AggregationSqlGenerator::new(DatabaseType::MySQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains("`customer_id` AS customer_id"), "got: {sql}");
        assert!(sql.contains("GROUP BY `customer_id`"), "got: {sql}");
    }
}
mod native_columns_integration {
    use std::collections::HashMap;

    use fraiseql_db::where_clause::WhereClause;

    use super::*;
    use crate::{
        compiler::aggregation::{AggregationPlanner, GroupByExpression},
        runtime::aggregate_parser::AggregateQueryParser,
    };

    /// Native-column map declaring `customer_id` as an `int8` column.
    fn native_cols() -> HashMap<String, String> {
        std::iter::once(("customer_id".to_string(), "int8".to_string())).collect()
    }

    #[test]
    fn parser_emits_native_where_when_column_in_map() {
        let query_json = serde_json::json!({
            "table": "tf_sales",
            "where": { "customer_id_eq": 42 },
            "groupBy": { "status": true },
            "aggregates": [{ "count": {} }]
        });
        let metadata = create_aggregation_test_metadata();
        let native = native_cols();
        let request = AggregateQueryParser::parse(&query_json, &metadata, &native).unwrap();
        let where_clause = request.where_clause.unwrap();
        // The predicate may be returned bare or wrapped in an AND.
        let found_native = match &where_clause {
            WhereClause::And(clauses) => clauses.iter().any(
                |c| matches!(c, WhereClause::NativeField { column, .. } if column == "customer_id"),
            ),
            WhereClause::NativeField { column, .. } => column == "customer_id",
            _ => false,
        };
        assert!(found_native, "expected NativeField for customer_id, got: {where_clause:?}");
    }

    #[test]
    fn parser_emits_native_groupby_when_column_in_map() {
        let query_json = serde_json::json!({
            "table": "tf_sales",
            "groupBy": { "customer_id": true, "status": true },
            "aggregates": [{ "count": {} }]
        });
        let metadata = create_aggregation_test_metadata();
        let native = native_cols();
        let request = AggregateQueryParser::parse(&query_json, &metadata, &native).unwrap();
        let plan = AggregationPlanner::plan(request, metadata).unwrap();
        let has_native = plan.group_by_expressions.iter().any(|e| {
            matches!(e, GroupByExpression::NativeColumn { column, .. } if column == "customer_id")
        });
        assert!(
            has_native,
            "expected NativeColumn for customer_id; got: {:?}",
            plan.group_by_expressions
        );
        let has_jsonb = plan
            .group_by_expressions
            .iter()
            .any(|e| matches!(e, GroupByExpression::JsonbPath { path, .. } if path == "status"));
        assert!(has_jsonb, "expected JsonbPath for status; got: {:?}", plan.group_by_expressions);
    }

    #[test]
    fn full_sql_uses_native_column_references() {
        let query_json = serde_json::json!({
            "table": "tf_sales",
            "where": { "customer_id_eq": 42 },
            "groupBy": { "customer_id": true },
            "aggregates": [{ "count": {} }]
        });
        let metadata = create_aggregation_test_metadata();
        let native = native_cols();
        let request = AggregateQueryParser::parse(&query_json, &metadata, &native).unwrap();
        let plan = AggregationPlanner::plan(request, metadata).unwrap();
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let result = generator.generate_parameterized(&plan).unwrap();
        // Pin both native usages. Checking `"customer_id"` and `$1::int8` separately
        // would pass if only one of WHERE / GROUP BY used the native column.
        assert!(result.sql.contains(r#""customer_id" = $1::int8"#), "got: {}", result.sql);
        assert!(result.sql.contains(r#"GROUP BY "customer_id""#), "got: {}", result.sql);
        assert!(!result.sql.contains("data->>'customer_id'"), "unexpected JSONB: {}", result.sql);
    }

    #[test]
    fn empty_native_map_falls_back_to_jsonb() {
        let query_json = serde_json::json!({
            "table": "tf_sales",
            "groupBy": { "customer_id": true },
            "aggregates": [{ "count": {} }]
        });
        let metadata = create_aggregation_test_metadata();
        let empty: HashMap<String, String> = HashMap::new();
        let request = AggregateQueryParser::parse(&query_json, &metadata, &empty).unwrap();
        let plan = AggregationPlanner::plan(request, metadata).unwrap();
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let result = generator.generate_parameterized(&plan).unwrap();
        assert!(
            result.sql.contains("data->>'customer_id'"),
            "expected JSONB fallback: {}",
            result.sql
        );
    }
}
mod native_orderby {
    use super::*;
    use crate::compiler::aggregation::{GroupByExpression, OrderByClause, OrderDirection};

    #[test]
    fn order_by_native_column_uses_alias_not_jsonb() {
        let mut plan = create_test_plan();
        plan.group_by_expressions = vec![GroupByExpression::NativeColumn {
            column: "customer_id".to_string(),
            pg_cast: "int8".to_string(),
            alias: "customer_id".to_string(),
        }];
        plan.request.order_by = vec![OrderByClause::new(
            "customer_id".to_string(),
            OrderDirection::Asc,
        )];
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#"ORDER BY "customer_id" ASC"#), "got: {sql}");
        // `create_test_plan` names its JSONB column `dimensions`, so checking only for
        // `data->>` would be vacuous; reject any `->>` extraction of the column.
        assert!(!sql.contains("->>'customer_id'"), "unexpected JSONB in ORDER BY: {sql}");
    }

    #[test]
    fn order_by_jsonb_dimension_unchanged() {
        let mut plan = create_test_plan();
        plan.request.order_by = vec![OrderByClause::new(
            "category".to_string(),
            OrderDirection::Desc,
        )];
        let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
        let sql = generator.generate_parameterized(&plan).unwrap().sql;
        assert!(sql.contains(r#"ORDER BY "category" DESC"#), "got: {sql}");
    }

    #[test]
    fn native_aliases_helper_returns_correct_set() {
        let mut plan = create_test_plan();
        plan.group_by_expressions = vec![
            GroupByExpression::NativeColumn {
                column: "customer_id".to_string(),
                pg_cast: "int8".to_string(),
                alias: "customer_id".to_string(),
            },
            GroupByExpression::JsonbPath {
                jsonb_column: "data".to_string(),
                path: "status".to_string(),
                alias: "status".to_string(),
            },
        ];
        let aliases = plan.native_aliases();
        assert!(aliases.contains("customer_id"), "expected customer_id in {aliases:?}");
        assert!(!aliases.contains("status"), "status should not be native: {aliases:?}");
    }
}
#[test]
fn test_array_agg_postgres() {
    let pg = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);

    let unordered = pg.generate_array_agg_sql("product_id", None);
    assert_eq!(unordered, "ARRAY_AGG(product_id)");

    // With an ordering, the ORDER BY is embedded inside the aggregate call.
    let by_revenue_desc = vec![OrderByClause::new("revenue".to_string(), OrderDirection::Desc)];
    let ordered = pg.generate_array_agg_sql("product_id", Some(&by_revenue_desc));
    assert_eq!(ordered, "ARRAY_AGG(product_id ORDER BY \"revenue\" DESC)");
}
#[test]
fn test_array_agg_mysql() {
    assert_eq!(
        AggregationSqlGenerator::new(DatabaseType::MySQL).generate_array_agg_sql("product_id", None),
        "JSON_ARRAYAGG(product_id)"
    );
}
#[test]
fn test_array_agg_sqlite() {
    // SQLite emulates an array by wrapping GROUP_CONCAT output in brackets.
    let rendered =
        AggregationSqlGenerator::new(DatabaseType::SQLite).generate_array_agg_sql("product_id", None);
    for piece in ["GROUP_CONCAT", "'[' ||", "|| ']'"] {
        assert!(rendered.contains(piece));
    }
}
#[test]
fn test_string_agg_postgres() {
    let pg = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);

    assert_eq!(
        pg.generate_string_agg_sql("product_name", ", ", None),
        "STRING_AGG(product_name, ', ')"
    );

    let by_revenue_desc = vec![OrderByClause::new("revenue".to_string(), OrderDirection::Desc)];
    assert_eq!(
        pg.generate_string_agg_sql("product_name", ", ", Some(&by_revenue_desc)),
        "STRING_AGG(product_name, ', ' ORDER BY \"revenue\" DESC)"
    );
}
#[test]
fn test_string_agg_mysql() {
    // MySQL puts ORDER BY before the SEPARATOR keyword inside GROUP_CONCAT.
    let by_revenue_desc = vec![OrderByClause::new("revenue".to_string(), OrderDirection::Desc)];
    let rendered = AggregationSqlGenerator::new(DatabaseType::MySQL).generate_string_agg_sql(
        "product_name",
        ", ",
        Some(&by_revenue_desc),
    );
    assert_eq!(rendered, "GROUP_CONCAT(product_name ORDER BY `revenue` DESC SEPARATOR ', ')");
}
#[test]
fn test_string_agg_sqlserver() {
    // SQL Server casts to NVARCHAR(MAX) and orders via WITHIN GROUP.
    let by_revenue_desc = vec![OrderByClause::new("revenue".to_string(), OrderDirection::Desc)];
    let rendered = AggregationSqlGenerator::new(DatabaseType::SQLServer).generate_string_agg_sql(
        "product_name",
        ", ",
        Some(&by_revenue_desc),
    );
    for piece in [
        "STRING_AGG(CAST(product_name AS NVARCHAR(MAX)), ', ')",
        "WITHIN GROUP (ORDER BY [revenue] DESC)",
    ] {
        assert!(rendered.contains(piece));
    }
}
#[test]
fn test_json_agg_postgres() {
    let pg = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
    assert_eq!(pg.generate_json_agg_sql("data", None), "JSON_AGG(data)");

    let oldest_first = vec![OrderByClause::new("created_at".to_string(), OrderDirection::Asc)];
    assert_eq!(
        pg.generate_json_agg_sql("data", Some(&oldest_first)),
        "JSON_AGG(data ORDER BY \"created_at\" ASC)"
    );
}
#[test]
fn test_jsonb_agg_postgres() {
    let rendered =
        AggregationSqlGenerator::new(DatabaseType::PostgreSQL).generate_jsonb_agg_sql("data", None);
    assert_eq!(rendered, "JSONB_AGG(data)");
}
#[test]
fn test_bool_and_postgres() {
    use crate::compiler::aggregate_types::BoolAggregateFunction;

    // PostgreSQL has native BOOL_AND / BOOL_OR aggregates.
    let pg = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
    let cases = [
        ("is_active", BoolAggregateFunction::And, "BOOL_AND(is_active)"),
        ("has_discount", BoolAggregateFunction::Or, "BOOL_OR(has_discount)"),
    ];
    for (column, function, expected) in cases {
        assert_eq!(pg.generate_bool_agg_sql(column, function), expected);
    }
}
#[test]
fn test_bool_and_mysql() {
    use crate::compiler::aggregate_types::BoolAggregateFunction;

    // MySQL emulates AND/OR aggregation with MIN/MAX over 0/1 values.
    let mysql = AggregationSqlGenerator::new(DatabaseType::MySQL);
    let cases = [
        ("is_active", BoolAggregateFunction::And, "MIN(is_active) = 1"),
        ("has_discount", BoolAggregateFunction::Or, "MAX(has_discount) = 1"),
    ];
    for (column, function, expected) in cases {
        assert_eq!(mysql.generate_bool_agg_sql(column, function), expected);
    }
}
#[test]
fn test_bool_and_sqlserver() {
    use crate::compiler::aggregate_types::BoolAggregateFunction;

    // SQL Server needs an explicit BIT cast before MIN/MAX.
    let mssql = AggregationSqlGenerator::new(DatabaseType::SQLServer);
    let cases = [
        ("is_active", BoolAggregateFunction::And, "MIN(CAST(is_active AS BIT)) = 1"),
        ("has_discount", BoolAggregateFunction::Or, "MAX(CAST(has_discount AS BIT)) = 1"),
    ];
    for (column, function, expected) in cases {
        assert_eq!(mssql.generate_bool_agg_sql(column, function), expected);
    }
}
#[test]
fn test_advanced_aggregate_full_query() {
    // Advanced aggregates appended to a planned query render inside the full SELECT.
    let mut plan = create_test_plan();

    let products = AggregateExpression::AdvancedAggregate {
        column: "product_id".to_string(),
        function: AggregateFunction::ArrayAgg,
        alias: "products".to_string(),
        delimiter: None,
        order_by: Some(vec![OrderByClause::new("revenue".to_string(), OrderDirection::Desc)]),
    };
    let product_names = AggregateExpression::AdvancedAggregate {
        column: "product_name".to_string(),
        function: AggregateFunction::StringAgg,
        alias: "product_names".to_string(),
        delimiter: Some(", ".to_string()),
        order_by: None,
    };
    plan.aggregate_expressions.push(products);
    plan.aggregate_expressions.push(product_names);

    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();
    assert!(rendered.sql.contains("ARRAY_AGG(product_id ORDER BY \"revenue\" DESC)"));
    assert!(rendered.sql.contains("STRING_AGG(product_name, ', ')"));
}
#[test]
fn test_having_string_value_is_bound_not_escaped() {
    use crate::compiler::aggregate_types::AggregateFunction;

    // A quote-bearing HAVING value must travel as a bind parameter, never inline.
    let tricky = serde_json::json!("O'Reilly");
    let mut plan = create_test_plan();
    plan.having_conditions = vec![ValidatedHavingCondition {
        aggregate: AggregateExpression::MeasureAggregate {
            column: "label".to_string(),
            function: AggregateFunction::Max,
            alias: "label_max".to_string(),
        },
        operator: HavingOperator::Eq,
        value: tricky.clone(),
    }];

    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();

    assert!(rendered.sql.contains("HAVING MAX(label) = $1"));
    assert!(
        !rendered.sql.contains("O'Reilly"),
        "raw string must not appear in SQL: {}",
        rendered.sql
    );
    assert_eq!(rendered.params, vec![tricky]);
}
#[test]
fn test_escape_sql_string_mysql_doubles_backslash() {
    // MySQL treats backslash as an escape character, so it is doubled along with quotes.
    let mysql = AggregationSqlGenerator::new(DatabaseType::MySQL);
    let cases = [
        ("test\\", "test\\\\"),
        ("te'st", "te''st"),
        ("te\\'st", "te\\\\''st"),
    ];
    for (input, expected) in cases {
        assert_eq!(mysql.escape_sql_string(input), expected);
    }
}
#[test]
fn test_escape_sql_string_postgres_only_doubles_quote() {
    // PostgreSQL standard strings leave backslashes untouched.
    let pg = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
    for (input, expected) in [("test\\", "test\\"), ("te'st", "te''st")] {
        assert_eq!(pg.escape_sql_string(input), expected);
    }
}
#[test]
fn test_escape_sql_string_strips_null_bytes() {
    let pg = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
    for (input, expected) in [("before\x00after", "beforeafter"), ("\x00", ""), ("no-null", "no-null")] {
        assert_eq!(pg.escape_sql_string(input), expected);
    }

    // NUL removal composes with MySQL's backslash doubling.
    let mysql = AggregationSqlGenerator::new(DatabaseType::MySQL);
    assert_eq!(mysql.escape_sql_string("te\x00st\\"), "test\\\\");
}
#[test]
fn test_jsonb_postgres_single_quote_escaped() {
    let extracted = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .jsonb_extract_sql("dimensions", "user'name");
    assert!(extracted.contains("user''name"), "Expected doubled quote, got: {extracted}");
    assert!(!extracted.contains("user'name'"), "Unescaped quote still present");
}
#[test]
fn test_jsonb_postgres_pg_sleep_injection_neutralised() {
    let payload = "a' || pg_sleep(10) --";
    let extracted =
        AggregationSqlGenerator::new(DatabaseType::PostgreSQL).jsonb_extract_sql("dimensions", payload);
    assert!(extracted.contains("a'' || pg_sleep(10) --"), "Escaping not applied: {extracted}");
}
#[test]
fn test_jsonb_postgres_clean_path_unchanged() {
    let extracted = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .jsonb_extract_sql("dimensions", "category");
    assert!(extracted.contains("dimensions->>'category'"), "Clean path altered: {extracted}");
}
#[test]
fn test_jsonb_mysql_single_quote_escaped() {
    let extracted =
        AggregationSqlGenerator::new(DatabaseType::MySQL).jsonb_extract_sql("dimensions", "user'name");
    assert!(extracted.contains("user''name"), "Expected doubled-quote escape in MySQL: {extracted}");
}
#[test]
fn test_jsonb_mysql_path_prefix_not_doubled() {
    // The `$.` JSON-path prefix is added exactly once.
    let extracted =
        AggregationSqlGenerator::new(DatabaseType::MySQL).jsonb_extract_sql("dimensions", "category");
    assert!(extracted.contains("$.category"), "Path prefix missing: {extracted}");
    assert!(!extracted.contains("$.$."), "Double prefix detected: {extracted}");
}
#[test]
fn test_jsonb_sqlite_single_quote_escaped() {
    let extracted =
        AggregationSqlGenerator::new(DatabaseType::SQLite).jsonb_extract_sql("dimensions", "it's");
    assert!(extracted.contains("it''s"), "Expected doubled-quote escape in SQLite: {extracted}");
}
#[test]
fn test_jsonb_sqlserver_single_quote_escaped() {
    let extracted = AggregationSqlGenerator::new(DatabaseType::SQLServer)
        .jsonb_extract_sql("dimensions", "user'name");
    assert!(extracted.contains("user''name"), "Expected doubled quote in SQL Server: {extracted}");
}
#[test]
fn test_stringagg_delimiter_single_quote_escaped() {
    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_string_agg_sql("product_name", "O'Reilly", None);
    assert!(rendered.contains("'O''Reilly'"), "single quote must be doubled: {rendered}");
    assert!(!rendered.contains("'O'Reilly'"), "unescaped quote must not appear");
}
#[test]
fn test_stringagg_delimiter_injection_payload_neutralised() {
    let generator = AggregationSqlGenerator::new(DatabaseType::PostgreSQL);
    let payload = "'; DROP TABLE users; --";
    let sql = generator.generate_string_agg_sql("product_name", payload, None);
    assert!(sql.starts_with("STRING_AGG("), "must remain a STRING_AGG call: {sql}");
    // A bare `contains("''")` check is vacuous: the UNescaped rendering
    // `STRING_AGG(product_name, ''; DROP …` also contains `''` (the opening
    // delimiter quote followed by the payload's quote). Pin the exact output so the
    // payload's quote must be doubled and remain inside the string literal.
    assert_eq!(
        sql, "STRING_AGG(product_name, '''; DROP TABLE users; --')",
        "single quotes must be doubled: {sql}"
    );
}
#[test]
fn test_stringagg_delimiter_mysql_backslash_and_quote_escaped() {
    let generator = AggregationSqlGenerator::new(DatabaseType::MySQL);

    let sql = generator.generate_string_agg_sql("col", r"a\b", None);
    assert!(sql.contains(r"a\\b"), "backslash must be doubled for MySQL: {sql}");

    // The test name also promises quote handling: a backslash followed by a quote
    // must yield a doubled backslash and a doubled quote. Otherwise `\'` would
    // close the literal early in MySQL's default (backslash-escaping) mode.
    let sql = generator.generate_string_agg_sql("col", r"a\'b", None);
    assert!(sql.contains(r"a\\''b"), "backslash and quote must both be escaped for MySQL: {sql}");
}
#[test]
fn test_stringagg_delimiter_mysql_single_quote_escaped() {
    let rendered =
        AggregationSqlGenerator::new(DatabaseType::MySQL).generate_string_agg_sql("col", "O'Reilly", None);
    assert!(rendered.contains("O''Reilly"), "single quote must be doubled for MySQL: {rendered}");
}
#[test]
fn test_stringagg_delimiter_sqlite_single_quote_escaped() {
    let rendered =
        AggregationSqlGenerator::new(DatabaseType::SQLite).generate_string_agg_sql("col", "it's", None);
    assert!(rendered.contains("it''s"), "single quote must be doubled for SQLite: {rendered}");
}
#[test]
fn test_stringagg_delimiter_sqlserver_single_quote_escaped() {
    let rendered = AggregationSqlGenerator::new(DatabaseType::SQLServer)
        .generate_string_agg_sql("col", "O'Reilly", None);
    assert!(
        rendered.contains("O''Reilly"),
        "single quote must be doubled for SQL Server: {rendered}"
    );
}
#[test]
fn test_stringagg_delimiter_clean_value_unchanged() {
    // A delimiter with nothing to escape passes through verbatim.
    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_string_agg_sql("product_name", ", ", None);
    assert_eq!(rendered, "STRING_AGG(product_name, ', ')");
}
/// Plans a `COUNT(*)` over `tf_sales`, grouped by the `category` JSONB path and
/// filtered by `status = "test_value"` on a denormalized filter column.
///
/// The `_db` argument is ignored. The plan is database-agnostic, and callers pick
/// the dialect when constructing the generator.
/// NOTE(review): `status` is declared as `SqlType::Timestamp` although it holds a
/// string; presumably the planner only needs the column to exist — confirm.
fn make_string_where_plan(_db: DatabaseType) -> AggregationPlan {
    let status_filter = FilterColumn {
        name: "status".to_string(),
        sql_type: SqlType::Timestamp,
        indexed: true,
    };
    let metadata = FactTableMetadata {
        table_name: "tf_sales".to_string(),
        measures: vec![],
        dimensions: DimensionColumn { name: "data".to_string(), paths: vec![] },
        denormalized_filters: vec![status_filter],
        calendar_dimensions: vec![],
    };

    let status_eq = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::Eq,
        value: serde_json::json!("test_value"),
    };
    let request = AggregationRequest {
        table_name: "tf_sales".to_string(),
        where_clause: Some(status_eq),
        group_by: vec![GroupBySelection::Dimension {
            path: "category".to_string(),
            alias: "category".to_string(),
        }],
        aggregates: vec![AggregateSelection::Count { alias: "count".to_string() }],
        having: vec![],
        order_by: vec![],
        limit: None,
        offset: None,
    };

    crate::compiler::aggregation::AggregationPlanner::plan(request, metadata).unwrap()
}
#[test]
fn test_generate_parameterized_where_string_becomes_placeholder() {
    let plan = make_string_where_plan(DatabaseType::PostgreSQL);
    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();

    assert!(rendered.sql.contains("$1"), "PostgreSQL placeholder must be $1: {}", rendered.sql);
    assert!(
        !rendered.sql.contains("'test_value'"),
        "String value must not appear as literal: {}",
        rendered.sql
    );
    assert_eq!(rendered.params.len(), 1);
    assert_eq!(rendered.params[0], serde_json::json!("test_value"));
}
#[test]
fn test_generate_parameterized_having_string_becomes_placeholder() {
    // A backslash-quote sequence is dangerous under MySQL escaping; it must be bound.
    let injection = "test\\' injection";
    let condition = ValidatedHavingCondition {
        aggregate: AggregateExpression::MeasureAggregate {
            column: "revenue".to_string(),
            function: AggregateFunction::Sum,
            alias: "revenue_sum".to_string(),
        },
        operator: HavingOperator::Eq,
        value: serde_json::json!(injection),
    };
    let mut plan = create_test_plan();
    plan.having_conditions = vec![condition];

    let rendered = AggregationSqlGenerator::new(DatabaseType::MySQL)
        .generate_parameterized(&plan)
        .unwrap();

    assert!(
        rendered.sql.contains("HAVING SUM(revenue) = ?"),
        "SQL must use ? placeholder: {}",
        rendered.sql
    );
    assert_eq!(rendered.params.len(), 1);
    assert_eq!(rendered.params[0], serde_json::json!(injection));
    assert!(
        !rendered.sql.contains("injection"),
        "Injection string must not appear in SQL: {}",
        rendered.sql
    );
}
#[test]
fn test_parameterized_postgres_placeholder_numbering() {
    // WHERE parameters are numbered before HAVING parameters: $1 then $2.
    let channel_value = "risky";
    let indexed_filter = |name: &str| FilterColumn {
        name: name.to_string(),
        sql_type: SqlType::Timestamp,
        indexed: true,
    };
    let metadata = FactTableMetadata {
        table_name: "tf_sales".to_string(),
        measures: vec![MeasureColumn {
            name: "revenue".to_string(),
            sql_type: SqlType::Decimal,
            nullable: false,
        }],
        dimensions: DimensionColumn { name: "dimensions".to_string(), paths: vec![] },
        denormalized_filters: vec![indexed_filter("occurred_at"), indexed_filter("channel")],
        calendar_dimensions: vec![],
    };
    let request = AggregationRequest {
        table_name: "tf_sales".to_string(),
        where_clause: Some(WhereClause::Field {
            path: vec!["channel".to_string()],
            operator: WhereOperator::Eq,
            value: serde_json::json!(channel_value),
        }),
        group_by: vec![GroupBySelection::TemporalBucket {
            column: "occurred_at".to_string(),
            bucket: TemporalBucket::Day,
            alias: "day".to_string(),
        }],
        aggregates: vec![AggregateSelection::MeasureAggregate {
            measure: "revenue".to_string(),
            function: AggregateFunction::Sum,
            alias: "total".to_string(),
        }],
        having: vec![],
        order_by: vec![],
        limit: None,
        offset: None,
    };

    let plan = {
        let mut planned =
            crate::compiler::aggregation::AggregationPlanner::plan(request, metadata).unwrap();
        planned.having_conditions = vec![ValidatedHavingCondition {
            aggregate: AggregateExpression::MeasureAggregate {
                column: "revenue".to_string(),
                function: AggregateFunction::Sum,
                alias: "total".to_string(),
            },
            operator: HavingOperator::Gt,
            value: serde_json::json!("threshold"),
        }];
        planned
    };

    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();
    assert!(rendered.sql.contains("WHERE channel = $1"), "SQL: {}", rendered.sql);
    assert!(rendered.sql.contains("HAVING SUM(revenue) > $2"), "SQL: {}", rendered.sql);
    assert_eq!(rendered.params.len(), 2);
    assert_eq!(rendered.params[0], serde_json::json!(channel_value));
    assert_eq!(rendered.params[1], serde_json::json!("threshold"));
}
#[test]
fn test_parameterized_mysql_uses_question_mark() {
    let plan = make_string_where_plan(DatabaseType::MySQL);
    let rendered = AggregationSqlGenerator::new(DatabaseType::MySQL)
        .generate_parameterized(&plan)
        .unwrap();
    assert!(rendered.sql.contains("WHERE status = ?"), "SQL: {}", rendered.sql);
    assert_eq!(rendered.params.len(), 1);
    assert_eq!(rendered.params[0], serde_json::json!("test_value"));
}
#[test]
fn test_parameterized_sqlserver_uses_at_p_placeholder() {
    let plan = make_string_where_plan(DatabaseType::SQLServer);
    let rendered = AggregationSqlGenerator::new(DatabaseType::SQLServer)
        .generate_parameterized(&plan)
        .unwrap();
    assert!(rendered.sql.contains("WHERE status = @P1"), "SQL: {}", rendered.sql);
    assert_eq!(rendered.params.len(), 1);
    assert_eq!(rendered.params[0], serde_json::json!("test_value"));
}
#[test]
fn test_parameterized_in_array_expands_to_multiple_placeholders() {
    // Each element of an IN array gets its own placeholder, in order.
    let metadata = FactTableMetadata {
        table_name: "tf_sales".to_string(),
        measures: vec![],
        dimensions: DimensionColumn { name: "data".to_string(), paths: vec![] },
        denormalized_filters: vec![FilterColumn {
            name: "status".to_string(),
            sql_type: SqlType::Timestamp,
            indexed: true,
        }],
        calendar_dimensions: vec![],
    };
    let status_in = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::In,
        value: serde_json::json!(["a", "b", "c"]),
    };
    let request = AggregationRequest {
        table_name: "tf_sales".to_string(),
        where_clause: Some(status_in),
        group_by: vec![],
        aggregates: vec![AggregateSelection::Count { alias: "count".to_string() }],
        having: vec![],
        order_by: vec![],
        limit: None,
        offset: None,
    };
    let plan = crate::compiler::aggregation::AggregationPlanner::plan(request, metadata).unwrap();

    let rendered = AggregationSqlGenerator::new(DatabaseType::PostgreSQL)
        .generate_parameterized(&plan)
        .unwrap();
    assert!(
        rendered.sql.contains("status IN ($1, $2, $3)"),
        "IN clause must expand to 3 placeholders: {}",
        rendered.sql
    );
    assert_eq!(rendered.params.len(), 3);
    for (position, expected) in ["a", "b", "c"].iter().enumerate() {
        assert_eq!(rendered.params[position], serde_json::json!(expected));
    }
}