use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use zerompk::{FromMessagePack, ToMessagePack};
use crate::error::ColumnarError;
/// Position of a single row: which segment holds it and where in that segment it sits.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToMessagePack, FromMessagePack,
)]
pub struct RowLocation {
/// Identifier of the segment that contains the row.
pub segment_id: u32,
/// Index of the row inside that segment.
pub row_index: u32,
}
/// Primary-key index: maps encoded primary-key bytes to the [`RowLocation`] of the
/// row that owns the key.
#[derive(Debug, Clone)]
pub struct PkIndex {
/// Encoded key bytes -> row location; the `BTreeMap` keeps entries in byte-wise key order.
inner: BTreeMap<Vec<u8>, RowLocation>,
}
impl PkIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            inner: BTreeMap::new(),
        }
    }

    /// Returns the number of primary keys currently indexed.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the index holds no keys.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Registers `pk_bytes` as pointing at `location`.
    ///
    /// # Errors
    ///
    /// Returns [`ColumnarError::DuplicatePrimaryKey`] if the key is already present;
    /// the existing entry is left untouched in that case.
    pub fn insert(
        &mut self,
        pk_bytes: Vec<u8>,
        location: RowLocation,
    ) -> Result<(), ColumnarError> {
        // The entry API resolves "present or not" and the insertion slot in a single
        // tree descent instead of a `contains_key` lookup followed by `insert`.
        match self.inner.entry(pk_bytes) {
            std::collections::btree_map::Entry::Occupied(_) => {
                Err(ColumnarError::DuplicatePrimaryKey)
            }
            std::collections::btree_map::Entry::Vacant(slot) => {
                slot.insert(location);
                Ok(())
            }
        }
    }

    /// Inserts `pk_bytes`, overwriting any location previously stored for it.
    pub fn upsert(&mut self, pk_bytes: Vec<u8>, location: RowLocation) {
        self.inner.insert(pk_bytes, location);
    }

    /// Looks up the location stored for `pk_bytes`, if any.
    pub fn get(&self, pk_bytes: &[u8]) -> Option<&RowLocation> {
        self.inner.get(pk_bytes)
    }

    /// Removes `pk_bytes` and returns the location it pointed at, if it was present.
    pub fn remove(&mut self, pk_bytes: &[u8]) -> Option<RowLocation> {
        self.inner.remove(pk_bytes)
    }

    /// Returns `true` when `pk_bytes` is present in the index.
    pub fn contains(&self, pk_bytes: &[u8]) -> bool {
        self.inner.contains_key(pk_bytes)
    }

    /// Rewrites every entry that currently points into `old_segment_id`.
    ///
    /// `remap_fn` is called with the entry's old row index. Returning
    /// `Some(location)` replaces the stored location; returning `None` drops the
    /// key from the index. Entries in other segments are not visited by `remap_fn`
    /// and are left unchanged.
    pub fn remap_segment(
        &mut self,
        old_segment_id: u32,
        remap_fn: impl Fn(u32) -> Option<RowLocation>,
    ) {
        // Keys never change here, only values, so the map can be updated in place in
        // one pass: no key clones and no remove/re-insert round trip per entry.
        self.inner.retain(|_, loc| {
            if loc.segment_id != old_segment_id {
                return true;
            }
            match remap_fn(loc.row_index) {
                Some(new_loc) => {
                    *loc = new_loc;
                    true
                }
                None => false,
            }
        });
    }

    /// Indexes a whole segment: the key at position `i` of `pk_bytes_list` is mapped
    /// to row `i` of `segment_id`.
    ///
    /// The operation is all-or-nothing: if any key is rejected, the rows added by
    /// this call are removed again before the error is returned.
    ///
    /// # Errors
    ///
    /// Returns [`ColumnarError::DuplicatePrimaryKey`] if a key already exists in the
    /// index or occurs more than once in `pk_bytes_list`.
    pub fn bulk_insert(
        &mut self,
        segment_id: u32,
        pk_bytes_list: &[Vec<u8>],
    ) -> Result<(), ColumnarError> {
        for (row_index, pk_bytes) in pk_bytes_list.iter().enumerate() {
            let location = RowLocation {
                segment_id,
                // NOTE: row indexes are stored as `u32`; positions beyond `u32::MAX`
                // would be truncated by this cast.
                row_index: row_index as u32,
            };
            if let Err(err) = self.insert(pk_bytes.clone(), location) {
                // Every key before `row_index` was inserted by this call (otherwise
                // the loop would have stopped earlier), so removing them restores
                // the index to its state before the call.
                for inserted in &pk_bytes_list[..row_index] {
                    self.inner.remove(inserted.as_slice());
                }
                return Err(err);
            }
        }
        Ok(())
    }

    /// Drops every entry whose location points into `segment_id`.
    pub fn remove_segment(&mut self, segment_id: u32) {
        self.inner.retain(|_, loc| loc.segment_id != segment_id);
    }

    /// Serializes the index with `zerompk::to_msgpack_vec` as a list of
    /// `(key, location)` pairs in key order.
    ///
    /// # Errors
    ///
    /// Returns [`ColumnarError::Serialization`] carrying the encoder's error text.
    pub fn to_bytes(&self) -> Result<Vec<u8>, ColumnarError> {
        let entries: Vec<(&Vec<u8>, &RowLocation)> = self.inner.iter().collect();
        zerompk::to_msgpack_vec(&entries).map_err(|e| ColumnarError::Serialization(e.to_string()))
    }

    /// Rebuilds an index from bytes produced by [`PkIndex::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`ColumnarError::Serialization`] carrying the decoder's error text.
    pub fn from_bytes(data: &[u8]) -> Result<Self, ColumnarError> {
        let entries: Vec<(Vec<u8>, RowLocation)> =
            zerompk::from_msgpack(data).map_err(|e| ColumnarError::Serialization(e.to_string()))?;
        Ok(Self {
            inner: entries.into_iter().collect(),
        })
    }
}
impl Default for PkIndex {
fn default() -> Self {
Self::new()
}
}
/// Encodes a single primary-key value as the byte string used as an index key.
///
/// * `Integer` and `DateTime` (via its `micros` field) are cast to `u64`, have their
///   sign bit flipped and are written big-endian, so byte-wise comparison of the
///   encoded keys follows signed numeric order.
/// * `String` and `Uuid` contribute their `as_bytes()` content unchanged.
/// * `Decimal` contributes the bytes returned by its `serialize()` method.
/// * Any other variant falls back to the bytes of its `Debug` representation.
pub fn encode_pk(value: &nodedb_types::value::Value) -> Vec<u8> {
    use nodedb_types::value::Value;

    // Toggling this bit maps the signed range onto the unsigned range in order.
    const SIGN_BIT: u64 = 1u64 << 63;

    // Big-endian bytes of `raw` with its sign bit flipped.
    fn order_preserving_be(raw: u64) -> Vec<u8> {
        Vec::from((raw ^ SIGN_BIT).to_be_bytes())
    }

    match value {
        Value::Integer(int) => order_preserving_be(*int as u64),
        Value::DateTime(stamp) => order_preserving_be(stamp.micros as u64),
        Value::String(text) => text.as_bytes().to_vec(),
        Value::Uuid(id) => id.as_bytes().to_vec(),
        Value::Decimal(dec) => dec.serialize().to_vec(),
        other => format!("{other:?}").into_bytes(),
    }
}
/// Encodes a multi-column primary key by concatenating the [`encode_pk`] bytes of
/// each value, with a single `0xFF` byte placed between consecutive components.
///
/// An empty `values` slice yields an empty key. Components are joined as-is: no
/// length prefix is written and `0xFF` bytes inside a component are not escaped.
pub fn encode_composite_pk(values: &[&nodedb_types::value::Value]) -> Vec<u8> {
    const SEPARATOR: u8 = 0xFF;

    let mut components = values.iter();
    let mut composite = Vec::new();

    // The first component goes in bare; every later one is preceded by the separator.
    if let Some(first) = components.next() {
        composite.extend(encode_pk(first));
        for component in components {
            composite.push(SEPARATOR);
            composite.extend(encode_pk(component));
        }
    }
    composite
}
#[cfg(test)]
mod tests {
    use nodedb_types::value::Value;

