use rand::prelude::*;
use serde::{Deserialize, Serialize};
use crate::graph::SparseGraph;
use crate::types::EdgeImportance;
/// Randomised edge sampler configured by a single `epsilon` parameter.
///
/// The sampler holds no other state; all inputs (edge scores, budget,
/// backbone set) are passed to the sampling methods.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SpectralSampler {
    /// Parameter that enters the sampling formulas as `1 / epsilon²`.
    ///
    /// It is stored unchecked by the constructor.
    pub epsilon: f64,
}
impl SpectralSampler {
    /// Creates a sampler with the given `epsilon`.
    ///
    /// The value is stored as-is; no validation is performed.
    pub fn new(epsilon: f64) -> Self {
        Self { epsilon }
    }

    /// Builds a sparse graph by keeping every backbone edge and randomly
    /// sampling the remaining edges in proportion to their importance score.
    ///
    /// # Arguments
    ///
    /// * `scores` - candidate edges with their `score` and `weight`.
    /// * `budget` - target number of edges; the number of entries in `scores`
    ///   that are backbone edges is subtracted (saturating at zero) to obtain
    ///   the budget used for random sampling.
    /// * `backbone_edges` - edges that are always kept with their original
    ///   weight. Keys are looked up in normalised `(min, max)` order.
    ///
    /// # Returns
    ///
    /// * An empty graph when `scores` is empty.
    /// * A graph holding only the backbone edges when the summed score is
    ///   `<= 0.0`.
    /// * Otherwise the backbone edges plus every non-backbone edge that was
    ///   drawn with probability `p`; a drawn edge is inserted with weight
    ///   `weight / p`, so its expected weight equals the original weight.
    ///
    /// The summed score used for normalisation covers *all* entries of
    /// `scores`, backbone edges included.
    pub fn sample_edges(
        &self,
        scores: &[EdgeImportance],
        budget: usize,
        backbone_edges: &std::collections::HashSet<(usize, usize)>,
    ) -> SparseGraph {
        if scores.is_empty() {
            return SparseGraph::new();
        }

        let total_importance: f64 = scores.iter().map(|s| s.score).sum();
        if total_importance <= 0.0 {
            return self.backbone_only_graph(scores, backbone_edges);
        }

        let n_vertices = Self::vertex_count(scores);
        let log_n = Self::log_vertex_count(n_vertices);

        let backbone_count = scores
            .iter()
            .filter(|s| Self::is_backbone(s, backbone_edges))
            .count();
        let sample_budget = budget.saturating_sub(backbone_count);
        let c = self.scaling_constant(total_importance, log_n, sample_budget);

        let mut rng = rand::thread_rng();
        let mut g = SparseGraph::with_capacity(n_vertices);
        for s in scores {
            if Self::is_backbone(s, backbone_edges) {
                // Backbone edges bypass sampling and keep their weight.
                let _ = g.insert_or_update_edge(s.u, s.v, s.weight);
                continue;
            }

            let p = self.inclusion_probability(c, s.score, log_n);
            if p >= 1.0 || rng.gen::<f64>() < p {
                // Reaching this branch implies 0 < p <= 1: `p` is never NaN
                // (see `inclusion_probability`) and the uniform draw is
                // non-negative, so the division is always well defined.
                let _ = g.insert_or_update_edge(s.u, s.v, s.weight / p);
            }
        }
        g
    }

    /// Makes the keep/drop decision for a single edge.
    ///
    /// Uses the same probability formula as [`Self::sample_edges`], but with
    /// caller-supplied `n_vertices`, `total_importance` and `budget`, and
    /// without any backbone handling. When `total_importance` is not
    /// positive the scaling constant falls back to `1.0` instead of
    /// short-circuiting.
    ///
    /// # Returns
    ///
    /// `Some((u, v, weight / p))` when the edge is kept, `None` otherwise.
    pub fn sample_single_edge(
        &self,
        importance: &EdgeImportance,
        n_vertices: usize,
        total_importance: f64,
        budget: usize,
    ) -> Option<(usize, usize, f64)> {
        let log_n = Self::log_vertex_count(n_vertices);
        let c = self.scaling_constant(total_importance, log_n, budget);
        let p = self.inclusion_probability(c, importance.score, log_n);

        let mut rng = rand::thread_rng();
        if p >= 1.0 || rng.gen::<f64>() < p {
            // As in `sample_edges`, a kept edge always has 0 < p <= 1.
            Some((importance.u, importance.v, importance.weight / p))
        } else {
            None
        }
    }

