1use std::collections::{HashMap, HashSet, VecDeque};
32
33use scirs2_core::random::{Rng, RngExt, SeedableRng, StdRng};
34
35use crate::diffusion::models::AdjList;
36use crate::error::{GraphError, Result};
37
/// A reverse-reachable (RR) set: the node ids reached by one reverse sample
/// from a uniformly random root (produced by `generate_one_rr_set`).
pub type RRSet = Vec<usize>;
47
/// Configuration for plain Reverse Influence Sampling (RIS).
#[derive(Debug, Clone)]
pub struct RISConfig {
    /// How many reverse-reachable sets to sample.
    pub num_rr_sets: usize,
    /// Seed for the sampler's RNG; `None` selects non-reproducible seeding.
    pub seed: Option<u64>,
}

impl Default for RISConfig {
    /// Ten thousand RR sets with no fixed seed.
    fn default() -> Self {
        Self {
            num_rr_sets: 10_000,
            seed: None,
        }
    }
}
69
/// Parameters for `imm_algorithm`.
#[derive(Debug, Clone)]
pub struct ImmConfig {
    /// Number of seed nodes to select.
    pub k: usize,
    /// Approximation slack; expected to lie strictly inside (0, 1).
    pub epsilon: f64,
    /// Failure probability; expected to lie strictly inside (0, 1).
    pub delta: f64,
    /// Seed for the sampler's RNG; `None` selects non-reproducible seeding.
    pub seed: Option<u64>,
}

