Skip to main content

chartml_core/scales/
band.rs

/// Maps discrete domain values (categories) to continuous range positions.
///
/// Equivalent to D3's `scaleBand()`, used for bar chart x-axes. Every domain
/// entry gets one band of width `bandwidth()`; the starts of consecutive bands
/// are `step()` apart.
#[derive(Debug, Clone, PartialEq)]
pub struct ScaleBand {
    /// Category labels, in the order their bands are laid out.
    domain: Vec<String>,
    /// Output interval as supplied by the caller (endpoints in either order).
    range: (f64, f64),
    /// Fraction of each step left empty between adjacent bands (never above 1.0).
    padding_inner: f64,
    /// Offset of the first band from the start of the range, in units of `step`.
    padding_outer: f64,
    /// Cached distance between the starts of two adjacent bands.
    step: f64,
    /// Cached width of a single band.
    bandwidth: f64,
}

impl ScaleBand {
    /// Inner and outer padding used by [`ScaleBand::new`].
    const DEFAULT_PADDING: f64 = 0.1;

    /// Upper bound for inner padding. Anything larger would make
    /// `1.0 - padding_inner` negative and therefore produce a negative bandwidth.
    const MAX_PADDING_INNER: f64 = 1.0;

    /// Create a new band scale with the given domain and range.
    ///
    /// Default padding: inner=0.1, outer=0.1.
    pub fn new(domain: Vec<String>, range: (f64, f64)) -> Self {
        let mut scale = Self {
            domain,
            range,
            padding_inner: Self::DEFAULT_PADDING,
            padding_outer: Self::DEFAULT_PADDING,
            step: 0.0,
            bandwidth: 0.0,
        };
        scale.recalculate();
        scale
    }

    /// Set inner padding (between bands). Returns self for chaining.
    ///
    /// Values above 1.0 are clamped to 1.0 so the bandwidth can never become
    /// negative.
    #[must_use]
    pub fn padding_inner(mut self, padding: f64) -> Self {
        self.padding_inner = Self::clamp_inner(padding);
        self.recalculate();
        self
    }

    /// Set outer padding (before first and after last band). Returns self for chaining.
    #[must_use]
    pub fn padding_outer(mut self, padding: f64) -> Self {
        self.padding_outer = padding;
        self.recalculate();
        self
    }

    /// Set both inner and outer padding to the same value.
    ///
    /// The inner padding is clamped to at most 1.0; the outer padding is stored
    /// exactly as given.
    #[must_use]
    pub fn padding(mut self, padding: f64) -> Self {
        self.padding_inner = Self::clamp_inner(padding);
        self.padding_outer = padding;
        self.recalculate();
        self
    }

    /// Map a domain value to its range position (the start of the band).
    ///
    /// Returns `None` if the value is not in the domain. When the domain
    /// contains the same label more than once, the first occurrence wins.
    ///
    /// Positions are measured from the smaller of the two range endpoints, so a
    /// reversed range such as `(300.0, 0.0)` yields the same positions as
    /// `(0.0, 300.0)`.
    // NOTE(review): D3 reverses the band order for a reversed range; confirm
    // whether callers rely on the current non-reversing behaviour.
    pub fn map(&self, value: &str) -> Option<f64> {
        let index = self.domain.iter().position(|d| d == value)?;
        let start = self.range.0.min(self.range.1);
        Some(start + self.padding_outer * self.step + index as f64 * self.step)
    }

    /// Get the width of each band.
    pub fn bandwidth(&self) -> f64 {
        self.bandwidth
    }

    /// Get the step size (band + inner padding).
    pub fn step(&self) -> f64 {
        self.step
    }

    /// Get the domain values.
    pub fn domain(&self) -> &[String] {
        &self.domain
    }

    /// Get the range extent.
    pub fn range(&self) -> (f64, f64) {
        self.range
    }

    /// Limit an inner-padding value to `MAX_PADDING_INNER`.
    ///
    /// Written as a comparison rather than `f64::min` so that a NaN input is
    /// passed through unchanged instead of being silently replaced.
    fn clamp_inner(padding: f64) -> f64 {
        if padding > Self::MAX_PADDING_INNER {
            Self::MAX_PADDING_INNER
        } else {
            padding
        }
    }

