Skip to main content

fdars_core/
function_on_scalar_2d.rs

1//! 2D Function-on-Scalar Regression (FOSR).
2//!
3//! Extends the 1D function-on-scalar model to surface-valued functional
4//! responses observed on a regular 2D grid:
5//! ```text
6//! Y_i(s,t) = beta_0(s,t) + sum_j z_{ij} beta_j(s,t) + epsilon_i(s,t)
7//! ```
8//!
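//! The estimation uses a two-step approach:
//! 1. Pointwise OLS at each grid point to obtain raw coefficient surfaces.
//! 2. Tensor-product roughness penalty smoothing of each coefficient surface.
//!
//! Surfaces are stored flattened in column-major order: grid point
//! `(s_idx, t_idx)` lives at flat index `s_idx + t_idx * m1`.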
12//!
13//! # Methods
14//!
15//! - [`fosr_2d`]: Penalized 2D FOSR with anisotropic smoothing
16//! - [`predict_fosr_2d`]: Predict new surfaces from a fitted model
17
18use crate::error::FdarError;
19use crate::function_on_scalar::{
20    build_fosr_design, compute_xtx, compute_xty_matrix, penalty_matrix, pointwise_r_squared,
21};
22use crate::matrix::FdMatrix;
23
24// ---------------------------------------------------------------------------
25// Types
26// ---------------------------------------------------------------------------
27
/// Rectangular 2D evaluation grid for surface-valued functional data.
///
/// A surface observed on this grid has `m1() * m2()` values, one per
/// `(argvals_s[i], argvals_t[j])` pair.
#[derive(Debug, Clone, PartialEq)]
pub struct Grid2d {
    /// Grid points along the first (row, `s`) dimension.
    pub argvals_s: Vec<f64>,
    /// Grid points along the second (column, `t`) dimension.
    pub argvals_t: Vec<f64>,
}

impl Grid2d {
    /// Build a grid from its two axis vectors (moved in, not copied).
    pub fn new(argvals_s: Vec<f64>, argvals_t: Vec<f64>) -> Self {
        Grid2d {
            argvals_s,
            argvals_t,
        }
    }

    /// Length of the `s` axis.
    #[inline]
    pub fn m1(&self) -> usize {
        let Self { argvals_s, .. } = self;
        argvals_s.len()
    }

    /// Length of the `t` axis.
    #[inline]
    pub fn m2(&self) -> usize {
        let Self { argvals_t, .. } = self;
        argvals_t.len()
    }

    /// Number of grid points in total, i.e. `m1 * m2`.
    #[inline]
    pub fn total(&self) -> usize {
        let Self {
            argvals_s,
            argvals_t,
        } = self;
        argvals_s.len() * argvals_t.len()
    }
}
64
65/// Result of 2D function-on-scalar regression.
66#[derive(Debug, Clone, PartialEq)]
67#[non_exhaustive]
68pub struct FosrResult2d {
69    /// Intercept surface beta_0(s,t), flattened column-major (length m1*m2).
70    pub intercept: Vec<f64>,
71    /// Coefficient surfaces beta_j(s,t), p x (m1*m2) matrix (row j = flattened beta_j).
72    pub beta: FdMatrix,
73    /// Fitted surface values, n x (m1*m2) matrix.
74    pub fitted: FdMatrix,
75    /// Residual surfaces, n x (m1*m2) matrix.
76    pub residuals: FdMatrix,
77    /// Pointwise R^2(s,t), flattened (length m1*m2).
78    pub r_squared_pointwise: Vec<f64>,
79    /// Global R^2 (average of pointwise).
80    pub r_squared: f64,
81    /// Standard error surfaces for each beta_j (p x (m1*m2)), or None if
82    /// the system is underdetermined.
83    pub beta_se: Option<FdMatrix>,
84    /// Smoothing parameter in the s-direction.
85    pub lambda_s: f64,
86    /// Smoothing parameter in the t-direction.
87    pub lambda_t: f64,
88    /// Generalised cross-validation score.
89    pub gcv: f64,
90    /// Grid specification.
91    pub grid: Grid2d,
92}
93
94impl FosrResult2d {
95    /// Reshape the j-th coefficient surface into an m1 x m2 matrix.
96    ///
97    /// # Panics
98    /// Panics if `j >= p` (number of predictors).
99    #[must_use]
100    pub fn beta_surface(&self, j: usize) -> FdMatrix {
101        let m1 = self.grid.m1();
102        let m2 = self.grid.m2();
103        let m_total = m1 * m2;
104        let mut mat = FdMatrix::zeros(m1, m2);
105        for g in 0..m_total {
106            // Column-major: g = s_idx + t_idx * m1
107            let s_idx = g % m1;
108            let t_idx = g / m1;
109            mat[(s_idx, t_idx)] = self.beta[(j, g)];
110        }
111        mat
112    }
113
114    /// Reshape pointwise R^2(s,t) into an m1 x m2 matrix.
115    #[must_use]
116    pub fn r_squared_surface(&self) -> FdMatrix {
117        let m1 = self.grid.m1();
118        let m2 = self.grid.m2();
119        let m_total = m1 * m2;
120        let mut mat = FdMatrix::zeros(m1, m2);
121        for g in 0..m_total {
122            let s_idx = g % m1;
123            let t_idx = g / m1;
124            mat[(s_idx, t_idx)] = self.r_squared_pointwise[g];
125        }
126        mat
127    }
128
129    /// Reshape the residual for observation `i` into an m1 x m2 matrix.
130    ///
131    /// # Panics
132    /// Panics if `i >= n` (number of observations).
133    #[must_use]
134    pub fn residual_surface(&self, i: usize) -> FdMatrix {
135        let m1 = self.grid.m1();
136        let m2 = self.grid.m2();
137        let m_total = m1 * m2;
138        let mut mat = FdMatrix::zeros(m1, m2);
139        for g in 0..m_total {
140            let s_idx = g % m1;
141            let t_idx = g / m1;
142            mat[(s_idx, t_idx)] = self.residuals[(i, g)];
143        }
144        mat
145    }
146
147    /// Predict functional surfaces for new predictors. Delegates to
148    /// [`predict_fosr_2d`].
149    pub fn predict(&self, new_predictors: &FdMatrix) -> Result<FdMatrix, FdarError> {
150        predict_fosr_2d(self, new_predictors)
151    }
152}
153
154// ---------------------------------------------------------------------------
155// Linear algebra helpers
156// ---------------------------------------------------------------------------
157
/// Cholesky factorization `A = L L'` of a symmetric positive-definite matrix.
///
/// `a` is a `p x p` matrix stored flat in row-major order; only its lower
/// triangle (including the diagonal) is read. Returns the lower-triangular
/// factor `L` (`p x p`, flat row-major, zeros above the diagonal), or `None`
/// when some pivot is not finite or not greater than `1e-12`, i.e. the matrix
/// is not numerically positive definite. Rejecting non-finite pivots keeps
/// NaN/infinite inputs from producing a silently NaN-filled factor.
fn cholesky_factor(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        let diag = (0..j).fold(a[j * p + j], |acc, k| acc - l[j * p + k] * l[j * p + k]);
        // `diag <= tol` alone is false for NaN, so test finiteness explicitly.
        if !diag.is_finite() || diag <= 1e-12 {
            return None;
        }
        let pivot = diag.sqrt();
        l[j * p + j] = pivot;
        for i in (j + 1)..p {
            let s = (0..j).fold(a[i * p + j], |acc, k| acc - l[i * p + k] * l[j * p + k]);
            l[i * p + j] = s / pivot;
        }
    }
    Some(l)
}
181
/// Solve `(L L') x = b` given the lower-triangular Cholesky factor `L`
/// (`p x p`, flat row-major): forward substitution `L z = b`, then back
/// substitution `L' x = z`, both performed on a copy of `b`.
fn cholesky_forward_back(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = b.to_vec();
    // Forward pass: top to bottom, reusing entries already solved above.
    for row in 0..p {
        let partial = (0..row).fold(x[row], |acc, col| acc - l[row * p + col] * x[col]);
        x[row] = partial / l[row * p + row];
    }
    // Backward pass with L': entry (row, k) of L' is l[k * p + row].
    for row in (0..p).rev() {
        let partial = ((row + 1)..p).fold(x[row], |acc, k| acc - l[k * p + row] * x[k]);
        x[row] = partial / l[row * p + row];
    }
    x
}
199
/// Kronecker product `C = A kron B` of two flat row-major matrices.
///
/// `A` is `rows_a x cols_a` and `B` is `rows_b x cols_b`; the result is
/// `(rows_a * rows_b) x (cols_a * cols_b)`, also flat row-major, with
/// `C[ia * rows_b + ib, ja * cols_b + jb] = A[ia, ja] * B[ib, jb]`.
fn kronecker_product(
    a: &[f64],
    rows_a: usize,
    cols_a: usize,
    b: &[f64],
    rows_b: usize,
    cols_b: usize,
) -> Vec<f64> {
    let out_rows = rows_a * rows_b;
    let out_cols = cols_a * cols_b;
    let mut out = Vec::with_capacity(out_rows * out_cols);
    for row in 0..out_rows {
        // Output row `row` pairs row `row / rows_b` of A with row `row % rows_b` of B.
        let (ia, ib) = (row / rows_b, row % rows_b);
        for col in 0..out_cols {
            let (ja, jb) = (col / cols_b, col % cols_b);
            out.push(a[ia * cols_a + ja] * b[ib * cols_b + jb]);
        }
    }
    out
}
230
/// `n x n` identity matrix, flat row-major (`n * n` entries).
fn identity_matrix(n: usize) -> Vec<f64> {
    // Flat offset k lies on the diagonal exactly when k is a multiple of n + 1.
    (0..n * n)
        .map(|k| if k % (n + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
239
240/// Build the 2D tensor-product penalty matrix.
241///
242/// P_2d = lambda_s * (P_s kron I_t) + lambda_t * (I_s kron P_t)
243///
244/// where P_s = D_s'D_s (m1 x m1) and P_t = D_t'D_t (m2 x m2) are the
245/// second-difference penalty matrices.
246fn penalty_matrix_2d(m1: usize, m2: usize, lambda_s: f64, lambda_t: f64) -> Vec<f64> {
247    let m_total = m1 * m2;
248    let ps = penalty_matrix(m1);
249    let pt = penalty_matrix(m2);
250    let i_t = identity_matrix(m2);
251    let i_s = identity_matrix(m1);
252
253    let ps_kron_it = kronecker_product(&ps, m1, m1, &i_t, m2, m2);
254    let is_kron_pt = kronecker_product(&i_s, m1, m1, &pt, m2, m2);
255
256    let mut p2d = vec![0.0; m_total * m_total];
257    for i in 0..m_total * m_total {
258        p2d[i] = lambda_s * ps_kron_it[i] + lambda_t * is_kron_pt[i];
259    }
260    p2d
261}
262
263// ---------------------------------------------------------------------------
264// Core fitting routines
265// ---------------------------------------------------------------------------
266
267/// Compute fitted values Y_hat = X * beta and residuals Y - Y_hat.
268fn compute_fitted_residuals(
269    design: &FdMatrix,
270    beta: &FdMatrix,
271    data: &FdMatrix,
272) -> (FdMatrix, FdMatrix) {
273    let (n, m_total) = data.shape();
274    let p_total = design.ncols();
275    let mut fitted = FdMatrix::zeros(n, m_total);
276    let mut residuals = FdMatrix::zeros(n, m_total);
277    for i in 0..n {
278        for g in 0..m_total {
279            let mut yhat = 0.0;
280            for j in 0..p_total {
281                yhat += design[(i, j)] * beta[(j, g)];
282            }
283            fitted[(i, g)] = yhat;
284            residuals[(i, g)] = data[(i, g)] - yhat;
285        }
286    }
287    (fitted, residuals)
288}
289
290/// Compute GCV: (1/n*m) * sum(r^2) / (1 - tr(H)/n)^2.
291fn compute_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
292    let (n, m) = residuals.shape();
293    let denom = (1.0 - trace_h / n as f64).max(1e-10);
294    let ss_res: f64 = residuals.as_slice().iter().map(|v| v * v).sum();
295    ss_res / (n as f64 * m as f64 * denom * denom)
296}
297
298/// Compute trace of hat matrix for OLS: tr(H) = p_total (since X(X'X)^{-1}X'
299/// projects onto a p_total-dimensional subspace).
300fn trace_hat_ols(p_total: usize) -> f64 {
301    p_total as f64
302}
303
304/// Smooth a raw coefficient vector (length m_total) using the 2D penalty.
305///
306/// Solves (I + P_2d) * beta_smooth = beta_raw via Cholesky.
307fn smooth_coefficient_surface(
308    beta_raw: &[f64],
309    penalty_2d: &[f64],
310    m_total: usize,
311) -> Result<Vec<f64>, FdarError> {
312    // Build (I + P_2d)
313    let mut a = penalty_2d.to_vec();
314    for i in 0..m_total {
315        a[i * m_total + i] += 1.0;
316    }
317    let l = cholesky_factor(&a, m_total).ok_or_else(|| FdarError::ComputationFailed {
318        operation: "smooth_coefficient_surface",
319        detail: "Cholesky factorization of (I + P_2d) failed; try increasing the smoothing parameter (lambda)".to_string(),
320    })?;
321    Ok(cholesky_forward_back(&l, beta_raw, m_total))
322}
323
324/// Compute standard errors from the diagonal of (X'X)^{-1} and residual
325/// variance at each grid point.
326fn compute_beta_se_2d(
327    xtx: &[f64],
328    residuals: &FdMatrix,
329    p_total: usize,
330    n: usize,
331) -> Option<FdMatrix> {
332    let m_total = residuals.ncols();
333    let l = cholesky_factor(xtx, p_total)?;
334
335    // Diagonal of (X'X)^{-1}
336    let a_inv_diag: Vec<f64> = (0..p_total)
337        .map(|j| {
338            let mut ej = vec![0.0; p_total];
339            ej[j] = 1.0;
340            let v = cholesky_forward_back(&l, &ej, p_total);
341            v[j]
342        })
343        .collect();
344
345    let df = (n - p_total).max(1) as f64;
346    // We return SE for the predictor coefficients only (drop intercept).
347    let p = p_total - 1;
348    let mut se = FdMatrix::zeros(p, m_total);
349    for g in 0..m_total {
350        let sigma2: f64 = (0..n).map(|i| residuals[(i, g)].powi(2)).sum::<f64>() / df;
351        for j in 0..p {
352            // j+1 to skip intercept row in a_inv_diag
353            se[(j, g)] = (sigma2 * a_inv_diag[j + 1]).max(0.0).sqrt();
354        }
355    }
356    Some(se)
357}
358
359// ---------------------------------------------------------------------------
360// GCV lambda selection
361// ---------------------------------------------------------------------------
362
363/// Select (lambda_s, lambda_t) via GCV over a 2D grid of candidate values.
364fn select_lambdas_gcv(
365    xtx: &[f64],
366    xty: &FdMatrix,
367    design: &FdMatrix,
368    data: &FdMatrix,
369    m1: usize,
370    m2: usize,
371    fix_lambda_s: Option<f64>,
372    fix_lambda_t: Option<f64>,
373) -> (f64, f64) {
374    let candidates = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
375    let p_total = design.ncols();
376    let m_total = m1 * m2;
377
378    let ls_candidates: Vec<f64> = if let Some(ls) = fix_lambda_s {
379        vec![ls]
380    } else {
381        candidates.to_vec()
382    };
383    let lt_candidates: Vec<f64> = if let Some(lt) = fix_lambda_t {
384        vec![lt]
385    } else {
386        candidates.to_vec()
387    };
388
389    // Pre-compute OLS inverse and raw beta
390    let Some(l_xtx) = cholesky_factor(xtx, p_total) else {
391        return (0.0, 0.0);
392    };
393
394    // beta_ols: p_total x m_total
395    let mut beta_ols = FdMatrix::zeros(p_total, m_total);
396    for g in 0..m_total {
397        let b: Vec<f64> = (0..p_total).map(|j| xty[(j, g)]).collect();
398        let x = cholesky_forward_back(&l_xtx, &b, p_total);
399        for j in 0..p_total {
400            beta_ols[(j, g)] = x[j];
401        }
402    }
403
404    let trace_h = trace_hat_ols(p_total);
405    let mut best_ls = 0.0;
406    let mut best_lt = 0.0;
407    let mut best_gcv = f64::INFINITY;
408
409    for &ls in &ls_candidates {
410        for &lt in &lt_candidates {
411            if ls == 0.0 && lt == 0.0 {
412                // No smoothing: use raw OLS
413                let (_, residuals) = compute_fitted_residuals(design, &beta_ols, data);
414                let gcv = compute_gcv(&residuals, trace_h);
415                if gcv < best_gcv {
416                    best_gcv = gcv;
417                    best_ls = ls;
418                    best_lt = lt;
419                }
420                continue;
421            }
422
423            let p2d = penalty_matrix_2d(m1, m2, ls, lt);
424            let mut beta_smooth = FdMatrix::zeros(p_total, m_total);
425            let mut ok = true;
426            for j in 0..p_total {
427                let raw: Vec<f64> = (0..m_total).map(|g| beta_ols[(j, g)]).collect();
428                if let Ok(smoothed) = smooth_coefficient_surface(&raw, &p2d, m_total) {
429                    for g in 0..m_total {
430                        beta_smooth[(j, g)] = smoothed[g];
431                    }
432                } else {
433                    ok = false;
434                    break;
435                }
436            }
437            if !ok {
438                continue;
439            }
440
441            let (_, residuals) = compute_fitted_residuals(design, &beta_smooth, data);
442            let gcv = compute_gcv(&residuals, trace_h);
443            if gcv < best_gcv {
444                best_gcv = gcv;
445                best_ls = ls;
446                best_lt = lt;
447            }
448        }
449    }
450
451    (best_ls, best_lt)
452}
453
454// ---------------------------------------------------------------------------
455// Public API
456// ---------------------------------------------------------------------------
457
458/// 2D Function-on-Scalar Regression with tensor-product penalty.
459///
460/// Fits the model
461/// ```text
462/// Y_i(s,t) = beta_0(s,t) + sum_j x_{ij} beta_j(s,t) + epsilon_i(s,t)
463/// ```
464/// with anisotropic roughness penalty
465/// `lambda_s * ||d^2 beta/ds^2||^2 + lambda_t * ||d^2 beta/dt^2||^2`.
466///
467/// # Arguments
468/// * `data` - Functional response matrix (n x m_total where m_total = m1*m2)
469/// * `predictors` - Scalar predictor matrix (n x p)
470/// * `grid` - 2D grid specification
471/// * `lambda_s` - Smoothing parameter in the s-direction (negative = GCV)
472/// * `lambda_t` - Smoothing parameter in the t-direction (negative = GCV)
473///
474/// # Errors
475///
476/// Returns [`FdarError::InvalidDimension`] if `data.ncols() != grid.total()`,
477/// `predictors.nrows() != data.nrows()`, or `n < p + 2`.
478/// Returns [`FdarError::InvalidParameter`] if the grid has zero size in
479/// either dimension.
480/// Returns [`FdarError::ComputationFailed`] if the Cholesky factorization
481/// fails during OLS or smoothing.
482///
483/// # Examples
484///
485/// ```
486/// use fdars_core::matrix::FdMatrix;
487/// use fdars_core::function_on_scalar_2d::{fosr_2d, Grid2d};
488///
489/// let m1 = 5;
490/// let m2 = 4;
491/// let n = 20;
492/// let grid = Grid2d::new(
493///     (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect(),
494///     (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect(),
495/// );
496/// let data = FdMatrix::from_column_major(
497///     (0..n * m1 * m2).map(|k| {
498///         let i = (k % n) as f64;
499///         let j = (k / n) as f64;
500///         ((i + 1.0) * 0.3 + j * 0.7).sin()
501///     }).collect(),
502///     n, m1 * m2,
503/// ).unwrap();
504/// let predictors = FdMatrix::from_column_major(
505///     (0..n * 2).map(|k| {
506///         let i = (k % n) as f64;
507///         let j = (k / n) as f64;
508///         (i * 0.4 + j * 1.5).cos()
509///     }).collect(),
510///     n, 2,
511/// ).unwrap();
512/// let result = fosr_2d(&data, &predictors, &grid, 0.1, 0.1).unwrap();
513/// assert_eq!(result.fitted.shape(), (n, m1 * m2));
514/// ```
515#[must_use = "expensive computation whose result should not be discarded"]
516pub fn fosr_2d(
517    data: &FdMatrix,
518    predictors: &FdMatrix,
519    grid: &Grid2d,
520    lambda_s: f64,
521    lambda_t: f64,
522) -> Result<FosrResult2d, FdarError> {
523    let (n, m_data) = data.shape();
524    let p = predictors.ncols();
525    let m1 = grid.m1();
526    let m2 = grid.m2();
527    let m_total = grid.total();
528
529    // ---- Validate inputs ----
530
531    if m1 == 0 {
532        return Err(FdarError::InvalidParameter {
533            parameter: "grid",
534            message: "argvals_s must not be empty".to_string(),
535        });
536    }
537    if m2 == 0 {
538        return Err(FdarError::InvalidParameter {
539            parameter: "grid",
540            message: "argvals_t must not be empty".to_string(),
541        });
542    }
543    if m_data != m_total {
544        return Err(FdarError::InvalidDimension {
545            parameter: "data",
546            expected: format!("{m_total} columns (grid m1*m2 = {m1}*{m2})"),
547            actual: format!("{m_data} columns"),
548        });
549    }
550    if predictors.nrows() != n {
551        return Err(FdarError::InvalidDimension {
552            parameter: "predictors",
553            expected: format!("{n} rows (matching data)"),
554            actual: format!("{} rows", predictors.nrows()),
555        });
556    }
557    if n < p + 2 {
558        return Err(FdarError::InvalidDimension {
559            parameter: "data",
560            expected: format!("at least {} observations (p + 2)", p + 2),
561            actual: format!("{n} observations"),
562        });
563    }
564
565    // ---- Step 1: Build design and compute OLS ----
566
567    let design = build_fosr_design(predictors, n);
568    let p_total = design.ncols(); // p + 1
569    let xtx = compute_xtx(&design);
570    let xty = compute_xty_matrix(&design, data);
571
572    let l_xtx = cholesky_factor(&xtx, p_total).ok_or_else(|| FdarError::ComputationFailed {
573        operation: "fosr_2d",
574        detail: "Cholesky factorization of X'X failed; design matrix is rank-deficient — remove constant or collinear predictors, or add regularization".to_string(),
575    })?;
576
577    // Pointwise OLS: beta_ols[:,g] = (X'X)^{-1} X' y[:,g]
578    let mut beta_ols = FdMatrix::zeros(p_total, m_total);
579    for g in 0..m_total {
580        let b: Vec<f64> = (0..p_total).map(|j| xty[(j, g)]).collect();
581        let x = cholesky_forward_back(&l_xtx, &b, p_total);
582        for j in 0..p_total {
583            beta_ols[(j, g)] = x[j];
584        }
585    }
586
587    // ---- Step 2: Determine lambda values ----
588
589    let fix_ls = if lambda_s >= 0.0 {
590        Some(lambda_s)
591    } else {
592        None
593    };
594    let fix_lt = if lambda_t >= 0.0 {
595        Some(lambda_t)
596    } else {
597        None
598    };
599
600    let (lambda_s_final, lambda_t_final) = if fix_ls.is_some() && fix_lt.is_some() {
601        (lambda_s, lambda_t)
602    } else {
603        select_lambdas_gcv(&xtx, &xty, &design, data, m1, m2, fix_ls, fix_lt)
604    };
605
606    // ---- Step 3: Smooth coefficient surfaces ----
607
608    let beta_smooth = if lambda_s_final == 0.0 && lambda_t_final == 0.0 {
609        beta_ols
610    } else {
611        let p2d = penalty_matrix_2d(m1, m2, lambda_s_final, lambda_t_final);
612        let mut smoothed = FdMatrix::zeros(p_total, m_total);
613        for j in 0..p_total {
614            let raw: Vec<f64> = (0..m_total).map(|g| beta_ols[(j, g)]).collect();
615            let s = smooth_coefficient_surface(&raw, &p2d, m_total)?;
616            for g in 0..m_total {
617                smoothed[(j, g)] = s[g];
618            }
619        }
620        smoothed
621    };
622
623    // ---- Step 4: Compute diagnostics ----
624
625    let (fitted, residuals) = compute_fitted_residuals(&design, &beta_smooth, data);
626
627    let r_squared_pointwise = pointwise_r_squared(data, &fitted);
628    let r_squared = if m_total > 0 {
629        r_squared_pointwise.iter().sum::<f64>() / m_total as f64
630    } else {
631        0.0
632    };
633
634    let trace_h = trace_hat_ols(p_total);
635    let gcv = compute_gcv(&residuals, trace_h);
636
637    let beta_se = compute_beta_se_2d(&xtx, &residuals, p_total, n);
638
639    // Extract intercept (row 0 of beta_smooth)
640    let intercept: Vec<f64> = (0..m_total).map(|g| beta_smooth[(0, g)]).collect();
641
642    // Extract predictor coefficients (rows 1..p_total)
643    let mut beta_out = FdMatrix::zeros(p, m_total);
644    for j in 0..p {
645        for g in 0..m_total {
646            beta_out[(j, g)] = beta_smooth[(j + 1, g)];
647        }
648    }
649
650    Ok(FosrResult2d {
651        intercept,
652        beta: beta_out,
653        fitted,
654        residuals,
655        r_squared_pointwise,
656        r_squared,
657        beta_se,
658        lambda_s: lambda_s_final,
659        lambda_t: lambda_t_final,
660        gcv,
661        grid: grid.clone(),
662    })
663}
664
665/// Predict functional surfaces for new observations.
666///
667/// # Arguments
668/// * `result` - Fitted [`FosrResult2d`]
669/// * `new_predictors` - New scalar predictors (n_new x p)
670///
671/// # Errors
672///
673/// Returns [`FdarError::InvalidDimension`] if the number of predictor columns
674/// does not match the fitted model.
675#[must_use = "prediction result should not be discarded"]
676pub fn predict_fosr_2d(
677    result: &FosrResult2d,
678    new_predictors: &FdMatrix,
679) -> Result<FdMatrix, FdarError> {
680    let n_new = new_predictors.nrows();
681    let m_total = result.intercept.len();
682    let p = result.beta.nrows();
683
684    if new_predictors.ncols() != p {
685        return Err(FdarError::InvalidDimension {
686            parameter: "new_predictors",
687            expected: format!("{p} columns (matching fitted model)"),
688            actual: format!("{} columns", new_predictors.ncols()),
689        });
690    }
691
692    let mut predicted = FdMatrix::zeros(n_new, m_total);
693    for i in 0..n_new {
694        for g in 0..m_total {
695            let mut yhat = result.intercept[g];
696            for j in 0..p {
697                yhat += new_predictors[(i, j)] * result.beta[(j, g)];
698            }
699            predicted[(i, g)] = yhat;
700        }
701    }
702    Ok(predicted)
703}
704
705// ---------------------------------------------------------------------------
706// Tests
707// ---------------------------------------------------------------------------
708
#[cfg(test)]
mod tests {
    // Unit tests for the 2D FOSR helpers (Kronecker product, 2D penalty)
    // and the public API (`fosr_2d`, `predict_fosr_2d`, result reshaping).
    use super::*;

    /// `m` equally spaced points on [0, 1].
    /// NOTE(review): `m - 1` underflows for `m == 0`; callers here only pass `m >= 1`.
    fn uniform_grid_1d(m: usize) -> Vec<f64> {
        (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect()
    }

    /// Uniform `m1 x m2` grid on the unit square.
    fn make_grid(m1: usize, m2: usize) -> Grid2d {
        Grid2d::new(uniform_grid_1d(m1), uniform_grid_1d(m2))
    }

    /// Generate test data: Y_i(s,t) = intercept(s,t) + z_1 * beta_1(s,t) + z_2 * beta_2(s,t) + noise
    ///
    /// intercept = s + t, beta_1 = s * t, beta_2 = s - t; z_1 = i / n and
    /// z_2 alternates 1, 0 over observations. The "noise" is a deterministic
    /// pattern in [0, noise_scale), so results are reproducible.
    fn generate_2d_data(
        n: usize,
        m1: usize,
        m2: usize,
        noise_scale: f64,
    ) -> (FdMatrix, FdMatrix, Grid2d) {
        let grid = make_grid(m1, m2);
        let m_total = m1 * m2;
        let mut y = FdMatrix::zeros(n, m_total);
        let mut z = FdMatrix::zeros(n, 2);

        for i in 0..n {
            let z1 = (i as f64) / (n as f64);
            let z2 = if i % 2 == 0 { 1.0 } else { 0.0 };
            z[(i, 0)] = z1;
            z[(i, 1)] = z2;

            for si in 0..m1 {
                for ti in 0..m2 {
                    let g = si + ti * m1; // column-major flat index
                    let s = grid.argvals_s[si];
                    let t = grid.argvals_t[ti];

                    let intercept = s + t;
                    let beta1 = s * t;
                    let beta2 = s - t;
                    let noise = noise_scale * ((i * 13 + si * 7 + ti * 3) % 100) as f64 / 100.0;

                    y[(i, g)] = intercept + z1 * beta1 + z2 * beta2 + noise;
                }
            }
        }
        (y, z, grid)
    }

    // Grid size accessors.
    #[test]
    fn test_grid2d_basic() {
        let grid = make_grid(5, 4);
        assert_eq!(grid.m1(), 5);
        assert_eq!(grid.m2(), 4);
        assert_eq!(grid.total(), 20);
    }

    // Kronecker product against a hand-computed 2x2 (x) 2x2 example.
    #[test]
    fn test_kronecker_product_small() {
        // A = [[1,2],[3,4]], B = [[0,5],[6,7]]
        // A kron B = [[0,5,0,10],[6,7,12,14],[0,15,0,20],[18,21,24,28]]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![0.0, 5.0, 6.0, 7.0];
        let c = kronecker_product(&a, 2, 2, &b, 2, 2);
        assert_eq!(c.len(), 16);
        #[rustfmt::skip]
        let expected = vec![
            0.0, 5.0, 0.0, 10.0,
            6.0, 7.0, 12.0, 14.0,
            0.0, 15.0, 0.0, 20.0,
            18.0, 21.0, 24.0, 28.0,
        ];
        for (i, (&ci, &ei)) in c.iter().zip(expected.iter()).enumerate() {
            assert!(
                (ci - ei).abs() < 1e-12,
                "kronecker mismatch at index {i}: got {ci}, expected {ei}"
            );
        }
    }

    // The 2D penalty must be symmetric (it is factored with Cholesky).
    #[test]
    fn test_penalty_2d_symmetry() {
        let m1 = 5;
        let m2 = 4;
        let p2d = penalty_matrix_2d(m1, m2, 1.0, 1.0);
        let m_total = m1 * m2;
        for i in 0..m_total {
            for j in 0..m_total {
                assert!(
                    (p2d[i * m_total + j] - p2d[j * m_total + i]).abs() < 1e-12,
                    "P_2d not symmetric at ({i},{j})"
                );
            }
        }
    }

    // The 2D penalty is (m1*m2) x (m1*m2), stored flat.
    #[test]
    fn test_penalty_2d_shape() {
        let m1 = 5;
        let m2 = 4;
        let p2d = penalty_matrix_2d(m1, m2, 1.0, 2.0);
        assert_eq!(p2d.len(), (m1 * m2) * (m1 * m2));
    }

    // Constant response, no smoothing: intercept recovers the constant and
    // all predictor coefficients vanish (within 1e-8).
    #[test]
    fn test_fosr_2d_constant_response() {
        let n = 20;
        let m1 = 5;
        let m2 = 4;
        let grid = make_grid(m1, m2);
        let m_total = m1 * m2;

        // Y_i(s,t) = 3.0 for all i,s,t
        let mut y = FdMatrix::zeros(n, m_total);
        for i in 0..n {
            for g in 0..m_total {
                y[(i, g)] = 3.0;
            }
        }

        let mut z = FdMatrix::zeros(n, 2);
        for i in 0..n {
            z[(i, 0)] = i as f64;
            z[(i, 1)] = (i % 3) as f64;
        }

        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // Intercept should be near 3.0 everywhere
        for g in 0..m_total {
            assert!(
                (result.intercept[g] - 3.0).abs() < 1e-8,
                "intercept[{g}] = {}, expected 3.0",
                result.intercept[g]
            );
        }

        // Beta coefficients should be near zero
        for j in 0..2 {
            for g in 0..m_total {
                assert!(
                    result.beta[(j, g)].abs() < 1e-8,
                    "beta[{j},{g}] = {}, expected ~0",
                    result.beta[(j, g)]
                );
            }
        }
    }

    // Low-noise data, no smoothing: global R^2 above 0.8.
    // NOTE(review): despite the name, this uses both predictor columns
    // produced by generate_2d_data.
    #[test]
    fn test_fosr_2d_single_predictor() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.01);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // With low noise, R^2 should be high
        assert!(
            result.r_squared > 0.8,
            "R^2 = {}, expected > 0.8",
            result.r_squared
        );
    }

    // fitted + residuals reproduces the data exactly (within 1e-10).
    #[test]
    fn test_fosr_2d_fitted_plus_residuals() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        let (n, m_total) = y.shape();
        for i in 0..n {
            for g in 0..m_total {
                let reconstructed = result.fitted[(i, g)] + result.residuals[(i, g)];
                assert!(
                    (reconstructed - y[(i, g)]).abs() < 1e-10,
                    "fitted + residuals != y at ({i},{g})"
                );
            }
        }
    }

    // Unsmoothed pointwise R^2 stays within [-0.01, 1 + 1e-10].
    #[test]
    fn test_fosr_2d_r_squared_range() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        for (g, &r2) in result.r_squared_pointwise.iter().enumerate() {
            assert!(
                (-0.01..=1.0 + 1e-10).contains(&r2),
                "R^2 out of range at grid point {g}: {r2}"
            );
        }
    }

    // Predicting on the training predictors reproduces the fitted values.
    #[test]
    fn test_fosr_2d_predict_matches_fitted() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        let preds = predict_fosr_2d(&result, &z).unwrap();
        let (n, m_total) = y.shape();
        for i in 0..n {
            for g in 0..m_total {
                assert!(
                    (preds[(i, g)] - result.fitted[(i, g)]).abs() < 1e-8,
                    "prediction != fitted at ({i},{g})"
                );
            }
        }
    }

    // Reshape helpers return m1 x m2 matrices.
    #[test]
    fn test_fosr_2d_reshape_beta_surface() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        let surface = result.beta_surface(0);
        assert_eq!(surface.shape(), (5, 4));

        let r2_surface = result.r_squared_surface();
        assert_eq!(r2_surface.shape(), (5, 4));

        let resid_surface = result.residual_surface(0);
        assert_eq!(resid_surface.shape(), (5, 4));
    }

    // Input validation: each malformed input yields an Err rather than a panic.
    #[test]
    fn test_fosr_2d_dimension_mismatch() {
        let grid = make_grid(5, 4);

        // Wrong number of columns in data
        let y = FdMatrix::zeros(20, 10); // 10 != 5*4=20
        let z = FdMatrix::zeros(20, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Mismatched rows between data and predictors
        let y = FdMatrix::zeros(20, 20);
        let z = FdMatrix::zeros(10, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Too few observations
        let y = FdMatrix::zeros(3, 20);
        let z = FdMatrix::zeros(3, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Empty grid
        let empty_grid = Grid2d::new(vec![], vec![0.0, 1.0]);
        let y = FdMatrix::zeros(20, 0);
        let z = FdMatrix::zeros(20, 2);
        assert!(fosr_2d(&y, &z, &empty_grid, 0.0, 0.0).is_err());

        // Predictor dimension mismatch in predict
        let grid = make_grid(3, 3);
        let y = FdMatrix::zeros(20, 9);
        let mut z = FdMatrix::zeros(20, 2);
        for i in 0..20 {
            z[(i, 0)] = i as f64;
            z[(i, 1)] = (i * 3 % 7) as f64;
        }
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();
        let z_bad = FdMatrix::zeros(5, 3); // 3 != 2
        assert!(predict_fosr_2d(&result, &z_bad).is_err());
    }

    // GCV path: selected lambdas are non-negative, the score is positive,
    // and the fit keeps R^2 above 0.5.
    #[test]
    fn test_fosr_2d_gcv() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        // Negative lambda triggers GCV selection
        let result = fosr_2d(&y, &z, &grid, -1.0, -1.0).unwrap();
        assert!(result.lambda_s >= 0.0);
        assert!(result.lambda_t >= 0.0);
        assert!(result.gcv > 0.0);
        assert!(result.r_squared > 0.5);
    }
}