quantum_boltzmann/
quantum_boltzmann.rs

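//! Quantum Boltzmann Machine Example
//!
//! This example demonstrates quantum Boltzmann machines for unsupervised learning,
//! including restricted Boltzmann machines (RBMs) and deep Boltzmann machines,
//! followed by an energy-landscape analysis and a pattern-completion demo.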
5
6use quantrs2_ml::prelude::*;
7use scirs2_core::ndarray::{s, Array1, Array2};
8use scirs2_core::random::prelude::*;
9
10fn main() -> Result<()> {
11    println!("=== Quantum Boltzmann Machine Demo ===\n");
12
13    // Step 1: Basic Boltzmann machine
14    println!("1. Basic Quantum Boltzmann Machine...");
15    basic_qbm_demo()?;
16
17    // Step 2: Restricted Boltzmann Machine
18    println!("\n2. Quantum Restricted Boltzmann Machine (RBM)...");
19    rbm_demo()?;
20
21    // Step 3: Deep Boltzmann Machine
22    println!("\n3. Deep Boltzmann Machine...");
23    deep_boltzmann_demo()?;
24
25    // Step 4: Energy landscape visualization
26    println!("\n4. Energy Landscape Analysis...");
27    energy_landscape_demo()?;
28
29    // Step 5: Pattern completion
30    println!("\n5. Pattern Completion Demo...");
31    pattern_completion_demo()?;
32
33    println!("\n=== Boltzmann Machine Demo Complete ===");
34
35    Ok(())
36}
37
38/// Basic Quantum Boltzmann Machine demonstration
39fn basic_qbm_demo() -> Result<()> {
40    // Create a small QBM
41    let mut qbm = QuantumBoltzmannMachine::new(
42        4,    // visible units
43        2,    // hidden units
44        1.0,  // temperature
45        0.01, // learning rate
46    )?;
47
48    println!("   Created QBM with 4 visible and 2 hidden units");
49
50    // Generate synthetic binary data
51    let data = generate_binary_patterns(100, 4);
52
53    // Train the QBM
54    println!("   Training on binary patterns...");
55    let losses = qbm.train(&data, 50, 10)?;
56
57    println!("   Training complete:");
58    println!("   - Initial loss: {:.4}", losses[0]);
59    println!("   - Final loss: {:.4}", losses.last().unwrap());
60
61    // Sample from trained model
62    let samples = qbm.sample(5)?;
63    println!("\n   Generated samples:");
64    for (i, sample) in samples.outer_iter().enumerate() {
65        print!("   Sample {}: [", i + 1);
66        for val in sample {
67            print!("{val:.0} ");
68        }
69        println!("]");
70    }
71
72    Ok(())
73}
74
75/// RBM demonstration with persistent contrastive divergence
76fn rbm_demo() -> Result<()> {
77    // Create RBM with annealing
78    let annealing = AnnealingSchedule::new(2.0, 0.5, 100);
79
80    let mut rbm = QuantumRBM::new(
81        6,    // visible units
82        3,    // hidden units
83        2.0,  // initial temperature
84        0.01, // learning rate
85    )?
86    .with_annealing(annealing);
87
88    println!("   Created Quantum RBM with annealing schedule");
89
90    // Generate correlated binary data
91    let data = generate_correlated_data(200, 6);
92
93    // Train with PCD
94    println!("   Training with Persistent Contrastive Divergence...");
95    let losses = rbm.train_pcd(
96        &data, 100, // epochs
97        20,  // batch size
98        50,  // persistent chains
99    )?;
100
101    // Analyze training
102    let improvement = (losses[0] - losses.last().unwrap()) / losses[0] * 100.0;
103    println!("   Training statistics:");
104    println!("   - Loss reduction: {improvement:.1}%");
105    println!("   - Final temperature: 0.5");
106
107    // Test reconstruction
108    let test_data = data.slice(s![0..5, ..]).to_owned();
109    let reconstructed = rbm.qbm().reconstruct(&test_data)?;
110
111    println!("\n   Reconstruction quality:");
112    for i in 0..3 {
113        print!("   Original:      [");
114        for val in test_data.row(i) {
115            print!("{val:.0} ");
116        }
117        print!("]  →  Reconstructed: [");
118        for val in reconstructed.row(i) {
119            print!("{val:.0} ");
120        }
121        println!("]");
122    }
123
124    Ok(())
125}
126
127/// Deep Boltzmann Machine demonstration
128fn deep_boltzmann_demo() -> Result<()> {
129    // Create a 3-layer DBM
130    let layer_sizes = vec![8, 4, 2];
131    let mut dbm = DeepBoltzmannMachine::new(
132        layer_sizes.clone(),
133        1.0,  // temperature
134        0.01, // learning rate
135    )?;
136
137    println!("   Created Deep Boltzmann Machine:");
138    println!("   - Architecture: {layer_sizes:?}");
139    println!("   - Total layers: {}", dbm.rbms().len());
140
141    // Generate hierarchical data
142    let data = generate_hierarchical_data(300, 8);
143
144    // Layer-wise pretraining
145    println!("\n   Performing layer-wise pretraining...");
146    dbm.pretrain(
147        &data, 50, // epochs per layer
148        30, // batch size
149    )?;
150
151    println!("\n   Pretraining complete!");
152    println!("   Each layer learned increasingly abstract features");
153
154    Ok(())
155}
156
157/// Energy landscape visualization
158fn energy_landscape_demo() -> Result<()> {
159    // Create small QBM for visualization
160    let qbm = QuantumBoltzmannMachine::new(
161        2,    // visible units (for 2D visualization)
162        1,    // hidden unit
163        0.5,  // temperature
164        0.01, // learning rate
165    )?;
166
167    println!("   Analyzing energy landscape of 2-unit system");
168
169    // Compute energy for all 4 possible states
170    let states = [
171        Array1::from_vec(vec![0.0, 0.0]),
172        Array1::from_vec(vec![0.0, 1.0]),
173        Array1::from_vec(vec![1.0, 0.0]),
174        Array1::from_vec(vec![1.0, 1.0]),
175    ];
176
177    println!("\n   State energies:");
178    for (i, state) in states.iter().enumerate() {
179        let energy = qbm.energy(state);
180        let prob = (-energy / qbm.temperature()).exp();
181        println!(
182            "   State [{:.0}, {:.0}]: E = {:.3}, P ∝ {:.3}",
183            state[0], state[1], energy, prob
184        );
185    }
186
187    // Show coupling matrix
188    println!("\n   Coupling matrix:");
189    for i in 0..3 {
190        print!("   [");
191        for j in 0..3 {
192            print!("{:6.3} ", qbm.couplings()[[i, j]]);
193        }
194        println!("]");
195    }
196
197    Ok(())
198}
199
200/// Pattern completion demonstration
201fn pattern_completion_demo() -> Result<()> {
202    // Create RBM
203    let mut rbm = QuantumRBM::new(
204        8,    // visible units
205        4,    // hidden units
206        1.0,  // temperature
207        0.02, // learning rate
208    )?;
209
210    // Train on specific patterns
211    let patterns = create_letter_patterns();
212    println!("   Training on letter-like patterns...");
213
214    rbm.train_pcd(&patterns, 100, 10, 20)?;
215
216    // Test pattern completion
217    println!("\n   Pattern completion test:");
218
219    // Create corrupted patterns
220    let mut corrupted = patterns.row(0).to_owned();
221    corrupted[3] = 1.0 - corrupted[3]; // Flip one bit
222    corrupted[5] = 1.0 - corrupted[5]; // Flip another
223
224    print!("   Corrupted:  [");
225    for val in &corrupted {
226        print!("{val:.0} ");
227    }
228    println!("]");
229
230    // Complete pattern
231    let completed = complete_pattern(&rbm, &corrupted)?;
232
233    print!("   Completed:  [");
234    for val in &completed {
235        print!("{val:.0} ");
236    }
237    println!("]");
238
239    print!("   Original:   [");
240    for val in patterns.row(0) {
241        print!("{val:.0} ");
242    }
243    println!("]");
244
245    let accuracy = patterns
246        .row(0)
247        .iter()
248        .zip(completed.iter())
249        .filter(|(&a, &b)| (a - b).abs() < 0.5)
250        .count() as f64
251        / 8.0;
252
253    println!("   Reconstruction accuracy: {:.1}%", accuracy * 100.0);
254
255    Ok(())
256}
257
258/// Generate binary patterns
259fn generate_binary_patterns(n_samples: usize, n_features: usize) -> Array2<f64> {
260    Array2::from_shape_fn((n_samples, n_features), |(_, _)| {
261        if thread_rng().gen::<f64>() > 0.5 {
262            1.0
263        } else {
264            0.0
265        }
266    })
267}
268
269/// Generate correlated binary data
270fn generate_correlated_data(n_samples: usize, n_features: usize) -> Array2<f64> {
271    let mut data = Array2::zeros((n_samples, n_features));
272
273    for i in 0..n_samples {
274        // Generate correlated features
275        let base = if thread_rng().gen::<f64>() > 0.5 {
276            1.0
277        } else {
278            0.0
279        };
280
281        for j in 0..n_features {
282            if j % 2 == 0 {
283                data[[i, j]] = base;
284            } else {
285                // Correlate with previous feature
286                data[[i, j]] = if thread_rng().gen::<f64>() > 0.2 {
287                    base
288                } else {
289                    1.0 - base
290                };
291            }
292        }
293    }
294
295    data
296}
297
298/// Generate hierarchical data
299fn generate_hierarchical_data(n_samples: usize, n_features: usize) -> Array2<f64> {
300    let mut data = Array2::zeros((n_samples, n_features));
301
302    for i in 0..n_samples {
303        // Choose high-level pattern
304        let pattern_type = i % 3;
305
306        match pattern_type {
307            0 => {
308                // Pattern A: alternating
309                for j in 0..n_features {
310                    data[[i, j]] = (j % 2) as f64;
311                }
312            }
313            1 => {
314                // Pattern B: blocks
315                for j in 0..n_features {
316                    data[[i, j]] = ((j / 2) % 2) as f64;
317                }
318            }
319            _ => {
320                // Pattern C: random with structure
321                let shift = (thread_rng().gen::<f64>() * 4.0) as usize;
322                for j in 0..n_features {
323                    data[[i, j]] = if (j + shift) % 3 == 0 { 1.0 } else { 0.0 };
324                }
325            }
326        }
327
328        // Add noise
329        for j in 0..n_features {
330            if thread_rng().gen::<f64>() < 0.1 {
331                data[[i, j]] = 1.0 - data[[i, j]];
332            }
333        }
334    }
335
336    data
337}
338
339/// Create letter-like patterns
340fn create_letter_patterns() -> Array2<f64> {
341    // Simple 8-bit patterns resembling letters
342    Array2::from_shape_vec(
343        (4, 8),
344        vec![
345            // Pattern 'L'
346            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, // Pattern 'T'
347            1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, // Pattern 'I'
348            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, // Pattern 'H'
349            1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0,
350        ],
351    )
352    .unwrap()
353}
354
355/// Complete a partial pattern
356fn complete_pattern(rbm: &QuantumRBM, partial: &Array1<f64>) -> Result<Array1<f64>> {
357    // Use Gibbs sampling to complete pattern
358    let mut current = partial.clone();
359
360    for _ in 0..10 {
361        let hidden = rbm.qbm().sample_hidden_given_visible(&current.view())?;
362        current = rbm.qbm().sample_visible_given_hidden(&hidden)?;
363    }
364
365    Ok(current)
366}