1use std::{collections::HashMap, vec};
2
3use graphviz_rust::attributes::len;
4use rand::{rngs::ThreadRng, Rng};
5
6use super::{DiGraph, EmptyPayload};
7use crate::digraph;
8use std::hash::Hash;
9
/// Parameters for the Erdős–Rényi style generator (`er_generate`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ERCfg {
    /// Number of nodes to create.
    pub node_len: usize,
    /// Probability of drawing an edge for each eligible ordered node pair.
    pub edge_prob: f64,
    /// Whether self loops (`from == to`) may be generated.
    pub self_conn: bool,
    /// When `true`, `from -> to` is not added if `to -> from` already exists.
    pub back_strict: bool,
    /// Maximum number of generated outgoing edges per node; `0` means unlimited.
    pub max_from: usize,
    /// Maximum number of generated incoming edges per node; `0` means unlimited.
    pub max_to: usize,
}

/// Parameters for the Watts–Strogatz style small-world generator (`ws_generate`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WSCfg {
    /// Number of nodes placed on the ring.
    pub node_len: usize,
    /// Ring neighbourhood size: each node is linked to `nearest_k / 2` nodes
    /// on either side. Generation asserts `node_len > nearest_k / 2`.
    pub nearest_k: usize,
    /// Probability that each ring edge is rewired to a random node.
    pub rewire_prob: f64,
}

/// Selects the random-graph model used by `RandomGraphGenerator`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RGGenCfg {
    /// Erdős–Rényi style model.
    ER(ERCfg),
    /// Watts–Strogatz style small-world model.
    WS(WSCfg),
}

