#[cfg(feature = "clustering")]
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;
use thiserror::Error;
use crate::sampler::SampleResult;
pub mod visualization;
pub use visualization::*;
pub mod graph;
#[derive(Error, Debug)]
pub enum AnalysisError {
#[error("Clustering error: {0}")]
ClusteringError(String),
#[error("Visualization error: {0}")]
VisualizationError(String),
#[error("Data processing error: {0}")]
DataProcessingError(String),
}
pub type AnalysisResult<T> = Result<T, AnalysisError>;
#[cfg(feature = "clustering")]
pub fn cluster_solutions(
results: &[SampleResult],
max_clusters: usize,
) -> AnalysisResult<Vec<(Vec<usize>, f64)>> {
use crate::scirs_stub::scirs2_ml::KMeans;
if results.is_empty() {
return Err(AnalysisError::DataProcessingError(
"Empty results list".to_string(),
));
}
let variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
let n_vars = variable_names.len();
let n_samples = results.len();
let mut data = Array2::<f64>::zeros((n_samples, n_vars));
for (i, result) in results.iter().enumerate() {
for (j, var_name) in variable_names.iter().enumerate() {
if let Some(&value) = result.assignments.get(var_name) {
data[[i, j]] = if value { 1.0 } else { 0.0 };
}
}
}
let actual_max_clusters = std::cmp::min(max_clusters, n_samples / 2);
let actual_max_clusters = std::cmp::max(actual_max_clusters, 2);
let kmeans = KMeans::new(actual_max_clusters);
let labels = kmeans
.fit_predict(&data)
.map_err(|e| AnalysisError::ClusteringError(e.to_string()))?;
let mut clusters: HashMap<usize, Vec<usize>> = HashMap::new();
let mut cluster_energies: HashMap<usize, Vec<f64>> = HashMap::new();
for (i, &label) in labels.iter().enumerate() {
clusters.entry(label).or_default().push(i);
cluster_energies
.entry(label)
.or_default()
.push(results[i].energy);
}
let mut cluster_results = Vec::new();
for (label, indices) in clusters {
let avg_energy: f64 = cluster_energies[&label].iter().sum::<f64>() / indices.len() as f64;
cluster_results.push((indices, avg_energy));
}
cluster_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(cluster_results)
}
#[cfg(not(feature = "clustering"))]
pub fn cluster_solutions(
results: &[SampleResult],
_max_clusters: usize,
) -> AnalysisResult<Vec<(Vec<usize>, f64)>> {
if results.is_empty() {
return Err(AnalysisError::DataProcessingError(
"Empty results list".to_string(),
));
}
let mut groups: HashMap<Vec<bool>, Vec<usize>> = HashMap::new();
let mut group_energies: HashMap<Vec<bool>, Vec<f64>> = HashMap::new();
let mut variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
variable_names.sort();
for (i, result) in results.iter().enumerate() {
let binary: Vec<bool> = variable_names
.iter()
.map(|name| *result.assignments.get(name).unwrap_or(&false))
.collect();
groups.entry(binary.clone()).or_default().push(i);
group_energies
.entry(binary)
.or_default()
.push(result.energy);
}
let mut group_results = Vec::new();
for (binary, indices) in groups {
let avg_energy: f64 = group_energies[&binary].iter().sum::<f64>() / indices.len() as f64;
group_results.push((indices, avg_energy));
}
group_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(group_results)
}
pub fn calculate_diversity(results: &[SampleResult]) -> AnalysisResult<HashMap<String, f64>> {
if results.is_empty() {
return Err(AnalysisError::DataProcessingError(
"Empty results list".to_string(),
));
}
let variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
let n_vars = variable_names.len();
let mut metrics = HashMap::new();
let mut distances = Vec::new();
for i in 0..results.len() {
for j in (i + 1)..results.len() {
let mut distance = 0;
for var_name in &variable_names {
let val_i = results[i].assignments.get(var_name).unwrap_or(&false);
let val_j = results[j].assignments.get(var_name).unwrap_or(&false);
if val_i != val_j {
distance += 1;
}
}
distances.push(distance as f64 / n_vars as f64);
}
}
if !distances.is_empty() {
let avg_distance: f64 = distances.iter().sum::<f64>() / distances.len() as f64;
distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let min_distance = distances.first().copied().unwrap_or(0.0);
let max_distance = distances.last().copied().unwrap_or(0.0);
let median_idx = distances.len() / 2;
let median_distance = if distances.len() % 2 == 0 {
f64::midpoint(distances[median_idx - 1], distances[median_idx])
} else {
distances[median_idx]
};
metrics.insert("avg_distance".to_string(), avg_distance);
metrics.insert("min_distance".to_string(), min_distance);
metrics.insert("max_distance".to_string(), max_distance);
metrics.insert("median_distance".to_string(), median_distance);
}
let energies: Vec<f64> = results.iter().map(|r| r.energy).collect();
if !energies.is_empty() {
let min_energy = *energies
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(&0.0);
let max_energy = *energies
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(&0.0);
let energy_range = max_energy - min_energy;
metrics.insert("energy_range".to_string(), energy_range);
metrics.insert("min_energy".to_string(), min_energy);
metrics.insert("max_energy".to_string(), max_energy);
}
for var_name in &variable_names {
let var_count = results
.iter()
.filter(|r| *r.assignments.get(var_name).unwrap_or(&false))
.count() as f64
/ results.len() as f64;
metrics.insert(format!("var_bias_{var_name}"), var_count);
}
Ok(metrics)
}
#[cfg(feature = "plotters")]
pub fn visualize_energy_distribution(
results: &[SampleResult],
file_path: &str,
) -> AnalysisResult<()> {
use plotters::prelude::*;
if results.is_empty() {
return Err(AnalysisError::DataProcessingError(
"Empty results list".to_string(),
));
}
let energies: Vec<f64> = results.iter().map(|r| r.energy).collect();
let root = BitMapBackend::new(file_path, (800, 600)).into_drawing_area();
root.fill(&WHITE)
.map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;
let min_energy = *energies
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.ok_or_else(|| AnalysisError::DataProcessingError("No energies found".to_string()))?;
let max_energy = *energies
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.ok_or_else(|| AnalysisError::DataProcessingError("No energies found".to_string()))?;
let energy_range = max_energy - min_energy;
let padding = energy_range * 0.1;
let y_min = min_energy - padding;
let y_max = max_energy + padding;
let mut chart = ChartBuilder::on(&root)
.caption("Energy Distribution", ("sans-serif", 30))
.margin(10)
.x_label_area_size(40)
.y_label_area_size(60)
.build_cartesian_2d(0..results.len(), y_min..y_max)
.map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;
chart
.configure_mesh()
.x_desc("Solution Index")
.y_desc("Energy")
.draw()
.map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;
let mut sorted_energies = energies;
sorted_energies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
chart
.draw_series(LineSeries::new(
sorted_energies.iter().enumerate().map(|(i, &e)| (i, e)),
&RED,
))
.map_err(|e| AnalysisError::VisualizationError(e.to_string()))?
.label("Energy")
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
chart
.configure_series_labels()
.background_style(WHITE.mix(0.8))
.border_style(BLACK)
.draw()
.map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;
root.present()
.map_err(|e| AnalysisError::VisualizationError(e.to_string()))?;
Ok(())
}
#[cfg(not(feature = "plotters"))]
pub fn visualize_energy_distribution(
_results: &[SampleResult],
_file_path: &str,
) -> AnalysisResult<()> {
Err(AnalysisError::VisualizationError(
"Visualization requires the 'plotters' feature".to_string(),
))
}