use std::collections::{HashMap, HashSet, VecDeque};
use scirs2_core::random::{Rng, RngExt, SeedableRng, StdRng};
use crate::diffusion::models::AdjList;
use crate::error::{GraphError, Result};
/// A reverse-reachable (RR) set: the ids of the nodes reached by one reverse sampling
/// walk from a uniformly chosen root (produced by `generate_one_rr_set`).
pub type RRSet = Vec<usize>;
/// Configuration for plain reverse influence sampling (RIS).
#[derive(Clone, Debug)]
pub struct RISConfig {
    /// Number of RR sets to sample (must be at least 1 for `generate_rr_sets`).
    pub num_rr_sets: usize,
    /// RNG seed for reproducible sampling; `None` draws a fresh seed from
    /// `scirs2_core::random::rng()`.
    pub seed: Option<u64>,
}

impl Default for RISConfig {
    /// 10 000 RR sets with an unseeded RNG.
    fn default() -> Self {
        Self {
            seed: None,
            num_rr_sets: 10_000,
        }
    }
}
/// Parameters for `imm_algorithm`.
#[derive(Clone, Debug)]
pub struct ImmConfig {
    /// Number of seed nodes to select.
    pub k: usize,
    /// Approximation parameter; `imm_algorithm` checks it against `(0, 1)`.
    pub epsilon: f64,
    /// Confidence parameter; enters the sample-size bound as `ln(1 / delta)` and is
    /// checked against `(0, 1)` by `imm_algorithm`.
    pub delta: f64,
    /// RNG seed for reproducible sampling; `None` draws a fresh seed.
    pub seed: Option<u64>,
}

