1use gam_linalg::faer_ndarray::{fast_xt_diag_x, fast_xt_diag_y};
18use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
19
20pub struct FixedDesignGramCache {
27 xtwx: Array2<f64>,
28 xtwy: Array1<f64>,
29 ywy: f64,
30 n: usize,
31 p: usize,
32}
33
34impl FixedDesignGramCache {
35 pub fn build(
40 x: ArrayView2<'_, f64>,
41 y: ArrayView1<'_, f64>,
42 offset: Option<ArrayView1<'_, f64>>,
43 weights: Option<ArrayView1<'_, f64>>,
44 ) -> Result<Self, String> {
45 let n = x.nrows();
46 let p = x.ncols();
47 if y.len() != n {
48 return Err(format!(
49 "y length {} must match design row count {}",
50 y.len(),
51 n
52 ));
53 }
54 if let Some(offset_values) = offset {
55 if offset_values.len() != n {
56 return Err(format!(
57 "offset length {} must match design row count {}",
58 offset_values.len(),
59 n
60 ));
61 }
62 }
63 if let Some(weight_values) = weights {
64 if weight_values.len() != n {
65 return Err(format!(
66 "weights length {} must match design row count {}",
67 weight_values.len(),
68 n
69 ));
70 }
71 validate_nonnegative_finite_weights(weight_values)?;
72 }
73 validate_finite_vector("y", y)?;
74 if let Some(offset_values) = offset {
75 validate_finite_vector("offset", offset_values)?;
76 }
77 validate_finite_matrix("x", x)?;
78
79 let r = match offset {
80 Some(offset_values) => &y.to_owned() - &offset_values.to_owned(),
81 None => y.to_owned(),
82 };
83 let w = match weights {
84 Some(weight_values) => weight_values.to_owned(),
85 None => Array1::ones(n),
86 };
87 let x_owned = x.to_owned();
88 let xtwx = fast_xt_diag_x(&x_owned, &w);
89 let r2 = r.view().insert_axis(ndarray::Axis(1));
90 let xtwy_mat = fast_xt_diag_y(&x_owned, &w, &r2);
91 let xtwy = xtwy_mat.column(0).to_owned();
92 let ywy = weighted_sum_squares(w.view(), r.view());
93
94 Ok(Self {
95 xtwx,
96 xtwy,
97 ywy,
98 n,
99 p,
100 })
101 }
102
103 pub fn n(&self) -> usize {
104 self.n
105 }
106
107 pub fn p(&self) -> usize {
108 self.p
109 }
110
111 pub fn xtwx(&self) -> ArrayView2<'_, f64> {
112 self.xtwx.view()
113 }
114
115 pub fn xtwy(&self) -> ArrayView1<'_, f64> {
116 self.xtwy.view()
117 }
118
119 pub fn ywy(&self) -> f64 {
120 self.ywy
121 }
122
123 pub fn penalized_normal_matrix(
125 &self,
126 penalty: ArrayView2<'_, f64>,
127 ) -> Result<Array2<f64>, String> {
128 if penalty.nrows() != self.p || penalty.ncols() != self.p {
129 return Err(format!(
130 "penalty shape {}x{} must match {}x{}",
131 penalty.nrows(),
132 penalty.ncols(),
133 self.p,
134 self.p
135 ));
136 }
137 let mut normal = self.xtwx.clone();
138 normal += &penalty;
139 Ok(normal)
140 }
141
142 pub fn penalized_rss(&self, beta: ArrayView1<'_, f64>) -> Result<f64, String> {
144 if beta.len() != self.p {
145 return Err(format!(
146 "beta length {} must match design column count {}",
147 beta.len(),
148 self.p
149 ));
150 }
151 let gram_beta = self.xtwx.dot(&beta);
153 let linear = beta.dot(&self.xtwy);
154 let quadratic = beta.dot(&gram_beta);
155 Ok(self.ywy - 2.0 * linear + quadratic)
156 }
157}
158
159pub struct FixedDesignRowCache {
167 x: Array2<f64>,
168 n: usize,
169 p: usize,
170}
171
172impl FixedDesignRowCache {
173 pub fn build(x: ArrayView2<'_, f64>) -> Result<Self, String> {
175 if x.nrows() == 0 || x.ncols() == 0 {
176 return Err(format!(
177 "design must be non-empty, got shape {}x{}",
178 x.nrows(),
179 x.ncols()
180 ));
181 }
182 validate_finite_matrix("x", x)?;
183 let n = x.nrows();
184 let p = x.ncols();
185 Ok(Self {
186 x: x.to_owned(),
187 n,
188 p,
189 })
190 }
191
192 pub fn n(&self) -> usize {
193 self.n
194 }
195
196 pub fn p(&self) -> usize {
197 self.p
198 }
199
200 pub fn design(&self) -> ArrayView2<'_, f64> {
201 self.x.view()
202 }
203
204 pub fn xtwx(&self, weights: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
209 self.validate_changing_weights(weights)?;
210 Ok(fast_xt_diag_x(&self.x, &weights))
211 }
212
213 pub fn xtwz(
215 &self,
216 weights: ArrayView1<'_, f64>,
217 z: ArrayView1<'_, f64>,
218 ) -> Result<Array1<f64>, String> {
219 self.validate_changing_weights(weights)?;
220 if z.len() != self.n {
221 return Err(format!(
222 "z length {} must match design row count {}",
223 z.len(),
224 self.n
225 ));
226 }
227 validate_finite_vector("z", z)?;
228 let z2 = z.insert_axis(ndarray::Axis(1));
229 let xtwz_mat = fast_xt_diag_y(&self.x, &weights, &z2);
230 Ok(xtwz_mat.column(0).to_owned())
231 }
232
233 fn validate_changing_weights(&self, weights: ArrayView1<'_, f64>) -> Result<(), String> {
234 if weights.len() != self.n {
235 return Err(format!(
236 "weights length {} must match design row count {}",
237 weights.len(),
238 self.n
239 ));
240 }
241 validate_finite_vector("weights", weights)
242 }
243}
244
245fn validate_finite_matrix(name: &str, matrix: ArrayView2<'_, f64>) -> Result<(), String> {
246 for ((row, col), value) in matrix.indexed_iter() {
247 if !(*value).is_finite() {
248 return Err(format!("{name}[{row},{col}] must be finite"));
249 }
250 }
251 Ok(())
252}
253
254fn validate_finite_vector(name: &str, vector: ArrayView1<'_, f64>) -> Result<(), String> {
255 for (index, value) in vector.iter().enumerate() {
256 if !(*value).is_finite() {
257 return Err(format!("{name}[{index}] must be finite"));
258 }
259 }
260 Ok(())
261}
262
263fn validate_nonnegative_finite_weights(weights: ArrayView1<'_, f64>) -> Result<(), String> {
264 for (index, weight) in weights.iter().enumerate() {
265 if !(*weight).is_finite() {
266 return Err(format!("weights[{index}] must be finite"));
267 }
268 if *weight < 0.0 {
269 return Err(format!("weights[{index}] must be non-negative"));
270 }
271 }
272 Ok(())
273}
274
275fn weighted_sum_squares(weights: ArrayView1<'_, f64>, values: ArrayView1<'_, f64>) -> f64 {
276 weights
277 .iter()
278 .zip(values.iter())
279 .map(|(weight, value)| *weight * *value * *value)
280 .sum()
281}
282
#[cfg(test)]
mod tests {
    use super::{FixedDesignGramCache, FixedDesignRowCache};
    use approx::assert_abs_diff_eq;
    use gam_linalg::faer_ndarray::fast_xt_diag_x;
    use ndarray::{Array1, Array2};

