#![allow(dead_code)]
#![allow(unused_imports)]
use std::borrow::Cow;
use std::sync::{Arc, Mutex};
use prax_orm::{Model, client};
use prax_query::capabilities::{SupportsNestedWrites, SupportsScalarSubqueryInSelect};
use prax_query::dialect::SqlDialect;
use prax_query::error::{QueryError, QueryResult};
use prax_query::filter::{Filter, FilterValue};
use prax_query::row::{FromRow, RowError, RowRef};
use prax_query::traits::{BoxFuture, Model as ModelTrait, QueryEngine};
use prax_query::types::{OrderBy, OrderByField, SortOrder};
use prax_query::{
AggregateField, AggregateOperation, GroupByOperation, HavingCondition, HavingOp, having,
};
/// Shared, mutex-guarded log of statements: each entry pairs the SQL text with
/// the parameter values that were passed alongside it.
type StatementLog = Arc<Mutex<Vec<(String, Vec<FilterValue>)>>>;
/// `QueryEngine` test double that keeps the statements it is handed in an
/// in-memory log. The log lives behind an `Arc`, so clones of the engine share
/// one log.
#[derive(Clone)]
struct RecordingEngine {
// Statement log shared by every clone of this engine.
recorded: StatementLog,
}
impl RecordingEngine {
fn new() -> Self {
Self {
recorded: Arc::new(Mutex::new(Vec::new())),
}
}
fn statements(&self) -> Vec<(String, Vec<FilterValue>)> {
self.recorded.lock().unwrap().clone()
}
}
impl QueryEngine for RecordingEngine {
fn dialect(&self) -> &dyn SqlDialect {
&prax_query::dialect::Postgres
}
fn query_many<T: ModelTrait + FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Vec<T>>> {
let recorded = self.recorded.clone();
let sql = sql.to_string();
Box::pin(async move {
recorded.lock().unwrap().push((sql, params));
Ok(Vec::new())
})
}
fn query_one<T: ModelTrait + FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<T>> {
let recorded = self.recorded.clone();
let sql = sql.to_string();
Box::pin(async move {
recorded.lock().unwrap().push((sql, params));
T::from_row(&CannedRow).map_err(|e| QueryError::internal(e.to_string()))
})
}
fn query_optional<T: ModelTrait + FromRow + Send + 'static>(
&self,
_sql: &str,
_params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Option<T>>> {
Box::pin(async { Ok(None) })
}
fn execute_insert<T: ModelTrait + FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<T>> {
let recorded = self.recorded.clone();
let sql = sql.to_string();
Box::pin(async move {
recorded.lock().unwrap().push((sql, params));
T::from_row(&CannedRow).map_err(|e| QueryError::internal(e.to_string()))
})
}
fn execute_update<T: ModelTrait + FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Vec<T>>> {
let recorded = self.recorded.clone();
let sql = sql.to_string();
Box::pin(async move {
recorded.lock().unwrap().push((sql, params));
Ok(Vec::new())
})
}
fn execute_delete(
&self,
_sql: &str,
_params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<u64>> {
Box::pin(async { Ok(0) })
}
fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
let recorded = self.recorded.clone();
let sql = sql.to_string();
Box::pin(async move {
recorded.lock().unwrap().push((sql, params));
Ok(1)
})
}
fn count(&self, _sql: &str, _params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
Box::pin(async { Ok(0) })
}
}
// Empty impls of the two `prax_query::capabilities` traits for the recording
// engine; neither impl supplies any items.
impl SupportsNestedWrites for RecordingEngine {}
impl SupportsScalarSubqueryInSelect for RecordingEngine {}
/// Stand-in row with fixed getter results; `RecordingEngine` passes it to
/// `T::from_row` in `query_one` and `execute_insert`.
struct CannedRow;
impl CannedRow {
    /// Value reported for every `i32` column, required or optional.
    const INT: i32 = 1;
    /// Value reported for every string column, required or optional.
    const TEXT: &'static str = "canned";
}

/// `RowRef` implementation that ignores the requested column name and always
/// answers with the same value per type. `i32` and string columns are reported
/// as present in both the required and optional getters; every other optional
/// getter reports `None`.
impl RowRef for CannedRow {
    // --- i32: always 1 ---
    fn get_i32(&self, _column: &str) -> Result<i32, RowError> {
        Ok(Self::INT)
    }

    fn get_i32_opt(&self, _column: &str) -> Result<Option<i32>, RowError> {
        Ok(Some(Self::INT))
    }

    // --- i64: zero when required, absent when optional ---
    fn get_i64(&self, _column: &str) -> Result<i64, RowError> {
        Ok(0)
    }

    fn get_i64_opt(&self, _column: &str) -> Result<Option<i64>, RowError> {
        Ok(None)
    }

    // --- f64: zero when required, absent when optional ---
    fn get_f64(&self, _column: &str) -> Result<f64, RowError> {
        Ok(0.0)
    }

