use crate::plot::{Plot, PositionType};
use crate::{DataFrame, Result};
use std::collections::HashMap;
pub fn apply_position_adjustments(
spec: &mut Plot,
data_map: &mut HashMap<String, DataFrame>,
) -> Result<()> {
for idx in 0..spec.layers.len() {
if spec.layers[idx].position.position_type() == PositionType::Identity {
continue;
}
let Some(key) = spec.layers[idx].data_key.clone() else {
continue;
};
let Some(df) = data_map.remove(&key) else {
continue;
};
let (adjusted_df, adjusted_width) =
spec.layers[idx]
.position
.apply_adjustment(df, &spec.layers[idx], spec)?;
data_map.insert(key.clone(), adjusted_df);
if let Some(width) = adjusted_width {
spec.layers[idx].adjusted_width = Some(width);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::array_util::as_f64;
use crate::df;
use crate::plot::facet::{Facet, FacetLayout};
use crate::plot::layer::{Geom, Position};
use crate::plot::{AestheticValue, Mappings, ParameterValue, Scale, ScaleType};
use arrow::array::Array;
fn make_continuous_scale(aesthetic: &str) -> Scale {
let mut scale = Scale::new(aesthetic);
scale.scale_type = Some(ScaleType::continuous());
scale
}
fn make_discrete_scale(aesthetic: &str) -> Scale {
let mut scale = Scale::new(aesthetic);
scale.scale_type = Some(ScaleType::discrete());
scale
}
fn make_test_df() -> DataFrame {
df! {
"__ggsql_aes_pos1__" => vec!["A", "A", "B", "B"],
"__ggsql_aes_pos2__" => vec![10.0, 20.0, 15.0, 25.0],
"__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0],
"__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"],
}
.unwrap()
}
fn make_test_layer() -> crate::plot::Layer {
let mut layer = crate::plot::Layer::new(Geom::bar());
layer.mappings = {
let mut m = Mappings::new();
m.insert(
"pos1",
AestheticValue::standard_column("__ggsql_aes_pos1__"),
);
m.insert(
"pos2",
AestheticValue::standard_column("__ggsql_aes_pos2__"),
);
m.insert(
"pos2end",
AestheticValue::standard_column("__ggsql_aes_pos2end__"),
);
m.insert(
"fill",
AestheticValue::standard_column("__ggsql_aes_fill__"),
);
m
};
layer.partition_by = vec!["__ggsql_aes_fill__".to_string()];
layer
}
#[test]
fn test_identity_no_change() {
let df = make_test_df();
let mut layer = make_test_layer();
layer.position = Position::identity();
let spec = Plot::new();
let mut data_map = HashMap::new();
layer.data_key = Some("__ggsql_layer_0__".to_string());
data_map.insert("__ggsql_layer_0__".to_string(), df.clone());
let mut spec_with_layer = spec;
spec_with_layer.layers.push(layer);
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
assert_eq!(result_df.height(), 4);
}
#[test]
fn test_stack_cumsum() {
let df = make_test_df();
let mut layer = make_test_layer();
layer.position = Position::stack();
let spec = Plot::new();
let mut data_map = HashMap::new();
layer.data_key = Some("__ggsql_layer_0__".to_string());
data_map.insert("__ggsql_layer_0__".to_string(), df);
let mut spec_with_layer = spec;
spec_with_layer.layers.push(layer);
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
let pos2_col = result_df.column("__ggsql_aes_pos2__").unwrap();
let pos2end_col = result_df.column("__ggsql_aes_pos2end__").unwrap();
assert!(
matches!(
pos2_col.data_type(),
arrow::datatypes::DataType::Float64
| arrow::datatypes::DataType::Int64
| arrow::datatypes::DataType::Int32
),
"pos2 should be numeric"
);
assert!(
matches!(
pos2end_col.data_type(),
arrow::datatypes::DataType::Float64
| arrow::datatypes::DataType::Int64
| arrow::datatypes::DataType::Int32
),
"pos2end should be numeric"
);
}
#[test]
fn test_dodge_offset() {
let df = make_test_df();
let mut layer = make_test_layer();
layer.position = Position::dodge();
let mut spec = Plot::new();
spec.scales.push(make_discrete_scale("pos1"));
spec.scales.push(make_continuous_scale("pos2"));
let mut data_map = HashMap::new();
layer.data_key = Some("__ggsql_layer_0__".to_string());
data_map.insert("__ggsql_layer_0__".to_string(), df);
let mut spec_with_layer = spec;
spec_with_layer.layers.push(layer);
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
let offset_col = result_df.column("__ggsql_aes_pos1offset__");
assert!(offset_col.is_ok(), "pos1offset column should be created");
let offset = as_f64(offset_col.unwrap()).unwrap();
let offsets: Vec<f64> = (0..offset.len())
.filter(|&i| !offset.is_null(i))
.map(|i| offset.value(i))
.collect();
assert!(
offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001),
"Should have offset -0.225 for group X, got {:?}",
offsets
);
assert!(
offsets.iter().any(|&v| (v - 0.225).abs() < 0.001),
"Should have offset +0.225 for group Y, got {:?}",
offsets
);
let adjusted = spec_with_layer.layers[0].adjusted_width;
assert!(adjusted.is_some());
assert!((adjusted.unwrap() - 0.45).abs() < 0.001);
}
#[test]
fn test_dodge_custom_width() {
let df = make_test_df();
let mut layer = make_test_layer();
layer.position = Position::dodge();
layer
.parameters
.insert("width".to_string(), ParameterValue::Number(0.6));
let mut spec = Plot::new();
spec.scales.push(make_discrete_scale("pos1"));
spec.scales.push(make_continuous_scale("pos2"));
let mut data_map = HashMap::new();
layer.data_key = Some("__ggsql_layer_0__".to_string());
data_map.insert("__ggsql_layer_0__".to_string(), df);
let mut spec_with_layer = spec;
spec_with_layer.layers.push(layer);
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
let offset_col = result_df.column("__ggsql_aes_pos1offset__").unwrap();
let offset = as_f64(offset_col).unwrap();
let offsets: Vec<f64> = (0..offset.len())
.filter(|&i| !offset.is_null(i))
.map(|i| offset.value(i))
.collect();
assert!(offsets.iter().any(|&v| (v - (-0.15)).abs() < 0.001));
assert!(offsets.iter().any(|&v| (v - 0.15).abs() < 0.001));
let adjusted = spec_with_layer.layers[0].adjusted_width;
assert!((adjusted.unwrap() - 0.3).abs() < 0.001);
}
#[test]
fn test_jitter_offset() {
let df = make_test_df();
let mut layer = make_test_layer();
layer.position = Position::jitter();
let mut spec = Plot::new();
spec.scales.push(make_discrete_scale("pos1"));
spec.scales.push(make_continuous_scale("pos2"));
let mut data_map = HashMap::new();
layer.data_key = Some("__ggsql_layer_0__".to_string());
data_map.insert("__ggsql_layer_0__".to_string(), df);
let mut spec_with_layer = spec;
spec_with_layer.layers.push(layer);
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
let offset_col = result_df.column("__ggsql_aes_pos1offset__");
assert!(offset_col.is_ok());
let offset = as_f64(offset_col.unwrap()).unwrap();
let offsets: Vec<f64> = (0..offset.len())
.filter(|&i| !offset.is_null(i))
.map(|i| offset.value(i))
.collect();
for &v in &offsets {
assert!((-0.45..=0.45).contains(&v));
}
assert!(spec_with_layer.layers[0].adjusted_width.is_none());
}
#[test]
fn test_jitter_custom_width() {
let df = make_test_df();
let mut layer = make_test_layer();
layer.position = Position::jitter();
layer
.parameters
.insert("width".to_string(), ParameterValue::Number(0.6));
let mut spec = Plot::new();
spec.scales.push(make_discrete_scale("pos1"));
spec.scales.push(make_continuous_scale("pos2"));
let mut data_map = HashMap::new();
layer.data_key = Some("__ggsql_layer_0__".to_string());
data_map.insert("__ggsql_layer_0__".to_string(), df);
let mut spec_with_layer = spec;
spec_with_layer.layers.push(layer);
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
let offset_col = result_df.column("__ggsql_aes_pos1offset__").unwrap();
let offset = as_f64(offset_col).unwrap();
let offsets: Vec<f64> = (0..offset.len())
.filter(|&i| !offset.is_null(i))
.map(|i| offset.value(i))
.collect();
for &v in &offsets {
assert!((-0.3..=0.3).contains(&v));
}
}
#[test]
fn test_stack_resets_per_facet_panel() {
let df = df! {
"__ggsql_aes_pos1__" => vec!["A", "A", "A", "A"],
"__ggsql_aes_pos2__" => vec![10.0, 20.0, 30.0, 40.0],
"__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0],
"__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"],
"__ggsql_aes_facet1__" => vec!["F1", "F1", "F2", "F2"],
}
.unwrap();
let mut layer = crate::plot::Layer::new(Geom::bar());
layer.mappings = {
let mut m = Mappings::new();
m.insert(
"pos1",
AestheticValue::standard_column("__ggsql_aes_pos1__"),
);
m.insert(
"pos2",
AestheticValue::standard_column("__ggsql_aes_pos2__"),
);
m.insert(
"pos2end",
AestheticValue::standard_column("__ggsql_aes_pos2end__"),
);
m.insert(
"fill",
AestheticValue::standard_column("__ggsql_aes_fill__"),
);
m.insert(
"facet1",
AestheticValue::standard_column("__ggsql_aes_facet1__"),
);
m
};
layer.partition_by = vec![
"__ggsql_aes_fill__".to_string(),
"__ggsql_aes_facet1__".to_string(),
];
layer.position = Position::stack();
layer.data_key = Some("__ggsql_layer_0__".to_string());
let mut spec = Plot::new();
spec.scales.push(make_discrete_scale("pos1"));
spec.scales.push(make_continuous_scale("pos2"));
spec.facet = Some(Facet::new(FacetLayout::Wrap {
variables: vec!["facet_var".to_string()],
}));
let mut data_map = HashMap::new();
data_map.insert("__ggsql_layer_0__".to_string(), df);
let mut spec_with_layer = spec;
spec_with_layer.layers.push(layer);
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
let facet_col =
crate::array_util::as_str(result_df.column("__ggsql_aes_facet1__").unwrap()).unwrap();
let fill_col =
crate::array_util::as_str(result_df.column("__ggsql_aes_fill__").unwrap()).unwrap();
let mut indices: Vec<usize> = (0..result_df.height()).collect();
indices.sort_by(|&a, &b| {
let fa = facet_col.value(a);
let fb = facet_col.value(b);
let cmp1 = fa.cmp(fb);
if cmp1 != std::cmp::Ordering::Equal {
return cmp1;
}
fill_col.value(a).cmp(fill_col.value(b))
});
let pos2_arr = as_f64(result_df.column("__ggsql_aes_pos2__").unwrap()).unwrap();
let pos2end_arr = as_f64(result_df.column("__ggsql_aes_pos2end__").unwrap()).unwrap();
let pos2_vals: Vec<f64> = indices.iter().map(|&i| pos2_arr.value(i)).collect();
let pos2end_vals: Vec<f64> = indices.iter().map(|&i| pos2end_arr.value(i)).collect();
assert_eq!(
pos2end_vals[2], 0.0,
"F2 panel first bar should start at 0, not carry over from F1. pos2end={:?}, pos2={:?}",
pos2end_vals, pos2_vals
);
}
#[test]
fn test_stack_groups_by_facet_not_fill_order() {
let df = df! {
"__ggsql_aes_pos1__" => vec!["A", "A", "A", "A", "A", "A"],
"__ggsql_aes_pos2__" => vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
"__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"__ggsql_aes_fill__" => vec!["X", "X", "X", "Y", "Y", "Y"],
"__ggsql_aes_facet1__" => vec!["F1", "F2", "F3", "F1", "F2", "F3"],
}
.unwrap();
let mut layer = crate::plot::Layer::new(Geom::bar());
layer.mappings = {
let mut m = Mappings::new();
m.insert(
"pos1",
AestheticValue::standard_column("__ggsql_aes_pos1__"),
);
m.insert(
"pos2",
AestheticValue::standard_column("__ggsql_aes_pos2__"),
);
m.insert(
"pos2end",
AestheticValue::standard_column("__ggsql_aes_pos2end__"),
);
m.insert(
"fill",
AestheticValue::standard_column("__ggsql_aes_fill__"),
);
m.insert(
"facet1",
AestheticValue::standard_column("__ggsql_aes_facet1__"),
);
m
};
layer.partition_by = vec![
"__ggsql_aes_fill__".to_string(),
"__ggsql_aes_facet1__".to_string(),
];
layer.position = Position::stack();
layer.data_key = Some("__ggsql_layer_0__".to_string());
let mut spec = Plot::new();
spec.scales.push(make_discrete_scale("pos1"));
spec.scales.push(make_continuous_scale("pos2"));
spec.facet = Some(Facet::new(FacetLayout::Wrap {
variables: vec!["facet_var".to_string()],
}));
let mut data_map = HashMap::new();
data_map.insert("__ggsql_layer_0__".to_string(), df);
spec.layers.push(layer);
apply_position_adjustments(&mut spec, &mut data_map).unwrap();
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
let facet_col =
crate::array_util::as_str(result_df.column("__ggsql_aes_facet1__").unwrap()).unwrap();
let fill_col =
crate::array_util::as_str(result_df.column("__ggsql_aes_fill__").unwrap()).unwrap();
let pos2_arr = as_f64(result_df.column("__ggsql_aes_pos2__").unwrap()).unwrap();
let pos2end_arr = as_f64(result_df.column("__ggsql_aes_pos2end__").unwrap()).unwrap();
let mut by_key: std::collections::HashMap<(&str, &str), (f64, f64)> =
std::collections::HashMap::new();
for i in 0..result_df.height() {
by_key.insert(
(facet_col.value(i), fill_col.value(i)),
(pos2_arr.value(i), pos2end_arr.value(i)),
);
}
assert_eq!(by_key[&("F1", "X")], (10.0, 0.0));
assert_eq!(by_key[&("F1", "Y")], (50.0, 10.0));
assert_eq!(by_key[&("F2", "X")], (20.0, 0.0));
assert_eq!(by_key[&("F2", "Y")], (70.0, 20.0));
assert_eq!(by_key[&("F3", "X")], (30.0, 0.0));
assert_eq!(by_key[&("F3", "Y")], (90.0, 30.0));
}
#[test]
fn test_dodge_ignores_facet_columns_in_group_count() {
let df = df! {
"__ggsql_aes_pos1__" => vec!["A", "A", "A", "A"],
"__ggsql_aes_pos2__" => vec![10.0, 20.0, 30.0, 40.0],
"__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0],
"__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"],
"__ggsql_aes_facet1__" => vec!["F1", "F1", "F2", "F2"],
}
.unwrap();
let mut layer = crate::plot::Layer::new(Geom::bar());
layer.mappings = {
let mut m = Mappings::new();
m.insert(
"pos1",
AestheticValue::standard_column("__ggsql_aes_pos1__"),
);
m.insert(
"pos2",
AestheticValue::standard_column("__ggsql_aes_pos2__"),
);
m.insert(
"pos2end",
AestheticValue::standard_column("__ggsql_aes_pos2end__"),
);
m.insert(
"fill",
AestheticValue::standard_column("__ggsql_aes_fill__"),
);
m.insert(
"facet1",
AestheticValue::standard_column("__ggsql_aes_facet1__"),
);
m
};
layer.partition_by = vec![
"__ggsql_aes_fill__".to_string(),
"__ggsql_aes_facet1__".to_string(),
];
layer.position = Position::dodge();
layer.data_key = Some("__ggsql_layer_0__".to_string());
let mut spec = Plot::new();
spec.scales.push(make_discrete_scale("pos1"));
spec.scales.push(make_continuous_scale("pos2"));
spec.facet = Some(Facet::new(FacetLayout::Wrap {
variables: vec!["facet_var".to_string()],
}));
let mut data_map = HashMap::new();
data_map.insert("__ggsql_layer_0__".to_string(), df);
spec.layers.push(layer);
apply_position_adjustments(&mut spec, &mut data_map).unwrap();
let adjusted = spec.layers[0].adjusted_width.unwrap();
assert!(
(adjusted - 0.45).abs() < 0.001,
"adjusted_width should be 0.45 (2 groups), got {} (facet columns inflated group count)",
adjusted
);
}
}