use scirs2_core::ndarray::{Array1, Array2, Axis};
use crate::error::{GraphError, Result};
use super::types::{CondensationConfig, CondensedGraph};
/// Condense a graph to `config.target_nodes` synthetic nodes via gradient matching.
///
/// Synthetic nodes are seeded from a class-balanced sample of the original nodes,
/// then their features and adjacency are iteratively refined so that the GCN
/// weight gradients computed on the synthetic graph track those computed on the
/// original graph.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` if `adj` is not square, if `features`
/// or `labels` do not have one entry per node, or if `target_nodes` is 0 or
/// larger than the original node count.
pub fn gradient_matching_condense(
    adj: &Array2<f64>,
    features: &Array2<f64>,
    labels: &[usize],
    config: &CondensationConfig,
) -> Result<CondensedGraph> {
    let n = adj.nrows();
    let d = features.ncols();
    let k = config.target_nodes;
    validate_distillation_inputs(adj, features, labels, k)?;
    let num_classes = count_classes(labels);

    // --- Seed synthetic nodes with a class-balanced sample of original nodes.
    let mut synth_features = Array2::<f64>::zeros((k, d));
    let mut synth_labels = Vec::with_capacity(k);
    let mut source_mapping = Vec::with_capacity(k);
    // Rough per-class quota; at least one slot per class.
    let per_class = (k / num_classes.max(1)).max(1);
    let mut class_counts = vec![0usize; num_classes];
    // O(1) membership flags instead of scanning `source_mapping` per candidate
    // (the original `contains` check made the top-up pass O(n * k)).
    let mut selected = vec![false; n];
    let mut filled = 0;
    for orig_idx in 0..n {
        if filled >= k {
            break;
        }
        let c = labels[orig_idx];
        if c < num_classes && class_counts[c] < per_class {
            for f in 0..d {
                synth_features[[filled, f]] = features[[orig_idx, f]];
            }
            synth_labels.push(c);
            source_mapping.push(orig_idx);
            selected[orig_idx] = true;
            class_counts[c] += 1;
            filled += 1;
        }
    }
    // Top up with any not-yet-selected nodes if the quota pass left slots empty.
    if filled < k {
        for orig_idx in 0..n {
            if filled >= k {
                break;
            }
            if !selected[orig_idx] {
                for f in 0..d {
                    synth_features[[filled, f]] = features[[orig_idx, f]];
                }
                synth_labels.push(labels[orig_idx]);
                source_mapping.push(orig_idx);
                selected[orig_idx] = true;
                filled += 1;
            }
        }
    }

    // --- Gradient-matching refinement loop (GCN weights `w` stay fixed).
    let mut synth_adj = build_initial_adjacency(&synth_features, k);
    let norm_adj_orig = normalise_adjacency(adj, n);
    let w = initialise_weight_matrix(d, num_classes);
    let lr = config.learning_rate;
    for _iter in 0..config.max_iterations {
        // Weight gradient on the original graph.
        let h_orig = gcn_forward(&norm_adj_orig, features, &w);
        let node_grad_orig = compute_gradient(&h_orig, labels, num_classes);
        let ax_orig = norm_adj_orig.dot(features);
        let w_grad_orig = ax_orig.t().dot(&node_grad_orig);
        // Weight gradient on the current synthetic graph.
        let norm_adj_synth = normalise_adjacency(&synth_adj, k);
        let h_synth = gcn_forward(&norm_adj_synth, &synth_features, &w);
        let node_grad_synth = compute_gradient(&h_synth, &synth_labels, num_classes);
        let ax_synth = norm_adj_synth.dot(&synth_features);
        let w_grad_synth = ax_synth.t().dot(&node_grad_synth);
        // Push synthetic features in the direction that closes the gradient gap.
        let w_grad_diff = &w_grad_orig - &w_grad_synth;
        let feature_update = norm_adj_synth
            .t()
            .dot(&node_grad_synth)
            .dot(&w_grad_diff.t());
        for i in 0..k.min(synth_features.nrows()) {
            for j in 0..d.min(synth_features.ncols()) {
                if j < feature_update.ncols() {
                    // Clamp each step to keep the update numerically stable.
                    let update = lr * feature_update[[i, j]];
                    let clamped = update.clamp(-1.0, 1.0);
                    synth_features[[i, j]] += clamped;
                }
            }
        }
        // Relax the synthetic adjacency towards current feature similarity.
        update_adjacency(&mut synth_adj, &synth_features, k, lr * 0.1);
    }
    Ok(CondensedGraph {
        adjacency: synth_adj,
        features: synth_features,
        labels: synth_labels,
        source_mapping,
    })
}
/// Squared Euclidean distance between the column-wise means of the original
/// and synthetic feature matrices.
///
/// Returns `0.0` when either matrix has zero rows (its mean is undefined).
pub fn feature_alignment_loss(orig_features: &Array2<f64>, synth_features: &Array2<f64>) -> f64 {
    let (Some(mu_orig), Some(mu_synth)) = (
        orig_features.mean_axis(Axis(0)),
        synth_features.mean_axis(Axis(0)),
    ) else {
        return 0.0;
    };
    let gap = &mu_orig - &mu_synth;
    gap.dot(&gap)
}
/// Distance between the degree sequences of two graphs.
///
/// Both sequences are sorted in descending order, zero-padded to a common
/// length, and normalised by their own sum (floored at 1e-12 to avoid division
/// by zero); the result is the Euclidean norm of their element-wise difference.
pub fn structure_matching_loss(orig_adj: &Array2<f64>, synth_adj: &Array2<f64>) -> f64 {
    let mut seq_a = degree_sequence(orig_adj);
    let mut seq_b = degree_sequence(synth_adj);
    // Descending sort; incomparable values (NaN) are treated as equal so the
    // comparator never panics.
    let descending = |x: &f64, y: &f64| y.partial_cmp(x).unwrap_or(std::cmp::Ordering::Equal);
    seq_a.sort_by(descending);
    seq_b.sort_by(descending);
    let len = seq_a.len().max(seq_b.len());
    seq_a.resize(len, 0.0);
    seq_b.resize(len, 0.0);
    let total_a = seq_a.iter().sum::<f64>().max(1e-12);
    let total_b = seq_b.iter().sum::<f64>().max(1e-12);
    seq_a
        .iter()
        .zip(&seq_b)
        .map(|(a, b)| {
            let delta = a / total_a - b / total_b;
            delta * delta
        })
        .sum::<f64>()
        .sqrt()
}
/// Symmetric GCN normalisation: `D^{-1/2} (A + I) D^{-1/2}` where `D` is the
/// degree matrix of `A` after self-loops are added.
///
/// Rows with zero degree keep an inverse-sqrt degree of 0, so isolated nodes
/// produce all-zero rows/columns in the result.
fn normalise_adjacency(adj: &Array2<f64>, n: usize) -> Array2<f64> {
    // Add self-loops.
    let mut with_loops = adj.clone();
    for i in 0..n {
        with_loops[[i, i]] += 1.0;
    }
    // Inverse square roots of the row degrees (0 for isolated rows).
    let inv_sqrt_deg = Array1::from_shape_fn(n, |i| {
        let deg: f64 = with_loops.row(i).sum();
        if deg > 0.0 {
            1.0 / deg.sqrt()
        } else {
            0.0
        }
    });
    Array2::from_shape_fn((n, n), |(i, j)| {
        inv_sqrt_deg[i] * with_loops[[i, j]] * inv_sqrt_deg[j]
    })
}
/// One GCN layer forward pass: `ReLU(A_hat · X · W)`.
fn gcn_forward(norm_adj: &Array2<f64>, features: &Array2<f64>, w: &Array2<f64>) -> Array2<f64> {
    norm_adj.dot(features).dot(w).mapv(|v| v.max(0.0))
}
fn compute_gradient(logits: &Array2<f64>, labels: &[usize], num_classes: usize) -> Array2<f64> {
let n = logits.nrows();
let c = logits.ncols().min(num_classes);
let mut probs = Array2::<f64>::zeros((n, c));
for i in 0..n {
let max_val = (0..c)
.map(|j| logits[[i, j]])
.fold(f64::NEG_INFINITY, f64::max);
let mut sum_exp = 0.0;
for j in 0..c {
let e = (logits[[i, j]] - max_val).exp();
probs[[i, j]] = e;
sum_exp += e;
}
if sum_exp > 0.0 {
for j in 0..c {
probs[[i, j]] /= sum_exp;
}
}
}
let mut grad = probs;
for i in 0..n {
let label = labels.get(i).copied().unwrap_or(0);
if label < c {
grad[[i, label]] -= 1.0;
}
}
let n_f64 = n as f64;
if n_f64 > 0.0 {
grad /= n_f64;
}
grad
}
/// Deterministic Xavier-style weight initialisation.
///
/// Each entry is a bounded pseudo-random value derived from its index pair
/// (via a small modular hash), scaled by `sqrt(2 / (fan_in + fan_out))`.
/// Deterministic so condensation runs are reproducible without an RNG.
fn initialise_weight_matrix(d: usize, num_classes: usize) -> Array2<f64> {
    let scale = (2.0 / (d + num_classes) as f64).sqrt();
    Array2::from_shape_fn((d, num_classes), |(i, j)| {
        let seed = (i * 7 + j * 13 + 3) as f64;
        ((seed % 17.0 - 8.0) / 17.0) * scale
    })
}
/// Build a symmetric k-nearest-neighbour adjacency (unit edge weights) over
/// the synthetic features, linking each node to its `min(3, k-1)` closest
/// neighbours by squared Euclidean distance.
fn build_initial_adjacency(features: &Array2<f64>, k: usize) -> Array2<f64> {
    let mut adj = Array2::<f64>::zeros((k, k));
    let knn = 3.min(k.saturating_sub(1));
    for i in 0..k {
        let row_i = features.row(i);
        // Squared distances to every other node, accumulated directly instead
        // of allocating owned row copies and a difference array per pair.
        let mut dists: Vec<(usize, f64)> = (0..k)
            .filter(|&j| j != i)
            .map(|j| {
                let d2: f64 = row_i
                    .iter()
                    .zip(features.row(j).iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum();
                (j, d2)
            })
            .collect();
        // NaN distances compare as equal rather than panicking the sort.
        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        for &(j, _) in dists.iter().take(knn) {
            adj[[i, j]] = 1.0;
            adj[[j, i]] = 1.0;
        }
    }
    adj
}
/// Nudge each edge weight towards the RBF feature similarity
/// `exp(-||x_i - x_j||^2)` of its endpoints, scaled by `lr`, keeping the
/// matrix symmetric and every weight clamped to `[0, 1]`.
fn update_adjacency(adj: &mut Array2<f64>, features: &Array2<f64>, k: usize, lr: f64) {
    for i in 0..k {
        for j in (i + 1)..k {
            // Squared distance computed without the per-pair owned-row and
            // difference-array allocations the naive `row - row` form incurs.
            let d2: f64 = features
                .row(i)
                .iter()
                .zip(features.row(j).iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            let sim = (-d2).exp();
            let delta = lr * (sim - adj[[i, j]]);
            let weight = (adj[[i, j]] + delta).clamp(0.0, 1.0);
            adj[[i, j]] = weight;
            adj[[j, i]] = weight;
        }
    }
}
/// Number of classes implied by the label set: `max(label) + 1`, or 0 for an
/// empty slice (labels are assumed to be 0-based and dense).
fn count_classes(labels: &[usize]) -> usize {
    labels.iter().map(|&label| label + 1).max().unwrap_or(0)
}
fn degree_sequence(adj: &Array2<f64>) -> Vec<f64> {
let n = adj.nrows();
(0..n).map(|i| adj.row(i).sum()).collect()
}
/// Check that the distillation inputs are mutually consistent.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` when `adj` is not square, when
/// `features` or `labels` do not have one entry per node, or when
/// `target_nodes` is 0 or exceeds the node count.
fn validate_distillation_inputs(
    adj: &Array2<f64>,
    features: &Array2<f64>,
    labels: &[usize],
    target_nodes: usize,
) -> Result<()> {
    let n = adj.nrows();
    // Shared error constructor so each check stays a one-glance guard clause.
    let invalid = |param: &str, value: String, expected: String, context: &str| {
        Err(GraphError::InvalidParameter {
            param: param.to_string(),
            value,
            expected,
            context: context.to_string(),
        })
    };
    if adj.nrows() != adj.ncols() {
        return invalid(
            "adj",
            format!("{}x{}", adj.nrows(), adj.ncols()),
            "square matrix".to_string(),
            "gradient_matching_condense",
        );
    }
    if features.nrows() != n {
        return invalid(
            "features",
            format!("{} rows", features.nrows()),
            format!("{n} rows"),
            "gradient_matching_condense: features must match adjacency",
        );
    }
    if labels.len() != n {
        return invalid(
            "labels",
            format!("length {}", labels.len()),
            format!("length {n}"),
            "gradient_matching_condense: labels must match adjacency",
        );
    }
    if target_nodes == 0 {
        return invalid(
            "target_nodes",
            "0".to_string(),
            "target_nodes > 0".to_string(),
            "gradient_matching_condense",
        );
    }
    if target_nodes > n {
        return invalid(
            "target_nodes",
            target_nodes.to_string(),
            format!("target_nodes <= {n}"),
            "gradient_matching_condense: cannot condense to more nodes than original",
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::condensation::types::{CondensationConfig, CondensationMethod};
/// Deterministic fixture: a path graph of `n` nodes whose features form two
/// well-separated clusters — the first half (label 0) near feature value 0,
/// the second half (label 1) near feature value 5.
fn simple_graph(n: usize, d: usize) -> (Array2<f64>, Array2<f64>, Vec<usize>) {
let mut adj = Array2::<f64>::zeros((n, n));
let mut features = Array2::<f64>::zeros((n, d));
let mut labels = vec![0usize; n];
// Chain the nodes into a path: i <-> i+1.
for i in 0..(n - 1) {
adj[[i, i + 1]] = 1.0;
adj[[i + 1, i]] = 1.0;
}
let half = n / 2;
for i in 0..n {
if i < half {
features[[i, 0]] = i as f64 * 0.1;
if d > 1 {
features[[i, 1]] = 0.0;
}
labels[i] = 0;
} else {
features[[i, 0]] = 5.0 + (i - half) as f64 * 0.1;
if d > 1 {
features[[i, 1]] = 5.0;
}
labels[i] = 1;
}
}
(adj, features, labels)
}
// All output dimensions must match the requested target_nodes / feature width.
#[test]
fn test_gradient_matching_produces_valid_output() {
let (adj, features, labels) = simple_graph(10, 3);
let config = CondensationConfig {
target_nodes: 4,
method: CondensationMethod::GradientMatching,
max_iterations: 10,
learning_rate: 0.01,
};
let result = gradient_matching_condense(&adj, &features, &labels, &config)
.expect("gradient_matching_condense should succeed");
assert_eq!(result.adjacency.nrows(), 4);
assert_eq!(result.adjacency.ncols(), 4);
assert_eq!(result.features.nrows(), 4);
assert_eq!(result.features.ncols(), 3);
assert_eq!(result.labels.len(), 4);
assert_eq!(result.source_mapping.len(), 4);
}
// Class-balanced seeding should keep both classes in the condensed graph.
#[test]
fn test_gradient_matching_covers_classes() {
let (adj, features, labels) = simple_graph(10, 3);
let config = CondensationConfig {
target_nodes: 4,
method: CondensationMethod::GradientMatching,
max_iterations: 5,
learning_rate: 0.01,
};
let result = gradient_matching_condense(&adj, &features, &labels, &config)
.expect("gradient_matching_condense should succeed");
let has_class0 = result.labels.contains(&0);
let has_class1 = result.labels.contains(&1);
assert!(has_class0, "class 0 should be in condensed graph");
assert!(has_class1, "class 1 should be in condensed graph");
}
// Sanity check, not strict monotonicity: more refinement iterations should not
// make the structural loss dramatically worse (hence the +0.5 slack).
#[test]
fn test_gradient_matching_loss_decreases() {
let (adj, features, labels) = simple_graph(12, 4);
let config_few = CondensationConfig {
target_nodes: 4,
method: CondensationMethod::GradientMatching,
max_iterations: 2,
learning_rate: 0.01,
};
let result_few = gradient_matching_condense(&adj, &features, &labels, &config_few)
.expect("should succeed with few iterations");
let config_many = CondensationConfig {
target_nodes: 4,
method: CondensationMethod::GradientMatching,
max_iterations: 50,
learning_rate: 0.01,
};
let result_many = gradient_matching_condense(&adj, &features, &labels, &config_many)
.expect("should succeed with many iterations");
let loss_few = structure_matching_loss(&adj, &result_few.adjacency);
let loss_many = structure_matching_loss(&adj, &result_many.adjacency);
assert!(
loss_many < loss_few + 0.5,
"more iterations should not dramatically increase loss: few={loss_few}, many={loss_many}"
);
}
// target_nodes == 0 must be rejected by input validation.
#[test]
fn test_gradient_matching_error_target_zero() {
let (adj, features, labels) = simple_graph(6, 2);
let config = CondensationConfig {
target_nodes: 0,
method: CondensationMethod::GradientMatching,
max_iterations: 5,
learning_rate: 0.01,
};
let result = gradient_matching_condense(&adj, &features, &labels, &config);
assert!(result.is_err());
}
// target_nodes > n must be rejected by input validation.
#[test]
fn test_gradient_matching_error_target_too_large() {
let (adj, features, labels) = simple_graph(6, 2);
let config = CondensationConfig {
target_nodes: 100,
method: CondensationMethod::GradientMatching,
max_iterations: 5,
learning_rate: 0.01,
};
let result = gradient_matching_condense(&adj, &features, &labels, &config);
assert!(result.is_err());
}
#[test]
fn test_feature_alignment_loss_identical() {
let features = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.expect("valid shape");
let loss = feature_alignment_loss(&features, &features);
assert!(
loss.abs() < 1e-12,
"feature alignment loss for identical features should be 0, got {loss}"
);
}
#[test]
fn test_feature_alignment_loss_different() {
let orig = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 0.0, 0.0]).expect("valid shape");
let synth =
Array2::from_shape_vec((2, 2), vec![10.0, 10.0, 10.0, 10.0]).expect("valid shape");
let loss = feature_alignment_loss(&orig, &synth);
assert!(
loss > 100.0,
"feature alignment loss for distant features should be large, got {loss}"
);
}
#[test]
fn test_structure_matching_loss_identical() {
let adj = Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0])
.expect("valid shape");
let loss = structure_matching_loss(&adj, &adj);
assert!(
loss.abs() < 1e-12,
"structure matching loss for identical adjacency should be 0, got {loss}"
);
}
// Complete triangle vs empty graph: degree sequences differ, loss must be > 0.
#[test]
fn test_structure_matching_loss_different() {
let complete =
Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0])
.expect("valid shape");
let empty = Array2::<f64>::zeros((3, 3));
let loss = structure_matching_loss(&complete, &empty);
assert!(
loss > 0.0,
"structure matching loss for different graphs should be positive, got {loss}"
);
}
}