impl Default for ImmConfig {
    /// Five seeds, `epsilon = 0.1`, `delta = 0.01`, no fixed seed.
    fn default() -> Self {
        Self {
            k: 5,
            epsilon: 0.1,
            delta: 0.01,
            seed: None,
        }
    }
}
95
/// Output of `imm_algorithm`.
#[derive(Debug, Clone)]
pub struct ImmResult {
    /// Selected seed node ids, in the order the greedy step picked them.
    pub seeds: Vec<usize>,
    /// RIS spread estimate: fraction of sampled RR sets hit by `seeds`,
    /// multiplied by the number of nodes.
    pub estimated_spread: f64,
    /// Total number of RR sets that were sampled.
    pub num_rr_sets: usize,
}
106
107fn reverse_adj(adjacency: &AdjList) -> AdjList {
113 let mut rev: AdjList = HashMap::new();
114 for (&src, nbrs) in adjacency {
115 for &(tgt, p) in nbrs {
116 rev.entry(tgt).or_default().push((src, p));
117 }
118 }
119 rev
120}
121
122fn generate_one_rr_set(rev_adj: &AdjList, num_nodes: usize, rng: &mut impl Rng) -> RRSet {
129 let root: usize = rng.random_range(0..num_nodes);
130 let mut rr_set: HashSet<usize> = HashSet::new();
131 let mut queue: VecDeque<usize> = VecDeque::new();
132
133 rr_set.insert(root);
134 queue.push_back(root);
135
136 while let Some(node) = queue.pop_front() {
137 if let Some(in_nbrs) = rev_adj.get(&node) {
139 for &(src, prob) in in_nbrs {
140 if !rr_set.contains(&src) && rng.random::<f64>() < prob {
142 rr_set.insert(src);
143 queue.push_back(src);
144 }
145 }
146 }
147 }
148
149 rr_set.into_iter().collect()
150}
151
152fn greedy_max_coverage(rr_sets: &[RRSet], num_nodes: usize, k: usize) -> (Vec<usize>, usize) {
159 let r = rr_sets.len();
160
161 let mut node_to_rr: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
163 for (i, rr) in rr_sets.iter().enumerate() {
164 for &node in rr {
165 if node < num_nodes {
166 node_to_rr[node].push(i);
167 }
168 }
169 }
170
171 let mut covered: Vec<bool> = vec![false; r];
172 let mut seeds: Vec<usize> = Vec::with_capacity(k);
173 let mut coverage: Vec<usize> = node_to_rr.iter().map(|v| v.len()).collect();
175
176 for _ in 0..k {
177 let best = (0..num_nodes).max_by_key(|&n| coverage[n]).unwrap_or(0);
179
180 seeds.push(best);
181
182 for &rr_idx in &node_to_rr[best] {
184 if !covered[rr_idx] {
185 covered[rr_idx] = true;
186 for &other in &rr_sets[rr_idx] {
188 if other < num_nodes && coverage[other] > 0 {
189 coverage[other] -= 1;
190 }
191 }
192 }
193 }
194
195 coverage[best] = 0;
197 }
198
199 let num_covered = covered.iter().filter(|&&c| c).count();
200 (seeds, num_covered)
201}
202
203pub fn generate_rr_sets(
220 adjacency: &AdjList,
221 num_nodes: usize,
222 config: &RISConfig,
223) -> Result<Vec<RRSet>> {
224 if num_nodes == 0 {
225 return Err(GraphError::InvalidParameter {
226 param: "num_nodes".to_string(),
227 value: "0".to_string(),
228 expected: ">= 1".to_string(),
229 context: "generate_rr_sets".to_string(),
230 });
231 }
232 if config.num_rr_sets == 0 {
233 return Err(GraphError::InvalidParameter {
234 param: "num_rr_sets".to_string(),
235 value: "0".to_string(),
236 expected: ">= 1".to_string(),
237 context: "generate_rr_sets".to_string(),
238 });
239 }
240
241 let mut rng: StdRng = match config.seed {
242 Some(s) => StdRng::seed_from_u64(s),
243 None => StdRng::from_rng(&mut scirs2_core::random::rng()),
244 };
245
246 let rev = reverse_adj(adjacency);
247 let mut rr_sets: Vec<RRSet> = Vec::with_capacity(config.num_rr_sets);
248
249 for _ in 0..config.num_rr_sets {
250 rr_sets.push(generate_one_rr_set(&rev, num_nodes, &mut rng));
251 }
252
253 Ok(rr_sets)
254}
255
256pub fn ris_estimate(rr_sets: &[RRSet], seeds: &[usize], num_nodes: usize) -> Result<f64> {
270 if rr_sets.is_empty() {
271 return Err(GraphError::InvalidParameter {
272 param: "rr_sets".to_string(),
273 value: "empty".to_string(),
274 expected: "non-empty RR set collection".to_string(),
275 context: "ris_estimate".to_string(),
276 });
277 }
278
279 let seed_set: HashSet<usize> = seeds.iter().cloned().collect();
280 let num_covered = rr_sets
281 .iter()
282 .filter(|rr| rr.iter().any(|n| seed_set.contains(n)))
283 .count();
284
285 Ok(num_covered as f64 / rr_sets.len() as f64 * num_nodes as f64)
286}
287
288pub fn imm_algorithm(
311 adjacency: &AdjList,
312 num_nodes: usize,
313 config: &ImmConfig,
314) -> Result<ImmResult> {
315 if num_nodes == 0 {
316 return Err(GraphError::InvalidParameter {
317 param: "num_nodes".to_string(),
318 value: "0".to_string(),
319 expected: ">= 1".to_string(),
320 context: "imm_algorithm".to_string(),
321 });
322 }
323 if config.k == 0 {
324 return Ok(ImmResult {
325 seeds: Vec::new(),
326 estimated_spread: 0.0,
327 num_rr_sets: 0,
328 });
329 }
330 if config.k > num_nodes {
331 return Err(GraphError::InvalidParameter {
332 param: "k".to_string(),
333 value: config.k.to_string(),
334 expected: format!("<= num_nodes={num_nodes}"),
335 context: "imm_algorithm".to_string(),
336 });
337 }
338 if !(0.0..1.0).contains(&config.epsilon) {
339 return Err(GraphError::InvalidParameter {
340 param: "epsilon".to_string(),
341 value: config.epsilon.to_string(),
342 expected: "(0, 1)".to_string(),
343 context: "imm_algorithm".to_string(),
344 });
345 }
346 if !(0.0..1.0).contains(&config.delta) {
347 return Err(GraphError::InvalidParameter {
348 param: "delta".to_string(),
349 value: config.delta.to_string(),
350 expected: "(0, 1)".to_string(),
351 context: "imm_algorithm".to_string(),
352 });
353 }
354
355 let n = num_nodes as f64;
356 let k = config.k;
357 let eps = config.epsilon;
358 let delta = config.delta;
359
360 let ell = (1.0_f64 / delta).ln();
362 let log_n = n.ln();
363
364 let log_cnk = if k >= 1 {
367 let k_f = k as f64;
368 k_f * (n / k_f).ln()
369 } else {
370 0.0_f64
371 };
372
373 let max_iters = (log_n / 2.0_f64.ln()).ceil() as usize + 1;
375
376 let mut rng: StdRng = match config.seed {
377 Some(s) => StdRng::seed_from_u64(s),
378 None => StdRng::from_rng(&mut scirs2_core::random::rng()),
379 };
380
381 let rev = reverse_adj(adjacency);
382 let mut rr_sets: Vec<RRSet> = Vec::new();
383
384 let lambda_prime =
387 (8.0 + 2.0 * eps) * n * (ell * log_n + log_cnk + (2.0_f64).ln()) / (eps * eps);
388
389 for i in 1..=max_iters {
390 let theta_i = (lambda_prime / (n / 2.0_f64.powi(i as i32 - 1))).ceil() as usize;
392
393 while rr_sets.len() < theta_i {
395 rr_sets.push(generate_one_rr_set(&rev, num_nodes, &mut rng));
396 }
397
398 let (_, num_covered) = greedy_max_coverage(&rr_sets, num_nodes, k);
400 let frac = num_covered as f64 / rr_sets.len() as f64;
401
402 let eps_star = compute_epsilon_prime(n, k, ell, frac * n, rr_sets.len());
406 if frac - eps_star >= (1.0 - 1.0 / std::f64::consts::E - eps) * frac {
407 break;
408 }
409 }
410
411 let total_rr = rr_sets.len();
413 let (seeds, num_covered) = greedy_max_coverage(&rr_sets, num_nodes, k);
414
415 let estimated_spread = num_covered as f64 / total_rr as f64 * n;
416
417 Ok(ImmResult {
418 seeds,
419 estimated_spread,
420 num_rr_sets: total_rr,
421 })
422}
423
/// Heuristic error term used by the IMM stopping test.
///
/// Returns `sqrt(2 * 1.1 * (ell + ln 6 + k * ln(n / k)) * n / (spread * num_rr))`,
/// clamped to at most `1.0`.
///
/// Returns `1.0` outright when `spread < 1.0` or `num_rr == 0`, i.e. when
/// there is too little signal for a meaningful bound.
fn compute_epsilon_prime(n: f64, k: usize, ell: f64, spread: f64, num_rr: usize) -> f64 {
    // Multiplicative slack on the concentration bound.
    const SLACK: f64 = 1.0 + 0.1;

    let too_little_signal = spread < 1.0 || num_rr == 0;
    if too_little_signal {
        return 1.0;
    }

    let kf = k as f64;
    let log_term = ell + 6.0_f64.ln() + kf * (n / kf).ln();
    let numerator = 2.0 * SLACK * log_term * n;
    let denominator = spread * num_rr as f64;
    (numerator / denominator).sqrt().min(1.0)
}
440
441pub fn sandwich_approximation(
456 adjacency: &AdjList,
457 num_nodes: usize,
458 k: usize,
459 config: &RISConfig,
460) -> Result<(Vec<usize>, f64, Vec<usize>, f64)> {
461 if num_nodes == 0 {
462 return Err(GraphError::InvalidParameter {
463 param: "num_nodes".to_string(),
464 value: "0".to_string(),
465 expected: ">= 1".to_string(),
466 context: "sandwich_approximation".to_string(),
467 });
468 }
469 if k == 0 {
470 return Ok((Vec::new(), 0.0, Vec::new(), 0.0));
471 }
472 if k > num_nodes {
473 return Err(GraphError::InvalidParameter {
474 param: "k".to_string(),
475 value: k.to_string(),
476 expected: format!("<= num_nodes={num_nodes}"),
477 context: "sandwich_approximation".to_string(),
478 });
479 }
480
481 let rr_sets = generate_rr_sets(adjacency, num_nodes, config)?;
483
484 let (lower_seeds, lower_covered) = greedy_max_coverage(&rr_sets, num_nodes, k);
486 let lower_spread = lower_covered as f64 / rr_sets.len() as f64 * num_nodes as f64;
487
488 let mut rng: StdRng = match config.seed {
491 Some(s) => StdRng::seed_from_u64(s.wrapping_add(0xDEAD_BEEF)),
492 None => StdRng::from_rng(&mut scirs2_core::random::rng()),
493 };
494 let rev = reverse_adj(adjacency);
495 let mut upper_rr = rr_sets.clone();
496 for _ in 0..config.num_rr_sets {
497 upper_rr.push(generate_one_rr_set(&rev, num_nodes, &mut rng));
498 }
499
500 let (upper_seeds, upper_covered) = greedy_max_coverage(&upper_rr, num_nodes, k);
501 let upper_spread = upper_covered as f64 / upper_rr.len() as f64 * num_nodes as f64;
502
503 Ok((lower_seeds, lower_spread, upper_seeds, upper_spread))
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Star graph: hub 0 with an edge of probability `p` to every other node.
    fn star_adj(n: usize, p: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 1..n {
            adj.entry(0).or_default().push((i, p));
        }
        adj
    }

    /// Complete directed graph on `n` nodes; every edge has probability `p`.
    fn complete_adj(n: usize, p: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    adj.entry(i).or_default().push((j, p));
                }
            }
        }
        adj
    }

    #[test]
    fn test_generate_rr_sets_basic() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 50,
            seed: Some(42),
        };
        let rr = generate_rr_sets(&adj, 6, &config).expect("rr sets");
        assert_eq!(rr.len(), 50);
        for r in &rr {
            // Every RR set contains its root; with p = 1 every leaf reaches the hub.
            assert!(!r.is_empty());
            assert!(r.iter().all(|&v| v < 6));
            assert!(r.contains(&0));
        }
    }

    #[test]
    fn test_generate_rr_sets_seed_reproducible() {
        let adj = complete_adj(6, 0.4);
        let config = RISConfig {
            num_rr_sets: 64,
            seed: Some(2024),
        };
        // Compare set contents, independent of element order.
        let normalise = |sets: Vec<RRSet>| -> Vec<RRSet> {
            sets.into_iter()
                .map(|mut s| {
                    s.sort_unstable();
                    s
                })
                .collect()
        };
        let a = normalise(generate_rr_sets(&adj, 6, &config).expect("rr sets a"));
        let b = normalise(generate_rr_sets(&adj, 6, &config).expect("rr sets b"));
        assert_eq!(a, b);
    }

    #[test]
    fn test_generate_rr_sets_invalid_params() {
        let adj = star_adj(6, 1.0);
        let err = generate_rr_sets(&adj, 0, &RISConfig::default());
        assert!(err.is_err());
        let config = RISConfig {
            num_rr_sets: 0,
            seed: None,
        };
        let err2 = generate_rr_sets(&adj, 6, &config);
        assert!(err2.is_err());
    }

    #[test]
    fn test_ris_estimate_star_hub() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 200,
            seed: Some(123),
        };
        let rr = generate_rr_sets(&adj, 6, &config).expect("rr sets");
        // With p = 1 the hub lies in every RR set, so it covers all of them.
        let spread = ris_estimate(&rr, &[0], 6).expect("estimate");
        assert_eq!(spread, 6.0, "spread={spread}");
    }

    #[test]
    fn test_ris_estimate_empty_seed() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 100,
            seed: Some(0),
        };
        let rr = generate_rr_sets(&adj, 6, &config).expect("rr sets");
        let spread = ris_estimate(&rr, &[], 6).expect("zero seed");
        assert_eq!(spread, 0.0);
    }

    #[test]
    fn test_ris_estimate_empty_rr_error() {
        let err = ris_estimate(&[], &[0], 6);
        assert!(err.is_err());
    }

    #[test]
    fn test_imm_star_selects_hub() {
        let adj = star_adj(8, 1.0);
        let config = ImmConfig {
            k: 1,
            epsilon: 0.3,
            delta: 0.1,
            seed: Some(42),
        };
        let result = imm_algorithm(&adj, 8, &config).expect("imm");
        assert_eq!(result.seeds, vec![0], "hub expected, got {:?}", result.seeds);
        assert!(result.num_rr_sets > 0);
        // Hub covers every RR set => estimate is exactly n.
        assert_eq!(result.estimated_spread, 8.0);
    }

    #[test]
    fn test_imm_k0_returns_empty() {
        let adj = star_adj(5, 1.0);
        let config = ImmConfig {
            k: 0,
            ..Default::default()
        };
        let result = imm_algorithm(&adj, 5, &config).expect("imm k=0");
        assert!(result.seeds.is_empty());
        assert_eq!(result.estimated_spread, 0.0);
        assert_eq!(result.num_rr_sets, 0);
    }

    #[test]
    fn test_imm_invalid_params() {
        let adj = star_adj(5, 1.0);
        let too_many_seeds = ImmConfig {
            k: 10,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &too_many_seeds).is_err());

        let eps_too_big = ImmConfig {
            epsilon: 1.5,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &eps_too_big).is_err());

        let eps_negative = ImmConfig {
            epsilon: -0.1,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &eps_negative).is_err());

        let delta_too_big = ImmConfig {
            delta: 1.5,
            ..Default::default()
        };
        assert!(imm_algorithm(&adj, 5, &delta_too_big).is_err());

        assert!(imm_algorithm(&adj, 0, &ImmConfig::default()).is_err());
    }

    #[test]
    fn test_imm_complete_graph() {
        let adj = complete_adj(5, 1.0);
        let config = ImmConfig {
            k: 1,
            epsilon: 0.3,
            delta: 0.1,
            seed: Some(7),
        };
        let result = imm_algorithm(&adj, 5, &config).expect("imm complete");
        assert_eq!(result.seeds.len(), 1);
        // With p = 1 every RR set contains every node, so any seed covers all.
        assert_eq!(result.estimated_spread, 5.0);
    }

    #[test]
    fn test_sandwich_approximation_basic() {
        let adj = star_adj(6, 1.0);
        let config = RISConfig {
            num_rr_sets: 100,
            seed: Some(99),
        };
        let (lower_seeds, lower_spread, upper_seeds, upper_spread) =
            sandwich_approximation(&adj, 6, 1, &config).expect("sandwich");
        assert_eq!(lower_seeds, vec![0]);
        assert_eq!(upper_seeds, vec![0]);
        assert_eq!(lower_spread, 6.0);
        assert_eq!(upper_spread, 6.0);
    }

    #[test]
    fn test_sandwich_k0_returns_empty() {
        let adj = star_adj(5, 1.0);
        let (ls, lsp, us, usp) =
            sandwich_approximation(&adj, 5, 0, &RISConfig::default()).expect("k=0");
        assert!(ls.is_empty());
        assert!(us.is_empty());
        assert_eq!(lsp, 0.0);
        assert_eq!(usp, 0.0);
    }

    #[test]
    fn test_sandwich_invalid_params() {
        let adj = star_adj(5, 1.0);
        let err = sandwich_approximation(&adj, 5, 10, &RISConfig::default());
        assert!(err.is_err());
        let err2 = sandwich_approximation(&adj, 0, 1, &RISConfig::default());
        assert!(err2.is_err());
    }
}
709}