1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, UInt32Array, Float64Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use std::sync::Arc;
5use rand::{Rng, SeedableRng};
6use rand::seq::SliceRandom;
7use rand_pcg::Pcg64;
8use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
9use crate::graph::ArrowGraph;
10use crate::error::{GraphError, Result};
11
12pub struct RandomWalk;
14
15impl RandomWalk {
16 fn compute_random_walks(
18 &self,
19 graph: &ArrowGraph,
20 start_nodes: &[String],
21 walk_length: usize,
22 num_walks: usize,
23 seed: Option<u64>,
24 ) -> Result<Vec<Vec<String>>> {
25 if walk_length == 0 {
26 return Err(GraphError::invalid_parameter(
27 "walk_length must be greater than 0"
28 ));
29 }
30
31 if num_walks == 0 {
32 return Err(GraphError::invalid_parameter(
33 "num_walks must be greater than 0"
34 ));
35 }
36
37 let mut rng = match seed {
38 Some(s) => Pcg64::seed_from_u64(s),
39 None => Pcg64::from_entropy(),
40 };
41
42 let mut all_walks = Vec::new();
43
44 for start_node in start_nodes {
45 if !graph.has_node(start_node) {
46 return Err(GraphError::node_not_found(start_node.clone()));
47 }
48
49 for _ in 0..num_walks {
50 let walk = self.single_random_walk(graph, start_node, walk_length, &mut rng)?;
51 all_walks.push(walk);
52 }
53 }
54
55 Ok(all_walks)
56 }
57
58 fn single_random_walk(
60 &self,
61 graph: &ArrowGraph,
62 start_node: &str,
63 walk_length: usize,
64 rng: &mut Pcg64,
65 ) -> Result<Vec<String>> {
66 let mut walk = Vec::with_capacity(walk_length);
67 let mut current_node = start_node.to_string();
68
69 walk.push(current_node.clone());
70
71 for _ in 1..walk_length {
72 let neighbors = match graph.neighbors(¤t_node) {
73 Some(neighbors) => neighbors,
74 None => break, };
76
77 if neighbors.is_empty() {
78 break; }
80
81 let next_node = neighbors.choose(rng).unwrap();
83 current_node = next_node.clone();
84 walk.push(current_node.clone());
85 }
86
87 Ok(walk)
88 }
89}
90
91impl GraphAlgorithm for RandomWalk {
92 fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
93 let walk_length: usize = params.get("walk_length").unwrap_or(10);
94 let num_walks: usize = params.get("num_walks").unwrap_or(10);
95 let seed: Option<u64> = params.get("seed");
96
97 let start_nodes: Vec<String> = if let Some(nodes) = params.get::<Vec<String>>("start_nodes") {
99 nodes
100 } else {
101 graph.node_ids().cloned().collect()
102 };
103
104 let walks = self.compute_random_walks(graph, &start_nodes, walk_length, num_walks, seed)?;
105
106 let schema = Arc::new(Schema::new(vec![
108 Field::new("walk_id", DataType::UInt32, false),
109 Field::new("step", DataType::UInt32, false),
110 Field::new("node_id", DataType::Utf8, false),
111 ]));
112
113 let mut walk_ids = Vec::new();
114 let mut steps = Vec::new();
115 let mut node_ids = Vec::new();
116
117 for (walk_id, walk) in walks.iter().enumerate() {
118 for (step, node_id) in walk.iter().enumerate() {
119 walk_ids.push(walk_id as u32);
120 steps.push(step as u32);
121 node_ids.push(node_id.clone());
122 }
123 }
124
125 RecordBatch::try_new(
126 schema,
127 vec![
128 Arc::new(UInt32Array::from(walk_ids)),
129 Arc::new(UInt32Array::from(steps)),
130 Arc::new(StringArray::from(node_ids)),
131 ],
132 ).map_err(GraphError::from)
133 }
134
135 fn name(&self) -> &'static str {
136 "random_walk"
137 }
138
139 fn description(&self) -> &'static str {
140 "Generate random walks from specified nodes for graph sampling and ML feature generation"
141 }
142}
143
144pub struct Node2VecWalk;
146
147impl Node2VecWalk {
148 fn compute_node2vec_walks(
150 &self,
151 graph: &ArrowGraph,
152 start_nodes: &[String],
153 walk_length: usize,
154 num_walks: usize,
155 p: f64, q: f64, seed: Option<u64>,
158 ) -> Result<Vec<Vec<String>>> {
159 if walk_length < 2 {
160 return Err(GraphError::invalid_parameter(
161 "walk_length must be at least 2 for Node2Vec walks"
162 ));
163 }
164
165 if p <= 0.0 || q <= 0.0 {
166 return Err(GraphError::invalid_parameter(
167 "p and q parameters must be positive"
168 ));
169 }
170
171 let mut rng = match seed {
172 Some(s) => Pcg64::seed_from_u64(s),
173 None => Pcg64::from_entropy(),
174 };
175
176 let mut all_walks = Vec::new();
177
178 for start_node in start_nodes {
179 if !graph.has_node(start_node) {
180 return Err(GraphError::node_not_found(start_node.clone()));
181 }
182
183 for _ in 0..num_walks {
184 let walk = self.single_node2vec_walk(graph, start_node, walk_length, p, q, &mut rng)?;
185 all_walks.push(walk);
186 }
187 }
188
189 Ok(all_walks)
190 }
191
192 fn single_node2vec_walk(
194 &self,
195 graph: &ArrowGraph,
196 start_node: &str,
197 walk_length: usize,
198 p: f64,
199 q: f64,
200 rng: &mut Pcg64,
201 ) -> Result<Vec<String>> {
202 let mut walk = Vec::with_capacity(walk_length);
203 walk.push(start_node.to_string());
204
205 let first_neighbors = match graph.neighbors(start_node) {
207 Some(neighbors) if !neighbors.is_empty() => neighbors,
208 _ => return Ok(walk), };
210
211 let second_node = first_neighbors.choose(rng).unwrap().clone();
212 walk.push(second_node.clone());
213
214 for _ in 2..walk_length {
216 let current_node = &walk[walk.len() - 1];
217 let previous_node = &walk[walk.len() - 2];
218
219 let neighbors = match graph.neighbors(current_node) {
220 Some(neighbors) if !neighbors.is_empty() => neighbors,
221 _ => break, };
223
224 let next_node = self.choose_next_node_biased(
225 graph, previous_node, current_node, &neighbors, p, q, rng
226 )?;
227 walk.push(next_node);
228 }
229
230 Ok(walk)
231 }
232
233 fn choose_next_node_biased(
235 &self,
236 graph: &ArrowGraph,
237 previous_node: &str,
238 current_node: &str,
239 neighbors: &[String],
240 p: f64,
241 q: f64,
242 rng: &mut Pcg64,
243 ) -> Result<String> {
244 let mut probabilities = Vec::new();
245 let mut cumulative_prob = 0.0;
246
247 let previous_neighbors: std::collections::HashSet<_> = match graph.neighbors(previous_node) {
249 Some(neighbors) => neighbors.iter().collect(),
250 None => std::collections::HashSet::new(),
251 };
252
253 for neighbor in neighbors {
254 let weight = if neighbor == previous_node {
255 1.0 / p
257 } else if previous_neighbors.contains(neighbor) {
258 1.0
260 } else {
261 1.0 / q
263 };
264
265 let edge_weight = graph.edge_weight(current_node, neighbor).unwrap_or(1.0);
267 let final_weight = weight * edge_weight;
268
269 cumulative_prob += final_weight;
270 probabilities.push((neighbor, cumulative_prob));
271 }
272
273 let random_val = rng.gen::<f64>() * cumulative_prob;
275
276 for (neighbor, cum_prob) in probabilities {
277 if random_val <= cum_prob {
278 return Ok(neighbor.clone());
279 }
280 }
281
282 Ok(neighbors[0].clone())
284 }
285}
286
287impl GraphAlgorithm for Node2VecWalk {
288 fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
289 let walk_length: usize = params.get("walk_length").unwrap_or(80);
290 let num_walks: usize = params.get("num_walks").unwrap_or(10);
291 let p: f64 = params.get("p").unwrap_or(1.0);
292 let q: f64 = params.get("q").unwrap_or(1.0);
293 let seed: Option<u64> = params.get("seed");
294
295 let start_nodes: Vec<String> = if let Some(nodes) = params.get::<Vec<String>>("start_nodes") {
297 nodes
298 } else {
299 graph.node_ids().cloned().collect()
300 };
301
302 let walks = self.compute_node2vec_walks(graph, &start_nodes, walk_length, num_walks, p, q, seed)?;
303
304 let schema = Arc::new(Schema::new(vec![
306 Field::new("walk_id", DataType::UInt32, false),
307 Field::new("step", DataType::UInt32, false),
308 Field::new("node_id", DataType::Utf8, false),
309 Field::new("p_param", DataType::Float64, false),
310 Field::new("q_param", DataType::Float64, false),
311 ]));
312
313 let mut walk_ids = Vec::new();
314 let mut steps = Vec::new();
315 let mut node_ids = Vec::new();
316 let mut p_params = Vec::new();
317 let mut q_params = Vec::new();
318
319 for (walk_id, walk) in walks.iter().enumerate() {
320 for (step, node_id) in walk.iter().enumerate() {
321 walk_ids.push(walk_id as u32);
322 steps.push(step as u32);
323 node_ids.push(node_id.clone());
324 p_params.push(p);
325 q_params.push(q);
326 }
327 }
328
329 RecordBatch::try_new(
330 schema,
331 vec![
332 Arc::new(UInt32Array::from(walk_ids)),
333 Arc::new(UInt32Array::from(steps)),
334 Arc::new(StringArray::from(node_ids)),
335 Arc::new(Float64Array::from(p_params)),
336 Arc::new(Float64Array::from(q_params)),
337 ],
338 ).map_err(GraphError::from)
339 }
340
341 fn name(&self) -> &'static str {
342 "node2vec"
343 }
344
345 fn description(&self) -> &'static str {
346 "Generate Node2Vec-style biased random walks with return (p) and in-out (q) parameters"
347 }
348}
349
350pub struct GraphSampling;
352
353impl GraphSampling {
354 pub fn random_node_sampling(
356 &self,
357 graph: &ArrowGraph,
358 sample_size: usize,
359 seed: Option<u64>,
360 ) -> Result<Vec<String>> {
361 let all_nodes: Vec<String> = graph.node_ids().cloned().collect();
362
363 if sample_size >= all_nodes.len() {
364 return Ok(all_nodes);
365 }
366
367 let mut rng = match seed {
368 Some(s) => Pcg64::seed_from_u64(s),
369 None => Pcg64::from_entropy(),
370 };
371
372 let sampled_nodes = all_nodes.choose_multiple(&mut rng, sample_size).cloned().collect();
373 Ok(sampled_nodes)
374 }
375
376 pub fn random_edge_sampling(
378 &self,
379 graph: &ArrowGraph,
380 sample_ratio: f64,
381 seed: Option<u64>,
382 ) -> Result<RecordBatch> {
383 if !(0.0..=1.0).contains(&sample_ratio) {
384 return Err(GraphError::invalid_parameter(
385 "sample_ratio must be between 0.0 and 1.0"
386 ));
387 }
388
389 let mut rng = match seed {
390 Some(s) => Pcg64::seed_from_u64(s),
391 None => Pcg64::from_entropy(),
392 };
393
394 let mut sampled_sources = Vec::new();
395 let mut sampled_targets = Vec::new();
396 let mut sampled_weights = Vec::new();
397
398 for node_id in graph.node_ids() {
400 if let Some(neighbors) = graph.neighbors(node_id) {
401 for neighbor in neighbors {
402 if rng.gen::<f64>() < sample_ratio {
403 sampled_sources.push(node_id.clone());
404 sampled_targets.push(neighbor.clone());
405 sampled_weights.push(graph.edge_weight(node_id, neighbor).unwrap_or(1.0));
406 }
407 }
408 }
409 }
410
411 let schema = Arc::new(Schema::new(vec![
412 Field::new("source", DataType::Utf8, false),
413 Field::new("target", DataType::Utf8, false),
414 Field::new("weight", DataType::Float64, false),
415 ]));
416
417 RecordBatch::try_new(
418 schema,
419 vec![
420 Arc::new(StringArray::from(sampled_sources)),
421 Arc::new(StringArray::from(sampled_targets)),
422 Arc::new(Float64Array::from(sampled_weights)),
423 ],
424 ).map_err(GraphError::from)
425 }
426
427 pub fn snowball_sampling(
429 &self,
430 graph: &ArrowGraph,
431 seed_nodes: &[String],
432 k_hops: usize,
433 max_nodes: Option<usize>,
434 ) -> Result<Vec<String>> {
435 if k_hops == 0 {
436 return Ok(seed_nodes.to_vec());
437 }
438
439 let mut sampled_nodes: std::collections::HashSet<String> = std::collections::HashSet::new();
440 let mut current_frontier: std::collections::HashSet<String> = std::collections::HashSet::new();
441
442 for seed_node in seed_nodes {
444 if !graph.has_node(seed_node) {
445 return Err(GraphError::node_not_found(seed_node.clone()));
446 }
447 sampled_nodes.insert(seed_node.clone());
448 current_frontier.insert(seed_node.clone());
449 }
450
451 for _ in 0..k_hops {
453 let mut next_frontier = std::collections::HashSet::new();
454
455 for node in ¤t_frontier {
456 if let Some(neighbors) = graph.neighbors(node) {
457 for neighbor in neighbors {
458 if !sampled_nodes.contains(neighbor) {
459 sampled_nodes.insert(neighbor.clone());
460 next_frontier.insert(neighbor.clone());
461
462 if let Some(max) = max_nodes {
464 if sampled_nodes.len() >= max {
465 return Ok(sampled_nodes.into_iter().collect());
466 }
467 }
468 }
469 }
470 }
471 }
472
473 current_frontier = next_frontier;
474 if current_frontier.is_empty() {
475 break; }
477 }
478
479 Ok(sampled_nodes.into_iter().collect())
480 }
481}
482
483impl GraphAlgorithm for GraphSampling {
484 fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
485 let sampling_method: String = params.get("method").unwrap_or("random_node".to_string());
486
487 match sampling_method.as_str() {
488 "random_node" => {
489 let sample_size: usize = params.get("sample_size").unwrap_or(graph.node_count() / 2);
490 let seed: Option<u64> = params.get("seed");
491
492 let sampled_nodes = self.random_node_sampling(graph, sample_size, seed)?;
493
494 let schema = Arc::new(Schema::new(vec![
495 Field::new("node_id", DataType::Utf8, false),
496 ]));
497
498 RecordBatch::try_new(
499 schema,
500 vec![Arc::new(StringArray::from(sampled_nodes))],
501 ).map_err(GraphError::from)
502 }
503 "random_edge" => {
504 let sample_ratio: f64 = params.get("sample_ratio").unwrap_or(0.5);
505 let seed: Option<u64> = params.get("seed");
506
507 self.random_edge_sampling(graph, sample_ratio, seed)
508 }
509 "snowball" => {
510 let seed_nodes: Vec<String> = params.get("seed_nodes")
511 .unwrap_or_else(|| vec![graph.node_ids().next().unwrap().clone()]);
512 let k_hops: usize = params.get("k_hops").unwrap_or(2);
513 let max_nodes: Option<usize> = params.get("max_nodes");
514
515 let sampled_nodes = self.snowball_sampling(graph, &seed_nodes, k_hops, max_nodes)?;
516
517 let schema = Arc::new(Schema::new(vec![
518 Field::new("node_id", DataType::Utf8, false),
519 ]));
520
521 RecordBatch::try_new(
522 schema,
523 vec![Arc::new(StringArray::from(sampled_nodes))],
524 ).map_err(GraphError::from)
525 }
526 _ => Err(GraphError::invalid_parameter(format!(
527 "Unknown sampling method: {}. Supported methods: random_node, random_edge, snowball",
528 sampling_method
529 )))
530 }
531 }
532
533 fn name(&self) -> &'static str {
534 "graph_sampling"
535 }
536
537 fn description(&self) -> &'static str {
538 "Perform various graph sampling strategies including random node/edge sampling and snowball sampling"
539 }
540}