Skip to main content

gamlss_core/
predictor.rs

1use crate::{DesignMatrix, ModelError};
2
3/// Predictor block for one distribution parameter.
4///
5/// Implementations map a local coefficient slice to a scalar linear predictor
6/// contribution for each observation and know how to propagate per-observation
7/// scores back to that local coefficient slice.
8pub trait PredictorBlock {
9    /// Number of observations.
10    fn nrows(&self) -> usize;
11    /// Number of local coefficients consumed by this block.
12    fn nparams(&self) -> usize;
13    /// Predictor contribution for one row.
14    fn eta_row(&self, row: usize, beta: &[f64]) -> f64;
15    /// Adds the gradient contribution implied by `scores` into `grad`.
16    fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]);
17
18    /// Validates internal block consistency.
19    fn validate(&self) -> Result<(), ModelError> {
20        Ok(())
21    }
22}
23
/// Predictor block that wraps a [`DesignMatrix`].
///
/// This is the explicit bridge between matrix-backed predictors and the
/// general [`PredictorBlock`] extension point: each row's contribution comes
/// from the wrapped design matrix.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LinearPredictorBlock<X> {
    /// The wrapped design matrix.
    pub x: X,
}

impl<X> LinearPredictorBlock<X> {
    /// Builds a predictor block around the design matrix `x`.
    pub fn new(x: X) -> Self {
        LinearPredictorBlock { x }
    }

    /// Consumes the block and hands back the design matrix it wrapped.
    pub fn into_inner(self) -> X {
        let LinearPredictorBlock { x } = self;
        x
    }
}
45
46impl<X> PredictorBlock for LinearPredictorBlock<X>
47where
48    X: DesignMatrix,
49{
50    fn nrows(&self) -> usize {
51        self.x.nrows()
52    }
53
54    fn nparams(&self) -> usize {
55        self.x.ncols()
56    }
57
58    fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
59        self.x.dot_row(row, beta)
60    }
61
62    fn add_gradient(&self, scores: &[f64], _: &[f64], grad: &mut [f64]) {
63        self.x.add_t_mul_vec(scores, grad);
64    }
65}
66
/// Predictor formed by adding up several blocks that share the same
/// observations.
///
/// `terms` is a tuple of blocks. The parameter's local coefficient slice is
/// divided among them in tuple order: the first term reads the first
/// `nparams()` coefficients, the next term reads the ones after that, and so
/// on. Storing the terms in a tuple keeps nonlinear or sparse user-defined
/// terms composable without dynamic dispatch.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SumBlock<Terms> {
    /// Tuple of predictor terms whose contributions are summed.
    pub terms: Terms,
}

