digraph_rs/
generator.rs

1use std::{collections::HashMap, vec};
2
3use graphviz_rust::attributes::len;
4use rand::{rngs::ThreadRng, Rng};
5
6use super::{DiGraph, EmptyPayload};
7use crate::digraph;
8use std::hash::Hash;
9
/// Configuration of the Erdős–Rényi random-graph model.
///
/// Every ordered pair of generated nodes receives an edge with probability `edge_prob`,
/// subject to the constraints below.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ERCfg {
    /// Number of nodes to generate.
    pub node_len: usize,
    /// Probability that a candidate edge `from -> to` is created; must lie in `[0, 1]`.
    pub edge_prob: f64,
    /// When `false`, no edge `n -> n` is ever created.
    pub self_conn: bool,
    /// When `true`, an edge `from -> to` is skipped if `to -> from` already exists.
    pub back_strict: bool,
    /// Upper bound on generated outgoing edges per node; `0` means unbounded.
    pub max_from: usize,
    /// Upper bound on generated incoming edges per node; `0` means unbounded.
    pub max_to: usize,
}
20
/// Configuration of the Watts–Strogatz small-world model.
///
/// Nodes are placed on a ring, linked to their nearest neighbours, and the resulting
/// lattice edges are then randomly rewired.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct WSCfg {
    /// Number of nodes placed on the ring; must be greater than `nearest_k / 2`.
    pub node_len: usize,
    /// Lattice degree: each node is linked to its `nearest_k / 2` nearest neighbours on each
    /// side of the ring (odd values are rounded down by the integer division).
    pub nearest_k: usize,
    /// Probability that any single lattice edge is rewired; must lie in `[0, 1]`.
    pub rewire_prob: f64,
}
28
29#[derive(Clone, Copy)]
30pub enum RGGenCfg {
31    ER(ERCfg),
32    WS(WSCfg),
33}
34
35impl Default for RGGenCfg {
36    fn default() -> Self {
37        RGGenCfg::ER(ERCfg {
38            node_len: 30,
39            edge_prob: 0.1,
40            self_conn: false,
41            back_strict: true,
42            max_from: 0,
43            max_to: 0,
44        })
45    }
46}
47fn has_back_link<NId, NL, EL>(g: &DiGraph<NId, NL, EL>, from: &NId, to: &NId) -> bool
48where
49    NId: Clone + Eq + Hash,
50{
51    g.successors(to.clone())
52        .map(|ss| ss.contains_key(from))
53        .unwrap_or(false)
54}
55
56fn ws_generate<NId, NL, EL, FNId, FNL, FEL>(
57    cfg: WSCfg,
58    mut f_id: FNId,
59    f_nl: FNL,
60    f_el: FEL,
61) -> DiGraph<NId, NL, EL>
62where
63    NId: Clone + Eq + Hash,
64    EL: Clone,
65    FNId: FnMut() -> NId,
66    FNL: Fn(&NId) -> NL,
67    FEL: Fn(&NId, &NId) -> EL,
68{
69    let mut g = digraph!(NId, NL, EL);
70    let WSCfg {
71        node_len,
72        nearest_k,
73        rewire_prob,
74    } = cfg;
75    let mut rand = rand::thread_rng();
76    let nsize = nearest_k / 2;
77    assert!(
78        node_len > nsize,
79        "the node len {} should be greater then nearest_k / 2: {}",
80        node_len,
81        nsize
82    );
83
84    let mut ids = vec![];
85    let mut ring: HashMap<NId, Vec<(NId, EL)>> = HashMap::new();
86    for _ in 0..node_len {
87        let id = f_id();
88        let nl = f_nl(&id);
89        g.add_node(id.clone(), nl);
90        ids.push(id.clone());
91    }
92
93    let l = ids.len();
94    for (idx, from) in ids.iter().enumerate() {
95        let mut ring_edges = vec![];
96        for r in 1..=nsize {
97            let lhs_idx = if idx < r { l - r } else { idx - r };
98            let rhs_idx = if idx + r > l { r } else { idx + r };
99
100            if let Some(to) = ids.get(lhs_idx) {
101                if !has_back_link(&g, from, to) {
102                    let payload = f_el(from, to);
103                    ring_edges.push((to.clone(), payload));
104                }
105            }
106
107            if let Some(to) = ids.get(rhs_idx) {
108                if !has_back_link(&g, from, to) {
109                    let payload = f_el(from, to);
110                    ring_edges.push((to.clone(), payload));
111                }
112            }
113        }
114        ring.insert(from.clone(), ring_edges);
115    }
116
117    for from in ids.iter() {
118        if let Some(edges) = ring.remove(from) {
119            let edges_nodes: Vec<NId> = edges.iter().map(|(id, _)| id.clone()).collect();
120            for (to, pl) in edges.into_iter() {
121                let should_replace = rand.gen_bool(rewire_prob);
122                if !should_replace {
123                    g.add_edge(from.clone(), to, pl);
124                } else {
125                    let mut rand_id = rand.gen_range(0..l);
126                    let mut rand_node = ids.get(rand_id).unwrap();
127                    while rand_node == from || edges_nodes.contains(rand_node) {
128                        rand_id = rand.gen_range(0..l);
129                        rand_node = ids.get(rand_id).unwrap();
130                    }
131                    g.add_edge(from.clone(), rand_node.clone(), pl);
132                }
133            }
134        }
135    }
136    g
137}
138
139fn er_generate<NId, NL, EL, FNId, FNL, FEL>(
140    cfg: ERCfg,
141    mut f_id: FNId,
142    f_nl: FNL,
143    f_el: FEL,
144) -> DiGraph<NId, NL, EL>
145where
146    NId: Clone + Eq + Hash,
147    EL: Clone,
148    FNId: FnMut() -> NId,
149    FNL: Fn(&NId) -> NL,
150    FEL: Fn(&NId, &NId) -> EL,
151{
152    let mut g = digraph!(NId, NL, EL);
153    let mut rand = rand::thread_rng();
154    let ERCfg {
155        node_len,
156        edge_prob,
157        self_conn,
158        back_strict,
159        max_from,
160        max_to,
161    } = cfg;
162
163    let mut ids_counters = HashMap::new();
164    let mut ids = vec![];
165    for _ in 0..node_len {
166        let id = f_id();
167        let nl = f_nl(&id);
168        g.add_node(id.clone(), nl);
169        ids.push(id.clone());
170        ids_counters.insert(id.clone(), (0usize, 0usize));
171    }
172    for from in ids.iter() {
173        for to in ids.iter() {
174            let max_bounds = max_from != 0
175                && ids_counters
176                    .get(from)
177                    .map(|(v, _)| v >= &max_from)
178                    .unwrap_or(false)
179                || max_to != 0
180                    && ids_counters
181                        .get(to)
182                        .map(|(_, v)| v >= &max_to)
183                        .unwrap_or(false);
184
185            if !max_bounds {
186                let should_gen = if !self_conn && from == to {
187                    false
188                } else {
189                    rand.gen_bool(edge_prob)
190                };
191                if should_gen {
192                    if !back_strict || !has_back_link(&g, from, to) {
193                        ids_counters.entry(from.clone()).and_modify(|v| {
194                            *v = (v.0 + 1, v.1);
195                        });
196                        ids_counters.entry(to.clone()).and_modify(|v| {
197                            *v = (v.0, v.1 + 1);
198                        });
199                        let el = f_el(from, to);
200                        g.add_edge(from.clone(), to.clone(), el);
201                    }
202                }
203            }
204        }
205    }
206    g
207}
208
209pub struct RandomGraphGenerator {
210    cfg: RGGenCfg,
211}
212
213impl RandomGraphGenerator {
214    pub fn generate_empty(&mut self) -> DiGraph<usize, EmptyPayload, EmptyPayload> {
215        self.generate_usize(|_| EmptyPayload {}, |_, _| EmptyPayload {})
216    }
217
218    pub fn generate_usize<NL, EL, FNL, FEL>(
219        &mut self,
220        f_nl: FNL,
221        f_el: FEL,
222    ) -> DiGraph<usize, NL, EL>
223    where
224        FNL: Fn(&usize) -> NL,
225        EL: Clone,
226        FEL: Fn(&usize, &usize) -> EL,
227    {
228        let len = match self.cfg {
229            RGGenCfg::ER(ERCfg { node_len, .. }) | RGGenCfg::WS(WSCfg { node_len, .. }) => node_len,
230        };
231        let mut r = 0..len;
232        self.generate(move || r.next().unwrap(), f_nl, f_el)
233    }
234}
235
236impl Default for RandomGraphGenerator {
237    fn default() -> Self {
238        Self {
239            cfg: Default::default(),
240        }
241    }
242}
243
244impl RandomGraphGenerator {
245    pub fn new(cfg: RGGenCfg) -> Self {
246        Self { cfg }
247    }
248
249    pub fn generate<NId, NL, EL, FNId, FNL, FEL>(
250        &mut self,
251        mut f_id: FNId,
252        f_nl: FNL,
253        f_el: FEL,
254    ) -> DiGraph<NId, NL, EL>
255    where
256        NId: Clone + Eq + Hash,
257        EL: Clone,
258        FNId: FnMut() -> NId,
259        FNL: Fn(&NId) -> NL,
260        FEL: Fn(&NId, &NId) -> EL,
261    {
262        match self.cfg {
263            RGGenCfg::WS(cfg) => ws_generate(cfg, f_id, f_nl, f_el),
264            RGGenCfg::ER(cfg) => er_generate(cfg, f_id, f_nl, f_el),
265        }
266    }
267}
268
#[cfg(test)]
pub mod tests {
    use crate::generator::{ERCfg, RGGenCfg, WSCfg};

