Skip to main content

chartml_core/scales/
band.rs

/// Maps discrete domain values (categories) to continuous range positions.
///
/// Modeled on D3's `scaleBand()` (with its default `align(0.5)`), used for bar chart x-axes.
/// The extent is divided into *steps*; each domain value owns one band of width
/// [`ScaleBand::bandwidth`], adjacent bands are separated by the inner padding, and outer
/// padding (measured in steps) is reserved before the first and after the last band.
#[derive(Debug, Clone, PartialEq)]
pub struct ScaleBand {
    /// Ordered categories; a value's index selects its band.
    domain: Vec<String>,
    /// Output extent exactly as supplied by the caller (may be descending).
    range: (f64, f64),
    /// Fraction of each step left empty between adjacent bands; setters keep it <= 1.0.
    padding_inner: f64,
    /// Space before the first and after the last band, as a multiple of `step`.
    padding_outer: f64,
    /// Distance between the starts of adjacent bands; cached by `recalculate`.
    step: f64,
    /// Width of every band; cached by `recalculate`.
    bandwidth: f64,
}

impl ScaleBand {
    /// Create a new band scale with the given domain and range.
    ///
    /// Default padding: inner = 0.1, outer = 0.1.
    pub fn new(domain: Vec<String>, range: (f64, f64)) -> Self {
        let mut scale = Self {
            domain,
            range,
            padding_inner: 0.1,
            padding_outer: 0.1,
            step: 0.0,
            bandwidth: 0.0,
        };
        scale.recalculate();
        scale
    }

    /// Set inner padding (between bands), as a fraction of the step. Returns `self` for
    /// chaining.
    ///
    /// Values above `1.0` are clamped to `1.0`, as D3 does; a larger inner padding would
    /// make the computed bandwidth negative.
    #[must_use]
    pub fn padding_inner(mut self, padding: f64) -> Self {
        self.padding_inner = padding.min(1.0);
        self.recalculate();
        self
    }

    /// Set outer padding (before the first and after the last band), in multiples of the
    /// step. Returns `self` for chaining.
    #[must_use]
    pub fn padding_outer(mut self, padding: f64) -> Self {
        self.padding_outer = padding;
        self.recalculate();
        self
    }

    /// Set both inner and outer padding to the same value. Returns `self` for chaining.
    ///
    /// The inner padding is clamped to at most `1.0`, exactly as in
    /// [`ScaleBand::padding_inner`]; the outer padding is stored unclamped.
    #[must_use]
    pub fn padding(mut self, padding: f64) -> Self {
        self.padding_inner = padding.min(1.0);
        self.padding_outer = padding;
        self.recalculate();
        self
    }

    /// Map a domain value to its range position (the start of its band).
    ///
    /// Returns `None` if the value is not in the domain. If the domain contains duplicates,
    /// the position of the first occurrence is returned.
    ///
    /// The range is treated as an unordered extent: bands always run upward from the
    /// smaller endpoint, so `(300.0, 0.0)` lays out exactly like `(0.0, 300.0)`.
    /// NOTE(review): D3 instead reverses band order for a descending range.
    pub fn map(&self, value: &str) -> Option<f64> {
        let index = self.domain.iter().position(|d| d == value)?;
        Some(self.first_band_start() + index as f64 * self.step)
    }

    /// Get the width of each band.
    pub fn bandwidth(&self) -> f64 {
        self.bandwidth
    }

    /// Get the step size (band + inner padding).
    pub fn step(&self) -> f64 {
        self.step
    }

    /// Get the domain values.
    pub fn domain(&self) -> &[String] {
        &self.domain
    }

    /// Get the range extent, exactly as supplied.
    pub fn range(&self) -> (f64, f64) {
        self.range
    }

    /// Position where the first band starts.
    ///
    /// Whatever part of the extent is not occupied by the bands and their inner gaps is
    /// split evenly before the first and after the last band. Normally that leftover is
    /// exactly `padding_outer * step` on each side. It differs only when `recalculate`
    /// clamped the step divisor to 1 (e.g. a single band whose outer padding is less than
    /// half its inner padding); the bands are then centered instead of hugging the start.
    fn first_band_start(&self) -> f64 {
        let start = self.range.0.min(self.range.1);
        let extent = (self.range.1 - self.range.0).abs();
        let occupied = self.step * (self.domain.len() as f64 - self.padding_inner);
        start + (extent - occupied) * 0.5
    }