impl<Terms> SumBlock<Terms> {
    /// Wraps a tuple of predictor terms into one summed predictor.
    pub fn new(terms: Terms) -> Self {
        SumBlock { terms }
    }
}
83
84macro_rules! impl_sum_block {
85    (
86        terms = ($($term:ident),+);
87        vars = ($($var:ident),+);
88        indices = ($($idx:tt),+);
89        names = ($($name:literal),+)
90    ) => {
91        impl<$($term,)+> PredictorBlock for SumBlock<($($term,)+)>
92        where
93            $($term: PredictorBlock,)+
94        {
95            fn nrows(&self) -> usize {
96                self.terms.0.nrows()
97            }
98
99            fn nparams(&self) -> usize {
100                0 $(+ self.terms.$idx.nparams())+
101            }
102
103            fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
104                let mut start = 0;
105                let mut eta = 0.0;
106                $(
107                    let $var = &self.terms.$idx;
108                    let end = start + $var.nparams();
109                    eta += $var.eta_row(row, &beta[start..end]);
110                    start = end;
111                )+
112                let _ = start;
113                eta
114            }
115
116            fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]) {
117                let mut start = 0;
118                $(
119                    let $var = &self.terms.$idx;
120                    let end = start + $var.nparams();
121                    $var.add_gradient(scores, &beta[start..end], &mut grad[start..end]);
122                    start = end;
123                )+
124                let _ = start;
125            }
126
127            fn validate(&self) -> Result<(), ModelError> {
128                let expected_rows = self.terms.0.nrows();
129                $(
130                    self.terms.$idx.validate()?;
131                    if self.terms.$idx.nrows() != expected_rows {
132                        return Err(ModelError::DesignRowMismatch {
133                            parameter: $name,
134                            expected_rows,
135                            actual_rows: self.terms.$idx.nrows(),
136                        });
137                    }
138                )+
139                Ok(())
140            }
141        }
142    };
143}
144
145impl_sum_block!(
146    terms = (T1);
147    vars = (term1);
148    indices = (0);
149    names = ("sum term")
150);
151
152impl_sum_block!(
153    terms = (T1, T2);
154    vars = (term1, term2);
155    indices = (0, 1);
156    names = ("sum first term", "sum second term")
157);
158
159impl_sum_block!(
160    terms = (T1, T2, T3);
161    vars = (term1, term2, term3);
162    indices = (0, 1, 2);
163    names = ("sum first term", "sum second term", "sum third term")
164);
165
166impl_sum_block!(
167    terms = (T1, T2, T3, T4);
168    vars = (term1, term2, term3, term4);
169    indices = (0, 1, 2, 3);
170    names = (
171        "sum first term",
172        "sum second term",
173        "sum third term",
174        "sum fourth term"
175    )
176);
177
178impl_sum_block!(
179    terms = (T1, T2, T3, T4, T5);
180    vars = (term1, term2, term3, term4, term5);
181    indices = (0, 1, 2, 3, 4);
182    names = (
183        "sum first term",
184        "sum second term",
185        "sum third term",
186        "sum fourth term",
187        "sum fifth term"
188    )
189);
190
191impl_sum_block!(
192    terms = (T1, T2, T3, T4, T5, T6);
193    vars = (term1, term2, term3, term4, term5, term6);
194    indices = (0, 1, 2, 3, 4, 5);
195    names = (
196        "sum first term",
197        "sum second term",
198        "sum third term",
199        "sum fourth term",
200        "sum fifth term",
201        "sum sixth term"
202    )
203);
204
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{DenseDesign, ModelError, PredictorBlock};

    use super::{LinearPredictorBlock, SumBlock};

    /// One-coefficient test block whose contribution for `row` is
    /// `weights[row] * beta[0]`; used to exercise `SumBlock` without relying
    /// on a second design-matrix layout.
    struct Scaled {
        weights: Vec<f64>,
    }

    impl PredictorBlock for Scaled {
        fn nrows(&self) -> usize {
            self.weights.len()
        }

        fn nparams(&self) -> usize {
            1
        }

        fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
            self.weights[row] * beta[0]
        }

        fn add_gradient(&self, scores: &[f64], _beta: &[f64], grad: &mut [f64]) {
            grad[0] += self
                .weights
                .iter()
                .zip(scores)
                .map(|(w, s)| w * s)
                .sum::<f64>();
        }
    }

    #[test]
    fn linear_predictor_block_matches_design_matrix_operations() {
        let design = DenseDesign::from_rows(&[[1.0, 2.0], [3.0, 4.0]]);
        let block = LinearPredictorBlock::new(design);
        let beta = [10.0, 1.0];

        assert_relative_eq!(block.eta_row(1, &beta), 34.0);

        let mut grad = vec![0.0, 0.0];
        block.add_gradient(&[0.5, 2.0], &beta, &mut grad);

        assert_relative_eq!(grad[0], 6.5);
        assert_relative_eq!(grad[1], 9.0);
    }

    #[test]
    fn sum_block_splits_coefficients_between_terms_in_tuple_order() {
        let design = DenseDesign::from_rows(&[[1.0, 2.0], [3.0, 4.0]]);
        let block = SumBlock::new((
            LinearPredictorBlock::new(design),
            Scaled {
                weights: vec![5.0, 7.0],
            },
        ));
        // The first two coefficients belong to the design term, the last one
        // to `Scaled`.
        let beta = [10.0, 1.0, 2.0];

        assert_eq!(block.nrows(), 2);
        assert_eq!(block.nparams(), 3);
        // 3 * 10 + 4 * 1 + 7 * 2
        assert_relative_eq!(block.eta_row(1, &beta), 48.0);

        let mut grad = vec![0.0; 3];
        block.add_gradient(&[0.5, 2.0], &beta, &mut grad);

        assert_relative_eq!(grad[0], 6.5);
        assert_relative_eq!(grad[1], 9.0);
        // 5 * 0.5 + 7 * 2
        assert_relative_eq!(grad[2], 16.5);
    }

    #[test]
    fn sum_block_validate_accepts_matching_rows() {
        let block = SumBlock::new((
            Scaled {
                weights: vec![1.0, 2.0],
            },
            Scaled {
                weights: vec![3.0, 4.0],
            },
        ));

        assert!(block.validate().is_ok());
    }

    #[test]
    fn sum_block_validate_reports_row_mismatch() {
        let block = SumBlock::new((
            Scaled {
                weights: vec![1.0, 2.0],
            },
            Scaled {
                weights: vec![1.0, 2.0, 3.0],
            },
        ));

        assert!(matches!(
            block.validate(),
            Err(ModelError::DesignRowMismatch {
                parameter: "sum second term",
                expected_rows: 2,
                actual_rows: 3,
            })
        ));
    }
}