    use super::*;

    /// Encodes every integer in `range` as a primary key, in order.
    fn int_keys(range: std::ops::Range<i64>) -> Vec<Vec<u8>> {
        range.map(|n| encode_pk(&Value::Integer(n))).collect()
    }

    /// Shorthand for building a `RowLocation`.
    fn at(segment_id: u32, row_index: u32) -> RowLocation {
        RowLocation {
            segment_id,
            row_index,
        }
    }

    /// A freshly inserted key can be looked up and is counted.
    #[test]
    fn insert_and_lookup() {
        let key = encode_pk(&Value::Integer(42));
        let location = at(0, 5);

        let mut index = PkIndex::new();
        index.insert(key.clone(), location).expect("insert");

        assert_eq!(index.get(&key), Some(&location));
        assert_eq!(index.len(), 1);
    }

    /// Inserting the same key twice fails with `DuplicatePrimaryKey`.
    #[test]
    fn duplicate_pk_rejected() {
        let key = encode_pk(&Value::Integer(1));
        let location = at(0, 0);

        let mut index = PkIndex::new();
        index.insert(key.clone(), location).expect("first insert");

        let second_attempt = index.insert(key, location);
        assert!(matches!(
            second_attempt,
            Err(ColumnarError::DuplicatePrimaryKey)
        ));
    }

