Skip to main content

node2vec_rs/
reader.rs

1use csv::Reader;
2use rustc_hash::FxHashMap;
3use std::error::Error;
4use std::fs::File;
5
6use crate::prelude::*;
7
8/// Helper function to read in a graph from CSV
9///
10/// ### Params
11///
12/// * `path` - Path to the CSV with a `"from"`, `"to"` and `"weight"` column.
13/// * `directed` - Boolean. Shall the graph be treated as a directed or
14///   undirected graph.
15/// * `p` - p parameter in node2vec that controls probability to return
16/// * `q` - q parameter in node2vec that controls probability to reach out
17///   futher in the graph.
18///
19/// ### Returns
20///
21/// The `Node2VecGraph` with adjacency stored in their and transition
22/// probabilities.
23pub fn read_graph(
24    path: &str,
25    directed: bool,
26    p: f32,
27    q: f32,
28) -> Result<Node2VecGraph, Box<dyn Error>> {
29    let mut adjacency = FxHashMap::default();
30    let file = File::open(path)?;
31    let mut rdr = Reader::from_reader(file);
32
33    for result in rdr.records() {
34        let record = result?;
35        let from: u32 = record[0]
36            .parse()
37            .map_err(|_| format!("Cannot cast 'from' to u32: {}", &record[0]))?;
38        let to: u32 = record[1]
39            .parse()
40            .map_err(|_| format!("Cannot cast 'to' to u32: {}", &record[1]))?;
41        let weight: f32 = record.get(2).and_then(|s| s.parse().ok()).unwrap_or(1.0);
42
43        adjacency
44            .entry(from)
45            .or_insert_with(Vec::new)
46            .push((to, weight));
47
48        if !directed {
49            adjacency
50                .entry(to)
51                .or_insert_with(Vec::new)
52                .push((from, weight));
53        }
54    }
55
56    let transition_probs = compute_transition_prob(&adjacency, p, q);
57
58    Ok(Node2VecGraph {
59        adjacency,
60        transition_probs,
61    })
62}
63
#[cfg(test)]
mod reader_tests {
    use rustc_hash::FxHashMap;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    /// Adjacency list keyed by node id; each entry is `(neighbour, weight)`.
    type Adjacency = FxHashMap<u32, Vec<(u32, f32)>>;
    type SimpleGraph = Result<Adjacency, Box<dyn std::error::Error>>;

    // Simplified version of read_graph for testing: builds only the adjacency
    // map and takes no p/q parameters.
    // NOTE(review): this is a copy of the parsing loop rather than a call into
    // `read_graph`, so these tests can drift from the production code.
    fn read_graph_test(path: &str, directed: bool) -> SimpleGraph {
        use csv::Reader;
        let mut adjacency = Adjacency::default();
        let file = File::open(path)?;
        let mut rdr = Reader::from_reader(file);

        for result in rdr.records() {
            let record = result?;
            let from: u32 = record[0].parse()?;
            let to: u32 = record[1].parse()?;
            let weight: f32 = record.get(2).and_then(|s| s.parse().ok()).unwrap_or(1.0);

            adjacency.entry(from).or_default().push((to, weight));

            if !directed {
                adjacency.entry(to).or_default().push((from, weight));
            }
        }

        Ok(adjacency)
    }

    /// Writes `rows` (header first) to a CSV in a fresh temp dir and parses it.
    ///
    /// Fixture I/O failures panic on purpose, so the returned `Err` can only
    /// come from parsing, which is what `test_invalid_node_id` relies on.
    fn load_csv(rows: &[&str], directed: bool) -> SimpleGraph {
        let dir = tempdir().expect("failed to create temp dir");
        let file_path = dir.path().join("test.csv");
        {
            let mut file = File::create(&file_path).expect("failed to create test CSV");
            for row in rows {
                writeln!(file, "{}", row).expect("failed to write test CSV");
            }
        }

        // `dir` stays alive until this function returns, i.e. after the read.
        read_graph_test(
            file_path.to_str().expect("temp path is not valid UTF-8"),
            directed,
        )
    }

    /// Weight of the first stored `from -> to` edge, if any.
    fn edge_weight(graph: &Adjacency, from: u32, to: u32) -> Option<f32> {
        graph
            .get(&from)?
            .iter()
            .find(|(neighbour, _)| *neighbour == to)
            .map(|&(_, weight)| weight)
    }

    /// Whether at least one `from -> to` edge is stored.
    fn has_edge(graph: &Adjacency, from: u32, to: u32) -> bool {
        edge_weight(graph, from, to).is_some()
    }

    #[test]
    fn test_undirected_graph_symmetry() {
        let graph = load_csv(&["from,to", "1,2", "2,3"], false).unwrap();

        // Check bidirectional edges exist
        for &(a, b) in &[(1, 2), (2, 3)] {
            assert!(has_edge(&graph, a, b), "missing edge {} -> {}", a, b);
            assert!(has_edge(&graph, b, a), "missing edge {} -> {}", b, a);
        }
    }

    #[test]
    fn test_directed_graph_no_symmetry() {
        let graph = load_csv(&["from,to", "1,2", "2,3"], true).unwrap();

        // Check only forward edges exist
        assert!(has_edge(&graph, 1, 2));
        assert!(has_edge(&graph, 2, 3));
        assert!(!has_edge(&graph, 2, 1));
        assert!(!has_edge(&graph, 3, 2));
        // Node 3 is only ever a target, so it must not get an entry of its own.
        assert!(!graph.contains_key(&3));
    }

    #[test]
    fn test_default_weights() {
        let graph = load_csv(&["from,to", "1,2"], false).unwrap();

        // Guard against the loop below passing vacuously on an empty graph.
        assert!(!graph.is_empty());

        // All weights should be 1.0
        for (_, weight) in graph.values().flatten() {
            assert_eq!(*weight, 1.0);
        }
    }

    #[test]
    fn test_explicit_weights() {
        let graph = load_csv(&["from,to,weight", "1,2,0.5", "2,3,2.0"], false).unwrap();

        // Check explicit weights
        assert_eq!(edge_weight(&graph, 1, 2), Some(0.5));
        assert_eq!(edge_weight(&graph, 2, 3), Some(2.0));

        // The mirrored edges of the undirected graph carry the same weights.
        assert_eq!(edge_weight(&graph, 2, 1), Some(0.5));
        assert_eq!(edge_weight(&graph, 3, 2), Some(2.0));
    }

    #[test]
    fn test_self_loops() {
        let graph = load_csv(&["from,to", "1,1"], false).unwrap();

        // Self-loop should exist
        assert!(has_edge(&graph, 1, 1));
    }

    #[test]
    fn test_invalid_node_id() {
        let result = load_csv(&["from,to", "invalid,2"], false);
        assert!(result.is_err());
    }
}