use crate::chart::Chart;
use crate::core::data::{ColumnVector, Dataset};
use crate::error::ChartonError;
use crate::mark::Mark;
use ahash::AHashMap;
use kernel_density_estimation::prelude::*;
/// Kernel function that can be selected for the density transform.
#[derive(Clone, Debug)]
pub enum KernelType {
    /// Normal (Gaussian) kernel.
    Normal,
    /// Epanechnikov kernel.
    Epanechnikov,
    /// Uniform kernel.
    Uniform,
}
impl KernelType {
fn as_str(&self) -> &'static str {
match self {
KernelType::Normal => "Normal",
KernelType::Epanechnikov => "Epanechnikov",
KernelType::Uniform => "Uniform",
}
}
}
impl std::fmt::Display for KernelType {
    /// Writes the variant name exactly as returned by `as_str`.
    ///
    /// The name is written verbatim; width/alignment flags on the outer
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Bandwidth selection strategy for the density transform.
#[derive(Clone, Debug)]
pub enum BandwidthType {
    /// Bandwidth chosen by the estimator's `Scott` selector.
    Scott,
    /// Bandwidth chosen by the estimator's `Silverman` selector.
    Silverman,
    /// Constant, caller-supplied bandwidth.
    Fixed(f64),
}
impl BandwidthType {
fn as_str(&self) -> &'static str {
match self {
BandwidthType::Scott => "Scott",
BandwidthType::Silverman => "Silverman",
BandwidthType::Fixed(_) => "Fixed",
}
}
}
impl std::fmt::Display for BandwidthType {
    /// Writes `Fixed(<value>)` for a fixed bandwidth and the bare variant
    /// name for every other variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `Fixed` carries a payload worth showing; handle it first.
        if let BandwidthType::Fixed(value) = self {
            return write!(f, "Fixed({})", value);
        }
        f.write_str(self.as_str())
    }
}
/// Parameters for the kernel density transform applied by
/// `Chart::transform_density`.
#[derive(Clone, Debug)]
pub struct DensityTransform {
    /// Name of the input column whose values are estimated.
    pub(crate) density: String,
    /// Output column names: `[sample position, estimated value]`.
    pub(crate) as_: [String; 2],
    /// Bandwidth selection strategy.
    pub(crate) bandwidth: BandwidthType,
    /// When `true`, estimates are multiplied by the number of observations
    /// in the group.
    pub(crate) counts: bool,
    /// When `true`, the estimator's `cdf` is evaluated instead of its `pdf`.
    pub(crate) cumulative: bool,
    /// Optional column used to split the data into independent groups.
    pub(crate) groupby: Option<String>,
    /// Kernel function used by the estimator.
    pub(crate) kernel: KernelType,
}
impl DensityTransform {
    /// Creates a transform for the column `density_field` with the defaults:
    /// output columns `"value"` / `"density"`, Scott bandwidth, normal kernel,
    /// no grouping, and both `counts` and `cumulative` disabled.
    pub fn new(density_field: impl Into<String>) -> Self {
        let density = density_field.into();
        let as_ = [String::from("value"), String::from("density")];
        Self {
            density,
            as_,
            bandwidth: BandwidthType::Scott,
            counts: false,
            cumulative: false,
            groupby: None,
            kernel: KernelType::Normal,
        }
    }

    /// Sets the names of the two output columns (sample position first,
    /// estimated value second).
    pub fn with_as(
        self,
        value_field: impl Into<String>,
        density_field: impl Into<String>,
    ) -> Self {
        Self {
            as_: [value_field.into(), density_field.into()],
            ..self
        }
    }

    /// Sets the bandwidth selection strategy.
    pub fn with_bandwidth(self, bandwidth: BandwidthType) -> Self {
        Self { bandwidth, ..self }
    }

    /// Chooses whether estimates are scaled by the group's observation count.
    pub fn with_counts(self, counts: bool) -> Self {
        Self { counts, ..self }
    }

    /// Chooses whether the cumulative estimate is produced.
    pub fn with_cumulative(self, cumulative: bool) -> Self {
        Self { cumulative, ..self }
    }

    /// Sets the column used to split the data into groups.
    pub fn with_groupby(self, groupby: &str) -> Self {
        Self {
            groupby: Some(groupby.to_owned()),
            ..self
        }
    }

