use crate::render::Color;
/// Options controlling how a pair plot is laid out and drawn.
///
/// Only `upper`, `lower` and `diag` are read by the code visible in this file
/// (`compute_pairplot_layout`); the remaining fields are presumably consumed by
/// rendering code elsewhere — TODO confirm.
#[derive(Debug, Clone)]
pub struct PairPlotConfig {
/// Variables to include in the plot (presumably column/variable names — TODO confirm).
pub vars: Vec<String>,
/// Kind of plot selected for diagonal cells.
pub diag_kind: DiagKind,
/// Kind of plot selected for off-diagonal cells.
pub off_diag_kind: OffDiagKind,
/// Optional explicit colors; `None` means none were supplied.
pub colors: Option<Vec<Color>>,
/// Scatter marker size (units not established in this file — TODO confirm).
pub scatter_size: f32,
/// Scatter marker opacity (presumably in `0.0..=1.0` — TODO confirm).
pub scatter_alpha: f32,
/// Bin count (presumably for histogram diagonals — TODO confirm).
pub bins: usize,
/// When `true`, the layout includes cells above the diagonal (`col > row`).
pub upper: bool,
/// When `true`, the layout includes cells below the diagonal (`col < row`).
pub lower: bool,
/// When `true`, the layout includes the diagonal cells (`row == col`).
pub diag: bool,
}
/// Kind of plot selected for the diagonal cells of a pair plot
/// (stored in `PairPlotConfig::diag_kind`).
///
/// NOTE(review): the rendering code that interprets these variants is not in
/// view; the per-variant descriptions below are inferred from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagKind {
/// Presumably a histogram of the variable — TODO confirm.
Hist,
/// Presumably a kernel density estimate of the variable — TODO confirm.
Kde,
/// Presumably draws nothing on the diagonal — TODO confirm.
None,
}
/// Kind of plot selected for the off-diagonal cells of a pair plot
/// (stored in `PairPlotConfig::off_diag_kind`).
///
/// NOTE(review): the rendering code that interprets these variants is not in
/// view; the per-variant descriptions below are inferred from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OffDiagKind {
/// Presumably a scatter plot of the two variables — TODO confirm.
Scatter,
/// Presumably a scatter plot with a regression fit — TODO confirm.
Reg,
/// Presumably a two-dimensional kernel density estimate — TODO confirm.
Kde,
}
impl Default for PairPlotConfig {
fn default() -> Self {
Self {
vars: vec![],
diag_kind: DiagKind::Hist,
off_diag_kind: OffDiagKind::Scatter,
colors: None,
scatter_size: 3.0,
scatter_alpha: 0.5,
bins: 20,
upper: true,
lower: true,
diag: true,
}
}
}
impl PairPlotConfig {
    /// Creates a configuration holding the [`Default`] values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the `vars` list and returns the updated configuration.
    pub fn vars(self, vars: Vec<String>) -> Self {
        Self { vars, ..self }
    }

    /// Sets the plot kind for diagonal cells and returns the updated configuration.
    pub fn diag_kind(self, kind: DiagKind) -> Self {
        Self {
            diag_kind: kind,
            ..self
        }
    }

    /// Sets the plot kind for off-diagonal cells and returns the updated configuration.
    pub fn off_diag_kind(self, kind: OffDiagKind) -> Self {
        Self {
            off_diag_kind: kind,
            ..self
        }
    }

    /// Stores an explicit color list (as `Some(colors)`) and returns the updated
    /// configuration.
    pub fn colors(self, colors: Vec<Color>) -> Self {
        Self {
            colors: Some(colors),
            ..self
        }
    }

    /// Enables the lower triangle and disables the upper one; `diag` is left untouched.
    pub fn lower_only(self) -> Self {
        Self {
            lower: true,
            upper: false,
            ..self
        }
    }

