#![allow(
    clippy::pedantic,
    clippy::unnecessary_wraps,
    clippy::needless_range_loop,
    clippy::useless_vec,
    clippy::needless_collect,
    clippy::too_many_arguments
)]
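//! Quantum Boltzmann machine examples for `quantrs2_ml`: a basic QBM, a
//! quantum RBM trained with persistent contrastive divergence under an
//! annealing schedule, a deep Boltzmann machine with layer-wise pretraining,
//! an energy-landscape analysis, and a pattern-completion demo.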
use quantrs2_ml::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::prelude::*;

fn main() -> Result<()> {
    println!("=== Quantum Boltzmann Machine Demo ===\n");

    println!("1. Basic Quantum Boltzmann Machine...");
    basic_qbm_demo()?;

    println!("\n2. Quantum Restricted Boltzmann Machine (RBM)...");
    rbm_demo()?;

    println!("\n3. Deep Boltzmann Machine...");
    deep_boltzmann_demo()?;

    println!("\n4. Energy Landscape Analysis...");
    energy_landscape_demo()?;

    println!("\n5. Pattern Completion Demo...");
    pattern_completion_demo()?;

    println!("\n=== Boltzmann Machine Demo Complete ===");

    Ok(())
}

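/// Builds a QBM with 4 visible and 2 hidden units, trains it on random binary
/// patterns, reports the initial and final loss, and prints a few samples
/// drawn from the trained model.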
fn basic_qbm_demo() -> Result<()> {
    let mut qbm = QuantumBoltzmannMachine::new(4, 2, 1.0, 0.01)?;

    println!(" Created QBM with 4 visible and 2 hidden units");

    let data = generate_binary_patterns(100, 4);

    println!(" Training on binary patterns...");
    let losses = qbm.train(&data, 50, 10)?;

    println!(" Training complete:");
    println!(" - Initial loss: {:.4}", losses[0]);
    println!(" - Final loss: {:.4}", losses.last().unwrap());

    let samples = qbm.sample(5)?;
    println!("\n Generated samples:");
    for (i, sample) in samples.outer_iter().enumerate() {
        print!(" Sample {}: [", i + 1);
        for val in sample {
            print!("{val:.0} ");
        }
        println!("]");
    }

    Ok(())
}

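/// Trains a quantum RBM on 6-feature correlated data using persistent
/// contrastive divergence with an annealing schedule (final temperature 0.5),
/// then reconstructs a few test rows to gauge quality.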
fn rbm_demo() -> Result<()> {
    let annealing = AnnealingSchedule::new(2.0, 0.5, 100);

    let mut rbm = QuantumRBM::new(6, 3, 2.0, 0.01)?.with_annealing(annealing);

    println!(" Created Quantum RBM with annealing schedule");

    let data = generate_correlated_data(200, 6);

    println!(" Training with Persistent Contrastive Divergence...");
    let losses = rbm.train_pcd(&data, 100, 20, 50)?;

    let improvement = (losses[0] - losses.last().unwrap()) / losses[0] * 100.0;
    println!(" Training statistics:");
    println!(" - Loss reduction: {improvement:.1}%");
    println!(" - Final temperature: 0.5");

    let test_data = data.slice(s![0..5, ..]).to_owned();
    let reconstructed = rbm.qbm().reconstruct(&test_data)?;

    println!("\n Reconstruction quality:");
    for i in 0..3 {
        print!(" Original: [");
        for val in test_data.row(i) {
            print!("{val:.0} ");
        }
        print!("] → Reconstructed: [");
        for val in reconstructed.row(i) {
            print!("{val:.0} ");
        }
        println!("]");
    }

    Ok(())
}

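/// Builds a deep Boltzmann machine with an 8-4-2 architecture and performs
/// layer-wise pretraining on hierarchically structured binary data.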
fn deep_boltzmann_demo() -> Result<()> {
    let layer_sizes = vec![8, 4, 2];
    let mut dbm = DeepBoltzmannMachine::new(layer_sizes.clone(), 1.0, 0.01)?;

    println!(" Created Deep Boltzmann Machine:");
    println!(" - Architecture: {layer_sizes:?}");
    println!(" - Total layers: {}", dbm.rbms().len());

    let data = generate_hierarchical_data(300, 8);

    println!("\n Performing layer-wise pretraining...");
    dbm.pretrain(&data, 50, 30)?;

    println!("\n Pretraining complete!");
    println!(" Each layer learned increasingly abstract features");

    Ok(())
}

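/// Evaluates the energy and unnormalized Boltzmann weight exp(-E/T) of all
/// four states of a small 2-unit QBM and prints its coupling matrix.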
fn energy_landscape_demo() -> Result<()> {
    let qbm = QuantumBoltzmannMachine::new(2, 1, 0.5, 0.01)?;

    println!(" Analyzing energy landscape of 2-unit system");

    let states = [
        Array1::from_vec(vec![0.0, 0.0]),
        Array1::from_vec(vec![0.0, 1.0]),
        Array1::from_vec(vec![1.0, 0.0]),
        Array1::from_vec(vec![1.0, 1.0]),
    ];

    println!("\n State energies:");
    for state in &states {
        let energy = qbm.energy(state);
        let prob = (-energy / qbm.temperature()).exp();
        println!(
            " State [{:.0}, {:.0}]: E = {:.3}, P ∝ {:.3}",
            state[0], state[1], energy, prob
        );
    }

    println!("\n Coupling matrix:");
    for i in 0..3 {
        print!(" [");
        for j in 0..3 {
            print!("{:6.3} ", qbm.couplings()[[i, j]]);
        }
        println!("]");
    }

    Ok(())
}

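/// Trains a quantum RBM on letter-like 8-bit patterns, flips two bits of one
/// pattern, and uses alternating Gibbs-style resampling (`complete_pattern`)
/// to restore it, reporting the reconstruction accuracy.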
fn pattern_completion_demo() -> Result<()> {
    let mut rbm = QuantumRBM::new(8, 4, 1.0, 0.02)?;

    let patterns = create_letter_patterns();
    println!(" Training on letter-like patterns...");

    rbm.train_pcd(&patterns, 100, 10, 20)?;

    println!("\n Pattern completion test:");

    let mut corrupted = patterns.row(0).to_owned();
    corrupted[3] = 1.0 - corrupted[3];
    corrupted[5] = 1.0 - corrupted[5];

    print!(" Corrupted: [");
    for val in &corrupted {
        print!("{val:.0} ");
    }
    println!("]");

    let completed = complete_pattern(&rbm, &corrupted)?;

    print!(" Completed: [");
    for val in &completed {
        print!("{val:.0} ");
    }
    println!("]");

    print!(" Original: [");
    for val in patterns.row(0) {
        print!("{val:.0} ");
    }
    println!("]");

    let accuracy = patterns
        .row(0)
        .iter()
        .zip(completed.iter())
        .filter(|(&a, &b)| (a - b).abs() < 0.5)
        .count() as f64
        / 8.0;

    println!(" Reconstruction accuracy: {:.1}%", accuracy * 100.0);

    Ok(())
}

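/// Generates `n_samples` rows of independent uniform random bits (0.0 or 1.0).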
fn generate_binary_patterns(n_samples: usize, n_features: usize) -> Array2<f64> {
    Array2::from_shape_fn((n_samples, n_features), |(_, _)| {
        if thread_rng().gen::<f64>() > 0.5 {
            1.0
        } else {
            0.0
        }
    })
}

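/// Generates binary data where even-indexed features copy a random base bit
/// and odd-indexed features agree with that bit 80% of the time.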
fn generate_correlated_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        let base = if thread_rng().gen::<f64>() > 0.5 {
            1.0
        } else {
            0.0
        };

        for j in 0..n_features {
            if j % 2 == 0 {
                data[[i, j]] = base;
            } else {
                data[[i, j]] = if thread_rng().gen::<f64>() > 0.2 {
                    base
                } else {
                    1.0 - base
                };
            }
        }
    }

    data
}

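/// Generates three families of structured binary patterns (alternating bits,
/// paired bits, and randomly shifted period-3 stripes) with 10% bit-flip noise.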
fn generate_hierarchical_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        let pattern_type = i % 3;

        match pattern_type {
            0 => {
                for j in 0..n_features {
                    data[[i, j]] = (j % 2) as f64;
                }
            }
            1 => {
                for j in 0..n_features {
                    data[[i, j]] = ((j / 2) % 2) as f64;
                }
            }
            _ => {
                let shift = (thread_rng().gen::<f64>() * 4.0) as usize;
                for j in 0..n_features {
                    data[[i, j]] = if (j + shift) % 3 == 0 { 1.0 } else { 0.0 };
                }
            }
        }

        for j in 0..n_features {
            if thread_rng().gen::<f64>() < 0.1 {
                data[[i, j]] = 1.0 - data[[i, j]];
            }
        }
    }

    data
}

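/// Returns four fixed 8-bit "letter-like" training patterns as a 4x8 array.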
fn create_letter_patterns() -> Array2<f64> {
    Array2::from_shape_vec(
        (4, 8),
        vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0,
            1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0,
        ],
    )
    .unwrap()
}

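/// Completes a (possibly corrupted) visible pattern by running ten rounds of
/// alternating hidden/visible sampling through the RBM.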
fn complete_pattern(rbm: &QuantumRBM, partial: &Array1<f64>) -> Result<Array1<f64>> {
    let mut current = partial.clone();

    for _ in 0..10 {
        let hidden = rbm.qbm().sample_hidden_given_visible(&current.view())?;
        current = rbm.qbm().sample_visible_given_hidden(&hidden)?;
    }

    Ok(current)
}