Skip to main content

fdars_core/
function_on_scalar_2d.rs

1//! 2D Function-on-Scalar Regression (FOSR).
2//!
3//! Extends the 1D function-on-scalar model to surface-valued functional
4//! responses observed on a regular 2D grid:
5//! ```text
6//! Y_i(s,t) = beta_0(s,t) + sum_j z_{ij} beta_j(s,t) + epsilon_i(s,t)
7//! ```
8//!
9//! The estimation uses a two-step approach:
10//! 1. Pointwise OLS at each grid point to obtain raw coefficient surfaces.
11//! 2. Tensor-product roughness penalty smoothing of each coefficient surface.
12//!
13//! # Methods
14//!
15//! - [`fosr_2d`]: Penalized 2D FOSR with anisotropic smoothing
16//! - [`predict_fosr_2d`]: Predict new surfaces from a fitted model
17
18use crate::error::FdarError;
19use crate::function_on_scalar::{
20    build_fosr_design, compute_xtx, compute_xty_matrix, penalty_matrix, pointwise_r_squared,
21};
22use crate::matrix::FdMatrix;
23
24// ---------------------------------------------------------------------------
25// Types
26// ---------------------------------------------------------------------------
27
/// Rectangular 2D evaluation grid for surface-valued functional data.
///
/// The grid is the Cartesian product of the `s`-axis coordinates
/// (`argvals_s`) and the `t`-axis coordinates (`argvals_t`).
#[derive(Debug, Clone, PartialEq)]
pub struct Grid2d {
    /// Coordinates along the first (row, `s`) axis.
    pub argvals_s: Vec<f64>,
    /// Coordinates along the second (column, `t`) axis.
    pub argvals_t: Vec<f64>,
}

impl Grid2d {
    /// Build a grid from the `s`-axis and `t`-axis coordinate vectors.
    ///
    /// The vectors are stored as given; no validation is performed.
    pub fn new(argvals_s: Vec<f64>, argvals_t: Vec<f64>) -> Self {
        Grid2d {
            argvals_s,
            argvals_t,
        }
    }

    /// Number of points on the `s` axis.
    #[inline]
    pub fn m1(&self) -> usize {
        self.argvals_s.len()
    }

    /// Number of points on the `t` axis.
    #[inline]
    pub fn m2(&self) -> usize {
        self.argvals_t.len()
    }

