use serde::Deserialize;
use serde_json::Value as JsonValue;
use crate::errors::AppError;
use crate::schema::DatasetSchema;
/// A single filter condition from a request body.
///
/// `op` is kept as a raw string here; how it is validated and translated is
/// not visible in this file.
#[derive(Clone, Debug, Deserialize)]
pub struct Predicate {
    /// Column the condition applies to.
    pub col: String,
    /// Operator name as sent by the client.
    pub op: String,
    /// Operand value, if the client supplied one.
    pub val: Option<JsonValue>,
}
/// One sort key of a query request.
#[derive(Clone, Debug, Deserialize)]
pub struct OrderBy {
    /// Column to sort by, matched case-insensitively. For grouped queries it
    /// must name a group column or an aggregation alias.
    pub col: String,
    /// `"asc"` or `"desc"` (case-insensitive); absent means ascending.
    #[serde(default)]
    pub dir: Option<String>,
}
/// A client-requested aggregation, before it is resolved against the schema.
#[derive(Clone, Debug, Deserialize)]
pub struct Aggregation {
    /// Column to aggregate. It may be omitted only for `count`, which then
    /// means `COUNT(*)`.
    #[serde(default)]
    pub col: Option<String>,
    /// One of `count|sum|avg|min|max`, case-insensitive.
    pub op: String,
    /// Output column name. When absent, `<op>_<column lowercased>` is used,
    /// or `count` for a column-less count.
    #[serde(default)]
    pub alias: Option<String>,
}
/// Body of a query request: projection, filters, grouping, sorting and paging.
///
/// Every field is optional on the wire. Missing `page` / `page_size` fall back
/// to `default_page()` (1) and `default_page_size()` (1000).
#[derive(Clone, Deserialize)]
pub struct QueryRequest {
    /// Columns to return. How an empty list is interpreted is decided outside
    /// this file (presumably "all columns" — confirm with the SQL builder).
    #[serde(default)]
    pub columns: Vec<String>,
    /// Filter conditions; see [`Predicate`].
    #[serde(default)]
    pub predicates: Vec<Predicate>,
    /// Grouping columns. A non-empty list turns the request into an
    /// aggregated query (see `agg_plan`).
    #[serde(default)]
    pub group_by: Vec<String>,
    /// Per-group aggregations; only allowed together with a non-empty `group_by`.
    #[serde(default)]
    pub aggregations: Vec<Aggregation>,
    /// Request distinct rows; mutually exclusive with `group_by` / `aggregations`.
    #[serde(default)]
    pub distinct: bool,
    /// Sort keys, applied in order; see `order_by_sql`.
    #[serde(default)]
    pub order_by: Vec<OrderBy>,
    /// Optional cap on the total number of rows across all pages.
    #[serde(default)]
    pub limit: Option<u64>,
    /// 1-based page number; `effective_limit_offset` treats 0 as 1.
    #[serde(default = "default_page")]
    pub page: u64,
    /// Rows per page; clamped by `effective_limit_offset` to the server cap.
    #[serde(default = "default_page_size")]
    pub page_size: u64,
}
/// One aggregation after it has been resolved against the dataset schema.
#[derive(Clone)]
pub struct AggSpec {
    /// Aggregate function to apply.
    pub op: AggOp,
    /// Canonical column name the function applies to. `None` renders as
    /// `COUNT(*)`; the planner only leaves it empty for `Count`.
    pub col: Option<String>,
    /// Output column name: the explicit alias, or the generated default.
    pub alias: String,
}
/// Aggregate functions a request may ask for.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AggOp {
    /// Row count; the only op that may be used without a column (`COUNT(*)`).
    Count,
    /// Sum of the column's values.
    Sum,
    /// Average of the column's values.
    Avg,
    /// Minimum of the column's values.
    Min,
    /// Maximum of the column's values.
    Max,
}

