// fdars_core/function_on_scalar_2d.rs
//! 2D Function-on-Scalar Regression (FOSR).
//!
//! Extends the 1D function-on-scalar model to surface-valued functional
//! responses observed on a regular 2D grid:
//! ```text
//! Y_i(s,t) = beta_0(s,t) + sum_j z_{ij} beta_j(s,t) + epsilon_i(s,t)
//! ```
//!
//! The estimation uses a two-step approach:
//! 1. Pointwise OLS at each grid point to obtain raw coefficient surfaces.
//! 2. Tensor-product roughness penalty smoothing of each coefficient surface.
//!
//! # Methods
//!
//! - [`fosr_2d`]: Penalized 2D FOSR with anisotropic smoothing
//! - [`predict_fosr_2d`]: Predict new surfaces from a fitted model
use crate::error::FdarError;
use crate::function_on_scalar::{
    build_fosr_design, compute_xty_matrix, penalty_matrix, pointwise_r_squared,
};
use crate::linalg::{cholesky_factor, cholesky_forward_back, compute_xtx};
use crate::matrix::FdMatrix;
24
25// ---------------------------------------------------------------------------
26// Types
27// ---------------------------------------------------------------------------
28
/// 2D grid description for surface-valued functional data.
#[derive(Debug, Clone, PartialEq)]
pub struct Grid2d {
    /// Grid points along the first (row) dimension.
    pub argvals_s: Vec<f64>,
    /// Grid points along the second (column) dimension.
    pub argvals_t: Vec<f64>,
}

impl Grid2d {
    /// Create a new 2D grid from argument vectors.
    pub fn new(argvals_s: Vec<f64>, argvals_t: Vec<f64>) -> Self {
        Self { argvals_s, argvals_t }
    }

    /// Number of grid points in the first (s) dimension.
    #[inline]
    pub fn m1(&self) -> usize {
        self.argvals_s.len()
    }

    /// Number of grid points in the second (t) dimension.
    #[inline]
    pub fn m2(&self) -> usize {
        self.argvals_t.len()
    }