    fn get_f64_opt(&self, _column: &str) -> Result<Option<f64>, RowError> {
        Ok(None)
    }

    // --- bool: false when required, absent when optional ---
    fn get_bool(&self, _column: &str) -> Result<bool, RowError> {
        Ok(false)
    }

    fn get_bool_opt(&self, _column: &str) -> Result<Option<bool>, RowError> {
        Ok(None)
    }

    // --- strings: always "canned" ---
    fn get_str(&self, _column: &str) -> Result<&str, RowError> {
        Ok(Self::TEXT)
    }

    fn get_str_opt(&self, _column: &str) -> Result<Option<&str>, RowError> {
        Ok(Some(Self::TEXT))
    }

    // --- bytes: empty slice when required, absent when optional ---
    fn get_bytes(&self, _column: &str) -> Result<&[u8], RowError> {
        Ok(b"")
    }

    fn get_bytes_opt(&self, _column: &str) -> Result<Option<&[u8]>, RowError> {
        Ok(None)
    }
}
/// Model fixture mapped to the `users` table. Its field names (`email`,
/// `team_id`, `region`, `active`, `views`, `score`) are the column names the
/// SQL assertions in this file look for.
#[derive(Model, Debug, Clone, Default)]
#[prax(table = "users")]
pub struct User {
// Marked `id, auto` for the derive macro.
// NOTE(review): presumably an auto-generated primary key; confirm against the prax attribute docs.
#[prax(id, auto)]
pub id: i32,
// Marked `unique` for the derive macro.
#[prax(unique)]
pub email: String,
// Grouping columns used by the GROUP BY tests.
pub team_id: i32,
pub region: String,
// Filter column used by the WHERE test.
pub active: bool,
// Numeric columns used by the SUM/AVG/MIN/MAX tests.
pub views: i32,
pub score: i32,
}
// Invokes prax_orm's `client!` macro for the `User` model.
// NOTE(review): presumably generates the client accessors for `User`; confirm against the macro docs.
client!(User);
/// `count()` combined with `count_column("email")` must render both
/// `COUNT(*)` and `COUNT(email)` and bind no parameters.
#[test]
fn count_select_emits_per_column_counts() {
    let builder = AggregateOperation::<User, RecordingEngine>::new()
        .count()
        .count_column("email");
    let (sql, params) = builder.build_sql(&prax_query::dialect::Postgres);

    for fragment in ["COUNT(*)", "COUNT(email)"] {
        assert!(sql.contains(fragment), "missing {fragment}; got: {sql}");
    }
    assert!(params.is_empty(), "no params expected; got: {params:?}");
}
/// Requesting SUM, AVG, MIN, MAX and COUNT together must render all five
/// function calls and bind no parameters.
#[test]
fn aggregate_emits_all_five_functions() {
    let builder = AggregateOperation::<User, RecordingEngine>::new()
        .sum("views")
        .avg("score")
        .min("views")
        .max("views")
        .count();
    let (sql, params) = builder.build_sql(&prax_query::dialect::Postgres);

    // Checked in the same order as the builder calls above.
    let expected = [
        "SUM(views)",
        "AVG(score)",
        "MIN(views)",
        "MAX(views)",
        "COUNT(*)",
    ];
    for fragment in expected {
        assert!(sql.contains(fragment), "missing {fragment}; got: {sql}");
    }
    assert!(params.is_empty(), "no params expected; got: {params:?}");
}
/// A `where` filter on an aggregate must show up as a WHERE clause that
/// mentions the column and binds the filter value as the only parameter.
#[test]
fn aggregate_where_filters_underlying_select() {
    let active_only = Filter::Equals("active".into(), FilterValue::Bool(true));
    let builder = AggregateOperation::<User, RecordingEngine>::new()
        .count()
        .r#where(active_only);
    let (sql, params) = builder.build_sql(&prax_query::dialect::Postgres);

    assert!(sql.contains("WHERE"), "missing WHERE clause; got: {sql}");
    assert!(
        sql.contains("active"),
        "missing `active` column in WHERE; got: {sql}"
    );

    assert_eq!(params.len(), 1, "expected exactly 1 param; got: {params:?}");
    // Safe to index: the length was just asserted to be 1.
    let bound = &params[0];
    assert_eq!(
        *bound,
        FilterValue::Bool(true),
        "param should be Bool(true); got: {:?}",
        bound
    );
}
/// Grouping by two columns must render them, in order, in a single GROUP BY
/// clause next to the requested `COUNT(*)`, with no bound parameters.
#[test]
fn group_by_emits_group_by_clause() {
    let builder =
        GroupByOperation::<User, RecordingEngine>::new(vec!["team_id".into(), "region".into()])
            .count();
    let (sql, params) = builder.build_sql(&prax_query::dialect::Postgres);

    let has_group_by = sql.contains("GROUP BY team_id, region");
    assert!(has_group_by, "missing GROUP BY clause; got: {sql}");

    let has_count = sql.contains("COUNT(*)");
    assert!(has_count, "missing COUNT(*); got: {sql}");

    assert!(params.is_empty(), "no params expected; got: {params:?}");
}
/// A `having::count_gt(5.0)` condition must render a HAVING clause containing
/// `COUNT(*) > 5`. The bound parameters are deliberately not inspected here.
#[test]
fn group_by_having_emits_having_clause() {
    let builder = GroupByOperation::<User, RecordingEngine>::new(vec!["team_id".into()])
        .count()
        .having(having::count_gt(5.0));
    let (sql, _) = builder.build_sql(&prax_query::dialect::Postgres);

    let has_having = sql.contains("HAVING");
    assert!(has_having, "missing HAVING clause; got: {sql}");

    let has_condition = sql.contains("COUNT(*) > 5");
    assert!(
        has_condition,
        "missing `COUNT(*) > 5` in HAVING; got: {sql}"
    );
}
/// Requesting only `sum("views")` must not drag any other aggregate function
/// into the generated SQL.
#[test]
fn aggregate_omits_unspecified_blocks() {
    let builder = AggregateOperation::<User, RecordingEngine>::new().sum("views");
    let (sql, params) = builder.build_sql(&prax_query::dialect::Postgres);

    assert!(sql.contains("SUM(views)"), "missing SUM(views); got: {sql}");

    // None of the functions that were not requested may appear.
    for function in ["AVG", "MIN", "MAX", "COUNT"] {
        assert!(
            !sql.contains(function),
            "unexpected {function} in SQL; got: {sql}"
        );
    }

    assert!(params.is_empty(), "no params expected; got: {params:?}");
}
/// `count_distinct("region")` must render `COUNT(DISTINCT region)` together
/// with the `_count_distinct_region` alias, and bind no parameters.
#[test]
fn distinct_count_emits_count_distinct_sql() {
    let builder = AggregateOperation::<User, RecordingEngine>::new().count_distinct("region");
    let (sql, params) = builder.build_sql(&prax_query::dialect::Postgres);

    let has_expression = sql.contains("COUNT(DISTINCT region)");
    assert!(
        has_expression,
        "missing COUNT(DISTINCT region); got: {sql}"
    );

    let has_alias = sql.contains("_count_distinct_region");
    assert!(
        has_alias,
        "missing alias _count_distinct_region; got: {sql}"
    );

    assert!(params.is_empty(), "no params expected; got: {params:?}");
}
/// Ordering a grouped query by the `_sum_views` alias, descending, must render
/// an ORDER BY clause that names the alias and the direction.
#[test]
fn group_by_order_by_emits_order_by_clause() {
    let builder = GroupByOperation::<User, RecordingEngine>::new(vec!["team_id".into()])
        .sum("views")
        .order_by(OrderByField::desc("_sum_views"));
    let (sql, params) = builder.build_sql(&prax_query::dialect::Postgres);

    assert!(
        sql.contains("ORDER BY"),
        "missing ORDER BY clause; got: {sql}"
    );
    // Alias first, then direction, matching the order of the original checks.
    for token in ["_sum_views", "DESC"] {
        assert!(
            sql.contains(token),
            "missing {token} in ORDER BY; got: {sql}"
        );
    }
    assert!(params.is_empty(), "no params expected; got: {params:?}");
}
/// `AggregateResult::from_row` must route `_count`, `_count_<col>` and
/// `_count_distinct_<col>` into the overall count, the per-column counts and
/// the distinct counts respectively, without treating a distinct alias as a
/// per-column count.
#[test]
fn aggregate_result_hydrates_per_column_and_distinct_counts() {
    use prax_query::filter::FilterValue;
    use std::collections::HashMap;

    let mut row: HashMap<String, FilterValue> = HashMap::new();
    for (alias, value) in [
        ("_count", 10),
        ("_count_email", 8),
        ("_count_distinct_email", 5),
    ] {
        row.insert(alias.to_string(), FilterValue::Int(value));
    }

    let result = prax_query::operations::AggregateResult::from_row(row);

    assert_eq!(
        result.count,
        Some(10),
        "overall COUNT(*) should be 10; got: {:?}",
        result.count
    );

    let email_count = result.count_of("email");
    assert_eq!(
        email_count,
        Some(8),
        "COUNT(email) should be 8; got: {:?}",
        email_count
    );

    let email_distinct = result.count_distinct_of("email");
    assert_eq!(
        email_distinct,
        Some(5),
        "COUNT(DISTINCT email) should be 5; got: {:?}",
        email_distinct
    );

    // `_count_distinct_email` must not be misread as a count of a column
    // named `distinct_email`.
    let misrouted = result.count_of("distinct_email");
    assert_eq!(
        misrouted, None,
        "distinct_email must not appear in count_columns"
    );
}