//! Score-based generative modeling primitives: a noise-conditional score
//! network, denoising (DSM) and sliced (SSM) score-matching objectives, and
//! an annealed Langevin dynamics sampler.
use crate::error::{NeuralError, Result};
/// Configuration for a [`ScoreNetwork`].
#[derive(Debug, Clone)]
pub struct ScoreNetworkConfig {
    /// Dimensionality of the data the network models.
    pub data_dim: usize,
    /// Width of each hidden layer.
    pub hidden_dim: usize,
    /// Number of SiLU-activated layers (a linear output projection is added on top).
    pub num_layers: usize,
    /// Number of noise levels in the geometric schedule.
    pub num_noise_levels: usize,
    /// Smallest noise level of the schedule.
    pub sigma_min: f64,
    /// Largest noise level of the schedule.
    pub sigma_max: f64,
    /// Seed for deterministic weight initialization.
    pub seed: u64,
}
impl ScoreNetworkConfig {
    /// A reasonable default: 3×128 hidden layers and 10 geometric noise
    /// levels spanning [0.01, 1.0].
    pub fn default_config(data_dim: usize) -> Self {
Self {
data_dim,
hidden_dim: 128,
num_layers: 3,
num_noise_levels: 10,
sigma_min: 0.01,
sigma_max: 1.0,
seed: 42,
}
}
    /// A minimal configuration sized for fast unit tests.
    pub fn tiny(data_dim: usize) -> Self {
Self {
data_dim,
hidden_dim: 16,
num_layers: 2,
num_noise_levels: 5,
sigma_min: 0.05,
sigma_max: 0.5,
seed: 0,
}
}
}
/// Anything that can evaluate a score estimate `∇_x log p_σ(x)`.
pub trait ScoreFunction: Send + Sync + std::fmt::Debug {
    /// Estimated score of the σ-smoothed density at `x`.
    fn score(&self, x: &[f64], sigma: f64) -> Result<Vec<f64>>;
    /// Total number of learnable parameters.
    fn parameter_count(&self) -> usize;
}
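/// Noise-conditional score network: a plain MLP over `[x, ln σ]` whose raw
/// output is scaled by `1/σ` (the standard NCSN parameterization).
///
/// A minimal usage sketch (doc-test ignored; assumes the crate's `Result`
/// alias and elides error handling):
///
/// ```ignore
/// let net = ScoreNetwork::new(ScoreNetworkConfig::tiny(2))?;
/// let s = net.score(&[0.1, -0.2], 0.1)?;
/// assert_eq!(s.len(), 2);
/// ```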
#[derive(Debug, Clone)]
pub struct ScoreNetwork {
    pub config: ScoreNetworkConfig,
    /// Per-layer `(weights, biases)`; weights are row-major `[out_dim, in_dim]`.
    layers: Vec<(Vec<f64>, Vec<f64>)>,
}
impl ScoreNetwork {
pub fn new(config: ScoreNetworkConfig) -> Result<Self> {
if config.data_dim == 0 {
return Err(NeuralError::InvalidArgument(
"ScoreNetwork: data_dim must be > 0".to_string(),
));
}
if config.hidden_dim == 0 {
return Err(NeuralError::InvalidArgument(
"ScoreNetwork: hidden_dim must be > 0".to_string(),
));
}
if config.num_layers == 0 {
return Err(NeuralError::InvalidArgument(
"ScoreNetwork: num_layers must be > 0".to_string(),
));
}
        // The network conditions on the noise level via an extra ln(σ) input.
        let in_dim = config.data_dim + 1;
        let layers = Self::init_layers(
            in_dim,
            config.hidden_dim,
            config.data_dim,
            config.num_layers,
            config.seed,
        );
Ok(Self { config, layers })
}
    /// Deterministic uniform sample in [-1, 1) from a 64-bit LCG
    /// (Knuth's MMIX constants); the top 53 bits form the mantissa.
    fn lcg_sample(state: &mut u64) -> f64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let bits = *state >> 11;
        (bits as f64) / (1u64 << 53) as f64 * 2.0 - 1.0
    }
    /// Xavier/Glorot-uniform initialization: each layer's weights are drawn
    /// from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)); biases start at 0.
    /// Builds one input layer, `num_layers - 1` hidden-to-hidden layers, and
    /// a linear output projection.
    fn init_layers(
        in_dim: usize,
        hidden: usize,
        out_dim: usize,
        num_layers: usize,
        seed: u64,
    ) -> Vec<(Vec<f64>, Vec<f64>)> {
        let mut rng = seed.wrapping_add(0xdeadbeef);
        let mut layers = Vec::with_capacity(num_layers + 1);
        // Input layer: (data_dim + 1) -> hidden.
        let limit = (6.0 / (in_dim + hidden) as f64).sqrt();
        let w: Vec<f64> = (0..in_dim * hidden)
            .map(|_| Self::lcg_sample(&mut rng) * limit)
            .collect();
        layers.push((w, vec![0.0f64; hidden]));
        // Hidden layers: hidden -> hidden.
        for _ in 1..num_layers {
            let limit = (6.0 / (hidden + hidden) as f64).sqrt();
            let w: Vec<f64> = (0..hidden * hidden)
                .map(|_| Self::lcg_sample(&mut rng) * limit)
                .collect();
            layers.push((w, vec![0.0f64; hidden]));
        }
        // Output projection: hidden -> data_dim.
        let limit = (6.0 / (hidden + out_dim) as f64).sqrt();
        let w: Vec<f64> = (0..hidden * out_dim)
            .map(|_| Self::lcg_sample(&mut rng) * limit)
            .collect();
        layers.push((w, vec![0.0f64; out_dim]));
        layers
    }
    /// SiLU (swish) activation: x · sigmoid(x).
    fn silu(x: f64) -> f64 {
        x / (1.0 + (-x).exp())
    }
    /// Forward pass: dense layers with SiLU on all but the final layer.
    fn mlp_forward(&self, inp: &[f64]) -> Vec<f64> {
        let mut h = inp.to_vec();
        let n = self.layers.len();
        for (idx, (w, b)) in self.layers.iter().enumerate() {
            let out_dim = b.len();
            let in_dim = h.len();
            let mut next = vec![0.0f64; out_dim];
            for j in 0..out_dim {
                let mut s = b[j];
                for i in 0..in_dim {
                    let wi = j * in_dim + i;
                    // Defensive bound check; dimensions always match by construction.
                    if wi < w.len() {
                        s += w[wi] * h[i];
                    }
                }
                next[j] = s;
            }
            // No activation on the output layer: scores are unbounded.
            if idx < n - 1 {
                for v in &mut next {
                    *v = Self::silu(*v);
                }
            }
            h = next;
        }
        h
    }
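    /// Jacobian-vector product of the score by central finite differences:
    /// (∇s)(x)·v ≈ (s(x + εv, σ) − s(x − εv, σ)) / (2ε), with ε = 1e-4.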
pub fn jvp_approx(&self, x: &[f64], v: &[f64], sigma: f64) -> Result<Vec<f64>> {
const EPS: f64 = 1e-4;
let d = x.len();
if v.len() != d {
return Err(NeuralError::ShapeMismatch(format!(
"ScoreNetwork jvp_approx: x len {} != v len {}",
d,
v.len()
)));
}
let x_plus: Vec<f64> = x.iter().zip(v).map(|(&xi, &vi)| xi + EPS * vi).collect();
let x_minus: Vec<f64> = x.iter().zip(v).map(|(&xi, &vi)| xi - EPS * vi).collect();
let s_plus = self.score(&x_plus, sigma)?;
let s_minus = self.score(&x_minus, sigma)?;
let jvp: Vec<f64> = s_plus
.iter()
.zip(&s_minus)
.map(|(&sp, &sm)| (sp - sm) / (2.0 * EPS))
.collect();
Ok(jvp)
}
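    /// Geometric noise schedule in increasing order:
    /// σ_i = σ_min · (σ_max/σ_min)^(i/(L−1)), i = 0, …, L−1.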
pub fn noise_levels(&self) -> Vec<f64> {
let l = self.config.num_noise_levels.max(1);
if l == 1 {
return vec![self.config.sigma_max];
}
let ratio = self.config.sigma_max / self.config.sigma_min.max(1e-12);
(0..l)
.map(|i| self.config.sigma_min * ratio.powf(i as f64 / (l - 1) as f64))
.collect()
}
}
impl ScoreFunction for ScoreNetwork {
fn score(&self, x: &[f64], sigma: f64) -> Result<Vec<f64>> {
if x.len() != self.config.data_dim {
return Err(NeuralError::ShapeMismatch(format!(
"ScoreNetwork: input dim {} != data_dim {}",
x.len(),
self.config.data_dim
)));
}
if sigma <= 0.0 {
return Err(NeuralError::InvalidArgument(format!(
"ScoreNetwork: sigma must be > 0, got {sigma}"
)));
}
        // Condition on the noise level through an extra ln(σ) input, then
        // apply the NCSN parameterization s_θ(x, σ) = f_θ([x, ln σ]) / σ.
        let mut inp = x.to_vec();
        inp.push(sigma.ln());
        let raw = self.mlp_forward(&inp);
        Ok(raw.iter().map(|&v| v / sigma).collect())
}
fn parameter_count(&self) -> usize {
self.layers.iter().map(|(w, b)| w.len() + b.len()).sum()
}
}
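/// Denoising score matching (Vincent, 2011): perturb the data with Gaussian
/// noise, x̃ = x + σε, and regress the model score toward the tractable
/// conditional target ∇_x̃ log q(x̃ | x) = −(x̃ − x)/σ² = −ε/σ.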
#[derive(Debug)]
pub struct DenoisingScoreMatching {
rng_state: u64,
}
impl DenoisingScoreMatching {
pub fn new(seed: u64) -> Self {
Self {
rng_state: seed.wrapping_add(0xfeedface),
}
}
    /// Standard normal draw via Box–Muller on two LCG uniforms; the `+ 1.0`
    /// offsets keep u1, u2 strictly inside (0, 1) so `ln(0)` cannot occur.
    fn sample_normal(&mut self) -> f64 {
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let u1 = ((self.rng_state >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let u2 = ((self.rng_state >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
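    /// Single-sample DSM loss at noise level σ:
    /// L = (1/d)·‖s_θ(x + σε, σ) + ε/σ‖² with ε ~ N(0, I).
    /// Returns the loss together with the perturbed point x̃.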
pub fn compute_loss(
&mut self,
x: &[f64],
sigma: f64,
score_fn: &dyn ScoreFunction,
) -> Result<(f64, Vec<f64>)> {
let d = x.len();
if sigma <= 0.0 {
return Err(NeuralError::InvalidArgument(format!(
"DSM: sigma must be > 0, got {sigma}"
)));
}
        // Perturb the sample: x̃ = x + σε, ε ~ N(0, I).
        let eps: Vec<f64> = (0..d).map(|_| self.sample_normal()).collect();
        let x_tilde: Vec<f64> = x.iter().zip(&eps).map(|(&xi, &ei)| xi + sigma * ei).collect();
        let s_pred = score_fn.score(&x_tilde, sigma)?;
        // Squared residual against the conditional target −ε/σ, averaged
        // over dimensions.
        let loss: f64 = s_pred
            .iter()
            .zip(&eps)
            .map(|(&s, &e)| {
                let residual = s + e / sigma;
                residual * residual
            })
            .sum::<f64>()
            / d as f64;
Ok((loss, x_tilde))
}
    /// Annealed DSM: draw one noise level uniformly at random and weight its
    /// loss by λ(σ) = σ², the standard NCSN weighting that keeps the terms
    /// comparable across levels (Song & Ermon, 2019).
    pub fn annealed_loss(
        &mut self,
        x: &[f64],
        score_net: &ScoreNetwork,
    ) -> Result<f64> {
        let sigmas = score_net.noise_levels();
        let l = sigmas.len();
        if l == 0 {
            return Ok(0.0);
        }
        // Advance the LCG and use its high bits to pick a level index.
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let idx = (self.rng_state >> 33) as usize % l;
        let sigma = sigmas[idx];
        let (loss, _x_tilde) = self.compute_loss(x, sigma, score_net)?;
        Ok(loss * sigma * sigma)
    }
}
/// Distribution of the random projection vectors used by sliced score matching.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ProjectionDist {
    /// Entries are ±1 with equal probability (the lowest-variance choice).
    Rademacher,
    /// Entries are standard normal.
    Gaussian,
}
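/// Sliced score matching (Song et al., 2019): estimate the intractable
/// divergence term of score matching with random projections `v`,
/// tr(∇s) ≈ E_v[vᵀ(∇s)v], using the finite-difference JVP above.
///
/// A minimal usage sketch, given some `net: ScoreNetwork` (doc-test ignored;
/// error handling elided):
///
/// ```ignore
/// let mut ssm = SlicedScoreMatching::new(8, ProjectionDist::Rademacher, 7)?;
/// let loss = ssm.compute_loss(&[0.1, -0.2], 0.1, &net)?;
/// ```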
#[derive(Debug)]
pub struct SlicedScoreMatching {
    /// Number of random projections per loss evaluation.
    pub num_projections: usize,
    /// Distribution the projection vectors are drawn from.
    pub proj_dist: ProjectionDist,
    rng_state: u64,
}
impl SlicedScoreMatching {
pub fn new(num_projections: usize, proj_dist: ProjectionDist, seed: u64) -> Result<Self> {
if num_projections == 0 {
return Err(NeuralError::InvalidArgument(
"SSM: num_projections must be > 0".to_string(),
));
}
Ok(Self {
num_projections,
proj_dist,
rng_state: seed.wrapping_add(0xc0ffee42),
})
}
    /// ±1 with equal probability, read from the LCG's top bit.
    fn sample_rademacher(&mut self) -> f64 {
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        if self.rng_state >> 63 == 0 { 1.0 } else { -1.0 }
    }
    /// Standard normal via Box–Muller on two LCG uniforms (same construction
    /// as `DenoisingScoreMatching::sample_normal`).
    fn sample_gaussian(&mut self) -> f64 {
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let u1 = ((self.rng_state >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let u2 = ((self.rng_state >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
    /// Draw one random projection direction of length `dim`.
    fn sample_projection(&mut self, dim: usize) -> Vec<f64> {
        match self.proj_dist {
            ProjectionDist::Rademacher => (0..dim).map(|_| self.sample_rademacher()).collect(),
            ProjectionDist::Gaussian => (0..dim).map(|_| self.sample_gaussian()).collect(),
        }
    }
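    /// Single-sample sliced score matching loss (variance-reduced form):
    /// L = E_v[vᵀ(∇s)v] + ½‖s‖², with both terms averaged per dimension so
    /// they share a common scale.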
    pub fn compute_loss(
        &mut self,
        x: &[f64],
        sigma: f64,
        score_net: &ScoreNetwork,
    ) -> Result<f64> {
        let d = x.len();
        let s = score_net.score(x, sigma)?;
        // ½‖s‖², averaged over dimensions.
        let half_sq_norm: f64 = s.iter().map(|&si| si * si).sum::<f64>() / (2.0 * d as f64);
        // Divergence estimate: tr(∇s) ≈ E_v[vᵀ(∇s)v], each vᵀ(∇s)v computed
        // with the finite-difference JVP.
        let mut div_term = 0.0f64;
        for _ in 0..self.num_projections {
            let v = self.sample_projection(d);
            let jvp = score_net.jvp_approx(x, &v, sigma)?;
            let vt_jvp: f64 = v.iter().zip(&jvp).map(|(&vi, &ji)| vi * ji).sum();
            div_term += vt_jvp;
        }
        // Average over projections and dimensions so this term is on the
        // same per-dimension scale as `half_sq_norm`.
        div_term /= (self.num_projections as f64) * d as f64;
        Ok(div_term + half_sq_norm)
    }
    /// Mean SSM loss over a dataset at a fixed noise level.
    pub fn train_epoch(
        &mut self,
        data: &[Vec<f64>],
        sigma: f64,
        score_net: &ScoreNetwork,
    ) -> Result<f64> {
        if data.is_empty() {
            return Ok(0.0);
        }
        let mut total = 0.0f64;
        for x in data {
            total += self.compute_loss(x, sigma, score_net)?;
        }
        Ok(total / data.len() as f64)
    }
}
/// Configuration for annealed Langevin dynamics sampling.
#[derive(Debug, Clone)]
pub struct LangevinConfig {
    /// Langevin updates per noise level.
    pub steps_per_level: usize,
    /// Base step size ε; the per-level step is α = ε·σ²/σ_max².
    pub step_size_coeff: f64,
    /// If false, runs noiseless gradient ascent on the log-density instead.
    pub add_noise: bool,
    /// RNG seed for the injected noise.
    pub seed: u64,
}
impl Default for LangevinConfig {
    fn default() -> Self {
        Self {
            steps_per_level: 100,
            step_size_coeff: 1e-5,
            add_noise: true,
            seed: 12345,
        }
    }
}
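/// Annealed Langevin dynamics (Song & Ermon, 2019): run Langevin MCMC at each
/// noise level from largest σ to smallest, warm-starting each level from the
/// previous one.
///
/// A minimal usage sketch, given some `net: ScoreNetwork` (doc-test ignored;
/// error handling elided):
///
/// ```ignore
/// let mut sampler = AnnealedLangevin::new(LangevinConfig::default());
/// let x = sampler.sample(&vec![0.0; 2], &net)?;
/// ```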
#[derive(Debug)]
pub struct AnnealedLangevin {
pub config: LangevinConfig,
rng_state: u64,
}
impl AnnealedLangevin {
pub fn new(config: LangevinConfig) -> Self {
let rng = config.seed.wrapping_add(0xabcdef01);
Self { config, rng_state: rng }
}
    /// Standard normal via Box–Muller on two LCG uniforms (same construction
    /// as `DenoisingScoreMatching::sample_normal`).
    fn sample_normal(&mut self) -> f64 {
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let u1 = ((self.rng_state >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
        self.rng_state = self
            .rng_state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let u2 = ((self.rng_state >> 11) as f64 + 1.0) / ((1u64 << 53) as f64 + 1.0);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
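    /// Sample by annealing from the largest noise level down to the smallest.
    /// At level σ each update is x ← x + α·s_θ(x, σ) + √(2α)·z, z ~ N(0, I),
    /// with α = ε·σ²/σ_max².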
pub fn sample(
&mut self,
x_init: &[f64],
score_net: &ScoreNetwork,
) -> Result<Vec<f64>> {
let sigmas = score_net.noise_levels();
if sigmas.is_empty() {
return Err(NeuralError::InvalidArgument(
"AnnealedLangevin: no noise levels".to_string(),
));
}
        // `noise_levels()` is increasing, so the last entry is σ_max.
        let sigma_max = sigmas.last().copied().unwrap_or(1.0);
        let sigma_max_sq = sigma_max * sigma_max;
        let d = x_init.len();
        let mut x = x_init.to_vec();
        // Anneal from the largest noise level down to the smallest.
        for &sigma in sigmas.iter().rev() {
            // α = ε·σ²/σ_max², so the step shrinks with the noise level.
            let alpha = self.config.step_size_coeff * sigma * sigma / sigma_max_sq.max(1e-12);
            let noise_std = (2.0 * alpha).sqrt();
for _ in 0..self.config.steps_per_level {
let score = score_net.score(&x, sigma)?;
let mut x_new: Vec<f64> = x
.iter()
.zip(&score)
.map(|(&xi, &si)| xi + alpha * si)
.collect();
if self.config.add_noise {
for xi in x_new.iter_mut() {
*xi += noise_std * self.sample_normal();
}
}
                // Guard against divergence: keep the previous coordinate if
                // the update produced a non-finite value.
                for i in 0..d {
                    if !x_new[i].is_finite() {
                        x_new[i] = x[i];
                    }
                }
x = x_new;
}
}
Ok(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_score_network_creation() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("score network creation");
assert!(net.parameter_count() > 0);
}
#[test]
fn test_score_network_output_shape() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("network");
let x = vec![0.1, -0.2, 0.3, -0.4];
let s = net.score(&x, 0.1).expect("score evaluation");
assert_eq!(s.len(), 4);
for &v in &s {
assert!(v.is_finite(), "score not finite: {v}");
}
}
#[test]
fn test_score_network_sigma_scaling() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("network");
let x = vec![0.1, -0.2, 0.3, -0.4];
let s1 = net.score(&x, 0.1).expect("score at σ=0.1");
let s2 = net.score(&x, 0.2).expect("score at σ=0.2");
assert_ne!(s1[0], s2[0]);
for &v in s1.iter().chain(s2.iter()) {
assert!(v.is_finite());
}
}
#[test]
fn test_dsm_loss() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("network");
let mut dsm = DenoisingScoreMatching::new(42);
let x = vec![0.5, -0.3, 0.2, 0.8];
let (loss, x_tilde) = dsm.compute_loss(&x, 0.1, &net).expect("DSM loss");
assert!(loss >= 0.0 && loss.is_finite(), "DSM loss invalid: {loss}");
assert_eq!(x_tilde.len(), 4);
}
#[test]
fn test_dsm_annealed() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("network");
let mut dsm = DenoisingScoreMatching::new(0);
let x = vec![0.5, -0.3, 0.2, 0.8];
let loss = dsm.annealed_loss(&x, &net).expect("annealed loss");
assert!(loss.is_finite());
}
#[test]
fn test_ssm_loss() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("network");
let mut ssm = SlicedScoreMatching::new(4, ProjectionDist::Rademacher, 99)
.expect("SSM creation");
let x = vec![0.5, -0.3, 0.2, 0.8];
let loss = ssm.compute_loss(&x, 0.1, &net).expect("SSM loss");
assert!(loss.is_finite(), "SSM loss not finite: {loss}");
}
#[test]
fn test_ssm_gaussian_projections() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("network");
let mut ssm = SlicedScoreMatching::new(2, ProjectionDist::Gaussian, 7)
.expect("SSM");
let data: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64 * 0.1; 4]).collect();
let loss = ssm.train_epoch(&data, 0.2, &net).expect("epoch");
assert!(loss.is_finite());
}
#[test]
fn test_noise_levels_geometric() {
let cfg = ScoreNetworkConfig {
num_noise_levels: 5,
sigma_min: 0.1,
sigma_max: 1.0,
..ScoreNetworkConfig::tiny(4)
};
let net = ScoreNetwork::new(cfg).expect("network");
let levels = net.noise_levels();
assert_eq!(levels.len(), 5);
assert!((levels[0] - 0.1).abs() < 1e-9);
assert!((levels[4] - 1.0).abs() < 1e-9);
for i in 1..5 {
assert!(levels[i] > levels[i - 1], "noise levels not increasing");
}
}
#[test]
fn test_annealed_langevin() {
let cfg = ScoreNetworkConfig::tiny(4);
let net = ScoreNetwork::new(cfg).expect("network");
let langevin_cfg = LangevinConfig {
steps_per_level: 3,
step_size_coeff: 1e-5,
add_noise: true,
seed: 0,
};
let mut sampler = AnnealedLangevin::new(langevin_cfg);
let x_init = vec![0.0; 4];
let sample = sampler.sample(&x_init, &net).expect("langevin sample");
assert_eq!(sample.len(), 4);
for &v in &sample {
assert!(v.is_finite(), "sample not finite: {v}");
}
}
#[test]
fn test_ssm_zero_projection_error() {
let result = SlicedScoreMatching::new(0, ProjectionDist::Rademacher, 42);
assert!(result.is_err());
}
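    // A sketch of an extra check beyond the original suite: `jvp_approx`
    // should return `data_dim` finite entries and reject a direction vector
    // of the wrong length.
    #[test]
    fn test_jvp_approx_shape_and_mismatch() {
        let cfg = ScoreNetworkConfig::tiny(4);
        let net = ScoreNetwork::new(cfg).expect("network");
        let x = vec![0.1, -0.2, 0.3, -0.4];
        let v = vec![1.0, 0.0, 0.0, 0.0];
        let jvp = net.jvp_approx(&x, &v, 0.1).expect("jvp");
        assert_eq!(jvp.len(), 4);
        for &j in &jvp {
            assert!(j.is_finite(), "jvp entry not finite: {j}");
        }
        assert!(net.jvp_approx(&x, &[1.0, 0.0], 0.1).is_err());
    }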
}