quantum_boltzmann/
quantum_boltzmann.rs

1#![allow(
2    clippy::pedantic,
3    clippy::unnecessary_wraps,
4    clippy::needless_range_loop,
5    clippy::useless_vec,
6    clippy::needless_collect,
7    clippy::too_many_arguments
8)]
//! Quantum Boltzmann Machine Example
//!
//! This example demonstrates quantum Boltzmann machines for unsupervised learning,
//! including restricted Boltzmann machines (RBMs) and deep Boltzmann machines.
13
14use quantrs2_ml::prelude::*;
15use scirs2_core::ndarray::{s, Array1, Array2};
16use scirs2_core::random::prelude::*;
17
18fn main() -> Result<()> {
19    println!("=== Quantum Boltzmann Machine Demo ===\n");
20
21    // Step 1: Basic Boltzmann machine
22    println!("1. Basic Quantum Boltzmann Machine...");
23    basic_qbm_demo()?;
24
25    // Step 2: Restricted Boltzmann Machine
26    println!("\n2. Quantum Restricted Boltzmann Machine (RBM)...");
27    rbm_demo()?;
28
29    // Step 3: Deep Boltzmann Machine
30    println!("\n3. Deep Boltzmann Machine...");
31    deep_boltzmann_demo()?;
32
33    // Step 4: Energy landscape visualization
34    println!("\n4. Energy Landscape Analysis...");
35    energy_landscape_demo()?;
36
37    // Step 5: Pattern completion
38    println!("\n5. Pattern Completion Demo...");
39    pattern_completion_demo()?;
40
41    println!("\n=== Boltzmann Machine Demo Complete ===");
42
43    Ok(())
44}
45
46/// Basic Quantum Boltzmann Machine demonstration
47fn basic_qbm_demo() -> Result<()> {
48    // Create a small QBM
49    let mut qbm = QuantumBoltzmannMachine::new(
50        4,    // visible units
51        2,    // hidden units
52        1.0,  // temperature
53        0.01, // learning rate
54    )?;
55
56    println!("   Created QBM with 4 visible and 2 hidden units");
57
58    // Generate synthetic binary data
59    let data = generate_binary_patterns(100, 4);
60
61    // Train the QBM
62    println!("   Training on binary patterns...");
63    let losses = qbm.train(&data, 50, 10)?;
64
65    println!("   Training complete:");
66    println!("   - Initial loss: {:.4}", losses[0]);
67    println!("   - Final loss: {:.4}", losses.last().unwrap());
68
69    // Sample from trained model
70    let samples = qbm.sample(5)?;
71    println!("\n   Generated samples:");
72    for (i, sample) in samples.outer_iter().enumerate() {
73        print!("   Sample {}: [", i + 1);
74        for val in sample {
75            print!("{val:.0} ");
76        }
77        println!("]");
78    }
79
80    Ok(())
81}
82
83/// RBM demonstration with persistent contrastive divergence
84fn rbm_demo() -> Result<()> {
85    // Create RBM with annealing
86    let annealing = AnnealingSchedule::new(2.0, 0.5, 100);
87
88    let mut rbm = QuantumRBM::new(
89        6,    // visible units
90        3,    // hidden units
91        2.0,  // initial temperature
92        0.01, // learning rate
93    )?
94    .with_annealing(annealing);
95
96    println!("   Created Quantum RBM with annealing schedule");
97
98    // Generate correlated binary data
99    let data = generate_correlated_data(200, 6);
100
101    // Train with PCD
102    println!("   Training with Persistent Contrastive Divergence...");
103    let losses = rbm.train_pcd(
104        &data, 100, // epochs
105        20,  // batch size
106        50,  // persistent chains
107    )?;
108
109    // Analyze training
110    let improvement = (losses[0] - losses.last().unwrap()) / losses[0] * 100.0;
111    println!("   Training statistics:");
112    println!("   - Loss reduction: {improvement:.1}%");
113    println!("   - Final temperature: 0.5");
114
115    // Test reconstruction
116    let test_data = data.slice(s![0..5, ..]).to_owned();
117    let reconstructed = rbm.qbm().reconstruct(&test_data)?;
118
119    println!("\n   Reconstruction quality:");
120    for i in 0..3 {
121        print!("   Original:      [");
122        for val in test_data.row(i) {
123            print!("{val:.0} ");
124        }
125        print!("]  →  Reconstructed: [");
126        for val in reconstructed.row(i) {
127            print!("{val:.0} ");
128        }
129        println!("]");
130    }
131
132    Ok(())
133}
134
135/// Deep Boltzmann Machine demonstration
136fn deep_boltzmann_demo() -> Result<()> {
137    // Create a 3-layer DBM
138    let layer_sizes = vec![8, 4, 2];
139    let mut dbm = DeepBoltzmannMachine::new(
140        layer_sizes.clone(),
141        1.0,  // temperature
142        0.01, // learning rate
143    )?;
144
145    println!("   Created Deep Boltzmann Machine:");
146    println!("   - Architecture: {layer_sizes:?}");
147    println!("   - Total layers: {}", dbm.rbms().len());
148
149    // Generate hierarchical data
150    let data = generate_hierarchical_data(300, 8);
151
152    // Layer-wise pretraining
153    println!("\n   Performing layer-wise pretraining...");
154    dbm.pretrain(
155        &data, 50, // epochs per layer
156        30, // batch size
157    )?;
158
159    println!("\n   Pretraining complete!");
160    println!("   Each layer learned increasingly abstract features");
161
162    Ok(())
163}
164
165/// Energy landscape visualization
166fn energy_landscape_demo() -> Result<()> {
167    // Create small QBM for visualization
168    let qbm = QuantumBoltzmannMachine::new(
169        2,    // visible units (for 2D visualization)
170        1,    // hidden unit
171        0.5,  // temperature
172        0.01, // learning rate
173    )?;
174
175    println!("   Analyzing energy landscape of 2-unit system");
176
177    // Compute energy for all 4 possible states
178    let states = [
179        Array1::from_vec(vec![0.0, 0.0]),
180        Array1::from_vec(vec![0.0, 1.0]),
181        Array1::from_vec(vec![1.0, 0.0]),
182        Array1::from_vec(vec![1.0, 1.0]),
183    ];
184
185    println!("\n   State energies:");
186    for (i, state) in states.iter().enumerate() {
187        let energy = qbm.energy(state);
188        let prob = (-energy / qbm.temperature()).exp();
189        println!(
190            "   State [{:.0}, {:.0}]: E = {:.3}, P ∝ {:.3}",
191            state[0], state[1], energy, prob
192        );
193    }
194
195    // Show coupling matrix
196    println!("\n   Coupling matrix:");
197    for i in 0..3 {
198        print!("   [");
199        for j in 0..3 {
200            print!("{:6.3} ", qbm.couplings()[[i, j]]);
201        }
202        println!("]");
203    }
204
205    Ok(())
206}
207
208/// Pattern completion demonstration
209fn pattern_completion_demo() -> Result<()> {
210    // Create RBM
211    let mut rbm = QuantumRBM::new(
212        8,    // visible units
213        4,    // hidden units
214        1.0,  // temperature
215        0.02, // learning rate
216    )?;
217
218    // Train on specific patterns
219    let patterns = create_letter_patterns();
220    println!("   Training on letter-like patterns...");
221
222    rbm.train_pcd(&patterns, 100, 10, 20)?;
223
224    // Test pattern completion
225    println!("\n   Pattern completion test:");
226
227    // Create corrupted patterns
228    let mut corrupted = patterns.row(0).to_owned();
229    corrupted[3] = 1.0 - corrupted[3]; // Flip one bit
230    corrupted[5] = 1.0 - corrupted[5]; // Flip another
231
232    print!("   Corrupted:  [");
233    for val in &corrupted {
234        print!("{val:.0} ");
235    }
236    println!("]");
237
238    // Complete pattern
239    let completed = complete_pattern(&rbm, &corrupted)?;
240
241    print!("   Completed:  [");
242    for val in &completed {
243        print!("{val:.0} ");
244    }
245    println!("]");
246
247    print!("   Original:   [");
248    for val in patterns.row(0) {
249        print!("{val:.0} ");
250    }
251    println!("]");
252
253    let accuracy = patterns
254        .row(0)
255        .iter()
256        .zip(completed.iter())
257        .filter(|(&a, &b)| (a - b).abs() < 0.5)
258        .count() as f64
259        / 8.0;
260
261    println!("   Reconstruction accuracy: {:.1}%", accuracy * 100.0);
262
263    Ok(())
264}
265
266/// Generate binary patterns
267fn generate_binary_patterns(n_samples: usize, n_features: usize) -> Array2<f64> {
268    Array2::from_shape_fn((n_samples, n_features), |(_, _)| {
269        if thread_rng().gen::<f64>() > 0.5 {
270            1.0
271        } else {
272            0.0
273        }
274    })
275}
276
277/// Generate correlated binary data
278fn generate_correlated_data(n_samples: usize, n_features: usize) -> Array2<f64> {
279    let mut data = Array2::zeros((n_samples, n_features));
280
281    for i in 0..n_samples {
282        // Generate correlated features
283        let base = if thread_rng().gen::<f64>() > 0.5 {
284            1.0
285        } else {
286            0.0
287        };
288
289        for j in 0..n_features {
290            if j % 2 == 0 {
291                data[[i, j]] = base;
292            } else {
293                // Correlate with previous feature
294                data[[i, j]] = if thread_rng().gen::<f64>() > 0.2 {
295                    base
296                } else {
297                    1.0 - base
298                };
299            }
300        }
301    }
302
303    data
304}
305
306/// Generate hierarchical data
307fn generate_hierarchical_data(n_samples: usize, n_features: usize) -> Array2<f64> {
308    let mut data = Array2::zeros((n_samples, n_features));
309
310    for i in 0..n_samples {
311        // Choose high-level pattern
312        let pattern_type = i % 3;
313
314        match pattern_type {
315            0 => {
316                // Pattern A: alternating
317                for j in 0..n_features {
318                    data[[i, j]] = (j % 2) as f64;
319                }
320            }
321            1 => {
322                // Pattern B: blocks
323                for j in 0..n_features {
324                    data[[i, j]] = ((j / 2) % 2) as f64;
325                }
326            }
327            _ => {
328                // Pattern C: random with structure
329                let shift = (thread_rng().gen::<f64>() * 4.0) as usize;
330                for j in 0..n_features {
331                    data[[i, j]] = if (j + shift) % 3 == 0 { 1.0 } else { 0.0 };
332                }
333            }
334        }
335
336        // Add noise
337        for j in 0..n_features {
338            if thread_rng().gen::<f64>() < 0.1 {
339                data[[i, j]] = 1.0 - data[[i, j]];
340            }
341        }
342    }
343
344    data
345}
346
347/// Create letter-like patterns
348fn create_letter_patterns() -> Array2<f64> {
349    // Simple 8-bit patterns resembling letters
350    Array2::from_shape_vec(
351        (4, 8),
352        vec![
353            // Pattern 'L'
354            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, // Pattern 'T'
355            1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, // Pattern 'I'
356            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, // Pattern 'H'
357            1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0,
358        ],
359    )
360    .unwrap()
361}
362
363/// Complete a partial pattern
364fn complete_pattern(rbm: &QuantumRBM, partial: &Array1<f64>) -> Result<Array1<f64>> {
365    // Use Gibbs sampling to complete pattern
366    let mut current = partial.clone();
367
368    for _ in 0..10 {
369        let hidden = rbm.qbm().sample_hidden_given_visible(&current.view())?;
370        current = rbm.qbm().sample_visible_given_hidden(&hidden)?;
371    }
372
373    Ok(current)
374}