Skip to main content

chartml_core/scales/
sqrt.rs

1use super::{ContinuousScale, tick_step, round_to_precision};
2
/// Square root scale for bubble sizes.
///
/// Maps a continuous domain to a continuous range via a sqrt transformation,
/// i.e. a power scale with exponent 0.5. Modelled on D3's `scaleSqrt()`.
///
/// Both `domain` and `range` are stored as `(start, end)` pairs exactly as
/// supplied to the constructor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScaleSqrt {
    /// Input extent `(d0, d1)` in data space.
    domain: (f64, f64),
    /// Output extent `(r0, r1)` that `d0` and `d1` map onto.
    range: (f64, f64),
}
10
11impl ScaleSqrt {
12    /// Create a new sqrt scale with the given domain and range.
13    pub fn new(domain: (f64, f64), range: (f64, f64)) -> Self {
14        Self { domain, range }
15    }
16
17    /// Map a domain value to a range value using sqrt interpolation.
18    /// Negative domain values are clamped to 0 before sqrt.
19    pub fn map(&self, value: f64) -> f64 {
20        let (d0, d1) = self.domain;
21        let (r0, r1) = self.range;
22        let sqrt_d0 = d0.max(0.0).sqrt();
23        let sqrt_d1 = d1.max(0.0).sqrt();
24        let sqrt_val = value.max(0.0).sqrt();
25        let sqrt_span = sqrt_d1 - sqrt_d0;
26        if sqrt_span == 0.0 {
27            return (r0 + r1) / 2.0;
28        }
29        r0 + (sqrt_val - sqrt_d0) / sqrt_span * (r1 - r0)
30    }
31
32    /// Inverse mapping: range value back to domain value.
33    pub fn invert(&self, value: f64) -> f64 {
34        let (d0, d1) = self.domain;
35        let (r0, r1) = self.range;
36        let sqrt_d0 = d0.max(0.0).sqrt();
37        let sqrt_d1 = d1.max(0.0).sqrt();
38        let range_span = r1 - r0;
39        if range_span == 0.0 {
40            return (d0 + d1) / 2.0;
41        }
42        let sqrt_val = sqrt_d0 + (value - r0) / range_span * (sqrt_d1 - sqrt_d0);
43        sqrt_val * sqrt_val
44    }
45
46    /// Generate ticks by computing them in sqrt-transformed space,
47    /// then squaring back to the original domain.
48    pub fn ticks(&self, count: usize) -> Vec<f64> {
49        if count == 0 {
50            return vec![];
51        }
52        let (d0, d1) = self.domain;
53        let sqrt_min = d0.max(0.0).sqrt().min(d1.max(0.0).sqrt());
54        let sqrt_max = d0.max(0.0).sqrt().max(d1.max(0.0).sqrt());
55        if sqrt_min == sqrt_max {
56            return vec![sqrt_min * sqrt_min];
57        }
58
59        let step = tick_step(sqrt_min, sqrt_max, count);
60        if step == 0.0 || !step.is_finite() {
61            return vec![];
62        }
63
64        let mut ticks = Vec::new();
65        let start = (sqrt_min / step).ceil();
66        let stop = (sqrt_max / step).floor();
67
68        let mut i = start;
69        while i <= stop {
70            let sqrt_tick = i * step;
71            let tick = round_to_precision(sqrt_tick * sqrt_tick, step * step);
72            ticks.push(tick);
73            i += 1.0;
74        }
75
76        ticks
77    }
78
79    /// Get the domain extent.
80    pub fn domain(&self) -> (f64, f64) {
81        self.domain
82    }
83
84    /// Get the range extent.
85    pub fn range(&self) -> (f64, f64) {
86        self.range
87    }
88}
89
90impl ContinuousScale for ScaleSqrt {
91    fn map(&self, value: f64) -> f64 {
92        ScaleSqrt::map(self, value)
93    }
94
95    fn domain(&self) -> (f64, f64) {
96        ScaleSqrt::domain(self)
97    }
98
99    fn range(&self) -> (f64, f64) {
100        ScaleSqrt::range(self)
101    }
102
103    fn ticks(&self, count: usize) -> Vec<f64> {
104        ScaleSqrt::ticks(self, count)
105    }
106
107    fn clamp(&self, value: f64) -> f64 {
108        let (d0, d1) = self.domain;
109        let min = d0.min(d1);
110        let max = d0.max(d1);
111        value.clamp(min, max)
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// The scale shared by every test: domain and range are both (0, 100).
    fn fixture() -> ScaleSqrt {
        ScaleSqrt::new((0.0, 100.0), (0.0, 100.0))
    }

    /// The lower domain endpoint maps to the lower range endpoint.
    #[test]
    fn sqrt_scale_maps_zero() {
        let got = fixture().map(0.0);
        assert!((got - 0.0).abs() < EPS);
    }

    /// The upper domain endpoint maps to the upper range endpoint.
    #[test]
    fn sqrt_scale_maps_max() {
        let got = fixture().map(100.0);
        assert!((got - 100.0).abs() < EPS);
    }

    /// sqrt(25) / sqrt(100) = 5 / 10 = 0.5, so 25 lands halfway along the range.
    #[test]
    fn sqrt_scale_nonlinear() {
        let got = fixture().map(25.0);
        assert!((got - 50.0).abs() < EPS, "map(25) should be 50, got {}", got);
    }

    /// `invert` undoes `map` for a value inside the domain.
    #[test]
    fn sqrt_scale_inverts() {
        let scale = fixture();
        let x = 42.0;
        let round_trip = scale.invert(scale.map(x));
        assert!(
            (round_trip - x).abs() < EPS,
            "invert(map({})) should be {}, got {}",
            x,
            x,
            round_trip
        );
    }

    /// Ticks exist, are bounded in number, and stay inside the domain.
    #[test]
    fn sqrt_scale_ticks() {
        let ticks = fixture().ticks(5);
        assert!(!ticks.is_empty(), "should generate at least one tick");
        let n = ticks.len();
        assert!(n <= 15, "should not generate too many ticks, got {}", n);
        for tick in ticks.iter().copied() {
            assert!(tick >= 0.0, "tick {} should be >= 0", tick);
            assert!(tick <= 100.0 + EPS, "tick {} should be <= 100", tick);
        }
    }

    /// Negative inputs are clamped to 0 before the sqrt, so map(-10) == map(0).
    #[test]
    fn sqrt_scale_negative_clamped() {
        let scale = fixture();
        let at_negative = scale.map(-10.0);
        let at_zero = scale.map(0.0);
        assert!(
            (at_negative - at_zero).abs() < EPS,
            "negative values should be clamped to 0"
        );
    }
}