Skip to main content

omnigraph/graph_index/
mod.rs

1use std::collections::HashMap;
2
3use arrow_array::StringArray;
4use futures::TryStreamExt;
5
6use crate::db::Snapshot;
7use crate::error::{OmniError, Result};
8
/// Dense u32 mapping for a single node type: String ID ↔ dense index.
///
/// Dense indices are handed out in insertion order starting at 0, so the
/// index of an ID is also its position in the reverse-lookup vector.
#[derive(Debug, Clone, Default)]
pub struct TypeIndex {
    /// Forward lookup: string ID → dense index.
    id_to_dense: HashMap<String, u32>,
    /// Reverse lookup: position `i` holds the ID that was assigned index `i`.
    dense_to_id: Vec<String>,
}

impl TypeIndex {
    /// Create an empty index.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Get or insert a string ID, returning its dense index.
    ///
    /// An ID that is already present keeps its original index; a new ID is
    /// assigned the next unused index.
    ///
    /// # Panics
    ///
    /// Panics if the next index does not fit in a `u32`. A plain `as` cast
    /// would wrap here and silently give two different IDs the same index.
    pub(crate) fn get_or_insert(&mut self, id: &str) -> u32 {
        if let Some(&existing) = self.id_to_dense.get(id) {
            return existing;
        }
        let idx = u32::try_from(self.dense_to_id.len())
            .expect("TypeIndex dense index overflowed u32");
        // Both directions own their own copy of the ID.
        let owned = id.to_owned();
        self.dense_to_id.push(owned.clone());
        self.id_to_dense.insert(owned, idx);
        idx
    }

    /// Look up the dense index of `id`, or `None` if it was never inserted.
    pub fn to_dense(&self, id: &str) -> Option<u32> {
        self.id_to_dense.get(id).copied()
    }

    /// Look up the string ID for a dense index, or `None` if out of range.
    pub fn to_id(&self, dense: u32) -> Option<&str> {
        self.dense_to_id.get(dense as usize).map(String::as_str)
    }

    /// Number of distinct IDs in the index.
    pub fn len(&self) -> usize {
        self.dense_to_id.len()
    }

    /// Returns `true` when no ID has been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.dense_to_id.is_empty()
    }
}
47
/// CSR (Compressed Sparse Row) adjacency index.
#[derive(Debug, Clone)]
pub struct CsrIndex {
    /// offsets[i] .. offsets[i+1] gives the neighbor range for node i.
    offsets: Vec<u32>,
    /// Dense indices of destination nodes.
    targets: Vec<u32>,
}

impl CsrIndex {
    /// Build the index for `num_nodes` source nodes from `(src, dst)` pairs.
    ///
    /// Neighbors of a node appear in the order their edges occur in `edges`.
    ///
    /// # Panics
    ///
    /// Panics if `edges` has more than `u32::MAX` entries (offsets are stored
    /// as `u32` and would otherwise wrap), or if any `src` is not smaller than
    /// `num_nodes`.
    pub(crate) fn build(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
        assert!(
            u32::try_from(edges.len()).is_ok(),
            "CSR index supports at most u32::MAX edges, got {}",
            edges.len()
        );

        // Count outgoing edges per source
        let mut counts = vec![0u32; num_nodes];
        for &(src, _) in edges {
            assert!(
                (src as usize) < num_nodes,
                "edge source {src} is out of range for {num_nodes} nodes"
            );
            counts[src as usize] += 1;
        }

        // Build offset array (prefix sum). The running total never exceeds
        // edges.len(), which was checked to fit in u32 above.
        let mut offsets = Vec::with_capacity(num_nodes + 1);
        let mut running = 0u32;
        offsets.push(running);
        for &count in &counts {
            running += count;
            offsets.push(running);
        }

        // Fill targets: each source writes at its own cursor, which starts at
        // the beginning of that source's range and advances per edge.
        let mut cursors = offsets[..num_nodes].to_vec();
        let mut targets = vec![0u32; edges.len()];
        for &(src, dst) in edges {
            let cursor = &mut cursors[src as usize];
            targets[*cursor as usize] = dst;
            *cursor += 1;
        }

        Self { offsets, targets }
    }

