use crate::indexes::{PageIndex, PageOffset, RowPointer, SquashedOffset};
use core::mem;
use core::ops::{Bound, RangeBounds};
use core::option::IntoIter;
use spacetimedb_sats::memory_usage::MemoryUsage;
/// A unique index mapping `usize` keys directly to [`RowPointer`]s.
///
/// Each key is split by `split_key` into a position in `outer` and a slot within the
/// fixed-size `InnerIndex` stored there. Inner indices are allocated lazily, on the first
/// insert of a key in their range.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UniqueDirectIndex {
    /// Lazily-allocated inner indices.
    ///
    /// `None` means no key in that range has been inserted since creation or the last `clear`.
    outer: Vec<Option<InnerIndex>>,
    /// The number of keys currently present in the index.
    len: usize,
}
impl MemoryUsage for UniqueDirectIndex {
    /// Reports the heap bytes owned by the index.
    ///
    /// This is the sum of what `MemoryUsage` reports for the outer vector and for the length counter.
    fn heap_usage(&self) -> usize {
        // Exhaustive destructuring: adding a field to `UniqueDirectIndex` stops this from
        // compiling until the new field is accounted for.
        let Self {
            outer: inner_slots,
            len: key_count,
        } = self;
        let slots_bytes = inner_slots.heap_usage();
        let count_bytes = key_count.heap_usage();
        slots_bytes + count_bytes
    }
}
/// Target size, in bytes, of the slot array of each inner index.
const PAGE_SIZE: usize = 4_096;
/// Number of consecutive keys covered by one inner index.
///
/// This is as many `RowPointer`s as fit in `PAGE_SIZE` bytes.
const KEYS_PER_INNER: usize = PAGE_SIZE / size_of::<RowPointer>();
/// The slot array of one inner index; a slot holding `NONE_PTR` means "key absent".
type InnerIndexArray = [RowPointer; KEYS_PER_INNER];
/// A heap-allocated block of slots covering `KEYS_PER_INNER` consecutive keys.
#[derive(Debug, Clone, PartialEq, Eq)]
struct InnerIndex {
    /// One slot per key; `NONE_PTR` marks an empty slot.
    inner: Box<InnerIndexArray>,
}
impl MemoryUsage for InnerIndex {
    /// Reports the heap bytes owned by this inner index, i.e. those of its boxed slot array.
    fn heap_usage(&self) -> usize {
        // Destructure exhaustively so that a newly added field cannot be silently forgotten.
        let Self { inner: slot_array } = self;
        slot_array.heap_usage()
    }
}
/// Sentinel stored in empty slots of an inner index.
///
/// `UniqueDirectIndex::insert` stores values with their reserved bit set. This sentinel is
/// constructed with `false` as its first (presumably reserved-bit) argument, so an occupied
/// slot never compares equal to it.
///
/// NOTE(review): `InnerIndex::new` zero-fills fresh slots and the rest of the index treats
/// them as empty. That presumes this value is all-zero bits (e.g. that
/// `SquashedOffset::TX_STATE` is 0); confirm against `RowPointer`'s representation.
pub(super) const NONE_PTR: RowPointer = RowPointer::new(false, PageIndex(0), PageOffset(0), SquashedOffset::TX_STATE);
/// A slot position within an `InnerIndex`.
///
/// When produced by `split_key`, the position is always `< KEYS_PER_INNER`.
/// `InnerIndex::get`/`get_mut` index without bounds checks, so only construct it via `split_key`.
struct InnerIndexKey(usize);
/// Splits `key` into two parts.
///
/// The first is the position of its inner index within the outer vector; the second is the
/// slot within that inner index, which is always `< KEYS_PER_INNER`.
fn split_key(key: usize) -> (usize, InnerIndexKey) {
    let outer_pos = key / KEYS_PER_INNER;
    let slot = key % KEYS_PER_INNER;
    (outer_pos, InnerIndexKey(slot))
}
impl InnerIndex {
fn new() -> Self {
use std::alloc::{alloc_zeroed, handle_alloc_error, Layout};
let layout = Layout::new::<InnerIndexArray>();
let raw: *mut InnerIndexArray = unsafe { alloc_zeroed(layout) }.cast();
if raw.is_null() {
handle_alloc_error(layout);
}
let inner = unsafe { Box::from_raw(raw) };
Self { inner }
}
fn get(&self, key: InnerIndexKey) -> RowPointer {
*unsafe { self.inner.get_unchecked(key.0) }
}
fn get_mut(&mut self, key: InnerIndexKey) -> &mut RowPointer {
unsafe { self.inner.get_unchecked_mut(key.0) }
}
}
impl UniqueDirectIndex {
    /// Inserts the mapping `key -> val`.
    ///
    /// `outer` is grown to cover `key`, and the inner index holding `key` is allocated on
    /// first use. The length of `outer` therefore tracks the largest key inserted so far.
    ///
    /// # Errors
    ///
    /// If `key` is already present, returns `Err(existing)` with the stored pointer
    /// (reserved bit cleared) and leaves the index unchanged.
    pub fn insert(&mut self, key: usize, val: RowPointer) -> Result<(), RowPointer> {
        let (key_outer, key_inner) = split_key(key);

        let outer = &mut self.outer;
        // Grow (never shrink) `outer` so that `key_outer` is a valid position.
        outer.resize(outer.len().max(key_outer + 1), None);
        // SAFETY: after the `resize` above, `key_outer < outer.len()`.
        let inner = unsafe { outer.get_unchecked_mut(key_outer) };
        let inner = inner.get_or_insert_with(InnerIndex::new);

        let slot = inner.get_mut(key_inner);
        let in_slot = *slot;
        if in_slot == NONE_PTR {
            // Occupied slots are stored with the reserved bit set. This keeps them distinct
            // from `NONE_PTR`, which is constructed with `false` as its first argument.
            *slot = val.with_reserved_bit(true);
            self.len += 1;
            Ok(())
        } else {
            Err(in_slot.with_reserved_bit(false))
        }
    }

