// fdars_core/fdata.rs

1//! Functional data operations: mean, center, derivatives, norms, and geometric median.
2
3use crate::helpers::{simpsons_weights, simpsons_weights_2d, NUMERICAL_EPS};
4use crate::iter_maybe_parallel;
5use crate::matrix::FdMatrix;
6#[cfg(feature = "parallel")]
7use rayon::iter::ParallelIterator;
8
/// Finite-difference estimate of a 1D function's first derivative at `idx`.
///
/// The left endpoint uses a forward difference and the right endpoint a backward
/// difference. Every other index uses a central difference spanning `idx - 1`
/// to `idx + 1`. The numerator is always divided by `step_sizes[idx]`, so the
/// caller must supply the matching span for each index.
fn finite_diff_1d(
    values: impl Fn(usize) -> f64,
    idx: usize,
    n_points: usize,
    step_sizes: &[f64],
) -> f64 {
    // Pick the (lower, upper) sample pair for this position's stencil.
    // The guard only evaluates `n_points - 1` once `idx != 0`, as before.
    let (lower, upper) = match idx {
        0 => (0, 1),
        i if i == n_points - 1 => (i - 1, i),
        i => (i - 1, i + 1),
    };
    (values(upper) - values(lower)) / step_sizes[idx]
}
27
28/// Compute 2D partial derivatives at a single grid point.
29///
30/// Returns (∂f/∂s, ∂f/∂t, ∂²f/∂s∂t) using finite differences.
31fn compute_2d_derivatives(
32    get_val: impl Fn(usize, usize) -> f64,
33    si: usize,
34    ti: usize,
35    m1: usize,
36    m2: usize,
37    hs: &[f64],
38    ht: &[f64],
39) -> (f64, f64, f64) {
40    // ∂f/∂s
41    let ds = finite_diff_1d(|s| get_val(s, ti), si, m1, hs);
42
43    // ∂f/∂t
44    let dt = finite_diff_1d(|t| get_val(si, t), ti, m2, ht);
45
46    // ∂²f/∂s∂t (mixed partial)
47    let denom = hs[si] * ht[ti];
48
49    // Get the appropriate indices for s and t differences
50    let (s_lo, s_hi) = if si == 0 {
51        (0, 1)
52    } else if si == m1 - 1 {
53        (m1 - 2, m1 - 1)
54    } else {
55        (si - 1, si + 1)
56    };
57
58    let (t_lo, t_hi) = if ti == 0 {
59        (0, 1)
60    } else if ti == m2 - 1 {
61        (m2 - 2, m2 - 1)
62    } else {
63        (ti - 1, ti + 1)
64    };
65
66    let dsdt = (get_val(s_hi, t_hi) - get_val(s_lo, t_hi) - get_val(s_hi, t_lo)
67        + get_val(s_lo, t_lo))
68        / denom;
69
70    (ds, dt, dsdt)
71}
72
73/// Perform Weiszfeld iteration to compute geometric median.
74///
75/// This is the core algorithm shared by 1D and 2D geometric median computations.
76fn weiszfeld_iteration(data: &FdMatrix, weights: &[f64], max_iter: usize, tol: f64) -> Vec<f64> {
77    let (n, m) = data.shape();
78
79    // Initialize with the mean
80    let mut median: Vec<f64> = (0..m)
81        .map(|j| {
82            let col = data.column(j);
83            col.iter().sum::<f64>() / n as f64
84        })
85        .collect();
86
87    for _ in 0..max_iter {
88        // Compute distances from current median to all curves
89        let distances: Vec<f64> = (0..n)
90            .map(|i| {
91                let mut dist_sq = 0.0;
92                for j in 0..m {
93                    let diff = data[(i, j)] - median[j];
94                    dist_sq += diff * diff * weights[j];
95                }
96                dist_sq.sqrt()
97            })
98            .collect();
99
100        // Compute weights (1/distance), handling zero distances
101        let inv_distances: Vec<f64> = distances
102            .iter()
103            .map(|d| {
104                if *d > NUMERICAL_EPS {
105                    1.0 / d
106                } else {
107                    1.0 / NUMERICAL_EPS
108                }
109            })
110            .collect();
111
112        let sum_inv_dist: f64 = inv_distances.iter().sum();
113
114        // Update median using Weiszfeld iteration
115        let new_median: Vec<f64> = (0..m)
116            .map(|j| {
117                let mut weighted_sum = 0.0;
118                for i in 0..n {
119                    weighted_sum += data[(i, j)] * inv_distances[i];
120                }
121                weighted_sum / sum_inv_dist
122            })
123            .collect();
124
125        // Check convergence
126        let diff: f64 = median
127            .iter()
128            .zip(new_median.iter())
129            .map(|(a, b)| (a - b).abs())
130            .sum::<f64>()
131            / m as f64;
132
133        median = new_median;
134
135        if diff < tol {
136            break;
137        }
138    }
139
140    median
141}
142
143/// Compute the mean function across all samples (1D).
144///
145/// # Arguments
146/// * `data` - Functional data matrix (n x m)
147///
148/// # Returns
149/// Mean function values at each evaluation point
150pub fn mean_1d(data: &FdMatrix) -> Vec<f64> {
151    let (n, m) = data.shape();
152    if n == 0 || m == 0 {
153        return Vec::new();
154    }
155
156    iter_maybe_parallel!(0..m)
157        .map(|j| {
158            let col = data.column(j);
159            col.iter().sum::<f64>() / n as f64
160        })
161        .collect()
162}
163
164/// Compute the mean function for 2D surfaces.
165///
166/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
167pub fn mean_2d(data: &FdMatrix) -> Vec<f64> {
168    // Same computation as 1D - just compute pointwise mean
169    mean_1d(data)
170}
171
172/// Center functional data by subtracting the mean function.
173///
174/// # Arguments
175/// * `data` - Functional data matrix (n x m)
176///
177/// # Returns
178/// Centered data matrix
179pub fn center_1d(data: &FdMatrix) -> FdMatrix {
180    let (n, m) = data.shape();
181    if n == 0 || m == 0 {
182        return FdMatrix::zeros(0, 0);
183    }
184
185    // First compute the mean for each column (parallelized)
186    let means: Vec<f64> = iter_maybe_parallel!(0..m)
187        .map(|j| {
188            let col = data.column(j);
189            col.iter().sum::<f64>() / n as f64
190        })
191        .collect();
192
193    // Create centered data
194    let mut centered = FdMatrix::zeros(n, m);
195    for j in 0..m {
196        let col = centered.column_mut(j);
197        let src = data.column(j);
198        for i in 0..n {
199            col[i] = src[i] - means[j];
200        }
201    }
202
203    centered
204}
205
206/// Compute Lp norm for each sample.
207///
208/// # Arguments
209/// * `data` - Functional data matrix (n x m)
210/// * `argvals` - Evaluation points for integration
211/// * `p` - Order of the norm (e.g., 2.0 for L2)
212///
213/// # Returns
214/// Vector of Lp norms for each sample
215pub fn norm_lp_1d(data: &FdMatrix, argvals: &[f64], p: f64) -> Vec<f64> {
216    let (n, m) = data.shape();
217    if n == 0 || m == 0 || argvals.len() != m {
218        return Vec::new();
219    }
220
221    let weights = simpsons_weights(argvals);
222
223    if (p - 2.0).abs() < 1e-14 {
224        iter_maybe_parallel!(0..n)
225            .map(|i| {
226                let mut integral = 0.0;
227                for j in 0..m {
228                    let v = data[(i, j)];
229                    integral += v * v * weights[j];
230                }
231                integral.sqrt()
232            })
233            .collect()
234    } else if (p - 1.0).abs() < 1e-14 {
235        iter_maybe_parallel!(0..n)
236            .map(|i| {
237                let mut integral = 0.0;
238                for j in 0..m {
239                    integral += data[(i, j)].abs() * weights[j];
240                }
241                integral
242            })
243            .collect()
244    } else {
245        iter_maybe_parallel!(0..n)
246            .map(|i| {
247                let mut integral = 0.0;
248                for j in 0..m {
249                    integral += data[(i, j)].abs().powf(p) * weights[j];
250                }
251                integral.powf(1.0 / p)
252            })
253            .collect()
254    }
255}
256
257/// Compute numerical derivative of functional data (parallelized over rows).
258///
259/// # Arguments
260/// * `data` - Functional data matrix (n x m)
261/// * `argvals` - Evaluation points
262/// * `nderiv` - Order of derivative
263///
264/// # Returns
265/// Derivative data matrix
266/// Compute one derivative step: forward/central/backward differences written column-wise.
267fn deriv_1d_step(
268    current: &FdMatrix,
269    n: usize,
270    m: usize,
271    h0: f64,
272    hn: f64,
273    h_central: &[f64],
274) -> FdMatrix {
275    let mut next = FdMatrix::zeros(n, m);
276    // Column 0: forward difference
277    let src_col0 = current.column(0);
278    let src_col1 = current.column(1);
279    let dst = next.column_mut(0);
280    for i in 0..n {
281        dst[i] = (src_col1[i] - src_col0[i]) / h0;
282    }
283    // Interior columns: central difference
284    for j in 1..(m - 1) {
285        let src_prev = current.column(j - 1);
286        let src_next = current.column(j + 1);
287        let dst = next.column_mut(j);
288        let h = h_central[j - 1];
289        for i in 0..n {
290            dst[i] = (src_next[i] - src_prev[i]) / h;
291        }
292    }
293    // Column m-1: backward difference
294    let src_colm2 = current.column(m - 2);
295    let src_colm1 = current.column(m - 1);
296    let dst = next.column_mut(m - 1);
297    for i in 0..n {
298        dst[i] = (src_colm1[i] - src_colm2[i]) / hn;
299    }
300    next
301}
302
303pub fn deriv_1d(data: &FdMatrix, argvals: &[f64], nderiv: usize) -> FdMatrix {
304    let (n, m) = data.shape();
305    if n == 0 || m < 2 || argvals.len() != m {
306        return FdMatrix::zeros(n, m);
307    }
308    if nderiv == 0 {
309        return data.clone();
310    }
311
312    let mut current = data.clone();
313
314    // Pre-compute step sizes for central differences
315    let h0 = argvals[1] - argvals[0];
316    let hn = argvals[m - 1] - argvals[m - 2];
317    let h_central: Vec<f64> = (1..(m - 1))
318        .map(|j| argvals[j + 1] - argvals[j - 1])
319        .collect();
320
321    for _ in 0..nderiv {
322        current = deriv_1d_step(&current, n, m, h0, hn, &h_central);
323    }
324
325    current
326}
327
328/// Result of 2D partial derivatives.
329#[derive(Debug, Clone, PartialEq)]
330pub struct Deriv2DResult {
331    /// Partial derivative with respect to s (∂f/∂s)
332    pub ds: FdMatrix,
333    /// Partial derivative with respect to t (∂f/∂t)
334    pub dt: FdMatrix,
335    /// Mixed partial derivative (∂²f/∂s∂t)
336    pub dsdt: FdMatrix,
337}
338
/// Finite-difference denominators for every point of a 1D grid.
///
/// * Entry 0 is the forward spacing `argvals[1] - argvals[0]`.
/// * The last entry is the backward spacing `argvals[m - 1] - argvals[m - 2]`.
/// * Each interior entry `j` is the central span `argvals[j + 1] - argvals[j - 1]`.
///
/// Grids with fewer than two points get a placeholder step of `1.0` per point.
fn compute_step_sizes(argvals: &[f64]) -> Vec<f64> {
    let len = argvals.len();
    if len < 2 {
        return vec![1.0; len];
    }

    let mut steps = Vec::with_capacity(len);
    steps.push(argvals[1] - argvals[0]);
    // Each length-3 window [a, b, c] centred on an interior point contributes c - a.
    steps.extend(argvals.windows(3).map(|w| w[2] - w[0]));
    steps.push(argvals[len - 1] - argvals[len - 2]);
    steps
}
359
360/// Collect per-curve row vectors into a column-major FdMatrix.
361fn reassemble_colmajor(rows: &[Vec<f64>], n: usize, ncol: usize) -> FdMatrix {
362    let mut mat = FdMatrix::zeros(n, ncol);
363    for i in 0..n {
364        for j in 0..ncol {
365            mat[(i, j)] = rows[i][j];
366        }
367    }
368    mat
369}
370
371/// Compute 2D partial derivatives for surface data.
372///
373/// For a surface f(s,t), computes:
374/// - ds: partial derivative with respect to s (∂f/∂s)
375/// - dt: partial derivative with respect to t (∂f/∂t)
376/// - dsdt: mixed partial derivative (∂²f/∂s∂t)
377///
378/// # Arguments
379/// * `data` - Functional data matrix, n surfaces, each stored as m1*m2 values
380/// * `argvals_s` - Grid points in s direction (length m1)
381/// * `argvals_t` - Grid points in t direction (length m2)
382/// * `m1` - Grid size in s direction
383/// * `m2` - Grid size in t direction
384pub fn deriv_2d(
385    data: &FdMatrix,
386    argvals_s: &[f64],
387    argvals_t: &[f64],
388    m1: usize,
389    m2: usize,
390) -> Option<Deriv2DResult> {
391    let n = data.nrows();
392    let ncol = m1 * m2;
393    if n == 0
394        || ncol == 0
395        || m1 < 2
396        || m2 < 2
397        || data.ncols() != ncol
398        || argvals_s.len() != m1
399        || argvals_t.len() != m2
400    {
401        return None;
402    }
403
404    let hs = compute_step_sizes(argvals_s);
405    let ht = compute_step_sizes(argvals_t);
406
407    // Compute all derivatives in parallel over surfaces
408    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
409        .map(|i| {
410            let mut ds = vec![0.0; ncol];
411            let mut dt = vec![0.0; ncol];
412            let mut dsdt = vec![0.0; ncol];
413
414            let get_val = |si: usize, ti: usize| -> f64 { data[(i, si + ti * m1)] };
415
416            for ti in 0..m2 {
417                for si in 0..m1 {
418                    let idx = si + ti * m1;
419                    let (ds_val, dt_val, dsdt_val) =
420                        compute_2d_derivatives(get_val, si, ti, m1, m2, &hs, &ht);
421                    ds[idx] = ds_val;
422                    dt[idx] = dt_val;
423                    dsdt[idx] = dsdt_val;
424                }
425            }
426
427            (ds, dt, dsdt)
428        })
429        .collect();
430
431    let (ds_vecs, (dt_vecs, dsdt_vecs)): (Vec<Vec<f64>>, (Vec<Vec<f64>>, Vec<Vec<f64>>)) =
432        results.into_iter().map(|(a, b, c)| (a, (b, c))).unzip();
433
434    Some(Deriv2DResult {
435        ds: reassemble_colmajor(&ds_vecs, n, ncol),
436        dt: reassemble_colmajor(&dt_vecs, n, ncol),
437        dsdt: reassemble_colmajor(&dsdt_vecs, n, ncol),
438    })
439}
440
441/// Compute the geometric median (L1 median) of functional data using Weiszfeld's algorithm.
442///
443/// The geometric median minimizes sum of L2 distances to all curves.
444///
445/// # Arguments
446/// * `data` - Functional data matrix (n x m)
447/// * `argvals` - Evaluation points for integration
448/// * `max_iter` - Maximum iterations
449/// * `tol` - Convergence tolerance
450pub fn geometric_median_1d(
451    data: &FdMatrix,
452    argvals: &[f64],
453    max_iter: usize,
454    tol: f64,
455) -> Vec<f64> {
456    let (n, m) = data.shape();
457    if n == 0 || m == 0 || argvals.len() != m {
458        return Vec::new();
459    }
460
461    let weights = simpsons_weights(argvals);
462    weiszfeld_iteration(data, &weights, max_iter, tol)
463}
464
465/// Compute the geometric median for 2D functional data.
466///
467/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
468///
469/// # Arguments
470/// * `data` - Functional data matrix (n x m) where m = m1*m2
471/// * `argvals_s` - Grid points in s direction (length m1)
472/// * `argvals_t` - Grid points in t direction (length m2)
473/// * `max_iter` - Maximum iterations
474/// * `tol` - Convergence tolerance
475pub fn geometric_median_2d(
476    data: &FdMatrix,
477    argvals_s: &[f64],
478    argvals_t: &[f64],
479    max_iter: usize,
480    tol: f64,
481) -> Vec<f64> {
482    let (n, m) = data.shape();
483    let expected_cols = argvals_s.len() * argvals_t.len();
484    if n == 0 || m == 0 || m != expected_cols {
485        return Vec::new();
486    }
487
488    let weights = simpsons_weights_2d(argvals_s, argvals_t);
489    weiszfeld_iteration(data, &weights, max_iter, tol)
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `n` equally spaced points on [0, 1], endpoints included.
    fn uniform_grid(n: usize) -> Vec<f64> {
        let denom = (n - 1) as f64;
        (0..n).map(|i| i as f64 / denom).collect()
    }

    /// Build a column-major `FdMatrix`, panicking on a shape mismatch.
    fn col_major(values: Vec<f64>, nrows: usize, ncols: usize) -> FdMatrix {
        FdMatrix::from_column_major(values, nrows, ncols).unwrap()
    }

    /// Column-major buffer in which each of the `nrows` rows equals `profile`.
    fn repeat_rows(profile: &[f64], nrows: usize) -> Vec<f64> {
        profile
            .iter()
            .flat_map(|&v| std::iter::repeat(v).take(nrows))
            .collect()
    }

    /// Sample `f(s, t)` on the tensor grid, flattened with index `si + ti * s_grid.len()`.
    fn flatten_surface(
        s_grid: &[f64],
        t_grid: &[f64],
        f: impl Fn(f64, f64) -> f64,
    ) -> Vec<f64> {
        let mut out = Vec::with_capacity(s_grid.len() * t_grid.len());
        for &t in t_grid {
            for &s in s_grid {
                out.push(f(s, t));
            }
        }
        out
    }

    /// Assert `field[(0, si + ti * m1)]` is within `tol` of `expected(si, ti)` at
    /// every grid point at least `margin` away from the boundary.
    fn assert_interior(
        field: &FdMatrix,
        label: &str,
        (m1, m2): (usize, usize),
        margin: usize,
        tol: f64,
        expected: impl Fn(usize, usize) -> f64,
    ) {
        for si in margin..(m1 - margin) {
            for ti in margin..(m2 - margin) {
                let got = field[(0, si + ti * m1)];
                let want = expected(si, ti);
                assert!(
                    (got - want).abs() < tol,
                    "{} at ({}, {}) = {}, expected {}",
                    label,
                    si,
                    ti,
                    got,
                    want
                );
            }
        }
    }

    /// 2 x 3 matrix of ones with a NaN at column-major slot 2 (row 0, column 1).
    fn ones_with_nan() -> FdMatrix {
        let mut values = vec![1.0; 6];
        values[2] = f64::NAN;
        col_major(values, 2, 3)
    }

    // ============== Mean tests ==============

    #[test]
    fn test_mean_1d() {
        // Rows [1, 2, 3] and [3, 4, 5] (stored column-major) average to [2, 3, 4].
        let mat = col_major(vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0], 2, 3);
        assert_eq!(mean_1d(&mat), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mean_1d_single_sample() {
        let mat = col_major(vec![1.0, 2.0, 3.0], 1, 3);
        assert_eq!(mean_1d(&mat), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_1d_invalid() {
        assert!(mean_1d(&FdMatrix::zeros(0, 0)).is_empty());
    }

    #[test]
    fn test_mean_2d_delegates() {
        let mat = col_major(vec![1.0, 3.0, 2.0, 4.0], 2, 2);
        let from_1d = mean_1d(&mat);
        let from_2d = mean_2d(&mat);
        assert_eq!(from_1d, from_2d);
    }

    // ============== Center tests ==============

    #[test]
    fn test_center_1d() {
        let mat = col_major(vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0], 2, 3);
        let centered = center_1d(&mat);
        // Column means are [2, 3, 4], leaving alternating -1 / +1.
        assert_eq!(centered.as_slice(), &[-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_center_1d_mean_zero() {
        let centered = center_1d(&col_major(vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0], 2, 3));
        for mu in mean_1d(&centered) {
            assert!(mu.abs() < 1e-10, "Centered data should have zero mean");
        }
    }

    #[test]
    fn test_center_1d_invalid() {
        let centered = center_1d(&FdMatrix::zeros(0, 0));
        assert!(centered.is_empty());
    }

    // ============== Norm tests ==============

    #[test]
    fn test_norm_lp_1d_constant() {
        // The constant 2 on [0, 1] has L2 norm 2.
        let argvals = uniform_grid(21);
        let norms = norm_lp_1d(&col_major(vec![2.0; 21], 1, 21), &argvals, 2.0);
        assert_eq!(norms.len(), 1);
        assert!(
            (norms[0] - 2.0).abs() < 0.1,
            "L2 norm of constant 2 should be 2"
        );
    }

    #[test]
    fn test_norm_lp_1d_sine() {
        // ||sin(pi x)||_2 on [0, 1] is sqrt(1/2).
        let argvals = uniform_grid(101);
        let samples: Vec<f64> = argvals.iter().map(|&x| (PI * x).sin()).collect();
        let norms = norm_lp_1d(&col_major(samples, 1, 101), &argvals, 2.0);
        let expected = 0.5_f64.sqrt();
        assert!(
            (norms[0] - expected).abs() < 0.05,
            "Expected {}, got {}",
            expected,
            norms[0]
        );
    }

    #[test]
    fn test_norm_lp_1d_invalid() {
        assert!(norm_lp_1d(&FdMatrix::zeros(0, 0), &[], 2.0).is_empty());
    }

    // ============== Derivative tests ==============

    #[test]
    fn test_deriv_1d_linear() {
        let argvals = uniform_grid(21);
        let deriv = deriv_1d(&col_major(argvals.clone(), 1, 21), &argvals, 1);
        // Interior slopes of f(x) = x should be 1.
        for j in 2..19 {
            assert!(
                (deriv[(0, j)] - 1.0).abs() < 0.1,
                "Derivative of x should be 1"
            );
        }
    }

    #[test]
    fn test_deriv_1d_quadratic() {
        let argvals = uniform_grid(51);
        let squares: Vec<f64> = argvals.iter().map(|&x| x * x).collect();
        let deriv = deriv_1d(&col_major(squares, 1, 51), &argvals, 1);
        // d/dx x^2 = 2x away from the boundaries (indices 5..45).
        for (j, &x) in argvals.iter().enumerate().take(45).skip(5) {
            assert!(
                (deriv[(0, j)] - 2.0 * x).abs() < 0.1,
                "Derivative of x^2 should be 2x"
            );
        }
    }

    #[test]
    fn test_deriv_1d_invalid() {
        let result = deriv_1d(&FdMatrix::zeros(0, 0), &[], 1);
        assert!(result.is_empty() || result.as_slice().iter().all(|&x| x == 0.0));
    }

    // ============== Geometric median tests ==============

    #[test]
    fn test_geometric_median_identical_curves() {
        // All curves identical -> the median is that curve.
        let argvals = uniform_grid(21);
        let (n, m) = (5, 21);
        let curve: Vec<f64> = argvals.iter().map(|&x| (2.0 * PI * x).sin()).collect();
        let mat = col_major(repeat_rows(&curve, n), n, m);
        let median = geometric_median_1d(&mat, &argvals, 100, 1e-6);
        for (j, &want) in curve.iter().enumerate() {
            assert!(
                (median[j] - want).abs() < 0.01,
                "Median should equal all curves"
            );
        }
    }

    #[test]
    fn test_geometric_median_converges() {
        let argvals = uniform_grid(21);
        let (n, m) = (10, 21);
        // Curve i is the line (i / n) * x.
        let data: Vec<f64> = argvals
            .iter()
            .flat_map(|&x| (0..n).map(move |i| (i as f64 / n as f64) * x))
            .collect();
        let mat = col_major(data, n, m);
        let median = geometric_median_1d(&mat, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        assert!(median.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_geometric_median_invalid() {
        assert!(geometric_median_1d(&FdMatrix::zeros(0, 0), &[], 100, 1e-6).is_empty());
    }

    // ============== 2D derivative tests ==============

    #[test]
    fn test_deriv_2d_linear_surface() {
        // f(s, t) = 2s + 3t: ∂f/∂s = 2, ∂f/∂t = 3, ∂²f/∂s∂t = 0.
        let (m1, m2) = (11, 11);
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let surface = flatten_surface(&argvals_s, &argvals_t, |s, t| 2.0 * s + 3.0 * t);
        let mat = col_major(surface, 1, m1 * m2);
        let result = deriv_2d(&mat, &argvals_s, &argvals_t, m1, m2).unwrap();

        assert_interior(&result.ds, "∂f/∂s", (m1, m2), 2, 0.2, |_, _| 2.0);
        assert_interior(&result.dt, "∂f/∂t", (m1, m2), 2, 0.2, |_, _| 3.0);
        assert_interior(&result.dsdt, "∂²f/∂s∂t", (m1, m2), 2, 0.5, |_, _| 0.0);
    }

    #[test]
    fn test_deriv_2d_quadratic_surface() {
        // f(s, t) = s * t: ∂f/∂s = t, ∂f/∂t = s, ∂²f/∂s∂t = 1.
        let (m1, m2) = (21, 21);
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let surface = flatten_surface(&argvals_s, &argvals_t, |s, t| s * t);
        let mat = col_major(surface, 1, m1 * m2);
        let result = deriv_2d(&mat, &argvals_s, &argvals_t, m1, m2).unwrap();

        assert_interior(&result.ds, "∂f/∂s", (m1, m2), 3, 0.1, |_, ti| argvals_t[ti]);
        assert_interior(&result.dt, "∂f/∂t", (m1, m2), 3, 0.1, |si, _| argvals_s[si]);
        assert_interior(&result.dsdt, "∂²f/∂s∂t", (m1, m2), 3, 0.3, |_, _| 1.0);
    }

    #[test]
    fn test_deriv_2d_invalid_input() {
        // Empty data
        assert!(deriv_2d(&FdMatrix::zeros(0, 0), &[], &[], 0, 0).is_none());

        // Mismatched dimensions
        let mat = col_major(vec![1.0; 4], 1, 4);
        let argvals = vec![0.0, 1.0];
        assert!(deriv_2d(&mat, &argvals, &[0.0, 0.5, 1.0], 2, 2).is_none());
    }

    // ============== 2D geometric median tests ==============

    #[test]
    fn test_geometric_median_2d_basic() {
        // Three identical copies of f(s, t) = s + t -> the median is that surface.
        let (m1, m2, n) = (5, 5, 3);
        let m = m1 * m2;
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let surface = flatten_surface(&argvals_s, &argvals_t, |s, t| s + t);
        let mat = col_major(repeat_rows(&surface, n), n, m);

        let median = geometric_median_2d(&mat, &argvals_s, &argvals_t, 100, 1e-6);
        assert_eq!(median.len(), m);

        for si in 0..m1 {
            for ti in 0..m2 {
                let idx = si + ti * m1;
                let expected = argvals_s[si] + argvals_t[ti];
                assert!(
                    (median[idx] - expected).abs() < 0.01,
                    "Median at ({}, {}) = {}, expected {}",
                    si,
                    ti,
                    median[idx],
                    expected
                );
            }
        }
    }

    // ============== Degenerate-input tests ==============

    #[test]
    fn test_nan_mean_no_panic() {
        assert_eq!(mean_1d(&ones_with_nan()).len(), 3);
    }

    #[test]
    fn test_nan_center_no_panic() {
        assert_eq!(center_1d(&ones_with_nan()).nrows(), 2);
    }

    #[test]
    fn test_nan_norm_no_panic() {
        let norms = norm_lp_1d(&ones_with_nan(), &[0.0, 0.5, 1.0], 2.0);
        assert_eq!(norms.len(), 2);
    }

    #[test]
    fn test_n1_norm() {
        let norms = norm_lp_1d(&col_major(vec![0.0, 1.0, 0.0], 1, 3), &[0.0, 0.5, 1.0], 2.0);
        assert_eq!(norms.len(), 1);
        assert!(norms[0] > 0.0);
    }

    #[test]
    fn test_n2_center() {
        let centered = center_1d(&col_major(vec![1.0, 3.0, 2.0, 4.0], 2, 2));
        // Column means are [2, 3], so centered[(0, 0)] = 1 - 2 = -1.
        assert!((centered[(0, 0)] - (-1.0)).abs() < 1e-12);
        assert!((centered[(1, 0)] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_deriv_nderiv0() {
        // nderiv = 0 is the identity operator.
        let data = col_major(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let result = deriv_1d(&data, &[0.0, 0.5, 1.0], 0);
        assert_eq!(result.shape(), data.shape());
        for i in 0..2 {
            for j in 0..3 {
                assert!((result[(i, j)] - data[(i, j)]).abs() < 1e-12);
            }
        }
    }
}