// quantum_boltzmann/quantum_boltzmann.rs

//! Quantum Boltzmann Machine Example
//!
//! This example demonstrates quantum Boltzmann machines for unsupervised learning,
//! including RBMs and deep Boltzmann machines.

use ndarray::{s, Array1, Array2};
use quantrs2_ml::prelude::*;

9fn main() -> Result<()> {
10    println!("=== Quantum Boltzmann Machine Demo ===\n");
11
12    // Step 1: Basic Boltzmann machine
13    println!("1. Basic Quantum Boltzmann Machine...");
14    basic_qbm_demo()?;
15
16    // Step 2: Restricted Boltzmann Machine
17    println!("\n2. Quantum Restricted Boltzmann Machine (RBM)...");
18    rbm_demo()?;
19
20    // Step 3: Deep Boltzmann Machine
21    println!("\n3. Deep Boltzmann Machine...");
22    deep_boltzmann_demo()?;
23
24    // Step 4: Energy landscape visualization
25    println!("\n4. Energy Landscape Analysis...");
26    energy_landscape_demo()?;
27
28    // Step 5: Pattern completion
29    println!("\n5. Pattern Completion Demo...");
30    pattern_completion_demo()?;
31
32    println!("\n=== Boltzmann Machine Demo Complete ===");
33
34    Ok(())
35}
36
37/// Basic Quantum Boltzmann Machine demonstration
38fn basic_qbm_demo() -> Result<()> {
39    // Create a small QBM
40    let mut qbm = QuantumBoltzmannMachine::new(
41        4,    // visible units
42        2,    // hidden units
43        1.0,  // temperature
44        0.01, // learning rate
45    )?;
46
47    println!("   Created QBM with 4 visible and 2 hidden units");
48
49    // Generate synthetic binary data
50    let data = generate_binary_patterns(100, 4);
51
52    // Train the QBM
53    println!("   Training on binary patterns...");
54    let losses = qbm.train(&data, 50, 10)?;
55
56    println!("   Training complete:");
57    println!("   - Initial loss: {:.4}", losses[0]);
58    println!("   - Final loss: {:.4}", losses.last().unwrap());
59
60    // Sample from trained model
61    let samples = qbm.sample(5)?;
62    println!("\n   Generated samples:");
63    for (i, sample) in samples.outer_iter().enumerate() {
64        print!("   Sample {}: [", i + 1);
65        for val in sample.iter() {
66            print!("{:.0} ", val);
67        }
68        println!("]");
69    }
70
71    Ok(())
72}
73
74/// RBM demonstration with persistent contrastive divergence
75fn rbm_demo() -> Result<()> {
76    // Create RBM with annealing
77    let annealing = AnnealingSchedule::new(2.0, 0.5, 100);
78
79    let mut rbm = QuantumRBM::new(
80        6,    // visible units
81        3,    // hidden units
82        2.0,  // initial temperature
83        0.01, // learning rate
84    )?
85    .with_annealing(annealing);
86
87    println!("   Created Quantum RBM with annealing schedule");
88
89    // Generate correlated binary data
90    let data = generate_correlated_data(200, 6);
91
92    // Train with PCD
93    println!("   Training with Persistent Contrastive Divergence...");
94    let losses = rbm.train_pcd(
95        &data, 100, // epochs
96        20,  // batch size
97        50,  // persistent chains
98    )?;
99
100    // Analyze training
101    let improvement = (losses[0] - losses.last().unwrap()) / losses[0] * 100.0;
102    println!("   Training statistics:");
103    println!("   - Loss reduction: {:.1}%", improvement);
104    println!("   - Final temperature: 0.5");
105
106    // Test reconstruction
107    let test_data = data.slice(s![0..5, ..]).to_owned();
108    let reconstructed = rbm.qbm().reconstruct(&test_data)?;
109
110    println!("\n   Reconstruction quality:");
111    for i in 0..3 {
112        print!("   Original:      [");
113        for val in test_data.row(i).iter() {
114            print!("{:.0} ", val);
115        }
116        print!("]  →  Reconstructed: [");
117        for val in reconstructed.row(i).iter() {
118            print!("{:.0} ", val);
119        }
120        println!("]");
121    }
122
123    Ok(())
124}
125
126/// Deep Boltzmann Machine demonstration
127fn deep_boltzmann_demo() -> Result<()> {
128    // Create a 3-layer DBM
129    let layer_sizes = vec![8, 4, 2];
130    let mut dbm = DeepBoltzmannMachine::new(
131        layer_sizes.clone(),
132        1.0,  // temperature
133        0.01, // learning rate
134    )?;
135
136    println!("   Created Deep Boltzmann Machine:");
137    println!("   - Architecture: {:?}", layer_sizes);
138    println!("   - Total layers: {}", dbm.rbms().len());
139
140    // Generate hierarchical data
141    let data = generate_hierarchical_data(300, 8);
142
143    // Layer-wise pretraining
144    println!("\n   Performing layer-wise pretraining...");
145    dbm.pretrain(
146        &data, 50, // epochs per layer
147        30, // batch size
148    )?;
149
150    println!("\n   Pretraining complete!");
151    println!("   Each layer learned increasingly abstract features");
152
153    Ok(())
154}
155
156/// Energy landscape visualization
157fn energy_landscape_demo() -> Result<()> {
158    // Create small QBM for visualization
159    let qbm = QuantumBoltzmannMachine::new(
160        2,    // visible units (for 2D visualization)
161        1,    // hidden unit
162        0.5,  // temperature
163        0.01, // learning rate
164    )?;
165
166    println!("   Analyzing energy landscape of 2-unit system");
167
168    // Compute energy for all 4 possible states
169    let states = vec![
170        Array1::from_vec(vec![0.0, 0.0]),
171        Array1::from_vec(vec![0.0, 1.0]),
172        Array1::from_vec(vec![1.0, 0.0]),
173        Array1::from_vec(vec![1.0, 1.0]),
174    ];
175
176    println!("\n   State energies:");
177    for (i, state) in states.iter().enumerate() {
178        let energy = qbm.energy(state);
179        let prob = (-energy / qbm.temperature()).exp();
180        println!(
181            "   State [{:.0}, {:.0}]: E = {:.3}, P ∝ {:.3}",
182            state[0], state[1], energy, prob
183        );
184    }
185
186    // Show coupling matrix
187    println!("\n   Coupling matrix:");
188    for i in 0..3 {
189        print!("   [");
190        for j in 0..3 {
191            print!("{:6.3} ", qbm.couplings()[[i, j]]);
192        }
193        println!("]");
194    }
195
196    Ok(())
197}
198
199/// Pattern completion demonstration
200fn pattern_completion_demo() -> Result<()> {
201    // Create RBM
202    let mut rbm = QuantumRBM::new(
203        8,    // visible units
204        4,    // hidden units
205        1.0,  // temperature
206        0.02, // learning rate
207    )?;
208
209    // Train on specific patterns
210    let patterns = create_letter_patterns();
211    println!("   Training on letter-like patterns...");
212
213    rbm.train_pcd(&patterns, 100, 10, 20)?;
214
215    // Test pattern completion
216    println!("\n   Pattern completion test:");
217
218    // Create corrupted patterns
219    let mut corrupted = patterns.row(0).to_owned();
220    corrupted[3] = 1.0 - corrupted[3]; // Flip one bit
221    corrupted[5] = 1.0 - corrupted[5]; // Flip another
222
223    print!("   Corrupted:  [");
224    for val in corrupted.iter() {
225        print!("{:.0} ", val);
226    }
227    println!("]");
228
229    // Complete pattern
230    let completed = complete_pattern(&rbm, &corrupted)?;
231
232    print!("   Completed:  [");
233    for val in completed.iter() {
234        print!("{:.0} ", val);
235    }
236    println!("]");
237
238    print!("   Original:   [");
239    for val in patterns.row(0).iter() {
240        print!("{:.0} ", val);
241    }
242    println!("]");
243
244    let accuracy = patterns
245        .row(0)
246        .iter()
247        .zip(completed.iter())
248        .filter(|(&a, &b)| (a - b).abs() < 0.5)
249        .count() as f64
250        / 8.0;
251
252    println!("   Reconstruction accuracy: {:.1}%", accuracy * 100.0);
253
254    Ok(())
255}
256
257/// Generate binary patterns
258fn generate_binary_patterns(n_samples: usize, n_features: usize) -> Array2<f64> {
259    Array2::from_shape_fn((n_samples, n_features), |(_, _)| {
260        if rand::random::<f64>() > 0.5 {
261            1.0
262        } else {
263            0.0
264        }
265    })
266}
267
268/// Generate correlated binary data
269fn generate_correlated_data(n_samples: usize, n_features: usize) -> Array2<f64> {
270    let mut data = Array2::zeros((n_samples, n_features));
271
272    for i in 0..n_samples {
273        // Generate correlated features
274        let base = if rand::random::<f64>() > 0.5 {
275            1.0
276        } else {
277            0.0
278        };
279
280        for j in 0..n_features {
281            if j % 2 == 0 {
282                data[[i, j]] = base;
283            } else {
284                // Correlate with previous feature
285                data[[i, j]] = if rand::random::<f64>() > 0.2 {
286                    base
287                } else {
288                    1.0 - base
289                };
290            }
291        }
292    }
293
294    data
295}
296
297/// Generate hierarchical data
298fn generate_hierarchical_data(n_samples: usize, n_features: usize) -> Array2<f64> {
299    let mut data = Array2::zeros((n_samples, n_features));
300
301    for i in 0..n_samples {
302        // Choose high-level pattern
303        let pattern_type = i % 3;
304
305        match pattern_type {
306            0 => {
307                // Pattern A: alternating
308                for j in 0..n_features {
309                    data[[i, j]] = (j % 2) as f64;
310                }
311            }
312            1 => {
313                // Pattern B: blocks
314                for j in 0..n_features {
315                    data[[i, j]] = ((j / 2) % 2) as f64;
316                }
317            }
318            _ => {
319                // Pattern C: random with structure
320                let shift = (rand::random::<f64>() * 4.0) as usize;
321                for j in 0..n_features {
322                    data[[i, j]] = if (j + shift) % 3 == 0 { 1.0 } else { 0.0 };
323                }
324            }
325        }
326
327        // Add noise
328        for j in 0..n_features {
329            if rand::random::<f64>() < 0.1 {
330                data[[i, j]] = 1.0 - data[[i, j]];
331            }
332        }
333    }
334
335    data
336}
337
338/// Create letter-like patterns
339fn create_letter_patterns() -> Array2<f64> {
340    // Simple 8-bit patterns resembling letters
341    Array2::from_shape_vec(
342        (4, 8),
343        vec![
344            // Pattern 'L'
345            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, // Pattern 'T'
346            1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, // Pattern 'I'
347            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, // Pattern 'H'
348            1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0,
349        ],
350    )
351    .unwrap()
352}
353
354/// Complete a partial pattern
355fn complete_pattern(rbm: &QuantumRBM, partial: &Array1<f64>) -> Result<Array1<f64>> {
356    // Use Gibbs sampling to complete pattern
357    let mut current = partial.clone();
358
359    for _ in 0..10 {
360        let hidden = rbm.qbm().sample_hidden_given_visible(&current.view())?;
361        current = rbm.qbm().sample_visible_given_hidden(&hidden)?;
362    }
363
364    Ok(current)
365}