use crate::core::error::Result;
use crate::core::style_utils::StyleResolver;
use crate::plots::traits::{PlotArea, PlotCompute, PlotConfig, PlotData, PlotRender};
use crate::render::{Color, LineStyle, SkiaRenderer, Theme};
use crate::stats::kde::{kde_1d, kde_2d};
/// Configuration for a one-dimensional kernel density estimate plot.
///
/// Build it with a struct literal, [`Default`], or the chained builder
/// methods on `KdeConfig`.
#[derive(Debug, Clone)]
pub struct KdeConfig {
/// Kernel bandwidth passed to `kde_1d`; `None` leaves the choice to `kde_1d`.
pub bandwidth: Option<f64>,
/// Number of evaluation grid points requested from `kde_1d`.
pub n_points: usize,
/// Whether to fill the area between the curve and y = 0 when rendering.
pub fill: bool,
/// Opacity applied to the fill color when `fill` is set.
pub fill_alpha: f32,
/// Optional curve color.
/// NOTE(review): the rendering code in this file draws with the color passed
/// to `render`/`render_styled` and does not read this field — confirm a caller
/// resolves it.
pub color: Option<Color>,
/// Stroke width of the curve; an explicit width given to `render_styled` overrides it.
pub line_width: f32,
/// NOTE(review): not read anywhere in this file; confirm its purpose before relying on it.
pub shade: bool,
/// X positions of vertical reference lines.
/// NOTE(review): not drawn by the rendering code in this file.
pub vertical_lines: Vec<f64>,
/// Plot the normalised cumulative distribution instead of the density.
pub cumulative: bool,
/// Optional `(min, max)` window; only grid points with `min <= x <= max` are kept.
pub clip: Option<(f64, f64)>,
}
impl Default for KdeConfig {
fn default() -> Self {
Self {
bandwidth: None,
n_points: 200,
fill: true,
fill_alpha: 0.3,
color: None,
line_width: 2.0,
shade: false,
vertical_lines: vec![],
cumulative: false,
clip: None,
}
}
}
// The impl body is empty, so any `PlotConfig` items use their trait defaults.
// It exists so `KdeConfig` satisfies `PlotConfig` bounds.
impl PlotConfig for KdeConfig {}
impl KdeConfig {
    /// Creates a configuration with the default settings (same as `KdeConfig::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Uses `bw` as the kernel bandwidth instead of letting `kde_1d` choose one.
    pub fn bandwidth(mut self, bw: f64) -> Self {
        self.bandwidth = Some(bw);
        self
    }

    /// Sets the number of evaluation grid points; values below 10 are raised to 10.
    pub fn n_points(mut self, n: usize) -> Self {
        self.n_points = n.max(10);
        self
    }

    /// Enables or disables filling the area between the curve and y = 0.
    pub fn fill(mut self, fill: bool) -> Self {
        self.fill = fill;
        self
    }

    /// Sets the fill opacity, clamped to `[0.0, 1.0]`.
    pub fn fill_alpha(mut self, alpha: f32) -> Self {
        self.fill_alpha = alpha.clamp(0.0, 1.0);
        self
    }

    /// Records a preferred curve color.
    ///
    /// Note: the `PlotRender` implementation in this module draws with the
    /// color it is handed and does not read this field itself.
    pub fn color(mut self, color: Color) -> Self {
        self.color = Some(color);
        self
    }

    /// Sets the stroke width; values below 0.1 are raised to 0.1.
    pub fn line_width(mut self, width: f32) -> Self {
        self.line_width = width.max(0.1);
        self
    }

    /// Plots the normalised cumulative distribution instead of the density.
    pub fn cumulative(mut self, cumulative: bool) -> Self {
        self.cumulative = cumulative;
        self
    }

    /// Restricts the plotted curve to grid points with `min <= x <= max`.
    ///
    /// Bounds given in reverse order are swapped. For example, `clip(4.0, 1.0)`
    /// behaves like `clip(1.0, 4.0)` instead of discarding every point.
    /// NaN bounds are stored unchanged.
    pub fn clip(mut self, min: f64, max: f64) -> Self {
        // `max < min` is false for NaN, so NaN bounds are never reordered.
        self.clip = Some(if max < min { (max, min) } else { (min, max) });
        self
    }

