use super::types::{wrap_with_dummy_axis, POSITION_VALUES, SIDE_VALUES};
use super::{DefaultAesthetics, GeomTrait, GeomType, StatResult};
use crate::{
naming,
plot::{
geom::types::get_column_name, DefaultAestheticValue, DefaultParamValue, ParamConstraint,
ParamDefinition, ParameterValue, Parameters,
},
DataFrame, GgsqlError, Mappings, Result,
};
/// Kernel names accepted by the violin's `kernel` parameter.
///
/// Used as the option list for the `kernel` entry in `default_params`.
// NOTE(review): "rectangular"/"uniform" and "biweight"/"quartic" look like
// alias pairs for the same kernel — confirm against the density stat.
const KERNEL_VALUES: &[&str] = &[
"gaussian",
"epanechnikov",
"triangular",
"rectangular",
"uniform",
"biweight",
"quartic",
"cosine",
];
/// Violin geom.
///
/// A field-less marker type: all behavior is provided by its `GeomTrait`
/// implementation, and it displays as `"violin"`.
#[derive(Debug, Clone, Copy)]
pub struct Violin;
impl GeomTrait for Violin {
    /// Reports the geom variant handled by this implementation.
    fn geom_type(&self) -> GeomType {
        GeomType::Violin
    }

    /// Lists every aesthetic the violin declares, paired with its default.
    ///
    /// `pos2` is the only entry marked `Required`; `pos1` is `Dummy`,
    /// `weight` is `Null` and `offset` is `Delayed`. The remaining entries
    /// carry literal styling defaults.
    fn aesthetics(&self) -> DefaultAesthetics {
        DefaultAesthetics {
            defaults: &[
                ("pos1", DefaultAestheticValue::Dummy),
                ("pos2", DefaultAestheticValue::Required),
                ("weight", DefaultAestheticValue::Null),
                ("fill", DefaultAestheticValue::String("black")),
                ("stroke", DefaultAestheticValue::String("black")),
                ("opacity", DefaultAestheticValue::Number(0.8)),
                ("linewidth", DefaultAestheticValue::Number(1.0)),
                ("linetype", DefaultAestheticValue::String("solid")),
                ("offset", DefaultAestheticValue::Delayed),
            ],
        }
    }

