use crate::error::{Result, TimeSeriesError};
use std::f64::consts::PI;
/// Configuration for a continuous normalizing flow (CNF) model.
#[derive(Debug, Clone)]
pub struct CnfConfig {
/// Dimensionality of the data / base distribution.
pub dim: usize,
/// Width of each hidden layer in the dynamics MLP.
pub hidden_dim: usize,
/// Number of hidden layers in the dynamics MLP (`MlpDynamics::new` clamps this to >= 1).
pub n_layers: usize,
/// Integration start time.
pub t0: f64,
/// Integration end time.
pub t1: f64,
/// Relative tolerance — not read by the fixed-step RK4 integrator in this file.
pub rtol: f64,
/// Absolute tolerance — not read by the fixed-step RK4 integrator in this file.
pub atol: f64,
/// Number of fixed RK4 steps between `t0` and `t1`.
pub n_steps: usize,
/// Step size for the finite-difference Jacobian-vector products in the trace estimator.
pub fd_eps: f64,
}
impl Default for CnfConfig {
fn default() -> Self {
Self {
dim: 2,
hidden_dim: 32,
n_layers: 3,
t0: 0.0,
t1: 1.0,
rtol: 1e-5,
atol: 1e-5,
n_steps: 50,
fd_eps: 1e-5,
}
}
}
/// A plain fully connected network defining the CNF vector field f(z, t).
/// All layer parameters are stored flattened in `params`: for each layer,
/// the row-major weight matrix followed by the bias vector.
#[derive(Debug, Clone)]
pub struct MlpDynamics {
/// State dimension; the network maps (dim + 1) inputs (z with t appended) to dim outputs.
pub dim: usize,
/// Width of every hidden layer.
pub hidden_dim: usize,
/// Number of hidden layers (clamped to >= 1 in `new`).
pub n_layers: usize,
/// Flattened weights and biases for all layers, in layer order.
pub params: Vec<f64>,
}
impl MlpDynamics {
    /// Layer shapes `(fan_out, fan_in)` for an MLP with `n_layers` tanh hidden
    /// layers mapping `dim + 1` inputs (z with t appended) to `dim` outputs.
    ///
    /// Shared by `new` (parameter allocation) and `forward` (parameter slicing)
    /// so the two can never disagree on the layout of `params`. Previously this
    /// list was built independently in both places.
    fn layer_shapes(dim: usize, hidden_dim: usize, n_layers: usize) -> Vec<(usize, usize)> {
        let n_layers = n_layers.max(1);
        let mut shapes = Vec::with_capacity(n_layers + 1);
        shapes.push((hidden_dim, dim + 1));
        for _ in 1..n_layers {
            shapes.push((hidden_dim, hidden_dim));
        }
        shapes.push((dim, hidden_dim));
        shapes
    }

    /// Builds a network with deterministic pseudo-random weights, uniform in
    /// `±sqrt(6 / fan_in)`, and zero biases.
    pub fn new(dim: usize, hidden_dim: usize, n_layers: usize) -> Self {
        let n_layers = n_layers.max(1);
        let mut params = Vec::new();
        // Fixed seed: initialization is reproducible across runs.
        let mut rng_state: u64 = 0xdeadbeef_cafebabe;
        for &(out, inp) in &Self::layer_shapes(dim, hidden_dim, n_layers) {
            // Fan-in-scaled uniform bound (Xavier/He-style).
            let bound = (6.0_f64 / inp as f64).sqrt();
            for _ in 0..(out * inp) {
                let r = lcg_next(&mut rng_state);
                params.push((r * 2.0 - 1.0) * bound);
            }
            // Biases start at zero.
            params.extend(std::iter::repeat(0.0).take(out));
        }
        Self {
            dim,
            hidden_dim,
            n_layers,
            params,
        }
    }

    /// Total number of scalar parameters (weights plus biases).
    pub fn n_params(&self) -> usize {
        self.params.len()
    }

