1use csv::Reader;
2use rustc_hash::FxHashMap;
3use std::error::Error;
4use std::fs::File;
5
6use crate::prelude::*;
7
8pub fn read_graph(
24 path: &str,
25 directed: bool,
26 p: f32,
27 q: f32,
28) -> Result<Node2VecGraph, Box<dyn Error>> {
29 let mut adjacency = FxHashMap::default();
30 let file = File::open(path)?;
31 let mut rdr = Reader::from_reader(file);
32
33 for result in rdr.records() {
34 let record = result?;
35 let from: u32 = record[0]
36 .parse()
37 .map_err(|_| format!("Cannot cast 'from' to u32: {}", &record[0]))?;
38 let to: u32 = record[1]
39 .parse()
40 .map_err(|_| format!("Cannot cast 'to' to u32: {}", &record[1]))?;
41 let weight: f32 = record.get(2).and_then(|s| s.parse().ok()).unwrap_or(1.0);
42
43 adjacency
44 .entry(from)
45 .or_insert_with(Vec::new)
46 .push((to, weight));
47
48 if !directed {
49 adjacency
50 .entry(to)
51 .or_insert_with(Vec::new)
52 .push((from, weight));
53 }
54 }
55
56 let transition_probs = compute_transition_prob(&adjacency, p, q);
57
58 Ok(Node2VecGraph {
59 adjacency,
60 transition_probs,
61 })
62}
63
64#[cfg(test)]
65mod reader_tests {
66 use rustc_hash::FxHashMap;
67 use std::fs::File;
68 use std::io::Write;
69 use tempfile::tempdir;
70
71 type SimpleGraph = Result<FxHashMap<u32, Vec<(u32, f32)>>, Box<dyn std::error::Error>>;
72
73 fn read_graph_test(path: &str, directed: bool) -> SimpleGraph {
75 use csv::Reader;
76 let mut adjacency = FxHashMap::default();
77 let file = File::open(path)?;
78 let mut rdr = Reader::from_reader(file);
79
80 for result in rdr.records() {
81 let record = result?;
82 let from: u32 = record[0].parse()?;
83 let to: u32 = record[1].parse()?;
84 let weight: f32 = record.get(2).and_then(|s| s.parse().ok()).unwrap_or(1.0);
85
86 adjacency
87 .entry(from)
88 .or_insert_with(Vec::new)
89 .push((to, weight));
90
91 if !directed {
92 adjacency
93 .entry(to)
94 .or_insert_with(Vec::new)
95 .push((from, weight));
96 }
97 }
98
99 Ok(adjacency)
100 }
101
102 #[test]
103 fn test_undirected_graph_symmetry() {
104 let dir = tempdir().unwrap();
105 let file_path = dir.path().join("test.csv");
106 let mut file = File::create(&file_path).unwrap();
107 writeln!(file, "from,to").unwrap();
108 writeln!(file, "1,2").unwrap();
109 writeln!(file, "2,3").unwrap();
110
111 let graph = read_graph_test(file_path.to_str().unwrap(), false).unwrap();
112
113 assert!(graph.get(&1).unwrap().iter().any(|(n, _)| *n == 2));
115 assert!(graph.get(&2).unwrap().iter().any(|(n, _)| *n == 1));
116 assert!(graph.get(&2).unwrap().iter().any(|(n, _)| *n == 3));
117 assert!(graph.get(&3).unwrap().iter().any(|(n, _)| *n == 2));
118 }
119
120 #[test]
121 fn test_directed_graph_no_symmetry() {
122 let dir = tempdir().unwrap();
123 let file_path = dir.path().join("test.csv");
124 let mut file = File::create(&file_path).unwrap();
125 writeln!(file, "from,to").unwrap();
126 writeln!(file, "1,2").unwrap();
127 writeln!(file, "2,3").unwrap();
128
129 let graph = read_graph_test(file_path.to_str().unwrap(), true).unwrap();
130
131 assert!(graph.get(&1).unwrap().iter().any(|(n, _)| *n == 2));
133 assert!(!graph.contains_key(&2) || !graph.get(&2).unwrap().iter().any(|(n, _)| *n == 1));
134 }
135
136 #[test]
137 fn test_default_weights() {
138 let dir = tempdir().unwrap();
139 let file_path = dir.path().join("test.csv");
140 let mut file = File::create(&file_path).unwrap();
141 writeln!(file, "from,to").unwrap();
142 writeln!(file, "1,2").unwrap();
143
144 let graph = read_graph_test(file_path.to_str().unwrap(), false).unwrap();
145
146 for edges in graph.values() {
148 for (_, weight) in edges {
149 assert_eq!(*weight, 1.0);
150 }
151 }
152 }
153
154 #[test]
155 fn test_explicit_weights() {
156 let dir = tempdir().unwrap();
157 let file_path = dir.path().join("test.csv");
158 let mut file = File::create(&file_path).unwrap();
159 writeln!(file, "from,to,weight").unwrap();
160 writeln!(file, "1,2,0.5").unwrap();
161 writeln!(file, "2,3,2.0").unwrap();
162
163 let graph = read_graph_test(file_path.to_str().unwrap(), false).unwrap();
164
165 let edge_1_2 = graph
167 .get(&1)
168 .unwrap()
169 .iter()
170 .find(|(n, _)| *n == 2)
171 .unwrap();
172 assert_eq!(edge_1_2.1, 0.5);
173
174 let edge_2_3 = graph
175 .get(&2)
176 .unwrap()
177 .iter()
178 .find(|(n, _)| *n == 3)
179 .unwrap();
180 assert_eq!(edge_2_3.1, 2.0);
181 }
182
183 #[test]
184 fn test_self_loops() {
185 let dir = tempdir().unwrap();
186 let file_path = dir.path().join("test.csv");
187 let mut file = File::create(&file_path).unwrap();
188 writeln!(file, "from,to").unwrap();
189 writeln!(file, "1,1").unwrap();
190
191 let graph = read_graph_test(file_path.to_str().unwrap(), false).unwrap();
192
193 assert!(graph.get(&1).unwrap().iter().any(|(n, _)| *n == 1));
195 }
196
197 #[test]
198 fn test_invalid_node_id() {
199 let dir = tempdir().unwrap();
200 let file_path = dir.path().join("test.csv");
201 let mut file = File::create(&file_path).unwrap();
202 writeln!(file, "from,to").unwrap();
203 writeln!(file, "invalid,2").unwrap();
204
205 let result = read_graph_test(file_path.to_str().unwrap(), false);
206 assert!(result.is_err());
207 }
208}