use crate::error::{NeuralError, Result};
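
// A minimal, dependency-free sketch of consistency models in the spirit of
// Song et al. (2023), "Consistency Models", using EDM-style preconditioning
// and the Karras et al. (2022) noise schedule. The MLP backbone below uses
// fixed deterministic weights, so this module demonstrates the shapes, the
// schedule, and the loss plumbing rather than a trainable implementation.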
/// Configuration for a consistency model and its noise schedule.
#[derive(Debug, Clone)]
pub struct ConsistencyConfig {
    /// Dimensionality of a flattened data sample.
    pub data_dim: usize,
    /// Smallest noise level (the near-clean boundary).
    pub sigma_min: f64,
    /// Largest noise level (the near-pure-noise boundary).
    pub sigma_max: f64,
    /// Assumed standard deviation of the data distribution.
    pub sigma_data: f64,
    /// Curvature of the Karras schedule; 7.0 is the EDM default.
    pub rho: f64,
    /// Number of discretization steps in the schedule.
    pub n_timesteps: usize,
    /// Width of the hidden layers in the MLP backbone.
    pub hidden_dim: usize,
    /// Seed for deterministic weight generation and noise sampling.
    pub seed: u64,
}
impl Default for ConsistencyConfig {
fn default() -> Self {
Self {
data_dim: 784,
sigma_min: 0.002,
sigma_max: 80.0,
sigma_data: 0.5,
rho: 7.0,
n_timesteps: 40,
hidden_dim: 512,
seed: 42,
}
}
}
impl ConsistencyConfig {
    /// A small configuration for tests: few timesteps and a narrow backbone.
    pub fn tiny(data_dim: usize) -> Self {
Self {
data_dim,
sigma_min: 0.002,
sigma_max: 80.0,
sigma_data: 0.5,
rho: 7.0,
n_timesteps: 10,
hidden_dim: 32,
seed: 0,
}
}
}
/// Skip-connection coefficient from EDM-style preconditioning (Karras et al.
/// 2022): `c_skip(sigma) = sigma_data^2 / (sigma^2 + sigma_data^2)`.
/// Tends to 1 as `sigma -> 0`, so near-clean inputs pass straight through.
#[inline]
pub fn c_skip(sigma: f64, sigma_data: f64) -> f64 {
    let sd2 = sigma_data * sigma_data;
    sd2 / (sigma * sigma + sd2)
}

/// Output-scaling coefficient:
/// `c_out(sigma) = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)`.
/// Tends to 0 as `sigma -> 0`, so the network's correction vanishes at low
/// noise. The `.max(1e-12)` guards the denominator against underflow.
#[inline]
pub fn c_out(sigma: f64, sigma_data: f64) -> f64 {
    let sd2 = sigma_data * sigma_data;
    sigma * sigma_data / (sigma * sigma + sd2).sqrt().max(1e-12)
}

/// Input-scaling coefficient: `c_in(sigma) = 1 / sqrt(sigma^2 + sigma_data^2)`,
/// which normalizes the variance of the noised input to roughly one.
#[inline]
pub fn c_in(sigma: f64, sigma_data: f64) -> f64 {
    let sd2 = sigma_data * sigma_data;
    1.0 / (sigma * sigma + sd2).sqrt().max(1e-12)
}
/// Discretized noise schedule: `sigmas` is strictly decreasing from
/// `sigma_max` down to `sigma_min`.
#[derive(Debug, Clone)]
pub struct ConsistencySchedule {
    pub sigmas: Vec<f64>,
    pub sigma_data: f64,
}
impl ConsistencySchedule {
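    /// Builds the Karras et al. (2022) discretization by interpolating
    /// linearly in `sigma^(1/rho)` space:
    /// `sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho`
    /// for `i = 0..N-1`, so that steps cluster near `sigma_min`.
    /// `sigmas[0]` equals `sigma_max` and `sigmas[N-1]` equals `sigma_min`.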
pub fn new(config: &ConsistencyConfig) -> Result<Self> {
if config.n_timesteps < 2 {
return Err(NeuralError::InvalidArgument(
"ConsistencySchedule: n_timesteps must be >= 2".to_string(),
));
}
if config.sigma_min <= 0.0 || config.sigma_max <= config.sigma_min {
return Err(NeuralError::InvalidArgument(format!(
"ConsistencySchedule: require 0 < sigma_min ({}) < sigma_max ({})",
config.sigma_min, config.sigma_max
)));
}
let n = config.n_timesteps;
let inv_rho = 1.0 / config.rho;
let s_max_pow = config.sigma_max.powf(inv_rho);
let s_min_pow = config.sigma_min.powf(inv_rho);
let sigmas: Vec<f64> = (0..n)
.map(|i| {
let frac = i as f64 / (n - 1) as f64;
(s_max_pow + frac * (s_min_pow - s_max_pow)).powf(config.rho)
})
.collect();
Ok(Self {
sigmas,
sigma_data: config.sigma_data,
})
}
pub fn sigma_at(&self, i: usize) -> Result<f64> {
self.sigmas.get(i).copied().ok_or_else(|| {
NeuralError::InvalidArgument(format!(
"ConsistencySchedule: index {} out of range (len={})",
i,
self.sigmas.len()
))
})
}
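    /// Index of the schedule sigma nearest to `sigma` (a linear scan; the
    /// schedules here are short enough that nothing faster is warranted).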
pub fn closest_index(&self, sigma: f64) -> usize {
self.sigmas
.iter()
.enumerate()
.min_by(|(_, &a), (_, &b)| {
(a - sigma)
.abs()
.partial_cmp(&(b - sigma).abs())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.unwrap_or(0)
}
}
/// A tiny fixed-weight MLP standing in for the backbone `F(x, sigma)`:
/// three linear layers with ReLU activations in between. Each entry of
/// `layers` is `(weights, bias)`, with weights stored row-major as
/// `[out_dim][in_dim]`.
#[derive(Debug, Clone)]
struct ConsistencyMLP {
    data_dim: usize,
    hidden_dim: usize,
    layers: Vec<(Vec<f64>, Vec<f64>)>,
}
impl ConsistencyMLP {
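    /// Builds a 3-layer MLP with fixed, deterministic weights: a He-style
    /// scale `sqrt(2 / fan_in)` applied to a sine sequence with golden-ratio
    /// stride. This is a reproducible, dependency-free stand-in for random
    /// initialization, not a trained network.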
fn new(data_dim: usize, hidden_dim: usize, seed: u64) -> Result<Self> {
if data_dim == 0 || hidden_dim == 0 {
return Err(NeuralError::InvalidArgument(
"ConsistencyMLP: data_dim and hidden_dim must be > 0".to_string(),
));
}
let input_dim = data_dim + 1;
let golden = 0.618_033_988_749_895_f64;
let make_layer = |in_d: usize, out_d: usize, offset: u64| -> (Vec<f64>, Vec<f64>) {
let std = (2.0 / in_d as f64).sqrt();
let weights: Vec<f64> = (0..in_d * out_d)
.map(|k| std * ((k as f64 + offset as f64 + seed as f64) * golden).sin())
.collect();
let bias = vec![0.0f64; out_d];
(weights, bias)
};
let layers = vec![
make_layer(input_dim, hidden_dim, 0),
make_layer(hidden_dim, hidden_dim, (hidden_dim * hidden_dim) as u64),
make_layer(hidden_dim, data_dim, (2 * hidden_dim * hidden_dim) as u64),
];
Ok(Self {
data_dim,
hidden_dim,
layers,
})
}
    fn forward(&self, input: &[f64]) -> Vec<f64> {
        let mut h = input.to_vec();
        for (layer_idx, (w, b)) in self.layers.iter().enumerate() {
            let out_dim = b.len();
            let in_dim = h.len();
            let mut next = vec![0.0f64; out_dim];
            for j in 0..out_dim {
                // Row j of the weight matrix holds the fan-in for output unit j.
                let mut acc = b[j];
                for i in 0..in_dim {
                    let idx = j * in_dim + i;
                    if idx < w.len() {
                        acc += w[idx] * h[i];
                    }
                }
                next[j] = acc;
            }
            // ReLU on every layer except the last (linear output).
            if layer_idx < self.layers.len() - 1 {
                for v in &mut next {
                    *v = v.max(0.0);
                }
            }
            h = next;
        }
        h
    }
}
/// A consistency model: configuration, noise schedule, MLP backbone, and a
/// small deterministic RNG state used for noise sampling.
#[derive(Debug, Clone)]
pub struct ConsistencyModel {
    pub config: ConsistencyConfig,
    pub schedule: ConsistencySchedule,
    network: ConsistencyMLP,
    rng_state: u64,
}
impl ConsistencyModel {
pub fn new(config: ConsistencyConfig) -> Result<Self> {
let schedule = ConsistencySchedule::new(&config)?;
let network = ConsistencyMLP::new(config.data_dim, config.hidden_dim, config.seed)?;
let rng_state = config.seed.wrapping_add(0xdeadbeef_cafebabe);
Ok(Self {
config,
schedule,
network,
rng_state,
})
}
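    /// Deterministic 64-bit LCG (Knuth's MMIX multiplier and increment),
    /// kept local so the module needs no RNG dependency.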
fn lcg_step(&mut self) -> u64 {
self.rng_state = self
.rng_state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
self.rng_state
}
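    /// One standard-normal draw via the Box-Muller transform. The `>> 11`
    /// keeps the top 53 bits of each LCG output (the full f64 mantissa), and
    /// the `+ 1.0` offsets keep the uniforms away from zero so `ln(u1)` is
    /// always finite.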
fn sample_normal(&mut self) -> f64 {
let r1 = self.lcg_step();
let r2 = self.lcg_step();
let u1 = ((r1 >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
let u2 = ((r2 >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
(-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}
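    /// Evaluates the consistency function with EDM preconditioning:
    /// `f(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, ln(sigma / sigma_data))`.
    /// Because `c_skip -> 1` and `c_out -> 0` as `sigma -> 0`, the boundary
    /// condition `f(x, sigma_min) ~= x` holds approximately by construction.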
pub fn consistency_fn(&self, x_t: &[f64], sigma: f64) -> Vec<f64> {
let d = self.config.data_dim;
let sd = self.config.sigma_data;
let cs = c_skip(sigma, sd);
let co = c_out(sigma, sd);
let ci = c_in(sigma, sd);
let mut inp: Vec<f64> = x_t.iter().map(|&v| ci * v).collect();
        // Encode the noise level as one extra input feature: ln(sigma / sigma_data).
        let sigma_enc = (sigma / sd).ln();
        inp.push(sigma_enc);
let f_out = self.network.forward(&inp);
let len = d.min(f_out.len()).min(x_t.len());
let mut out = vec![0.0f64; d];
for i in 0..len {
out[i] = cs * x_t[i] + co * f_out[i];
}
out
}
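    /// Consistency-training (CT) style loss: corrupt `x0` with the *same*
    /// Gaussian noise at two adjacent levels `sigma > sigma_next` and return
    /// the mean squared distance between the two denoised estimates. This is
    /// a simplified sketch of the CT objective of Song et al. (2023); the
    /// full recipe evaluates the `sigma_next` branch with an EMA target
    /// network under stop-gradient, which this single-network version omits.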
pub fn ct_loss(&mut self, x0: &[f64], sigma: f64, sigma_next: f64) -> Result<f64> {
let d = self.config.data_dim;
if x0.len() != d {
return Err(NeuralError::ShapeMismatch(format!(
"ConsistencyModel ct_loss: x0 len {} != data_dim {}",
x0.len(),
d
)));
}
if sigma <= 0.0 {
return Err(NeuralError::InvalidArgument(format!(
"ConsistencyModel ct_loss: sigma must be > 0, got {sigma}"
)));
}
if sigma_next <= 0.0 || sigma_next >= sigma {
return Err(NeuralError::InvalidArgument(format!(
"ConsistencyModel ct_loss: sigma_next ({sigma_next}) must be in (0, sigma={sigma})"
)));
}
let noise: Vec<f64> = (0..d).map(|_| self.sample_normal()).collect();
let x_t: Vec<f64> = x0
.iter()
.zip(&noise)
.map(|(&x, &n)| x + n * sigma)
.collect();
let x_t_prime: Vec<f64> = x0
.iter()
.zip(&noise)
.map(|(&x, &n)| x + n * sigma_next)
.collect();
let f1 = self.consistency_fn(&x_t, sigma);
let f2 = self.consistency_fn(&x_t_prime, sigma_next);
let loss: f64 = f1
.iter()
.zip(&f2)
.map(|(&a, &b)| (a - b).powi(2))
.sum::<f64>()
/ d as f64;
Ok(loss)
}
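    /// Consistency-distillation (CD) style loss. Here `teacher_output` plays
    /// the role of the teacher's `x0` estimate for the noised sample; it is
    /// passed through the consistency function at `sigma_min` (where `f` is
    /// nearly the identity) and compared against the student's output at
    /// `sigma`. A simplified sketch: in the paper the target comes from a
    /// frozen teacher ODE-solver step plus an EMA network, both of which are
    /// stand-ins here.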
pub fn cd_loss(&mut self, x0: &[f64], sigma: f64, teacher_output: &[f64]) -> Result<f64> {
let d = self.config.data_dim;
if x0.len() != d {
return Err(NeuralError::ShapeMismatch(format!(
"ConsistencyModel cd_loss: x0 len {} != data_dim {}",
x0.len(),
d
)));
}
if teacher_output.len() != d {
return Err(NeuralError::ShapeMismatch(format!(
"ConsistencyModel cd_loss: teacher_output len {} != data_dim {}",
teacher_output.len(),
d
)));
}
if sigma <= 0.0 {
return Err(NeuralError::InvalidArgument(format!(
"ConsistencyModel cd_loss: sigma must be > 0, got {sigma}"
)));
}
let noise: Vec<f64> = (0..d).map(|_| self.sample_normal()).collect();
let x_t: Vec<f64> = x0
.iter()
.zip(&noise)
.map(|(&x, &n)| x + n * sigma)
.collect();
let f_student = self.consistency_fn(&x_t, sigma);
let sigma_target = self.config.sigma_min;
let f_teacher = self.consistency_fn(teacher_output, sigma_target);
let loss: f64 = f_student
.iter()
.zip(&f_teacher)
.map(|(&a, &b)| (a - b).powi(2))
.sum::<f64>()
/ d as f64;
Ok(loss)
}
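    /// One-step generation: treat `x_t` as a sample at `sigma_max` (pure
    /// noise) and map it to a clean-data estimate with a single call to `f`.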
pub fn sample_single_step(&self, x_t: &[f64]) -> Result<Vec<f64>> {
let d = self.config.data_dim;
if x_t.len() != d {
return Err(NeuralError::ShapeMismatch(format!(
"ConsistencyModel sample_single_step: x_T len {} != data_dim {}",
x_t.len(),
d
)));
}
Ok(self.consistency_fn(x_t, self.config.sigma_max))
}
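    /// Two-step generation: denoise from `sigma_max`, re-noise the estimate
    /// to `sigma_mid`, then denoise once more. Note the multistep sampler in
    /// Song et al. (2023) re-noises with scale `sqrt(sigma_mid^2 - sigma_min^2)`;
    /// this sketch uses `sigma_mid` directly, which is nearly identical when
    /// `sigma_min` is small.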
pub fn sample_two_step(&mut self, x_t: &[f64], sigma_mid: f64) -> Result<Vec<f64>> {
let d = self.config.data_dim;
if x_t.len() != d {
return Err(NeuralError::ShapeMismatch(format!(
"ConsistencyModel sample_two_step: x_T len {} != data_dim {}",
x_t.len(),
d
)));
}
if sigma_mid <= self.config.sigma_min || sigma_mid >= self.config.sigma_max {
return Err(NeuralError::InvalidArgument(format!(
"ConsistencyModel sample_two_step: sigma_mid ({sigma_mid}) must be in \
(sigma_min={}, sigma_max={})",
self.config.sigma_min, self.config.sigma_max
)));
}
let x0_hat = self.consistency_fn(x_t, self.config.sigma_max);
let x_mid: Vec<f64> = x0_hat
.iter()
.map(|&v| {
let z = self.sample_normal();
v + z * sigma_mid
})
.collect();
Ok(self.consistency_fn(&x_mid, sigma_mid))
}
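    /// Multistep generation over the whole schedule: at each level predict
    /// `x0`, then (except after the final, smallest sigma) re-noise the
    /// estimate to the next level and repeat.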
pub fn sample_multistep(&mut self, x_t: &[f64]) -> Result<Vec<f64>> {
let d = self.config.data_dim;
if x_t.len() != d {
return Err(NeuralError::ShapeMismatch(format!(
"ConsistencyModel sample_multistep: x_T len {} != data_dim {}",
x_t.len(),
d
)));
}
let sigmas = self.schedule.sigmas.clone();
let mut x_current = x_t.to_vec();
for (i, &sigma) in sigmas.iter().enumerate() {
let x0_hat = self.consistency_fn(&x_current, sigma);
if i + 1 < sigmas.len() {
let sigma_next = sigmas[i + 1];
x_current = x0_hat
.iter()
.map(|&v| {
let z = self.sample_normal();
v + z * sigma_next
})
.collect();
} else {
x_current = x0_hat;
}
}
Ok(x_current)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_defaults() {
let cfg = ConsistencyConfig::default();
assert_eq!(cfg.data_dim, 784);
assert!((cfg.sigma_min - 0.002).abs() < 1e-9);
assert!((cfg.sigma_max - 80.0).abs() < 1e-9);
assert!((cfg.sigma_data - 0.5).abs() < 1e-9);
assert!((cfg.rho - 7.0).abs() < 1e-9);
assert_eq!(cfg.n_timesteps, 40);
assert_eq!(cfg.seed, 42);
}
#[test]
fn test_tiny_config() {
let cfg = ConsistencyConfig::tiny(4);
assert_eq!(cfg.data_dim, 4);
assert_eq!(cfg.n_timesteps, 10);
}
#[test]
fn test_schedule_length() {
let cfg = ConsistencyConfig::tiny(4);
let sched = ConsistencySchedule::new(&cfg).expect("schedule");
assert_eq!(sched.sigmas.len(), cfg.n_timesteps);
}
#[test]
fn test_schedule_strictly_decreasing() {
let cfg = ConsistencyConfig::tiny(4);
let sched = ConsistencySchedule::new(&cfg).expect("schedule");
for window in sched.sigmas.windows(2) {
assert!(
window[0] > window[1],
"schedule not strictly decreasing: {} >= {}",
window[0],
window[1]
);
}
}
#[test]
fn test_schedule_endpoints() {
let cfg = ConsistencyConfig::tiny(4);
let sched = ConsistencySchedule::new(&cfg).expect("schedule");
assert!(
(sched.sigmas[0] - cfg.sigma_max).abs() < 1e-6,
"first sigma should be sigma_max; got {}",
sched.sigmas[0]
);
let last = *sched.sigmas.last().expect("non-empty");
assert!(
(last - cfg.sigma_min).abs() < 1e-6,
"last sigma should be sigma_min; got {last}"
);
}
#[test]
fn test_schedule_invalid_args() {
let mut cfg = ConsistencyConfig::tiny(4);
cfg.n_timesteps = 1;
assert!(ConsistencySchedule::new(&cfg).is_err());
cfg.n_timesteps = 10;
cfg.sigma_min = 0.0;
assert!(ConsistencySchedule::new(&cfg).is_err());
        cfg.sigma_min = 100.0;
        assert!(ConsistencySchedule::new(&cfg).is_err());
}
#[test]
fn test_c_skip_at_zero_noise() {
let sd = 0.5;
let eps = 1e-10;
let cs = c_skip(eps, sd);
assert!((cs - 1.0).abs() < 1e-4, "c_skip near 0 should ≈ 1, got {cs}");
}
#[test]
fn test_c_out_at_zero_noise() {
let sd = 0.5;
let eps = 1e-10;
let co = c_out(eps, sd);
assert!(co.abs() < 1e-5, "c_out near 0 should ≈ 0, got {co}");
}
#[test]
    fn test_c_skip_and_c_out_at_sigma_data() {
        let sd = 0.5;
        let cs = c_skip(sd, sd);
        assert!(
            (cs - 0.5).abs() < 1e-10,
            "c_skip(sigma_data) should be 0.5, got {cs}"
        );
        let co = c_out(sd, sd);
        assert!(
            (co - sd / 2.0f64.sqrt()).abs() < 1e-10,
            "c_out(sigma_data) should be sigma_data / sqrt(2), got {co}"
        );
    }
#[test]
fn test_c_in_normalises_scale() {
let sd = 0.5;
for sigma in [0.1, 0.5, 1.0, 10.0, 80.0] {
let ci = c_in(sigma, sd);
assert!(ci > 0.0 && ci.is_finite(), "c_in should be finite positive");
assert!(ci * sigma < 1.0 + 1e-9);
}
}
#[test]
fn test_model_creation() {
let cfg = ConsistencyConfig::tiny(4);
let model = ConsistencyModel::new(cfg).expect("model creation");
assert_eq!(model.config.data_dim, 4);
}
#[test]
fn test_consistency_fn_output_shape() {
let cfg = ConsistencyConfig::tiny(4);
let model = ConsistencyModel::new(cfg).expect("model");
let x_noisy = vec![0.5, -0.3, 0.2, 0.8];
let out = model.consistency_fn(&x_noisy, 10.0);
assert_eq!(out.len(), 4);
for &v in &out {
assert!(v.is_finite(), "consistency_fn output must be finite");
}
}
#[test]
fn test_consistency_fn_boundary_condition() {
let cfg = ConsistencyConfig::tiny(4);
let model = ConsistencyModel::new(cfg.clone()).expect("model");
let x_clean = vec![0.1, 0.2, -0.3, 0.4];
        let sigma_min = cfg.sigma_min;
        let out = model.consistency_fn(&x_clean, sigma_min);
for (&xi, &oi) in x_clean.iter().zip(&out) {
let diff = (xi - oi).abs();
assert!(
diff < 0.5,
"At sigma_min, f(x,t) should be close to x; diff={diff}"
);
}
}
#[test]
fn test_ct_loss_runs() {
let cfg = ConsistencyConfig::tiny(4);
let mut model = ConsistencyModel::new(cfg).expect("model");
let x0 = vec![0.3, -0.1, 0.5, -0.2];
let loss = model.ct_loss(&x0, 10.0, 5.0).expect("ct_loss");
assert!(loss >= 0.0 && loss.is_finite(), "CT loss must be finite non-negative");
}
#[test]
fn test_ct_loss_invalid_args() {
let cfg = ConsistencyConfig::tiny(4);
let mut model = ConsistencyModel::new(cfg).expect("model");
let x0 = vec![0.3, -0.1, 0.5, -0.2];
assert!(model.ct_loss(&x0, 0.0, -1.0).is_err());
assert!(model.ct_loss(&x0, 5.0, 5.0).is_err());
assert!(model.ct_loss(&x0, 5.0, 10.0).is_err());
assert!(model.ct_loss(&[1.0, 2.0], 10.0, 5.0).is_err());
}
#[test]
fn test_cd_loss_runs() {
let cfg = ConsistencyConfig::tiny(4);
let mut model = ConsistencyModel::new(cfg).expect("model");
let x0 = vec![0.3, -0.1, 0.5, -0.2];
let teacher = vec![-0.1, 0.2, 0.1, 0.3];
let loss = model.cd_loss(&x0, 10.0, &teacher).expect("cd_loss");
assert!(loss >= 0.0 && loss.is_finite());
}
#[test]
fn test_sample_single_step_shape() {
let cfg = ConsistencyConfig::tiny(4);
let model = ConsistencyModel::new(cfg).expect("model");
let x_noise = vec![1.0, -0.5, 0.3, -0.2];
let x0 = model.sample_single_step(&x_noise).expect("single-step sample");
assert_eq!(x0.len(), 4);
for &v in &x0 {
assert!(v.is_finite());
}
}
#[test]
fn test_sample_single_step_wrong_dim() {
let cfg = ConsistencyConfig::tiny(4);
let model = ConsistencyModel::new(cfg).expect("model");
        let x_wrong = vec![1.0, -0.5];
        assert!(model.sample_single_step(&x_wrong).is_err());
}
#[test]
fn test_sample_two_step_shape() {
let cfg = ConsistencyConfig::tiny(4);
let mut model = ConsistencyModel::new(cfg).expect("model");
let x_noise = vec![1.0, -0.5, 0.3, -0.2];
let sigma_mid = 5.0; let x0 = model.sample_two_step(&x_noise, sigma_mid).expect("two-step sample");
assert_eq!(x0.len(), 4);
for &v in &x0 {
assert!(v.is_finite());
}
}
#[test]
fn test_sample_two_step_invalid_sigma_mid() {
let cfg = ConsistencyConfig::tiny(4);
let mut model = ConsistencyModel::new(cfg.clone()).expect("model");
let x_noise = vec![1.0, -0.5, 0.3, -0.2];
assert!(model.sample_two_step(&x_noise, cfg.sigma_min).is_err());
assert!(model.sample_two_step(&x_noise, cfg.sigma_max).is_err());
assert!(model.sample_two_step(&x_noise, 0.0001).is_err());
}
#[test]
fn test_sample_multistep_shape() {
let cfg = ConsistencyConfig::tiny(4);
let mut model = ConsistencyModel::new(cfg).expect("model");
let x_noise = vec![1.0, -0.5, 0.3, -0.2];
let x0 = model.sample_multistep(&x_noise).expect("multistep");
assert_eq!(x0.len(), 4);
for &v in &x0 {
assert!(v.is_finite());
}
}
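    // An end-to-end smoke test: a hedged sketch that only checks that the
    // pieces compose and stay finite. The fixed-weight MLP is not trained,
    // so no convergence or sample-quality property is asserted.
    #[test]
    fn test_end_to_end_smoke() {
        let cfg = ConsistencyConfig::tiny(4);
        let mut model = ConsistencyModel::new(cfg).expect("model");
        let x0 = vec![0.3, -0.1, 0.5, -0.2];
        // Walk adjacent schedule levels and evaluate the CT loss at each pair.
        let sigmas = model.schedule.sigmas.clone();
        for pair in sigmas.windows(2) {
            let loss = model.ct_loss(&x0, pair[0], pair[1]).expect("ct_loss");
            assert!(loss.is_finite() && loss >= 0.0);
        }
        // Then draw from pure noise with each sampler.
        let noise = vec![0.9, -1.1, 0.4, 0.2];
        let a = model.sample_single_step(&noise).expect("single-step");
        let b = model.sample_two_step(&noise, 5.0).expect("two-step");
        let c = model.sample_multistep(&noise).expect("multistep");
        for v in a.iter().chain(&b).chain(&c) {
            assert!(v.is_finite());
        }
    }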
}