    /// Normalises an undirected edge to `(min, max)` endpoint order.
    fn edge_key(u: usize, v: usize) -> (usize, usize) {
        if u <= v { (u, v) } else { (v, u) }
    }

    /// Returns `true` when the edge's normalised key is in `backbone_edges`.
    fn is_backbone(
        s: &EdgeImportance,
        backbone_edges: &std::collections::HashSet<(usize, usize)>,
    ) -> bool {
        backbone_edges.contains(&Self::edge_key(s.u, s.v))
    }

    /// Returns one more than the largest endpoint index in `scores`,
    /// or `0` when `scores` is empty.
    fn vertex_count(scores: &[EdgeImportance]) -> usize {
        scores
            .iter()
            .map(|s| s.u.max(s.v) + 1)
            .max()
            .unwrap_or(0)
    }

    /// Natural logarithm of the vertex count, floored at `1.0`.
    ///
    /// The floor also covers `n_vertices == 0`, where `ln` yields `-inf`.
    fn log_vertex_count(n_vertices: usize) -> f64 {
        (n_vertices as f64).ln().max(1.0)
    }

    /// Scaling constant `c = budget / (total_importance * log_n / epsilon²)`.
    ///
    /// Falls back to `1.0` when `total_importance` is not strictly positive
    /// (which includes NaN).
    fn scaling_constant(&self, total_importance: f64, log_n: f64, budget: usize) -> f64 {
        if total_importance > 0.0 {
            let eps_sq = self.epsilon * self.epsilon;
            budget as f64 / (total_importance * log_n / eps_sq)
        } else {
            1.0
        }
    }

    /// Inclusion probability `min(1, c * score * log_n / epsilon²)`.
    ///
    /// With `c` taken from [`Self::scaling_constant`] the `log_n / epsilon²`
    /// factor cancels algebraically, so the result is approximately
    /// `budget * score / total_importance`, capped at `1.0`.
    ///
    /// `f64::min` returns the non-NaN operand, so a NaN intermediate yields
    /// `1.0` and the result is never NaN.
    ///
    /// NOTE(review): `epsilon == 0.0` produces such a NaN (`0.0 / 0.0`),
    /// which means every edge is kept regardless of budget. Confirm whether
    /// callers can pass a zero epsilon.
    fn inclusion_probability(&self, c: f64, score: f64, log_n: f64) -> f64 {
        let eps_sq = self.epsilon * self.epsilon;
        (c * score * log_n / eps_sq).min(1.0)
    }

    /// Builds a graph that contains only the entries of `scores` whose
    /// normalised key is in `backbone_edges`, each with its original weight.
    fn backbone_only_graph(
        &self,
        scores: &[EdgeImportance],
        backbone_edges: &std::collections::HashSet<(usize, usize)>,
    ) -> SparseGraph {
        let mut g = SparseGraph::with_capacity(Self::vertex_count(scores));
        for s in scores {
            if Self::is_backbone(s, backbone_edges) {
                let _ = g.insert_or_update_edge(s.u, s.v, s.weight);
            }
        }
        g
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// A backbone edge must always be present in the sampled graph.
    #[test]
    fn test_sample_with_backbone() {
        let edge_scores = vec![
            EdgeImportance::new(0, 1, 1.0, 1.0),
            EdgeImportance::new(1, 2, 1.0, 1.0),
            EdgeImportance::new(0, 2, 1.0, 1.0),
        ];
        let backbone: HashSet<(usize, usize)> = HashSet::from([(0, 1)]);

        let sampled = SpectralSampler::new(0.2).sample_edges(&edge_scores, 10, &backbone);

        assert!(sampled.has_edge(0, 1));
    }

    /// Sampling from an empty score list yields a graph with no edges.
    #[test]
    fn test_sample_empty() {
        let no_backbone: HashSet<(usize, usize)> = HashSet::new();

        let sampled = SpectralSampler::new(0.2).sample_edges(&[], 10, &no_backbone);

        assert_eq!(sampled.num_edges(), 0);
    }

    /// With a budget far above the edge count, at least one edge survives.
    #[test]
    fn test_high_budget_keeps_all() {
        let edge_scores = vec![
            EdgeImportance::new(0, 1, 1.0, 10.0),
            EdgeImportance::new(1, 2, 1.0, 10.0),
        ];
        let no_backbone: HashSet<(usize, usize)> = HashSet::new();

        let sampled = SpectralSampler::new(0.01).sample_edges(&edge_scores, 1000, &no_backbone);

        assert!(sampled.num_edges() >= 1);
    }
}