fdars_core/covariance.rs

//! Covariance kernels and Gaussian process generation.
//!
//! This module provides a flexible [`CovKernel`] enum for defining covariance
//! functions (Gaussian, Matern, Periodic, etc.) with kernel algebra via
//! [`Sum`](CovKernel::Sum) and [`Product`](CovKernel::Product) combinators,
//! and a [`generate_gaussian_process`] function for drawing sample paths from
//! a Gaussian process with a given kernel and optional mean function.
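//!
//! A minimal usage sketch (the doctest is marked `ignore` because the exact
//! crate re-export path is an assumption; the items themselves are defined in
//! this module):
//!
//! ```ignore
//! use fdars_core::covariance::{generate_gaussian_process, CovKernel};
//!
//! // Squared-exponential kernel with unit variance and length scale 0.5.
//! let kernel = CovKernel::Gaussian { length_scale: 0.5, variance: 1.0 };
//! let argvals: Vec<f64> = (0..=100).map(|i| i as f64 / 100.0).collect();
//!
//! // Draw ten sample paths with a fixed seed for reproducibility.
//! let gp = generate_gaussian_process(10, &kernel, &argvals, None, Some(42))
//!     .expect("GP generation failed");
//! assert_eq!(gp.samples.nrows(), 10);
//! assert_eq!(gp.samples.ncols(), argvals.len());
//! ```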

use crate::error::FdarError;
use crate::linalg::cholesky_d;
use crate::matrix::FdMatrix;
use rand::prelude::*;
use rand_distr::StandardNormal;
use std::f64::consts::PI;

/// Covariance kernel specification.
///
/// Each variant encodes a family of covariance functions `k(s, t)`.
/// Kernels can be composed with [`Sum`](CovKernel::Sum) and
/// [`Product`](CovKernel::Product) to build richer covariance structures.
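///
/// For example, a smooth trend plus observation noise can be written as a sum
/// kernel (an illustrative sketch; the doctest is marked `ignore` because the
/// crate re-export path is an assumption):
///
/// ```ignore
/// use fdars_core::covariance::CovKernel;
///
/// let smooth = CovKernel::Gaussian { length_scale: 0.5, variance: 1.0 };
/// let noise = CovKernel::WhiteNoise { variance: 0.1 };
/// let composed = CovKernel::Sum(Box::new(smooth), Box::new(noise));
///
/// // On the diagonal both components contribute; off the diagonal only the
/// // Gaussian part is non-zero.
/// assert!((composed.eval(0.0, 0.0) - 1.1).abs() < 1e-12);
/// ```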
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum CovKernel {
    /// Squared-exponential (RBF) kernel: `variance * exp(-0.5 * ((s-t)/length_scale)^2)`.
    Gaussian { length_scale: f64, variance: f64 },
    /// Exponential (Ornstein-Uhlenbeck) kernel: `variance * exp(-|s-t| / length_scale)`.
    Exponential { length_scale: f64, variance: f64 },
    /// Matern kernel with smoothness parameter `nu`.
    ///
    /// A closed-form polynomial expression is used whenever `nu` is a
    /// half-integer (`nu = 0.5` is the exponential kernel; `nu = 1.5` and
    /// `nu = 2.5` are the common smooth variants). For other values of `nu`
    /// the general formula with a gamma-function approximation and a series
    /// Bessel `K_nu` is used.
    Matern {
        length_scale: f64,
        variance: f64,
        nu: f64,
    },
    /// Brownian motion (Wiener process) kernel: `variance * min(s, t)` for `s, t >= 0`.
    Brownian { variance: f64 },
    /// Periodic kernel: `variance * exp(-2 * sin^2(pi * |s-t| / period) / length_scale^2)`.
    Periodic {
        length_scale: f64,
        variance: f64,
        period: f64,
    },
    /// Linear kernel: `variance * (s - offset) * (t - offset)`.
    Linear { variance: f64, offset: f64 },
    /// Polynomial kernel: `(variance * s * t + offset)^degree`.
    Polynomial {
        variance: f64,
        offset: f64,
        degree: u32,
    },
    /// White noise kernel: `variance * delta(s, t)`.
    WhiteNoise { variance: f64 },
    /// Sum of two kernels: `k1(s,t) + k2(s,t)`.
    Sum(Box<CovKernel>, Box<CovKernel>),
    /// Product of two kernels: `k1(s,t) * k2(s,t)`.
    Product(Box<CovKernel>, Box<CovKernel>),
}

impl CovKernel {
    /// Evaluate the covariance function `k(s, t)`.
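    ///
    /// A small illustrative sketch (the doctest is marked `ignore` because the
    /// crate re-export path is an assumption):
    ///
    /// ```ignore
    /// use fdars_core::covariance::CovKernel;
    ///
    /// let k = CovKernel::Exponential { length_scale: 2.0, variance: 1.0 };
    /// // For the exponential kernel, k(0, 1) = exp(-|0 - 1| / 2) = exp(-0.5).
    /// assert!((k.eval(0.0, 1.0) - (-0.5_f64).exp()).abs() < 1e-12);
    /// ```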
    pub fn eval(&self, s: f64, t: f64) -> f64 {
        match self {
            CovKernel::Gaussian {
                length_scale,
                variance,
            } => {
                let d = (s - t) / length_scale;
                variance * (-0.5 * d * d).exp()
            }
            CovKernel::Exponential {
                length_scale,
                variance,
            } => {
                let d = (s - t).abs() / length_scale;
                variance * (-d).exp()
            }
            CovKernel::Matern {
                length_scale,
                variance,
                nu,
            } => eval_matern(s, t, *length_scale, *variance, *nu),
            CovKernel::Brownian { variance } => {
                if s >= 0.0 && t >= 0.0 {
                    variance * s.min(t)
                } else {
                    0.0
                }
            }
            CovKernel::Periodic {
                length_scale,
                variance,
                period,
            } => {
                let sin_val = (PI * (s - t).abs() / period).sin();
                variance * (-2.0 * sin_val * sin_val / (length_scale * length_scale)).exp()
            }
            CovKernel::Linear { variance, offset } => variance * (s - offset) * (t - offset),
            CovKernel::Polynomial {
                variance,
                offset,
                degree,
            } => {
                let inner = variance * s * t + offset;
                inner.powi(*degree as i32)
            }
            CovKernel::WhiteNoise { variance } => {
                if (s - t).abs() < 1e-15 {
                    *variance
                } else {
                    0.0
                }
            }
            CovKernel::Sum(k1, k2) => k1.eval(s, t) + k2.eval(s, t),
            CovKernel::Product(k1, k2) => k1.eval(s, t) * k2.eval(s, t),
        }
    }

    /// Validate kernel parameters, returning an error for a non-positive (or NaN)
    /// variance, length scale, `nu`, or period, or for a polynomial degree of zero.
    fn validate(&self) -> Result<(), FdarError> {
        match self {
            CovKernel::Gaussian {
                length_scale,
                variance,
            } => {
                check_positive("variance", *variance)?;
                check_positive("length_scale", *length_scale)?;
            }
            CovKernel::Exponential {
                length_scale,
                variance,
            } => {
                check_positive("variance", *variance)?;
                check_positive("length_scale", *length_scale)?;
            }
            CovKernel::Matern {
                length_scale,
                variance,
                nu,
            } => {
                check_positive("variance", *variance)?;
                check_positive("length_scale", *length_scale)?;
                check_positive("nu", *nu)?;
            }
            CovKernel::Brownian { variance } => {
                check_positive("variance", *variance)?;
            }
            CovKernel::Periodic {
                length_scale,
                variance,
                period,
            } => {
                check_positive("variance", *variance)?;
                check_positive("length_scale", *length_scale)?;
                check_positive("period", *period)?;
            }
            CovKernel::Linear { variance, .. } => {
                check_positive("variance", *variance)?;
            }
            CovKernel::Polynomial {
                variance, degree, ..
            } => {
                check_positive("variance", *variance)?;
                if *degree == 0 {
                    return Err(FdarError::InvalidParameter {
                        parameter: "degree",
                        message: "must be >= 1".to_string(),
                    });
                }
            }
            CovKernel::WhiteNoise { variance } => {
                check_positive("variance", *variance)?;
            }
            CovKernel::Sum(k1, k2) => {
                k1.validate()?;
                k2.validate()?;
            }
            CovKernel::Product(k1, k2) => {
                k1.validate()?;
                k2.validate()?;
            }
        }
        Ok(())
    }
}

/// Check that a parameter is strictly positive.
fn check_positive(name: &'static str, value: f64) -> Result<(), FdarError> {
    if value <= 0.0 || value.is_nan() {
        return Err(FdarError::InvalidParameter {
            parameter: name,
            message: format!("must be positive, got {value}"),
        });
    }
    Ok(())
}

/// Evaluate the Matern covariance function.
///
/// For half-integer `nu = p + 0.5` (which includes the common cases
/// `nu = 0.5`, `1.5`, and `2.5`), uses the closed-form polynomial expression.
/// For other values of `nu`, uses the general formula with the Lanczos
/// gamma approximation and a series-based Bessel `K_nu` implementation.
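///
/// In the general (non-half-integer) case the value computed below is
/// `k(s, t) = variance * 2^(1 - nu) / Gamma(nu) * z^nu * K_nu(z)` with
/// `z = sqrt(2 * nu) * |s - t| / length_scale`, evaluated in log-space.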
fn eval_matern(s: f64, t: f64, length_scale: f64, variance: f64, nu: f64) -> f64 {
    let d = (s - t).abs();
    if d < 1e-15 {
        return variance;
    }
    let r = d / length_scale;
    let z = (2.0 * nu).sqrt() * r;

    // Check if nu is a half-integer: nu = p + 0.5 for integer p >= 0
    let twice_nu = 2.0 * nu;
    let twice_nu_rounded = twice_nu.round();
    if (twice_nu - twice_nu_rounded).abs() < 1e-10 && twice_nu_rounded >= 1.0 {
        let twice_nu_int = twice_nu_rounded as u64;
        if twice_nu_int % 2 == 1 {
            // nu is half-integer: use the closed-form polynomial expression
            // For nu = p + 0.5, K_{p+1/2}(z) = sqrt(pi/(2z)) * exp(-z) * sum_{k=0}^p (p+k)!/(k!(p-k)!) * (2z)^{-k}
            let p = (twice_nu_int - 1) / 2;
            return variance * matern_half_integer(z, p as usize);
        }
    }

    // General formula: k(r) = variance * 2^(1-nu) / Gamma(nu) * (sqrt(2*nu)*r)^nu * K_nu(sqrt(2*nu)*r)
    // Work in log-space for numerical stability
    let log_prefactor = (1.0 - nu) * 2.0_f64.ln() - ln_gamma(nu) + nu * z.ln();
    let knu = bessel_knu(nu, z);
    if knu <= 0.0 {
        return 0.0;
    }
    variance * (log_prefactor + knu.ln()).exp()
}

/// Closed-form Matern correlation for half-integer nu = p + 0.5.
///
/// Uses the Rasmussen & Williams formula (eq. 4.17): for nu = p + 1/2,
/// `k(z) = exp(-z) * (p! / (2p)!) * sum_{i=0}^{p} ((p+i)! / (i! * (p-i)!)) * (2z)^{p-i}`
/// where `z = sqrt(2 * nu) * |s - t| / length_scale` is the scaled distance
/// passed in as the argument `z`.
///
/// This avoids Bessel function evaluation entirely and is exact up to
/// floating-point rounding.
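///
/// As a concrete check of the formula, `p = 1` (`nu = 1.5`) reduces to
/// `(1 + z) * exp(-z)` and `p = 2` (`nu = 2.5`) to `(1 + z + z^2 / 3) * exp(-z)`,
/// the familiar Matern-3/2 and Matern-5/2 correlation forms.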
fn matern_half_integer(z: f64, p: usize) -> f64 {
    let two_z = 2.0 * z;
    let mut poly = 0.0;
    for i in 0..=p {
        // coefficient = (p+i)! / (i! * (p-i)!)
        let coeff = factorial(p + i) as f64 / (factorial(i) as f64 * factorial(p - i) as f64);
        // (2z)^{p-i}
        let power = two_z.powi((p - i) as i32);
        poly += coeff * power;
    }
    // Prefactor: p! / (2p)!
    let prefactor = factorial(p) as f64 / factorial(2 * p) as f64;
    prefactor * poly * (-z).exp()
}

/// Compute n! (for small n; `u64` overflows for `n > 20`, which limits the
/// half-integer Matern path to `nu <= 10.5`).
fn factorial(n: usize) -> u64 {
    (1..=n as u64).product::<u64>().max(1)
}

/// Lanczos approximation of the log-gamma function.
fn ln_gamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::INFINITY;
    }
    if x < 0.5 {
        // Reflection formula: Gamma(x) * Gamma(1-x) = pi / sin(pi*x)
        let log_pi = PI.ln();
        return log_pi - (PI * x).sin().abs().ln() - ln_gamma(1.0 - x);
    }

    let g = 7.0;
    #[allow(clippy::excessive_precision, clippy::inconsistent_digit_grouping)]
    let coefficients = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_9,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    let x = x - 1.0;
    let mut sum = coefficients[0];
    for (i, &c) in coefficients.iter().enumerate().skip(1) {
        sum += c / (x + i as f64);
    }

    let t = x + g + 0.5;
    0.5 * (2.0 * PI).ln() + (t.ln() * (x + 0.5)) - t + sum.ln()
}

/// Compute the modified Bessel function of the second kind K_nu(z).
///
/// Uses the series representation `K_nu(z) = (pi/2) * (I_{-nu}(z) - I_nu(z)) / sin(nu*pi)`
/// for non-integer nu, upward recurrence from `K_0` and `K_1` for integer nu,
/// and the asymptotic expansion for large z.
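///
/// A convenient reference value for spot checks is the half-order closed form
/// `K_{1/2}(z) = sqrt(pi / (2 * z)) * exp(-z)`.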
fn bessel_knu(nu: f64, z: f64) -> f64 {
    if z <= 0.0 {
        return f64::INFINITY;
    }

    // For large z, use asymptotic expansion (good for all nu)
    if z > 50.0 {
        return bessel_knu_asymptotic(nu, z);
    }

    // Check if nu is close to an integer
    let nu_rounded = nu.round();
    let is_integer = (nu - nu_rounded).abs() < 1e-10;

    if is_integer {
        // For integer nu, use upward recurrence from K_0 and K_1
        // (upward recurrence is stable for K_n)
        let n = nu_rounded.abs() as u32;
        bessel_kn_miller(n, z)
    } else {
        // For non-integer nu, use the series K_nu = (pi/2)(I_{-nu} - I_nu)/sin(nu*pi)
        let sin_nu_pi = (nu * PI).sin();
        let i_neg_nu = bessel_inu_series(-nu, z);
        let i_nu = bessel_inu_series(nu, z);
        (PI / 2.0) * (i_neg_nu - i_nu) / sin_nu_pi
    }
}

/// Asymptotic expansion of K_nu(z) for large z.
fn bessel_knu_asymptotic(nu: f64, z: f64) -> f64 {
    let prefactor = (PI / (2.0 * z)).sqrt() * (-z).exp();
    let mu = 4.0 * nu * nu;
    let mut term = 1.0;
    let mut sum = 1.0;
    for k in 1..=20 {
        let kf = k as f64;
        term *= (mu - (2.0 * kf - 1.0).powi(2)) / (8.0 * z * kf);
        sum += term;
        if term.abs() < 1e-15 * sum.abs() {
            break;
        }
    }
    prefactor * sum
}

/// Power series for I_nu(z), the modified Bessel function of the first kind.
///
/// The leading coefficient `1 / Gamma(nu + 1)` is computed with its sign via
/// the reflection formula so that negative non-integer orders (needed for the
/// `K_nu` series above) are handled correctly.
fn bessel_inu_series(nu: f64, z: f64) -> f64 {
    let half_z = z / 2.0;
    // First term: (z/2)^nu / Gamma(nu+1), keeping the sign of 1/Gamma(nu+1)
    let a = nu + 1.0;
    let inv_gamma = if a > 0.0 {
        (-ln_gamma(a)).exp()
    } else {
        // Reflection formula: 1/Gamma(a) = sin(pi*a) * Gamma(1-a) / pi
        (PI * a).sin() * ln_gamma(1.0 - a).exp() / PI
    };
    let mut term = half_z.powf(nu) * inv_gamma;
    let mut sum = term;
    let z2_over4 = half_z * half_z;
    for k in 1..=80 {
        let kf = k as f64;
        term *= z2_over4 / (kf * (kf + nu));
        sum += term;
        if term.abs() < 1e-15 * sum.abs() {
            break;
        }
    }
    sum
}

/// Compute K_n(z) for non-negative integer n using K_0 and K_1 with forward recurrence.
fn bessel_kn_miller(n: u32, z: f64) -> f64 {
    let k0 = bessel_k0_series(z);
    if n == 0 {
        return k0;
    }
    let k1 = bessel_k1_series(z);
    if n == 1 {
        return k1;
    }
    let mut km1 = k0;
    let mut k_cur = k1;
    for i in 1..n {
        let k_next = (2.0 * i as f64 / z) * k_cur + km1;
        km1 = k_cur;
        k_cur = k_next;
    }
    k_cur
}

/// K_0(z) via its small-argument power series (Abramowitz & Stegun 9.6.13).
fn bessel_k0_series(z: f64) -> f64 {
    if z > 2.0 {
        // Asymptotic
        return bessel_knu_asymptotic(0.0, z);
    }
    // K_0(z) = -[ln(z/2) + gamma] * I_0(z) + sum_{k=0}^inf (z/2)^{2k} * h_k / (k!)^2
    // where h_k = sum_{j=1}^k 1/j (harmonic numbers), h_0 = 0
    let euler_gamma = 0.577_215_664_901_532_9;
    let half_z = z / 2.0;
    let ln_half_z = half_z.ln();

    // I_0(z) series
    let mut i0 = 1.0;
    let mut term_i0 = 1.0;
    let z2_over4 = half_z * half_z;
    for k in 1..=30 {
        let kf = k as f64;
        term_i0 *= z2_over4 / (kf * kf);
        i0 += term_i0;
    }

    // The sum part
    let mut sum_part = 0.0;
    let mut term_s = 1.0; // (z/2)^{2k} / (k!)^2 for k=0
    let mut h_k = 0.0;
    sum_part += term_s * h_k; // k=0 contributes 0
    for k in 1..=30 {
        let kf = k as f64;
        term_s *= z2_over4 / (kf * kf);
        h_k += 1.0 / kf;
        sum_part += term_s * h_k;
    }

    -(ln_half_z + euler_gamma) * i0 + sum_part
}

/// K_1(z) via the Wronskian relation `I_1(z)*K_0(z) + I_0(z)*K_1(z) = 1/z`,
/// i.e. `K_1 = (1/z - I_1 * K_0) / I_0`, with the asymptotic expansion for z > 2.
fn bessel_k1_series(z: f64) -> f64 {
    if z > 2.0 {
        return bessel_knu_asymptotic(1.0, z);
    }
    // The direct series for K_1 (via K_1 = -dK_0/dz) is messy; the Wronskian
    // relation I_1*K_0 + I_0*K_1 = 1/z gives K_1 = (1/z - I_1*K_0) / I_0.
    let half_z = z / 2.0;
    let z2_over4 = half_z * half_z;

    // I_0(z)
    let mut i0 = 1.0;
    let mut term = 1.0;
    for k in 1..=30 {
        let kf = k as f64;
        term *= z2_over4 / (kf * kf);
        i0 += term;
    }

    // I_1(z) = (z/2) * sum (z/2)^{2k} / (k! * (k+1)!)
    let mut i1 = half_z;
    term = half_z;
    for k in 1..=30 {
        let kf = k as f64;
        term *= z2_over4 / (kf * (kf + 1.0));
        i1 += term;
    }

    let k0 = bessel_k0_series(z);
    (1.0 / z - i1 * k0) / i0
}

// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------

/// Build the m x m covariance matrix K\[i,j\] = k(argvals\[i\], argvals\[j\]).
///
/// The result is a symmetric positive semi-definite matrix stored in an
/// [`FdMatrix`] with `m` rows and `m` columns (column-major layout).
///
/// # Errors
///
/// * [`FdarError::InvalidDimension`] if `argvals` is empty.
/// * [`FdarError::InvalidParameter`] if kernel parameters are invalid
///   (e.g. negative variance or length scale).
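///
/// # Example
///
/// An illustrative sketch (the doctest is marked `ignore` because the crate
/// re-export path is an assumption):
///
/// ```ignore
/// use fdars_core::covariance::{covariance_matrix, CovKernel};
///
/// let kernel = CovKernel::Exponential { length_scale: 1.0, variance: 2.0 };
/// let argvals = [0.0, 0.5, 1.0];
/// let cov = covariance_matrix(&kernel, &argvals).expect("valid kernel");
///
/// // The diagonal equals the variance and the matrix is symmetric.
/// assert!((cov[(0, 0)] - 2.0).abs() < 1e-12);
/// assert!((cov[(0, 2)] - cov[(2, 0)]).abs() < 1e-12);
/// ```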
#[must_use = "returns the covariance matrix without modifying the kernel"]
pub fn covariance_matrix(kernel: &CovKernel, argvals: &[f64]) -> Result<FdMatrix, FdarError> {
    if argvals.is_empty() {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: ">= 1".to_string(),
            actual: "0".to_string(),
        });
    }
    kernel.validate()?;

    let m = argvals.len();
    let mut data = vec![0.0; m * m];

    // Fill column-major: data[i + j * m] = k(argvals[i], argvals[j])
    for j in 0..m {
        for i in 0..m {
            let val = kernel.eval(argvals[i], argvals[j]);
            data[i + j * m] = val;
        }
    }

    FdMatrix::from_column_major(data, m, m)
}

/// Result of Gaussian process sample generation.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct GaussianProcessResult {
    /// Sample paths, stored as an n x m matrix (n samples, m evaluation points).
    pub samples: FdMatrix,
    /// Evaluation points (length m).
    pub argvals: Vec<f64>,
    /// Kernel used for generation.
    pub kernel: CovKernel,
    /// Mean function used for generation (length m).
    pub mean_function: Vec<f64>,
}

/// Generate `n` sample paths from a Gaussian process.
///
/// Uses Cholesky decomposition of `K + jitter * I` (where `jitter = 1e-10`)
/// to produce `L`, then each sample is `mean + L * z` where `z ~ N(0, I)`.
///
/// # Arguments
///
/// * `n` — number of sample paths to generate.
/// * `kernel` — covariance kernel specification.
/// * `argvals` — evaluation points (length `m`).
/// * `mean_fn` — optional mean function values (length `m`); defaults to zero.
/// * `seed` — optional RNG seed for reproducibility.
///
/// # Errors
///
/// * [`FdarError::InvalidDimension`] if `argvals` is empty, `n == 0`, or
///   `mean_fn` length does not match `argvals`.
/// * [`FdarError::InvalidParameter`] if kernel parameters are invalid.
/// * [`FdarError::ComputationFailed`] if Cholesky decomposition fails.
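///
/// # Example
///
/// An illustrative sketch (the doctest is marked `ignore` because the crate
/// re-export path is an assumption):
///
/// ```ignore
/// use fdars_core::covariance::{generate_gaussian_process, CovKernel};
///
/// let kernel = CovKernel::Matern { length_scale: 0.5, variance: 1.0, nu: 1.5 };
/// let argvals: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
///
/// // Eight sample paths with a fixed seed; the mean function defaults to zero.
/// let gp = generate_gaussian_process(8, &kernel, &argvals, None, Some(7))
///     .expect("GP generation failed");
/// assert_eq!(gp.samples.nrows(), 8);
/// assert_eq!(gp.samples.ncols(), argvals.len());
/// ```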
#[must_use = "returns GP samples without modifying inputs"]
pub fn generate_gaussian_process(
    n: usize,
    kernel: &CovKernel,
    argvals: &[f64],
    mean_fn: Option<&[f64]>,
    seed: Option<u64>,
) -> Result<GaussianProcessResult, FdarError> {
    // Validate inputs
    if argvals.is_empty() {
        return Err(FdarError::InvalidDimension {
            parameter: "argvals",
            expected: ">= 1".to_string(),
            actual: "0".to_string(),
        });
    }
    if n == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "n",
            expected: ">= 1".to_string(),
            actual: "0".to_string(),
        });
    }

    let m = argvals.len();

    let mean = if let Some(mf) = mean_fn {
        if mf.len() != m {
            return Err(FdarError::InvalidDimension {
                parameter: "mean_fn",
                expected: format!("{m}"),
                actual: format!("{}", mf.len()),
            });
        }
        mf.to_vec()
    } else {
        vec![0.0; m]
    };

    kernel.validate()?;

    // Build covariance matrix in row-major for cholesky_d
    let mut cov_row = vec![0.0; m * m];
    for i in 0..m {
        for j in 0..m {
            cov_row[i * m + j] = kernel.eval(argvals[i], argvals[j]);
        }
    }

    // Add jitter for numerical stability
    let jitter = 1e-10;
    for i in 0..m {
        cov_row[i * m + i] += jitter;
    }

    // Cholesky decomposition: cov = L * L^T
    let l = cholesky_d(&cov_row, m).map_err(|_| FdarError::ComputationFailed {
        operation: "Cholesky decomposition",
        detail: "covariance matrix is not positive definite — try adding jitter or checking kernel parameters".to_string(),
    })?;

    // Generate samples
    let mut rng: Box<dyn RngCore> = match seed {
        Some(s) => Box::new(StdRng::seed_from_u64(s)),
        None => Box::new(StdRng::from_entropy()),
    };

    // Samples stored column-major: data[i + j * n] = sample i at argval j
    let mut data = vec![0.0; n * m];

    for i in 0..n {
        // Draw z ~ N(0, I) of length m
        let z: Vec<f64> = (0..m)
            .map(|_| rng.sample::<f64, _>(StandardNormal))
            .collect();

        // Compute mean + L * z (L is the row-major m x m lower-triangular Cholesky factor)
        for j in 0..m {
            let mut val = mean[j];
            for k in 0..=j {
                val += l[j * m + k] * z[k];
            }
            data[i + j * n] = val;
        }
    }

    let samples = FdMatrix::from_column_major(data, n, m)?;

    Ok(GaussianProcessResult {
        samples,
        argvals: argvals.to_vec(),
        kernel: kernel.clone(),
        mean_function: mean,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1e-10;

    // -----------------------------------------------------------------------
    // Kernel evaluation tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_gaussian_kernel_eval() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        // k(0,0) = 1
        assert!((k.eval(0.0, 0.0) - 1.0).abs() < TOL);
        // k(0,1) = exp(-0.5)
        assert!((k.eval(0.0, 1.0) - (-0.5_f64).exp()).abs() < TOL);
        // symmetry
        assert!((k.eval(0.3, 0.7) - k.eval(0.7, 0.3)).abs() < TOL);
    }

    #[test]
    fn test_gaussian_kernel_variance_scale() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 2.5,
        };
        assert!((k.eval(0.0, 0.0) - 2.5).abs() < TOL);
        assert!((k.eval(0.0, 1.0) - 2.5 * (-0.5_f64).exp()).abs() < TOL);
    }

    #[test]
    fn test_exponential_kernel_eval() {
        let k = CovKernel::Exponential {
            length_scale: 2.0,
            variance: 1.0,
        };
        assert!((k.eval(0.0, 0.0) - 1.0).abs() < TOL);
        // k(0, 1) = exp(-1/2) = exp(-0.5)
        assert!((k.eval(0.0, 1.0) - (-0.5_f64).exp()).abs() < TOL);
    }

    #[test]
    fn test_matern_05_matches_exponential() {
        let matern = CovKernel::Matern {
            length_scale: 1.5,
            variance: 2.0,
            nu: 0.5,
        };
        let exp = CovKernel::Exponential {
            length_scale: 1.5,
            variance: 2.0,
        };
        let points = [0.0, 0.3, 0.7, 1.0, 2.5];
        for &s in &points {
            for &t in &points {
                assert!(
                    (matern.eval(s, t) - exp.eval(s, t)).abs() < 1e-8,
                    "Matern(0.5) != Exponential at ({s}, {t}): {} vs {}",
                    matern.eval(s, t),
                    exp.eval(s, t)
                );
            }
        }
    }

    #[test]
    fn test_matern_15_eval() {
        let k = CovKernel::Matern {
            length_scale: 1.0,
            variance: 1.0,
            nu: 1.5,
        };
        // At s=t, k=variance
        assert!((k.eval(0.0, 0.0) - 1.0).abs() < TOL);
        // k(0, 1) = (1 + sqrt(3)) * exp(-sqrt(3))
        let sqrt3 = 3.0_f64.sqrt();
        let expected = (1.0 + sqrt3) * (-sqrt3).exp();
        assert!((k.eval(0.0, 1.0) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_matern_25_eval() {
        let k = CovKernel::Matern {
            length_scale: 1.0,
            variance: 1.0,
            nu: 2.5,
        };
        assert!((k.eval(0.0, 0.0) - 1.0).abs() < TOL);
        let sqrt5 = 5.0_f64.sqrt();
        let expected = (1.0 + sqrt5 + 5.0 / 3.0) * (-sqrt5).exp();
        assert!((k.eval(0.0, 1.0) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_brownian_kernel_eval() {
        let k = CovKernel::Brownian { variance: 1.0 };
        assert!((k.eval(0.3, 0.7) - 0.3).abs() < TOL);
        assert!((k.eval(0.7, 0.3) - 0.3).abs() < TOL);
        assert!((k.eval(0.0, 0.5) - 0.0).abs() < TOL);
    }

    #[test]
    fn test_periodic_kernel_eval() {
        let k = CovKernel::Periodic {
            length_scale: 1.0,
            variance: 1.0,
            period: 1.0,
        };
        // k(t, t) = variance
        assert!((k.eval(0.5, 0.5) - 1.0).abs() < TOL);
        // k(0, 1) should equal k(0, 0) because period=1
        assert!((k.eval(0.0, 1.0) - 1.0).abs() < 1e-10);
        // symmetry
        assert!((k.eval(0.2, 0.8) - k.eval(0.8, 0.2)).abs() < TOL);
    }

    #[test]
    fn test_linear_kernel_eval() {
        let k = CovKernel::Linear {
            variance: 1.0,
            offset: 0.0,
        };
        assert!((k.eval(2.0, 3.0) - 6.0).abs() < TOL);
        assert!((k.eval(0.0, 5.0) - 0.0).abs() < TOL);
    }

    #[test]
    fn test_polynomial_kernel_eval() {
        let k = CovKernel::Polynomial {
            variance: 1.0,
            offset: 1.0,
            degree: 2,
        };
        // (1*2*3 + 1)^2 = 49
        assert!((k.eval(2.0, 3.0) - 49.0).abs() < TOL);
    }

    #[test]
    fn test_white_noise_kernel_eval() {
        let k = CovKernel::WhiteNoise { variance: 3.0 };
        assert!((k.eval(1.0, 1.0) - 3.0).abs() < TOL);
        assert!((k.eval(1.0, 1.001) - 0.0).abs() < TOL);
    }

    // -----------------------------------------------------------------------
    // Kernel algebra tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_sum_kernel() {
        let k1 = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let k2 = CovKernel::WhiteNoise { variance: 0.5 };
        let sum = CovKernel::Sum(Box::new(k1.clone()), Box::new(k2.clone()));

        // At s=t: 1.0 + 0.5 = 1.5
        assert!((sum.eval(0.0, 0.0) - 1.5).abs() < TOL);
        // At s!=t: gaussian + 0 = gaussian
        let val = sum.eval(0.0, 1.0);
        let expected = k1.eval(0.0, 1.0);
        assert!((val - expected).abs() < TOL);
    }

    #[test]
    fn test_product_kernel() {
        let k1 = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 2.0,
        };
        let k2 = CovKernel::Gaussian {
            length_scale: 2.0,
            variance: 3.0,
        };
        let prod = CovKernel::Product(Box::new(k1.clone()), Box::new(k2.clone()));

        let s = 0.0;
        let t = 0.5;
        let expected = k1.eval(s, t) * k2.eval(s, t);
        assert!((prod.eval(s, t) - expected).abs() < TOL);
    }

    // -----------------------------------------------------------------------
    // Covariance matrix tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_covariance_matrix_symmetric() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let argvals: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        let cov = covariance_matrix(&k, &argvals).unwrap();
        let m = argvals.len();
        assert_eq!(cov.nrows(), m);
        assert_eq!(cov.ncols(), m);

        // Check symmetry
        for i in 0..m {
            for j in 0..m {
                assert!(
                    (cov[(i, j)] - cov[(j, i)]).abs() < TOL,
                    "not symmetric at ({i}, {j})"
                );
            }
        }
    }

    #[test]
    fn test_covariance_matrix_positive_definite() {
        let k = CovKernel::Gaussian {
            length_scale: 0.5,
            variance: 1.0,
        };
        let argvals: Vec<f64> = (0..10).map(|i| i as f64 * 0.1).collect();
        let cov = covariance_matrix(&k, &argvals).unwrap();
        let m = argvals.len();

        // Convert to row-major for cholesky_d
        let mut row_major = vec![0.0; m * m];
        for i in 0..m {
            for j in 0..m {
                row_major[i * m + j] = cov[(i, j)];
            }
        }
        // Add tiny jitter
        for i in 0..m {
            row_major[i * m + i] += 1e-10;
        }

        // Cholesky should succeed on a positive definite matrix
        assert!(cholesky_d(&row_major, m).is_ok());
    }

    #[test]
    fn test_covariance_matrix_diagonal_equals_variance() {
        let variance = 2.5;
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance,
        };
        let argvals = vec![0.0, 0.5, 1.0];
        let cov = covariance_matrix(&k, &argvals).unwrap();
        for i in 0..3 {
            assert!((cov[(i, i)] - variance).abs() < TOL);
        }
    }

    // -----------------------------------------------------------------------
    // GP generation tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_gp_deterministic_with_seed() {
        let k = CovKernel::Gaussian {
            length_scale: 0.5,
            variance: 1.0,
        };
        let argvals: Vec<f64> = (0..20).map(|i| i as f64 * 0.05).collect();

        let r1 = generate_gaussian_process(5, &k, &argvals, None, Some(42)).unwrap();
        let r2 = generate_gaussian_process(5, &k, &argvals, None, Some(42)).unwrap();

        assert_eq!(r1.samples.nrows(), 5);
        assert_eq!(r1.samples.ncols(), 20);

        // Same seed should produce identical results
        for i in 0..5 {
            for j in 0..20 {
                assert!(
                    (r1.samples[(i, j)] - r2.samples[(i, j)]).abs() < TOL,
                    "mismatch at ({i}, {j})"
                );
            }
        }
    }

    #[test]
    fn test_gp_different_seeds_differ() {
        let k = CovKernel::Gaussian {
            length_scale: 0.5,
            variance: 1.0,
        };
        let argvals: Vec<f64> = (0..10).map(|i| i as f64 * 0.1).collect();

        let r1 = generate_gaussian_process(3, &k, &argvals, None, Some(1)).unwrap();
        let r2 = generate_gaussian_process(3, &k, &argvals, None, Some(2)).unwrap();

        // At least some values should differ
        let mut differ = false;
        for i in 0..3 {
            for j in 0..10 {
                if (r1.samples[(i, j)] - r2.samples[(i, j)]).abs() > 1e-6 {
                    differ = true;
                }
            }
        }
        assert!(differ, "different seeds should produce different samples");
    }

    #[test]
    fn test_gp_mean_and_variance() {
        let variance = 1.0;
        let k = CovKernel::Gaussian {
            length_scale: 0.3,
            variance,
        };
        let argvals: Vec<f64> = (0..5).map(|i| i as f64 * 0.25).collect();
        let n = 10_000;

        let result = generate_gaussian_process(n, &k, &argvals, None, Some(123)).unwrap();

        // Check empirical mean is close to zero and variance close to 1.0
        let m = argvals.len();
        for j in 0..m {
            let col: Vec<f64> = (0..n).map(|i| result.samples[(i, j)]).collect();
            let mean = col.iter().sum::<f64>() / n as f64;
            let var = col.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
            assert!(
                mean.abs() < 0.1,
                "empirical mean at j={j} is {mean}, expected ~0"
            );
            assert!(
                (var - variance).abs() < 0.15,
                "empirical variance at j={j} is {var}, expected ~{variance}"
            );
        }
    }

    #[test]
    fn test_gp_with_mean_function() {
        let k = CovKernel::Gaussian {
            length_scale: 0.3,
            variance: 1.0,
        };
        let argvals = vec![0.0, 0.5, 1.0];
        let mean_fn = vec![10.0, 20.0, 30.0];
        let n = 5000;

        let result = generate_gaussian_process(n, &k, &argvals, Some(&mean_fn), Some(99)).unwrap();
        assert_eq!(result.mean_function, mean_fn);

        // Empirical mean should be close to the specified mean function
        for j in 0..3 {
            let col_mean: f64 = (0..n).map(|i| result.samples[(i, j)]).sum::<f64>() / n as f64;
            assert!(
                (col_mean - mean_fn[j]).abs() < 0.2,
                "empirical mean at j={j} is {col_mean}, expected ~{}",
                mean_fn[j]
            );
        }
    }

    #[test]
    fn test_gp_result_fields() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let argvals = vec![0.0, 0.5, 1.0];
        let result = generate_gaussian_process(3, &k, &argvals, None, Some(0)).unwrap();
        assert_eq!(result.argvals, argvals);
        assert_eq!(result.kernel, k);
        assert_eq!(result.mean_function, vec![0.0, 0.0, 0.0]);
    }

    // -----------------------------------------------------------------------
    // Error case tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_covariance_matrix_empty_argvals() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let err = covariance_matrix(&k, &[]).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "argvals");
            }
            other => panic!("expected InvalidDimension, got {other:?}"),
        }
    }

    #[test]
    fn test_gp_empty_argvals() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let err = generate_gaussian_process(5, &k, &[], None, None).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "argvals");
            }
            other => panic!("expected InvalidDimension, got {other:?}"),
        }
    }

    #[test]
    fn test_gp_n_zero() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let err = generate_gaussian_process(0, &k, &[0.0, 1.0], None, None).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "n");
            }
            other => panic!("expected InvalidDimension, got {other:?}"),
        }
    }

    #[test]
    fn test_gp_wrong_mean_fn_length() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let err = generate_gaussian_process(3, &k, &[0.0, 0.5, 1.0], Some(&[1.0, 2.0]), None)
            .unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "mean_fn");
            }
            other => panic!("expected InvalidDimension, got {other:?}"),
        }
    }

    #[test]
    fn test_negative_variance_error() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: -1.0,
        };
        let err = covariance_matrix(&k, &[0.0, 1.0]).unwrap_err();
        match err {
            FdarError::InvalidParameter { parameter, .. } => {
                assert_eq!(parameter, "variance");
            }
            other => panic!("expected InvalidParameter, got {other:?}"),
        }
    }

    #[test]
    fn test_negative_length_scale_error() {
        let k = CovKernel::Gaussian {
            length_scale: -1.0,
            variance: 1.0,
        };
        let err = covariance_matrix(&k, &[0.0, 1.0]).unwrap_err();
        match err {
            FdarError::InvalidParameter { parameter, .. } => {
                assert_eq!(parameter, "length_scale");
            }
            other => panic!("expected InvalidParameter, got {other:?}"),
        }
    }

    #[test]
    fn test_negative_variance_in_sum_error() {
        let k = CovKernel::Sum(
            Box::new(CovKernel::Gaussian {
                length_scale: 1.0,
                variance: 1.0,
            }),
            Box::new(CovKernel::WhiteNoise { variance: -0.1 }),
        );
        let err = covariance_matrix(&k, &[0.0, 1.0]).unwrap_err();
        match err {
            FdarError::InvalidParameter { parameter, .. } => {
                assert_eq!(parameter, "variance");
            }
            other => panic!("expected InvalidParameter, got {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // Kernel-specific edge cases
    // -----------------------------------------------------------------------

    #[test]
    fn test_matern_general_nu() {
        // nu = 3.5 is a higher-order half-integer, handled by the closed-form branch
        let k = CovKernel::Matern {
            length_scale: 1.0,
            variance: 1.0,
            nu: 3.5,
        };
        // At s=t, should be variance
        assert!((k.eval(0.0, 0.0) - 1.0).abs() < TOL);
        // Should be positive and decreasing with distance
        let v1 = k.eval(0.0, 0.5);
        let v2 = k.eval(0.0, 1.0);
        assert!(v1 > 0.0, "Matern should be positive: {v1}");
        assert!(v2 > 0.0, "Matern should be positive: {v2}");
        assert!(v1 > v2, "Matern should decrease with distance");
    }

    #[test]
    fn test_brownian_negative_args() {
        let k = CovKernel::Brownian { variance: 1.0 };
        // Brownian motion is defined for non-negative arguments
        assert!((k.eval(-1.0, 0.5)).abs() < TOL);
    }

    #[test]
    fn test_periodic_kernel_periodicity() {
        let period = 2.0;
        let k = CovKernel::Periodic {
            length_scale: 1.0,
            variance: 1.0,
            period,
        };
        // k(0, period) should be close to k(0, 0)
        assert!((k.eval(0.0, period) - k.eval(0.0, 0.0)).abs() < 1e-10);
        // k(0.3, 0.3 + period) should be close to k(0.3, 0.3)
        assert!((k.eval(0.3, 0.3 + period) - k.eval(0.3, 0.3)).abs() < 1e-10);
    }

    #[test]
    fn test_gp_single_point() {
        let k = CovKernel::Gaussian {
            length_scale: 1.0,
            variance: 1.0,
        };
        let result = generate_gaussian_process(10, &k, &[0.5], None, Some(7)).unwrap();
        assert_eq!(result.samples.nrows(), 10);
        assert_eq!(result.samples.ncols(), 1);
    }

    #[test]
    fn test_gp_with_brownian_kernel() {
        let k = CovKernel::Brownian { variance: 1.0 };
        let argvals: Vec<f64> = (1..=10).map(|i| i as f64 * 0.1).collect();
        // Brownian kernel on positive argvals should work
        let result = generate_gaussian_process(5, &k, &argvals, None, Some(42)).unwrap();
        assert_eq!(result.samples.nrows(), 5);
        assert_eq!(result.samples.ncols(), 10);
    }

    #[test]
    fn test_gp_with_sum_kernel() {
        let k = CovKernel::Sum(
            Box::new(CovKernel::Gaussian {
                length_scale: 0.5,
                variance: 1.0,
            }),
            Box::new(CovKernel::WhiteNoise { variance: 0.1 }),
        );
        let argvals: Vec<f64> = (0..10).map(|i| i as f64 * 0.1).collect();
        let result = generate_gaussian_process(3, &k, &argvals, None, Some(55)).unwrap();
        assert_eq!(result.samples.nrows(), 3);
        assert_eq!(result.samples.ncols(), 10);
    }

    #[test]
    fn test_polynomial_degree_zero_error() {
        let k = CovKernel::Polynomial {
            variance: 1.0,
            offset: 1.0,
            degree: 0,
        };
        let err = covariance_matrix(&k, &[0.0, 1.0]).unwrap_err();
        match err {
            FdarError::InvalidParameter { parameter, .. } => {
                assert_eq!(parameter, "degree");
            }
            other => panic!("expected InvalidParameter, got {other:?}"),
        }
    }
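
    // -----------------------------------------------------------------------
    // Special-function helpers
    // -----------------------------------------------------------------------

    // Sanity sketch: spot-check the private special-function helpers against
    // known closed-form values (log-gamma at a few points, and the half-order
    // Bessel identity K_{1/2}(z) = sqrt(pi / (2z)) * exp(-z)).
    #[test]
    fn test_special_function_helpers_known_values() {
        // ln Gamma(1) = 0, ln Gamma(5) = ln 24, ln Gamma(0.5) = ln sqrt(pi)
        assert!(ln_gamma(1.0).abs() < 1e-8);
        assert!((ln_gamma(5.0) - 24.0_f64.ln()).abs() < 1e-8);
        assert!((ln_gamma(0.5) - 0.5 * PI.ln()).abs() < 1e-8);

        // K_{1/2}(z) has the closed form sqrt(pi / (2z)) * exp(-z).
        for &z in &[0.5, 1.0, 2.0] {
            let exact = (PI / (2.0 * z)).sqrt() * (-z).exp();
            assert!(
                (bessel_knu(0.5, z) - exact).abs() < 1e-8 * exact,
                "K_1/2({z}) mismatch"
            );
        }
    }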
}