use grafeo_common::types::{EdgeId, NodeId};
use grafeo_common::utils::hash::FxHashMap;
use smallvec::SmallVec;
/// One level of the trie: children keyed by the next path element, plus the
/// edge ids stored under the exact path that ends at this node.
#[derive(Debug, Clone)]
struct TrieNode {
    /// Child nodes keyed by the next `NodeId` in the path.
    children: FxHashMap<NodeId, TrieNode>,
    /// Edge ids inserted under the path leading to this node, in insertion order.
    edges: SmallVec<[EdgeId; 4]>,
}

impl TrieNode {
    /// Creates a node with no children and no edges.
    fn new() -> Self {
        Self {
            children: FxHashMap::default(),
            edges: SmallVec::new(),
        }
    }

    /// Appends `edge_id` to the node reached by `path`, creating any missing
    /// intermediate nodes. An empty `path` stores the edge on `self`.
    ///
    /// Walks the path iteratively, so stack usage does not grow with path length.
    fn insert(&mut self, path: &[NodeId], edge_id: EdgeId) {
        let mut node = self;
        for &key in path {
            node = node.children.entry(key).or_insert_with(TrieNode::new);
        }
        node.edges.push(edge_id);
    }

    /// Returns the child stored under `key`, if any.
    fn get_child(&self, key: NodeId) -> Option<&TrieNode> {
        self.children.get(&key)
    }

    /// Returns this node's child keys in ascending order.
    fn children_sorted(&self) -> Vec<NodeId> {
        let mut keys: Vec<NodeId> = self.children.keys().copied().collect();
        // Hash-map keys are distinct, so an unstable sort yields exactly the same
        // order as a stable one without the stable sort's scratch allocation.
        keys.sort_unstable();
        keys
    }
}
/// A trie over `NodeId` paths whose nodes store `EdgeId`s.
///
/// Each level can be walked in ascending key order with a [`TrieIterator`];
/// such cursors are the inputs to [`LeapfrogJoin`].
#[derive(Debug)]
pub struct TrieIndex {
    root: TrieNode,
    /// Number of `insert` calls so far; repeated paths are counted each time.
    size: usize,
}

impl TrieIndex {
    /// Creates an empty index.
    #[must_use]
    pub fn new() -> Self {
        Self {
            root: TrieNode::new(),
            size: 0,
        }
    }

    /// Stores `edge_id` under `path`, creating intermediate levels as needed.
    ///
    /// Inserting the same path again appends another edge id rather than
    /// replacing the existing ones. An empty `path` stores the edge at the root.
    pub fn insert(&mut self, path: &[NodeId], edge_id: EdgeId) {
        self.root.insert(path, edge_id);
        self.size += 1;
    }

    /// Stores `edge_id` under the two-level path `[src, dst]`.
    pub fn insert_edge(&mut self, src: NodeId, dst: NodeId, edge_id: EdgeId) {
        self.insert(&[src, dst], edge_id);
    }

    /// Number of edge ids inserted so far (every `insert` call counts,
    /// including repeats of the same path).
    #[must_use]
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns `true` if nothing has been inserted.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Returns a cursor over the first-level keys, in ascending order.
    // `TrieIterator` is a seekable cursor (key/next/seek/open), not a
    // `std::iter::Iterator`, so clippy's naming lint does not apply here.
    #[allow(clippy::iter_not_returning_iterator)]
    #[must_use]
    pub fn iter(&self) -> TrieIterator<'_> {
        TrieIterator::new(&self.root)
    }

    /// Returns a cursor over the children of the node reached by `path`, or
    /// `None` if no inserted path has `path` as a prefix.
    ///
    /// An empty `path` yields the same cursor as [`TrieIndex::iter`].
    pub fn iter_at(&self, path: &[NodeId]) -> Option<TrieIterator<'_>> {
        self.find_node(path).map(TrieIterator::new)
    }

    /// Returns the edge ids stored under exactly `path`, in insertion order.
    ///
    /// Returns `None` if `path` does not exist, or if it exists only as a
    /// prefix of longer paths and holds no edges itself.
    pub fn get(&self, path: &[NodeId]) -> Option<&[EdgeId]> {
        let node = self.find_node(path)?;
        if node.edges.is_empty() {
            None
        } else {
            Some(&node.edges)
        }
    }

    /// Follows `path` from the root and returns the node it ends at, if present.
    fn find_node(&self, path: &[NodeId]) -> Option<&TrieNode> {
        path.iter()
            .try_fold(&self.root, |node, &key| node.get_child(key))
    }
}

