use crate::sd_fm::SdFmTrainConfig;
use crate::{Error, Result};
use ndarray::{Array1, ArrayView1, ArrayView2};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use skel::Manifold;
/// A trainable, time-dependent vector field over `f32` point coordinates.
///
/// This is the model type trained by [`train_riemannian_fm`].
pub trait ManifoldVectorField {
    /// Evaluates the field at point `x` and time `t`, returning a vector of
    /// field values.
    fn eval(&self, x: &ArrayView1<f32>, t: f32) -> Array1<f32>;
    /// Performs one in-place training update at `(x, t)` with learning rate `lr`.
    ///
    /// [`train_riemannian_fm`] calls this once per training example, with `x`
    /// set to the interpolated point `x_t` and `target` set to the conditional
    /// velocity `u_t` at that point. Implementors are expected to nudge
    /// `eval(x, t)` toward `target`. The loss and the optimizer details are left
    /// to the implementation.
    fn sgd_step(&mut self, x: &ArrayView1<f32>, t: f32, target: &ArrayView1<f32>, lr: f32);
}
pub fn train_riemannian_fm<M, F, S>(
manifold: &M,
field: &mut F,
x1_samples: &ArrayView2<f32>,
cfg: &SdFmTrainConfig,
sample_x0: S,
) -> Result<()>
where
M: Manifold,
F: ManifoldVectorField,
S: Fn(&mut ChaCha8Rng) -> Array1<f32>,
{
let n_data = x1_samples.nrows();
if n_data == 0 {
return Err(Error::Domain("x1_samples must be non-empty"));
}
let mut rng = ChaCha8Rng::seed_from_u64(cfg.seed);
for _step in 0..cfg.steps {
for _ in 0..cfg.batch_size {
let t = cfg.t_schedule.sample_t(&mut rng);
let x0_f32 = sample_x0(&mut rng);
let idx = rng.random_range(0..n_data);
let x1_f32 = x1_samples.row(idx);
let x0 = x0_f32.mapv(|v| v as f64);
let x1 = x1_f32.mapv(|v| v as f64);
let v_init = manifold.log_map(&x0.view(), &x1.view());
let t_v_init = &v_init * (t as f64);
let xt = manifold.exp_map(&x0.view(), &t_v_init.view());
let ut = manifold.parallel_transport(&x0.view(), &xt.view(), &v_init.view());
let xt_f32 = xt.mapv(|v| v as f32);
let ut_f32 = ut.mapv(|v| v as f32);
field.sgd_step(&xt_f32.view(), t, &ut_f32.view(), cfg.lr);
}
}
Ok(())
}