use crate::error::{Result, TimeSeriesError};
use crate::latent_ode::{
    ode_func::{integrate_trajectory, LinearLayer, OdeFunc},
    recognition_rnn::{GruCell, RecognitionRnn},
    types::{LatentOdeConfig, LatentOdeResult},
};
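/// Affine decoder mapping a latent state `z` back to observation space: `x = W z + b`.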
#[derive(Debug, Clone)]
struct Decoder {
    /// Row-major weights of shape `[out_dim, latent_dim]`.
    w: Vec<f64>,
    b: Vec<f64>,
    latent_dim: usize,
out_dim: usize,
}
impl Decoder {
fn new(latent_dim: usize, out_dim: usize) -> Self {
let scale = (1.0_f64 / latent_dim as f64).sqrt();
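        // Deterministic pseudo-random init: k·√3 taken mod 2 and shifted into
        // [-1, 1), then scaled by 1/√latent_dim.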
let w = (0..out_dim * latent_dim)
.map(|k| {
let v = ((k as f64 * 1.73205080) % 2.0) - 1.0;
v * scale
})
.collect();
Self {
w,
b: vec![0.0; out_dim],
latent_dim,
out_dim,
}
}
fn decode(&self, z: &[f64]) -> Vec<f64> {
let mut out = self.b.clone();
for (i, oi) in out.iter_mut().enumerate() {
for (j, &zj) in z.iter().enumerate() {
*oi += self.w[i * self.latent_dim + j] * zj;
}
}
out
}
}
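/// Minimal Adam optimizer: first and second moment estimates with bias
/// correction, applied element-wise in `step`.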
#[derive(Debug, Clone)]
struct AdamState {
m: Vec<f64>,
v: Vec<f64>,
t: u64,
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
}
impl AdamState {
fn new(n: usize, lr: f64) -> Self {
Self {
m: vec![0.0; n],
v: vec![0.0; n],
t: 0,
lr,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
}
}
fn step(&mut self, p: &mut [f64], g: &[f64]) {
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
for i in 0..p.len() {
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g[i];
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g[i] * g[i];
let m_hat = self.m[i] / bc1;
let v_hat = self.v[i] / bc2;
p[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
}
}
}
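/// Deterministic stand-in for Gaussian noise: one LCG step (Knuth's MMIX
/// multiplier/increment) per index, with the high 32 bits mapped to [-0.5, 0.5].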
fn pseudo_eps(dim: usize, seed: u64) -> Vec<f64> {
    let scale = 0.5_f64;
    (0..dim)
        .map(|i| {
            // One LCG step per index; keep the high 32 bits (the best-mixed)
            // and map them onto [-1, 1] before scaling.
            let bits = seed
                .wrapping_add(i as u64)
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407)
                >> 32;
            let v = bits as f64 / u32::MAX as f64 * 2.0 - 1.0;
            v * scale
        })
        .collect()
}
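/// KL divergence from `N(mu, diag(sigma^2))` to the standard normal:
/// `0.5 * Σ (sigma^2 + mu^2 - 1 - 2*log_sigma)`.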
fn kl_normal(mu: &[f64], log_sigma: &[f64]) -> f64 {
mu.iter()
.zip(log_sigma.iter())
.map(|(&m, &ls)| {
let sigma2 = (2.0 * ls).exp();
0.5 * (sigma2 + m * m - 1.0 - 2.0 * ls)
})
.sum::<f64>()
}
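/// Concatenates each layer's weights then biases into one flat parameter vector.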
fn flatten_layers(layers: &[LinearLayer]) -> Vec<f64> {
let mut flat = Vec::new();
for l in layers {
flat.extend_from_slice(&l.weights);
flat.extend_from_slice(&l.biases);
}
flat
}
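/// Inverse of `flatten_layers`: writes a flat parameter vector back into the layers.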
#[allow(dead_code)]
fn unflatten_layers(flat: &[f64], layers: &mut [LinearLayer]) {
let mut idx = 0;
for l in layers.iter_mut() {
let wn = l.weights.len();
let bn = l.biases.len();
l.weights.copy_from_slice(&flat[idx..idx + wn]);
idx += wn;
l.biases.copy_from_slice(&flat[idx..idx + bn]);
idx += bn;
}
}
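/// Placeholder: the GRU's parameters are not exposed for training, so the
/// encoder effectively stays at its initialization.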
#[allow(dead_code)]
fn flatten_gru(_gru: &GruCell) -> Vec<f64> {
    vec![]
}
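/// Latent ODE model for irregularly-sampled time series: a recognition RNN
/// infers a variational posterior over the initial latent state `z0`, a
/// neural ODE evolves the latent state through time, and an affine decoder
/// maps it back to observation space. Training minimizes reconstruction
/// error plus KL to a standard-normal prior, with decoder gradients taken
/// by finite differences.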
#[derive(Debug, Clone)]
pub struct LatentOde {
config: LatentOdeConfig,
input_dim: usize,
encoder: RecognitionRnn,
ode_func: OdeFunc,
decoder: Decoder,
z0_mu: Vec<f64>,
z0_log_sigma: Vec<f64>,
adam_dec: AdamState,
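    // Sized for the ODE function's parameters, but not yet stepped in `fit`.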
adam_ode: AdamState,
last_obs_times: Vec<f64>,
}
impl LatentOde {
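    /// ODE integration resolution: solver steps per unit of time span.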
const STEPS_PER_UNIT: usize = 10;
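    /// Creates an untrained model for observations of dimension `input_dim`.
    /// The recognition RNN takes `input_dim + 1` inputs: the observation
    /// values plus a time channel.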
pub fn new(config: LatentOdeConfig, input_dim: usize) -> Result<Self> {
if input_dim == 0 {
return Err(TimeSeriesError::InvalidInput(
"input_dim must be > 0".to_string(),
));
}
let latent_dim = config.latent_dim;
let hidden_dim = config.hidden_dim;
let n_layers = config.n_layers;
let lr = config.learning_rate;
let encoder = RecognitionRnn::new(input_dim + 1, hidden_dim, latent_dim);
let ode_func = OdeFunc::new(latent_dim, hidden_dim, n_layers);
let decoder = Decoder::new(latent_dim, input_dim);
let n_dec = decoder.w.len() + decoder.b.len();
let n_ode = flatten_layers(ode_func.layers()).len();
Ok(Self {
config,
input_dim,
encoder,
ode_func,
decoder,
z0_mu: vec![0.0; latent_dim],
z0_log_sigma: vec![-1.0; latent_dim],
adam_dec: AdamState::new(n_dec, lr),
adam_ode: AdamState::new(n_ode, lr),
last_obs_times: vec![],
})
}
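    /// Draws `z0` from the stored posterior via the reparameterization trick,
    /// `z = mu + eps * sigma`, with deterministic pseudo-noise keyed by `seed`.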
    fn sample_z0(&self, seed: u64) -> Vec<f64> {
        let eps = pseudo_eps(self.config.latent_dim, seed);
self.z0_mu
.iter()
.zip(self.z0_log_sigma.iter())
.zip(eps.iter())
.map(|((&m, &ls), &e)| m + e * ls.exp())
.collect()
}
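    /// Full pass: encode the observations into a posterior over `z0`, sample
    /// it with the reparameterization trick, integrate the latent ODE across
    /// the observation times, and decode each latent state. Returns
    /// (predictions, latent trajectory, reconstruction loss, KL), where the
    /// reconstruction loss is the squared error summed over dimensions and
    /// averaged over time points.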
fn forward(
&self,
observations: &[(f64, Vec<f64>)],
epoch: u64,
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, f64, f64) {
let (mu, log_sigma) = self.encoder.encode(observations);
let kl = kl_normal(&mu, &log_sigma);
let eps = pseudo_eps(self.config.latent_dim, epoch);
let z0: Vec<f64> = mu
.iter()
.zip(log_sigma.iter())
.zip(eps.iter())
.map(|((&m, &ls), &e)| m + e * ls.exp())
.collect();
let times: Vec<f64> = observations.iter().map(|(t, _)| *t).collect();
let t0 = times.first().copied().unwrap_or(0.0);
let t_range = times.last().copied().unwrap_or(1.0) - t0;
let n_steps = (Self::STEPS_PER_UNIT as f64 * t_range.abs()).ceil() as usize + 1;
        let latent_traj = integrate_trajectory(&self.ode_func, &z0, t0, &times, n_steps);
let predictions: Vec<Vec<f64>> =
latent_traj.iter().map(|z| self.decoder.decode(z)).collect();
let recon_loss = predictions
.iter()
.zip(observations.iter())
.map(|(pred, (_, obs))| {
pred.iter()
.zip(obs.iter())
.map(|(&p, &o)| (p - o).powi(2))
.sum::<f64>()
})
.sum::<f64>()
/ predictions.len().max(1) as f64;
(predictions, latent_traj, recon_loss, kl)
}
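    /// Forward-difference estimate of the loss gradient w.r.t. the decoder's
    /// weights and biases; costs one extra forward pass per parameter.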
    fn decoder_gradient(&self, observations: &[(f64, Vec<f64>)], epoch: u64) -> Vec<f64> {
        let eps_fd = 1e-4;
        let base = self.elbo_loss(observations, epoch);
        let n_w = self.decoder.w.len();
        let n_b = self.decoder.b.len();
        let mut grad = vec![0.0_f64; n_w + n_b];
        for k in 0..n_w {
            let perturbed = elbo_with_dec_perturb(self, observations, epoch, k, eps_fd, true);
            grad[k] = (perturbed - base) / eps_fd;
        }
        for k in 0..n_b {
            let perturbed = elbo_with_dec_perturb(self, observations, epoch, k, eps_fd, false);
            grad[n_w + k] = (perturbed - base) / eps_fd;
        }
        grad
    }
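    /// Training loss: the negative ELBO up to constants, i.e. reconstruction
    /// error plus KL divergence.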
fn elbo_loss(&self, observations: &[(f64, Vec<f64>)], epoch: u64) -> f64 {
let (_, _, recon, kl) = self.forward(observations, epoch);
recon + kl
}
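    /// Fits the model to time-stamped observations for `n_epochs`. Only the
    /// decoder is optimized (finite-difference Adam); the posterior over `z0`
    /// is refreshed from the encoder each epoch, while the encoder and ODE
    /// function weights stay at initialization.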
pub fn fit(
&mut self,
observations: &[(f64, Vec<f64>)],
n_epochs: usize,
) -> Result<LatentOdeResult> {
if observations.is_empty() {
return Err(TimeSeriesError::InvalidInput(
"observations must not be empty".to_string(),
));
}
let obs_dim = observations[0].1.len();
if obs_dim != self.input_dim {
return Err(TimeSeriesError::DimensionMismatch {
expected: self.input_dim,
actual: obs_dim,
});
}
for w in observations.windows(2) {
if w[0].0 > w[1].0 {
return Err(TimeSeriesError::InvalidInput(
"observations must be sorted by time".to_string(),
));
}
}
self.last_obs_times = observations.iter().map(|(t, _)| *t).collect();
let mut last_predictions = vec![];
let mut last_latent = vec![];
let mut last_recon = 0.0;
let mut last_kl = 0.0;
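        // Per epoch: refresh the z0 posterior from the encoder, take one
        // finite-difference Adam step on the decoder, and record the epoch's
        // forward pass for the returned result.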
for epoch in 0..n_epochs as u64 {
let (mu, log_sigma) = self.encoder.encode(observations);
self.z0_mu = mu;
self.z0_log_sigma = log_sigma;
let dec_grad = self.decoder_gradient(observations, epoch);
let n_w = self.decoder.w.len();
let n_b = self.decoder.b.len();
{
let mut all_params: Vec<f64> = self
.decoder
.w
.iter()
.chain(self.decoder.b.iter())
.copied()
.collect();
self.adam_dec.step(&mut all_params, &dec_grad);
self.decoder.w.copy_from_slice(&all_params[..n_w]);
self.decoder.b.copy_from_slice(&all_params[n_w..n_w + n_b]);
}
let (preds, latent, recon, kl) = self.forward(observations, epoch);
last_predictions = preds;
last_latent = latent;
last_recon = recon;
last_kl = kl;
}
let times: Vec<f64> = observations.iter().map(|(t, _)| *t).collect();
Ok(LatentOdeResult {
predicted_times: times,
predicted_values: last_predictions,
latent_trajectory: last_latent,
reconstruction_loss: last_recon,
kl_divergence: last_kl,
})
}
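    /// Integrates the latent ODE from the first time seen during `fit` (or
    /// the first query time if the model is unfitted) and decodes it at the
    /// query times. A fixed noise seed keeps repeated calls deterministic;
    /// `reconstruction_loss` is 0.0 because there are no targets to compare.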
pub fn predict(&self, query_times: &[f64]) -> Result<LatentOdeResult> {
if query_times.is_empty() {
return Err(TimeSeriesError::InvalidInput(
"query_times must not be empty".to_string(),
));
}
let z0 = self.sample_z0(999);
let t0 = self
.last_obs_times
.first()
.copied()
.unwrap_or_else(|| query_times[0]);
let t_range = query_times.last().copied().unwrap_or(1.0) - t0;
let n_steps = (Self::STEPS_PER_UNIT as f64 * t_range.abs()).ceil() as usize + 1;
let latent_traj = integrate_trajectory(&self.ode_func, &z0, t0, query_times, n_steps);
let predictions: Vec<Vec<f64>> =
latent_traj.iter().map(|z| self.decoder.decode(z)).collect();
let kl = kl_normal(&self.z0_mu, &self.z0_log_sigma);
Ok(LatentOdeResult {
predicted_times: query_times.to_vec(),
predicted_values: predictions,
latent_trajectory: latent_traj,
            reconstruction_loss: 0.0,
            kl_divergence: kl,
})
}
}
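/// Clones the model, nudges a single decoder weight (or bias) by `eps_fd`,
/// and re-evaluates the training loss; helper for the finite-difference
/// gradient.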
fn elbo_with_dec_perturb(
model: &LatentOde,
observations: &[(f64, Vec<f64>)],
epoch: u64,
k: usize,
eps_fd: f64,
is_weight: bool,
) -> f64 {
let mut perturbed = model.clone();
if is_weight {
perturbed.decoder.w[k] += eps_fd;
} else {
perturbed.decoder.b[k] += eps_fd;
}
perturbed.elbo_loss(observations, epoch)
}
#[cfg(test)]
mod tests {
use super::*;
fn make_obs(n: usize) -> Vec<(f64, Vec<f64>)> {
(0..n)
.map(|i| (i as f64 * 0.1, vec![(i as f64 * 0.1).sin()]))
.collect()
}
#[test]
fn latent_ode_new_and_fit() {
let config = LatentOdeConfig {
latent_dim: 4,
hidden_dim: 8,
n_layers: 1,
..Default::default()
};
let mut model = LatentOde::new(config, 1).expect("new");
let obs = make_obs(5);
let result = model.fit(&obs, 2).expect("fit");
assert_eq!(result.predicted_times.len(), 5);
assert_eq!(result.predicted_values.len(), 5);
assert_eq!(result.predicted_values[0].len(), 1);
}
#[test]
fn latent_ode_predict_shape() {
let config = LatentOdeConfig {
latent_dim: 4,
hidden_dim: 8,
n_layers: 1,
..Default::default()
};
let mut model = LatentOde::new(config, 1).expect("new");
let obs = make_obs(5);
model.fit(&obs, 1).expect("fit");
let query = vec![0.0, 0.25, 0.5, 0.75, 1.0];
let result = model.predict(&query).expect("predict");
assert_eq!(result.predicted_times.len(), 5);
assert_eq!(result.predicted_values.len(), 5);
}
#[test]
fn latent_ode_config_default() {
let cfg = LatentOdeConfig::default();
assert_eq!(cfg.latent_dim, 16);
assert_eq!(cfg.hidden_dim, 64);
assert_eq!(cfg.n_layers, 2);
}
#[test]
fn latent_ode_kl_positive() {
let config = LatentOdeConfig {
latent_dim: 4,
hidden_dim: 8,
n_layers: 1,
..Default::default()
};
let mut model = LatentOde::new(config, 1).expect("new");
let obs = make_obs(5);
let result = model.fit(&obs, 3).expect("fit");
assert!(result.kl_divergence >= 0.0);
}
#[test]
fn latent_ode_latent_trajectory_changes_over_time() {
let config = LatentOdeConfig {
latent_dim: 4,
hidden_dim: 8,
n_layers: 1,
..Default::default()
};
let mut model = LatentOde::new(config, 1).expect("new");
let obs = make_obs(5);
let result = model.fit(&obs, 2).expect("fit");
let first = &result.latent_trajectory[0];
let last = &result.latent_trajectory[result.latent_trajectory.len() - 1];
let diff: f64 = first
.iter()
.zip(last.iter())
.map(|(&a, &b)| (a - b).abs())
.sum();
        assert!(diff > 0.0, "latent state should change across the time span");
    }
#[test]
fn latent_ode_empty_obs_error() {
let mut model = LatentOde::new(LatentOdeConfig::default(), 1).expect("new");
assert!(model.fit(&[], 1).is_err());
}
#[test]
fn latent_ode_dim_mismatch_error() {
let mut model = LatentOde::new(LatentOdeConfig::default(), 2).expect("new");
        let obs = vec![(0.0, vec![1.0])];
        assert!(model.fit(&obs, 1).is_err());
}
}