1use scirs2_core::random::prelude::*;
8use scirs2_core::random::ChaCha8Rng;
9use scirs2_core::random::{Rng, SeedableRng};
10use scirs2_core::SliceRandomExt;
11use std::collections::HashMap;
12use std::time::{Duration, Instant};
13use thiserror::Error;
14
15use crate::ising::{IsingError, IsingModel};
16use crate::simulator::{AnnealingParams, AnnealingSolution, QuantumAnnealingSimulator};
17
18#[derive(Error, Debug)]
20pub enum QbmError {
21 #[error("Ising error: {0}")]
23 IsingError(#[from] IsingError),
24
25 #[error("Invalid model: {0}")]
27 InvalidModel(String),
28
29 #[error("Training error: {0}")]
31 TrainingError(String),
32
33 #[error("Sampling error: {0}")]
35 SamplingError(String),
36
37 #[error("Data error: {0}")]
39 DataError(String),
40}
41
42pub type QbmResult<T> = Result<T, QbmError>;
44
/// Activation family used by every unit in an RBM layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnitType {
    /// Bernoulli unit: the activation is squashed through the logistic sigmoid.
    Binary,
    /// Real-valued unit: the raw activation is used directly as the unit's value.
    Gaussian,
    /// Raw activations are normalised with softmax across the whole layer.
    Softmax,
}

/// Description of a single RBM layer.
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Number of units in the layer.
    pub num_units: usize,

    /// Activation family shared by all units of the layer.
    pub unit_type: UnitType,

    /// Identifier for the layer (e.g. `"visible"` / `"hidden"`).
    pub name: String,

    /// `(min, max)` interval the initial biases are drawn from.
    pub bias_init_range: (f64, f64),

    /// Whether this layer requests annealing-based sampling during training.
    pub quantum_sampling: bool,
}

impl LayerConfig {
    /// Builds a layer with biases initialised in `(-0.1, 0.1)` and quantum sampling on.
    #[must_use]
    pub const fn new(name: String, num_units: usize, unit_type: UnitType) -> Self {
        let bias_init_range = (-0.1, 0.1);
        Self {
            name,
            num_units,
            unit_type,
            quantum_sampling: true,
            bias_init_range,
        }
    }

    /// Returns the same layer with its bias initialisation interval replaced by `(min, max)`.
    #[must_use]
    pub const fn with_bias_range(self, min: f64, max: f64) -> Self {
        let mut updated = self;
        updated.bias_init_range = (min, max);
        updated
    }

    /// Returns the same layer with quantum sampling switched on or off.
    #[must_use]
    pub const fn with_quantum_sampling(self, enabled: bool) -> Self {
        let mut updated = self;
        updated.quantum_sampling = enabled;
        updated
    }
}
102
/// Restricted Boltzmann machine whose Gibbs sampling can optionally be delegated
/// to a quantum annealing simulator.
#[derive(Debug)]
pub struct QuantumRestrictedBoltzmannMachine {
    /// Configuration of the visible (input) layer.
    visible_config: LayerConfig,

    /// Configuration of the hidden (latent) layer.
    hidden_config: LayerConfig,

    /// One bias per visible unit (length `visible_config.num_units`).
    visible_biases: Vec<f64>,

    /// One bias per hidden unit (length `hidden_config.num_units`).
    hidden_biases: Vec<f64>,

    /// Coupling matrix indexed as `weights[visible][hidden]`.
    weights: Vec<Vec<f64>>,

    /// Hyper-parameters used by training.
    training_config: QbmTrainingConfig,

    /// Deterministic RNG, seeded from `training_config.seed` when provided.
    rng: ChaCha8Rng,