    /// Appends an x position for a vertical reference line; may be called repeatedly.
    ///
    /// Note: the `PlotRender` implementation in this module does not draw these.
    pub fn vertical_line(mut self, x: f64) -> Self {
        self.vertical_lines.push(x);
        self
    }
}
/// Former name of [`KdeConfig`], kept so existing code continues to compile.
#[deprecated(since = "0.8.0", note = "Use KdeConfig instead")]
pub type KdePlotConfig = KdeConfig;
/// Result of [`compute_kde`]: a curve sampled on a grid, ready to plot.
#[derive(Debug, Clone)]
pub struct KdeData {
/// Grid x coordinates (restricted to the `config.clip` window when one is set).
pub x: Vec<f64>,
/// Value at each `x`: the density, or the normalised CDF when `cumulative` is true.
pub y: Vec<f64>,
/// Bandwidth reported by `kde_1d`; `0.0` when there was no input data.
pub bandwidth: f64,
/// Whether `y` holds a cumulative distribution rather than a density.
pub cumulative: bool,
/// Copy of the configuration used; the renderer reads `fill`, `fill_alpha` and `line_width` from it.
pub(crate) config: KdeConfig,
}
/// Former name of [`KdeData`], kept so existing code continues to compile.
#[deprecated(since = "0.8.0", note = "Use KdeData instead")]
pub type KdePlotData = KdeData;
/// Computes a one-dimensional kernel density estimate of `data` for plotting.
///
/// * Non-finite samples (NaN, ±∞) are dropped before estimation, so a single
///   missing value cannot poison the bandwidth or the density. If no finite
///   samples remain, an empty [`KdeData`] with bandwidth `0.0` is returned.
/// * The density comes from `kde_1d`, evaluated on `config.n_points` grid
///   points and using `config.bandwidth` when set.
/// * With `config.cumulative`, `y` is the running integral of the density.
///   It uses the trapezoidal rule over the actual grid spacing, starts at 0,
///   and is rescaled so the last value is 1 whenever the total is positive.
/// * With `config.clip = Some((min, max))`, only grid points satisfying
///   `min <= x <= max` are kept. Clipping happens *after* integration, so a
///   clipped CDF still accounts for mass outside the window.
pub fn compute_kde(data: &[f64], config: &KdeConfig) -> KdeData {
    // Only allocate a filtered copy when the input actually contains bad values.
    let filtered: Vec<f64>;
    let data: &[f64] = if data.iter().all(|v| v.is_finite()) {
        data
    } else {
        filtered = data.iter().copied().filter(|v| v.is_finite()).collect();
        &filtered[..]
    };

    if data.is_empty() {
        return KdeData {
            x: Vec::new(),
            y: Vec::new(),
            bandwidth: 0.0,
            cumulative: config.cumulative,
            config: config.clone(),
        };
    }

    let kde = kde_1d(data, config.bandwidth, Some(config.n_points));
    let bandwidth = kde.bandwidth;
    let x = kde.x;
    let y = if config.cumulative {
        cumulative_trapezoid(&x, &kde.density)
    } else {
        kde.density
    };

    let (x, y) = match config.clip {
        Some((min, max)) => clip_xy(x, y, min, max),
        None => (x, y),
    };

    KdeData {
        x,
        y,
        bandwidth,
        cumulative: config.cumulative,
        config: config.clone(),
    }
}

/// Running trapezoidal integral of `density` over the grid `x`.
///
/// The first value is 0. The result is divided by its final value when that
/// total is positive, so a proper density yields a CDF ending at 1. Inputs of
/// different lengths are truncated to the shorter one.
fn cumulative_trapezoid(x: &[f64], density: &[f64]) -> Vec<f64> {
    let n = x.len().min(density.len());
    let mut cdf = Vec::with_capacity(n);
    if n == 0 {
        return cdf;
    }
    cdf.push(0.0);
    let mut running = 0.0;
    for (xs, ds) in x[..n].windows(2).zip(density[..n].windows(2)) {
        // Area of the trapezoid between consecutive grid points.
        running += 0.5 * (ds[0] + ds[1]) * (xs[1] - xs[0]);
        cdf.push(running);
    }
    if let Some(&total) = cdf.last() {
        if total > 0.0 {
            cdf.iter_mut().for_each(|c| *c /= total);
        }
    }
    cdf
}

