use scirs2_core::random::prelude::*;
use scirs2_core::random::ChaCha8Rng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::SliceRandomExt;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use thiserror::Error;
use crate::ising::{IsingError, IsingModel};
use crate::simulator::{AnnealingParams, AnnealingSolution, QuantumAnnealingSimulator};
/// Errors produced by the quantum Boltzmann machine (QBM) module.
#[derive(Error, Debug)]
pub enum QbmError {
/// Propagated failure from the underlying Ising model.
#[error("Ising error: {0}")]
IsingError(#[from] IsingError),
/// The requested model configuration is structurally invalid.
#[error("Invalid model: {0}")]
InvalidModel(String),
/// Training could not proceed (e.g. persistent chains missing).
#[error("Training error: {0}")]
TrainingError(String),
/// Sampling via the annealing simulator failed.
#[error("Sampling error: {0}")]
SamplingError(String),
/// The supplied dataset or input vector is malformed.
#[error("Data error: {0}")]
DataError(String),
}
/// Convenience alias for results returned by this module.
pub type QbmResult<T> = Result<T, QbmError>;
/// Activation model used by the units of a layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnitType {
    /// Stochastic binary units with sigmoid activation probabilities.
    Binary,
    /// Linear (real-valued) units.
    Gaussian,
    /// Units normalized with a softmax over the whole layer.
    Softmax,
}

/// Configuration of a single RBM layer.
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Number of units in the layer.
    pub num_units: usize,
    /// Activation model of the units.
    pub unit_type: UnitType,
    /// Human-readable layer name.
    pub name: String,
    /// Half-open `(min, max)` range used when initializing biases.
    pub bias_init_range: (f64, f64),
    /// Whether quantum-annealing sampling is used for this layer.
    pub quantum_sampling: bool,
}

impl LayerConfig {
    /// Creates a layer configuration with the default bias range
    /// `(-0.1, 0.1)` and quantum sampling enabled.
    #[must_use]
    pub const fn new(name: String, num_units: usize, unit_type: UnitType) -> Self {
        Self {
            name,
            num_units,
            unit_type,
            bias_init_range: (-0.1, 0.1),
            quantum_sampling: true,
        }
    }

    /// Overrides the bias initialization range (builder style).
    #[must_use]
    pub const fn with_bias_range(mut self, min: f64, max: f64) -> Self {
        self.bias_init_range = (min, max);
        self
    }

    /// Enables or disables quantum sampling for this layer (builder style).
    #[must_use]
    pub const fn with_quantum_sampling(mut self, enabled: bool) -> Self {
        self.quantum_sampling = enabled;
        self
    }
}
/// Quantum Restricted Boltzmann Machine: a two-layer energy-based model
/// whose Gibbs sampling steps can be delegated to a quantum annealing
/// simulator.
#[derive(Debug)]
pub struct QuantumRestrictedBoltzmannMachine {
// Configuration of the visible layer.
visible_config: LayerConfig,
// Configuration of the hidden layer.
hidden_config: LayerConfig,
// Per-unit biases of the visible layer.
visible_biases: Vec<f64>,
// Per-unit biases of the hidden layer.
hidden_biases: Vec<f64>,
// Weight matrix indexed as weights[visible_index][hidden_index].
weights: Vec<Vec<f64>>,
// Hyperparameters controlling training.
training_config: QbmTrainingConfig,
// Seedable RNG used for all stochastic operations.
rng: ChaCha8Rng,
// Populated by `train`; `None` before the first training run.
training_stats: Option<QbmTrainingStats>,
}
/// Hyperparameters controlling QBM training.
#[derive(Debug, Clone)]
pub struct QbmTrainingConfig {
/// Step size for parameter updates.
pub learning_rate: f64,
/// Maximum number of passes over the dataset.
pub epochs: usize,
/// Number of samples per mini-batch (also used as the persistent-chain count).
pub batch_size: usize,
/// Gibbs steps per contrastive-divergence update (the "k" in CD-k).
pub k_steps: usize,
/// Use persistent contrastive divergence instead of plain CD-k.
pub persistent_cd: bool,
/// L2 weight-decay coefficient applied to the weight matrix.
pub weight_decay: f64,
/// Momentum coefficient for parameter updates.
pub momentum: f64,
/// Parameters forwarded to the quantum annealing simulator.
pub annealing_params: AnnealingParams,
/// Optional RNG seed for reproducible training.
pub seed: Option<u64>,
/// Optional early-stopping threshold on reconstruction error.
pub error_threshold: Option<f64>,
/// Log progress every this many epochs.
pub log_frequency: usize,
}
impl Default for QbmTrainingConfig {
    /// Reasonable starting hyperparameters for small QBM experiments.
    fn default() -> Self {
        Self {
            epochs: 100,
            batch_size: 32,
            learning_rate: 0.01,
            momentum: 0.5,
            weight_decay: 0.0001,
            k_steps: 1,
            persistent_cd: false,
            annealing_params: AnnealingParams::default(),
            seed: None,
            error_threshold: None,
            log_frequency: 10,
        }
    }
}
/// Summary of a completed training run.
#[derive(Debug, Clone)]
pub struct QbmTrainingStats {
/// Wall-clock time spent inside `train`.
pub total_training_time: Duration,
/// Per-epoch mean reconstruction errors.
pub reconstruction_errors: Vec<f64>,
/// Per-epoch mean free-energy differences (data minus reconstruction).
pub free_energy_diffs: Vec<f64>,
/// Number of epochs actually executed (may be fewer on early stopping).
pub epochs_completed: usize,
/// Reconstruction error of the last completed epoch.
pub final_reconstruction_error: f64,
/// Whether the configured error threshold was reached.
pub converged: bool,
/// Aggregated statistics over all quantum sampling calls.
pub quantum_sampling_stats: QuantumSamplingStats,
}
/// Aggregated metrics for quantum-annealing sampling calls.
#[derive(Debug, Clone)]
pub struct QuantumSamplingStats {
    /// Total wall-clock time spent inside sampling calls.
    pub total_sampling_time: Duration,
    /// Number of sampling calls folded into these statistics.
    pub sampling_calls: usize,
    /// Mean annealing energy across calls.
    pub average_annealing_energy: f64,
    /// Fraction of calls where annealing succeeded.
    pub success_rate: f64,
    /// Fraction of calls that fell back to classical sampling.
    pub classical_fallback_rate: f64,
}