    /// Total number of grid points, `m1() * m2()`.
    #[inline]
    pub fn total(&self) -> usize {
        self.argvals_s.len() * self.argvals_t.len()
    }
}
65
/// Result of 2D function-on-scalar regression.
///
/// All surface-valued fields are stored flattened column-major, i.e. the
/// flat index is `g = s_idx + t_idx * m1` (the s-index varies fastest).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FosrResult2d {
    /// Intercept surface beta_0(s,t), flattened column-major (length m1*m2).
    pub intercept: Vec<f64>,
    /// Coefficient surfaces beta_j(s,t), p x (m1*m2) matrix (row j = flattened beta_j).
    pub beta: FdMatrix,
    /// Fitted surface values, n x (m1*m2) matrix.
    pub fitted: FdMatrix,
    /// Residual surfaces, n x (m1*m2) matrix.
    pub residuals: FdMatrix,
    /// Pointwise R^2(s,t), flattened (length m1*m2).
    pub r_squared_pointwise: Vec<f64>,
    /// Global R^2 (average of pointwise).
    pub r_squared: f64,
    /// Standard error surfaces for each beta_j (p x (m1*m2)), or None if
    /// the system is underdetermined.
    pub beta_se: Option<FdMatrix>,
    /// Smoothing parameter in the s-direction.
    pub lambda_s: f64,
    /// Smoothing parameter in the t-direction.
    pub lambda_t: f64,
    /// Generalised cross-validation score.
    pub gcv: f64,
    /// Grid specification.
    pub grid: Grid2d,
}
94
95impl FosrResult2d {
96    /// Reshape the j-th coefficient surface into an m1 x m2 matrix.
97    ///
98    /// # Panics
99    /// Panics if `j >= p` (number of predictors).
100    #[must_use]
101    pub fn beta_surface(&self, j: usize) -> FdMatrix {
102        let m1 = self.grid.m1();
103        let m2 = self.grid.m2();
104        let m_total = m1 * m2;
105        let mut mat = FdMatrix::zeros(m1, m2);
106        for g in 0..m_total {
107            // Column-major: g = s_idx + t_idx * m1
108            let s_idx = g % m1;
109            let t_idx = g / m1;
110            mat[(s_idx, t_idx)] = self.beta[(j, g)];
111        }
112        mat
113    }
114
115    /// Reshape pointwise R^2(s,t) into an m1 x m2 matrix.
116    #[must_use]
117    pub fn r_squared_surface(&self) -> FdMatrix {
118        let m1 = self.grid.m1();
119        let m2 = self.grid.m2();
120        let m_total = m1 * m2;
121        let mut mat = FdMatrix::zeros(m1, m2);
122        for g in 0..m_total {
123            let s_idx = g % m1;
124            let t_idx = g / m1;
125            mat[(s_idx, t_idx)] = self.r_squared_pointwise[g];
126        }
127        mat
128    }
129
130    /// Reshape the residual for observation `i` into an m1 x m2 matrix.
131    ///
132    /// # Panics
133    /// Panics if `i >= n` (number of observations).
134    #[must_use]
135    pub fn residual_surface(&self, i: usize) -> FdMatrix {
136        let m1 = self.grid.m1();
137        let m2 = self.grid.m2();
138        let m_total = m1 * m2;
139        let mut mat = FdMatrix::zeros(m1, m2);
140        for g in 0..m_total {
141            let s_idx = g % m1;
142            let t_idx = g / m1;
143            mat[(s_idx, t_idx)] = self.residuals[(i, g)];
144        }
145        mat
146    }
147
148    /// Predict functional surfaces for new predictors. Delegates to
149    /// [`predict_fosr_2d`].
150    pub fn predict(&self, new_predictors: &FdMatrix) -> Result<FdMatrix, FdarError> {
151        predict_fosr_2d(self, new_predictors)
152    }
153}
154
/// Kronecker product of two flat row-major matrices.
///
/// Given A (rows_a x cols_a) and B (rows_b x cols_b), produces
/// C = A kron B of size (rows_a * rows_b) x (cols_a * cols_b) in
/// row-major layout, i.e. C[(ia*rows_b + ib), (ja*cols_b + jb)] =
/// A[ia,ja] * B[ib,jb].
fn kronecker_product(
    a: &[f64],
    rows_a: usize,
    cols_a: usize,
    b: &[f64],
    rows_b: usize,
    cols_b: usize,
) -> Vec<f64> {
    let out_cols = cols_a * cols_b;
    (0..rows_a * rows_b * out_cols)
        .map(|idx| {
            // Decompose the flat output index into (row, col), then split
            // each into its A-block coordinate and within-block B coordinate.
            let (row, col) = (idx / out_cols, idx % out_cols);
            let (ia, ib) = (row / rows_b, row % rows_b);
            let (ja, jb) = (col / cols_b, col % cols_b);
            a[ia * cols_a + ja] * b[ib * cols_b + jb]
        })
        .collect()
}
185
/// Identity matrix as flat row-major vector.
fn identity_matrix(n: usize) -> Vec<f64> {
    // Diagonal entries of a row-major n x n matrix sit at indices i*(n+1),
    // which are exactly the multiples of n+1 below n*n.
    (0..n * n)
        .map(|k| if k % (n + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
194
195/// Build the 2D tensor-product penalty matrix.
196///
197/// P_2d = lambda_s * (P_s kron I_t) + lambda_t * (I_s kron P_t)
198///
199/// where P_s = D_s'D_s (m1 x m1) and P_t = D_t'D_t (m2 x m2) are the
200/// second-difference penalty matrices.
201fn penalty_matrix_2d(m1: usize, m2: usize, lambda_s: f64, lambda_t: f64) -> Vec<f64> {
202    let m_total = m1 * m2;
203    let ps = penalty_matrix(m1);
204    let pt = penalty_matrix(m2);
205    let i_t = identity_matrix(m2);
206    let i_s = identity_matrix(m1);
207
208    let ps_kron_it = kronecker_product(&ps, m1, m1, &i_t, m2, m2);
209    let is_kron_pt = kronecker_product(&i_s, m1, m1, &pt, m2, m2);
210
211    let mut p2d = vec![0.0; m_total * m_total];
212    for i in 0..m_total * m_total {
213        p2d[i] = lambda_s * ps_kron_it[i] + lambda_t * is_kron_pt[i];
214    }
215    p2d
216}
217
218// ---------------------------------------------------------------------------
219// Core fitting routines
220// ---------------------------------------------------------------------------
221
222/// Compute fitted values Y_hat = X * beta and residuals Y - Y_hat.
223fn compute_fitted_residuals(
224    design: &FdMatrix,
225    beta: &FdMatrix,
226    data: &FdMatrix,
227) -> (FdMatrix, FdMatrix) {
228    let (n, m_total) = data.shape();
229    let p_total = design.ncols();
230    let mut fitted = FdMatrix::zeros(n, m_total);
231    let mut residuals = FdMatrix::zeros(n, m_total);
232    for i in 0..n {
233        for g in 0..m_total {
234            let mut yhat = 0.0;
235            for j in 0..p_total {
236                yhat += design[(i, j)] * beta[(j, g)];
237            }
238            fitted[(i, g)] = yhat;
239            residuals[(i, g)] = data[(i, g)] - yhat;
240        }
241    }
242    (fitted, residuals)
243}
244
245/// Compute GCV: (1/n*m) * sum(r^2) / (1 - tr(H)/n)^2.
246fn compute_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
247    let (n, m) = residuals.shape();
248    let denom = (1.0 - trace_h / n as f64).max(1e-10);
249    let ss_res: f64 = residuals.as_slice().iter().map(|v| v * v).sum();
250    ss_res / (n as f64 * m as f64 * denom * denom)
251}
252
/// Trace of the OLS hat matrix: tr(H) = p_total, because X(X'X)^{-1}X' is
/// the orthogonal projection onto the p_total-dimensional column space of X.
fn trace_hat_ols(p_total: usize) -> f64 {
    p_total as f64
}
258
259/// Smooth a raw coefficient vector (length m_total) using the 2D penalty.
260///
261/// Solves (I + P_2d) * beta_smooth = beta_raw via Cholesky.
262fn smooth_coefficient_surface(
263    beta_raw: &[f64],
264    penalty_2d: &[f64],
265    m_total: usize,
266) -> Result<Vec<f64>, FdarError> {
267    // Build (I + P_2d)
268    let mut a = penalty_2d.to_vec();
269    for i in 0..m_total {
270        a[i * m_total + i] += 1.0;
271    }
272    let l = cholesky_factor(&a, m_total)?;
273    Ok(cholesky_forward_back(&l, beta_raw, m_total))
274}
275
276/// Compute standard errors from the diagonal of (X'X)^{-1} and residual
277/// variance at each grid point.
278fn compute_beta_se_2d(
279    xtx: &[f64],
280    residuals: &FdMatrix,
281    p_total: usize,
282    n: usize,
283) -> Option<FdMatrix> {
284    let m_total = residuals.ncols();
285    let l = cholesky_factor(xtx, p_total).ok()?;
286
287    // Diagonal of (X'X)^{-1}
288    let a_inv_diag: Vec<f64> = (0..p_total)
289        .map(|j| {
290            let mut ej = vec![0.0; p_total];
291            ej[j] = 1.0;
292            let v = cholesky_forward_back(&l, &ej, p_total);
293            v[j]
294        })
295        .collect();
296
297    let df = (n - p_total).max(1) as f64;
298    // We return SE for the predictor coefficients only (drop intercept).
299    let p = p_total - 1;
300    let mut se = FdMatrix::zeros(p, m_total);
301    for g in 0..m_total {
302        let sigma2: f64 = (0..n).map(|i| residuals[(i, g)].powi(2)).sum::<f64>() / df;
303        for j in 0..p {
304            // j+1 to skip intercept row in a_inv_diag
305            se[(j, g)] = (sigma2 * a_inv_diag[j + 1]).max(0.0).sqrt();
306        }
307    }
308    Some(se)
309}
310
311// ---------------------------------------------------------------------------
312// GCV lambda selection
313// ---------------------------------------------------------------------------
314
/// Select (lambda_s, lambda_t) via GCV over a 2D grid of candidate values.
///
/// Any dimension with a `fix_lambda_*` value supplied is held fixed; the
/// other is searched over the candidate grid. Returns `(0.0, 0.0)` (no
/// smoothing) if the OLS normal equations cannot be factorized.
fn select_lambdas_gcv(
    xtx: &[f64],
    xty: &FdMatrix,
    design: &FdMatrix,
    data: &FdMatrix,
    m1: usize,
    m2: usize,
    fix_lambda_s: Option<f64>,
    fix_lambda_t: Option<f64>,
) -> (f64, f64) {
    // Log-spaced candidates; 0.0 means "no smoothing in this direction".
    let candidates = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
    let p_total = design.ncols();
    let m_total = m1 * m2;

    let ls_candidates: Vec<f64> = if let Some(ls) = fix_lambda_s {
        vec![ls]
    } else {
        candidates.to_vec()
    };
    let lt_candidates: Vec<f64> = if let Some(lt) = fix_lambda_t {
        vec![lt]
    } else {
        candidates.to_vec()
    };

    // Pre-compute OLS inverse and raw beta
    let Ok(l_xtx) = cholesky_factor(xtx, p_total) else {
        return (0.0, 0.0);
    };

    // beta_ols: p_total x m_total pointwise OLS solution; each candidate
    // pair only re-smooths these raw surfaces, never re-solves the OLS.
    let mut beta_ols = FdMatrix::zeros(p_total, m_total);
    for g in 0..m_total {
        let b: Vec<f64> = (0..p_total).map(|j| xty[(j, g)]).collect();
        let x = cholesky_forward_back(&l_xtx, &b, p_total);
        for j in 0..p_total {
            beta_ols[(j, g)] = x[j];
        }
    }

    // NOTE(review): trace_h is the unpenalized OLS trace (= p_total) for
    // every candidate, so the GCV denominator ignores the change in
    // effective degrees of freedom caused by smoothing — confirm this
    // approximation is intended.
    let trace_h = trace_hat_ols(p_total);
    let mut best_ls = 0.0;
    let mut best_lt = 0.0;
    let mut best_gcv = f64::INFINITY;

    for &ls in &ls_candidates {
        for &lt in &lt_candidates {
            if ls == 0.0 && lt == 0.0 {
                // No smoothing: use raw OLS
                let (_, residuals) = compute_fitted_residuals(design, &beta_ols, data);
                let gcv = compute_gcv(&residuals, trace_h);
                if gcv < best_gcv {
                    best_gcv = gcv;
                    best_ls = ls;
                    best_lt = lt;
                }
                continue;
            }

            let p2d = penalty_matrix_2d(m1, m2, ls, lt);
            let mut beta_smooth = FdMatrix::zeros(p_total, m_total);
            let mut ok = true;
            for j in 0..p_total {
                let raw: Vec<f64> = (0..m_total).map(|g| beta_ols[(j, g)]).collect();
                if let Ok(smoothed) = smooth_coefficient_surface(&raw, &p2d, m_total) {
                    for g in 0..m_total {
                        beta_smooth[(j, g)] = smoothed[g];
                    }
                } else {
                    ok = false;
                    break;
                }
            }
            // Skip candidate pairs whose smoothing system failed to factorize.
            if !ok {
                continue;
            }

            let (_, residuals) = compute_fitted_residuals(design, &beta_smooth, data);
            let gcv = compute_gcv(&residuals, trace_h);
            if gcv < best_gcv {
                best_gcv = gcv;
                best_ls = ls;
                best_lt = lt;
            }
        }
    }

    (best_ls, best_lt)
}
405
406// ---------------------------------------------------------------------------
407// Public API
408// ---------------------------------------------------------------------------
409
/// 2D Function-on-Scalar Regression with tensor-product penalty.
///
/// Fits the model
/// ```text
/// Y_i(s,t) = beta_0(s,t) + sum_j x_{ij} beta_j(s,t) + epsilon_i(s,t)
/// ```
/// with anisotropic roughness penalty
/// `lambda_s * ||d^2 beta/ds^2||^2 + lambda_t * ||d^2 beta/dt^2||^2`.
///
/// # Arguments
/// * `data` - Functional response matrix (n x m_total where m_total = m1*m2)
/// * `predictors` - Scalar predictor matrix (n x p)
/// * `grid` - 2D grid specification
/// * `lambda_s` - Smoothing parameter in the s-direction (negative = GCV)
/// * `lambda_t` - Smoothing parameter in the t-direction (negative = GCV)
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] if `data.ncols() != grid.total()`,
/// `predictors.nrows() != data.nrows()`, or `n < p + 2`.
/// Returns [`FdarError::InvalidParameter`] if the grid has zero size in
/// either dimension.
/// Returns [`FdarError::ComputationFailed`] if the Cholesky factorization
/// fails during OLS or smoothing.
///
/// # Examples
///
/// ```
/// use fdars_core::matrix::FdMatrix;
/// use fdars_core::function_on_scalar_2d::{fosr_2d, Grid2d};
///
/// let m1 = 5;
/// let m2 = 4;
/// let n = 20;
/// let grid = Grid2d::new(
///     (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect(),
///     (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect(),
/// );
/// let data = FdMatrix::from_column_major(
///     (0..n * m1 * m2).map(|k| {
///         let i = (k % n) as f64;
///         let j = (k / n) as f64;
///         ((i + 1.0) * 0.3 + j * 0.7).sin()
///     }).collect(),
///     n, m1 * m2,
/// ).unwrap();
/// let predictors = FdMatrix::from_column_major(
///     (0..n * 2).map(|k| {
///         let i = (k % n) as f64;
///         let j = (k / n) as f64;
///         (i * 0.4 + j * 1.5).cos()
///     }).collect(),
///     n, 2,
/// ).unwrap();
/// let result = fosr_2d(&data, &predictors, &grid, 0.1, 0.1).unwrap();
/// assert_eq!(result.fitted.shape(), (n, m1 * m2));
/// ```
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fosr_2d(
    data: &FdMatrix,
    predictors: &FdMatrix,
    grid: &Grid2d,
    lambda_s: f64,
    lambda_t: f64,
) -> Result<FosrResult2d, FdarError> {
    let (n, m_data) = data.shape();
    let p = predictors.ncols();
    let m1 = grid.m1();
    let m2 = grid.m2();
    let m_total = grid.total();

    // ---- Validate inputs ----

    if m1 == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "grid",
            message: "argvals_s must not be empty".to_string(),
        });
    }
    if m2 == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "grid",
            message: "argvals_t must not be empty".to_string(),
        });
    }
    if m_data != m_total {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: format!("{m_total} columns (grid m1*m2 = {m1}*{m2})"),
            actual: format!("{m_data} columns"),
        });
    }
    if predictors.nrows() != n {
        return Err(FdarError::InvalidDimension {
            parameter: "predictors",
            expected: format!("{n} rows (matching data)"),
            actual: format!("{} rows", predictors.nrows()),
        });
    }
    // n >= p + 2 guarantees at least one residual degree of freedom beyond
    // the p + 1 design columns (intercept included).
    if n < p + 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: format!("at least {} observations (p + 2)", p + 2),
            actual: format!("{n} observations"),
        });
    }

    // ---- Step 1: Build design and compute OLS ----

    let design = build_fosr_design(predictors, n);
    let p_total = design.ncols(); // p + 1
    let xtx = compute_xtx(&design);
    let xty = compute_xty_matrix(&design, data);

    let l_xtx = cholesky_factor(&xtx, p_total)?;

    // Pointwise OLS: beta_ols[:,g] = (X'X)^{-1} X' y[:,g]
    let mut beta_ols = FdMatrix::zeros(p_total, m_total);
    for g in 0..m_total {
        let b: Vec<f64> = (0..p_total).map(|j| xty[(j, g)]).collect();
        let x = cholesky_forward_back(&l_xtx, &b, p_total);
        for j in 0..p_total {
            beta_ols[(j, g)] = x[j];
        }
    }

    // ---- Step 2: Determine lambda values ----

    // Negative lambda encodes "select this dimension via GCV".
    let fix_ls = if lambda_s >= 0.0 {
        Some(lambda_s)
    } else {
        None
    };
    let fix_lt = if lambda_t >= 0.0 {
        Some(lambda_t)
    } else {
        None
    };

    let (lambda_s_final, lambda_t_final) = if fix_ls.is_some() && fix_lt.is_some() {
        (lambda_s, lambda_t)
    } else {
        select_lambdas_gcv(&xtx, &xty, &design, data, m1, m2, fix_ls, fix_lt)
    };

    // ---- Step 3: Smooth coefficient surfaces ----

    // When both lambdas are exactly zero, the raw OLS surfaces are reused
    // (moved) without building the m_total^2 penalty matrix.
    let beta_smooth = if lambda_s_final == 0.0 && lambda_t_final == 0.0 {
        beta_ols
    } else {
        let p2d = penalty_matrix_2d(m1, m2, lambda_s_final, lambda_t_final);
        let mut smoothed = FdMatrix::zeros(p_total, m_total);
        for j in 0..p_total {
            let raw: Vec<f64> = (0..m_total).map(|g| beta_ols[(j, g)]).collect();
            let s = smooth_coefficient_surface(&raw, &p2d, m_total)?;
            for g in 0..m_total {
                smoothed[(j, g)] = s[g];
            }
        }
        smoothed
    };

    // ---- Step 4: Compute diagnostics ----

    let (fitted, residuals) = compute_fitted_residuals(&design, &beta_smooth, data);

    let r_squared_pointwise = pointwise_r_squared(data, &fitted);
    let r_squared = if m_total > 0 {
        r_squared_pointwise.iter().sum::<f64>() / m_total as f64
    } else {
        0.0
    };

    let trace_h = trace_hat_ols(p_total);
    let gcv = compute_gcv(&residuals, trace_h);

    let beta_se = compute_beta_se_2d(&xtx, &residuals, p_total, n);

    // Extract intercept (row 0 of beta_smooth)
    let intercept: Vec<f64> = (0..m_total).map(|g| beta_smooth[(0, g)]).collect();

    // Extract predictor coefficients (rows 1..p_total)
    let mut beta_out = FdMatrix::zeros(p, m_total);
    for j in 0..p {
        for g in 0..m_total {
            beta_out[(j, g)] = beta_smooth[(j + 1, g)];
        }
    }

    Ok(FosrResult2d {
        intercept,
        beta: beta_out,
        fitted,
        residuals,
        r_squared_pointwise,
        r_squared,
        beta_se,
        lambda_s: lambda_s_final,
        lambda_t: lambda_t_final,
        gcv,
        grid: grid.clone(),
    })
}
613
614/// Predict functional surfaces for new observations.
615///
616/// # Arguments
617/// * `result` - Fitted [`FosrResult2d`]
618/// * `new_predictors` - New scalar predictors (n_new x p)
619///
620/// # Errors
621///
622/// Returns [`FdarError::InvalidDimension`] if the number of predictor columns
623/// does not match the fitted model.
624#[must_use = "prediction result should not be discarded"]
625pub fn predict_fosr_2d(
626    result: &FosrResult2d,
627    new_predictors: &FdMatrix,
628) -> Result<FdMatrix, FdarError> {
629    let n_new = new_predictors.nrows();
630    let m_total = result.intercept.len();
631    let p = result.beta.nrows();
632
633    if new_predictors.ncols() != p {
634        return Err(FdarError::InvalidDimension {
635            parameter: "new_predictors",
636            expected: format!("{p} columns (matching fitted model)"),
637            actual: format!("{} columns", new_predictors.ncols()),
638        });
639    }
640
641    let mut predicted = FdMatrix::zeros(n_new, m_total);
642    for i in 0..n_new {
643        for g in 0..m_total {
644            let mut yhat = result.intercept[g];
645            for j in 0..p {
646                yhat += new_predictors[(i, j)] * result.beta[(j, g)];
647            }
648            predicted[(i, g)] = yhat;
649        }
650    }
651    Ok(predicted)
652}
653
654// ---------------------------------------------------------------------------
655// Tests
656// ---------------------------------------------------------------------------
657
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform grid on [0, 1] with m points (degenerate single point at 0
    /// when m == 1, thanks to the `.max(1)` divisor).
    fn uniform_grid_1d(m: usize) -> Vec<f64> {
        (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect()
    }

    fn make_grid(m1: usize, m2: usize) -> Grid2d {
        Grid2d::new(uniform_grid_1d(m1), uniform_grid_1d(m2))
    }

    /// Generate test data: Y_i(s,t) = intercept(s,t) + z_1 * beta_1(s,t) + z_2 * beta_2(s,t) + noise
    ///
    /// Noise is a deterministic pseudo-random sequence so tests are
    /// reproducible without an RNG dependency.
    fn generate_2d_data(
        n: usize,
        m1: usize,
        m2: usize,
        noise_scale: f64,
    ) -> (FdMatrix, FdMatrix, Grid2d) {
        let grid = make_grid(m1, m2);
        let m_total = m1 * m2;
        let mut y = FdMatrix::zeros(n, m_total);
        let mut z = FdMatrix::zeros(n, 2);

        for i in 0..n {
            let z1 = (i as f64) / (n as f64);
            let z2 = if i % 2 == 0 { 1.0 } else { 0.0 };
            z[(i, 0)] = z1;
            z[(i, 1)] = z2;

            for si in 0..m1 {
                for ti in 0..m2 {
                    let g = si + ti * m1; // column-major flat index
                    let s = grid.argvals_s[si];
                    let t = grid.argvals_t[ti];

                    // True coefficient surfaces the fit should recover.
                    let intercept = s + t;
                    let beta1 = s * t;
                    let beta2 = s - t;
                    let noise = noise_scale * ((i * 13 + si * 7 + ti * 3) % 100) as f64 / 100.0;

                    y[(i, g)] = intercept + z1 * beta1 + z2 * beta2 + noise;
                }
            }
        }
        (y, z, grid)
    }

    #[test]
    fn test_grid2d_basic() {
        let grid = make_grid(5, 4);
        assert_eq!(grid.m1(), 5);
        assert_eq!(grid.m2(), 4);
        assert_eq!(grid.total(), 20);
    }

    #[test]
    fn test_kronecker_product_small() {
        // A = [[1,2],[3,4]], B = [[0,5],[6,7]]
        // A kron B = [[0,5,0,10],[6,7,12,14],[0,15,0,20],[18,21,24,28]]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![0.0, 5.0, 6.0, 7.0];
        let c = kronecker_product(&a, 2, 2, &b, 2, 2);
        assert_eq!(c.len(), 16);
        #[rustfmt::skip]
        let expected = vec![
            0.0, 5.0, 0.0, 10.0,
            6.0, 7.0, 12.0, 14.0,
            0.0, 15.0, 0.0, 20.0,
            18.0, 21.0, 24.0, 28.0,
        ];
        for (i, (&ci, &ei)) in c.iter().zip(expected.iter()).enumerate() {
            assert!(
                (ci - ei).abs() < 1e-12,
                "kronecker mismatch at index {i}: got {ci}, expected {ei}"
            );
        }
    }

    #[test]
    fn test_penalty_2d_symmetry() {
        let m1 = 5;
        let m2 = 4;
        let p2d = penalty_matrix_2d(m1, m2, 1.0, 1.0);
        let m_total = m1 * m2;
        for i in 0..m_total {
            for j in 0..m_total {
                assert!(
                    (p2d[i * m_total + j] - p2d[j * m_total + i]).abs() < 1e-12,
                    "P_2d not symmetric at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_penalty_2d_shape() {
        let m1 = 5;
        let m2 = 4;
        let p2d = penalty_matrix_2d(m1, m2, 1.0, 2.0);
        assert_eq!(p2d.len(), (m1 * m2) * (m1 * m2));
    }

    #[test]
    fn test_fosr_2d_constant_response() {
        let n = 20;
        let m1 = 5;
        let m2 = 4;
        let grid = make_grid(m1, m2);
        let m_total = m1 * m2;

        // Y_i(s,t) = 3.0 for all i,s,t
        let mut y = FdMatrix::zeros(n, m_total);
        for i in 0..n {
            for g in 0..m_total {
                y[(i, g)] = 3.0;
            }
        }

        let mut z = FdMatrix::zeros(n, 2);
        for i in 0..n {
            z[(i, 0)] = i as f64;
            z[(i, 1)] = (i % 3) as f64;
        }

        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // Intercept should be near 3.0 everywhere
        for g in 0..m_total {
            assert!(
                (result.intercept[g] - 3.0).abs() < 1e-8,
                "intercept[{g}] = {}, expected 3.0",
                result.intercept[g]
            );
        }

        // Beta coefficients should be near zero
        for j in 0..2 {
            for g in 0..m_total {
                assert!(
                    result.beta[(j, g)].abs() < 1e-8,
                    "beta[{j},{g}] = {}, expected ~0",
                    result.beta[(j, g)]
                );
            }
        }
    }

    #[test]
    fn test_fosr_2d_single_predictor() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.01);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // With low noise, R^2 should be high
        assert!(
            result.r_squared > 0.8,
            "R^2 = {}, expected > 0.8",
            result.r_squared
        );
    }

    #[test]
    fn test_fosr_2d_fitted_plus_residuals() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // Exact decomposition Y = fitted + residuals must hold.
        let (n, m_total) = y.shape();
        for i in 0..n {
            for g in 0..m_total {
                let reconstructed = result.fitted[(i, g)] + result.residuals[(i, g)];
                assert!(
                    (reconstructed - y[(i, g)]).abs() < 1e-10,
                    "fitted + residuals != y at ({i},{g})"
                );
            }
        }
    }

    #[test]
    fn test_fosr_2d_r_squared_range() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // Slightly loose lower bound tolerates numerical round-off.
        for (g, &r2) in result.r_squared_pointwise.iter().enumerate() {
            assert!(
                (-0.01..=1.0 + 1e-10).contains(&r2),
                "R^2 out of range at grid point {g}: {r2}"
            );
        }
    }

    #[test]
    fn test_fosr_2d_predict_matches_fitted() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        // Predicting on the training predictors must reproduce fitted values.
        let preds = predict_fosr_2d(&result, &z).unwrap();
        let (n, m_total) = y.shape();
        for i in 0..n {
            for g in 0..m_total {
                assert!(
                    (preds[(i, g)] - result.fitted[(i, g)]).abs() < 1e-8,
                    "prediction != fitted at ({i},{g})"
                );
            }
        }
    }

    #[test]
    fn test_fosr_2d_reshape_beta_surface() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();

        let surface = result.beta_surface(0);
        assert_eq!(surface.shape(), (5, 4));

        let r2_surface = result.r_squared_surface();
        assert_eq!(r2_surface.shape(), (5, 4));

        let resid_surface = result.residual_surface(0);
        assert_eq!(resid_surface.shape(), (5, 4));
    }

    #[test]
    fn test_fosr_2d_dimension_mismatch() {
        let grid = make_grid(5, 4);

        // Wrong number of columns in data
        let y = FdMatrix::zeros(20, 10); // 10 != 5*4=20
        let z = FdMatrix::zeros(20, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Mismatched rows between data and predictors
        let y = FdMatrix::zeros(20, 20);
        let z = FdMatrix::zeros(10, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Too few observations
        let y = FdMatrix::zeros(3, 20);
        let z = FdMatrix::zeros(3, 2);
        assert!(fosr_2d(&y, &z, &grid, 0.0, 0.0).is_err());

        // Empty grid
        let empty_grid = Grid2d::new(vec![], vec![0.0, 1.0]);
        let y = FdMatrix::zeros(20, 0);
        let z = FdMatrix::zeros(20, 2);
        assert!(fosr_2d(&y, &z, &empty_grid, 0.0, 0.0).is_err());

        // Predictor dimension mismatch in predict
        let grid = make_grid(3, 3);
        let y = FdMatrix::zeros(20, 9);
        let mut z = FdMatrix::zeros(20, 2);
        for i in 0..20 {
            z[(i, 0)] = i as f64;
            z[(i, 1)] = (i * 3 % 7) as f64;
        }
        let result = fosr_2d(&y, &z, &grid, 0.0, 0.0).unwrap();
        let z_bad = FdMatrix::zeros(5, 3); // 3 != 2
        assert!(predict_fosr_2d(&result, &z_bad).is_err());
    }

    #[test]
    fn test_fosr_2d_gcv() {
        let (y, z, grid) = generate_2d_data(20, 5, 4, 0.05);
        // Negative lambda triggers GCV selection
        let result = fosr_2d(&y, &z, &grid, -1.0, -1.0).unwrap();
        assert!(result.lambda_s >= 0.0);
        assert!(result.lambda_t >= 0.0);
        assert!(result.gcv > 0.0);
        assert!(result.r_squared > 0.5);
    }
}
929}