    /// Evaluates dz/dt = f(z, t): tanh on every layer except the final linear one.
    pub fn forward(&self, z: &[f64], t: f64) -> Vec<f64> {
        let shapes = Self::layer_shapes(self.dim, self.hidden_dim, self.n_layers);
        let last = shapes.len() - 1;
        // Time enters the network as one extra input feature.
        let mut h: Vec<f64> = z.to_vec();
        h.push(t);
        let mut offset = 0usize;
        for (layer_idx, &(out, inp)) in shapes.iter().enumerate() {
            let w_size = out * inp;
            let w = &self.params[offset..offset + w_size];
            let b = &self.params[offset + w_size..offset + w_size + out];
            offset += w_size + out;
            let mut next_h = vec![0.0f64; out];
            for i in 0..out {
                let mut s = b[i];
                for j in 0..h.len() {
                    s += w[i * inp + j] * h[j];
                }
                next_h[i] = if layer_idx == last { s } else { s.tanh() };
            }
            h = next_h;
        }
        h
    }
}
/// One classical fourth-order Runge-Kutta step for dz/dt = f(z, t),
/// advancing the state from `t` to `t + dt`.
pub fn rk4_step<F>(f: &F, z: &[f64], t: f64, dt: f64) -> Vec<f64>
where
    F: Fn(&[f64], f64) -> Vec<f64>,
{
    // z + c * k, element-wise — used to form each intermediate stage point.
    let shifted = |k: &[f64], c: f64| -> Vec<f64> {
        z.iter().zip(k).map(|(zi, ki)| zi + c * ki).collect()
    };
    let k1 = f(z, t);
    let k2 = f(&shifted(&k1, 0.5 * dt), t + 0.5 * dt);
    let k3 = f(&shifted(&k2, 0.5 * dt), t + 0.5 * dt);
    let k4 = f(&shifted(&k3, dt), t + dt);
    // Combine the four slopes with weights 1, 2, 2, 1 (divided by 6).
    z.iter()
        .enumerate()
        .map(|(i, zi)| zi + (dt / 6.0) * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]))
        .collect()
}
/// Single-probe Hutchinson term: v^T (df/dz) v at `(z, t)`, with the
/// Jacobian-vector product approximated by a central finite difference
/// of step `eps` along the probe direction `v`.
fn hutchinson_trace(mlp_dyn: &MlpDynamics, z: &[f64], t: f64, v: &[f64], eps: f64) -> f64 {
    // Evaluate the dynamics at z +/- eps * v.
    let zp: Vec<f64> = z.iter().zip(v).map(|(zi, vi)| zi + eps * vi).collect();
    let zm: Vec<f64> = z.iter().zip(v).map(|(zi, vi)| zi - eps * vi).collect();
    let fp = mlp_dyn.forward(&zp, t);
    let fm = mlp_dyn.forward(&zm, t);
    // Dot the finite-difference JVP with the probe.
    fp.iter()
        .zip(&fm)
        .zip(v)
        .map(|((p, m), vi)| vi * ((p - m) / (2.0 * eps)))
        .sum()
}
/// Continuous normalizing flow: MLP dynamics integrated with fixed-step RK4
/// between `config.t0` and `config.t1`, over a standard-normal base distribution.
#[derive(Debug, Clone)]
pub struct CnfModel {
/// The vector field f(z, t) being learned.
pub dynamics: MlpDynamics,
/// Architecture and integration settings.
pub config: CnfConfig,
// Internal LCG state, used for Rademacher probes and Box-Muller base sampling.
rng_state: u64,
}
impl CnfModel {
    /// Builds a CNF whose dynamics MLP is deterministically initialized from
    /// the dimensions in `config`.
    pub fn new(config: CnfConfig) -> Self {
        let dynamics = MlpDynamics::new(config.dim, config.hidden_dim, config.n_layers);
        Self {
            dynamics,
            config,
            rng_state: 0x12345678_9abcdef0,
        }
    }

    /// Draws a length-`d` Rademacher probe (entries -1.0 or +1.0) for the
    /// Hutchinson trace estimator.
    fn rademacher(&mut self, d: usize) -> Vec<f64> {
        let mut v = Vec::with_capacity(d);
        for _ in 0..d {
            let r = lcg_next(&mut self.rng_state);
            v.push(if r < 0.5 { -1.0 } else { 1.0 });
        }
        v
    }

