use std::collections::HashMap;
use arrow_array::StringArray;
use futures::TryStreamExt;
use crate::db::Snapshot;
use crate::error::{OmniError, Result};
/// Bidirectional mapping between the string ids of one node type and dense `u32` indices.
///
/// Dense indices are handed out in first-insertion order starting at 0, so they are
/// contiguous in `0..len()` and can be used directly as node numbers in per-node arrays.
#[derive(Debug, Clone, Default)]
pub struct TypeIndex {
    /// String id -> dense index.
    id_to_dense: HashMap<String, u32>,
    /// Dense index -> string id; the position in the vector is the dense index.
    dense_to_id: Vec<String>,
}

impl TypeIndex {
    /// Creates an empty index.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns the dense index for `id`, assigning the next free index if `id` is new.
    ///
    /// # Panics
    ///
    /// Panics if the index already holds 2^32 distinct ids (the `u32` range is exhausted).
    pub(crate) fn get_or_insert(&mut self, id: &str) -> u32 {
        // Fast path: look up by `&str` first so hits never allocate.
        if let Some(&idx) = self.id_to_dense.get(id) {
            return idx;
        }
        // Checked conversion: a plain `as` cast would silently wrap and hand out an
        // index that already belongs to a different id.
        let idx = u32::try_from(self.dense_to_id.len())
            .expect("TypeIndex exhausted the u32 dense-index space");
        self.dense_to_id.push(id.to_owned());
        self.id_to_dense.insert(id.to_owned(), idx);
        idx
    }

    /// Returns the dense index of `id`, or `None` if it was never inserted.
    pub fn to_dense(&self, id: &str) -> Option<u32> {
        self.id_to_dense.get(id).copied()
    }

    /// Returns the string id for `dense`, or `None` if it is out of range.
    pub fn to_id(&self, dense: u32) -> Option<&str> {
        self.dense_to_id.get(dense as usize).map(String::as_str)
    }

    /// Returns the number of distinct ids stored.
    pub fn len(&self) -> usize {
        self.dense_to_id.len()
    }

    /// Returns `true` if no ids have been inserted.
    pub fn is_empty(&self) -> bool {
        self.dense_to_id.is_empty()
    }
}
/// Compressed sparse row (CSR) adjacency over dense `u32` node ids.
///
/// The neighbors of node `n` are `targets[offsets[n]..offsets[n + 1]]`, so `offsets`
/// always holds `num_nodes + 1` entries and starts at 0.
#[derive(Debug, Clone)]
pub struct CsrIndex {
    /// Running prefix sums of per-node out-degrees.
    offsets: Vec<u32>,
    /// Neighbor ids, stored contiguously per source node.
    targets: Vec<u32>,
}

impl CsrIndex {
    /// Builds a CSR index over `num_nodes` nodes from `(src, dst)` edge pairs.
    ///
    /// For each source node, its neighbors keep the order in which their edges appear in
    /// `edges`. Duplicate edges and self-loops are kept as-is.
    ///
    /// # Panics
    ///
    /// Panics if any `src` is `>= num_nodes`, or if `edges.len()` exceeds `u32::MAX`
    /// (offsets are stored as `u32`).
    pub(crate) fn build(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
        // Offsets are u32. Without this guard an oversized edge list would overflow them:
        // a panic in debug builds, but silent wraparound and a corrupt index in release.
        assert!(
            u32::try_from(edges.len()).is_ok(),
            "CsrIndex::build: {} edges exceed the u32 offset range",
            edges.len()
        );

        // Out-degree of every source node.
        let mut counts = vec![0u32; num_nodes];
        for &(src, _) in edges {
            counts[src as usize] += 1;
        }

        // Prefix sum: offsets[n] is where node n's neighbors start.
        // Cannot overflow because the total is bounded by edges.len() (checked above).
        let mut offsets = Vec::with_capacity(num_nodes + 1);
        let mut running = 0u32;
        offsets.push(running);
        for &count in &counts {
            running += count;
            offsets.push(running);
        }

        // Scatter targets into place. Each node's write cursor starts at its offset and
        // advances once per edge, which preserves the input order within a node.
        let mut targets = vec![0u32; edges.len()];
        let mut cursors = offsets[..num_nodes].to_vec();
        for &(src, dst) in edges {
            let cursor = &mut cursors[src as usize];
            targets[*cursor as usize] = dst;
            *cursor += 1;
        }

        Self { offsets, targets }
    }

    /// Returns the neighbors of `node`, in edge insertion order.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not below the `num_nodes` the index was built with.
    pub fn neighbors(&self, node: u32) -> &[u32] {
        let n = node as usize;
        let start = self.offsets[n] as usize;
        let end = self.offsets[n + 1] as usize;
        &self.targets[start..end]
    }