impl Default for ImmConfig {
    /// `k = 5`, `epsilon = 0.1`, `delta = 0.01`, unseeded.
    fn default() -> Self {
        Self {
            seed: None,
            delta: 0.01,
            epsilon: 0.1,
            k: 5,
        }
    }
}
/// Output of `imm_algorithm`.
#[derive(Clone, Debug)]
pub struct ImmResult {
    /// Selected seed node ids, in the order greedy max-coverage picked them.
    pub seeds: Vec<usize>,
    /// RIS spread estimate: `num_nodes * covered_rr_sets / num_rr_sets`.
    pub estimated_spread: f64,
    /// Total number of RR sets sampled.
    pub num_rr_sets: usize,
}
/// Builds the transpose of `adjacency`.
///
/// Every edge `src -> tgt` with probability `p` becomes an entry `(src, p)` in the list
/// keyed by `tgt`. Nodes without incoming edges get no key in the result.
fn reverse_adj(adjacency: &AdjList) -> AdjList {
    let mut transposed: AdjList = HashMap::new();
    let edges = adjacency
        .iter()
        .flat_map(|(&src, nbrs)| nbrs.iter().map(move |&(tgt, p)| (tgt, src, p)));
    for (tgt, src, p) in edges {
        transposed.entry(tgt).or_default().push((src, p));
    }
    transposed
}
/// Samples one reverse-reachable (RR) set.
///
/// A root is drawn uniformly from `0..num_nodes`. A BFS then walks *incoming* edges
/// (via the transposed graph `rev_adj`), keeping each edge `(src, prob)` independently
/// with probability `prob`. Every reached node appears exactly once in the result, root
/// first, in BFS discovery order. The order is therefore fully determined by the RNG
/// stream, so seeded runs reproduce identical vectors.
///
/// Requires `num_nodes > 0` (the root range would be empty); every caller in this module
/// validates that before calling.
fn generate_one_rr_set(rev_adj: &AdjList, num_nodes: usize, rng: &mut impl Rng) -> RRSet {
    let root: usize = rng.random_range(0..num_nodes);
    // `visited` answers membership queries; `rr_set` records discovery order. Collecting
    // the HashSet itself would yield a per-process random order (std `RandomState`),
    // breaking reproducibility of seeded runs.
    let mut visited: HashSet<usize> = HashSet::new();
    let mut rr_set: RRSet = Vec::new();
    let mut queue: VecDeque<usize> = VecDeque::new();
    visited.insert(root);
    rr_set.push(root);
    queue.push_back(root);
    while let Some(node) = queue.pop_front() {
        if let Some(in_nbrs) = rev_adj.get(&node) {
            for &(src, prob) in in_nbrs {
                // Membership is checked first so the RNG is consulted only for
                // not-yet-reached nodes (same RNG consumption as before).
                if !visited.contains(&src) && rng.random::<f64>() < prob {
                    visited.insert(src);
                    rr_set.push(src);
                    queue.push_back(src);
                }
            }
        }
    }
    rr_set
}
/// Greedy maximum coverage over a collection of RR sets.
///
/// Selects up to `k` *distinct* nodes. Each round picks the not-yet-selected node that
/// appears in the most still-uncovered RR sets; ties go to the highest node index, which
/// is `Iterator::max_by_key`'s tie rule. Node ids `>= num_nodes` inside an RR set are
/// ignored.
///
/// Returns `(seeds, num_covered)`, where `num_covered` is the number of RR sets that
/// contain at least one selected seed. Fewer than `k` seeds are returned only when
/// `k > num_nodes`.
fn greedy_max_coverage(rr_sets: &[RRSet], num_nodes: usize, k: usize) -> (Vec<usize>, usize) {
    // Inverted index: node -> indices of the RR sets containing it.
    let mut node_to_rr: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
    for (i, rr) in rr_sets.iter().enumerate() {
        for &node in rr {
            if node < num_nodes {
                node_to_rr[node].push(i);
            }
        }
    }
    let mut covered: Vec<bool> = vec![false; rr_sets.len()];
    let mut selected: Vec<bool> = vec![false; num_nodes];
    // coverage[n] = number of still-uncovered RR sets that contain n (marginal gain).
    let mut coverage: Vec<usize> = node_to_rr.iter().map(Vec::len).collect();
    let mut seeds: Vec<usize> = Vec::with_capacity(k.min(num_nodes));
    let mut num_covered = 0;
    for _ in 0..k {
        // Already-chosen nodes are excluded: once every RR set is covered all gains are
        // zero, and an unfiltered argmax would keep returning the same node.
        let best = match (0..num_nodes)
            .filter(|&n| !selected[n])
            .max_by_key(|&n| coverage[n])
        {
            Some(n) => n,
            None => break,
        };
        selected[best] = true;
        seeds.push(best);
        for &rr_idx in &node_to_rr[best] {
            if !covered[rr_idx] {
                covered[rr_idx] = true;
                num_covered += 1;
                // This RR set no longer contributes to anyone's marginal gain.
                for &other in &rr_sets[rr_idx] {
                    if other < num_nodes && coverage[other] > 0 {
                        coverage[other] -= 1;
                    }
                }
            }
        }
    }
    (seeds, num_covered)
}
/// Samples `config.num_rr_sets` reverse-reachable sets from `adjacency`.
///
/// The graph is transposed once and reused for every sample. With `config.seed` set, the
/// RNG is seeded deterministically via `StdRng::seed_from_u64`.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` when `num_nodes` is zero (checked first) or
/// when `config.num_rr_sets` is zero.
pub fn generate_rr_sets(
    adjacency: &AdjList,
    num_nodes: usize,
    config: &RISConfig,
) -> Result<Vec<RRSet>> {
    // Both validation failures share everything except the parameter name.
    let zero_param = |param: &str| GraphError::InvalidParameter {
        param: param.to_string(),
        value: "0".to_string(),
        expected: ">= 1".to_string(),
        context: "generate_rr_sets".to_string(),
    };
    if num_nodes == 0 {
        return Err(zero_param("num_nodes"));
    }
    if config.num_rr_sets == 0 {
        return Err(zero_param("num_rr_sets"));
    }
    let mut rng: StdRng = config.seed.map_or_else(
        || StdRng::from_rng(&mut scirs2_core::random::rng()),
        StdRng::seed_from_u64,
    );
    let transposed = reverse_adj(adjacency);
    let samples: Vec<RRSet> = (0..config.num_rr_sets)
        .map(|_| generate_one_rr_set(&transposed, num_nodes, &mut rng))
        .collect();
    Ok(samples)
}
/// Estimates the expected spread of `seeds` from pre-sampled RR sets.
///
/// The estimate is `num_nodes * hits / rr_sets.len()`, where `hits` counts the RR sets
/// that contain at least one seed. Duplicate seeds are harmless, and an empty seed slice
/// yields `0.0`.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` if `rr_sets` is empty.
pub fn ris_estimate(rr_sets: &[RRSet], seeds: &[usize], num_nodes: usize) -> Result<f64> {
    if rr_sets.is_empty() {
        return Err(GraphError::InvalidParameter {
            param: "rr_sets".to_string(),
            value: "empty".to_string(),
            expected: "non-empty RR set collection".to_string(),
            context: "ris_estimate".to_string(),
        });
    }
    let lookup: HashSet<usize> = seeds.iter().copied().collect();
    let mut hits: usize = 0;
    for rr in rr_sets {
        if rr.iter().any(|node| lookup.contains(node)) {
            hits += 1;
        }
    }
    let covered_fraction = hits as f64 / rr_sets.len() as f64;
    Ok(covered_fraction * num_nodes as f64)
}
/// Selects `config.k` seed nodes with an IMM-style reverse-influence-sampling loop.
///
/// RR sets are sampled in doubling rounds, with `theta_i = lambda' * 2^(i-1) / n`, for
/// at most `ceil(log2 n) + 1` rounds. After each round greedy max-coverage is run, and
/// sampling stops early once the `compute_epsilon_prime`-based check passes. Final seeds
/// come from greedy max-coverage over all sampled sets. The spread estimate is
/// `n * covered / total`.
///
/// NOTE(review): the round schedule and stopping rule differ from the published IMM
/// algorithm (there is no final `theta = lambda* / LB` sampling phase). Confirm before
/// relying on the textbook `(1 - 1/e - epsilon)` guarantee.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` if `num_nodes == 0`, `k > num_nodes`, or
/// `epsilon` / `delta` is not strictly inside `(0, 1)` (NaN included). `k == 0` is not
/// an error and yields an empty result.
pub fn imm_algorithm(
    adjacency: &AdjList,
    num_nodes: usize,
    config: &ImmConfig,
) -> Result<ImmResult> {
    if num_nodes == 0 {
        return Err(GraphError::InvalidParameter {
            param: "num_nodes".to_string(),
            value: "0".to_string(),
            expected: ">= 1".to_string(),
            context: "imm_algorithm".to_string(),
        });
    }
    if config.k == 0 {
        return Ok(ImmResult {
            seeds: Vec::new(),
            estimated_spread: 0.0,
            num_rr_sets: 0,
        });
    }
    if config.k > num_nodes {
        return Err(GraphError::InvalidParameter {
            param: "k".to_string(),
            value: config.k.to_string(),
            expected: format!("<= num_nodes={num_nodes}"),
            context: "imm_algorithm".to_string(),
        });
    }
    // epsilon and delta must lie strictly inside (0, 1). A zero epsilon divides by zero
    // in `lambda_prime`, and a zero delta makes ln(1 / delta) infinite. Either way
    // `theta_i` saturates to usize::MAX and the sampling loop never finishes. Explicit
    // comparisons also reject NaN.
    let in_open_unit = |x: f64| x > 0.0 && x < 1.0;
    if !in_open_unit(config.epsilon) {
        return Err(GraphError::InvalidParameter {
            param: "epsilon".to_string(),
            value: config.epsilon.to_string(),
            expected: "(0, 1)".to_string(),
            context: "imm_algorithm".to_string(),
        });
    }
    if !in_open_unit(config.delta) {
        return Err(GraphError::InvalidParameter {
            param: "delta".to_string(),
            value: config.delta.to_string(),
            expected: "(0, 1)".to_string(),
            context: "imm_algorithm".to_string(),
        });
    }
    let n = num_nodes as f64;
    let k = config.k;
    let k_f = k as f64;
    let eps = config.epsilon;
    let delta = config.delta;
    let ell = (1.0_f64 / delta).ln();
    let log_n = n.ln();
    // k * ln(n / k) = ln((n / k)^k), a lower bound on ln C(n, k); k >= 1 at this point.
    let log_cnk = k_f * (n / k_f).ln();
    // ceil(log2 n) + 1 doubling rounds.
    let max_iters = (log_n / 2.0_f64.ln()).ceil() as usize + 1;
    let mut rng: StdRng = match config.seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => StdRng::from_rng(&mut scirs2_core::random::rng()),
    };
    let rev = reverse_adj(adjacency);
    let mut rr_sets: Vec<RRSet> = Vec::new();
    // Strictly positive with the validated parameters, so every round samples >= 1 set
    // and the coverage fractions below never divide by zero.
    let lambda_prime =
        (8.0 + 2.0 * eps) * n * (ell * log_n + log_cnk + (2.0_f64).ln()) / (eps * eps);
    for i in 1..=max_iters {
        // theta_i = lambda' / (n / 2^(i-1)): the RR-set budget doubles every round.
        let theta_i = (lambda_prime / (n / 2.0_f64.powi(i as i32 - 1))).ceil() as usize;
        while rr_sets.len() < theta_i {
            rr_sets.push(generate_one_rr_set(&rev, num_nodes, &mut rng));
        }
        let (_, num_covered) = greedy_max_coverage(&rr_sets, num_nodes, k);
        let frac = num_covered as f64 / rr_sets.len() as f64;
        let eps_star = compute_epsilon_prime(n, k, ell, frac * n, rr_sets.len());
        if frac - eps_star >= (1.0 - 1.0 / std::f64::consts::E - eps) * frac {
            break;
        }
    }
    let total_rr = rr_sets.len();
    let (seeds, num_covered) = greedy_max_coverage(&rr_sets, num_nodes, k);
    let estimated_spread = num_covered as f64 / total_rr as f64 * n;
    Ok(ImmResult {
        seeds,
        estimated_spread,
        num_rr_sets: total_rr,
    })
}
/// Error term used by `imm_algorithm`'s early-stopping test.
///
/// Computes `sqrt(2 * 1.1 * (ell + ln 6 + k ln(n / k)) * n / (spread * num_rr))`,
/// clamped to at most `1.0`. Returns `1.0` outright when `spread < 1` or `num_rr == 0`.
/// `k` must be at least 1, since `k = 0` makes `n / k` infinite; the only caller
/// guarantees this.
fn compute_epsilon_prime(n: f64, k: usize, ell: f64, spread: f64, num_rr: usize) -> f64 {
    let degenerate = num_rr == 0 || spread < 1.0;
    if degenerate {
        return 1.0;
    }
    let kf = k as f64;
    let log_binom_bound = kf * (n / kf).ln();
    let log_term = ell + 6.0_f64.ln() + log_binom_bound;
    let numerator = 2.0 * (1.0 + 0.1) * log_term * n;
    let denominator = spread * num_rr as f64;
    (numerator / denominator).sqrt().min(1.0)
}
/// Runs greedy max-coverage on two RR-set samples of different sizes.
///
/// The "lower" result uses `config.num_rr_sets` RR sets from `generate_rr_sets`. The
/// "upper" result uses those same sets plus `config.num_rr_sets` more, drawn from an
/// independently seeded RNG (`seed + 0xDEAD_BEEF` when a seed is given). Each spread is
/// `num_nodes * covered / total`.
///
/// NOTE(review): both values are RIS estimates of the same quantity at different sample
/// sizes; neither is a certified lower/upper bound as in the sandwich approximation
/// literature. Confirm that this is the intended meaning before relying on the naming.
///
/// Returns `(lower_seeds, lower_spread, upper_seeds, upper_spread)`; `k == 0` yields
/// empty seeds and zero spreads.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` if `num_nodes == 0` or `k > num_nodes`, plus
/// any error from `generate_rr_sets` (e.g. `config.num_rr_sets == 0`).
pub fn sandwich_approximation(
    adjacency: &AdjList,
    num_nodes: usize,
    k: usize,
    config: &RISConfig,
) -> Result<(Vec<usize>, f64, Vec<usize>, f64)> {
    if num_nodes == 0 {
        return Err(GraphError::InvalidParameter {
            param: "num_nodes".to_string(),
            value: "0".to_string(),
            expected: ">= 1".to_string(),
            context: "sandwich_approximation".to_string(),
        });
    }
    if k == 0 {
        return Ok((Vec::new(), 0.0, Vec::new(), 0.0));
    }
    if k > num_nodes {
        return Err(GraphError::InvalidParameter {
            param: "k".to_string(),
            value: k.to_string(),
            expected: format!("<= num_nodes={num_nodes}"),
            context: "sandwich_approximation".to_string(),
        });
    }
    let rr_sets = generate_rr_sets(adjacency, num_nodes, config)?;
    let (lower_seeds, lower_covered) = greedy_max_coverage(&rr_sets, num_nodes, k);
    let lower_spread = lower_covered as f64 / rr_sets.len() as f64 * num_nodes as f64;
    // Offset the seed so the second batch is not a replay of the first.
    let mut rng: StdRng = match config.seed {
        Some(s) => StdRng::seed_from_u64(s.wrapping_add(0xDEAD_BEEF)),
        None => StdRng::from_rng(&mut scirs2_core::random::rng()),
    };
    let rev = reverse_adj(adjacency);
    // The first batch is no longer needed on its own, so extend it in place instead of
    // cloning every RR set.
    let mut upper_rr = rr_sets;
    upper_rr.extend(
        (0..config.num_rr_sets).map(|_| generate_one_rr_set(&rev, num_nodes, &mut rng)),
    );
    let (upper_seeds, upper_covered) = greedy_max_coverage(&upper_rr, num_nodes, k);
    let upper_spread = upper_covered as f64 / upper_rr.len() as f64 * num_nodes as f64;
    Ok((lower_seeds, lower_spread, upper_seeds, upper_spread))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Directed star: hub `0` has an edge with probability `p` to each of `1..n`.
    fn star_adj(n: usize, p: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 1..n {
            adj.entry(0).or_default().push((i, p));
        }
        adj
    }

    /// Complete directed graph on `n` nodes (no self-loops), every edge with probability `p`.
    fn complete_adj(n: usize, p: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    adj.entry(i).or_default().push((j, p));
                }
            }
        }
        adj
    }

    #[test]
    fn test_generate_rr_sets_basic() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 50,
            seed: Some(42),
        };
        let rr = generate_rr_sets(&adj, 6, &config).expect("rr sets");
        assert_eq!(rr.len(), 50);
        for r in &rr {
            assert!(!r.is_empty());
        }
    }

    #[test]
    fn test_generate_rr_sets_invalid_params() {
        let adj = star_adj(6, 1.0);
        let err = generate_rr_sets(&adj, 0, &RISConfig::default());
        assert!(err.is_err());
        let config = RISConfig {
            num_rr_sets: 0,
            seed: None,
        };
        let err2 = generate_rr_sets(&adj, 6, &config);
        assert!(err2.is_err());
    }

    #[test]
    fn test_ris_estimate_star_hub() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 200,
            seed: Some(123),
        };
        let rr = generate_rr_sets(&adj, 6, &config).expect("rr sets");
        let spread = ris_estimate(&rr, &[0], 6).expect("estimate");
        assert!(spread >= 4.0, "spread={spread}");
    }

    #[test]
    fn test_ris_estimate_empty_seed() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 100,
            seed: Some(0),
        };
        let rr = generate_rr_sets(&adj, 6, &config).expect("rr sets");
        let spread = ris_estimate(&rr, &[], 6).expect("zero seed");
        assert_eq!(spread, 0.0);
    }

    #[test]
    fn test_ris_estimate_hand_computed() {
        // One of two RR sets contains seed 0, so the estimate is 1/2 * 3 nodes.
        let rr: Vec<RRSet> = vec![vec![0, 1], vec![2]];
        let spread = ris_estimate(&rr, &[0], 3).expect("estimate");
        assert_eq!(spread, 1.5);
    }

    #[test]
    fn test_ris_estimate_empty_rr_error() {
        let err = ris_estimate(&[], &[0], 6);
        assert!(err.is_err());
    }

    #[test]
    fn test_imm_star_selects_hub() {
        let adj = star_adj(8, 1.0);
        let config = ImmConfig {
            k: 1,
            epsilon: 0.3,
            delta: 0.1,
            seed: Some(42),
        };
        let result = imm_algorithm(&adj, 8, &config).expect("imm");
        assert_eq!(result.seeds.len(), 1);
        assert_eq!(result.seeds[0], 0, "hub expected, got {:?}", result.seeds);
        assert!(result.estimated_spread >= 1.0);
    }

    #[test]
    fn test_imm_k0_returns_empty() {
        let adj = star_adj(5, 1.0);
        let config = ImmConfig {
            k: 0,
            ..Default::default()
        };
        let result = imm_algorithm(&adj, 5, &config).expect("imm k=0");
        assert!(result.seeds.is_empty());
        assert_eq!(result.estimated_spread, 0.0);
    }

    #[test]
    fn test_imm_invalid_params() {
        let adj = star_adj(5, 1.0);
        let too_many_seeds = ImmConfig {
            k: 10,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &too_many_seeds).is_err());
        let bad_epsilon = ImmConfig {
            epsilon: 1.5,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &bad_epsilon).is_err());
        let nan_epsilon = ImmConfig {
            epsilon: f64::NAN,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &nan_epsilon).is_err());
        let bad_delta = ImmConfig {
            delta: 1.5,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &bad_delta).is_err());
        assert!(imm_algorithm(&adj, 0, &ImmConfig::default()).is_err());
    }

    #[test]
    fn test_imm_complete_graph() {
        let adj = complete_adj(5, 1.0);
        let config = ImmConfig {
            k: 1,
            epsilon: 0.3,
            delta: 0.1,
            seed: Some(7),
        };
        let result = imm_algorithm(&adj, 5, &config).expect("imm complete");
        assert_eq!(result.seeds.len(), 1);
        assert!(result.estimated_spread >= 1.0);
    }

    #[test]
    fn test_sandwich_approximation_basic() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 100,
            seed: Some(99),
        };
        let (lower_seeds, lower_spread, upper_seeds, upper_spread) =
            sandwich_approximation(&adj, 6, 1, &config).expect("sandwich");
        // With p = 1 every RR set contains the hub, so the hub is chosen and every set is
        // covered: spread = 6 * (covered / total) = 6 exactly.
        assert_eq!(lower_seeds, vec![0]);
        assert_eq!(upper_seeds, vec![0]);
        assert_eq!(lower_spread, 6.0);
        assert_eq!(upper_spread, 6.0);
    }

    #[test]
    fn test_sandwich_k0_returns_empty() {
        let adj = star_adj(5, 1.0);
        let (ls, lsp, us, usp) =
            sandwich_approximation(&adj, 5, 0, &RISConfig::default()).expect("k=0");
        assert!(ls.is_empty());
        assert!(us.is_empty());
        assert_eq!(lsp, 0.0);
        assert_eq!(usp, 0.0);
    }

    #[test]
    fn test_sandwich_invalid_params() {
        let adj = star_adj(5, 1.0);
        let err = sandwich_approximation(&adj, 5, 10, &RISConfig::default());
        assert!(err.is_err());
        let err2 = sandwich_approximation(&adj, 0, 1, &RISConfig::default());
        assert!(err2.is_err());
    }
}