use std::collections::HashSet;
use super::stat_aggregate;
use super::types::{get_column_name, wrap_stat_with_dummy_pos1, POSITION_VALUES};
use super::{
has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType,
ParamConstraint, ParamDefinition, StatResult,
};
use crate::naming;
use crate::plot::types::{DefaultAestheticValue, Parameters};
use crate::reader::SqlDialect;
use crate::{DataFrame, GgsqlError, Mappings, Result};
use super::types::Schema;
/// Unit marker type for the bar geom.
///
/// Carries no data; it exists so the geom's behaviour can be attached through
/// trait implementations.
#[derive(Clone, Copy, Debug)]
pub struct Bar;
impl GeomTrait for Bar {
    /// Identifies this geom as [`GeomType::Bar`].
    fn geom_type(&self) -> GeomType {
        GeomType::Bar
    }

    /// Default aesthetics of a bar layer.
    ///
    /// Positional aesthetics (`pos1`, `pos2`, `pos2end`) and `weight` use marker
    /// values; `fill`, `stroke` and `opacity` carry concrete defaults.
    fn aesthetics(&self) -> DefaultAesthetics {
        DefaultAesthetics {
            defaults: &[
                ("pos1", DefaultAestheticValue::Dummy),
                ("pos2", DefaultAestheticValue::Null),
                ("pos2end", DefaultAestheticValue::Delayed),
                ("weight", DefaultAestheticValue::Null),
                ("fill", DefaultAestheticValue::String("black")),
                ("stroke", DefaultAestheticValue::String("black")),
                ("opacity", DefaultAestheticValue::Number(0.8)),
            ],
        }
    }

    /// Default remappings: `pos2` reads the `count` column and `pos2end` is the
    /// constant `0.0`.
    fn default_remappings(&self) -> DefaultAesthetics {
        DefaultAesthetics {
            defaults: &[
                ("pos2", DefaultAestheticValue::Column("count")),
                ("pos2end", DefaultAestheticValue::Number(0.0)),
            ],
        }
    }

