use scirs2_core::ndarray::Array2;
/// Selects which condensation strategy a [`CondensationConfig`] requests.
///
/// This type only names the strategies; the algorithms themselves are not part
/// of this definition. The enum is `#[non_exhaustive]`, so code in other crates
/// must include a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CondensationMethod {
    /// K-center strategy. This is the variant `CondensationConfig::default()` uses.
    KCenter,
    /// Importance-sampling strategy.
    ImportanceSampling,
    /// Gradient-matching strategy.
    // NOTE(review): presumably the strategy that consumes `max_iterations` and
    // `learning_rate` from the config — confirm against the implementation.
    GradientMatching,
    /// Herding strategy.
    Herding,
}
/// Tunable parameters for a graph-condensation run.
///
/// All fields are public, so a configuration can be written as a struct literal
/// or derived from `CondensationConfig::default()` with struct-update syntax.
#[derive(Debug, Clone)]
pub struct CondensationConfig {
    /// Requested number of nodes in the condensed graph.
    pub target_nodes: usize,
    /// Condensation strategy to apply.
    pub method: CondensationMethod,
    /// Upper bound on iterations.
    // NOTE(review): presumably only consulted by iterative strategies — confirm.
    pub max_iterations: usize,
    /// Step size for optimisation.
    // NOTE(review): presumably only consulted by gradient-based strategies — confirm.
    pub learning_rate: f64,
}
impl Default for CondensationConfig {
fn default() -> Self {
Self {
target_nodes: 100,
method: CondensationMethod::KCenter,
max_iterations: 200,
learning_rate: 0.01,
}
}
}
/// A condensed graph: adjacency, node features, labels and a mapping back to
/// the source graph.
///
/// The type does not enforce any consistency between its fields; producers are
/// responsible for keeping their dimensions aligned.
#[derive(Debug, Clone)]
pub struct CondensedGraph {
    /// Adjacency matrix of the condensed graph.
    // NOTE(review): assumed square (nodes x nodes) — not enforced here, confirm
    // against the code that builds it.
    pub adjacency: Array2<f64>,
    /// Node feature matrix.
    // NOTE(review): assumed to hold one row per condensed node — confirm.
    pub features: Array2<f64>,
    /// Node labels.
    // NOTE(review): assumed to hold one entry per condensed node — confirm.
    pub labels: Vec<usize>,
    /// Indices relating condensed nodes to the source graph.
    // NOTE(review): presumably `source_mapping[i]` is the original-graph node
    // that condensed node `i` was taken from — verify against the producer.
    pub source_mapping: Vec<usize>,
}
/// Scores describing how closely a condensed graph matches its source graph.
///
/// This type is a plain container; it neither computes nor validates the values.
#[derive(Debug, Clone)]
pub struct QualityMetrics {
    /// Distance between the degree distributions of the two graphs.
    // NOTE(review): the exact distance measure is not defined here; presumably
    // smaller is better — confirm.
    pub degree_distribution_distance: f64,
    /// Distance between the spectra of the two graphs.
    // NOTE(review): the exact distance measure is not defined here; presumably
    // smaller is better — confirm.
    pub spectral_distance: f64,
    /// Label coverage of the condensed graph.
    // NOTE(review): presumably a fraction in [0, 1] where 1.0 means every
    // source label is represented — confirm.
    pub label_coverage: f64,
}
/// Outcome of a condensation run: the condensed graph plus summary figures.
#[derive(Debug, Clone)]
pub struct CondensationResult {
    /// The condensed graph that was produced.
    pub condensed: CondensedGraph,
    /// Compression ratio achieved by the run.
    // NOTE(review): presumably original node count divided by condensed node
    // count — confirm against the code that fills this in.
    pub compression_ratio: f64,
    /// Quality scores for `condensed`.
    pub quality_metrics: QualityMetrics,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f64` fields with literal values.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// `Default` must produce the baseline configuration.
    #[test]
    fn test_default_config() {
        let config = CondensationConfig::default();
        assert_eq!(config.target_nodes, 100);
        assert_eq!(config.method, CondensationMethod::KCenter);
        assert_eq!(config.max_iterations, 200);
        assert_close(config.learning_rate, 0.01);
    }

    /// Every pair of distinct variants must compare unequal.
    #[test]
    fn test_condensation_method_variants() {
        let methods = [
            CondensationMethod::KCenter,
            CondensationMethod::ImportanceSampling,
            CondensationMethod::GradientMatching,
            CondensationMethod::Herding,
        ];
        // Compare each variant only with the ones after it, so every unordered
        // pair is checked exactly once.
        for (index, lhs) in methods.iter().enumerate() {
            for rhs in &methods[index + 1..] {
                assert_ne!(lhs, rhs);
            }
        }
    }

    /// Both the `Clone` and the `Copy` implementations must preserve the value.
    #[test]
    // The explicit `.clone()` on a `Copy` type is deliberate: it is the only way
    // to exercise the derived `Clone` impl rather than an implicit copy.
    #[allow(clippy::clone_on_copy)]
    fn test_condensation_method_clone_and_copy() {
        let method = CondensationMethod::Herding;
        let cloned = method.clone();
        let copied = method;
        assert_eq!(method, cloned);
        assert_eq!(method, copied);
    }

    /// A graph built from a struct literal exposes exactly what was put in.
    #[test]
    fn test_condensed_graph_creation() {
        // The inputs are moved straight into the struct; no clones are needed
        // because nothing reads them afterwards.
        let graph = CondensedGraph {
            adjacency: Array2::<f64>::zeros((3, 3)),
            features: Array2::<f64>::ones((3, 2)),
            labels: vec![0, 1, 0],
            source_mapping: vec![0, 5, 10],
        };
        assert_eq!(graph.adjacency.nrows(), 3);
        assert_eq!(graph.adjacency.ncols(), 3);
        assert_eq!(graph.features.nrows(), 3);
        assert_eq!(graph.features.ncols(), 2);
        assert_eq!(graph.labels, vec![0, 1, 0]);
        assert_eq!(graph.source_mapping, vec![0, 5, 10]);
    }

    /// Cloning a graph must duplicate every field.
    #[test]
    fn test_condensed_graph_clone() {
        let graph = CondensedGraph {
            adjacency: Array2::<f64>::eye(2),
            features: Array2::<f64>::ones((2, 3)),
            labels: vec![1, 2],
            source_mapping: vec![0, 1],
        };
        let cloned = graph.clone();
        assert_eq!(cloned.labels, graph.labels);
        assert_eq!(cloned.source_mapping, graph.source_mapping);
        assert_eq!(cloned.adjacency, graph.adjacency);
        assert_eq!(cloned.features, graph.features);
    }

    /// Metric fields hold exactly the values they were constructed with.
    #[test]
    fn test_quality_metrics_creation() {
        let metrics = QualityMetrics {
            degree_distribution_distance: 0.05,
            spectral_distance: 0.1,
            label_coverage: 0.95,
        };
        assert_close(metrics.degree_distribution_distance, 0.05);
        assert_close(metrics.spectral_distance, 0.1);
        assert_close(metrics.label_coverage, 0.95);
    }

    /// Zero distances and full coverage round-trip unchanged.
    #[test]
    fn test_quality_metrics_perfect() {
        let metrics = QualityMetrics {
            degree_distribution_distance: 0.0,
            spectral_distance: 0.0,
            label_coverage: 1.0,
        };
        assert_close(metrics.degree_distribution_distance, 0.0);
        assert_close(metrics.spectral_distance, 0.0);
        assert_close(metrics.label_coverage, 1.0);
    }

    /// A result aggregates the graph, the ratio and the metrics without altering them.
    #[test]
    fn test_condensation_result_creation() {
        let result = CondensationResult {
            condensed: CondensedGraph {
                adjacency: Array2::<f64>::zeros((2, 2)),
                features: Array2::<f64>::ones((2, 4)),
                labels: vec![0, 1],
                source_mapping: vec![3, 7],
            },
            compression_ratio: 5.0,
            quality_metrics: QualityMetrics {
                degree_distribution_distance: 0.1,
                spectral_distance: 0.2,
                label_coverage: 1.0,
            },
        };
        assert_close(result.compression_ratio, 5.0);
        assert_eq!(result.condensed.labels.len(), 2);
        assert_close(result.quality_metrics.label_coverage, 1.0);
    }

    /// A hand-written configuration keeps every custom value.
    #[test]
    fn test_config_custom() {
        let config = CondensationConfig {
            target_nodes: 50,
            method: CondensationMethod::GradientMatching,
            max_iterations: 500,
            learning_rate: 0.001,
        };
        assert_eq!(config.target_nodes, 50);
        assert_eq!(config.method, CondensationMethod::GradientMatching);
        assert_eq!(config.max_iterations, 500);
        assert_close(config.learning_rate, 0.001);
    }
}