    /// Removes `key` from the index, returning whether it was present.
    ///
    /// The inner index is kept allocated even if it becomes empty.
    pub fn delete(&mut self, key: usize) -> bool {
        let (key_outer, key_inner) = split_key(key);
        let Some(Some(inner)) = self.outer.get_mut(key_outer) else {
            // `key` is beyond `outer` or in a never-allocated inner index: nothing to delete.
            return false;
        };
        let old_val = mem::replace(inner.get_mut(key_inner), NONE_PTR);
        let deleted = old_val != NONE_PTR;
        self.len -= deleted as usize;
        deleted
    }

    /// Returns an iterator yielding the pointer stored under `key`, if any.
    pub fn seek_point(&self, key: usize) -> UniqueDirectIndexPointIter {
        let (outer_key, inner_key) = split_key(key);
        let point = self
            .outer
            .get(outer_key)
            .and_then(Option::as_ref)
            .map(|inner| inner.get(inner_key))
            .filter(|slot| *slot != NONE_PTR);
        UniqueDirectIndexPointIter::new(point)
    }

    /// Returns an iterator over the pointers stored under keys in `range`, in ascending key order.
    ///
    /// The range is clamped to the keys the index can currently hold. Extreme bounds such as
    /// `..=usize::MAX` or `(Bound::Excluded(usize::MAX), _)` are handled without overflow.
    pub fn seek_range(&self, range: &impl RangeBounds<usize>) -> UniqueDirectIndexRangeIter<'_> {
        // One past the largest key any allocated-or-reserved inner index can hold.
        let max_key = self.outer.len() * KEYS_PER_INNER;
        let start = match range.start_bound() {
            Bound::Included(&s) => s,
            // Saturate so that `Excluded(usize::MAX)` yields an empty range instead of
            // overflowing (a debug panic, or a wrap to 0 in release that would scan everything).
            Bound::Excluded(&s) => s.saturating_add(1),
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            // Saturate so that `..=usize::MAX` is clamped to `max_key` below instead of
            // overflowing (a debug panic, or a wrap to 0 in release that would scan nothing).
            Bound::Included(&e) => e.saturating_add(1),
            Bound::Excluded(&e) => e,
            Bound::Unbounded => max_key,
        };
        // Establish `start <= end <= max_key`, which the range iterator relies on.
        let end = end.min(max_key);
        let start = start.min(end);
        UniqueDirectIndexRangeIter {
            outer: &self.outer,
            start,
            end,
        }
    }

    /// Returns the number of keys present in the index (same as [`Self::len`]).
    pub fn num_keys(&self) -> usize {
        self.len
    }

    /// Returns the number of keys present in the index.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns whether the index holds no keys.
    // May be unused within the crate. It presumably exists to pair with `len`
    // (clippy `len_without_is_empty`), hence the `allow`.
    #[allow(unused)]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every key and drops all inner indices.
    pub fn clear(&mut self) {
        self.outer.clear();
        self.len = 0;
    }

    /// Checks whether `other` could be merged into `self` without a key collision.
    ///
    /// A collision is a key present in both indices. Collisions for which `ignore` returns
    /// `true` (given `self`'s pointer) are tolerated.
    ///
    /// # Errors
    ///
    /// Returns `Err(ptr)` with `self`'s pointer (reserved bit cleared) for the first
    /// non-ignored collision, in ascending key order.
    pub(crate) fn can_merge(&self, other: &Self, ignore: impl Fn(&RowPointer) -> bool) -> Result<(), RowPointer> {
        // Positions beyond the shorter `outer` cannot collide, so `zip` is sufficient.
        for (inner_s, inner_o) in self.outer.iter().zip(&other.outer) {
            let (Some(inner_s), Some(inner_o)) = (inner_s, inner_o) else {
                continue;
            };
            for (slot_s, slot_o) in inner_s.inner.iter().zip(inner_o.inner.iter()) {
                let ptr_s = slot_s.with_reserved_bit(false);
                if *slot_s != NONE_PTR && *slot_o != NONE_PTR && !ignore(&ptr_s) {
                    return Err(ptr_s);
                }
            }
        }
        Ok(())
    }
}
/// Iterator over the zero or one pointer found by [`UniqueDirectIndex::seek_point`].
///
/// It derives `Debug` like its sibling `UniqueDirectIndexRangeIter`. `RowPointer: Debug`
/// already holds, because `InnerIndex` derives `Debug` over `RowPointer`s.
#[derive(Debug)]
pub struct UniqueDirectIndexPointIter {
    /// The pointer to yield, already stripped of its reserved bit.
    iter: IntoIter<RowPointer>,
}
impl UniqueDirectIndexPointIter {
    /// Creates an iterator yielding `point`, if present, with its reserved bit cleared.
    ///
    /// Values stored in the index carry the reserved bit as an internal marker. It is
    /// cleared here so callers never observe it.
    pub(super) fn new(point: Option<RowPointer>) -> Self {
        let unmarked = point.map(|row_ptr| row_ptr.with_reserved_bit(false));
        Self {
            iter: unmarked.into_iter(),
        }
    }
}
impl Iterator for UniqueDirectIndexPointIter {
    type Item = RowPointer;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    /// Forwards the exact `(n, Some(n))` hint (`n` is 0 or 1) of the underlying `Option`
    /// iterator. The default `(0, None)` would hide it from consumers such as `collect`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// Iterator over the pointers stored under a contiguous range of keys, in ascending key order.
///
/// Built by [`UniqueDirectIndex::seek_range`], which establishes
/// `start <= end <= outer.len() * KEYS_PER_INNER`.
#[derive(Debug)]
pub struct UniqueDirectIndexRangeIter<'a> {
    /// The outer index being scanned.
    outer: &'a [Option<InnerIndex>],
    /// The next key to examine.
    start: usize,
    /// One past the last key to examine.
    end: usize,
}
impl Iterator for UniqueDirectIndexRangeIter<'_> {
    type Item = RowPointer;

    fn next(&mut self) -> Option<Self::Item> {
        while self.start < self.end {
            let (outer_key, inner_key) = split_key(self.start);
            // SAFETY: `seek_range` clamps `end` to `outer.len() * KEYS_PER_INNER`, and
            // `start < end` here, so `outer_key = start / KEYS_PER_INNER < outer.len()`.
            let inner = unsafe { self.outer.get_unchecked(outer_key) };
            let Some(inner) = inner else {
                // This inner index was never allocated, so none of its keys are present.
                // Jump to the first key of the *next* inner index. `start` may lie anywhere
                // inside the current one, so adding `KEYS_PER_INNER` would overshoot and
                // skip keys at the beginning of the next inner index.
                self.start = (outer_key + 1) * KEYS_PER_INNER;
                continue;
            };
            let ptr = inner.get(inner_key);
            self.start += 1;
            if ptr != NONE_PTR {
                return Some(ptr.with_reserved_bit(false));
            }
        }
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Each remaining key yields at most one pointer. Skipping an unallocated inner index
        // may move `start` past `end`, hence `saturating_sub`.
        (0, Some(self.end.saturating_sub(self.start)))
    }
}
#[cfg(test)]
pub(super) mod test {
    use super::*;
    use core::iter::repeat_with;
    use spacetimedb_sats::layout::Size;