    /// Sets the kernel function.
    pub fn with_kernel(self, kernel: impl Into<KernelType>) -> Self {
        Self {
            kernel: kernel.into(),
            ..self
        }
    }
}
impl<T: Mark> Chart<T> {
pub fn transform_density(mut self, params: DensityTransform) -> Result<Self, ChartonError> {
let density_field = ¶ms.density;
let density_col = self.data.column(density_field)?;
let (min_val, max_val) = density_col.min_max();
let mut extended_min = 1.3 * min_val - 0.3 * max_val;
let mut extended_max = 1.3 * max_val - 0.3 * min_val;
if (extended_max - extended_min).abs() < 1e-12 {
let offset = if extended_min == 0.0 {
1.0
} else {
extended_min.abs() * 0.1
};
extended_min -= offset;
extended_max += offset;
}
let steps = 200;
let step_size = (extended_max - extended_min) / (steps as f64);
let eval_points: Vec<f32> = (0..steps)
.map(|i| (extended_min + (i as f64) * step_size) as f32)
.collect();
let x_axis_values: Vec<f64> = eval_points.iter().map(|&v| v as f64).collect();
let group_order: Vec<Option<String>> = if let Some(ref g_field) = params.groupby {
self.data
.column(g_field)?
.unique_values()
.into_iter()
.map(Some)
.collect()
} else {
vec![None]
};
let mut groups: AHashMap<Option<String>, Vec<f32>> = AHashMap::new();
let row_count = self.data.height();
if let Some(ref g_field) = params.groupby {
let group_col = self.data.column(g_field)?;
for i in 0..row_count {
if let Some(val) = density_col.get_f64(i) {
let key = group_col.get_str(i);
groups.entry(key).or_default().push(val as f32);
}
}
} else {
let mut all_obs = Vec::with_capacity(row_count);
for i in 0..row_count {
if let Some(val) = density_col.get_f64(i) {
all_obs.push(val as f32);
}
}
groups.insert(None, all_obs);
}
let mut final_x = Vec::new();
let mut final_y = Vec::new();
let mut final_group = Vec::new();
for key in group_order {
let observations = match groups.get(&key) {
Some(obs) if !obs.is_empty() => obs,
_ => continue,
};
let group_label = key.as_deref().unwrap_or("all").to_string();
let density_values: Vec<f64> = match (¶ms.bandwidth, ¶ms.kernel) {
(BandwidthType::Scott, KernelType::Normal) => {
let kde = KernelDensityEstimator::new(observations.clone(), Scott, Normal);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Scott, KernelType::Epanechnikov) => {
let kde =
KernelDensityEstimator::new(observations.clone(), Scott, Epanechnikov);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Scott, KernelType::Uniform) => {
let kde = KernelDensityEstimator::new(observations.clone(), Scott, Uniform);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Silverman, KernelType::Normal) => {
let kde = KernelDensityEstimator::new(observations.clone(), Silverman, Normal);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Silverman, KernelType::Epanechnikov) => {
let kde =
KernelDensityEstimator::new(observations.clone(), Silverman, Epanechnikov);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Silverman, KernelType::Uniform) => {
let kde = KernelDensityEstimator::new(observations.clone(), Silverman, Uniform);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Fixed(bw), KernelType::Normal) => {
let h = *bw as f32;
let kde = KernelDensityEstimator::new(
observations.clone(),
move |_: &[f32]| h,
Normal,
);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Fixed(bw), KernelType::Epanechnikov) => {
let h = *bw as f32;
let kde = KernelDensityEstimator::new(
observations.clone(),
move |_: &[f32]| h,
Epanechnikov,
);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
(BandwidthType::Fixed(bw), KernelType::Uniform) => {
let h = *bw as f32;
let kde = KernelDensityEstimator::new(
observations.clone(),
move |_: &[f32]| h,
Uniform,
);
if params.cumulative {
kde.cdf(&eval_points)
} else {
kde.pdf(&eval_points)
}
}
}
.into_iter()
.map(|v| v as f64)
.collect();
let obs_count = observations.len() as f64;
let processed_y = if params.counts {
density_values.into_iter().map(|v| v * obs_count).collect()
} else {
density_values
};
final_y.extend(processed_y);
final_x.extend(x_axis_values.clone());
if params.groupby.is_some() {
for _ in 0..steps {
final_group.push(group_label.clone());
}
}
}
let mut new_ds = Dataset::new();
new_ds.add_column(¶ms.as_[0], ColumnVector::F64 { data: final_x })?;
new_ds.add_column(¶ms.as_[1], ColumnVector::F64 { data: final_y })?;
if let Some(ref g_field) = params.groupby {
new_ds.add_column(
g_field,
ColumnVector::String {
data: final_group,
validity: None,
},
)?;
}
self.data = new_ds;
Ok(self)
}
}