impl Default for TrieIndex {
    fn default() -> Self {
        Self::new()
    }
}
/// A sorted, seekable cursor over the child keys of one trie node.
///
/// The keys are snapshotted in ascending order when the cursor is created.
/// The cursor starts on the smallest key. Once it moves past the last key it
/// is exhausted: [`TrieIterator::key`] returns `None` and
/// [`TrieIterator::is_valid`] returns `false`.
#[derive(Debug, Clone)]
pub struct TrieIterator<'a> {
    /// Node whose children this cursor walks.
    node: &'a TrieNode,
    /// The node's child keys, sorted ascending.
    keys: Vec<NodeId>,
    /// Index into `keys` of the current key; equals `keys.len()` when exhausted.
    pos: usize,
}

impl<'a> TrieIterator<'a> {
    /// Creates a cursor positioned on the smallest child key of `node`.
    fn new(node: &'a TrieNode) -> Self {
        Self {
            node,
            keys: node.children_sorted(),
            pos: 0,
        }
    }

    /// The key under the cursor, or `None` if the cursor is exhausted.
    #[must_use]
    pub fn key(&self) -> Option<NodeId> {
        self.keys.get(self.pos).copied()
    }

    /// Advances to the next key.
    ///
    /// Returns `true` if the cursor is still on a key afterwards. On an
    /// exhausted cursor this does nothing and returns `false`.
    pub fn next(&mut self) -> bool {
        if !self.is_valid() {
            return false;
        }
        self.pos += 1;
        self.is_valid()
    }

    /// Moves forward to the first key `>= target` and returns whether the
    /// cursor is positioned on a key.
    ///
    /// Never moves backward: if the current key is already `>= target` the
    /// cursor stays where it is.
    pub fn seek(&mut self, target: NodeId) -> bool {
        // `pos <= keys.len()` always holds, so the slice is in bounds. The
        // remaining keys are sorted, so those below `target` form a prefix.
        self.pos += self.keys[self.pos..].partition_point(|&k| k < target);
        self.is_valid()
    }

    /// Opens a cursor over the children of the current key.
    ///
    /// Returns `None` if this cursor is exhausted.
    #[must_use]
    pub fn open(&self) -> Option<TrieIterator<'a>> {
        let node = self.node;
        self.key()
            .and_then(|key| node.get_child(key))
            .map(TrieIterator::new)
    }

    /// Returns `true` while the cursor is positioned on a key.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.pos < self.keys.len()
    }
}
/// Leapfrog intersection over the current level of several [`TrieIterator`]s.
///
/// Enumerates, in ascending order, every key present in *all* input cursors.
/// The cursors keep the order they were passed to [`LeapfrogJoin::new`], so
/// [`LeapfrogJoin::open`] returns child cursors positionally aligned with the
/// inputs.
pub struct LeapfrogJoin<'a> {
    /// Input cursors, in the caller's original order (never reordered).
    iters: Vec<TrieIterator<'a>>,
    /// Indices into `iters`. Read cyclically starting at `p`, they visit the
    /// cursors in ascending key order, so the entry just before `p` holds the
    /// largest key.
    order: Vec<usize>,
    /// Position in `order` of the cursor holding the smallest key.
    p: usize,
    /// Key all cursors currently agree on, or `None` once the join is exhausted.
    current_key: Option<NodeId>,
}

impl<'a> LeapfrogJoin<'a> {
    /// Builds the join and positions it on the first common key, if any.
    ///
    /// An empty `iters` produces a join with no keys.
    pub fn new(iters: Vec<TrieIterator<'a>>) -> Self {
        let order = (0..iters.len()).collect();
        let mut join = Self {
            iters,
            order,
            p: 0,
            current_key: None,
        };
        join.init();
        join
    }

