Skip to main content

fdars_core/
fdata.rs

1//! Functional data operations: mean, center, derivatives, norms, and geometric median.
2
3use crate::helpers::{simpsons_weights, simpsons_weights_2d, NUMERICAL_EPS};
4use crate::iter_maybe_parallel;
5#[cfg(feature = "parallel")]
6use rayon::iter::ParallelIterator;
7
/// First-order finite difference of a sampled 1D function at grid index `idx`.
///
/// `values(k)` returns the sample at grid index `k`, `n_points` is the grid
/// length and `step_sizes[idx]` is the spacing spanned by the stencil at `idx`
/// (as produced by `compute_step_sizes`). The first point uses a forward
/// difference, the last point a backward difference, and every other point a
/// central difference over its two neighbours.
fn finite_diff_1d(
    values: impl Fn(usize) -> f64,
    idx: usize,
    n_points: usize,
    step_sizes: &[f64],
) -> f64 {
    // Pick the (lower, upper) neighbours whose difference approximates the slope.
    let (lo, hi) = match idx {
        0 => (0, 1),
        i if i == n_points - 1 => (i - 1, i),
        i => (i - 1, i + 1),
    };
    (values(hi) - values(lo)) / step_sizes[idx]
}
26
27/// Compute 2D partial derivatives at a single grid point.
28///
29/// Returns (∂f/∂s, ∂f/∂t, ∂²f/∂s∂t) using finite differences.
30fn compute_2d_derivatives(
31    get_val: impl Fn(usize, usize) -> f64,
32    si: usize,
33    ti: usize,
34    m1: usize,
35    m2: usize,
36    hs: &[f64],
37    ht: &[f64],
38) -> (f64, f64, f64) {
39    // ∂f/∂s
40    let ds = finite_diff_1d(|s| get_val(s, ti), si, m1, hs);
41
42    // ∂f/∂t
43    let dt = finite_diff_1d(|t| get_val(si, t), ti, m2, ht);
44
45    // ∂²f/∂s∂t (mixed partial)
46    let denom = hs[si] * ht[ti];
47
48    // Get the appropriate indices for s and t differences
49    let (s_lo, s_hi) = if si == 0 {
50        (0, 1)
51    } else if si == m1 - 1 {
52        (m1 - 2, m1 - 1)
53    } else {
54        (si - 1, si + 1)
55    };
56
57    let (t_lo, t_hi) = if ti == 0 {
58        (0, 1)
59    } else if ti == m2 - 1 {
60        (m2 - 2, m2 - 1)
61    } else {
62        (ti - 1, ti + 1)
63    };
64
65    let dsdt = (get_val(s_hi, t_hi) - get_val(s_lo, t_hi) - get_val(s_hi, t_lo)
66        + get_val(s_lo, t_lo))
67        / denom;
68
69    (ds, dt, dsdt)
70}
71
72/// Perform Weiszfeld iteration to compute geometric median.
73///
74/// This is the core algorithm shared by 1D and 2D geometric median computations.
75fn weiszfeld_iteration(
76    data: &[f64],
77    n: usize,
78    m: usize,
79    weights: &[f64],
80    max_iter: usize,
81    tol: f64,
82) -> Vec<f64> {
83    // Initialize with the mean
84    let mut median: Vec<f64> = (0..m)
85        .map(|j| {
86            let mut sum = 0.0;
87            for i in 0..n {
88                sum += data[i + j * n];
89            }
90            sum / n as f64
91        })
92        .collect();
93
94    for _ in 0..max_iter {
95        // Compute distances from current median to all curves
96        let distances: Vec<f64> = (0..n)
97            .map(|i| {
98                let mut dist_sq = 0.0;
99                for j in 0..m {
100                    let diff = data[i + j * n] - median[j];
101                    dist_sq += diff * diff * weights[j];
102                }
103                dist_sq.sqrt()
104            })
105            .collect();
106
107        // Compute weights (1/distance), handling zero distances
108        let inv_distances: Vec<f64> = distances
109            .iter()
110            .map(|d| {
111                if *d > NUMERICAL_EPS {
112                    1.0 / d
113                } else {
114                    1.0 / NUMERICAL_EPS
115                }
116            })
117            .collect();
118
119        let sum_inv_dist: f64 = inv_distances.iter().sum();
120
121        // Update median using Weiszfeld iteration
122        let new_median: Vec<f64> = (0..m)
123            .map(|j| {
124                let mut weighted_sum = 0.0;
125                for i in 0..n {
126                    weighted_sum += data[i + j * n] * inv_distances[i];
127                }
128                weighted_sum / sum_inv_dist
129            })
130            .collect();
131
132        // Check convergence
133        let diff: f64 = median
134            .iter()
135            .zip(new_median.iter())
136            .map(|(a, b)| (a - b).abs())
137            .sum::<f64>()
138            / m as f64;
139
140        median = new_median;
141
142        if diff < tol {
143            break;
144        }
145    }
146
147    median
148}
149
150/// Compute the mean function across all samples (1D).
151///
152/// # Arguments
153/// * `data` - Column-major matrix (n x m)
154/// * `n` - Number of samples
155/// * `m` - Number of evaluation points
156///
157/// # Returns
158/// Mean function values at each evaluation point
159pub fn mean_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
160    if n == 0 || m == 0 || data.len() != n * m {
161        return Vec::new();
162    }
163
164    iter_maybe_parallel!(0..m)
165        .map(|j| {
166            let mut sum = 0.0;
167            for i in 0..n {
168                sum += data[i + j * n];
169            }
170            sum / n as f64
171        })
172        .collect()
173}
174
175/// Compute the mean function for 2D surfaces.
176///
177/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
178pub fn mean_2d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
179    // Same computation as 1D - just compute pointwise mean
180    mean_1d(data, n, m)
181}
182
183/// Center functional data by subtracting the mean function.
184///
185/// # Arguments
186/// * `data` - Column-major matrix (n x m)
187/// * `n` - Number of samples
188/// * `m` - Number of evaluation points
189///
190/// # Returns
191/// Centered data matrix (column-major)
192pub fn center_1d(data: &[f64], n: usize, m: usize) -> Vec<f64> {
193    if n == 0 || m == 0 || data.len() != n * m {
194        return Vec::new();
195    }
196
197    // First compute the mean for each column (parallelized)
198    let means: Vec<f64> = iter_maybe_parallel!(0..m)
199        .map(|j| {
200            let mut sum = 0.0;
201            for i in 0..n {
202                sum += data[i + j * n];
203            }
204            sum / n as f64
205        })
206        .collect();
207
208    // Create centered data (parallelized by column)
209    let mut centered = vec![0.0; n * m];
210    for j in 0..m {
211        for i in 0..n {
212            centered[i + j * n] = data[i + j * n] - means[j];
213        }
214    }
215
216    centered
217}
218
219/// Compute Lp norm for each sample.
220///
221/// # Arguments
222/// * `data` - Column-major matrix (n x m)
223/// * `n` - Number of samples
224/// * `m` - Number of evaluation points
225/// * `argvals` - Evaluation points for integration
226/// * `p` - Order of the norm (e.g., 2.0 for L2)
227///
228/// # Returns
229/// Vector of Lp norms for each sample
230pub fn norm_lp_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], p: f64) -> Vec<f64> {
231    if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
232        return Vec::new();
233    }
234
235    let weights = simpsons_weights(argvals);
236
237    iter_maybe_parallel!(0..n)
238        .map(|i| {
239            let mut integral = 0.0;
240            for j in 0..m {
241                let val = data[i + j * n].abs().powf(p);
242                integral += val * weights[j];
243            }
244            integral.powf(1.0 / p)
245        })
246        .collect()
247}
248
249/// Compute numerical derivative of functional data (parallelized over rows).
250///
251/// # Arguments
252/// * `data` - Column-major matrix (n x m)
253/// * `n` - Number of samples
254/// * `m` - Number of evaluation points
255/// * `argvals` - Evaluation points
256/// * `nderiv` - Order of derivative
257///
258/// # Returns
259/// Derivative data matrix (column-major)
260pub fn deriv_1d(data: &[f64], n: usize, m: usize, argvals: &[f64], nderiv: usize) -> Vec<f64> {
261    if n == 0 || m == 0 || argvals.len() != m || nderiv < 1 || data.len() != n * m {
262        return vec![0.0; n * m];
263    }
264
265    let mut current = data.to_vec();
266
267    // Pre-compute step sizes for central differences
268    let h0 = argvals[1] - argvals[0];
269    let hn = argvals[m - 1] - argvals[m - 2];
270    let h_central: Vec<f64> = (1..(m - 1))
271        .map(|j| argvals[j + 1] - argvals[j - 1])
272        .collect();
273
274    for _ in 0..nderiv {
275        // Compute derivative for each row in parallel
276        let deriv: Vec<f64> = iter_maybe_parallel!(0..n)
277            .flat_map(|i| {
278                let mut row_deriv = vec![0.0; m];
279
280                // Forward difference at left boundary
281                row_deriv[0] = (current[i + n] - current[i]) / h0;
282
283                // Central differences for interior points
284                for j in 1..(m - 1) {
285                    row_deriv[j] =
286                        (current[i + (j + 1) * n] - current[i + (j - 1) * n]) / h_central[j - 1];
287                }
288
289                // Backward difference at right boundary
290                row_deriv[m - 1] = (current[i + (m - 1) * n] - current[i + (m - 2) * n]) / hn;
291
292                row_deriv
293            })
294            .collect();
295
296        // Reorder from row-major to column-major order
297        current = vec![0.0; n * m];
298        for i in 0..n {
299            for j in 0..m {
300                current[i + j * n] = deriv[i * m + j];
301            }
302        }
303    }
304
305    current
306}
307
/// Partial derivatives of a set of surfaces, as returned by [`deriv_2d`].
///
/// Each field is an n x (m1*m2) column-major matrix laid out like the input
/// surfaces: surface `i`, grid point `(si, ti)` lives at `[i + (si + ti * m1) * n]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Deriv2DResult {
    /// Partial derivative with respect to s (∂f/∂s)
    pub ds: Vec<f64>,
    /// Partial derivative with respect to t (∂f/∂t)
    pub dt: Vec<f64>,
    /// Mixed partial derivative (∂²f/∂s∂t)
    pub dsdt: Vec<f64>,
}
317
/// Spacing used as the finite-difference denominator at every grid index.
///
/// The first and last entries are the one-sided spacings `t[1] - t[0]` and
/// `t[m-1] - t[m-2]`; every interior entry `j` is the two-sided spacing
/// `t[j+1] - t[j-1]`. Returns an empty vector for an empty grid. A grid with a
/// single point has no valid spacing and panics on the index.
fn compute_step_sizes(argvals: &[f64]) -> Vec<f64> {
    let m = argvals.len();
    if m == 0 {
        return Vec::new();
    }

    let first = argvals[1] - argvals[0];
    let last = argvals[m - 1] - argvals[m - 2];

    let mut steps = Vec::with_capacity(m);
    steps.push(first);
    steps.extend(argvals.windows(3).map(|w| w[2] - w[0]));
    steps.push(last);
    steps
}
335
/// Stack the first `n` per-curve rows (each at least `ncol` long) into an
/// n x ncol column-major matrix: entry `(i, j)` ends up at index `i + j * n`.
fn reassemble_colmajor(rows: &[Vec<f64>], n: usize, ncol: usize) -> Vec<f64> {
    let rows = &rows[..n];
    // Walk columns in order and, within each, every row in order.
    (0..ncol)
        .flat_map(|col| rows.iter().map(move |row| row[col]))
        .collect()
}
346
347/// Compute 2D partial derivatives for surface data.
348///
349/// For a surface f(s,t), computes:
350/// - ds: partial derivative with respect to s (∂f/∂s)
351/// - dt: partial derivative with respect to t (∂f/∂t)
352/// - dsdt: mixed partial derivative (∂²f/∂s∂t)
353///
354/// # Arguments
355/// * `data` - Column-major matrix, n surfaces, each stored as m1*m2 values
356/// * `n` - Number of surfaces
357/// * `argvals_s` - Grid points in s direction (length m1)
358/// * `argvals_t` - Grid points in t direction (length m2)
359/// * `m1` - Grid size in s direction
360/// * `m2` - Grid size in t direction
361pub fn deriv_2d(
362    data: &[f64],
363    n: usize,
364    argvals_s: &[f64],
365    argvals_t: &[f64],
366    m1: usize,
367    m2: usize,
368) -> Option<Deriv2DResult> {
369    let ncol = m1 * m2;
370    if n == 0 || ncol == 0 || argvals_s.len() != m1 || argvals_t.len() != m2 {
371        return None;
372    }
373
374    let hs = compute_step_sizes(argvals_s);
375    let ht = compute_step_sizes(argvals_t);
376
377    // Compute all derivatives in parallel over surfaces
378    let results: Vec<(Vec<f64>, Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
379        .map(|i| {
380            let mut ds = vec![0.0; ncol];
381            let mut dt = vec![0.0; ncol];
382            let mut dsdt = vec![0.0; ncol];
383
384            let get_val = |si: usize, ti: usize| -> f64 { data[i + (si + ti * m1) * n] };
385
386            for ti in 0..m2 {
387                for si in 0..m1 {
388                    let idx = si + ti * m1;
389                    let (ds_val, dt_val, dsdt_val) =
390                        compute_2d_derivatives(get_val, si, ti, m1, m2, &hs, &ht);
391                    ds[idx] = ds_val;
392                    dt[idx] = dt_val;
393                    dsdt[idx] = dsdt_val;
394                }
395            }
396
397            (ds, dt, dsdt)
398        })
399        .collect();
400
401    let ds_vecs: Vec<Vec<f64>> = results.iter().map(|r| r.0.clone()).collect();
402    let dt_vecs: Vec<Vec<f64>> = results.iter().map(|r| r.1.clone()).collect();
403    let dsdt_vecs: Vec<Vec<f64>> = results.iter().map(|r| r.2.clone()).collect();
404
405    Some(Deriv2DResult {
406        ds: reassemble_colmajor(&ds_vecs, n, ncol),
407        dt: reassemble_colmajor(&dt_vecs, n, ncol),
408        dsdt: reassemble_colmajor(&dsdt_vecs, n, ncol),
409    })
410}
411
412/// Compute the geometric median (L1 median) of functional data using Weiszfeld's algorithm.
413///
414/// The geometric median minimizes sum of L2 distances to all curves.
415///
416/// # Arguments
417/// * `data` - Column-major matrix (n x m)
418/// * `n` - Number of samples
419/// * `m` - Number of evaluation points
420/// * `argvals` - Evaluation points for integration
421/// * `max_iter` - Maximum iterations
422/// * `tol` - Convergence tolerance
423pub fn geometric_median_1d(
424    data: &[f64],
425    n: usize,
426    m: usize,
427    argvals: &[f64],
428    max_iter: usize,
429    tol: f64,
430) -> Vec<f64> {
431    if n == 0 || m == 0 || argvals.len() != m || data.len() != n * m {
432        return Vec::new();
433    }
434
435    let weights = simpsons_weights(argvals);
436    weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
437}
438
439/// Compute the geometric median for 2D functional data.
440///
441/// Data is stored as n x (m1*m2) matrix where each row is a flattened surface.
442///
443/// # Arguments
444/// * `data` - Column-major matrix (n x m) where m = m1*m2
445/// * `n` - Number of samples
446/// * `m` - Number of grid points (m1 * m2)
447/// * `argvals_s` - Grid points in s direction (length m1)
448/// * `argvals_t` - Grid points in t direction (length m2)
449/// * `max_iter` - Maximum iterations
450/// * `tol` - Convergence tolerance
451pub fn geometric_median_2d(
452    data: &[f64],
453    n: usize,
454    m: usize,
455    argvals_s: &[f64],
456    argvals_t: &[f64],
457    max_iter: usize,
458    tol: f64,
459) -> Vec<f64> {
460    let expected_cols = argvals_s.len() * argvals_t.len();
461    if n == 0 || m == 0 || m != expected_cols || data.len() != n * m {
462        return Vec::new();
463    }
464
465    let weights = simpsons_weights_2d(argvals_s, argvals_t);
466    weiszfeld_iteration(data, n, m, &weights, max_iter, tol)
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `n` equally spaced points on [0, 1], both endpoints included.
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Evaluate `f` at every grid point, producing a single curve.
    fn sample(argvals: &[f64], f: impl Fn(f64) -> f64) -> Vec<f64> {
        argvals.iter().map(|&x| f(x)).collect()
    }

    /// Stack `n` copies of one curve into an n x m column-major matrix.
    fn replicate_curve(curve: &[f64], n: usize) -> Vec<f64> {
        curve
            .iter()
            .flat_map(|&v| std::iter::repeat(v).take(n))
            .collect()
    }

    /// `n` copies of the surface `f(s, t)` as an n x (m1*m2) column-major
    /// matrix, with grid point (si, ti) stored in column `si + ti * m1`.
    fn replicate_surface(
        n: usize,
        argvals_s: &[f64],
        argvals_t: &[f64],
        f: impl Fn(f64, f64) -> f64,
    ) -> Vec<f64> {
        let mut flat = Vec::with_capacity(argvals_s.len() * argvals_t.len());
        for &t in argvals_t {
            for &s in argvals_s {
                flat.push(f(s, t));
            }
        }
        replicate_curve(&flat, n)
    }

    /// Grid indices (si, ti) at least `margin` points away from every edge of
    /// an m1 x m2 grid, with si as the outer index.
    fn inner_points(m1: usize, m2: usize, margin: usize) -> impl Iterator<Item = (usize, usize)> {
        (margin..m1 - margin).flat_map(move |si| (margin..m2 - margin).map(move |ti| (si, ti)))
    }

    // ============== Mean tests ==============

    #[test]
    fn test_mean_1d() {
        // Two samples on three points, column-major:
        //   sample 1: [1, 2, 3], sample 2: [3, 4, 5]  =>  mean [2, 3, 4]
        let data = [1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        assert_eq!(mean_1d(&data, 2, 3), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_mean_1d_single_sample() {
        let data = [1.0, 2.0, 3.0];
        assert_eq!(mean_1d(&data, 1, 3), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mean_1d_invalid() {
        assert!(mean_1d(&[], 0, 0).is_empty());
        assert!(mean_1d(&[1.0], 1, 2).is_empty()); // wrong data length
    }

    #[test]
    fn test_mean_2d_delegates() {
        let data = [1.0, 3.0, 2.0, 4.0];
        assert_eq!(mean_1d(&data, 2, 2), mean_2d(&data, 2, 2));
    }

    // ============== Center tests ==============

    #[test]
    fn test_center_1d() {
        let data = [1.0, 3.0, 2.0, 4.0, 3.0, 5.0]; // column-major
        // Mean is [2, 3, 4], so every entry sits exactly 1 below or above it.
        assert_eq!(
            center_1d(&data, 2, 3),
            vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0]
        );
    }

    #[test]
    fn test_center_1d_mean_zero() {
        let data = [1.0, 3.0, 2.0, 4.0, 3.0, 5.0];
        let centered = center_1d(&data, 2, 3);
        for m in mean_1d(&centered, 2, 3) {
            assert!(m.abs() < 1e-10, "Centered data should have zero mean");
        }
    }

    #[test]
    fn test_center_1d_invalid() {
        assert!(center_1d(&[], 0, 0).is_empty());
        assert!(center_1d(&[1.0], 1, 2).is_empty()); // wrong data length
    }

    // ============== Norm tests ==============

    #[test]
    fn test_norm_lp_1d_constant() {
        // Constant function 2 on [0, 1] has L2 norm = 2
        let argvals = uniform_grid(21);
        let data = [2.0; 21];
        let norms = norm_lp_1d(&data, 1, 21, &argvals, 2.0);
        assert_eq!(norms.len(), 1);
        assert!(
            (norms[0] - 2.0).abs() < 0.1,
            "L2 norm of constant 2 should be 2"
        );
    }

    #[test]
    fn test_norm_lp_1d_sine() {
        // L2 norm of sin(pi*x) on [0, 1] = sqrt(0.5)
        let argvals = uniform_grid(101);
        let data = sample(&argvals, |x| (PI * x).sin());
        let norms = norm_lp_1d(&data, 1, 101, &argvals, 2.0);
        let expected = 0.5_f64.sqrt();
        assert!(
            (norms[0] - expected).abs() < 0.05,
            "Expected {}, got {}",
            expected,
            norms[0]
        );
    }

    #[test]
    fn test_norm_lp_1d_invalid() {
        assert!(norm_lp_1d(&[], 0, 0, &[], 2.0).is_empty());
    }

    // ============== Derivative tests ==============

    #[test]
    fn test_deriv_1d_linear() {
        // Derivative of linear function x should be 1
        let argvals = uniform_grid(21);
        let deriv = deriv_1d(&argvals, 1, 21, &argvals, 1);
        // Interior points should have derivative close to 1
        for d in &deriv[2..19] {
            assert!((d - 1.0).abs() < 0.1, "Derivative of x should be 1");
        }
    }

    #[test]
    fn test_deriv_1d_quadratic() {
        // Derivative of x^2 should be 2x
        let argvals = uniform_grid(51);
        let data = sample(&argvals, |x| x * x);
        let deriv = deriv_1d(&data, 1, 51, &argvals, 1);
        // Check interior points
        for (d, x) in deriv[5..45].iter().zip(&argvals[5..45]) {
            let expected = 2.0 * x;
            assert!(
                (d - expected).abs() < 0.1,
                "Derivative of x^2 should be 2x"
            );
        }
    }

    #[test]
    fn test_deriv_1d_second_order() {
        // Second derivative of x^2 is 2. The points next to each boundary are
        // skipped because the one-sided end differences leak into them.
        let argvals = uniform_grid(51);
        let data = sample(&argvals, |x| x * x);
        let second = deriv_1d(&data, 1, 51, &argvals, 2);
        assert_eq!(second.len(), 51);
        for d in &second[2..49] {
            assert!(
                (d - 2.0).abs() < 1e-6,
                "Second derivative of x^2 should be 2, got {}",
                d
            );
        }
    }

    #[test]
    fn test_deriv_1d_invalid() {
        let result = deriv_1d(&[], 0, 0, &[], 1);
        assert!(result.is_empty() || result.iter().all(|&x| x == 0.0));
    }

    // ============== Geometric median tests ==============

    #[test]
    fn test_geometric_median_identical_curves() {
        // All curves identical -> median = that curve
        let argvals = uniform_grid(21);
        let (n, m) = (5, 21);
        let curve = sample(&argvals, |x| (2.0 * PI * x).sin());
        let data = replicate_curve(&curve, n);
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        for (got, expected) in median.iter().zip(&curve) {
            assert!(
                (got - expected).abs() < 0.01,
                "Median should equal all curves"
            );
        }
    }

    #[test]
    fn test_geometric_median_converges() {
        let argvals = uniform_grid(21);
        let (n, m) = (10, 21);
        // Curve i is the line (i / n) * x.
        let mut data = vec![0.0; n * m];
        for (j, &x) in argvals.iter().enumerate() {
            for i in 0..n {
                data[i + j * n] = (i as f64 / n as f64) * x;
            }
        }
        let median = geometric_median_1d(&data, n, m, &argvals, 100, 1e-6);
        assert_eq!(median.len(), m);
        assert!(median.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_geometric_median_invalid() {
        assert!(geometric_median_1d(&[], 0, 0, &[], 100, 1e-6).is_empty());
    }

    // ============== 2D derivative tests ==============

    #[test]
    fn test_deriv_2d_linear_surface() {
        // f(s, t) = 2*s + 3*t
        // ∂f/∂s = 2, ∂f/∂t = 3, ∂²f/∂s∂t = 0
        let (m1, m2) = (11, 11);
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let n = 1; // single surface
        let data = replicate_surface(n, &argvals_s, &argvals_t, |s, t| 2.0 * s + 3.0 * t);

        let result = deriv_2d(&data, n, &argvals_s, &argvals_t, m1, m2).unwrap();

        // Check interior points for ∂f/∂s ≈ 2
        for (si, ti) in inner_points(m1, m2, 2) {
            let idx = si + ti * m1;
            assert!(
                (result.ds[idx] - 2.0).abs() < 0.2,
                "∂f/∂s at ({}, {}) = {}, expected 2",
                si,
                ti,
                result.ds[idx]
            );
        }

        // Check interior points for ∂f/∂t ≈ 3
        for (si, ti) in inner_points(m1, m2, 2) {
            let idx = si + ti * m1;
            assert!(
                (result.dt[idx] - 3.0).abs() < 0.2,
                "∂f/∂t at ({}, {}) = {}, expected 3",
                si,
                ti,
                result.dt[idx]
            );
        }

        // Check interior points for mixed partial ≈ 0
        for (si, ti) in inner_points(m1, m2, 2) {
            let idx = si + ti * m1;
            assert!(
                result.dsdt[idx].abs() < 0.5,
                "∂²f/∂s∂t at ({}, {}) = {}, expected 0",
                si,
                ti,
                result.dsdt[idx]
            );
        }
    }

    #[test]
    fn test_deriv_2d_quadratic_surface() {
        // f(s, t) = s*t
        // ∂f/∂s = t, ∂f/∂t = s, ∂²f/∂s∂t = 1
        let (m1, m2) = (21, 21);
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let n = 1;
        let data = replicate_surface(n, &argvals_s, &argvals_t, |s, t| s * t);

        let result = deriv_2d(&data, n, &argvals_s, &argvals_t, m1, m2).unwrap();

        // Check interior points for ∂f/∂s ≈ t
        for (si, ti) in inner_points(m1, m2, 3) {
            let idx = si + ti * m1;
            let expected = argvals_t[ti];
            assert!(
                (result.ds[idx] - expected).abs() < 0.1,
                "∂f/∂s at ({}, {}) = {}, expected {}",
                si,
                ti,
                result.ds[idx],
                expected
            );
        }

        // Check interior points for ∂f/∂t ≈ s
        for (si, ti) in inner_points(m1, m2, 3) {
            let idx = si + ti * m1;
            let expected = argvals_s[si];
            assert!(
                (result.dt[idx] - expected).abs() < 0.1,
                "∂f/∂t at ({}, {}) = {}, expected {}",
                si,
                ti,
                result.dt[idx],
                expected
            );
        }

        // Check interior points for mixed partial ≈ 1
        for (si, ti) in inner_points(m1, m2, 3) {
            let idx = si + ti * m1;
            assert!(
                (result.dsdt[idx] - 1.0).abs() < 0.3,
                "∂²f/∂s∂t at ({}, {}) = {}, expected 1",
                si,
                ti,
                result.dsdt[idx]
            );
        }
    }

    #[test]
    fn test_deriv_2d_output_layout() {
        // Two identical surfaces: every output is n x (m1*m2) column-major and
        // the two rows of each column agree exactly.
        let (m1, m2, n) = (5, 4, 2);
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let data = replicate_surface(n, &argvals_s, &argvals_t, |s, t| s * t);
        let result = deriv_2d(&data, n, &argvals_s, &argvals_t, m1, m2).unwrap();
        for out in [&result.ds, &result.dt, &result.dsdt] {
            assert_eq!(out.len(), n * m1 * m2);
            for col in out.chunks_exact(n) {
                assert_eq!(col[0], col[1]);
            }
        }
    }

    #[test]
    fn test_deriv_2d_invalid_input() {
        // Empty data
        let result = deriv_2d(&[], 0, &[], &[], 0, 0);
        assert!(result.is_none());

        // Mismatched dimensions
        let data = [1.0; 4];
        let argvals = [0.0, 1.0];
        let result = deriv_2d(&data, 1, &argvals, &[0.0, 0.5, 1.0], 2, 2);
        assert!(result.is_none());
    }

    // ============== 2D geometric median tests ==============

    #[test]
    fn test_geometric_median_2d_basic() {
        // Three identical surfaces f(s, t) = s + t -> median = that surface
        let (m1, m2) = (5, 5);
        let m = m1 * m2;
        let n = 3;
        let argvals_s = uniform_grid(m1);
        let argvals_t = uniform_grid(m2);
        let data = replicate_surface(n, &argvals_s, &argvals_t, |s, t| s + t);

        let median = geometric_median_2d(&data, n, m, &argvals_s, &argvals_t, 100, 1e-6);
        assert_eq!(median.len(), m);

        // Check that median equals the surface
        for (si, ti) in inner_points(m1, m2, 0) {
            let idx = si + ti * m1;
            let expected = argvals_s[si] + argvals_t[ti];
            assert!(
                (median[idx] - expected).abs() < 0.01,
                "Median at ({}, {}) = {}, expected {}",
                si,
                ti,
                median[idx],
                expected
            );
        }
    }
}