use std::{collections::HashMap, fmt::Display, slice};
use gnuplot::{
AutoOption,
Axes2D,
AxesCommon,
ColorType,
DashType,
Fix,
LabelOption,
Major,
PlotOption,
TickOption,
};
use indexmap::IndexMap;
use num_traits::AsPrimitive;
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
use crate::plot::{
AxisFormat,
DEFAULT_FONT_FAMILY,
DEFAULT_FONT_SIZE,
Plot,
auto_option,
init_scaler,
};
#[derive(Default, Clone)]
pub struct BoxPlot {
font_type: Option<String>,
font_size: Option<f64>,
title: Option<String>,
x_label: Option<String>,
y_label: Option<String>,
y_min: Option<f64>,
y_max: Option<f64>,
y_tick: Option<f64>,
y_format: Option<AxisFormat>,
y_log_base: Option<f64>,
map: IndexMap<String, Vec<f64>>,
colors: HashMap<String, String>,
}
struct Stats {
min: f64,
max: f64,
mean: f64,
median: f64,
q1: f64,
q3: f64,
}
impl Plot for BoxPlot {
fn is_empty(&self) -> bool {
self.map.is_empty()
}
fn set_font_type<T>(&mut self, font_type: T)
where
T: AsRef<str>,
{
self.font_type = Some(font_type.as_ref().to_string());
}
fn with_font_type<T>(mut self, font_type: T) -> Self
where
T: AsRef<str>,
{
self.set_font_type(font_type);
self
}
fn set_font_size(&mut self, font_size: impl AsPrimitive<f64>) {
self.font_size = Some(font_size.as_());
}
fn with_font_size(mut self, font_size: impl AsPrimitive<f64>) -> Self {
self.set_font_size(font_size);
self
}
fn set_title<T>(&mut self, title: T)
where
T: Display,
{
self.title = Some(title.to_string());
}
fn with_title<T>(mut self, title: T) -> Self
where
T: Display,
{
self.set_title(title);
self
}
fn set_x_label<T>(&mut self, label: T)
where
T: Display,
{
self.x_label = Some(label.to_string());
}
fn with_x_label<T>(mut self, label: T) -> Self
where
T: Display,
{
self.set_x_label(label);
self
}
fn set_y_label<T>(&mut self, label: T)
where
T: Display,
{
self.y_label = Some(label.to_string());
}
fn with_y_label<T>(mut self, label: T) -> Self
where
T: Display,
{
self.set_y_label(label);
self
}
fn configure(&mut self, axes: &mut Axes2D) {
let font = LabelOption::Font(
self.font_type
.as_deref()
.unwrap_or(DEFAULT_FONT_FAMILY),
self.font_size.unwrap_or(DEFAULT_FONT_SIZE),
);
let labels = self
.map
.keys()
.map(|label| label.into())
.collect::<Vec<String>>();
let y_scaler = init_scaler(self.y_format, self.max_y_value());
axes.set_x_range(
AutoOption::Fix(0.0),
AutoOption::Fix(self.map.len() as f64 + 1.0),
)
.set_y_range(
auto_option(self.y_min, y_scaler.as_ref()),
auto_option(self.y_max, y_scaler.as_ref()),
)
.set_x_ticks_custom(
labels
.iter()
.enumerate()
.map(|(index, label)| Major(index as f64 + 1.0, Fix(label))),
&[
TickOption::Mirror(false),
TickOption::Inward(false),
],
&[
font.clone(),
LabelOption::Rotate(-45.0),
],
)
.set_y_ticks(
Some((auto_option(self.y_tick, y_scaler.as_ref()), 0)),
&[
TickOption::Mirror(false),
TickOption::Inward(false),
],
slice::from_ref(&font),
)
.set_grid_options(false, &[
PlotOption::Color(ColorType::RGBString("#bbbbbb")),
PlotOption::LineWidth(2.0),
PlotOption::LineStyle(DashType::Dot),
])
.set_y_grid(true);
if let Some(title) = &self.title {
axes.set_title(title, slice::from_ref(&font));
}
if let Some(x_label) = &self.x_label {
axes.set_x_label(x_label, slice::from_ref(&font));
}
if let Some(y_label) = &self.y_label {
axes.set_y_label(&y_scaler.apply_unit(y_label), &[font]);
}
if let Some(base) = self.y_log_base {
axes.set_y_log(Some(base));
}
for (index, label) in labels.iter().enumerate() {
let x_value = index as f64 + 1.0;
let stats = self.get_stats(label);
let color = self
.colors
.get(label)
.map(|color| color.as_str())
.unwrap_or("red");
axes.box_and_whisker(
[x_value],
[y_scaler.scale(stats.q1())],
[y_scaler.scale(stats.min())],
[y_scaler.scale(stats.max())],
[y_scaler.scale(stats.q3())],
&[
PlotOption::BoxWidth(vec![0.25]),
PlotOption::Color(ColorType::RGBString("white")),
PlotOption::BorderColor(ColorType::RGBString(color)),
PlotOption::WhiskerBars(0.5),
PlotOption::LineWidth(1.25),
],
)
.points([x_value], [y_scaler.scale(stats.mean())], &[
PlotOption::Color(ColorType::RGBString("blue")),
PlotOption::PointSymbol('x'),
PlotOption::PointSize(0.75),
])
.points([x_value], [y_scaler.scale(stats.median())], &[
PlotOption::Color(ColorType::RGBString("blue")),
PlotOption::PointSymbol('+'),
PlotOption::PointSize(0.75),
]);
}
}
}
impl BoxPlot {
pub fn set_y_min(&mut self, y_min: impl AsPrimitive<f64>) {
self.y_min = Some(y_min.as_());
}
pub fn with_y_min(mut self, y_min: impl AsPrimitive<f64>) -> Self {
self.set_y_min(y_min);
self
}
pub fn set_y_max(&mut self, y_max: impl AsPrimitive<f64>) {
self.y_max = Some(y_max.as_());
}
pub fn with_y_max(mut self, y_max: impl AsPrimitive<f64>) -> Self {
self.set_y_max(y_max);
self
}
pub fn set_y_tick(&mut self, y_tick: impl AsPrimitive<f64>) {
self.y_tick = Some(y_tick.as_());
}
pub fn with_y_tick(mut self, y_tick: impl AsPrimitive<f64>) -> Self {
self.set_y_tick(y_tick);
self
}
pub fn set_y_format(&mut self, format_type: AxisFormat) {
if let AxisFormat::Log(base) = format_type {
self.y_log_base = Some(base);
return;
}
self.y_format = Some(format_type);
}
pub fn with_y_format(mut self, format_type: AxisFormat) -> Self {
self.set_y_format(format_type);
self
}
pub fn set_color<T1, T2>(&mut self, label: T1, color: T2)
where
T1: AsRef<str>,
T2: AsRef<str>,
{
self.colors
.insert(label.as_ref().to_string(), color.as_ref().to_string());
}
pub fn with_color<T1, T2>(mut self, label: T1, color: T1) -> Self
where
T1: AsRef<str>,
T2: AsRef<str>,
{
self.set_color(label, color);
self
}
pub fn add<T>(&mut self, label: T, value: impl AsPrimitive<f64>)
where
T: Display,
{
self.map
.entry(label.to_string())
.and_modify(|values| values.push(value.as_()))
.or_insert(vec![value.as_()]);
}
fn get_stats(&mut self, label: &str) -> Stats {
let values = self
.map
.get_mut(label)
.expect("Could not get stats");
Stats::new(values)
}
fn max_y_value(&self) -> f64 {
let mut max = self.y_max;
for (_, values) in &self.map {
for y_value in values {
if max.is_none_or(|value| value < *y_value) {
max = Some(*y_value);
}
}
}
max.unwrap_or(0.0)
}
}
impl Stats {
fn new(values: &mut Vec<f64>) -> Self {
let mut data = Data::new(values);
Stats {
min: data.min(),
max: data.max(),
mean: data
.mean()
.expect("Could not calculate mean of data."),
median: data.median(),
q1: data.lower_quartile(),
q3: data.upper_quartile(),
}
}
fn min(&self) -> f64 {
self.min
}
fn max(&self) -> f64 {
self.max
}
fn mean(&self) -> f64 {
self.mean
}
fn median(&self) -> f64 {
self.median
}
fn q1(&self) -> f64 {
self.q1
}
fn q3(&self) -> f64 {
self.q3
}
}