use crate::chart::Chart;
use crate::core::data::*;
use crate::error::ChartonError;
use crate::mark::Mark;
use kernel_density_estimation::prelude::*;
use polars::prelude::*;
/// Smoothing kernel used by [`DensityTransform`] for kernel density estimation.
///
/// Variants map to the corresponding kernels of the `kernel_density_estimation` crate.
#[derive(Debug, Clone)]
pub enum KernelType {
/// Gaussian kernel (the default in [`DensityTransform::new`]).
Normal,
/// Epanechnikov (parabolic) kernel.
Epanechnikov,
/// Uniform (rectangular/box) kernel.
Uniform,
}
impl KernelType {
fn as_str(&self) -> &'static str {
match self {
KernelType::Normal => "Normal",
KernelType::Epanechnikov => "Epanechnikov",
KernelType::Uniform => "Uniform",
}
}
}
impl std::fmt::Display for KernelType {
    /// Formats the kernel by delegating to its static name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Bandwidth-selection rule used by [`DensityTransform`] for kernel density estimation.
#[derive(Debug, Clone)]
pub enum BandwidthType {
/// Scott's rule of thumb (the default in [`DensityTransform::new`]).
Scott,
/// Silverman's rule of thumb.
Silverman,
/// A fixed, user-supplied bandwidth value.
Fixed(f64),
}
impl BandwidthType {
fn as_str(&self) -> &'static str {
match self {
BandwidthType::Scott => "Scott",
BandwidthType::Silverman => "Silverman",
BandwidthType::Fixed(_) => "Fixed",
}
}
}
impl std::fmt::Display for BandwidthType {
    /// Formats as the selector name; `Fixed` additionally shows its value,
    /// e.g. `Fixed(0.5)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let BandwidthType::Fixed(value) = self {
            write!(f, "Fixed({})", value)
        } else {
            f.write_str(self.as_str())
        }
    }
}
/// Configuration for a kernel-density-estimation transform (see
/// `Chart::transform_density`). Build with [`DensityTransform::new`] and the
/// `with_*` methods.
#[derive(Debug, Clone)]
pub struct DensityTransform {
/// Name of the column whose values are smoothed into a density.
pub(crate) density: String,
/// Output column names: `[value (x), density (y)]`.
pub(crate) as_: [String; 2],
/// Bandwidth-selection rule for the estimator.
pub(crate) bandwidth: BandwidthType,
/// When true, densities are scaled by the group size (expected counts).
pub(crate) counts: bool,
/// When true, the CDF is evaluated instead of the PDF.
pub(crate) cumulative: bool,
/// Optional column to compute one density per distinct value of.
pub(crate) groupby: Option<String>,
/// Smoothing kernel for the estimator.
pub(crate) kernel: KernelType,
}
impl DensityTransform {
    /// Creates a density transform over `density_field` with defaults:
    /// output columns `["value", "density"]`, Scott bandwidth, Normal kernel,
    /// PDF (not cumulative), densities (not counts), and no grouping.
    pub fn new(density_field: impl Into<String>) -> Self {
        Self {
            density: density_field.into(),
            as_: ["value".to_string(), "density".to_string()],
            bandwidth: BandwidthType::Scott,
            counts: false,
            cumulative: false,
            groupby: None,
            kernel: KernelType::Normal,
        }
    }
    /// Renames the two output columns: x values and density estimates.
    pub fn with_as(
        mut self,
        value_field: impl Into<String>,
        density_field: impl Into<String>,
    ) -> Self {
        self.as_ = [value_field.into(), density_field.into()];
        self
    }
    /// Selects the bandwidth rule (Scott, Silverman, or a fixed width).
    pub fn with_bandwidth(mut self, bandwidth: BandwidthType) -> Self {
        self.bandwidth = bandwidth;
        self
    }
    /// When `counts` is true, densities are scaled by group size to yield counts.
    pub fn with_counts(mut self, counts: bool) -> Self {
        self.counts = counts;
        self
    }
    /// When `cumulative` is true, the CDF is computed instead of the PDF.
    pub fn with_cumulative(mut self, cumulative: bool) -> Self {
        self.cumulative = cumulative;
        self
    }
    /// Computes a separate density for each distinct value of `groupby`.
    ///
    /// Generalized from `&str` to `impl Into<String>` for consistency with the
    /// other builders; existing `&str` callers are unaffected.
    pub fn with_groupby(mut self, groupby: impl Into<String>) -> Self {
        self.groupby = Some(groupby.into());
        self
    }
    /// Selects the smoothing kernel.
    pub fn with_kernel(mut self, kernel: impl Into<KernelType>) -> Self {
        self.kernel = kernel.into();
        self
    }
}
impl<T: Mark> Chart<T> {
pub fn transform_density(mut self, params: DensityTransform) -> Result<Self, ChartonError> {
let density_field = ¶ms.density;
let density_series = self.data.column(density_field)?;
let min_val = density_series.min::<f64>()?.unwrap();
let max_val = density_series.max::<f64>()?.unwrap();
let extended_min = 1.3 * min_val - 0.3 * max_val;
let extended_max = 1.3 * max_val - 0.3 * min_val;
let (min_val, max_val) = if (extended_max - extended_min).abs() < 1e-12 {
let offset = if extended_min == 0.0 {
1.0
} else {
extended_min.abs() * 0.1
};
(extended_min - offset, extended_max + offset)
} else {
(extended_min, extended_max)
};
let steps = 200;
let step_size = (max_val - min_val) / (steps as f64);
let eval_points: Vec<f32> = (0..steps)
.map(|i| (min_val + (i as f64) * step_size) as f32)
.collect();
let value_column: Vec<f64> = eval_points.iter().map(|&v| v as f64).collect();
let group_field_name = params
.groupby
.clone()
.unwrap_or_else(|| format!("__charton_temp_group_{}", crate::TEMP_SUFFIX));
let working_df = if let Some(ref group_field) = params.groupby {
self.data.df.select([density_field, group_field])?
} else {
let fake_group_series = Series::new(
(&group_field_name).into(),
vec!["fake"; self.data.df.height()],
);
self.data
.df
.select([density_field])?
.with_column(fake_group_series)?
.clone() };
let grouped_df = working_df
.lazy()
.group_by_stable([col(&group_field_name)])
.agg([col(density_field).implode().alias(density_field)])
.collect()?;
let mut all_groups = Vec::new();
let mut all_x_values = Vec::new();
let mut all_y_values = Vec::new();
for i in 0..grouped_df.height() {
let group_value = match grouped_df
.column(&group_field_name)
.map_err(ChartonError::Polars)?
.get(i)?
{
AnyValue::String(s) => s.to_string(),
AnyValue::Int32(v) => v.to_string(),
AnyValue::Int64(v) => v.to_string(),
AnyValue::Float64(v) => v.to_string(),
_ => "unknown".to_string(),
};
let list_series = grouped_df.column(density_field)?.get(i)?;
let group_vals: Vec<f64> = match list_series {
AnyValue::List(inner) => inner.f64()?.into_no_null_iter().collect(),
_ => continue,
};
let observations: Vec<f32> = group_vals.iter().map(|&v| v as f32).collect();
let density_values: Vec<f64> = match (¶ms.bandwidth, ¶ms.kernel) {
(BandwidthType::Scott, KernelType::Normal) => {
let kde = KernelDensityEstimator::new(observations.clone(), Scott, Normal);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Scott, KernelType::Epanechnikov) => {
let kde =
KernelDensityEstimator::new(observations.clone(), Scott, Epanechnikov);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Scott, KernelType::Uniform) => {
let kde = KernelDensityEstimator::new(observations.clone(), Scott, Uniform);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Silverman, KernelType::Normal) => {
let kde = KernelDensityEstimator::new(observations.clone(), Silverman, Normal);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Silverman, KernelType::Epanechnikov) => {
let kde =
KernelDensityEstimator::new(observations.clone(), Silverman, Epanechnikov);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Silverman, KernelType::Uniform) => {
let kde = KernelDensityEstimator::new(observations.clone(), Silverman, Uniform);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Fixed(value), KernelType::Normal) => {
let bandwidth = Box::new(|_: &[f32]| *value as f32);
let kde = KernelDensityEstimator::new(observations.clone(), bandwidth, Normal);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Fixed(value), KernelType::Epanechnikov) => {
let bandwidth = Box::new(|_: &[f32]| *value as f32);
let kde =
KernelDensityEstimator::new(observations.clone(), bandwidth, Epanechnikov);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
(BandwidthType::Fixed(value), KernelType::Uniform) => {
let bandwidth = Box::new(|_: &[f32]| *value as f32);
let kde = KernelDensityEstimator::new(observations.clone(), bandwidth, Uniform);
if params.cumulative {
kde.cdf(&eval_points).iter().map(|&v| v as f64).collect()
} else {
kde.pdf(&eval_points).iter().map(|&v| v as f64).collect()
}
}
};
let density_values = if params.counts {
density_values
.into_iter()
.map(|v| v * group_vals.len() as f64)
.collect()
} else {
density_values
};
for _ in 0..value_column.len() {
all_groups.push(group_value.clone());
}
all_x_values.extend(value_column.clone());
all_y_values.extend(density_values);
}
let result_df = if params.groupby.is_some() {
df![
¶ms.as_[0] => all_x_values,
¶ms.as_[1] => all_y_values,
&group_field_name => all_groups
]
} else {
df![
¶ms.as_[0] => all_x_values,
¶ms.as_[1] => all_y_values
]
};
self.data = DataFrameSource::new(result_df?);
Ok(self)
}
}