/// Keeps only the `(x, y)` pairs whose `x` lies in the closed range `[min, max]`.
fn clip_xy(x: Vec<f64>, y: Vec<f64>, min: f64, max: f64) -> (Vec<f64>, Vec<f64>) {
    x.into_iter()
        .zip(y)
        .filter(|&(xi, _)| xi >= min && xi <= max)
        .unzip()
}
/// Former name of [`compute_kde`]; forwards its arguments unchanged.
#[deprecated(since = "0.8.0", note = "Use compute_kde instead")]
pub fn compute_kde_plot(data: &[f64], config: &KdeConfig) -> KdeData {
compute_kde(data, config)
}
/// Builds a closed polygon outlining the area between the KDE curve and a
/// horizontal `baseline`, in data coordinates.
///
/// The result lists the curve points `(x[i], y[i])` in order. These are
/// followed by `(x_last, baseline)` and `(x_first, baseline)`, which close the
/// shape.
///
/// Points are paired with `zip`, so mismatched `x`/`y` lengths are truncated
/// to the shorter one instead of panicking. Empty data yields an empty vector.
pub fn kde_fill_polygon(kde_data: &KdeData, baseline: f64) -> Vec<(f64, f64)> {
    let paired = kde_data.x.len().min(kde_data.y.len());
    // Curve points plus the two baseline corners.
    let mut polygon = Vec::with_capacity(paired + 2);
    polygon.extend(kde_data.x.iter().copied().zip(kde_data.y.iter().copied()));

    let ends = polygon
        .first()
        .zip(polygon.last())
        .map(|(&(first_x, _), &(last_x, _))| (first_x, last_x));
    if let Some((first_x, last_x)) = ends {
        polygon.push((last_x, baseline));
        polygon.push((first_x, baseline));
    }
    polygon
}
/// Marker type that exposes [`compute_kde`] through the [`PlotCompute`] trait.
pub struct Kde;
impl PlotCompute for Kde {
    type Output = KdeData;
    type Config = KdeConfig;
    type Input<'a> = &'a [f64];

    /// Delegates to [`compute_kde`], which cannot fail, so the result is always `Ok`.
    fn compute(input: Self::Input<'_>, config: &Self::Config) -> Result<Self::Output> {
        let estimate = compute_kde(input, config);
        Ok(estimate)
    }
}
impl PlotData for KdeData {
    /// Returns `((x_min, x_max), (y_min, y_max))` for fitting the axes.
    ///
    /// The y range always starts at 0, and its top is the curve maximum padded
    /// by 5 % so the peak does not touch the frame.
    ///
    /// Degenerate ranges fall back to unit-sized ones, matching the
    /// unit-square result for empty data. Three cases are handled:
    /// * a single grid point after clipping is padded by ±0.5 in x;
    /// * non-finite x values use `(0, 1)`;
    /// * an all-zero or missing `y` uses a top of 1.
    fn data_bounds(&self) -> ((f64, f64), (f64, f64)) {
        if self.x.is_empty() {
            return ((0.0, 1.0), (0.0, 1.0));
        }

        let x_min = self.x.iter().copied().fold(f64::INFINITY, f64::min);
        let x_max = self.x.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let x_range = if !(x_min.is_finite() && x_max.is_finite()) {
            (0.0, 1.0)
        } else if x_max > x_min {
            (x_min, x_max)
        } else {
            // Zero-width range (single point): widen so axis scaling stays defined.
            (x_min - 0.5, x_max + 0.5)
        };

        let y_max = self.y.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let y_top = if y_max.is_finite() && y_max > 0.0 {
            y_max * 1.05
        } else {
            1.0
        };

        (x_range, (0.0, y_top))
    }

