1use super::index::CsrIndex;
8use crate::csr::LocalNodeId;
9
10impl CsrIndex {
11 pub(crate) fn enable_weights(&mut self) {
13 self.has_weights = true;
14
15 if !self.out_targets.is_empty() {
17 self.out_weights = Some(vec![1.0; self.out_targets.len()].into());
18 }
19 if !self.in_targets.is_empty() {
20 self.in_weights = Some(vec![1.0; self.in_targets.len()].into());
21 }
22
23 for (buf, wbuf) in self
25 .buffer_out
26 .iter()
27 .zip(self.buffer_out_weights.iter_mut())
28 {
29 if wbuf.len() < buf.len() {
30 wbuf.resize(buf.len(), 1.0);
31 }
32 }
33 for (buf, wbuf) in self.buffer_in.iter().zip(self.buffer_in_weights.iter_mut()) {
34 if wbuf.len() < buf.len() {
35 wbuf.resize(buf.len(), 1.0);
36 }
37 }
38 }
39
40 pub fn has_weights(&self) -> bool {
42 self.has_weights
43 }
44
45 pub fn out_edge_weight(&self, edge_idx: usize) -> f64 {
50 self.out_weights
51 .as_ref()
52 .and_then(|ws| ws.get(edge_idx).copied())
53 .unwrap_or(1.0)
54 }
55
56 pub fn in_edge_weight(&self, edge_idx: usize) -> f64 {
58 self.in_weights
59 .as_ref()
60 .and_then(|ws| ws.get(edge_idx).copied())
61 .unwrap_or(1.0)
62 }
63
64 pub fn edge_weight(&self, src: &str, label: &str, dst: &str) -> Option<f64> {
69 let src_id = *self.node_to_id.get(src)?;
70 let dst_id = *self.node_to_id.get(dst)?;
71 let label_id = *self.label_to_id.get(label)?;
72
73 let idx = src_id as usize;
75 if idx + 1 < self.out_offsets.len() {
76 let start = self.out_offsets[idx] as usize;
77 let end = self.out_offsets[idx + 1] as usize;
78 for i in start..end {
79 if self.out_labels[i] == label_id
80 && self.out_targets[i] == dst_id
81 && !self.deleted_edges.contains(&(src_id, label_id, dst_id))
82 {
83 return Some(self.out_edge_weight(i));
84 }
85 }
86 }
87
88 if idx < self.buffer_out.len() {
90 for (buf_idx, &(l, d)) in self.buffer_out[idx].iter().enumerate() {
91 if l == label_id && d == dst_id {
92 if self.has_weights {
93 return Some(
94 self.buffer_out_weights[idx]
95 .get(buf_idx)
96 .copied()
97 .unwrap_or(1.0),
98 );
99 }
100 return Some(1.0);
101 }
102 }
103 }
104
105 None
106 }
107
108 pub fn iter_out_edges_weighted(
113 &self,
114 node: LocalNodeId,
115 ) -> impl Iterator<Item = (u32, LocalNodeId, f64)> + '_ {
116 let node = node.raw(self.partition_tag);
117 let tag = self.partition_tag;
118 let idx = node as usize;
119
120 let dense_start = if idx + 1 < self.out_offsets.len() {
122 self.out_offsets[idx] as usize
123 } else {
124 0
125 };
126 let dense_end = if idx + 1 < self.out_offsets.len() {
127 self.out_offsets[idx + 1] as usize
128 } else {
129 0
130 };
131
132 let dense = (dense_start..dense_end)
133 .map(move |i| {
134 let w = self.out_edge_weight(i);
135 (self.out_labels[i], self.out_targets[i], w)
136 })
137 .filter(move |&(lid, dst, _)| !self.deleted_edges.contains(&(node, lid, dst)))
138 .collect::<Vec<_>>();
139
140 let buffer = if idx < self.buffer_out.len() {
142 self.buffer_out[idx]
143 .iter()
144 .enumerate()
145 .map(|(buf_idx, &(lid, dst))| {
146 let w = if self.has_weights {
147 self.buffer_out_weights[idx]
148 .get(buf_idx)
149 .copied()
150 .unwrap_or(1.0)
151 } else {
152 1.0
153 };
154 (lid, dst, w)
155 })
156 .collect::<Vec<_>>()
157 } else {
158 Vec::new()
159 };
160
161 dense
162 .into_iter()
163 .chain(buffer)
164 .map(move |(lid, dst, w)| (lid, LocalNodeId::new(dst, tag), w))
165 }
166
167 pub fn iter_out_edges_weighted_raw(
171 &self,
172 node: u32,
173 ) -> impl Iterator<Item = (u32, u32, f64)> + '_ {
174 self.iter_out_edges_weighted(self.local(node))
175 .map(move |(lid, dst, w)| (lid, dst.raw(self.partition_tag), w))
176 }
177}
178
179pub fn extract_weight_from_properties(properties: &[u8]) -> f64 {
185 if properties.is_empty() {
186 return 1.0;
187 }
188 let Ok(val) = rmpv::decode::read_value(&mut &properties[..]) else {
189 return 1.0;
190 };
191 match val {
192 rmpv::Value::Map(entries) => {
193 for (k, v) in entries {
194 if let rmpv::Value::String(ref s) = k
195 && s.as_str() == Some("weight")
196 {
197 return match v {
198 rmpv::Value::F64(f) => f,
199 rmpv::Value::F32(f) => f as f64,
200 rmpv::Value::Integer(i) => i.as_f64().unwrap_or(1.0),
201 _ => 1.0,
202 };
203 }
204 }
205 1.0
206 }
207 _ => 1.0,
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unweighted_graph_has_no_weight_arrays() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        assert!(!csr.has_weights());
        assert!(csr.out_weights.is_none());
        assert!(csr.in_weights.is_none());
    }

    #[test]
    fn weighted_edge_basic() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "ROAD", "b", 5.0).unwrap();
        csr.add_edge_weighted("b", "ROAD", "c", 3.0).unwrap();
        csr.add_edge("c", "ROAD", "d").unwrap();

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "ROAD", "b"), Some(5.0));
        assert_eq!(csr.edge_weight("b", "ROAD", "c"), Some(3.0));
        assert_eq!(csr.edge_weight("c", "ROAD", "d"), Some(1.0));
        assert_eq!(csr.edge_weight("a", "ROAD", "c"), None);
    }

    #[test]
    fn weighted_edges_survive_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();

        csr.compact();

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(csr.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(csr.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn weighted_edge_remove_keeps_weights_consistent() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.add_edge_weighted("a", "R", "d", 4.0).unwrap();

        csr.remove_edge("a", "R", "c");

        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.0));
        assert_eq!(csr.edge_weight("a", "R", "c"), None);
        assert_eq!(csr.edge_weight("a", "R", "d"), Some(4.0));
    }

    #[test]
    fn weighted_edge_remove_after_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.add_edge_weighted("a", "R", "d", 4.0).unwrap();
        csr.compact();

        csr.remove_edge("a", "R", "c");

        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.0));
        assert_eq!(csr.edge_weight("a", "R", "c"), None);
        assert_eq!(csr.edge_weight("a", "R", "d"), Some(4.0));

        let mut weights: Vec<f64> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|(_, _, w)| w)
            .collect();
        weights.sort_by(f64::total_cmp);
        assert_eq!(weights, vec![2.0, 4.0]);
    }

    #[test]
    fn iter_out_edges_weighted_returns_weights() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();
        csr.compact();

        let edges: Vec<(u32, LocalNodeId, f64)> =
            csr.iter_out_edges_weighted(csr.local(0)).collect();
        assert_eq!(edges.len(), 2);

        let weights: Vec<f64> = edges.iter().map(|e| e.2).collect();
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_includes_buffered_edges() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();

        // No compaction: both edges are still in the write buffer.
        let weights: Vec<f64> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|(_, _, w)| w)
            .collect();
        assert_eq!(weights.len(), 2);
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_raw_matches_local_variant() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.compact();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();

        // One dense edge plus one buffered edge: both variants must agree.
        let local: Vec<(u32, u32, f64)> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|(lid, dst, w)| (lid, dst.raw(csr.partition_tag), w))
            .collect();
        let raw: Vec<(u32, u32, f64)> = csr.iter_out_edges_weighted_raw(0).collect();
        assert_eq!(raw.len(), 2);
        assert_eq!(local, raw);
    }

    #[test]
    fn mixed_weighted_unweighted_backfill() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.add_edge("b", "L", "c").unwrap();
        assert!(!csr.has_weights());

        csr.add_edge_weighted("c", "L", "d", 5.0).unwrap();
        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "L", "b"), Some(1.0));
        assert_eq!(csr.edge_weight("c", "L", "d"), Some(5.0));
    }

    #[test]
    fn extract_weight_from_empty_properties() {
        assert_eq!(extract_weight_from_properties(b""), 1.0);
    }

    #[test]
    fn extract_weight_f64() {
        let props = rmpv::Value::Map(vec![(
            rmpv::Value::String("weight".into()),
            rmpv::Value::F64(0.75),
        )]);
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &props).unwrap();
        assert_eq!(extract_weight_from_properties(&buf), 0.75);
    }

    #[test]
    fn extract_weight_integer() {
        let props = rmpv::Value::Map(vec![(
            rmpv::Value::String("weight".into()),
            rmpv::Value::Integer(rmpv::Integer::from(42)),
        )]);
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &props).unwrap();
        assert_eq!(extract_weight_from_properties(&buf), 42.0);
    }

    #[test]
    fn extract_weight_missing_key() {
        let props = rmpv::Value::Map(vec![(
            rmpv::Value::String("color".into()),
            rmpv::Value::String("red".into()),
        )]);
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &props).unwrap();
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }
}