impl AggOp {
    /// SQL function keyword for this op, e.g. `"SUM"`.
    pub fn as_sql(self) -> &'static str {
        match self {
            AggOp::Count => "COUNT",
            AggOp::Sum => "SUM",
            AggOp::Avg => "AVG",
            AggOp::Min => "MIN",
            AggOp::Max => "MAX",
        }
    }

    /// Lowercase name, matching the spelling accepted in requests and used as
    /// the prefix of default aggregation aliases.
    pub fn name(self) -> &'static str {
        match self {
            AggOp::Count => "count",
            AggOp::Sum => "sum",
            AggOp::Avg => "avg",
            AggOp::Min => "min",
            AggOp::Max => "max",
        }
    }
}
impl AggSpec {
    /// Renders this aggregation as a SQL expression such as `SUM("score")`.
    ///
    /// A `Count` without a column renders as `COUNT(*)`. Any op with a column
    /// applies its SQL keyword to the quoted column name.
    ///
    /// # Errors
    ///
    /// Returns [`AppError::Internal`] when a non-`Count` op has no column.
    /// `QueryRequest::agg_plan` rejects that combination up front, so this is
    /// only reachable for specs constructed outside the planner.
    pub fn sql_expr(&self) -> Result<String, AppError> {
        let op = self.op;
        match self.col.as_deref() {
            Some(column) => Ok(format!(
                "{}({})",
                op.as_sql(),
                DatasetSchema::quote_ident(column)
            )),
            None if matches!(op, AggOp::Count) => Ok("COUNT(*)".to_string()),
            None => Err(AppError::Internal(format!(
                "aggregation '{}' resolved without a column (planner invariant violated)",
                op.name()
            ))),
        }
    }
}
/// Resolved grouping/aggregation plan, as produced by `QueryRequest::agg_plan`.
#[derive(Clone)]
pub struct AggPlan {
    /// Canonical (schema-spelled) grouping column names, in request order.
    pub group_cols: Vec<String>,
    /// Aggregations computed for each group, in request order.
    pub aggs: Vec<AggSpec>,
}

impl AggPlan {
    /// Output column names: every group column first, then each aggregation's
    /// alias, both in order.
    pub fn output_names(&self) -> Vec<String> {
        let aliases = self.aggs.iter().map(|spec| spec.alias.clone());
        self.group_cols.iter().cloned().chain(aliases).collect()
    }
}
impl QueryRequest {
pub fn agg_plan(&self, schema: &DatasetSchema) -> Result<Option<AggPlan>, AppError> {
if self.distinct && (!self.group_by.is_empty() || !self.aggregations.is_empty()) {
return Err(AppError::InvalidValue(
"distinct is mutually exclusive with group_by / aggregations".into(),
));
}
if self.group_by.is_empty() {
if !self.aggregations.is_empty() {
return Err(AppError::InvalidValue(
"aggregations require a non-empty group_by".into(),
));
}
return Ok(None);
}
let mut group_cols = Vec::with_capacity(self.group_by.len());
for name in &self.group_by {
group_cols.push(schema.find(name)?.name.clone());
}
let raw_aggs: Vec<Aggregation> = if self.aggregations.is_empty() {
vec![Aggregation {
col: None,
op: "count".into(),
alias: None,
}]
} else {
self.aggregations.clone()
};
let mut aggs = Vec::with_capacity(raw_aggs.len());
for a in &raw_aggs {
let op = match a.op.to_ascii_lowercase().as_str() {
"count" => AggOp::Count,
"sum" => AggOp::Sum,
"avg" => AggOp::Avg,
"min" => AggOp::Min,
"max" => AggOp::Max,
other => {
return Err(AppError::InvalidValue(format!(
"unknown aggregation op '{other}' (expected count|sum|avg|min|max)"
)));
}
};
let col = match (op, a.col.as_deref()) {
(AggOp::Count, None) => None,
(_, None) => {
return Err(AppError::InvalidValue(format!(
"aggregation '{}' requires a 'col'",
op.name()
)));
}
(_, Some(c)) => Some(schema.find(c)?.name.clone()),
};
let alias = a.alias.clone().unwrap_or_else(|| match col.as_deref() {
Some(c) => format!("{}_{}", op.name(), c.to_lowercase()),
None => "count".into(),
});
aggs.push(AggSpec { col, op, alias });
}
Ok(Some(AggPlan { group_cols, aggs }))
}
pub fn order_by_sql(
&self,
schema: &DatasetSchema,
plan: Option<&AggPlan>,
) -> Result<Option<String>, AppError> {
if self.order_by.is_empty() {
return Ok(None);
}
let parts: Vec<String> = self
.order_by
.iter()
.map(|o| {
let dir = match o
.dir
.as_deref()
.unwrap_or("asc")
.to_ascii_lowercase()
.as_str()
{
"asc" => "ASC",
"desc" => "DESC",
other => {
return Err(AppError::InvalidValue(format!(
"order_by direction must be 'asc' or 'desc' (got '{other}')"
)));
}
};
let ident = match plan {
Some(p) => {
let lc = o.col.to_lowercase();
let allowed = p.output_names();
allowed
.iter()
.find(|n| n.to_lowercase() == lc)
.map(|n| DatasetSchema::quote_ident(n))
.ok_or_else(|| {
AppError::UnknownColumn(format!(
"{} (must be a group_by column or aggregation alias)",
o.col
))
})?
}
None => DatasetSchema::quote_ident(&schema.find(&o.col)?.name),
};
Ok(format!("{ident} {dir}"))
})
.collect::<Result<_, _>>()?;
Ok(Some(parts.join(", ")))
}
pub fn effective_limit_offset(&self, page_size_cap: u64) -> (u64, u64) {
let page = self.page.max(1);
let page_size = self.page_size.clamp(1, page_size_cap);
let offset = (page - 1) * page_size;
let limit = match self.limit {
Some(cap) => {
if offset >= cap {
0
} else {
page_size.min(cap - offset)
}
}
None => page_size,
};
(limit, offset)
}
}
/// Serde default for `QueryRequest::page`: pages are 1-based, so start at the first one.
fn default_page() -> u64 {
    1_u64
}
/// Serde default for `QueryRequest::page_size` when the client omits it.
fn default_page_size() -> u64 {
    1_000_u64
}
/// Request body carrying only filter predicates. Judging by its name, it is
/// presumably used for a row-count query — confirm with the handler.
#[derive(Clone, Deserialize, Default)]
pub struct CountRequest {
    /// Filter conditions; an absent field deserializes to an empty list.
    #[serde(default)]
    pub predicates: Vec<Predicate>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{ColumnInfo, DatasetSchema, LogicalType};