    /// Returns `true` if `node` has at least one outgoing edge.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not below the `num_nodes` the index was built with.
    pub fn has_neighbors(&self, node: u32) -> bool {
        let n = node as usize;
        self.offsets[n + 1] > self.offsets[n]
    }
}
/// Per-node-type dense id maps plus forward (CSR) and reverse (CSC) adjacency per edge type.
#[derive(Debug, Clone)]
pub struct GraphIndex {
    /// Node type name -> dense id mapping for nodes of that type.
    type_indices: HashMap<String, TypeIndex>,
    /// Edge type name -> adjacency from `from_type` dense ids to `to_type` dense ids.
    csr: HashMap<String, CsrIndex>,
    /// Edge type name -> reverse adjacency from `to_type` dense ids to `from_type` dense ids.
    csc: HashMap<String, CsrIndex>,
}

impl GraphIndex {
    /// Builds the graph index by scanning the `src`/`dst` columns of every edge table.
    ///
    /// `edge_types` maps an edge type name to its `(from_type, to_type)` node types. The
    /// table for edge type `E` is looked up in `snapshot` under the key `"edge:E"`. Edge
    /// types whose table is absent are skipped and get no CSR/CSC entry. A node id only
    /// enters a [`TypeIndex`] if it appears as an endpoint of some scanned edge.
    ///
    /// Dense ids are assigned in scan order, which follows the (unspecified) iteration
    /// order of `edge_types`. They are therefore not guaranteed to be stable across builds.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - an edge table cannot be opened or scanned;
    /// - its `src`/`dst` columns are missing or are not Utf8;
    /// - any `src`/`dst` value is null.
    pub async fn build(
        snapshot: &Snapshot,
        edge_types: &HashMap<String, (String, String)>,
    ) -> Result<Self> {
        let mut type_indices: HashMap<String, TypeIndex> = HashMap::new();
        let mut edge_pairs: HashMap<String, Vec<(u32, u32)>> = HashMap::new();

        // Pass 1: scan every edge table and translate endpoints into dense ids. This must
        // finish before any CSR is built. A node type may be shared by several edge types,
        // and its final node count is only known once all tables have been read.
        for (edge_name, (from_type, to_type)) in edge_types {
            let table_key = format!("edge:{}", edge_name);
            if snapshot.entry(&table_key).is_none() {
                continue;
            }
            let ds = snapshot.open(&table_key).await?;
            let batches: Vec<arrow_array::RecordBatch> = ds
                .scan()
                .project(&["src", "dst"])
                .map_err(|e| OmniError::Lance(e.to_string()))?
                .try_into_stream()
                .await
                .map_err(|e| OmniError::Lance(e.to_string()))?
                .try_collect()
                .await
                .map_err(|e| OmniError::Lance(e.to_string()))?;

            type_indices
                .entry(from_type.clone())
                .or_insert_with(TypeIndex::new);
            type_indices
                .entry(to_type.clone())
                .or_insert_with(TypeIndex::new);

            let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
            let mut edges: Vec<(u32, u32)> = Vec::with_capacity(total_rows);
            for batch in &batches {
                let srcs = string_column(batch, "src")?;
                let dsts = string_column(batch, "dst")?;
                for (src, dst) in srcs.iter().zip(dsts.iter()) {
                    // `StringArray::value` returns an arbitrary string for null slots. That
                    // would silently create a phantom node, so nulls are rejected instead.
                    let (Some(src), Some(dst)) = (src, dst) else {
                        return Err(OmniError::manifest_internal(format!(
                            "graph index table '{table_key}' has a null src/dst value"
                        )));
                    };
                    // Looked up per row rather than held across the loop, because
                    // from_type and to_type may name the same TypeIndex.
                    let src_dense = type_indices
                        .get_mut(from_type)
                        .expect("from_type TypeIndex inserted above")
                        .get_or_insert(src);
                    let dst_dense = type_indices
                        .get_mut(to_type)
                        .expect("to_type TypeIndex inserted above")
                        .get_or_insert(dst);
                    edges.push((src_dense, dst_dense));
                }
            }
            edge_pairs.insert(edge_name.clone(), edges);
        }

        // Pass 2: every TypeIndex now has its final size, so build forward and reverse
        // adjacency sized to the full node count of each endpoint type.
        let mut csr = HashMap::new();
        let mut csc = HashMap::new();
        for (edge_name, (from_type, to_type)) in edge_types {
            let Some(edges) = edge_pairs.get(edge_name) else {
                continue;
            };
            let src_count = type_indices[from_type].len();
            let dst_count = type_indices[to_type].len();
            csr.insert(edge_name.clone(), CsrIndex::build(src_count, edges));
            let reversed: Vec<(u32, u32)> = edges.iter().map(|&(s, d)| (d, s)).collect();
            csc.insert(edge_name.clone(), CsrIndex::build(dst_count, &reversed));
        }

        Ok(Self {
            type_indices,
            csr,
            csc,
        })
    }

    /// Returns the dense id mapping for `type_name`, if any edge touched that type.
    pub fn type_index(&self, type_name: &str) -> Option<&TypeIndex> {
        self.type_indices.get(type_name)
    }

    /// Returns the forward (source -> destination) adjacency for `edge_type`.
    pub fn csr(&self, edge_type: &str) -> Option<&CsrIndex> {
        self.csr.get(edge_type)
    }

