fdars_core/
fdata.rs

1//! Functional data operations: mean, center, derivatives, norms, and geometric median.
2
3use crate::helpers::{simpsons_weights, simpsons_weights_2d, NUMERICAL_EPS};
4use crate::iter_maybe_parallel;
5#[cfg(feature = "parallel")]
6use rayon::iter::ParallelIterator;
7
/// One-dimensional finite difference of `values` at grid index `idx`.
///
/// The stencil is one-sided at the ends (forward at `idx == 0`, backward at
/// `idx == n_points - 1`) and central everywhere else. `step_sizes[idx]` must be the
/// distance between the two sample points that the stencil at `idx` uses.
fn finite_diff_1d(
    values: impl Fn(usize) -> f64,
    idx: usize,
    n_points: usize,
    step_sizes: &[f64],
) -> f64 {
    // Pick the pair of neighbours (lo, hi) the stencil differences.
    let (lo, hi) = match idx {
        0 => (0, 1),
        i if i + 1 == n_points => (i - 1, i),
        i => (i - 1, i + 1),
    };
    (values(hi) - values(lo)) / step_sizes[idx]
}
26
27/// Compute 2D partial derivatives at a single grid point.
28///
29/// Returns (∂f/∂s, ∂f/∂t, ∂²f/∂s∂t) using finite differences.
30fn compute_2d_derivatives(
31    get_val: impl Fn(usize, usize) -> f64,
32    si: usize,
33    ti: usize,
34    m1: usize,
35    m2: usize,
36    hs: &[f64],
37    ht: &[f64],
38) -> (f64, f64, f64) {
39    // ∂f/∂s
40    let ds = finite_diff_1d(|s| get_val(s, ti), si, m1, hs);
41
42    // ∂f/∂t
43    let dt = finite_diff_1d(|t| get_val(si, t), ti, m2, ht);
44
45    // ∂²f/∂s∂t (mixed partial)
46    let denom = hs[si] * ht[ti];
47
48    // Get the appropriate indices for s and t differences
49    let (s_lo, s_hi) = if si == 0 {
50        (0, 1)
51    } else if si == m1 - 1 {
52        (m1 - 2, m1 - 1)
53    } else {
54        (si - 1, si + 1)
55    };
56
57    let (t_lo, t_hi) = if ti == 0 {
58        (0, 1)
59    } else if ti == m2 - 1 {
60        (m2 - 2, m2 - 1)
61    } else {
62        (ti - 1, ti + 1)
63    };
64
65    let dsdt = (get_val(s_hi, t_hi) - get_val(s_lo, t_hi) - get_val(s_hi, t_lo)
66        + get_val(s_lo, t_lo))
67        / denom;
68
69    (ds, dt, dsdt)
70}
71
72/// Perform Weiszfeld iteration to compute geometric median.
73///
74/// This is the core algorithm shared by 1D and 2D geometric median computations.
75fn weiszfeld_iteration(
76    data: &[f64],
77    n: usize,
78    m: usize,
79    weights: &[f64],
80    max_iter: usize,
81    tol: f64,
82) -> Vec<f64> {
83    // Initialize with the mean
84    let mut median: Vec<f64> = (0..m)
85        .map(|j| {
86            let mut sum = 0.0;
87            for i in 0..n {
88                sum += data[i + j * n];
89            }
90            sum / n as f64
91        })
92        .collect();
93
94    for _ in 0..max_iter {
95        // Compute distances from current median to all curves
96        let distances: Vec<f64> = (0..n)
97            .map(|i| {
98                let mut dist_sq = 0.0;
99                for j in 0..m {
100                    let diff = data[i + j * n] - median[j];
101                    dist_sq += diff * diff * weights[j];
102                }
103                dist_sq.sqrt()
104            })
105            .collect();
106
107        // Compute weights (1/distance), handling zero distances
108        let inv_distances: Vec<f64> = distances
109            .iter()
110            .map(|d| {
111                if *d > NUMERICAL_EPS {
112                    1.0 / d
113                } else {
114                    1.0 / NUMERICAL_EPS
115                }
116            })
117            .collect();
118
119        let sum_inv_dist: f64 = inv_distances.iter().sum();
120
121        // Update median using Weiszfeld iteration
122        let new_median: Vec<f64> = (0..m)
123            .map(|j| {
124                let mut weighted_sum = 0.0;
125                for i in 0..n {
126                    weighted_sum += data[i + j * n] * inv_distances[i];
127                }
128                weighted_sum / sum_inv_dist
129            })
130            .collect();
131
132        // Check convergence
133        let diff: f64 = median
134            .iter()
135            .zip(new_median.iter())
136            .map(|(a, b)| (a - b).abs())
137            .sum::<f64>()
138            / m as f64;
139
140        median = new_median;
141
142        if diff < tol {
143            break;
144        }
145    }
146
147    median
148}
149
150/// Compute the mean function across all samples (1D).
151///
152/// # Arguments
153/// * `data` - Column-major matrix (n x m)
154/// * `n` - Number of samples
155/// * `m` - Number of evaluation points
156///
157/// # Returns
158/// Mean function values at each evaluation point
159pub fn mean_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
160    if n == 0 || m == 0 || data.len() != n * m {
161        return Vec::new();
162    }
163
164    iter_maybe_parallel!(0..m)
165        .map(|j| {
166            let mut sum = 0.0;
167            for i in 0..n {
168                sum += data[i + j * n];
169            }
170            sum / n as f64
171        })
172        .collect()
173}
174
175/// Compute the mean function for 2D surfaces.
176///
177/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
178pub fn mean_2d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
179    // Same computation as 1D - just compute pointwise mean
180    mean_1d(data, n, m)
181}
182
183/// Center functional data by subtracting the mean function.
184///
185/// # Arguments
186/// * `data` - Column-major matrix (n x m)
187/// * `n` - Number of samples
188/// * `m` - Number of evaluation points
189///
190/// # Returns
191/// Centered data matrix (column-major)
192pub fn center_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
193    if n == 0 || m == 0 || data.len() != n * m {
194        return Vec::new();
195    }
196
197    // First compute the mean for each column (parallelized)
198    let means: Vec<f64> = iter_maybe_parallel!(0..m)
199        .map(|j| {
200            let mut sum = 0.0;
201            for i in 0..n {
202                sum += data[i + j * n];
203            }
204            sum / n as f64
205        })
206        .collect();
207
208    // Create centered data (parallelized by column)
209    let mut centered = vec![0.0; n * m];
210    for j in 0..m {
211        for i in 0..n {
212            centered[i + j * n] = data[i + j * n] - means[j];
213        }
214    }
215
216    centered
217}
218
219/// Compute Lp norm for each sample.
220///
221/// # Arguments
222/// * `data` - Column-major matrix (n x m)
223/// * `n` - Number of samples
224/// * `m` - Number of evaluation points
225/// * `argvals` - Evaluation points for integration
226/// * `p` - Order of the norm (e.g., 2.0 for L2)
227///
228/// # Returns
229/// Vector of Lp norms for each sample
230pub fn norm_lp_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], p: f64) -> Vec<f64> {
231    if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
232        return Vec::new();
233    }
234
235    let weights = simpsons_weights(argvals);
236
237    iter_maybe_parallel!(0..n)
238        .map(|i| {
239            let mut integral = 0.0;
240            for j in 0..m {
241                let val = data[i + j * n].abs().powf(p);
242                integral += val * weights[j];
243            }
244            integral.powf(1.0 / p)
245        })
246        .collect()
247}
248
249/// Compute numerical derivative of functional data (parallelized over rows).
250///
251/// # Arguments
252/// * `data` - Column-major matrix (n x m)
253/// * `n` - Number of samples
254/// * `m` - Number of evaluation points
255/// * `argvals` - Evaluation points
256/// * `nderiv` - Order of derivative
257///
258/// # Returns
259/// Derivative data matrix (column-major)
260pub fn deriv_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], nderiv: usize) -> Vec<f64> {
261    if n == 0 || m == 0 || argvals.len() != m || nderiv < 1 || data.len() != n * m {
262        return vec![0.0; n * m];
263    }
264
265    let mut current = data.to_vec();
266
267    // Pre-compute step sizes for central differences
268    let h0 = argvals[1] - argvals[0];
269    let hn = argvals[m - 1] - argvals[m - 2];
270    let h_central: Vec<f64> = (1..(m - 1))
271        .map(|j| argvals[j + 1] - argvals[j - 1])
272        .collect();
273
274    for _ in 0..nderiv {
275        // Compute derivative for each row in parallel
276        let deriv: Vec<f64> = iter_maybe_parallel!(0..n)
277            .flat_map(|i| {
278                let mut row_deriv = vec![0.0; m];
279
280                // Forward difference at left boundary
281                row_deriv[0] = (current[i + n] - current[i]) / h0;
282
283                // Central differences for interior points
284                for j in 1..(m - 1) {
285                    row_deriv[j] =
286                        (current[i + (j + 1) * n] - current[i + (j - 1) * n]) / h_central[j - 1];
287                }
288
289                // Backward difference at right boundary
290                row_deriv[m - 1] = (current[i + (m - 1) * n] - current[i + (m - 2) * n]) / hn;
291
292                row_deriv
293            })
294            .collect();
295
296        // Reorder from row-major to column-major order
297        current = vec![0.0; n * m];
298        for i in 0..n {
299            for j in 0..m {
300                current[i + j * n] = deriv[i * m + j];
301            }
302        }
303    }
304
305    current
306}
307
/// Result of 2D partial derivatives.
///
/// As produced by [`deriv_2d`], each field is a column-major matrix with `n` rows (one
/// per surface) and `m1 * m2` columns. Grid point `(si, ti)` is stored in column
/// `si + ti * m1`.
#[derive(Debug, Clone, PartialEq)]
pub struct Deriv2DResult {
    /// Partial derivative with respect to s (∂f/∂s)
    pub ds: Vec<f64>,
    /// Partial derivative with respect to t (∂f/∂t)
    pub dt: Vec<f64>,
    /// Mixed partial derivative (∂²f/∂s∂t)
    pub dsdt: Vec<f64>,
}
317
318/// Compute 2D partial derivatives for surface data.
319///
320/// For a surface f(s,t), computes:
321/// - ds: partial derivative with respect to s (∂f/∂s)
322/// - dt: partial derivative with respect to t (∂f/∂t)
323/// - dsdt: mixed partial derivative (∂²f/∂s∂t)
324///
325/// # Arguments
326/// * `data` - Column-major matrix, n surfaces, each stored as m1*m2 values
327/// * `n` - Number of surfaces
328/// * `argvals_s` - Grid points in s direction (length m1)
329/// * `argvals_t` - Grid points in t direction (length m2)
330/// * `m1` - Grid size in s direction
331/// * `m2` - Grid size in t direction
332pub fn deriv_2d(
333    data: &[f64],
334    n: usize,
335    argvals_s: &[f64],
336    argvals_t: &[f64],
337    m1: usize,
338    m2: usize,
339) -> Option<Deriv2DResult> {
340    let ncol = m1 * m2;
341    if n == 0 || ncol == 0 || argvals_s.len() != m1 || argvals_t.len() != m2 {
342        return None;
343    }
344
345    // Pre-compute step sizes for s direction
346    let hs: Vec<f64> = (0..m1)
347        .map(|j| {
348            if j == 0 {
349                argvals_s[1] - argvals_s[0]
350            } else if j == m1 - 1 {
351                argvals_s[m1 - 1] - argvals_s[m1 - 2]
352            } else {
353                argvals_s[j + 1] - argvals_s[j - 1]
354            }
355        })
356        .collect();
357
358    // Pre-compute step sizes for t direction
359    let ht: Vec<f64> = (0..m2)
360        .map(|j| {
361            if j == 0 {
362                argvals_t[1] - argvals_t[0]
363            } else if j == m2 - 1 {
364                argvals_t[m2 - 1] - argvals_t[m2 - 2]
365            } else {
366                argvals_t[j + 1] - argvals_t[j - 1]
367            }
368        })
369        .collect();
370
371    // Compute all derivatives in parallel over surfaces
372    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
373        .map(|i| {
374            let mut ds = vec![0.0; m1 * m2];
375            let mut dt = vec![0.0; m1 * m2];
376            let mut dsdt = vec![0.0; m1 * m2];
377
378            // Closure to access data for surface i
379            let get_val = |si: usize, ti: usize| -> f64 { data[i + (si + ti * m1) * n] };
380
381            for ti in 0..m2 {
382                for si in 0..m1 {
383                    let idx = si + ti * m1;
384                    let (ds_val, dt_val, dsdt_val) =
385                        compute_2d_derivatives(get_val, si, ti, m1, m2, &hs, &ht);
386                    ds[idx] = ds_val;
387                    dt[idx] = dt_val;
388                    dsdt[idx] = dsdt_val;
389                }
390            }
391
392            (ds, dt, dsdt)
393        })
394        .collect();
395
396    // Convert to column-major matrices
397    let mut ds_mat = vec![0.0; n * ncol];
398    let mut dt_mat = vec![0.0; n * ncol];
399    let mut dsdt_mat = vec![0.0; n * ncol];
400
401    for i in 0..n {
402        for j in 0..ncol {
403            ds_mat[i + j * n] = results[i].0[j];
404            dt_mat[i + j * n] = results[i].1[j];
405            dsdt_mat[i + j * n] = results[i].2[j];
406        }
407    }
408
409    Some(Deriv2DResult {
410        ds: ds_mat,
411        dt: dt_mat,
412        dsdt: dsdt_mat,
413    })
414}
415
416/// Compute the geometric median (L1 median) of functional data using Weiszfeld's algorithm.
417///
418/// The geometric median minimizes sum of L2 distances to all curves.
419///
420/// # Arguments
421/// * `data` - Column-major matrix (n x m)
422/// * `n` - Number of samples
423/// * `m` - Number of evaluation points
424/// * `argvals` - Evaluation points for integration
425/// * `max_iter` - Maximum iterations
426/// * `tol` - Convergence tolerance
427pub fn geometric_median_1d(
428    data: &[f64],
429    n: usize,
430    m: usize,
431    argvals: &[f64],
432    max_iter: usize,
433    tol: f64,
434) -> Vec<f64> {
435    if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
436        return Vec::new();
437    }
438
439    let weights = simpsons_weights(argvals);
440    weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
441}
442
443/// Compute the geometric median for 2D functional data.
444///
445/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
446///
447/// # Arguments
448/// * `data` - Column-major matrix (n x m) where m = m1*m2
449/// * `n` - Number of samples
450/// * `m` - Number of grid points (m1 * m2)
451/// * `argvals_s` - Grid points in s direction (length m1)
452/// * `argvals_t` - Grid points in t direction (length m2)
453/// * `max_iter` - Maximum iterations
454/// * `tol` - Convergence tolerance
455pub fn geometric_median_2d(
456    data: &[f64],
457    n: usize,
458    m: usize,
459    argvals_s: &[f64],
460    argvals_t: &[f64],
461    max_iter: usize,
462    tol: f64,
463) -> Vec<f64> {
464    let expected_cols = argvals_s.len() * argvals_t.len();
465    if n == 0 || m == 0 || m != expected_cols || data.len() != n * m {
466        return Vec::new();
467    }
468
469    let weights = simpsons_weights_2d(argvals_s, argvals_t);
470    weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `n` equally spaced points covering [0, 1], endpoints included.
    fn uniform_grid(n: usize) -> Vec<f64> {
        let last = (n - 1) as f64;
        (0..n).map(|i| i as f64 / last).collect()
    }

    /// Sample `f` on the tensor grid `s x t`, flattened as `si + ti * s.len()`.
    fn sample_surface(s: &[f64], t: &[f64], f: impl Fn(f64, f64) -> f64) -> Vec<f64> {
        let mut out = Vec::with_capacity(s.len() * t.len());
        for &tv in t {
            for &sv in s {
                out.push(f(sv, tv));
            }
        }
        out
    }

    /// Column-major n x m matrix whose every row equals `row`.
    fn repeat_rows(row: &[f64], n: usize) -> Vec<f64> {
        row.iter()
            .flat_map(|&v| std::iter::repeat(v).take(n))
            .collect()
    }

    /// Assert `got[si + ti * m1]` is within `tol` of `want(si, ti)` for every grid point
    /// at least `margin` indices away from the edges.
    fn assert_interior(
        label: &str,
        got: &[f64],
        m1: usize,
        m2: usize,
        margin: usize,
        tol: f64,
        want: impl Fn(usize, usize) -> f64,
    ) {
        for si in margin..(m1 - margin) {
            for ti in margin..(m2 - margin) {
                let idx = si + ti * m1;
                let expected = want(si, ti);
                assert!(
                    (got[idx] - expected).abs() < tol,
                    "{} at ({}, {}) = {}, expected {}",
                    label,
                    si,
                    ti,
                    got[idx],
                    expected
                );
            }
        }
    }

    // ============== Mean tests ==============

    #[test]
    fn test_mean_1d() {
        // Column-major 2 x 3: samples [1, 2, 3] and [3, 4, 5] average to [2, 3, 4].
        let data = [1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        assert_eq!(mean_1d(&data, 2, 3), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mean_1d_single_sample() {
        assert_eq!(mean_1d(&[1.0, 2.0, 3.0], 1, 3), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_1d_invalid() {
        assert!(mean_1d(&[], 0, 0).is_empty());
        // Data length does not match n * m.
        assert!(mean_1d(&[1.0], 1, 2).is_empty());
    }

    #[test]
    fn test_mean_2d_delegates() {
        let data = [1.0, 3.0, 2.0, 4.0];
        assert_eq!(mean_2d(&data, 2, 2), mean_1d(&data, 2, 2));
    }

    // ============== Center tests ==============

    #[test]
    fn test_center_1d() {
        // The mean is [2, 3, 4], so every entry sits one unit away from it.
        let data = [1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        assert_eq!(
            center_1d(&data, 2, 3),
            vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]
        );
    }

    #[test]
    fn test_center_1d_mean_zero() {
        let data = [1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let centered = center_1d(&data, 2, 3);
        assert!(
            mean_1d(&centered, 2, 3).iter().all(|m| m.abs() < 1e-10),
            "Centered data should have zero mean"
        );
    }

    #[test]
    fn test_center_1d_invalid() {
        assert!(center_1d(&[], 0, 0).is_empty());
    }

    // ============== Norm tests ==============

    #[test]
    fn test_norm_lp_1d_constant() {
        // The constant function 2 on [0, 1] has L2 norm 2.
        let argvals = uniform_grid(21);
        let data = vec![2.0; 21];
        let norms = norm_lp_1d(&data, 1, 21, &argvals, 2.0);
        assert_eq!(norms.len(), 1);
        assert!(
            (norms[0] - 2.0).abs() < 0.1,
            "L2 norm of constant 2 should be 2"
        );
    }

    #[test]
    fn test_norm_lp_1d_sine() {
        // ||sin(pi x)||_2 on [0, 1] equals sqrt(1/2).
        let argvals = uniform_grid(101);
        let data: Vec<f64> = argvals.iter().map(|&x| (PI * x).sin()).collect();
        let norms = norm_lp_1d(&data, 1, 101, &argvals, 2.0);
        let expected = 0.5_f64.sqrt();
        assert!(
            (norms[0] - expected).abs() < 0.05,
            "Expected {}, got {}",
            expected,
            norms[0]
        );
    }

    #[test]
    fn test_norm_lp_1d_invalid() {
        assert!(norm_lp_1d(&[], 0, 0, &[], 2.0).is_empty());
    }

    // ============== Derivative tests ==============

    #[test]
    fn test_deriv_1d_linear() {
        // d/dx x = 1; only interior points are checked.
        let argvals = uniform_grid(21);
        let deriv = deriv_1d(&argvals, 1, 21, &argvals, 1);
        for d in &deriv[2..19] {
            assert!((d - 1.0).abs() < 0.1, "Derivative of x should be 1");
        }
    }

    #[test]
    fn test_deriv_1d_quadratic() {
        // d/dx x^2 = 2x; only interior points are checked.
        let argvals = uniform_grid(51);
        let data: Vec<f64> = argvals.iter().map(|&x| x * x).collect();
        let deriv = deriv_1d(&data, 1, 51, &argvals, 1);
        for j in 5..45 {
            let expected = 2.0 * argvals[j];
            assert!(
                (deriv[j] - expected).abs() < 0.1,
                "Derivative of x^2 should be 2x"
            );
        }
    }

    #[test]
    fn test_deriv_1d_invalid() {
        let result = deriv_1d(&[], 0, 0, &[], 1);
        assert!(result.is_empty() || result.iter().all(|&x| x == 0.0));
    }

    // ============== Geometric median tests ==============

    #[test]
    fn test_geometric_median_identical_curves() {
        // Every curve is sin(2 pi t), so the median must be that curve as well.
        let argvals = uniform_grid(21);
        let (n, m) = (5, 21);
        let curve: Vec<f64> = argvals.iter().map(|&t| (2.0 * PI * t).sin()).collect();
        let data = repeat_rows(&curve, n);
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        for (got, want) in median.iter().zip(&curve) {
            assert!(
                (got - want).abs() < 0.01,
                "Median should equal all curves"
            );
        }
    }

    #[test]
    fn test_geometric_median_converges() {
        let argvals = uniform_grid(21);
        let (n, m) = (10, 21);
        // Curve i is the line (i / n) * t.
        let data: Vec<f64> = argvals
            .iter()
            .flat_map(move |&t| (0..n).map(move |i| (i as f64 / n as f64) * t))
            .collect();
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        assert!(median.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_geometric_median_invalid() {
        assert!(geometric_median_1d(&[], 0, 0, &[], 100, 1e-6).is_empty());
    }

    // ============== 2D derivative tests ==============

    #[test]
    fn test_deriv_2d_linear_surface() {
        // f(s, t) = 2s + 3t: ∂f/∂s = 2, ∂f/∂t = 3, ∂²f/∂s∂t = 0.
        let (m1, m2) = (11, 11);
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let data = sample_surface(&argvals_s, &argvals_t, |s, t| 2.0 * s + 3.0 * t);

        let result = deriv_2d(&data, 1, &argvals_s, &argvals_t, m1, m2).unwrap();

        assert_interior("∂f/∂s", &result.ds, m1, m2, 2, 0.2, |_, _| 2.0);
        assert_interior("∂f/∂t", &result.dt, m1, m2, 2, 0.2, |_, _| 3.0);
        assert_interior("∂²f/∂s∂t", &result.dsdt, m1, m2, 2, 0.5, |_, _| 0.0);
    }

    #[test]
    fn test_deriv_2d_quadratic_surface() {
        // f(s, t) = s t: ∂f/∂s = t, ∂f/∂t = s, ∂²f/∂s∂t = 1.
        let (m1, m2) = (21, 21);
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let data = sample_surface(&argvals_s, &argvals_t, |s, t| s * t);

        let result = deriv_2d(&data, 1, &argvals_s, &argvals_t, m1, m2).unwrap();

        assert_interior("∂f/∂s", &result.ds, m1, m2, 3, 0.1, |_, ti| argvals_t[ti]);
        assert_interior("∂f/∂t", &result.dt, m1, m2, 3, 0.1, |si, _| argvals_s[si]);
        assert_interior("∂²f/∂s∂t", &result.dsdt, m1, m2, 3, 0.3, |_, _| 1.0);
    }

    #[test]
    fn test_deriv_2d_invalid_input() {
        // Empty data.
        assert!(deriv_2d(&[], 0, &[], &[], 0, 0).is_none());
        // argvals_t has three points but m2 = 2.
        let data = [1.0; 4];
        assert!(deriv_2d(&data, 1, &[0.0, 1.0], &[0.0, 0.5, 1.0], 2, 2).is_none());
    }

    // ============== 2D geometric median tests ==============

    #[test]
    fn test_geometric_median_2d_basic() {
        // Three identical surfaces f(s, t) = s + t, so the median is that surface.
        let (m1, m2, n) = (5, 5, 3);
        let m = m1 * m2;
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let surface = sample_surface(&argvals_s, &argvals_t, |s, t| s + t);
        let data = repeat_rows(&surface, n);

        let median = geometric_median_2d(&data, n, m, &argvals_s, &argvals_t, 100, 1e-6);
        assert_eq!(median.len(), m);

        for si in 0..m1 {
            for ti in 0..m2 {
                let idx = si + ti * m1;
                let expected = argvals_s[si] + argvals_t[ti];
                assert!(
                    (median[idx] - expected).abs() < 0.01,
                    "Median at ({}, {}) = {}, expected {}",
                    si,
                    ti,
                    median[idx],
                    expected
                );
            }
        }
    }
}