    /// Removing a key returns its location and leaves the index empty.
    #[test]
    fn remove_entry() {
        let key = encode_pk(&Value::Integer(1));
        let location = at(0, 0);

        let mut index = PkIndex::new();
        index.insert(key.clone(), location).expect("insert");

        assert_eq!(index.remove(&key), Some(location));
        assert!(index.is_empty());
    }

    /// Bulk insertion assigns row indexes by position in the key list.
    #[test]
    fn bulk_insert() {
        let keys = int_keys(0..10);

        let mut index = PkIndex::new();
        index.bulk_insert(0, &keys).expect("bulk insert");
        assert_eq!(index.len(), 10);

        let found = index.get(&keys[5]).expect("lookup");
        assert_eq!(found.segment_id, 0);
        assert_eq!(found.row_index, 5);
    }

    /// Remapping moves surviving rows to the new segment and drops rows mapped to `None`.
    #[test]
    fn remap_segment() {
        let keys = int_keys(0..5);

        let mut index = PkIndex::new();
        index.bulk_insert(0, &keys).expect("bulk insert");

        // Row 2 is dropped; every other row moves to segment 1, shifted by 10.
        index.remap_segment(0, |old_row| match old_row {
            2 => None,
            kept => Some(at(1, kept + 10)),
        });

        assert_eq!(index.len(), 4);
        let moved = index.get(&keys[0]).expect("row 0");
        assert_eq!(moved.segment_id, 1);
        assert_eq!(moved.row_index, 10);
        assert!(index.get(&keys[2]).is_none());
    }

    /// Removing a segment discards only that segment's entries.
    #[test]
    fn remove_segment() {
        let first_segment = int_keys(0..5);
        let second_segment = int_keys(10..15);

        let mut index = PkIndex::new();
        index.bulk_insert(0, &first_segment).expect("seg 0");
        index.bulk_insert(1, &second_segment).expect("seg 1");
        assert_eq!(index.len(), 10);

        index.remove_segment(0);
        assert_eq!(index.len(), 5);
    }

    /// An index survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn serialization_roundtrip() {
        let keys = int_keys(0..100);

        let mut index = PkIndex::new();
        index.bulk_insert(0, &keys).expect("bulk insert");

        let encoded = index.to_bytes().expect("serialize");
        let decoded = PkIndex::from_bytes(&encoded).expect("deserialize");
        assert_eq!(decoded.len(), 100);

        let found = decoded.get(&keys[50]).expect("lookup");
        assert_eq!(found.segment_id, 0);
        assert_eq!(found.row_index, 50);
    }

    /// Encoded integers compare byte-wise in the same order as the integers themselves.
    #[test]
    fn int_sort_order() {
        let values = [-100i64, -1, 0, 1, 100];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|v| encode_pk(&Value::Integer(*v)))
            .collect();

        for (keys, ints) in encoded.windows(2).zip(values.windows(2)) {
            assert!(
                keys[0] < keys[1],
                "sort order broken: {} < {}",
                ints[0],
                ints[1]
            );
        }
    }

    /// Composite keys order by the first component, then by the second.
    #[test]
    fn composite_pk() {
        let one_a = encode_composite_pk(&[&Value::Integer(1), &Value::String("a".into())]);
        let one_b = encode_composite_pk(&[&Value::Integer(1), &Value::String("b".into())]);
        let two_a = encode_composite_pk(&[&Value::Integer(2), &Value::String("a".into())]);

        assert!(one_a < one_b);
        assert!(one_b < two_a);
    }
}