    /// Orders the rotation by each cursor's current key, then searches for the
    /// first common key.
    fn init(&mut self) {
        // Nothing can match without inputs, or if any input is already exhausted.
        if self.iters.is_empty() || !self.iters.iter().all(|it| it.is_valid()) {
            self.current_key = None;
            return;
        }
        // Sort the index permutation, not `iters` itself, so `open` can hand
        // children back in the caller's order.
        let iters = &self.iters;
        self.order.sort_by_key(|&i| iters[i].key());
        self.p = 0;
        self.search();
    }

    /// Leapfrogs until all cursors agree on a key (stored in `current_key`) or
    /// one of them runs out (`current_key` becomes `None`).
    ///
    /// Expects every cursor to be valid and `order` to satisfy the rotation
    /// invariant documented on the struct.
    fn search(&mut self) {
        let k = self.order.len();
        if k == 0 {
            self.current_key = None;
            return;
        }
        let Some(mut max_key) = self.iters[self.order[(self.p + k - 1) % k]].key() else {
            self.current_key = None;
            return;
        };
        loop {
            let it = &mut self.iters[self.order[self.p]];
            let Some(min_key) = it.key() else {
                self.current_key = None;
                return;
            };
            if min_key == max_key {
                // The smallest key equals the largest, so every cursor agrees.
                self.current_key = Some(min_key);
                return;
            }
            // Leap the smallest cursor up to the current maximum. It becomes the
            // new maximum, and the next cursor in the rotation becomes the smallest.
            it.seek(max_key);
            let Some(new_max) = it.key() else {
                self.current_key = None;
                return;
            };
            max_key = new_max;
            self.p = (self.p + 1) % k;
        }
    }

    /// The current common key, or `None` if the join is exhausted.
    #[must_use]
    pub fn key(&self) -> Option<NodeId> {
        self.current_key
    }

    /// Advances to the next common key. Returns `true` if one was found.
    ///
    /// Once the join is exhausted, further calls return `false`.
    pub fn next(&mut self) -> bool {
        if self.current_key.is_none() || self.iters.is_empty() {
            return false;
        }
        // All cursors sit on `current_key`. Advancing the one at `p` makes it
        // the largest, and stepping `p` past it restores the rotation invariant.
        if !self.iters[self.order[self.p]].next() {
            self.current_key = None;
            return false;
        }
        self.p = (self.p + 1) % self.order.len();
        self.search();
        self.current_key.is_some()
    }

    /// Opens a child cursor under the current key for every input.
    ///
    /// The result has one cursor per input, in the order the inputs were passed
    /// to [`LeapfrogJoin::new`]. Returns `None` if the join is exhausted or if
    /// any input's [`TrieIterator::open`] returns `None`.
    pub fn open(&self) -> Option<Vec<TrieIterator<'a>>> {
        self.current_key?;
        self.iters.iter().map(|it| it.open()).collect()
    }
}

#[cfg(test)]
mod leapfrog_open_order_tests {
    use super::*;