    /// Declares the parameters the violin accepts, with defaults and
    /// constraints.
    ///
    /// `kernel`, `position` and `side` are restricted to `KERNEL_VALUES`,
    /// `POSITION_VALUES` and `SIDE_VALUES` respectively. `bandwidth` has no
    /// default value.
    fn default_params(&self) -> &'static [ParamDefinition] {
        const DEFINITIONS: &[ParamDefinition] = &[
            ParamDefinition {
                name: "bandwidth",
                default: DefaultParamValue::Null,
                constraint: ParamConstraint::number_min_exclusive(0.0),
            },
            ParamDefinition {
                name: "adjust",
                default: DefaultParamValue::Number(1.0),
                constraint: ParamConstraint::number_min_exclusive(0.0),
            },
            ParamDefinition {
                name: "kernel",
                default: DefaultParamValue::String("gaussian"),
                constraint: ParamConstraint::string_option(KERNEL_VALUES),
            },
            ParamDefinition {
                name: "position",
                default: DefaultParamValue::String("dodge"),
                constraint: ParamConstraint::string_option(POSITION_VALUES),
            },
            ParamDefinition {
                name: "width",
                default: DefaultParamValue::Number(0.9),
                constraint: ParamConstraint::number_min_exclusive(0.0),
            },
            ParamDefinition {
                name: "side",
                default: DefaultParamValue::String("both"),
                constraint: ParamConstraint::string_option(SIDE_VALUES),
            },
            ParamDefinition {
                name: "tails",
                default: DefaultParamValue::Number(3.0),
                constraint: ParamConstraint::number_min(0.0),
            },
        ];
        DEFINITIONS
    }

    /// Default remappings applied to the stat output: `pos2` is fed from
    /// the `pos2` column and `offset` from the `density` column.
    fn default_remappings(&self) -> DefaultAesthetics {
        DefaultAesthetics {
            defaults: &[
                ("pos2", DefaultAestheticValue::Column("pos2")),
                ("offset", DefaultAestheticValue::Column("density")),
            ],
        }
    }

    /// Names of the stat columns this geom declares as valid.
    fn valid_stat_columns(&self) -> &'static [&'static str] {
        &["pos2", "density", "intensity"]
    }

    /// Aesthetics this geom reports as consumed by the stat transform.
    fn stat_consumed_aesthetics(&self) -> &'static [&'static str] {
        &["pos2", "weight"]
    }

    /// Delegates the stat transform to `stat_violin`.
    ///
    /// The schema and the query-execution callback are not used here.
    fn apply_stat_transform(
        &self,
        query: &str,
        _schema: &crate::plot::Schema,
        aesthetics: &Mappings,
        group_by: &[String],
        parameters: &Parameters,
        _execute_query: &dyn Fn(&str) -> crate::Result<crate::DataFrame>,
        dialect: &dyn crate::reader::SqlDialect,
        aesthetic_ctx: &crate::plot::aesthetic::AestheticContext,
    ) -> Result<StatResult> {
        stat_violin(
            query,
            aesthetics,
            group_by,
            parameters,
            dialect,
            aesthetic_ctx,
        )
    }

    /// Rescales the `offset` aesthetic column so that its maximum equals
    /// half of the `width` parameter.
    ///
    /// `width` falls back to `0.9` when the parameter is absent or is not
    /// a number.
    fn post_process(&self, df: DataFrame, parameters: &Parameters) -> Result<DataFrame> {
        let width = match parameters.get("width") {
            Some(ParameterValue::Number(n)) => *n,
            _ => 0.9,
        };
        let offset_col = naming::aesthetic_column("offset");
        scale_offset_column(df, &offset_col, width * 0.5)
    }
}
/// Renders the geom as the literal name `violin`.
impl std::fmt::Display for Violin {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("violin")
    }
}
/// Rescales `offset_col` in `df` so that its largest value becomes
/// `half_width`.
///
/// The frame is returned unchanged when the column does not exist or when
/// the column maximum is not positive. Null entries stay null; every other
/// entry is multiplied by `half_width / max`.
///
/// # Errors
///
/// Returns `GgsqlError::InternalError` when the column cannot be viewed as
/// an f64 array, and propagates any error from replacing the column.
fn scale_offset_column(df: DataFrame, offset_col: &str, half_width: f64) -> Result<DataFrame> {
    let Ok(raw_offsets) = df.column(offset_col) else {
        return Ok(df);
    };
    let offsets = crate::array_util::as_f64(raw_offsets)
        .map_err(|e| GgsqlError::InternalError(format!("Offset column must be f64: {}", e)))?;

    // When `max` reports no value, 1.0 is used so the scale factor is just `half_width`.
    let peak = arrow::compute::max(offsets).unwrap_or(1.0);
    if peak <= 0.0 {
        return Ok(df);
    }

    let factor = half_width / peak;
    let rescaled: Vec<Option<f64>> = offsets
        .iter()
        .map(|entry| entry.map(|value| value * factor))
        .collect();

    let replacement = crate::array_util::new_f64_array(rescaled);
    let updated = df.with_column(offset_col, replacement)?;
    Ok(updated)
}
/// Builds the violin stat transform by delegating to the density stat on
/// `pos2`, grouped by `pos1` in addition to the caller's `group_by`.
///
/// When `pos1` is mapped, its column is appended to the grouping (if not
/// already present) and the density result is returned as-is. When `pos1`
/// is not mapped, the query is wrapped with a dummy `pos1` axis, the dummy
/// stat column joins the grouping, and `"pos1"` is added to both the stat
/// columns and the dummy columns of the result.
///
/// # Errors
///
/// Returns `GgsqlError::ValidationError` when `pos2` is not mapped, and
/// propagates any error from `stat_density`.
///
/// # Panics
///
/// Panics if `stat_density` returns `StatResult::Identity` on the dummy
/// `pos1` path.
fn stat_violin(
    query: &str,
    aesthetics: &Mappings,
    group_by: &[String],
    parameters: &Parameters,
    dialect: &dyn crate::reader::SqlDialect,
    aesthetic_ctx: &crate::plot::aesthetic::AestheticContext,
) -> Result<StatResult> {
    if get_column_name(aesthetics, "pos2").is_none() {
        return Err(GgsqlError::ValidationError(format!(
            "Violin requires '{}' aesthetic mapping (continuous)",
            aesthetic_ctx.map_internal_to_user("pos2")
        )));
    }

    let mut grouping: Vec<String> = group_by.to_vec();
    let pos1_column = get_column_name(aesthetics, "pos1");
    let needs_dummy = pos1_column.is_none();

    let source_query = if let Some(x_col) = pos1_column {
        if !grouping.contains(&x_col) {
            grouping.push(x_col);
        }
        query.to_string()
    } else {
        grouping.push(naming::stat_column("pos1"));
        wrap_with_dummy_axis(query, "pos1")
    };

    let density = super::density::stat_density(
        &source_query,
        aesthetics,
        "pos2",
        None,
        grouping.as_slice(),
        parameters,
        dialect,
        aesthetic_ctx,
    )?;

    if !needs_dummy {
        return Ok(density);
    }

    let StatResult::Transformed {
        query: dummy_query,
        mut stat_columns,
        mut dummy_columns,
        consumed_aesthetics,
    } = density
    else {
        unreachable!("stat_density always returns Transformed");
    };

    // Register the synthetic axis in both lists, without duplicating it.
    for columns in [&mut stat_columns, &mut dummy_columns] {
        if !columns.iter().any(|existing| existing == "pos1") {
            columns.push("pos1".to_string());
        }
    }

    Ok(StatResult::Transformed {
        query: dummy_query,
        stat_columns,
        dummy_columns,
        consumed_aesthetics,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plot::AestheticValue;
    use crate::plot::Parameters;
    use crate::reader::duckdb::DuckDBReader;
    use crate::reader::AnsiDialect;
    use crate::reader::Reader;
    use arrow::array::Array;

    /// Counts the distinct non-null values of a string column.
    fn count_unique_strings(col: &arrow::array::ArrayRef) -> usize {
        let arr = crate::array_util::as_str(col).expect("expected string array");
        let mut seen = std::collections::HashSet::new();
        for i in 0..arr.len() {
            if !arr.is_null(i) {
                seen.insert(arr.value(i).to_string());
            }
        }
        seen.len()
    }

    /// Mappings with `pos1 -> species` and `pos2 -> flipper_length`.
    fn create_basic_aesthetics() -> Mappings {
        let mut aesthetics = Mappings::new();
        aesthetics.insert(
            "pos1".to_string(),
            AestheticValue::standard_column("species".to_string()),
        );
        aesthetics.insert(
            "pos2".to_string(),
            AestheticValue::standard_column("flipper_length".to_string()),
        );
        aesthetics
    }

    /// The basic mappings plus `color -> island`.
    fn create_aesthetics_with_color() -> Mappings {
        let mut aesthetics = create_basic_aesthetics();
        aesthetics.insert(
            "color".to_string(),
            AestheticValue::standard_column("island".to_string()),
        );
        aesthetics
    }

    /// With `pos1` mapped and no extra groups, the generated SQL runs and
    /// keeps one group per species.
    #[test]
    fn test_violin_no_extra_groups() {
        let query = "SELECT species, flipper_length FROM penguins";
        let aesthetics = create_basic_aesthetics();
        let groups: Vec<String> = vec![];
        let mut parameters = Parameters::new();
        parameters.insert("bandwidth".to_string(), ParameterValue::Number(5.0));
        parameters.insert(
            "kernel".to_string(),
            ParameterValue::String("gaussian".to_string()),
        );
        let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
        let setup_sql = "CREATE TABLE penguins AS SELECT * FROM (VALUES
('Adelie', 181.0), ('Adelie', 186.0), ('Adelie', 195.0),
('Gentoo', 217.0), ('Gentoo', 221.0), ('Gentoo', 230.0),
('Chinstrap', 192.0), ('Chinstrap', 196.0), ('Chinstrap', 201.0)
) AS t(species, flipper_length)";
        reader.execute_sql(setup_sql).unwrap();
        let execute = |sql: &str| reader.execute_sql(sql);
        let ctx = crate::plot::aesthetic::AestheticContext::from_static(&["x", "y"], &[]);
        let result = stat_violin(query, &aesthetics, &groups, &parameters, &AnsiDialect, &ctx)
            .expect("stat_violin should succeed");
        match result {
            StatResult::Transformed {
                query: stat_query,
                stat_columns,
                consumed_aesthetics,
                ..
            } => {
                assert_eq!(stat_columns, vec!["pos2", "intensity", "density"]);
                assert_eq!(consumed_aesthetics, vec!["pos2"]);
                let df = execute(&stat_query).expect("Generated SQL should execute");
                let col_names = df.get_column_names();
                assert!(col_names.iter().any(|s| s == "__ggsql_stat_pos2"));
                assert!(col_names.iter().any(|s| s == "__ggsql_stat_density"));
                assert!(col_names.iter().any(|s| s == "species"));
                assert!(df.height() > 0);
                let species_col = df.column("species").unwrap();
                let unique_species = count_unique_strings(species_col);
                assert_eq!(unique_species, 3, "Should have 3 unique species");
            }
            _ => panic!("Expected Transformed result"),
        }
    }

    /// An extra grouping column (`island`) is carried through to the stat
    /// output alongside `species`.
    #[test]
    fn test_violin_with_extra_groups() {
        let query = "SELECT species, flipper_length, island FROM penguins";
        let aesthetics = create_aesthetics_with_color();
        let groups = vec!["island".to_string()];
        let mut parameters = Parameters::new();
        parameters.insert("bandwidth".to_string(), ParameterValue::Number(5.0));
        parameters.insert(
            "kernel".to_string(),
            ParameterValue::String("gaussian".to_string()),
        );
        let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
        let setup_sql = "CREATE TABLE penguins AS SELECT * FROM (VALUES
('Adelie', 181.0, 'Torgersen'), ('Adelie', 186.0, 'Torgersen'),
('Adelie', 195.0, 'Biscoe'), ('Adelie', 190.0, 'Biscoe'),
('Gentoo', 217.0, 'Biscoe'), ('Gentoo', 221.0, 'Biscoe'),
('Chinstrap', 192.0, 'Dream'), ('Chinstrap', 196.0, 'Dream')
) AS t(species, flipper_length, island)";
        reader.execute_sql(setup_sql).unwrap();
        let execute = |sql: &str| reader.execute_sql(sql);
        let ctx = crate::plot::aesthetic::AestheticContext::from_static(&["x", "y"], &[]);
        let result = stat_violin(query, &aesthetics, &groups, &parameters, &AnsiDialect, &ctx)
            .expect("stat_violin should succeed");
        match result {
            StatResult::Transformed {
                query: stat_query,
                stat_columns,
                consumed_aesthetics,
                ..
            } => {
                assert_eq!(stat_columns, vec!["pos2", "intensity", "density"]);
                assert_eq!(consumed_aesthetics, vec!["pos2"]);
                let df = execute(&stat_query).expect("Generated SQL should execute");
                let col_names = df.get_column_names();
                assert!(col_names.iter().any(|s| s == "__ggsql_stat_pos2"));
                assert!(col_names.iter().any(|s| s == "__ggsql_stat_density"));
                assert!(col_names.iter().any(|s| s == "species"));
                assert!(col_names.iter().any(|s| s == "island"));
                assert!(df.height() > 0);
                let species_col = df.column("species").unwrap();
                let unique_species = count_unique_strings(species_col);
                assert!(unique_species >= 2, "Should have at least 2 unique species");
                let island_col = df.column("island").unwrap();
                let unique_islands = count_unique_strings(island_col);
                assert!(unique_islands >= 2, "Should have at least 2 unique islands");
            }
            _ => panic!("Expected Transformed result"),
        }
    }

    /// The `width` parameter exists and defaults to 0.9.
    #[test]
    fn test_violin_width_parameter() {
        let violin = Violin;
        let params = violin.default_params();
        let width_param = params.iter().find(|p| p.name == "width");
        assert!(
            width_param.is_some(),
            "Violin should have a 'width' parameter"
        );
        if let Some(param) = width_param {
            match param.default {
                DefaultParamValue::Number(n) => {
                    assert!(
                        (n - 0.9).abs() < 1e-6,
                        "Default width should be 0.9, got {}",
                        n
                    );
                }
                _ => panic!("Width parameter should have a numeric default"),
            }
        }
    }

    /// The `tails` parameter defaults to 3.0, and a custom value shows up
    /// in the generated SQL.
    #[test]
    fn test_violin_tails_parameter() {
        let violin = Violin;
        let params = violin.default_params();
        let tails_param = params.iter().find(|p| p.name == "tails");
        assert!(
            tails_param.is_some(),
            "Violin should have a 'tails' parameter"
        );
        if let Some(param) = tails_param {
            match param.default {
                DefaultParamValue::Number(n) => {
                    assert!(
                        (n - 3.0).abs() < 1e-6,
                        "Default tails should be 3.0, got {}",
                        n
                    );
                }
                _ => panic!("Tails parameter should have a numeric default"),
            }
        }
        let query = "SELECT species, flipper_length FROM penguins";
        let aesthetics = create_basic_aesthetics();
        let groups: Vec<String> = vec![];
        let mut parameters = Parameters::new();
        parameters.insert("bandwidth".to_string(), ParameterValue::Number(5.0));
        parameters.insert(
            "kernel".to_string(),
            ParameterValue::String("gaussian".to_string()),
        );
        parameters.insert("tails".to_string(), ParameterValue::Number(1.5));
        let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
        let setup_sql = "CREATE TABLE penguins AS SELECT * FROM (VALUES
('Adelie', 181.0), ('Adelie', 186.0), ('Adelie', 195.0),
('Gentoo', 217.0), ('Gentoo', 221.0), ('Gentoo', 230.0)
) AS t(species, flipper_length)";
        reader.execute_sql(setup_sql).unwrap();
        let execute = |sql: &str| reader.execute_sql(sql);
        let ctx = crate::plot::aesthetic::AestheticContext::from_static(&["x", "y"], &[]);
        let result = stat_violin(query, &aesthetics, &groups, &parameters, &AnsiDialect, &ctx)
            .expect("stat_violin with custom tails should succeed");
        match result {
            StatResult::Transformed {
                query: stat_query, ..
            } => {
                assert!(
                    stat_query.contains("1.5"),
                    "SQL should contain the custom tails value 1.5"
                );
                let df = execute(&stat_query).expect("Generated SQL should execute");
                assert!(df.height() > 0, "Should produce density data");
            }
            _ => panic!("Expected Transformed result"),
        }
    }

    /// With the default width (0.9) and a column maximum of 1.0, offsets
    /// are multiplied by 0.45.
    #[test]
    fn test_violin_post_process_scales_offset() {
        use crate::df;
        let violin = Violin;
        let offset_col = naming::aesthetic_column("offset");
        let df = df! {
            offset_col.as_str() => vec![0.0, 0.5, 1.0, 0.25],
            "__ggsql_aes_pos2__" => vec![1.0, 2.0, 3.0, 4.0],
        }
        .unwrap();
        let parameters = Parameters::new();
        let result = violin.post_process(df, &parameters).unwrap();
        let scaled_arr = crate::array_util::as_f64(result.column(&offset_col).unwrap()).unwrap();
        let values: Vec<f64> = (0..scaled_arr.len())
            .filter(|&i| !scaled_arr.is_null(i))
            .map(|i| scaled_arr.value(i))
            .collect();
        assert!((values[0] - 0.0).abs() < 1e-6, "0.0 should stay 0.0");
        assert!((values[1] - 0.225).abs() < 1e-6, "0.5 should become 0.225");
        assert!((values[2] - 0.45).abs() < 1e-6, "1.0 should become 0.45");
        assert!(
            (values[3] - 0.1125).abs() < 1e-6,
            "0.25 should become 0.1125"
        );
    }

    /// With `width = 0.6` and a column maximum of 1.0, offsets are
    /// multiplied by 0.3.
    #[test]
    fn test_violin_post_process_custom_width() {
        use crate::df;
        let violin = Violin;
        let offset_col = naming::aesthetic_column("offset");
        let df = df! {
            offset_col.as_str() => vec![0.0, 0.5, 1.0],
            "__ggsql_aes_pos2__" => vec![1.0, 2.0, 3.0],
        }
        .unwrap();
        let mut parameters = Parameters::new();
        parameters.insert("width".to_string(), ParameterValue::Number(0.6));
        let result = violin.post_process(df, &parameters).unwrap();
        let scaled_arr = crate::array_util::as_f64(result.column(&offset_col).unwrap()).unwrap();
        let values: Vec<f64> = (0..scaled_arr.len())
            .filter(|&i| !scaled_arr.is_null(i))
            .map(|i| scaled_arr.value(i))
            .collect();
        assert!((values[0] - 0.0).abs() < 1e-6, "0.0 should stay 0.0");
        assert!((values[1] - 0.15).abs() < 1e-6, "0.5 should become 0.15");
        assert!((values[2] - 0.3).abs() < 1e-6, "1.0 should become 0.3");
    }

    /// Without a `pos1` mapping, a dummy `pos1` stat column is produced and
    /// all rows fall into a single group.
    #[test]
    fn test_violin_dummy_pos1_when_unmapped() {
        let query = "SELECT flipper_length FROM penguins";
        let mut aesthetics = Mappings::new();
        aesthetics.insert(
            "pos2".to_string(),
            AestheticValue::standard_column("flipper_length".to_string()),
        );
        let groups: Vec<String> = vec![];
        let mut parameters = Parameters::new();
        parameters.insert("bandwidth".to_string(), ParameterValue::Number(5.0));
        parameters.insert(
            "kernel".to_string(),
            ParameterValue::String("gaussian".to_string()),
        );
        let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
        let setup_sql = "CREATE TABLE penguins AS SELECT * FROM (VALUES
(181.0), (186.0), (195.0), (217.0), (221.0), (230.0), (192.0)
) AS t(flipper_length)";
        reader.execute_sql(setup_sql).unwrap();
        let execute = |sql: &str| reader.execute_sql(sql);
        let ctx = crate::plot::aesthetic::AestheticContext::from_static(&["x", "y"], &[]);
        let result = stat_violin(query, &aesthetics, &groups, &parameters, &AnsiDialect, &ctx)
            .expect("stat_violin should succeed without pos1");
        match result {
            StatResult::Transformed {
                query: stat_query,
                stat_columns,
                dummy_columns,
                ..
            } => {
                assert!(stat_columns.contains(&"pos1".to_string()));
                assert_eq!(dummy_columns, vec!["pos1".to_string()]);
                assert!(stat_query.contains("__ggsql_stat_dummy"));
                assert!(stat_query.contains("__ggsql_stat_pos1"));
                let df = execute(&stat_query).expect("Generated SQL should execute");
                assert!(df.height() > 0);
                let pos1_col = df.column("__ggsql_stat_pos1").unwrap();
                let unique = count_unique_strings(pos1_col);
                assert_eq!(unique, 1, "dummy pos1 should collapse to one group");
            }
            _ => panic!("Expected Transformed result"),
        }
    }

    /// A frame without the offset column passes through `post_process`
    /// with its row count unchanged.
    #[test]
    fn test_violin_post_process_no_offset_column() {
        use crate::df;
        let violin = Violin;
        let df = df! {
            "__ggsql_aes_pos2__" => vec![1.0, 2.0, 3.0],
        }
        .unwrap();
        let parameters = Parameters::new();
        let result = violin.post_process(df.clone(), &parameters).unwrap();
        assert_eq!(result.height(), df.height());
    }
}