    /// Recalculate step and bandwidth from current domain/range/padding.
    fn recalculate(&mut self) {
        // An empty domain has no bands; avoid dividing the range among zero slots.
        if self.domain.is_empty() {
            self.step = 0.0;
            self.bandwidth = 0.0;
            return;
        }

        let n = self.domain.len() as f64;
        let range_size = (self.range.1 - self.range.0).abs();

        // Number of steps the range is divided into: n bands, minus the inner
        // gap that does not follow the last band, plus outer padding on both
        // sides. Never fewer than one, so the step cannot exceed the range.
        let slots = (n - self.padding_inner + 2.0 * self.padding_outer).max(1.0);

        self.step = range_size / slots;
        self.bandwidth = self.step * (1.0 - self.padding_inner);
    }
}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Pixel range shared by every test case.
    const RANGE: (f64, f64) = (0.0, 300.0);

    /// Builds the three-category domain `["A", "B", "C"]`.
    fn domain_abc() -> Vec<String> {
        ["A", "B", "C"].iter().map(|label| label.to_string()).collect()
    }

    /// Band scale over `domain_abc()` and `RANGE` with default padding.
    fn default_abc_scale() -> ScaleBand {
        ScaleBand::new(domain_abc(), RANGE)
    }

    /// Bands are ordered, offset by outer padding, and stay inside the range.
    #[test]
    fn band_scale_basic() {
        let scale = default_abc_scale();
        let (a, c) = (scale.map("A").unwrap(), scale.map("C").unwrap());

        // First category comes before the last one.
        assert!(a < c, "A position {} should be less than C position {}", a, c);
        // Outer padding pushes the first band away from the range start.
        assert!(a > 0.0, "A should have some outer padding offset");
        // The last band must end within the range (allowing rounding noise).
        let last_band_end = c + scale.bandwidth();
        assert!(last_band_end <= RANGE.1 + EPS);
    }

    /// Bandwidth is positive, smaller than the range, and no larger than step.
    #[test]
    fn band_scale_bandwidth() {
        let scale = default_abc_scale();
        let width = scale.bandwidth();

        assert!(width > 0.0);
        assert!(width < RANGE.1);
        assert!(width <= scale.step());
    }

    /// Values outside the domain map to `None`.
    #[test]
    fn band_scale_unknown_value() {
        assert!(default_abc_scale().map("D").is_none());
    }

    /// An empty domain yields zero step and zero bandwidth.
    #[test]
    fn band_scale_empty_domain() {
        let scale = ScaleBand::new(Vec::new(), RANGE);

        assert_eq!(scale.bandwidth(), 0.0);
        assert_eq!(scale.step(), 0.0);
    }

    /// A single category still gets a positive band inside the range.
    #[test]
    fn band_scale_single_item() {
        let scale = ScaleBand::new(vec![String::from("A")], RANGE);
        let a = scale.map("A").unwrap();

        assert!(a >= 0.0);
        assert!(scale.bandwidth() > 0.0);
        assert!(a + scale.bandwidth() <= RANGE.1 + EPS);
    }

    /// Increasing the padding shrinks the bands.
    #[test]
    fn band_scale_custom_padding() {
        let scale_default = default_abc_scale();
        let scale_padded = default_abc_scale().padding(0.2);

        assert!(
            scale_padded.bandwidth() < scale_default.bandwidth(),
            "padded bandwidth {} should be less than default {}",
            scale_padded.bandwidth(),
            scale_default.bandwidth()
        );
    }

    /// Without padding, bandwidth == step == range / n and the first band sits at 0.
    #[test]
    fn band_scale_no_padding() {
        let scale = default_abc_scale().padding_inner(0.0).padding_outer(0.0);
        let expected = 300.0 / 3.0;

        let bandwidth_error = (scale.bandwidth() - expected).abs();
        assert!(
            bandwidth_error < EPS,
            "bandwidth should be {} but got {}",
            expected,
            scale.bandwidth()
        );

        let step_error = (scale.step() - expected).abs();
        assert!(
            step_error < EPS,
            "step should be {} but got {}",
            expected,
            scale.step()
        );

        let a = scale.map("A").unwrap();
        assert!(a.abs() < EPS, "A should be at 0, got {}", a);
    }

    /// Step is never smaller than bandwidth.
    #[test]
    fn band_scale_step() {
        let scale = default_abc_scale();

        assert!(
            scale.step() >= scale.bandwidth(),
            "step {} should be >= bandwidth {}",
            scale.step(),
            scale.bandwidth()
        );
    }
}
194}