use std::collections::{HashMap, HashSet};
use std::sync::OnceLock;
use regex::Regex;
use super::types::StatResult;
use crate::naming;
use crate::plot::aesthetic::AestheticContext;
use crate::plot::types::{ArrayElement, ParameterValue, Parameters, Schema};
use crate::reader::SqlDialect;
use crate::{GgsqlError, Mappings, Result};
/// Every stat name accepted as a plain (band-less) aggregate entry, e.g.
/// `aggregate => 'mean'` or `'y:p90'`. The order here is the order shown in
/// "Allowed: ..." error messages.
pub const AGG_NAMES: &[&str] = &[
"count", "sum", "prod", "min", "max", "range", "mid", "mean", "geomean", "harmean", "rms", "median", "sdev", "var", "iqr", "p05", "p10", "p25", "p50", "p75", "p90", "p95", "first", "last", "diff",
];
/// Stats allowed on the left-hand side (the centre) of a band entry such as
/// `mean+sdev` or `p50-iqr`.
pub const OFFSET_STATS: &[&str] = &[
"mean", "median", "geomean", "harmean", "rms", "sum", "prod", "min", "max", "mid", "p05",
"p10", "p25", "p50", "p75", "p90", "p95",
];
/// Stats allowed on the right-hand side (the spread) of a band entry such as
/// `mean+2sdev`.
pub const EXPANSION_STATS: &[&str] = &["sdev", "se", "var", "iqr", "range"];
/// One aggregate function: a single stat (`offset`), optionally widened into
/// a band rendered as `offset ± |mod_value| * expansion`.
#[derive(Debug, Clone, PartialEq)]
pub struct AggSpec {
/// Stat evaluated first. Taken from `AGG_NAMES` when `band` is `None`, or
/// from `OFFSET_STATS` when a band is present.
pub offset: &'static str,
/// Optional signed spread added to `offset`; `None` means the plain stat.
pub band: Option<Band>,
}
/// The spread part of a band entry such as `mean-2sdev`.
#[derive(Debug, Clone, PartialEq)]
pub struct Band {
/// Signed multiplier applied to `expansion`: negative for a `-` band, and
/// ±1 when the entry omits the number.
pub mod_value: f64,
/// Spread stat; one of `EXPANSION_STATS`.
pub expansion: &'static str,
}
/// Looks `name` up in `vocab`. On a hit, returns the vocabulary's own
/// `'static` copy of the string so callers can store it without allocating.
fn resolve_static(name: &str, vocab: &'static [&'static str]) -> Option<&'static str> {
    for &word in vocab {
        if word == name {
            return Some(word);
        }
    }
    None
}
/// Returns the lazily compiled pattern that splits one aggregate entry into
/// its parts. Capture groups:
///
/// 1. optional aesthetic prefix (the text before a `:`, which may not itself contain `:`),
/// 2. offset stat name (lowercase letters optionally followed by digits, e.g. `p05`),
/// 3. band sign (`+` or `-`),
/// 4. optional numeric band multiplier (e.g. `2` or `1.5`),
/// 5. expansion stat name.
///
/// Groups 3 and 5 sit unconditionally inside one optional group, so they
/// are either both present or both absent.
fn entry_re() -> &'static Regex {
    const PATTERN: &str = r"^(?:([^:]+):)?([a-z]+\d*)(?:([+-])(\d+(?:\.\d+)?)?([a-z]+))?$";
    static ENTRY_PATTERN: OnceLock<Regex> = OnceLock::new();
    ENTRY_PATTERN.get_or_init(|| Regex::new(PATTERN).unwrap())
}
/// Result of parsing a single `aggregate` entry.
struct ParsedEntry {
/// User-facing aesthetic prefix (`y` in `y:mean`), or `None` for an
/// unprefixed default entry.
aesthetic: Option<String>,
/// The validated aggregate function described by the entry.
spec: AggSpec,
}
/// Parses one `aggregate` entry of the form `[aes:]offset[(+|-)[k]expansion]`.
///
/// Examples: `mean`, `y:p90`, `mean-sdev`, `color:median+1.5iqr`.
///
/// Without a band the offset must be one of [`AGG_NAMES`]. With a band the
/// offset must be one of [`OFFSET_STATS`] and the expansion one of
/// [`EXPANSION_STATS`]. An omitted multiplier `k` means 1.
///
/// # Errors
///
/// Returns a user-facing message when the entry does not match the grammar,
/// names an unknown or disallowed stat, or carries a band multiplier too
/// large to be represented as a finite number.
fn parse_entry(entry: &str) -> std::result::Result<ParsedEntry, String> {
    let caps = entry_re()
        .captures(entry)
        .ok_or_else(|| format!("could not parse aggregate entry '{}'", entry))?;
    let aesthetic = caps.get(1).map(|m| m.as_str().to_string());
    // Group 2 is mandatory in the pattern, so it exists on every match.
    let offset_str = caps
        .get(2)
        .expect("offset group is mandatory in the entry pattern")
        .as_str();

    // Groups 3 (sign) and 5 (expansion) are both required inside the optional
    // band group: either both matched or neither did.
    let band = match caps.get(3) {
        None => None,
        Some(sign) => {
            let expansion_str = caps
                .get(5)
                .expect("expansion group always accompanies the band sign")
                .as_str();
            let expansion = resolve_static(expansion_str, EXPANSION_STATS).ok_or_else(|| {
                format!(
                    "'{}': '{}' is not a valid expansion stat. Allowed expansions: {}",
                    entry,
                    expansion_str,
                    crate::or_list_quoted(EXPANSION_STATS, '\''),
                )
            })?;
            // The pattern only admits `\d+(\.\d+)?`, so parsing cannot fail.
            // An absurdly long digit string still overflows to infinity,
            // which would otherwise be rendered as `inf` inside the SQL.
            let magnitude: f64 = caps
                .get(4)
                .map_or(1.0, |m| m.as_str().parse().unwrap_or(1.0));
            if !magnitude.is_finite() {
                return Err(format!("'{}': band multiplier is too large", entry));
            }
            let mod_value = if sign.as_str() == "-" {
                -magnitude
            } else {
                magnitude
            };
            Some(Band {
                mod_value,
                expansion,
            })
        }
    };

    let offset = if band.is_some() {
        resolve_static(offset_str, OFFSET_STATS).ok_or_else(|| {
            // Distinguish "real stat, but not usable as a band centre" from
            // "not a stat at all".
            if AGG_NAMES.contains(&offset_str) {
                format!(
                    "'{}': '{}' is not a valid offset stat. Allowed offsets: {}",
                    entry,
                    offset_str,
                    crate::or_list_quoted(OFFSET_STATS, '\''),
                )
            } else {
                format!(
                    "'{}': '{}' is not a known stat. Allowed offsets: {}",
                    entry,
                    offset_str,
                    crate::or_list_quoted(OFFSET_STATS, '\''),
                )
            }
        })?
    } else {
        resolve_static(offset_str, AGG_NAMES).ok_or_else(|| {
            format!(
                "unknown aggregate function '{}'. Allowed: {} (or use a band like `mean+sdev`)",
                offset_str,
                crate::or_list_quoted(AGG_NAMES, '\''),
            )
        })?
    };
    Ok(ParsedEntry {
        aesthetic,
        spec: AggSpec { offset, band },
    })
}
/// Parsed form of a layer's `aggregate` parameter.
#[derive(Debug, Clone, PartialEq)]
pub struct AggregateSpec {
/// First unprefixed entry. Used for non-discrete mappings that are not
/// targeted (and, when `default_upper` is unset, for upper-half ones too).
pub default_lower: Option<AggSpec>,
/// Second unprefixed entry. Used for untargeted "upper half" aesthetics
/// (names ending in `max` or `end`).
pub default_upper: Option<AggSpec>,
/// Prefixed entries grouped by user-facing aesthetic name, in first-seen
/// order. Several functions for one aesthetic explode the layer into several rows.
pub targets: Vec<(String, Vec<AggSpec>)>,
}
impl AggregateSpec {
fn new() -> Self {
Self {
default_lower: None,
default_upper: None,
targets: Vec::new(),
}
}
pub fn explosion_factor(&self) -> usize {
self.targets
.iter()
.map(|(_, fns)| fns.len())
.max()
.unwrap_or(1)
.max(1)
}
pub fn explosion_labels(&self) -> Option<Vec<String>> {
let n = self.explosion_factor();
if n <= 1 {
return None;
}
let exploded: Vec<&Vec<AggSpec>> = self
.targets
.iter()
.filter(|(_, fns)| fns.len() == n)
.map(|(_, fns)| fns)
.collect();
let labels = (0..n)
.map(|row| {
let mut parts: Vec<String> = Vec::new();
for fns in &exploded {
let label = agg_label(&fns[row]);
if !parts.contains(&label) {
parts.push(label);
}
}
parts.join("/")
})
.collect();
Some(labels)
}
}
/// Human-readable label for one aggregate function, mirroring the entry
/// syntax: `mean`, `mean+sdev`, or `mean-2sdev`. A unit multiplier is omitted.
fn agg_label(spec: &AggSpec) -> String {
    let Some(band) = &spec.band else {
        return spec.offset.to_string();
    };
    let sign = if band.mod_value < 0.0 { "-" } else { "+" };
    let magnitude = band.mod_value.abs();
    let multiplier = if magnitude == 1.0 {
        String::new()
    } else {
        magnitude.to_string()
    };
    format!("{}{}{}{}", spec.offset, sign, multiplier, band.expansion)
}
pub fn parse_aggregate_param(
value: &ParameterValue,
) -> std::result::Result<Option<AggregateSpec>, String> {
let entries: Vec<&str> = match value {
ParameterValue::Null => return Ok(None),
ParameterValue::String(s) => vec![s.as_str()],
ParameterValue::Array(arr) => {
let mut out = Vec::with_capacity(arr.len());
for el in arr {
match el {
ArrayElement::String(s) => out.push(s.as_str()),
ArrayElement::Null => continue,
_ => {
return Err("'aggregate' array entries must be strings or null".to_string());
}
}
}
if out.is_empty() {
return Ok(None);
}
out
}
_ => return Err("'aggregate' must be a string, array of strings, or null".to_string()),
};
let mut spec = AggregateSpec::new();
for entry in entries {
let parsed = parse_entry(entry)?;
match parsed.aesthetic {
Some(aes) => {
if let Some((_, fns)) = spec.targets.iter_mut().find(|(a, _)| *a == aes) {
fns.push(parsed.spec);
} else {
spec.targets.push((aes, vec![parsed.spec]));
}
}
None => {
if spec.default_lower.is_none() {
spec.default_lower = Some(parsed.spec);
} else if spec.default_upper.is_none() {
spec.default_upper = Some(parsed.spec);
} else {
return Err(format!(
"'aggregate' accepts at most two unprefixed defaults; got a third: '{}'",
entry
));
}
}
}
}
if spec.default_lower.is_none() && spec.default_upper.is_none() && spec.targets.is_empty() {
return Ok(None);
}
let n = spec.explosion_factor();
if n > 1 {
for (aes, fns) in &spec.targets {
if fns.len() != 1 && fns.len() != n {
return Err(format!(
"aggregate target '{}' has {} functions; targets in an exploded layer must \
have either 1 or {} functions (the longest target's count)",
aes,
fns.len(),
n
));
}
}
}
Ok(Some(spec))
}
/// Maps a percentile-style stat name to its quantile fraction. `median` is
/// an alias for `p50`. Returns `None` for every other name.
fn percentile_fraction(func: &str) -> Option<f64> {
    const QUANTILES: [(&str, f64); 8] = [
        ("p05", 0.05),
        ("p10", 0.10),
        ("p25", 0.25),
        ("median", 0.50),
        ("p50", 0.50),
        ("p75", 0.75),
        ("p90", 0.90),
        ("p95", 0.95),
    ];
    QUANTILES
        .iter()
        .find(|entry| entry.0 == func)
        .map(|entry| entry.1)
}
/// Inline SQL for a single (band-less) stat over the quoted column `qcol`.
///
/// Percentiles and `iqr` go through the dialect's inline-quantile hook, which
/// receives the *unquoted* column name. `iqr` is rendered as `p75 - p25`.
/// Every other stat is delegated to `sql_aggregate`. Returns `None` when the
/// dialect cannot express the stat inline.
fn simple_stat_sql_inline(name: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option<String> {
    match (name, percentile_fraction(name)) {
        (_, Some(frac)) => dialect.sql_quantile_inline(&unquote(qcol), frac),
        ("iqr", None) => {
            let raw = unquote(qcol);
            let upper = dialect.sql_quantile_inline(&raw, 0.75)?;
            let lower = dialect.sql_quantile_inline(&raw, 0.25)?;
            Some(format!("({} - {})", upper, lower))
        }
        _ => dialect.sql_aggregate(name, qcol),
    }
}
/// Whether `dialect` can compute stat `name` at all.
///
/// Percentiles and `iqr` always count as supported, because a
/// `sql_percentile` fallback path exists for them. Other stats are probed by
/// rendering them against a placeholder column.
fn dialect_supports(name: &str, dialect: &dyn SqlDialect) -> bool {
    let quantile_based = percentile_fraction(name).is_some() || name == "iqr";
    quantile_based || dialect.sql_aggregate(name, "x").is_some()
}
/// Collects every stat name (offsets and band expansions) used by
/// `aggregated` that `dialect` cannot compute. Names appear once each, in
/// first-seen order.
fn unsupported_functions(
    aggregated: &[(String, String, Vec<AggSpec>)],
    dialect: &dyn SqlDialect,
) -> Vec<String> {
    let used_names = aggregated
        .iter()
        .flat_map(|(_, _, specs)| specs.iter())
        .flat_map(|spec| {
            std::iter::once(spec.offset).chain(spec.band.as_ref().map(|b| b.expansion))
        });
    let mut missing: Vec<String> = Vec::new();
    for name in used_names {
        let already_reported = || missing.iter().any(|m| m == name);
        if !dialect_supports(name, dialect) && !already_reported() {
            missing.push(name.to_string());
        }
    }
    missing
}
/// Inline SQL for one aggregate function (offset plus optional band). Returns
/// `None` if any part cannot be expressed inline by `dialect`.
fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Option<String> {
    let offset_sql = simple_stat_sql_inline(spec.offset, qcol, dialect)?;
    let Some(band) = spec.band.as_ref() else {
        return Some(offset_sql);
    };
    simple_stat_sql_inline(band.expansion, qcol, dialect)
        .map(|exp_sql| format_band(&offset_sql, band.mod_value, &exp_sql))
}
/// Renders `offset ± |mod_value| * exp` as a parenthesised SQL expression.
/// The `* ` factor is dropped when the multiplier's magnitude is exactly 1.
fn format_band(offset: &str, mod_value: f64, exp: &str) -> String {
    let operator = if mod_value < 0.0 { '-' } else { '+' };
    let scale = mod_value.abs();
    let spread = if scale == 1.0 {
        exp.to_string()
    } else {
        format!("{} * {}", scale, exp)
    };
    format!("({} {} {})", offset, operator, spread)
}
/// SQL for a single stat when inline quantiles are unavailable.
///
/// Percentiles and `iqr` use the dialect's `sql_percentile`, which receives
/// the source alias and grouping columns. Any other stat uses its inline
/// form and degrades to `NULL` if even that is unavailable.
fn simple_stat_sql_fallback(
    name: &str,
    raw_col: &str,
    dialect: &dyn SqlDialect,
    src_alias: &str,
    group_cols: &[String],
) -> String {
    let percentile = |frac: f64| dialect.sql_percentile(raw_col, frac, src_alias, group_cols);
    if let Some(frac) = percentile_fraction(name) {
        return percentile(frac);
    }
    if name == "iqr" {
        // Arguments evaluate left to right: p75 first, then p25.
        return format!("({} - {})", percentile(0.75), percentile(0.25));
    }
    simple_stat_sql_inline(name, &naming::quote_ident(raw_col), dialect)
        .unwrap_or_else(|| "NULL".to_string())
}
/// Fallback-path SQL for one aggregate function (offset plus optional band),
/// built from `simple_stat_sql_fallback` pieces.
fn agg_sql_fallback(
    spec: &AggSpec,
    raw_col: &str,
    dialect: &dyn SqlDialect,
    src_alias: &str,
    group_cols: &[String],
) -> String {
    let stat_sql =
        |name: &str| simple_stat_sql_fallback(name, raw_col, dialect, src_alias, group_cols);
    let offset_sql = stat_sql(spec.offset);
    match spec.band.as_ref() {
        Some(band) => format_band(&offset_sql, band.mod_value, &stat_sql(band.expansion)),
        None => offset_sql,
    }
}
/// Whether any part of `spec` (the offset or the band expansion) needs the
/// non-inline quantile path for `probe_col` under `dialect`.
fn needs_quantile_fallback(spec: &AggSpec, probe_col: &str, dialect: &dyn SqlDialect) -> bool {
    let expansion = spec.band.as_ref().map(|band| band.expansion);
    std::iter::once(spec.offset)
        .chain(expansion)
        .any(|name| simple_needs_fallback(name, probe_col, dialect))
}
/// Whether stat `name` must take the fallback path because `dialect` cannot
/// render, inline for `probe_col`, the quantile(s) the stat needs.
///
/// This must probe exactly the fractions `simple_stat_sql_inline` will request.
/// The query builders `expect` the inline path to succeed whenever this
/// returns `false`, so a mismatch would panic.
fn simple_needs_fallback(name: &str, probe_col: &str, dialect: &dyn SqlDialect) -> bool {
    if let Some(frac) = percentile_fraction(name) {
        return dialect.sql_quantile_inline(probe_col, frac).is_none();
    }
    if name == "iqr" {
        // `iqr` is rendered as p75 - p25, so both of those fractions (not
        // the median) must be available inline.
        return [0.75, 0.25]
            .iter()
            .any(|&frac| dialect.sql_quantile_inline(probe_col, frac).is_none());
    }
    false
}
fn unquote(qcol: &str) -> String {
naming::unquote_ident(qcol)
}
/// Resolves a user-facing aesthetic name from an `aes:func` entry to the
/// internal aesthetic names mapped on this layer.
///
/// Resolution order:
/// 1. a direct user → internal translation that is mapped on the layer;
/// 2. an alias from `AESTHETIC_ALIASES`, expanded to its mapped members
///    (deduplicated). The first matching alias decides, even if nothing matched;
/// 3. the literal name, if mapped.
///
/// Returns an empty vector when nothing resolves.
fn resolve_target_aesthetic(
    user_aes: &str,
    aesthetics: &Mappings,
    aesthetic_ctx: &AestheticContext,
) -> Vec<String> {
    use crate::plot::layer::geom::types::AESTHETIC_ALIASES;

    if let Some(internal) = aesthetic_ctx.map_user_to_internal(user_aes) {
        if aesthetics.aesthetics.contains_key(internal) {
            return vec![internal.to_string()];
        }
    }

    if let Some((_, members)) = AESTHETIC_ALIASES
        .iter()
        .find(|(alias, _)| *alias == user_aes)
    {
        let mut expanded: Vec<String> = Vec::new();
        for member in members.iter() {
            let internal = aesthetic_ctx
                .map_user_to_internal(member)
                .map_or_else(|| member.to_string(), |mapped| mapped.to_string());
            if aesthetics.aesthetics.contains_key(&internal) && !expanded.contains(&internal) {
                expanded.push(internal);
            }
        }
        return expanded;
    }

    if aesthetics.aesthetics.contains_key(user_aes) {
        vec![user_aes.to_string()]
    } else {
        Vec::new()
    }
}
/// Whether an internal aesthetic name denotes the upper end of a range
/// (its name ends in `max` or `end`). Such aesthetics prefer `default_upper`.
fn is_upper_half(internal_aes: &str) -> bool {
    ["max", "end"]
        .iter()
        .any(|suffix| internal_aes.ends_with(*suffix))
}
/// Maps every targeted user aesthetic in `spec` to internal aesthetic names,
/// each paired with that target's function list.
///
/// # Errors
///
/// A target resolves to nothing mapped on this layer, or two targets resolve
/// to the same internal aesthetic.
pub(crate) fn resolve_aggregate_targets(
    spec: &AggregateSpec,
    aesthetics: &Mappings,
    aesthetic_ctx: &AestheticContext,
) -> std::result::Result<HashMap<String, Vec<AggSpec>>, String> {
    let mut by_internal: HashMap<String, Vec<AggSpec>> = HashMap::new();
    for (user_aes, fns) in spec.targets.iter() {
        let internals = resolve_target_aesthetic(user_aes, aesthetics, aesthetic_ctx);
        if internals.is_empty() {
            return Err(format!(
                "aggregate target '{}' is not mapped on this layer",
                user_aes
            ));
        }
        for internal in internals {
            match by_internal.entry(internal) {
                std::collections::hash_map::Entry::Occupied(taken) => {
                    return Err(format!(
                        "aggregate target '{}' resolves to aesthetic '{}' which is already targeted",
                        user_aes,
                        taken.key()
                    ));
                }
                std::collections::hash_map::Entry::Vacant(slot) => {
                    slot.insert(fns.clone());
                }
            }
        }
    }
    Ok(by_internal)
}
/// Internal aesthetic names explicitly targeted (`aes:func`) by the layer's
/// `aggregate` parameter.
///
/// Returns an empty set when the parameter is absent, null, fails to parse,
/// or has no targets.
pub fn targeted_aesthetics(
    parameters: &Parameters,
    aesthetics: &Mappings,
    aesthetic_ctx: &AestheticContext,
) -> HashSet<String> {
    let parsed = match parameters.get("aggregate") {
        None | Some(ParameterValue::Null) => None,
        Some(raw) => parse_aggregate_param(raw).ok().flatten(),
    };
    let Some(spec) = parsed else {
        return HashSet::new();
    };
    spec.targets
        .iter()
        .flat_map(|(user_aes, _)| resolve_target_aesthetic(user_aes, aesthetics, aesthetic_ctx))
        .collect()
}
/// Determines which mapped aesthetics the `aggregate` parameter affects.
///
/// Returns `(targeted, aggregated)`:
/// - `targeted` holds the internal names explicitly targeted via `aes:func`;
/// - `aggregated` holds `targeted` plus every column-backed, non-domain,
///   non-discrete mapping for which an unprefixed default applies.
///
/// Returns `None` when the parameter is absent, null, fails to parse, or is empty.
pub fn aggregated_aesthetics(
    parameters: &Parameters,
    aesthetics: &Mappings,
    schema: &Schema,
    aesthetic_ctx: &AestheticContext,
    domain_aesthetics: &[&'static str],
) -> Option<(HashSet<String>, HashSet<String>)> {
    let raw = match parameters.get("aggregate")? {
        ParameterValue::Null => return None,
        other => other,
    };
    let spec = parse_aggregate_param(raw).ok()??;

    let targeted: HashSet<String> = spec
        .targets
        .iter()
        .flat_map(|(user_aes, _)| resolve_target_aesthetic(user_aes, aesthetics, aesthetic_ctx))
        .collect();

    // NOTE(review): targeted aesthetics count as aggregated unconditionally
    // here, whereas `apply` skips domain and discrete mappings before it
    // consults the targets. Confirm this asymmetry is intended.
    let mut aggregated = targeted.clone();
    // The result is a set, so the map's iteration order does not matter.
    for (aes, value) in aesthetics.aesthetics.iter() {
        let Some(col) = value.column_name() else {
            continue;
        };
        if domain_aesthetics.contains(&aes.as_str()) || targeted.contains(aes) {
            continue;
        }
        let is_discrete = schema
            .iter()
            .find(|c| c.name == col)
            .is_some_and(|c| c.is_discrete);
        if is_discrete {
            continue;
        }
        let default = if is_upper_half(aes) {
            spec.default_upper.as_ref().or(spec.default_lower.as_ref())
        } else {
            spec.default_lower.as_ref()
        };
        if default.is_some() {
            aggregated.insert(aes.clone());
        }
    }
    Some((targeted, aggregated))
}
/// Applies the `aggregate` stat to a layer query.
///
/// Parses the layer's `aggregate` parameter and classifies each column-backed
/// mapping:
/// - kept as a grouping column (domain aesthetics and discrete columns);
/// - aggregated (its targeted function, or the applicable default);
/// - dropped with a warning on stderr.
///
/// It then wraps `query` in a grouped SELECT. When a target lists several
/// functions, the result is a `UNION ALL` of one SELECT per function index,
/// tagged with an extra `aggregate` label column.
///
/// Returns `StatResult::Identity` when the parameter is absent, null, empty,
/// or leaves nothing to aggregate.
///
/// # Errors
///
/// `GgsqlError::ValidationError` for an unparseable parameter, a target that
/// cannot be resolved on this layer, or a function the dialect cannot compute.
#[allow(clippy::too_many_arguments)]
pub fn apply(
query: &str,
schema: &Schema,
aesthetics: &Mappings,
group_by: &[String],
parameters: &Parameters,
dialect: &dyn SqlDialect,
aesthetic_ctx: &AestheticContext,
domain_aesthetics: &[&'static str],
) -> Result<StatResult> {
let raw = match parameters.get("aggregate") {
None | Some(ParameterValue::Null) => return Ok(StatResult::Identity),
Some(v) => v,
};
let spec = parse_aggregate_param(raw).map_err(GgsqlError::ValidationError)?;
let spec = match spec {
Some(s) => s,
None => return Ok(StatResult::Identity),
};
// n > 1 means the layer is exploded into n rows per group, one per label.
let n = spec.explosion_factor();
let labels = spec.explosion_labels();
let targets_internal = resolve_aggregate_targets(&spec, aesthetics, aesthetic_ctx)
.map_err(GgsqlError::ValidationError)?;
// (internal aesthetic, source column, per-row functions) for each aggregated mapping.
let mut aggregated: Vec<(String, String, Vec<AggSpec>)> = Vec::new();
// Columns that pass through as GROUP BY keys, in first-seen order.
let mut kept_cols: Vec<String> = Vec::new();
let mut dropped: Vec<String> = Vec::new();
// Visit mappings sorted by aesthetic name so the column order in the
// generated SQL does not depend on the map's iteration order.
let mut entries: Vec<(&String, &crate::AestheticValue)> =
aesthetics.aesthetics.iter().collect();
entries.sort_by(|a, b| a.0.cmp(b.0));
for (aes, value) in entries {
// Mappings without a backing column are left untouched.
let col = match value.column_name() {
Some(c) => c.to_string(),
None => continue, };
// Domain aesthetics become grouping columns and are never aggregated.
if domain_aesthetics.contains(&aes.as_str()) {
if !kept_cols.contains(&col) {
kept_cols.push(col);
}
continue;
}
// Discrete columns (per the schema; unknown columns count as continuous)
// are also grouping keys.
let info = schema.iter().find(|c| c.name == col);
let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false);
if is_discrete {
if !kept_cols.contains(&col) {
kept_cols.push(col);
}
continue;
}
// A targeted function list wins; a single-function target is recycled
// across all n rows (parse_aggregate_param guarantees length 1 or n).
// Otherwise the default applies: default_upper (falling back to
// default_lower) for upper-half aesthetics, default_lower for the rest.
let fns: Option<Vec<AggSpec>> = if let Some(list) = targets_internal.get(aes) {
if list.len() == n {
Some(list.clone())
} else {
debug_assert_eq!(list.len(), 1);
Some(vec![list[0].clone(); n])
}
} else {
let default = if is_upper_half(aes) {
spec.default_upper
.clone()
.or_else(|| spec.default_lower.clone())
} else {
spec.default_lower.clone()
};
default.map(|d| vec![d; n])
};
match fns {
Some(list) => aggregated.push((aes.clone(), col, list)),
None => dropped.push(aes.clone()),
}
}
for d in &dropped {
let user_aes = aesthetic_ctx.map_internal_to_user(d);
eprintln!(
"Warning: aggregate dropped numeric mapping for aesthetic '{}' \
(no applicable default and no targeted function). \
Suggestion: add an unprefixed default like `aggregate => 'mean'` \
to apply one function to every numeric mapping, or target this \
aesthetic with `'{0}:<func>'`.",
user_aes,
);
}
if aggregated.is_empty() {
return Ok(StatResult::Identity);
}
// GROUP BY keys: caller-supplied group_by first, then kept columns,
// deduplicated with order preserved.
let mut group_cols: Vec<String> = Vec::new();
for g in group_by {
if !group_cols.contains(g) {
group_cols.push(g.clone());
}
}
for c in &kept_cols {
if !group_cols.contains(c) {
group_cols.push(c.clone());
}
}
// Reject unsupported functions before building SQL: the query builders
// assume every function can be rendered.
let missing = unsupported_functions(&aggregated, dialect);
if !missing.is_empty() {
return Err(GgsqlError::ValidationError(format!(
"aggregate function(s) {} are not supported by this database backend",
crate::or_list_quoted(&missing, '\''),
)));
}
let transformed_query = match &labels {
Some(ls) => build_aggregate_query(query, &aggregated, &group_cols, ls, dialect),
None => build_group_by_query(query, &aggregated, &group_cols, dialect),
};
let mut stat_columns: Vec<String> = aggregated.iter().map(|(a, _, _)| a.clone()).collect();
// Captured before the synthetic `aggregate` label column is appended.
let consumed_aesthetics: Vec<String> = stat_columns.clone();
if labels.is_some() {
stat_columns.push("aggregate".to_string());
}
Ok(StatResult::Transformed {
query: transformed_query,
stat_columns,
dummy_columns: vec![],
consumed_aesthetics,
})
}
/// Builds the leading `WITH` clause exposing `query` as a CTE. Returns the
/// clause together with the alias the aggregating SELECT should read from.
///
/// When some stat's dialect SQL references `__ggsql_rn__`, a second CTE is
/// chained on. It numbers rows within each group (`__ggsql_rn__`, in
/// unspecified order) and records each group's size (`__ggsql_max_rn__`).
fn source_cte_chain(
    query: &str,
    aggregated: &[(String, String, Vec<AggSpec>)],
    group_cols: &[String],
    dialect: &dyn SqlDialect,
) -> (String, &'static str) {
    const RAW_SRC: &str = "\"__ggsql_stat_src__\"";
    const RN_SRC: &str = "\"__ggsql_stat_src_rn__\"";

    if !needs_row_position(aggregated, dialect) {
        return (format!("WITH {} AS ({})", RAW_SRC, query), RAW_SRC);
    }

    // Trailing space lets ROW_NUMBER's window append `ORDER BY` directly.
    let rn_partition = if group_cols.is_empty() {
        String::new()
    } else {
        let quoted: Vec<String> = group_cols.iter().map(|c| naming::quote_ident(c)).collect();
        format!("PARTITION BY {} ", quoted.join(", "))
    };
    let count_partition = rn_partition.trim_end();
    let with_clause = format!(
        "WITH {raw} AS ({query}), {numbered} AS (\
         SELECT *, \
         ROW_NUMBER() OVER ({rn_partition}ORDER BY (SELECT 1)) AS \"__ggsql_rn__\", \
         COUNT(*) OVER ({count_partition}) AS \"__ggsql_max_rn__\" \
         FROM {raw}\
         )",
        raw = RAW_SRC,
        query = query,
        numbered = RN_SRC,
        rn_partition = rn_partition,
        count_partition = count_partition,
    );
    (with_clause, RN_SRC)
}
/// Whether any stat used by `aggregated` (offset or band expansion) renders,
/// under `dialect`, to SQL that references the `__ggsql_rn__` row-number
/// column. Such SQL requires the numbered CTE.
fn needs_row_position(
    aggregated: &[(String, String, Vec<AggSpec>)],
    dialect: &dyn SqlDialect,
) -> bool {
    aggregated
        .iter()
        .flat_map(|(_, _, specs)| specs.iter())
        .flat_map(|spec| {
            std::iter::once(spec.offset).chain(spec.band.as_ref().map(|b| b.expansion))
        })
        .filter_map(|name| dialect.sql_aggregate(name, "x"))
        .any(|sql| sql.contains("__ggsql_rn__"))
}
/// Builds a single grouped SELECT (no explosion). Each aggregated aesthetic's
/// first function is written to that aesthetic's stat column.
fn build_group_by_query(
    query: &str,
    aggregated: &[(String, String, Vec<AggSpec>)],
    group_cols: &[String],
    dialect: &dyn SqlDialect,
) -> String {
    const OUTER_ALIAS: &str = "\"__ggsql_qt__\"";
    let (with_clause, src_alias) = source_cte_chain(query, aggregated, group_cols, dialect);
    let quoted_groups: Vec<String> = group_cols.iter().map(|c| naming::quote_ident(c)).collect();

    let mut columns: Vec<String> = quoted_groups.clone();
    columns.extend(aggregated.iter().map(|(aes, raw_col, fns)| {
        let spec = &fns[0];
        let expr = if needs_quantile_fallback(spec, raw_col, dialect) {
            agg_sql_fallback(spec, raw_col, dialect, src_alias, group_cols)
        } else {
            agg_sql_inline(spec, &naming::quote_ident(raw_col), dialect)
                .expect("agg_sql_inline must succeed when needs_quantile_fallback is false")
        };
        format!(
            "{} AS {}",
            expr,
            naming::quote_ident(&naming::stat_column(aes))
        )
    }));

    let group_by = if group_cols.is_empty() {
        String::new()
    } else {
        format!(" GROUP BY {}", quoted_groups.join(", "))
    };
    format!(
        "{} SELECT {} FROM {} AS {}{}",
        with_clause,
        columns.join(", "),
        src_alias,
        OUTER_ALIAS,
        group_by
    )
}
/// Builds the exploded query: one grouped SELECT per label (function index),
/// joined with `UNION ALL`. Each branch tags its rows with its label in the
/// `aggregate` stat column.
fn build_aggregate_query(
    query: &str,
    aggregated: &[(String, String, Vec<AggSpec>)],
    group_cols: &[String],
    labels: &[String],
    dialect: &dyn SqlDialect,
) -> String {
    const OUTER_ALIAS: &str = "\"__ggsql_qt__\"";
    let (with_clause, src_alias) = source_cte_chain(query, aggregated, group_cols, dialect);
    let quoted_groups: Vec<String> = group_cols.iter().map(|c| naming::quote_ident(c)).collect();
    let group_by = if group_cols.is_empty() {
        String::new()
    } else {
        format!(" GROUP BY {}", quoted_groups.join(", "))
    };
    let label_column = naming::quote_ident(&naming::stat_column("aggregate"));

    let mut branches: Vec<String> = Vec::with_capacity(labels.len());
    for (row, label) in labels.iter().enumerate() {
        let mut columns: Vec<String> = quoted_groups.clone();
        for (aes, raw_col, fns) in aggregated {
            let spec = &fns[row];
            let expr = if needs_quantile_fallback(spec, raw_col, dialect) {
                agg_sql_fallback(spec, raw_col, dialect, src_alias, group_cols)
            } else {
                agg_sql_inline(spec, &naming::quote_ident(raw_col), dialect)
                    .expect("agg_sql_inline must succeed when needs_quantile_fallback is false")
            };
            columns.push(format!(
                "{} AS {}",
                expr,
                naming::quote_ident(&naming::stat_column(aes))
            ));
        }
        columns.push(format!(
            "{} AS {}",
            naming::quote_literal(label),
            label_column
        ));
        branches.push(format!(
            "SELECT {} FROM {} AS {}{}",
            columns.join(", "),
            src_alias,
            OUTER_ALIAS,
            group_by
        ));
    }
    format!("{} {}", with_clause, branches.join(" UNION ALL "))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::plot::aesthetic::AestheticContext;
use crate::plot::types::{AestheticValue, ColumnInfo};
use crate::plot::Parameters;
use arrow::datatypes::DataType;
struct InlineQuantileDialect;
impl SqlDialect for InlineQuantileDialect {
fn sql_quantile_inline(&self, column: &str, fraction: f64) -> Option<String> {
Some(format!(
"QUANTILE_CONT({}, {})",
naming::quote_ident(column),
fraction
))
}
fn sql_aggregate(&self, name: &str, qcol: &str) -> Option<String> {
match name {
"first" => Some(format!("FIRST({})", qcol)),
"last" => Some(format!("LAST({})", qcol)),
"diff" => Some(format!("(LAST({c}) - FIRST({c}))", c = qcol)),
_ => crate::reader::default_sql_aggregate(name, qcol),
}
}
}
struct NoInlineQuantileDialect;
impl SqlDialect for NoInlineQuantileDialect {}
fn col(name: &str) -> AestheticValue {
AestheticValue::Column {
name: name.to_string(),
original_name: None,
is_dummy: false,
}
}
fn schema_for(cols: &[(&str, bool)]) -> Schema {
cols.iter()
.map(|(name, is_discrete)| ColumnInfo {
name: name.to_string(),
dtype: if *is_discrete {
DataType::Utf8
} else {
DataType::Float64
},
is_discrete: *is_discrete,
min: None,
max: None,
})
.collect()
}
fn cartesian_ctx() -> AestheticContext {
AestheticContext::from_static(&["x", "y"], &[])
}
fn run(
params: ParameterValue,
aes: &Mappings,
schema: &Schema,
group_by: &[String],
dialect: &dyn SqlDialect,
) -> Result<StatResult> {
run_with_domain(params, aes, schema, group_by, dialect, &[])
}
fn run_with_domain(
params: ParameterValue,
aes: &Mappings,
schema: &Schema,
group_by: &[String],
dialect: &dyn SqlDialect,
domain: &[&'static str],
) -> Result<StatResult> {
let mut p = Parameters::new();
p.insert("aggregate".to_string(), params);
let ctx = cartesian_ctx();
apply(
"SELECT * FROM t",
schema,
aes,
group_by,
&p,
dialect,
&ctx,
domain,
)
}
fn arr(items: &[&str]) -> ParameterValue {
ParameterValue::Array(
items
.iter()
.map(|s| ArrayElement::String(s.to_string()))
.collect(),
)
}
#[test]
fn parses_unset_and_null() {
assert_eq!(parse_aggregate_param(&ParameterValue::Null).unwrap(), None);
assert_eq!(parse_aggregate_param(&arr(&[])).unwrap(), None);
}
#[test]
fn parses_single_default() {
let s = parse_aggregate_param(&ParameterValue::String("mean".to_string()))
.unwrap()
.unwrap();
assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("mean"));
assert!(s.default_upper.is_none());
assert!(s.targets.is_empty());
}
#[test]
fn parses_two_defaults_in_order() {
let s = parse_aggregate_param(&arr(&["min", "max"]))
.unwrap()
.unwrap();
assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("min"));
assert_eq!(s.default_upper.as_ref().map(|a| a.offset), Some("max"));
}
#[test]
fn three_unprefixed_defaults_is_error() {
let err = parse_aggregate_param(&arr(&["mean", "min", "max"])).unwrap_err();
assert!(err.contains("at most two"), "got: {}", err);
}
fn target_funcs<'a>(spec: &'a AggregateSpec, aes: &str) -> Option<&'a [AggSpec]> {
spec.targets
.iter()
.find(|(a, _)| a == aes)
.map(|(_, fns)| fns.as_slice())
}
#[test]
fn parses_targeted_entries() {
let s = parse_aggregate_param(&arr(&["mean", "y:max", "color:median"]))
.unwrap()
.unwrap();
assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("mean"));
assert_eq!(target_funcs(&s, "y").map(|fs| fs[0].offset), Some("max"));
assert_eq!(
target_funcs(&s, "color").map(|fs| fs[0].offset),
Some("median")
);
}
#[test]
fn duplicate_target_explodes_into_a_list() {
let s = parse_aggregate_param(&arr(&["y:min", "y:max"]))
.unwrap()
.unwrap();
let fns = target_funcs(&s, "y").unwrap();
assert_eq!(fns.len(), 2);
assert_eq!(fns[0].offset, "min");
assert_eq!(fns[1].offset, "max");
assert_eq!(s.explosion_factor(), 2);
assert_eq!(
s.explosion_labels(),
Some(vec!["min".to_string(), "max".to_string()])
);
}
#[test]
fn multi_aesthetic_explosion_joins_unique_function_names() {
let s = parse_aggregate_param(&arr(&["y:min", "y:max", "color:sum", "color:prod"]))
.unwrap()
.unwrap();
assert_eq!(
s.explosion_labels(),
Some(vec!["min/sum".to_string(), "max/prod".to_string()])
);
}
#[test]
fn multi_aesthetic_explosion_dedups_repeats() {
let s = parse_aggregate_param(&arr(&["y:mean", "y:max", "color:mean", "color:prod"]))
.unwrap()
.unwrap();
assert_eq!(
s.explosion_labels(),
Some(vec!["mean".to_string(), "max/prod".to_string()])
);
}
#[test]
fn recycled_target_excluded_from_label() {
let s = parse_aggregate_param(&arr(&["y:min", "y:max", "color:median"]))
.unwrap()
.unwrap();
assert_eq!(
s.explosion_labels(),
Some(vec!["min".to_string(), "max".to_string()])
);
}
#[test]
fn single_row_returns_no_labels() {
let s = parse_aggregate_param(&ParameterValue::String("mean".to_string()))
.unwrap()
.unwrap();
assert_eq!(s.explosion_labels(), None);
let s = parse_aggregate_param(&arr(&["mean", "color:median"]))
.unwrap()
.unwrap();
assert_eq!(s.explosion_labels(), None);
}
#[test]
fn recycling_violation_is_error() {
let err = parse_aggregate_param(&arr(&[
"y:min",
"y:max",
"color:p10",
"color:p50",
"color:p90",
]))
.unwrap_err();
assert!(err.contains("longest target"), "got: {}", err);
}
#[test]
fn length_one_target_recycles_in_explosion() {
let s = parse_aggregate_param(&arr(&["y:min", "y:max", "color:median"]))
.unwrap()
.unwrap();
assert_eq!(s.explosion_factor(), 2);
assert_eq!(target_funcs(&s, "color").map(|f| f.len()), Some(1));
}
#[test]
fn empty_prefix_is_error() {
let err = parse_aggregate_param(&ParameterValue::String(":mean".to_string())).unwrap_err();
assert!(err.contains("could not parse"), "got: {}", err);
}
#[test]
fn unknown_function_is_error() {
let err = parse_aggregate_param(&ParameterValue::String("nope".to_string())).unwrap_err();
assert!(err.contains("unknown aggregate"), "got: {}", err);
}
#[test]
fn band_functions_parse() {
let s = parse_aggregate_param(&arr(&["mean-sdev", "mean+sdev"]))
.unwrap()
.unwrap();
assert_eq!(s.default_lower.as_ref().unwrap().offset, "mean");
assert_eq!(
s.default_lower
.as_ref()
.unwrap()
.band
.as_ref()
.unwrap()
.expansion,
"sdev"
);
assert_eq!(
s.default_lower
.as_ref()
.unwrap()
.band
.as_ref()
.unwrap()
.mod_value,
-1.0,
);
assert_eq!(s.default_upper.as_ref().unwrap().offset, "mean");
assert_eq!(
s.default_upper
.as_ref()
.unwrap()
.band
.as_ref()
.unwrap()
.mod_value,
1.0,
);
}
#[test]
fn returns_identity_when_param_unset() {
let aes = Mappings::new();
let schema: Schema = vec![];
let p = Parameters::new();
let ctx = cartesian_ctx();
let result = apply(
"SELECT * FROM t",
&schema,
&aes,
&[],
&p,
&InlineQuantileDialect,
&ctx,
&[],
)
.unwrap();
assert_eq!(result, StatResult::Identity);
}
#[test]
fn returns_identity_when_param_null() {
let aes = Mappings::new();
let schema: Schema = vec![];
let result = run(
ParameterValue::Null,
&aes,
&schema,
&[],
&InlineQuantileDialect,
)
.unwrap();
assert_eq!(result, StatResult::Identity);
}
#[test]
fn single_default_applies_to_every_numeric_mapping() {
let mut aes = Mappings::new();
aes.insert("pos1", col("__ggsql_aes_pos1__"));
aes.insert("pos2", col("__ggsql_aes_pos2__"));
let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);
let result = run(
ParameterValue::String("mean".to_string()),
&aes,
&schema,
&[],
&InlineQuantileDialect,
)
.unwrap();
match result {
StatResult::Transformed {
query,
stat_columns,
consumed_aesthetics,
..
} => {
assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "{}", query);
assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "{}", query);
assert!(!query.contains("CROSS JOIN"));
assert!(!query.contains("UNION ALL"));
assert_eq!(stat_columns.len(), 2);
assert!(stat_columns.contains(&"pos1".to_string()));
assert!(stat_columns.contains(&"pos2".to_string()));
assert_eq!(consumed_aesthetics.len(), 2);
}
_ => panic!("expected Transformed"),
}
}
#[cfg(feature = "sqlite")]
#[test]
fn sqlite_dialect_emits_portable_stddev_and_first() {
use crate::reader::sqlite::SqliteDialect;
let mut aes = Mappings::new();
aes.insert("pos1", col("__ggsql_aes_pos1__"));
aes.insert("pos2", col("__ggsql_aes_pos2__"));
let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);
let result = run(
ParameterValue::String("sdev".to_string()),
&aes,
&schema,
&[],
&SqliteDialect,
)
.unwrap();
match result {
StatResult::Transformed { query, .. } => {
assert!(
!query.contains("STDDEV_POP"),
"SQLite dialect must not emit STDDEV_POP, got: {query}"
);
assert!(query.contains("SQRT") && query.contains("AVG"), "{query}");
}
_ => panic!("expected Transformed"),
}
let result = run(
ParameterValue::String("first".to_string()),
&aes,
&schema,
&[],
&SqliteDialect,
)
.unwrap();
match result {
StatResult::Transformed { query, .. } => {
assert!(
query.contains("ROW_NUMBER()"),
"expected ROW_NUMBER prep, got: {query}"
);
assert!(
query.contains("\"__ggsql_rn__\" = 1"),
"expected first via rn=1, got: {query}"
);
assert!(
!query.contains("FIRST(\""),
"must not call FIRST as an aggregate, got: {query}"
);
}
_ => panic!("expected Transformed"),
}
}
#[cfg(feature = "sqlite")]
#[test]
fn sqlite_first_last_diff_return_correct_values() {
use crate::naming;
use crate::reader::SqliteReader;
let reader = SqliteReader::new().unwrap();
let body = "WITH t(g, ord, v) AS (\
SELECT 'A', 1, 10 UNION ALL SELECT 'A', 2, 30 \
UNION ALL SELECT 'A', 3, 20 \
UNION ALL SELECT 'B', 1, 100 UNION ALL SELECT 'B', 2, 50) \
SELECT g, v FROM t ORDER BY g, ord";
let run_agg = |func: &str| -> Vec<(String, f64)> {
let query = format!(
"{body} VISUALISE \
DRAW point MAPPING g AS x, v AS y \
SETTING aggregate => '{func}'"
);
let prepared = crate::execute::prepare_data_with_reader(&query, &reader).unwrap();
let df = prepared
.data
.get(prepared.specs[0].layers[0].data_key.as_ref().unwrap())
.unwrap();
let xs = df.column("__ggsql_aes_pos1__").unwrap();
let ys = df.column("__ggsql_aes_pos2__").unwrap();
let mut out: Vec<(String, f64)> = (0..df.height())
.map(|i| {
let x = crate::array_util::value_to_string(xs, i);
let y = crate::array_util::value_to_string(ys, i)
.parse::<f64>()
.unwrap();
(x, y)
})
.collect();
out.sort_by(|a, b| a.0.cmp(&b.0));
out
};
assert_eq!(
run_agg("first"),
vec![("A".to_string(), 10.0), ("B".to_string(), 100.0)],
"first should pick the group's first row in ORDER BY ord"
);
assert_eq!(
run_agg("last"),
vec![("A".to_string(), 20.0), ("B".to_string(), 50.0)],
"last should pick the group's last row"
);
assert_eq!(
run_agg("diff"),
vec![("A".to_string(), 10.0), ("B".to_string(), -50.0)],
"diff should be last - first per group"
);
let _ = naming::layer_key(0); }
#[test]
fn unsupported_aggregate_errors_with_dialect_that_lacks_function() {
struct OptOutDialect;
impl SqlDialect for OptOutDialect {
fn sql_aggregate(&self, name: &str, qcol: &str) -> Option<String> {
if name == "first" {
return None;
}
crate::reader::default_sql_aggregate(name, qcol)
}
}
let mut aes = Mappings::new();
aes.insert("pos1", col("__ggsql_aes_pos1__"));
aes.insert("pos2", col("__ggsql_aes_pos2__"));
let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);
let err = run(
ParameterValue::String("first".to_string()),
&aes,
&schema,
&[],
&OptOutDialect,
)
.unwrap_err();
let msg = format!("{}", err);
assert!(
msg.contains("first") && msg.contains("not supported"),
"expected unsupported-function error mentioning 'first', got: {msg}"
);
}
/// `mid` is rendered as the midpoint `(MIN + MAX) / 2.0` of each aggregated column.
#[test]
fn mid_emits_min_max_midpoint() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let outcome = run(
        ParameterValue::String("mid".to_string()),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = outcome else {
        panic!("expected Transformed");
    };

    let midpoint = "(MIN(\"__ggsql_aes_pos1__\") + MAX(\"__ggsql_aes_pos1__\")) / 2.0";
    assert!(query.contains(midpoint), "{query}");
}
/// `diff` has two renderings:
/// - a dialect relying on the trait defaults gets a ROW_NUMBER-based rewrite
///   that picks the first (`rn = 1`) and last (`rn = max_rn`) rows;
/// - `InlineQuantileDialect` uses native FIRST/LAST with no ROW_NUMBER prep.
#[test]
fn diff_uses_row_position_and_subtracts_first_from_last() {
    // Dialect with no overrides: every method uses the trait default.
    struct AnsiTestDialect;
    impl SqlDialect for AnsiTestDialect {}

    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);
    let diff = || ParameterValue::String("diff".to_string());

    // Fallback path: positional rows are located via ROW_NUMBER.
    let ansi = run(diff(), &aes, &schema, &[], &AnsiTestDialect).unwrap();
    let StatResult::Transformed { query, .. } = ansi else {
        panic!("expected Transformed");
    };
    for needle in [
        "ROW_NUMBER()",
        "\"__ggsql_rn__\" = \"__ggsql_max_rn__\"",
        "\"__ggsql_rn__\" = 1",
    ] {
        assert!(query.contains(needle), "{query}");
    }
    assert!(query.contains(" - "), "expected subtraction, got: {query}");

    // Native path: FIRST/LAST are used directly.
    let native = run(diff(), &aes, &schema, &[], &InlineQuantileDialect).unwrap();
    let StatResult::Transformed { query, .. } = native else {
        panic!("expected Transformed");
    };
    assert!(
        query.contains("LAST(") && query.contains("FIRST("),
        "expected native LAST/FIRST: {query}"
    );
    assert!(
        !query.contains("__ggsql_rn__"),
        "native dialect must not add ROW_NUMBER prep: {query}"
    );
}
/// With `DuckDbDialect`, `first` renders as a native `FIRST(...)` aggregate and
/// adds no `__ggsql_rn__` ROW_NUMBER prep.
#[cfg(feature = "duckdb")]
#[test]
fn duckdb_first_skips_row_number_cte() {
    use crate::reader::duckdb::DuckDbDialect;

    let mut aes = Mappings::new();
    for (aesthetic, column) in [("pos1", "__ggsql_aes_pos1__"), ("pos2", "__ggsql_aes_pos2__")] {
        aes.insert(aesthetic, col(column));
    }
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let outcome = run(
        ParameterValue::String("first".to_string()),
        &aes,
        &schema,
        &[],
        &DuckDbDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = outcome else {
        panic!("expected Transformed");
    };

    assert!(
        query.contains("FIRST(\""),
        "expected native FIRST aggregate, got: {query}"
    );
    assert!(
        !query.contains("__ggsql_rn__"),
        "DuckDB has native FIRST, must not add ROW_NUMBER prep: {query}"
    );
}
/// `last` with a discrete x column. Both renderings must GROUP BY that column:
/// - `InlineQuantileDialect` uses a native `LAST(...)`;
/// - a trait-default dialect gets ROW_NUMBER and COUNT(*) windows partitioned
///   by the group column.
#[test]
fn last_with_discrete_group_partitions_row_number_over_group() {
    // Dialect with no overrides: every method uses the trait default.
    struct AnsiTestDialect;
    impl SqlDialect for AnsiTestDialect {}

    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    // pos1 is flagged discrete, so it becomes the grouping key.
    let schema = schema_for(&[("__ggsql_aes_pos1__", true), ("__ggsql_aes_pos2__", false)]);
    let group_by = "GROUP BY \"__ggsql_aes_pos1__\"";

    let native = run(
        ParameterValue::String("last".to_string()),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = native else {
        panic!("expected Transformed");
    };
    assert!(
        !query.contains("__ggsql_rn__"),
        "native LAST must not add ROW_NUMBER prep: {query}"
    );
    assert!(query.contains("LAST(\"__ggsql_aes_pos2__\")"), "{query}");
    assert!(query.contains(group_by), "{query}");

    let fallback = run(
        ParameterValue::String("last".to_string()),
        &aes,
        &schema,
        &[],
        &AnsiTestDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = fallback else {
        panic!("expected Transformed");
    };
    let row_number =
        "ROW_NUMBER() OVER (PARTITION BY \"__ggsql_aes_pos1__\" ORDER BY (SELECT 1))";
    let group_count = "COUNT(*) OVER (PARTITION BY \"__ggsql_aes_pos1__\")";
    for needle in [row_number, group_count, group_by] {
        assert!(query.contains(needle), "{query}");
    }
}
/// With two untargeted defaults `first`, `last` on a range mapping, FIRST is
/// applied to pos2min and LAST to pos2max.
#[test]
fn first_and_last_emit_positional_aggregates() {
    let mut aes = Mappings::new();
    for (aesthetic, column) in [
        ("pos1", "__ggsql_aes_pos1__"),
        ("pos2min", "__ggsql_aes_pos2min__"),
        ("pos2max", "__ggsql_aes_pos2max__"),
    ] {
        aes.insert(aesthetic, col(column));
    }
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2min__", false),
        ("__ggsql_aes_pos2max__", false),
    ]);

    let outcome = run(
        arr(&["first", "last"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = outcome else {
        panic!("expected Transformed");
    };

    let lower = "FIRST(\"__ggsql_aes_pos2min__\")";
    let upper = "LAST(\"__ggsql_aes_pos2max__\")";
    assert!(query.contains(lower), "{query}");
    assert!(query.contains(upper), "{query}");
}
/// With two untargeted defaults on a segment:
/// - the first (`min`) feeds the start point (pos1, pos2);
/// - the second (`max`) feeds the end point (pos1end, pos2end).
///
/// Neither aggregate may leak onto the other endpoint, on either axis.
#[test]
fn two_defaults_split_lower_and_upper_for_segment() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    aes.insert("pos1end", col("__ggsql_aes_pos1end__"));
    aes.insert("pos2end", col("__ggsql_aes_pos2end__"));
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2__", false),
        ("__ggsql_aes_pos1end__", false),
        ("__ggsql_aes_pos2end__", false),
    ]);
    let result = run(
        arr(&["min", "max"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    match result {
        StatResult::Transformed { query, .. } => {
            for expected in [
                "MIN(\"__ggsql_aes_pos1__\")",
                "MIN(\"__ggsql_aes_pos2__\")",
                "MAX(\"__ggsql_aes_pos1end__\")",
                "MAX(\"__ggsql_aes_pos2end__\")",
            ] {
                assert!(query.contains(expected), "{query}");
            }
            // Check both axes: a swapped split on pos2/pos2end must fail just
            // like one on pos1/pos1end.
            for unexpected in [
                "MIN(\"__ggsql_aes_pos1end__\")",
                "MIN(\"__ggsql_aes_pos2end__\")",
                "MAX(\"__ggsql_aes_pos1__\")",
                "MAX(\"__ggsql_aes_pos2__\")",
            ] {
                assert!(
                    !query.contains(unexpected),
                    "unexpected {unexpected} in: {query}"
                );
            }
        }
        _ => panic!("expected Transformed"),
    }
}
/// On a ribbon, the defaults `mean-sdev` and `mean+sdev` split across the two
/// range columns:
/// - pos2min receives the lower band, whose expansion is subtracted;
/// - pos2max receives the upper band, whose expansion is added.
#[test]
fn two_defaults_split_for_ribbon() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2min", col("__ggsql_aes_pos2min__"));
    aes.insert("pos2max", col("__ggsql_aes_pos2max__"));
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2min__", false),
        ("__ggsql_aes_pos2max__", false),
    ]);
    let result = run(
        arr(&["mean-sdev", "mean+sdev"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    match result {
        StatResult::Transformed { query, .. } => {
            // Each bound computes its offset and expansion from its own column.
            assert!(query.contains("AVG(\"__ggsql_aes_pos2min__\")"), "{query}");
            assert!(query.contains("STDDEV_POP(\"__ggsql_aes_pos2min__\")"), "{query}");
            assert!(query.contains("STDDEV_POP(\"__ggsql_aes_pos2max__\")"), "{query}");
            // Check the sign next to each column, so a swapped lower/upper
            // assignment fails the test.
            assert!(
                query.contains("AVG(\"__ggsql_aes_pos2max__\") + STDDEV_POP"),
                "upper bound should be mean + sdev: {query}"
            );
            assert!(
                !query.contains("AVG(\"__ggsql_aes_pos2min__\") + STDDEV_POP"),
                "lower bound must not add sdev: {query}"
            );
        }
        _ => panic!("expected Transformed"),
    }
}
/// An `aes:stat` entry replaces the untargeted default for that aesthetic
/// only. Here `y:max` wins over `mean` for pos2, while pos1 keeps `mean`.
#[test]
fn targeted_prefix_overrides_default() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let outcome = run(
        arr(&["mean", "y:max"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = outcome else {
        panic!("expected Transformed");
    };

    assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "{query}");
    assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"), "{query}");
    let pos2_averaged = query.contains("AVG(\"__ggsql_aes_pos2__\")");
    assert!(!pos2_averaged);
}
/// A non-positional aesthetic (`size`) can be targeted by its user-facing
/// name. `size:median` renders as QUANTILE_CONT at 0.5, and `size` is reported
/// among the stat columns.
#[test]
fn material_aesthetic_targeted_by_user_facing_name() {
    let mut aes = Mappings::new();
    for (aesthetic, column) in [
        ("pos1", "__ggsql_aes_pos1__"),
        ("pos2", "__ggsql_aes_pos2__"),
        ("size", "__ggsql_aes_size__"),
    ] {
        aes.insert(aesthetic, col(column));
    }
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2__", false),
        ("__ggsql_aes_size__", false),
    ]);

    let outcome = run(
        arr(&["mean", "size:median"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed {
        query,
        stat_columns,
        ..
    } = outcome
    else {
        panic!("expected Transformed");
    };

    assert!(query.contains("QUANTILE_CONT(\"__ggsql_aes_size__\", 0.5)"));
    assert!(stat_columns.iter().any(|c| c == "size"));
}
/// The `color:` prefix is an alias that reaches the fill aesthetic. Only `fill`
/// is mapped here, so `color:max` becomes MAX on the fill column, while pos1
/// keeps the `mean` default.
#[test]
fn color_alias_targets_stroke_and_fill() {
    let mut aes = Mappings::new();
    for (aesthetic, column) in [
        ("pos1", "__ggsql_aes_pos1__"),
        ("pos2", "__ggsql_aes_pos2__"),
        ("fill", "__ggsql_aes_fill__"),
    ] {
        aes.insert(aesthetic, col(column));
    }
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2__", false),
        ("__ggsql_aes_fill__", false),
    ]);

    let outcome = run(
        arr(&["mean", "color:max"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed {
        query,
        stat_columns,
        ..
    } = outcome
    else {
        panic!("expected Transformed");
    };

    assert!(query.contains("MAX(\"__ggsql_aes_fill__\")"), "{query}");
    assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"));
    assert!(stat_columns.iter().any(|c| c == "fill"));
}
/// Two targeted entries for the same aesthetic (`y:min`, `y:max`) explode into
/// UNION ALL branches. Each branch is labelled through the
/// `__ggsql_stat_aggregate` column, which is exposed as the `aggregate` stat
/// column rather than as a consumed aesthetic.
#[test]
fn explosion_emits_union_all_with_aggregate_label_column() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let outcome = run(
        arr(&["y:min", "y:max"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed {
        query,
        stat_columns,
        consumed_aesthetics,
        ..
    } = outcome
    else {
        panic!("expected Transformed");
    };

    for needle in [
        "UNION ALL",
        "MIN(\"__ggsql_aes_pos2__\")",
        "MAX(\"__ggsql_aes_pos2__\")",
    ] {
        assert!(query.contains(needle), "{query}");
    }
    assert!(query.contains("'min' AS \"__ggsql_stat_aggregate\""));
    assert!(query.contains("'max' AS \"__ggsql_stat_aggregate\""));

    assert!(consumed_aesthetics.iter().any(|a| a == "pos2"));
    assert!(!consumed_aesthetics.iter().any(|a| a == "aggregate"));
    assert!(stat_columns.iter().any(|c| c == "aggregate"));
}
/// When `y:min`/`y:max` explode the query into two branches, the single
/// `color:median` target and the untargeted `mean` default are repeated in
/// both branches. Each of their aggregates therefore appears exactly twice.
#[test]
fn explosion_recycles_length_one_targets_and_defaults() {
    let mut aes = Mappings::new();
    for (aesthetic, column) in [
        ("pos1", "__ggsql_aes_pos1__"),
        ("pos2", "__ggsql_aes_pos2__"),
        ("fill", "__ggsql_aes_fill__"),
        ("size", "__ggsql_aes_size__"),
    ] {
        aes.insert(aesthetic, col(column));
    }
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2__", false),
        ("__ggsql_aes_fill__", false),
        ("__ggsql_aes_size__", false),
    ]);

    let outcome = run(
        arr(&["mean", "y:min", "y:max", "color:median"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = outcome else {
        panic!("expected Transformed");
    };

    assert!(query.contains("MIN(\"__ggsql_aes_pos2__\")"), "{query}");
    assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"));

    let occurrences = |needle: &str| query.matches(needle).count();
    assert_eq!(
        occurrences("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.5)"),
        2,
        "color median should appear once per branch: {query}"
    );
    assert_eq!(
        occurrences("AVG(\"__ggsql_aes_size__\")"),
        2,
        "size mean should appear once per branch: {query}"
    );
    assert_eq!(occurrences("AVG(\"__ggsql_aes_pos1__\")"), 2);
}
/// pos1 is listed as a domain aesthetic via `run_with_domain`. It therefore
/// stays a GROUP BY key and is never aggregated, even though it is continuous.
/// Only pos2 is aggregated and consumed.
#[test]
fn domain_aesthetic_kept_as_group_key_even_when_continuous() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let outcome = run_with_domain(
        arr(&["y:min", "y:max"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
        &["pos1"],
    )
    .unwrap();
    let StatResult::Transformed {
        query,
        stat_columns,
        consumed_aesthetics,
        ..
    } = outcome
    else {
        panic!("expected Transformed");
    };

    assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\""), "{query}");
    for absent in ["MIN(\"__ggsql_aes_pos1__\")", "MAX(\"__ggsql_aes_pos1__\")"] {
        assert!(!query.contains(absent));
    }
    for present in ["MIN(\"__ggsql_aes_pos2__\")", "MAX(\"__ggsql_aes_pos2__\")"] {
        assert!(query.contains(present));
    }
    assert!(!consumed_aesthetics.iter().any(|a| a == "pos1"));
    assert!(consumed_aesthetics.iter().any(|a| a == "pos2"));
    assert!(stat_columns.iter().any(|c| c == "aggregate"));
}
/// Combines two band defaults (`mean-sdev`, `mean+sdev`) on a range geom with
/// two `color:` targets (`p25`, `p75`). The result explodes into UNION ALL
/// branches, adds an `aggregate` stat column, and includes both fill
/// quantiles and the pos2max upper-band expression.
#[test]
fn explosion_with_range_geom_two_defaults() {
    let mut aes = Mappings::new();
    for (aesthetic, column) in [
        ("pos1", "__ggsql_aes_pos1__"),
        ("pos2min", "__ggsql_aes_pos2min__"),
        ("pos2max", "__ggsql_aes_pos2max__"),
        ("fill", "__ggsql_aes_fill__"),
    ] {
        aes.insert(aesthetic, col(column));
    }
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2min__", false),
        ("__ggsql_aes_pos2max__", false),
        ("__ggsql_aes_fill__", false),
    ]);

    let outcome = run(
        arr(&["mean-sdev", "mean+sdev", "color:p25", "color:p75"]),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed {
        query,
        stat_columns,
        ..
    } = outcome
    else {
        panic!("expected Transformed");
    };

    assert!(query.contains("UNION ALL"));
    let upper_band = "AVG(\"__ggsql_aes_pos2max__\") + STDDEV_POP";
    assert!(query.contains(upper_band), "{query}");
    assert!(query.contains("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.25)"));
    assert!(query.contains("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.75)"));
    assert!(stat_columns.iter().any(|c| c == "aggregate"));
}
/// A discrete (schema flag `true`) color mapping is used as a GROUP BY key.
/// It is not reported as a stat column, while the continuous positions are
/// averaged.
#[test]
fn discrete_mapping_becomes_group_key() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    aes.insert("color", col("__ggsql_aes_color__"));
    let schema = schema_for(&[
        ("__ggsql_aes_pos1__", false),
        ("__ggsql_aes_pos2__", false),
        ("__ggsql_aes_color__", true),
    ]);

    let outcome = run(
        ParameterValue::String("mean".to_string()),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed {
        query,
        stat_columns,
        ..
    } = outcome
    else {
        panic!("expected Transformed");
    };

    assert!(query.contains("GROUP BY \"__ggsql_aes_color__\""), "{query}");
    assert!(!stat_columns.iter().any(|c| c == "color"));
    assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"));
    assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"));
}
/// A literal (constant) fill mapping is not aggregated. Only the column-mapped
/// positions receive the `mean` default.
#[test]
fn literal_mapping_passes_through() {
    let steelblue = AestheticValue::Literal(ParameterValue::String("steelblue".to_string()));
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    aes.insert("fill", steelblue);
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let outcome = run(
        ParameterValue::String("mean".to_string()),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = outcome else {
        panic!("expected Transformed");
    };

    assert!(!query.contains("AVG(\"__ggsql_aes_fill__\")"));
    assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"));
    assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"));
}
/// With only a targeted entry (`y:mean`) and no default, the untargeted pos1
/// column is dropped from the query entirely. pos2 is then the sole stat column.
#[test]
fn untargeted_numeric_mapping_dropped_when_no_default() {
    let mut aes = Mappings::new();
    aes.insert("pos1", col("__ggsql_aes_pos1__"));
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let outcome = run(
        ParameterValue::String("y:mean".to_string()),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed {
        query,
        stat_columns,
        ..
    } = outcome
    else {
        panic!("expected Transformed");
    };

    assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"));
    let pos1_referenced = query.contains("\"__ggsql_aes_pos1__\"");
    assert!(!pos1_referenced);
    assert_eq!(stat_columns, vec!["pos2".to_string()]);
}
/// For a dialect with an inline quantile function, `p25` must render as a
/// direct `QUANTILE_CONT(col, 0.25)` call on the mapped column.
#[test]
fn quantile_uses_dialect_inline_when_available() {
    let mut aes = Mappings::new();
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos2__", false)]);
    let result = run(
        ParameterValue::String("p25".to_string()),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap();
    match result {
        StatResult::Transformed { query, .. } => {
            // Match the full call. Checking `QUANTILE_CONT` and `0.25` separately
            // would also pass if 0.25 appeared elsewhere or the wrong column
            // were used.
            assert!(
                query.contains("QUANTILE_CONT(\"__ggsql_aes_pos2__\", 0.25)"),
                "expected inline p25 quantile on pos2: {query}"
            );
        }
        _ => panic!("expected Transformed"),
    }
}
/// Without an inline quantile function, `p25` falls back to an NTILE(4)-based
/// computation. That fallback must not explode the query into UNION ALL
/// branches.
#[test]
fn quantile_falls_back_to_correlated_subquery_without_inline() {
    let mut aes = Mappings::new();
    aes.insert("pos2", col("__ggsql_aes_pos2__"));
    let schema = schema_for(&[("__ggsql_aes_pos2__", false)]);

    let outcome = run(
        ParameterValue::String("p25".to_string()),
        &aes,
        &schema,
        &[],
        &NoInlineQuantileDialect,
    )
    .unwrap();
    let StatResult::Transformed { query, .. } = outcome else {
        panic!("expected Transformed");
    };

    let uses_quartiles = query.contains("NTILE(4)");
    let exploded = query.contains("UNION ALL");
    assert!(uses_quartiles);
    assert!(!exploded);
}
/// Targeting an aesthetic that has no mapping (`size:` here) is rejected with
/// a "not mapped" error instead of being silently ignored.
#[test]
fn unknown_targeted_aesthetic_is_error() {
    let mut aes = Mappings::new();
    for (aesthetic, column) in [("pos1", "__ggsql_aes_pos1__"), ("pos2", "__ggsql_aes_pos2__")] {
        aes.insert(aesthetic, col(column));
    }
    let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]);

    let msg = run(
        ParameterValue::String("size:mean".to_string()),
        &aes,
        &schema,
        &[],
        &InlineQuantileDialect,
    )
    .unwrap_err()
    .to_string();

    assert!(msg.contains("not mapped"), "got: {msg}");
}
}