Skip to main content

nodedb_graph/csr/
weights.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Edge weight management for the CSR index.
4//!
5//! Optional `f64` weight per edge stored in parallel arrays. `None` when the
6//! graph is entirely unweighted (zero memory overhead). Populated from the
7//! `"weight"` edge property at insertion time. Unweighted edges default to 1.0.
8
9use super::index::CsrIndex;
10use crate::csr::LocalNodeId;
11
12impl CsrIndex {
13    /// Enable weight tracking. Backfills existing buffer entries with 1.0.
14    pub(crate) fn enable_weights(&mut self) {
15        self.has_weights = true;
16
17        // Backfill existing dense arrays with 1.0.
18        if !self.out_targets.is_empty() {
19            self.out_weights = Some(vec![1.0; self.out_targets.len()].into());
20        }
21        if !self.in_targets.is_empty() {
22            self.in_weights = Some(vec![1.0; self.in_targets.len()].into());
23        }
24
25        // Backfill existing buffer entries with 1.0.
26        for (buf, wbuf) in self
27            .buffer_out
28            .iter()
29            .zip(self.buffer_out_weights.iter_mut())
30        {
31            if wbuf.len() < buf.len() {
32                wbuf.resize(buf.len(), 1.0);
33            }
34        }
35        for (buf, wbuf) in self.buffer_in.iter().zip(self.buffer_in_weights.iter_mut()) {
36            if wbuf.len() < buf.len() {
37                wbuf.resize(buf.len(), 1.0);
38            }
39        }
40    }
41
42    /// Whether this CSR has any weighted edges.
43    pub fn has_weights(&self) -> bool {
44        self.has_weights
45    }
46
47    /// Get the weight of the i-th outbound edge from a node (dense CSR only).
48    ///
49    /// `edge_idx` is the absolute index into `out_targets`/`out_weights`.
50    /// Returns 1.0 for unweighted graphs.
51    pub fn out_edge_weight(&self, edge_idx: usize) -> f64 {
52        self.out_weights
53            .as_ref()
54            .and_then(|ws| ws.get(edge_idx).copied())
55            .unwrap_or(1.0)
56    }
57
58    /// Get the weight of the i-th inbound edge to a node (dense CSR only).
59    pub fn in_edge_weight(&self, edge_idx: usize) -> f64 {
60        self.in_weights
61            .as_ref()
62            .and_then(|ws| ws.get(edge_idx).copied())
63            .unwrap_or(1.0)
64    }
65
66    /// Get the weight of a specific outbound edge from `src` to `dst` via `label`.
67    ///
68    /// Checks both dense and buffer. Returns 1.0 if the edge exists but has
69    /// no weight, or `None` if the edge doesn't exist.
70    pub fn edge_weight(&self, src: &str, label: &str, dst: &str) -> Option<f64> {
71        let src_id = *self.node_to_id.get(src)?;
72        let dst_id = *self.node_to_id.get(dst)?;
73        let label_id = *self.label_to_id.get(label)?;
74
75        // Check dense CSR.
76        let idx = src_id as usize;
77        if idx + 1 < self.out_offsets.len() {
78            let start = self.out_offsets[idx] as usize;
79            let end = self.out_offsets[idx + 1] as usize;
80            for i in start..end {
81                if self.out_labels[i] == label_id
82                    && self.out_targets[i] == dst_id
83                    && !self.deleted_edges.contains(&(src_id, label_id, dst_id))
84                {
85                    return Some(self.out_edge_weight(i));
86                }
87            }
88        }
89
90        // Check buffer.
91        if idx < self.buffer_out.len() {
92            for (buf_idx, &(l, d)) in self.buffer_out[idx].iter().enumerate() {
93                if l == label_id && d == dst_id {
94                    if self.has_weights {
95                        return Some(
96                            self.buffer_out_weights[idx]
97                                .get(buf_idx)
98                                .copied()
99                                .unwrap_or(1.0),
100                        );
101                    }
102                    return Some(1.0);
103                }
104            }
105        }
106
107        None
108    }
109
110    /// Iterate outbound edges of a node with weights: `(label_id, dst_id, weight)`.
111    ///
112    /// Yields from both dense CSR and buffer, excluding deleted edges.
113    /// Weights are 1.0 for unweighted graphs.
114    pub fn iter_out_edges_weighted(
115        &self,
116        node: LocalNodeId,
117    ) -> impl Iterator<Item = (u32, LocalNodeId, f64)> + '_ {
118        let node = node.raw(self.partition_tag);
119        let tag = self.partition_tag;
120        let idx = node as usize;
121
122        // Dense edges with weights.
123        let dense_start = if idx + 1 < self.out_offsets.len() {
124            self.out_offsets[idx] as usize
125        } else {
126            0
127        };
128        let dense_end = if idx + 1 < self.out_offsets.len() {
129            self.out_offsets[idx + 1] as usize
130        } else {
131            0
132        };
133
134        let dense = (dense_start..dense_end)
135            .map(move |i| {
136                let w = self.out_edge_weight(i);
137                (self.out_labels[i], self.out_targets[i], w)
138            })
139            .filter(move |&(lid, dst, _)| !self.deleted_edges.contains(&(node, lid, dst)))
140            .collect::<Vec<_>>();
141
142        // Buffer edges with weights.
143        let buffer = if idx < self.buffer_out.len() {
144            self.buffer_out[idx]
145                .iter()
146                .enumerate()
147                .map(|(buf_idx, &(lid, dst))| {
148                    let w = if self.has_weights {
149                        self.buffer_out_weights[idx]
150                            .get(buf_idx)
151                            .copied()
152                            .unwrap_or(1.0)
153                    } else {
154                        1.0
155                    };
156                    (lid, dst, w)
157                })
158                .collect::<Vec<_>>()
159        } else {
160            Vec::new()
161        };
162
163        dense
164            .into_iter()
165            .chain(buffer)
166            .map(move |(lid, dst, w)| (lid, LocalNodeId::new(dst, tag), w))
167    }
168
169    /// Raw u32 variant of `iter_out_edges_weighted`. In-partition
170    /// algorithm use only — see [`Self::iter_out_edges_raw`] for the
171    /// safety rationale.
172    pub fn iter_out_edges_weighted_raw(
173        &self,
174        node: u32,
175    ) -> impl Iterator<Item = (u32, u32, f64)> + '_ {
176        self.iter_out_edges_weighted(self.local(node))
177            .map(move |(lid, dst, w)| (lid, dst.raw(self.partition_tag), w))
178    }
179}
180
181/// Extract the `"weight"` property from MessagePack-encoded edge properties.
182///
183/// Returns 1.0 if properties are empty, malformed, or don't contain a
184/// `"weight"` key. Handles F64, F32, and integer weight values; other
185/// numeric types default to 1.0.
186pub fn extract_weight_from_properties(properties: &[u8]) -> f64 {
187    if properties.is_empty() {
188        return 1.0;
189    }
190    let Ok(val) = rmpv::decode::read_value(&mut &properties[..]) else {
191        return 1.0;
192    };
193    match val {
194        rmpv::Value::Map(entries) => {
195            for (k, v) in entries {
196                if let rmpv::Value::String(ref s) = k
197                    && s.as_str() == Some("weight")
198                {
199                    return match v {
200                        rmpv::Value::F64(f) => f,
201                        rmpv::Value::F32(f) => f as f64,
202                        rmpv::Value::Integer(i) => i.as_f64().unwrap_or(1.0),
203                        _ => 1.0,
204                    };
205                }
206            }
207            1.0
208        }
209        _ => 1.0,
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `value` as MessagePack bytes for `extract_weight_from_properties`.
    fn encode(value: &rmpv::Value) -> Vec<u8> {
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, value).unwrap();
        buf
    }

    /// Build a single-entry MessagePack map `{ key: value }`.
    fn single_entry_map(key: &str, value: rmpv::Value) -> Vec<u8> {
        encode(&rmpv::Value::Map(vec![(rmpv::Value::String(key.into()), value)]))
    }

    #[test]
    fn unweighted_graph_has_no_weight_arrays() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        assert!(!csr.has_weights());
        assert!(csr.out_weights.is_none());
        assert!(csr.in_weights.is_none());
    }

    #[test]
    fn weighted_edge_basic() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "ROAD", "b", 5.0).unwrap();
        csr.add_edge_weighted("b", "ROAD", "c", 3.0).unwrap();
        csr.add_edge("c", "ROAD", "d").unwrap(); // unweighted → 1.0

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "ROAD", "b"), Some(5.0));
        assert_eq!(csr.edge_weight("b", "ROAD", "c"), Some(3.0));
        assert_eq!(csr.edge_weight("c", "ROAD", "d"), Some(1.0));
        assert_eq!(csr.edge_weight("a", "ROAD", "c"), None);
    }

    #[test]
    fn weighted_edges_survive_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();

        csr.compact().expect("no governor, cannot fail");

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(csr.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(csr.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn weighted_edge_remove_keeps_weights_consistent() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.add_edge_weighted("a", "R", "d", 4.0).unwrap();

        csr.remove_edge("a", "R", "c");

        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.0));
        assert_eq!(csr.edge_weight("a", "R", "c"), None);
        assert_eq!(csr.edge_weight("a", "R", "d"), Some(4.0));
    }

    #[test]
    fn removed_dense_edge_is_hidden_from_weighted_lookups() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.compact().expect("no governor, cannot fail");

        csr.remove_edge("a", "R", "b");

        assert_eq!(csr.edge_weight("a", "R", "b"), None);
        assert_eq!(csr.edge_weight("a", "R", "c"), Some(3.0));

        let edges: Vec<(u32, LocalNodeId, f64)> =
            csr.iter_out_edges_weighted(csr.local(0)).collect();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].2, 3.0);
    }

    #[test]
    fn iter_out_edges_weighted_returns_weights() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();
        csr.compact().expect("no governor, cannot fail");

        let edges: Vec<(u32, LocalNodeId, f64)> =
            csr.iter_out_edges_weighted(csr.local(0)).collect();
        assert_eq!(edges.len(), 2);

        let weights: Vec<f64> = edges.iter().map(|e| e.2).collect();
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_reads_buffered_edges() {
        // No compaction: both edges are still in the buffer.
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge("a", "R", "c").unwrap();

        let mut weights: Vec<f64> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|e| e.2)
            .collect();
        weights.sort_by(f64::total_cmp);
        assert_eq!(weights, vec![1.0, 2.5]);
    }

    #[test]
    fn iter_out_edges_weighted_merges_dense_and_buffer() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.compact().expect("no governor, cannot fail");
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();

        let mut weights: Vec<f64> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|e| e.2)
            .collect();
        weights.sort_by(f64::total_cmp);
        assert_eq!(weights, vec![2.5, 7.0]);
    }

    #[test]
    fn mixed_weighted_unweighted_backfill() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.add_edge("b", "L", "c").unwrap();
        assert!(!csr.has_weights());

        csr.add_edge_weighted("c", "L", "d", 5.0).unwrap();
        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "L", "b"), Some(1.0));
        assert_eq!(csr.edge_weight("c", "L", "d"), Some(5.0));
    }

    #[test]
    fn extract_weight_from_empty_properties() {
        assert_eq!(extract_weight_from_properties(b""), 1.0);
    }

    #[test]
    fn extract_weight_f64() {
        let buf = single_entry_map("weight", rmpv::Value::F64(0.75));
        assert_eq!(extract_weight_from_properties(&buf), 0.75);
    }

    #[test]
    fn extract_weight_f32() {
        let buf = single_entry_map("weight", rmpv::Value::F32(0.5));
        assert_eq!(extract_weight_from_properties(&buf), 0.5);
    }

    #[test]
    fn extract_weight_integer() {
        let buf = single_entry_map("weight", rmpv::Value::Integer(rmpv::Integer::from(42)));
        assert_eq!(extract_weight_from_properties(&buf), 42.0);
    }

    #[test]
    fn extract_weight_missing_key() {
        let buf = single_entry_map("color", rmpv::Value::String("red".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_numeric_value_defaults() {
        let buf = single_entry_map("weight", rmpv::Value::String("heavy".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_map_or_malformed_defaults() {
        // A bare number at the top level is not a property map.
        let bare = encode(&rmpv::Value::F64(3.0));
        assert_eq!(extract_weight_from_properties(&bare), 1.0);

        // 0xcb is the F64 marker; without its 8-byte payload decoding fails.
        assert_eq!(extract_weight_from_properties(&[0xcb]), 1.0);
    }
}