arrow_graph/algorithms/
sampling.rs

1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, UInt32Array, Float64Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use std::sync::Arc;
5use rand::{Rng, SeedableRng};
6use rand::seq::SliceRandom;
7use rand_pcg::Pcg64;
8use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
9use crate::graph::ArrowGraph;
10use crate::error::{GraphError, Result};
11
12/// Random Walk implementation for graph sampling and ML feature generation
13pub struct RandomWalk;
14
15impl RandomWalk {
16    /// Perform random walks starting from specified nodes
17    fn compute_random_walks(
18        &self,
19        graph: &ArrowGraph,
20        start_nodes: &[String],
21        walk_length: usize,
22        num_walks: usize,
23        seed: Option<u64>,
24    ) -> Result<Vec<Vec<String>>> {
25        if walk_length == 0 {
26            return Err(GraphError::invalid_parameter(
27                "walk_length must be greater than 0"
28            ));
29        }
30
31        if num_walks == 0 {
32            return Err(GraphError::invalid_parameter(
33                "num_walks must be greater than 0"
34            ));
35        }
36
37        let mut rng = match seed {
38            Some(s) => Pcg64::seed_from_u64(s),
39            None => Pcg64::from_entropy(),
40        };
41
42        let mut all_walks = Vec::new();
43
44        for start_node in start_nodes {
45            if !graph.has_node(start_node) {
46                return Err(GraphError::node_not_found(start_node.clone()));
47            }
48
49            for _ in 0..num_walks {
50                let walk = self.single_random_walk(graph, start_node, walk_length, &mut rng)?;
51                all_walks.push(walk);
52            }
53        }
54
55        Ok(all_walks)
56    }
57
58    /// Perform a single random walk from a starting node
59    fn single_random_walk(
60        &self,
61        graph: &ArrowGraph,
62        start_node: &str,
63        walk_length: usize,
64        rng: &mut Pcg64,
65    ) -> Result<Vec<String>> {
66        let mut walk = Vec::with_capacity(walk_length);
67        let mut current_node = start_node.to_string();
68
69        walk.push(current_node.clone());
70
71        for _ in 1..walk_length {
72            let neighbors = match graph.neighbors(&current_node) {
73                Some(neighbors) => neighbors,
74                None => break, // Dead end - end walk early
75            };
76
77            if neighbors.is_empty() {
78                break; // No neighbors - end walk early
79            }
80
81            // Choose random neighbor
82            let next_node = neighbors.choose(rng).unwrap();
83            current_node = next_node.clone();
84            walk.push(current_node.clone());
85        }
86
87        Ok(walk)
88    }
89}
90
91impl GraphAlgorithm for RandomWalk {
92    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
93        let walk_length: usize = params.get("walk_length").unwrap_or(10);
94        let num_walks: usize = params.get("num_walks").unwrap_or(10);
95        let seed: Option<u64> = params.get("seed");
96        
97        // Get start nodes - if not specified, use all nodes
98        let start_nodes: Vec<String> = if let Some(nodes) = params.get::<Vec<String>>("start_nodes") {
99            nodes
100        } else {
101            graph.node_ids().cloned().collect()
102        };
103
104        let walks = self.compute_random_walks(graph, &start_nodes, walk_length, num_walks, seed)?;
105
106        // Convert walks to Arrow format
107        let schema = Arc::new(Schema::new(vec![
108            Field::new("walk_id", DataType::UInt32, false),
109            Field::new("step", DataType::UInt32, false),
110            Field::new("node_id", DataType::Utf8, false),
111        ]));
112
113        let mut walk_ids = Vec::new();
114        let mut steps = Vec::new();
115        let mut node_ids = Vec::new();
116
117        for (walk_id, walk) in walks.iter().enumerate() {
118            for (step, node_id) in walk.iter().enumerate() {
119                walk_ids.push(walk_id as u32);
120                steps.push(step as u32);
121                node_ids.push(node_id.clone());
122            }
123        }
124
125        RecordBatch::try_new(
126            schema,
127            vec![
128                Arc::new(UInt32Array::from(walk_ids)),
129                Arc::new(UInt32Array::from(steps)),
130                Arc::new(StringArray::from(node_ids)),
131            ],
132        ).map_err(GraphError::from)
133    }
134
135    fn name(&self) -> &'static str {
136        "random_walk"
137    }
138
139    fn description(&self) -> &'static str {
140        "Generate random walks from specified nodes for graph sampling and ML feature generation"
141    }
142}
143
144/// Node2Vec-style biased random walks with return parameter p and in-out parameter q
145pub struct Node2VecWalk;
146
147impl Node2VecWalk {
148    /// Perform Node2Vec-style biased random walks
149    fn compute_node2vec_walks(
150        &self,
151        graph: &ArrowGraph,
152        start_nodes: &[String],
153        walk_length: usize,
154        num_walks: usize,
155        p: f64, // Return parameter (controls likelihood of returning to previous node)
156        q: f64, // In-out parameter (controls likelihood of exploring vs. staying local)
157        seed: Option<u64>,
158    ) -> Result<Vec<Vec<String>>> {
159        if walk_length < 2 {
160            return Err(GraphError::invalid_parameter(
161                "walk_length must be at least 2 for Node2Vec walks"
162            ));
163        }
164
165        if p <= 0.0 || q <= 0.0 {
166            return Err(GraphError::invalid_parameter(
167                "p and q parameters must be positive"
168            ));
169        }
170
171        let mut rng = match seed {
172            Some(s) => Pcg64::seed_from_u64(s),
173            None => Pcg64::from_entropy(),
174        };
175
176        let mut all_walks = Vec::new();
177
178        for start_node in start_nodes {
179            if !graph.has_node(start_node) {
180                return Err(GraphError::node_not_found(start_node.clone()));
181            }
182
183            for _ in 0..num_walks {
184                let walk = self.single_node2vec_walk(graph, start_node, walk_length, p, q, &mut rng)?;
185                all_walks.push(walk);
186            }
187        }
188
189        Ok(all_walks)
190    }
191
192    /// Perform a single Node2Vec-style biased random walk
193    fn single_node2vec_walk(
194        &self,
195        graph: &ArrowGraph,
196        start_node: &str,
197        walk_length: usize,
198        p: f64,
199        q: f64,
200        rng: &mut Pcg64,
201    ) -> Result<Vec<String>> {
202        let mut walk = Vec::with_capacity(walk_length);
203        walk.push(start_node.to_string());
204
205        // First step is uniform random
206        let first_neighbors = match graph.neighbors(start_node) {
207            Some(neighbors) if !neighbors.is_empty() => neighbors,
208            _ => return Ok(walk), // No neighbors - return single node walk
209        };
210
211        let second_node = first_neighbors.choose(rng).unwrap().clone();
212        walk.push(second_node.clone());
213
214        // Subsequent steps use biased probabilities
215        for _ in 2..walk_length {
216            let current_node = &walk[walk.len() - 1];
217            let previous_node = &walk[walk.len() - 2];
218
219            let neighbors = match graph.neighbors(current_node) {
220                Some(neighbors) if !neighbors.is_empty() => neighbors,
221                _ => break, // No neighbors - end walk
222            };
223
224            let next_node = self.choose_next_node_biased(
225                graph, previous_node, current_node, &neighbors, p, q, rng
226            )?;
227            walk.push(next_node);
228        }
229
230        Ok(walk)
231    }
232
233    /// Choose next node based on Node2Vec biased probabilities
234    fn choose_next_node_biased(
235        &self,
236        graph: &ArrowGraph,
237        previous_node: &str,
238        current_node: &str,
239        neighbors: &[String],
240        p: f64,
241        q: f64,
242        rng: &mut Pcg64,
243    ) -> Result<String> {
244        let mut probabilities = Vec::new();
245        let mut cumulative_prob = 0.0;
246
247        // Get neighbors of previous node for distance calculation
248        let previous_neighbors: std::collections::HashSet<_> = match graph.neighbors(previous_node) {
249            Some(neighbors) => neighbors.iter().collect(),
250            None => std::collections::HashSet::new(),
251        };
252
253        for neighbor in neighbors {
254            let weight = if neighbor == previous_node {
255                // Return to previous node - controlled by p
256                1.0 / p
257            } else if previous_neighbors.contains(neighbor) {
258                // Stay within local neighborhood - weight = 1
259                1.0
260            } else {
261                // Explore further - controlled by q
262                1.0 / q
263            };
264
265            // Apply edge weight if available
266            let edge_weight = graph.edge_weight(current_node, neighbor).unwrap_or(1.0);
267            let final_weight = weight * edge_weight;
268
269            cumulative_prob += final_weight;
270            probabilities.push((neighbor, cumulative_prob));
271        }
272
273        // Sample from the probability distribution
274        let random_val = rng.gen::<f64>() * cumulative_prob;
275        
276        for (neighbor, cum_prob) in probabilities {
277            if random_val <= cum_prob {
278                return Ok(neighbor.clone());
279            }
280        }
281
282        // Fallback - should not happen with proper probability distribution
283        Ok(neighbors[0].clone())
284    }
285}
286
287impl GraphAlgorithm for Node2VecWalk {
288    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
289        let walk_length: usize = params.get("walk_length").unwrap_or(80);
290        let num_walks: usize = params.get("num_walks").unwrap_or(10);
291        let p: f64 = params.get("p").unwrap_or(1.0);
292        let q: f64 = params.get("q").unwrap_or(1.0);
293        let seed: Option<u64> = params.get("seed");
294        
295        // Get start nodes - if not specified, use all nodes
296        let start_nodes: Vec<String> = if let Some(nodes) = params.get::<Vec<String>>("start_nodes") {
297            nodes
298        } else {
299            graph.node_ids().cloned().collect()
300        };
301
302        let walks = self.compute_node2vec_walks(graph, &start_nodes, walk_length, num_walks, p, q, seed)?;
303
304        // Convert walks to Arrow format with additional Node2Vec metadata
305        let schema = Arc::new(Schema::new(vec![
306            Field::new("walk_id", DataType::UInt32, false),
307            Field::new("step", DataType::UInt32, false),
308            Field::new("node_id", DataType::Utf8, false),
309            Field::new("p_param", DataType::Float64, false),
310            Field::new("q_param", DataType::Float64, false),
311        ]));
312
313        let mut walk_ids = Vec::new();
314        let mut steps = Vec::new();
315        let mut node_ids = Vec::new();
316        let mut p_params = Vec::new();
317        let mut q_params = Vec::new();
318
319        for (walk_id, walk) in walks.iter().enumerate() {
320            for (step, node_id) in walk.iter().enumerate() {
321                walk_ids.push(walk_id as u32);
322                steps.push(step as u32);
323                node_ids.push(node_id.clone());
324                p_params.push(p);
325                q_params.push(q);
326            }
327        }
328
329        RecordBatch::try_new(
330            schema,
331            vec![
332                Arc::new(UInt32Array::from(walk_ids)),
333                Arc::new(UInt32Array::from(steps)),
334                Arc::new(StringArray::from(node_ids)),
335                Arc::new(Float64Array::from(p_params)),
336                Arc::new(Float64Array::from(q_params)),
337            ],
338        ).map_err(GraphError::from)
339    }
340
341    fn name(&self) -> &'static str {
342        "node2vec"
343    }
344
345    fn description(&self) -> &'static str {
346        "Generate Node2Vec-style biased random walks with return (p) and in-out (q) parameters"
347    }
348}
349
350/// Graph sampling using various strategies
351pub struct GraphSampling;
352
353impl GraphSampling {
354    /// Random node sampling
355    pub fn random_node_sampling(
356        &self,
357        graph: &ArrowGraph,
358        sample_size: usize,
359        seed: Option<u64>,
360    ) -> Result<Vec<String>> {
361        let all_nodes: Vec<String> = graph.node_ids().cloned().collect();
362        
363        if sample_size >= all_nodes.len() {
364            return Ok(all_nodes);
365        }
366
367        let mut rng = match seed {
368            Some(s) => Pcg64::seed_from_u64(s),
369            None => Pcg64::from_entropy(),
370        };
371
372        let sampled_nodes = all_nodes.choose_multiple(&mut rng, sample_size).cloned().collect();
373        Ok(sampled_nodes)
374    }
375
376    /// Random edge sampling
377    pub fn random_edge_sampling(
378        &self,
379        graph: &ArrowGraph,
380        sample_ratio: f64,
381        seed: Option<u64>,
382    ) -> Result<RecordBatch> {
383        if !(0.0..=1.0).contains(&sample_ratio) {
384            return Err(GraphError::invalid_parameter(
385                "sample_ratio must be between 0.0 and 1.0"
386            ));
387        }
388
389        let mut rng = match seed {
390            Some(s) => Pcg64::seed_from_u64(s),
391            None => Pcg64::from_entropy(),
392        };
393
394        let mut sampled_sources = Vec::new();
395        let mut sampled_targets = Vec::new();
396        let mut sampled_weights = Vec::new();
397
398        // Sample edges based on ratio
399        for node_id in graph.node_ids() {
400            if let Some(neighbors) = graph.neighbors(node_id) {
401                for neighbor in neighbors {
402                    if rng.gen::<f64>() < sample_ratio {
403                        sampled_sources.push(node_id.clone());
404                        sampled_targets.push(neighbor.clone());
405                        sampled_weights.push(graph.edge_weight(node_id, neighbor).unwrap_or(1.0));
406                    }
407                }
408            }
409        }
410
411        let schema = Arc::new(Schema::new(vec![
412            Field::new("source", DataType::Utf8, false),
413            Field::new("target", DataType::Utf8, false),
414            Field::new("weight", DataType::Float64, false),
415        ]));
416
417        RecordBatch::try_new(
418            schema,
419            vec![
420                Arc::new(StringArray::from(sampled_sources)),
421                Arc::new(StringArray::from(sampled_targets)),
422                Arc::new(Float64Array::from(sampled_weights)),
423            ],
424        ).map_err(GraphError::from)
425    }
426
427    /// Snowball sampling (BFS-based expansion)
428    pub fn snowball_sampling(
429        &self,
430        graph: &ArrowGraph,
431        seed_nodes: &[String],
432        k_hops: usize,
433        max_nodes: Option<usize>,
434    ) -> Result<Vec<String>> {
435        if k_hops == 0 {
436            return Ok(seed_nodes.to_vec());
437        }
438
439        let mut sampled_nodes: std::collections::HashSet<String> = std::collections::HashSet::new();
440        let mut current_frontier: std::collections::HashSet<String> = std::collections::HashSet::new();
441
442        // Initialize with seed nodes
443        for seed_node in seed_nodes {
444            if !graph.has_node(seed_node) {
445                return Err(GraphError::node_not_found(seed_node.clone()));
446            }
447            sampled_nodes.insert(seed_node.clone());
448            current_frontier.insert(seed_node.clone());
449        }
450
451        // Expand k hops
452        for _ in 0..k_hops {
453            let mut next_frontier = std::collections::HashSet::new();
454
455            for node in &current_frontier {
456                if let Some(neighbors) = graph.neighbors(node) {
457                    for neighbor in neighbors {
458                        if !sampled_nodes.contains(neighbor) {
459                            sampled_nodes.insert(neighbor.clone());
460                            next_frontier.insert(neighbor.clone());
461
462                            // Check if we've reached the maximum number of nodes
463                            if let Some(max) = max_nodes {
464                                if sampled_nodes.len() >= max {
465                                    return Ok(sampled_nodes.into_iter().collect());
466                                }
467                            }
468                        }
469                    }
470                }
471            }
472
473            current_frontier = next_frontier;
474            if current_frontier.is_empty() {
475                break; // No more nodes to expand
476            }
477        }
478
479        Ok(sampled_nodes.into_iter().collect())
480    }
481}
482
483impl GraphAlgorithm for GraphSampling {
484    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
485        let sampling_method: String = params.get("method").unwrap_or("random_node".to_string());
486
487        match sampling_method.as_str() {
488            "random_node" => {
489                let sample_size: usize = params.get("sample_size").unwrap_or(graph.node_count() / 2);
490                let seed: Option<u64> = params.get("seed");
491                
492                let sampled_nodes = self.random_node_sampling(graph, sample_size, seed)?;
493
494                let schema = Arc::new(Schema::new(vec![
495                    Field::new("node_id", DataType::Utf8, false),
496                ]));
497
498                RecordBatch::try_new(
499                    schema,
500                    vec![Arc::new(StringArray::from(sampled_nodes))],
501                ).map_err(GraphError::from)
502            }
503            "random_edge" => {
504                let sample_ratio: f64 = params.get("sample_ratio").unwrap_or(0.5);
505                let seed: Option<u64> = params.get("seed");
506                
507                self.random_edge_sampling(graph, sample_ratio, seed)
508            }
509            "snowball" => {
510                let seed_nodes: Vec<String> = params.get("seed_nodes")
511                    .unwrap_or_else(|| vec![graph.node_ids().next().unwrap().clone()]);
512                let k_hops: usize = params.get("k_hops").unwrap_or(2);
513                let max_nodes: Option<usize> = params.get("max_nodes");
514                
515                let sampled_nodes = self.snowball_sampling(graph, &seed_nodes, k_hops, max_nodes)?;
516
517                let schema = Arc::new(Schema::new(vec![
518                    Field::new("node_id", DataType::Utf8, false),
519                ]));
520
521                RecordBatch::try_new(
522                    schema,
523                    vec![Arc::new(StringArray::from(sampled_nodes))],
524                ).map_err(GraphError::from)
525            }
526            _ => Err(GraphError::invalid_parameter(format!(
527                "Unknown sampling method: {}. Supported methods: random_node, random_edge, snowball",
528                sampling_method
529            )))
530        }
531    }
532
533    fn name(&self) -> &'static str {
534        "graph_sampling"
535    }
536
537    fn description(&self) -> &'static str {
538        "Perform various graph sampling strategies including random node/edge sampling and snowball sampling"
539    }
540}