    /// Statistics from the most recent training run; `None` until training completes.
    training_stats: Option<QbmTrainingStats>,
}
130
131#[derive(Debug, Clone)]
133pub struct QbmTrainingConfig {
134 pub learning_rate: f64,
136
137 pub epochs: usize,
139
140 pub batch_size: usize,
142
143 pub k_steps: usize,
145
146 pub persistent_cd: bool,
148
149 pub weight_decay: f64,
151
152 pub momentum: f64,
154
155 pub annealing_params: AnnealingParams,
157
158 pub seed: Option<u64>,
160
161 pub error_threshold: Option<f64>,
163
164 pub log_frequency: usize,
166}
167
168impl Default for QbmTrainingConfig {
169 fn default() -> Self {
170 Self {
171 learning_rate: 0.01,
172 epochs: 100,
173 batch_size: 32,
174 k_steps: 1,
175 persistent_cd: false,
176 weight_decay: 0.0001,
177 momentum: 0.5,
178 annealing_params: AnnealingParams::default(),
179 seed: None,
180 error_threshold: None,
181 log_frequency: 10,
182 }
183 }
184}
185
/// Summary of a completed training run.
#[derive(Debug, Clone)]
pub struct QbmTrainingStats {
    /// Wall-clock duration of the whole run.
    pub total_training_time: Duration,

    /// Mean reconstruction error recorded for each completed epoch.
    pub reconstruction_errors: Vec<f64>,

    /// Mean free-energy difference (data minus negative sample) for each epoch.
    pub free_energy_diffs: Vec<f64>,

    /// Number of epochs actually run (may be fewer than configured on early stop).
    pub epochs_completed: usize,

    /// Reconstruction error of the last completed epoch.
    pub final_reconstruction_error: f64,

    /// Whether the final error fell below the configured threshold.
    pub converged: bool,

    /// Aggregated statistics of the sampling calls made during training.
    pub quantum_sampling_stats: QuantumSamplingStats,
}

/// Aggregated statistics over a series of sampling calls.
#[derive(Debug, Clone)]
pub struct QuantumSamplingStats {
    /// Total wall-clock time spent sampling.
    pub total_sampling_time: Duration,

    /// Number of sampling calls recorded.
    pub sampling_calls: usize,

    /// Call-weighted mean annealing energy.
    pub average_annealing_energy: f64,

    /// Call-weighted fraction of calls served by the annealer.
    pub success_rate: f64,

    /// Call-weighted fraction of calls that fell back to classical sampling.
    pub classical_fallback_rate: f64,
}

impl Default for QuantumSamplingStats {
    /// Empty statistics: no calls yet, rates reading "all quantum, no fallback".
    fn default() -> Self {
        let no_time = Duration::ZERO;
        Self {
            sampling_calls: 0,
            total_sampling_time: no_time,
            classical_fallback_rate: 0.0,
            success_rate: 1.0,
            average_annealing_energy: 0.0,
        }
    }
}
241
/// One training example: a feature vector with an optional label vector.
#[derive(Debug, Clone)]
pub struct TrainingSample {
    /// Feature values presented to the visible layer.
    pub data: Vec<f64>,

    /// Optional target values attached to the sample.
    pub label: Option<Vec<f64>>,
}

impl TrainingSample {
    /// Wraps `data` as an unlabelled sample.
    #[must_use]
    pub const fn new(data: Vec<f64>) -> Self {
        Self { label: None, data }
    }

    /// Wraps `data` together with its `label`.
    #[must_use]
    pub const fn labeled(data: Vec<f64>, label: Vec<f64>) -> Self {
        let label = Some(label);
        Self { data, label }
    }
}
268
/// Output of running one input through the machine.
#[derive(Debug, Clone)]
pub struct QbmInferenceResult {
    /// Visible-layer values reconstructed from a sampled hidden state.
    pub reconstruction: Vec<f64>,

    /// Hidden-layer activation values computed from the input.
    pub hidden_activations: Vec<f64>,

    /// Free energy of the input under the current parameters.
    pub free_energy: f64,