    /// True when there are no grid points to draw.
    fn is_empty(&self) -> bool {
        self.x.is_empty()
    }
}
impl KdeData {
    /// Projects the `(x, y)` curve into the screen coordinates of `area`.
    ///
    /// Pairs are formed with `zip`, so if `x` and `y` differ in length the
    /// surplus entries of the longer one are ignored.
    fn to_screen_points(&self, area: &PlotArea) -> Vec<(f32, f32)> {
        self.x
            .iter()
            .zip(&self.y)
            .map(|(&x, &y)| area.data_to_screen(x, y))
            .collect()
    }

    /// Closes the screen-space curve `points` down to the data-space y = 0
    /// baseline, producing the polygon that fills the area under the curve.
    ///
    /// Returns an empty vector when `points` is empty.
    fn under_curve_polygon(points: &[(f32, f32)], area: &PlotArea) -> Vec<(f32, f32)> {
        let (first_x, last_x) = match (points.first(), points.last()) {
            (Some(first), Some(last)) => (first.0, last.0),
            _ => return Vec::new(),
        };
        let baseline_y = area.data_to_screen(0.0, 0.0).1;
        let mut polygon = Vec::with_capacity(points.len() + 2);
        polygon.push((first_x, baseline_y));
        polygon.extend_from_slice(points);
        polygon.push((last_x, baseline_y));
        polygon
    }
}

impl PlotRender for KdeData {
    /// Draws the curve in `color`, optionally filling the area under it.
    ///
    /// When `config.fill` is set, the fill uses `color` with alpha
    /// `config.fill_alpha`. The stroke is solid with width `config.line_width`.
    /// Nothing is drawn when there are no `(x, y)` pairs.
    /// `config.vertical_lines` is not rendered here.
    fn render(
        &self,
        renderer: &mut SkiaRenderer,
        area: &PlotArea,
        _theme: &Theme,
        color: Color,
    ) -> Result<()> {
        let points = self.to_screen_points(area);
        // Also guards against `y` being empty while `x` is not.
        if points.is_empty() {
            return Ok(());
        }
        if self.config.fill {
            let polygon = Self::under_curve_polygon(&points, area);
            let fill_color = color.with_alpha(self.config.fill_alpha);
            renderer.draw_filled_polygon(&polygon, fill_color)?;
        }
        renderer.draw_polyline(&points, color, self.config.line_width, LineStyle::Solid)?;
        Ok(())
    }

    /// Like [`PlotRender::render`], but with an overall opacity and an
    /// optional line-width override.
    ///
    /// Details:
    /// * `alpha` is clamped to `[0, 1]`. It becomes the stroke alpha and
    ///   multiplies `config.fill_alpha` for the fill.
    /// * An explicit `line_width` wins. Otherwise the width comes from the
    ///   theme's `StyleResolver`, seeded with `config.line_width`.
    fn render_styled(
        &self,
        renderer: &mut SkiaRenderer,
        area: &PlotArea,
        theme: &Theme,
        color: Color,
        alpha: f32,
        line_width: Option<f32>,
    ) -> Result<()> {
        let points = self.to_screen_points(area);
        if points.is_empty() {
            return Ok(());
        }
        let alpha = alpha.clamp(0.0, 1.0);
        let resolver = StyleResolver::new(theme);
        let actual_line_width =
            line_width.unwrap_or_else(|| resolver.line_width(Some(self.config.line_width)));

        if self.config.fill {
            let polygon = Self::under_curve_polygon(&points, area);
            let fill_color = color.with_alpha(self.config.fill_alpha * alpha);
            renderer.draw_filled_polygon(&polygon, fill_color)?;
        }
        let actual_color = color.with_alpha(alpha);
        renderer.draw_polyline(&points, actual_color, actual_line_width, LineStyle::Solid)?;
        Ok(())
    }
}
/// Configuration for a two-dimensional kernel density estimate plot.
#[derive(Debug, Clone)]
pub struct Kde2dPlotConfig {
    /// Bandwidth along x. `compute_kde_2d_plot` only uses an explicit pair when
    /// `bandwidth_y` is also set; otherwise it leaves both to `kde_2d`.
    pub bandwidth_x: Option<f64>,
    /// Bandwidth along y (see `bandwidth_x`).
    pub bandwidth_y: Option<f64>,
    /// Grid resolution passed to `kde_2d`.
    pub grid_size: usize,
    /// Intended number of contour levels (not consumed by the compute step).
    pub levels: usize,
    /// Whether contours are meant to be filled (not consumed by the compute step).
    pub fill: bool,
    /// Colormap name.
    pub cmap: String,
    /// Whether the raw sample points should also be drawn.
    pub show_points: bool,
    /// Marker size for the raw sample points.
    pub point_size: f32,
    /// Marker opacity for the raw sample points.
    pub point_alpha: f32,
}

