Skip to main content

nodedb_graph/csr/
weights.rs

//! Edge weight management for the CSR index.
//!
//! Each edge may carry an optional `f64` weight, stored in arrays parallel to
//! the edge arrays. The weight arrays stay `None` while the graph is entirely
//! unweighted, so unweighted graphs pay no memory overhead. Weights are taken
//! from the `"weight"` edge property at insertion time; edges without one
//! default to 1.0.
6
7use super::index::CsrIndex;
8use crate::csr::LocalNodeId;
9
10impl CsrIndex {
11    /// Enable weight tracking. Backfills existing buffer entries with 1.0.
12    pub(crate) fn enable_weights(&mut self) {
13        self.has_weights = true;
14
15        // Backfill existing dense arrays with 1.0.
16        if !self.out_targets.is_empty() {
17            self.out_weights = Some(vec![1.0; self.out_targets.len()].into());
18        }
19        if !self.in_targets.is_empty() {
20            self.in_weights = Some(vec![1.0; self.in_targets.len()].into());
21        }
22
23        // Backfill existing buffer entries with 1.0.
24        for (buf, wbuf) in self
25            .buffer_out
26            .iter()
27            .zip(self.buffer_out_weights.iter_mut())
28        {
29            if wbuf.len() < buf.len() {
30                wbuf.resize(buf.len(), 1.0);
31            }
32        }
33        for (buf, wbuf) in self.buffer_in.iter().zip(self.buffer_in_weights.iter_mut()) {
34            if wbuf.len() < buf.len() {
35                wbuf.resize(buf.len(), 1.0);
36            }
37        }
38    }
39
40    /// Whether this CSR has any weighted edges.
41    pub fn has_weights(&self) -> bool {
42        self.has_weights
43    }
44
45    /// Get the weight of the i-th outbound edge from a node (dense CSR only).
46    ///
47    /// `edge_idx` is the absolute index into `out_targets`/`out_weights`.
48    /// Returns 1.0 for unweighted graphs.
49    pub fn out_edge_weight(&self, edge_idx: usize) -> f64 {
50        self.out_weights
51            .as_ref()
52            .and_then(|ws| ws.get(edge_idx).copied())
53            .unwrap_or(1.0)
54    }
55
56    /// Get the weight of the i-th inbound edge to a node (dense CSR only).
57    pub fn in_edge_weight(&self, edge_idx: usize) -> f64 {
58        self.in_weights
59            .as_ref()
60            .and_then(|ws| ws.get(edge_idx).copied())
61            .unwrap_or(1.0)
62    }
63
64    /// Get the weight of a specific outbound edge from `src` to `dst` via `label`.
65    ///
66    /// Checks both dense and buffer. Returns 1.0 if the edge exists but has
67    /// no weight, or `None` if the edge doesn't exist.
68    pub fn edge_weight(&self, src: &str, label: &str, dst: &str) -> Option<f64> {
69        let src_id = *self.node_to_id.get(src)?;
70        let dst_id = *self.node_to_id.get(dst)?;
71        let label_id = *self.label_to_id.get(label)?;
72
73        // Check dense CSR.
74        let idx = src_id as usize;
75        if idx + 1 < self.out_offsets.len() {
76            let start = self.out_offsets[idx] as usize;
77            let end = self.out_offsets[idx + 1] as usize;
78            for i in start..end {
79                if self.out_labels[i] == label_id
80                    && self.out_targets[i] == dst_id
81                    && !self.deleted_edges.contains(&(src_id, label_id, dst_id))
82                {
83                    return Some(self.out_edge_weight(i));
84                }
85            }
86        }
87
88        // Check buffer.
89        if idx < self.buffer_out.len() {
90            for (buf_idx, &(l, d)) in self.buffer_out[idx].iter().enumerate() {
91                if l == label_id && d == dst_id {
92                    if self.has_weights {
93                        return Some(
94                            self.buffer_out_weights[idx]
95                                .get(buf_idx)
96                                .copied()
97                                .unwrap_or(1.0),
98                        );
99                    }
100                    return Some(1.0);
101                }
102            }
103        }
104
105        None
106    }
107
108    /// Iterate outbound edges of a node with weights: `(label_id, dst_id, weight)`.
109    ///
110    /// Yields from both dense CSR and buffer, excluding deleted edges.
111    /// Weights are 1.0 for unweighted graphs.
112    pub fn iter_out_edges_weighted(
113        &self,
114        node: LocalNodeId,
115    ) -> impl Iterator<Item = (u32, LocalNodeId, f64)> + '_ {
116        let node = node.raw(self.partition_tag);
117        let tag = self.partition_tag;
118        let idx = node as usize;
119
120        // Dense edges with weights.
121        let dense_start = if idx + 1 < self.out_offsets.len() {
122            self.out_offsets[idx] as usize
123        } else {
124            0
125        };
126        let dense_end = if idx + 1 < self.out_offsets.len() {
127            self.out_offsets[idx + 1] as usize
128        } else {
129            0
130        };
131
132        let dense = (dense_start..dense_end)
133            .map(move |i| {
134                let w = self.out_edge_weight(i);
135                (self.out_labels[i], self.out_targets[i], w)
136            })
137            .filter(move |&(lid, dst, _)| !self.deleted_edges.contains(&(node, lid, dst)))
138            .collect::<Vec<_>>();
139
140        // Buffer edges with weights.
141        let buffer = if idx < self.buffer_out.len() {
142            self.buffer_out[idx]
143                .iter()
144                .enumerate()
145                .map(|(buf_idx, &(lid, dst))| {
146                    let w = if self.has_weights {
147                        self.buffer_out_weights[idx]
148                            .get(buf_idx)
149                            .copied()
150                            .unwrap_or(1.0)
151                    } else {
152                        1.0
153                    };
154                    (lid, dst, w)
155                })
156                .collect::<Vec<_>>()
157        } else {
158            Vec::new()
159        };
160
161        dense
162            .into_iter()
163            .chain(buffer)
164            .map(move |(lid, dst, w)| (lid, LocalNodeId::new(dst, tag), w))
165    }
166
167    /// Raw u32 variant of `iter_out_edges_weighted`. In-partition
168    /// algorithm use only — see [`Self::iter_out_edges_raw`] for the
169    /// safety rationale.
170    pub fn iter_out_edges_weighted_raw(
171        &self,
172        node: u32,
173    ) -> impl Iterator<Item = (u32, u32, f64)> + '_ {
174        self.iter_out_edges_weighted(self.local(node))
175            .map(move |(lid, dst, w)| (lid, dst.raw(self.partition_tag), w))
176    }
177}
178
179/// Extract the `"weight"` property from MessagePack-encoded edge properties.
180///
181/// Returns 1.0 if properties are empty, malformed, or don't contain a
182/// `"weight"` key. Handles F64, F32, and integer weight values; other
183/// numeric types default to 1.0.
184pub fn extract_weight_from_properties(properties: &[u8]) -> f64 {
185    if properties.is_empty() {
186        return 1.0;
187    }
188    let Ok(val) = rmpv::decode::read_value(&mut &properties[..]) else {
189        return 1.0;
190    };
191    match val {
192        rmpv::Value::Map(entries) => {
193            for (k, v) in entries {
194                if let rmpv::Value::String(ref s) = k
195                    && s.as_str() == Some("weight")
196                {
197                    return match v {
198                        rmpv::Value::F64(f) => f,
199                        rmpv::Value::F32(f) => f as f64,
200                        rmpv::Value::Integer(i) => i.as_f64().unwrap_or(1.0),
201                        _ => 1.0,
202                    };
203                }
204            }
205            1.0
206        }
207        _ => 1.0,
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode a MessagePack value into a fresh byte buffer.
    fn encode(value: &rmpv::Value) -> Vec<u8> {
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, value).unwrap();
        buf
    }

    /// Encode a single-entry map `{ key: value }`.
    fn encode_single(key: &str, value: rmpv::Value) -> Vec<u8> {
        encode(&rmpv::Value::Map(vec![(rmpv::Value::String(key.into()), value)]))
    }

    #[test]
    fn unweighted_graph_has_no_weight_arrays() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        assert!(!csr.has_weights());
        assert!(csr.out_weights.is_none());
        assert!(csr.in_weights.is_none());
    }

    #[test]
    fn weighted_edge_basic() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "ROAD", "b", 5.0).unwrap();
        csr.add_edge_weighted("b", "ROAD", "c", 3.0).unwrap();
        csr.add_edge("c", "ROAD", "d").unwrap(); // unweighted → 1.0

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "ROAD", "b"), Some(5.0));
        assert_eq!(csr.edge_weight("b", "ROAD", "c"), Some(3.0));
        assert_eq!(csr.edge_weight("c", "ROAD", "d"), Some(1.0));
        assert_eq!(csr.edge_weight("a", "ROAD", "c"), None);
    }

    #[test]
    fn weighted_edges_survive_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        csr.add_edge("c", "R", "d").unwrap();

        csr.compact();

        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.5));
        assert_eq!(csr.edge_weight("b", "R", "c"), Some(7.0));
        assert_eq!(csr.edge_weight("c", "R", "d"), Some(1.0));
    }

    #[test]
    fn weighted_edge_remove_keeps_weights_consistent() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.0).unwrap();
        csr.add_edge_weighted("a", "R", "c", 3.0).unwrap();
        csr.add_edge_weighted("a", "R", "d", 4.0).unwrap();

        csr.remove_edge("a", "R", "c");

        assert_eq!(csr.edge_weight("a", "R", "b"), Some(2.0));
        assert_eq!(csr.edge_weight("a", "R", "c"), None);
        assert_eq!(csr.edge_weight("a", "R", "d"), Some(4.0));
    }

    #[test]
    fn iter_out_edges_weighted_returns_weights() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge_weighted("a", "R", "c", 7.0).unwrap();
        csr.compact();

        let edges: Vec<(u32, LocalNodeId, f64)> =
            csr.iter_out_edges_weighted(csr.local(0)).collect();
        assert_eq!(edges.len(), 2);

        let weights: Vec<f64> = edges.iter().map(|e| e.2).collect();
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn iter_out_edges_weighted_reads_buffer_before_compaction() {
        let mut csr = CsrIndex::new();
        csr.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        csr.add_edge("a", "R", "c").unwrap();

        let mut weights: Vec<f64> = csr
            .iter_out_edges_weighted(csr.local(0))
            .map(|(_, _, w)| w)
            .collect();
        weights.sort_by(f64::total_cmp);
        assert_eq!(weights, vec![1.0, 2.5]);
    }

    #[test]
    fn unweighted_iteration_yields_unit_weights() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.add_edge("a", "L", "c").unwrap();
        csr.compact();

        let edges: Vec<(u32, LocalNodeId, f64)> =
            csr.iter_out_edges_weighted(csr.local(0)).collect();
        assert_eq!(edges.len(), 2);
        assert!(edges.iter().all(|e| e.2 == 1.0));
    }

    #[test]
    fn mixed_weighted_unweighted_backfill() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b").unwrap();
        csr.add_edge("b", "L", "c").unwrap();
        assert!(!csr.has_weights());

        csr.add_edge_weighted("c", "L", "d", 5.0).unwrap();
        assert!(csr.has_weights());
        assert_eq!(csr.edge_weight("a", "L", "b"), Some(1.0));
        assert_eq!(csr.edge_weight("c", "L", "d"), Some(5.0));
    }

    #[test]
    fn extract_weight_from_empty_properties() {
        assert_eq!(extract_weight_from_properties(b""), 1.0);
    }

    #[test]
    fn extract_weight_f64() {
        let buf = encode_single("weight", rmpv::Value::F64(0.75));
        assert_eq!(extract_weight_from_properties(&buf), 0.75);
    }

    #[test]
    fn extract_weight_f32() {
        let buf = encode_single("weight", rmpv::Value::F32(0.5));
        assert_eq!(extract_weight_from_properties(&buf), 0.5);
    }

    #[test]
    fn extract_weight_integer() {
        let buf = encode_single("weight", rmpv::Value::Integer(rmpv::Integer::from(42)));
        assert_eq!(extract_weight_from_properties(&buf), 42.0);
    }

    #[test]
    fn extract_weight_missing_key() {
        let buf = encode_single("color", rmpv::Value::String("red".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_numeric_value_defaults() {
        let buf = encode_single("weight", rmpv::Value::String("heavy".into()));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_non_map_payload_defaults() {
        let buf = encode(&rmpv::Value::F64(3.0));
        assert_eq!(extract_weight_from_properties(&buf), 1.0);
    }

    #[test]
    fn extract_weight_malformed_bytes_default() {
        // A fixmap header announcing one entry, followed by a truncated fixstr key.
        assert_eq!(extract_weight_from_properties(&[0x81, 0xa6]), 1.0);
    }
}