    // Deterministic fixtures: smooth, non-degenerate values with no randomness, so every
    // failure reproduces exactly.

    fn deterministic_design(n: usize, p: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, p), |(i, j)| {
            let row = i as f64 + 1.0;
            let col = j as f64 + 1.0;
            ((row * 0.17 + col * 0.31).sin()) + row * col * 0.002
        })
    }

    fn deterministic_response(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            (row * 0.23).cos() + row * 0.015
        })
    }

    fn deterministic_offset(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            0.2 * (row * 0.11).sin() - 0.01 * row
        })
    }

    /// Weights of at least 0.4 for non-negative `scale`, since `1 + sin(..) >= 0`.
    fn deterministic_weights(n: usize, scale: f64) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            0.4 + scale * (1.0 + (row * 0.19).sin())
        })
    }

    /// Reference `X' diag(w) X` via an explicit triple loop.
    fn naive_xtwx(x: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
        let n = x.nrows();
        let p = x.ncols();
        let mut out = Array2::zeros((p, p));
        for row in 0..n {
            for a in 0..p {
                for b in 0..p {
                    out[[a, b]] += weights[row] * x[[row, a]] * x[[row, b]];
                }
            }
        }
        out
    }

    /// Reference `X' diag(w) r` via explicit loops.
    fn naive_xtwy(x: &Array2<f64>, weights: &Array1<f64>, r: &Array1<f64>) -> Array1<f64> {
        let n = x.nrows();
        let p = x.ncols();
        let mut out = Array1::zeros(p);
        for row in 0..n {
            for col in 0..p {
                out[col] += x[[row, col]] * weights[row] * r[row];
            }
        }
        out
    }

    fn naive_xtwz(x: &Array2<f64>, weights: &Array1<f64>, z: &Array1<f64>) -> Array1<f64> {
        naive_xtwy(x, weights, z)
    }

    /// Reference `r' diag(w) r`.
    fn naive_ywy(weights: &Array1<f64>, r: &Array1<f64>) -> f64 {
        let mut sum = 0.0;
        for row in 0..weights.len() {
            sum += weights[row] * r[row] * r[row];
        }
        sum
    }

    fn assert_matrix_close(actual: ndarray::ArrayView2<'_, f64>, expected: &Array2<f64>, eps: f64) {
        assert_eq!(actual.nrows(), expected.nrows());
        assert_eq!(actual.ncols(), expected.ncols());
        for row in 0..expected.nrows() {
            for col in 0..expected.ncols() {
                assert_abs_diff_eq!(actual[[row, col]], expected[[row, col]], epsilon = eps);
            }
        }
    }

    fn assert_vector_close(actual: ndarray::ArrayView1<'_, f64>, expected: &Array1<f64>, eps: f64) {
        assert_eq!(actual.len(), expected.len());
        for index in 0..expected.len() {
            assert_abs_diff_eq!(actual[index], expected[index], epsilon = eps);
        }
    }

    #[test]
    fn gaussian_xtwx_matches_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let cache = FixedDesignGramCache::build(x.view(), y.view(), None, None).unwrap();
        let unit_weights: Array1<f64> = Array1::ones(n);
        let naive = naive_xtwx(&x, &unit_weights);
        assert_matrix_close(cache.xtwx(), &naive, 1.0e-9);
    }

    #[test]
    fn weighted_xtwx_matches_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let offset = deterministic_offset(n);
        let weights = deterministic_weights(n, 0.35);
        let cache = FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(offset.view()),
            Some(weights.view()),
        )
        .unwrap();
        let naive = naive_xtwx(&x, &weights);
        assert_matrix_close(cache.xtwx(), &naive, 1.0e-9);
        assert_eq!(cache.n(), n);
        assert_eq!(cache.p(), p);
    }

    #[test]
    fn unweighted_xtwy_and_ywy_use_raw_response() {
        // No offset and no weights: r = y and W = I.
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let cache = FixedDesignGramCache::build(x.view(), y.view(), None, None).unwrap();
        let unit_weights: Array1<f64> = Array1::ones(n);
        let expected_xtwy = naive_xtwy(&x, &unit_weights, &y);
        let expected_ywy = naive_ywy(&unit_weights, &y);
        assert_vector_close(cache.xtwy(), &expected_xtwy, 1.0e-9);
        assert_abs_diff_eq!(cache.ywy(), expected_ywy, epsilon = 1.0e-9);
    }

    #[test]
    fn gaussian_xtwy_and_ywy_match_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let offset = deterministic_offset(n);
        let weights = deterministic_weights(n, 0.35);
        let r = &y - &offset;
        let cache = FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(offset.view()),
            Some(weights.view()),
        )
        .unwrap();
        let expected_xtwy = naive_xtwy(&x, &weights, &r);
        let expected_ywy = naive_ywy(&weights, &r);
        assert_vector_close(cache.xtwy(), &expected_xtwy, 1.0e-9);
        assert_abs_diff_eq!(cache.ywy(), expected_ywy, epsilon = 1.0e-9);
    }

    #[test]
    fn penalized_rss_matches_direct_residual() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let offset = deterministic_offset(n);
        let weights = deterministic_weights(n, 0.21);
        let beta = Array1::from_vec(vec![0.4, -0.2, 0.15, 0.05]);
        let r = &y - &offset;
        let cache = FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(offset.view()),
            Some(weights.view()),
        )
        .unwrap();
        let mut direct = 0.0;
        for row in 0..n {
            let mut fit = 0.0;
            for col in 0..p {
                fit += x[[row, col]] * beta[col];
            }
            let residual = r[row] - fit;
            direct += weights[row] * residual * residual;
        }
        let cached = cache.penalized_rss(beta.view()).unwrap();
        assert_abs_diff_eq!(cached, direct, epsilon = 1.0e-8);
    }

    #[test]
    fn penalized_normal_matrix_adds_penalty() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let cache = FixedDesignGramCache::build(x.view(), y.view(), None, None).unwrap();
        let penalty = Array2::from_shape_fn((p, p), |(row, col)| {
            if row == col {
                0.5 + row as f64 * 0.1
            } else {
                0.02 * (row + col) as f64
            }
        });
        let normal = cache.penalized_normal_matrix(penalty.view()).unwrap();
        for row in 0..p {
            for col in 0..p {
                let expected = cache.xtwx()[[row, col]] + penalty[[row, col]];
                assert_abs_diff_eq!(normal[[row, col]], expected, epsilon = 1.0e-12);
            }
        }
    }

    #[test]
    fn penalized_queries_reject_wrong_shapes() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);
        let cache = FixedDesignGramCache::build(x.view(), y.view(), None, None).unwrap();

        let wrong_penalty: Array2<f64> = Array2::eye(p + 1);
        assert!(cache.penalized_normal_matrix(wrong_penalty.view()).is_err());

        let short_beta: Array1<f64> = Array1::zeros(p - 1);
        assert!(cache.penalized_rss(short_beta.view()).is_err());
    }

    #[test]
    fn row_cache_xtwx_matches_fresh_build_across_weights() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let cache = FixedDesignRowCache::build(x.view()).unwrap();
        let weight_sets = [
            deterministic_weights(n, 0.12),
            deterministic_weights(n, 0.27),
            deterministic_weights(n, 0.41),
        ];
        for weights in weight_sets.iter() {
            let cached = cache.xtwx(weights.view()).unwrap();
            let fresh = fast_xt_diag_x(&x, weights);
            assert_matrix_close(cached.view(), &fresh, 1.0e-12);
        }
    }

    #[test]
    fn row_cache_xtwz_matches_naive() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let weights = deterministic_weights(n, 0.33);
        let z = Array1::from_shape_fn(n, |i| {
            let row = i as f64 + 1.0;
            (row * 0.07).sin() + 0.03 * row
        });
        let cache = FixedDesignRowCache::build(x.view()).unwrap();
        let cached = cache.xtwz(weights.view(), z.view()).unwrap();
        let expected = naive_xtwz(&x, &weights, &z);
        assert_vector_close(cached.view(), &expected, 1.0e-9);
    }

    #[test]
    fn row_cache_rejects_empty_design() {
        let no_rows: Array2<f64> = Array2::zeros((0, 3));
        assert!(FixedDesignRowCache::build(no_rows.view()).is_err());
        let no_cols: Array2<f64> = Array2::zeros((5, 0));
        assert!(FixedDesignRowCache::build(no_cols.view()).is_err());
    }

    #[test]
    fn row_cache_rejects_mismatched_or_non_finite_inputs() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let cache = FixedDesignRowCache::build(x.view()).unwrap();

        let short_weights = deterministic_weights(n - 1, 0.2);
        assert!(cache.xtwx(short_weights.view()).is_err());

        let weights = deterministic_weights(n, 0.2);
        let short_z = deterministic_response(n - 1);
        assert!(cache.xtwz(weights.view(), short_z.view()).is_err());

        let mut bad_z = deterministic_response(n);
        bad_z[5] = f64::INFINITY;
        assert!(cache.xtwz(weights.view(), bad_z.view()).is_err());

        let mut bad_weights = weights.clone();
        bad_weights[0] = f64::NAN;
        assert!(cache.xtwx(bad_weights.view()).is_err());
    }

    #[test]
    fn build_rejects_shape_mismatch() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);

        let mismatched_y = deterministic_response(n - 1);
        assert!(FixedDesignGramCache::build(x.view(), mismatched_y.view(), None, None).is_err());

        let mismatched_offset = deterministic_offset(n + 1);
        assert!(FixedDesignGramCache::build(
            x.view(),
            y.view(),
            Some(mismatched_offset.view()),
            None
        )
        .is_err());

        let mismatched_weights = deterministic_weights(n + 1, 0.2);
        assert!(FixedDesignGramCache::build(
            x.view(),
            y.view(),
            None,
            Some(mismatched_weights.view())
        )
        .is_err());
    }

    #[test]
    fn build_rejects_invalid_values() {
        let n = 40;
        let p = 4;
        let x = deterministic_design(n, p);
        let y = deterministic_response(n);

        let mut nan_weights = deterministic_weights(n, 0.2);
        nan_weights[3] = f64::NAN;
        assert!(
            FixedDesignGramCache::build(x.view(), y.view(), None, Some(nan_weights.view()))
                .is_err()
        );

        let mut negative_weights = deterministic_weights(n, 0.2);
        negative_weights[7] = -0.5;
        assert!(FixedDesignGramCache::build(
            x.view(),
            y.view(),
            None,
            Some(negative_weights.view())
        )
        .is_err());

        let mut bad_y = y.clone();
        bad_y[0] = f64::INFINITY;
        assert!(FixedDesignGramCache::build(x.view(), bad_y.view(), None, None).is_err());

        let mut bad_offset = deterministic_offset(n);
        bad_offset[1] = f64::NAN;
        assert!(
            FixedDesignGramCache::build(x.view(), y.view(), Some(bad_offset.view()), None)
                .is_err()
        );

        let mut bad_x = x.clone();
        bad_x[[2, 1]] = f64::NEG_INFINITY;
        assert!(FixedDesignGramCache::build(bad_x.view(), y.view(), None, None).is_err());
    }
}