    /// Names of the stat columns this geom accepts: `count` and `proportion`.
    fn valid_stat_columns(&self) -> &'static [&'static str] {
        &["count", "proportion"]
    }

    /// Parameter definitions for bar layers.
    ///
    /// * `width` — number, default `0.9`, constrained by `number_range(0.0, 1.0)`.
    /// * `position` — string, default `"stack"`, restricted to `POSITION_VALUES`.
    /// * the shared `AGGREGATE_PARAM` definition.
    fn default_params(&self) -> &'static [ParamDefinition] {
        // Kept as a named const so the slice has a `'static` home even though the
        // constraints are built through constructor calls.
        const BAR_PARAMS: &[ParamDefinition] = &[
            ParamDefinition {
                name: "width",
                default: DefaultParamValue::Number(0.9),
                constraint: ParamConstraint::number_range(0.0, 1.0),
            },
            ParamDefinition {
                name: "position",
                default: DefaultParamValue::String("stack"),
                constraint: ParamConstraint::string_option(POSITION_VALUES),
            },
            super::types::AGGREGATE_PARAM,
        ];
        BAR_PARAMS
    }

    /// Aesthetics reported as consumed by the stat: `pos1`, `pos2` and `weight`.
    fn stat_consumed_aesthetics(&self) -> &'static [&'static str] {
        &["pos1", "pos2", "weight"]
    }

    /// Always `Some` with an empty list of aggregate-domain aesthetics.
    fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> {
        Some(&[])
    }

    /// Runs the stat transform for a bar layer.
    ///
    /// When `parameters` carries an aggregate parameter (per
    /// `has_aggregate_param`) the work is delegated to `stat_aggregate::apply`;
    /// otherwise the count stat from `stat_bar_count` is used. If no column is
    /// mapped to `pos1`, the result is additionally passed through
    /// `wrap_stat_with_dummy_pos1`.
    ///
    /// `_execute_query` is not used by this geom.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the selected stat implementation.
    fn apply_stat_transform(
        &self,
        query: &str,
        schema: &Schema,
        aesthetics: &Mappings,
        group_by: &[String],
        parameters: &Parameters,
        _execute_query: &dyn Fn(&str) -> Result<DataFrame>,
        dialect: &dyn SqlDialect,
        aesthetic_ctx: &crate::plot::aesthetic::AestheticContext,
    ) -> Result<StatResult> {
        let stat = if !has_aggregate_param(parameters) {
            stat_bar_count(query, schema, aesthetics, group_by)?
        } else {
            let domain = self.aggregate_domain_aesthetics().unwrap_or(&[]);
            stat_aggregate::apply(
                query,
                schema,
                aesthetics,
                group_by,
                parameters,
                dialect,
                aesthetic_ctx,
                domain,
            )?
        };

        // Checked only after the stat has succeeded, so stat errors surface first.
        let pos1_unmapped = get_column_name(aesthetics, "pos1").is_none();
        Ok(if pos1_unmapped {
            wrap_stat_with_dummy_pos1(query, stat)
        } else {
            stat
        })
    }
}
/// User-facing name of the geom: always the literal text `bar`.
impl std::fmt::Display for Bar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("bar")
    }
}
/// Builds the count/proportion stat query for a bar layer.
///
/// Returns [`StatResult::Identity`] when `pos2` is already supplied, either as a
/// literal or as a column that exists in `schema`. Otherwise returns
/// [`StatResult::Transformed`] whose query wraps `query` in two CTEs:
///
/// 1. `"__stat_src__"` — the incoming `query`;
/// 2. `"__grouped__"` — one row per group with the count column, which is
///    `SUM(weight)` when `weight` maps to a column present in `schema` and
///    `COUNT(*)` otherwise;
///
/// and finally selects everything plus a proportion column computed as the
/// count divided by its windowed sum (partitioned by `group_by` when non-empty).
///
/// When no column is mapped to `pos1`, a constant placeholder value is selected
/// under the `pos1` stat column and `pos1` is reported as a dummy column; in
/// that case rows are grouped by `group_by` only (no `GROUP BY` at all when
/// `group_by` is empty). Otherwise rows are grouped by `group_by` followed by
/// the quoted `pos1` column.
///
/// `group_by` entries are interpolated into the SQL as given.
///
/// # Errors
///
/// Returns [`GgsqlError::ValidationError`] when `weight` is mapped to a literal.
fn stat_bar_count(
    query: &str,
    schema: &Schema,
    aesthetics: &Mappings,
    group_by: &[String],
) -> Result<StatResult> {
    let schema_columns: HashSet<&str> = schema.iter().map(|c| c.name.as_str()).collect();

    // A usable pos2 (literal or existing column) means there is nothing to count.
    if let Some(y_value) = aesthetics.get("pos2") {
        if y_value.is_literal()
            || matches!(y_value.column_name(), Some(col) if schema_columns.contains(col))
        {
            return Ok(StatResult::Identity);
        }
    }

    let q_count = naming::quote_ident(&naming::stat_column("count"));
    let q_prop = naming::quote_ident(&naming::stat_column("proportion"));

    // Weighted sum only when `weight` names a real column; a weight column that is
    // missing from the schema silently falls back to a plain row count.
    let weighted_sum = match aesthetics.get("weight") {
        Some(weight) if weight.is_literal() => {
            return Err(GgsqlError::ValidationError(
                "Bar weight aesthetic must be a column, not a literal".to_string(),
            ));
        }
        Some(weight) => weight
            .column_name()
            .filter(|col| schema_columns.contains(*col))
            .map(|col| format!("SUM({}) AS {}", naming::quote_ident(col), q_count)),
        None => None,
    };
    let agg_expr = weighted_sum.unwrap_or_else(|| format!("COUNT(*) AS {}", q_count));

    // Grouping keys start with the caller's groups; a mapped pos1 column is
    // appended, while the dummy pos1 is a constant and therefore not grouped on.
    let mut group_keys: Vec<String> = group_by.to_vec();
    let (x_select, stat_columns, dummy_columns) = match get_column_name(aesthetics, "pos1") {
        Some(x_col) => {
            let quoted_x = naming::quote_ident(&x_col);
            group_keys.push(quoted_x.clone());
            (
                quoted_x,
                vec!["count".to_string(), "proportion".to_string()],
                Vec::new(),
            )
        }
        None => {
            let dummy_select = format!(
                "'{}' AS {}",
                naming::stat_column("dummy"),
                naming::quote_ident(&naming::stat_column("pos1")),
            );
            (
                dummy_select,
                vec![
                    "pos1".to_string(),
                    "count".to_string(),
                    "proportion".to_string(),
                ],
                vec!["pos1".to_string()],
            )
        }
    };

    // SELECT list of the grouped CTE: caller groups, then pos1, then the count.
    let grouped_select = group_by
        .iter()
        .map(String::as_str)
        .chain([x_select.as_str(), agg_expr.as_str()])
        .collect::<Vec<_>>()
        .join(", ");

    let group_clause = if group_keys.is_empty() {
        String::new()
    } else {
        format!(" GROUP BY {}", group_keys.join(", "))
    };

    // Proportions are relative to the whole result, or to each caller group.
    let partition = if group_by.is_empty() {
        String::new()
    } else {
        format!("PARTITION BY {}", group_by.join(", "))
    };
    let final_select = format!(
        "*, {count} * 1.0 / SUM({count}) OVER ({partition}) AS {prop}",
        count = q_count,
        partition = partition,
        prop = q_prop
    );

    let transformed_query = format!(
        "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\"{group_clause}) SELECT {final_select} FROM \"__grouped__\"",
        query = query,
        grouped = grouped_select,
        group_clause = group_clause,
        final_select = final_select
    );

    Ok(StatResult::Transformed {
        query: transformed_query,
        stat_columns,
        dummy_columns,
        consumed_aesthetics: vec!["weight".to_string()],
    })
}