quantrs2_tytan/analysis/
mod.rs

//! Analysis utilities for quantum annealing results.
//!
//! This module provides tools for analyzing and interpreting the results of
//! quantum annealing runs, including solution clustering, diversity metrics,
//! and visualization.
5
6#[cfg(feature = "clustering")]
7use scirs2_core::ndarray::Array2;
8use std::collections::HashMap;
9use thiserror::Error;
10
11use crate::sampler::SampleResult;
12
13// Re-export visualization module
14pub mod visualization;
15pub use visualization::*;
16
17// Graph utilities
18pub mod graph;
19
20/// Errors that can occur during analysis
21#[derive(Error, Debug)]
22pub enum AnalysisError {
23    /// Error in clustering algorithm
24    #[error("Clustering error: {0}")]
25    ClusteringError(String),
26
27    /// Error in visualization
28    #[error("Visualization error: {0}")]
29    VisualizationError(String),
30
31    /// Error in data processing
32    #[error("Data processing error: {0}")]
33    DataProcessingError(String),
34}
35
36/// Result type for analysis operations
37pub type AnalysisResult<T> = Result<T, AnalysisError>;
38
/// Cluster similar solutions to identify patterns.
///
/// Every sample is encoded as a row of `0.0`/`1.0` values over the variables
/// of the first sample (sorted by name so the column order is reproducible);
/// a variable missing from a later sample is encoded as `0.0`. The rows are
/// then partitioned with k-means.
///
/// The number of clusters is `min(max_clusters, n_samples / 2)`, raised to at
/// least 2 and finally capped at `n_samples` so that it never exceeds the
/// number of samples available.
///
/// # Returns
///
/// One `(member_indices, average_energy)` pair per non-empty cluster, where
/// the indices refer to positions in `results`. Clusters are sorted by
/// ascending average energy; ties are broken by member indices so the output
/// order does not depend on hash-map iteration order.
///
/// # Errors
///
/// * [`AnalysisError::DataProcessingError`] if `results` is empty.
/// * [`AnalysisError::ClusteringError`] if the k-means fit fails.
#[cfg(feature = "clustering")]
pub fn cluster_solutions(
    results: &[SampleResult],
    max_clusters: usize,
) -> AnalysisResult<Vec<(Vec<usize>, f64)>> {
    use crate::scirs_stub::scirs2_ml::KMeans;

    let Some(first) = results.first() else {
        return Err(AnalysisError::DataProcessingError(
            "Empty results list".to_string(),
        ));
    };

    // Sorted so that the feature-matrix columns are in the same order on
    // every run (hash-map key order is not stable between processes).
    let mut variable_names: Vec<&String> = first.assignments.keys().collect();
    variable_names.sort_unstable();

    let n_vars = variable_names.len();
    let n_samples = results.len();

    // Convert solutions to binary row vectors.
    let mut data = Array2::<f64>::zeros((n_samples, n_vars));
    for (i, result) in results.iter().enumerate() {
        for (j, var_name) in variable_names.iter().enumerate() {
            if result.assignments.get(*var_name).copied().unwrap_or(false) {
                data[[i, j]] = 1.0;
            }
        }
    }

    // At least 2 clusters where possible, but never more clusters than
    // samples (a single sample can only form a single cluster).
    let n_clusters = max_clusters.min(n_samples / 2).max(2).min(n_samples);

    // Run K-means clustering.
    let kmeans = KMeans::new(n_clusters);
    let labels = kmeans
        .fit_predict(&data)
        .map_err(|e| AnalysisError::ClusteringError(e.to_string()))?;

    // Accumulate member indices and the running energy sum per cluster label.
    let mut clusters: HashMap<usize, (Vec<usize>, f64)> = HashMap::new();
    for (i, (&label, result)) in labels.iter().zip(results).enumerate() {
        let (members, energy_sum) = clusters.entry(label).or_default();
        members.push(i);
        *energy_sum += result.energy;
    }

    let mut cluster_results: Vec<(Vec<usize>, f64)> = clusters
        .into_values()
        .map(|(members, energy_sum)| {
            let avg_energy = energy_sum / members.len() as f64;
            (members, avg_energy)
        })
        .collect();

    // `total_cmp` is a total order even when an energy is NaN, which
    // `sort_by` requires of its comparator.
    cluster_results.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));

    Ok(cluster_results)
}
104
105/// Fallback clustering implementation
106#[cfg(not(feature = "clustering"))]
107pub fn cluster_solutions(
108    results: &[SampleResult],
109    _max_clusters: usize,
110) -> AnalysisResult<Vec<(Vec<usize>, f64)>> {
111    // Simple implementation: just group identical solutions
112    if results.is_empty() {
113        return Err(AnalysisError::DataProcessingError(
114            "Empty results list".to_string(),
115        ));
116    }
117
118    // Group solutions by their binary representation
119    let mut groups: HashMap<Vec<bool>, Vec<usize>> = HashMap::new();
120    let mut group_energies: HashMap<Vec<bool>, Vec<f64>> = HashMap::new();
121
122    // Extract all variable names in sorted order
123    let mut variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
124    variable_names.sort();
125
126    for (i, result) in results.iter().enumerate() {
127        // Convert to sorted binary vector for consistent comparison
128        let binary: Vec<bool> = variable_names
129            .iter()
130            .map(|name| *result.assignments.get(name).unwrap_or(&false))
131            .collect();
132
133        groups.entry(binary.clone()).or_default().push(i);
134        group_energies
135            .entry(binary)
136            .or_default()
137            .push(result.energy);
138    }
139
140    // Calculate average energy for each group
141    let mut group_results = Vec::new();
142    for (binary, indices) in groups {
143        let avg_energy: f64 = group_energies[&binary].iter().sum::<f64>() / indices.len() as f64;
144        group_results.push((indices, avg_energy));
145    }
146
147    // Sort groups by average energy
148    group_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
149
150    Ok(group_results)
151}
152
153/// Calculate diversity metrics for a set of solutions
154pub fn calculate_diversity(results: &[SampleResult]) -> AnalysisResult<HashMap<String, f64>> {
155    if results.is_empty() {
156        return Err(AnalysisError::DataProcessingError(
157            "Empty results list".to_string(),
158        ));
159    }
160
161    // Extract all variable names
162    let variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
163
164    let n_vars = variable_names.len();
165
166    // Calculate diversity metrics
167    let mut metrics = HashMap::new();
168
169    // 1. Hamming distance statistics
170    let mut distances = Vec::new();
171
172    for i in 0..results.len() {
173        for j in (i + 1)..results.len() {
174            let mut distance = 0;
175
176            for var_name in &variable_names {
177                let val_i = results[i].assignments.get(var_name).unwrap_or(&false);
178                let val_j = results[j].assignments.get(var_name).unwrap_or(&false);
179
180                if val_i != val_j {
181                    distance += 1;
182                }
183            }
184
185            distances.push(distance as f64 / n_vars as f64);
186        }
187    }
188
189    if !distances.is_empty() {
190        // Calculate statistics
191        let avg_distance: f64 = distances.iter().sum::<f64>() / distances.len() as f64;
192
193        // Sort for percentiles
194        distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
195
196        let min_distance = distances.first().copied().unwrap_or(0.0);
197        let max_distance = distances.last().copied().unwrap_or(0.0);
198
199        let median_idx = distances.len() / 2;
200        let median_distance = if distances.len() % 2 == 0 {
201            f64::midpoint(distances[median_idx - 1], distances[median_idx])
202        } else {
203            distances[median_idx]
204        };
205
206        metrics.insert("avg_distance".to_string(), avg_distance);
207        metrics.insert("min_distance".to_string(), min_distance);
208        metrics.insert("max_distance".to_string(), max_distance);
209        metrics.insert("median_distance".to_string(), median_distance);
210    }
211
212    // 2. Energy spread
213    let energies: Vec<f64> = results.iter().map(|r| r.energy).collect();
214
215    if !energies.is_empty() {
216        let min_energy = *energies
217            .iter()
218            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
219            .unwrap_or(&0.0);
220        let max_energy = *energies
221            .iter()
222            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
223            .unwrap_or(&0.0);
224        let energy_range = max_energy - min_energy;
225
226        metrics.insert("energy_range".to_string(), energy_range);
227        metrics.insert("min_energy".to_string(), min_energy);
228        metrics.insert("max_energy".to_string(), max_energy);
229    }
230
231    // 3. Variable bias - how often each variable is 1
232    for var_name in &variable_names {
233        let var_count = results
234            .iter()
235            .filter(|r| *r.assignments.get(var_name).unwrap_or(&false))
236            .count() as f64
237            / results.len() as f64;
238
239        metrics.insert(format!("var_bias_{var_name}"), var_count);
240    }
241
242    Ok(metrics)
243}
244
/// Plot the sorted energies of a set of solutions as a line chart and write
/// it to `file_path` as an 800x600 bitmap image.
///
/// The x axis is the rank of a solution after sorting by ascending energy and
/// the y axis is its energy. The y range is padded by 10% of the energy
/// spread; when all energies are equal a fallback padding is used so the axis
/// never collapses to a single value. Non-finite energies (NaN or infinite)
/// are left out of the plot because they cannot be placed on the axis.
///
/// # Errors
///
/// * [`AnalysisError::DataProcessingError`] if `results` is empty or contains
///   no finite energy.
/// * [`AnalysisError::VisualizationError`] if drawing or writing the image
///   fails.
#[cfg(feature = "plotters")]
pub fn visualize_energy_distribution(
    results: &[SampleResult],
    file_path: &str,
) -> AnalysisResult<()> {
    use plotters::prelude::*;

    if results.is_empty() {
        return Err(AnalysisError::DataProcessingError(
            "Empty results list".to_string(),
        ));
    }

    // Validate and sort the data before touching the output file.
    let mut sorted_energies: Vec<f64> = results
        .iter()
        .map(|r| r.energy)
        .filter(|e| e.is_finite())
        .collect();
    if sorted_energies.is_empty() {
        return Err(AnalysisError::DataProcessingError(
            "No finite energies found".to_string(),
        ));
    }
    sorted_energies.sort_unstable_by(f64::total_cmp);

    let min_energy = sorted_energies[0];
    let max_energy = sorted_energies[sorted_energies.len() - 1];

    // Add some padding; fall back to a magnitude-based padding when every
    // energy is identical so that `y_min < y_max` always holds.
    let energy_range = max_energy - min_energy;
    let padding = if energy_range > 0.0 {
        energy_range * 0.1
    } else {
        min_energy.abs().max(1.0) * 0.1
    };
    let y_min = min_energy - padding;
    let y_max = max_energy + padding;

    let root = BitMapBackend::new(file_path, (800, 600)).into_drawing_area();

    root.fill(&WHITE)
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    let mut chart = ChartBuilder::on(&root)
        .caption("Energy Distribution", ("sans-serif", 30))
        .margin(10)
        .x_label_area_size(40)
        .y_label_area_size(60)
        .build_cartesian_2d(0..sorted_energies.len(), y_min..y_max)
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    chart
        .configure_mesh()
        .x_desc("Solution Index")
        .y_desc("Energy")
        .draw()
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    chart
        .draw_series(LineSeries::new(
            sorted_energies.iter().enumerate().map(|(i, &e)| (i, e)),
            &RED,
        ))
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?
        .label("Energy")
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));

    chart
        .configure_series_labels()
        .background_style(WHITE.mix(0.8))
        .border_style(BLACK)
        .draw()
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    root.present()
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    Ok(())
}
323
324/// Fallback visualization (empty implementation)
325#[cfg(not(feature = "plotters"))]
326pub fn visualize_energy_distribution(
327    _results: &[SampleResult],
328    _file_path: &str,
329) -> AnalysisResult<()> {
330    Err(AnalysisError::VisualizationError(
331        "Visualization requires the 'plotters' feature".to_string(),
332    ))
333}