    /// Size of each fake fixed-size row laid out by `gen_row_pointers`.
    const FIXED_ROW_SIZE: Size = Size(4 * 4);

    /// Produces an endless stream of committed-state row pointers.
    ///
    /// Each pointer advances the page offset by `FIXED_ROW_SIZE`. Once the advanced offset
    /// would reach `PageOffset::PAGE_END`, the generator moves to offset 0 of the next page.
    pub(crate) fn gen_row_pointers() -> impl Iterator<Item = RowPointer> {
        let mut page_index = PageIndex(0);
        let mut page_offset = PageOffset(0);
        repeat_with(move || {
            let next_offset = page_offset.0 as usize + FIXED_ROW_SIZE.0 as usize;
            if next_offset >= PageOffset::PAGE_END.0 as usize {
                page_offset = PageOffset(0);
                page_index.0 += 1;
            } else {
                page_offset += FIXED_ROW_SIZE;
            }
            RowPointer::new(false, page_index, page_offset, SquashedOffset::COMMITTED_STATE)
        })
    }

    /// Builds an index holding the four keys `KEYS_PER_INNER - 2 .. KEYS_PER_INNER + 2`.
    ///
    /// These keys straddle the first two inner indices and are mapped to fresh row pointers.
    /// Returns the keys, their pointers (in key order), and the populated index.
    fn straddling_fixture() -> (Vec<usize>, Vec<RowPointer>, UniqueDirectIndex) {
        let keys: Vec<usize> = ((KEYS_PER_INNER - 2)..(KEYS_PER_INNER + 2)).collect();
        let ptrs: Vec<RowPointer> = gen_row_pointers().take(keys.len()).collect();
        let mut index = UniqueDirectIndex::default();
        for (&key, &ptr) in keys.iter().zip(&ptrs) {
            index.insert(key, ptr).unwrap();
        }
        (keys, ptrs, index)
    }

    #[test]
    fn seek_range_gives_back_inserted() {
        let (keys, ptrs, index) = straddling_fixture();
        assert_eq!(index.len(), 4);
        let range = keys[0]..keys[keys.len() - 1] + 1;
        let ptrs_found: Vec<_> = index.seek_range(&range).collect();
        assert_eq!(ptrs, ptrs_found);
    }

    #[test]
    fn inserting_again_errors() {
        let (keys, ptrs, mut index) = straddling_fixture();
        for (&key, &ptr) in keys.iter().zip(&ptrs) {
            let existing = index.insert(key, ptr).unwrap_err();
            assert_eq!(existing, ptr);
        }
    }

    #[test]
    fn deleting_allows_reinsertion() {
        let (_, _, mut index) = straddling_fixture();
        assert_eq!(index.len(), 4);

        let key = KEYS_PER_INNER + 1;
        let ptr = index.seek_point(key).next().unwrap();

        // The first delete removes the key; the second finds nothing to remove.
        assert!(index.delete(key));
        assert!(!index.delete(key));
        assert_eq!(index.len(), 3);

        index.insert(key, ptr).unwrap();
        assert_eq!(index.len(), 4);
    }
}