// fdars_core/fdata.rs
//! Functional data operations: mean, center, derivatives, norms, and geometric median.

3use crate::helpers::{simpsons_weights, simpsons_weights_2d, NUMERICAL_EPS};
4use crate::iter_maybe_parallel;
5use crate::matrix::FdMatrix;
6#[cfg(feature = "parallel")]
7use rayon::iter::ParallelIterator;
8
/// Approximate the first derivative of a sampled 1D function at grid index `idx`.
///
/// `values(k)` yields the function value at grid index `k`; `step_sizes[idx]` is the
/// denominator used at `idx` (the full width of the stencil). Index 0 uses a forward
/// difference, index `n_points - 1` a backward difference, and every other index the
/// centred difference `values(idx + 1) - values(idx - 1)`.
fn finite_diff_1d(
    values: impl Fn(usize) -> f64,
    idx: usize,
    n_points: usize,
    step_sizes: &[f64],
) -> f64 {
    // Choose the two stencil endpoints, then apply a single divided difference.
    let (lo, hi) = match idx {
        0 => (0, 1),
        i if i == n_points - 1 => (i - 1, i),
        i => (i - 1, i + 1),
    };
    (values(hi) - values(lo)) / step_sizes[idx]
}

28/// Compute 2D partial derivatives at a single grid point.
29///
30/// Returns (∂f/∂s, ∂f/∂t, ∂²f/∂s∂t) using finite differences.
31fn compute_2d_derivatives(
32    get_val: impl Fn(usize, usize) -> f64,
33    si: usize,
34    ti: usize,
35    m1: usize,
36    m2: usize,
37    hs: &[f64],
38    ht: &[f64],
39) -> (f64, f64, f64) {
40    // ∂f/∂s
41    let ds = finite_diff_1d(|s| get_val(s, ti), si, m1, hs);
42
43    // ∂f/∂t
44    let dt = finite_diff_1d(|t| get_val(si, t), ti, m2, ht);
45
46    // ∂²f/∂s∂t (mixed partial)
47    let denom = hs[si] * ht[ti];
48
49    // Get the appropriate indices for s and t differences
50    let (s_lo, s_hi) = if si == 0 {
51        (0, 1)
52    } else if si == m1 - 1 {
53        (m1 - 2, m1 - 1)
54    } else {
55        (si - 1, si + 1)
56    };
57
58    let (t_lo, t_hi) = if ti == 0 {
59        (0, 1)
60    } else if ti == m2 - 1 {
61        (m2 - 2, m2 - 1)
62    } else {
63        (ti - 1, ti + 1)
64    };
65
66    let dsdt = (get_val(s_hi, t_hi) - get_val(s_lo, t_hi) - get_val(s_hi, t_lo)
67        + get_val(s_lo, t_lo))
68        / denom;
69
70    (ds, dt, dsdt)
71}
72
73/// Perform Weiszfeld iteration to compute geometric median.
74///
75/// This is the core algorithm shared by 1D and 2D geometric median computations.
76fn weiszfeld_iteration(data: &FdMatrix, weights: &[f64], max_iter: usize, tol: f64) -> Vec<f64> {
77    let (n, m) = data.shape();
78
79    // Initialize with the mean
80    let mut median: Vec<f64> = (0..m)
81        .map(|j| {
82            let col = data.column(j);
83            col.iter().sum::<f64>() / n as f64
84        })
85        .collect();
86
87    for _ in 0..max_iter {
88        // Compute distances from current median to all curves
89        let distances: Vec<f64> = (0..n)
90            .map(|i| {
91                let mut dist_sq = 0.0;
92                for j in 0..m {
93                    let diff = data[(i, j)] - median[j];
94                    dist_sq += diff * diff * weights[j];
95                }
96                dist_sq.sqrt()
97            })
98            .collect();
99
100        // Compute weights (1/distance), handling zero distances
101        let inv_distances: Vec<f64> = distances
102            .iter()
103            .map(|d| {
104                if *d > NUMERICAL_EPS {
105                    1.0 / d
106                } else {
107                    1.0 / NUMERICAL_EPS
108                }
109            })
110            .collect();
111
112        let sum_inv_dist: f64 = inv_distances.iter().sum();
113
114        // Update median using Weiszfeld iteration
115        let new_median: Vec<f64> = (0..m)
116            .map(|j| {
117                let mut weighted_sum = 0.0;
118                for i in 0..n {
119                    weighted_sum += data[(i, j)] * inv_distances[i];
120                }
121                weighted_sum / sum_inv_dist
122            })
123            .collect();
124
125        // Check convergence
126        let diff: f64 = median
127            .iter()
128            .zip(new_median.iter())
129            .map(|(a, b)| (a - b).abs())
130            .sum::<f64>()
131            / m as f64;
132
133        median = new_median;
134
135        if diff < tol {
136            break;
137        }
138    }
139
140    median
141}
142
143/// Compute the mean function across all samples (1D).
144///
145/// # Arguments
146/// * `data` - Functional data matrix (n x m)
147///
148/// # Returns
149/// Mean function values at each evaluation point
150pub fn mean_1d(data: &FdMatrix) -> Vec<f64> {
151    let (n, m) = data.shape();
152    if n == 0 || m == 0 {
153        return Vec::new();
154    }
155
156    iter_maybe_parallel!(0..m)
157        .map(|j| {
158            let col = data.column(j);
159            col.iter().sum::<f64>() / n as f64
160        })
161        .collect()
162}
163
164/// Compute the mean function for 2D surfaces.
165///
166/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
167pub fn mean_2d(data: &FdMatrix) -> Vec<f64> {
168    // Same computation as 1D - just compute pointwise mean
169    mean_1d(data)
170}
171
172/// Center functional data by subtracting the mean function.
173///
174/// # Arguments
175/// * `data` - Functional data matrix (n x m)
176///
177/// # Returns
178/// Centered data matrix
179pub fn center_1d(data: &FdMatrix) -> FdMatrix {
180    let (n, m) = data.shape();
181    if n == 0 || m == 0 {
182        return FdMatrix::zeros(0, 0);
183    }
184
185    // First compute the mean for each column (parallelized)
186    let means: Vec<f64> = iter_maybe_parallel!(0..m)
187        .map(|j| {
188            let col = data.column(j);
189            col.iter().sum::<f64>() / n as f64
190        })
191        .collect();
192
193    // Create centered data
194    let mut centered = FdMatrix::zeros(n, m);
195    for j in 0..m {
196        let col = centered.column_mut(j);
197        let src = data.column(j);
198        for i in 0..n {
199            col[i] = src[i] - means[j];
200        }
201    }
202
203    centered
204}
205
206/// Compute Lp norm for each sample.
207///
208/// # Arguments
209/// * `data` - Functional data matrix (n x m)
210/// * `argvals` - Evaluation points for integration
211/// * `p` - Order of the norm (e.g., 2.0 for L2)
212///
213/// # Returns
214/// Vector of Lp norms for each sample
215pub fn norm_lp_1d(data: &FdMatrix, argvals: &[f64], p: f64) -> Vec<f64> {
216    let (n, m) = data.shape();
217    if n == 0 || m == 0 || argvals.len() != m {
218        return Vec::new();
219    }
220
221    let weights = simpsons_weights(argvals);
222
223    iter_maybe_parallel!(0..n)
224        .map(|i| {
225            let mut integral = 0.0;
226            for j in 0..m {
227                let val = data[(i, j)].abs().powf(p);
228                integral += val * weights[j];
229            }
230            integral.powf(1.0 / p)
231        })
232        .collect()
233}
234
235/// Compute numerical derivative of functional data (parallelized over rows).
236///
237/// # Arguments
238/// * `data` - Functional data matrix (n x m)
239/// * `argvals` - Evaluation points
240/// * `nderiv` - Order of derivative
241///
242/// # Returns
243/// Derivative data matrix
244pub fn deriv_1d(data: &FdMatrix, argvals: &[f64], nderiv: usize) -> FdMatrix {
245    let (n, m) = data.shape();
246    if n == 0 || m == 0 || argvals.len() != m || nderiv < 1 {
247        return FdMatrix::zeros(n, m);
248    }
249
250    let mut current = data.clone();
251
252    // Pre-compute step sizes for central differences
253    let h0 = argvals[1] - argvals[0];
254    let hn = argvals[m - 1] - argvals[m - 2];
255    let h_central: Vec<f64> = (1..(m - 1))
256        .map(|j| argvals[j + 1] - argvals[j - 1])
257        .collect();
258
259    for _ in 0..nderiv {
260        // Compute derivative for each row in parallel
261        let deriv: Vec<f64> = iter_maybe_parallel!(0..n)
262            .flat_map(|i| {
263                let mut row_deriv = vec![0.0; m];
264
265                // Forward difference at left boundary
266                row_deriv[0] = (current[(i, 1)] - current[(i, 0)]) / h0;
267
268                // Central differences for interior points
269                for j in 1..(m - 1) {
270                    row_deriv[j] = (current[(i, j + 1)] - current[(i, j - 1)]) / h_central[j - 1];
271                }
272
273                // Backward difference at right boundary
274                row_deriv[m - 1] = (current[(i, m - 1)] - current[(i, m - 2)]) / hn;
275
276                row_deriv
277            })
278            .collect();
279
280        // Reorder from row-major to column-major order
281        let mut next = FdMatrix::zeros(n, m);
282        for i in 0..n {
283            for j in 0..m {
284                next[(i, j)] = deriv[i * m + j];
285            }
286        }
287        current = next;
288    }
289
290    current
291}
292
293/// Result of 2D partial derivatives.
294pub struct Deriv2DResult {
295    /// Partial derivative with respect to s (∂f/∂s)
296    pub ds: FdMatrix,
297    /// Partial derivative with respect to t (∂f/∂t)
298    pub dt: FdMatrix,
299    /// Mixed partial derivative (∂²f/∂s∂t)
300    pub dsdt: FdMatrix,
301}
302
/// Finite-difference denominators for every point of a 1D grid.
///
/// The entries are:
/// - entry 0: `argvals[1] - argvals[0]` (one-sided stencil);
/// - last entry: `argvals[m - 1] - argvals[m - 2]` (one-sided stencil);
/// - interior entry `j`: the full centred width `argvals[j + 1] - argvals[j - 1]`.
///
/// An empty grid yields an empty vector. A single-point grid panics, because the
/// left-edge width needs `argvals[1]`.
fn compute_step_sizes(argvals: &[f64]) -> Vec<f64> {
    let m = argvals.len();
    if m == 0 {
        return Vec::new();
    }

    let mut steps = Vec::with_capacity(m);
    // Forward-difference width at the left edge.
    steps.push(argvals[1] - argvals[0]);
    // Centred widths for every interior point.
    steps.extend(argvals.windows(3).map(|w| w[2] - w[0]));
    // Backward-difference width at the right edge (reaching here implies m >= 2).
    steps.push(argvals[m - 1] - argvals[m - 2]);
    steps
}