impl Default for RGGenCfg {
    /// An Erdős–Rényi configuration with 30 nodes, edge probability `0.1`,
    /// no self loops, back links forbidden and unbounded degrees.
    fn default() -> Self {
        RGGenCfg::ER(ERCfg {
            node_len: 30,
            edge_prob: 0.1,
            self_conn: false,
            back_strict: true,
            max_from: 0,
            max_to: 0,
        })
    }
}
47fn has_back_link<NId, NL, EL>(g: &DiGraph<NId, NL, EL>, from: &NId, to: &NId) -> bool
48where
49 NId: Clone + Eq + Hash,
50{
51 g.successors(to.clone())
52 .map(|ss| ss.contains_key(from))
53 .unwrap_or(false)
54}
55
56fn ws_generate<NId, NL, EL, FNId, FNL, FEL>(
57 cfg: WSCfg,
58 mut f_id: FNId,
59 f_nl: FNL,
60 f_el: FEL,
61) -> DiGraph<NId, NL, EL>
62where
63 NId: Clone + Eq + Hash,
64 EL: Clone,
65 FNId: FnMut() -> NId,
66 FNL: Fn(&NId) -> NL,
67 FEL: Fn(&NId, &NId) -> EL,
68{
69 let mut g = digraph!(NId, NL, EL);
70 let WSCfg {
71 node_len,
72 nearest_k,
73 rewire_prob,
74 } = cfg;
75 let mut rand = rand::thread_rng();
76 let nsize = nearest_k / 2;
77 assert!(
78 node_len > nsize,
79 "the node len {} should be greater then nearest_k / 2: {}",
80 node_len,
81 nsize
82 );
83
84 let mut ids = vec![];
85 let mut ring: HashMap<NId, Vec<(NId, EL)>> = HashMap::new();
86 for _ in 0..node_len {
87 let id = f_id();
88 let nl = f_nl(&id);
89 g.add_node(id.clone(), nl);
90 ids.push(id.clone());
91 }
92
93 let l = ids.len();
94 for (idx, from) in ids.iter().enumerate() {
95 let mut ring_edges = vec![];
96 for r in 1..=nsize {
97 let lhs_idx = if idx < r { l - r } else { idx - r };
98 let rhs_idx = if idx + r > l { r } else { idx + r };
99
100 if let Some(to) = ids.get(lhs_idx) {
101 if !has_back_link(&g, from, to) {
102 let payload = f_el(from, to);
103 ring_edges.push((to.clone(), payload));
104 }
105 }
106
107 if let Some(to) = ids.get(rhs_idx) {
108 if !has_back_link(&g, from, to) {
109 let payload = f_el(from, to);
110 ring_edges.push((to.clone(), payload));
111 }
112 }
113 }
114 ring.insert(from.clone(), ring_edges);
115 }
116
117 for from in ids.iter() {
118 if let Some(edges) = ring.remove(from) {
119 let edges_nodes: Vec<NId> = edges.iter().map(|(id, _)| id.clone()).collect();
120 for (to, pl) in edges.into_iter() {
121 let should_replace = rand.gen_bool(rewire_prob);
122 if !should_replace {
123 g.add_edge(from.clone(), to, pl);
124 } else {
125 let mut rand_id = rand.gen_range(0..l);
126 let mut rand_node = ids.get(rand_id).unwrap();
127 while rand_node == from || edges_nodes.contains(rand_node) {
128 rand_id = rand.gen_range(0..l);
129 rand_node = ids.get(rand_id).unwrap();
130 }
131 g.add_edge(from.clone(), rand_node.clone(), pl);
132 }
133 }
134 }
135 }
136 g
137}
138
139fn er_generate<NId, NL, EL, FNId, FNL, FEL>(
140 cfg: ERCfg,
141 mut f_id: FNId,
142 f_nl: FNL,
143 f_el: FEL,
144) -> DiGraph<NId, NL, EL>
145where
146 NId: Clone + Eq + Hash,
147 EL: Clone,
148 FNId: FnMut() -> NId,
149 FNL: Fn(&NId) -> NL,
150 FEL: Fn(&NId, &NId) -> EL,
151{
152 let mut g = digraph!(NId, NL, EL);
153 let mut rand = rand::thread_rng();
154 let ERCfg {
155 node_len,
156 edge_prob,
157 self_conn,
158 back_strict,
159 max_from,
160 max_to,
161 } = cfg;
162
163 let mut ids_counters = HashMap::new();
164 let mut ids = vec![];
165 for _ in 0..node_len {
166 let id = f_id();
167 let nl = f_nl(&id);
168 g.add_node(id.clone(), nl);
169 ids.push(id.clone());
170 ids_counters.insert(id.clone(), (0usize, 0usize));
171 }
172 for from in ids.iter() {
173 for to in ids.iter() {
174 let max_bounds = max_from != 0
175 && ids_counters
176 .get(from)
177 .map(|(v, _)| v >= &max_from)
178 .unwrap_or(false)
179 || max_to != 0
180 && ids_counters
181 .get(to)
182 .map(|(_, v)| v >= &max_to)
183 .unwrap_or(false);
184
185 if !max_bounds {
186 let should_gen = if !self_conn && from == to {
187 false
188 } else {
189 rand.gen_bool(edge_prob)
190 };
191 if should_gen {
192 if !back_strict || !has_back_link(&g, from, to) {
193 ids_counters.entry(from.clone()).and_modify(|v| {
194 *v = (v.0 + 1, v.1);
195 });
196 ids_counters.entry(to.clone()).and_modify(|v| {
197 *v = (v.0, v.1 + 1);
198 });
199 let el = f_el(from, to);
200 g.add_edge(from.clone(), to.clone(), el);
201 }
202 }
203 }
204 }
205 }
206 g
207}
208
209pub struct RandomGraphGenerator {
210 cfg: RGGenCfg,
211}
212
213impl RandomGraphGenerator {
214 pub fn generate_empty(&mut self) -> DiGraph<usize, EmptyPayload, EmptyPayload> {
215 self.generate_usize(|_| EmptyPayload {}, |_, _| EmptyPayload {})
216 }
217
218 pub fn generate_usize<NL, EL, FNL, FEL>(
219 &mut self,
220 f_nl: FNL,
221 f_el: FEL,
222 ) -> DiGraph<usize, NL, EL>
223 where
224 FNL: Fn(&usize) -> NL,
225 EL: Clone,
226 FEL: Fn(&usize, &usize) -> EL,
227 {
228 let len = match self.cfg {
229 RGGenCfg::ER(ERCfg { node_len, .. }) | RGGenCfg::WS(WSCfg { node_len, .. }) => node_len,
230 };
231 let mut r = 0..len;
232 self.generate(move || r.next().unwrap(), f_nl, f_el)
233 }
234}
235
236impl Default for RandomGraphGenerator {
237 fn default() -> Self {
238 Self {
239 cfg: Default::default(),
240 }
241 }
242}
243
244impl RandomGraphGenerator {
245 pub fn new(cfg: RGGenCfg) -> Self {
246 Self { cfg }
247 }
248
249 pub fn generate<NId, NL, EL, FNId, FNL, FEL>(
250 &mut self,
251 mut f_id: FNId,
252 f_nl: FNL,
253 f_el: FEL,
254 ) -> DiGraph<NId, NL, EL>
255 where
256 NId: Clone + Eq + Hash,
257 EL: Clone,
258 FNId: FnMut() -> NId,
259 FNL: Fn(&NId) -> NL,
260 FEL: Fn(&NId, &NId) -> EL,
261 {
262 match self.cfg {
263 RGGenCfg::WS(cfg) => ws_generate(cfg, f_id, f_nl, f_el),
264 RGGenCfg::ER(cfg) => er_generate(cfg, f_id, f_nl, f_el),
265 }
266 }
267}
268
#[cfg(test)]
pub mod tests {
    use crate::generator::{ERCfg, RGGenCfg, WSCfg};

