Skip to main content

augurs_forecaster/transforms/
exp.rs

//! Exponential transformations, including log and logit.
2
3use std::fmt;
4
5use super::{Error, Transformer};
6
// Logit and logistic functions.
8
/// Evaluates the logistic (sigmoid) function `1 / (1 + e^(-x))`.
///
/// A `NaN` input propagates through the arithmetic and yields `NaN`.
fn logistic(x: f64) -> f64 {
    // Build the denominator first, then take its reciprocal.
    let denominator = 1.0 + f64::exp(-x);
    denominator.recip()
}
13
/// Evaluates the logit (log-odds) function `ln(x / (1 - x))`.
///
/// No domain check is performed: the result is whatever the floating-point
/// division and logarithm produce for the given `x`.
fn logit(x: f64) -> f64 {
    let odds = x / (1.0 - x);
    f64::ln(odds)
}
18
/// The logit transform.
///
/// The struct carries no data; its only field is a private unit marker, so
/// code outside this module cannot build it with a struct literal and has to
/// go through [`Logit::new`] or [`Default`].
#[derive(Default, Clone)]
pub struct Logit {
    _priv: (),
}

impl Logit {
    /// Construct a logit transform.
    pub fn new() -> Self {
        Self { _priv: () }
    }
}
31
32impl fmt::Debug for Logit {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.debug_struct("Logit").finish()
35    }
36}
37
38impl Transformer for Logit {
39    fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
40        Ok(())
41    }
42
43    fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
44        data.iter_mut().for_each(|x| *x = logit(*x));
45        Ok(())
46    }
47
48    fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> {
49        data.iter_mut().for_each(|x| *x = logistic(*x));
50        Ok(())
51    }
52}
53
/// The log transform.
///
/// The struct carries no data; its only field is a private unit marker, so
/// code outside this module cannot build it with a struct literal and has to
/// go through [`Log::new`] or [`Default`].
#[derive(Default, Clone)]
pub struct Log {
    _priv: (),
}

impl Log {
    /// Construct a log transform.
    pub fn new() -> Self {
        Self { _priv: () }
    }
}
66
67impl fmt::Debug for Log {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        f.debug_struct("Log").finish()
70    }
71}
72
73impl Transformer for Log {
74    fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
75        Ok(())
76    }
77
78    fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
79        data.iter_mut().for_each(|x| *x = f64::ln(*x));
80        Ok(())
81    }
82
83    fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> {
84        data.iter_mut().for_each(|x| *x = f64::exp(*x));
85        Ok(())
86    }
87}
88
#[cfg(test)]
mod test {
    use augurs_testing::{assert_all_close, assert_approx_eq};

    use super::*;

    /// `logistic` agrees with the closed form at 0, 1 and -1.
    #[test]
    fn test_logistic() {
        let cases: [(f64, f64); 3] = [
            (0.0, 0.5),
            (1.0, 1.0 / (1.0 + (-1.0_f64).exp())),
            (-1.0, 1.0 / (1.0 + 1.0_f64.exp())),
        ];
        for &(x, expected) in cases.iter() {
            let actual = logistic(x);
            assert_approx_eq!(expected, actual);
        }
    }

    /// `logistic` propagates NaN.
    #[test]
    fn test_logistic_nan() {
        let result = logistic(f64::NAN);
        assert!(result.is_nan());
    }

    /// `logit` agrees exactly with the closed form at 0.5, 0.75 and 0.25.
    #[test]
    fn test_logit() {
        let cases: [(f64, f64); 3] = [
            (0.5, 0.0),
            (0.75, (0.75_f64 / (1.0 - 0.75)).ln()),
            (0.25, (0.25_f64 / (1.0 - 0.25)).ln()),
        ];
        for &(x, expected) in cases.iter() {
            let actual = logit(x);
            assert_eq!(expected, actual);
        }
    }

    /// `logit` propagates NaN.
    #[test]
    fn test_logit_nan() {
        let result = logit(f64::NAN);
        assert!(result.is_nan());
    }

    /// `Logit::transform` applies `logit` to every element in place.
    #[test]
    fn logit_transform() {
        let expected = vec![
            0.0_f64,
            (0.75_f64 / (1.0 - 0.75)).ln(),
            (0.25_f64 / (1.0 - 0.25)).ln(),
        ];
        let mut values = vec![0.5, 0.75, 0.25];
        let transformer = Logit::new();
        transformer
            .transform(&mut values)
            .expect("failed to logit transform");
        assert_all_close(&expected, &values);
    }

    /// `Logit::inverse_transform` applies `logistic` to every element in place.
    #[test]
    fn logit_inverse_transform() {
        let expected = vec![
            0.5_f64,
            1.0 / (1.0 + (-1.0_f64).exp()),
            1.0 / (1.0 + 1.0_f64.exp()),
        ];
        let mut values = vec![0.0, 1.0, -1.0];
        let transformer = Logit::new();
        transformer
            .inverse_transform(&mut values)
            .expect("failed to inverse logit transform");
        assert_all_close(&expected, &values);
    }

    /// `Log::transform` takes the natural log of every element in place.
    #[test]
    fn log_transform() {
        let expected = vec![0.0_f64, 2.0_f64.ln(), 3.0_f64.ln()];
        let mut values = vec![1.0, 2.0, 3.0];
        let transformer = Log::new();
        transformer
            .transform(&mut values)
            .expect("failed to log transform");
        assert_all_close(&expected, &values);
    }

    /// `Log::inverse_transform` exponentiates every element in place.
    #[test]
    fn log_inverse_transform() {
        let expected = vec![1.0, 2.0, 3.0];
        let mut values = vec![0.0, 2.0_f64.ln(), 3.0_f64.ln()];
        let transformer = Log::new();
        transformer
            .inverse_transform(&mut values)
            .expect("failed to inverse log transform");
        assert_all_close(&expected, &values);
    }
}