321/// Collect per-curve row vectors into a column-major FdMatrix.
322fn reassemble_colmajor(rows: &[Vec<f64>], n: usize, ncol: usize) -> FdMatrix {
323    let mut mat = FdMatrix::zeros(n, ncol);
324    for i in 0..n {
325        for j in 0..ncol {
326            mat[(i, j)] = rows[i][j];
327        }
328    }
329    mat
330}
331
332/// Compute 2D partial derivatives for surface data.
333///
334/// For a surface f(s,t), computes:
335/// - ds: partial derivative with respect to s (∂f/∂s)
336/// - dt: partial derivative with respect to t (∂f/∂t)
337/// - dsdt: mixed partial derivative (∂²f/∂s∂t)
338///
339/// # Arguments
340/// * `data` - Functional data matrix, n surfaces, each stored as m1*m2 values
341/// * `argvals_s` - Grid points in s direction (length m1)
342/// * `argvals_t` - Grid points in t direction (length m2)
343/// * `m1` - Grid size in s direction
344/// * `m2` - Grid size in t direction
345pub fn deriv_2d(
346    data: &FdMatrix,
347    argvals_s: &[f64],
348    argvals_t: &[f64],
349    m1: usize,
350    m2: usize,
351) -> Option<Deriv2DResult> {
352    let n = data.nrows();
353    let ncol = m1 * m2;
354    if n == 0 || ncol == 0 || argvals_s.len() != m1 || argvals_t.len() != m2 {
355        return None;
356    }
357
358    let hs = compute_step_sizes(argvals_s);
359    let ht = compute_step_sizes(argvals_t);
360
361    // Compute all derivatives in parallel over surfaces
362    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
363        .map(|i| {
364            let mut ds = vec![0.0; ncol];
365            let mut dt = vec![0.0; ncol];
366            let mut dsdt = vec![0.0; ncol];
367
368            let get_val = |si: usize, ti: usize| -> f64 { data[(i, si + ti * m1)] };
369
370            for ti in 0..m2 {
371                for si in 0..m1 {
372                    let idx = si + ti * m1;
373                    let (ds_val, dt_val, dsdt_val) =
374                        compute_2d_derivatives(get_val, si, ti, m1, m2, &hs, &ht);
375                    ds[idx] = ds_val;
376                    dt[idx] = dt_val;
377                    dsdt[idx] = dsdt_val;
378                }
379            }
380
381            (ds, dt, dsdt)
382        })
383        .collect();
384
385    let ds_vecs: Vec<Vec<f64>> = results.iter().map(|r| r.0.clone()).collect();
386    let dt_vecs: Vec<Vec<f64>> = results.iter().map(|r| r.1.clone()).collect();
387    let dsdt_vecs: Vec<Vec<f64>> = results.iter().map(|r| r.2.clone()).collect();
388
389    Some(Deriv2DResult {
390        ds: reassemble_colmajor(&ds_vecs, n, ncol),
391        dt: reassemble_colmajor(&dt_vecs, n, ncol),
392        dsdt: reassemble_colmajor(&dsdt_vecs, n, ncol),
393    })
394}
395
396/// Compute the geometric median (L1 median) of functional data using Weiszfeld's algorithm.
397///
398/// The geometric median minimizes sum of L2 distances to all curves.
399///
400/// # Arguments
401/// * `data` - Functional data matrix (n x m)
402/// * `argvals` - Evaluation points for integration
403/// * `max_iter` - Maximum iterations
404/// * `tol` - Convergence tolerance
405pub fn geometric_median_1d(
406    data: &FdMatrix,
407    argvals: &[f64],
408    max_iter: usize,
409    tol: f64,
410) -> Vec<f64> {
411    let (n, m) = data.shape();
412    if n == 0 || m == 0 || argvals.len() != m {
413        return Vec::new();
414    }
415
416    let weights = simpsons_weights(argvals);
417    weiszfeld_iteration(data, &weights, max_iter, tol)
418}
419
420/// Compute the geometric median for 2D functional data.
421///
422/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
423///
424/// # Arguments
425/// * `data` - Functional data matrix (n x m) where m = m1*m2
426/// * `argvals_s` - Grid points in s direction (length m1)
427/// * `argvals_t` - Grid points in t direction (length m2)
428/// * `max_iter` - Maximum iterations
429/// * `tol` - Convergence tolerance
430pub fn geometric_median_2d(
431    data: &FdMatrix,
432    argvals_s: &[f64],
433    argvals_t: &[f64],
434    max_iter: usize,
435    tol: f64,
436) -> Vec<f64> {
437    let (n, m) = data.shape();
438    let expected_cols = argvals_s.len() * argvals_t.len();
439    if n == 0 || m == 0 || m != expected_cols {
440        return Vec::new();
441    }
442
443    let weights = simpsons_weights_2d(argvals_s, argvals_t);
444    weiszfeld_iteration(data, &weights, max_iter, tol)
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // ---------------- fixtures and assertion helpers ----------------

    /// `n` equally spaced points from 0 to 1, both endpoints included.
    fn uniform_grid(n: usize) -> Vec<f64> {
        let denom = (n - 1) as f64;
        (0..n).map(|i| i as f64 / denom).collect()
    }

    /// The curves [1, 2, 3] and [3, 4, 5] as a 2 x 3 matrix.
    fn two_curves() -> FdMatrix {
        FdMatrix::from_column_major(vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0], 2, 3).unwrap()
    }

    /// `n x m` matrix whose entry `(i, j)` is `f(i, j)`, supplied in column-major order.
    fn matrix_from_fn(n: usize, m: usize, f: impl Fn(usize, usize) -> f64) -> FdMatrix {
        let values: Vec<f64> = (0..m)
            .flat_map(|j| (0..n).map(move |i| (i, j)))
            .map(|(i, j)| f(i, j))
            .collect();
        FdMatrix::from_column_major(values, n, m).unwrap()
    }

    /// `n` copies of the surface `f(s, t)` sampled on `s x t`; grid point `(si, ti)` lands in
    /// column `si + ti * s.len()`.
    fn surface_matrix(n: usize, s: &[f64], t: &[f64], f: impl Fn(f64, f64) -> f64) -> FdMatrix {
        let m1 = s.len();
        matrix_from_fn(n, m1 * t.len(), |_, idx| f(s[idx % m1], t[idx / m1]))
    }

    /// Assert `|actual - expected| < tol`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "Expected {}, got {}",
            expected,
            actual
        );
    }

    /// Compare surface 0 of `field` with `expected(si, ti)` at every grid point that lies at
    /// least `margin` points away from all four edges.
    fn assert_interior(
        field: &FdMatrix,
        m1: usize,
        m2: usize,
        margin: usize,
        tol: f64,
        label: &str,
        expected: impl Fn(usize, usize) -> f64,
    ) {
        for si in margin..(m1 - margin) {
            for ti in margin..(m2 - margin) {
                let got = field[(0, si + ti * m1)];
                let want = expected(si, ti);
                assert!(
                    (got - want).abs() < tol,
                    "{} at ({}, {}) = {}, expected {}",
                    label,
                    si,
                    ti,
                    got,
                    want
                );
            }
        }
    }

    // ---------------- mean ----------------

    #[test]
    fn test_mean_1d() {
        // Averaging [1, 2, 3] and [3, 4, 5] pointwise gives [2, 3, 4].
        assert_eq!(mean_1d(&two_curves()), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mean_1d_single_sample() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
        assert_eq!(mean_1d(&mat), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_1d_invalid() {
        assert!(mean_1d(&FdMatrix::zeros(0, 0)).is_empty());
    }

    #[test]
    fn test_mean_2d_delegates() {
        let mat = FdMatrix::from_column_major(vec![1.0, 3.0, 2.0, 4.0], 2, 2).unwrap();
        assert_eq!(mean_2d(&mat), mean_1d(&mat));
    }

    // ---------------- centering ----------------

    #[test]
    fn test_center_1d() {
        // Column means are [2, 3, 4], so every entry shifts to -1 or +1.
        let centered = center_1d(&two_curves());
        assert_eq!(centered.as_slice(), &[-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_center_1d_mean_zero() {
        let centered = center_1d(&two_curves());
        assert!(
            mean_1d(&centered).iter().all(|mu| mu.abs() < 1e-10),
            "Centered data should have zero mean"
        );
    }

    #[test]
    fn test_center_1d_invalid() {
        assert!(center_1d(&FdMatrix::zeros(0, 0)).is_empty());
    }

    // ---------------- norms ----------------

    #[test]
    fn test_norm_lp_1d_constant() {
        // The constant function 2 on [0, 1] has L2 norm 2.
        let argvals = uniform_grid(21);
        let mat = matrix_from_fn(1, 21, |_, _| 2.0);
        let norms = norm_lp_1d(&mat, &argvals, 2.0);
        assert_eq!(norms.len(), 1);
        assert!(
            (norms[0] - 2.0).abs() < 0.1,
            "L2 norm of constant 2 should be 2"
        );
    }

    #[test]
    fn test_norm_lp_1d_sine() {
        // ||sin(pi x)||_2 on [0, 1] equals sqrt(1/2).
        let argvals = uniform_grid(101);
        let mat = matrix_from_fn(1, argvals.len(), |_, j| (PI * argvals[j]).sin());
        let norms = norm_lp_1d(&mat, &argvals, 2.0);
        assert_close(norms[0], 0.5_f64.sqrt(), 0.05);
    }

    #[test]
    fn test_norm_lp_1d_invalid() {
        assert!(norm_lp_1d(&FdMatrix::zeros(0, 0), &[], 2.0).is_empty());
    }

    // ---------------- 1D derivatives ----------------

    #[test]
    fn test_deriv_1d_linear() {
        // d/dx x = 1, checked away from the one-sided boundary stencils.
        let argvals = uniform_grid(21);
        let mat = matrix_from_fn(1, argvals.len(), |_, j| argvals[j]);
        let deriv = deriv_1d(&mat, &argvals, 1);
        assert!(
            (2..19).all(|j| (deriv[(0, j)] - 1.0).abs() < 0.1),
            "Derivative of x should be 1"
        );
    }

    #[test]
    fn test_deriv_1d_quadratic() {
        // d/dx x^2 = 2x on the interior.
        let argvals = uniform_grid(51);
        let mat = matrix_from_fn(1, argvals.len(), |_, j| argvals[j] * argvals[j]);
        let deriv = deriv_1d(&mat, &argvals, 1);
        assert!(
            (5..45).all(|j| (deriv[(0, j)] - 2.0 * argvals[j]).abs() < 0.1),
            "Derivative of x^2 should be 2x"
        );
    }

    #[test]
    fn test_deriv_1d_invalid() {
        let result = deriv_1d(&FdMatrix::zeros(0, 0), &[], 1);
        assert!(result.is_empty() || result.as_slice().iter().all(|&x| x == 0.0));
    }

    // ---------------- 1D geometric median ----------------

    #[test]
    fn test_geometric_median_identical_curves() {
        // Five copies of sin(2 pi x): the median must reproduce that curve.
        let argvals = uniform_grid(21);
        let curve = |x: f64| (2.0 * PI * x).sin();
        let mat = matrix_from_fn(5, argvals.len(), |_, j| curve(argvals[j]));
        let median = geometric_median_1d(&mat, &argvals, 100, 1e-6);
        assert_eq!(median.len(), argvals.len());
        for (&x, &value) in argvals.iter().zip(&median) {
            assert!(
                (value - curve(x)).abs() < 0.01,
                "Median should equal all curves"
            );
        }
    }

    #[test]
    fn test_geometric_median_converges() {
        let argvals = uniform_grid(21);
        let n = 10;
        let mat = matrix_from_fn(n, argvals.len(), |i, j| (i as f64 / n as f64) * argvals[j]);
        let median = geometric_median_1d(&mat, &argvals, 100, 1e-6);
        assert_eq!(median.len(), 21);
        assert!(median.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_geometric_median_invalid() {
        assert!(geometric_median_1d(&FdMatrix::zeros(0, 0), &[], 100, 1e-6).is_empty());
    }

    // ---------------- 2D derivatives ----------------

    #[test]
    fn test_deriv_2d_linear_surface() {
        // f(s, t) = 2s + 3t  =>  ∂f/∂s = 2, ∂f/∂t = 3, ∂²f/∂s∂t = 0.
        let (m1, m2) = (11, 11);
        let s = uniform_grid(m1);
        let t = uniform_grid(m2);
        let mat = surface_matrix(1, &s, &t, |x, y| 2.0 * x + 3.0 * y);
        let result = deriv_2d(&mat, &s, &t, m1, m2).unwrap();

        assert_interior(&result.ds, m1, m2, 2, 0.2, "∂f/∂s", |_, _| 2.0);
        assert_interior(&result.dt, m1, m2, 2, 0.2, "∂f/∂t", |_, _| 3.0);
        assert_interior(&result.dsdt, m1, m2, 2, 0.5, "∂²f/∂s∂t", |_, _| 0.0);
    }

    #[test]
    fn test_deriv_2d_quadratic_surface() {
        // f(s, t) = s t  =>  ∂f/∂s = t, ∂f/∂t = s, ∂²f/∂s∂t = 1.
        let (m1, m2) = (21, 21);
        let s = uniform_grid(m1);
        let t = uniform_grid(m2);
        let mat = surface_matrix(1, &s, &t, |x, y| x * y);
        let result = deriv_2d(&mat, &s, &t, m1, m2).unwrap();

        assert_interior(&result.ds, m1, m2, 3, 0.1, "∂f/∂s", |_, ti| t[ti]);
        assert_interior(&result.dt, m1, m2, 3, 0.1, "∂f/∂t", |si, _| s[si]);
        assert_interior(&result.dsdt, m1, m2, 3, 0.3, "∂²f/∂s∂t", |_, _| 1.0);
    }

    #[test]
    fn test_deriv_2d_invalid_input() {
        // Empty data.
        assert!(deriv_2d(&FdMatrix::zeros(0, 0), &[], &[], 0, 0).is_none());

        // argvals_t has 3 points but m2 = 2.
        let mat = FdMatrix::from_column_major(vec![1.0; 4], 1, 4).unwrap();
        assert!(deriv_2d(&mat, &[0.0, 1.0], &[0.0, 0.5, 1.0], 2, 2).is_none());
    }

    // ---------------- 2D geometric median ----------------

    #[test]
    fn test_geometric_median_2d_basic() {
        // Three identical surfaces f(s, t) = s + t: the median must be that surface.
        let (m1, m2, n) = (5, 5, 3);
        let s = uniform_grid(m1);
        let t = uniform_grid(m2);
        let mat = surface_matrix(n, &s, &t, |x, y| x + y);

        let median = geometric_median_2d(&mat, &s, &t, 100, 1e-6);
        assert_eq!(median.len(), m1 * m2);

        for (idx, &value) in median.iter().enumerate() {
            let (si, ti) = (idx % m1, idx / m1);
            let expected = s[si] + t[ti];
            assert!(
                (value - expected).abs() < 0.01,
                "Median at ({}, {}) = {}, expected {}",
                si,
                ti,
                value,
                expected
            );
        }
    }
}