use crate::core::Result;
use crate::plots::traits::{PlotArea, PlotCompute, PlotConfig, PlotData, PlotRender};
use crate::render::skia::SkiaRenderer;
use crate::render::{Color, MarkerStyle, Theme};
use crate::stats::beeswarm::beeswarm_positions;
/// Configuration for a swarm plot.
///
/// Construct with [`SwarmConfig::new`] / `Default` and adjust with the chained
/// setter methods defined on this type.
#[derive(Debug, Clone)]
pub struct SwarmConfig {
/// Marker size. Passed to the renderer when drawing each point and, converted
/// to `f64`, to the beeswarm layout as the point size.
pub size: f32,
/// Explicit marker color. When `None`, the color handed to `render` is used.
pub color: Option<Color>,
/// Alpha applied to the marker color at render time.
pub alpha: f32,
/// Which screen axis carries the data values.
pub orientation: SwarmOrientation,
/// Width handed to the beeswarm layout; divided by the number of groups when
/// dodging is active.
pub width: f64,
/// When `true` and more than one group is present, each group is laid out
/// separately and shifted sideways within its category.
pub dodge: bool,
/// Gap between dodged groups.
/// NOTE(review): not read by any code visible in this file — confirm whether
/// it is consumed elsewhere.
pub dodge_gap: f64,
}
/// Direction of the value axis of a swarm plot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SwarmOrientation {
/// Values are placed on `y`; categories (plus swarm offsets) on `x`.
Vertical,
/// Values are placed on `x`; categories (plus swarm offsets) on `y`.
Horizontal,
}
impl Default for SwarmConfig {
fn default() -> Self {
Self {
size: 5.0,
color: None,
alpha: 0.8,
orientation: SwarmOrientation::Vertical,
width: 0.8,
dodge: false,
dodge_gap: 0.05,
}
}
}
impl SwarmConfig {
    /// Creates a configuration populated with the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the marker size; values below `0.1` are raised to `0.1`.
    pub fn size(self, size: f32) -> Self {
        Self {
            size: size.max(0.1),
            ..self
        }
    }

    /// Sets an explicit marker color, overriding the color supplied at render time.
    pub fn color(self, color: Color) -> Self {
        Self {
            color: Some(color),
            ..self
        }
    }