impl Default for Kde2dPlotConfig {
    fn default() -> Self {
        Self {
            bandwidth_x: None,
            bandwidth_y: None,
            grid_size: 100,
            levels: 10,
            fill: true,
            cmap: "viridis".to_string(),
            show_points: false,
            point_size: 3.0,
            point_alpha: 0.5,
        }
    }
}

impl Kde2dPlotConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets explicit bandwidths along both axes.
    ///
    /// Both values are set together because `compute_kde_2d_plot` ignores a
    /// bandwidth given for only one axis.
    pub fn bandwidth(mut self, x: f64, y: f64) -> Self {
        self.bandwidth_x = Some(x);
        self.bandwidth_y = Some(y);
        self
    }

    /// Sets the grid resolution; values below 10 are raised to 10.
    pub fn grid_size(mut self, size: usize) -> Self {
        self.grid_size = size.max(10);
        self
    }

    /// Sets the number of contour levels; values below 2 are raised to 2.
    pub fn levels(mut self, levels: usize) -> Self {
        self.levels = levels.max(2);
        self
    }

    /// Enables or disables filled contours.
    pub fn fill(mut self, fill: bool) -> Self {
        self.fill = fill;
        self
    }

    /// Enables or disables drawing the raw sample points.
    pub fn show_points(mut self, show: bool) -> Self {
        self.show_points = show;
        self
    }

    /// Sets the colormap by name.
    pub fn cmap(mut self, cmap: &str) -> Self {
        self.cmap = cmap.to_string();
        self
    }

    /// Sets the marker size for raw points; negative sizes become 0.
    pub fn point_size(mut self, size: f32) -> Self {
        self.point_size = size.max(0.0);
        self
    }

    /// Sets the marker opacity for raw points, clamped to `[0.0, 1.0]`.
    pub fn point_alpha(mut self, alpha: f32) -> Self {
        self.point_alpha = alpha.clamp(0.0, 1.0);
        self
    }
}
/// Result of [`compute_kde_2d_plot`]: a 2-D density evaluated on a grid.
#[derive(Debug, Clone)]
pub struct Kde2dPlotData {
/// Grid coordinates along x, as returned by `kde_2d`.
pub x: Vec<f64>,
/// Grid coordinates along y, as returned by `kde_2d`.
pub y: Vec<f64>,
/// Density values on the grid.
/// NOTE(review): the row/column order (`density[i][j]` vs. x/y index) is
/// defined by `kde_2d` — confirm before indexing.
pub density: Vec<Vec<f64>>,
}
/// Computes a 2-D kernel density estimate of the paired samples `(x[i], y[i])`.
///
/// An explicit bandwidth pair is passed to `kde_2d` only when *both*
/// `config.bandwidth_x` and `config.bandwidth_y` are set. If either is `None`,
/// both are left to `kde_2d`. The grid resolution is `config.grid_size`.
///
/// Empty input yields empty grids without calling `kde_2d`, mirroring the
/// empty-input handling of [`compute_kde`].
pub fn compute_kde_2d_plot(x: &[f64], y: &[f64], config: &Kde2dPlotConfig) -> Kde2dPlotData {
    if x.is_empty() || y.is_empty() {
        return Kde2dPlotData {
            x: Vec::new(),
            y: Vec::new(),
            density: Vec::new(),
        };
    }
    let bandwidth = config.bandwidth_x.zip(config.bandwidth_y);
    let (x_grid, y_grid, density) = kde_2d(x, y, bandwidth, Some(config.grid_size));
    Kde2dPlotData {
        x: x_grid,
        y: y_grid,
        density,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plots::traits::PlotCompute;

    /// Small evenly spaced sample shared by several tests.
    fn one_to_five() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0]
    }

    #[test]
    fn test_kde_basic() {
        let samples = [1.0, 2.0, 2.5, 3.0, 3.5, 4.0];
        let result = compute_kde(&samples, &KdeConfig::default());
        assert!(!result.x.is_empty());
        assert_eq!(result.y.len(), result.x.len());
        assert!(!result.cumulative);
    }

    #[test]
    fn test_kde_cumulative() {
        let samples = one_to_five();
        let result = compute_kde(&samples, &KdeConfig::default().cumulative(true));
        assert!(result.cumulative);
        // A CDF must never decrease (up to floating-point noise).
        assert!(result.y.windows(2).all(|pair| pair[1] >= pair[0] - 1e-10));
        if let Some(last) = result.y.last() {
            assert!((last - 1.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_kde_clipped() {
        let samples = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let result = compute_kde(&samples, &KdeConfig::default().clip(1.0, 4.0));
        assert!(result.x.iter().all(|x| (1.0..=4.0).contains(x)));
    }

    #[test]
    fn test_kde_fill_polygon() {
        let curve = KdeData {
            x: vec![0.0, 1.0, 2.0],
            y: vec![0.1, 0.5, 0.2],
            bandwidth: 0.5,
            cumulative: false,
            config: KdeConfig::default(),
        };
        // Three curve points plus the two baseline corners.
        assert_eq!(kde_fill_polygon(&curve, 0.0).len(), 5);
    }

    #[test]
    fn test_kde_empty() {
        let result = compute_kde(&[], &KdeConfig::default());
        assert!(result.x.is_empty() && result.y.is_empty());
    }

    #[test]
    fn test_kde_plot_compute_trait() {
        let config = KdeConfig::default();
        let samples = one_to_five();
        let outcome = Kde::compute(&samples, &config);
        assert!(outcome.is_ok());
        let estimate = outcome.unwrap();
        assert!(!estimate.is_empty());
        assert_eq!(estimate.x.len(), config.n_points);
    }

    #[test]
    fn test_kde_plot_data_trait() {
        let config = KdeConfig::default();
        let samples = one_to_five();
        let populated = compute_kde(&samples, &config);
        let ((x_lo, x_hi), (y_lo, y_hi)) = populated.data_bounds();
        assert!(x_lo < x_hi);
        assert_eq!(y_lo, 0.0);
        assert!(y_hi > 0.0);
        assert!(!populated.is_empty());
        assert!(compute_kde(&[], &config).is_empty());
    }

    #[test]
    fn test_kde_config_implements_plot_config() {
        fn requires_plot_config<C: PlotConfig>(_: &C) {}
        requires_plot_config(&KdeConfig::default());
    }

    #[test]
    fn test_kde_config_builder_methods() {
        let cfg = KdeConfig::new()
            .bandwidth(0.5)
            .n_points(100)
            .fill(true)
            .fill_alpha(0.5)
            .cumulative(false)
            .clip(0.0, 10.0)
            .vertical_line(5.0);
        assert_eq!(cfg.bandwidth, Some(0.5));
        assert_eq!(cfg.n_points, 100);
        assert!(cfg.fill);
        assert_eq!(cfg.fill_alpha, 0.5);
        assert!(!cfg.cumulative);
        assert_eq!(cfg.clip, Some((0.0, 10.0)));
        assert_eq!(cfg.vertical_lines.len(), 1);
    }

    #[test]
    // The deprecated aliases are exercised on purpose to keep them compiling.
    #[allow(deprecated)]
    fn test_deprecated_type_aliases() {
        let config: KdePlotConfig = KdeConfig::default();
        let _: KdePlotData = compute_kde_plot(&[1.0, 2.0, 3.0], &config);
    }
}