#![cfg(test)]
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use prax_postgres::{PgEngine, PgPool, PgPoolBuilder};
use prax_query::filter::FilterValue;
use prax_query::operations::having;
use prax_query::operations::{
AggregateOperation, AggregateResult, GroupByOperation, GroupByResult,
};
use prax_query::traits::{Model, QueryEngine};
use prax_query::types::OrderByField;
/// Process-wide sequence number consumed by `unique_table`.
static TABLE_COUNTER: AtomicU32 = AtomicU32::new(0);

/// Builds a table name of the form `agg_<prefix>_<pid>_<n>`.
///
/// `<pid>` is the current process id and `<n>` is the value `TABLE_COUNTER` held before
/// this call incremented it, so successive calls within one process yield distinct names.
fn unique_table(prefix: &str) -> String {
    let sequence = TABLE_COUNTER.fetch_add(1, Ordering::SeqCst).to_string();
    let process_id = std::process::id().to_string();
    ["agg", prefix, process_id.as_str(), sequence.as_str()].join("_")
}
/// Returns the PostgreSQL URL when end-to-end tests are enabled.
///
/// Yields `Some(url)` only if `PRAX_E2E` is exactly `"1"` and `POSTGRES_URL` can be read;
/// every other combination yields `None`.
fn skip_unless_e2e() -> Option<String> {
    match std::env::var("PRAX_E2E") {
        Ok(flag) if flag == "1" => std::env::var("POSTGRES_URL").ok(),
        _ => None,
    }
}
/// Builds a `PgPool` from the URL reported by `skip_unless_e2e`.
///
/// The pool is configured with at most 4 connections and a 10-second connection timeout.
///
/// # Panics
///
/// Panics when `skip_unless_e2e` yields no URL or when building the pool fails.
async fn pool() -> PgPool {
    let database_url = skip_unless_e2e().expect("PRAX_E2E=1 and POSTGRES_URL required");
    let builder = PgPoolBuilder::new()
        .url(database_url)
        .max_connections(4)
        .connection_timeout(Duration::from_secs(10));
    builder.build().await.expect("connect to postgres")
}
/// Issues a best-effort `DROP TABLE IF EXISTS` for `table`.
///
/// The outcome of the DROP statement itself is deliberately discarded.
///
/// # Panics
///
/// Panics if a connection cannot be acquired from `pool`.
async fn drop_table(pool: &PgPool, table: &str) {
    let statement = format!("DROP TABLE IF EXISTS {table}");
    let conn = pool.get().await.expect("acquire conn for cleanup");
    let _ = conn.batch_execute(&statement).await;
}
/// Marker model used by `count_select_round_trip` to type its `AggregateOperation`.
struct CountModel;
impl Model for CountModel {
const MODEL_NAME: &'static str = "CountModel";
// Placeholder only: the test replaces this string in the generated SQL with the
// per-run name produced by `unique_table`.
const TABLE_NAME: &'static str = "count_model_placeholder";
const PRIMARY_KEY: &'static [&'static str] = &["id"];
// Same column names as the table the test creates (`id`, `email`).
const COLUMNS: &'static [&'static str] = &["id", "email"];
}
/// Marker model used by `aggregate_sum_avg_count_round_trip` to type its `AggregateOperation`.
struct ScoreModel;
impl Model for ScoreModel {
const MODEL_NAME: &'static str = "ScoreModel";
// Placeholder only: the test replaces this string in the generated SQL with the
// per-run name produced by `unique_table`.
const TABLE_NAME: &'static str = "score_model_placeholder";
const PRIMARY_KEY: &'static [&'static str] = &["id"];
// Same column names as the table the test creates (`id`, `score`).
const COLUMNS: &'static [&'static str] = &["id", "score"];
}
/// Marker model used by `group_by_with_having_round_trip` to type its `GroupByOperation`.
struct TeamModel;
impl Model for TeamModel {
const MODEL_NAME: &'static str = "TeamModel";
// Placeholder only: the test replaces this string in the generated SQL with the
// per-run name produced by `unique_table`.
const TABLE_NAME: &'static str = "team_model_placeholder";
const PRIMARY_KEY: &'static [&'static str] = &["id"];
// Same column names as the table the test creates (`id`, `team_id`, `score`).
const COLUMNS: &'static [&'static str] = &["id", "team_id", "score"];
}
/// Marker model used by `distinct_count_round_trip` to type its `AggregateOperation`.
struct RegionModel;
impl Model for RegionModel {
const MODEL_NAME: &'static str = "RegionModel";
// Placeholder only: the test replaces this string in the generated SQL with the
// per-run name produced by `unique_table`.
const TABLE_NAME: &'static str = "region_model_placeholder";
const PRIMARY_KEY: &'static [&'static str] = &["id"];
// Same column names as the table the test creates (`id`, `region`).
const COLUMNS: &'static [&'static str] = &["id", "region"];
}
/// Marker model used by `group_by_order_by_round_trip` to type its `GroupByOperation`.
struct ViewsModel;
impl Model for ViewsModel {
const MODEL_NAME: &'static str = "ViewsModel";
// Placeholder only: the test replaces this string in the generated SQL with the
// per-run name produced by `unique_table`.
const TABLE_NAME: &'static str = "views_model_placeholder";
const PRIMARY_KEY: &'static [&'static str] = &["id"];
// Same column names as the table the test creates (`id`, `team_id`, `views`).
const COLUMNS: &'static [&'static str] = &["id", "team_id", "views"];
}
/// End-to-end check that `AggregateOperation::count()` counts every row, NULL emails included.
///
/// Creates a uniquely named table holding five rows (two with a NULL `email`), runs the
/// generated SQL through `PgEngine::aggregate_query`, and expects a count of 5.
/// Returns early without testing anything when `skip_unless_e2e` yields `None`.
#[tokio::test]
#[ignore = "requires running PostgreSQL via docker-compose (PRAX_E2E=1 + POSTGRES_URL)"]
async fn count_select_round_trip() {
    if skip_unless_e2e().is_none() {
        // The guard also yields `None` when only POSTGRES_URL is missing, so name both.
        eprintln!("skipping: PRAX_E2E=1 and POSTGRES_URL are both required");
        return;
    }

    let pool = pool().await;
    let table = unique_table("count");
    drop_table(&pool, &table).await;
    {
        // Scoped so `conn` is dropped before the engine runs its query.
        let conn = pool.get().await.expect("conn");
        conn.batch_execute(&format!(
            "CREATE TABLE {table} (id SERIAL PRIMARY KEY, email TEXT)"
        ))
        .await
        .expect("create table");
        conn.batch_execute(&format!(
            "INSERT INTO {table} (email) VALUES \
             ('a@example.com'), ('b@example.com'), ('c@example.com'), (NULL), (NULL)"
        ))
        .await
        .expect("insert rows");
    }

    let engine = PgEngine::new(pool.clone());
    let dialect = engine.dialect();
    let op: AggregateOperation<CountModel, PgEngine> = AggregateOperation::new().count();
    let (sql, params) = op.build_sql(dialect);
    // The model only carries a placeholder table name; point the SQL at the real table.
    let sql = sql.replace(CountModel::TABLE_NAME, &table);

    let query_outcome = engine.aggregate_query(&sql, params).await;
    // Clean up before unwrapping or asserting so a failure below does not leave the
    // uniquely named table behind.
    drop_table(&pool, &table).await;

    let mut rows = query_outcome.expect("aggregate_query");
    // Fail loudly on an empty result set instead of silently building a default row.
    let row = rows.pop().expect("aggregate query returned no rows");
    let result = AggregateResult::from_row(row);
    assert_eq!(
        result.count,
        Some(5),
        "COUNT(*) should be 5 (includes NULLs)"
    );
}
/// End-to-end check of `COUNT(*)`, `SUM(score)` and `AVG(score)` in a single aggregate query.
///
/// Creates a uniquely named table with scores 10, 20 and 30, runs the generated SQL through
/// `PgEngine::aggregate_query`, and expects count 3, sum 60 and average 20.
/// Returns early without testing anything when `skip_unless_e2e` yields `None`.
#[tokio::test]
#[ignore = "requires running PostgreSQL via docker-compose (PRAX_E2E=1 + POSTGRES_URL)"]
async fn aggregate_sum_avg_count_round_trip() {
    if skip_unless_e2e().is_none() {
        // The guard also yields `None` when only POSTGRES_URL is missing, so name both.
        eprintln!("skipping: PRAX_E2E=1 and POSTGRES_URL are both required");
        return;
    }

    let pool = pool().await;
    let table = unique_table("score");
    drop_table(&pool, &table).await;
    {
        // Scoped so `conn` is dropped before the engine runs its query.
        let conn = pool.get().await.expect("conn");
        conn.batch_execute(&format!(
            "CREATE TABLE {table} (id SERIAL PRIMARY KEY, score INT NOT NULL)"
        ))
        .await
        .expect("create table");
        conn.batch_execute(&format!(
            "INSERT INTO {table} (score) VALUES (10), (20), (30)"
        ))
        .await
        .expect("insert rows");
    }

    let engine = PgEngine::new(pool.clone());
    let dialect = engine.dialect();
    let op: AggregateOperation<ScoreModel, PgEngine> =
        AggregateOperation::new().count().sum("score").avg("score");
    let (sql, params) = op.build_sql(dialect);
    // The model only carries a placeholder table name; point the SQL at the real table.
    let sql = sql.replace(ScoreModel::TABLE_NAME, &table);

    let query_outcome = engine.aggregate_query(&sql, params).await;
    // Clean up before unwrapping or asserting so a failure below does not leave the
    // uniquely named table behind.
    drop_table(&pool, &table).await;

    let mut rows = query_outcome.expect("aggregate_query");
    // Fail loudly on an empty result set instead of silently building a default row.
    let row = rows.pop().expect("aggregate query returned no rows");
    let result = AggregateResult::from_row(row);

    assert_eq!(result.count, Some(3), "COUNT(*) should be 3");

    let sum = result
        .sum_as_f64("score")
        .expect("sum(score) should be present");
    assert!(
        (sum - 60.0).abs() < 0.001,
        "SUM(score) should be 60, got {sum}"
    );

    let avg = result
        .avg_as_f64("score")
        .expect("avg(score) should be present");
    assert!(
        (avg - 20.0).abs() < 0.001,
        "AVG(score) should be 20.0, got {avg}"
    );
}
/// End-to-end check that `GROUP BY team_id` with a `HAVING` count filter keeps only large groups.
///
/// Creates a uniquely named table with two rows for team 1 and four rows for team 2, groups
/// by `team_id` with `having::count_gt(3.0)`, and expects team 2 (count 4) as the only group.
/// Returns early without testing anything when `skip_unless_e2e` yields `None`.
#[tokio::test]
#[ignore = "requires running PostgreSQL via docker-compose (PRAX_E2E=1 + POSTGRES_URL)"]
async fn group_by_with_having_round_trip() {
    if skip_unless_e2e().is_none() {
        // The guard also yields `None` when only POSTGRES_URL is missing, so name both.
        eprintln!("skipping: PRAX_E2E=1 and POSTGRES_URL are both required");
        return;
    }

    let pool = pool().await;
    let table = unique_table("team");
    drop_table(&pool, &table).await;
    {
        // Scoped so `conn` is dropped before the engine runs its query.
        let conn = pool.get().await.expect("conn");
        conn.batch_execute(&format!(
            "CREATE TABLE {table} (id SERIAL PRIMARY KEY, team_id INT NOT NULL, score INT NOT NULL)"
        ))
        .await
        .expect("create table");
        conn.batch_execute(&format!(
            "INSERT INTO {table} (team_id, score) VALUES \
             (1, 10), (1, 20), \
             (2, 30), (2, 40), (2, 50), (2, 60)"
        ))
        .await
        .expect("insert rows");
    }

    let engine = PgEngine::new(pool.clone());
    let dialect = engine.dialect();
    let op: GroupByOperation<TeamModel, PgEngine> =
        GroupByOperation::new(vec!["team_id".to_string()])
            .count()
            .having(having::count_gt(3.0));
    let (sql, params) = op.build_sql(dialect);
    // The model only carries a placeholder table name; point the SQL at the real table.
    let sql = sql.replace(TeamModel::TABLE_NAME, &table);

    let query_outcome = engine.aggregate_query(&sql, params).await;
    // Clean up before unwrapping or asserting so a failure below does not leave the
    // uniquely named table behind.
    drop_table(&pool, &table).await;
    let raw_rows = query_outcome.expect("aggregate_query for group_by");

    // Split every raw row into its grouping columns (converted to JSON) and the remaining
    // aggregate columns, which `AggregateResult::from_row` interprets.
    let group_columns = ["team_id"];
    let results: Vec<GroupByResult> = raw_rows
        .into_iter()
        .map(|row| {
            let mut group_values: HashMap<String, serde_json::Value> = HashMap::new();
            let mut agg_map: HashMap<String, FilterValue> = HashMap::new();
            for (column, value) in row {
                if !group_columns.contains(&column.as_str()) {
                    agg_map.insert(column, value);
                    continue;
                }
                let json_val = match &value {
                    FilterValue::Int(n) => serde_json::Value::from(*n),
                    FilterValue::Float(f) => serde_json::json!(*f),
                    FilterValue::String(s) => serde_json::Value::String(s.clone()),
                    FilterValue::Bool(b) => serde_json::Value::Bool(*b),
                    // Any other variant is represented as JSON null.
                    _ => serde_json::Value::Null,
                };
                group_values.insert(column, json_val);
            }
            GroupByResult {
                group_values,
                aggregates: AggregateResult::from_row(agg_map),
            }
        })
        .collect();

    assert_eq!(
        results.len(),
        1,
        "HAVING COUNT(*) > 3 should return exactly one group"
    );
    let group = &results[0];

    let team_id = group
        .group_values
        .get("team_id")
        .and_then(serde_json::Value::as_i64)
        .expect("team_id should be present as integer");
    assert_eq!(team_id, 2, "the surviving group should be team 2");

    let count = group
        .aggregates
        .count
        .expect("COUNT(*) should be present in aggregates");
    assert_eq!(count, 4, "team 2 has 4 rows");
}
/// End-to-end check of `COUNT(region)` next to `COUNT(DISTINCT region)`.
///
/// Creates a uniquely named table with regions a, a, b, b, c, runs the generated SQL through
/// `PgEngine::aggregate_query`, and expects a column count of 5 and a distinct count of 3.
/// Returns early without testing anything when `skip_unless_e2e` yields `None`.
#[tokio::test]
#[ignore = "requires running PostgreSQL via docker-compose (PRAX_E2E=1 + POSTGRES_URL)"]
async fn distinct_count_round_trip() {
    if skip_unless_e2e().is_none() {
        // The guard also yields `None` when only POSTGRES_URL is missing, so name both.
        eprintln!("skipping: PRAX_E2E=1 and POSTGRES_URL are both required");
        return;
    }

    let pool = pool().await;
    let table = unique_table("region");
    drop_table(&pool, &table).await;
    {
        // Scoped so `conn` is dropped before the engine runs its query.
        let conn = pool.get().await.expect("conn");
        conn.batch_execute(&format!(
            "CREATE TABLE {table} (id SERIAL PRIMARY KEY, region TEXT NOT NULL)"
        ))
        .await
        .expect("create table");
        conn.batch_execute(&format!(
            "INSERT INTO {table} (region) VALUES ('a'), ('a'), ('b'), ('b'), ('c')"
        ))
        .await
        .expect("insert rows");
    }

    let engine = PgEngine::new(pool.clone());
    let dialect = engine.dialect();
    let op: AggregateOperation<RegionModel, PgEngine> = AggregateOperation::new()
        .count_column("region")
        .count_distinct("region");
    let (sql, params) = op.build_sql(dialect);
    // The model only carries a placeholder table name; point the SQL at the real table.
    let sql = sql.replace(RegionModel::TABLE_NAME, &table);

    let query_outcome = engine.aggregate_query(&sql, params).await;
    // Clean up before unwrapping or asserting so a failure below does not leave the
    // uniquely named table behind.
    drop_table(&pool, &table).await;

    let mut rows = query_outcome.expect("aggregate_query");
    // Fail loudly on an empty result set instead of silently building a default row.
    let row = rows.pop().expect("aggregate query returned no rows");
    let result = AggregateResult::from_row(row);

    assert_eq!(
        result.count_of("region"),
        Some(5),
        "COUNT(region) should be 5 (5 non-NULL rows)"
    );
    assert_eq!(
        result.count_distinct_of("region"),
        Some(3),
        "COUNT(DISTINCT region) should be 3 (a, b, c)"
    );
}
/// End-to-end check that a grouped `SUM(views)` can be ordered descending by the aggregate.
///
/// Creates a uniquely named table where team 1 totals 100 views and team 2 totals 300,
/// groups by `team_id`, orders by `_sum_views` descending, and expects team 2 (sum 300) first.
/// Returns early without testing anything when `skip_unless_e2e` yields `None`.
#[tokio::test]
#[ignore = "requires running PostgreSQL via docker-compose (PRAX_E2E=1 + POSTGRES_URL)"]
async fn group_by_order_by_round_trip() {
    if skip_unless_e2e().is_none() {
        // The guard also yields `None` when only POSTGRES_URL is missing, so name both.
        eprintln!("skipping: PRAX_E2E=1 and POSTGRES_URL are both required");
        return;
    }

    let pool = pool().await;
    let table = unique_table("views");
    drop_table(&pool, &table).await;
    {
        // Scoped so `conn` is dropped before the engine runs its query.
        let conn = pool.get().await.expect("conn");
        conn.batch_execute(&format!(
            "CREATE TABLE {table} (id SERIAL PRIMARY KEY, team_id INT NOT NULL, views INT NOT NULL)"
        ))
        .await
        .expect("create table");
        conn.batch_execute(&format!(
            "INSERT INTO {table} (team_id, views) VALUES \
             (1, 40), (1, 60), \
             (2, 100), (2, 200)"
        ))
        .await
        .expect("insert rows");
    }

    let engine = PgEngine::new(pool.clone());
    let dialect = engine.dialect();
    let op: GroupByOperation<ViewsModel, PgEngine> =
        GroupByOperation::new(vec!["team_id".to_string()])
            .sum("views")
            .order_by(OrderByField::desc("_sum_views"));
    let (sql, params) = op.build_sql(dialect);
    // The model only carries a placeholder table name; point the SQL at the real table.
    let sql = sql.replace(ViewsModel::TABLE_NAME, &table);

    let query_outcome = engine.aggregate_query(&sql, params).await;
    // Clean up before unwrapping or asserting so a failure below does not leave the
    // uniquely named table behind.
    drop_table(&pool, &table).await;
    let raw_rows = query_outcome.expect("aggregate_query for group_by order_by");

    // Split every raw row into its grouping columns (converted to JSON) and the remaining
    // aggregate columns, which `AggregateResult::from_row` interprets.
    let group_columns = ["team_id"];
    let results: Vec<GroupByResult> = raw_rows
        .into_iter()
        .map(|row| {
            let mut group_values: HashMap<String, serde_json::Value> = HashMap::new();
            let mut agg_map: HashMap<String, FilterValue> = HashMap::new();
            for (column, value) in row {
                if !group_columns.contains(&column.as_str()) {
                    agg_map.insert(column, value);
                    continue;
                }
                let json_val = match &value {
                    FilterValue::Int(n) => serde_json::Value::from(*n),
                    FilterValue::Float(f) => serde_json::json!(*f),
                    FilterValue::String(s) => serde_json::Value::String(s.clone()),
                    FilterValue::Bool(b) => serde_json::Value::Bool(*b),
                    // Any other variant is represented as JSON null.
                    _ => serde_json::Value::Null,
                };
                group_values.insert(column, json_val);
            }
            GroupByResult {
                group_values,
                aggregates: AggregateResult::from_row(agg_map),
            }
        })
        .collect();

    assert_eq!(results.len(), 2, "two groups expected (team 1 and team 2)");
    let first_group = &results[0];

    let first_team_id = first_group
        .group_values
        .get("team_id")
        .and_then(serde_json::Value::as_i64)
        .expect("team_id present in first group");
    assert_eq!(first_team_id, 2, "team 2 should be first (highest sum)");

    let first_sum = first_group
        .aggregates
        .sum_as_f64("views")
        .expect("SUM(views) present in first group");
    assert!(
        (first_sum - 300.0).abs() < 0.001,
        "team 2 sum should be 300, got {first_sum}"
    );
}