Skip to main content

nodedb_graph/csr/
weights.rs

1//! Edge weight management for the CSR index.
2//!
3//! Optional `f64` weight per edge stored in parallel arrays. `None` when the
4//! graph is entirely unweighted (zero memory overhead). Populated from the
5//! `"weight"` edge property at insertion time. Unweighted edges default to 1.0.
6
7use super::index::CsrIndex;
8
9impl CsrIndex {
10    /// Enable weight tracking. Backfills existing buffer entries with 1.0.
11    pub(crate) fn enable_weights(&mut self) {
12        self.has_weights = true;
13
14        // Backfill existing dense arrays with 1.0.
15        if !self.out_targets.is_empty() {
16            self.out_weights = Some(vec![1.0; self.out_targets.len()].into());
17        }
18        if !self.in_targets.is_empty() {
19            self.in_weights = Some(vec![1.0; self.in_targets.len()].into());
20        }
21
22        // Backfill existing buffer entries with 1.0.
23        for (buf, wbuf) in self
24            .buffer_out
25            .iter()
26            .zip(self.buffer_out_weights.iter_mut())
27        {
28            if wbuf.len() < buf.len() {
29                wbuf.resize(buf.len(), 1.0);
30            }
31        }
32        for (buf, wbuf) in self.buffer_in.iter().zip(self.buffer_in_weights.iter_mut()) {
33            if wbuf.len() < buf.len() {
34                wbuf.resize(buf.len(), 1.0);
35            }
36        }
37    }
38
39    /// Whether this CSR has any weighted edges.
40    pub fn has_weights(&self) -> bool {
41        self.has_weights
42    }
43
44    /// Get the weight of the i-th outbound edge from a node (dense CSR only).
45    ///
46    /// `edge_idx` is the absolute index into `out_targets`/`out_weights`.
47    /// Returns 1.0 for unweighted graphs.
48    pub fn out_edge_weight(&self, edge_idx: usize) -> f64 {
49        self.out_weights
50            .as_ref()
51            .and_then(|ws| ws.get(edge_idx).copied())
52            .unwrap_or(1.0)
53    }
54
55    /// Get the weight of the i-th inbound edge to a node (dense CSR only).
56    pub fn in_edge_weight(&self, edge_idx: usize) -> f64 {
57        self.in_weights
58            .as_ref()
59            .and_then(|ws| ws.get(edge_idx).copied())
60            .unwrap_or(1.0)
61    }
62
63    /// Get the weight of a specific outbound edge from `src` to `dst` via `label`.
64    ///
65    /// Checks both dense and buffer. Returns 1.0 if the edge exists but has
66    /// no weight, or `None` if the edge doesn't exist.
67    pub fn edge_weight(&self, src: &str, label: &str, dst: &str) -> Option<f64> {
68        let src_id = *self.node_to_id.get(src)?;
69        let dst_id = *self.node_to_id.get(dst)?;
70        let label_id = *self.label_to_id.get(label)?;
71
72        // Check dense CSR.
73        let idx = src_id as usize;
74        if idx + 1 < self.out_offsets.len() {
75            let start = self.out_offsets[idx] as usize;
76            let end = self.out_offsets[idx + 1] as usize;
77            for i in start..end {
78                if self.out_labels[i] == label_id
79                    && self.out_targets[i] == dst_id
80                    && !self.deleted_edges.contains(&(src_id, label_id, dst_id))
81                {
82                    return Some(self.out_edge_weight(i));
83                }
84            }
85        }
86
87        // Check buffer.
88        if idx < self.buffer_out.len() {
89            for (buf_idx, &(l, d)) in self.buffer_out[idx].iter().enumerate() {
90                if l == label_id && d == dst_id {
91                    if self.has_weights {
92                        return Some(
93                            self.buffer_out_weights[idx]
94                                .get(buf_idx)
95                                .copied()
96                                .unwrap_or(1.0),
97                        );
98                    }
99                    return Some(1.0);
100                }
101            }
102        }
103
104        None
105    }
106
107    /// Iterate outbound edges of a node with weights: `(label_id, dst_id, weight)`.
108    ///
109    /// Yields from both dense CSR and buffer, excluding deleted edges.
110    /// Weights are 1.0 for unweighted graphs.
111    pub fn iter_out_edges_weighted(&self, node: u32) -> impl Iterator<Item = (u32, u32, f64)> + '_ {
112        let idx = node as usize;
113
114        // Dense edges with weights.
115        let dense_start = if idx + 1 < self.out_offsets.len() {
116            self.out_offsets[idx] as usize
117        } else {
118            0
119        };
120        let dense_end = if idx + 1 < self.out_offsets.len() {
121            self.out_offsets[idx + 1] as usize
122        } else {
123            0
124        };
125
126        let dense = (dense_start..dense_end)
127            .map(move |i| {
128                let w = self.out_edge_weight(i);
129                (self.out_labels[i], self.out_targets[i], w)
130            })
131            .filter(move |&(lid, dst, _)| !self.deleted_edges.contains(&(node, lid, dst)))
132            .collect::<Vec<_>>();
133
134        // Buffer edges with weights.
135        let buffer = if idx < self.buffer_out.len() {
136            self.buffer_out[idx]
137                .iter()
138                .enumerate()
139                .map(|(buf_idx, &(lid, dst))| {
140                    let w = if self.has_weights {
141                        self.buffer_out_weights[idx]
142                            .get(buf_idx)
143                            .copied()
144                            .unwrap_or(1.0)
145                    } else {
146                        1.0
147                    };
148                    (lid, dst, w)
149                })
150                .collect::<Vec<_>>()
151        } else {
152            Vec::new()
153        };
154
155        dense.into_iter().chain(buffer)
156    }
157}
158
159/// Extract the `"weight"` property from MessagePack-encoded edge properties.
160///
161/// Returns 1.0 if properties are empty, malformed, or don't contain a
162/// `"weight"` key. Handles F64, F32, and integer weight values; other
163/// numeric types default to 1.0.
164pub fn extract_weight_from_properties(properties: &[u8]) -> f64 {
165    if properties.is_empty() {
166        return 1.0;
167    }
168    let Ok(val) = rmpv::decode::read_value(&mut &properties[..]) else {
169        return 1.0;
170    };
171    match val {
172        rmpv::Value::Map(entries) => {
173            for (k, v) in entries {
174                if let rmpv::Value::String(ref s) = k
175                    && s.as_str() == Some("weight")
176                {
177                    return match v {
178                        rmpv::Value::F64(f) => f,
179                        rmpv::Value::F32(f) => f as f64,
180                        rmpv::Value::Integer(i) => i.as_f64().unwrap_or(1.0),
181                        _ => 1.0,
182                    };
183                }
184            }
185            1.0
186        }
187        _ => 1.0,
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode a one-entry MessagePack map `{key: value}` into bytes.
    fn encode_single_prop(key: &str, value: rmpv::Value) -> Vec<u8> {
        let map = rmpv::Value::Map(vec![(rmpv::Value::String(key.into()), value)]);
        let mut bytes = Vec::new();
        rmpv::encode::write_value(&mut bytes, &map).unwrap();
        bytes
    }

    #[test]
    fn unweighted_graph_has_no_weight_arrays() {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b").unwrap();

        assert!(!graph.has_weights());
        assert!(graph.out_weights.is_none());
        assert!(graph.in_weights.is_none());
    }

    #[test]
    fn weighted_edge_basic() {
        let mut graph = CsrIndex::new();
        graph.add_edge_weighted("a", "ROAD", "b", 5.0).unwrap();
        graph.add_edge_weighted("b", "ROAD", "c", 3.0).unwrap();
        // Plain insert into a weighted graph defaults to 1.0.
        graph.add_edge("c", "ROAD", "d").unwrap();

        assert!(graph.has_weights());
        let expectations = [
            ("a", "b", Some(5.0)),
            ("b", "c", Some(3.0)),
            ("c", "d", Some(1.0)),
            ("a", "c", None),
        ];
        for (src, dst, expected) in expectations {
            assert_eq!(graph.edge_weight(src, "ROAD", dst), expected);
        }
    }

    #[test]
    fn weighted_edges_survive_compaction() {
        let mut graph = CsrIndex::new();
        graph.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        graph.add_edge_weighted("b", "R", "c", 7.0).unwrap();
        graph.add_edge("c", "R", "d").unwrap();

        graph.compact();

        assert!(graph.has_weights());
        let expectations = [
            ("a", "b", Some(2.5)),
            ("b", "c", Some(7.0)),
            ("c", "d", Some(1.0)),
        ];
        for (src, dst, expected) in expectations {
            assert_eq!(graph.edge_weight(src, "R", dst), expected);
        }
    }

    #[test]
    fn weighted_edge_remove_keeps_weights_consistent() {
        let mut graph = CsrIndex::new();
        for (dst, weight) in [("b", 2.0), ("c", 3.0), ("d", 4.0)] {
            graph.add_edge_weighted("a", "R", dst, weight).unwrap();
        }

        graph.remove_edge("a", "R", "c");

        let expectations = [("b", Some(2.0)), ("c", None), ("d", Some(4.0))];
        for (dst, expected) in expectations {
            assert_eq!(graph.edge_weight("a", "R", dst), expected);
        }
    }

    #[test]
    fn iter_out_edges_weighted_returns_weights() {
        let mut graph = CsrIndex::new();
        graph.add_edge_weighted("a", "R", "b", 2.5).unwrap();
        graph.add_edge_weighted("a", "R", "c", 7.0).unwrap();
        graph.compact();

        let edges: Vec<(u32, u32, f64)> = graph.iter_out_edges_weighted(0).collect();
        assert_eq!(edges.len(), 2);

        let weights: Vec<f64> = edges.iter().map(|&(_, _, w)| w).collect();
        assert!(weights.contains(&2.5));
        assert!(weights.contains(&7.0));
    }

    #[test]
    fn mixed_weighted_unweighted_backfill() {
        let mut graph = CsrIndex::new();
        graph.add_edge("a", "L", "b").unwrap();
        graph.add_edge("b", "L", "c").unwrap();
        assert!(!graph.has_weights());

        graph.add_edge_weighted("c", "L", "d", 5.0).unwrap();
        assert!(graph.has_weights());
        assert_eq!(graph.edge_weight("a", "L", "b"), Some(1.0));
        assert_eq!(graph.edge_weight("c", "L", "d"), Some(5.0));
    }

    #[test]
    fn extract_weight_from_empty_properties() {
        assert_eq!(extract_weight_from_properties(b""), 1.0);
    }

    #[test]
    fn extract_weight_f64() {
        let bytes = encode_single_prop("weight", rmpv::Value::F64(0.75));
        assert_eq!(extract_weight_from_properties(&bytes), 0.75);
    }

    #[test]
    fn extract_weight_integer() {
        let bytes = encode_single_prop(
            "weight",
            rmpv::Value::Integer(rmpv::Integer::from(42)),
        );
        assert_eq!(extract_weight_from_properties(&bytes), 42.0);
    }

    #[test]
    fn extract_weight_missing_key() {
        let bytes = encode_single_prop("color", rmpv::Value::String("red".into()));
        assert_eq!(extract_weight_from_properties(&bytes), 1.0);
    }
}