    /// Enables the upper triangle and disables the lower one; `diag` is left untouched.
    pub fn upper_only(self) -> Self {
        Self {
            lower: false,
            upper: true,
            ..self
        }
    }
}
/// One cell of a pair-plot grid, as emitted by `compute_pairplot_layout`.
#[derive(Debug, Clone)]
pub struct PairPlotCell {
/// Grid row of the cell.
pub row: usize,
/// Grid column of the cell.
pub col: usize,
/// `(x, y)` variable indices: `cell_variable_names` uses `.0` for the x name
/// and `.1` for the y name. `compute_pairplot_layout` stores `(col, row)` here.
pub var_indices: (usize, usize),
/// `true` when the cell lies on the diagonal (`row == col` in the layout code).
pub is_diagonal: bool,
/// `(x, y, width, height)` of the cell. `compute_pairplot_layout` fills this
/// with offsets and a square size expressed as fractions of a `1.0`-wide grid.
pub bounds: (f64, f64, f64, f64),
}
/// Result of `compute_pairplot_layout`: the cells of a pair-plot grid.
#[derive(Debug, Clone)]
pub struct PairPlotLayout {
/// Number of variables along each side of the grid.
pub n_vars: usize,
/// Cells selected for rendering; regions disabled in the configuration are absent.
pub cells: Vec<PairPlotCell>,
/// Gap used between cells and around the grid when positioning `cells`,
/// in the same coordinates as `PairPlotCell::bounds`.
pub gap: f64,
}
/// Lays out the cells of an `n_vars` × `n_vars` pair-plot grid inside the unit square.
///
/// Every cell is a square of identical size. One gap separates neighbouring cells and
/// one gap pads each outer edge, so `n_vars` cells plus `n_vars + 1` gaps fill each axis.
/// Column 0 receives the smallest `x` offset and row 0 receives the largest `y` offset.
///
/// A cell is emitted only when its region is enabled in `config`: `config.diag` for
/// `row == col`, `config.upper` for `col > row`, and `config.lower` for `col < row`.
/// Cells are produced in row-major order and each one records `(col, row)` as its
/// `(x, y)` variable indices.
///
/// The gap is `0.02` as long as all gaps together use at most half of an axis (grids of
/// up to 24 variables). For larger grids the gap shrinks so the gaps never exceed that
/// half; a fixed `0.02` gap would otherwise produce a non-positive cell size from 49
/// variables on. The gap actually used is reported in the returned layout's `gap` field.
///
/// With `n_vars == 0` the result has no cells and reports the default gap.
pub fn compute_pairplot_layout(n_vars: usize, config: &PairPlotConfig) -> PairPlotLayout {
    // Gap between cells (and around the grid) whenever the grid is small enough.
    const DEFAULT_GAP: f64 = 0.02;
    // Largest share of an axis that all gaps together may occupy.
    const MAX_TOTAL_GAP: f64 = 0.5;

    if n_vars == 0 {
        return PairPlotLayout {
            n_vars: 0,
            cells: Vec::new(),
            gap: DEFAULT_GAP,
        };
    }

    let n = n_vars as f64;
    let gap_count = n + 1.0;
    // Shrink the gap for large grids so the cell size always stays positive.
    let gap = DEFAULT_GAP.min(MAX_TOTAL_GAP / gap_count);
    let cell_size = (1.0 - gap * gap_count) / n;
    // Distance between the origins of two neighbouring cells.
    let pitch = cell_size + gap;

    let mut cells = Vec::new();
    for row in 0..n_vars {
        for col in 0..n_vars {
            let is_diagonal = row == col;
            let enabled = if is_diagonal {
                config.diag
            } else if col > row {
                config.upper
            } else {
                config.lower
            };
            if !enabled {
                continue;
            }

            let x = gap + col as f64 * pitch;
            // Rows are flipped: row 0 gets the largest y offset.
            let y = gap + (n_vars - 1 - row) as f64 * pitch;
            cells.push(PairPlotCell {
                row,
                col,
                var_indices: (col, row),
                is_diagonal,
                bounds: (x, y, cell_size, cell_size),
            });
        }
    }

    PairPlotLayout { n_vars, cells, gap }
}
/// Resolves the x- and y-axis variable names for `cell`.
///
/// `cell.var_indices` is read as `(x_index, y_index)` into `var_names`. An index
/// past the end of `var_names` resolves to the empty string instead of panicking.
///
/// Returns `(x_name, y_name)`, both borrowed from `var_names`.
pub fn cell_variable_names<'a>(
    cell: &PairPlotCell,
    var_names: &'a [String],
) -> (&'a str, &'a str) {
    let (x_idx, y_idx) = cell.var_indices;
    // Shared lookup: a missing index falls back to "".
    let name_at = move |idx: usize| -> &'a str { var_names.get(idx).map_or("", String::as_str) };
    (name_at(x_idx), name_at(y_idx))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With every region enabled, three variables yield the full 3 × 3 grid.
    #[test]
    fn test_pairplot_layout() {
        let layout = compute_pairplot_layout(3, &PairPlotConfig::default());
        assert_eq!(layout.n_vars, 3);
        assert_eq!(layout.cells.len(), 9);
    }

    /// `lower_only` keeps the diagonal and the cells below it (3 + 3 cells).
    #[test]
    fn test_pairplot_lower_only() {
        let lower_cfg = PairPlotConfig::default().lower_only();
        let layout = compute_pairplot_layout(3, &lower_cfg);
        assert_eq!(layout.cells.len(), 6);
        assert!(layout.cells.iter().all(|cell| cell.col <= cell.row));
    }

    /// Cell origins of a 2 × 2 grid fall inside the unit interval on both axes.
    #[test]
    fn test_pairplot_cell_bounds() {
        let layout = compute_pairplot_layout(2, &PairPlotConfig::default());
        let unit = 0.0..=1.0;
        for cell in &layout.cells {
            let (x, y, _, _) = cell.bounds;
            assert!(unit.contains(&x));
            assert!(unit.contains(&y));
        }
    }

    /// `var_indices` is interpreted as `(x, y)` positions in the name list.
    #[test]
    fn test_cell_variable_names() {
        let var_names: Vec<String> = ["A", "B", "C"].iter().map(|s| s.to_string()).collect();
        let cell = PairPlotCell {
            row: 1,
            col: 0,
            var_indices: (0, 1),
            is_diagonal: false,
            bounds: (0.0, 0.0, 0.5, 0.5),
        };
        assert_eq!(cell_variable_names(&cell, &var_names), ("A", "B"));
    }
}