// gamlss_transform/transforms/log1p_shift.rs

1use crate::{
2    TargetTransform, TransformError, validate_non_empty_finite, validate_output_len,
3    validate_shifted_non_negative,
4};
/// `log1p` transform with a fitted additive shift, for targets that may
/// contain zero or negative values.
///
/// The forward map is `transform(y) = ln(1 + y + shift)` and the inverse is
/// `inverse(v) = exp(v) - 1 - shift`, computed with `ln_1p` / `exp_m1`.
///
/// Fitting looks only at the training minimum `min`:
///
/// - `min <= 0`: `shift = margin - min` with `margin = 1e-12`, so the training
///   minimum maps to `margin` (up to floating-point rounding, which for large
///   `|min|` can absorb the margin entirely and map it to `0`);
/// - `min > 0`: no shift is needed, so `shift = margin = 0` and every training
///   value maps to `ln(1 + y)` unchanged.
///
/// Slices passed to `transform_slice` / `transform_into` must satisfy
/// `y + shift >= 0` (see `Log1pShiftState::lower_bound`); values below that
/// bound are rejected with an error. The scalar `transform` performs no such
/// check.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Log1pShift;

/// State for [`Log1pShift`].
///
/// As produced by fitting: when the training minimum `min` is `<= 0`,
/// `shift = margin - min` with `margin = 1e-12`; when `min > 0`, both fields
/// are `0.0`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Log1pShiftState {
    /// Additive shift applied before `ln_1p`; non-negative when produced by
    /// fitting.
    pub shift: f64,
    /// Nominal offset placed above zero for a non-positive training minimum
    /// (`1e-12`), or `0.0` when no shift was needed. Because of floating-point
    /// rounding the shifted training minimum `min + shift` only approximately
    /// equals `margin`; for large `|min|` it can be exactly `0`.
    pub margin: f64,
}

impl Log1pShiftState {
    /// Lower bound accepted by [`Log1pShift::transform_slice`]: inputs must
    /// satisfy `y + shift >= 0`, i.e. `y >= -shift`.
    ///
    /// Returns `+0.0` (never `-0.0`) for an unshifted state.
    #[must_use]
    pub fn lower_bound(self) -> f64 {
        // `0.0 - shift` equals `-shift` for every non-zero shift, but yields
        // `+0.0` rather than `-0.0` when `shift == 0.0`, so an unshifted state
        // reports (and prints) a bound of `0`, not `-0`.
        0.0 - self.shift
    }
}

33impl TargetTransform for Log1pShift {
34    type State = Log1pShiftState;
35
36    fn fit(y: &[f64]) -> Result<Self::State, TransformError> {
37        validate_non_empty_finite(y)?;
38
39        let min = y.iter().copied().fold(f64::INFINITY, f64::min);
40        let margin = if min <= 0.0 { 1.0e-12 } else { 0.0 };
41        let shift = (-min + margin).max(0.0);
42        Ok(Log1pShiftState { shift, margin })
43    }
44
45    fn transform(state: &Self::State, y: f64) -> f64 {
46        (y + state.shift).ln_1p()
47    }
48
49    fn inverse(state: &Self::State, value: f64) -> f64 {
50        value.exp_m1() - state.shift
51    }
52
53    fn transform_slice(state: &Self::State, y: &[f64]) -> Result<Vec<f64>, TransformError> {
54        let mut out = vec![0.0; y.len()];
55        Self::transform_into(state, y, &mut out)?;
56        Ok(out)
57    }
58
59    fn transform_into(
60        state: &Self::State,
61        y: &[f64],
62        out: &mut [f64],
63    ) -> Result<(), TransformError> {
64        validate_output_len(y.len(), out.len())?;
65        validate_shifted_non_negative(y, state.shift)?;
66        for (out, value) in out.iter_mut().zip(y.iter().copied()) {
67            *out = Self::transform(state, value);
68        }
69        Ok(())
70    }
71}
72
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{Log1pShift, TargetTransform, TransformError};

    #[test]
    fn round_trips_values_with_negative_minimum() {
        let y = [-3.0, 0.0, 4.0];
        let (state, transformed) = Log1pShift::fit_transform(&y).unwrap();
        let restored = Log1pShift::inverse_slice(&state, &transformed).unwrap();

        assert!(state.shift > 0.0);
        assert_relative_eq!(state.lower_bound(), -state.shift);
        // The non-positive training minimum lands (up to rounding) on `margin`.
        assert_relative_eq!(y[0] + state.shift, state.margin, epsilon = 1.0e-15);
        for (actual, expected) in restored.iter().zip(y) {
            assert_relative_eq!(*actual, expected, epsilon = 1.0e-12);
        }
    }

    #[test]
    fn leaves_strictly_positive_targets_unshifted() {
        let state = Log1pShift::fit(&[0.5, 2.0]).unwrap();

        assert_relative_eq!(state.shift, 0.0);
        assert_relative_eq!(state.margin, 0.0);
        assert_relative_eq!(Log1pShift::transform(&state, 0.5), 0.5_f64.ln_1p());
        assert_relative_eq!(
            Log1pShift::inverse(&state, 0.5_f64.ln_1p()),
            0.5,
            epsilon = 1.0e-12
        );
    }

    #[test]
    fn rejects_values_below_fitted_lower_bound() {
        let state = Log1pShift::fit(&[-2.0, 1.0]).unwrap();

        assert_eq!(
            Log1pShift::transform_slice(&state, &[-2.1]).unwrap_err(),
            TransformError::BelowLowerBound
        );
    }
}