Skip to main content

fdars_core/
fdata.rs

1//! Functional data operations: mean, center, derivatives, norms, and geometric median.
2
3use crate::helpers::{simpsons_weights, simpsons_weights_2d, NUMERICAL_EPS};
4use crate::iter_maybe_parallel;
5use crate::matrix::FdMatrix;
6#[cfg(feature = "parallel")]
7use rayon::iter::ParallelIterator;
8
9/// Compute finite difference for a 1D function at a given index.
10///
11/// Uses forward difference at left boundary, backward difference at right boundary,
12/// and central difference for interior points.
13fn finite_diff_1d(
14    values: impl Fn(usize) -> f64,
15    idx: usize,
16    n_points: usize,
17    step_sizes: &[f64],
18) -> f64 {
19    if idx == 0 {
20        (values(1) - values(0)) / step_sizes[0]
21    } else if idx == n_points - 1 {
22        (values(n_points - 1) - values(n_points - 2)) / step_sizes[n_points - 1]
23    } else {
24        (values(idx + 1) - values(idx - 1)) / step_sizes[idx]
25    }
26}
27
28/// Compute 2D partial derivatives at a single grid point.
29///
30/// Returns (∂f/∂s, ∂f/∂t, ∂²f/∂s∂t) using finite differences.
31fn compute_2d_derivatives(
32    get_val: impl Fn(usize, usize) -> f64,
33    si: usize,
34    ti: usize,
35    m1: usize,
36    m2: usize,
37    hs: &[f64],
38    ht: &[f64],
39) -> (f64, f64, f64) {
40    // ∂f/∂s
41    let ds = finite_diff_1d(|s| get_val(s, ti), si, m1, hs);
42
43    // ∂f/∂t
44    let dt = finite_diff_1d(|t| get_val(si, t), ti, m2, ht);
45
46    // ∂²f/∂s∂t (mixed partial)
47    let denom = hs[si] * ht[ti];
48
49    // Get the appropriate indices for s and t differences
50    let (s_lo, s_hi) = if si == 0 {
51        (0, 1)
52    } else if si == m1 - 1 {
53        (m1 - 2, m1 - 1)
54    } else {
55        (si - 1, si + 1)
56    };
57
58    let (t_lo, t_hi) = if ti == 0 {
59        (0, 1)
60    } else if ti == m2 - 1 {
61        (m2 - 2, m2 - 1)
62    } else {
63        (ti - 1, ti + 1)
64    };
65
66    let dsdt = (get_val(s_hi, t_hi) - get_val(s_lo, t_hi) - get_val(s_hi, t_lo)
67        + get_val(s_lo, t_lo))
68        / denom;
69
70    (ds, dt, dsdt)
71}
72
73/// Perform Weiszfeld iteration to compute geometric median.
74///
75/// This is the core algorithm shared by 1D and 2D geometric median computations.
76fn weiszfeld_iteration(data: &FdMatrix, weights: &[f64], max_iter: usize, tol: f64) -> Vec<f64> {
77    let (n, m) = data.shape();
78
79    // Initialize with the mean
80    let mut median: Vec<f64> = (0..m)
81        .map(|j| {
82            let col = data.column(j);
83            col.iter().sum::<f64>() / n as f64
84        })
85        .collect();
86
87    for _ in 0..max_iter {
88        // Compute distances from current median to all curves
89        let distances: Vec<f64> = (0..n)
90            .map(|i| {
91                let mut dist_sq = 0.0;
92                for j in 0..m {
93                    let diff = data[(i, j)] - median[j];
94                    dist_sq += diff * diff * weights[j];
95                }
96                dist_sq.sqrt()
97            })
98            .collect();
99
100        // Compute weights (1/distance), handling zero distances
101        let inv_distances: Vec<f64> = distances
102            .iter()
103            .map(|d| {
104                if *d > NUMERICAL_EPS {
105                    1.0 / d
106                } else {
107                    1.0 / NUMERICAL_EPS
108                }
109            })
110            .collect();
111
112        let sum_inv_dist: f64 = inv_distances.iter().sum();
113
114        // Update median using Weiszfeld iteration
115        let new_median: Vec<f64> = (0..m)
116            .map(|j| {
117                let mut weighted_sum = 0.0;
118                for i in 0..n {
119                    weighted_sum += data[(i, j)] * inv_distances[i];
120                }
121                weighted_sum / sum_inv_dist
122            })
123            .collect();
124
125        // Check convergence
126        let diff: f64 = median
127            .iter()
128            .zip(new_median.iter())
129            .map(|(a, b)| (a - b).abs())
130            .sum::<f64>()
131            / m as f64;
132
133        median = new_median;
134
135        if diff < tol {
136            break;
137        }
138    }
139
140    median
141}
142
143/// Compute the mean function across all samples (1D).
144///
145/// # Arguments
146/// * `data` - Functional data matrix (n x m)
147///
148/// # Returns
149/// Mean function values at each evaluation point
150///
151/// # Examples
152///
153/// ```
154/// use fdars_core::matrix::FdMatrix;
155/// use fdars_core::fdata::mean_1d;
156///
157/// // 3 curves at 4 evaluation points
158/// let data = FdMatrix::from_column_major(
159///     vec![1.0, 2.0, 3.0,  4.0, 5.0, 6.0,  7.0, 8.0, 9.0,  10.0, 11.0, 12.0],
160///     3, 4,
161/// ).unwrap();
162/// let mean = mean_1d(&data);
163/// assert_eq!(mean.len(), 4);
164/// assert!((mean[0] - 2.0).abs() < 1e-10); // mean of [1, 2, 3]
165/// ```
166pub fn mean_1d(data: &FdMatrix) -> Vec<f64> {
167    let (n, m) = data.shape();
168    if n == 0 || m == 0 {
169        return Vec::new();
170    }
171
172    iter_maybe_parallel!(0..m)
173        .map(|j| {
174            let col = data.column(j);
175            col.iter().sum::<f64>() / n as f64
176        })
177        .collect()
178}
179
180/// Compute the mean function for 2D surfaces.
181///
182/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
183pub fn mean_2d(data: &FdMatrix) -> Vec<f64> {
184    // Same computation as 1D - just compute pointwise mean
185    mean_1d(data)
186}
187
188/// Center functional data by subtracting the mean function.
189///
190/// # Arguments
191/// * `data` - Functional data matrix (n x m)
192///
193/// # Returns
194/// Centered data matrix
195///
196/// # Examples
197///
198/// ```
199/// use fdars_core::matrix::FdMatrix;
200/// use fdars_core::fdata::{center_1d, mean_1d};
201///
202/// let data = FdMatrix::from_column_major(
203///     vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0], 2, 3,
204/// ).unwrap();
205/// let centered = center_1d(&data);
206/// assert_eq!(centered.shape(), (2, 3));
207/// // Column means of centered data should be zero
208/// let means = mean_1d(&centered);
209/// assert!(means.iter().all(|m| m.abs() < 1e-10));
210/// ```
211pub fn center_1d(data: &FdMatrix) -> FdMatrix {
212    let (n, m) = data.shape();
213    if n == 0 || m == 0 {
214        return FdMatrix::zeros(0, 0);
215    }
216
217    // First compute the mean for each column (parallelized)
218    let means: Vec<f64> = iter_maybe_parallel!(0..m)
219        .map(|j| {
220            let col = data.column(j);
221            col.iter().sum::<f64>() / n as f64
222        })
223        .collect();
224
225    // Create centered data
226    let mut centered = FdMatrix::zeros(n, m);
227    for j in 0..m {
228        let col = centered.column_mut(j);
229        let src = data.column(j);
230        for i in 0..n {
231            col[i] = src[i] - means[j];
232        }
233    }
234
235    centered
236}
237
238/// Compute Lp norm for each sample.
239///
240/// # Arguments
241/// * `data` - Functional data matrix (n x m)
242/// * `argvals` - Evaluation points for integration
243/// * `p` - Order of the norm (e.g., 2.0 for L2)
244///
245/// # Returns
246/// Vector of Lp norms for each sample
247pub fn norm_lp_1d(data: &FdMatrix, argvals: &[f64], p: f64) -> Vec<f64> {
248    let (n, m) = data.shape();
249    if n == 0 || m == 0 || argvals.len() != m {
250        return Vec::new();
251    }
252
253    let weights = simpsons_weights(argvals);
254
255    if (p - 2.0).abs() < 1e-14 {
256        iter_maybe_parallel!(0..n)
257            .map(|i| {
258                let mut integral = 0.0;
259                for j in 0..m {
260                    let v = data[(i, j)];
261                    integral += v * v * weights[j];
262                }
263                integral.sqrt()
264            })
265            .collect()
266    } else if (p - 1.0).abs() < 1e-14 {
267        iter_maybe_parallel!(0..n)
268            .map(|i| {
269                let mut integral = 0.0;
270                for j in 0..m {
271                    integral += data[(i, j)].abs() * weights[j];
272                }
273                integral
274            })
275            .collect()
276    } else {
277        iter_maybe_parallel!(0..n)
278            .map(|i| {
279                let mut integral = 0.0;
280                for j in 0..m {
281                    integral += data[(i, j)].abs().powf(p) * weights[j];
282                }
283                integral.powf(1.0 / p)
284            })
285            .collect()
286    }
287}
288
289/// Compute numerical derivative of functional data (parallelized over rows).
290///
291/// # Arguments
292/// * `data` - Functional data matrix (n x m)
293/// * `argvals` - Evaluation points
294/// * `nderiv` - Order of derivative
295///
296/// # Returns
297/// Derivative data matrix
298///
299/// # Examples
300///
301/// ```
302/// use fdars_core::matrix::FdMatrix;
303/// use fdars_core::fdata::deriv_1d;
304///
305/// // Linear function f(t) = t on [0, 1], derivative should be ~1
306/// let argvals: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
307/// let data = FdMatrix::from_column_major(argvals.clone(), 1, 20).unwrap();
308/// let deriv = deriv_1d(&data, &argvals, 1);
309/// assert_eq!(deriv.shape(), (1, 20));
310/// // Interior points should have derivative close to 1.0
311/// assert!((deriv[(0, 10)] - 1.0).abs() < 0.1);
312/// ```
313/// Compute one derivative step: forward/central/backward differences written column-wise.
314fn deriv_1d_step(
315    current: &FdMatrix,
316    n: usize,
317    m: usize,
318    h0: f64,
319    hn: f64,
320    h_central: &[f64],
321) -> FdMatrix {
322    let mut next = FdMatrix::zeros(n, m);
323    // Column 0: forward difference
324    let src_col0 = current.column(0);
325    let src_col1 = current.column(1);
326    let dst = next.column_mut(0);
327    for i in 0..n {
328        dst[i] = (src_col1[i] - src_col0[i]) / h0;
329    }
330    // Interior columns: central difference
331    for j in 1..(m - 1) {
332        let src_prev = current.column(j - 1);
333        let src_next = current.column(j + 1);
334        let dst = next.column_mut(j);
335        let h = h_central[j - 1];
336        for i in 0..n {
337            dst[i] = (src_next[i] - src_prev[i]) / h;
338        }
339    }
340    // Column m-1: backward difference
341    let src_colm2 = current.column(m - 2);
342    let src_colm1 = current.column(m - 1);
343    let dst = next.column_mut(m - 1);
344    for i in 0..n {
345        dst[i] = (src_colm1[i] - src_colm2[i]) / hn;
346    }
347    next
348}
349
350pub fn deriv_1d(data: &FdMatrix, argvals: &[f64], nderiv: usize) -> FdMatrix {
351    let (n, m) = data.shape();
352    if n == 0 || m < 2 || argvals.len() != m {
353        return FdMatrix::zeros(n, m);
354    }
355    if nderiv == 0 {
356        return data.clone();
357    }
358
359    let mut current = data.clone();
360
361    // Pre-compute step sizes for central differences
362    let h0 = argvals[1] - argvals[0];
363    let hn = argvals[m - 1] - argvals[m - 2];
364    let h_central: Vec<f64> = (1..(m - 1))
365        .map(|j| argvals[j + 1] - argvals[j - 1])
366        .collect();
367
368    for _ in 0..nderiv {
369        current = deriv_1d_step(&current, n, m, h0, hn, &h_central);
370    }
371
372    current
373}
374
375/// Result of 2D partial derivatives.
376#[derive(Debug, Clone, PartialEq)]
377#[non_exhaustive]
378pub struct Deriv2DResult {
379    /// Partial derivative with respect to s (∂f/∂s)
380    pub ds: FdMatrix,
381    /// Partial derivative with respect to t (∂f/∂t)
382    pub dt: FdMatrix,
383    /// Mixed partial derivative (∂²f/∂s∂t)
384    pub dsdt: FdMatrix,
385}
386
387/// Compute finite-difference step sizes for a grid.
388///
389/// Uses forward/backward difference at boundaries and central difference for interior.
390fn compute_step_sizes(argvals: &[f64]) -> Vec<f64> {
391    let m = argvals.len();
392    if m < 2 {
393        return vec![1.0; m];
394    }
395    (0..m)
396        .map(|j| {
397            if j == 0 {
398                argvals[1] - argvals[0]
399            } else if j == m - 1 {
400                argvals[m - 1] - argvals[m - 2]
401            } else {
402                argvals[j + 1] - argvals[j - 1]
403            }
404        })
405        .collect()
406}
407
408/// Collect per-curve row vectors into a column-major FdMatrix.
409fn reassemble_colmajor(rows: &[Vec<f64>], n: usize, ncol: usize) -> FdMatrix {
410    let mut mat = FdMatrix::zeros(n, ncol);
411    for i in 0..n {
412        for j in 0..ncol {
413            mat[(i, j)] = rows[i][j];
414        }
415    }
416    mat
417}
418
419/// Compute 2D partial derivatives for surface data.
420///
421/// For a surface f(s,t), computes:
422/// - ds: partial derivative with respect to s (∂f/∂s)
423/// - dt: partial derivative with respect to t (∂f/∂t)
424/// - dsdt: mixed partial derivative (∂²f/∂s∂t)
425///
426/// # Arguments
427/// * `data` - Functional data matrix, n surfaces, each stored as m1*m2 values
428/// * `argvals_s` - Grid points in s direction (length m1)
429/// * `argvals_t` - Grid points in t direction (length m2)
430/// * `m1` - Grid size in s direction
431/// * `m2` - Grid size in t direction
432pub fn deriv_2d(
433    data: &FdMatrix,
434    argvals_s: &[f64],
435    argvals_t: &[f64],
436    m1: usize,
437    m2: usize,
438) -> Option<Deriv2DResult> {
439    let n = data.nrows();
440    let ncol = m1 * m2;
441    if n == 0
442        || ncol == 0
443        || m1 < 2
444        || m2 < 2
445        || data.ncols() != ncol
446        || argvals_s.len() != m1
447        || argvals_t.len() != m2
448    {
449        return None;
450    }
451
452    let hs = compute_step_sizes(argvals_s);
453    let ht = compute_step_sizes(argvals_t);
454
455    // Compute all derivatives in parallel over surfaces
456    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
457        .map(|i| {
458            let mut ds = vec![0.0; ncol];
459            let mut dt = vec![0.0; ncol];
460            let mut dsdt = vec![0.0; ncol];
461
462            let get_val = |si: usize, ti: usize| -> f64 { data[(i, si + ti * m1)] };
463
464            for ti in 0..m2 {
465                for si in 0..m1 {
466                    let idx = si + ti * m1;
467                    let (ds_val, dt_val, dsdt_val) =
468                        compute_2d_derivatives(get_val, si, ti, m1, m2, &hs, &ht);
469                    ds[idx] = ds_val;
470                    dt[idx] = dt_val;
471                    dsdt[idx] = dsdt_val;
472                }
473            }
474
475            (ds, dt, dsdt)
476        })
477        .collect();
478
479    let (ds_vecs, (dt_vecs, dsdt_vecs)): (Vec<Vec<f64>>, (Vec<Vec<f64>>, Vec<Vec<f64>>)) =
480        results.into_iter().map(|(a, b, c)| (a, (b, c))).unzip();
481
482    Some(Deriv2DResult {
483        ds: reassemble_colmajor(&ds_vecs, n, ncol),
484        dt: reassemble_colmajor(&dt_vecs, n, ncol),
485        dsdt: reassemble_colmajor(&dsdt_vecs, n, ncol),
486    })
487}
488
489/// Compute the geometric median (L1 median) of functional data using Weiszfeld's algorithm.
490///
491/// The geometric median minimizes sum of L2 distances to all curves.
492///
493/// # Arguments
494/// * `data` - Functional data matrix (n x m)
495/// * `argvals` - Evaluation points for integration
496/// * `max_iter` - Maximum iterations
497/// * `tol` - Convergence tolerance
498pub fn geometric_median_1d(
499    data: &FdMatrix,
500    argvals: &[f64],
501    max_iter: usize,
502    tol: f64,
503) -> Vec<f64> {
504    let (n, m) = data.shape();
505    if n == 0 || m == 0 || argvals.len() != m {
506        return Vec::new();
507    }
508
509    let weights = simpsons_weights(argvals);
510    weiszfeld_iteration(data, &weights, max_iter, tol)
511}
512
513/// Compute the geometric median for 2D functional data.
514///
515/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
516///
517/// # Arguments
518/// * `data` - Functional data matrix (n x m) where m = m1*m2
519/// * `argvals_s` - Grid points in s direction (length m1)
520/// * `argvals_t` - Grid points in t direction (length m2)
521/// * `max_iter` - Maximum iterations
522/// * `tol` - Convergence tolerance
523pub fn geometric_median_2d(
524    data: &FdMatrix,
525    argvals_s: &[f64],
526    argvals_t: &[f64],
527    max_iter: usize,
528    tol: f64,
529) -> Vec<f64> {
530    let (n, m) = data.shape();
531    let expected_cols = argvals_s.len() * argvals_t.len();
532    if n == 0 || m == 0 || m != expected_cols {
533        return Vec::new();
534    }
535
536    let weights = simpsons_weights_2d(argvals_s, argvals_t);
537    weiszfeld_iteration(data, &weights, max_iter, tol)
538}
539
540#[cfg(test)]
541mod tests {
542    use super::*;
543    use crate::test_helpers::uniform_grid;
544    use std::f64::consts::PI;
545
546    // ============== Mean tests ==============
547
548    #[test]
549    fn test_mean_1d() {
550        // 2 samples, 3 points each
551        // Sample 1: [1, 2, 3]
552        // Sample 2: [3, 4, 5]
553        // Mean should be [2, 3, 4]
554        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0]; // column-major
555        let mat = FdMatrix::from_column_major(data, 2, 3).unwrap();
556        let mean = mean_1d(&mat);
557        assert_eq!(mean, vec![2.0, 3.0, 4.0]);
558    }
559
560    #[test]
561    fn test_mean_1d_single_sample() {
562        let data = vec![1.0, 2.0, 3.0];
563        let mat = FdMatrix::from_column_major(data, 1, 3).unwrap();
564        let mean = mean_1d(&mat);
565        assert_eq!(mean, vec![1.0, 2.0, 3.0]);
566    }
567
568    #[test]
569    fn test_mean_1d_invalid() {
570        assert!(mean_1d(&FdMatrix::zeros(0, 0)).is_empty());
571    }
572
573    #[test]
574    fn test_mean_2d_delegates() {
575        let data = vec![1.0, 3.0, 2.0, 4.0];
576        let mat = FdMatrix::from_column_major(data, 2, 2).unwrap();
577        let mean1d = mean_1d(&mat);
578        let mean2d = mean_2d(&mat);
579        assert_eq!(mean1d, mean2d);
580    }
581
582    // ============== Center tests ==============
583
584    #[test]
585    fn test_center_1d() {
586        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0]; // column-major
587        let mat = FdMatrix::from_column_major(data, 2, 3).unwrap();
588        let centered = center_1d(&mat);
589        // Mean is [2, 3, 4], so centered should be [-1, 1, -1, 1, -1, 1]
590        assert_eq!(centered.as_slice(), &[-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]);
591    }
592
593    #[test]
594    fn test_center_1d_mean_zero() {
595        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
596        let mat = FdMatrix::from_column_major(data, 2, 3).unwrap();
597        let centered = center_1d(&mat);
598        let centered_mean = mean_1d(&centered);
599        for m in centered_mean {
600            assert!(m.abs() < 1e-10, "Centered data should have zero mean");
601        }
602    }
603
604    #[test]
605    fn test_center_1d_invalid() {
606        let centered = center_1d(&FdMatrix::zeros(0, 0));
607        assert!(centered.is_empty());
608    }
609
610    // ============== Norm tests ==============
611
612    #[test]
613    fn test_norm_lp_1d_constant() {
614        // Constant function 2 on [0, 1] has L2 norm = 2
615        let argvals = uniform_grid(21);
616        let data: Vec<f64> = vec![2.0; 21];
617        let mat = FdMatrix::from_column_major(data, 1, 21).unwrap();
618        let norms = norm_lp_1d(&mat, &argvals, 2.0);
619        assert_eq!(norms.len(), 1);
620        assert!(
621            (norms[0] - 2.0).abs() < 0.1,
622            "L2 norm of constant 2 should be 2"
623        );
624    }
625
626    #[test]
627    fn test_norm_lp_1d_sine() {
628        // L2 norm of sin(pi*x) on [0, 1] = sqrt(0.5)
629        let argvals = uniform_grid(101);
630        let data: Vec<f64> = argvals.iter().map(|&x| (PI * x).sin()).collect();
631        let mat = FdMatrix::from_column_major(data, 1, 101).unwrap();
632        let norms = norm_lp_1d(&mat, &argvals, 2.0);
633        let expected = 0.5_f64.sqrt();
634        assert!(
635            (norms[0] - expected).abs() < 0.05,
636            "Expected {}, got {}",
637            expected,
638            norms[0]
639        );
640    }
641
642    #[test]
643    fn test_norm_lp_1d_invalid() {
644        assert!(norm_lp_1d(&FdMatrix::zeros(0, 0), &[], 2.0).is_empty());
645    }
646
647    // ============== Derivative tests ==============
648
649    #[test]
650    fn test_deriv_1d_linear() {
651        // Derivative of linear function x should be 1
652        let argvals = uniform_grid(21);
653        let data = argvals.clone();
654        let mat = FdMatrix::from_column_major(data, 1, 21).unwrap();
655        let deriv = deriv_1d(&mat, &argvals, 1);
656        // Interior points should have derivative close to 1
657        for j in 2..19 {
658            assert!(
659                (deriv[(0, j)] - 1.0).abs() < 0.1,
660                "Derivative of x should be 1"
661            );
662        }
663    }
664
665    #[test]
666    fn test_deriv_1d_quadratic() {
667        // Derivative of x^2 should be 2x
668        let argvals = uniform_grid(51);
669        let data: Vec<f64> = argvals.iter().map(|&x| x * x).collect();
670        let mat = FdMatrix::from_column_major(data, 1, 51).unwrap();
671        let deriv = deriv_1d(&mat, &argvals, 1);
672        // Check interior points
673        for j in 5..45 {
674            let expected = 2.0 * argvals[j];
675            assert!(
676                (deriv[(0, j)] - expected).abs() < 0.1,
677                "Derivative of x^2 should be 2x"
678            );
679        }
680    }
681
682    #[test]
683    fn test_deriv_1d_invalid() {
684        let result = deriv_1d(&FdMatrix::zeros(0, 0), &[], 1);
685        assert!(result.is_empty() || result.as_slice().iter().all(|&x| x == 0.0));
686    }
687
688    // ============== Geometric median tests ==============
689
690    #[test]
691    fn test_geometric_median_identical_curves() {
692        // All curves identical -> median = that curve
693        let argvals = uniform_grid(21);
694        let n = 5;
695        let m = 21;
696        let mut data = vec![0.0; n * m];
697        for i in 0..n {
698            for j in 0..m {
699                data[i + j * n] = (2.0 * PI * argvals[j]).sin();
700            }
701        }
702        let mat = FdMatrix::from_column_major(data, n, m).unwrap();
703        let median = geometric_median_1d(&mat, &argvals, 100, 1e-6);
704        for j in 0..m {
705            let expected = (2.0 * PI * argvals[j]).sin();
706            assert!(
707                (median[j] - expected).abs() < 0.01,
708                "Median should equal all curves"
709            );
710        }
711    }
712
713    #[test]
714    fn test_geometric_median_converges() {
715        let argvals = uniform_grid(21);
716        let n = 10;
717        let m = 21;
718        let mut data = vec![0.0; n * m];
719        for i in 0..n {
720            for j in 0..m {
721                data[i + j * n] = (i as f64 / n as f64) * argvals[j];
722            }
723        }
724        let mat = FdMatrix::from_column_major(data, n, m).unwrap();
725        let median = geometric_median_1d(&mat, &argvals, 100, 1e-6);
726        assert_eq!(median.len(), m);
727        assert!(median.iter().all(|&x| x.is_finite()));
728    }
729
730    #[test]
731    fn test_geometric_median_invalid() {
732        assert!(geometric_median_1d(&FdMatrix::zeros(0, 0), &[], 100, 1e-6).is_empty());
733    }
734
735    // ============== 2D derivative tests ==============
736
737    #[test]
738    fn test_deriv_2d_linear_surface() {
739        // f(s, t) = 2*s + 3*t
740        // ∂f/∂s = 2, ∂f/∂t = 3, ∂²f/∂s∂t = 0
741        let m1 = 11;
742        let m2 = 11;
743        let argvals_s: Vec<f64> = (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect();
744        let argvals_t: Vec<f64> = (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect();
745
746        let n = 1; // single surface
747        let ncol = m1 * m2;
748        let mut data = vec![0.0; n * ncol];
749
750        for si in 0..m1 {
751            for ti in 0..m2 {
752                let s = argvals_s[si];
753                let t = argvals_t[ti];
754                let idx = si + ti * m1;
755                data[idx] = 2.0 * s + 3.0 * t;
756            }
757        }
758
759        let mat = FdMatrix::from_column_major(data, n, ncol).unwrap();
760        let result = deriv_2d(&mat, &argvals_s, &argvals_t, m1, m2).unwrap();
761
762        // Check interior points for ∂f/∂s ≈ 2
763        for si in 2..(m1 - 2) {
764            for ti in 2..(m2 - 2) {
765                let idx = si + ti * m1;
766                assert!(
767                    (result.ds[(0, idx)] - 2.0).abs() < 0.2,
768                    "∂f/∂s at ({}, {}) = {}, expected 2",
769                    si,
770                    ti,
771                    result.ds[(0, idx)]
772                );
773            }
774        }
775
776        // Check interior points for ∂f/∂t ≈ 3
777        for si in 2..(m1 - 2) {
778            for ti in 2..(m2 - 2) {
779                let idx = si + ti * m1;
780                assert!(
781                    (result.dt[(0, idx)] - 3.0).abs() < 0.2,
782                    "∂f/∂t at ({}, {}) = {}, expected 3",
783                    si,
784                    ti,
785                    result.dt[(0, idx)]
786                );
787            }
788        }
789
790        // Check interior points for mixed partial ≈ 0
791        for si in 2..(m1 - 2) {
792            for ti in 2..(m2 - 2) {
793                let idx = si + ti * m1;
794                assert!(
795                    result.dsdt[(0, idx)].abs() < 0.5,
796                    "∂²f/∂s∂t at ({}, {}) = {}, expected 0",
797                    si,
798                    ti,
799                    result.dsdt[(0, idx)]
800                );
801            }
802        }
803    }
804
805    #[test]
806    fn test_deriv_2d_quadratic_surface() {
807        // f(s, t) = s*t
808        // ∂f/∂s = t, ∂f/∂t = s, ∂²f/∂s∂t = 1
809        let m1 = 21;
810        let m2 = 21;
811        let argvals_s: Vec<f64> = (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect();
812        let argvals_t: Vec<f64> = (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect();
813
814        let n = 1;
815        let ncol = m1 * m2;
816        let mut data = vec![0.0; n * ncol];
817
818        for si in 0..m1 {
819            for ti in 0..m2 {
820                let s = argvals_s[si];
821                let t = argvals_t[ti];
822                let idx = si + ti * m1;
823                data[idx] = s * t;
824            }
825        }
826
827        let mat = FdMatrix::from_column_major(data, n, ncol).unwrap();
828        let result = deriv_2d(&mat, &argvals_s, &argvals_t, m1, m2).unwrap();
829
830        // Check interior points for ∂f/∂s ≈ t
831        for si in 3..(m1 - 3) {
832            for ti in 3..(m2 - 3) {
833                let idx = si + ti * m1;
834                let expected = argvals_t[ti];
835                assert!(
836                    (result.ds[(0, idx)] - expected).abs() < 0.1,
837                    "∂f/∂s at ({}, {}) = {}, expected {}",
838                    si,
839                    ti,
840                    result.ds[(0, idx)],
841                    expected
842                );
843            }
844        }
845
846        // Check interior points for ∂f/∂t ≈ s
847        for si in 3..(m1 - 3) {
848            for ti in 3..(m2 - 3) {
849                let idx = si + ti * m1;
850                let expected = argvals_s[si];
851                assert!(
852                    (result.dt[(0, idx)] - expected).abs() < 0.1,
853                    "∂f/∂t at ({}, {}) = {}, expected {}",
854                    si,
855                    ti,
856                    result.dt[(0, idx)],
857                    expected
858                );
859            }
860        }
861
862        // Check interior points for mixed partial ≈ 1
863        for si in 3..(m1 - 3) {
864            for ti in 3..(m2 - 3) {
865                let idx = si + ti * m1;
866                assert!(
867                    (result.dsdt[(0, idx)] - 1.0).abs() < 0.3,
868                    "∂²f/∂s∂t at ({}, {}) = {}, expected 1",
869                    si,
870                    ti,
871                    result.dsdt[(0, idx)]
872                );
873            }
874        }
875    }
876
877    #[test]
878    fn test_deriv_2d_invalid_input() {
879        // Empty data
880        let result = deriv_2d(&FdMatrix::zeros(0, 0), &[], &[], 0, 0);
881        assert!(result.is_none());
882
883        // Mismatched dimensions
884        let mat = FdMatrix::from_column_major(vec![1.0; 4], 1, 4).unwrap();
885        let argvals = vec![0.0, 1.0];
886        let result = deriv_2d(&mat, &argvals, &[0.0, 0.5, 1.0], 2, 2);
887        assert!(result.is_none());
888    }
889
890    // ============== 2D geometric median tests ==============
891
892    #[test]
893    fn test_geometric_median_2d_basic() {
894        // Three identical surfaces -> median = that surface
895        let m1 = 5;
896        let m2 = 5;
897        let m = m1 * m2;
898        let n = 3;
899        let argvals_s: Vec<f64> = (0..m1).map(|i| i as f64 / (m1 - 1) as f64).collect();
900        let argvals_t: Vec<f64> = (0..m2).map(|i| i as f64 / (m2 - 1) as f64).collect();
901
902        let mut data = vec![0.0; n * m];
903
904        // Create identical surfaces: f(s, t) = s + t
905        for i in 0..n {
906            for si in 0..m1 {
907                for ti in 0..m2 {
908                    let idx = si + ti * m1;
909                    let s = argvals_s[si];
910                    let t = argvals_t[ti];
911                    data[i + idx * n] = s + t;
912                }
913            }
914        }
915
916        let mat = FdMatrix::from_column_major(data, n, m).unwrap();
917        let median = geometric_median_2d(&mat, &argvals_s, &argvals_t, 100, 1e-6);
918        assert_eq!(median.len(), m);
919
920        // Check that median equals the surface
921        for si in 0..m1 {
922            for ti in 0..m2 {
923                let idx = si + ti * m1;
924                let expected = argvals_s[si] + argvals_t[ti];
925                assert!(
926                    (median[idx] - expected).abs() < 0.01,
927                    "Median at ({}, {}) = {}, expected {}",
928                    si,
929                    ti,
930                    median[idx],
931                    expected
932                );
933            }
934        }
935    }
936
937    #[test]
938    fn test_nan_mean_no_panic() {
939        let mut data_vec = vec![1.0; 6];
940        data_vec[2] = f64::NAN;
941        let data = FdMatrix::from_column_major(data_vec, 2, 3).unwrap();
942        let m = mean_1d(&data);
943        assert_eq!(m.len(), 3);
944    }
945
946    #[test]
947    fn test_nan_center_no_panic() {
948        let mut data_vec = vec![1.0; 6];
949        data_vec[2] = f64::NAN;
950        let data = FdMatrix::from_column_major(data_vec, 2, 3).unwrap();
951        let c = center_1d(&data);
952        assert_eq!(c.nrows(), 2);
953    }
954
955    #[test]
956    fn test_nan_norm_no_panic() {
957        let mut data_vec = vec![1.0; 6];
958        data_vec[2] = f64::NAN;
959        let data = FdMatrix::from_column_major(data_vec, 2, 3).unwrap();
960        let argvals = vec![0.0, 0.5, 1.0];
961        let norms = norm_lp_1d(&data, &argvals, 2.0);
962        assert_eq!(norms.len(), 2);
963    }
964
965    #[test]
966    fn test_n1_norm() {
967        let data = FdMatrix::from_column_major(vec![0.0, 1.0, 0.0], 1, 3).unwrap();
968        let argvals = vec![0.0, 0.5, 1.0];
969        let norms = norm_lp_1d(&data, &argvals, 2.0);
970        assert_eq!(norms.len(), 1);
971        assert!(norms[0] > 0.0);
972    }
973
974    #[test]
975    fn test_n2_center() {
976        let data = FdMatrix::from_column_major(vec![1.0, 3.0, 2.0, 4.0], 2, 2).unwrap();
977        let centered = center_1d(&data);
978        // Mean at each point: [2.0, 3.0]
979        // centered[0,0] = 1.0 - 2.0 = -1.0
980        assert!((centered[(0, 0)] - (-1.0)).abs() < 1e-12);
981        assert!((centered[(1, 0)] - 1.0).abs() < 1e-12);
982    }
983
984    #[test]
985    fn test_deriv_nderiv0() {
986        // nderiv=0 returns the original data (0th derivative = identity)
987        let data = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
988        let argvals = vec![0.0, 0.5, 1.0];
989        let result = deriv_1d(&data, &argvals, 0);
990        assert_eq!(result.shape(), data.shape());
991        for i in 0..2 {
992            for j in 0..3 {
993                assert!((result[(i, j)] - data[(i, j)]).abs() < 1e-12);
994            }
995        }
996    }
997}