Skip to main content

scirs2_graph/
social.rs

1//! Social network analysis algorithms
2//!
3//! This module provides specialized algorithms for analyzing social graphs:
4//!
5//! - **Influence Maximization**: Find the top-k most influential seed nodes
6//!   under Independent Cascade or Linear Threshold diffusion models
//! - **Role Detection**: Assign heuristic structural roles (hub, peripheral,
//!   bridge, member, isolated) from degree and local clustering statistics
8//! - **Echo Chamber Detection**: Partition users into ideologically isolated groups
9//! - **Polarization Index**: Quantify the degree of network polarization
10//!
11//! # References
12//! - Kempe, Kleinberg & Tardos (2003) — influence maximization / IC / LT
13//! - Lorrain & White (1971) — structural equivalence
14//! - Del Vicario et al. (2016) — echo chamber detection
15
16use std::cmp::Reverse;
17use std::collections::{HashMap, HashSet, VecDeque};
18
19use scirs2_core::random::{Rng, RngExt};
20
21use crate::base::{EdgeWeight, Graph, Node};
22use crate::error::{GraphError, Result};
23
/// Type alias for numeric node id used in social network operations.
///
/// Ids are 0-based positions in the node order returned by `graph.nodes()`
/// at the time a function is called — they are not the node values themselves.
pub type NodeId = usize;