    /// Integrates `z0` forward from `t0` to `t1` with fixed-step RK4.
    ///
    /// Returns `(z1, log_det)` where
    /// `log_det ≈ log p(z1) - log p(z0) = -∫ tr(df/dz) dt`
    /// (instantaneous change of variables).
    pub fn forward(&mut self, z0: &[f64]) -> (Vec<f64>, f64) {
        let d = z0.len();
        let t0 = self.config.t0;
        let t1 = self.config.t1;
        // max(1) guards a degenerate config against division by zero below.
        let n_steps = self.config.n_steps.max(1);
        let dt = (t1 - t0) / n_steps as f64;
        let fd_eps = self.config.fd_eps;
        let mut z = z0.to_vec();
        let mut log_det = 0.0f64;
        // One probe vector is reused for the whole trajectory.
        let v = self.rademacher(d);
        for step in 0..n_steps {
            let t = t0 + step as f64 * dt;
            let dynamics = &self.dynamics;
            // RK4 stages for the state; the trace is estimated at each stage point.
            let k1 = dynamics.forward(&z, t);
            let tr1 = hutchinson_trace(dynamics, &z, t, &v, fd_eps);
            let z2: Vec<f64> = (0..d).map(|i| z[i] + 0.5 * dt * k1[i]).collect();
            let k2 = dynamics.forward(&z2, t + 0.5 * dt);
            let tr2 = hutchinson_trace(dynamics, &z2, t + 0.5 * dt, &v, fd_eps);
            let z3: Vec<f64> = (0..d).map(|i| z[i] + 0.5 * dt * k2[i]).collect();
            let k3 = dynamics.forward(&z3, t + 0.5 * dt);
            let tr3 = hutchinson_trace(dynamics, &z3, t + 0.5 * dt, &v, fd_eps);
            let z4: Vec<f64> = (0..d).map(|i| z[i] + dt * k3[i]).collect();
            let k4 = dynamics.forward(&z4, t + dt);
            let tr4 = hutchinson_trace(dynamics, &z4, t + dt, &v, fd_eps);
            for i in 0..d {
                z[i] += (dt / 6.0) * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
            }
            // d(log p)/dt = -tr(df/dz), integrated with the same RK4 weights.
            let avg_trace = (tr1 + 2.0 * tr2 + 2.0 * tr3 + tr4) / 6.0;
            log_det -= dt * avg_trace;
        }
        (z, log_det)
    }

    /// Integrates `x` backward from `t1` to `t0`, recovering the base point `z0`.
    ///
    /// Returns `(z0, delta_log_p)` with the same convention as `forward`:
    /// `delta_log_p ≈ log p(x) - log p_base(z0) = -∫_{t0}^{t1} tr(df/dz) dt`,
    /// so `log_prob` can simply add it to the base log-density.
    pub fn backward(&mut self, x: &[f64]) -> (Vec<f64>, f64) {
        let d = x.len();
        let t0 = self.config.t0;
        let t1 = self.config.t1;
        let n_steps = self.config.n_steps.max(1);
        let dt = (t1 - t0) / n_steps as f64;
        let fd_eps = self.config.fd_eps;
        let mut z = x.to_vec();
        let mut log_det = 0.0f64;
        let v = self.rademacher(d);
        for step in 0..n_steps {
            let t = t1 - step as f64 * dt;
            let dynamics = &self.dynamics;
            // Reverse-time RK4: the state moves along -f.
            let k1 = dynamics.forward(&z, t);
            let tr1 = hutchinson_trace(dynamics, &z, t, &v, fd_eps);
            let z2: Vec<f64> = (0..d).map(|i| z[i] - 0.5 * dt * k1[i]).collect();
            let k2 = dynamics.forward(&z2, t - 0.5 * dt);
            let tr2 = hutchinson_trace(dynamics, &z2, t - 0.5 * dt, &v, fd_eps);
            let z3: Vec<f64> = (0..d).map(|i| z[i] - 0.5 * dt * k2[i]).collect();
            let k3 = dynamics.forward(&z3, t - 0.5 * dt);
            let tr3 = hutchinson_trace(dynamics, &z3, t - 0.5 * dt, &v, fd_eps);
            let z4: Vec<f64> = (0..d).map(|i| z[i] - dt * k3[i]).collect();
            let k4 = dynamics.forward(&z4, t - dt);
            let tr4 = hutchinson_trace(dynamics, &z4, t - dt, &v, fd_eps);
            for i in 0..d {
                z[i] -= (dt / 6.0) * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
            }
            let avg_trace = (tr1 + 2.0 * tr2 + 2.0 * tr3 + tr4) / 6.0;
            // BUGFIX: this used `+=`, which flipped the sign of the density
            // correction. The change-of-variables theorem gives
            //   log p(x) = log p_base(z0) - ∫_{t0}^{t1} tr(df/dz) dt,
            // and this sum (positive dt times traces along the path)
            // approximates that integral, so it must be subtracted — matching
            // `forward`, which also returns -∫ tr dt.
            log_det -= dt * avg_trace;
        }
        (z, log_det)
    }

    /// Log-density of `x` under the flow: standard-normal log-density at the
    /// recovered base point `z0` plus the accumulated log-determinant term.
    pub fn log_prob(&mut self, x: &[f64]) -> f64 {
        let d = x.len();
        let (z0, delta_log_p) = self.backward(x);
        let log_base =
            -0.5 * d as f64 * (2.0 * PI).ln() - 0.5 * z0.iter().map(|&v| v * v).sum::<f64>();
        log_base + delta_log_p
    }

