random_graphs/distributions/
binomial.rs1use petgraph::{Graph, Undirected};
2use rand::distributions::{Bernoulli, Distribution};
3use rand::Rng;
4use std::iter::FromIterator;
5use thiserror::Error;
6
7#[derive(Debug, Error, PartialEq)]
8pub enum BinomialGraphError {
9 #[error("invalid parameter `p` = {0}, should be 0 <= `p` <= 1")]
10 InvalidProbability(f64),
11}
12
/// Parameters of a random-graph distribution over undirected graphs in which
/// each possible edge is present independently with probability `p`.
///
/// Construct it with `BinomialGraphDistribution::new`, which validates `p`.
#[derive(Clone, Debug)]
pub struct BinomialGraphDistribution {
    /// Number of vertices in every sampled graph.
    nodes: usize,
    /// Probability that any particular edge is included.
    p: f64,
}
18
19impl BinomialGraphDistribution {
20 pub fn new(nodes: usize, p: f64) -> Result<Self, BinomialGraphError> {
37 if p < 0.0 || p > 1.0 {
39 return Err(BinomialGraphError::InvalidProbability(p));
40 }
41
42 Ok(Self { nodes, p })
43 }
44}
45
46impl Distribution<Graph<usize, (), Undirected>> for BinomialGraphDistribution {
47 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Graph<usize, (), Undirected> {
48 let mut graph = Graph::new_undirected();
50
51 let nodes = Vec::from_iter((0..self.nodes).map(|index| graph.add_node(index)));
52
53 let bernoulli = Bernoulli::new(self.p).unwrap();
55
56 for (index, start_node) in nodes.iter().enumerate() {
57 for end_node in nodes.iter().skip(index + 1) {
58 if bernoulli.sample(rng) {
59 graph.add_edge(start_node.clone(), end_node.clone(), ());
60 }
61 }
62 }
63
64 graph
65 }
66}
67
#[cfg(test)]
mod test {
    use super::*;
    use rand::thread_rng;

    #[test]
    fn test_invalid_p_causes_error() {
        let distribution = BinomialGraphDistribution::new(4, -0.05);
        assert_eq!(distribution.err(), Some(BinomialGraphError::InvalidProbability(-0.05)));

        for &acceptable_p in &[0.0, 0.05, 0.4, 0.77, 0.33, 0.999, 1.0] {
            let distribution = BinomialGraphDistribution::new(4, acceptable_p);
            assert!(distribution.is_ok());
        }

        let distribution = BinomialGraphDistribution::new(4, 1.01);
        assert_eq!(distribution.err(), Some(BinomialGraphError::InvalidProbability(1.01)));
    }

    #[test]
    fn test_extreme_probabilities_are_deterministic() {
        let nodes: usize = 5;
        let mut rng = thread_rng();

        // p = 0: no pair is ever connected.
        let empty = BinomialGraphDistribution::new(nodes, 0.0)
            .unwrap()
            .sample(&mut rng);
        assert_eq!(empty.node_count(), nodes);
        assert_eq!(empty.edge_count(), 0);

        // p = 1: every unordered pair is connected (complete graph).
        let complete = BinomialGraphDistribution::new(nodes, 1.0)
            .unwrap()
            .sample(&mut rng);
        assert_eq!(complete.node_count(), nodes);
        assert_eq!(complete.edge_count(), nodes * (nodes - 1) / 2);
    }

    #[test]
    fn test_binomial_graph_distribution() {
        let nodes: usize = 9;
        let p = 1.0 / 6.0;

        let distribution = BinomialGraphDistribution::new(nodes, p).unwrap();
        let mut rng = thread_rng();

        let iteration_count: usize = 10_000;

        let edge_count: usize = (0..iteration_count)
            .map(|_| distribution.sample(&mut rng).edge_count())
            .sum();

        let average_number_of_edges = (edge_count as f64) / (iteration_count as f64);

        // Each sample draws one independent Bernoulli(p) per unordered pair, so
        // its edge count is Binomial(pairs, p):
        // - mean pairs * p (= 6 here);
        // - variance pairs * p * (1 - p).
        // The standard error of the mean of `iteration_count` samples is
        // sqrt(variance / iteration_count).
        let pairs = (nodes * (nodes - 1) / 2) as f64;
        let expected_edges = pairs * p;
        let standard_error = (pairs * p * (1.0 - p) / iteration_count as f64).sqrt();

        // Compare the absolute deviation, so averages that are too low fail as
        // well as ones that are too high. A 5-standard-error band keeps the
        // chance of a spurious failure with an unseeded RNG negligible.
        let deviation = (average_number_of_edges - expected_edges).abs();
        assert!(
            deviation < 5.0 * standard_error,
            "average edge count {} deviates from expected {} by more than 5 standard errors (SE = {})",
            average_number_of_edges,
            expected_edges,
            standard_error
        );
    }
}