1use std::ops::Range;
2
3use gamlss_core::{DenseDesign, DesignMatrix, Link, ModelError, PredictorBlock, Softplus};
4use gamlss_spline::{ISplineBasis, MonotoneDirection};
5
6#[derive(Debug, Clone, PartialEq)]
9pub struct FormulaPredictorBlock {
10 dense: DenseDesign,
11 offset: Option<Vec<f64>>,
12 monotone: Vec<MonotoneSegment>,
13 nparams: usize,
14}
15
16#[derive(Debug, Clone, PartialEq)]
17pub(crate) struct MonotoneSegment {
18 pub(crate) range: Range<usize>,
19 pub(crate) values: Vec<f64>,
20 pub(crate) basis: ISplineBasis,
21 pub(crate) direction: MonotoneDirection,
22}
23
24impl FormulaPredictorBlock {
25 #[must_use]
27 pub(crate) fn new(
28 dense: DenseDesign,
29 offset: Option<Vec<f64>>,
30 monotone: Vec<MonotoneSegment>,
31 nparams: usize,
32 ) -> Self {
33 Self {
34 dense,
35 offset,
36 monotone,
37 nparams,
38 }
39 }
40
41 #[must_use]
43 pub fn dense(&self) -> &DenseDesign {
44 &self.dense
45 }
46
47 fn monotone_eta(segment: &MonotoneSegment, row: usize, beta: &[f64]) -> f64 {
48 debug_assert_eq!(beta.len(), segment.range.len());
49
50 let sign = monotone_sign(segment.direction);
51 beta[0]
52 + segment
53 .basis
54 .evaluate(segment.values[row])
55 .iter()
56 .zip(&beta[1..])
57 .map(|(basis, beta)| sign * Softplus::inverse(*beta) * basis)
58 .sum::<f64>()
59 }
60
61 fn add_monotone_gradient(
62 segment: &MonotoneSegment,
63 scores: &[f64],
64 multiplier: Option<&[f64]>,
65 beta: &[f64],
66 grad: &mut [f64],
67 ) {
68 debug_assert_eq!(beta.len(), segment.range.len());
69 debug_assert_eq!(grad.len(), segment.range.len());
70
71 let sign = monotone_sign(segment.direction);
72 for (row, score) in scores.iter().copied().enumerate() {
73 let score = multiplier.map_or(score, |multiplier| score * multiplier[row]);
74 grad[0] += score;
75 for (index, basis) in segment
76 .basis
77 .evaluate(segment.values[row])
78 .iter()
79 .enumerate()
80 {
81 grad[index + 1] +=
82 score * sign * basis * Softplus::derivative_inverse(beta[index + 1]);
83 }
84 }
85 }
86}
87
88impl PredictorBlock for FormulaPredictorBlock {
89 fn nrows(&self) -> usize {
90 self.dense.nrows()
91 }
92
93 fn nparams(&self) -> usize {
94 self.nparams
95 }
96
97 fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
98 debug_assert_eq!(beta.len(), self.nparams);
99
100 let dense_ncols = self.dense.ncols();
101 let mut eta = self.dense.dot_row(row, &beta[..dense_ncols]);
102 if let Some(offset) = &self.offset {
103 eta += offset[row];
104 }
105 for segment in &self.monotone {
106 eta += Self::monotone_eta(segment, row, &beta[segment.range.clone()]);
107 }
108 eta
109 }
110
111 fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]) {
112 debug_assert_eq!(scores.len(), self.nrows());
113 debug_assert_eq!(beta.len(), self.nparams);
114 debug_assert_eq!(grad.len(), self.nparams);
115
116 let dense_ncols = self.dense.ncols();
117 self.dense.add_t_mul_vec(scores, &mut grad[..dense_ncols]);
118 for segment in &self.monotone {
119 Self::add_monotone_gradient(
120 segment,
121 scores,
122 None,
123 &beta[segment.range.clone()],
124 &mut grad[segment.range.clone()],
125 );
126 }
127 }
128
129 fn add_weighted_gradient(
130 &self,
131 scores: &[f64],
132 multiplier: &[f64],
133 beta: &[f64],
134 grad: &mut [f64],
135 ) {
136 debug_assert_eq!(scores.len(), self.nrows());
137 debug_assert_eq!(multiplier.len(), self.nrows());
138 debug_assert_eq!(beta.len(), self.nparams);
139 debug_assert_eq!(grad.len(), self.nparams);
140
141 let dense_ncols = self.dense.ncols();
142 self.dense
143 .add_weighted_t_mul_vec(scores, multiplier, &mut grad[..dense_ncols]);
144 for segment in &self.monotone {
145 Self::add_monotone_gradient(
146 segment,
147 scores,
148 Some(multiplier),
149 &beta[segment.range.clone()],
150 &mut grad[segment.range.clone()],
151 );
152 }
153 }
154
155 fn validate(&self) -> Result<(), ModelError> {
156 let nrows = self.nrows();
157 let dense_ncols = self.dense.ncols();
158 if dense_ncols > self.nparams {
159 return Err(ModelError::InvalidParameter {
160 parameter: "formula predictor",
161 expected: "dense columns <= local parameter count",
162 });
163 }
164
165 if let Some(offset) = &self.offset
166 && offset.len() != nrows
167 {
168 return Err(ModelError::DesignRowMismatch {
169 parameter: "formula offset",
170 expected_rows: nrows,
171 actual_rows: offset.len(),
172 });
173 }
174
175 for segment in &self.monotone {
176 if segment.range.end > self.nparams {
177 return Err(ModelError::BlockRangeOverflow {
178 parameter: "formula monotone",
179 offset: segment.range.start,
180 len: segment.range.len(),
181 });
182 }
183 if segment.range.start < dense_ncols {
184 return Err(ModelError::BlockOverlap {
185 first: "formula dense",
186 second: "formula monotone",
187 });
188 }
189 if segment.values.len() != nrows {
190 return Err(ModelError::DesignRowMismatch {
191 parameter: "formula monotone",
192 expected_rows: nrows,
193 actual_rows: segment.values.len(),
194 });
195 }
196 }
197
198 Ok(())
199 }
200}
201
202fn monotone_sign(direction: MonotoneDirection) -> f64 {
203 match direction {
204 MonotoneDirection::Increasing => 1.0,
205 MonotoneDirection::Decreasing => -1.0,
206 }
207}