    use super::RandomGraphGenerator;

    // The test harness runs tests in parallel, so every test writes to its own file.
    // Sharing a path would make the tests overwrite each other's output.

    /// The default (Erdős–Rényi) configuration renders successfully.
    #[test]
    fn simple_gen_test() {
        let mut g = RandomGraphGenerator::default();
        let di = g.generate_empty();

        let r = di.visualize().str_to_dot_file("dots/gen.svg");
        assert!(r.is_ok());
    }

    /// Erdős–Rényi with `usize` payloads renders successfully.
    #[test]
    fn simple_gen_load_test() {
        let mut g = RandomGraphGenerator::new(RGGenCfg::ER(ERCfg {
            node_len: 30,
            edge_prob: 0.1,
            self_conn: false,
            back_strict: true,
            max_from: 0,
            max_to: 0,
        }));
        let di = g.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);

        let r = di.visualize().str_to_dot_file("dots/gen_load.svg");
        assert!(r.is_ok());
    }

    /// Watts–Strogatz with `usize` payloads renders successfully.
    #[test]
    fn simple_gen_sw_test() {
        let mut g = RandomGraphGenerator::new(RGGenCfg::WS(WSCfg {
            node_len: 20,
            nearest_k: 4,
            rewire_prob: 0.5,
        }));
        let di = g.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);

        let r = di.visualize().str_to_dot_file("dots/gen_sw.svg");
        assert!(r.is_ok());
    }

    /// Both models render successfully one after another.
    #[test]
    fn simple_gen_both_test() {
        let mut ws_gen = RandomGraphGenerator::new(RGGenCfg::WS(WSCfg {
            node_len: 20,
            nearest_k: 4,
            rewire_prob: 0.5,
        }));
        let di = ws_gen.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);
        let r = di.visualize().str_to_dot_file("dots/gen_ws.svg");
        assert!(r.is_ok());

        let mut er_gen = RandomGraphGenerator::new(RGGenCfg::ER(ERCfg {
            node_len: 20,
            edge_prob: 0.1,
            self_conn: false,
            back_strict: true,
            max_from: 0,
            max_to: 0,
        }));
        let di = er_gen.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);
        let r = di.visualize().str_to_dot_file("dots/gen_er.svg");
        assert!(r.is_ok());
    }
}