1use super::index::CsrIndex;
10use crate::csr::LocalNodeId;
11
12impl CsrIndex {
13 pub(crate) fn enable_weights(&mut self) {
15 self.has_weights = true;
16
17 if !self.out_targets.is_empty() {
19 self.out_weights = Some(vec![1.0; self.out_targets.len()].into());
20 }
21 if !self.in_targets.is_empty() {
22 self.in_weights = Some(vec![1.0; self.in_targets.len()].into());
23 }
24
25 for (buf, wbuf) in self
27 .buffer_out
28 .iter()
29 .zip(self.buffer_out_weights.iter_mut())
30 {
31 if wbuf.len() < buf.len() {
32 wbuf.resize(buf.len(), 1.0);
33 }
34 }
35 for (buf, wbuf) in self.buffer_in.iter().zip(self.buffer_in_weights.iter_mut()) {
36 if wbuf.len() < buf.len() {
37 wbuf.resize(buf.len(), 1.0);
38 }
39 }
40 }
41
42 pub fn has_weights(&self) -> bool {
44 self.has_weights
45 }
46
47 pub fn out_edge_weight(&self, edge_idx: usize) -> f64 {
52 self.out_weights
53 .as_ref()
54 .and_then(|ws| ws.get(edge_idx).copied())
55 .unwrap_or(1.0)
56 }
57
58 pub fn in_edge_weight(&self, edge_idx: usize) -> f64 {
60 self.in_weights
61 .as_ref()
62 .and_then(|ws| ws.get(edge_idx).copied())
63 .unwrap_or(1.0)
64 }
65
66 pub fn edge_weight(&self, src: &str, label: &str, dst: &str) -> Option<f64> {
71 let src_id = *self.node_to_id.get(src)?;
72 let dst_id = *self.node_to_id.get(dst)?;
73 let label_id = *self.label_to_id.get(label)?;
74
75 let idx = src_id as usize;
77 if idx + 1 < self.out_offsets.len() {
78 let start = self.out_offsets[idx] as usize;
79 let end = self.out_offsets[idx + 1] as usize;
80 for i in start..end {
81 if self.out_labels[i] == label_id
82 && self.out_targets[i] == dst_id
83 && !self.deleted_edges.contains(&(src_id, label_id, dst_id))
84 {
85 return Some(self.out_edge_weight(i));
86 }
87 }
88 }
89
90 if idx < self.buffer_out.len() {
92 for (buf_idx, &(l, d)) in self.buffer_out[idx].iter().enumerate() {
93 if l == label_id && d == dst_id {
94 if self.has_weights {
95 return Some(
96 self.buffer_out_weights[idx]
97 .get(buf_idx)
98 .copied()
99 .unwrap_or(1.0),
100 );
101 }
102 return Some(1.0);
103 }
104 }
105 }
106
107 None
108 }
109
110 pub fn iter_out_edges_weighted(
115 &self,
116 node: LocalNodeId,
117 ) -> impl Iterator<Item = (u32, LocalNodeId, f64)> + '_ {
118 let node = node.raw(self.partition_tag);
119 let tag = self.partition_tag;
120 let idx = node as usize;
121
122 let dense_start = if idx + 1 < self.out_offsets.len() {
124 self.out_offsets[idx] as usize
125 } else {
126 0
127 };
128 let dense_end = if idx + 1 < self.out_offsets.len() {
129 self.out_offsets[idx + 1] as usize
130 } else {
131 0
132 };
133
134 let dense = (dense_start..dense_end)
135 .map(move |i| {
136 let w = self.out_edge_weight(i);
137 (self.out_labels[i], self.out_targets[i], w)
138 })
139 .filter(move |&(lid, dst, _)| !self.deleted_edges.contains(&(node, lid, dst)))
140 .collect::<Vec<_>>();
141
142 let buffer = if idx < self.buffer_out.len() {
144 self.buffer_out[idx]
145 .iter()
146 .enumerate()
147 .map(|(buf_idx, &(lid, dst))| {
148 let w = if self.has_weights {
149 self.buffer_out_weights[idx]
150 .get(buf_idx)
151 .copied()
152 .unwrap_or(1.0)
153 } else {
154 1.0
155 };
156 (lid, dst, w)
157 })
158 .collect::<Vec<_>>()
159 } else {
160 Vec::new()
161 };
162
163 dense
164 .into_iter()
165 .chain(buffer)
166 .map(move |(lid, dst, w)| (lid, LocalNodeId::new(dst, tag), w))
167 }
168
169 pub fn iter_out_edges_weighted_raw(
173 &self,
174 node: u32,
175 ) -> impl Iterator<Item = (u32, u32, f64)> + '_ {
176 self.iter_out_edges_weighted(self.local(node))
177 .map(move |(lid, dst, w)| (lid, dst.raw(self.partition_tag), w))
178 }
179}
180
181pub fn extract_weight_from_properties(properties: &[u8]) -> f64 {
187 if properties.is_empty() {
188 return 1.0;
189 }
190 let Ok(val) = rmpv::decode::read_value(&mut &properties[..]) else {
191 return 1.0;
192 };
193 match val {
194 rmpv::Value::Map(entries) => {
195 for (k, v) in entries {
196 if let rmpv::Value::String(ref s) = k
197 && s.as_str() == Some("weight")
198 {
199 return match v {
200 rmpv::Value::F64(f) => f,
201 rmpv::Value::F32(f) => f as f64,
202 rmpv::Value::Integer(i) => i.as_f64().unwrap_or(1.0),
203 _ => 1.0,
204 };
205 }
206 }
207 1.0
208 }
209 _ => 1.0,
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a single-entry MessagePack map `{key: value}`.
    fn encode_props(key: &str, value: rmpv::Value) -> Vec<u8> {
        let props = rmpv::Value::Map(vec![(rmpv::Value::String(key.into()), value)]);
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &props).unwrap();
        buf
    }

    #[test]
    fn unweighted_graph_has_no_weight_arrays() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        assert!(!csr.has_weights());
        assert!(csr.out_weights.is_none());
        assert!(csr.in_weights.is_none());
    }

    #[test]
    fn weighted_edge_basic() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "ROAD", "b", 5.0).unwrap();
        csr.add_edge_weighted("b", "ROAD", "c", 3.0).unwrap();
        csr.add_edge("c", "ROAD", "d").unwrap();

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "ROAD", "b"), Some(5.0));
        assert_eq!(csr.edge_weight("b", "ROAD", "c"), Some(3.0));
        assert_eq!(csr.edge_weight("c", "ROAD", "d"), Some(1.0));
        assert_eq!(csr.edge_weight("a", "ROAD", "c"), None);
    }

    #[test]
    fn edge_weight_unknown_names_or_direction_return_none() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();

        assert_eq!(csr.edge_weight("missing", "R", "b"), None);
        assert_eq!(csr.edge_weight("a", "R", "missing"), None);
        assert_eq!(csr.edge_weight("a", "NOPE", "b"), None);
        // Edges are directed: the reverse direction does not exist.
        assert_eq!(csr.edge_weight("b", "R", "a"), None);
    }

    #[test]
    fn weighted_edges_survive_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();

        csr.compact().expect("no governor, cannot fail");

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(csr.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(csr.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn weighted_edge_remove_keeps_weights_consistent() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.add_edge_weighted("a", "R", "d", 4.0).unwrap();

        csr.remove_edge("a", "R", "c");

        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.0));
        assert_eq!(csr.edge_weight("a", "R", "c"), None);
        assert_eq!(csr.edge_weight("a", "R", "d"), Some(4.0));
    }

    #[test]
    fn removed_dense_edge_is_hidden_after_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.compact().expect("no governor, cannot fail");

        csr.remove_edge("a", "R", "b");

        assert_eq!(csr.edge_weight("a", "R", "b"), None);
        assert_eq!(csr.edge_weight("a", "R", "c"), Some(3.0));

        let weights: Vec<f64> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|e| e.2)
            .collect();
        assert_eq!(weights, vec![3.0]);
    }

    #[test]
    fn iter_out_edges_weighted_returns_weights() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();
        csr.compact().expect("no governor, cannot fail");

        let edges: Vec<(u32, LocalNodeId, f64)> =
            csr.iter_out_edges_weighted(csr.local(0)).collect();
        assert_eq!(edges.len(), 2);

        let weights: Vec<f64> = edges.iter().map(|e| e.2).collect();
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_includes_buffered_edges() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();

        // No compaction: both edges live in the insertion buffer.
        let weights: Vec<f64> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|e| e.2)
            .collect();
        assert_eq!(weights.len(), 2);
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_raw_returns_weights() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();
        csr.compact().expect("no governor, cannot fail");

        let weights: Vec<f64> = csr.iter_out_edges_weighted_raw(0).map(|e| e.2).collect();
        assert_eq!(weights.len(), 2);
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn mixed_weighted_unweighted_backfill() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.add_edge("b", "L", "c").unwrap();
        assert!(!csr.has_weights());

        csr.add_edge_weighted("c", "L", "d", 5.0).unwrap();
        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "L", "b"), Some(1.0));
        assert_eq!(csr.edge_weight("c", "L", "d"), Some(5.0));
    }

    #[test]
    fn extract_weight_from_empty_properties() {
        assert_eq!(extract_weight_from_properties(b""), 1.0);
    }

    #[test]
    fn extract_weight_f64() {
        let buf = encode_props("weight", rmpv::Value::F64(0.75));
        assert_eq!(extract_weight_from_properties(&buf), 0.75);
    }

    #[test]
    fn extract_weight_f32() {
        let buf = encode_props("weight", rmpv::Value::F32(0.5));
        assert_eq!(extract_weight_from_properties(&buf), 0.5);
    }

    #[test]
    fn extract_weight_integer() {
        let buf = encode_props(
            "weight",
            rmpv::Value::Integer(rmpv::Integer::from(42)),
        );
        assert_eq!(extract_weight_from_properties(&buf), 42.0);
    }

    #[test]
    fn extract_weight_missing_key() {
        let buf = encode_props("color", rmpv::Value::String("red".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_numeric_value_defaults() {
        let buf = encode_props("weight", rmpv::Value::String("heavy".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_map_root_defaults() {
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &rmpv::Value::F64(3.0)).unwrap();
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_truncated_input_defaults() {
        // 0xcb is the float64 marker; the 8 payload bytes are missing.
        assert_eq!(extract_weight_from_properties(&[0xcb]), 1.0);
    }
}
335}