    /// Sets the marker alpha, clamped to `0.0..=1.0`.
    pub fn alpha(self, alpha: f32) -> Self {
        Self {
            alpha: alpha.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Switches the plot to horizontal orientation.
    pub fn horizontal(self) -> Self {
        Self {
            orientation: SwarmOrientation::Horizontal,
            ..self
        }
    }

    /// Sets the swarm width, clamped to `0.1..=1.0`.
    pub fn width(self, width: f64) -> Self {
        Self {
            width: width.clamp(0.1, 1.0),
            ..self
        }
    }

    /// Enables or disables side-by-side placement of groups within a category.
    pub fn dodge(self, dodge: bool) -> Self {
        Self { dodge, ..self }
    }
}
// Marks `SwarmConfig` as a plot configuration type; the impl adds no items of its own.
impl PlotConfig for SwarmConfig {}
/// Unit type that carries the `PlotCompute` implementation for swarm plots.
pub struct Swarm;
/// One laid-out observation of a swarm plot, in data coordinates.
#[derive(Debug, Clone, Copy)]
pub struct SwarmPoint {
/// Category index the observation belongs to.
pub category: usize,
/// The observation's data value.
pub value: f64,
/// Data-space x coordinate: category position plus offsets when vertical,
/// the value when horizontal.
pub x: f64,
/// Data-space y coordinate: the value when vertical, category position plus
/// offsets when horizontal.
pub y: f64,
/// Group index of the observation; `None` when no groups were supplied.
pub group: Option<usize>,
}
/// Lays out the observations of a swarm plot in data coordinates.
///
/// Only the first `min(categories.len(), values.len())` observations are used;
/// an empty vector is returned when that count is zero.
///
/// * `categories` - category index of each observation.
/// * `values` - data value of each observation.
/// * `groups` - optional group index of each observation; an observation whose
///   index lies past the end of this slice is treated as group `0`.
/// * `config` - supplies marker size, swarm width, orientation and dodging.
///
/// Points are emitted in ascending category order. When dodging is active
/// (`config.dodge` and more than one group), they are further ordered by
/// ascending group; within each bucket the input order is kept.
pub fn compute_swarm_points(
    categories: &[usize],
    values: &[f64],
    groups: Option<&[usize]>,
    config: &SwarmConfig,
) -> Vec<SwarmPoint> {
    // Total category-axis width shared by all dodged groups.
    // NOTE(review): this is a fixed value rather than `config.width`, and
    // `config.dodge_gap` is not applied here — confirm that is intended.
    const DODGE_TOTAL_WIDTH: f64 = 0.8;

    let n = categories.len().min(values.len());
    if n == 0 {
        return Vec::new();
    }

    let num_groups = groups.map_or(1, |g| g.iter().max().map_or(1, |&m| m.saturating_add(1)));
    let group_of = |i: usize| groups.map(|g| g.get(i).copied().unwrap_or(0));
    let dodging = config.dodge && num_groups > 1;
    let point_size = f64::from(config.size);
    let orientation = config.orientation;

    // Maps a category-axis position and a value onto (x, y) for the orientation.
    let place = |base: f64, offset: f64, value: f64| match orientation {
        SwarmOrientation::Vertical => (base + offset, value),
        SwarmOrientation::Horizontal => (value, base + offset),
    };

    // Bucket observation indices by category in one pass. The ordered map keeps
    // categories ascending without scanning every id between 0 and the maximum,
    // so sparse or very large category ids stay cheap.
    let mut by_category: std::collections::BTreeMap<usize, Vec<usize>> =
        std::collections::BTreeMap::new();
    for (i, &cat) in categories[..n].iter().enumerate() {
        by_category.entry(cat).or_default().push(i);
    }

    let mut all_points = Vec::with_capacity(n);
    for (cat, cat_indices) in by_category {
        let center = cat as f64;

        if dodging {
            let jitter_width = config.width / num_groups as f64;
            let dodge_width = DODGE_TOTAL_WIDTH / num_groups as f64;

            // Split this category's values by group, ascending by group id.
            let mut by_group: std::collections::BTreeMap<usize, Vec<f64>> =
                std::collections::BTreeMap::new();
            for &i in &cat_indices {
                by_group
                    .entry(group_of(i).unwrap_or(0))
                    .or_default()
                    .push(values[i]);
            }

            for (grp, grp_values) in by_group {
                let positions = beeswarm_positions(&grp_values, point_size, jitter_width);
                // Center the row of groups on the category position.
                let dodge_offset = (grp as f64 - (num_groups - 1) as f64 / 2.0) * dodge_width;
                let base = center + dodge_offset;
                for (&value, &pos) in grp_values.iter().zip(positions.iter()) {
                    let (x, y) = place(base, pos, value);
                    all_points.push(SwarmPoint {
                        category: cat,
                        value,
                        x,
                        y,
                        group: Some(grp),
                    });
                }
            }
        } else {
            let cat_values: Vec<f64> = cat_indices.iter().map(|&i| values[i]).collect();
            let positions = beeswarm_positions(&cat_values, point_size, config.width);
            let laid_out = cat_indices
                .iter()
                .zip(cat_values.iter())
                .zip(positions.iter());
            for ((&i, &value), &pos) in laid_out {
                let (x, y) = place(center, pos, value);
                all_points.push(SwarmPoint {
                    category: cat,
                    value,
                    x,
                    y,
                    group: group_of(i),
                });
            }
        }
    }
    all_points
}
/// Computes the data ranges `((x_min, x_max), (y_min, y_max))` for a set of
/// swarm points.
///
/// The value axis spans the minimum and maximum `value` of the points. The
/// category axis spans `-spread - 0.5 ..= num_categories + spread - 0.5`, where
/// `spread` is the largest absolute category-axis coordinate among the points.
/// The category axis is `x` for vertical plots and `y` for horizontal plots, so
/// the coordinate is chosen by `orientation`; reading `x` unconditionally would
/// feed data values into the category range of horizontal plots.
///
/// Returns `((0.0, 1.0), (0.0, 1.0))` when `points` is empty.
pub fn swarm_range(
    points: &[SwarmPoint],
    num_categories: usize,
    orientation: SwarmOrientation,
) -> ((f64, f64), (f64, f64)) {
    if points.is_empty() {
        return ((0.0, 1.0), (0.0, 1.0));
    }

    let (val_min, val_max, spread) = points.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY, 0.0_f64),
        |(lo, hi, spread), p| {
            let category_coord = match orientation {
                SwarmOrientation::Vertical => p.x,
                SwarmOrientation::Horizontal => p.y,
            };
            (
                lo.min(p.value),
                hi.max(p.value),
                spread.max(category_coord.abs()),
            )
        },
    );

