fdars_core/
fdata.rs

1//! Functional data operations: mean, center, derivatives, norms, and geometric median.
2
3use crate::helpers::{simpsons_weights, simpsons_weights_2d, NUMERICAL_EPS};
4use rayon::prelude::*;
5
/// Approximate the first derivative of a sampled 1D function at one grid index.
///
/// Uses a forward difference at the left boundary, a backward difference at the
/// right boundary, and a central difference at interior points.
///
/// # Arguments
/// * `values` - Accessor returning the function value at a grid index
/// * `idx` - Grid index at which to differentiate
/// * `n_points` - Number of grid points
/// * `step_sizes` - Denominator used at each index; the caller supplies the width of
///   the stencil actually used there (one spacing at the boundaries, the two-sided
///   spacing in the interior)
///
/// # Returns
/// The finite-difference quotient at `idx`, or `0.0` when the grid has fewer than
/// two points (no neighbor exists to difference against).
fn finite_diff_1d(
    values: impl Fn(usize) -> f64,
    idx: usize,
    n_points: usize,
    step_sizes: &[f64],
) -> f64 {
    // A degenerate grid has no neighbor; without this guard `values(1)` would be
    // evaluated out of range and `n_points - 2` would underflow.
    if n_points < 2 {
        return 0.0;
    }

    // Pick the pair of samples bracketing `idx` (one-sided at the boundaries).
    let (lo, hi) = if idx == 0 {
        (0, 1)
    } else if idx == n_points - 1 {
        (n_points - 2, n_points - 1)
    } else {
        (idx - 1, idx + 1)
    };

    (values(hi) - values(lo)) / step_sizes[idx]
}
24
25/// Compute 2D partial derivatives at a single grid point.
26///
27/// Returns (∂f/∂s, ∂f/∂t, ∂²f/∂s∂t) using finite differences.
28fn compute_2d_derivatives(
29    get_val: impl Fn(usize, usize) -> f64,
30    si: usize,
31    ti: usize,
32    m1: usize,
33    m2: usize,
34    hs: &[f64],
35    ht: &[f64],
36) -> (f64, f64, f64) {
37    // ∂f/∂s
38    let ds = finite_diff_1d(|s| get_val(s, ti), si, m1, hs);
39
40    // ∂f/∂t
41    let dt = finite_diff_1d(|t| get_val(si, t), ti, m2, ht);
42
43    // ∂²f/∂s∂t (mixed partial)
44    let denom = hs[si] * ht[ti];
45
46    // Get the appropriate indices for s and t differences
47    let (s_lo, s_hi) = if si == 0 {
48        (0, 1)
49    } else if si == m1 - 1 {
50        (m1 - 2, m1 - 1)
51    } else {
52        (si - 1, si + 1)
53    };
54
55    let (t_lo, t_hi) = if ti == 0 {
56        (0, 1)
57    } else if ti == m2 - 1 {
58        (m2 - 2, m2 - 1)
59    } else {
60        (ti - 1, ti + 1)
61    };
62
63    let dsdt = (get_val(s_hi, t_hi) - get_val(s_lo, t_hi) - get_val(s_hi, t_lo)
64        + get_val(s_lo, t_lo))
65        / denom;
66
67    (ds, dt, dsdt)
68}
69
70/// Perform Weiszfeld iteration to compute geometric median.
71///
72/// This is the core algorithm shared by 1D and 2D geometric median computations.
73fn weiszfeld_iteration(
74    data: &[f64],
75    n: usize,
76    m: usize,
77    weights: &[f64],
78    max_iter: usize,
79    tol: f64,
80) -> Vec<f64> {
81    // Initialize with the mean
82    let mut median: Vec<f64> = (0..m)
83        .map(|j| {
84            let mut sum = 0.0;
85            for i in 0..n {
86                sum += data[i + j * n];
87            }
88            sum / n as f64
89        })
90        .collect();
91
92    for _ in 0..max_iter {
93        // Compute distances from current median to all curves
94        let distances: Vec<f64> = (0..n)
95            .map(|i| {
96                let mut dist_sq = 0.0;
97                for j in 0..m {
98                    let diff = data[i + j * n] - median[j];
99                    dist_sq += diff * diff * weights[j];
100                }
101                dist_sq.sqrt()
102            })
103            .collect();
104
105        // Compute weights (1/distance), handling zero distances
106        let inv_distances: Vec<f64> = distances
107            .iter()
108            .map(|d| {
109                if *d > NUMERICAL_EPS {
110                    1.0 / d
111                } else {
112                    1.0 / NUMERICAL_EPS
113                }
114            })
115            .collect();
116
117        let sum_inv_dist: f64 = inv_distances.iter().sum();
118
119        // Update median using Weiszfeld iteration
120        let new_median: Vec<f64> = (0..m)
121            .map(|j| {
122                let mut weighted_sum = 0.0;
123                for i in 0..n {
124                    weighted_sum += data[i + j * n] * inv_distances[i];
125                }
126                weighted_sum / sum_inv_dist
127            })
128            .collect();
129
130        // Check convergence
131        let diff: f64 = median
132            .iter()
133            .zip(new_median.iter())
134            .map(|(a, b)| (a - b).abs())
135            .sum::<f64>()
136            / m as f64;
137
138        median = new_median;
139
140        if diff < tol {
141            break;
142        }
143    }
144
145    median
146}
147
148/// Compute the mean function across all samples (1D).
149///
150/// # Arguments
151/// * `data` - Column-major matrix (n x m)
152/// * `n` - Number of samples
153/// * `m` - Number of evaluation points
154///
155/// # Returns
156/// Mean function values at each evaluation point
157pub fn mean_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
158    if n == 0 || m == 0 || data.len() != n * m {
159        return Vec::new();
160    }
161
162    (0..m)
163        .into_par_iter()
164        .map(|j| {
165            let mut sum = 0.0;
166            for i in 0..n {
167                sum += data[i + j * n];
168            }
169            sum / n as f64
170        })
171        .collect()
172}
173
/// Compute the mean function for 2D surfaces.
///
/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
/// The mean is taken independently at every grid point, so the grid shape is
/// irrelevant here and the work is delegated to [`mean_1d`].
///
/// # Arguments
/// * `data` - Column-major matrix (n x m)
/// * `n` - Number of surfaces
/// * `m` - Number of grid points per surface (presumably `m1 * m2`; the individual
///   grid dimensions are not needed)
///
/// # Returns
/// Mean surface values at each grid point (length `m`), or an empty vector when
/// `n` or `m` is zero or `data.len() != n * m`.
pub fn mean_2d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
    // Same computation as 1D - just compute pointwise mean
    mean_1d(data, n, m)
}
181
182/// Center functional data by subtracting the mean function.
183///
184/// # Arguments
185/// * `data` - Column-major matrix (n x m)
186/// * `n` - Number of samples
187/// * `m` - Number of evaluation points
188///
189/// # Returns
190/// Centered data matrix (column-major)
191pub fn center_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
192    if n == 0 || m == 0 || data.len() != n * m {
193        return Vec::new();
194    }
195
196    // First compute the mean for each column (parallelized)
197    let means: Vec<f64> = (0..m)
198        .into_par_iter()
199        .map(|j| {
200            let mut sum = 0.0;
201            for i in 0..n {
202                sum += data[i + j * n];
203            }
204            sum / n as f64
205        })
206        .collect();
207
208    // Create centered data (parallelized by column)
209    let mut centered = vec![0.0; n * m];
210    for j in 0..m {
211        for i in 0..n {
212            centered[i + j * n] = data[i + j * n] - means[j];
213        }
214    }
215
216    centered
217}
218
219/// Compute Lp norm for each sample.
220///
221/// # Arguments
222/// * `data` - Column-major matrix (n x m)
223/// * `n` - Number of samples
224/// * `m` - Number of evaluation points
225/// * `argvals` - Evaluation points for integration
226/// * `p` - Order of the norm (e.g., 2.0 for L2)
227///
228/// # Returns
229/// Vector of Lp norms for each sample
230pub fn norm_lp_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], p: f64) -> Vec<f64> {
231    if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
232        return Vec::new();
233    }
234
235    let weights = simpsons_weights(argvals);
236
237    (0..n)
238        .into_par_iter()
239        .map(|i| {
240            let mut integral = 0.0;
241            for j in 0..m {
242                let val = data[i + j * n].abs().powf(p);
243                integral += val * weights[j];
244            }
245            integral.powf(1.0 / p)
246        })
247        .collect()
248}
249
250/// Compute numerical derivative of functional data (parallelized over rows).
251///
252/// # Arguments
253/// * `data` - Column-major matrix (n x m)
254/// * `n` - Number of samples
255/// * `m` - Number of evaluation points
256/// * `argvals` - Evaluation points
257/// * `nderiv` - Order of derivative
258///
259/// # Returns
260/// Derivative data matrix (column-major)
261pub fn deriv_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], nderiv: usize) -> Vec<f64> {
262    if n == 0 || m == 0 || argvals.len() != m || nderiv < 1 || data.len() != n * m {
263        return vec![0.0; n * m];
264    }
265
266    let mut current = data.to_vec();
267
268    // Pre-compute step sizes for central differences
269    let h0 = argvals[1] - argvals[0];
270    let hn = argvals[m - 1] - argvals[m - 2];
271    let h_central: Vec<f64> = (1..(m - 1))
272        .map(|j| argvals[j + 1] - argvals[j - 1])
273        .collect();
274
275    for _ in 0..nderiv {
276        // Compute derivative for each row in parallel
277        let deriv: Vec<f64> = (0..n)
278            .into_par_iter()
279            .flat_map(|i| {
280                let mut row_deriv = vec![0.0; m];
281
282                // Forward difference at left boundary
283                row_deriv[0] = (current[i + n] - current[i]) / h0;
284
285                // Central differences for interior points
286                for j in 1..(m - 1) {
287                    row_deriv[j] =
288                        (current[i + (j + 1) * n] - current[i + (j - 1) * n]) / h_central[j - 1];
289                }
290
291                // Backward difference at right boundary
292                row_deriv[m - 1] = (current[i + (m - 1) * n] - current[i + (m - 2) * n]) / hn;
293
294                row_deriv
295            })
296            .collect();
297
298        // Reorder from row-major to column-major order
299        current = vec![0.0; n * m];
300        for i in 0..n {
301            for j in 0..m {
302                current[i + j * n] = deriv[i * m + j];
303            }
304        }
305    }
306
307    current
308}
309
/// Result of 2D partial derivatives.
///
/// Each field has the same length as the input data passed to `deriv_2d`.
#[derive(Debug, Clone, PartialEq)]
pub struct Deriv2DResult {
    /// Partial derivative with respect to s (∂f/∂s)
    pub ds: Vec<f64>,
    /// Partial derivative with respect to t (∂f/∂t)
    pub dt: Vec<f64>,
    /// Mixed partial derivative (∂²f/∂s∂t)
    pub dsdt: Vec<f64>,
}
319
320/// Compute 2D partial derivatives for surface data.
321///
322/// For a surface f(s,t), computes:
323/// - ds: partial derivative with respect to s (∂f/∂s)
324/// - dt: partial derivative with respect to t (∂f/∂t)
325/// - dsdt: mixed partial derivative (∂²f/∂s∂t)
326///
327/// # Arguments
328/// * `data` - Column-major matrix, n surfaces, each stored as m1*m2 values
329/// * `n` - Number of surfaces
330/// * `argvals_s` - Grid points in s direction (length m1)
331/// * `argvals_t` - Grid points in t direction (length m2)
332/// * `m1` - Grid size in s direction
333/// * `m2` - Grid size in t direction
334pub fn deriv_2d(
335    data: &[f64],
336    n: usize,
337    argvals_s: &[f64],
338    argvals_t: &[f64],
339    m1: usize,
340    m2: usize,
341) -> Option<Deriv2DResult> {
342    let ncol = m1 * m2;
343    if n == 0 || ncol == 0 || argvals_s.len() != m1 || argvals_t.len() != m2 {
344        return None;
345    }
346
347    // Pre-compute step sizes for s direction
348    let hs: Vec<f64> = (0..m1)
349        .map(|j| {
350            if j == 0 {
351                argvals_s[1] - argvals_s[0]
352            } else if j == m1 - 1 {
353                argvals_s[m1 - 1] - argvals_s[m1 - 2]
354            } else {
355                argvals_s[j + 1] - argvals_s[j - 1]
356            }
357        })
358        .collect();
359
360    // Pre-compute step sizes for t direction
361    let ht: Vec<f64> = (0..m2)
362        .map(|j| {
363            if j == 0 {
364                argvals_t[1] - argvals_t[0]
365            } else if j == m2 - 1 {
366                argvals_t[m2 - 1] - argvals_t[m2 - 2]
367            } else {
368                argvals_t[j + 1] - argvals_t[j - 1]
369            }
370        })
371        .collect();
372
373    // Compute all derivatives in parallel over surfaces
374    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = (0..n)
375        .into_par_iter()
376        .map(|i| {
377            let mut ds = vec![0.0; m1 * m2];
378            let mut dt = vec![0.0; m1 * m2];
379            let mut dsdt = vec![0.0; m1 * m2];
380
381            // Closure to access data for surface i
382            let get_val = |si: usize, ti: usize| -> f64 { data[i + (si + ti * m1) * n] };
383
384            for ti in 0..m2 {
385                for si in 0..m1 {
386                    let idx = si + ti * m1;
387                    let (ds_val, dt_val, dsdt_val) =
388                        compute_2d_derivatives(get_val, si, ti, m1, m2, &hs, &ht);
389                    ds[idx] = ds_val;
390                    dt[idx] = dt_val;
391                    dsdt[idx] = dsdt_val;
392                }
393            }
394
395            (ds, dt, dsdt)
396        })
397        .collect();
398
399    // Convert to column-major matrices
400    let mut ds_mat = vec![0.0; n * ncol];
401    let mut dt_mat = vec![0.0; n * ncol];
402    let mut dsdt_mat = vec![0.0; n * ncol];
403
404    for i in 0..n {
405        for j in 0..ncol {
406            ds_mat[i + j * n] = results[i].0[j];
407            dt_mat[i + j * n] = results[i].1[j];
408            dsdt_mat[i + j * n] = results[i].2[j];
409        }
410    }
411
412    Some(Deriv2DResult {
413        ds: ds_mat,
414        dt: dt_mat,
415        dsdt: dsdt_mat,
416    })
417}
418
419/// Compute the geometric median (L1 median) of functional data using Weiszfeld's algorithm.
420///
421/// The geometric median minimizes sum of L2 distances to all curves.
422///
423/// # Arguments
424/// * `data` - Column-major matrix (n x m)
425/// * `n` - Number of samples
426/// * `m` - Number of evaluation points
427/// * `argvals` - Evaluation points for integration
428/// * `max_iter` - Maximum iterations
429/// * `tol` - Convergence tolerance
430pub fn geometric_median_1d(
431    data: &[f64],
432    n: usize,
433    m: usize,
434    argvals: &[f64],
435    max_iter: usize,
436    tol: f64,
437) -> Vec<f64> {
438    if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
439        return Vec::new();
440    }
441
442    let weights = simpsons_weights(argvals);
443    weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
444}
445
446/// Compute the geometric median for 2D functional data.
447///
448/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
449///
450/// # Arguments
451/// * `data` - Column-major matrix (n x m) where m = m1*m2
452/// * `n` - Number of samples
453/// * `m` - Number of grid points (m1 * m2)
454/// * `argvals_s` - Grid points in s direction (length m1)
455/// * `argvals_t` - Grid points in t direction (length m2)
456/// * `max_iter` - Maximum iterations
457/// * `tol` - Convergence tolerance
458pub fn geometric_median_2d(
459    data: &[f64],
460    n: usize,
461    m: usize,
462    argvals_s: &[f64],
463    argvals_t: &[f64],
464    max_iter: usize,
465    tol: f64,
466) -> Vec<f64> {
467    let expected_cols = argvals_s.len() * argvals_t.len();
468    if n == 0 || m == 0 || m != expected_cols || data.len() != n * m {
469        return Vec::new();
470    }
471
472    let weights = simpsons_weights_2d(argvals_s, argvals_t);
473    weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `n` equally spaced points on [0, 1].
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|k| k as f64 / (n - 1) as f64).collect()
    }

    /// Visit `(si, ti, flat_index)` for every grid point at least `margin` steps away
    /// from each edge; `si` is the outer loop and `flat_index = si + ti * m1`.
    fn grid_points(
        m1: usize,
        m2: usize,
        margin: usize,
    ) -> impl Iterator<Item = (usize, usize, usize)> {
        (margin..(m1 - margin)).flat_map(move |si| {
            (margin..(m2 - margin)).map(move |ti| (si, ti, si + ti * m1))
        })
    }

    /// Sample `f(s, t)` on the grid into the flattened layout `idx = si + ti * m1`.
    fn sample_surface(s: &[f64], t: &[f64], f: impl Fn(f64, f64) -> f64) -> Vec<f64> {
        let m1 = s.len();
        (0..m1 * t.len())
            .map(|idx| f(s[idx % m1], t[idx / m1]))
            .collect()
    }

    /// Column-major matrix holding `n` identical copies of one sampled function.
    fn replicate(values: &[f64], n: usize) -> Vec<f64> {
        values
            .iter()
            .flat_map(|&v| std::iter::repeat(v).take(n))
            .collect()
    }

    // ============== Mean tests ==============

    #[test]
    fn test_mean_1d() {
        // Two samples over three points, stored column-major:
        //   sample 1 = [1, 2, 3], sample 2 = [3, 4, 5]  ->  mean = [2, 3, 4]
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        assert_eq!(mean_1d(&data, 2, 3), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mean_1d_single_sample() {
        let data = vec![1.0, 2.0, 3.0];
        assert_eq!(mean_1d(&data, 1, 3), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_1d_invalid() {
        assert!(mean_1d(&[], 0, 0).is_empty());
        // Data length does not match n * m
        assert!(mean_1d(&[1.0], 1, 2).is_empty());
    }

    #[test]
    fn test_mean_2d_delegates() {
        let data = vec![1.0, 3.0, 2.0, 4.0];
        assert_eq!(mean_1d(&data, 2, 2), mean_2d(&data, 2, 2));
    }

    // ============== Center tests ==============

    #[test]
    fn test_center_1d() {
        // Column means are [2, 3, 4], so every column becomes [-1, 1]
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        assert_eq!(
            center_1d(&data, 2, 3),
            vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]
        );
    }

    #[test]
    fn test_center_1d_mean_zero() {
        let data = vec![1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let centered = center_1d(&data, 2, 3);
        let residual_mean = mean_1d(&centered, 2, 3);
        assert!(
            residual_mean.iter().all(|v| v.abs() < 1e-10),
            "Centered data should have zero mean"
        );
    }

    #[test]
    fn test_center_1d_invalid() {
        assert!(center_1d(&[], 0, 0).is_empty());
    }

    // ============== Norm tests ==============

    #[test]
    fn test_norm_lp_1d_constant() {
        // The constant function 2 on [0, 1] has L2 norm 2
        let argvals = uniform_grid(21);
        let data = vec![2.0; 21];
        let norms = norm_lp_1d(&data, 1, 21, &argvals, 2.0);
        assert_eq!(norms.len(), 1);
        assert!(
            (norms[0] - 2.0).abs() < 0.1,
            "L2 norm of constant 2 should be 2"
        );
    }

    #[test]
    fn test_norm_lp_1d_sine() {
        // ||sin(pi x)||_2 on [0, 1] equals sqrt(1/2)
        let argvals = uniform_grid(101);
        let data: Vec<f64> = argvals.iter().map(|&x| (PI * x).sin()).collect();
        let norms = norm_lp_1d(&data, 1, 101, &argvals, 2.0);
        let expected = 0.5_f64.sqrt();
        assert!(
            (norms[0] - expected).abs() < 0.05,
            "Expected {}, got {}",
            expected,
            norms[0]
        );
    }

    #[test]
    fn test_norm_lp_1d_invalid() {
        assert!(norm_lp_1d(&[], 0, 0, &[], 2.0).is_empty());
    }

    // ============== Derivative tests ==============

    #[test]
    fn test_deriv_1d_linear() {
        // d/dx x = 1
        let argvals = uniform_grid(21);
        let data = argvals.clone();
        let deriv = deriv_1d(&data, 1, 21, &argvals, 1);
        // Only interior points are inspected
        for d in &deriv[2..19] {
            assert!((d - 1.0).abs() < 0.1, "Derivative of x should be 1");
        }
    }

    #[test]
    fn test_deriv_1d_quadratic() {
        // d/dx x^2 = 2x
        let argvals = uniform_grid(51);
        let data: Vec<f64> = argvals.iter().map(|&x| x * x).collect();
        let deriv = deriv_1d(&data, 1, 51, &argvals, 1);
        // Only interior points are inspected
        for (d, &x) in deriv[5..45].iter().zip(&argvals[5..45]) {
            let expected = 2.0 * x;
            assert!(
                (d - expected).abs() < 0.1,
                "Derivative of x^2 should be 2x"
            );
        }
    }

    #[test]
    fn test_deriv_1d_invalid() {
        let result = deriv_1d(&[], 0, 0, &[], 1);
        assert!(result.is_empty() || result.iter().all(|&x| x == 0.0));
    }

    // ============== Geometric median tests ==============

    #[test]
    fn test_geometric_median_identical_curves() {
        // When every curve is the same, the median is that curve
        let argvals = uniform_grid(21);
        let n = 5;
        let m = 21;
        let curve: Vec<f64> = argvals.iter().map(|&x| (2.0 * PI * x).sin()).collect();
        let data = replicate(&curve, n);

        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        for j in 0..m {
            let expected = (2.0 * PI * argvals[j]).sin();
            assert!(
                (median[j] - expected).abs() < 0.01,
                "Median should equal all curves"
            );
        }
    }

    #[test]
    fn test_geometric_median_converges() {
        // Curve i is the line (i / n) * x
        let argvals = uniform_grid(21);
        let n = 10;
        let m = 21;
        let data: Vec<f64> = (0..n * m)
            .map(|k| ((k % n) as f64 / n as f64) * argvals[k / n])
            .collect();

        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        assert!(median.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_geometric_median_invalid() {
        assert!(geometric_median_1d(&[], 0, 0, &[], 100, 1e-6).is_empty());
    }

    // ============== 2D derivative tests ==============

    #[test]
    fn test_deriv_2d_linear_surface() {
        // f(s, t) = 2*s + 3*t
        // ∂f/∂s = 2, ∂f/∂t = 3, ∂²f/∂s∂t = 0
        let m1 = 11;
        let m2 = 11;
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);

        let n = 1; // single surface
        let data = sample_surface(&argvals_s, &argvals_t, |s, t| 2.0 * s + 3.0 * t);

        let result = deriv_2d(&data, n, &argvals_s, &argvals_t, m1, m2).unwrap();

        // Interior: ∂f/∂s ≈ 2
        for (si, ti, idx) in grid_points(m1, m2, 2) {
            assert!(
                (result.ds[idx] - 2.0).abs() < 0.2,
                "∂f/∂s at ({}, {}) = {}, expected 2",
                si,
                ti,
                result.ds[idx]
            );
        }

        // Interior: ∂f/∂t ≈ 3
        for (si, ti, idx) in grid_points(m1, m2, 2) {
            assert!(
                (result.dt[idx] - 3.0).abs() < 0.2,
                "∂f/∂t at ({}, {}) = {}, expected 3",
                si,
                ti,
                result.dt[idx]
            );
        }

        // Interior: mixed partial ≈ 0
        for (si, ti, idx) in grid_points(m1, m2, 2) {
            assert!(
                result.dsdt[idx].abs() < 0.5,
                "∂²f/∂s∂t at ({}, {}) = {}, expected 0",
                si,
                ti,
                result.dsdt[idx]
            );
        }
    }

    #[test]
    fn test_deriv_2d_quadratic_surface() {
        // f(s, t) = s*t
        // ∂f/∂s = t, ∂f/∂t = s, ∂²f/∂s∂t = 1
        let m1 = 21;
        let m2 = 21;
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);

        let n = 1;
        let data = sample_surface(&argvals_s, &argvals_t, |s, t| s * t);

        let result = deriv_2d(&data, n, &argvals_s, &argvals_t, m1, m2).unwrap();

        // Interior: ∂f/∂s ≈ t
        for (si, ti, idx) in grid_points(m1, m2, 3) {
            let expected = argvals_t[ti];
            assert!(
                (result.ds[idx] - expected).abs() < 0.1,
                "∂f/∂s at ({}, {}) = {}, expected {}",
                si,
                ti,
                result.ds[idx],
                expected
            );
        }

        // Interior: ∂f/∂t ≈ s
        for (si, ti, idx) in grid_points(m1, m2, 3) {
            let expected = argvals_s[si];
            assert!(
                (result.dt[idx] - expected).abs() < 0.1,
                "∂f/∂t at ({}, {}) = {}, expected {}",
                si,
                ti,
                result.dt[idx],
                expected
            );
        }

        // Interior: mixed partial ≈ 1
        for (si, ti, idx) in grid_points(m1, m2, 3) {
            assert!(
                (result.dsdt[idx] - 1.0).abs() < 0.3,
                "∂²f/∂s∂t at ({}, {}) = {}, expected 1",
                si,
                ti,
                result.dsdt[idx]
            );
        }
    }

    #[test]
    fn test_deriv_2d_invalid_input() {
        // No data at all
        assert!(deriv_2d(&[], 0, &[], &[], 0, 0).is_none());

        // Grid vector length disagrees with the declared grid size
        let data = vec![1.0; 4];
        let argvals = vec![0.0, 1.0];
        assert!(deriv_2d(&data, 1, &argvals, &[0.0, 0.5, 1.0], 2, 2).is_none());
    }

    // ============== 2D geometric median tests ==============

    #[test]
    fn test_geometric_median_2d_basic() {
        // Three copies of f(s, t) = s + t: the median must be that surface
        let m1 = 5;
        let m2 = 5;
        let m = m1 * m2;
        let n = 3;
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);

        let surface = sample_surface(&argvals_s, &argvals_t, |s, t| s + t);
        let data = replicate(&surface, n);

        let median = geometric_median_2d(&data, n, m, &argvals_s, &argvals_t, 100, 1e-6);
        assert_eq!(median.len(), m);

        // Every grid point, boundaries included
        for (si, ti, idx) in grid_points(m1, m2, 0) {
            let expected = argvals_s[si] + argvals_t[ti];
            assert!(
                (median[idx] - expected).abs() < 0.01,
                "Median at ({}, {}) = {}, expected {}",
                si,
                ti,
                median[idx],
                expected
            );
        }
    }
}