    /// Four-column fixture. `Mixed` checks that resolved names keep the
    /// schema's spelling.
    fn schema() -> DatasetSchema {
        DatasetSchema::new(
            "t",
            vec![
                ColumnInfo {
                    name: "id".into(),
                    logical: LogicalType::Int,
                    sql_type: "BIGINT".into(),
                    nullable: false,
                },
                ColumnInfo {
                    name: "name".into(),
                    logical: LogicalType::Utf8,
                    sql_type: "VARCHAR".into(),
                    nullable: true,
                },
                ColumnInfo {
                    name: "score".into(),
                    logical: LogicalType::Float,
                    sql_type: "DOUBLE".into(),
                    nullable: true,
                },
                ColumnInfo {
                    name: "Mixed".into(),
                    logical: LogicalType::Utf8,
                    sql_type: "VARCHAR".into(),
                    nullable: true,
                },
            ],
        )
    }

    /// A request with every field at its serde default (page 1, page_size 1000).
    fn empty_req() -> QueryRequest {
        QueryRequest {
            columns: vec![],
            predicates: vec![],
            group_by: vec![],
            aggregations: vec![],
            distinct: false,
            order_by: vec![],
            limit: None,
            page: 1,
            page_size: 1000,
        }
    }

    #[test]
    fn agg_plan_none_when_no_group_by() {
        let r = empty_req();
        assert!(r.agg_plan(&schema()).unwrap().is_none());
    }

    #[test]
    fn agg_plan_rejects_aggs_without_group_by() {
        let mut r = empty_req();
        r.aggregations = vec![Aggregation {
            col: Some("score".into()),
            op: "sum".into(),
            alias: None,
        }];
        let err = r.agg_plan(&schema()).err().expect("expected error");
        assert!(matches!(err, AppError::InvalidValue(_)), "got {err:?}");
    }