    // NOTE(review): `spread` is an absolute coordinate, so it includes the
    // category index itself and the padding grows with the number of
    // categories — confirm this is the intended amount of padding.
    let cat_range = (-spread - 0.5, num_categories as f64 + spread - 0.5);
    let val_range = (val_min, val_max);

    match orientation {
        SwarmOrientation::Vertical => (cat_range, val_range),
        SwarmOrientation::Horizontal => (val_range, cat_range),
    }
}
/// Computed swarm plot, ready for bounds queries and rendering.
#[derive(Debug, Clone)]
pub struct SwarmData {
/// Laid-out points in data coordinates.
pub points: Vec<SwarmPoint>,
/// Largest category index of the input plus one.
pub num_categories: usize,
/// Copy of the configuration the points were computed with; read again when
/// computing bounds and when rendering.
pub(crate) config: SwarmConfig,
}
/// Borrowed input for computing a swarm plot.
pub struct SwarmInput<'a> {
/// Category index of each observation.
pub categories: &'a [usize],
/// Data value of each observation.
pub values: &'a [f64],
/// Optional group index of each observation.
pub groups: Option<&'a [usize]>,
}
impl<'a> SwarmInput<'a> {
pub fn new(categories: &'a [usize], values: &'a [f64]) -> Self {
Self {
categories,
values,
groups: None,
}
}
pub fn with_groups(mut self, groups: &'a [usize]) -> Self {
self.groups = Some(groups);
self
}
}
impl PlotCompute for Swarm {
type Input<'a> = SwarmInput<'a>;
type Config = SwarmConfig;
type Output = SwarmData;
fn compute(input: Self::Input<'_>, config: &Self::Config) -> Result<Self::Output> {
let points = compute_swarm_points(input.categories, input.values, input.groups, config);
if points.is_empty() {
return Err(crate::core::PlottingError::EmptyDataSet);
}
let num_categories = input.categories.iter().max().map_or(0, |&m| m + 1);
Ok(SwarmData {
points,
num_categories,
config: config.clone(),
})
}
}
impl PlotData for SwarmData {
    /// Returns `((x_min, x_max), (y_min, y_max))` as computed by `swarm_range`.
    fn data_bounds(&self) -> ((f64, f64), (f64, f64)) {
        let Self {
            points,
            num_categories,
            config,
        } = self;
        swarm_range(points, *num_categories, config.orientation)
    }

