#![allow(clippy::pedantic, clippy::unnecessary_wraps)]
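//! Quantum Boltzmann Machine examples.
//!
//! This demo walks through a basic Quantum Boltzmann Machine, a quantum
//! Restricted Boltzmann Machine with temperature annealing, layer-wise
//! pretraining of a Deep Boltzmann Machine, an energy-landscape analysis of a
//! small two-visible-unit system, and pattern completion on letter-like
//! bit patterns.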
use quantrs2_ml::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::prelude::*;

fn main() -> Result<()> {
    println!("=== Quantum Boltzmann Machine Demo ===\n");

    println!("1. Basic Quantum Boltzmann Machine...");
    basic_qbm_demo()?;

    println!("\n2. Quantum Restricted Boltzmann Machine (RBM)...");
    rbm_demo()?;

    println!("\n3. Deep Boltzmann Machine...");
    deep_boltzmann_demo()?;

    println!("\n4. Energy Landscape Analysis...");
    energy_landscape_demo()?;

    println!("\n5. Pattern Completion Demo...");
    pattern_completion_demo()?;

    println!("\n=== Boltzmann Machine Demo Complete ===");

    Ok(())
}

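/// Builds a small QBM (4 visible, 2 hidden units), trains it on random binary
/// patterns, and prints a few generated samples.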
fn basic_qbm_demo() -> Result<()> {
    let mut qbm = QuantumBoltzmannMachine::new(
        4,    // visible units
        2,    // hidden units
        1.0,  // temperature
        0.01, // learning rate
    )?;

    println!(" Created QBM with 4 visible and 2 hidden units");

    let data = generate_binary_patterns(100, 4);

    println!(" Training on binary patterns...");
    let losses = qbm.train(&data, 50, 10)?;

    println!(" Training complete:");
    println!(" - Initial loss: {:.4}", losses[0]);
    println!(" - Final loss: {:.4}", losses.last().unwrap());

    let samples = qbm.sample(5)?;
    println!("\n Generated samples:");
    for (i, sample) in samples.outer_iter().enumerate() {
        print!(" Sample {}: [", i + 1);
        for val in sample {
            print!("{val:.0} ");
        }
        println!("]");
    }

    Ok(())
}

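/// Trains a quantum RBM with a temperature-annealing schedule using
/// Persistent Contrastive Divergence, then reconstructs a few training
/// samples to show reconstruction quality.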
fn rbm_demo() -> Result<()> {
    // Temperature annealing schedule from 2.0 down to 0.5
    let annealing = AnnealingSchedule::new(2.0, 0.5, 100);

    let mut rbm = QuantumRBM::new(
        6,    // visible units
        3,    // hidden units
        2.0,  // initial temperature
        0.01, // learning rate
    )?
    .with_annealing(annealing);

    println!(" Created Quantum RBM with annealing schedule");

    let data = generate_correlated_data(200, 6);

    println!(" Training with Persistent Contrastive Divergence...");
    let losses = rbm.train_pcd(&data, 100, 20, 50)?;

    let improvement = (losses[0] - losses.last().unwrap()) / losses[0] * 100.0;
    println!(" Training statistics:");
    println!(" - Loss reduction: {improvement:.1}%");
    println!(" - Final temperature: 0.5");

    let test_data = data.slice(s![0..5, ..]).to_owned();
    let reconstructed = rbm.qbm().reconstruct(&test_data)?;

    println!("\n Reconstruction quality:");
    for i in 0..3 {
        print!(" Original: [");
        for val in test_data.row(i) {
            print!("{val:.0} ");
        }
        print!("] → Reconstructed: [");
        for val in reconstructed.row(i) {
            print!("{val:.0} ");
        }
        println!("]");
    }

    Ok(())
}

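/// Builds a Deep Boltzmann Machine with an 8-4-2 architecture and runs greedy
/// layer-wise pretraining on hierarchical toy data.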
fn deep_boltzmann_demo() -> Result<()> {
    let layer_sizes = vec![8, 4, 2];
    let mut dbm = DeepBoltzmannMachine::new(
        layer_sizes.clone(),
        1.0,  // temperature
        0.01, // learning rate
    )?;

    println!(" Created Deep Boltzmann Machine:");
    println!(" - Architecture: {layer_sizes:?}");
    println!(" - Total layers: {}", dbm.rbms().len());

    let data = generate_hierarchical_data(300, 8);

    println!("\n Performing layer-wise pretraining...");
    dbm.pretrain(&data, 50, 30)?;

    println!("\n Pretraining complete!");
    println!(" Each layer learned increasingly abstract features");

    Ok(())
}

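/// Enumerates all four visible states of a two-visible-unit QBM and prints
/// each state's energy, its unnormalized Boltzmann weight, and the coupling
/// matrix.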
fn energy_landscape_demo() -> Result<()> {
    let qbm = QuantumBoltzmannMachine::new(
        2,    // visible units
        1,    // hidden unit
        0.5,  // temperature
        0.01, // learning rate
    )?;

    println!(" Analyzing energy landscape of 2-unit system");

    // All four configurations of the two visible units
    let states = [
        Array1::from_vec(vec![0.0, 0.0]),
        Array1::from_vec(vec![0.0, 1.0]),
        Array1::from_vec(vec![1.0, 0.0]),
        Array1::from_vec(vec![1.0, 1.0]),
    ];

    println!("\n State energies:");
    for state in &states {
        let energy = qbm.energy(state);
        // Unnormalized Boltzmann weight exp(-E/T)
        let prob = (-energy / qbm.temperature()).exp();
        println!(
            " State [{:.0}, {:.0}]: E = {:.3}, P ∝ {:.3}",
            state[0], state[1], energy, prob
        );
    }

    println!("\n Coupling matrix:");
    for i in 0..3 {
        print!(" [");
        for j in 0..3 {
            print!("{:6.3} ", qbm.couplings()[[i, j]]);
        }
        println!("]");
    }

    Ok(())
}

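/// Trains a quantum RBM on letter-like bit patterns, corrupts one pattern by
/// flipping two bits, and checks how well Gibbs sampling restores the
/// original.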
fn pattern_completion_demo() -> Result<()> {
    let mut rbm = QuantumRBM::new(
        8,    // visible units
        4,    // hidden units
        1.0,  // temperature
        0.02, // learning rate
    )?;

    let patterns = create_letter_patterns();
    println!(" Training on letter-like patterns...");

    rbm.train_pcd(&patterns, 100, 10, 20)?;

    println!("\n Pattern completion test:");

    // Corrupt the first pattern by flipping two of its bits
    let mut corrupted = patterns.row(0).to_owned();
    corrupted[3] = 1.0 - corrupted[3];
    corrupted[5] = 1.0 - corrupted[5];

    print!(" Corrupted: [");
    for val in &corrupted {
        print!("{val:.0} ");
    }
    println!("]");

    let completed = complete_pattern(&rbm, &corrupted)?;

    print!(" Completed: [");
    for val in &completed {
        print!("{val:.0} ");
    }
    println!("]");

    print!(" Original: [");
    for val in patterns.row(0) {
        print!("{val:.0} ");
    }
    println!("]");

    // Fraction of the 8 bits recovered correctly
    let accuracy = patterns
        .row(0)
        .iter()
        .zip(completed.iter())
        .filter(|(&a, &b)| (a - b).abs() < 0.5)
        .count() as f64
        / 8.0;

    println!(" Reconstruction accuracy: {:.1}%", accuracy * 100.0);

    Ok(())
}

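/// Generates `n_samples` rows of independent uniform random bits (0.0 or 1.0).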
fn generate_binary_patterns(n_samples: usize, n_features: usize) -> Array2<f64> {
    Array2::from_shape_fn((n_samples, n_features), |(_, _)| {
        if thread_rng().gen::<f64>() > 0.5 {
            1.0
        } else {
            0.0
        }
    })
}

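/// Generates binary data with correlated features: each row picks a random
/// base bit, even-indexed features copy it, and odd-indexed features match it
/// with 80% probability.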
fn generate_correlated_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        let base = if thread_rng().gen::<f64>() > 0.5 {
            1.0
        } else {
            0.0
        };

        for j in 0..n_features {
            if j % 2 == 0 {
                data[[i, j]] = base;
            } else {
                data[[i, j]] = if thread_rng().gen::<f64>() > 0.2 {
                    base
                } else {
                    1.0 - base
                };
            }
        }
    }

    data
}

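/// Generates three families of structured bit patterns (alternating bits,
/// two-wide blocks, and a randomly shifted period-3 pattern), then flips each
/// bit with 10% probability as noise.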
fn generate_hierarchical_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        let pattern_type = i % 3;

        match pattern_type {
            0 => {
                // Alternating bits: 0, 1, 0, 1, ...
                for j in 0..n_features {
                    data[[i, j]] = (j % 2) as f64;
                }
            }
            1 => {
                // Two-wide blocks: 0, 0, 1, 1, ...
                for j in 0..n_features {
                    data[[i, j]] = ((j / 2) % 2) as f64;
                }
            }
            _ => {
                // Period-3 pattern with a random shift
                let shift = (thread_rng().gen::<f64>() * 4.0) as usize;
                for j in 0..n_features {
                    data[[i, j]] = if (j + shift) % 3 == 0 { 1.0 } else { 0.0 };
                }
            }
        }

        // Flip each bit with 10% probability as noise
        for j in 0..n_features {
            if thread_rng().gen::<f64>() < 0.1 {
                data[[i, j]] = 1.0 - data[[i, j]];
            }
        }
    }

    data
}

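/// Returns four fixed 8-bit patterns (one per row) used as letter-like
/// training examples for the pattern completion demo.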
fn create_letter_patterns() -> Array2<f64> {
    Array2::from_shape_vec(
        (4, 8),
        vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0,
            1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0,
        ],
    )
    .unwrap()
}

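/// Completes a (possibly corrupted) pattern by running ten rounds of Gibbs
/// sampling between the RBM's visible and hidden layers.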
fn complete_pattern(rbm: &QuantumRBM, partial: &Array1<f64>) -> Result<Array1<f64>> {
    let mut current = partial.clone();

    for _ in 0..10 {
        let hidden = rbm.qbm().sample_hidden_given_visible(&current.view())?;
        current = rbm.qbm().sample_visible_given_hidden(&hidden)?;
    }

    Ok(current)
}