use crate::error::{NeuralError, Result};
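/// An energy function `E(x)` over real-valued configurations.
///
/// The default `energy_gradient` is a forward finite-difference estimate,
/// which costs one extra energy evaluation per dimension; implementors with
/// an analytic gradient should override it.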
pub trait EnergyFunction: Send + Sync {
fn energy(&self, x: &[f64]) -> f64;
fn energy_gradient(&self, x: &[f64], eps: f64) -> Vec<f64> {
let e0 = self.energy(x);
let mut grad = vec![0.0f64; x.len()];
let mut x_perturb = x.to_vec();
for i in 0..x.len() {
x_perturb[i] += eps;
let e1 = self.energy(&x_perturb);
grad[i] = (e1 - e0) / eps;
x_perturb[i] = x[i];
}
grad
}
}
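/// Numerically stable logistic sigmoid: branching on the sign of `x` keeps
/// the exponential from overflowing for large |x|.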
fn sigmoid(x: f64) -> f64 {
if x >= 0.0 {
1.0 / (1.0 + (-x).exp())
} else {
let ex = x.exp();
ex / (1.0 + ex)
}
}
fn bernoulli_sample(p: f64, state: &mut u64) -> f64 {
    // Advance a 64-bit LCG (Knuth's MMIX multiplier and increment).
    *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1_442_695_040_888_963_407);
    // Map the high 32 bits (the strongest bits of an LCG) to a uniform u in [0, 1).
    let u = (*state >> 32) as f64 / (u32::MAX as f64 + 1.0);
    if u < p { 1.0 } else { 0.0 }
}
fn normal_sample(state: &mut u64) -> f64 {
    // Two LCG draws mapped into (0, 1), combined via the Box-Muller transform.
    // The +1.0 / +2.0 offsets keep u1 strictly positive so ln(u1) stays finite.
    *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1_442_695_040_888_963_407);
    let u1 = ((*state >> 32) as f64 + 1.0) / (u32::MAX as f64 + 2.0);
    *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1_442_695_040_888_963_407);
    let u2 = ((*state >> 32) as f64 + 1.0) / (u32::MAX as f64 + 2.0);
    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}
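/// A binary-binary restricted Boltzmann machine with joint energy
///
/// ```text
/// E(v, h) = -sum_i a_i v_i - sum_j b_j h_j - sum_{i,j} v_i W_ij h_j
/// ```
///
/// `weights` is stored row-major as `weights[i * n_hidden + j]` for visible
/// unit `i` and hidden unit `j`.
///
/// # Examples
///
/// A minimal training sketch (the import path is assumed; adjust it to this
/// crate's module layout):
///
/// ```ignore
/// let mut rbm = RestrictedBoltzmannMachine::new(4, 3, 0.1, 1)?;
/// let data = vec![vec![1.0, 0.0, 1.0, 0.0], vec![0.0, 1.0, 0.0, 1.0]];
/// for _ in 0..100 {
///     let _err = rbm.train_epoch(&data)?;
/// }
/// let recon = rbm.reconstruct(&data[0]);
/// ```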
#[derive(Debug, Clone)]
pub struct RestrictedBoltzmannMachine {
pub n_visible: usize,
pub n_hidden: usize,
pub weights: Vec<f64>,
pub visible_bias: Vec<f64>,
pub hidden_bias: Vec<f64>,
pub learning_rate: f64,
pub cd_steps: usize,
rng_state: u64,
}
impl RestrictedBoltzmannMachine {
pub fn new(n_visible: usize, n_hidden: usize, learning_rate: f64, cd_steps: usize) -> Result<Self> {
if n_visible == 0 || n_hidden == 0 {
return Err(NeuralError::InvalidArgument(
"RBM: n_visible and n_hidden must be > 0".to_string(),
));
}
if cd_steps == 0 {
return Err(NeuralError::InvalidArgument(
"RBM: cd_steps must be >= 1".to_string(),
));
}
        // Xavier-style scale; weights come from a deterministic golden-ratio
        // sine sequence so construction is reproducible without an RNG.
        let scale = (2.0 / (n_visible + n_hidden) as f64).sqrt();
        let weights: Vec<f64> = (0..n_visible * n_hidden)
            .map(|k| scale * ((k as f64) * 0.6180339887).sin())
            .collect();
Ok(Self {
n_visible,
n_hidden,
weights,
visible_bias: vec![0.0; n_visible],
hidden_bias: vec![0.0; n_hidden],
learning_rate,
cd_steps,
rng_state: 0xdeadbeef_cafebabe,
})
}
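    /// P(h_j = 1 | v) = sigmoid(b_j + sum_i v_i W_ij) for each hidden unit.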
pub fn hidden_probs(&self, v: &[f64]) -> Vec<f64> {
(0..self.n_hidden)
.map(|j| {
let s: f64 = (0..self.n_visible)
.map(|i| v[i] * self.weights[i * self.n_hidden + j])
.sum();
sigmoid(s + self.hidden_bias[j])
})
.collect()
}
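    /// P(v_i = 1 | h) = sigmoid(a_i + sum_j W_ij h_j) for each visible unit.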
pub fn visible_probs(&self, h: &[f64]) -> Vec<f64> {
(0..self.n_visible)
.map(|i| {
let s: f64 = (0..self.n_hidden)
.map(|j| h[j] * self.weights[i * self.n_hidden + j])
.sum();
sigmoid(s + self.visible_bias[i])
})
.collect()
}
pub fn sample_hidden(&mut self, v: &[f64]) -> Vec<f64> {
let probs = self.hidden_probs(v);
probs
.iter()
.map(|&p| bernoulli_sample(p, &mut self.rng_state))
.collect()
}
pub fn sample_visible(&mut self, h: &[f64]) -> Vec<f64> {
let probs = self.visible_probs(h);
probs
.iter()
.map(|&p| bernoulli_sample(p, &mut self.rng_state))
.collect()
}
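    /// Free energy with the hidden units summed out:
    ///
    /// ```text
    /// F(v) = -a . v - sum_j softplus(b_j + sum_i v_i W_ij)
    /// ```
    ///
    /// The softplus is evaluated in a numerically stable branch form.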
pub fn free_energy(&self, v: &[f64]) -> f64 {
let bv_term: f64 = v.iter().zip(&self.visible_bias).map(|(&vi, &bi)| vi * bi).sum();
let hidden_term: f64 = (0..self.n_hidden)
.map(|j| {
let s: f64 = (0..self.n_visible)
.map(|i| v[i] * self.weights[i * self.n_hidden + j])
.sum();
let x = s + self.hidden_bias[j];
if x > 0.0 {
x + (1.0 + (-x).exp()).ln()
} else {
(1.0 + x.exp()).ln()
}
})
.sum();
-bv_term - hidden_term
}
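    /// One CD-k update from a single data vector: the positive phase uses
    /// P(h | v_data), the negative phase runs `cd_steps` rounds of Gibbs
    /// sampling, and the parameters move along
    ///
    /// ```text
    /// dW_ij = lr * ( v_i P(h_j | v) - v'_i P(h_j | v') )
    /// ```
    ///
    /// Returns the mean squared reconstruction error for this input.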
pub fn train_step(&mut self, v_data: &[f64]) -> Result<f64> {
if v_data.len() != self.n_visible {
return Err(NeuralError::ShapeMismatch(format!(
"RBM train_step: expected {} visible units, got {}",
self.n_visible,
v_data.len()
)));
}
let h_pos_probs = self.hidden_probs(v_data);
let h_pos: Vec<f64> = h_pos_probs
.iter()
.map(|&p| bernoulli_sample(p, &mut self.rng_state))
.collect();
let mut v_neg = v_data.to_vec();
let mut h_neg = h_pos.clone();
for _ in 0..self.cd_steps {
v_neg = self.sample_visible(&h_neg);
h_neg = self.sample_hidden(&v_neg);
}
let h_neg_probs = self.hidden_probs(&v_neg);
let lr = self.learning_rate;
for i in 0..self.n_visible {
for j in 0..self.n_hidden {
let pos = v_data[i] * h_pos_probs[j];
let neg = v_neg[i] * h_neg_probs[j];
self.weights[i * self.n_hidden + j] += lr * (pos - neg);
}
self.visible_bias[i] += lr * (v_data[i] - v_neg[i]);
}
for j in 0..self.n_hidden {
self.hidden_bias[j] += lr * (h_pos_probs[j] - h_neg_probs[j]);
}
let recon = self.visible_probs(&h_pos);
let err: f64 = v_data
.iter()
.zip(&recon)
.map(|(&v, &r)| (v - r).powi(2))
.sum::<f64>()
/ self.n_visible as f64;
Ok(err)
}
pub fn train_epoch(&mut self, data: &[Vec<f64>]) -> Result<f64> {
if data.is_empty() {
return Ok(0.0);
}
let total_err: f64 = data
.iter()
.map(|v| self.train_step(v))
.collect::<Result<Vec<f64>>>()?
.iter()
.sum();
Ok(total_err / data.len() as f64)
}
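    /// One stochastic up-down pass: sample h from P(h | v), then return the
    /// visible probabilities P(v | h).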
pub fn reconstruct(&mut self, v: &[f64]) -> Vec<f64> {
let h = self.sample_hidden(v);
self.visible_probs(&h)
}
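    /// Runs a Gibbs chain for `steps` rounds from a uniform random visible
    /// state and returns the final binary visible sample.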
pub fn generate(&mut self, steps: usize) -> Vec<f64> {
let mut v: Vec<f64> = (0..self.n_visible)
.map(|_| bernoulli_sample(0.5, &mut self.rng_state))
.collect();
for _ in 0..steps {
let h = self.sample_hidden(&v);
v = self.sample_visible(&h);
}
v
}
}
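/// The RBM's `energy` is its free energy, i.e. the hidden units are
/// marginalized out rather than supplied as part of `x`.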
impl EnergyFunction for RestrictedBoltzmannMachine {
fn energy(&self, x: &[f64]) -> f64 {
self.free_energy(x)
}
}
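/// A stack of RBMs trained greedily layer by layer, in the style of deep
/// belief network pretraining: each RBM models the hidden activations of
/// the layer below it.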
pub struct DeepBoltzmannMachine {
rbms: Vec<RestrictedBoltzmannMachine>,
pub layer_sizes: Vec<usize>,
pub learning_rate: f64,
pub cd_steps: usize,
}
impl std::fmt::Debug for DeepBoltzmannMachine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DeepBoltzmannMachine")
.field("layer_sizes", &self.layer_sizes)
.field("num_rbms", &self.rbms.len())
.finish()
}
}
impl DeepBoltzmannMachine {
pub fn new(layer_sizes: Vec<usize>, learning_rate: f64, cd_steps: usize) -> Result<Self> {
if layer_sizes.len() < 2 {
return Err(NeuralError::InvalidArgument(
"DBM requires at least 2 layer sizes (visible + 1 hidden)".to_string(),
));
}
let rbms = layer_sizes
.windows(2)
.map(|w| RestrictedBoltzmannMachine::new(w[0], w[1], learning_rate, cd_steps))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
rbms,
layer_sizes,
learning_rate,
cd_steps,
})
}
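    /// Greedy layer-wise pretraining: trains each RBM for `epochs_per_layer`
    /// epochs, then maps the data through its hidden probabilities to form
    /// the next layer's training set.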
pub fn pretrain(&mut self, data: &[Vec<f64>], epochs_per_layer: usize) -> Result<()> {
let mut current_data: Vec<Vec<f64>> = data.to_vec();
        for rbm in self.rbms.iter_mut() {
            for _ in 0..epochs_per_layer {
                let _err = rbm.train_epoch(&current_data)?;
            }
            // Propagate the data upward: this layer's hidden activations
            // become the training input for the next RBM in the stack.
            current_data = current_data
                .iter()
                .map(|v| rbm.hidden_probs(v))
                .collect();
        }
Ok(())
}
pub fn encode(&self, v: &[f64]) -> Result<Vec<f64>> {
if v.len() != self.layer_sizes[0] {
return Err(NeuralError::ShapeMismatch(format!(
"DBM encode: expected {} inputs, got {}",
self.layer_sizes[0],
v.len()
)));
}
let mut h = v.to_vec();
for rbm in &self.rbms {
h = rbm.hidden_probs(&h);
}
Ok(h)
}
pub fn total_energy(&self, v: &[f64]) -> Result<f64> {
if v.len() != self.layer_sizes[0] {
return Err(NeuralError::ShapeMismatch(format!(
"DBM total_energy: expected {} inputs, got {}",
self.layer_sizes[0],
v.len()
)));
}
let mut energy = 0.0f64;
let mut current = v.to_vec();
for rbm in &self.rbms {
            energy += rbm.free_energy(&current);
            current = rbm.hidden_probs(&current);
}
Ok(energy)
}
pub fn num_layers(&self) -> usize {
self.rbms.len()
}
}
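/// Hyperparameters for Langevin sampling: step size and step count, a scale
/// on the injected Gaussian noise, a per-coordinate gradient clip (disabled
/// when <= 0), and the epsilon used for finite-difference gradients.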
#[derive(Debug, Clone)]
pub struct LangevinConfig {
pub step_size: f64,
pub num_steps: usize,
pub noise_scale: f64,
pub grad_clip: f64,
pub fd_eps: f64,
}
impl Default for LangevinConfig {
fn default() -> Self {
Self {
step_size: 0.01,
num_steps: 20,
noise_scale: 1.0,
grad_clip: 1.0,
fd_eps: 1e-3,
}
}
}
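/// An energy-based model p(x) proportional to exp(-E(x)), paired with a
/// Langevin sampler for drawing approximate samples.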
pub struct EnergyBasedModel {
pub energy_fn: Box<dyn EnergyFunction>,
pub langevin_config: LangevinConfig,
rng_state: u64,
}
impl std::fmt::Debug for EnergyBasedModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EnergyBasedModel")
.field("langevin_config", &self.langevin_config)
.finish()
}
}
impl EnergyBasedModel {
pub fn new(energy_fn: Box<dyn EnergyFunction>, langevin_config: LangevinConfig) -> Self {
Self {
energy_fn,
langevin_config,
rng_state: 0x1234567890abcdef,
}
}
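    /// Unadjusted Langevin dynamics: at each step,
    ///
    /// ```text
    /// x <- x - eta * grad E(x) + sqrt(2 * eta) * noise_scale * z,  z ~ N(0, 1)
    /// ```
    ///
    /// with the gradient estimated by finite differences and optionally
    /// clipped per coordinate.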
pub fn langevin_sample(&mut self, x_init: &[f64]) -> Vec<f64> {
let cfg = &self.langevin_config;
let mut x = x_init.to_vec();
for _ in 0..cfg.num_steps {
let grad = self.energy_fn.energy_gradient(&x, cfg.fd_eps);
let noise_std = (2.0 * cfg.step_size).sqrt() * cfg.noise_scale;
for i in 0..x.len() {
let mut g = grad[i];
if cfg.grad_clip > 0.0 {
g = g.clamp(-cfg.grad_clip, cfg.grad_clip);
}
let noise = normal_sample(&mut self.rng_state) * noise_std;
x[i] -= cfg.step_size * g + noise;
}
}
x
}
pub fn energy(&self, x: &[f64]) -> f64 {
self.energy_fn.energy(x)
}
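    /// A crude contrastive proxy for the log-likelihood: returns
    /// `mean(E(model samples)) - E(x)`. This tracks the data/model energy
    /// gap but omits the log-partition constant, so only relative values
    /// are meaningful.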
pub fn log_likelihood_estimate(&mut self, x: &[f64], n_samples: usize) -> f64 {
let e_data = self.energy(x);
let mut noise_init: Vec<f64> = (0..x.len())
.map(|_| normal_sample(&mut self.rng_state))
.collect();
let mut energies = Vec::with_capacity(n_samples);
for _ in 0..n_samples {
let sample = self.langevin_sample(&noise_init);
energies.push(self.energy(&sample));
noise_init = sample;
}
let mean_model_e = energies.iter().sum::<f64>() / n_samples as f64;
-(e_data - mean_model_e)
}
}
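/// Contrastive divergence with k Gibbs steps (CD-k): a thin wrapper that
/// forwards training to an inner RBM configured with `cd_steps = k`.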
pub struct ContrastiveDivergenceK {
pub rbm: RestrictedBoltzmannMachine,
pub k: usize,
}
impl std::fmt::Debug for ContrastiveDivergenceK {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ContrastiveDivergenceK")
.field("rbm", &format!("RBM({}, {})", self.rbm.n_visible, self.rbm.n_hidden))
.field("k", &self.k)
.finish()
}
}
impl ContrastiveDivergenceK {
pub fn new(
n_visible: usize,
n_hidden: usize,
learning_rate: f64,
k: usize,
) -> Result<Self> {
let rbm = RestrictedBoltzmannMachine::new(n_visible, n_hidden, learning_rate, k)?;
Ok(Self { rbm, k })
}
pub fn train_epoch(&mut self, data: &[Vec<f64>]) -> Result<f64> {
self.rbm.train_epoch(data)
}
pub fn train(&mut self, data: &[Vec<f64>], epochs: usize) -> Result<Vec<f64>> {
(0..epochs).map(|_| self.train_epoch(data)).collect()
}
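    /// Pseudo-log-likelihood: sums log P(v_i | v_{-i}) over all visible
    /// units, where each conditional is sigmoid(F(v with bit i flipped) - F(v)).
    /// Costs one free-energy evaluation per visible unit.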
pub fn pseudo_log_likelihood(&self, v: &[f64]) -> Result<f64> {
if v.len() != self.rbm.n_visible {
return Err(NeuralError::ShapeMismatch(format!(
"CD-k PLL: expected {} visible, got {}",
self.rbm.n_visible,
v.len()
)));
}
let fe = self.rbm.free_energy(v);
        let mut pll = 0.0f64;
        let mut v_flip = v.to_vec();
        for i in 0..self.rbm.n_visible {
            v_flip[i] = 1.0 - v_flip[i];
            let fe_flip = self.rbm.free_energy(&v_flip);
            v_flip[i] = v[i];
            // log P(v_i | v_{-i}) = log sigmoid(F(v_flipped) - F(v)),
            // evaluated in a numerically stable branch form.
            let diff = fe_flip - fe;
            pll += if diff >= 0.0 {
                -(1.0 + (-diff).exp()).ln()
            } else {
                diff - (1.0 + diff.exp()).ln()
            };
        }
        Ok(pll)
}
}
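/// Persistent contrastive divergence: the negative phase is estimated from
/// `num_chains` persistent Gibbs chains that carry over between updates
/// instead of being re-initialized from the data each step.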
pub struct PersistentCD {
pub rbm: RestrictedBoltzmannMachine,
pub num_chains: usize,
chain_state: Vec<Vec<f64>>,
pub gibbs_steps: usize,
}
impl std::fmt::Debug for PersistentCD {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PersistentCD")
.field("n_visible", &self.rbm.n_visible)
.field("n_hidden", &self.rbm.n_hidden)
.field("num_chains", &self.num_chains)
.field("gibbs_steps", &self.gibbs_steps)
.finish()
}
}
impl PersistentCD {
pub fn new(
n_visible: usize,
n_hidden: usize,
learning_rate: f64,
num_chains: usize,
gibbs_steps: usize,
) -> Result<Self> {
if num_chains == 0 {
return Err(NeuralError::InvalidArgument(
"PersistentCD: num_chains must be >= 1".to_string(),
));
}
if gibbs_steps == 0 {
return Err(NeuralError::InvalidArgument(
"PersistentCD: gibbs_steps must be >= 1".to_string(),
));
}
let rbm = RestrictedBoltzmannMachine::new(n_visible, n_hidden, learning_rate, 1)?;
let mut rng_state: u64 = 0xfeedcafedeadbeef;
let chain_state = (0..num_chains)
.map(|_| {
(0..n_visible)
.map(|_| bernoulli_sample(0.5, &mut rng_state))
.collect::<Vec<f64>>()
})
.collect();
Ok(Self {
rbm,
num_chains,
chain_state,
gibbs_steps,
})
}
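    /// One PCD update over a mini-batch: positive statistics are averaged
    /// over the batch, negative statistics over the persistent chains after
    /// `gibbs_steps` rounds of Gibbs updates each.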
pub fn train_step(&mut self, batch: &[Vec<f64>]) -> Result<f64> {
if batch.is_empty() {
return Ok(0.0);
}
let lr = self.rbm.learning_rate;
let n_v = self.rbm.n_visible;
let n_h = self.rbm.n_hidden;
let mut dw_pos = vec![0.0f64; n_v * n_h];
let mut dv_pos = vec![0.0f64; n_v];
let mut dh_pos = vec![0.0f64; n_h];
for v in batch.iter() {
if v.len() != n_v {
return Err(NeuralError::ShapeMismatch(format!(
"PCD train_step: expected {n_v} visible, got {}",
v.len()
)));
}
let h_probs = self.rbm.hidden_probs(v);
for i in 0..n_v {
dv_pos[i] += v[i];
for j in 0..n_h {
dw_pos[i * n_h + j] += v[i] * h_probs[j];
}
}
for j in 0..n_h {
dh_pos[j] += h_probs[j];
}
}
let bs = batch.len() as f64;
let mut dw_neg = vec![0.0f64; n_v * n_h];
let mut dv_neg = vec![0.0f64; n_v];
let mut dh_neg = vec![0.0f64; n_h];
for c in 0..self.num_chains {
            for _ in 0..self.gibbs_steps {
                // Move the chain state out temporarily so `self.rbm` can be
                // borrowed mutably without cloning the whole vector.
                let v = std::mem::take(&mut self.chain_state[c]);
                let h = self.rbm.sample_hidden(&v);
                self.chain_state[c] = self.rbm.sample_visible(&h);
            }
let v_neg = &self.chain_state[c];
let h_probs_neg = self.rbm.hidden_probs(v_neg);
for i in 0..n_v {
dv_neg[i] += v_neg[i];
for j in 0..n_h {
dw_neg[i * n_h + j] += v_neg[i] * h_probs_neg[j];
}
}
for j in 0..n_h {
dh_neg[j] += h_probs_neg[j];
}
}
let chains_f = self.num_chains as f64;
for i in 0..n_v {
for j in 0..n_h {
self.rbm.weights[i * n_h + j] +=
lr * (dw_pos[i * n_h + j] / bs - dw_neg[i * n_h + j] / chains_f);
}
self.rbm.visible_bias[i] += lr * (dv_pos[i] / bs - dv_neg[i] / chains_f);
}
for j in 0..n_h {
self.rbm.hidden_bias[j] += lr * (dh_pos[j] / bs - dh_neg[j] / chains_f);
}
let recon_err: f64 = batch
.iter()
.map(|v| {
let h = self.rbm.hidden_probs(v);
let v_recon = self.rbm.visible_probs(&h);
v.iter()
.zip(&v_recon)
.map(|(&vi, &ri)| (vi - ri).powi(2))
.sum::<f64>()
/ n_v as f64
})
.sum::<f64>()
/ bs;
Ok(recon_err)
}
pub fn train_epoch(&mut self, data: &[Vec<f64>], batch_size: usize) -> Result<f64> {
if data.is_empty() {
return Ok(0.0);
}
let actual_bs = batch_size.max(1);
let mut total_err = 0.0f64;
let mut count = 0usize;
let mut start = 0;
while start < data.len() {
let end = (start + actual_bs).min(data.len());
let batch = &data[start..end];
total_err += self.train_step(batch)?;
count += 1;
start = end;
}
Ok(if count > 0 { total_err / count as f64 } else { 0.0 })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rbm_creation() {
let rbm = RestrictedBoltzmannMachine::new(4, 3, 0.01, 1).expect("RBM creation failed");
assert_eq!(rbm.n_visible, 4);
assert_eq!(rbm.n_hidden, 3);
}
#[test]
fn test_rbm_hidden_probs_range() {
let rbm = RestrictedBoltzmannMachine::new(4, 3, 0.01, 1).expect("RBM creation");
let v = vec![1.0, 0.0, 1.0, 0.0];
let probs = rbm.hidden_probs(&v);
assert_eq!(probs.len(), 3);
for &p in &probs {
assert!(p >= 0.0 && p <= 1.0, "prob {p} out of range");
}
}
#[test]
fn test_rbm_train_step() {
let mut rbm = RestrictedBoltzmannMachine::new(4, 3, 0.01, 1).expect("RBM creation");
let v = vec![1.0, 0.0, 1.0, 0.0];
let err = rbm.train_step(&v).expect("train step failed");
assert!(err >= 0.0, "reconstruction error must be non-negative");
}
#[test]
fn test_rbm_free_energy() {
let rbm = RestrictedBoltzmannMachine::new(4, 3, 0.01, 1).expect("RBM creation");
let v = vec![0.5, 0.5, 0.5, 0.5];
let fe = rbm.free_energy(&v);
assert!(fe.is_finite(), "free energy must be finite");
}
#[test]
fn test_dbm_creation_and_pretrain() {
let mut dbm =
DeepBoltzmannMachine::new(vec![4, 3, 2], 0.01, 1).expect("DBM creation failed");
let data = vec![
vec![1.0, 0.0, 1.0, 0.0],
vec![0.0, 1.0, 0.0, 1.0],
vec![1.0, 1.0, 0.0, 0.0],
];
dbm.pretrain(&data, 2).expect("pretraining failed");
let enc = dbm.encode(&data[0]).expect("encode failed");
assert_eq!(enc.len(), 2);
}
#[test]
fn test_cd_k_pseudo_log_likelihood() {
let cdk = ContrastiveDivergenceK::new(4, 3, 0.01, 1).expect("CDk creation");
let v = vec![1.0, 0.0, 1.0, 0.0];
let pll = cdk.pseudo_log_likelihood(&v).expect("pll failed");
assert!(pll.is_finite(), "PLL should be finite");
}
#[test]
fn test_persistent_cd_train() {
let mut pcd = PersistentCD::new(4, 3, 0.01, 4, 1).expect("PCD creation");
let data = vec![
vec![1.0, 0.0, 1.0, 0.0],
vec![0.0, 1.0, 0.0, 1.0],
];
let err = pcd.train_epoch(&data, 2).expect("PCD epoch failed");
assert!(err >= 0.0);
}
struct QuadraticEnergy {
center: Vec<f64>,
}
impl EnergyFunction for QuadraticEnergy {
fn energy(&self, x: &[f64]) -> f64 {
x.iter()
.zip(&self.center)
.map(|(&xi, &ci)| (xi - ci).powi(2))
.sum()
}
}
#[test]
fn test_ebm_langevin_converges() {
let energy_fn = Box::new(QuadraticEnergy {
center: vec![1.0, 2.0],
});
let cfg = LangevinConfig {
step_size: 0.05,
num_steps: 100,
noise_scale: 0.01,
grad_clip: 5.0,
fd_eps: 1e-3,
};
let mut ebm = EnergyBasedModel::new(energy_fn, cfg);
let x_init = vec![0.0, 0.0];
let sample = ebm.langevin_sample(&x_init);
assert!(
(sample[0] - 1.0).abs() < 0.5,
"Langevin should converge: x[0]={}", sample[0]
);
assert!(
(sample[1] - 2.0).abs() < 0.5,
"Langevin should converge: x[1]={}", sample[1]
);
}
}