Skip to main content

chartml_core/scales/
sqrt.rs

use super::{ContinuousScale, tick_step, round_to_precision};

/// Square root scale for bubble sizes. Maps a continuous domain to a continuous
/// range via a sqrt transformation, modelled on D3's `scaleSqrt()` (a power
/// scale with exponent 0.5).
///
/// Unlike D3, which uses a sign-preserving square root, this scale clamps
/// negative domain values to `0` before the square root is taken.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScaleSqrt {
    /// Domain extent `(d0, d1)`; may be given in either order.
    domain: (f64, f64),
    /// Range extent `(r0, r1)`; may be given in either order.
    range: (f64, f64),
}

11impl ScaleSqrt {
12    /// Create a new sqrt scale with the given domain and range.
13    pub fn new(domain: (f64, f64), range: (f64, f64)) -> Self {
14        Self { domain, range }
15    }
16
17    /// Map a domain value to a range value using sqrt interpolation.
18    /// Negative domain values are clamped to 0 before sqrt.
19    pub fn map(&self, value: f64) -> f64 {
20        let (d0, d1) = self.domain;
21        let (r0, r1) = self.range;
22        let sqrt_d0 = d0.max(0.0).sqrt();
23        let sqrt_d1 = d1.max(0.0).sqrt();
24        let sqrt_val = value.max(0.0).sqrt();
25        let sqrt_span = sqrt_d1 - sqrt_d0;
26        if sqrt_span == 0.0 {
27            return (r0 + r1) / 2.0;
28        }
29        r0 + (sqrt_val - sqrt_d0) / sqrt_span * (r1 - r0)
30    }
31
32    /// Inverse mapping: range value back to domain value.
33    pub fn invert(&self, value: f64) -> f64 {
34        let (d0, d1) = self.domain;
35        let (r0, r1) = self.range;
36        let sqrt_d0 = d0.max(0.0).sqrt();
37        let sqrt_d1 = d1.max(0.0).sqrt();
38        let range_span = r1 - r0;
39        if range_span == 0.0 {
40            return (d0 + d1) / 2.0;
41        }
42        let sqrt_val = sqrt_d0 + (value - r0) / range_span * (sqrt_d1 - sqrt_d0);
43        sqrt_val * sqrt_val
44    }
45
46    /// Generate ticks by computing them in sqrt-transformed space,
47    /// then squaring back to the original domain.
48    pub fn ticks(&self, count: usize) -> Vec<f64> {
49        if count == 0 {
50            return vec![];
51        }
52        let (d0, d1) = self.domain;
53        let sqrt_min = d0.max(0.0).sqrt().min(d1.max(0.0).sqrt());
54        let sqrt_max = d0.max(0.0).sqrt().max(d1.max(0.0).sqrt());
55        if sqrt_min == sqrt_max {
56            return vec![sqrt_min * sqrt_min];
57        }
58
59        let step = tick_step(sqrt_min, sqrt_max, count);
60        if step == 0.0 || !step.is_finite() {
61            return vec![];
62        }
63
64        let mut ticks = Vec::new();
65        let start = (sqrt_min / step).ceil();
66        let stop = (sqrt_max / step).floor();
67
68        let mut i = start;
69        while i <= stop {
70            let sqrt_tick = i * step;
71            let tick = round_to_precision(sqrt_tick * sqrt_tick, step * step);
72            ticks.push(tick);
73            i += 1.0;
74        }
75
76        ticks
77    }
78
79    /// Get the domain extent.
80    pub fn domain(&self) -> (f64, f64) {
81        self.domain
82    }
83
84    /// Get the range extent.
85    pub fn range(&self) -> (f64, f64) {
86        self.range
87    }
88}
89
90impl ContinuousScale for ScaleSqrt {
91    fn map(&self, value: f64) -> f64 {
92        ScaleSqrt::map(self, value)
93    }
94
95    fn domain(&self) -> (f64, f64) {
96        ScaleSqrt::domain(self)
97    }
98
99    fn range(&self) -> (f64, f64) {
100        ScaleSqrt::range(self)
101    }
102
103    fn ticks(&self, count: usize) -> Vec<f64> {
104        ScaleSqrt::ticks(self, count)
105    }
106
107    fn clamp(&self, value: f64) -> f64 {
108        let (d0, d1) = self.domain;
109        let min = d0.min(d1);
110        let max = d0.max(d1);
111        value.clamp(min, max)
112    }
113}
114
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Absolute tolerance for every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// The scale exercised by every test: domain `[0, 100]` onto range `[0, 100]`.
    fn hundred_to_hundred() -> ScaleSqrt {
        ScaleSqrt::new((0.0, 100.0), (0.0, 100.0))
    }

    #[test]
    fn sqrt_scale_maps_zero() {
        let at_zero = hundred_to_hundred().map(0.0);
        assert!(at_zero.abs() < EPS);
    }

    #[test]
    fn sqrt_scale_maps_max() {
        let at_max = hundred_to_hundred().map(100.0);
        assert!((at_max - 100.0).abs() < EPS);
    }

    #[test]
    fn sqrt_scale_nonlinear() {
        // sqrt(25) / sqrt(100) = 5 / 10 = 0.5, so 25 lands halfway along the range.
        let halfway = hundred_to_hundred().map(25.0);
        assert!(
            (halfway - 50.0).abs() < EPS,
            "map(25) should be 50, got {}",
            halfway
        );
    }

    #[test]
    fn sqrt_scale_inverts() {
        let s = hundred_to_hundred();
        let original = 42.0;
        let round_trip = s.invert(s.map(original));
        assert!(
            (round_trip - original).abs() < EPS,
            "invert(map({})) should be {}, got {}",
            original,
            original,
            round_trip
        );
    }

    #[test]
    fn sqrt_scale_ticks() {
        let ticks = hundred_to_hundred().ticks(5);
        assert!(!ticks.is_empty(), "should generate at least one tick");
        let produced = ticks.len();
        assert!(
            produced <= 15,
            "should not generate too many ticks, got {}",
            produced
        );
        // Every tick must fall inside the non-negative domain.
        for &tick in &ticks {
            assert!(tick >= 0.0, "tick {} should be >= 0", tick);
            assert!(tick <= 100.0 + EPS, "tick {} should be <= 100", tick);
        }
    }

    #[test]
    fn sqrt_scale_negative_clamped() {
        let s = hundred_to_hundred();
        // Negative inputs are clamped to 0 before the sqrt, so -10 must land where 0 does.
        let gap = s.map(-10.0) - s.map(0.0);
        assert!(gap.abs() < EPS, "negative values should be clamped to 0");
    }
}