// fdars_core/alignment/generative.rs
1//! Gaussian generative model for random curve synthesis from aligned data.
2
3use super::srsf::{reparameterize_curve, srsf_inverse};
4use super::KarcherMeanResult;
5use crate::elastic_fpca::{horiz_fpca, sphere_karcher_mean, vert_fpca, warps_to_normalized_psi};
6use crate::error::FdarError;
7use crate::matrix::FdMatrix;
8use crate::warping::{exp_map_sphere, normalize_warp, psi_to_gam};
9
10use rand::prelude::*;
11use rand_distr::StandardNormal;
12
13// ─── Types ──────────────────────────────────────────────────────────────────
14
/// Result of Gaussian generative model sampling.
///
/// Produced by `gauss_model` and `joint_gauss_model`; all three matrices
/// share the same number of rows (`n_samples`).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct GenerativeModelResult {
    /// Generated function samples (n_samples x m).
    pub samples: FdMatrix,
    /// Generated warping functions (n_samples x m), each on the original domain.
    pub warps: FdMatrix,
    /// FPCA scores used for generation (n_samples x ncomp),
    /// amplitude scores first, then phase scores.
    pub scores: FdMatrix,
}
26
27// ─── Gaussian Generative Model ──────────────────────────────────────────────
28
29/// Generate random curves from a fitted Gaussian model on aligned data.
30///
31/// Samples amplitude and phase components independently from their
32/// respective FPCA score distributions (Gaussian with covariance = diag(eigenvalues)),
33/// then combines them to produce synthetic functional data.
34///
35/// # Arguments
36/// * `karcher` — Pre-computed Karcher mean result (with aligned data and gammas)
37/// * `argvals` — Evaluation points (length m)
38/// * `ncomp` — Number of principal components for both amplitude and phase
39/// * `n_samples` — Number of curves to generate
40/// * `seed` — RNG seed for reproducibility
41///
42/// # Errors
43/// Returns `FdarError::InvalidDimension` if dimensions are inconsistent or
44/// `FdarError::ComputationFailed` if FPCA fails.
45#[must_use = "expensive computation whose result should not be discarded"]
46pub fn gauss_model(
47    karcher: &KarcherMeanResult,
48    argvals: &[f64],
49    ncomp: usize,
50    n_samples: usize,
51    seed: u64,
52) -> Result<GenerativeModelResult, FdarError> {
53    let (n, m) = karcher.aligned_data.shape();
54    if argvals.len() != m {
55        return Err(FdarError::InvalidDimension {
56            parameter: "argvals",
57            expected: format!("length {m}"),
58            actual: format!("length {}", argvals.len()),
59        });
60    }
61    if n < 2 || m < 2 {
62        return Err(FdarError::InvalidDimension {
63            parameter: "aligned_data",
64            expected: "n >= 2, m >= 2".to_string(),
65            actual: format!("n={n}, m={m}"),
66        });
67    }
68    if ncomp < 1 {
69        return Err(FdarError::InvalidParameter {
70            parameter: "ncomp",
71            message: "ncomp must be >= 1".to_string(),
72        });
73    }
74    if n_samples < 1 {
75        return Err(FdarError::InvalidParameter {
76            parameter: "n_samples",
77            message: "n_samples must be >= 1".to_string(),
78        });
79    }
80
81    // Amplitude FPCA
82    let vert = vert_fpca(karcher, argvals, ncomp)?;
83    let vert_ncomp = vert.eigenvalues.len();
84    let m_aug = m + 1;
85
86    // Phase FPCA
87    let horiz = horiz_fpca(karcher, argvals, ncomp)?;
88    let horiz_ncomp = horiz.eigenvalues.len();
89
90    let t0 = argvals[0];
91    let t1 = argvals[m - 1];
92    let domain = t1 - t0;
93    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
94
95    // Get the mean psi on the sphere for phase generation
96    let psis = warps_to_normalized_psi(&karcher.gammas, argvals);
97    let mu_psi = sphere_karcher_mean(&psis, &time, 50);
98
99    // Mean SRSF (augmented)
100    let mean_q = &vert.mean_q;
101
102    let total_ncomp = vert_ncomp + horiz_ncomp;
103    let mut samples = FdMatrix::zeros(n_samples, m);
104    let mut warps = FdMatrix::zeros(n_samples, m);
105    let mut scores = FdMatrix::zeros(n_samples, total_ncomp);
106
107    for i in 0..n_samples {
108        let mut rng = StdRng::seed_from_u64(seed + i as u64);
109
110        // Generate amplitude scores and reconstruct SRSF
111        let mut q_new = vec![0.0; m_aug];
112        q_new[..m_aug].copy_from_slice(&mean_q[..m_aug]);
113        for k in 0..vert_ncomp {
114            let std_dev = vert.eigenvalues[k].max(0.0).sqrt();
115            let z: f64 = rng.sample(StandardNormal);
116            let score_k = z * std_dev;
117            scores[(i, k)] = score_k;
118            for j in 0..m_aug {
119                q_new[j] += score_k * vert.eigenfunctions_q[(k, j)];
120            }
121        }
122
123        // Reconstruct curve from SRSF
124        let aug_val = q_new[m];
125        let f0 = aug_val.signum() * aug_val * aug_val;
126        let f_new = srsf_inverse(&q_new[..m], argvals, f0);
127
128        // Generate phase scores and reconstruct warping function
129        let mut v = vec![0.0; m];
130        for k in 0..horiz_ncomp {
131            let std_dev = horiz.eigenvalues[k].max(0.0).sqrt();
132            let z: f64 = rng.sample(StandardNormal);
133            let score_k = z * std_dev;
134            scores[(i, vert_ncomp + k)] = score_k;
135            for j in 0..m {
136                v[j] += score_k * horiz.eigenfunctions_psi[(k, j)];
137            }
138        }
139
140        // Map shooting vector to sphere via exp map at mean psi
141        let psi_new = exp_map_sphere(&mu_psi, &v, &time);
142        let gam_01 = psi_to_gam(&psi_new, &time);
143
144        // Rescale gamma to original domain
145        let mut gamma: Vec<f64> = gam_01.iter().map(|&g| t0 + g * domain).collect();
146        normalize_warp(&mut gamma, argvals);
147
148        // Apply warp to generate final sample
149        let sample = reparameterize_curve(&f_new, argvals, &gamma);
150
151        for j in 0..m {
152            samples[(i, j)] = sample[j];
153            warps[(i, j)] = gamma[j];
154        }
155    }
156
157    Ok(GenerativeModelResult {
158        samples,
159        warps,
160        scores,
161    })
162}
163
164/// Generate random curves from a joint Gaussian model preserving amplitude-phase
165/// correlation.
166///
167/// Computes amplitude and phase FPCA separately, concatenates their scores to
168/// form a joint score vector, estimates the joint covariance, and samples from
169/// the joint distribution. This preserves cross-correlation between amplitude
170/// and phase variability.
171///
172/// # Arguments
173/// * `karcher` — Pre-computed Karcher mean result
174/// * `argvals` — Evaluation points (length m)
175/// * `ncomp` — Number of principal components per domain (amplitude and phase)
176/// * `n_samples` — Number of curves to generate
177/// * `balance_c` — Weight for balancing phase vs amplitude variance
178/// * `seed` — RNG seed for reproducibility
179///
180/// # Errors
181/// Returns `FdarError` on dimension mismatch or FPCA failure.
182#[must_use = "expensive computation whose result should not be discarded"]
183pub fn joint_gauss_model(
184    karcher: &KarcherMeanResult,
185    argvals: &[f64],
186    ncomp: usize,
187    n_samples: usize,
188    balance_c: f64,
189    seed: u64,
190) -> Result<GenerativeModelResult, FdarError> {
191    let (_n, m) = karcher.aligned_data.shape();
192    if argvals.len() != m {
193        return Err(FdarError::InvalidDimension {
194            parameter: "argvals",
195            expected: format!("length {m}"),
196            actual: format!("length {}", argvals.len()),
197        });
198    }
199    if ncomp < 1 {
200        return Err(FdarError::InvalidParameter {
201            parameter: "ncomp",
202            message: "ncomp must be >= 1".to_string(),
203        });
204    }
205    if n_samples < 1 {
206        return Err(FdarError::InvalidParameter {
207            parameter: "n_samples",
208            message: "n_samples must be >= 1".to_string(),
209        });
210    }
211
212    // Amplitude FPCA
213    let vert = vert_fpca(karcher, argvals, ncomp)?;
214    let vert_ncomp = vert.eigenvalues.len();
215    let m_aug = m + 1;
216
217    // Phase FPCA
218    let horiz = horiz_fpca(karcher, argvals, ncomp)?;
219    let horiz_ncomp = horiz.eigenvalues.len();
220
221    let total_ncomp = vert_ncomp + horiz_ncomp;
222    let n = karcher.aligned_data.nrows();
223
224    // Build joint score matrix: [vert_scores | balance_c * horiz_scores]
225    let mut joint_scores = FdMatrix::zeros(n, total_ncomp);
226    for i in 0..n {
227        for k in 0..vert_ncomp {
228            joint_scores[(i, k)] = vert.scores[(i, k)];
229        }
230        for k in 0..horiz_ncomp {
231            joint_scores[(i, vert_ncomp + k)] = balance_c * horiz.scores[(i, k)];
232        }
233    }
234
235    // Estimate joint covariance (diagonal for sampling)
236    let mut joint_mean = vec![0.0; total_ncomp];
237    for k in 0..total_ncomp {
238        for i in 0..n {
239            joint_mean[k] += joint_scores[(i, k)];
240        }
241        joint_mean[k] /= n as f64;
242    }
243
244    let mut joint_var = vec![0.0; total_ncomp];
245    for k in 0..total_ncomp {
246        for i in 0..n {
247            let diff = joint_scores[(i, k)] - joint_mean[k];
248            joint_var[k] += diff * diff;
249        }
250        joint_var[k] /= (n - 1).max(1) as f64;
251    }
252
253    // Sphere/warping setup
254    let t0 = argvals[0];
255    let t1 = argvals[m - 1];
256    let domain = t1 - t0;
257    let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
258
259    let psis = warps_to_normalized_psi(&karcher.gammas, argvals);
260    let mu_psi = sphere_karcher_mean(&psis, &time, 50);
261    let mean_q = &vert.mean_q;
262
263    let mut samples = FdMatrix::zeros(n_samples, m);
264    let mut warps_out = FdMatrix::zeros(n_samples, m);
265    let mut scores_out = FdMatrix::zeros(n_samples, total_ncomp);
266
267    for i in 0..n_samples {
268        let mut rng = StdRng::seed_from_u64(seed + i as u64);
269
270        // Sample from joint distribution
271        let mut joint_z = vec![0.0; total_ncomp];
272        for k in 0..total_ncomp {
273            let z: f64 = rng.sample(StandardNormal);
274            joint_z[k] = joint_mean[k] + z * joint_var[k].max(0.0).sqrt();
275            scores_out[(i, k)] = joint_z[k];
276        }
277
278        // Reconstruct amplitude from SRSF
279        let mut q_new = vec![0.0; m_aug];
280        q_new[..m_aug].copy_from_slice(&mean_q[..m_aug]);
281        for k in 0..vert_ncomp {
282            let score_k = joint_z[k];
283            for j in 0..m_aug {
284                q_new[j] += score_k * vert.eigenfunctions_q[(k, j)];
285            }
286        }
287        let aug_val = q_new[m];
288        let f0 = aug_val.signum() * aug_val * aug_val;
289        let f_new = srsf_inverse(&q_new[..m], argvals, f0);
290
291        // Reconstruct phase from shooting vector
292        let mut v = vec![0.0; m];
293        for k in 0..horiz_ncomp {
294            // Undo balance_c scaling
295            let score_k = if balance_c.abs() > 1e-15 {
296                joint_z[vert_ncomp + k] / balance_c
297            } else {
298                0.0
299            };
300            for j in 0..m {
301                v[j] += score_k * horiz.eigenfunctions_psi[(k, j)];
302            }
303        }
304
305        let psi_new = exp_map_sphere(&mu_psi, &v, &time);
306        let gam_01 = psi_to_gam(&psi_new, &time);
307        let mut gamma: Vec<f64> = gam_01.iter().map(|&g| t0 + g * domain).collect();
308        normalize_warp(&mut gamma, argvals);
309
310        let sample = reparameterize_curve(&f_new, argvals, &gamma);
311        for j in 0..m {
312            samples[(i, j)] = sample[j];
313            warps_out[(i, j)] = gamma[j];
314        }
315    }
316
317    Ok(GenerativeModelResult {
318        samples,
319        warps: warps_out,
320        scores: scores_out,
321    })
322}
323
// ─── Tests ──────────────────────────────────────────────────────────────────
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::karcher_mean;
    use std::f64::consts::PI;

    /// Synthetic fixture: `n` shifted/scaled sinusoids sampled on a uniform
    /// grid of `m` points over [0, 1], with their Karcher mean precomputed.
    fn make_test_karcher(n: usize, m: usize) -> (KarcherMeanResult, Vec<f64>) {
        let grid: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut curves = FdMatrix::zeros(n, m);
        for i in 0..n {
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            let amp = 1.0 + 0.2 * (i as f64 / n as f64);
            for (j, &tj) in grid.iter().enumerate() {
                curves[(i, j)] = amp * (2.0 * PI * (tj + shift)).sin();
            }
        }
        let km = karcher_mean(&curves, &grid, 10, 1e-4, 0.0);
        (km, grid)
    }

    #[test]
    fn gauss_model_correct_shapes() {
        let (km, t) = make_test_karcher(15, 51);
        let ncomp = 3;
        let n_samples = 10;
        let out = gauss_model(&km, &t, ncomp, n_samples, 42).unwrap();

        assert_eq!(out.samples.shape(), (n_samples, 51));
        assert_eq!(out.warps.shape(), (n_samples, 51));
        // scores is n_samples x (vert_ncomp + horiz_ncomp), so at least
        // ncomp columns wide.
        let score_cols = out.scores.shape().1;
        assert!(
            score_cols >= ncomp,
            "scores should have at least ncomp columns, got {score_cols}"
        );
        assert_eq!(out.scores.nrows(), n_samples);
    }

    #[test]
    fn gauss_model_reproducible() {
        // Identical seeds must produce byte-identical results.
        let (km, t) = make_test_karcher(15, 51);
        let first = gauss_model(&km, &t, 3, 5, 42).unwrap();
        let second = gauss_model(&km, &t, 3, 5, 42).unwrap();

        assert_eq!(first.samples, second.samples);
        assert_eq!(first.warps, second.warps);
        assert_eq!(first.scores, second.scores);
    }

    #[test]
    fn gauss_model_warps_valid() {
        let (km, t) = make_test_karcher(15, 51);
        let out = gauss_model(&km, &t, 3, 10, 99).unwrap();
        let m = t.len();

        for i in 0..out.warps.nrows() {
            let warp = out.warps.row(i);

            // Each generated warp must be monotone non-decreasing …
            for j in 1..m {
                assert!(
                    warp[j] >= warp[j - 1] - 1e-12,
                    "warp {i} not monotone at j={j}: {} < {}",
                    warp[j],
                    warp[j - 1]
                );
            }

            // … and must pin the domain endpoints.
            assert!(
                (warp[0] - t[0]).abs() < 1e-10,
                "warp {i} start: {} != {}",
                warp[0],
                t[0]
            );
            assert!(
                (warp[m - 1] - t[m - 1]).abs() < 1e-10,
                "warp {i} end: {} != {}",
                warp[m - 1],
                t[m - 1]
            );
        }
    }

    #[test]
    fn joint_gauss_model_smoke() {
        let (km, t) = make_test_karcher(15, 51);
        let n_samples = 8;
        let out = joint_gauss_model(&km, &t, 3, n_samples, 1.0, 42).unwrap();

        assert_eq!(out.samples.shape(), (n_samples, 51));
        assert_eq!(out.warps.shape(), (n_samples, 51));
        assert_eq!(out.scores.nrows(), n_samples);

        // Every generated value must be finite (no NaN/inf leakage).
        for i in 0..n_samples {
            for j in 0..51 {
                assert!(
                    out.samples[(i, j)].is_finite(),
                    "sample ({i},{j}) is not finite"
                );
            }
        }
    }
}
431}