1use crate::{DesignMatrix, ModelError};
2
3pub trait PredictorBlock {
9 fn nrows(&self) -> usize;
11 fn nparams(&self) -> usize;
13 fn eta_row(&self, row: usize, beta: &[f64]) -> f64;
15 fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]);
17
18 fn validate(&self) -> Result<(), ModelError> {
20 Ok(())
21 }
22}
23
/// Predictor block whose contribution is a design matrix applied to its coefficients.
///
/// The design is held in the public field `x`; the `PredictorBlock` implementation
/// requires `X: DesignMatrix`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct LinearPredictorBlock<X> {
    /// Design matrix; its columns correspond one-to-one with this block's coefficients.
    pub x: X,
}
33
34impl<X> LinearPredictorBlock<X> {
35 pub fn new(x: X) -> Self {
37 Self { x }
38 }
39
40 pub fn into_inner(self) -> X {
42 self.x
43 }
44}
45
46impl<X> PredictorBlock for LinearPredictorBlock<X>
47where
48 X: DesignMatrix,
49{
50 fn nrows(&self) -> usize {
51 self.x.nrows()
52 }
53
54 fn nparams(&self) -> usize {
55 self.x.ncols()
56 }
57
58 fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
59 self.x.dot_row(row, beta)
60 }
61
62 fn add_gradient(&self, scores: &[f64], _: &[f64], grad: &mut [f64]) {
63 self.x.add_t_mul_vec(scores, grad);
64 }
65}
66
/// Predictor block formed by adding the contributions of several terms.
///
/// `terms` is a tuple of blocks. The `impl_sum_block!` invocations in this file implement
/// `PredictorBlock` for tuples of one to six terms, with the coefficients of the terms
/// laid out one after another.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct SumBlock<Terms> {
    /// The summed terms, in coefficient order.
    pub terms: Terms,
}
76
77impl<Terms> SumBlock<Terms> {
78 pub fn new(terms: Terms) -> Self {
80 Self { terms }
81 }
82}
83
/// Implements `PredictorBlock` for `SumBlock` over a tuple of one fixed arity.
///
/// * `terms` — generic type parameters, one per tuple element.
/// * `vars` — local binding names, one per term, used inside the generated methods.
/// * `indices` — tuple field indices (`0`, `1`, ...), one per term.
/// * `names` — labels reported as `parameter` in `ModelError::DesignRowMismatch`.
///
/// All four lists must have the same length. Coefficients are laid out term after term:
/// the first term's `nparams()` entries of `beta` / `grad`, then the second term's, and
/// so on.
macro_rules! impl_sum_block {
    (
        terms = ($($term:ident),+);
        vars = ($($var:ident),+);
        indices = ($($idx:tt),+);
        names = ($($name:literal),+)
    ) => {
        impl<$($term,)+> PredictorBlock for SumBlock<($($term,)+)>
        where
            $($term: PredictorBlock,)+
        {
            fn nrows(&self) -> usize {
                // The first term stands in for all of them; `validate` checks they agree.
                self.terms.0.nrows()
            }

            fn nparams(&self) -> usize {
                let per_term = [$(self.terms.$idx.nparams()),+];
                per_term.iter().sum()
            }

            fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
                let mut offset = 0;
                let mut total = 0.0;
                $(
                    let $var = &self.terms.$idx;
                    let width = $var.nparams();
                    total += $var.eta_row(row, &beta[offset..offset + width]);
                    offset += width;
                )+
                // Read the final offset so the last `+=` is not flagged as an unused
                // assignment.
                let _ = offset;
                total
            }

            fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]) {
                let mut offset = 0;
                $(
                    let $var = &self.terms.$idx;
                    // This term's coefficients and gradient entries share the same span.
                    let span = offset..offset + $var.nparams();
                    offset = span.end;
                    $var.add_gradient(scores, &beta[span.clone()], &mut grad[span]);
                )+
                let _ = offset;
            }

            fn validate(&self) -> Result<(), ModelError> {
                let expected_rows = self.terms.0.nrows();
                $(
                    let $var = &self.terms.$idx;
                    $var.validate()?;
                    let actual_rows = $var.nrows();
                    if actual_rows != expected_rows {
                        return Err(ModelError::DesignRowMismatch {
                            parameter: $name,
                            expected_rows,
                            actual_rows,
                        });
                    }
                )+
                Ok(())
            }
        }
    };
}
144
145impl_sum_block!(
146 terms = (T1);
147 vars = (term1);
148 indices = (0);
149 names = ("sum term")
150);
151
152impl_sum_block!(
153 terms = (T1, T2);
154 vars = (term1, term2);
155 indices = (0, 1);
156 names = ("sum first term", "sum second term")
157);
158
159impl_sum_block!(
160 terms = (T1, T2, T3);
161 vars = (term1, term2, term3);
162 indices = (0, 1, 2);
163 names = ("sum first term", "sum second term", "sum third term")
164);
165
166impl_sum_block!(
167 terms = (T1, T2, T3, T4);
168 vars = (term1, term2, term3, term4);
169 indices = (0, 1, 2, 3);
170 names = (
171 "sum first term",
172 "sum second term",
173 "sum third term",
174 "sum fourth term"
175 )
176);
177
178impl_sum_block!(
179 terms = (T1, T2, T3, T4, T5);
180 vars = (term1, term2, term3, term4, term5);
181 indices = (0, 1, 2, 3, 4);
182 names = (
183 "sum first term",
184 "sum second term",
185 "sum third term",
186 "sum fourth term",
187 "sum fifth term"
188 )
189);
190
191impl_sum_block!(
192 terms = (T1, T2, T3, T4, T5, T6);
193 vars = (term1, term2, term3, term4, term5, term6);
194 indices = (0, 1, 2, 3, 4, 5);
195 names = (
196 "sum first term",
197 "sum second term",
198 "sum third term",
199 "sum fourth term",
200 "sum fifth term",
201 "sum sixth term"
202 )
203);
204
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{DenseDesign, ModelError, PredictorBlock};

    use super::{LinearPredictorBlock, SumBlock};

    #[test]
    fn linear_predictor_block_matches_design_matrix_operations() {
        let design = DenseDesign::from_rows(&[[1.0, 2.0], [3.0, 4.0]]);
        let block = LinearPredictorBlock::new(design);
        let beta = [10.0, 1.0];

        assert_relative_eq!(block.eta_row(1, &beta), 34.0);

        let mut grad = vec![0.0, 0.0];
        block.add_gradient(&[0.5, 2.0], &beta, &mut grad);

        assert_relative_eq!(grad[0], 6.5);
        assert_relative_eq!(grad[1], 9.0);
    }

    /// A two-term sum must route `beta[0..2]` to the first term and `beta[2..4]` to the
    /// second, and write each term's gradient into the matching slice of `grad`.
    #[test]
    fn sum_block_splits_coefficients_between_terms() {
        let first = LinearPredictorBlock::new(DenseDesign::from_rows(&[[1.0, 2.0], [3.0, 4.0]]));
        let second = LinearPredictorBlock::new(DenseDesign::from_rows(&[[5.0, 6.0], [7.0, 8.0]]));
        let block = SumBlock::new((first, second));
        let beta = [10.0, 1.0, 2.0, -1.0];

        assert_eq!(block.nrows(), 2);
        assert_eq!(block.nparams(), 4);
        assert!(block.validate().is_ok());

        // Row 0: (1*10 + 2*1) + (5*2 + 6*-1) = 12 + 4.
        assert_relative_eq!(block.eta_row(0, &beta), 16.0);
        // Row 1: (3*10 + 4*1) + (7*2 + 8*-1) = 34 + 6.
        assert_relative_eq!(block.eta_row(1, &beta), 40.0);

        let mut grad = vec![0.0; 4];
        block.add_gradient(&[0.5, 2.0], &beta, &mut grad);
        let expected = [6.5, 9.0, 16.5, 19.0];
        assert_eq!(grad.len(), expected.len());
        for (got, want) in grad.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want);
        }
    }

    /// `validate` must name the offending term when its row count differs from the first.
    #[test]
    fn sum_block_validate_reports_row_mismatch() {
        let first = LinearPredictorBlock::new(DenseDesign::from_rows(&[[1.0, 2.0], [3.0, 4.0]]));
        let second = LinearPredictorBlock::new(DenseDesign::from_rows(&[
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
        ]));
        let block = SumBlock::new((first, second));

        assert!(matches!(
            block.validate(),
            Err(ModelError::DesignRowMismatch {
                parameter: "sum second term",
                expected_rows: 2,
                actual_rows: 3,
                ..
            })
        ));
    }
}