quantrs2_tytan/analysis/
mod.rs1#[cfg(feature = "clustering")]
7use scirs2_core::ndarray::Array2;
8use std::collections::HashMap;
9use thiserror::Error;
10
11use crate::sampler::SampleResult;
12
13pub mod visualization;
15pub use visualization::*;
16
17pub mod graph;
19
20#[derive(Error, Debug)]
22pub enum AnalysisError {
23 #[error("Clustering error: {0}")]
25 ClusteringError(String),
26
27 #[error("Visualization error: {0}")]
29 VisualizationError(String),
30
31 #[error("Data processing error: {0}")]
33 DataProcessingError(String),
34}
35
36pub type AnalysisResult<T> = Result<T, AnalysisError>;
38
/// Groups sampled solutions with k-means over their binary assignment vectors.
///
/// Each solution becomes a row of `0.0`/`1.0` values, one column per variable
/// name found in `results[0]` (sorted by name). A variable missing from a
/// solution is treated as `false`.
///
/// The requested cluster count is `min(max_clusters, n / 2)`. It is raised to
/// at least 2 but never exceeds the number of samples `n`.
///
/// Returns one `(sample_indices, average_energy)` pair per non-empty cluster.
/// Pairs are sorted by ascending average energy; ties are broken by the
/// smallest sample index.
///
/// # Errors
///
/// * [`AnalysisError::DataProcessingError`] if `results` is empty.
/// * [`AnalysisError::ClusteringError`] if the k-means backend fails.
#[cfg(feature = "clustering")]
pub fn cluster_solutions(
    results: &[SampleResult],
    max_clusters: usize,
) -> AnalysisResult<Vec<(Vec<usize>, f64)>> {
    use crate::scirs_stub::scirs2_ml::KMeans;

    if results.is_empty() {
        return Err(AnalysisError::DataProcessingError(
            "Empty results list".to_string(),
        ));
    }

    // Sort so the column layout is deterministic and matches the
    // non-clustering fallback.
    let mut variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
    variable_names.sort();

    let n_samples = results.len();
    let n_vars = variable_names.len();

    // The matrix starts at zero, so only `true` assignments need writing.
    let mut data = Array2::<f64>::zeros((n_samples, n_vars));
    for (i, result) in results.iter().enumerate() {
        for (j, var_name) in variable_names.iter().enumerate() {
            if result.assignments.get(var_name).copied().unwrap_or(false) {
                data[[i, j]] = 1.0;
            }
        }
    }

    // Heuristic: at most half as many clusters as samples and at least 2,
    // but never more clusters than samples (one sample -> one cluster).
    let k = max_clusters.min(n_samples / 2).max(2).min(n_samples);

    let kmeans = KMeans::new(k);
    let labels = kmeans
        .fit_predict(&data)
        .map_err(|e| AnalysisError::ClusteringError(e.to_string()))?;

    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::new();
    for (i, &label) in labels.iter().enumerate() {
        clusters.entry(label).or_default().push(i);
    }

    let mut cluster_results: Vec<(Vec<usize>, f64)> = clusters
        .into_values()
        .map(|indices| {
            let total: f64 = indices.iter().map(|&i| results[i].energy).sum();
            let avg = total / indices.len() as f64;
            (indices, avg)
        })
        .collect();

    // `total_cmp` is a total order, so NaN energies cannot scramble the sort.
    // Ties fall back to the first member index, because HashMap iteration
    // order is unspecified. Every cluster has at least one member.
    cluster_results.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0[0].cmp(&b.0[0])));

    Ok(cluster_results)
}
104
105#[cfg(not(feature = "clustering"))]
107pub fn cluster_solutions(
108 results: &[SampleResult],
109 _max_clusters: usize,
110) -> AnalysisResult<Vec<(Vec<usize>, f64)>> {
111 if results.is_empty() {
113 return Err(AnalysisError::DataProcessingError(
114 "Empty results list".to_string(),
115 ));
116 }
117
118 let mut groups: HashMap<Vec<bool>, Vec<usize>> = HashMap::new();
120 let mut group_energies: HashMap<Vec<bool>, Vec<f64>> = HashMap::new();
121
122 let mut variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
124 variable_names.sort();
125
126 for (i, result) in results.iter().enumerate() {
127 let binary: Vec<bool> = variable_names
129 .iter()
130 .map(|name| *result.assignments.get(name).unwrap_or(&false))
131 .collect();
132
133 groups.entry(binary.clone()).or_default().push(i);
134 group_energies
135 .entry(binary)
136 .or_default()
137 .push(result.energy);
138 }
139
140 let mut group_results = Vec::new();
142 for (binary, indices) in groups {
143 let avg_energy: f64 = group_energies[&binary].iter().sum::<f64>() / indices.len() as f64;
144 group_results.push((indices, avg_energy));
145 }
146
147 group_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
149
150 Ok(group_results)
151}
152
153pub fn calculate_diversity(results: &[SampleResult]) -> AnalysisResult<HashMap<String, f64>> {
155 if results.is_empty() {
156 return Err(AnalysisError::DataProcessingError(
157 "Empty results list".to_string(),
158 ));
159 }
160
161 let variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
163
164 let n_vars = variable_names.len();
165
166 let mut metrics = HashMap::new();
168
169 let mut distances = Vec::new();
171
172 for i in 0..results.len() {
173 for j in (i + 1)..results.len() {
174 let mut distance = 0;
175
176 for var_name in &variable_names {
177 let val_i = results[i].assignments.get(var_name).unwrap_or(&false);
178 let val_j = results[j].assignments.get(var_name).unwrap_or(&false);
179
180 if val_i != val_j {
181 distance += 1;
182 }
183 }
184
185 distances.push(distance as f64 / n_vars as f64);
186 }
187 }
188
189 if !distances.is_empty() {
190 let avg_distance: f64 = distances.iter().sum::<f64>() / distances.len() as f64;
192
193 distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
195
196 let min_distance = distances.first().copied().unwrap_or(0.0);
197 let max_distance = distances.last().copied().unwrap_or(0.0);
198
199 let median_idx = distances.len() / 2;
200 let median_distance = if distances.len() % 2 == 0 {
201 f64::midpoint(distances[median_idx - 1], distances[median_idx])
202 } else {
203 distances[median_idx]
204 };
205
206 metrics.insert("avg_distance".to_string(), avg_distance);
207 metrics.insert("min_distance".to_string(), min_distance);
208 metrics.insert("max_distance".to_string(), max_distance);
209 metrics.insert("median_distance".to_string(), median_distance);
210 }
211
212 let energies: Vec<f64> = results.iter().map(|r| r.energy).collect();
214
215 if !energies.is_empty() {
216 let min_energy = *energies
217 .iter()
218 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
219 .unwrap_or(&0.0);
220 let max_energy = *energies
221 .iter()
222 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
223 .unwrap_or(&0.0);
224 let energy_range = max_energy - min_energy;
225
226 metrics.insert("energy_range".to_string(), energy_range);
227 metrics.insert("min_energy".to_string(), min_energy);
228 metrics.insert("max_energy".to_string(), max_energy);
229 }
230
231 for var_name in &variable_names {
233 let var_count = results
234 .iter()
235 .filter(|r| *r.assignments.get(var_name).unwrap_or(&false))
236 .count() as f64
237 / results.len() as f64;
238
239 metrics.insert(format!("var_bias_{var_name}"), var_count);
240 }
241
242 Ok(metrics)
243}
244
/// Plots the sampled energies, sorted ascending, as a line chart.
///
/// The chart is an 800x600 bitmap written to `file_path`. The y-axis is
/// padded by 10% of the energy span. When every energy is identical, the
/// padding instead scales with the energy magnitude (at least 1.0), so the
/// axis never collapses to an empty range.
///
/// # Errors
///
/// * [`AnalysisError::DataProcessingError`] if `results` is empty or the
///   energies do not yield a finite range.
/// * [`AnalysisError::VisualizationError`] if plotters fails to draw or write
///   the image.
#[cfg(feature = "plotters")]
pub fn visualize_energy_distribution(
    results: &[SampleResult],
    file_path: &str,
) -> AnalysisResult<()> {
    use plotters::prelude::*;

    if results.is_empty() {
        return Err(AnalysisError::DataProcessingError(
            "Empty results list".to_string(),
        ));
    }

    let mut sorted_energies: Vec<f64> = results.iter().map(|r| r.energy).collect();
    sorted_energies.sort_by(f64::total_cmp);

    // f64::min/max skip NaN; an all-NaN or infinite input leaves a
    // non-finite bound that cannot be used as an axis range.
    let (min_energy, max_energy) = sorted_energies.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &e| (lo.min(e), hi.max(e)),
    );
    if !min_energy.is_finite() || !max_energy.is_finite() {
        return Err(AnalysisError::DataProcessingError(
            "Cannot plot non-finite energies".to_string(),
        ));
    }

    let energy_range = max_energy - min_energy;
    // A flat series would give plotters an empty y-range (y_min == y_max).
    let padding = if energy_range > 0.0 {
        energy_range * 0.1
    } else {
        (max_energy.abs() * 0.1).max(1.0)
    };
    let y_min = min_energy - padding;
    let y_max = max_energy + padding;

    // Validation is done before creating the backend, so bad input never
    // produces a partial image.
    let root = BitMapBackend::new(file_path, (800, 600)).into_drawing_area();
    root.fill(&WHITE)
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    let mut chart = ChartBuilder::on(&root)
        .caption("Energy Distribution", ("sans-serif", 30))
        .margin(10)
        .x_label_area_size(40)
        .y_label_area_size(60)
        .build_cartesian_2d(0..results.len(), y_min..y_max)
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    chart
        .configure_mesh()
        .x_desc("Solution Index")
        .y_desc("Energy")
        .draw()
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    chart
        .draw_series(LineSeries::new(
            sorted_energies.iter().enumerate().map(|(i, &e)| (i, e)),
            &RED,
        ))
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?
        .label("Energy")
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));

    chart
        .configure_series_labels()
        .background_style(WHITE.mix(0.8))
        .border_style(BLACK)
        .draw()
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    root.present()
        .map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;

    Ok(())
}
323
324#[cfg(not(feature = "plotters"))]
326pub fn visualize_energy_distribution(
327 _results: &[SampleResult],
328 _file_path: &str,
329) -> AnalysisResult<()> {
330 Err(AnalysisError::VisualizationError(
331 "Visualization requires the 'plotters' feature".to_string(),
332 ))
333}