1use super::srsf::{reparameterize_curve, srsf_inverse};
4use super::KarcherMeanResult;
5use crate::elastic_fpca::{horiz_fpca, sphere_karcher_mean, vert_fpca, warps_to_normalized_psi};
6use crate::error::FdarError;
7use crate::matrix::FdMatrix;
8use crate::warping::{exp_map_sphere, normalize_warp, psi_to_gam};
9
10use rand::prelude::*;
11use rand_distr::StandardNormal;
12
13#[derive(Debug, Clone, PartialEq)]
17#[non_exhaustive]
18pub struct GenerativeModelResult {
19 pub samples: FdMatrix,
21 pub warps: FdMatrix,
23 pub scores: FdMatrix,
25}
26
27#[must_use = "expensive computation whose result should not be discarded"]
46pub fn gauss_model(
47 karcher: &KarcherMeanResult,
48 argvals: &[f64],
49 ncomp: usize,
50 n_samples: usize,
51 seed: u64,
52) -> Result<GenerativeModelResult, FdarError> {
53 let (n, m) = karcher.aligned_data.shape();
54 if argvals.len() != m {
55 return Err(FdarError::InvalidDimension {
56 parameter: "argvals",
57 expected: format!("length {m}"),
58 actual: format!("length {}", argvals.len()),
59 });
60 }
61 if n < 2 || m < 2 {
62 return Err(FdarError::InvalidDimension {
63 parameter: "aligned_data",
64 expected: "n >= 2, m >= 2".to_string(),
65 actual: format!("n={n}, m={m}"),
66 });
67 }
68 if ncomp < 1 {
69 return Err(FdarError::InvalidParameter {
70 parameter: "ncomp",
71 message: "ncomp must be >= 1".to_string(),
72 });
73 }
74 if n_samples < 1 {
75 return Err(FdarError::InvalidParameter {
76 parameter: "n_samples",
77 message: "n_samples must be >= 1".to_string(),
78 });
79 }
80
81 let vert = vert_fpca(karcher, argvals, ncomp)?;
83 let vert_ncomp = vert.eigenvalues.len();
84 let m_aug = m + 1;
85
86 let horiz = horiz_fpca(karcher, argvals, ncomp)?;
88 let horiz_ncomp = horiz.eigenvalues.len();
89
90 let t0 = argvals[0];
91 let t1 = argvals[m - 1];
92 let domain = t1 - t0;
93 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
94
95 let psis = warps_to_normalized_psi(&karcher.gammas, argvals);
97 let mu_psi = sphere_karcher_mean(&psis, &time, 50);
98
99 let mean_q = &vert.mean_q;
101
102 let total_ncomp = vert_ncomp + horiz_ncomp;
103 let mut samples = FdMatrix::zeros(n_samples, m);
104 let mut warps = FdMatrix::zeros(n_samples, m);
105 let mut scores = FdMatrix::zeros(n_samples, total_ncomp);
106
107 for i in 0..n_samples {
108 let mut rng = StdRng::seed_from_u64(seed + i as u64);
109
110 let mut q_new = vec![0.0; m_aug];
112 q_new[..m_aug].copy_from_slice(&mean_q[..m_aug]);
113 for k in 0..vert_ncomp {
114 let std_dev = vert.eigenvalues[k].max(0.0).sqrt();
115 let z: f64 = rng.sample(StandardNormal);
116 let score_k = z * std_dev;
117 scores[(i, k)] = score_k;
118 for j in 0..m_aug {
119 q_new[j] += score_k * vert.eigenfunctions_q[(k, j)];
120 }
121 }
122
123 let aug_val = q_new[m];
125 let f0 = aug_val.signum() * aug_val * aug_val;
126 let f_new = srsf_inverse(&q_new[..m], argvals, f0);
127
128 let mut v = vec![0.0; m];
130 for k in 0..horiz_ncomp {
131 let std_dev = horiz.eigenvalues[k].max(0.0).sqrt();
132 let z: f64 = rng.sample(StandardNormal);
133 let score_k = z * std_dev;
134 scores[(i, vert_ncomp + k)] = score_k;
135 for j in 0..m {
136 v[j] += score_k * horiz.eigenfunctions_psi[(k, j)];
137 }
138 }
139
140 let psi_new = exp_map_sphere(&mu_psi, &v, &time);
142 let gam_01 = psi_to_gam(&psi_new, &time);
143
144 let mut gamma: Vec<f64> = gam_01.iter().map(|&g| t0 + g * domain).collect();
146 normalize_warp(&mut gamma, argvals);
147
148 let sample = reparameterize_curve(&f_new, argvals, &gamma);
150
151 for j in 0..m {
152 samples[(i, j)] = sample[j];
153 warps[(i, j)] = gamma[j];
154 }
155 }
156
157 Ok(GenerativeModelResult {
158 samples,
159 warps,
160 scores,
161 })
162}
163
164#[must_use = "expensive computation whose result should not be discarded"]
183pub fn joint_gauss_model(
184 karcher: &KarcherMeanResult,
185 argvals: &[f64],
186 ncomp: usize,
187 n_samples: usize,
188 balance_c: f64,
189 seed: u64,
190) -> Result<GenerativeModelResult, FdarError> {
191 let (_n, m) = karcher.aligned_data.shape();
192 if argvals.len() != m {
193 return Err(FdarError::InvalidDimension {
194 parameter: "argvals",
195 expected: format!("length {m}"),
196 actual: format!("length {}", argvals.len()),
197 });
198 }
199 if ncomp < 1 {
200 return Err(FdarError::InvalidParameter {
201 parameter: "ncomp",
202 message: "ncomp must be >= 1".to_string(),
203 });
204 }
205 if n_samples < 1 {
206 return Err(FdarError::InvalidParameter {
207 parameter: "n_samples",
208 message: "n_samples must be >= 1".to_string(),
209 });
210 }
211
212 let vert = vert_fpca(karcher, argvals, ncomp)?;
214 let vert_ncomp = vert.eigenvalues.len();
215 let m_aug = m + 1;
216
217 let horiz = horiz_fpca(karcher, argvals, ncomp)?;
219 let horiz_ncomp = horiz.eigenvalues.len();
220
221 let total_ncomp = vert_ncomp + horiz_ncomp;
222 let n = karcher.aligned_data.nrows();
223
224 let mut joint_scores = FdMatrix::zeros(n, total_ncomp);
226 for i in 0..n {
227 for k in 0..vert_ncomp {
228 joint_scores[(i, k)] = vert.scores[(i, k)];
229 }
230 for k in 0..horiz_ncomp {
231 joint_scores[(i, vert_ncomp + k)] = balance_c * horiz.scores[(i, k)];
232 }
233 }
234
235 let mut joint_mean = vec![0.0; total_ncomp];
237 for k in 0..total_ncomp {
238 for i in 0..n {
239 joint_mean[k] += joint_scores[(i, k)];
240 }
241 joint_mean[k] /= n as f64;
242 }
243
244 let mut joint_var = vec![0.0; total_ncomp];
245 for k in 0..total_ncomp {
246 for i in 0..n {
247 let diff = joint_scores[(i, k)] - joint_mean[k];
248 joint_var[k] += diff * diff;
249 }
250 joint_var[k] /= (n - 1).max(1) as f64;
251 }
252
253 let t0 = argvals[0];
255 let t1 = argvals[m - 1];
256 let domain = t1 - t0;
257 let time: Vec<f64> = (0..m).map(|i| i as f64 / (m - 1) as f64).collect();
258
259 let psis = warps_to_normalized_psi(&karcher.gammas, argvals);
260 let mu_psi = sphere_karcher_mean(&psis, &time, 50);
261 let mean_q = &vert.mean_q;
262
263 let mut samples = FdMatrix::zeros(n_samples, m);
264 let mut warps_out = FdMatrix::zeros(n_samples, m);
265 let mut scores_out = FdMatrix::zeros(n_samples, total_ncomp);
266
267 for i in 0..n_samples {
268 let mut rng = StdRng::seed_from_u64(seed + i as u64);
269
270 let mut joint_z = vec![0.0; total_ncomp];
272 for k in 0..total_ncomp {
273 let z: f64 = rng.sample(StandardNormal);
274 joint_z[k] = joint_mean[k] + z * joint_var[k].max(0.0).sqrt();
275 scores_out[(i, k)] = joint_z[k];
276 }
277
278 let mut q_new = vec![0.0; m_aug];
280 q_new[..m_aug].copy_from_slice(&mean_q[..m_aug]);
281 for k in 0..vert_ncomp {
282 let score_k = joint_z[k];
283 for j in 0..m_aug {
284 q_new[j] += score_k * vert.eigenfunctions_q[(k, j)];
285 }
286 }
287 let aug_val = q_new[m];
288 let f0 = aug_val.signum() * aug_val * aug_val;
289 let f_new = srsf_inverse(&q_new[..m], argvals, f0);
290
291 let mut v = vec![0.0; m];
293 for k in 0..horiz_ncomp {
294 let score_k = if balance_c.abs() > 1e-15 {
296 joint_z[vert_ncomp + k] / balance_c
297 } else {
298 0.0
299 };
300 for j in 0..m {
301 v[j] += score_k * horiz.eigenfunctions_psi[(k, j)];
302 }
303 }
304
305 let psi_new = exp_map_sphere(&mu_psi, &v, &time);
306 let gam_01 = psi_to_gam(&psi_new, &time);
307 let mut gamma: Vec<f64> = gam_01.iter().map(|&g| t0 + g * domain).collect();
308 normalize_warp(&mut gamma, argvals);
309
310 let sample = reparameterize_curve(&f_new, argvals, &gamma);
311 for j in 0..m {
312 samples[(i, j)] = sample[j];
313 warps_out[(i, j)] = gamma[j];
314 }
315 }
316
317 Ok(GenerativeModelResult {
318 samples,
319 warps: warps_out,
320 scores: scores_out,
321 })
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::karcher_mean;
    use std::f64::consts::PI;

    /// Grid size shared by every test fixture.
    const GRID: usize = 51;
    /// Number of curves in every test fixture.
    const CURVES: usize = 15;

    /// Builds a Karcher mean from `n` phase-shifted, amplitude-scaled sine
    /// curves sampled at `m` equally spaced points on `[0, 1]`.
    fn make_test_karcher(n: usize, m: usize) -> (KarcherMeanResult, Vec<f64>) {
        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            let shift = 0.1 * (i as f64 - n as f64 / 2.0);
            let scale = 1.0 + 0.2 * (i as f64 / n as f64);
            for (j, &tj) in t.iter().enumerate() {
                data[(i, j)] = scale * (2.0 * PI * (tj + shift)).sin();
            }
        }
        (karcher_mean(&data, &t, 10, 1e-4, 0.0), t)
    }

    /// Output matrices have the requested number of rows and the grid's
    /// number of columns; the score matrix has at least `ncomp` columns.
    #[test]
    fn gauss_model_correct_shapes() {
        let (km, t) = make_test_karcher(CURVES, GRID);
        let ncomp = 3;
        let n_samples = 10;
        let result = gauss_model(&km, &t, ncomp, n_samples, 42).unwrap();

        assert_eq!(result.samples.shape(), (n_samples, GRID));
        assert_eq!(result.warps.shape(), (n_samples, GRID));

        let (_, score_cols) = result.scores.shape();
        assert!(
            score_cols >= ncomp,
            "scores should have at least ncomp columns, got {score_cols}"
        );
        assert_eq!(result.scores.nrows(), n_samples);
    }

    /// Two runs with the same seed produce identical output.
    #[test]
    fn gauss_model_reproducible() {
        let (km, t) = make_test_karcher(CURVES, GRID);
        let first = gauss_model(&km, &t, 3, 5, 42).unwrap();
        let second = gauss_model(&km, &t, 3, 5, 42).unwrap();

        assert_eq!(first.samples, second.samples);
        assert_eq!(first.warps, second.warps);
        assert_eq!(first.scores, second.scores);
    }

    /// Every generated warp is non-decreasing and fixes both grid endpoints.
    #[test]
    fn gauss_model_warps_valid() {
        let (km, t) = make_test_karcher(CURVES, GRID);
        let result = gauss_model(&km, &t, 3, 10, 99).unwrap();
        let m = t.len();
        let last = m - 1;

        for i in 0..result.warps.nrows() {
            let warp = result.warps.row(i);

            // Monotonicity, allowing a tiny numerical slack.
            for j in 1..m {
                assert!(
                    warp[j] >= warp[j - 1] - 1e-12,
                    "warp {i} not monotone at j={j}: {} < {}",
                    warp[j],
                    warp[j - 1]
                );
            }

            // Boundary conditions.
            assert!(
                (warp[0] - t[0]).abs() < 1e-10,
                "warp {i} start: {} != {}",
                warp[0],
                t[0]
            );
            assert!(
                (warp[last] - t[last]).abs() < 1e-10,
                "warp {i} end: {} != {}",
                warp[last],
                t[last]
            );
        }
    }

    /// The joint model runs end to end and yields finite curves of the
    /// expected shape.
    #[test]
    fn joint_gauss_model_smoke() {
        let (km, t) = make_test_karcher(CURVES, GRID);
        let ncomp = 3;
        let n_samples = 8;
        let result = joint_gauss_model(&km, &t, ncomp, n_samples, 1.0, 42).unwrap();

        assert_eq!(result.samples.shape(), (n_samples, GRID));
        assert_eq!(result.warps.shape(), (n_samples, GRID));
        assert_eq!(result.scores.nrows(), n_samples);

        let cells = (0..n_samples).flat_map(|i| (0..GRID).map(move |j| (i, j)));
        for (i, j) in cells {
            assert!(
                result.samples[(i, j)].is_finite(),
                "sample ({i},{j}) is not finite"
            );
        }
    }
}