1use std::collections::HashMap;
2
3use arrow_array::StringArray;
4use futures::TryStreamExt;
5
6use crate::db::Snapshot;
7use crate::error::{OmniError, Result};
8
/// Bidirectional mapping between the string ids of one node type and dense `u32` ids.
///
/// Dense ids are assigned in first-insertion order starting at 0, so they can be used
/// directly as positions in per-node arrays such as [`CsrIndex`] offsets.
#[derive(Debug, Clone, Default)]
pub struct TypeIndex {
    /// String id -> dense id.
    id_to_dense: HashMap<String, u32>,
    /// Dense id -> string id; position `i` holds the string id that was assigned dense id `i`.
    dense_to_id: Vec<String>,
}

impl TypeIndex {
    /// Creates an empty index.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns the dense id for `id`, assigning the next unused dense id if `id` is new.
    ///
    /// Inserting an id that is already present is a no-op that returns its existing dense id.
    ///
    /// # Panics
    ///
    /// Panics if assigning a new id would exceed the `u32` dense-id space. A plain `as` cast
    /// would silently wrap and hand out duplicate dense ids.
    pub(crate) fn get_or_insert(&mut self, id: &str) -> u32 {
        // Look up first so the common "already seen" path does not allocate.
        if let Some(&idx) = self.id_to_dense.get(id) {
            return idx;
        }
        let idx = u32::try_from(self.dense_to_id.len())
            .expect("TypeIndex holds more ids than fit in a u32 dense id");
        self.dense_to_id.push(id.to_string());
        self.id_to_dense.insert(id.to_string(), idx);
        idx
    }

    /// Returns the dense id previously assigned to `id`, or `None` if `id` was never inserted.
    pub fn to_dense(&self, id: &str) -> Option<u32> {
        self.id_to_dense.get(id).copied()
    }

    /// Returns the string id for `dense`, or `None` if no such dense id has been assigned.
    pub fn to_id(&self, dense: u32) -> Option<&str> {
        self.dense_to_id.get(dense as usize).map(String::as_str)
    }

    /// Number of distinct ids in the index. This is also one past the largest dense id.
    pub fn len(&self) -> usize {
        self.dense_to_id.len()
    }

    /// Returns `true` if no ids have been inserted.
    pub fn is_empty(&self) -> bool {
        self.dense_to_id.is_empty()
    }
}
47
/// Compressed sparse row (CSR) adjacency for one direction of one edge type.
///
/// The neighbors of node `n` are `targets[offsets[n]..offsets[n + 1]]`, so `offsets` has
/// exactly one more entry than there are nodes.
#[derive(Debug, Clone)]
pub struct CsrIndex {
    /// Exclusive prefix sums of per-node degree: `offsets[0] == 0`, length `num_nodes + 1`.
    offsets: Vec<u32>,
    /// Neighbor ids, grouped by source node in ascending source order.
    targets: Vec<u32>,
}

impl CsrIndex {
    /// Builds a CSR index over `num_nodes` nodes from `(src, dst)` pairs.
    ///
    /// Within one source node, neighbors keep the order in which their edges appear in
    /// `edges`. `dst` values are stored as-is and are not range-checked.
    ///
    /// # Panics
    ///
    /// Panics if any `src` is `>= num_nodes`, or if `edges.len()` does not fit in a `u32`.
    /// Offsets are stored as `u32`, so a longer edge list would otherwise wrap silently in
    /// release builds.
    pub(crate) fn build(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
        assert!(
            u32::try_from(edges.len()).is_ok(),
            "CsrIndex supports at most u32::MAX edges, got {}",
            edges.len()
        );

        // Degree of every source node.
        let mut counts = vec![0u32; num_nodes];
        for &(src, _) in edges {
            counts[src as usize] += 1;
        }

        // Exclusive prefix sum: offsets[n] is where node n's neighbors start. The guard
        // above bounds the running total by u32::MAX, so this cannot overflow.
        let mut offsets = Vec::with_capacity(num_nodes + 1);
        let mut running = 0u32;
        offsets.push(running);
        for c in counts {
            running += c;
            offsets.push(running);
        }

        // Scatter each target into its source's slot range; `next[n]` is node n's next free slot.
        let mut next: Vec<u32> = offsets[..num_nodes].to_vec();
        let mut targets = vec![0u32; edges.len()];
        for &(src, dst) in edges {
            let slot = &mut next[src as usize];
            targets[*slot as usize] = dst;
            *slot += 1;
        }

        Self { offsets, targets }
    }