    #[test]
    fn open_returns_children_in_input_order() {
        // The first input starts on a larger key than the second, which used to
        // make the join reorder its cursors before `open`.
        let mut trie1 = TrieIndex::new();
        let mut trie2 = TrieIndex::new();
        trie1.insert_edge(NodeId::new(5), NodeId::new(10), EdgeId::new(0));
        trie2.insert_edge(NodeId::new(2), NodeId::new(20), EdgeId::new(1));
        trie2.insert_edge(NodeId::new(5), NodeId::new(30), EdgeId::new(2));
        let join = LeapfrogJoin::new(vec![trie1.iter(), trie2.iter()]);
        assert_eq!(join.key(), Some(NodeId::new(5)));
        let children = join.open().expect("both inputs have children under key 5");
        assert_eq!(children.len(), 2);
        assert_eq!(children[0].key(), Some(NodeId::new(10)));
        assert_eq!(children[1].key(), Some(NodeId::new(30)));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains `join`, returning every key it yields, in order.
    fn collect_keys(join: &mut LeapfrogJoin<'_>) -> Vec<NodeId> {
        let mut keys = Vec::new();
        while let Some(key) = join.key() {
            keys.push(key);
            join.next();
        }
        keys
    }

    #[test]
    fn test_trie_basic() {
        let mut trie = TrieIndex::new();
        assert!(trie.is_empty());
        trie.insert_edge(NodeId::new(1), NodeId::new(2), EdgeId::new(0));
        trie.insert_edge(NodeId::new(1), NodeId::new(3), EdgeId::new(1));
        trie.insert_edge(NodeId::new(2), NodeId::new(3), EdgeId::new(2));
        assert_eq!(trie.len(), 3);
        assert!(!trie.is_empty());
    }

    #[test]
    fn test_trie_duplicate_path_keeps_all_edges() {
        let mut trie = TrieIndex::new();
        trie.insert_edge(NodeId::new(1), NodeId::new(2), EdgeId::new(7));
        trie.insert_edge(NodeId::new(1), NodeId::new(2), EdgeId::new(8));
        assert_eq!(trie.len(), 2);
        assert_eq!(
            trie.get(&[NodeId::new(1), NodeId::new(2)]),
            Some(&[EdgeId::new(7), EdgeId::new(8)][..])
        );
    }

    #[test]
    fn test_trie_iterator() {
        let mut trie = TrieIndex::new();
        trie.insert_edge(NodeId::new(1), NodeId::new(10), EdgeId::new(0));
        trie.insert_edge(NodeId::new(2), NodeId::new(20), EdgeId::new(1));
        trie.insert_edge(NodeId::new(3), NodeId::new(30), EdgeId::new(2));
        let mut iter = trie.iter();
        assert_eq!(iter.key(), Some(NodeId::new(1)));
        assert!(iter.next());
        assert_eq!(iter.key(), Some(NodeId::new(2)));
        assert!(iter.next());
        assert_eq!(iter.key(), Some(NodeId::new(3)));
        assert!(!iter.next());
        assert_eq!(iter.key(), None);
        assert!(!iter.next());
    }

    #[test]
    fn test_trie_seek() {
        let mut trie = TrieIndex::new();
        for i in [1, 3, 5, 7, 9] {
            trie.insert_edge(NodeId::new(i), NodeId::new(100), EdgeId::new(i));
        }
        let mut iter = trie.iter();
        assert!(iter.seek(NodeId::new(4)));
        assert_eq!(iter.key(), Some(NodeId::new(5)));
        assert!(iter.seek(NodeId::new(7)));
        assert_eq!(iter.key(), Some(NodeId::new(7)));
        assert!(!iter.seek(NodeId::new(10)));
        assert_eq!(iter.key(), None);
        assert!(!iter.is_valid());
    }

    #[test]
    fn test_leapfrog_join() {
        let mut trie1 = TrieIndex::new();
        let mut trie2 = TrieIndex::new();
        for &i in &[1, 2, 3, 5] {
            trie1.insert_edge(NodeId::new(i), NodeId::new(100), EdgeId::new(i));
        }
        for &i in &[2, 3, 4, 5] {
            trie2.insert_edge(NodeId::new(i), NodeId::new(100), EdgeId::new(i + 10));
        }
        let mut join = LeapfrogJoin::new(vec![trie1.iter(), trie2.iter()]);
        assert_eq!(
            collect_keys(&mut join),
            vec![NodeId::new(2), NodeId::new(3), NodeId::new(5)]
        );
        assert!(!join.next());
    }

    #[test]
    fn test_trie_get_existing_path() {
        let mut trie = TrieIndex::new();
        trie.insert_edge(NodeId::new(1), NodeId::new(2), EdgeId::new(10));
        trie.insert_edge(NodeId::new(1), NodeId::new(3), EdgeId::new(11));
        assert_eq!(
            trie.get(&[NodeId::new(1), NodeId::new(2)]),
            Some(&[EdgeId::new(10)][..])
        );
        assert_eq!(
            trie.get(&[NodeId::new(1), NodeId::new(3)]),
            Some(&[EdgeId::new(11)][..])
        );
    }

    #[test]
    fn test_trie_get_nonexistent_path() {
        let mut trie = TrieIndex::new();
        trie.insert_edge(NodeId::new(1), NodeId::new(2), EdgeId::new(0));
        assert!(trie.get(&[NodeId::new(99)]).is_none());
        assert!(trie.get(&[NodeId::new(1), NodeId::new(99)]).is_none());
        // A prefix of an inserted path exists as a node but holds no edges.
        assert!(trie.get(&[NodeId::new(1)]).is_none());
    }

    #[test]
    fn test_trie_get_empty_path() {
        let trie = TrieIndex::new();
        assert!(trie.get(&[]).is_none());
    }

    #[test]
    fn test_trie_iter_at_existing() {
        let mut trie = TrieIndex::new();
        trie.insert_edge(NodeId::new(1), NodeId::new(2), EdgeId::new(0));
        trie.insert_edge(NodeId::new(1), NodeId::new(3), EdgeId::new(1));
        let mut iter = trie
            .iter_at(&[NodeId::new(1)])
            .expect("node 1 has children");
        assert_eq!(iter.key(), Some(NodeId::new(2)));
        assert!(iter.next());
        assert_eq!(iter.key(), Some(NodeId::new(3)));
        assert!(!iter.next());
    }

    #[test]
    fn test_trie_iter_at_nonexistent() {
        let trie = TrieIndex::new();
        assert!(trie.iter_at(&[NodeId::new(99)]).is_none());
        let root = trie.iter_at(&[]).expect("the root always exists");
        assert!(!root.is_valid());
    }

    #[test]
    fn test_leapfrog_join_open() {
        let mut trie1 = TrieIndex::new();
        let mut trie2 = TrieIndex::new();
        trie1.insert_edge(NodeId::new(1), NodeId::new(10), EdgeId::new(0));
        trie1.insert_edge(NodeId::new(1), NodeId::new(20), EdgeId::new(1));
        trie2.insert_edge(NodeId::new(1), NodeId::new(15), EdgeId::new(2));
        trie2.insert_edge(NodeId::new(1), NodeId::new(20), EdgeId::new(3));
        let join = LeapfrogJoin::new(vec![trie1.iter(), trie2.iter()]);
        assert_eq!(join.key(), Some(NodeId::new(1)));
        let child_iters = join.open().expect("key 1 has children in both tries");
        assert_eq!(child_iters.len(), 2);
        assert_eq!(child_iters[0].key(), Some(NodeId::new(10)));
        assert_eq!(child_iters[1].key(), Some(NodeId::new(15)));
        // Joining the second level intersects {10, 20} with {15, 20}.
        let mut nested = LeapfrogJoin::new(child_iters);
        assert_eq!(collect_keys(&mut nested), vec![NodeId::new(20)]);
    }

    #[test]
    fn test_leapfrog_join_empty_intersection() {
        let mut trie1 = TrieIndex::new();
        let mut trie2 = TrieIndex::new();
        trie1.insert_edge(NodeId::new(1), NodeId::new(10), EdgeId::new(0));
        trie2.insert_edge(NodeId::new(2), NodeId::new(20), EdgeId::new(1));
        let mut join = LeapfrogJoin::new(vec![trie1.iter(), trie2.iter()]);
        assert!(join.key().is_none());
        assert!(join.open().is_none());
        assert!(!join.next());
    }

    #[test]
    fn test_leapfrog_join_no_inputs() {
        let mut join = LeapfrogJoin::new(Vec::new());
        assert!(join.key().is_none());
        assert!(!join.next());
    }

    #[test]
    fn test_leapfrog_join_single_input_yields_all_keys() {
        let mut trie = TrieIndex::new();
        for i in [4, 1, 9] {
            trie.insert_edge(NodeId::new(i), NodeId::new(0), EdgeId::new(i));
        }
        let mut join = LeapfrogJoin::new(vec![trie.iter()]);
        assert_eq!(
            collect_keys(&mut join),
            vec![NodeId::new(1), NodeId::new(4), NodeId::new(9)]
        );
    }

    #[test]
    fn test_trie_seek_backward_stays_forward() {
        let mut trie = TrieIndex::new();
        for i in [1, 3, 5, 7] {
            trie.insert_edge(NodeId::new(i), NodeId::new(100), EdgeId::new(i));
        }
        let mut iter = trie.iter();
        assert!(iter.seek(NodeId::new(5)));
        assert_eq!(iter.key(), Some(NodeId::new(5)));
        assert!(iter.seek(NodeId::new(5)));
        assert_eq!(iter.key(), Some(NodeId::new(5)));
        // Seeking to a smaller key never moves the cursor backward.
        assert!(iter.seek(NodeId::new(1)));
        assert_eq!(iter.key(), Some(NodeId::new(5)));
    }
}