    /// Draws `n` samples by pushing standard-normal base draws through `forward`.
    pub fn sample(&mut self, n: usize) -> Vec<Vec<f64>> {
        let d = self.config.dim;
        let mut samples = Vec::with_capacity(n);
        for _ in 0..n {
            let z0 = self.sample_base(d);
            let (x, _) = self.forward(&z0);
            samples.push(x);
        }
        samples
    }

    /// Samples a `d`-dimensional standard normal via Box-Muller over the
    /// internal LCG.
    fn sample_base(&mut self, d: usize) -> Vec<f64> {
        let mut z = Vec::with_capacity(d);
        let mut i = 0;
        while i < d {
            // Clamp u1 away from 0 so ln(u1) stays finite.
            let u1 = lcg_next(&mut self.rng_state).max(1e-12);
            let u2 = lcg_next(&mut self.rng_state);
            let r = (-2.0 * u1.ln()).sqrt();
            let theta = 2.0 * PI * u2;
            z.push(r * theta.cos());
            if i + 1 < d {
                z.push(r * theta.sin());
            }
            i += 2;
        }
        z.truncate(d);
        z
    }

    /// One SGD step using central finite differences over every parameter.
    ///
    /// Returns the mean NLL of the batch evaluated before the update.
    ///
    /// # Errors
    /// Returns `TimeSeriesError::InvalidInput` when `batch` is empty.
    pub fn train_step(&mut self, batch: &[Vec<f64>], lr: f64) -> Result<f64> {
        if batch.is_empty() {
            return Err(TimeSeriesError::InvalidInput(
                "train_step: empty batch".to_string(),
            ));
        }
        let n = batch.len();
        let n_params = self.dynamics.n_params();
        // Common random numbers: restart the RNG from the same state before
        // every NLL evaluation so the +eps and -eps passes see identical
        // Rademacher probes. Without this, probe noise differs between the two
        // evaluations and is amplified by 1 / (2 * fd_step), swamping the true
        // finite-difference gradient.
        let rng_snapshot = self.rng_state;
        let baseline_nll = self.batch_nll(batch);
        let fd_step = 1e-4;
        let mut grad = vec![0.0f64; n_params];
        for p_idx in 0..n_params {
            let orig = self.dynamics.params[p_idx];
            self.dynamics.params[p_idx] = orig + fd_step;
            self.rng_state = rng_snapshot;
            let nll_plus = self.batch_nll(batch);
            self.dynamics.params[p_idx] = orig - fd_step;
            self.rng_state = rng_snapshot;
            let nll_minus = self.batch_nll(batch);
            self.dynamics.params[p_idx] = orig;
            grad[p_idx] = (nll_plus - nll_minus) / (2.0 * fd_step);
        }
        // Plain gradient-descent update.
        for p_idx in 0..n_params {
            self.dynamics.params[p_idx] -= lr * grad[p_idx];
        }
        Ok(baseline_nll / n as f64)
    }

