quantrs2_anneal/quantum_boltzmann_machine.rs

//! Quantum Boltzmann Machines for machine learning with quantum annealing
//!
//! This module provides implementations of Restricted Boltzmann Machines (RBMs)
//! and other Boltzmann machine variants that leverage quantum annealing for
//! sampling and training, enabling quantum machine learning applications.
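//!
//! # Example
//!
//! A minimal usage sketch with illustrative toy data (the import path assumes
//! this module is exported as `quantrs2_anneal::quantum_boltzmann_machine`):
//!
//! ```no_run
//! use quantrs2_anneal::quantum_boltzmann_machine::{
//!     create_binary_rbm, QbmTrainingConfig, TrainingSample,
//! };
//!
//! // Small binary RBM: 4 visible units, 2 hidden units.
//! let config = QbmTrainingConfig { epochs: 10, ..Default::default() };
//! let mut rbm = create_binary_rbm(4, 2, config).expect("valid layer sizes");
//!
//! // Train on a toy dataset, then draw new samples from the model.
//! let data = vec![
//!     TrainingSample::new(vec![1.0, 0.0, 1.0, 0.0]),
//!     TrainingSample::new(vec![0.0, 1.0, 0.0, 1.0]),
//! ];
//! rbm.train(&data).expect("training failed");
//! let samples = rbm.generate_samples(5).expect("sampling failed");
//! assert_eq!(samples.len(), 5);
//! ```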

use scirs2_core::random::prelude::*;
use scirs2_core::random::ChaCha8Rng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::SliceRandomExt;
use std::time::{Duration, Instant};
use thiserror::Error;

use crate::ising::{IsingError, IsingModel};
use crate::simulator::{AnnealingParams, QuantumAnnealingSimulator};

/// Errors that can occur in quantum Boltzmann machine operations
#[derive(Error, Debug)]
pub enum QbmError {
    /// Ising model error
    #[error("Ising error: {0}")]
    IsingError(#[from] IsingError),

    /// Invalid model configuration
    #[error("Invalid model: {0}")]
    InvalidModel(String),

    /// Training error
    #[error("Training error: {0}")]
    TrainingError(String),

    /// Sampling error
    #[error("Sampling error: {0}")]
    SamplingError(String),

    /// Data format error
    #[error("Data error: {0}")]
    DataError(String),
}

/// Result type for QBM operations
pub type QbmResult<T> = Result<T, QbmError>;

/// Type of Boltzmann machine unit
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnitType {
    /// Binary units (0/1 or -1/+1)
    Binary,
    /// Gaussian units for continuous values
    Gaussian,
    /// Softmax units for categorical data
    Softmax,
}

/// Configuration for a Boltzmann machine layer
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Number of units in the layer
    pub num_units: usize,

    /// Type of units
    pub unit_type: UnitType,

    /// Layer name
    pub name: String,

    /// Bias initialization range
    pub bias_init_range: (f64, f64),

    /// Whether to use quantum annealing for sampling
    pub quantum_sampling: bool,
}

impl LayerConfig {
    /// Create a new layer configuration
    #[must_use]
    pub const fn new(name: String, num_units: usize, unit_type: UnitType) -> Self {
        Self {
            num_units,
            unit_type,
            name,
            bias_init_range: (-0.1, 0.1),
            quantum_sampling: true,
        }
    }

    /// Set bias initialization range
    #[must_use]
    pub const fn with_bias_range(mut self, min: f64, max: f64) -> Self {
        self.bias_init_range = (min, max);
        self
    }

    /// Enable or disable quantum sampling
    #[must_use]
    pub const fn with_quantum_sampling(mut self, enabled: bool) -> Self {
        self.quantum_sampling = enabled;
        self
    }
}

/// Restricted Boltzmann Machine with quantum annealing support
#[derive(Debug)]
pub struct QuantumRestrictedBoltzmannMachine {
    /// Visible layer configuration
    visible_config: LayerConfig,

    /// Hidden layer configuration
    hidden_config: LayerConfig,

    /// Visible unit biases
    visible_biases: Vec<f64>,

    /// Hidden unit biases
    hidden_biases: Vec<f64>,

    /// Weight matrix (visible x hidden)
    weights: Vec<Vec<f64>>,

    /// Training configuration
    training_config: QbmTrainingConfig,

    /// Random number generator
    rng: ChaCha8Rng,

    /// Training statistics
    training_stats: Option<QbmTrainingStats>,
}

/// Configuration for QBM training
#[derive(Debug, Clone)]
pub struct QbmTrainingConfig {
    /// Learning rate
    pub learning_rate: f64,

    /// Number of training epochs
    pub epochs: usize,

    /// Batch size for training
    pub batch_size: usize,

    /// Number of Gibbs sampling steps for negative phase
    pub k_steps: usize,

    /// Use persistent contrastive divergence
    pub persistent_cd: bool,

    /// Weight decay regularization
    pub weight_decay: f64,

    /// Momentum for parameter updates
    pub momentum: f64,

    /// Annealing parameters for quantum sampling
    pub annealing_params: AnnealingParams,

    /// Random seed
    pub seed: Option<u64>,

    /// Reconstruction error threshold for early stopping
    pub error_threshold: Option<f64>,

    /// Logging frequency (epochs)
    pub log_frequency: usize,
}

impl Default for QbmTrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            epochs: 100,
            batch_size: 32,
            k_steps: 1,
            persistent_cd: false,
            weight_decay: 0.0001,
            momentum: 0.5,
            annealing_params: AnnealingParams::default(),
            seed: None,
            error_threshold: None,
            log_frequency: 10,
        }
    }
}

/// Training statistics for QBM
#[derive(Debug, Clone)]
pub struct QbmTrainingStats {
    /// Training time
    pub total_training_time: Duration,

    /// Reconstruction error per epoch
    pub reconstruction_errors: Vec<f64>,

    /// Free energy difference per epoch
    pub free_energy_diffs: Vec<f64>,

    /// Number of epochs completed
    pub epochs_completed: usize,

    /// Final reconstruction error
    pub final_reconstruction_error: f64,

    /// Convergence achieved
    pub converged: bool,

    /// Quantum sampling statistics
    pub quantum_sampling_stats: QuantumSamplingStats,
}

/// Statistics for quantum sampling in QBM
#[derive(Debug, Clone)]
pub struct QuantumSamplingStats {
    /// Total quantum sampling time
    pub total_sampling_time: Duration,

    /// Number of quantum sampling calls
    pub sampling_calls: usize,

    /// Average annealing energy
    pub average_annealing_energy: f64,

    /// Success rate of quantum sampling
    pub success_rate: f64,

    /// Classical fallback usage percentage
    pub classical_fallback_rate: f64,
}

impl Default for QuantumSamplingStats {
    fn default() -> Self {
        Self {
            total_sampling_time: Duration::from_secs(0),
            sampling_calls: 0,
            average_annealing_energy: 0.0,
            success_rate: 1.0,
            classical_fallback_rate: 0.0,
        }
    }
}

/// Training sample for QBM
#[derive(Debug, Clone)]
pub struct TrainingSample {
    /// Input data
    pub data: Vec<f64>,

    /// Optional label (for supervised variants)
    pub label: Option<Vec<f64>>,
}

impl TrainingSample {
    /// Create a new training sample
    #[must_use]
    pub const fn new(data: Vec<f64>) -> Self {
        Self { data, label: None }
    }

    /// Create a labeled training sample
    #[must_use]
    pub const fn labeled(data: Vec<f64>, label: Vec<f64>) -> Self {
        Self {
            data,
            label: Some(label),
        }
    }
}

/// Results from QBM inference
#[derive(Debug, Clone)]
pub struct QbmInferenceResult {
    /// Reconstructed visible units
    pub reconstruction: Vec<f64>,

    /// Hidden unit activations
    pub hidden_activations: Vec<f64>,

    /// Free energy of the configuration
    pub free_energy: f64,

    /// Probability of the input
    pub probability: f64,
}

impl QuantumRestrictedBoltzmannMachine {
    /// Create a new Quantum RBM
    pub fn new(
        visible_config: LayerConfig,
        hidden_config: LayerConfig,
        training_config: QbmTrainingConfig,
    ) -> QbmResult<Self> {
        if visible_config.num_units == 0 || hidden_config.num_units == 0 {
            return Err(QbmError::InvalidModel(
                "Layer sizes must be > 0".to_string(),
            ));
        }

        let rng = match training_config.seed {
            Some(seed) => ChaCha8Rng::seed_from_u64(seed),
            None => ChaCha8Rng::seed_from_u64(thread_rng().gen()),
        };

        let mut rbm = Self {
            visible_config: visible_config.clone(),
            hidden_config: hidden_config.clone(),
            visible_biases: vec![0.0; visible_config.num_units],
            hidden_biases: vec![0.0; hidden_config.num_units],
            weights: vec![vec![0.0; hidden_config.num_units]; visible_config.num_units],
            training_config,
            rng,
            training_stats: None,
        };

        rbm.initialize_parameters()?;
        Ok(rbm)
    }

    /// Initialize RBM parameters randomly
    fn initialize_parameters(&mut self) -> QbmResult<()> {
        // Initialize visible biases
        let (v_min, v_max) = self.visible_config.bias_init_range;
        for bias in &mut self.visible_biases {
            *bias = self.rng.gen_range(v_min..v_max);
        }

        // Initialize hidden biases
        let (h_min, h_max) = self.hidden_config.bias_init_range;
        for bias in &mut self.hidden_biases {
            *bias = self.rng.gen_range(h_min..h_max);
        }

        // Initialize weights using Xavier initialization
        let fan_in = self.visible_config.num_units as f64;
        let fan_out = self.hidden_config.num_units as f64;
        let xavier_std = (2.0 / (fan_in + fan_out)).sqrt();

        for i in 0..self.visible_config.num_units {
            for j in 0..self.hidden_config.num_units {
                self.weights[i][j] = self.rng.gen_range(-xavier_std..xavier_std);
            }
        }

        Ok(())
    }

    /// Train the RBM on a dataset
    pub fn train(&mut self, dataset: &[TrainingSample]) -> QbmResult<()> {
        if dataset.is_empty() {
            return Err(QbmError::DataError("Dataset is empty".to_string()));
        }

        // Validate data dimensions
        let expected_size = self.visible_config.num_units;
        for (i, sample) in dataset.iter().enumerate() {
            if sample.data.len() != expected_size {
                return Err(QbmError::DataError(format!(
                    "Sample {} has {} features, expected {}",
                    i,
                    sample.data.len(),
                    expected_size
                )));
            }
        }

        println!("Starting QBM training with {} samples", dataset.len());

        let start_time = Instant::now();
        let mut reconstruction_errors = Vec::new();
        let mut free_energy_diffs = Vec::new();
        let mut quantum_stats = QuantumSamplingStats::default();

        // Momentum terms
        let mut weight_momentum =
            vec![vec![0.0; self.hidden_config.num_units]; self.visible_config.num_units];
        let mut visible_bias_momentum = vec![0.0; self.visible_config.num_units];
        let mut hidden_bias_momentum = vec![0.0; self.hidden_config.num_units];

        // Persistent chains for PCD
        let mut persistent_chains = if self.training_config.persistent_cd {
            Some(self.initialize_persistent_chains(self.training_config.batch_size)?)
        } else {
            None
        };

        for epoch in 0..self.training_config.epochs {
            let epoch_start = Instant::now();
            let mut epoch_error = 0.0;
            let mut epoch_free_energy_diff = 0.0;
            let mut num_batches = 0;

            // Shuffle dataset
            let mut shuffled_indices: Vec<usize> = (0..dataset.len()).collect();
            shuffled_indices.shuffle(&mut self.rng);

            // Process batches
            for batch_start in (0..dataset.len()).step_by(self.training_config.batch_size) {
                let batch_end = (batch_start + self.training_config.batch_size).min(dataset.len());
                let batch_indices = &shuffled_indices[batch_start..batch_end];

                let batch_samples: Vec<&TrainingSample> =
                    batch_indices.iter().map(|&i| &dataset[i]).collect();

                // Perform contrastive divergence
                let (batch_error, batch_fe_diff, batch_quantum_stats) =
                    self.contrastive_divergence_batch(&batch_samples, &mut persistent_chains)?;

                // Update parameters with momentum
                self.update_parameters_with_momentum(
                    &batch_samples,
                    &mut weight_momentum,
                    &mut visible_bias_momentum,
                    &mut hidden_bias_momentum,
                )?;

                epoch_error += batch_error;
                epoch_free_energy_diff += batch_fe_diff;
                quantum_stats.merge(&batch_quantum_stats);
                num_batches += 1;
            }

            let avg_error = epoch_error / f64::from(num_batches);
            let avg_fe_diff = epoch_free_energy_diff / f64::from(num_batches);

            reconstruction_errors.push(avg_error);
            free_energy_diffs.push(avg_fe_diff);

            // Logging
            if epoch % self.training_config.log_frequency == 0 {
                println!(
                    "Epoch {}: Error = {:.6}, FE Diff = {:.6}, Time = {:.2?}",
                    epoch,
                    avg_error,
                    avg_fe_diff,
                    epoch_start.elapsed()
                );
            }

            // Early stopping
            if let Some(threshold) = self.training_config.error_threshold {
                if avg_error < threshold {
                    println!("Converged at epoch {epoch} with error {avg_error:.6}");
                    break;
                }
            }
        }

        let total_time = start_time.elapsed();

        // Store training statistics
        self.training_stats = Some(QbmTrainingStats {
            total_training_time: total_time,
            reconstruction_errors: reconstruction_errors.clone(),
            free_energy_diffs,
            epochs_completed: reconstruction_errors.len(),
            final_reconstruction_error: reconstruction_errors.last().copied().unwrap_or(0.0),
            converged: self.training_config.error_threshold.map_or(false, |t| {
                reconstruction_errors.last().unwrap_or(&f64::INFINITY) < &t
            }),
            quantum_sampling_stats: quantum_stats,
        });

        println!("Training completed in {total_time:.2?}");
        Ok(())
    }

    /// Perform contrastive divergence for a batch
    fn contrastive_divergence_batch(
        &mut self,
        batch: &[&TrainingSample],
        persistent_chains: &mut Option<Vec<Vec<f64>>>,
    ) -> QbmResult<(f64, f64, QuantumSamplingStats)> {
        let mut total_error = 0.0;
        let mut total_fe_diff = 0.0;
        let mut quantum_stats = QuantumSamplingStats::default();

        for (i, sample) in batch.iter().enumerate() {
            // Positive phase (these statistics would drive the positive-phase
            // gradient in a full implementation)
            let hidden_probs_pos = self.sample_hidden_given_visible(&sample.data)?;
            let _hidden_states_pos = self.sample_binary_units(&hidden_probs_pos)?;

            // Negative phase
            let (visible_recon, _hidden_probs_neg, sampling_stats) =
                if self.training_config.persistent_cd {
                    if let Some(ref mut chains) = persistent_chains {
                        let chain_index = i % chains.len();
                        let mut chain = chains[chain_index].clone();
                        for _ in 0..self.training_config.k_steps {
                            let h_probs = self.sample_hidden_given_visible(&chain)?;
                            let h_states = self.sample_binary_units(&h_probs)?;
                            chain = self.sample_visible_given_hidden(&h_states)?;
                        }
                        chains[chain_index] = chain.clone();
                        let h_probs = self.sample_hidden_given_visible(&chain)?;
                        (chain, h_probs, QuantumSamplingStats::default())
                    } else {
                        return Err(QbmError::TrainingError(
                            "Persistent chains not initialized".to_string(),
                        ));
                    }
                } else {
                    // Standard CD-k
                    let mut v_states = sample.data.clone();
                    let mut sampling_stats = QuantumSamplingStats::default();

                    for _ in 0..self.training_config.k_steps {
                        let h_probs = self.sample_hidden_given_visible(&v_states)?;
                        let h_states = if self.hidden_config.quantum_sampling {
                            let (states, stats) = self.quantum_sample_hidden(&h_probs)?;
                            sampling_stats.merge(&stats);
                            states
                        } else {
                            self.sample_binary_units(&h_probs)?
                        };

                        v_states = if self.visible_config.quantum_sampling {
                            let (states, stats) = self.quantum_sample_visible(&h_states)?;
                            sampling_stats.merge(&stats);
                            states
                        } else {
                            self.sample_visible_given_hidden(&h_states)?
                        };
                    }

                    let h_probs_neg = self.sample_hidden_given_visible(&v_states)?;
                    (v_states, h_probs_neg, sampling_stats)
                };

            // Gradients are accumulated separately (see
            // `update_parameters_with_momentum`)

            // Compute reconstruction error
            let error = sample
                .data
                .iter()
                .zip(visible_recon.iter())
                .map(|(orig, recon)| (orig - recon).powi(2))
                .sum::<f64>()
                / sample.data.len() as f64;

            // Compute free energy difference
            let fe_pos = self.free_energy(&sample.data)?;
            let fe_neg = self.free_energy(&visible_recon)?;
            let fe_diff = fe_pos - fe_neg;

            total_error += error;
            total_fe_diff += fe_diff;
            quantum_stats.merge(&sampling_stats);
        }

        Ok((
            total_error / batch.len() as f64,
            total_fe_diff / batch.len() as f64,
            quantum_stats,
        ))
    }

    /// Update parameters using momentum
    fn update_parameters_with_momentum(
        &mut self,
        _batch: &[&TrainingSample],
        weight_momentum: &mut Vec<Vec<f64>>,
        visible_bias_momentum: &mut Vec<f64>,
        hidden_bias_momentum: &mut Vec<f64>,
    ) -> QbmResult<()> {
        // This is a simplified update: in practice, the gradients would be
        // computed from the positive and negative phases of contrastive
        // divergence rather than the placeholder noise used below.
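        //
        // For reference, the standard CD weight update is (sketch):
        //
        //   delta_w[i][j] = lr * (<v_i h_j>_data - <v_i h_j>_recon) - decay * w[i][j]
        //
        // where <.>_data averages v_i * h_j over the positive phase (visible
        // units clamped to the batch) and <.>_recon over the negative-phase
        // reconstructions; bias updates use the corresponding unit averages.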

        let lr = self.training_config.learning_rate;
        let momentum = self.training_config.momentum;
        let decay = self.training_config.weight_decay;

        // Update weights (simplified - normally computed from CD phases)
        for i in 0..self.visible_config.num_units {
            for j in 0..self.hidden_config.num_units {
                let gradient = self.rng.gen_range(-0.001..0.001); // Placeholder
                weight_momentum[i][j] = momentum.mul_add(weight_momentum[i][j], lr * gradient);
                self.weights[i][j] += decay.mul_add(-self.weights[i][j], weight_momentum[i][j]);
            }
        }

        // Update visible biases
        for i in 0..self.visible_config.num_units {
            let gradient = self.rng.gen_range(-0.001..0.001); // Placeholder
            visible_bias_momentum[i] = momentum.mul_add(visible_bias_momentum[i], lr * gradient);
            self.visible_biases[i] += visible_bias_momentum[i];
        }

        // Update hidden biases
        for j in 0..self.hidden_config.num_units {
            let gradient = self.rng.gen_range(-0.001..0.001); // Placeholder
            hidden_bias_momentum[j] = momentum.mul_add(hidden_bias_momentum[j], lr * gradient);
            self.hidden_biases[j] += hidden_bias_momentum[j];
        }

        Ok(())
    }

    /// Initialize persistent chains for PCD
    fn initialize_persistent_chains(&mut self, num_chains: usize) -> QbmResult<Vec<Vec<f64>>> {
        let mut chains = Vec::new();

        for _ in 0..num_chains {
            let chain: Vec<f64> = (0..self.visible_config.num_units)
                .map(|_| if self.rng.gen_bool(0.5) { 1.0 } else { 0.0 })
                .collect();
            chains.push(chain);
        }

        Ok(chains)
    }

    /// Sample hidden units given visible units
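    ///
    /// For binary hidden units this computes the conditional probabilities
    /// P(h_j = 1 | v) = sigmoid(c_j + sum_i v_i * w[i][j]).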
    fn sample_hidden_given_visible(&self, visible: &[f64]) -> QbmResult<Vec<f64>> {
        if visible.len() != self.visible_config.num_units {
            return Err(QbmError::DataError("Visible size mismatch".to_string()));
        }

        let mut hidden_probs = vec![0.0; self.hidden_config.num_units];

        for j in 0..self.hidden_config.num_units {
            let activation = self.hidden_biases[j]
                + visible
                    .iter()
                    .enumerate()
                    .map(|(i, &v)| v * self.weights[i][j])
                    .sum::<f64>();

            hidden_probs[j] = match self.hidden_config.unit_type {
                UnitType::Binary => sigmoid(activation),
                UnitType::Gaussian => activation, // Linear for Gaussian
                UnitType::Softmax => activation,  // Will be normalized later
            };
        }

        // Apply softmax normalization if needed
        if self.hidden_config.unit_type == UnitType::Softmax {
            softmax_normalize(&mut hidden_probs);
        }

        Ok(hidden_probs)
    }

    /// Sample visible units given hidden units
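    ///
    /// For binary visible units this computes the conditional probabilities
    /// P(v_i = 1 | h) = sigmoid(b_i + sum_j h_j * w[i][j]).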
    fn sample_visible_given_hidden(&self, hidden: &[f64]) -> QbmResult<Vec<f64>> {
        if hidden.len() != self.hidden_config.num_units {
            return Err(QbmError::DataError("Hidden size mismatch".to_string()));
        }

        let mut visible_probs = vec![0.0; self.visible_config.num_units];

        for i in 0..self.visible_config.num_units {
            let activation = self.visible_biases[i]
                + hidden
                    .iter()
                    .enumerate()
                    .map(|(j, &h)| h * self.weights[i][j])
                    .sum::<f64>();

            visible_probs[i] = match self.visible_config.unit_type {
                UnitType::Binary => sigmoid(activation),
                UnitType::Gaussian => activation,
                UnitType::Softmax => activation,
            };
        }

        if self.visible_config.unit_type == UnitType::Softmax {
            softmax_normalize(&mut visible_probs);
        }

        Ok(visible_probs)
    }

    /// Sample binary units from probabilities
    fn sample_binary_units(&mut self, probabilities: &[f64]) -> QbmResult<Vec<f64>> {
        Ok(probabilities
            .iter()
            // Clamp so `gen_bool` never panics on values outside [0, 1]
            // (e.g. the unbounded activations produced for Gaussian units)
            .map(|&p| {
                if self.rng.gen_bool(p.clamp(0.0, 1.0)) {
                    1.0
                } else {
                    0.0
                }
            })
            .collect())
    }

    /// Quantum sample hidden units using annealing
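    ///
    /// Each unit is mapped to an Ising spin whose bias encodes the logit of
    /// its activation probability; with no couplings between spins, the
    /// annealer samples the units (approximately) independently.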
    fn quantum_sample_hidden(
        &mut self,
        probabilities: &[f64],
    ) -> QbmResult<(Vec<f64>, QuantumSamplingStats)> {
        let start_time = Instant::now();

        // Create Ising model for sampling
        let mut ising_model = IsingModel::new(probabilities.len());

        // Set biases based on probabilities, clamping away from 0 and 1 so
        // the logit stays finite
        for (i, &prob) in probabilities.iter().enumerate() {
            let p = prob.clamp(1e-9, 1.0 - 1e-9);
            let bias = -2.0 * (p.ln() - (1.0 - p).ln()); // Logit transformation
            ising_model.set_bias(i, bias)?;
        }

        // Sample using quantum annealing
        if let Ok(sample) = self.quantum_annealing_sample(&ising_model) {
            let sampling_time = start_time.elapsed();
            let stats = QuantumSamplingStats {
                total_sampling_time: sampling_time,
                sampling_calls: 1,
                average_annealing_energy: 0.0, // Would compute from result
                success_rate: 1.0,
                classical_fallback_rate: 0.0,
            };

            // Convert spins to 0/1
            let binary_sample = sample
                .iter()
                .map(|&s| if s > 0 { 1.0 } else { 0.0 })
                .collect();

            Ok((binary_sample, stats))
        } else {
            // Fallback to classical sampling
            let sample = self.sample_binary_units(probabilities)?;
            let stats = QuantumSamplingStats {
                total_sampling_time: start_time.elapsed(),
                sampling_calls: 1,
                average_annealing_energy: 0.0,
                success_rate: 0.0,
                classical_fallback_rate: 1.0,
            };
            Ok((sample, stats))
        }
    }

    /// Quantum sample visible units using annealing
    fn quantum_sample_visible(
        &mut self,
        hidden_states: &[f64],
    ) -> QbmResult<(Vec<f64>, QuantumSamplingStats)> {
        let visible_probs = self.sample_visible_given_hidden(hidden_states)?;
        self.quantum_sample_hidden(&visible_probs) // Same process
    }

    /// Perform quantum annealing sampling
    fn quantum_annealing_sample(&self, model: &IsingModel) -> QbmResult<Vec<i8>> {
        let mut simulator =
            QuantumAnnealingSimulator::new(self.training_config.annealing_params.clone())
                .map_err(|e| QbmError::SamplingError(e.to_string()))?;

        let result = simulator
            .solve(model)
            .map_err(|e| QbmError::SamplingError(e.to_string()))?;

        Ok(result.best_spins)
    }

    /// Compute free energy of a configuration
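    ///
    /// For a binary RBM this is
    /// F(v) = -sum_i b_i v_i - sum_j ln(1 + exp(c_j + sum_i v_i * w[i][j])),
    /// where b are the visible biases, c the hidden biases, and w the weights.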
    fn free_energy(&self, visible: &[f64]) -> QbmResult<f64> {
        if visible.len() != self.visible_config.num_units {
            return Err(QbmError::DataError("Visible size mismatch".to_string()));
        }

        // Visible bias term
        let visible_term: f64 = visible
            .iter()
            .zip(self.visible_biases.iter())
            .map(|(&v, &b)| v * b)
            .sum();

        // Hidden term: softplus ln(1 + exp(activation)) for each hidden unit,
        // using the identity ln(1 + e^x) = x + ln(1 + e^-x) for positive x to
        // avoid overflow
        let hidden_term: f64 = (0..self.hidden_config.num_units)
            .map(|j| {
                let activation = self.hidden_biases[j]
                    + visible
                        .iter()
                        .enumerate()
                        .map(|(i, &v)| v * self.weights[i][j])
                        .sum::<f64>();
                if activation > 0.0 {
                    activation + (-activation).exp().ln_1p()
                } else {
                    activation.exp().ln_1p()
                }
            })
            .sum();

        Ok(-(visible_term + hidden_term))
    }

    /// Perform inference on input data
    pub fn infer(&mut self, input: &[f64]) -> QbmResult<QbmInferenceResult> {
        if input.len() != self.visible_config.num_units {
            return Err(QbmError::DataError("Input size mismatch".to_string()));
        }

        // Compute hidden activations
        let hidden_probs = self.sample_hidden_given_visible(input)?;
        let hidden_states = self.sample_binary_units(&hidden_probs)?;

        // Reconstruct visible units
        let reconstruction = self.sample_visible_given_hidden(&hidden_states)?;

        // Compute free energy and probability
        let free_energy = self.free_energy(input)?;
        let probability = (-free_energy).exp(); // Unnormalized

        Ok(QbmInferenceResult {
            reconstruction,
            hidden_activations: hidden_probs,
            free_energy,
            probability,
        })
    }

    /// Generate samples from the learned distribution
    pub fn generate_samples(&mut self, num_samples: usize) -> QbmResult<Vec<Vec<f64>>> {
        let mut samples = Vec::new();

        for _ in 0..num_samples {
            // Start with random visible state
            let mut visible: Vec<f64> = (0..self.visible_config.num_units)
                .map(|_| if self.rng.gen_bool(0.5) { 1.0 } else { 0.0 })
                .collect();

            // Run Gibbs sampling for burn-in
            for _ in 0..100 {
                let hidden_probs = self.sample_hidden_given_visible(&visible)?;
                let hidden_states = self.sample_binary_units(&hidden_probs)?;
                visible = self.sample_visible_given_hidden(&hidden_states)?;
            }

            samples.push(visible);
        }

        Ok(samples)
    }

    /// Get training statistics
    #[must_use]
    pub const fn get_training_stats(&self) -> Option<&QbmTrainingStats> {
        self.training_stats.as_ref()
    }

    /// Save model parameters
    pub fn save_model(&self, path: &str) -> QbmResult<()> {
        // TODO: implement model serialization; this is currently a stub that
        // only reports the intended path
        println!("Model would be saved to: {path}");
        Ok(())
    }

    /// Load model parameters
    pub fn load_model(&mut self, path: &str) -> QbmResult<()> {
        // TODO: implement model deserialization; this is currently a stub
        // that only reports the intended path
        println!("Model would be loaded from: {path}");
        Ok(())
    }
}

impl QuantumSamplingStats {
    /// Merge another stats object into this one
    fn merge(&mut self, other: &Self) {
        self.total_sampling_time += other.total_sampling_time;
        self.sampling_calls += other.sampling_calls;

        if self.sampling_calls > 0 {
            let total_calls = self.sampling_calls as f64;
            self.average_annealing_energy = self.average_annealing_energy.mul_add(
                total_calls - other.sampling_calls as f64,
                other.average_annealing_energy * other.sampling_calls as f64,
            ) / total_calls;

            self.success_rate = self.success_rate.mul_add(
                total_calls - other.sampling_calls as f64,
                other.success_rate * other.sampling_calls as f64,
            ) / total_calls;

            self.classical_fallback_rate = self.classical_fallback_rate.mul_add(
                total_calls - other.sampling_calls as f64,
                other.classical_fallback_rate * other.sampling_calls as f64,
            ) / total_calls;
        }
    }
}

/// Sigmoid activation function
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Apply softmax normalization in-place
fn softmax_normalize(values: &mut [f64]) {
    let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let sum: f64 = values.iter().map(|&x| (x - max_val).exp()).sum();

    for value in values.iter_mut() {
        *value = (*value - max_val).exp() / sum;
    }
}

// Helper functions for different QBM variants

/// Create a binary-binary RBM for typical unsupervised learning
pub fn create_binary_rbm(
    num_visible: usize,
    num_hidden: usize,
    training_config: QbmTrainingConfig,
) -> QbmResult<QuantumRestrictedBoltzmannMachine> {
    let visible_config = LayerConfig::new("visible".to_string(), num_visible, UnitType::Binary);
    let hidden_config = LayerConfig::new("hidden".to_string(), num_hidden, UnitType::Binary);

    QuantumRestrictedBoltzmannMachine::new(visible_config, hidden_config, training_config)
}

/// Create a Gaussian-Bernoulli RBM for continuous input data
pub fn create_gaussian_bernoulli_rbm(
    num_visible: usize,
    num_hidden: usize,
    training_config: QbmTrainingConfig,
) -> QbmResult<QuantumRestrictedBoltzmannMachine> {
    let visible_config = LayerConfig::new("visible".to_string(), num_visible, UnitType::Gaussian);
    let hidden_config = LayerConfig::new("hidden".to_string(), num_hidden, UnitType::Binary);

    QuantumRestrictedBoltzmannMachine::new(visible_config, hidden_config, training_config)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rbm_creation() {
        let training_config = QbmTrainingConfig {
            epochs: 10,
            ..Default::default()
        };

        let rbm = create_binary_rbm(4, 3, training_config).expect("failed to create binary RBM");
        assert_eq!(rbm.visible_config.num_units, 4);
        assert_eq!(rbm.hidden_config.num_units, 3);
    }

    #[test]
    fn test_sigmoid_function() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-10);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_softmax_normalization() {
        let mut values = vec![1.0, 2.0, 3.0];
        softmax_normalize(&mut values);

        let sum: f64 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
        assert!(values.iter().all(|&x| x > 0.0 && x < 1.0));
    }

    #[test]
    fn test_training_sample_creation() {
        let sample = TrainingSample::new(vec![1.0, 0.0, 1.0]);
        assert_eq!(sample.data.len(), 3);
        assert!(sample.label.is_none());

        let labeled_sample = TrainingSample::labeled(vec![1.0, 0.0], vec![1.0]);
        assert_eq!(labeled_sample.data.len(), 2);
        assert_eq!(
            labeled_sample
                .label
                .as_ref()
                .expect("label should exist")
                .len(),
            1
        );
    }

    #[test]
    fn test_layer_config() {
        let config = LayerConfig::new("test".to_string(), 10, UnitType::Binary)
            .with_bias_range(-0.5, 0.5)
            .with_quantum_sampling(false);

        assert_eq!(config.num_units, 10);
        assert_eq!(config.unit_type, UnitType::Binary);
        assert_eq!(config.bias_init_range, (-0.5, 0.5));
        assert!(!config.quantum_sampling);
    }
}