// random_graphs/distributions/binomial.rs

1use petgraph::{Graph, Undirected};
2use rand::distributions::{Bernoulli, Distribution};
3use rand::Rng;
4use std::iter::FromIterator;
5use thiserror::Error;
6
7#[derive(Debug, Error, PartialEq)]
8pub enum BinomialGraphError {
9    #[error("invalid parameter `p` = {0}, should be 0 <= `p` <= 1")]
10    InvalidProbability(f64),
11}
12
/// A distribution over undirected graphs with a fixed node count, in which every pair of
/// distinct nodes is joined by an edge independently with the same probability.
#[derive(Clone, Debug)]
pub struct BinomialGraphDistribution {
    /// Number of nodes in every sampled graph.
    nodes: usize,
    /// Probability that any given pair of distinct nodes is joined by an edge.
    p: f64,
}
18
19impl BinomialGraphDistribution {
20    /// Creates a new `BinomialGraphDistribution` with `nodes` nodes, and where up to
21    /// `binomial(nodes, 2)` edges are inserted independently with probability `p`.
22    ///
23    /// Will return an error if `p < 0` or `p > 1`.
24    ///
25    /// # Example
26    /// ```rust
27    /// use random_graphs::prelude::*;
28    /// use rand::prelude::*;
29    ///
30    /// let distribution = BinomialGraphDistribution::new(4, 0.25).unwrap();
31    ///
32    /// // Generate a random graph
33    /// let graph = distribution.sample(&mut thread_rng());
34    /// assert_eq!(graph.node_count(), 4);
35    /// ```
36    pub fn new(nodes: usize, p: f64) -> Result<Self, BinomialGraphError> {
37        // Probability must be between 0 and 1.
38        if p < 0.0 || p > 1.0 {
39            return Err(BinomialGraphError::InvalidProbability(p));
40        }
41
42        Ok(Self { nodes, p })
43    }
44}
45
46impl Distribution<Graph<usize, (), Undirected>> for BinomialGraphDistribution {
47    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Graph<usize, (), Undirected> {
48        // Expected number of edges is binomial(n, 2) * p
49        let mut graph = Graph::new_undirected();
50
51        let nodes = Vec::from_iter((0..self.nodes).map(|index| graph.add_node(index)));
52
53        // Unwrap is fine here because we've already verified that 0 <= self.p <= 1.
54        let bernoulli = Bernoulli::new(self.p).unwrap();
55
56        for (index, start_node) in nodes.iter().enumerate() {
57            for end_node in nodes.iter().skip(index + 1) {
58                if bernoulli.sample(rng) {
59                    graph.add_edge(start_node.clone(), end_node.clone(), ());
60                }
61            }
62        }
63
64        graph
65    }
66}
67
#[cfg(test)]
mod test {
    use super::*;
    use rand::thread_rng;

    #[test]
    fn test_invalid_p_causes_error() {
        // Negative value should cause an error
        let distribution = BinomialGraphDistribution::new(4, -0.05);
        assert_eq!(distribution.err(), Some(BinomialGraphError::InvalidProbability(-0.05)));

        // A couple of p-values that should be fine
        for acceptable_p in &[0.0, 0.05, 0.4, 0.77, 0.33, 0.999, 1.0] {
            let distribution = BinomialGraphDistribution::new(4, *acceptable_p);
            assert!(distribution.is_ok());
        }

        // A value greater than 1 should cause an error
        let distribution = BinomialGraphDistribution::new(4, 1.01);
        assert_eq!(distribution.err(), Some(BinomialGraphError::InvalidProbability(1.01)));
    }

    #[test]
    fn test_extreme_p_values() {
        let mut rng = thread_rng();
        let nodes = 9;
        let possible_edges = nodes * (nodes - 1) / 2;

        // p = 0 never inserts an edge.
        let empty = BinomialGraphDistribution::new(nodes, 0.0).unwrap().sample(&mut rng);
        assert_eq!(empty.node_count(), nodes);
        assert_eq!(empty.edge_count(), 0);

        // p = 1 inserts every possible edge, giving the complete graph.
        let complete = BinomialGraphDistribution::new(nodes, 1.0).unwrap().sample(&mut rng);
        assert_eq!(complete.node_count(), nodes);
        assert_eq!(complete.edge_count(), possible_edges);
    }

    #[test]
    fn test_degenerate_node_counts() {
        // Graphs with zero or one node have no pairs, so no edges even when p = 1.
        let mut rng = thread_rng();
        for &nodes in &[0usize, 1] {
            let graph = BinomialGraphDistribution::new(nodes, 1.0).unwrap().sample(&mut rng);
            assert_eq!(graph.node_count(), nodes);
            assert_eq!(graph.edge_count(), 0);
        }
    }

    #[test]
    fn test_binomial_graph_distribution() {
        // Given 9 nodes, there are 36 possible edges. Using `p = 1/6` the expected
        // number of edges in our graph is 6.
        let nodes = 9;
        let p = 1.0 / 6.0;
        let possible_edges = (nodes * (nodes - 1) / 2) as f64;
        let expected = possible_edges * p;

        let distribution = BinomialGraphDistribution::new(nodes, p).unwrap();
        let mut rng = thread_rng();

        let iteration_count = 10_000;

        // Count the number of edges across 10,000 generations
        let edge_count: usize = (0..iteration_count)
            .map(|_| distribution.sample(&mut rng).edge_count())
            .sum();

        let average_number_of_edges = (edge_count as f64) / (iteration_count as f64);

        // Each sample's edge count is Binomial(36, p), with variance 36 * p * (1 - p) = 5.
        // By the CLT, the mean of 10,000 samples is approximately normal with standard
        // error sqrt(5 / 10,000) ~= 0.022 edges (~0.37% of the expected 6). Allowing five
        // standard errors on either side keeps the spurious-failure rate below one in a
        // million. The check is two-sided, so an average that is too low also fails.
        let variance = possible_edges * p * (1.0 - p);
        let standard_error = (variance / iteration_count as f64).sqrt();
        let tolerance = 5.0 * standard_error;
        let deviation = (average_number_of_edges - expected).abs();
        assert!(
            deviation < tolerance,
            "average edge count {} is too far from expected {} (tolerance {})",
            average_number_of_edges,
            expected,
            tolerance
        );
    }
}