use crate::error::{QuantRS2Error, QuantRS2Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::Complex64;
use std::f64::consts::PI;
/// Hyperparameters describing the shape and training behaviour of a [`QuantumRBM`].
#[derive(Debug, Clone)]
pub struct QRBMConfig {
    /// Number of visible units; training states must hold `2^num_visible` amplitudes.
    pub num_visible: usize,
    /// Number of hidden units.
    pub num_hidden: usize,
    /// Step size applied to every weight and bias update.
    pub learning_rate: f64,
    /// Number of alternating visible/hidden passes used to build the negative phase.
    pub k_steps: usize,
    /// Divisor applied to a unit's net input before the logistic function.
    pub temperature: f64,
    /// Coefficient of the weight-decay term subtracted in the weight update.
    pub l2_reg: f64,
}
impl Default for QRBMConfig {
fn default() -> Self {
Self {
num_visible: 4,
num_hidden: 2,
learning_rate: 0.01,
k_steps: 1,
temperature: 1.0,
l2_reg: 0.001,
}
}
}
/// Restricted Boltzmann machine trained on per-qubit probabilities extracted
/// from quantum state vectors.
#[derive(Debug, Clone)]
pub struct QuantumRBM {
    /// Hyperparameters this machine was built with.
    config: QRBMConfig,
    /// Coupling matrix of shape `(num_visible, num_hidden)`.
    weights: Array2<f64>,
    /// One bias per visible unit.
    visible_bias: Array1<f64>,
    /// One bias per hidden unit.
    hidden_bias: Array1<f64>,
    /// Average reconstruction error of every `train_batch` call, oldest first.
    history: Vec<f64>,
}
impl QuantumRBM {
/// Builds a machine for `config` with small random weights and zero biases.
///
/// Every weight is drawn from the range `-0.01..0.01` using the thread-local
/// random number generator; the training history starts out empty.
pub fn new(config: QRBMConfig) -> Self {
    const INIT_SCALE: f64 = 0.01;

    let (n_visible, n_hidden) = (config.num_visible, config.num_hidden);
    let mut rng = thread_rng();
    let weights = Array2::from_shape_fn((n_visible, n_hidden), |_| {
        rng.random_range(-INIT_SCALE..INIT_SCALE)
    });

    Self {
        weights,
        visible_bias: Array1::zeros(n_visible),
        hidden_bias: Array1::zeros(n_hidden),
        history: Vec::new(),
        config,
    }
}
pub fn train_batch(&mut self, data: &[Array1<Complex64>]) -> QuantRS2Result<f64> {
let mut total_error = 0.0;
for state in data {
let visible = self.quantum_to_classical(state)?;
let hidden_probs = self.hidden_given_visible(&visible)?;
let hidden_sample = self.sample_binary(&hidden_probs)?;
let mut v_neg = visible.clone();
let mut h_neg = hidden_sample.clone();
for _ in 0..self.config.k_steps {
v_neg = self.visible_given_hidden(&h_neg)?;
h_neg = self.hidden_given_visible(&v_neg)?;
}
let pos_grad = self.outer_product(&visible, &hidden_probs);
let neg_grad = self.outer_product(&v_neg, &h_neg);
let grad = (pos_grad - neg_grad) / data.len() as f64;
self.weights = &self.weights + &(grad * self.config.learning_rate)
- &(&self.weights * self.config.l2_reg * self.config.learning_rate);
let visible_grad = &visible - &v_neg;
let hidden_grad = &hidden_probs - &h_neg;
self.visible_bias = &self.visible_bias + &(visible_grad * self.config.learning_rate);
self.hidden_bias = &self.hidden_bias + &(hidden_grad * self.config.learning_rate);
let error = (&visible - &v_neg)
.iter()
.map(|x| x * x)
.sum::<f64>()
.sqrt();
total_error += error;
}
let avg_error = total_error / data.len() as f64;
self.history.push(avg_error);
Ok(avg_error)
}
/// Reduces a state vector to the probability of measuring `1` on each qubit.
///
/// Entry `q` of the result is the sum of `|amplitude|^2` over all basis
/// indices whose bit `q` is set.
///
/// # Errors
///
/// Returns `QuantRS2Error::InvalidInput` when `state` does not contain exactly
/// `2^num_visible` amplitudes, including the case where `2^num_visible` is
/// not representable as a `usize`.
fn quantum_to_classical(&self, state: &Array1<Complex64>) -> QuantRS2Result<Array1<f64>> {
    let num_qubits = self.config.num_visible;

    // `1 << n` overflows once n reaches the bit width of usize, so only form
    // the expected dimension when the shift is well defined.
    let expected_dim = if num_qubits < usize::BITS as usize {
        Some(1usize << num_qubits)
    } else {
        None
    };
    if expected_dim != Some(state.len()) {
        return Err(QuantRS2Error::InvalidInput(format!(
            "State dimension {} doesn't match visible units 2^{}",
            state.len(),
            num_qubits
        )));
    }

    let mut probs = Array1::<f64>::zeros(num_qubits);
    for (basis_index, amplitude) in state.iter().enumerate() {
        // Compute the weight once per amplitude rather than once per qubit.
        let weight = amplitude.norm_sqr();
        for q in 0..num_qubits {
            if (basis_index >> q) & 1 == 1 {
                probs[q] += weight;
            }
        }
    }
    Ok(probs)
}
/// Computes the activation probability of every hidden unit given `visible`.
///
/// Each unit's net input is its bias plus the weighted sum of the visible
/// values; the result is the logistic function of that input divided by the
/// configured temperature.
fn hidden_given_visible(&self, visible: &Array1<f64>) -> QuantRS2Result<Array1<f64>> {
    let n_visible = self.config.num_visible;
    let temperature = self.config.temperature;

    let mut activations = self.hidden_bias.clone();
    for j in 0..self.config.num_hidden {
        // Accumulate starting from the bias, visiting visible units in order.
        let net_input = (0..n_visible).fold(activations[j], |acc, i| {
            acc + self.weights[[i, j]] * visible[i]
        });
        activations[j] = 1.0 / (1.0 + (-net_input / temperature).exp());
    }
    Ok(activations)
}
/// Computes the activation probability of every visible unit given `hidden`.
///
/// Each unit's net input is its bias plus the weighted sum of the hidden
/// values; the result is the logistic function of that input divided by the
/// configured temperature.
fn visible_given_hidden(&self, hidden: &Array1<f64>) -> QuantRS2Result<Array1<f64>> {
    let n_hidden = self.config.num_hidden;
    let temperature = self.config.temperature;

    let mut activations = self.visible_bias.clone();
    for i in 0..self.config.num_visible {
        // Accumulate starting from the bias, visiting hidden units in order.
        let net_input = (0..n_hidden).fold(activations[i], |acc, j| {
            acc + self.weights[[i, j]] * hidden[j]
        });
        activations[i] = 1.0 / (1.0 + (-net_input / temperature).exp());
    }
    Ok(activations)
}
/// Draws one binary value per entry of `probs`.
///
/// Entry `i` of the result is `1.0` when a fresh random `f64` from the
/// thread-local generator is below `probs[i]`, and `0.0` otherwise.
fn sample_binary(&self, probs: &Array1<f64>) -> QuantRS2Result<Array1<f64>> {
    let mut rng = thread_rng();
    let samples = Array1::from_shape_fn(probs.len(), |i| {
        let draw: f64 = rng.random();
        if draw < probs[i] {
            1.0
        } else {
            0.0
        }
    });
    Ok(samples)
}
/// Returns the `a.len() x b.len()` matrix whose `(i, j)` entry is `a[i] * b[j]`.
fn outer_product(&self, a: &Array1<f64>, b: &Array1<f64>) -> Array2<f64> {
    Array2::from_shape_fn((a.len(), b.len()), |(row, col)| a[row] * b[col])
}
/// Produces a state vector from the model.
///
/// The hidden layer starts from a random binary configuration (each unit is
/// `0.0` or `1.0` depending on whether a random draw is below 0.5). It is then
/// passed through 100 visible/hidden rounds using the conditional
/// probabilities directly. The final visible probabilities are encoded as a
/// state vector.
pub fn generate_sample(&self) -> QuantRS2Result<Array1<Complex64>> {
    const ROUNDS: usize = 100;

    let mut rng = thread_rng();
    let mut hidden = Array1::from_shape_fn(self.config.num_hidden, |_| {
        let coin: f64 = rng.random();
        if coin < 0.5 {
            0.0
        } else {
            1.0
        }
    });

    for _ in 0..ROUNDS {
        hidden = self
            .visible_given_hidden(&hidden)
            .and_then(|visible| self.hidden_given_visible(&visible))?;
    }

    self.visible_given_hidden(&hidden)
        .and_then(|visible_probs| self.classical_to_quantum(&visible_probs))
}
fn classical_to_quantum(&self, probs: &Array1<f64>) -> QuantRS2Result<Array1<Complex64>> {
let dim = 1 << self.config.num_visible;
let mut state = Array1::zeros(dim);
for i in 0..dim {
let mut amplitude = 1.0;
for q in 0..self.config.num_visible {
let bit = (i >> q) & 1;
amplitude *= if bit == 1 {
probs[q].sqrt()
} else {
(1.0 - probs[q]).sqrt()
};
}
state[i] = Complex64::new(amplitude, 0.0);
}
let norm: f64 = state
.iter()
.map(|x: &Complex64| x.norm_sqr())
.sum::<f64>()
.sqrt();
for i in 0..dim {
state[i] = state[i] / norm;
}
Ok(state)
}
pub fn free_energy(&self, visible: &Array1<f64>) -> QuantRS2Result<f64> {
let mut energy = 0.0;
for i in 0..self.config.num_visible {
energy -= self.visible_bias[i] * visible[i];
}
for j in 0..self.config.num_hidden {
let mut h_input = self.hidden_bias[j];
for i in 0..self.config.num_visible {
h_input += self.weights[[i, j]] * visible[i];
}
energy -= h_input.exp().ln_1p();
}
Ok(energy)
}
/// Returns the average reconstruction error recorded by each `train_batch`
/// call, oldest first.
pub fn history(&self) -> &[f64] {
    self.history.as_slice()
}
/// Borrows the weight matrix, indexed as `[visible, hidden]`.
pub const fn weights(&self) -> &Array2<f64> {
    &self.weights
}
}
/// A stack of [`QuantumRBM`] layers trained one after another.
#[derive(Debug)]
pub struct DeepQuantumBoltzmannMachine {
    /// The layers, ordered from the input layer to the top layer.
    layers: Vec<QuantumRBM>,
    /// The configurations the layers were created from, in the same order.
    layer_configs: Vec<QRBMConfig>,
}
impl DeepQuantumBoltzmannMachine {
/// Creates one freshly initialised [`QuantumRBM`] per entry of
/// `layer_configs`, preserving their order, and keeps the configurations.
pub fn new(layer_configs: Vec<QRBMConfig>) -> Self {
    let layers = layer_configs
        .iter()
        .cloned()
        .map(QuantumRBM::new)
        .collect();
    Self {
        layers,
        layer_configs,
    }
}
pub fn pretrain(
&mut self,
data: &[Array1<Complex64>],
epochs_per_layer: usize,
) -> QuantRS2Result<Vec<Vec<f64>>> {
let mut all_histories = Vec::new();
let mut current_data = data.to_vec();
let num_layers = self.layers.len();
for layer_idx in 0..num_layers {
println!("Pretraining layer {layer_idx}...");
let mut layer_history = Vec::new();
for epoch in 0..epochs_per_layer {
let error = self.layers[layer_idx].train_batch(¤t_data)?;
layer_history.push(error);
if epoch % 10 == 0 {
println!(" Epoch {epoch}: Error = {error:.6}");
}
}
all_histories.push(layer_history);
if layer_idx < num_layers - 1 {
current_data =
self.transform_to_next_layer(¤t_data, &self.layers[layer_idx])?;
}
}
Ok(all_histories)
}
/// Pushes every state in `data` through `layer`.
///
/// Each state is reduced to per-qubit probabilities, mapped to hidden-unit
/// probabilities, and re-encoded as a state vector. Processing stops at the
/// first state that produces an error.
fn transform_to_next_layer(
    &self,
    data: &[Array1<Complex64>],
    layer: &QuantumRBM,
) -> QuantRS2Result<Vec<Array1<Complex64>>> {
    data.iter()
        .map(|state| {
            let visible = layer.quantum_to_classical(state)?;
            let hidden_probs = layer.hidden_given_visible(&visible)?;
            layer.classical_to_quantum(&hidden_probs)
        })
        .collect()
}
pub fn generate(&self) -> QuantRS2Result<Array1<Complex64>> {
let mut sample = self
.layers
.last()
.ok_or_else(|| {
QuantRS2Error::RuntimeError(
"No layers in deep quantum Boltzmann machine".to_string(),
)
})?
.generate_sample()?;
for layer in self.layers.iter().rev().skip(1) {
let hidden = layer.quantum_to_classical(&sample)?;
let visible_probs = layer.visible_given_hidden(&hidden)?;
sample = layer.classical_to_quantum(&visible_probs)?;
}
Ok(sample)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single training pass on one two-qubit state yields a non-negative error.
    #[test]
    fn test_qrbm() {
        let mut rbm = QuantumRBM::new(QRBMConfig {
            num_visible: 2,
            num_hidden: 2,
            learning_rate: 0.01,
            k_steps: 1,
            temperature: 1.0,
            l2_reg: 0.001,
        });

        let amplitudes = [0.7, 0.3, 0.2, 0.6];
        let state = Array1::from_vec(
            amplitudes
                .iter()
                .map(|&re| Complex64::new(re, 0.0))
                .collect(),
        );

        let error = rbm
            .train_batch(&[state])
            .expect("Failed to train quantum RBM on batch");
        assert!(error >= 0.0);
    }

    /// Building a deep machine from two configurations creates two layers.
    #[test]
    fn test_deep_qbm() {
        let configs = vec![
            QRBMConfig {
                num_visible: 2,
                num_hidden: 2,
                ..Default::default()
            },
            QRBMConfig {
                num_visible: 2,
                num_hidden: 1,
                ..Default::default()
            },
        ];
        let dbm = DeepQuantumBoltzmannMachine::new(configs);
        assert_eq!(dbm.layers.len(), 2);
    }
}