    /// Returns the reverse (destination -> source) adjacency for `edge_type`.
    pub fn csc(&self, edge_type: &str) -> Option<&CsrIndex> {
        self.csc.get(edge_type)
    }

    /// Returns an index with no node types and no edges, for tests.
    #[cfg(test)]
    pub(crate) fn empty_for_test() -> Self {
        Self {
            type_indices: HashMap::new(),
            csr: HashMap::new(),
            csc: HashMap::new(),
        }
    }
}
/// Borrows column `name` of `batch` as a Utf8 `StringArray`.
///
/// # Errors
///
/// Returns an internal manifest error if the column is absent, or if it cannot be
/// downcast to `StringArray`.
fn string_column<'a>(batch: &'a arrow_array::RecordBatch, name: &str) -> Result<&'a StringArray> {
    let Some(column) = batch.column_by_name(name) else {
        return Err(OmniError::manifest_internal(format!(
            "graph index batch missing '{name}' column"
        )));
    };
    match column.as_any().downcast_ref::<StringArray>() {
        Some(strings) => Ok(strings),
        None => Err(OmniError::manifest_internal(format!(
            "graph index column '{name}' is not Utf8"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::UInt64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    #[test]
    fn type_index_round_trip() {
        let mut idx = TypeIndex::new();
        let a = idx.get_or_insert("Alice");
        let b = idx.get_or_insert("Bob");
        let c = idx.get_or_insert("Charlie");
        assert_eq!(idx.to_dense("Alice"), Some(a));
        assert_eq!(idx.to_dense("Bob"), Some(b));
        assert_eq!(idx.to_dense("Charlie"), Some(c));
        assert_eq!(idx.to_id(a), Some("Alice"));
        assert_eq!(idx.to_id(b), Some("Bob"));
        assert_eq!(idx.to_id(c), Some("Charlie"));
        assert_eq!(idx.len(), 3);
    }

    #[test]
    fn type_index_idempotent_insert() {
        let mut idx = TypeIndex::new();
        let a1 = idx.get_or_insert("Alice");
        let a2 = idx.get_or_insert("Alice");
        assert_eq!(a1, a2);
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn type_index_unknown_returns_none() {
        let idx = TypeIndex::new();
        assert_eq!(idx.to_dense("unknown"), None);
        assert_eq!(idx.to_id(999), None);
    }

    #[test]
    fn csr_neighbors_correct() {
        let edges = vec![(0, 1), (0, 2), (1, 2)];
        let csr = CsrIndex::build(3, &edges);
        let mut n0: Vec<u32> = csr.neighbors(0).to_vec();
        n0.sort();
        assert_eq!(n0, vec![1, 2]);
        assert_eq!(csr.neighbors(1), &[2]);
        assert_eq!(csr.neighbors(2), &[] as &[u32]);
    }

    /// Neighbors of a node come back in the order their edges were supplied,
    /// even when edges for different nodes are interleaved.
    #[test]
    fn csr_preserves_edge_order_within_node() {
        let csr = CsrIndex::build(3, &[(2, 1), (0, 2), (2, 0), (0, 1)]);
        assert_eq!(csr.neighbors(0), &[2, 1]);
        assert_eq!(csr.neighbors(2), &[1, 0]);
        assert!(!csr.has_neighbors(1));
    }

    #[test]
    fn csr_empty_graph() {
        let csr = CsrIndex::build(3, &[]);
        assert_eq!(csr.neighbors(0), &[] as &[u32]);
        assert_eq!(csr.neighbors(1), &[] as &[u32]);
        assert_eq!(csr.neighbors(2), &[] as &[u32]);
        assert!(!csr.has_neighbors(0));
    }

    #[test]
    fn csr_has_neighbors() {
        let csr = CsrIndex::build(3, &[(0, 1), (1, 2)]);
        assert!(csr.has_neighbors(0));
        assert!(csr.has_neighbors(1));
        assert!(!csr.has_neighbors(2));
    }

    #[test]
    fn string_column_returns_error_for_bad_schema() {
        let batch = arrow_array::RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new(
                "src",
                DataType::UInt64,
                false,
            )])),
            vec![Arc::new(UInt64Array::from(vec![1_u64]))],
        )
        .unwrap();
        let err = string_column(&batch, "src").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("src"));
        assert!(msg.contains("Utf8"));
    }

    #[test]
    fn string_column_returns_error_for_missing_column() {
        let batch = arrow_array::RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("dst", DataType::Utf8, false)])),
            vec![Arc::new(arrow_array::StringArray::from(vec!["a"]))],
        )
        .unwrap();
        let err = string_column(&batch, "src").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("src"));
        assert!(msg.contains("missing"));
    }

    #[test]
    fn string_column_returns_utf8_values() {
        let batch = arrow_array::RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("src", DataType::Utf8, false)])),
            vec![Arc::new(arrow_array::StringArray::from(vec!["a", "b"]))],
        )
        .unwrap();
        let col = string_column(&batch, "src").unwrap();
        assert_eq!(col.value(0), "a");
        assert_eq!(col.value(1), "b");
    }
}