1use super::index::CsrIndex;
8
9impl CsrIndex {
10 pub(crate) fn enable_weights(&mut self) {
12 self.has_weights = true;
13
14 if !self.out_targets.is_empty() {
16 self.out_weights = Some(vec![1.0; self.out_targets.len()].into());
17 }
18 if !self.in_targets.is_empty() {
19 self.in_weights = Some(vec![1.0; self.in_targets.len()].into());
20 }
21
22 for (buf, wbuf) in self
24 .buffer_out
25 .iter()
26 .zip(self.buffer_out_weights.iter_mut())
27 {
28 if wbuf.len() < buf.len() {
29 wbuf.resize(buf.len(), 1.0);
30 }
31 }
32 for (buf, wbuf) in self.buffer_in.iter().zip(self.buffer_in_weights.iter_mut()) {
33 if wbuf.len() < buf.len() {
34 wbuf.resize(buf.len(), 1.0);
35 }
36 }
37 }
38
39 pub fn has_weights(&self) -> bool {
41 self.has_weights
42 }
43
44 pub fn out_edge_weight(&self, edge_idx: usize) -> f64 {
49 self.out_weights
50 .as_ref()
51 .and_then(|ws| ws.get(edge_idx).copied())
52 .unwrap_or(1.0)
53 }
54
55 pub fn in_edge_weight(&self, edge_idx: usize) -> f64 {
57 self.in_weights
58 .as_ref()
59 .and_then(|ws| ws.get(edge_idx).copied())
60 .unwrap_or(1.0)
61 }
62
63 pub fn edge_weight(&self, src: &str, label: &str, dst: &str) -> Option<f64> {
68 let src_id = *self.node_to_id.get(src)?;
69 let dst_id = *self.node_to_id.get(dst)?;
70 let label_id = *self.label_to_id.get(label)?;
71
72 let idx = src_id as usize;
74 if idx + 1 < self.out_offsets.len() {
75 let start = self.out_offsets[idx] as usize;
76 let end = self.out_offsets[idx + 1] as usize;
77 for i in start..end {
78 if self.out_labels[i] == label_id
79 && self.out_targets[i] == dst_id
80 && !self.deleted_edges.contains(&(src_id, label_id, dst_id))
81 {
82 return Some(self.out_edge_weight(i));
83 }
84 }
85 }
86
87 if idx < self.buffer_out.len() {
89 for (buf_idx, &(l, d)) in self.buffer_out[idx].iter().enumerate() {
90 if l == label_id && d == dst_id {
91 if self.has_weights {
92 return Some(
93 self.buffer_out_weights[idx]
94 .get(buf_idx)
95 .copied()
96 .unwrap_or(1.0),
97 );
98 }
99 return Some(1.0);
100 }
101 }
102 }
103
104 None
105 }
106
107 pub fn iter_out_edges_weighted(&self, node: u32) -> impl Iterator<Item = (u32, u32, f64)> + '_ {
112 let idx = node as usize;
113
114 let dense_start = if idx + 1 < self.out_offsets.len() {
116 self.out_offsets[idx] as usize
117 } else {
118 0
119 };
120 let dense_end = if idx + 1 < self.out_offsets.len() {
121 self.out_offsets[idx + 1] as usize
122 } else {
123 0
124 };
125
126 let dense = (dense_start..dense_end)
127 .map(move |i| {
128 let w = self.out_edge_weight(i);
129 (self.out_labels[i], self.out_targets[i], w)
130 })
131 .filter(move |&(lid, dst, _)| !self.deleted_edges.contains(&(node, lid, dst)))
132 .collect::<Vec<_>>();
133
134 let buffer = if idx < self.buffer_out.len() {
136 self.buffer_out[idx]
137 .iter()
138 .enumerate()
139 .map(|(buf_idx, &(lid, dst))| {
140 let w = if self.has_weights {
141 self.buffer_out_weights[idx]
142 .get(buf_idx)
143 .copied()
144 .unwrap_or(1.0)
145 } else {
146 1.0
147 };
148 (lid, dst, w)
149 })
150 .collect::<Vec<_>>()
151 } else {
152 Vec::new()
153 };
154
155 dense.into_iter().chain(buffer)
156 }
157}
158
159pub fn extract_weight_from_properties(properties: &[u8]) -> f64 {
165 if properties.is_empty() {
166 return 1.0;
167 }
168 let Ok(val) = rmpv::decode::read_value(&mut &properties[..]) else {
169 return 1.0;
170 };
171 match val {
172 rmpv::Value::Map(entries) => {
173 for (k, v) in entries {
174 if let rmpv::Value::String(ref s) = k
175 && s.as_str() == Some("weight")
176 {
177 return match v {
178 rmpv::Value::F64(f) => f,
179 rmpv::Value::F32(f) => f as f64,
180 rmpv::Value::Integer(i) => i.as_f64().unwrap_or(1.0),
181 _ => 1.0,
182 };
183 }
184 }
185 1.0
186 }
187 _ => 1.0,
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises `value` as MessagePack.
    fn encode(value: &rmpv::Value) -> Vec<u8> {
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, value).unwrap();
        buf
    }

    /// MessagePack encoding of the single-entry map `{key: value}`.
    fn single_prop(key: &str, value: rmpv::Value) -> Vec<u8> {
        encode(&rmpv::Value::Map(vec![(rmpv::Value::String(key.into()), value)]))
    }

    #[test]
    fn unweighted_graph_has_no_weight_arrays() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        assert!(!csr.has_weights());
        assert!(csr.out_weights.is_none());
        assert!(csr.in_weights.is_none());
    }

    #[test]
    fn weighted_edge_basic() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "ROAD", "b", 5.0).unwrap();
        csr.add_edge_weighted("b", "ROAD", "c", 3.0).unwrap();
        csr.add_edge("c", "ROAD", "d").unwrap();

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "ROAD", "b"), Some(5.0));
        assert_eq!(csr.edge_weight("b", "ROAD", "c"), Some(3.0));
        assert_eq!(csr.edge_weight("c", "ROAD", "d"), Some(1.0));
        assert_eq!(csr.edge_weight("a", "ROAD", "c"), None);
    }

    #[test]
    fn edge_weight_unknown_endpoint_or_label_is_none() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();

        assert_eq!(csr.edge_weight("a", "L", "zzz"), None);
        assert_eq!(csr.edge_weight("zzz", "L", "b"), None);
        assert_eq!(csr.edge_weight("a", "NOPE", "b"), None);
        assert_eq!(csr.edge_weight("b", "L", "a"), None);
    }

    #[test]
    fn weighted_edges_survive_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();

        csr.compact();

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(csr.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(csr.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn weighted_edge_remove_keeps_weights_consistent() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.add_edge_weighted("a", "R", "d", 4.0).unwrap();

        csr.remove_edge("a", "R", "c");

        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.0));
        assert_eq!(csr.edge_weight("a", "R", "c"), None);
        assert_eq!(csr.edge_weight("a", "R", "d"), Some(4.0));
    }

    #[test]
    fn iter_out_edges_weighted_returns_weights() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();
        csr.compact();

        let edges: Vec<(u32, u32, f64)> = csr.iter_out_edges_weighted(0).collect();
        assert_eq!(edges.len(), 2);

        let weights: Vec<f64> = edges.iter().map(|e| e.2).collect();
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_reads_write_buffer() {
        // No compaction: both edges are still in the per-node write buffer.
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();

        let weights: Vec<f64> = csr.iter_out_edges_weighted(0).map(|e| e.2).collect();
        assert_eq!(weights.len(), 2);
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_unweighted_defaults_to_one() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.add_edge("a", "L", "c").unwrap();

        let weights: Vec<f64> = csr.iter_out_edges_weighted(0).map(|e| e.2).collect();
        assert_eq!(weights, vec![1.0, 1.0]);
    }

    #[test]
    fn mixed_weighted_unweighted_backfill() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.add_edge("b", "L", "c").unwrap();
        assert!(!csr.has_weights());

        csr.add_edge_weighted("c", "L", "d", 5.0).unwrap();
        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "L", "b"), Some(1.0));
        assert_eq!(csr.edge_weight("c", "L", "d"), Some(5.0));
    }

    #[test]
    fn extract_weight_from_empty_properties() {
        assert_eq!(extract_weight_from_properties(b""), 1.0);
    }

    #[test]
    fn extract_weight_f64() {
        let buf = single_prop("weight", rmpv::Value::F64(0.75));
        assert_eq!(extract_weight_from_properties(&buf), 0.75);
    }

    #[test]
    fn extract_weight_f32() {
        let buf = single_prop("weight", rmpv::Value::F32(0.5));
        assert_eq!(extract_weight_from_properties(&buf), 0.5);
    }

    #[test]
    fn extract_weight_integer() {
        let buf = single_prop("weight", rmpv::Value::Integer(rmpv::Integer::from(42)));
        assert_eq!(extract_weight_from_properties(&buf), 42.0);
    }

    #[test]
    fn extract_weight_negative_integer() {
        let buf = single_prop("weight", rmpv::Value::Integer(rmpv::Integer::from(-3)));
        assert_eq!(extract_weight_from_properties(&buf), -3.0);
    }

    #[test]
    fn extract_weight_missing_key() {
        let buf = single_prop("color", rmpv::Value::String("red".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_numeric_value_defaults() {
        let buf = single_prop("weight", rmpv::Value::String("heavy".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_map_payload_defaults() {
        let buf = encode(&rmpv::Value::F64(3.0));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_truncated_payload_defaults() {
        // 0x81 announces a one-entry fixmap, but the key/value bytes are missing.
        assert_eq!(extract_weight_from_properties(&[0x81]), 1.0);
    }
}