    #[test]
    fn agg_plan_implicit_count_star() {
        let mut r = empty_req();
        r.group_by = vec!["name".into()];
        let plan = r.agg_plan(&schema()).unwrap().unwrap();
        assert_eq!(plan.group_cols, vec!["name"]);
        assert_eq!(plan.aggs.len(), 1);
        assert_eq!(plan.aggs[0].alias, "count");
        assert!(plan.aggs[0].col.is_none());
        assert!(matches!(plan.aggs[0].op, AggOp::Count));
    }

    #[test]
    fn agg_plan_default_alias_format() {
        let mut r = empty_req();
        r.group_by = vec!["name".into()];
        r.aggregations = vec![
            Aggregation {
                col: Some("score".into()),
                op: "Sum".into(),
                alias: None,
            },
            Aggregation {
                col: Some("Mixed".into()),
                op: "MAX".into(),
                alias: Some("hi".into()),
            },
        ];
        let plan = r.agg_plan(&schema()).unwrap().unwrap();
        assert_eq!(plan.aggs[0].alias, "sum_score");
        assert_eq!(plan.aggs[1].alias, "hi");
        assert_eq!(plan.aggs[1].col.as_deref(), Some("Mixed"));
    }

    #[test]
    fn agg_plan_output_names_groups_then_aliases() {
        let mut r = empty_req();
        r.group_by = vec!["NAME".into()];
        r.aggregations = vec![Aggregation {
            col: Some("score".into()),
            op: "min".into(),
            alias: None,
        }];
        let plan = r.agg_plan(&schema()).unwrap().unwrap();
        assert_eq!(plan.output_names(), vec!["name", "min_score"]);
    }

    #[test]
    fn agg_plan_unknown_op() {
        let mut r = empty_req();
        r.group_by = vec!["name".into()];
        r.aggregations = vec![Aggregation {
            col: Some("score".into()),
            op: "median".into(),
            alias: None,
        }];
        let err = r.agg_plan(&schema()).err().expect("expected error");
        assert!(matches!(err, AppError::InvalidValue(m) if m.contains("median")));
    }

    #[test]
    fn agg_plan_non_count_requires_col() {
        let mut r = empty_req();
        r.group_by = vec!["name".into()];
        r.aggregations = vec![Aggregation {
            col: None,
            op: "avg".into(),
            alias: None,
        }];
        let err = r.agg_plan(&schema()).err().expect("expected error");
        assert!(matches!(err, AppError::InvalidValue(m) if m.contains("avg")));
    }

    #[test]
    fn agg_plan_unknown_group_col() {
        let mut r = empty_req();
        r.group_by = vec!["nope".into()];
        let err = r.agg_plan(&schema()).err().expect("expected error");
        assert!(matches!(err, AppError::UnknownColumn(_)));
    }

