gam_terms/analytic_penalties/
nested_prefix.rs1use super::*;
2
/// Smoothed-L1 penalty applied over a nested family of prefixes of the latent axes.
///
/// With `target` laid out row-major as `n_rows × F` and prefix sizes `m_1 < … < m_K`, the
/// penalty is `Σ_k λ_k · Σ_n Σ_{i < m_k} sqrt(x_{n,i}² + ε²)`. Axes with a small index belong to
/// more prefixes, so they are charged by more shells than axes near the end.
#[derive(Debug, Clone)]
pub struct NestedPrefixPenalty {
    /// Parameter slice the penalty acts on. When its `latent_dim` is `Some(F)` that value fixes
    /// the row width `F`; otherwise the largest prefix size is used as `F`.
    pub target: PsiSlice,
    /// Tier reported back through `AnalyticPenalty::tier`.
    pub target_tier: PenaltyTier,
    /// Prefix lengths `m_k`: non-zero and strictly increasing (validated by `new`), and no larger
    /// than `target.latent_dim` when that is set.
    pub prefix_sizes: Vec<usize>,
    /// Base weight for each prefix shell (finite, ≥ 0), one per entry of `prefix_sizes`.
    /// Combined with `rho[rho_indices[k]]` through `resolve_learnable_weight`; rescaled in place
    /// when a weight schedule is applied.
    pub shell_weights: Vec<f64>,
    /// Smoothing constant ε > 0 in `sqrt(x² + ε²)`; keeps the derivative finite at `x = 0`.
    pub eps: f64,
    /// For shell `k`, the index into the `rho` vector that drives its learnable weight.
    /// `new` initialises this to `0..K`.
    pub rho_indices: Vec<usize>,
    /// Optional schedule used to rescale `shell_weights` across iterations; `None` unless set
    /// with `with_weight_schedule`.
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
51
52impl NestedPrefixPenalty {
53 #[must_use = "build error must be handled"]
63 pub fn new(
64 target: PsiSlice,
65 target_tier: PenaltyTier,
66 prefix_sizes: Vec<usize>,
67 shell_weights: Vec<f64>,
68 eps: f64,
69 ) -> Result<Self, String> {
70 if prefix_sizes.is_empty() {
71 return Err("NestedPrefixPenalty requires at least one prefix".into());
72 }
73 if shell_weights.len() != prefix_sizes.len() {
74 return Err(format!(
75 "NestedPrefixPenalty requires shell_weights.len() == prefix_sizes.len(); \
76 got {} weights for {} prefixes",
77 shell_weights.len(),
78 prefix_sizes.len()
79 ));
80 }
81 for w in &shell_weights {
82 if !w.is_finite() || *w < 0.0 {
83 return Err(format!(
84 "NestedPrefixPenalty shell weights must be finite and ≥ 0; got {w}"
85 ));
86 }
87 }
88 for i in 0..prefix_sizes.len() {
89 if prefix_sizes[i] == 0 {
90 return Err("NestedPrefixPenalty prefixes must be > 0".into());
91 }
92 if i > 0 && prefix_sizes[i] <= prefix_sizes[i - 1] {
93 return Err(format!(
94 "NestedPrefixPenalty prefixes must be strictly increasing; got {:?}",
95 prefix_sizes
96 ));
97 }
98 }
99 if let Some(d) = target.latent_dim {
100 let max_prefix = *prefix_sizes.last().expect("non-empty");
101 if max_prefix > d {
102 return Err(format!(
103 "NestedPrefixPenalty largest prefix {max_prefix} exceeds latent_dim {d}"
104 ));
105 }
106 }
107 if !(eps.is_finite() && eps > 0.0) {
108 return Err(format!(
109 "NestedPrefixPenalty requires eps > 0 (1/sqrt(x²+ε²) singularity at 0); got {eps}"
110 ));
111 }
112 let rho_indices = (0..prefix_sizes.len()).collect();
113 Ok(Self {
114 target,
115 target_tier,
116 prefix_sizes,
117 shell_weights,
118 eps,
119 rho_indices,
120 weight_schedule: None,
121 })
122 }
123
124 #[must_use]
127 pub fn with_weight_schedule(mut self, schedule: ScalarWeightSchedule) -> Self {
128 self.weight_schedule = Some(schedule);
129 self
130 }
131
132 fn latent_dim(&self) -> usize {
134 self.target
135 .latent_dim
136 .unwrap_or_else(|| *self.prefix_sizes.last().expect("non-empty"))
137 }
138
139 fn lambdas(&self, rho: ArrayView1<'_, f64>) -> Vec<f64> {
141 self.prefix_sizes
142 .iter()
143 .enumerate()
144 .map(|(k, _)| resolve_learnable_weight(self.shell_weights[k], rho[self.rho_indices[k]]))
145 .collect()
146 }
147
148 fn per_axis_weights(&self, lambdas: &[f64]) -> Vec<f64> {
151 let f = self.latent_dim();
152 let mut w = vec![0.0_f64; f];
153 for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
157 let lam = lambdas[k];
158 if lam == 0.0 {
159 continue;
160 }
161 let end = m_k.min(f);
162 for entry in w.iter_mut().take(end) {
163 *entry += lam;
164 }
165 }
166 w
167 }
168}
169
170impl AnalyticPenalty for NestedPrefixPenalty {
171 fn tier(&self) -> PenaltyTier {
172 self.target_tier
173 }
174
175 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
176 let f = self.latent_dim();
177 assert!(
178 target.len().is_multiple_of(f),
179 "target length must be n_rows · F"
180 );
181 let n_rows = target.len() / f;
182 let lambdas = self.lambdas(rho);
183 let eps2 = self.eps * self.eps;
184 let mut s_axis = vec![0.0_f64; f];
186 for n in 0..n_rows {
187 let row = &target.as_slice().expect("contiguous")[n * f..(n + 1) * f];
188 for (i, &x) in row.iter().enumerate() {
189 s_axis[i] += (x * x + eps2).sqrt();
190 }
191 }
192 let mut total = 0.0;
194 for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
195 let end = m_k.min(f);
196 let mut acc = 0.0;
197 for &v in s_axis.iter().take(end) {
198 acc += v;
199 }
200 total += lambdas[k] * acc;
201 }
202 total
203 }
204
205 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
206 let f = self.latent_dim();
207 let n_rows = target.len() / f;
208 let lambdas = self.lambdas(rho);
209 let w_per_axis = self.per_axis_weights(&lambdas);
210 let eps2 = self.eps * self.eps;
211 let src = target.as_slice().expect("contiguous");
212 let mut g = Array1::<f64>::zeros(target.len());
213 let g_slice = g.as_slice_mut().expect("contiguous");
214 for n in 0..n_rows {
215 for i in 0..f {
216 let x = src[n * f + i];
217 let w = w_per_axis[i];
218 if w == 0.0 {
219 continue;
220 }
221 g_slice[n * f + i] = w * x / (x * x + eps2).sqrt();
222 }
223 }
224 g
225 }
226
227 fn hessian_diag(
228 &self,
229 target: ArrayView1<'_, f64>,
230 rho: ArrayView1<'_, f64>,
231 ) -> Option<Array1<f64>> {
232 let f = self.latent_dim();
233 let n_rows = target.len() / f;
234 let lambdas = self.lambdas(rho);
235 let w_per_axis = self.per_axis_weights(&lambdas);
236 let eps2 = self.eps * self.eps;
237 let src = target.as_slice().expect("contiguous");
238 let mut d = Array1::<f64>::zeros(target.len());
239 let d_slice = d.as_slice_mut().expect("contiguous");
240 for n in 0..n_rows {
241 for i in 0..f {
242 let w = w_per_axis[i];
243 if w == 0.0 {
244 continue;
245 }
246 let x = src[n * f + i];
247 let r = (x * x + eps2).sqrt();
248 d_slice[n * f + i] = w * eps2 / (r * r * r);
249 }
250 }
251 Some(d)
252 }
253
254 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
255 let f = self.latent_dim();
256 let n_rows = target.len() / f;
257 let lambdas = self.lambdas(rho);
258 let eps2 = self.eps * self.eps;
259 let mut s_axis = vec![0.0_f64; f];
262 let src = target.as_slice().expect("contiguous");
263 for n in 0..n_rows {
264 for i in 0..f {
265 let x = src[n * f + i];
266 s_axis[i] += (x * x + eps2).sqrt();
267 }
268 }
269 let n_rho = self.rho_count();
270 let mut out = Array1::<f64>::zeros(n_rho);
271 for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
272 let end = m_k.min(f);
273 let mut shell_sum = 0.0;
274 for &v in s_axis.iter().take(end) {
275 shell_sum += v;
276 }
277 out[self.rho_indices[k]] = lambdas[k] * shell_sum;
279 }
280 out
281 }
282
283 fn rho_count(&self) -> usize {
284 self.prefix_sizes.len()
285 }
286
287 fn name(&self) -> &str {
288 "nested_prefix"
289 }
290
291 fn apply_schedule(&mut self, iter: usize) {
292 if let Some(schedule) = self.weight_schedule.as_mut() {
293 let prev = schedule.current_weight(schedule.iter_count);
294 let next = schedule.current_weight(iter);
295 if prev > 0.0 {
296 let ratio = next / prev;
297 for w in &mut self.shell_weights {
298 *w *= ratio;
299 }
300 }
301 schedule.iter_count = iter + 1;
302 }
303 }
304}