    use super::RandomGraphGenerator;

    #[test]
    fn simple_gen_test() {
        let mut g = RandomGraphGenerator::default();
        let di = g.generate_empty();

        let r = di.visualize().str_to_dot_file("dots/gen.svg");
        assert!(r.is_ok());
    }

    #[test]
    fn simple_gen_load_test() {
        let mut g = RandomGraphGenerator::new(RGGenCfg::ER(ERCfg {
            node_len: 30,
            edge_prob: 0.1,
            self_conn: false,
            back_strict: true,
            max_from: 0,
            max_to: 0,
        }));
        let di = g.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);

        let r = di.visualize().str_to_dot_file("dots/gen_load.svg");
        assert!(r.is_ok());
    }

    #[test]
    fn simple_gen_sw_test() {
        let mut g = RandomGraphGenerator::new(RGGenCfg::WS(WSCfg {
            node_len: 20,
            nearest_k: 4,
            rewire_prob: 0.5,
        }));
        let di = g.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);

        // Own output file: tests run in parallel, so sharing
        // `dots/gen_load.svg` with `simple_gen_load_test` races.
        let r = di.visualize().str_to_dot_file("dots/gen_sw.svg");
        assert!(r.is_ok());
    }

    #[test]
    fn simple_gen_both_test() {
        let mut ws_gen = RandomGraphGenerator::new(RGGenCfg::WS(WSCfg {
            node_len: 20,
            nearest_k: 4,
            rewire_prob: 0.5,
        }));
        let di = ws_gen.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);
        let r = di.visualize().str_to_dot_file("dots/gen_ws.svg");
        assert!(r.is_ok());

        let mut er_gen = RandomGraphGenerator::new(RGGenCfg::ER(ERCfg {
            node_len: 20,
            edge_prob: 0.1,
            self_conn: false,
            back_strict: true,
            max_from: 0,
            max_to: 0,
        }));
        let di = er_gen.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);
        let r = di.visualize().str_to_dot_file("dots/gen_er.svg");
        assert!(r.is_ok());
    }

    /// Structural guarantees of the ER model, independent of randomness:
    /// no self loops, no mutual edges under `back_strict`, and out-degree
    /// capped by `max_from`.
    #[test]
    fn er_respects_constraints() {
        let node_len = 30;
        let max_from = 3;
        let mut g = RandomGraphGenerator::new(RGGenCfg::ER(ERCfg {
            node_len,
            edge_prob: 0.5,
            self_conn: false,
            back_strict: true,
            max_from,
            max_to: 0,
        }));
        let di = g.generate_usize(|_| 0, |lhs, rhs| lhs + rhs);
        let has_edge = |from: usize, to: usize| {
            di.successors(from)
                .map_or(false, |ss| ss.contains_key(&to))
        };

        for from in 0..node_len {
            assert!(!has_edge(from, from), "unexpected self loop on {}", from);
            let out_degree = (0..node_len).filter(|&to| has_edge(from, to)).count();
            assert!(
                out_degree <= max_from,
                "node {} has out-degree {} > {}",
                from,
                out_degree,
                max_from
            );
            for to in 0..node_len {
                assert!(
                    !(has_edge(from, to) && has_edge(to, from)),
                    "mutual edge between {} and {}",
                    from,
                    to
                );
            }
        }
    }
}