    #[test]
    fn agg_plan_distinct_conflicts_with_group_by() {
        let mut r = empty_req();
        r.distinct = true;
        r.group_by = vec!["name".into()];
        let err = r.agg_plan(&schema()).err().expect("expected error");
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn agg_plan_distinct_conflicts_with_aggregations() {
        let mut r = empty_req();
        r.distinct = true;
        r.aggregations = vec![Aggregation {
            col: None,
            op: "count".into(),
            alias: None,
        }];
        let err = r.agg_plan(&schema()).err().expect("expected error");
        assert!(matches!(err, AppError::InvalidValue(m) if m.contains("distinct")));
    }

    #[test]
    fn agg_spec_sql_expr_forms() {
        let count_star = AggSpec {
            col: None,
            op: AggOp::Count,
            alias: "count".into(),
        };
        assert_eq!(count_star.sql_expr().unwrap(), "COUNT(*)");

        let sum = AggSpec {
            col: Some("score".into()),
            op: AggOp::Sum,
            alias: "sum_score".into(),
        };
        assert_eq!(sum.sql_expr().unwrap(), "SUM(\"score\")");

        // A non-count op without a column is a planner bug, reported as Internal.
        let broken = AggSpec {
            col: None,
            op: AggOp::Avg,
            alias: "avg".into(),
        };
        let err = broken.sql_expr().unwrap_err();
        assert!(matches!(err, AppError::Internal(m) if m.contains("avg")));
    }

    #[test]
    fn order_by_none_when_empty() {
        let r = empty_req();
        assert!(r.order_by_sql(&schema(), None).unwrap().is_none());
    }

    #[test]
    fn order_by_default_asc_and_quoting() {
        let mut r = empty_req();
        r.order_by = vec![OrderBy {
            col: "ID".into(),
            dir: None,
        }];
        let sql = r.order_by_sql(&schema(), None).unwrap().unwrap();
        assert_eq!(sql, "\"id\" ASC");
    }

    #[test]
    fn order_by_desc_case_insensitive() {
        let mut r = empty_req();
        r.order_by = vec![OrderBy {
            col: "name".into(),
            dir: Some("DESC".into()),
        }];
        let sql = r.order_by_sql(&schema(), None).unwrap().unwrap();
        assert_eq!(sql, "\"name\" DESC");
    }

    #[test]
    fn order_by_bad_direction() {
        let mut r = empty_req();
        r.order_by = vec![OrderBy {
            col: "id".into(),
            dir: Some("backwards".into()),
        }];
        let err = r.order_by_sql(&schema(), None).unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(m) if m.contains("backwards")));
    }

    #[test]
    fn order_by_unknown_col_no_plan() {
        let mut r = empty_req();
        r.order_by = vec![OrderBy {
            col: "missing".into(),
            dir: None,
        }];
        let err = r.order_by_sql(&schema(), None).unwrap_err();
        assert!(matches!(err, AppError::UnknownColumn(_)));
    }

    #[test]
    fn order_by_with_plan_restricts_to_outputs() {
        let mut r = empty_req();
        r.group_by = vec!["name".into()];
        r.aggregations = vec![Aggregation {
            col: Some("score".into()),
            op: "sum".into(),
            alias: Some("total".into()),
        }];
        let plan = r.agg_plan(&schema()).unwrap().unwrap();
        r.order_by = vec![
            OrderBy {
                col: "name".into(),
                dir: Some("asc".into()),
            },
            OrderBy {
                col: "TOTAL".into(),
                dir: Some("desc".into()),
            },
        ];
        let sql = r.order_by_sql(&schema(), Some(&plan)).unwrap().unwrap();
        assert_eq!(sql, "\"name\" ASC, \"total\" DESC");
        // `id` is a real schema column but not an output of the grouped query.
        r.order_by = vec![OrderBy {
            col: "id".into(),
            dir: None,
        }];
        let err = r.order_by_sql(&schema(), Some(&plan)).unwrap_err();
        assert!(matches!(err, AppError::UnknownColumn(_)));
    }

    #[test]
    fn limit_offset_first_page_default() {
        let r = empty_req();
        assert_eq!(r.effective_limit_offset(1000), (1000, 0));
    }

    #[test]
    fn limit_offset_pagination() {
        let mut r = empty_req();
        r.page = 3;
        r.page_size = 50;
        assert_eq!(r.effective_limit_offset(1000), (50, 100));
    }

    #[test]
    fn limit_offset_caps_page_size_to_max() {
        let mut r = empty_req();
        r.page_size = 10_000;
        assert_eq!(r.effective_limit_offset(1000), (1000, 0));
    }

    #[test]
    fn limit_offset_page_zero_treated_as_one() {
        let mut r = empty_req();
        r.page = 0;
        r.page_size = 10;
        assert_eq!(r.effective_limit_offset(1000), (10, 0));
    }

    #[test]
    fn limit_offset_page_size_zero_treated_as_one() {
        let mut r = empty_req();
        r.page_size = 0;
        assert_eq!(r.effective_limit_offset(1000), (1, 0));
    }

    #[test]
    fn limit_offset_top_level_cap_truncates_last_page() {
        let mut r = empty_req();
        r.page = 2;
        r.page_size = 50;
        r.limit = Some(75);
        assert_eq!(r.effective_limit_offset(1000), (25, 50));
    }

    #[test]
    fn limit_offset_top_level_cap_exhausted_returns_zero() {
        let mut r = empty_req();
        r.page = 3;
        r.page_size = 50;
        r.limit = Some(75);
        assert_eq!(r.effective_limit_offset(1000), (0, 100));
    }

    #[test]
    fn limit_offset_zero_top_level_limit() {
        let mut r = empty_req();
        r.limit = Some(0);
        assert_eq!(r.effective_limit_offset(1000), (0, 0));
    }
}