1use crate::graph::{Graph, GraphRef};
4use rand::prelude::*;
5use rand_chacha::ChaCha8Rng;
6
/// Parameters shared by the random-walk generators in this file.
///
/// The unbiased generators read only `length`, `walks_per_node` and `seed`;
/// the biased generators additionally read `p` and `q`.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WalkConfig {
    /// Upper bound on the number of nodes in one walk. A walk ends early at a
    /// node without neighbors and always contains at least its start node.
    pub length: usize,
    /// Number of passes over the start nodes; each pass begins one walk at
    /// every start node.
    pub walks_per_node: usize,
    /// Return parameter of the biased walks: stepping back to the previous
    /// node carries weight `1 / p`.
    pub p: f32,
    /// In-out parameter of the biased walks: stepping to a node that is not a
    /// neighbor of the previous node carries weight `1 / q`.
    pub q: f32,
    /// Seed from which the ChaCha8 generators used by the walk functions are
    /// derived.
    pub seed: u64,
}
16
17impl Default for WalkConfig {
18 fn default() -> Self {
19 Self {
20 length: 80,
21 walks_per_node: 10,
22 p: 1.0,
23 q: 1.0,
24 seed: 42,
25 }
26 }
27}
28
29pub fn sample_start_nodes_reservoir(node_count: usize, k: usize, seed: u64) -> Vec<usize> {
50 let mut rng = ChaCha8Rng::seed_from_u64(seed);
51 if k == 0 || node_count == 0 {
52 return Vec::new();
53 }
54 if k >= node_count {
55 return (0..node_count).collect();
56 }
57
58 let mut reservoir: Vec<usize> = (0..k).collect();
60 for i in k..node_count {
61 let j = rng.random_range(0..=i);
63 if j < k {
64 reservoir[j] = i;
65 }
66 }
67 reservoir
68}
69
70pub fn generate_walks<G: Graph>(graph: &G, config: WalkConfig) -> Vec<Vec<usize>> {
71 let start_nodes: Vec<usize> = (0..graph.node_count()).collect();
72 generate_walks_from_nodes(graph, &start_nodes, config)
73}
74
75pub fn generate_walks_from_nodes<G: Graph>(
81 graph: &G,
82 start_nodes: &[usize],
83 config: WalkConfig,
84) -> Vec<Vec<usize>> {
85 let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
86 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
87 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
88
89 for _ in 0..config.walks_per_node {
90 epoch_nodes.shuffle(&mut rng);
92 for &node in &epoch_nodes {
93 walks.push(unbiased_walk(graph, node, config.length, &mut rng));
94 }
95 }
96 walks
97}
98
99fn unbiased_walk<G: Graph, R: Rng>(
100 graph: &G,
101 start: usize,
102 length: usize,
103 rng: &mut R,
104) -> Vec<usize> {
105 let mut walk = Vec::with_capacity(length);
106 walk.push(start);
107 let mut curr = start;
108 for _ in 1..length {
109 let neighbors = graph.neighbors(curr);
110 if neighbors.is_empty() {
111 break;
112 }
113 curr = *neighbors.choose(rng).unwrap();
114 walk.push(curr);
115 }
116 walk
117}
118
119pub fn generate_walks_ref<G: GraphRef>(graph: &G, config: WalkConfig) -> Vec<Vec<usize>> {
123 let start_nodes: Vec<usize> = (0..graph.node_count()).collect();
124 generate_walks_ref_from_nodes(graph, &start_nodes, config)
125}
126
127pub fn generate_walks_ref_from_nodes<G: GraphRef>(
136 graph: &G,
137 start_nodes: &[usize],
138 config: WalkConfig,
139) -> Vec<Vec<usize>> {
140 let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
141 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
142 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
143
144 for _ in 0..config.walks_per_node {
145 epoch_nodes.shuffle(&mut rng);
146 for &node in &epoch_nodes {
147 walks.push(unbiased_walk_ref(graph, node, config.length, &mut rng));
148 }
149 }
150 walks
151}
152
153fn unbiased_walk_ref<G: GraphRef, R: Rng>(
154 graph: &G,
155 start: usize,
156 length: usize,
157 rng: &mut R,
158) -> Vec<usize> {
159 let mut walk = Vec::with_capacity(length);
160 walk.push(start);
161 let mut curr = start;
162 for _ in 1..length {
163 let neighbors = graph.neighbors_ref(curr);
164 if neighbors.is_empty() {
165 break;
166 }
167 curr = *neighbors.choose(rng).unwrap();
168 walk.push(curr);
169 }
170 walk
171}
172
173pub fn generate_walks_ref_streaming_from_nodes<G, F>(
181 graph: &G,
182 start_nodes: &[usize],
183 config: WalkConfig,
184 mut on_walk: F,
185) where
186 G: GraphRef,
187 F: FnMut(&[usize]),
188{
189 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
190 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
191 let mut buf: Vec<usize> = Vec::with_capacity(config.length);
192
193 for _ in 0..config.walks_per_node {
194 epoch_nodes.shuffle(&mut rng);
195 for &node in &epoch_nodes {
196 buf.clear();
197 unbiased_walk_ref_into(graph, node, config.length, &mut rng, &mut buf);
198 on_walk(&buf);
199 }
200 }
201}
202
203fn unbiased_walk_ref_into<G: GraphRef, R: Rng>(
204 graph: &G,
205 start: usize,
206 length: usize,
207 rng: &mut R,
208 out: &mut Vec<usize>,
209) {
210 out.reserve(length.saturating_sub(out.capacity()));
211 out.push(start);
212 let mut curr = start;
213 for _ in 1..length {
214 let neighbors = graph.neighbors_ref(curr);
215 if neighbors.is_empty() {
216 break;
217 }
218 curr = *neighbors.choose(rng).unwrap();
219 out.push(curr);
220 }
221}
222
223pub fn generate_biased_walks<G: Graph>(graph: &G, config: WalkConfig) -> Vec<Vec<usize>> {
224 let start_nodes: Vec<usize> = (0..graph.node_count()).collect();
225 generate_biased_walks_from_nodes(graph, &start_nodes, config)
226}
227
228pub fn generate_biased_walks_from_nodes<G: Graph>(
230 graph: &G,
231 start_nodes: &[usize],
232 config: WalkConfig,
233) -> Vec<Vec<usize>> {
234 let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
235 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
236 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
237
238 for _ in 0..config.walks_per_node {
239 epoch_nodes.shuffle(&mut rng);
240 for &node in &epoch_nodes {
241 walks.push(biased_walk(graph, node, config, &mut rng));
242 }
243 }
244 walks
245}
246
247fn biased_walk<G: Graph, R: Rng>(
248 graph: &G,
249 start: usize,
250 config: WalkConfig,
251 rng: &mut R,
252) -> Vec<usize> {
253 let mut walk = Vec::with_capacity(config.length);
254 walk.push(start);
255 let mut curr = start;
256 let mut prev: Option<usize> = None;
257 let mut prev_neighbors: Vec<usize> = Vec::new();
258
259 for _ in 1..config.length {
260 let neighbors = graph.neighbors(curr);
261 if neighbors.is_empty() {
262 break;
263 }
264 let next = if let Some(p_node) = prev {
265 sample_biased_rejection(rng, p_node, &prev_neighbors, &neighbors, config.p, config.q)
266 } else {
267 *neighbors.choose(rng).unwrap()
268 };
269 walk.push(next);
270 prev = Some(curr);
271 prev_neighbors = neighbors;
272 curr = next;
273 }
274 walk
275}
276
277pub fn generate_biased_walks_ref<G: GraphRef>(graph: &G, config: WalkConfig) -> Vec<Vec<usize>> {
279 let start_nodes: Vec<usize> = (0..graph.node_count()).collect();
280 generate_biased_walks_ref_from_nodes(graph, &start_nodes, config)
281}
282
283pub fn generate_biased_walks_ref_from_nodes<G: GraphRef>(
285 graph: &G,
286 start_nodes: &[usize],
287 config: WalkConfig,
288) -> Vec<Vec<usize>> {
289 let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
290 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
291 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
292
293 for _ in 0..config.walks_per_node {
294 epoch_nodes.shuffle(&mut rng);
295 for &node in &epoch_nodes {
296 walks.push(biased_walk_ref(graph, node, config, &mut rng));
297 }
298 }
299 walks
300}
301
302fn biased_walk_ref<G: GraphRef, R: Rng>(
303 graph: &G,
304 start: usize,
305 config: WalkConfig,
306 rng: &mut R,
307) -> Vec<usize> {
308 let mut walk = Vec::with_capacity(config.length);
309 walk.push(start);
310
311 let mut curr = start;
312 let mut prev: Option<usize> = None;
313 let mut prev_neighbors: &[usize] = &[];
314
315 for _ in 1..config.length {
316 let neighbors = graph.neighbors_ref(curr);
317 if neighbors.is_empty() {
318 break;
319 }
320
321 let next = if let Some(p_node) = prev {
322 sample_biased_rejection(rng, p_node, prev_neighbors, neighbors, config.p, config.q)
323 } else {
324 *neighbors.choose(rng).unwrap()
325 };
326
327 walk.push(next);
328
329 prev = Some(curr);
331 prev_neighbors = neighbors;
332
333 curr = next;
334 }
335 walk
336}
337
338pub fn generate_biased_walks_ref_streaming_from_nodes<G, F>(
342 graph: &G,
343 start_nodes: &[usize],
344 config: WalkConfig,
345 mut on_walk: F,
346) where
347 G: GraphRef,
348 F: FnMut(&[usize]),
349{
350 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
351 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
352 let mut buf: Vec<usize> = Vec::with_capacity(config.length);
353
354 for _ in 0..config.walks_per_node {
355 epoch_nodes.shuffle(&mut rng);
356 for &node in &epoch_nodes {
357 buf.clear();
358 biased_walk_ref_into(graph, node, config, &mut rng, &mut buf);
359 on_walk(&buf);
360 }
361 }
362}
363
364fn biased_walk_ref_into<G: GraphRef, R: Rng>(
365 graph: &G,
366 start: usize,
367 config: WalkConfig,
368 rng: &mut R,
369 out: &mut Vec<usize>,
370) {
371 out.reserve(config.length.saturating_sub(out.capacity()));
372 out.push(start);
373
374 let mut curr = start;
375 let mut prev: Option<usize> = None;
376 let mut prev_neighbors: &[usize] = &[];
377
378 for _ in 1..config.length {
379 let neighbors = graph.neighbors_ref(curr);
380 if neighbors.is_empty() {
381 break;
382 }
383
384 let next = if let Some(p_node) = prev {
385 sample_biased_rejection(rng, p_node, prev_neighbors, neighbors, config.p, config.q)
386 } else {
387 *neighbors.choose(rng).unwrap()
388 };
389
390 out.push(next);
391
392 prev = Some(curr);
393 prev_neighbors = neighbors;
394 curr = next;
395 }
396}
397
398fn sample_biased_rejection<R: Rng>(
399 rng: &mut R,
400 prev_node: usize,
401 prev_neighbors: &[usize],
402 neighbors: &[usize],
403 p: f32,
404 q: f32,
405) -> usize {
406 let max_prob = (1.0 / p).max(1.0).max(1.0 / q);
407 loop {
408 let candidate = *neighbors.choose(rng).unwrap();
409 let r: f32 = rng.random();
410 let is_in_edge = prev_neighbors.contains(&candidate);
411 let unnorm_prob = if candidate == prev_node {
412 1.0 / p
413 } else if is_in_edge {
414 1.0
415 } else {
416 1.0 / q
417 };
418 if r < unnorm_prob / max_prob {
419 return candidate;
420 }
421 }
422}
423
424#[cfg(feature = "parallel")]
/// Scrambles a 64-bit value with three xor-shift steps interleaved with two
/// wrapping multiplications by fixed odd constants (the same shifts and
/// multipliers as the SplitMix64 output function).
///
/// Used to turn structured inputs such as `seed ^ epoch` into seeds for the
/// per-epoch and per-walk generators. Zero maps to zero.
fn mix64(x: u64) -> u64 {
    let x = (x ^ (x >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
    let x = (x ^ (x >> 27)).wrapping_mul(0x94d049bb133111eb);
    x ^ (x >> 31)
}
434
/// Generates uniform random walks from every node of `graph` in parallel.
///
/// Delegates to [`generate_walks_ref_parallel_from_nodes`] with the start
/// nodes `0..graph.node_count()`.
#[cfg(feature = "parallel")]
pub fn generate_walks_ref_parallel<G: GraphRef + Sync>(
    graph: &G,
    config: WalkConfig,
) -> Vec<Vec<usize>> {
    let every_node = (0..graph.node_count()).collect::<Vec<usize>>();
    generate_walks_ref_parallel_from_nodes(graph, &every_node, config)
}
447
/// Generates uniform random walks from `start_nodes` using rayon.
///
/// The job list is built sequentially: for each epoch (one per
/// `config.walks_per_node`) the start nodes are reshuffled with a generator
/// seeded from `mix64(config.seed ^ epoch)` and appended as `(epoch, node)`
/// jobs. The jobs are then mapped in parallel; every job creates its own
/// ChaCha8 generator from a seed derived from `config.seed`, the epoch, the
/// start node and the job's index, so a walk does not depend on which thread
/// runs it.
#[cfg(feature = "parallel")]
pub fn generate_walks_ref_parallel_from_nodes<G: GraphRef + Sync>(
    graph: &G,
    start_nodes: &[usize],
    config: WalkConfig,
) -> Vec<Vec<usize>> {
    use rayon::prelude::*;

    let mut order: Vec<usize> = start_nodes.to_vec();
    let mut jobs: Vec<(u32, usize)> = Vec::with_capacity(start_nodes.len() * config.walks_per_node);

    // The epoch counter is a `u32`; the cast truncates `walks_per_node`
    // values above `u32::MAX`.
    let epochs = config.walks_per_node as u32;
    for epoch in 0..epochs {
        let shuffle_seed = mix64(config.seed ^ u64::from(epoch));
        let mut shuffle_rng = ChaCha8Rng::seed_from_u64(shuffle_seed);
        order.shuffle(&mut shuffle_rng);
        jobs.extend(order.iter().map(|&node| (epoch, node)));
    }

    jobs.par_iter()
        .enumerate()
        .map(|(index, &(epoch, node))| {
            // NOTE(review): the start node and the job index are XORed
            // together, so two jobs of one epoch with equal `node ^ index`
            // derive the same seed; confirm this is acceptable.
            let walk_seed =
                mix64(config.seed ^ (u64::from(epoch) << 32) ^ (node as u64) ^ (index as u64));
            let mut walk_rng = ChaCha8Rng::seed_from_u64(walk_seed);
            unbiased_walk_ref(graph, node, config.length, &mut walk_rng)
        })
        .collect()
}
479
/// Generates biased (`p`/`q`-weighted) random walks from every node of
/// `graph` in parallel.
///
/// Delegates to [`generate_biased_walks_ref_parallel_from_nodes`] with the
/// start nodes `0..graph.node_count()`.
#[cfg(feature = "parallel")]
pub fn generate_biased_walks_ref_parallel<G: GraphRef + Sync>(
    graph: &G,
    config: WalkConfig,
) -> Vec<Vec<usize>> {
    let every_node = (0..graph.node_count()).collect::<Vec<usize>>();
    generate_biased_walks_ref_parallel_from_nodes(graph, &every_node, config)
}
492
/// Generates biased random walks from `start_nodes` using rayon.
///
/// The job list is built sequentially: for each epoch (one per
/// `config.walks_per_node`) the start nodes are reshuffled with a generator
/// seeded from `mix64(config.seed ^ epoch)` and appended as `(epoch, node)`
/// jobs. The jobs are then mapped in parallel; every job runs
/// `biased_walk_ref` with its own ChaCha8 generator, seeded from
/// `config.seed`, the epoch, the start node and the job's index, so a walk
/// does not depend on which thread runs it.
#[cfg(feature = "parallel")]
pub fn generate_biased_walks_ref_parallel_from_nodes<G: GraphRef + Sync>(
    graph: &G,
    start_nodes: &[usize],
    config: WalkConfig,
) -> Vec<Vec<usize>> {
    use rayon::prelude::*;

    let mut order: Vec<usize> = start_nodes.to_vec();
    let mut jobs: Vec<(u32, usize)> = Vec::with_capacity(start_nodes.len() * config.walks_per_node);

    // The epoch counter is a `u32`; the cast truncates `walks_per_node`
    // values above `u32::MAX`.
    let epochs = config.walks_per_node as u32;
    for epoch in 0..epochs {
        let shuffle_seed = mix64(config.seed ^ u64::from(epoch));
        let mut shuffle_rng = ChaCha8Rng::seed_from_u64(shuffle_seed);
        order.shuffle(&mut shuffle_rng);
        jobs.extend(order.iter().map(|&node| (epoch, node)));
    }

    jobs.par_iter()
        .enumerate()
        .map(|(index, &(epoch, node))| {
            // NOTE(review): the start node and the job index are XORed
            // together, so two jobs of one epoch with equal `node ^ index`
            // derive the same seed; confirm this is acceptable.
            let walk_seed =
                mix64(config.seed ^ (u64::from(epoch) << 32) ^ (node as u64) ^ (index as u64));
            let mut walk_rng = ChaCha8Rng::seed_from_u64(walk_seed);
            biased_walk_ref(graph, node, config, &mut walk_rng)
        })
        .collect()
}