/// Spread estimation function type: maps a seed set to its estimated
/// (Monte-Carlo averaged) number of activated nodes.
type SpreadFn = Box<dyn Fn(&[NodeId]) -> f64>;
29
30// ============================================================================
31// Diffusion / cascade models
32// ============================================================================
33
/// Diffusion model for influence propagation.
///
/// In both models the value attached to an edge is the edge weight when it
/// lies in (0, 1], and `InfluenceConfig::default_prob` otherwise.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum CascadeModel {
    /// Independent Cascade (IC): each newly activated node gets one chance to
    /// activate each inactive neighbor, succeeding independently with the
    /// edge's activation probability.
    #[default]
    IndependentCascade,
    /// Linear Threshold (LT): every node draws a random threshold from the
    /// unit interval; an inactive node activates once the summed weight of its
    /// active neighbors reaches (`>=`) that threshold.
    LinearThreshold,
}
46
47// ============================================================================
48// Influence maximization
49// ============================================================================
50
51/// Estimate the expected spread of a seed set under the IC model
52/// using Monte-Carlo simulation with `num_simulations` runs.
53///
54/// Returns the average number of activated nodes (including seeds).
55fn estimate_spread_ic(
56    adj: &HashMap<NodeId, Vec<(NodeId, f64)>>,
57    seeds: &[NodeId],
58    num_simulations: usize,
59) -> f64 {
60    let mut rng = scirs2_core::random::rng();
61    let mut total = 0.0f64;
62
63    for _ in 0..num_simulations {
64        let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
65        let mut queue: VecDeque<NodeId> = seeds.iter().cloned().collect();
66
67        while let Some(node) = queue.pop_front() {
68            if let Some(neighbors) = adj.get(&node) {
69                for &(nbr, prob) in neighbors {
70                    if !active.contains(&nbr) && rng.random::<f64>() < prob {
71                        active.insert(nbr);
72                        queue.push_back(nbr);
73                    }
74                }
75            }
76        }
77        total += active.len() as f64;
78    }
79
80    total / num_simulations as f64
81}
82
83/// Estimate expected spread under the Linear Threshold model
84fn estimate_spread_lt(
85    adj: &HashMap<NodeId, Vec<(NodeId, f64)>>,
86    n_nodes: usize,
87    seeds: &[NodeId],
88    num_simulations: usize,
89) -> f64 {
90    let mut rng = scirs2_core::random::rng();
91    let mut total = 0.0f64;
92
93    for _ in 0..num_simulations {
94        // Draw random thresholds
95        let thresholds: Vec<f64> = (0..n_nodes).map(|_| rng.random::<f64>()).collect();
96        let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
97        let mut changed = true;
98
99        while changed {
100            changed = false;
101            for node in 0..n_nodes {
102                if active.contains(&node) {
103                    continue;
104                }
105                // Sum of weights from active in-neighbors
106                let influence: f64 = adj
107                    .get(&node)
108                    .map(|nbrs| {
109                        nbrs.iter()
110                            .filter(|&&(nbr, _)| active.contains(&nbr))
111                            .map(|&(_, w)| w)
112                            .sum::<f64>()
113                    })
114                    .unwrap_or(0.0);
115
116                if influence >= thresholds[node] {
117                    active.insert(node);
118                    changed = true;
119                }
120            }
121        }
122        total += active.len() as f64;
123    }
124
125    total / num_simulations as f64
126}
127
128/// Configuration for influence maximization
129#[derive(Debug, Clone)]
130pub struct InfluenceConfig {
131    /// Diffusion model to use
132    pub model: CascadeModel,
133    /// Number of Monte-Carlo simulations per candidate
134    pub num_simulations: usize,
135    /// Default edge activation probability (used when weight ∉ (0,1))
136    pub default_prob: f64,
137}
138
139impl Default for InfluenceConfig {
140    fn default() -> Self {
141        InfluenceConfig {
142            model: CascadeModel::IndependentCascade,
143            num_simulations: 100,
144            default_prob: 0.1,
145        }
146    }
147}
148
149/// Select the top-k most influential seed nodes using a greedy hill-climbing
150/// algorithm (Kempe, Kleinberg & Tardos 2003).
151///
152/// At each step, the node that maximises the marginal gain in expected spread
153/// is added to the seed set.
154///
155/// # Arguments
156/// * `graph` - Undirected or directed graph with edge weights as probabilities
157/// * `k` - Number of seed nodes to select
158/// * `config` - Model and simulation settings
159///
160/// # Returns
161/// A vector of `k` node ids (0-indexed) in order of selection
162pub fn influence_maximization<N, E, Ix>(
163    graph: &Graph<N, E, Ix>,
164    k: usize,
165    config: &InfluenceConfig,
166) -> Result<Vec<NodeId>>
167where
168    N: Node + Clone + std::fmt::Debug,
169    E: EdgeWeight + Clone + Into<f64>,
170    Ix: petgraph::graph::IndexType,
171{
172    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
173    let n = nodes.len();
174
175    if k == 0 {
176        return Ok(Vec::new());
177    }
178    if k > n {
179        return Err(GraphError::InvalidParameter {
180            param: "k".to_string(),
181            value: k.to_string(),
182            expected: format!("<= n_nodes ({})", n),
183            context: "influence_maximization".to_string(),
184        });
185    }
186
187    // Build numeric adjacency list
188    let node_to_idx: HashMap<N, NodeId> = nodes
189        .iter()
190        .enumerate()
191        .map(|(i, nd)| (nd.clone(), i))
192        .collect();
193
194    let mut adj: HashMap<NodeId, Vec<(NodeId, f64)>> = HashMap::new();
195    for edge in graph.edges() {
196        let si = *node_to_idx
197            .get(&edge.source)
198            .ok_or_else(|| GraphError::node_not_found("source node"))?;
199        let ti = *node_to_idx
200            .get(&edge.target)
201            .ok_or_else(|| GraphError::node_not_found("target node"))?;
202        let w: f64 = edge.weight.clone().into();
203        let prob = if w > 0.0 && w <= 1.0 {
204            w
205        } else {
206            config.default_prob
207        };
208        adj.entry(si).or_default().push((ti, prob));
209        adj.entry(ti).or_default().push((si, prob)); // undirected
210    }
211
212    let spread_fn: SpreadFn = match &config.model {
213        CascadeModel::IndependentCascade => {
214            let adj_ref = adj.clone();
215            let sims = config.num_simulations;
216            Box::new(move |seeds| estimate_spread_ic(&adj_ref, seeds, sims))
217        }
218        CascadeModel::LinearThreshold => {
219            let adj_ref = adj.clone();
220            let sims = config.num_simulations;
221            Box::new(move |seeds| estimate_spread_lt(&adj_ref, n, seeds, sims))
222        }
223    };
224
225    let mut seeds: Vec<NodeId> = Vec::with_capacity(k);
226    let mut current_spread = 0.0f64;
227
228    for _ in 0..k {
229        let mut best_node = None;
230        let mut best_gain = f64::NEG_INFINITY;
231
232        for candidate in 0..n {
233            if seeds.contains(&candidate) {
234                continue;
235            }
236            let mut trial_seeds = seeds.clone();
237            trial_seeds.push(candidate);
238            let spread = spread_fn(&trial_seeds);
239            let gain = spread - current_spread;
240
241            if gain > best_gain {
242                best_gain = gain;
243                best_node = Some(candidate);
244            }
245        }
246
247        if let Some(node) = best_node {
248            seeds.push(node);
249            current_spread += best_gain;
250        }
251    }
252
253    Ok(seeds)
254}
255
256// ============================================================================
257// Role detection
258// ============================================================================
259
/// Structural role of a node, as assigned by the degree / local-clustering
/// heuristics of `role_detection`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RoleType {
    /// Degree well above the graph's mean degree
    Hub,
    /// Low degree and local clustering well below the graph average
    Peripheral,
    /// Degree of at least 2 whose neighbors are rarely connected to each other
    /// (local clustering far below the graph average). Betweenness is not computed.
    Bridge,
    /// Ordinary member of a community
    Member,
    /// Isolated node with no connections
    Isolated,
}