impl Default for QuantumSamplingStats {
    /// An "empty" record: zero calls, perfect success rate, no fallbacks.
    fn default() -> Self {
        Self {
            sampling_calls: 0,
            total_sampling_time: Duration::from_secs(0),
            average_annealing_energy: 0.0,
            success_rate: 1.0,
            classical_fallback_rate: 0.0,
        }
    }
}
/// A single training example: a visible-layer vector plus an optional label.
#[derive(Debug, Clone)]
pub struct TrainingSample {
    /// Visible-layer values (one entry per visible unit).
    pub data: Vec<f64>,
    /// Optional supervision target.
    pub label: Option<Vec<f64>>,
}

impl TrainingSample {
    /// Creates an unlabeled sample.
    #[must_use]
    pub const fn new(data: Vec<f64>) -> Self {
        Self { data, label: None }
    }

    /// Creates a labeled sample.
    #[must_use]
    pub const fn labeled(data: Vec<f64>, label: Vec<f64>) -> Self {
        let label = Some(label);
        Self { data, label }
    }
}
/// Result of a single inference pass through the RBM.
#[derive(Debug, Clone)]
pub struct QbmInferenceResult {
/// Visible-layer reconstruction after one up-down pass.
pub reconstruction: Vec<f64>,
/// Hidden-unit activation probabilities for the input.
pub hidden_activations: Vec<f64>,
/// Free energy of the input configuration.
pub free_energy: f64,
/// Unnormalized Boltzmann weight e^{-F} of the input (not a normalized
/// probability: the partition function is not computed).
pub probability: f64,
}
impl QuantumRestrictedBoltzmannMachine {
/// Creates a new quantum RBM with randomly initialized parameters.
///
/// # Errors
/// Returns [`QbmError::InvalidModel`] if either layer has zero units.
pub fn new(
    visible_config: LayerConfig,
    hidden_config: LayerConfig,
    training_config: QbmTrainingConfig,
) -> QbmResult<Self> {
    // Capture sizes first so both configs can be moved without cloning.
    let num_visible = visible_config.num_units;
    let num_hidden = hidden_config.num_units;
    if num_visible == 0 || num_hidden == 0 {
        return Err(QbmError::InvalidModel(
            "Layer sizes must be > 0".to_string(),
        ));
    }
    // A fixed seed gives reproducible runs; otherwise seed from the thread RNG.
    let rng = match training_config.seed {
        Some(seed) => ChaCha8Rng::seed_from_u64(seed),
        None => ChaCha8Rng::seed_from_u64(thread_rng().random()),
    };
    let mut rbm = Self {
        visible_config,
        hidden_config,
        visible_biases: vec![0.0; num_visible],
        hidden_biases: vec![0.0; num_hidden],
        weights: vec![vec![0.0; num_hidden]; num_visible],
        training_config,
        rng,
        training_stats: None,
    };
    rbm.initialize_parameters()?;
    Ok(rbm)
}
/// Randomly initializes biases (uniform within each layer's configured
/// range) and weights (Xavier/Glorot uniform).
fn initialize_parameters(&mut self) -> QbmResult<()> {
    // `random_range` panics on an empty range, so a degenerate range
    // (min >= max) falls back to the constant `min` instead of panicking.
    let (v_min, v_max) = self.visible_config.bias_init_range;
    for bias in &mut self.visible_biases {
        *bias = if v_min < v_max {
            self.rng.random_range(v_min..v_max)
        } else {
            v_min
        };
    }
    let (h_min, h_max) = self.hidden_config.bias_init_range;
    for bias in &mut self.hidden_biases {
        *bias = if h_min < h_max {
            self.rng.random_range(h_min..h_max)
        } else {
            h_min
        };
    }
    // Xavier/Glorot uniform initialization keeps initial activations scaled
    // sensibly for both layer sizes.
    let fan_in = self.visible_config.num_units as f64;
    let fan_out = self.hidden_config.num_units as f64;
    let xavier_std = (2.0 / (fan_in + fan_out)).sqrt();
    for i in 0..self.visible_config.num_units {
        for j in 0..self.hidden_config.num_units {
            self.weights[i][j] = self.rng.random_range(-xavier_std..xavier_std);
        }
    }
    Ok(())
}
/// Trains the RBM on `dataset` using (persistent) contrastive divergence,
/// recording per-epoch statistics in `self.training_stats`.
///
/// # Errors
/// Returns [`QbmError::DataError`] if the dataset is empty or any sample's
/// length differs from the visible layer size.
pub fn train(&mut self, dataset: &[TrainingSample]) -> QbmResult<()> {
    if dataset.is_empty() {
        return Err(QbmError::DataError("Dataset is empty".to_string()));
    }
    // Validate every sample up front so training never fails midway.
    let expected_size = self.visible_config.num_units;
    for (i, sample) in dataset.iter().enumerate() {
        if sample.data.len() != expected_size {
            return Err(QbmError::DataError(format!(
                "Sample {} has {} features, expected {}",
                i,
                sample.data.len(),
                expected_size
            )));
        }
    }
    println!("Starting QBM training with {} samples", dataset.len());
    let start_time = Instant::now();
    let mut reconstruction_errors = Vec::with_capacity(self.training_config.epochs);
    let mut free_energy_diffs = Vec::with_capacity(self.training_config.epochs);
    let mut quantum_stats = QuantumSamplingStats::default();
    // Momentum buffers for all trainable parameters.
    let mut weight_momentum =
        vec![vec![0.0; self.hidden_config.num_units]; self.visible_config.num_units];
    let mut visible_bias_momentum = vec![0.0; self.visible_config.num_units];
    let mut hidden_bias_momentum = vec![0.0; self.hidden_config.num_units];
    // Persistent CD keeps Markov chains alive across batches and epochs.
    let mut persistent_chains = if self.training_config.persistent_cd {
        Some(self.initialize_persistent_chains(self.training_config.batch_size)?)
    } else {
        None
    };
    for epoch in 0..self.training_config.epochs {
        let epoch_start = Instant::now();
        let mut epoch_error = 0.0;
        let mut epoch_free_energy_diff = 0.0;
        let mut num_batches = 0;
        // Visit samples in a fresh random order each epoch.
        let mut shuffled_indices: Vec<usize> = (0..dataset.len()).collect();
        shuffled_indices.shuffle(&mut self.rng);
        for batch_start in (0..dataset.len()).step_by(self.training_config.batch_size) {
            let batch_end = (batch_start + self.training_config.batch_size).min(dataset.len());
            let batch_indices = &shuffled_indices[batch_start..batch_end];
            let batch_samples: Vec<&TrainingSample> =
                batch_indices.iter().map(|&i| &dataset[i]).collect();
            let (batch_error, batch_fe_diff, batch_quantum_stats) =
                self.contrastive_divergence_batch(&batch_samples, &mut persistent_chains)?;
            self.update_parameters_with_momentum(
                &batch_samples,
                &mut weight_momentum,
                &mut visible_bias_momentum,
                &mut hidden_bias_momentum,
            )?;
            epoch_error += batch_error;
            epoch_free_energy_diff += batch_fe_diff;
            quantum_stats.merge(&batch_quantum_stats);
            num_batches += 1;
        }
        // `dataset` is non-empty, so there is always at least one batch.
        let avg_error = epoch_error / f64::from(num_batches);
        let avg_fe_diff = epoch_free_energy_diff / f64::from(num_batches);
        reconstruction_errors.push(avg_error);
        free_energy_diffs.push(avg_fe_diff);
        // Guard against `log_frequency == 0`, which would panic on `%`.
        if self.training_config.log_frequency > 0
            && epoch % self.training_config.log_frequency == 0
        {
            println!(
                "Epoch {}: Error = {:.6}, FE Diff = {:.6}, Time = {:.2?}",
                epoch,
                avg_error,
                avg_fe_diff,
                epoch_start.elapsed()
            );
        }
        // Optional early stopping on reconstruction error.
        if let Some(threshold) = self.training_config.error_threshold {
            if avg_error < threshold {
                println!("Converged at epoch {epoch} with error {avg_error:.6}");
                break;
            }
        }
    }
    let total_time = start_time.elapsed();
    self.training_stats = Some(QbmTrainingStats {
        total_training_time: total_time,
        reconstruction_errors: reconstruction_errors.clone(),
        free_energy_diffs,
        epochs_completed: reconstruction_errors.len(),
        final_reconstruction_error: reconstruction_errors.last().copied().unwrap_or(0.0),
        converged: self.training_config.error_threshold.is_some_and(|t| {
            reconstruction_errors.last().copied().unwrap_or(f64::INFINITY) < t
        }),
        quantum_sampling_stats: quantum_stats,
    });
    println!("Training completed in {total_time:.2?}");
    Ok(())
}
/// Runs one contrastive-divergence pass over `batch`.
///
/// Returns `(mean reconstruction error, mean free-energy difference,
/// accumulated quantum sampling stats)` averaged over the batch.
///
/// NOTE(review): the positive-phase statistics computed below are not
/// consumed by the parameter update — confirm against the intended CD
/// design.
fn contrastive_divergence_batch(
&mut self,
batch: &[&TrainingSample],
persistent_chains: &mut Option<Vec<Vec<f64>>>,
) -> QbmResult<(f64, f64, QuantumSamplingStats)> {
let mut total_error = 0.0;
let mut total_fe_diff = 0.0;
let mut quantum_stats = QuantumSamplingStats::default();
for (i, sample) in batch.iter().enumerate() {
// Positive phase: hidden statistics driven by the data sample.
let hidden_probs_pos = self.sample_hidden_given_visible(&sample.data)?;
let hidden_states_pos = self.sample_binary_units(&hidden_probs_pos)?;
// Negative phase: either advance a persistent chain, or run k Gibbs
// steps starting from the data (plain CD-k).
let (visible_recon, hidden_probs_neg, sampling_stats) =
if self.training_config.persistent_cd {
if let Some(ref mut chains) = persistent_chains {
// Each batch element advances one chain (round-robin).
let chain_index = i % chains.len();
let mut chain = chains[chain_index].clone();
for _ in 0..self.training_config.k_steps {
let h_probs = self.sample_hidden_given_visible(&chain)?;
let h_states = self.sample_binary_units(&h_probs)?;
chain = self.sample_visible_given_hidden(&h_states)?;
}
chains[chain_index] = chain.clone();
let h_probs = self.sample_hidden_given_visible(&chain)?;
// Persistent path is purely classical, so no quantum stats.
(chain, h_probs, QuantumSamplingStats::default())
} else {
return Err(QbmError::TrainingError(
"Persistent chains not initialized".to_string(),
));
}
} else {
let mut v_states = sample.data.clone();
let mut sampling_stats = QuantumSamplingStats::default();
for _ in 0..self.training_config.k_steps {
let h_probs = self.sample_hidden_given_visible(&v_states)?;
// Per-layer flag selects quantum annealing or classical
// Bernoulli sampling for each Gibbs half-step.
let h_states = if self.hidden_config.quantum_sampling {
let (states, stats) = self.quantum_sample_hidden(&h_probs)?;
sampling_stats.merge(&stats);
states
} else {
self.sample_binary_units(&h_probs)?
};
v_states = if self.visible_config.quantum_sampling {
let (states, stats) = self.quantum_sample_visible(&h_states)?;
sampling_stats.merge(&stats);
states
} else {
self.sample_visible_given_hidden(&h_states)?
};
}
let h_probs_neg = self.sample_hidden_given_visible(&v_states)?;
(v_states, h_probs_neg, sampling_stats)
};
// Mean squared reconstruction error for this sample.
let error = sample
.data
.iter()
.zip(visible_recon.iter())
.map(|(orig, recon)| (orig - recon).powi(2))
.sum::<f64>()
/ sample.data.len() as f64;
// Free-energy gap between the data and its reconstruction.
let fe_pos = self.free_energy(&sample.data)?;
let fe_neg = self.free_energy(&visible_recon)?;
let fe_diff = fe_pos - fe_neg;
total_error += error;
total_fe_diff += fe_diff;
quantum_stats.merge(&sampling_stats);
}
Ok((
total_error / batch.len() as f64,
total_fe_diff / batch.len() as f64,
quantum_stats,
))
}
/// Applies one momentum-based parameter update from CD-1 gradient
/// estimates computed over `batch`.
///
/// The previous implementation used uniformly random values in
/// `(-0.001, 0.001)` as "gradients" — placeholder noise that cannot train
/// the model. This version estimates the standard CD-1 gradients:
/// positive statistics from the data, negative statistics from a one-step
/// Gibbs reconstruction.
fn update_parameters_with_momentum(
    &mut self,
    batch: &[&TrainingSample],
    weight_momentum: &mut Vec<Vec<f64>>,
    visible_bias_momentum: &mut Vec<f64>,
    hidden_bias_momentum: &mut Vec<f64>,
) -> QbmResult<()> {
    if batch.is_empty() {
        return Ok(());
    }
    let lr = self.training_config.learning_rate;
    let momentum = self.training_config.momentum;
    let decay = self.training_config.weight_decay;
    let num_visible = self.visible_config.num_units;
    let num_hidden = self.hidden_config.num_units;
    // Accumulate gradient estimates over the batch.
    let mut weight_grad = vec![vec![0.0; num_hidden]; num_visible];
    let mut visible_grad = vec![0.0; num_visible];
    let mut hidden_grad = vec![0.0; num_hidden];
    for sample in batch {
        // Positive phase: hidden probabilities driven by the data.
        let h_pos = self.sample_hidden_given_visible(&sample.data)?;
        let h_states = self.sample_binary_units(&h_pos)?;
        // Negative phase: one-step Gibbs reconstruction.
        let v_neg = self.sample_visible_given_hidden(&h_states)?;
        let h_neg = self.sample_hidden_given_visible(&v_neg)?;
        for i in 0..num_visible {
            for j in 0..num_hidden {
                // <v_i h_j>_data - <v_i h_j>_model
                weight_grad[i][j] += sample.data[i] * h_pos[j] - v_neg[i] * h_neg[j];
            }
            visible_grad[i] += sample.data[i] - v_neg[i];
        }
        for j in 0..num_hidden {
            hidden_grad[j] += h_pos[j] - h_neg[j];
        }
    }
    let batch_size = batch.len() as f64;
    // Momentum update plus L2 weight decay on the weight matrix.
    for i in 0..num_visible {
        for j in 0..num_hidden {
            let gradient = weight_grad[i][j] / batch_size;
            weight_momentum[i][j] = momentum.mul_add(weight_momentum[i][j], lr * gradient);
            self.weights[i][j] += decay.mul_add(-self.weights[i][j], weight_momentum[i][j]);
        }
    }
    for i in 0..num_visible {
        let gradient = visible_grad[i] / batch_size;
        visible_bias_momentum[i] = momentum.mul_add(visible_bias_momentum[i], lr * gradient);
        self.visible_biases[i] += visible_bias_momentum[i];
    }
    for j in 0..num_hidden {
        let gradient = hidden_grad[j] / batch_size;
        hidden_bias_momentum[j] = momentum.mul_add(hidden_bias_momentum[j], lr * gradient);
        self.hidden_biases[j] += hidden_bias_momentum[j];
    }
    Ok(())
}
/// Creates `num_chains` random binary visible configurations used as the
/// starting states for persistent contrastive divergence.
fn initialize_persistent_chains(&mut self, num_chains: usize) -> QbmResult<Vec<Vec<f64>>> {
    let num_units = self.visible_config.num_units;
    let mut chains = Vec::with_capacity(num_chains);
    for _ in 0..num_chains {
        let mut chain = Vec::with_capacity(num_units);
        for _ in 0..num_units {
            // Fair coin flip per unit: 0.0 or 1.0.
            chain.push(if self.rng.random_bool(0.5) { 1.0 } else { 0.0 });
        }
        chains.push(chain);
    }
    Ok(chains)
}
/// Computes hidden-unit activation probabilities conditioned on a visible
/// configuration.
///
/// # Errors
/// Returns [`QbmError::DataError`] if `visible` does not match the visible
/// layer size.
fn sample_hidden_given_visible(&self, visible: &[f64]) -> QbmResult<Vec<f64>> {
    if visible.len() != self.visible_config.num_units {
        return Err(QbmError::DataError("Visible size mismatch".to_string()));
    }
    let num_hidden = self.hidden_config.num_units;
    let mut hidden_probs = Vec::with_capacity(num_hidden);
    for j in 0..num_hidden {
        // Pre-activation: bias plus the weighted visible inputs.
        let activation = self.hidden_biases[j]
            + visible
                .iter()
                .enumerate()
                .map(|(i, &v)| v * self.weights[i][j])
                .sum::<f64>();
        // Gaussian and softmax units keep the raw activation here; softmax
        // values are normalized once the whole layer is computed.
        hidden_probs.push(match self.hidden_config.unit_type {
            UnitType::Binary => sigmoid(activation),
            UnitType::Gaussian | UnitType::Softmax => activation,
        });
    }
    if self.hidden_config.unit_type == UnitType::Softmax {
        softmax_normalize(&mut hidden_probs);
    }
    Ok(hidden_probs)
}
/// Computes visible-unit activation probabilities conditioned on a hidden
/// configuration.
///
/// # Errors
/// Returns [`QbmError::DataError`] if `hidden` does not match the hidden
/// layer size.
fn sample_visible_given_hidden(&self, hidden: &[f64]) -> QbmResult<Vec<f64>> {
    if hidden.len() != self.hidden_config.num_units {
        return Err(QbmError::DataError("Hidden size mismatch".to_string()));
    }
    let num_visible = self.visible_config.num_units;
    let mut visible_probs = Vec::with_capacity(num_visible);
    for i in 0..num_visible {
        // Pre-activation: bias plus the weighted hidden inputs.
        let activation = self.visible_biases[i]
            + hidden
                .iter()
                .enumerate()
                .map(|(j, &h)| h * self.weights[i][j])
                .sum::<f64>();
        // Gaussian and softmax units keep the raw activation here; softmax
        // values are normalized once the whole layer is computed.
        visible_probs.push(match self.visible_config.unit_type {
            UnitType::Binary => sigmoid(activation),
            UnitType::Gaussian | UnitType::Softmax => activation,
        });
    }
    if self.visible_config.unit_type == UnitType::Softmax {
        softmax_normalize(&mut visible_probs);
    }
    Ok(visible_probs)
}
/// Draws an independent Bernoulli sample (0.0 or 1.0) per unit.
///
/// Probabilities are clamped to `[0, 1]` first: non-binary unit types can
/// feed raw activations through here, and `random_bool` panics on values
/// outside that range.
fn sample_binary_units(&mut self, probabilities: &[f64]) -> QbmResult<Vec<f64>> {
    Ok(probabilities
        .iter()
        .map(|&p| {
            let p = p.clamp(0.0, 1.0);
            if self.rng.random_bool(p) { 1.0 } else { 0.0 }
        })
        .collect())
}
/// Samples a layer of binary units by mapping their Bernoulli
/// probabilities onto Ising biases and running the quantum annealing
/// simulator; falls back to classical sampling if annealing fails.
fn quantum_sample_hidden(
    &mut self,
    probabilities: &[f64],
) -> QbmResult<(Vec<f64>, QuantumSamplingStats)> {
    let start_time = Instant::now();
    let mut ising_model = IsingModel::new(probabilities.len());
    for (i, &prob) in probabilities.iter().enumerate() {
        // Clamp away from 0 and 1 so the log-odds below stay finite
        // (ln(0) would otherwise produce infinite biases).
        let p = prob.clamp(1e-12, 1.0 - 1e-12);
        // Bias encodes the negated, scaled log-odds of the unit being on.
        let bias = -2.0 * (p.ln() - (1.0 - p).ln());
        ising_model.set_bias(i, bias)?;
    }
    if let Ok(sample) = self.quantum_annealing_sample(&ising_model) {
        let sampling_time = start_time.elapsed();
        let stats = QuantumSamplingStats {
            total_sampling_time: sampling_time,
            sampling_calls: 1,
            // NOTE(review): annealing energy is not propagated back yet.
            average_annealing_energy: 0.0,
            success_rate: 1.0,
            classical_fallback_rate: 0.0,
        };
        // Map spins {-1, +1} to binary {0.0, 1.0}.
        let binary_sample = sample
            .iter()
            .map(|&s| if s > 0 { 1.0 } else { 0.0 })
            .collect();
        Ok((binary_sample, stats))
    } else {
        // Classical fallback: independent Bernoulli sampling.
        let sample = self.sample_binary_units(probabilities)?;
        let stats = QuantumSamplingStats {
            total_sampling_time: start_time.elapsed(),
            sampling_calls: 1,
            average_annealing_energy: 0.0,
            success_rate: 0.0,
            classical_fallback_rate: 1.0,
        };
        Ok((sample, stats))
    }
}
/// Samples visible units with the quantum sampler by first deriving their
/// conditional probabilities from the given hidden states.
fn quantum_sample_visible(
    &mut self,
    hidden_states: &[f64],
) -> QbmResult<(Vec<f64>, QuantumSamplingStats)> {
    let probs = self.sample_visible_given_hidden(hidden_states)?;
    // The probability-to-Ising mapping is layer-agnostic, so the hidden
    // sampler is reused here.
    self.quantum_sample_hidden(&probs)
}
/// Runs the quantum annealing simulator on `model` and returns the best
/// spin configuration it found.
///
/// # Errors
/// Returns [`QbmError::SamplingError`] if the simulator cannot be built or
/// fails to solve the model.
fn quantum_annealing_sample(&self, model: &IsingModel) -> QbmResult<Vec<i8>> {
    let params = self.training_config.annealing_params.clone();
    let mut simulator = QuantumAnnealingSimulator::new(params)
        .map_err(|e| QbmError::SamplingError(e.to_string()))?;
    let result = simulator
        .solve(model)
        .map_err(|e| QbmError::SamplingError(e.to_string()))?;
    Ok(result.best_spins)
}
/// Computes the RBM free energy
/// `F(v) = -Σ_i v_i b_i - Σ_j softplus(a_j)`, where `a_j` is the
/// pre-activation of hidden unit `j`.
///
/// # Errors
/// Returns [`QbmError::DataError`] on a visible-size mismatch.
fn free_energy(&self, visible: &[f64]) -> QbmResult<f64> {
    if visible.len() != self.visible_config.num_units {
        return Err(QbmError::DataError("Visible size mismatch".to_string()));
    }
    let visible_term: f64 = visible
        .iter()
        .zip(self.visible_biases.iter())
        .map(|(&v, &b)| v * b)
        .sum();
    let hidden_term: f64 = (0..self.hidden_config.num_units)
        .map(|j| {
            let activation = self.hidden_biases[j]
                + visible
                    .iter()
                    .enumerate()
                    .map(|(i, &v)| v * self.weights[i][j])
                    .sum::<f64>();
            // Numerically stable softplus: the naive `exp(a).ln_1p()`
            // overflows to infinity for large `a`; use
            // max(a, 0) + ln(1 + e^{-|a|}) instead.
            activation.max(0.0) + (-activation.abs()).exp().ln_1p()
        })
        .sum();
    Ok(-(visible_term + hidden_term))
}
/// Performs a single up-down pass: hidden activation probabilities from
/// the input, a stochastic hidden sample, and the resulting visible
/// reconstruction, along with the input's free energy.
///
/// # Errors
/// Returns [`QbmError::DataError`] on an input-size mismatch.
pub fn infer(&mut self, input: &[f64]) -> QbmResult<QbmInferenceResult> {
    if input.len() != self.visible_config.num_units {
        return Err(QbmError::DataError("Input size mismatch".to_string()));
    }
    let hidden_probs = self.sample_hidden_given_visible(input)?;
    let hidden_states = self.sample_binary_units(&hidden_probs)?;
    let reconstruction = self.sample_visible_given_hidden(&hidden_states)?;
    let free_energy = self.free_energy(input)?;
    // NOTE(review): this is the unnormalized Boltzmann weight e^{-F}, not
    // a normalized probability (the partition function is intractable).
    let probability = (-free_energy).exp();
    Ok(QbmInferenceResult {
        hidden_activations: hidden_probs,
        reconstruction,
        free_energy,
        probability,
    })
}
/// Generates `num_samples` visible configurations, each produced by
/// running a fixed-length Gibbs chain from a random binary start.
pub fn generate_samples(&mut self, num_samples: usize) -> QbmResult<Vec<Vec<f64>>> {
    // Number of Gibbs sweeps used to approach the model distribution
    // (previously an inline magic number).
    const GIBBS_STEPS: usize = 100;
    let mut samples = Vec::with_capacity(num_samples);
    for _ in 0..num_samples {
        // Random binary starting state for the chain.
        let mut visible: Vec<f64> = (0..self.visible_config.num_units)
            .map(|_| if self.rng.random_bool(0.5) { 1.0 } else { 0.0 })
            .collect();
        for _ in 0..GIBBS_STEPS {
            let hidden_probs = self.sample_hidden_given_visible(&visible)?;
            let hidden_states = self.sample_binary_units(&hidden_probs)?;
            visible = self.sample_visible_given_hidden(&hidden_states)?;
        }
        samples.push(visible);
    }
    Ok(samples)
}
/// Returns statistics from the most recent `train` call, or `None` if the
/// model has not been trained yet.
#[must_use]
pub const fn get_training_stats(&self) -> Option<&QbmTrainingStats> {
self.training_stats.as_ref()
}
/// Persists the model to `path`.
///
/// NOTE(review): serialization is not implemented yet — this is a stub
/// that only logs the target path and always succeeds.
pub fn save_model(&self, path: &str) -> QbmResult<()> {
println!("Model would be saved to: {path}");
Ok(())
}
/// Loads model parameters from `path`.
///
/// NOTE(review): deserialization is not implemented yet — this is a stub
/// that only logs the source path and always succeeds.
pub fn load_model(&mut self, path: &str) -> QbmResult<()> {
println!("Model would be loaded from: {path}");
Ok(())
}
}
impl QuantumSamplingStats {
/// Folds `other` into `self`, keeping call-count-weighted averages for
/// the energy, success-rate, and fallback-rate fields.
fn merge(&mut self, other: &Self) {
self.total_sampling_time += other.total_sampling_time;
self.sampling_calls += other.sampling_calls;
if self.sampling_calls > 0 {
// `sampling_calls` already includes `other` at this point, so self's
// previous contribution is weighted by (total - other.sampling_calls).
let total_calls = self.sampling_calls as f64;
self.average_annealing_energy = self.average_annealing_energy.mul_add(
total_calls - other.sampling_calls as f64,
other.average_annealing_energy * other.sampling_calls as f64,
) / total_calls;
self.success_rate = self.success_rate.mul_add(
total_calls - other.sampling_calls as f64,
other.success_rate * other.sampling_calls as f64,
) / total_calls;
self.classical_fallback_rate = self.classical_fallback_rate.mul_add(
total_calls - other.sampling_calls as f64,
other.classical_fallback_rate * other.sampling_calls as f64,
) / total_calls;
}
}
}
/// Logistic sigmoid: maps any real number into (0, 1).
fn sigmoid(x: f64) -> f64 {
    (1.0 + (-x).exp()).recip()
}
/// Normalizes `values` in place into a softmax distribution.
///
/// The maximum is subtracted before exponentiating for numerical
/// stability, and each exponential is computed only once (the previous
/// version evaluated every `exp` twice: once for the sum and once per
/// element). An empty slice is left untouched.
fn softmax_normalize(values: &mut [f64]) {
    if values.is_empty() {
        return;
    }
    let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut sum = 0.0;
    for value in values.iter_mut() {
        *value = (*value - max_val).exp();
        sum += *value;
    }
    for value in values.iter_mut() {
        *value /= sum;
    }
}
/// Convenience constructor for a binary-binary RBM.
///
/// # Errors
/// Returns [`QbmError::InvalidModel`] if either layer size is zero.
pub fn create_binary_rbm(
    num_visible: usize,
    num_hidden: usize,
    training_config: QbmTrainingConfig,
) -> QbmResult<QuantumRestrictedBoltzmannMachine> {
    QuantumRestrictedBoltzmannMachine::new(
        LayerConfig::new("visible".to_string(), num_visible, UnitType::Binary),
        LayerConfig::new("hidden".to_string(), num_hidden, UnitType::Binary),
        training_config,
    )
}
/// Convenience constructor for a Gaussian-Bernoulli RBM (real-valued
/// visible units, binary hidden units).
///
/// # Errors
/// Returns [`QbmError::InvalidModel`] if either layer size is zero.
pub fn create_gaussian_bernoulli_rbm(
    num_visible: usize,
    num_hidden: usize,
    training_config: QbmTrainingConfig,
) -> QbmResult<QuantumRestrictedBoltzmannMachine> {
    QuantumRestrictedBoltzmannMachine::new(
        LayerConfig::new("visible".to_string(), num_visible, UnitType::Gaussian),
        LayerConfig::new("hidden".to_string(), num_hidden, UnitType::Binary),
        training_config,
    )
}
#[cfg(test)]
mod tests {
use super::*;
// Constructing a small binary RBM succeeds and records the layer sizes.
#[test]
fn test_rbm_creation() {
let training_config = QbmTrainingConfig {
epochs: 10,
..Default::default()
};
let rbm = create_binary_rbm(4, 3, training_config).expect("failed to create binary RBM");
assert_eq!(rbm.visible_config.num_units, 4);
assert_eq!(rbm.hidden_config.num_units, 3);
}
// Sigmoid is 0.5 at the origin and saturates toward 0/1 at the tails.
#[test]
fn test_sigmoid_function() {
assert!((sigmoid(0.0) - 0.5).abs() < 1e-10);
assert!(sigmoid(10.0) > 0.99);
assert!(sigmoid(-10.0) < 0.01);
}
// Softmax output must be a valid probability distribution.
#[test]
fn test_softmax_normalization() {
let mut values = vec![1.0, 2.0, 3.0];
softmax_normalize(&mut values);
let sum: f64 = values.iter().sum();
assert!((sum - 1.0).abs() < 1e-10);
assert!(values.iter().all(|&x| x > 0.0 && x < 1.0));
}
// Unlabeled and labeled constructors store the expected fields.
#[test]
fn test_training_sample_creation() {
let sample = TrainingSample::new(vec![1.0, 0.0, 1.0]);
assert_eq!(sample.data.len(), 3);
assert!(sample.label.is_none());
let labeled_sample = TrainingSample::labeled(vec![1.0, 0.0], vec![1.0]);
assert_eq!(labeled_sample.data.len(), 2);
assert_eq!(
labeled_sample
.label
.as_ref()
.expect("label should exist")
.len(),
1
);
}
// Builder methods override the defaults set by `LayerConfig::new`.
#[test]
fn test_layer_config() {
let config = LayerConfig::new("test".to_string(), 10, UnitType::Binary)
.with_bias_range(-0.5, 0.5)
.with_quantum_sampling(false);
assert_eq!(config.num_units, 10);
assert_eq!(config.unit_type, UnitType::Binary);
assert_eq!(config.bias_init_range, (-0.5, 0.5));
assert!(!config.quantum_sampling);
}
}