Skip to main content

gamlss_transform/transforms/
log1p_shift.rs

1use crate::{
2    TargetTransform, TransformError, validate_non_empty_finite, validate_output_len,
3    validate_shifted_non_negative,
4};
5
/// `log1p` transform with a fitted shift for targets that may contain zero or
/// negative values.
///
/// Fitting inspects the training minimum `min`:
///
/// - if `min <= 0`, the shift is chosen so that `min + shift` equals
///   [`Log1pShiftState::margin`] (`1e-12`), up to floating-point rounding —
///   for large `|min|` the margin can be absorbed entirely, leaving
///   `min + shift == 0`;
/// - if `min > 0`, no shift is applied and both `shift` and `margin` are zero.
///
/// The forward map is `transform(y) = ln(1 + y + shift)` and the inverse is
/// `inverse(v) = exp(v) - 1 - shift`. The slice-based transforms require new
/// values to satisfy `y + shift >= 0`, i.e. to lie at or above
/// [`Log1pShiftState::lower_bound`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Log1pShift;
14
/// Fitted state for [`Log1pShift`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Log1pShiftState {
    /// Additive shift applied before `ln_1p`.
    ///
    /// Zero when the training minimum is positive; otherwise `margin - min`,
    /// where `min` is the training minimum.
    pub shift: f64,
    /// Offset requested between a non-positive training minimum and zero
    /// (`1e-12`), or `0.0` when the training minimum is positive and no shift
    /// is applied.
    ///
    /// This records the requested offset, not the realized distance
    /// `min + shift`. The two can differ slightly because of floating-point
    /// rounding, and for large `|min|` the realized distance may be exactly
    /// zero.
    pub margin: f64,
}
24
25impl Log1pShiftState {
26    /// Lower bound accepted by [`Log1pShift::transform_slice`].
27    #[must_use]
28    #[inline(always)]
29    pub fn lower_bound(self) -> f64 {
30        -self.shift
31    }
32}
33
34impl TargetTransform for Log1pShift {
35    type State = Log1pShiftState;
36
37    fn fit(y: &[f64]) -> Result<Self::State, TransformError> {
38        validate_non_empty_finite(y)?;
39
40        let min = y.iter().copied().fold(f64::INFINITY, f64::min);
41        let margin = if min <= 0.0 { 1.0e-12 } else { 0.0 };
42        let shift = (-min + margin).max(0.0);
43        Ok(Log1pShiftState { shift, margin })
44    }
45
46    #[inline(always)]
47    fn transform(state: &Self::State, y: f64) -> f64 {
48        (y + state.shift).ln_1p()
49    }
50
51    #[inline(always)]
52    fn inverse(state: &Self::State, value: f64) -> f64 {
53        value.exp_m1() - state.shift
54    }
55
56    #[inline]
57    fn transform_slice(state: &Self::State, y: &[f64]) -> Result<Vec<f64>, TransformError> {
58        let mut out = vec![0.0; y.len()];
59        Self::transform_into(state, y, &mut out)?;
60        Ok(out)
61    }
62
63    #[inline]
64    fn transform_into(
65        state: &Self::State,
66        y: &[f64],
67        out: &mut [f64],
68    ) -> Result<(), TransformError> {
69        validate_output_len(y.len(), out.len())?;
70        validate_shifted_non_negative(y, state.shift)?;
71        for (out, value) in out.iter_mut().zip(y.iter().copied()) {
72            *out = Self::transform(state, value);
73        }
74        Ok(())
75    }
76}
77
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{Log1pShift, TargetTransform, TransformError};

    #[test]
    fn round_trips_values_with_negative_minimum() {
        let y = [-3.0, 0.0, 4.0];
        let (state, transformed) = Log1pShift::fit_transform(&y).unwrap();
        let restored = Log1pShift::inverse_slice(&state, &transformed).unwrap();

        assert!(state.shift > 0.0);
        // The training minimum lands `margin` above zero (up to rounding),
        // so it sits on or above the fitted lower bound.
        assert_relative_eq!(y[0] + state.shift, state.margin, epsilon = 1.0e-15);
        assert!(y[0] >= state.lower_bound());
        assert_eq!(restored.len(), y.len());
        for (actual, expected) in restored.iter().zip(y) {
            assert_relative_eq!(*actual, expected, epsilon = 1.0e-12);
        }
    }

    #[test]
    fn positive_minimum_needs_no_shift() {
        let state = Log1pShift::fit(&[2.0, 5.0]).unwrap();

        assert_eq!(state.shift, 0.0);
        assert_eq!(state.margin, 0.0);
        assert_eq!(state.lower_bound(), 0.0);
        assert_relative_eq!(
            Log1pShift::transform(&state, 2.0),
            3.0_f64.ln(),
            epsilon = 1.0e-12
        );
    }

    #[test]
    fn zero_minimum_gets_margin_only() {
        let state = Log1pShift::fit(&[0.0, 1.0]).unwrap();

        assert_eq!(state.margin, 1.0e-12);
        assert_eq!(state.shift, 1.0e-12);
    }

    #[test]
    fn accepts_values_at_fitted_lower_bound() {
        let state = Log1pShift::fit(&[-2.0, 1.0]).unwrap();
        let transformed = Log1pShift::transform_slice(&state, &[state.lower_bound()]).unwrap();

        assert_eq!(transformed, vec![0.0]);
    }

    #[test]
    fn rejects_values_below_fitted_lower_bound() {
        let state = Log1pShift::fit(&[-2.0, 1.0]).unwrap();

        assert_eq!(
            Log1pShift::transform_slice(&state, &[-2.1]).unwrap_err(),
            TransformError::BelowLowerBound
        );
    }

    #[test]
    fn rejects_mismatched_output_length() {
        let state = Log1pShift::fit(&[-2.0, 1.0]).unwrap();
        let mut out = [0.0; 1];

        assert!(Log1pShift::transform_into(&state, &[0.0, 1.0], &mut out).is_err());
    }

    #[test]
    fn fit_rejects_empty_input() {
        assert!(Log1pShift::fit(&[]).is_err());
    }
}
106}