use ndarray::{s, Array1, Array2};
use quantrs2_ml::prelude::*;

fn main() -> Result<()> {
    println!("=== Quantum Boltzmann Machine Demo ===\n");

    println!("1. Basic Quantum Boltzmann Machine...");
    basic_qbm_demo()?;

    println!("\n2. Quantum Restricted Boltzmann Machine (RBM)...");
    rbm_demo()?;

    println!("\n3. Deep Boltzmann Machine...");
    deep_boltzmann_demo()?;

    println!("\n4. Energy Landscape Analysis...");
    energy_landscape_demo()?;

    println!("\n5. Pattern Completion Demo...");
    pattern_completion_demo()?;

    println!("\n=== Boltzmann Machine Demo Complete ===");

    Ok(())
}

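/// Trains a small Quantum Boltzmann Machine on random binary patterns and
/// draws a few samples from the learned distribution. The constructor call
/// below is read as (visible units, hidden units, temperature, learning rate);
/// only the unit counts are confirmed by the demo's own output, the last two
/// are assumptions.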
fn basic_qbm_demo() -> Result<()> {
    // 4 visible units, 2 hidden units; the remaining arguments are assumed
    // to be temperature and learning rate.
    let mut qbm = QuantumBoltzmannMachine::new(4, 2, 1.0, 0.01)?;

    println!(" Created QBM with 4 visible and 2 hidden units");

    let data = generate_binary_patterns(100, 4);

    println!(" Training on binary patterns...");
    let losses = qbm.train(&data, 50, 10)?;

    println!(" Training complete:");
    println!(" - Initial loss: {:.4}", losses[0]);
    println!(" - Final loss: {:.4}", losses.last().unwrap());

    let samples = qbm.sample(5)?;
    println!("\n Generated samples:");
    for (i, sample) in samples.outer_iter().enumerate() {
        print!(" Sample {}: [", i + 1);
        for val in sample.iter() {
            print!("{:.0} ", val);
        }
        println!("]");
    }

    Ok(())
}

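/// Trains a Quantum RBM under an annealing schedule (temperature 2.0 down to
/// 0.5) using Persistent Contrastive Divergence, then checks how well the
/// model reconstructs the first few training rows.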
fn rbm_demo() -> Result<()> {
    // Anneal the temperature from 2.0 down to 0.5 over 100 steps.
    let annealing = AnnealingSchedule::new(2.0, 0.5, 100);

    // 6 visible units, 3 hidden units; the remaining arguments are assumed
    // to be the initial temperature and learning rate.
    let mut rbm = QuantumRBM::new(6, 3, 2.0, 0.01)?.with_annealing(annealing);

    println!(" Created Quantum RBM with annealing schedule");

    let data = generate_correlated_data(200, 6);

    println!(" Training with Persistent Contrastive Divergence...");
    let losses = rbm.train_pcd(&data, 100, 20, 50)?;

    let improvement = (losses[0] - losses.last().unwrap()) / losses[0] * 100.0;
    println!(" Training statistics:");
    println!(" - Loss reduction: {:.1}%", improvement);
    println!(" - Final temperature: 0.5");

    let test_data = data.slice(s![0..5, ..]).to_owned();
    let reconstructed = rbm.qbm().reconstruct(&test_data)?;

    println!("\n Reconstruction quality:");
    for i in 0..3 {
        print!(" Original: [");
        for val in test_data.row(i).iter() {
            print!("{:.0} ", val);
        }
        print!("] → Reconstructed: [");
        for val in reconstructed.row(i).iter() {
            print!("{:.0} ", val);
        }
        println!("]");
    }

    Ok(())
}

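/// Builds a Deep Boltzmann Machine with an 8-4-2 architecture and runs greedy
/// layer-wise pretraining, in which each stacked RBM is trained on the
/// representation produced by the layer below it.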
fn deep_boltzmann_demo() -> Result<()> {
    // Three layers: 8 visible units, then hidden layers of 4 and 2 units.
    let layer_sizes = vec![8, 4, 2];
    // The remaining arguments are assumed to be temperature and learning rate.
    let mut dbm = DeepBoltzmannMachine::new(layer_sizes.clone(), 1.0, 0.01)?;

    println!(" Created Deep Boltzmann Machine:");
    println!(" - Architecture: {:?}", layer_sizes);
    println!(" - Stacked RBMs: {}", dbm.rbms().len());

    let data = generate_hierarchical_data(300, 8);

    println!("\n Performing layer-wise pretraining...");
    dbm.pretrain(&data, 50, 30)?;

    println!("\n Pretraining complete!");
    println!(" Each layer learned increasingly abstract features");

    Ok(())
}

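/// Enumerates all four visible configurations of a tiny machine and prints
/// their energies and unnormalised Boltzmann weights P ∝ exp(-E / T),
/// followed by the learned coupling matrix over all three units
/// (2 visible + 1 hidden).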
fn energy_landscape_demo() -> Result<()> {
    // 2 visible units, 1 hidden unit; the remaining arguments are assumed
    // to be temperature and learning rate.
    let qbm = QuantumBoltzmannMachine::new(2, 1, 0.5, 0.01)?;

    println!(" Analyzing energy landscape of 2-unit system");

    // All four configurations of the two visible units.
    let states = vec![
        Array1::from_vec(vec![0.0, 0.0]),
        Array1::from_vec(vec![0.0, 1.0]),
        Array1::from_vec(vec![1.0, 0.0]),
        Array1::from_vec(vec![1.0, 1.0]),
    ];

    println!("\n State energies:");
    for state in &states {
        let energy = qbm.energy(state);
        // Unnormalised Boltzmann weight: P ∝ exp(-E / T).
        let prob = (-energy / qbm.temperature()).exp();
        println!(
            " State [{:.0}, {:.0}]: E = {:.3}, P ∝ {:.3}",
            state[0], state[1], energy, prob
        );
    }

    println!("\n Coupling matrix:");
    // 3 × 3 over all units (2 visible + 1 hidden).
    for i in 0..3 {
        print!(" [");
        for j in 0..3 {
            print!("{:6.3} ", qbm.couplings()[[i, j]]);
        }
        println!("]");
    }

    Ok(())
}

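/// Trains an RBM on four letter-like 8-bit patterns, flips two bits of the
/// first pattern, and asks the model to complete the corrupted input via
/// Gibbs sampling (see `complete_pattern` at the end of this file).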
fn pattern_completion_demo() -> Result<()> {
    // 8 visible units, 4 hidden units; the remaining arguments are assumed
    // to be temperature and learning rate.
    let mut rbm = QuantumRBM::new(8, 4, 1.0, 0.02)?;

    let patterns = create_letter_patterns();
    println!(" Training on letter-like patterns...");

    rbm.train_pcd(&patterns, 100, 10, 20)?;

    println!("\n Pattern completion test:");

    // Corrupt the first pattern by flipping two of its bits.
    let mut corrupted = patterns.row(0).to_owned();
    corrupted[3] = 1.0 - corrupted[3];
    corrupted[5] = 1.0 - corrupted[5];

    print!(" Corrupted: [");
    for val in corrupted.iter() {
        print!("{:.0} ", val);
    }
    println!("]");

    let completed = complete_pattern(&rbm, &corrupted)?;

    print!(" Completed: [");
    for val in completed.iter() {
        print!("{:.0} ", val);
    }
    println!("]");

    print!(" Original: [");
    for val in patterns.row(0).iter() {
        print!("{:.0} ", val);
    }
    println!("]");

    // Fraction of the 8 bits that match the original pattern.
    let accuracy = patterns
        .row(0)
        .iter()
        .zip(completed.iter())
        .filter(|(&a, &b)| (a - b).abs() < 0.5)
        .count() as f64
        / 8.0;

    println!(" Reconstruction accuracy: {:.1}%", accuracy * 100.0);

    Ok(())
}

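/// Generates `n_samples` rows of `n_features` independent uniform random bits.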
fn generate_binary_patterns(n_samples: usize, n_features: usize) -> Array2<f64> {
    Array2::from_shape_fn((n_samples, n_features), |(_, _)| {
        if rand::random::<f64>() > 0.5 {
            1.0
        } else {
            0.0
        }
    })
}

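/// Generates rows in which even-indexed features copy a random base bit
/// exactly and odd-indexed features copy it with 80% probability, giving the
/// RBM strong pairwise correlations to learn.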
fn generate_correlated_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        // Pick a random base bit for this row.
        let base = if rand::random::<f64>() > 0.5 { 1.0 } else { 0.0 };

        for j in 0..n_features {
            if j % 2 == 0 {
                // Even-indexed features copy the base bit exactly.
                data[[i, j]] = base;
            } else {
                // Odd-indexed features copy it with 80% probability.
                data[[i, j]] = if rand::random::<f64>() > 0.2 {
                    base
                } else {
                    1.0 - base
                };
            }
        }
    }

    data
}

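/// Generates three interleaved pattern families (alternating bits, paired
/// bits, and a randomly shifted every-third-bit pattern), then flips each bit
/// with 10% probability as noise.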
fn generate_hierarchical_data(n_samples: usize, n_features: usize) -> Array2<f64> {
    let mut data = Array2::zeros((n_samples, n_features));

    for i in 0..n_samples {
        let pattern_type = i % 3;

        match pattern_type {
            0 => {
                // Alternating bits: 0, 1, 0, 1, ...
                for j in 0..n_features {
                    data[[i, j]] = (j % 2) as f64;
                }
            }
            1 => {
                // Paired bits: 0, 0, 1, 1, ...
                for j in 0..n_features {
                    data[[i, j]] = ((j / 2) % 2) as f64;
                }
            }
            _ => {
                // Every third bit set, with a random shift.
                let shift = (rand::random::<f64>() * 4.0) as usize;
                for j in 0..n_features {
                    data[[i, j]] = if (j + shift) % 3 == 0 { 1.0 } else { 0.0 };
                }
            }
        }

        // Flip each bit with 10% probability as noise.
        for j in 0..n_features {
            if rand::random::<f64>() < 0.1 {
                data[[i, j]] = 1.0 - data[[i, j]];
            }
        }
    }

    data
}

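/// Returns four hand-crafted 8-bit binary patterns used as toy "letters" in
/// the pattern completion demo.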
fn create_letter_patterns() -> Array2<f64> {
    // Four 8-bit patterns, one per row.
    Array2::from_shape_vec(
        (4, 8),
        vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, // pattern 1
            1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, // pattern 2
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, // pattern 3
            1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, // pattern 4
        ],
    )
    .unwrap()
}

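/// Completes a (possibly corrupted) visible pattern by alternating hidden and
/// visible sampling for ten Gibbs steps.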
fn complete_pattern(rbm: &QuantumRBM, partial: &Array1<f64>) -> Result<Array1<f64>> {
    let mut current = partial.clone();

    // Alternate Gibbs sampling between hidden and visible units.
    for _ in 0..10 {
        let hidden = rbm.qbm().sample_hidden_given_visible(&current.view())?;
        current = rbm.qbm().sample_visible_given_hidden(&hidden)?;
    }

    Ok(current)
}