    /// Reports whether the plot holds no points.
    fn is_empty(&self) -> bool {
        let laid_out = &self.points;
        laid_out.is_empty()
    }
}
impl PlotRender for SwarmData {
fn render(
&self,
renderer: &mut SkiaRenderer,
area: &PlotArea,
_theme: &Theme,
color: Color,
) -> Result<()> {
if self.points.is_empty() {
return Ok(());
}
let config = &self.config;
let point_color = config.color.unwrap_or(color).with_alpha(config.alpha);
for point in &self.points {
let (px, py) = area.data_to_screen(point.x, point.y);
renderer.draw_marker(px, py, config.size, MarkerStyle::Circle, point_color)?;
}
Ok(())
}
}
// Unit tests for the swarm layout function and the trait implementations above.
#[cfg(test)]
mod tests {
use super::*;
// Identical values within one category must not all land on the same x position.
#[test]
fn test_swarm_basic() {
let categories = vec![0, 0, 0, 1, 1, 1];
let values = vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
let config = SwarmConfig::default();
let points = compute_swarm_points(&categories, &values, None, &config);
assert_eq!(points.len(), 6);
let cat0_points: Vec<_> = points.iter().filter(|p| p.category == 0).collect();
let x_positions: Vec<f64> = cat0_points.iter().map(|p| p.x).collect();
let all_same = x_positions.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
assert!(!all_same || cat0_points.len() == 1);
}
// In horizontal orientation the x coordinate carries the data value.
#[test]
fn test_swarm_horizontal() {
let categories = vec![0, 1];
let values = vec![1.0, 2.0];
let config = SwarmConfig::default().horizontal();
let points = compute_swarm_points(&categories, &values, None, &config);
for point in &points {
assert!((point.x - point.value).abs() < 1e-10);
}
}
// With dodging enabled, every point is emitted and tagged with a group.
#[test]
fn test_swarm_with_groups() {
let categories = vec![0, 0, 0, 0];
let values = vec![1.0, 1.0, 2.0, 2.0];
let groups = vec![0, 1, 0, 1];
let config = SwarmConfig::default().dodge(true);
let points = compute_swarm_points(&categories, &values, Some(&groups), &config);
assert_eq!(points.len(), 4);
for point in &points {
assert!(point.group.is_some());
}
}
// Empty input yields no points.
#[test]
fn test_swarm_empty() {
let categories: Vec<usize> = vec![];
let values: Vec<f64> = vec![];
let config = SwarmConfig::default();
let points = compute_swarm_points(&categories, &values, None, &config);
assert!(points.is_empty());
}
// Compile-time check that SwarmConfig satisfies the PlotConfig bound.
#[test]
fn test_swarm_config_implements_plot_config() {
fn assert_plot_config<T: PlotConfig>() {}
assert_plot_config::<SwarmConfig>();
}
// Swarm::compute returns all points and the category count (max index + 1).
#[test]
fn test_swarm_plot_compute_trait() {
use crate::plots::traits::PlotCompute;
let categories = vec![0, 0, 1, 1, 2, 2];
let values = vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5];
let config = SwarmConfig::default();
let input = SwarmInput::new(&categories, &values);
let result = Swarm::compute(input, &config);
assert!(result.is_ok());
let swarm_data = result.unwrap();
assert_eq!(swarm_data.points.len(), 6);
assert_eq!(swarm_data.num_categories, 3);
}
// Swarm::compute accepts grouped input with dodging enabled.
#[test]
fn test_swarm_plot_compute_with_groups() {
use crate::plots::traits::PlotCompute;
let categories = vec![0, 0, 1, 1];
let values = vec![1.0, 2.0, 1.0, 2.0];
let groups = vec![0, 1, 0, 1];
let config = SwarmConfig::default().dodge(true);
let input = SwarmInput::new(&categories, &values).with_groups(&groups);
let result = Swarm::compute(input, &config);
assert!(result.is_ok());
let swarm_data = result.unwrap();
assert_eq!(swarm_data.points.len(), 4);
}
// Swarm::compute reports an error for empty input.
#[test]
fn test_swarm_plot_compute_empty() {
use crate::plots::traits::PlotCompute;
let categories: Vec<usize> = vec![];
let values: Vec<f64> = vec![];
let config = SwarmConfig::default();
let input = SwarmInput::new(&categories, &values);
let result = Swarm::compute(input, &config);
assert!(result.is_err());
}
// data_bounds returns ordered ranges and is_empty is false for non-empty data.
#[test]
fn test_swarm_plot_data_trait() {
use crate::plots::traits::{PlotCompute, PlotData};
let categories = vec![0, 1, 2];
let values = vec![1.0, 5.0, 3.0];
let config = SwarmConfig::default();
let input = SwarmInput::new(&categories, &values);
let swarm_data = Swarm::compute(input, &config).unwrap();
let ((x_min, x_max), (y_min, y_max)) = swarm_data.data_bounds();
assert!(x_min <= x_max);
assert!(y_min <= y_max);
assert!(!swarm_data.is_empty());
}
}