Skip to main content

ggplot_rs/scale/
transform.rs

1use crate::data::Value;
2
/// Scale transformation types.
///
/// Each variant defines a forward map from data space to plotting space and its
/// inverse; see `ScaleTransform::apply` and `ScaleTransform::inverse`.
#[derive(Clone, Debug)]
pub enum ScaleTransform {
    /// No transformation: values pass through unchanged.
    Identity,
    /// Base-10 logarithm (data must be positive).
    Log10,
    /// Base-2 logarithm (data must be positive).
    Log2,
    /// Natural logarithm (data must be positive).
    Ln,
    /// Square root (data must be non-negative).
    Sqrt,
    /// Negation, which flips the direction of the axis.
    Reverse,
    /// Logit `ln(p / (1 - p))`, intended for proportions in the open interval (0, 1).
    Logit,
    /// Probit (inverse standard-normal CDF), intended for proportions in (0, 1).
    Probit,
    /// Pseudo-log `asinh(x / 2)`: log-like for large |x|, yet defined at zero and
    /// for negative values, with the sign preserved.
    PseudoLog,
    /// Reciprocal `1 / x`.
    Reciprocal,
    /// Exponential `exp(x)`, whose inverse is the natural logarithm.
    Exp,
    /// Box–Cox power transform with parameter λ: `(x^λ − 1) / λ`, reducing to
    /// `ln(x)` at λ = 0; defined for x > 0.
    BoxCox(f64),
}
25
26impl ScaleTransform {
27    /// Apply the forward transformation.
28    pub fn apply(&self, value: f64) -> f64 {
29        match self {
30            ScaleTransform::Identity => value,
31            ScaleTransform::Log10 => {
32                if value > 0.0 {
33                    value.log10()
34                } else {
35                    f64::NEG_INFINITY
36                }
37            }
38            ScaleTransform::Log2 => {
39                if value > 0.0 {
40                    value.log2()
41                } else {
42                    f64::NEG_INFINITY
43                }
44            }
45            ScaleTransform::Ln => {
46                if value > 0.0 {
47                    value.ln()
48                } else {
49                    f64::NEG_INFINITY
50                }
51            }
52            ScaleTransform::Sqrt => {
53                if value >= 0.0 {
54                    value.sqrt()
55                } else {
56                    f64::NAN
57                }
58            }
59            ScaleTransform::Reverse => -value,
60            ScaleTransform::Logit => {
61                if value <= 0.0 {
62                    f64::NEG_INFINITY
63                } else if value >= 1.0 {
64                    f64::INFINITY
65                } else {
66                    (value / (1.0 - value)).ln()
67                }
68            }
69            ScaleTransform::Probit => qnorm(value),
70            ScaleTransform::PseudoLog => (value / 2.0).asinh(),
71            ScaleTransform::Reciprocal => {
72                if value != 0.0 {
73                    1.0 / value
74                } else {
75                    f64::NAN
76                }
77            }
78            ScaleTransform::Exp => value.exp(),
79            ScaleTransform::BoxCox(lambda) => {
80                if value <= 0.0 {
81                    f64::NAN
82                } else if lambda.abs() < 1e-9 {
83                    value.ln()
84                } else {
85                    (value.powf(*lambda) - 1.0) / lambda
86                }
87            }
88        }
89    }
90
91    /// Apply the inverse transformation.
92    pub fn inverse(&self, value: f64) -> f64 {
93        match self {
94            ScaleTransform::Identity => value,
95            ScaleTransform::Log10 => 10f64.powf(value),
96            ScaleTransform::Log2 => 2f64.powf(value),
97            ScaleTransform::Ln => value.exp(),
98            ScaleTransform::Sqrt => value * value,
99            ScaleTransform::Reverse => -value,
100            ScaleTransform::Logit => 1.0 / (1.0 + (-value).exp()),
101            ScaleTransform::Probit => pnorm(value),
102            ScaleTransform::PseudoLog => 2.0 * value.sinh(),
103            ScaleTransform::Reciprocal => {
104                if value != 0.0 {
105                    1.0 / value
106                } else {
107                    f64::NAN
108                }
109            }
110            ScaleTransform::Exp => value.ln(),
111            ScaleTransform::BoxCox(lambda) => {
112                if lambda.abs() < 1e-9 {
113                    value.exp()
114                } else {
115                    (value * lambda + 1.0).powf(1.0 / lambda)
116                }
117            }
118        }
119    }
120
121    /// Transform a Value.
122    pub fn transform_value(&self, value: &Value) -> Value {
123        match value.as_f64() {
124            Some(f) => {
125                let t = self.apply(f);
126                if t.is_finite() {
127                    Value::Float(t)
128                } else {
129                    Value::Na
130                }
131            }
132            None => value.clone(),
133        }
134    }
135
136    pub fn is_identity(&self) -> bool {
137        matches!(self, ScaleTransform::Identity)
138    }
139}
140
141/// Standard normal CDF, via the Abramowitz & Stegun 7.1.26 erf approximation.
142fn pnorm(x: f64) -> f64 {
143    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
144}
145
/// Error function `erf(x)`, via the Abramowitz & Stegun 7.1.26 approximation
/// (absolute error at most about 1.5e-7).
///
/// The formula is stated for `x >= 0`. Negative inputs use the odd symmetry
/// `erf(-x) = -erf(x)`.
fn erf(x: f64) -> f64 {
    // Polynomial coefficients a5..a1, highest degree first, for Horner evaluation.
    const COEFFS: [f64; 5] = [
        1.061_405_429,
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];
    const P: f64 = 0.327_591_1;

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    let poly = COEFFS[1..]
        .iter()
        .fold(COEFFS[0], |acc, &c| acc * t + c);
    let magnitude = 1.0 - poly * t * (-ax * ax).exp();
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
157
/// Inverse of the standard normal CDF (the probit function): returns `z` such
/// that `Φ(z) = p`.
///
/// Uses Peter J. Acklam's piecewise rational approximation, with relative error
/// below about 1.15e-9 on the open interval (0, 1). One rational function covers
/// the central region, and a second one in `sqrt(-2 ln p)` covers each tail.
///
/// Edge cases:
/// - `p <= 0` returns `-∞`.
/// - `p >= 1` returns `+∞`.
/// - NaN input returns NaN.
fn qnorm(p: f64) -> f64 {
    // Central-region numerator (A) and denominator (B) coefficients, highest
    // degree first; B's trailing 1.0 is the constant term.
    const A: [f64; 6] = [
        -3.969_683_028_665_376e1,
        2.209_460_984_245_205e2,
        -2.759_285_104_469_687e2,
        1.383_577_518_672_690e2,
        -3.066_479_806_614_716e1,
        2.506_628_277_459_239e0,
    ];
    const B: [f64; 6] = [
        -5.447_609_879_822_406e1,
        1.615_858_368_580_409e2,
        -1.556_989_798_598_866e2,
        6.680_131_188_771_972e1,
        -1.328_068_155_288_572e1,
        1.0,
    ];
    // Tail numerator (C) and denominator (D) coefficients, highest degree first.
    const C: [f64; 6] = [
        -7.784_894_002_430_293e-3,
        -3.223_964_580_411_365e-1,
        -2.400_758_277_161_838e0,
        -2.549_732_539_343_734e0,
        4.374_664_141_464_968e0,
        2.938_163_982_698_783e0,
    ];
    const D: [f64; 5] = [
        7.784_695_709_041_462e-3,
        3.224_671_290_700_398e-1,
        2.445_134_137_142_996e0,
        3.754_408_661_907_416e0,
        1.0,
    ];
    // Boundary between the lower tail and the central region (mirrored for the
    // upper tail).
    const P_LOW: f64 = 0.02425;

    /// Horner evaluation of a polynomial given highest-degree-first coefficients.
    fn horner(coeffs: &[f64], x: f64) -> f64 {
        coeffs.iter().fold(0.0, |acc, &c| acc * x + c)
    }

    // Checked explicitly: NaN fails every comparison below and would otherwise
    // be routed into one of the regions.
    if p.is_nan() {
        return f64::NAN;
    }
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    if p < P_LOW {
        let q = (-2.0 * p.ln()).sqrt();
        horner(&C, q) / horner(&D, q)
    } else if p > 1.0 - P_LOW {
        let q = (-2.0 * (1.0 - p).ln()).sqrt();
        -horner(&C, q) / horner(&D, q)
    } else {
        let q = p - 0.5;
        let r = q * q;
        q * horner(&A, r) / horner(&B, r)
    }
}
184
#[cfg(test)]
mod tests {
    use super::ScaleTransform::*;

    /// Assert that `inverse(apply(v))` recovers `v` to within `tol`.
    fn roundtrip(t: super::ScaleTransform, v: f64, tol: f64) {
        let forward = t.apply(v);
        let back = t.inverse(forward);
        assert!(
            (back - v).abs() < tol,
            "{t:?}: {v} -> {forward} -> {back}"
        );
    }

    #[test]
    fn transforms_roundtrip() {
        roundtrip(Logit, 0.3, 1e-9);
        roundtrip(Probit, 0.3, 1e-2); // approximation
        roundtrip(PseudoLog, -4.0, 1e-9);
        roundtrip(PseudoLog, 0.0, 1e-9);
        roundtrip(Reciprocal, 2.5, 1e-9);
        roundtrip(Exp, 1.7, 1e-9);
        roundtrip(BoxCox(0.5), 4.0, 1e-9);
        roundtrip(BoxCox(0.0), 4.0, 1e-9); // lambda 0 == ln
    }

    #[test]
    fn basic_transforms_roundtrip() {
        roundtrip(Identity, 3.5, 1e-12);
        roundtrip(Log10, 250.0, 1e-9);
        roundtrip(Log2, 0.125, 1e-12);
        roundtrip(Ln, 7.0, 1e-9);
        roundtrip(Sqrt, 9.0, 1e-12);
        roundtrip(Reverse, -2.5, 1e-12);
    }

    #[test]
    fn transforms_domain_edges() {
        assert_eq!(Logit.apply(0.0), f64::NEG_INFINITY);
        assert_eq!(Logit.apply(1.0), f64::INFINITY);
        assert!(Reciprocal.apply(0.0).is_nan());
        assert!(BoxCox(0.5).apply(-1.0).is_nan());
        assert_eq!(Probit.apply(0.5), 0.0);
        // Box-Cox at lambda 0 is ln.
        assert!((BoxCox(0.0).apply(std::f64::consts::E) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn log_and_sqrt_domain_edges() {
        assert_eq!(Log10.apply(0.0), f64::NEG_INFINITY);
        assert_eq!(Log2.apply(-1.0), f64::NEG_INFINITY);
        assert_eq!(Ln.apply(0.0), f64::NEG_INFINITY);
        assert!(Sqrt.apply(-1.0).is_nan());
        assert_eq!(PseudoLog.apply(0.0), 0.0);
        assert_eq!(Logit.inverse(0.0), 0.5);
    }

    #[test]
    fn identity_detection() {
        assert!(Identity.is_identity());
        assert!(!Log10.is_identity());
        assert!(!BoxCox(1.0).is_identity());
    }
}