use quantrs2_ml::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::prelude::*;

fn main() -> Result<()> {
    println!("=== Quantum Boltzmann Machine Demo ===\n");

    println!("1. Basic Quantum Boltzmann Machine...");
    basic_qbm_demo()?;

    println!("\n2. Quantum Restricted Boltzmann Machine (RBM)...");
    rbm_demo()?;

    println!("\n3. Deep Boltzmann Machine...");
    deep_boltzmann_demo()?;

    println!("\n4. Energy Landscape Analysis...");
    energy_landscape_demo()?;

    println!("\n5. Pattern Completion Demo...");
    pattern_completion_demo()?;

    println!("\n=== Boltzmann Machine Demo Complete ===");

    Ok(())
}

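/// Trains a small QBM (4 visible, 2 hidden units) on random binary patterns,
/// reports the initial and final training loss, and draws a few samples from
/// the learned distribution.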
fn basic_qbm_demo() -> Result<()> {
    let mut qbm = QuantumBoltzmannMachine::new(
        4,    // visible units
        2,    // hidden units
        1.0,  // temperature (inferred)
        0.01, // learning rate (inferred)
    )?;

    println!(" Created QBM with 4 visible and 2 hidden units");

    let data = generate_binary_patterns(100, 4);

    println!(" Training on binary patterns...");
    let losses = qbm.train(&data, 50, 10)?;

    println!(" Training complete:");
    println!(" - Initial loss: {:.4}", losses[0]);
    println!(" - Final loss: {:.4}", losses.last().unwrap());

    let samples = qbm.sample(5)?;
    println!("\n Generated samples:");
    for (i, sample) in samples.outer_iter().enumerate() {
        print!(" Sample {}: [", i + 1);
        for val in sample {
            print!("{val:.0} ");
        }
        println!("]");
    }

    Ok(())
}

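/// Trains a quantum RBM (6 visible, 3 hidden units) with a temperature
/// annealing schedule and Persistent Contrastive Divergence on correlated
/// binary data, then checks how well it reconstructs a few training rows.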
fn rbm_demo() -> Result<()> {
    // Anneal the temperature from 2.0 down to 0.5 over 100 steps (argument
    // meaning inferred from the final-temperature printout below).
    let annealing = AnnealingSchedule::new(2.0, 0.5, 100);

    let mut rbm = QuantumRBM::new(
        6,    // visible units
        3,    // hidden units
        2.0,  // initial temperature (inferred)
        0.01, // learning rate (inferred)
    )?
    .with_annealing(annealing);

    println!(" Created Quantum RBM with annealing schedule");

    let data = generate_correlated_data(200, 6);

    println!(" Training with Persistent Contrastive Divergence...");
    let losses = rbm.train_pcd(&data, 100, 20, 50)?;

    let improvement = (losses[0] - losses.last().unwrap()) / losses[0] * 100.0;
    println!(" Training statistics:");
    println!(" - Loss reduction: {improvement:.1}%");
    println!(" - Final temperature: 0.5");

    let test_data = data.slice(s![0..5, ..]).to_owned();
    let reconstructed = rbm.qbm().reconstruct(&test_data)?;

    println!("\n Reconstruction quality:");
    for i in 0..3 {
        print!(" Original: [");
        for val in test_data.row(i) {
            print!("{val:.0} ");
        }
        print!("] → Reconstructed: [");
        for val in reconstructed.row(i) {
            print!("{val:.0} ");
        }
        println!("]");
    }

    Ok(())
}

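/// Builds a three-layer deep Boltzmann machine (8-4-2) and runs layer-wise
/// pretraining on hierarchically structured binary data.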
fn deep_boltzmann_demo() -> Result<()> {
    let layer_sizes = vec![8, 4, 2];
    let mut dbm = DeepBoltzmannMachine::new(
        layer_sizes.clone(),
        1.0,  // temperature (inferred)
        0.01, // learning rate (inferred)
    )?;

    println!(" Created Deep Boltzmann Machine:");
    println!(" - Architecture: {layer_sizes:?}");
    println!(" - Total layers: {}", dbm.rbms().len());

    let data = generate_hierarchical_data(300, 8);

    println!("\n Performing layer-wise pretraining...");
    dbm.pretrain(&data, 50, 30)?;

    println!("\n Pretraining complete!");
    println!(" Each layer learned increasingly abstract features");

    Ok(())
}

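/// Enumerates all four visible configurations of a tiny QBM (2 visible, 1 hidden
/// unit), printing each state's energy and unnormalized Boltzmann weight
/// exp(-E/T), followed by the model's coupling matrix.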
fn energy_landscape_demo() -> Result<()> {
    let qbm = QuantumBoltzmannMachine::new(
        2,    // visible units
        1,    // hidden unit
        0.5,  // temperature (inferred)
        0.01, // learning rate (inferred)
    )?;

    println!(" Analyzing energy landscape of 2-unit system");

    let states = [
        Array1::from_vec(vec![0.0, 0.0]),
        Array1::from_vec(vec![0.0, 1.0]),
        Array1::from_vec(vec![1.0, 0.0]),
        Array1::from_vec(vec![1.0, 1.0]),
    ];

    println!("\n State energies:");
    for state in &states {
        let energy = qbm.energy(state);
        // Unnormalized Boltzmann probability at the model temperature.
        let prob = (-energy / qbm.temperature()).exp();
        println!(
            " State [{:.0}, {:.0}]: E = {:.3}, P ∝ {:.3}",
            state[0], state[1], energy, prob
        );
    }

    // 3 x 3 couplings: 2 visible units + 1 hidden unit.
    println!("\n Coupling matrix:");
    for i in 0..3 {
        print!(" [");
        for j in 0..3 {
            print!("{:6.3} ", qbm.couplings()[[i, j]]);
        }
        println!("]");
    }

    Ok(())
}

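/// Trains an RBM on four 8-bit letter-like patterns, flips two bits of one
/// pattern, and restores it by Gibbs sampling, reporting per-bit accuracy of
/// the completed pattern.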
fn pattern_completion_demo() -> Result<()> {
    let mut rbm = QuantumRBM::new(
        8,    // visible units
        4,    // hidden units
        1.0,  // temperature (inferred)
        0.02, // learning rate (inferred)
    )?;

    let patterns = create_letter_patterns();
    println!(" Training on letter-like patterns...");

    rbm.train_pcd(&patterns, 100, 10, 20)?;

    println!("\n Pattern completion test:");

    // Corrupt the first pattern by flipping two of its bits.
    let mut corrupted = patterns.row(0).to_owned();
    corrupted[3] = 1.0 - corrupted[3];
    corrupted[5] = 1.0 - corrupted[5];

    print!(" Corrupted: [");
    for val in &corrupted {
        print!("{val:.0} ");
    }
    println!("]");

    let completed = complete_pattern(&rbm, &corrupted)?;

    print!(" Completed: [");
    for val in &completed {
        print!("{val:.0} ");
    }
    println!("]");

    print!(" Original: [");
    for val in patterns.row(0) {
        print!("{val:.0} ");
    }
    println!("]");

    // Fraction of the 8 bits recovered correctly.
    let accuracy = patterns
        .row(0)
        .iter()
        .zip(completed.iter())
        .filter(|(&a, &b)| (a - b).abs() < 0.5)
        .count() as f64
        / 8.0;

    println!(" Reconstruction accuracy: {:.1}%", accuracy * 100.0);

    Ok(())
}

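/// Generates `n_samples` random binary vectors of length `n_features`, with each
/// bit set to 1 with probability 0.5.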
fn generate_binary_patterns(n_samples: usize, n_features: usize) -> Array2<f64> {
    Array2::from_shape_fn((n_samples, n_features), |(_, _)| {
        if thread_rng().gen::<f64>() > 0.5 {
            1.0
        } else {
            0.0
        }
    })
}

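/// Generates binary samples with correlated features: even-indexed features copy
/// a per-sample random base bit, odd-indexed features agree with it 80% of the
/// time.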
fn generate_correlated_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        // Random base bit for this sample.
        let base = if thread_rng().gen::<f64>() > 0.5 {
            1.0
        } else {
            0.0
        };

        for j in 0..n_features {
            if j % 2 == 0 {
                // Even features copy the base bit exactly.
                data[[i, j]] = base;
            } else {
                // Odd features agree with the base bit 80% of the time.
                data[[i, j]] = if thread_rng().gen::<f64>() > 0.2 {
                    base
                } else {
                    1.0 - base
                };
            }
        }
    }

    data
}

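/// Generates three families of structured binary patterns (alternating bits,
/// bit pairs, and randomly shifted period-3 stripes), then flips each bit with
/// 10% probability as noise.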
fn generate_hierarchical_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        let pattern_type = i % 3;

        match pattern_type {
            0 => {
                // Alternating bits: 0, 1, 0, 1, ...
                for j in 0..n_features {
                    data[[i, j]] = (j % 2) as f64;
                }
            }
            1 => {
                // Bit pairs: 0, 0, 1, 1, 0, 0, ...
                for j in 0..n_features {
                    data[[i, j]] = ((j / 2) % 2) as f64;
                }
            }
            _ => {
                // Period-3 stripes with a random shift.
                let shift = (thread_rng().gen::<f64>() * 4.0) as usize;
                for j in 0..n_features {
                    data[[i, j]] = if (j + shift) % 3 == 0 { 1.0 } else { 0.0 };
                }
            }
        }

        // Add 10% bit-flip noise.
        for j in 0..n_features {
            if thread_rng().gen::<f64>() < 0.1 {
                data[[i, j]] = 1.0 - data[[i, j]];
            }
        }
    }

    data
}

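/// Returns four hand-coded 8-bit binary patterns used as the "letters" for the
/// pattern-completion demo.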
fn create_letter_patterns() -> Array2<f64> {
    Array2::from_shape_vec(
        (4, 8),
        vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0,
            1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0,
        ],
    )
    .unwrap()
}

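/// Completes a corrupted pattern by running 10 rounds of alternating Gibbs
/// sampling (visible → hidden → visible) through the trained RBM, starting from
/// the corrupted input.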
fn complete_pattern(rbm: &QuantumRBM, partial: &Array1<f64>) -> Result<Array1<f64>> {
    let mut current = partial.clone();

    // Alternate between sampling the hidden and visible layers.
    for _ in 0..10 {
        let hidden = rbm.qbm().sample_hidden_given_visible(&current.view())?;
        current = rbm.qbm().sample_visible_given_hidden(&hidden)?;
    }

    Ok(current)
}