    /// Return the dense indices of neighbors for a given dense node index.
    ///
    /// A node outside the range the index was built for has no neighbors, so
    /// an empty slice is returned instead of panicking.
    pub fn neighbors(&self, node: u32) -> &[u32] {
        // Needs both offsets[node] and offsets[node + 1]; the slice pattern
        // checks for the pair without computing `node + 1`.
        match self.offsets.get(node as usize..) {
            Some([start, end, ..]) => &self.targets[*start as usize..*end as usize],
            _ => &[],
        }
    }

    /// Check if a node has any outgoing edges. O(1), no allocation.
    ///
    /// Returns `false` for a node outside the range the index was built for.
    pub fn has_neighbors(&self, node: u32) -> bool {
        !self.neighbors(node).is_empty()
    }
}
98
99/// Topology-only graph index. No node data cached — just adjacency.
100#[derive(Debug, Clone)]
101pub struct GraphIndex {
102    /// Dense index per node type (built from edge src/dst columns).
103    type_indices: HashMap<String, TypeIndex>,
104    /// Outgoing adjacency per edge type.
105    csr: HashMap<String, CsrIndex>,
106    /// Incoming adjacency per edge type.
107    csc: HashMap<String, CsrIndex>,
108}
109
110impl GraphIndex {
111    /// Build a graph index by scanning edge sub-tables from a snapshot.
112    pub async fn build(
113        snapshot: &Snapshot,
114        edge_types: &HashMap<String, (String, String)>, // edge_name → (from_type, to_type)
115    ) -> Result<Self> {
116        let mut type_indices: HashMap<String, TypeIndex> = HashMap::new();
117        let mut csr = HashMap::new();
118        let mut csc = HashMap::new();
119
120        // Phase 1: Scan all edges, build TypeIndices and collect edge pairs
121        let mut edge_pairs: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
122
123        for (edge_name, (from_type, to_type)) in edge_types {
124            let table_key = format!("edge:{}", edge_name);
125            if snapshot.entry(&table_key).is_none() {
126                continue;
127            }
128
129            let ds = snapshot.open(&table_key).await?;
130
131            let batches: Vec<arrow_array::RecordBatch> = ds
132                .scan()
133                .project(&["src", "dst"])
134                .map_err(|e| OmniError::Lance(e.to_string()))?
135                .try_into_stream()
136                .await
137                .map_err(|e| OmniError::Lance(e.to_string()))?
138                .try_collect()
139                .await
140                .map_err(|e| OmniError::Lance(e.to_string()))?;
141
142            type_indices
143                .entry(from_type.clone())
144                .or_insert_with(TypeIndex::new);
145            type_indices
146                .entry(to_type.clone())
147                .or_insert_with(TypeIndex::new);
148
149            let mut edges: Vec<(u32, u32)> = Vec::new();
150            for batch in &batches {
151                let srcs = string_column(batch, "src")?;
152                let dsts = string_column(batch, "dst")?;
153
154                for i in 0..batch.num_rows() {
155                    let src_dense = type_indices
156                        .get_mut(from_type)
157                        .unwrap()
158                        .get_or_insert(srcs.value(i));
159                    let dst_dense = type_indices
160                        .get_mut(to_type)
161                        .unwrap()
162                        .get_or_insert(dsts.value(i));
163                    edges.push((src_dense, dst_dense));
164                }
165            }
166            edge_pairs.insert(edge_name.clone(), edges);
167        }
168
169        // Phase 2: Build CSR/CSC using final TypeIndex sizes
170        for (edge_name, (from_type, to_type)) in edge_types {
171            let Some(edges) = edge_pairs.get(edge_name) else {
172                continue;
173            };
174
175            let src_count = type_indices[from_type].len();
176            let dst_count = type_indices[to_type].len();
177
178            csr.insert(edge_name.clone(), CsrIndex::build(src_count, edges));
179
180            let reversed: Vec<(u32, u32)> = edges.iter().map(|&(s, d)| (d, s)).collect();
181            csc.insert(edge_name.clone(), CsrIndex::build(dst_count, &reversed));
182        }
183
184        Ok(Self {
185            type_indices,
186            csr,
187            csc,
188        })
189    }
190
191    pub fn type_index(&self, type_name: &str) -> Option<&TypeIndex> {
192        self.type_indices.get(type_name)
193    }
194
195    pub fn csr(&self, edge_type: &str) -> Option<&CsrIndex> {
196        self.csr.get(edge_type)
197    }
198
199    pub fn csc(&self, edge_type: &str) -> Option<&CsrIndex> {
200        self.csc.get(edge_type)
201    }
202
203    #[cfg(test)]
204    pub(crate) fn empty_for_test() -> Self {
205        Self {
206            type_indices: HashMap::new(),
207            csr: HashMap::new(),
208            csc: HashMap::new(),
209        }
210    }
211}
212
213fn string_column<'a>(batch: &'a arrow_array::RecordBatch, name: &str) -> Result<&'a StringArray> {
214    batch
215        .column_by_name(name)
216        .ok_or_else(|| {
217            OmniError::manifest_internal(format!("graph index batch missing '{name}' column"))
218        })?
219        .as_any()
220        .downcast_ref::<StringArray>()
221        .ok_or_else(|| {
222            OmniError::manifest_internal(format!("graph index column '{name}' is not Utf8"))
223        })
224}
225
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::UInt64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Every inserted ID maps to a dense index and back to the same ID.
    #[test]
    fn type_index_round_trip() {
        let names = ["Alice", "Bob", "Charlie"];
        let mut idx = TypeIndex::new();
        let dense: Vec<u32> = names.iter().map(|name| idx.get_or_insert(name)).collect();

        for (name, &d) in names.iter().zip(&dense) {
            assert_eq!(idx.to_dense(name), Some(d));
        }
        for (name, &d) in names.iter().zip(&dense) {
            assert_eq!(idx.to_id(d), Some(*name));
        }
        assert_eq!(idx.len(), 3);
    }

    /// Inserting the same ID twice yields the same index and one entry.
    #[test]
    fn type_index_idempotent_insert() {
        let mut idx = TypeIndex::new();
        let first = idx.get_or_insert("Alice");
        let second = idx.get_or_insert("Alice");
        assert_eq!(first, second);
        assert_eq!(idx.len(), 1);
    }

    /// Lookups on an empty index miss in both directions.
    #[test]
    fn type_index_unknown_returns_none() {
        let idx = TypeIndex::new();
        assert_eq!(idx.to_dense("unknown"), None);
        assert_eq!(idx.to_id(999), None);
    }

    /// Neighbor lists match the edges that were fed in.
    #[test]
    fn csr_neighbors_correct() {
        // Graph: 0→1, 0→2, 1→2
        let csr = CsrIndex::build(3, &[(0, 1), (0, 2), (1, 2)]);

        let mut from_zero: Vec<u32> = csr.neighbors(0).to_vec();
        from_zero.sort();
        assert_eq!(from_zero, vec![1, 2]);

        assert_eq!(csr.neighbors(1), &[2]);
        assert!(csr.neighbors(2).is_empty());
    }

    /// A graph without edges has no neighbors for any node.
    #[test]
    fn csr_empty_graph() {
        let csr = CsrIndex::build(3, &[]);
        for node in 0..3 {
            assert!(csr.neighbors(node).is_empty());
        }
        assert!(!csr.has_neighbors(0));
    }

    /// `has_neighbors` is true exactly for nodes with outgoing edges.
    #[test]
    fn csr_has_neighbors() {
        // 0→1, 1→2
        let csr = CsrIndex::build(3, &[(0, 1), (1, 2)]);
        assert!(csr.has_neighbors(0));
        assert!(csr.has_neighbors(1));
        assert!(!csr.has_neighbors(2));
    }

    /// A non-string `src` column is reported as an error naming the column.
    #[test]
    fn string_column_returns_error_for_bad_schema() {
        let schema = Schema::new(vec![Field::new("src", DataType::UInt64, false)]);
        let batch = arrow_array::RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(UInt64Array::from(vec![1_u64]))],
        )
        .unwrap();

        let err = string_column(&batch, "src").unwrap_err();
        assert!(err.to_string().contains("src"));
    }
}