    /// Negative log-likelihood summed over the batch.
    fn batch_nll(&mut self, batch: &[Vec<f64>]) -> f64 {
        let total: f64 = batch.iter().map(|x| self.log_prob(x)).sum();
        -total
    }
}
/// Advances a 64-bit linear congruential generator (MMIX-style constants) and
/// maps the top 31 bits of the new state to a float in [0, 1).
pub(crate) fn lcg_next(state: &mut u64) -> f64 {
    const MUL: u64 = 6364136223846793005;
    const INC: u64 = 1442695040888963407;
    *state = state.wrapping_mul(MUL).wrapping_add(INC);
    // Keep the high bits: the low bits of an LCG have short periods.
    ((*state >> 33) as f64) / ((1u64 << 31) as f64)
}
// Unit tests: shape and finiteness checks for the CNF, one RK4 accuracy check,
// and a sanity check of the Hutchinson trace estimator.
#[cfg(test)]
mod tests {
use super::*;
// Tiny 2-D model (8 hidden units, 2 layers, 4 RK4 steps) so tests stay fast.
fn make_cnf_2d() -> CnfModel {
CnfModel::new(CnfConfig {
dim: 2,
hidden_dim: 8,
n_layers: 2,
t0: 0.0,
t1: 1.0,
n_steps: 4,
fd_eps: 1e-5,
..Default::default()
})
}
// The dynamics must map R^dim -> R^dim (time enters as an internal extra input).
#[test]
fn test_mlp_output_shape_matches_input_dim() {
let mlp = MlpDynamics::new(3, 8, 2);
let z = vec![0.1, 0.2, 0.3];
let out = mlp.forward(&z, 0.5);
assert_eq!(out.len(), 3, "output dim must equal input dim");
}
// Same shape invariant across several dimensionalities.
#[test]
fn test_mlp_different_dims() {
for dim in [1, 4, 8] {
let mlp = MlpDynamics::new(dim, 16, 3);
let z: Vec<f64> = (0..dim).map(|i| i as f64 * 0.1).collect();
let out = mlp.forward(&z, 0.0);
assert_eq!(out.len(), dim);
}
}
// The flow is dimension-preserving.
#[test]
fn test_cnf_forward_output_shapes() {
let mut model = make_cnf_2d();
let z0 = vec![0.5, -0.3];
let (z1, _log_det) = model.forward(&z0);
assert_eq!(z1.len(), 2, "forward output dim must match input dim");
}
// The accumulated log-determinant must not be NaN/inf.
#[test]
fn test_cnf_forward_log_det_is_finite() {
let mut model = make_cnf_2d();
let z0 = vec![0.5, -0.3];
let (_, log_det) = model.forward(&z0);
assert!(log_det.is_finite(), "log_det must be finite, got {log_det}");
}
// z' = z with z(0) = 1: one RK4 step should match e^dt to ~dt^5 local error.
#[test]
fn test_rk4_step_accuracy_linear_ode() {
let f = |z: &[f64], _t: f64| -> Vec<f64> { vec![z[0]] };
let z0 = vec![1.0_f64];
let dt = 0.1;
let z1 = rk4_step(&f, &z0, 0.0, dt);
let exact = dt.exp();
let err = (z1[0] - exact).abs();
assert!(err < 1e-6, "RK4 error {err} should be < 1e-6 for dt={dt}");
}
// Density evaluation (backward integration + base log-density) stays finite.
#[test]
fn test_cnf_log_prob_is_finite() {
let mut model = make_cnf_2d();
let x = vec![0.0, 0.0];
let lp = model.log_prob(&x);
assert!(lp.is_finite(), "log_prob must be finite, got {lp}");
}
// sample(n) returns n vectors of the configured dimension.
#[test]
fn test_cnf_sample_correct_dimension() {
let mut model = make_cnf_2d();
let samples = model.sample(5);
assert_eq!(samples.len(), 5);
for s in &samples {
assert_eq!(s.len(), 2);
}
}
// No NaN/inf leaks through base sampling + forward integration.
#[test]
fn test_cnf_sample_finite_values() {
let mut model = make_cnf_2d();
let samples = model.sample(10);
for s in &samples {
for &v in s {
assert!(v.is_finite(), "sample value must be finite, got {v}");
}
}
}
// Average of v^T J v over many Rademacher probes should approach the trace of J,
// here computed by finite differences along the coordinate axes. The tolerance
// below is deliberately very loose since only 200 probes are averaged.
#[test]
fn test_hutchinson_estimator_unbiasedness() {
let mlp = MlpDynamics::new(4, 8, 1);
let z = vec![0.0, 0.0, 0.0, 0.0];
let t = 0.5;
let eps = 1e-5;
let mut sum = 0.0;
let n_probes = 200;
let mut rng_state: u64 = 42;
for _ in 0..n_probes {
let v: Vec<f64> = (0..4)
.map(|_| {
if lcg_next(&mut rng_state) < 0.5 {
-1.0
} else {
1.0
}
})
.collect();
sum += hutchinson_trace(&mlp, &z, t, &v, eps);
}
let mean_est = sum / n_probes as f64;
// Reference trace: sum of diagonal entries of J via central differences.
let mut true_trace = 0.0;
for i in 0..4 {
let mut ei = vec![0.0; 4];
ei[i] = eps;
let zp: Vec<f64> = z.iter().zip(&ei).map(|(a, b)| a + b).collect();
let zm: Vec<f64> = z.iter().zip(&ei).map(|(a, b)| a - b).collect();
let fp = mlp.forward(&zp, t);
let fm = mlp.forward(&zm, t);
true_trace += (fp[i] - fm[i]) / (2.0 * eps);
}
let err = (mean_est - true_trace).abs();
assert!(
err < (true_trace.abs() + 1.0) * 2.0 + 2.0,
"Hutchinson mean {mean_est:.4} far from true trace {true_trace:.4} (err={err:.4})"
);
}
}