    /// Returns the neighbors of `node`, in the order their edges were supplied to [`Self::build`].
    ///
    /// # Panics
    ///
    /// Panics if `node` is not less than the `num_nodes` the index was built with.
    pub fn neighbors(&self, node: u32) -> &[u32] {
        let n = node as usize;
        let start = self.offsets[n] as usize;
        let end = self.offsets[n + 1] as usize;
        &self.targets[start..end]
    }

    /// Returns `true` if `node` has at least one neighbor.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not less than the `num_nodes` the index was built with.
    pub fn has_neighbors(&self, node: u32) -> bool {
        let n = node as usize;
        self.offsets[n + 1] > self.offsets[n]
    }
}
98
99#[derive(Debug, Clone)]
101pub struct GraphIndex {
102 type_indices: HashMap<String, TypeIndex>,
104 csr: HashMap<String, CsrIndex>,
106 csc: HashMap<String, CsrIndex>,
108}
109
110impl GraphIndex {
111 pub async fn build(
113 snapshot: &Snapshot,
114 edge_types: &HashMap<String, (String, String)>, ) -> Result<Self> {
116 let mut type_indices: HashMap<String, TypeIndex> = HashMap::new();
117 let mut csr = HashMap::new();
118 let mut csc = HashMap::new();
119
120 let mut edge_pairs: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
122
123 for (edge_name, (from_type, to_type)) in edge_types {
124 let table_key = format!("edge:{}", edge_name);
125 if snapshot.entry(&table_key).is_none() {
126 continue;
127 }
128
129 let ds = snapshot.open(&table_key).await?;
130
131 let batches: Vec<arrow_array::RecordBatch> = ds
132 .scan()
133 .project(&["src", "dst"])
134 .map_err(|e| OmniError::Lance(e.to_string()))?
135 .try_into_stream()
136 .await
137 .map_err(|e| OmniError::Lance(e.to_string()))?
138 .try_collect()
139 .await
140 .map_err(|e| OmniError::Lance(e.to_string()))?;
141
142 type_indices
143 .entry(from_type.clone())
144 .or_insert_with(TypeIndex::new);
145 type_indices
146 .entry(to_type.clone())
147 .or_insert_with(TypeIndex::new);
148
149 let mut edges: Vec<(u32, u32)> = Vec::new();
150 for batch in &batches {
151 let srcs = string_column(batch, "src")?;
152 let dsts = string_column(batch, "dst")?;
153
154 for i in 0..batch.num_rows() {
155 let src_dense = type_indices
156 .get_mut(from_type)
157 .unwrap()
158 .get_or_insert(srcs.value(i));
159 let dst_dense = type_indices
160 .get_mut(to_type)
161 .unwrap()
162 .get_or_insert(dsts.value(i));
163 edges.push((src_dense, dst_dense));
164 }
165 }
166 edge_pairs.insert(edge_name.clone(), edges);
167 }
168
169 for (edge_name, (from_type, to_type)) in edge_types {
171 let Some(edges) = edge_pairs.get(edge_name) else {
172 continue;
173 };
174
175 let src_count = type_indices[from_type].len();
176 let dst_count = type_indices[to_type].len();
177
178 csr.insert(edge_name.clone(), CsrIndex::build(src_count, edges));
179
180 let reversed: Vec<(u32, u32)> = edges.iter().map(|&(s, d)| (d, s)).collect();
181 csc.insert(edge_name.clone(), CsrIndex::build(dst_count, &reversed));
182 }
183
184 Ok(Self {
185 type_indices,
186 csr,
187 csc,
188 })
189 }
190
191 pub fn type_index(&self, type_name: &str) -> Option<&TypeIndex> {
192 self.type_indices.get(type_name)
193 }
194
195 pub fn csr(&self, edge_type: &str) -> Option<&CsrIndex> {
196 self.csr.get(edge_type)
197 }
198
199 pub fn csc(&self, edge_type: &str) -> Option<&CsrIndex> {
200 self.csc.get(edge_type)
201 }
202
203 #[cfg(test)]
204 pub(crate) fn empty_for_test() -> Self {
205 Self {
206 type_indices: HashMap::new(),
207 csr: HashMap::new(),
208 csc: HashMap::new(),
209 }
210 }
211}
212
213fn string_column<'a>(batch: &'a arrow_array::RecordBatch, name: &str) -> Result<&'a StringArray> {
214 batch
215 .column_by_name(name)
216 .ok_or_else(|| {
217 OmniError::manifest_internal(format!("graph index batch missing '{name}' column"))
218 })?
219 .as_any()
220 .downcast_ref::<StringArray>()
221 .ok_or_else(|| {
222 OmniError::manifest_internal(format!("graph index column '{name}' is not Utf8"))
223 })
224}
225
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::UInt64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Empty neighbor slice, for comparisons against `CsrIndex::neighbors`.
    const NO_NEIGHBORS: &[u32] = &[];

    /// Every inserted id maps to a dense id and back again.
    #[test]
    fn type_index_round_trip() {
        let names = ["Alice", "Bob", "Charlie"];
        let mut index = TypeIndex::new();
        let dense: Vec<u32> = names.iter().map(|name| index.get_or_insert(name)).collect();

        for (name, &id) in names.iter().zip(&dense) {
            assert_eq!(index.to_dense(name), Some(id));
            assert_eq!(index.to_id(id), Some(*name));
        }
        assert_eq!(index.len(), names.len());
    }

    /// Re-inserting an id returns the same dense id and does not grow the index.
    #[test]
    fn type_index_idempotent_insert() {
        let mut index = TypeIndex::new();
        let first = index.get_or_insert("Alice");
        let second = index.get_or_insert("Alice");
        assert_eq!(first, second);
        assert_eq!(index.len(), 1);
    }

    /// Lookups of ids or dense ids that were never assigned yield `None`.
    #[test]
    fn type_index_unknown_returns_none() {
        let index = TypeIndex::new();
        assert!(index.to_dense("unknown").is_none());
        assert!(index.to_id(999).is_none());
    }

    /// Each node sees exactly its outgoing targets.
    #[test]
    fn csr_neighbors_correct() {
        let csr = CsrIndex::build(3, &[(0, 1), (0, 2), (1, 2)]);

        let mut from_zero = csr.neighbors(0).to_vec();
        from_zero.sort_unstable();
        assert_eq!(from_zero, vec![1, 2]);

        assert_eq!(csr.neighbors(1), &[2]);
        assert_eq!(csr.neighbors(2), NO_NEIGHBORS);
    }

    /// With no edges, every node has an empty neighbor list.
    #[test]
    fn csr_empty_graph() {
        let csr = CsrIndex::build(3, &[]);
        for node in 0..3 {
            assert_eq!(csr.neighbors(node), NO_NEIGHBORS);
        }
        assert!(!csr.has_neighbors(0));
    }

    /// `has_neighbors` is true exactly for nodes with an outgoing edge.
    #[test]
    fn csr_has_neighbors() {
        let csr = CsrIndex::build(3, &[(0, 1), (1, 2)]);
        let flags: Vec<bool> = (0..3).map(|node| csr.has_neighbors(node)).collect();
        assert_eq!(flags, vec![true, true, false]);
    }

    /// A non-Utf8 column produces an error that names the column.
    #[test]
    fn string_column_returns_error_for_bad_schema() {
        let schema = Schema::new(vec![Field::new("src", DataType::UInt64, false)]);
        let ids = UInt64Array::from(vec![1_u64]);
        let batch =
            arrow_array::RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids)]).unwrap();

        let message = string_column(&batch, "src").unwrap_err().to_string();
        assert!(message.contains("src"));
    }
}