hep_classification/
hep_classification.rs

1#![allow(clippy::pedantic, clippy::unnecessary_wraps)]
2use quantrs2_ml::hep::{CollisionEvent, HEPQuantumClassifier, ParticleFeatures, ParticleType};
3use quantrs2_ml::prelude::*;
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::prelude::*;
6use std::time::Instant;
7
8fn main() -> Result<()> {
9    println!("Quantum High-Energy Physics Classification Example");
10    println!("=================================================");
11
12    // Create a quantum classifier for high-energy physics data
13    let num_qubits = 8;
14    let feature_dim = 8;
15    let num_classes = 2;
16
17    println!("Creating HEP quantum classifier with {num_qubits} qubits...");
18    let mut classifier = HEPQuantumClassifier::new(
19        num_qubits,
20        feature_dim,
21        num_classes,
22        quantrs2_ml::hep::HEPEncodingMethod::HybridEncoding,
23        vec!["background".to_string(), "higgs".to_string()],
24    )?;
25
26    // Generate synthetic training data
27    println!("Generating synthetic training data...");
28    let (training_particles, training_labels) = generate_synthetic_data(500);
29
30    println!("Training quantum classifier...");
31    let start = Instant::now();
32    let metrics = classifier.train_on_particles(
33        &training_particles,
34        &training_labels,
35        20,   // epochs
36        0.05, // learning rate
37    )?;
38
39    println!("Training completed in {:.2?}", start.elapsed());
40    println!("Final loss: {:.4}", metrics.final_loss);
41
42    // Generate test data
43    println!("Generating test data...");
44    let (test_particles, test_labels) = generate_synthetic_data(100);
45
46    // Evaluate classifier
47    println!("Evaluating classifier...");
48    // Convert test data to ndarray format
49    let num_samples = test_particles.len();
50    let mut test_features = Array2::zeros((num_samples, classifier.feature_dimension));
51    let mut test_labels_array = Array1::zeros(num_samples);
52
53    for (i, particle) in test_particles.iter().enumerate() {
54        let features = classifier.extract_features(particle)?;
55        for j in 0..features.len() {
56            test_features[[i, j]] = features[j];
57        }
58        test_labels_array[i] = test_labels[i] as f64;
59    }
60
61    let evaluation = classifier.evaluate(&test_features, &test_labels_array)?;
62
63    println!("Evaluation results:");
64    println!("  Overall accuracy: {:.2}%", evaluation.accuracy * 100.0);
65
66    println!("Class accuracies:");
67    for (i, &acc) in evaluation.class_accuracies.iter().enumerate() {
68        println!("  {}: {:.2}%", evaluation.class_labels[i], acc * 100.0);
69    }
70
71    // Create a test collision event
72    println!("\nClassifying a test collision event...");
73    let event = create_test_collision_event();
74
75    // Run classification
76    let classifications = classifier.classify_event(&event)?;
77
78    println!("Event classification results:");
79    for (i, (class, confidence)) in classifications.iter().enumerate() {
80        println!("  Particle {i}: {class} (confidence: {confidence:.2})");
81    }
82
83    // Create a Higgs detector
84    println!("\nCreating Higgs detector...");
85    let higgs_detector = quantrs2_ml::hep::HiggsDetector::new(num_qubits)?;
86
87    // Detect Higgs
88    let higgs_detections = higgs_detector.detect_higgs(&event)?;
89
90    println!("Higgs detection results:");
91    let higgs_count = higgs_detections.iter().filter(|&&x| x).count();
92    println!("  Found {higgs_count} potential Higgs particles");
93
94    Ok(())
95}
96
97// Generate synthetic particle data for training/testing
98fn generate_synthetic_data(num_samples: usize) -> (Vec<ParticleFeatures>, Vec<usize>) {
99    let mut particles = Vec::with_capacity(num_samples);
100    let mut labels = Vec::with_capacity(num_samples);
101
102    let particle_types = [
103        ParticleType::Electron,
104        ParticleType::Muon,
105        ParticleType::Photon,
106        ParticleType::Quark, // Changed from Proton which doesn't exist
107        ParticleType::Higgs,
108    ];
109
110    for i in 0..num_samples {
111        let is_higgs = i % 5 == 0;
112        let particle_type = if is_higgs {
113            ParticleType::Higgs
114        } else {
115            particle_types[i % 4]
116        };
117
118        // Generate synthetic four-momentum
119        // Higgs particles have higher energy
120        let energy_base = if is_higgs { 125.0 } else { 50.0 };
121        let energy = thread_rng().gen::<f64>().mul_add(10.0, energy_base);
122        let px = (thread_rng().gen::<f64>() - 0.5) * 20.0;
123        let py = (thread_rng().gen::<f64>() - 0.5) * 20.0;
124        let pz = (thread_rng().gen::<f64>() - 0.5) * 50.0;
125
126        // Create additional features
127        let mut additional_features = Vec::with_capacity(3);
128        for _ in 0..3 {
129            additional_features.push(thread_rng().gen::<f64>());
130        }
131
132        // Create particle features
133        let particle = ParticleFeatures {
134            particle_type,
135            four_momentum: [energy, px, py, pz],
136            additional_features,
137        };
138
139        particles.push(particle);
140        labels.push(usize::from(is_higgs));
141    }
142
143    (particles, labels)
144}
145
146// Create a test collision event
147fn create_test_collision_event() -> CollisionEvent {
148    let mut particles = Vec::new();
149
150    // Add an electron
151    particles.push(ParticleFeatures {
152        particle_type: ParticleType::Electron,
153        four_momentum: [50.5, 10.2, -15.7, 45.9],
154        additional_features: vec![0.8, 0.2, 0.3],
155    });
156
157    // Add a positron
158    particles.push(ParticleFeatures {
159        particle_type: ParticleType::Electron, // Type is electron, but with opposite charge
160        four_momentum: [50.2, -9.7, 14.3, -44.1],
161        additional_features: vec![0.7, 0.3, 0.2],
162    });
163
164    // Add photons (potential Higgs decay products)
165    particles.push(ParticleFeatures {
166        particle_type: ParticleType::Photon,
167        four_momentum: [62.8, 25.4, 30.1, 41.2],
168        additional_features: vec![0.9, 0.1, 0.4],
169    });
170
171    particles.push(ParticleFeatures {
172        particle_type: ParticleType::Photon,
173        four_momentum: [63.2, -24.1, -29.5, -40.8],
174        additional_features: vec![0.9, 0.1, 0.5],
175    });
176
177    // Create global event features
178    let global_features = vec![230.0]; // Total energy
179
180    CollisionEvent {
181        particles,
182        global_features,
183        event_type: Some("potential_higgs".to_string()),
184    }
185}