impl std::fmt::Display for RoleType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            RoleType::Hub => "Hub",
            RoleType::Peripheral => "Peripheral",
            RoleType::Bridge => "Bridge",
            RoleType::Member => "Member",
            RoleType::Isolated => "Isolated",
        };
        f.write_str(name)
    }
}
286
287/// Detect structural roles for all nodes based on degree and local clustering.
288///
289/// The assignment uses thresholds derived from the graph's degree statistics:
290/// - **Isolated**: degree 0
291/// - **Hub**: degree > mean + std
292/// - **Peripheral**: degree < mean - 0.5*std AND low local clustering
293/// - **Bridge**: clustering coefficient much lower than graph average
294/// - **Member**: otherwise
295///
296/// # Arguments
297/// * `graph` - The graph to analyze
298///
299/// # Returns
300/// Map from node index to its detected role
301pub fn role_detection<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<HashMap<NodeId, RoleType>>
302where
303    N: Node + Clone + std::fmt::Debug,
304    E: EdgeWeight + Clone + Into<f64>,
305    Ix: petgraph::graph::IndexType,
306{
307    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
308    let n = nodes.len();
309
310    if n == 0 {
311        return Ok(HashMap::new());
312    }
313
314    // Compute degrees
315    let degrees: Vec<f64> = nodes.iter().map(|nd| graph.degree(nd) as f64).collect();
316    let mean_deg = degrees.iter().sum::<f64>() / n as f64;
317    let var_deg = degrees.iter().map(|d| (d - mean_deg).powi(2)).sum::<f64>() / n as f64;
318    let std_deg = var_deg.sqrt();
319
320    // Compute local clustering coefficient per node
321    let clustering: Vec<f64> = nodes
322        .iter()
323        .map(|nd| local_clustering_coefficient(graph, nd))
324        .collect();
325
326    let mean_clustering = if n > 0 {
327        clustering.iter().sum::<f64>() / n as f64
328    } else {
329        0.0
330    };
331
332    let mut roles = HashMap::with_capacity(n);
333
334    for (i, _node) in nodes.iter().enumerate() {
335        let deg = degrees[i];
336        let clust = clustering[i];
337
338        let role = if deg == 0.0 {
339            RoleType::Isolated
340        } else if deg > mean_deg + std_deg {
341            RoleType::Hub
342        } else if deg < (mean_deg - 0.5 * std_deg).max(1.0) && clust < mean_clustering * 0.5 {
343            RoleType::Peripheral
344        } else if clust < mean_clustering * 0.4 && deg >= 2.0 {
345            RoleType::Bridge
346        } else {
347            RoleType::Member
348        };
349
350        roles.insert(i, role);
351    }
352
353    Ok(roles)
354}
355
356/// Compute local clustering coefficient for a node
357fn local_clustering_coefficient<N, E, Ix>(graph: &Graph<N, E, Ix>, node: &N) -> f64
358where
359    N: Node + Clone + std::fmt::Debug,
360    E: EdgeWeight + Clone + Into<f64>,
361    Ix: petgraph::graph::IndexType,
362{
363    let neighbors: Vec<N> = match graph.neighbors(node) {
364        Ok(nbrs) => nbrs,
365        Err(_) => return 0.0,
366    };
367    let k = neighbors.len();
368    if k < 2 {
369        return 0.0;
370    }
371
372    let mut triangles = 0usize;
373    for i in 0..k {
374        for j in i + 1..k {
375            if graph.has_edge(&neighbors[i], &neighbors[j]) {
376                triangles += 1;
377            }
378        }
379    }
380
381    let max_possible = k * (k - 1) / 2;
382    if max_possible == 0 {
383        0.0
384    } else {
385        triangles as f64 / max_possible as f64
386    }
387}
388
389// ============================================================================
390// Echo chamber detection
391// ============================================================================
392
393/// Detect echo chambers using a label-propagation-inspired algorithm
394/// that respects node opinion features.
395///
396/// Nodes are partitioned into communities where internal edge density is
397/// high and cross-community connections are few. The `features` argument
398/// provides a numeric opinion/attribute vector per node (indexed 0..n).
399///
400/// # Arguments
401/// * `graph` - The social graph
402/// * `features` - Per-node feature vectors; must have `graph.node_count()` entries
403///
404/// # Returns
405/// A list of echo chambers, each being a list of node indices
406pub fn echo_chamber_detection<N, E, Ix>(
407    graph: &Graph<N, E, Ix>,
408    features: &[Vec<f64>],
409) -> Result<Vec<Vec<NodeId>>>
410where
411    N: Node + Clone + std::fmt::Debug,
412    E: EdgeWeight + Clone + Into<f64>,
413    Ix: petgraph::graph::IndexType,
414{
415    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
416    let n = nodes.len();
417
418    if n == 0 {
419        return Ok(Vec::new());
420    }
421    if features.len() != n {
422        return Err(GraphError::InvalidParameter {
423            param: "features".to_string(),
424            value: format!("{} rows", features.len()),
425            expected: format!("{} rows (one per node)", n),
426            context: "echo_chamber_detection".to_string(),
427        });
428    }
429
430    // Build numeric adjacency
431    let node_to_idx: HashMap<N, NodeId> = nodes
432        .iter()
433        .enumerate()
434        .map(|(i, nd)| (nd.clone(), i))
435        .collect();
436
437    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
438    for edge in graph.edges() {
439        if let (Some(&si), Some(&ti)) =
440            (node_to_idx.get(&edge.source), node_to_idx.get(&edge.target))
441        {
442            adj[si].push(ti);
443            adj[ti].push(si);
444        }
445    }
446
447    // Feature-aware label propagation
448    // Initialize: each node has its own label
449    let mut labels: Vec<NodeId> = (0..n).collect();
450
451    // Run multiple rounds of propagation
452    for _round in 0..20 {
453        let mut changed = false;
454
455        // Randomized order via deterministic pseudo-random permutation
456        let mut order: Vec<usize> = (0..n).collect();
457        // Simple Fisher-Yates with deterministic seed
458        for i in (1..n).rev() {
459            let j = i
460                .wrapping_mul(6364136223846793005)
461                .wrapping_add(1442695040888963407)
462                % (i + 1);
463            order.swap(i, j);
464        }
465
466        for &node in &order {
467            let nbrs = &adj[node];
468            if nbrs.is_empty() {
469                continue;
470            }
471
472            // Score each candidate label by: frequency + feature similarity
473            let mut label_scores: HashMap<NodeId, f64> = HashMap::new();
474            for &nbr in nbrs {
475                let lbl = labels[nbr];
476                let sim = feature_similarity(&features[node], &features[nbr]);
477                *label_scores.entry(lbl).or_default() += 1.0 + sim;
478            }
479
480            if let Some((&best_label, _)) = label_scores
481                .iter()
482                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
483            {
484                if best_label != labels[node] {
485                    labels[node] = best_label;
486                    changed = true;
487                }
488            }
489        }
490
491        if !changed {
492            break;
493        }
494    }
495
496    // Group nodes by final label
497    let mut chambers: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
498    for (node, &lbl) in labels.iter().enumerate() {
499        chambers.entry(lbl).or_default().push(node);
500    }
501
502    let mut result: Vec<Vec<NodeId>> = chambers.into_values().collect();
503    result.sort_by_key(|b| Reverse(b.len())); // Largest first
504    Ok(result)
505}
506
/// Cosine similarity of two feature vectors; in [−1, 1] for finite inputs.
///
/// Returns 0.0 when either vector is empty or their lengths differ. Each norm
/// is floored at 1e-10, so a zero vector yields 0.0 rather than NaN.
fn feature_similarity(a: &[f64], b: &[f64]) -> f64 {
    const NORM_FLOOR: f64 = 1e-10;

    // Equal lengths plus a non-empty `a` imply a non-empty `b`.
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt().max(NORM_FLOOR);
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();

    dot / (norm(a) * norm(b))
}
517
518// ============================================================================
519// Polarization index
520// ============================================================================
521
522/// Compute the polarization index of a social network.
523///
524/// The index combines three signals:
525/// 1. **Modularity**: ratio of intra-community vs cross-community edges
526///    (computed via echo chamber partition)
527/// 2. **Homophily**: average feature similarity within vs across detected chambers
528/// 3. **Fragmentation**: fraction of cut edges between chambers
529///
530/// The returned value is in [0, 1]: 0 = fully integrated, 1 = fully polarized.
531///
532/// # Arguments
533/// * `graph` - The social graph
534/// * `features` - Per-node opinion features (optional; pass empty vecs if unavailable)
535///
536/// # Returns
537/// Polarization index ∈ [0, 1]
538pub fn polarization_index<N, E, Ix>(graph: &Graph<N, E, Ix>, features: &[Vec<f64>]) -> Result<f64>
539where
540    N: Node + Clone + std::fmt::Debug,
541    E: EdgeWeight + Clone + Into<f64>,
542    Ix: petgraph::graph::IndexType,
543{
544    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
545    let n = nodes.len();
546
547    if n < 2 {
548        return Ok(0.0);
549    }
550
551    let feat_len = features.first().map(|f| f.len()).unwrap_or(0);
552    let feature_pad: Vec<Vec<f64>>;
553    let features_ref: &[Vec<f64>] = if features.len() == n {
554        features
555    } else {
556        feature_pad = vec![vec![0.0; feat_len.max(1)]; n];
557        &feature_pad
558    };
559
560    // Get chamber assignments
561    let chambers = echo_chamber_detection(graph, features_ref)?;
562    let num_chambers = chambers.len();
563
564    if num_chambers <= 1 {
565        return Ok(0.0);
566    }
567
568    // Map node → chamber id
569    let mut node_chamber: Vec<usize> = vec![0; n];
570    for (cid, chamber) in chambers.iter().enumerate() {
571        for &node in chamber {
572            if node < n {
573                node_chamber[node] = cid;
574            }
575        }
576    }
577
578    let node_to_idx: HashMap<N, NodeId> = nodes
579        .iter()
580        .enumerate()
581        .map(|(i, nd)| (nd.clone(), i))
582        .collect();
583
584    let edges = graph.edges();
585    let total_edges = edges.len() as f64;
586
587    if total_edges == 0.0 {
588        return Ok(0.0);
589    }
590
591    // Count intra- and cross-chamber edges
592    let mut intra = 0.0f64;
593    let mut cross = 0.0f64;
594    let mut intra_sim = 0.0f64;
595    let mut cross_sim = 0.0f64;
596
597    for edge in &edges {
598        if let (Some(&si), Some(&ti)) =
599            (node_to_idx.get(&edge.source), node_to_idx.get(&edge.target))
600        {
601            let sim = feature_similarity(
602                features_ref.get(si).map(|v| v.as_slice()).unwrap_or(&[]),
603                features_ref.get(ti).map(|v| v.as_slice()).unwrap_or(&[]),
604            );
605            if node_chamber[si] == node_chamber[ti] {
606                intra += 1.0;
607                intra_sim += sim;
608            } else {
609                cross += 1.0;
610                cross_sim += sim;
611            }
612        }
613    }
614
615    // Modularity component: high intra / total = high polarization
616    let modularity_component = intra / total_edges;
617
618    // Homophily component: if features available, compare similarities
619    let homophily_component = if feat_len > 0 && (intra + cross) > 0.0 {
620        let avg_intra_sim = if intra > 0.0 { intra_sim / intra } else { 0.0 };
621        let avg_cross_sim = if cross > 0.0 { cross_sim / cross } else { 0.0 };
622        // Normalize difference to [0, 1]
623        ((avg_intra_sim - avg_cross_sim + 2.0) / 4.0).clamp(0.0, 1.0)
624    } else {
625        0.5 // neutral if no features
626    };
627
628    // Combine: weighted average
629    let polarization = 0.6 * modularity_component + 0.4 * homophily_component;
630    Ok(polarization.clamp(0.0, 1.0))
631}
632
633// ============================================================================
634// Additional utility: spread simulation
635// ============================================================================
636
637/// Simulate information spread from a set of seed nodes and return all
638/// activated nodes under the chosen diffusion model.
639///
640/// Useful for what-if analysis after running `influence_maximization`.
641pub fn simulate_spread<N, E, Ix>(
642    graph: &Graph<N, E, Ix>,
643    seeds: &[NodeId],
644    config: &InfluenceConfig,
645) -> Result<HashSet<NodeId>>
646where
647    N: Node + Clone + std::fmt::Debug,
648    E: EdgeWeight + Clone + Into<f64>,
649    Ix: petgraph::graph::IndexType,
650{
651    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
652    let n = nodes.len();
653    let node_to_idx: HashMap<N, NodeId> = nodes
654        .iter()
655        .enumerate()
656        .map(|(i, nd)| (nd.clone(), i))
657        .collect();
658
659    let mut adj: HashMap<NodeId, Vec<(NodeId, f64)>> = HashMap::new();
660    for edge in graph.edges() {
661        let si = *node_to_idx
662            .get(&edge.source)
663            .ok_or_else(|| GraphError::node_not_found("source"))?;
664        let ti = *node_to_idx
665            .get(&edge.target)
666            .ok_or_else(|| GraphError::node_not_found("target"))?;
667        let w: f64 = edge.weight.clone().into();
668        let prob = if w > 0.0 && w <= 1.0 {
669            w
670        } else {
671            config.default_prob
672        };
673        adj.entry(si).or_default().push((ti, prob));
674        adj.entry(ti).or_default().push((si, prob));
675    }
676
677    let active_count = match &config.model {
678        CascadeModel::IndependentCascade => {
679            // Run a single full simulation
680            let mut rng = scirs2_core::random::rng();
681            let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
682            let mut queue: VecDeque<NodeId> = seeds.iter().cloned().collect();
683            while let Some(node) = queue.pop_front() {
684                if let Some(neighbors) = adj.get(&node) {
685                    for &(nbr, prob) in neighbors {
686                        if !active.contains(&nbr) && rng.random::<f64>() < prob {
687                            active.insert(nbr);
688                            queue.push_back(nbr);
689                        }
690                    }
691                }
692            }
693            active
694        }
695        CascadeModel::LinearThreshold => {
696            let mut rng = scirs2_core::random::rng();
697            let thresholds: Vec<f64> = (0..n).map(|_| rng.random::<f64>()).collect();
698            let mut active: HashSet<NodeId> = seeds.iter().cloned().collect();
699            let mut changed = true;
700            while changed {
701                changed = false;
702                for node in 0..n {
703                    if active.contains(&node) {
704                        continue;
705                    }
706                    let influence: f64 = adj
707                        .get(&node)
708                        .map(|nbrs| {
709                            nbrs.iter()
710                                .filter(|&&(nbr, _)| active.contains(&nbr))
711                                .map(|&(_, w)| w)
712                                .sum::<f64>()
713                        })
714                        .unwrap_or(0.0);
715                    if influence >= thresholds[node] {
716                        active.insert(node);
717                        changed = true;
718                    }
719                }
720            }
721            active
722        }
723    };
724
725    Ok(active_count)
726}
727
728// ============================================================================
729// Tests
730// ============================================================================
731
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::Graph;

    /// Two 4-cliques, {0, 1, 2, 3} and {5, 6, 7, 8}, joined through node 4 by
    /// the edges 3–4 and 4–5.
    ///
    /// Note: the `NodeId`s returned by the functions under test are positions
    /// in `graph.nodes()`, which need not equal these node values.
    fn make_social_graph() -> Graph<usize, f64> {
        let mut g: Graph<usize, f64> = Graph::new();
        for i in 0..4 {
            for j in i + 1..4 {
                let _ = g.add_edge(i, j, 0.3);
            }
        }
        for i in 5..9 {
            for j in i + 1..9 {
                let _ = g.add_edge(i, j, 0.3);
            }
        }
        // Node 4 links the two cliques.
        let _ = g.add_edge(3, 4, 0.1);
        let _ = g.add_edge(4, 5, 0.1);
        g
    }

    #[test]
    fn test_influence_maximization_returns_k_seeds() {
        let g = make_social_graph();
        let config = InfluenceConfig {
            model: CascadeModel::IndependentCascade,
            num_simulations: 20,
            default_prob: 0.3,
        };
        let seeds = influence_maximization(&g, 3, &config).expect("IM failed");
        assert_eq!(seeds.len(), 3, "Should return exactly k seeds");
        // No duplicates
        let unique: HashSet<_> = seeds.iter().cloned().collect();
        assert_eq!(unique.len(), 3, "Seeds should be unique");
    }

    #[test]
    fn test_influence_maximization_linear_threshold() {
        let g = make_social_graph();
        let config = InfluenceConfig {
            model: CascadeModel::LinearThreshold,
            num_simulations: 20,
            default_prob: 0.3,
        };
        let seeds = influence_maximization(&g, 2, &config).expect("IM LT failed");
        assert_eq!(seeds.len(), 2);
    }

    #[test]
    fn test_influence_maximization_k_zero() {
        let g = make_social_graph();
        let config = InfluenceConfig::default();
        let seeds = influence_maximization(&g, 0, &config).expect("IM k=0");
        assert!(seeds.is_empty());
    }

    #[test]
    fn test_influence_maximization_k_exceeds_node_count() {
        let g = make_social_graph();
        let result = influence_maximization(&g, 10, &InfluenceConfig::default());
        assert!(result.is_err(), "k > n_nodes must be rejected");
    }

    #[test]
    fn test_role_detection_identifies_hub() {
        let g = make_social_graph();
        let roles = role_detection(&g).expect("Role detection failed");
        assert_eq!(roles.len(), 9, "Every node should receive a role");
        // Only the two clique nodes attached to the connector (values 3 and 5,
        // degree 4) exceed mean + std of the degrees (about 3.11 + 0.57).
        let hubs = roles.values().filter(|r| **r == RoleType::Hub).count();
        assert_eq!(hubs, 2, "Expected exactly the two attachment nodes as hubs");
    }

    #[test]
    fn test_role_detection_isolated() {
        let mut g: Graph<usize, f64> = Graph::new();
        g.add_node(0);
        g.add_node(1);
        let _ = g.add_edge(0, 1, 1.0);
        g.add_node(2); // isolated
        let roles = role_detection(&g).expect("Roles failed");
        assert_eq!(roles.get(&2), Some(&RoleType::Isolated));
    }

    #[test]
    fn test_role_detection_empty_graph() {
        let g: Graph<usize, f64> = Graph::new();
        let roles = role_detection(&g).expect("Roles on empty graph failed");
        assert!(roles.is_empty());
    }

    #[test]
    fn test_echo_chamber_detection_two_groups() {
        let g = make_social_graph();
        // Feature: the first four indices have opinion ~0, the rest ~1
        let features: Vec<Vec<f64>> = (0..9)
            .map(|i| vec![if i < 4 { 0.1 } else { 0.9 }])
            .collect();
        let chambers = echo_chamber_detection(&g, &features).expect("Echo chamber failed");
        assert!(!chambers.is_empty(), "Should detect at least one chamber");
        // Chambers must partition the node indices: every index exactly once
        let total: usize = chambers.iter().map(|c| c.len()).sum();
        assert_eq!(total, 9, "All nodes must be assigned to a chamber");
        let distinct: HashSet<NodeId> = chambers.iter().flatten().copied().collect();
        assert_eq!(distinct.len(), 9, "No node may appear in two chambers");
    }

    #[test]
    fn test_echo_chamber_feature_size_mismatch() {
        let g = make_social_graph();
        let features: Vec<Vec<f64>> = vec![vec![0.5]; 3]; // wrong size
        let result = echo_chamber_detection(&g, &features);
        assert!(
            result.is_err(),
            "Should return error for mismatched features"
        );
    }

    #[test]
    fn test_echo_chamber_detection_empty_graph() {
        let g: Graph<usize, f64> = Graph::new();
        let no_features: Vec<Vec<f64>> = Vec::new();
        let chambers = echo_chamber_detection(&g, &no_features).expect("Empty graph failed");
        assert!(chambers.is_empty());
    }

    #[test]
    fn test_polarization_index_range() {
        let g = make_social_graph();
        let features: Vec<Vec<f64>> = (0..9)
            .map(|i| vec![if i < 4 { 0.0 } else { 1.0 }])
            .collect();
        let pi = polarization_index(&g, &features).expect("Polarization failed");
        assert!(
            (0.0..=1.0).contains(&pi),
            "Polarization index must be in [0,1], got {}",
            pi
        );
    }

    #[test]
    fn test_polarization_index_no_features() {
        let g = make_social_graph();
        let features: Vec<Vec<f64>> = vec![vec![0.0; 0]; 9];
        let pi = polarization_index(&g, &features).expect("Polarization (no feat)");
        assert!((0.0..=1.0).contains(&pi));
    }

    #[test]
    fn test_polarization_index_empty_graph() {
        let g: Graph<usize, f64> = Graph::new();
        let no_features: Vec<Vec<f64>> = Vec::new();
        let pi = polarization_index(&g, &no_features).expect("Polarization (empty)");
        assert_eq!(pi, 0.0);
    }

    #[test]
    fn test_feature_similarity_basics() {
        assert_eq!(feature_similarity(&[3.0, 4.0], &[3.0, 4.0]), 1.0);
        assert_eq!(feature_similarity(&[1.0, 0.0], &[-1.0, 0.0]), -1.0);
        assert_eq!(feature_similarity(&[], &[]), 0.0);
        assert_eq!(feature_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(feature_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn test_simulate_spread_ic() {
        let g = make_social_graph();
        let config = InfluenceConfig {
            model: CascadeModel::IndependentCascade,
            num_simulations: 10,
            default_prob: 0.3,
        };
        let activated = simulate_spread(&g, &[0], &config).expect("Spread failed");
        // At minimum, the seed itself is activated
        assert!(activated.contains(&0), "Seed must be in activated set");
    }

    #[test]
    fn test_simulate_spread_lt_contains_seed() {
        let g = make_social_graph();
        let config = InfluenceConfig {
            model: CascadeModel::LinearThreshold,
            num_simulations: 1,
            default_prob: 0.3,
        };
        let activated = simulate_spread(&g, &[0], &config).expect("LT spread failed");
        assert!(activated.contains(&0), "Seed must be in activated set");
        assert!(activated.len() <= 9, "Cannot activate more nodes than exist");
    }
}