// gamlss_formula/predictor.rs

1use std::ops::Range;
2
3use gamlss_core::{DenseDesign, DesignMatrix, Link, ModelError, PredictorBlock, Softplus};
4use gamlss_spline::{ISplineBasis, MonotoneDirection};
5
/// Formula predictor composed from a dense linear design, optional row offsets,
/// and nonlinear monotone spline segments.
///
/// For row `i` the linear predictor is
/// `eta_i = dense_row_i · beta[..dense.ncols()] + offset[i] + Σ_s segment_s(i, beta[range_s])`,
/// i.e. the dense columns map to the leading local parameters and each monotone
/// segment reads the parameter slice given by its own `range`.
#[derive(Debug, Clone, PartialEq)]
pub struct FormulaPredictorBlock {
    /// Dense linear design; its columns use the leading `dense.ncols()` parameters.
    dense: DenseDesign,
    /// Optional fixed per-row term added to eta; `validate` requires one entry per row.
    offset: Option<Vec<f64>>,
    /// Monotone spline segments, each reading `beta[segment.range]`.
    monotone: Vec<MonotoneSegment>,
    /// Local parameter count; `validate` requires it to cover the dense columns
    /// and every segment range.
    nparams: usize,
}
15
/// One monotone spline term of a [`FormulaPredictorBlock`].
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct MonotoneSegment {
    /// Local parameter indices used by this term. The first index is an additive
    /// intercept; the remaining ones are raw coefficients that are passed through
    /// `Softplus::inverse` before weighting the basis values.
    pub(crate) range: Range<usize>,
    /// Per-row covariate values at which `basis` is evaluated; one entry per row.
    pub(crate) values: Vec<f64>,
    /// I-spline basis evaluated at `values[row]` for each row.
    pub(crate) basis: ISplineBasis,
    /// Direction of the term; selects a `+1.0` or `-1.0` sign on the weighted basis sum.
    pub(crate) direction: MonotoneDirection,
}
23
24impl FormulaPredictorBlock {
25    /// Creates a formula predictor block.
26    #[must_use]
27    pub(crate) fn new(
28        dense: DenseDesign,
29        offset: Option<Vec<f64>>,
30        monotone: Vec<MonotoneSegment>,
31        nparams: usize,
32    ) -> Self {
33        Self {
34            dense,
35            offset,
36            monotone,
37            nparams,
38        }
39    }
40
41    /// Returns the dense linear part.
42    #[must_use]
43    pub fn dense(&self) -> &DenseDesign {
44        &self.dense
45    }
46
47    fn monotone_eta(segment: &MonotoneSegment, row: usize, beta: &[f64]) -> f64 {
48        debug_assert_eq!(beta.len(), segment.range.len());
49
50        let sign = monotone_sign(segment.direction);
51        beta[0]
52            + segment
53                .basis
54                .evaluate(segment.values[row])
55                .iter()
56                .zip(&beta[1..])
57                .map(|(basis, beta)| sign * Softplus::inverse(*beta) * basis)
58                .sum::<f64>()
59    }
60
61    fn add_monotone_gradient(
62        segment: &MonotoneSegment,
63        scores: &[f64],
64        multiplier: Option<&[f64]>,
65        beta: &[f64],
66        grad: &mut [f64],
67    ) {
68        debug_assert_eq!(beta.len(), segment.range.len());
69        debug_assert_eq!(grad.len(), segment.range.len());
70
71        let sign = monotone_sign(segment.direction);
72        for (row, score) in scores.iter().copied().enumerate() {
73            let score = multiplier.map_or(score, |multiplier| score * multiplier[row]);
74            grad[0] += score;
75            for (index, basis) in segment
76                .basis
77                .evaluate(segment.values[row])
78                .iter()
79                .enumerate()
80            {
81                grad[index + 1] +=
82                    score * sign * basis * Softplus::derivative_inverse(beta[index + 1]);
83            }
84        }
85    }
86}
87
88impl PredictorBlock for FormulaPredictorBlock {
89    fn nrows(&self) -> usize {
90        self.dense.nrows()
91    }
92
93    fn nparams(&self) -> usize {
94        self.nparams
95    }
96
97    fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
98        debug_assert_eq!(beta.len(), self.nparams);
99
100        let dense_ncols = self.dense.ncols();
101        let mut eta = self.dense.dot_row(row, &beta[..dense_ncols]);
102        if let Some(offset) = &self.offset {
103            eta += offset[row];
104        }
105        for segment in &self.monotone {
106            eta += Self::monotone_eta(segment, row, &beta[segment.range.clone()]);
107        }
108        eta
109    }
110
111    fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]) {
112        debug_assert_eq!(scores.len(), self.nrows());
113        debug_assert_eq!(beta.len(), self.nparams);
114        debug_assert_eq!(grad.len(), self.nparams);
115
116        let dense_ncols = self.dense.ncols();
117        self.dense.add_t_mul_vec(scores, &mut grad[..dense_ncols]);
118        for segment in &self.monotone {
119            Self::add_monotone_gradient(
120                segment,
121                scores,
122                None,
123                &beta[segment.range.clone()],
124                &mut grad[segment.range.clone()],
125            );
126        }
127    }
128
129    fn add_weighted_gradient(
130        &self,
131        scores: &[f64],
132        multiplier: &[f64],
133        beta: &[f64],
134        grad: &mut [f64],
135    ) {
136        debug_assert_eq!(scores.len(), self.nrows());
137        debug_assert_eq!(multiplier.len(), self.nrows());
138        debug_assert_eq!(beta.len(), self.nparams);
139        debug_assert_eq!(grad.len(), self.nparams);
140
141        let dense_ncols = self.dense.ncols();
142        self.dense
143            .add_weighted_t_mul_vec(scores, multiplier, &mut grad[..dense_ncols]);
144        for segment in &self.monotone {
145            Self::add_monotone_gradient(
146                segment,
147                scores,
148                Some(multiplier),
149                &beta[segment.range.clone()],
150                &mut grad[segment.range.clone()],
151            );
152        }
153    }
154
155    fn validate(&self) -> Result<(), ModelError> {
156        let nrows = self.nrows();
157        let dense_ncols = self.dense.ncols();
158        if dense_ncols > self.nparams {
159            return Err(ModelError::InvalidParameter {
160                parameter: "formula predictor",
161                expected: "dense columns <= local parameter count",
162            });
163        }
164
165        if let Some(offset) = &self.offset
166            && offset.len() != nrows
167        {
168            return Err(ModelError::DesignRowMismatch {
169                parameter: "formula offset",
170                expected_rows: nrows,
171                actual_rows: offset.len(),
172            });
173        }
174
175        for segment in &self.monotone {
176            if segment.range.end > self.nparams {
177                return Err(ModelError::BlockRangeOverflow {
178                    parameter: "formula monotone",
179                    offset: segment.range.start,
180                    len: segment.range.len(),
181                });
182            }
183            if segment.range.start < dense_ncols {
184                return Err(ModelError::BlockOverlap {
185                    first: "formula dense",
186                    second: "formula monotone",
187                });
188            }
189            if segment.values.len() != nrows {
190                return Err(ModelError::DesignRowMismatch {
191                    parameter: "formula monotone",
192                    expected_rows: nrows,
193                    actual_rows: segment.values.len(),
194                });
195            }
196        }
197
198        Ok(())
199    }
200}
201
202fn monotone_sign(direction: MonotoneDirection) -> f64 {
203    match direction {
204        MonotoneDirection::Increasing => 1.0,
205        MonotoneDirection::Decreasing => -1.0,
206    }
207}