    /// Total number of grid points, i.e. `m1 * m2`.
    #[inline]
    pub fn total(&self) -> usize {
        self.argvals_s.len() * self.argvals_t.len()
    }
}
64
65/// Result of 2D function-on-scalar regression.
66#[derive(Debug, Clone, PartialEq)]
67pub struct FosrResult2d {
68    /// Intercept surface beta_0(s,t), flattened column-major (length m1*m2).
69    pub intercept: Vec<f64>,
70    /// Coefficient surfaces beta_j(s,t), p x (m1*m2) matrix (row j = flattened beta_j).
71    pub beta: FdMatrix,
72    /// Fitted surface values, n x (m1*m2) matrix.
73    pub fitted: FdMatrix,
74    /// Residual surfaces, n x (m1*m2) matrix.
75    pub residuals: FdMatrix,
76    /// Pointwise R^2(s,t), flattened (length m1*m2).
77    pub r_squared_pointwise: Vec<f64>,
78    /// Global R^2 (average of pointwise).
79    pub r_squared: f64,
80    /// Standard error surfaces for each beta_j (p x (m1*m2)), or None if
81    /// the system is underdetermined.
82    pub beta_se: Option<FdMatrix>,
83    /// Smoothing parameter in the s-direction.
84    pub lambda_s: f64,
85    /// Smoothing parameter in the t-direction.
86    pub lambda_t: f64,
87    /// Generalised cross-validation score.
88    pub gcv: f64,
89    /// Grid specification.
90    pub grid: Grid2d,
91}
92
93impl FosrResult2d {
94    /// Reshape the j-th coefficient surface into an m1 x m2 matrix.
95    ///
96    /// # Panics
97    /// Panics if `j >= p` (number of predictors).
98    #[must_use]
99    pub fn beta_surface(&self, j: usize) -> FdMatrix {
100        let m1 = self.grid.m1();
101        let m2 = self.grid.m2();
102        let m_total = m1 * m2;
103        let mut mat = FdMatrix::zeros(m1, m2);
104        for g in 0..m_total {
105            // Column-major: g = s_idx + t_idx * m1
106            let s_idx = g % m1;
107            let t_idx = g / m1;
108            mat[(s_idx, t_idx)] = self.beta[(j, g)];
109        }
110        mat
111    }
112
113    /// Reshape pointwise R^2(s,t) into an m1 x m2 matrix.
114    #[must_use]
115    pub fn r_squared_surface(&self) -> FdMatrix {
116        let m1 = self.grid.m1();
117        let m2 = self.grid.m2();
118        let m_total = m1 * m2;
119        let mut mat = FdMatrix::zeros(m1, m2);
120        for g in 0..m_total {
121            let s_idx = g % m1;
122            let t_idx = g / m1;
123            mat[(s_idx, t_idx)] = self.r_squared_pointwise[g];
124        }
125        mat
126    }
127
128    /// Reshape the residual for observation `i` into an m1 x m2 matrix.
129    ///
130    /// # Panics
131    /// Panics if `i >= n` (number of observations).
132    #[must_use]
133    pub fn residual_surface(&self, i: usize) -> FdMatrix {
134        let m1 = self.grid.m1();
135        let m2 = self.grid.m2();
136        let m_total = m1 * m2;
137        let mut mat = FdMatrix::zeros(m1, m2);
138        for g in 0..m_total {
139            let s_idx = g % m1;
140            let t_idx = g / m1;
141            mat[(s_idx, t_idx)] = self.residuals[(i, g)];
142        }
143        mat
144    }
145
146    /// Predict functional surfaces for new predictors. Delegates to
147    /// [`predict_fosr_2d`].
148    pub fn predict(&self, new_predictors: &FdMatrix) -> Result<FdMatrix, FdarError> {
149        predict_fosr_2d(self, new_predictors)
150    }
151}
152
153// ---------------------------------------------------------------------------
154// Linear algebra helpers
155// ---------------------------------------------------------------------------
156
/// Cholesky factorization `A = L L'` of a symmetric positive-definite matrix.
///
/// `a` is read as a `p x p` row-major matrix. Only its lower triangle
/// (diagonal included) is accessed. Returns the lower-triangular factor `L`
/// as a flat row-major `p x p` vector.
///
/// Returns `None` if any pivot is non-finite (NaN/inf in the input) or does
/// not exceed `1e-12`, i.e. the matrix is not numerically positive definite.
fn cholesky_factor(a: &[f64], p: usize) -> Option<Vec<f64>> {
    /// Smallest pivot accepted before the matrix is declared singular.
    const MIN_PIVOT: f64 = 1e-12;

    let mut l = vec![0.0; p * p];
    for j in 0..p {
        let mut diag = a[j * p + j];
        for k in 0..j {
            diag -= l[j * p + k] * l[j * p + k];
        }
        // `diag <= MIN_PIVOT` alone is false for NaN, which would let NaN
        // propagate silently into the factor; reject non-finite pivots too.
        if !diag.is_finite() || diag <= MIN_PIVOT {
            return None;
        }
        let l_jj = diag.sqrt();
        l[j * p + j] = l_jj;
        for i in (j + 1)..p {
            let mut s = a[i * p + j];
            for k in 0..j {
                s -= l[i * p + k] * l[j * p + k];
            }
            l[i * p + j] = s / l_jj;
        }
    }
    Some(l)
}
180
/// Solve `(L L') x = b` given the lower-triangular Cholesky factor `L`.
///
/// `l` is a flat row-major `p x p` lower-triangular matrix (as produced by
/// `cholesky_factor`) and `b` holds the right-hand side. Two triangular
/// sweeps share one buffer:
/// - forward substitution solves `L z = b`;
/// - back substitution solves `L' x = z`.
fn cholesky_forward_back(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = b.to_vec();

    // Forward sweep: L z = b.
    for row in 0..p {
        let reduced = (0..row).fold(x[row], |acc, col| acc - l[row * p + col] * x[col]);
        x[row] = reduced / l[row * p + row];
    }

    // Backward sweep: L' x = z, where L'(row, col) is stored at l[col * p + row].
    for row in (0..p).rev() {
        let reduced = ((row + 1)..p).fold(x[row], |acc, col| acc - l[col * p + row] * x[col]);
        x[row] = reduced / l[row * p + row];
    }

    x
}
198
/// Kronecker product `A ⊗ B` of two flat row-major matrices.
///
/// `a` is `rows_a x cols_a` and `b` is `rows_b x cols_b`. The result is a
/// row-major `(rows_a * rows_b) x (cols_a * cols_b)` matrix whose entry at
/// `(ia * rows_b + ib, ja * cols_b + jb)` equals `a[ia, ja] * b[ib, jb]`.
fn kronecker_product(
    a: &[f64],
    rows_a: usize,
    cols_a: usize,
    b: &[f64],
    rows_b: usize,
    cols_b: usize,
) -> Vec<f64> {
    let n_rows = rows_a * rows_b;
    let n_cols = cols_a * cols_b;
    // Walk the output in row-major order and map each cell back to its
    // (A cell, B cell) pair. The divisions only run when the output is
    // non-empty, i.e. when rows_b and cols_b are both non-zero.
    (0..n_rows * n_cols)
        .map(|flat| {
            let (row, col) = (flat / n_cols, flat % n_cols);
            let (ia, ib) = (row / rows_b, row % rows_b);
            let (ja, jb) = (col / cols_b, col % cols_b);
            a[ia * cols_a + ja] * b[ib * cols_b + jb]
        })
        .collect()
}
229
/// `n x n` identity matrix as a flat row-major vector.
///
/// Diagonal cells sit at flat indices `i * n + i = i * (n + 1)`, which are
/// exactly the multiples of `n + 1` below `n * n`.
fn identity_matrix(n: usize) -> Vec<f64> {
    (0..n * n)
        .map(|idx| if idx % (n + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
238
239/// Build the 2D tensor-product penalty matrix.
240///
241/// P_2d = lambda_s * (P_s kron I_t) + lambda_t * (I_s kron P_t)
242///
243/// where P_s = D_s'D_s (m1 x m1) and P_t = D_t'D_t (m2 x m2) are the
244/// second-difference penalty matrices.
245fn penalty_matrix_2d(m1: usize, m2: usize, lambda_s: f64, lambda_t: f64) -> Vec<f64> {
246    let m_total = m1 * m2;
247    let ps = penalty_matrix(m1);
248    let pt = penalty_matrix(m2);
249    let i_t = identity_matrix(m2);
250    let i_s = identity_matrix(m1);
251
252    let ps_kron_it = kronecker_product(&ps, m1, m1, &i_t, m2, m2);
253    let is_kron_pt = kronecker_product(&i_s, m1, m1, &pt, m2, m2);
254
255    let mut p2d = vec![0.0; m_total * m_total];
256    for i in 0..m_total * m_total {
257        p2d[i] = lambda_s * ps_kron_it[i] + lambda_t * is_kron_pt[i];
258    }
259    p2d
260}
261
262// ---------------------------------------------------------------------------
263// Core fitting routines
264// ---------------------------------------------------------------------------
265
266/// Compute fitted values Y_hat = X * beta and residuals Y - Y_hat.
267fn compute_fitted_residuals(
268    design: &FdMatrix,
269    beta: &FdMatrix,
270    data: &FdMatrix,
271) -> (FdMatrix, FdMatrix) {
272    let (n, m_total) = data.shape();
273    let p_total = design.ncols();
274    let mut fitted = FdMatrix::zeros(n, m_total);
275    let mut residuals = FdMatrix::zeros(n, m_total);
276    for i in 0..n {
277        for g in 0..m_total {
278            let mut yhat = 0.0;
279            for j in 0..p_total {
280                yhat += design[(i, j)] * beta[(j, g)];
281            }
282            fitted[(i, g)] = yhat;
283            residuals[(i, g)] = data[(i, g)] - yhat;
284        }
285    }
286    (fitted, residuals)
287}
288
289/// Compute GCV: (1/n*m) * sum(r^2) / (1 - tr(H)/n)^2.
290fn compute_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
291    let (n, m) = residuals.shape();
292    let denom = (1.0 - trace_h / n as f64).max(1e-10);
293    let ss_res: f64 = residuals.as_slice().iter().map(|v| v * v).sum();
294    ss_res / (n as f64 * m as f64 * denom * denom)
295}
296
/// Trace of the OLS hat matrix `H = X (X'X)^{-1} X'` per grid point.
///
/// `H` is the orthogonal projector onto the column space of the full-rank
/// `n x p_total` design. Its trace therefore equals its rank, `p_total`.
fn trace_hat_ols(p_total: usize) -> f64 {
    p_total as f64
}
302
303/// Smooth a raw coefficient vector (length m_total) using the 2D penalty.
304///
305/// Solves (I + P_2d) * beta_smooth = beta_raw via Cholesky.
306fn smooth_coefficient_surface(
307    beta_raw: &[f64],
308    penalty_2d: &[f64],
309    m_total: usize,
310) -> Result<Vec<f64>, FdarError> {
311    // Build (I + P_2d)
312    let mut a = penalty_2d.to_vec();
313    for i in 0..m_total {
314        a[i * m_total + i] += 1.0;
315    }
316    let l = cholesky_factor(&a, m_total).ok_or_else(|| FdarError::ComputationFailed {
317        operation: "smooth_coefficient_surface",
318        detail: "Cholesky factorization of (I + P_2d) failed".to_string(),
319    })?;
320    Ok(cholesky_forward_back(&l, beta_raw, m_total))
321}
322
323/// Compute standard errors from the diagonal of (X'X)^{-1} and residual
324/// variance at each grid point.
325fn compute_beta_se_2d(
326    xtx: &[f64],
327    residuals: &FdMatrix,
328    p_total: usize,
329    n: usize,
330) -> Option<FdMatrix> {
331    let m_total = residuals.ncols();
332    let l = cholesky_factor(xtx, p_total)?;
333
334    // Diagonal of (X'X)^{-1}
335    let a_inv_diag: Vec<f64> = (0..p_total)
336        .map(|j| {
337            let mut ej = vec![0.0; p_total];
338            ej[j] = 1.0;
339            let v = cholesky_forward_back(&l, &ej, p_total);
340            v[j]
341        })
342        .collect();
343
344    let df = (n - p_total).max(1) as f64;
345    // We return SE for the predictor coefficients only (drop intercept).
346    let p = p_total - 1;
347    let mut se = FdMatrix::zeros(p, m_total);
348    for g in 0..m_total {
349        let sigma2: f64 = (0..n).map(|i| residuals[(i, g)].powi(2)).sum::<f64>() / df;
350        for j in 0..p {
351            // j+1 to skip intercept row in a_inv_diag
352            se[(j, g)] = (sigma2 * a_inv_diag[j + 1]).max(0.0).sqrt();
353        }
354    }
355    Some(se)
356}
357
358// ---------------------------------------------------------------------------
359// GCV lambda selection
360// ---------------------------------------------------------------------------
361
362/// Select (lambda_s, lambda_t) via GCV over a 2D grid of candidate values.
363fn select_lambdas_gcv(
364    xtx: &[f64],
365    xty: &FdMatrix,
366    design: &FdMatrix,
367    data: &FdMatrix,
368    m1: usize,
369    m2: usize,
370    fix_lambda_s: Option<f64>,
371    fix_lambda_t: Option<f64>,
372) -> (f64, f64) {
373    let candidates = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
374    let p_total = design.ncols();
375    let m_total = m1 * m2;
376
377    let ls_candidates: Vec<f64> = if let Some(ls) = fix_lambda_s {
378        vec![ls]
379    } else {
380        candidates.to_vec()
381    };
382    let lt_candidates: Vec<f64> = if let Some(lt) = fix_lambda_t {
383        vec![lt]
384    } else {
385        candidates.to_vec()
386    };
387
388    // Pre-compute OLS inverse and raw beta
389    let l_xtx = match cholesky_factor(xtx, p_total) {
390        Some(l) => l,
391        None => return (0.0, 0.0),
392    };
393
394    // beta_ols: p_total x m_total
395    let mut beta_ols = FdMatrix::zeros(p_total, m_total);
396    for g in 0..m_total {
397        let b: Vec<f64> = (0..p_total).map(|j| xty[(j, g)]).collect();
398        let x = cholesky_forward_back(&l_xtx, &b, p_total);
399        for j in 0..p_total {
400            beta_ols[(j, g)] = x[j];
401        }
402    }
403
404    let trace_h = trace_hat_ols(p_total);
405    let mut best_ls = 0.0;
406    let mut best_lt = 0.0;
407    let mut best_gcv = f64::INFINITY;
408
409    for &ls in &ls_candidates {
410        for &lt in &lt_candidates {
411            if ls == 0.0 && lt == 0.0 {
412                // No smoothing: use raw OLS
413                let (_, residuals) = compute_fitted_residuals(design, &beta_ols, data);
414                let gcv = compute_gcv(&residuals, trace_h);
415                if gcv < best_gcv {
416                    best_gcv = gcv;
417                    best_ls = ls;
418                    best_lt = lt;
419                }
420                continue;
421            }
422
423            let p2d = penalty_matrix_2d(m1, m2, ls, lt);
424            let mut beta_smooth = FdMatrix::zeros(p_total, m_total);
425            let mut ok = true;
426            for j in 0..p_total {
427                let raw: Vec<f64> = (0..m_total).map(|g| beta_ols[(j, g)]).collect();
428                match smooth_coefficient_surface(&raw, &p2d, m_total) {
429                    Ok(smoothed) => {
430                        for g in 0..m_total {
431                            beta_smooth[(j, g)] = smoothed[g];
432                        }
433                    }
434                    Err(_) => {
435                        ok = false;
436                        break;
437                    }
438                }
439            }
440            if !ok {
441                continue;
442            }
443
444            let (_, residuals) = compute_fitted_residuals(design, &beta_smooth, data);
445            let gcv = compute_gcv(&residuals, trace_h);
446            if gcv < best_gcv {
447                best_gcv = gcv;
448                best_ls = ls;
449                best_lt = lt;
450            }
451        }
452    }
453
454    (best_ls, best_lt)
455}
456
457// ---------------------------------------------------------------------------
458// Public API
459// ---------------------------------------------------------------------------
460
461/// 2D Function-on-Scalar Regression with tensor-product penalty.
462///
463/// Fits the model
464/// ```text
465/// Y_i(s,t) = beta_0(s,t) + sum_j x_{ij} beta_j(s,t) + epsilon_i(s,t)
466/// ```
467/// with anisotropic roughness penalty
468/// `lambda_s * ||d^2 beta/ds^2||^2 + lambda_t * ||d^2 beta/dt^2||^2`.
469///
470/// # Arguments
471/// * `data` - Functional response matrix (n x m_total where m_total = m1*m2)
472/// * `predictors` - Scalar predictor matrix (n x p)
473/// * `grid` - 2D grid specification
474/// * `lambda_s` - Smoothing parameter in the s-direction (negative = GCV)
475/// * `lambda_t` - Smoothing parameter in the t-direction (negative = GCV)
476///
477/// # Errors
478///
479/// Returns [`FdarError::InvalidDimension`] if `data.ncols() != grid.total()`,
480/// `predictors.nrows() != data.nrows()`, or `n < p + 2`.
481/// Returns [`FdarError::InvalidParameter`] if the grid has zero size in
482/// either dimension.
483/// Returns [`FdarError::ComputationFailed`] if the Cholesky factorization
484/// fails during OLS or smoothing.
485#[must_use = "expensive computation whose result should not be discarded"]
486pub fn fosr_2d(
487    data: &FdMatrix,
488    predictors: &FdMatrix,
489    grid: &Grid2d,
490    lambda_s: f64,
491    lambda_t: f64,
492) -> Result<FosrResult2d, FdarError> {
493    let (n, m_data) = data.shape();
494    let p = predictors.ncols();
495    let m1 = grid.m1();
496    let m2 = grid.m2();
497    let m_total = grid.total();
498
499    // ---- Validate inputs ----
500
501    if m1 == 0 {
502        return Err(FdarError::InvalidParameter {
503            parameter: "grid",
504            message: "argvals_s must not be empty".to_string(),
505        });
506    }
507    if m2 == 0 {
508        return Err(FdarError::InvalidParameter {
509            parameter: "grid",
510            message: "argvals_t must not be empty".to_string(),
511        });
512    }
513    if m_data != m_total {
514        return Err(FdarError::InvalidDimension {
515            parameter: "data",
516            expected: format!("{m_total} columns (grid m1*m2 = {m1}*{m2})"),
517            actual: format!("{m_data} columns"),
518        });
519    }
520    if predictors.nrows() != n {
521        return Err(FdarError::InvalidDimension {
522            parameter: "predictors",
523            expected: format!("{n} rows (matching data)"),
524            actual: format!("{} rows", predictors.nrows()),
525        });
526    }
527    if n < p + 2 {
528        return Err(FdarError::InvalidDimension {
529            parameter: "data",
530            expected: format!("at least {} observations (p + 2)", p + 2),
531            actual: format!("{n} observations"),
532        });
533    }
534
535    // ---- Step 1: Build design and compute OLS ----
536
537    let design = build_fosr_design(predictors, n);
538    let p_total = design.ncols(); // p + 1
539    let xtx = compute_xtx(&design);
540    let xty = compute_xty_matrix(&design, data);
541
542    let l_xtx = cholesky_factor(&xtx, p_total).ok_or_else(|| FdarError::ComputationFailed {
543        operation: "fosr_2d",
544        detail: "Cholesky factorization of X'X failed; design matrix is rank-deficient".to_string(),
545    })?;
546
547    // Pointwise OLS: beta_ols[:,g] = (X'X)^{-1} X' y[:,g]
548    let mut beta_ols = FdMatrix::zeros(p_total, m_total);
549    for g in 0..m_total {
550        let b: Vec<f64> = (0..p_total).map(|j| xty[(j, g)]).collect();
551        let x = cholesky_forward_back(&l_xtx, &b, p_total);
552        for j in 0..p_total {
553            beta_ols[(j, g)] = x[j];
554        }
555    }
556
557    // ---- Step 2: Determine lambda values ----
558
559    let fix_ls = if lambda_s >= 0.0 {
560        Some(lambda_s)
561    } else {
562        None
563    };
564    let fix_lt = if lambda_t >= 0.0 {
565        Some(lambda_t)
566    } else {
567        None
568    };
569
570    let (lambda_s_final, lambda_t_final) = if fix_ls.is_some() && fix_lt.is_some() {
571        (lambda_s, lambda_t)
572    } else {
573        select_lambdas_gcv(&xtx, &xty, &design, data, m1, m2, fix_ls, fix_lt)
574    };
575
576    // ---- Step 3: Smooth coefficient surfaces ----
577
578    let beta_smooth = if lambda_s_final == 0.0 && lambda_t_final == 0.0 {
579        beta_ols
580    } else {
581        let p2d = penalty_matrix_2d(m1, m2, lambda_s_final, lambda_t_final);
582        let mut smoothed = FdMatrix::zeros(p_total, m_total);
583        for j in 0..p_total {
584            let raw: Vec<f64> = (0..m_total).map(|g| beta_ols[(j, g)]).collect();
585            let s = smooth_coefficient_surface(&raw, &p2d, m_total)?;
586            for g in 0..m_total {
587                smoothed[(j, g)] = s[g];
588            }
589        }
590        smoothed
591    };
592
593    // ---- Step 4: Compute diagnostics ----
594
595    let (fitted, residuals) = compute_fitted_residuals(&design, &beta_smooth, data);
596
597    let r_squared_pointwise = pointwise_r_squared(data, &fitted);
598    let r_squared = if m_total > 0 {
599        r_squared_pointwise.iter().sum::<f64>() / m_total as f64
600    } else {
601        0.0
602    };
603
604    let trace_h = trace_hat_ols(p_total);
605    let gcv = compute_gcv(&residuals, trace_h);
606
607    let beta_se = compute_beta_se_2d(&xtx, &residuals, p_total, n);
608
609    // Extract intercept (row 0 of beta_smooth)
610    let intercept: Vec<f64> = (0..m_total).map(|g| beta_smooth[(0, g)]).collect();
611
612    // Extract predictor coefficients (rows 1..p_total)
613    let mut beta_out = FdMatrix::zeros(p, m_total);
614    for j in 0..p {
615        for g in 0..m_total {
616            beta_out[(j, g)] = beta_smooth[(j + 1, g)];
617        }
618    }
619
620    Ok(FosrResult2d {
621        intercept,
622        beta: beta_out,
623        fitted,
624        residuals,
625        r_squared_pointwise,
626        r_squared,
627        beta_se,
628        lambda_s: lambda_s_final,
629        lambda_t: lambda_t_final,
630        gcv,
631        grid: grid.clone(),
632    })
633}
634
635/// Predict functional surfaces for new observations.
636///
637/// # Arguments
638/// * `result` - Fitted [`FosrResult2d`]
639/// * `new_predictors` - New scalar predictors (n_new x p)
640///
641/// # Errors
642///
643/// Returns [`FdarError::InvalidDimension`] if the number of predictor columns
644/// does not match the fitted model.
645#[must_use = "prediction result should not be discarded"]
646pub fn predict_fosr_2d(
647    result: &FosrResult2d,
648    new_predictors: &FdMatrix,
649) -> Result<FdMatrix, FdarError> {
650    let n_new = new_predictors.nrows();
651    let m_total = result.intercept.len();
652    let p = result.beta.nrows();
653
654    if new_predictors.ncols() != p {
655        return Err(FdarError::InvalidDimension {
656            parameter: "new_predictors",
657            expected: format!("{p} columns (matching fitted model)"),
658            actual: format!("{} columns", new_predictors.ncols()),
659        });
660    }
661
662    let mut predicted = FdMatrix::zeros(n_new, m_total);
663    for i in 0..n_new {
664        for g in 0..m_total {
665            let mut yhat = result.intercept[g];
666            for j in 0..p {
667                yhat += new_predictors[(i, j)] * result.beta[(j, g)];
668            }
669            predicted[(i, g)] = yhat;
670        }
671    }
672    Ok(predicted)
673}
674
675// ---------------------------------------------------------------------------
676// Tests
677// ---------------------------------------------------------------------------
678
#[cfg(test)]
mod tests {
    // Unit tests for the 2D FOSR module: linear-algebra helpers, penalty
    // construction, model fitting, prediction and input validation.
    use super::*;

    /// `m` evenly spaced points on [0, 1]; `m = 1` yields `[0.0]`.
    fn uniform_grid_1d(m: usize) -> Vec<f64> {
        // `.max(1)` avoids dividing by zero for a single-point axis.
        (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect()
    }

    /// Uniform `m1 x m2` grid on the unit square.
    fn make_grid(m1: usize, m2: usize) -> Grid2d {
        Grid2d::new(uniform_grid_1d(m1), uniform_grid_1d(m2))
    }

    /// Generate test data: Y_i(s,t) = intercept(s,t) + z_1 * beta_1(s,t) + z_2 * beta_2(s,t) + noise
    ///
    /// True surfaces: intercept = s + t, beta_1 = s * t, beta_2 = s - t.
    /// Predictors: z_1 = i / n, and z_2 is 1 for even i and 0 for odd i.
    /// The "noise" is a deterministic pattern in `[0, 0.99 * noise_scale]`,
    /// so results are reproducible. Responses are flattened column-major
    /// (`g = si + ti * m1`).
    fn generate_2d_data(
        n: usize,
        m1: usize,
        m2: usize,
        noise_scale: f64,
    ) -> (FdMatrix, FdMatrix, Grid2d) {
        let grid = make_grid(m1, m2);
        let m_total = m1 * m2;
        let mut y = FdMatrix::zeros(n, m_total);
        let mut z = FdMatrix::zeros(n, 2);

        for i in 0..n {
            let z1 = (i as f64) / (n as f64);
            let z2 = if i % 2 == 0 { 1.0 } else { 0.0 };
            z[(i, 0)] = z1;
            z[(i, 1)] = z2;

            for si in 0..m1 {
                for ti in 0..m2 {
                    let g = si + ti * m1; // column-major flat index
                    let s = grid.argvals_s[si];
                    let t = grid.argvals_t[ti];

                    let intercept = s + t;
                    let beta1 = s * t;
                    let beta2 = s - t;
                    let noise = noise_scale * ((i * 13 + si * 7 + ti * 3) % 100) as f64 / 100.0;

                    y[(i, g)] = intercept + z1 * beta1 + z2 * beta2 + noise;
                }
            }
        }
        (y, z, grid)
    }

    #[test]
    fn test_grid2d_basic() {
        let grid = make_grid(5, 4);
        assert_eq!(grid.m1(), 5);
        assert_eq!(grid.m2(), 4);
        assert_eq!(grid.total(), 20);
    }

    #[test]
    fn test_kronecker_product_small() {
        // A = [[1,2],[3,4]], B = [[0,5],[6,7]]
        // A kron B = [[0,5,0,10],[6,7,12,14],[0,15,0,20],[18,21,24,28]]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![0.0, 5.0, 6.0, 7.0];
        let c = kronecker_product(&a, 2, 2, &b, 2, 2);
        assert_eq!(c.len(), 16);
        #[rustfmt::skip]
        let expected = vec![
            0.0, 5.0, 0.0, 10.0,
            6.0, 7.0, 12.0, 14.0,
            0.0, 15.0, 0.0, 20.0,
            18.0, 21.0, 24.0, 28.0,
        ];
        for (i, (&ci, &ei)) in c.iter().zip(expected.iter()).enumerate() {
            assert!(
                (ci - ei).abs() < 1e-12,
                "kronecker mismatch at index {i}: got {ci}, expected {ei}"
            );
        }
    }

    #[test]
    fn test_penalty_2d_symmetry() {
        // Both Kronecker terms of P_2d are symmetric, so their weighted sum must be too.
        let m1 = 5;
        let m2 = 4;
        let p2d = penalty_matrix_2d(m1, m2, 1.0, 1.0);
        let m_total = m1 * m2;
        for i in 0..m_total {
            for j in 0..m_total {
                assert!(
                    (p2d[i * m_total + j] - p2d[j * m_total + i]).abs() < 1e-12,
                    "P_2d not symmetric at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_penalty_2d_shape() {
        let m1 = 5;
        let m2 = 4;
        let p2d = penalty_matrix_2d(m1, m2, 1.0, 2.0);
        assert_eq!(p2d.len(), (m1 * m2) * (m1 * m2));
    }

    #[test]
    fn test_fosr_2d_constant_response() {
        let n = 20;
        let m1 = 5;
        let m2 = 4;
        let grid = make_grid(m1, m2);
        let m_total = m1 * m2;

        // Y_i(s,t) = 3.0 for all i,s,t
        let mut y = FdMatrix::zeros(n, m_total);
        for i in 0..n {
            for g in 0..m_total {
                y[(i, g)] = 3.0;
            }
        }

        let mut z = FdMatrix::zeros(n, 2);
        for i in 0..n {
            z[(i, 0)] = i as f64;
            z[(i, 1)] = (i % 3) as f64;
        }

        // lambda = 0 in both directions: plain pointwise OLS, no smoothing.
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // Intercept should be near 3.0 everywhere
        for g in 0..m_total {
            assert!(
                (result.intercept[g] - 3.0).abs() < 1e-8,
                "intercept[{g}] = {}, expected 3.0",
                result.intercept[g]
            );
        }

        // Beta coefficients should be near zero
        for j in 0..2 {
            for g in 0..m_total {
                assert!(
                    result.beta[(j, g)].abs() < 1e-8,
                    "beta[{j},{g}] = {}, expected ~0",
                    result.beta[(j, g)]
                );
            }
        }
    }

    // NOTE(review): despite its name, this test fits two predictors
    // (generate_2d_data always supplies z_1 and z_2).
    #[test]
    fn test_fosr_2d_single_predictor() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.01);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // With low noise, R^2 should be high
        assert!(
            result.r_squared > 0.8,
            "R^2 = {}, expected > 0.8",
            result.r_squared
        );
    }

    #[test]
    fn test_fosr_2d_fitted_plus_residuals() {
        // Residuals are defined as data - fitted, so the sum must reproduce the data.
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        let (n, m_total) = y.shape();
        for i in 0..n {
            for g in 0..m_total {
                let reconstructed = result.fitted[(i, g)] + result.residuals[(i, g)];
                assert!(
                    (reconstructed - y[(i, g)]).abs() < 1e-10,
                    "fitted + residuals != y at ({i},{g})"
                );
            }
        }
    }

    #[test]
    fn test_fosr_2d_r_squared_range() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // Unsmoothed OLS fit: pointwise R^2 should lie in [0, 1] up to a small tolerance.
        for (g, &r2) in result.r_squared_pointwise.iter().enumerate() {
            assert!(
                (-0.01..=1.0 + 1e-10).contains(&r2),
                "R^2 out of range at grid point {g}: {r2}"
            );
        }
    }

    #[test]
    fn test_fosr_2d_predict_matches_fitted() {
        // Predicting on the training predictors must reproduce the in-sample fit.
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        let preds = predict_fosr_2d(&result, &z).unwrap();
        let (n, m_total) = y.shape();
        for i in 0..n {
            for g in 0..m_total {
                assert!(
                    (preds[(i, g)] - result.fitted[(i, g)]).abs() < 1e-8,
                    "prediction != fitted at ({i},{g})"
                );
            }
        }
    }

    #[test]
    fn test_fosr_2d_reshape_beta_surface() {
        // Reshaped surfaces must have shape m1 x m2.
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        let surface = result.beta_surface(0);
        assert_eq!(surface.shape(), (5, 4));

        let r2_surface = result.r_squared_surface();
        assert_eq!(r2_surface.shape(), (5, 4));

        let resid_surface = result.residual_surface(0);
        assert_eq!(resid_surface.shape(), (5, 4));
    }

    #[test]
    fn test_fosr_2d_dimension_mismatch() {
        let grid = make_grid(5, 4);

        // Wrong number of columns in data
        let y = FdMatrix::zeros(20, 10); // 10 != 5*4=20
        let z = FdMatrix::zeros(20, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Mismatched rows between data and predictors
        let y = FdMatrix::zeros(20, 20);
        let z = FdMatrix::zeros(10, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Too few observations
        let y = FdMatrix::zeros(3, 20);
        let z = FdMatrix::zeros(3, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Empty grid
        let empty_grid = Grid2d::new(vec![], vec![0.0, 1.0]);
        let y = FdMatrix::zeros(20, 0);
        let z = FdMatrix::zeros(20, 2);
        assert!(fosr_2d(&y, &z, &empty_grid, 0.0, 0.0).is_err());

        // Predictor dimension mismatch in predict
        // (non-collinear predictors so the fit itself succeeds first)
        let grid = make_grid(3, 3);
        let y = FdMatrix::zeros(20, 9);
        let mut z = FdMatrix::zeros(20, 2);
        for i in 0..20 {
            z[(i, 0)] = i as f64;
            z[(i, 1)] = (i * 3 % 7) as f64;
        }
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();
        let z_bad = FdMatrix::zeros(5, 3); // 3 != 2
        assert!(predict_fosr_2d(&result, &z_bad).is_err());
    }

    #[test]
    fn test_fosr_2d_gcv() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        // Negative lambda triggers GCV selection
        let result = fosr_2d(&y, &z, &grid, -1.0, -1.0).unwrap();
        assert!(result.lambda_s >= 0.0);
        assert!(result.lambda_t >= 0.0);
        assert!(result.gcv > 0.0);
        assert!(result.r_squared > 0.5);
    }
}