    /// `exp(-free_energy)`: an unnormalised score, not a true probability
    /// (the partition function is never computed).
    pub probability: f64,
}
284
285impl QuantumRestrictedBoltzmannMachine {
286 pub fn new(
288 visible_config: LayerConfig,
289 hidden_config: LayerConfig,
290 training_config: QbmTrainingConfig,
291 ) -> QbmResult<Self> {
292 if visible_config.num_units == 0 || hidden_config.num_units == 0 {
293 return Err(QbmError::InvalidModel(
294 "Layer sizes must be > 0".to_string(),
295 ));
296 }
297
298 let rng = match training_config.seed {
299 Some(seed) => ChaCha8Rng::seed_from_u64(seed),
300 None => ChaCha8Rng::seed_from_u64(thread_rng().gen()),
301 };
302
303 let mut rbm = Self {
304 visible_config: visible_config.clone(),
305 hidden_config: hidden_config.clone(),
306 visible_biases: vec![0.0; visible_config.num_units],
307 hidden_biases: vec![0.0; hidden_config.num_units],
308 weights: vec![vec![0.0; hidden_config.num_units]; visible_config.num_units],
309 training_config,
310 rng,
311 training_stats: None,
312 };
313
314 rbm.initialize_parameters()?;
315 Ok(rbm)
316 }
317
318 fn initialize_parameters(&mut self) -> QbmResult<()> {
320 let (v_min, v_max) = self.visible_config.bias_init_range;
322 for bias in &mut self.visible_biases {
323 *bias = self.rng.gen_range(v_min..v_max);
324 }
325
326 let (h_min, h_max) = self.hidden_config.bias_init_range;
328 for bias in &mut self.hidden_biases {
329 *bias = self.rng.gen_range(h_min..h_max);
330 }
331
332 let fan_in = self.visible_config.num_units as f64;
334 let fan_out = self.hidden_config.num_units as f64;
335 let xavier_std = (2.0 / (fan_in + fan_out)).sqrt();
336
337 for i in 0..self.visible_config.num_units {
338 for j in 0..self.hidden_config.num_units {
339 self.weights[i][j] = self.rng.gen_range(-xavier_std..xavier_std);
340 }
341 }
342
343 Ok(())
344 }
345
346 pub fn train(&mut self, dataset: &[TrainingSample]) -> QbmResult<()> {
348 if dataset.is_empty() {
349 return Err(QbmError::DataError("Dataset is empty".to_string()));
350 }
351
352 let expected_size = self.visible_config.num_units;
354 for (i, sample) in dataset.iter().enumerate() {
355 if sample.data.len() != expected_size {
356 return Err(QbmError::DataError(format!(
357 "Sample {} has {} features, expected {}",
358 i,
359 sample.data.len(),
360 expected_size
361 )));
362 }
363 }
364
365 println!("Starting QBM training with {} samples", dataset.len());
366
367 let start_time = Instant::now();
368 let mut reconstruction_errors = Vec::new();
369 let mut free_energy_diffs = Vec::new();
370 let mut quantum_stats = QuantumSamplingStats::default();
371
372 let mut weight_momentum =
374 vec![vec![0.0; self.hidden_config.num_units]; self.visible_config.num_units];
375 let mut visible_bias_momentum = vec![0.0; self.visible_config.num_units];
376 let mut hidden_bias_momentum = vec![0.0; self.hidden_config.num_units];
377
378 let mut persistent_chains = if self.training_config.persistent_cd {
380 Some(self.initialize_persistent_chains(self.training_config.batch_size)?)
381 } else {
382 None
383 };
384
385 for epoch in 0..self.training_config.epochs {
386 let epoch_start = Instant::now();
387 let mut epoch_error = 0.0;
388 let mut epoch_free_energy_diff = 0.0;
389 let mut num_batches = 0;
390
391 let mut shuffled_indices: Vec<usize> = (0..dataset.len()).collect();
393 use scirs2_core::random::prelude::*;
394 shuffled_indices.shuffle(&mut self.rng);
395
396 for batch_start in (0..dataset.len()).step_by(self.training_config.batch_size) {
398 let batch_end = (batch_start + self.training_config.batch_size).min(dataset.len());
399 let batch_indices = &shuffled_indices[batch_start..batch_end];
400
401 let batch_samples: Vec<&TrainingSample> =
402 batch_indices.iter().map(|&i| &dataset[i]).collect();
403
404 let (batch_error, batch_fe_diff, batch_quantum_stats) =
406 self.contrastive_divergence_batch(&batch_samples, &mut persistent_chains)?;
407
408 self.update_parameters_with_momentum(
410 &batch_samples,
411 &mut weight_momentum,
412 &mut visible_bias_momentum,
413 &mut hidden_bias_momentum,
414 )?;
415
416 epoch_error += batch_error;
417 epoch_free_energy_diff += batch_fe_diff;
418 quantum_stats.merge(&batch_quantum_stats);
419 num_batches += 1;
420 }
421
422 let avg_error = epoch_error / f64::from(num_batches);
423 let avg_fe_diff = epoch_free_energy_diff / f64::from(num_batches);
424
425 reconstruction_errors.push(avg_error);
426 free_energy_diffs.push(avg_fe_diff);
427
428 if epoch % self.training_config.log_frequency == 0 {
430 println!(
431 "Epoch {}: Error = {:.6}, FE Diff = {:.6}, Time = {:.2?}",
432 epoch,
433 avg_error,
434 avg_fe_diff,
435 epoch_start.elapsed()
436 );
437 }
438
439 if let Some(threshold) = self.training_config.error_threshold {
441 if avg_error < threshold {
442 println!("Converged at epoch {epoch} with error {avg_error:.6}");
443 break;
444 }
445 }
446 }
447
448 let total_time = start_time.elapsed();
449
450 self.training_stats = Some(QbmTrainingStats {
452 total_training_time: total_time,
453 reconstruction_errors: reconstruction_errors.clone(),
454 free_energy_diffs,
455 epochs_completed: reconstruction_errors.len(),
456 final_reconstruction_error: reconstruction_errors.last().copied().unwrap_or(0.0),
457 converged: self.training_config.error_threshold.map_or(false, |t| {
458 reconstruction_errors.last().unwrap_or(&f64::INFINITY) < &t
459 }),
460 quantum_sampling_stats: quantum_stats,
461 });
462
463 println!("Training completed in {total_time:.2?}");
464 Ok(())
465 }
466
467 fn contrastive_divergence_batch(
469 &mut self,
470 batch: &[&TrainingSample],
471 persistent_chains: &mut Option<Vec<Vec<f64>>>,
472 ) -> QbmResult<(f64, f64, QuantumSamplingStats)> {
473 let mut total_error = 0.0;
474 let mut total_fe_diff = 0.0;
475 let mut quantum_stats = QuantumSamplingStats::default();
476
477 for (i, sample) in batch.iter().enumerate() {
478 let hidden_probs_pos = self.sample_hidden_given_visible(&sample.data)?;
480 let hidden_states_pos = self.sample_binary_units(&hidden_probs_pos)?;
481
482 let (visible_recon, hidden_probs_neg, sampling_stats) =
484 if self.training_config.persistent_cd {
485 if let Some(ref mut chains) = persistent_chains {
486 let chain_index = i % chains.len();
487 let mut chain = chains[chain_index].clone();
488 for _ in 0..self.training_config.k_steps {
489 let h_probs = self.sample_hidden_given_visible(&chain)?;
490 let h_states = self.sample_binary_units(&h_probs)?;
491 chain = self.sample_visible_given_hidden(&h_states)?;
492 }
493 chains[chain_index] = chain.clone();
494 let h_probs = self.sample_hidden_given_visible(&chain)?;
495 (chain, h_probs, QuantumSamplingStats::default())
496 } else {
497 return Err(QbmError::TrainingError(
498 "Persistent chains not initialized".to_string(),
499 ));
500 }
501 } else {
502 let mut v_states = sample.data.clone();
504 let mut sampling_stats = QuantumSamplingStats::default();
505
506 for _ in 0..self.training_config.k_steps {
507 let h_probs = self.sample_hidden_given_visible(&v_states)?;
508 let h_states = if self.hidden_config.quantum_sampling {
509 let (states, stats) = self.quantum_sample_hidden(&h_probs)?;
510 sampling_stats.merge(&stats);
511 states
512 } else {
513 self.sample_binary_units(&h_probs)?
514 };
515
516 v_states = if self.visible_config.quantum_sampling {
517 let (states, stats) = self.quantum_sample_visible(&h_states)?;
518 sampling_stats.merge(&stats);
519 states
520 } else {
521 self.sample_visible_given_hidden(&h_states)?
522 };
523 }
524
525 let h_probs_neg = self.sample_hidden_given_visible(&v_states)?;
526 (v_states, h_probs_neg, sampling_stats)
527 };
528
529 let error = sample
533 .data
534 .iter()
535 .zip(visible_recon.iter())
536 .map(|(orig, recon)| (orig - recon).powi(2))
537 .sum::<f64>()
538 / sample.data.len() as f64;
539
540 let fe_pos = self.free_energy(&sample.data)?;
542 let fe_neg = self.free_energy(&visible_recon)?;
543 let fe_diff = fe_pos - fe_neg;
544
545 total_error += error;
546 total_fe_diff += fe_diff;
547 quantum_stats.merge(&sampling_stats);
548 }
549
550 Ok((
551 total_error / batch.len() as f64,
552 total_fe_diff / batch.len() as f64,
553 quantum_stats,
554 ))
555 }
556
557 fn update_parameters_with_momentum(
559 &mut self,
560 _batch: &[&TrainingSample],
561 weight_momentum: &mut Vec<Vec<f64>>,
562 visible_bias_momentum: &mut Vec<f64>,
563 hidden_bias_momentum: &mut Vec<f64>,
564 ) -> QbmResult<()> {
565 let lr = self.training_config.learning_rate;
569 let momentum = self.training_config.momentum;
570 let decay = self.training_config.weight_decay;
571
572 for i in 0..self.visible_config.num_units {
574 for j in 0..self.hidden_config.num_units {
575 let gradient = self.rng.gen_range(-0.001..0.001); weight_momentum[i][j] = momentum.mul_add(weight_momentum[i][j], lr * gradient);
577 self.weights[i][j] += decay.mul_add(-self.weights[i][j], weight_momentum[i][j]);
578 }
579 }
580
581 for i in 0..self.visible_config.num_units {
583 let gradient = self.rng.gen_range(-0.001..0.001); visible_bias_momentum[i] = momentum.mul_add(visible_bias_momentum[i], lr * gradient);
585 self.visible_biases[i] += visible_bias_momentum[i];
586 }
587
588 for j in 0..self.hidden_config.num_units {
590 let gradient = self.rng.gen_range(-0.001..0.001); hidden_bias_momentum[j] = momentum.mul_add(hidden_bias_momentum[j], lr * gradient);
592 self.hidden_biases[j] += hidden_bias_momentum[j];
593 }
594
595 Ok(())
596 }
597
598 fn initialize_persistent_chains(&mut self, num_chains: usize) -> QbmResult<Vec<Vec<f64>>> {
600 let mut chains = Vec::new();
601
602 for _ in 0..num_chains {
603 let chain: Vec<f64> = (0..self.visible_config.num_units)
604 .map(|_| if self.rng.gen_bool(0.5) { 1.0 } else { 0.0 })
605 .collect();
606 chains.push(chain);
607 }
608
609 Ok(chains)
610 }
611
612 fn sample_hidden_given_visible(&self, visible: &[f64]) -> QbmResult<Vec<f64>> {
614 if visible.len() != self.visible_config.num_units {
615 return Err(QbmError::DataError("Visible size mismatch".to_string()));
616 }
617
618 let mut hidden_probs = vec![0.0; self.hidden_config.num_units];
619
620 for j in 0..self.hidden_config.num_units {
621 let activation = self.hidden_biases[j]
622 + visible
623 .iter()
624 .enumerate()
625 .map(|(i, &v)| v * self.weights[i][j])
626 .sum::<f64>();
627
628 hidden_probs[j] = match self.hidden_config.unit_type {
629 UnitType::Binary => sigmoid(activation),
630 UnitType::Gaussian => activation, UnitType::Softmax => activation, };
633 }
634
635 if self.hidden_config.unit_type == UnitType::Softmax {
637 softmax_normalize(&mut hidden_probs);
638 }
639
640 Ok(hidden_probs)
641 }
642
643 fn sample_visible_given_hidden(&self, hidden: &[f64]) -> QbmResult<Vec<f64>> {
645 if hidden.len() != self.hidden_config.num_units {
646 return Err(QbmError::DataError("Hidden size mismatch".to_string()));
647 }
648
649 let mut visible_probs = vec![0.0; self.visible_config.num_units];
650
651 for i in 0..self.visible_config.num_units {
652 let activation = self.visible_biases[i]
653 + hidden
654 .iter()
655 .enumerate()
656 .map(|(j, &h)| h * self.weights[i][j])
657 .sum::<f64>();
658
659 visible_probs[i] = match self.visible_config.unit_type {
660 UnitType::Binary => sigmoid(activation),
661 UnitType::Gaussian => activation,
662 UnitType::Softmax => activation,
663 };
664 }
665
666 if self.visible_config.unit_type == UnitType::Softmax {
667 softmax_normalize(&mut visible_probs);
668 }
669
670 Ok(visible_probs)
671 }
672
673 fn sample_binary_units(&mut self, probabilities: &[f64]) -> QbmResult<Vec<f64>> {
675 Ok(probabilities
676 .iter()
677 .map(|&p| if self.rng.gen_bool(p) { 1.0 } else { 0.0 })
678 .collect())
679 }
680
681 fn quantum_sample_hidden(
683 &mut self,
684 probabilities: &[f64],
685 ) -> QbmResult<(Vec<f64>, QuantumSamplingStats)> {
686 let start_time = Instant::now();
687
688 let mut ising_model = IsingModel::new(probabilities.len());
690
691 for (i, &prob) in probabilities.iter().enumerate() {
693 let bias = -2.0 * (prob.ln() - (1.0 - prob).ln()); ising_model.set_bias(i, bias)?;
695 }
696
697 if let Ok(sample) = self.quantum_annealing_sample(&ising_model) {
699 let sampling_time = start_time.elapsed();
700 let stats = QuantumSamplingStats {
701 total_sampling_time: sampling_time,
702 sampling_calls: 1,
703 average_annealing_energy: 0.0, success_rate: 1.0,
705 classical_fallback_rate: 0.0,
706 };
707
708 let binary_sample = sample
710 .iter()
711 .map(|&s| if s > 0 { 1.0 } else { 0.0 })
712 .collect();
713
714 Ok((binary_sample, stats))
715 } else {
716 let sample = self.sample_binary_units(probabilities)?;
718 let stats = QuantumSamplingStats {
719 total_sampling_time: start_time.elapsed(),
720 sampling_calls: 1,
721 average_annealing_energy: 0.0,
722 success_rate: 0.0,
723 classical_fallback_rate: 1.0,
724 };
725 Ok((sample, stats))
726 }
727 }
728
729 fn quantum_sample_visible(
731 &mut self,
732 hidden_states: &[f64],
733 ) -> QbmResult<(Vec<f64>, QuantumSamplingStats)> {
734 let visible_probs = self.sample_visible_given_hidden(hidden_states)?;
735 self.quantum_sample_hidden(&visible_probs) }
737
738 fn quantum_annealing_sample(&self, model: &IsingModel) -> QbmResult<Vec<i8>> {
740 let mut simulator =
741 QuantumAnnealingSimulator::new(self.training_config.annealing_params.clone())
742 .map_err(|e| QbmError::SamplingError(e.to_string()))?;
743
744 let result = simulator
745 .solve(model)
746 .map_err(|e| QbmError::SamplingError(e.to_string()))?;
747
748 Ok(result.best_spins)
749 }
750
751 fn free_energy(&self, visible: &[f64]) -> QbmResult<f64> {
753 if visible.len() != self.visible_config.num_units {
754 return Err(QbmError::DataError("Visible size mismatch".to_string()));
755 }
756
757 let visible_term: f64 = visible
759 .iter()
760 .zip(self.visible_biases.iter())
761 .map(|(&v, &b)| v * b)
762 .sum();
763
764 let hidden_term: f64 = (0..self.hidden_config.num_units)
766 .map(|j| {
767 let activation = self.hidden_biases[j]
768 + visible
769 .iter()
770 .enumerate()
771 .map(|(i, &v)| v * self.weights[i][j])
772 .sum::<f64>();
773 activation.exp().ln_1p()
774 })
775 .sum();
776
777 Ok(-(visible_term + hidden_term))
778 }
779
780 pub fn infer(&mut self, input: &[f64]) -> QbmResult<QbmInferenceResult> {
782 if input.len() != self.visible_config.num_units {
783 return Err(QbmError::DataError("Input size mismatch".to_string()));
784 }
785
786 let hidden_probs = self.sample_hidden_given_visible(input)?;
788 let hidden_states = self.sample_binary_units(&hidden_probs)?;
789
790 let reconstruction = self.sample_visible_given_hidden(&hidden_states)?;
792
793 let free_energy = self.free_energy(input)?;
795 let probability = (-free_energy).exp(); Ok(QbmInferenceResult {
798 reconstruction,
799 hidden_activations: hidden_probs,
800 free_energy,
801 probability,
802 })
803 }
804
805 pub fn generate_samples(&mut self, num_samples: usize) -> QbmResult<Vec<Vec<f64>>> {
807 let mut samples = Vec::new();
808
809 for _ in 0..num_samples {
810 let mut visible: Vec<f64> = (0..self.visible_config.num_units)
812 .map(|_| if self.rng.gen_bool(0.5) { 1.0 } else { 0.0 })
813 .collect();
814
815 for _ in 0..100 {
817 let hidden_probs = self.sample_hidden_given_visible(&visible)?;
818 let hidden_states = self.sample_binary_units(&hidden_probs)?;
819 visible = self.sample_visible_given_hidden(&hidden_states)?;
820 }
821
822 samples.push(visible);
823 }
824
825 Ok(samples)
826 }
827
828 #[must_use]
830 pub const fn get_training_stats(&self) -> Option<&QbmTrainingStats> {
831 self.training_stats.as_ref()
832 }
833
834 pub fn save_model(&self, path: &str) -> QbmResult<()> {
836 println!("Model would be saved to: {path}");
839 Ok(())
840 }
841
842 pub fn load_model(&mut self, path: &str) -> QbmResult<()> {
844 println!("Model would be loaded from: {path}");
847 Ok(())
848 }
849}
850
851impl QuantumSamplingStats {
852 fn merge(&mut self, other: &Self) {
854 self.total_sampling_time += other.total_sampling_time;
855 self.sampling_calls += other.sampling_calls;
856
857 if self.sampling_calls > 0 {
858 let total_calls = self.sampling_calls as f64;
859 self.average_annealing_energy = self.average_annealing_energy.mul_add(
860 total_calls - other.sampling_calls as f64,
861 other.average_annealing_energy * other.sampling_calls as f64,
862 ) / total_calls;
863
864 self.success_rate = self.success_rate.mul_add(
865 total_calls - other.sampling_calls as f64,
866 other.success_rate * other.sampling_calls as f64,
867 ) / total_calls;
868
869 self.classical_fallback_rate = self.classical_fallback_rate.mul_add(
870 total_calls - other.sampling_calls as f64,
871 other.classical_fallback_rate * other.sampling_calls as f64,
872 ) / total_calls;
873 }
874 }
875}
876
/// Logistic function `1 / (1 + e^{-x})`, mapping any real number into `[0, 1]`.
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
881
/// Replaces `values` in place with their softmax, a probability distribution.
///
/// The maximum is subtracted before exponentiating, for numerical stability. Degenerate inputs
/// that would otherwise produce all-`NaN` are handled explicitly:
/// * if any entry is `+inf`, the mass is split evenly among the `+inf` entries;
/// * if every entry is `-inf`, the result is uniform.
///
/// An empty slice is left untouched. Any `NaN` entry makes every output `NaN`.
fn softmax_normalize(values: &mut [f64]) {
    if values.is_empty() {
        return;
    }
    if values.iter().any(|x| x.is_nan()) {
        values.fill(f64::NAN);
        return;
    }

    let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    if max_val == f64::INFINITY {
        // `inf - inf` would be NaN; take the limit instead.
        let infinite = values.iter().filter(|&&x| x == f64::INFINITY).count() as f64;
        for value in values.iter_mut() {
            *value = if *value == f64::INFINITY {
                infinite.recip()
            } else {
                0.0
            };
        }
        return;
    }
    if max_val == f64::NEG_INFINITY {
        // Every entry is -inf: treat all as equally likely.
        values.fill((values.len() as f64).recip());
        return;
    }

    // Exponentiate once, accumulating the normaliser in the same pass.
    let mut sum = 0.0;
    for value in values.iter_mut() {
        *value = (*value - max_val).exp();
        sum += *value;
    }
    for value in values.iter_mut() {
        *value /= sum;
    }
}
891
892pub fn create_binary_rbm(
896 num_visible: usize,
897 num_hidden: usize,
898 training_config: QbmTrainingConfig,
899) -> QbmResult<QuantumRestrictedBoltzmannMachine> {
900 let visible_config = LayerConfig::new("visible".to_string(), num_visible, UnitType::Binary);
901 let hidden_config = LayerConfig::new("hidden".to_string(), num_hidden, UnitType::Binary);
902
903 QuantumRestrictedBoltzmannMachine::new(visible_config, hidden_config, training_config)
904}
905
906pub fn create_gaussian_bernoulli_rbm(
908 num_visible: usize,
909 num_hidden: usize,
910 training_config: QbmTrainingConfig,
911) -> QbmResult<QuantumRestrictedBoltzmannMachine> {
912 let visible_config = LayerConfig::new("visible".to_string(), num_visible, UnitType::Gaussian);
913 let hidden_config = LayerConfig::new("hidden".to_string(), num_hidden, UnitType::Binary);
914
915 QuantumRestrictedBoltzmannMachine::new(visible_config, hidden_config, training_config)
916}
917
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rbm_creation() {
        let training_config = QbmTrainingConfig {
            epochs: 10,
            ..QbmTrainingConfig::default()
        };

        let rbm = create_binary_rbm(4, 3, training_config).expect("failed to create binary RBM");
        let layer_sizes = (rbm.visible_config.num_units, rbm.hidden_config.num_units);
        assert_eq!(layer_sizes, (4, 3));
    }

    #[test]
    fn test_sigmoid_function() {
        let midpoint = sigmoid(0.0);
        assert!((midpoint - 0.5).abs() < 1e-10);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_softmax_normalization() {
        let mut probs = vec![1.0, 2.0, 3.0];
        softmax_normalize(&mut probs);

        let total = probs.iter().sum::<f64>();
        assert!((total - 1.0).abs() < 1e-10);
        for &p in &probs {
            assert!(0.0 < p && p < 1.0);
        }
    }

    #[test]
    fn test_training_sample_creation() {
        let unlabeled = TrainingSample::new(vec![1.0, 0.0, 1.0]);
        assert_eq!(unlabeled.data.len(), 3);
        assert!(unlabeled.label.is_none());

        let labeled = TrainingSample::labeled(vec![1.0, 0.0], vec![1.0]);
        assert_eq!(labeled.data.len(), 2);
        let label_len = labeled
            .label
            .as_ref()
            .map(Vec::len)
            .expect("label should exist");
        assert_eq!(label_len, 1);
    }

    #[test]
    fn test_layer_config() {
        let config = LayerConfig::new("test".to_string(), 10, UnitType::Binary)
            .with_bias_range(-0.5, 0.5)
            .with_quantum_sampling(false);

        assert_eq!(
            (config.num_units, config.unit_type),
            (10, UnitType::Binary)
        );
        assert_eq!(config.bias_init_range, (-0.5, 0.5));
        assert!(!config.quantum_sampling);
    }
}
980}