    /// Recompute the cached `step` and `bandwidth` from the current domain, range and padding.
    ///
    /// The divisor (bands plus outer padding, in step units) is clamped to at least 1 so the
    /// step never exceeds the whole extent. An empty domain yields zero for both values.
    fn recalculate(&mut self) {
        if self.domain.is_empty() {
            self.step = 0.0;
            self.bandwidth = 0.0;
            return;
        }

        let n = self.domain.len() as f64;
        let extent = (self.range.1 - self.range.0).abs();
        let steps = (n - self.padding_inner + 2.0 * self.padding_outer).max(1.0);
        self.step = extent / steps;
        self.bandwidth = self.step * (1.0 - self.padding_inner);
    }
}
93
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Absolute tolerance for comparing computed positions with hand-derived values.
    const EPS: f64 = 1e-10;

    fn domain_abc() -> Vec<String> {
        vec!["A".to_string(), "B".to_string(), "C".to_string()]
    }

    /// Assert that `actual` is within `EPS` of `expected`, naming the quantity on failure.
    fn assert_close(what: &str, actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "{} should be {} but got {}",
            what,
            expected,
            actual
        );
    }

    #[test]
    fn band_scale_basic() {
        let scale = ScaleBand::new(domain_abc(), (0.0, 300.0));
        let a = scale.map("A").unwrap();
        let c = scale.map("C").unwrap();
        // A should be the first position, C should be the last
        assert!(a < c, "A position {} should be less than C position {}", a, c);
        // A should start after outer padding
        assert!(a > 0.0, "A should have some outer padding offset");
        // C + bandwidth should be close to but not exceed 300
        assert!(c + scale.bandwidth() <= 300.0 + EPS);
    }

    #[test]
    fn band_scale_default_padding_exact_values() {
        // 3 bands - 0.1 inner + 2 * 0.1 outer = 3.1 steps across 300px.
        let scale = ScaleBand::new(domain_abc(), (0.0, 300.0));
        let step = 300.0 / 3.1;
        assert_close("step", scale.step(), step);
        assert_close("bandwidth", scale.bandwidth(), step * 0.9);
        assert_close("A", scale.map("A").unwrap(), 0.1 * step);
        assert_close("B", scale.map("B").unwrap(), 1.1 * step);
        assert_close("C", scale.map("C").unwrap(), 2.1 * step);
    }

    #[test]
    fn band_scale_bandwidth() {
        let scale = ScaleBand::new(domain_abc(), (0.0, 300.0));
        // With 3 items in 300px, bandwidth should be reasonable (not 0, not 300)
        assert!(scale.bandwidth() > 0.0);
        assert!(scale.bandwidth() < 300.0);
        // Bandwidth should be less than step
        assert!(scale.bandwidth() <= scale.step());
    }

    #[test]
    fn band_scale_unknown_value() {
        let scale = ScaleBand::new(domain_abc(), (0.0, 300.0));
        assert!(scale.map("D").is_none());
    }

    #[test]
    fn band_scale_empty_domain() {
        let scale = ScaleBand::new(vec![], (0.0, 300.0));
        assert_eq!(scale.bandwidth(), 0.0);
        assert_eq!(scale.step(), 0.0);
        assert!(scale.map("A").is_none());
    }

    #[test]
    fn band_scale_single_item() {
        let scale = ScaleBand::new(vec!["A".to_string()], (0.0, 300.0));
        let a = scale.map("A").unwrap();
        assert!(a >= 0.0);
        assert!(scale.bandwidth() > 0.0);
        assert!(a + scale.bandwidth() <= 300.0 + EPS);
    }

    #[test]
    fn band_scale_custom_padding() {
        let scale_default = ScaleBand::new(domain_abc(), (0.0, 300.0));
        let scale_padded = ScaleBand::new(domain_abc(), (0.0, 300.0)).padding(0.2);
        // More padding means smaller bandwidth
        assert!(
            scale_padded.bandwidth() < scale_default.bandwidth(),
            "padded bandwidth {} should be less than default {}",
            scale_padded.bandwidth(),
            scale_default.bandwidth()
        );
    }

    #[test]
    fn band_scale_padding_sets_inner_and_outer() {
        let both = ScaleBand::new(domain_abc(), (0.0, 300.0)).padding(0.2);
        let separate = ScaleBand::new(domain_abc(), (0.0, 300.0))
            .padding_inner(0.2)
            .padding_outer(0.2);
        assert_close("step", both.step(), separate.step());
        assert_close("bandwidth", both.bandwidth(), separate.bandwidth());
        for v in ["A", "B", "C"] {
            assert_close(v, both.map(v).unwrap(), separate.map(v).unwrap());
        }
    }

    #[test]
    fn band_scale_no_padding() {
        let scale = ScaleBand::new(domain_abc(), (0.0, 300.0))
            .padding_inner(0.0)
            .padding_outer(0.0);
        // With no padding, bandwidth == step == range / n
        let expected = 300.0 / 3.0;
        assert_close("bandwidth", scale.bandwidth(), expected);
        assert_close("step", scale.step(), expected);
        // First item should start at 0
        assert_close("A", scale.map("A").unwrap(), 0.0);
    }

    #[test]
    fn band_scale_step() {
        let scale = ScaleBand::new(domain_abc(), (0.0, 300.0));
        assert!(
            scale.step() >= scale.bandwidth(),
            "step {} should be >= bandwidth {}",
            scale.step(),
            scale.bandwidth()
        );
    }

    #[test]
    fn band_scale_offset_range_shifts_positions() {
        let base = ScaleBand::new(domain_abc(), (0.0, 300.0));
        let shifted = ScaleBand::new(domain_abc(), (10.0, 310.0));
        assert_close("bandwidth", shifted.bandwidth(), base.bandwidth());
        for v in ["A", "B", "C"] {
            assert_close(v, shifted.map(v).unwrap(), base.map(v).unwrap() + 10.0);
        }
    }

    #[test]
    fn band_scale_getters_return_inputs() {
        let scale = ScaleBand::new(domain_abc(), (10.0, 310.0));
        assert_eq!(scale.domain().to_vec(), domain_abc());
        assert_eq!(scale.range(), (10.0, 310.0));
    }
}