use std::collections::BTreeSet;
use std::fmt;
use std::os::raw::c_void;
use incrementalmerkletree::{Address, Level};
use shardtree::{
store::{Checkpoint, ShardStore},
LocatedPrunableTree, LocatedTree, PrunableTree, Tree,
};
use crate::hash::{MerkleHashVote, SHARD_HEIGHT};
use crate::serde::{read_checkpoint, read_shard_vote, write_checkpoint, write_shard_vote};
/// Failure modes surfaced by the callback-backed key-value store.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum KvError {
    /// A KV callback reported a failure status.
    IoError,
    /// Bytes read from the store could not be decoded.
    Deserialization,
    /// A value could not be encoded for storage.
    Serialization,
}
impl fmt::Display for KvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
KvError::IoError => write!(f, "KV callback returned an error"),
KvError::Deserialization => write!(f, "failed to deserialize KV data"),
KvError::Serialization => write!(f, "failed to serialize data for KV"),
}
}
}
// Marker impl so `KvError` can be used wherever a `std::error::Error` is
// expected; it relies on the `Debug` derive and the `Display` impl.
impl std::error::Error for KvError {}
/// First byte of every shard key; `shard_key` follows it with the shard index.
const SHARD_PREFIX: u8 = 0x0F;
/// The single byte that forms the key returned by `cap_key`.
const CAP_KEY: u8 = 0x10;
/// First byte of every checkpoint key; `checkpoint_key` follows it with the id.
const CHECKPOINT_PREFIX: u8 = 0x11;
/// Builds the 9-byte storage key for a shard: `SHARD_PREFIX` followed by
/// `index` in big-endian byte order.
fn shard_key(index: u64) -> [u8; 9] {
    let mut key = [0u8; 9];
    let (tag, body) = key.split_at_mut(1);
    tag[0] = SHARD_PREFIX;
    body.copy_from_slice(&index.to_be_bytes());
    key
}
/// Returns the one-byte key (`CAP_KEY`) under which the cap is stored.
fn cap_key() -> [u8; 1] {
    [CAP_KEY]
}
/// Builds the 5-byte storage key for a checkpoint: `CHECKPOINT_PREFIX`
/// followed by `id` in big-endian byte order.
fn checkpoint_key(id: u32) -> [u8; 5] {
    let mut key = [0u8; 5];
    let (tag, body) = key.split_at_mut(1);
    tag[0] = CHECKPOINT_PREFIX;
    body.copy_from_slice(&id.to_be_bytes());
    key
}
/// C callback that looks up `key` (pointer + length).
///
/// As interpreted by `KvCallbacks::get`: a return of `0` means a value was
/// written to `*out_val` / `*out_val_len` (the buffer is later handed to the
/// `free_buf` callback), `1` means the key is absent, and any other value is
/// reported as `KvError::IoError`.
pub type KvGetFn = unsafe extern "C" fn(
ctx: *mut c_void,
key: *const u8,
key_len: usize,
out_val: *mut *mut u8,
out_val_len: *mut usize,
) -> i32;
/// C callback that stores `val` under `key`.
///
/// `KvCallbacks::set` treats `0` as success and any other return as
/// `KvError::IoError`.
pub type KvSetFn = unsafe extern "C" fn(
ctx: *mut c_void,
key: *const u8,
key_len: usize,
val: *const u8,
val_len: usize,
) -> i32;
/// C callback that deletes `key`.
///
/// `KvCallbacks::delete` treats `0` as success and any other return as
/// `KvError::IoError`.
pub type KvDeleteFn = unsafe extern "C" fn(ctx: *mut c_void, key: *const u8, key_len: usize) -> i32;
/// C callback that creates an iterator over keys starting with `prefix`.
///
/// The Rust side passes `reverse` as a `bool` cast to `u8` (0 or 1). A null
/// return is tolerated: the resulting iterator simply yields nothing.
pub type KvIterCreateFn = unsafe extern "C" fn(
ctx: *mut c_void,
prefix: *const u8,
prefix_len: usize,
reverse: u8,
) -> *mut c_void;
/// C callback that advances an iterator created by `KvIterCreateFn`.
///
/// The Rust iterator treats `0` as "an entry was produced" and reads the key
/// and value buffers (each later released through the `free_buf` callback);
/// any non-zero return ends iteration.
pub type KvIterNextFn = unsafe extern "C" fn(
iter: *mut c_void,
out_key: *mut *mut u8,
out_key_len: *mut usize,
out_val: *mut *mut u8,
out_val_len: *mut usize,
) -> i32;
/// C callback that releases an iterator handle; invoked when the Rust
/// iterator wrapper is dropped with a non-null handle.
pub type KvIterFreeFn = unsafe extern "C" fn(iter: *mut c_void);
/// C callback that releases a buffer previously returned through an
/// out-pointer; it receives the same pointer and length that were reported.
pub type KvFreeBufFn = unsafe extern "C" fn(ptr: *mut u8, len: usize);
/// Table of host-provided C callbacks plus the opaque context pointer that is
/// passed back to `get`, `set`, `delete` and `iter_create` on every call.
#[derive(Clone, Copy)]
pub struct KvCallbacks {
/// Opaque host context; never dereferenced on the Rust side in this file.
pub ctx: *mut c_void,
/// Point lookup.
pub get: KvGetFn,
/// Insert or overwrite.
pub set: KvSetFn,
/// Delete a key.
pub delete: KvDeleteFn,
/// Create a prefix iterator.
pub iter_create: KvIterCreateFn,
/// Advance a prefix iterator.
pub iter_next: KvIterNextFn,
/// Release a prefix iterator.
pub iter_free: KvIterFreeFn,
/// Release a buffer returned by `get` or `iter_next`.
pub free_buf: KvFreeBufFn,
}
// SAFETY / NOTE(review): `KvCallbacks` holds a raw `ctx` pointer, so it is not
// `Send`/`Sync` automatically. Nothing in this file enforces thread-safety;
// these impls presumably rely on the host guaranteeing that `ctx` and every
// callback may be used from any thread — confirm against the FFI caller.
unsafe impl Send for KvCallbacks {}
unsafe impl Sync for KvCallbacks {}
impl KvCallbacks {
    /// Looks up `key` through the host `get` callback.
    ///
    /// Returns `Ok(Some(bytes))` when the callback returns `0`, `Ok(None)`
    /// when it returns `1`, and `Err(KvError::IoError)` for any other status.
    ///
    /// A successful call that leaves the output pointer null is accepted only
    /// for a zero-length value (returned as an empty `Vec`); a null pointer
    /// with a non-zero length is reported as `KvError::IoError`.
    pub fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, KvError> {
        let mut out_ptr: *mut u8 = std::ptr::null_mut();
        let mut out_len: usize = 0;
        // SAFETY: `key` is a live slice for the duration of the call and the
        // out-pointers refer to locals; everything else relies on the
        // host-provided callback honouring its contract.
        let rc = unsafe {
            (self.get)(
                self.ctx,
                key.as_ptr(),
                key.len(),
                &mut out_ptr,
                &mut out_len,
            )
        };
        match rc {
            0 => {
                if out_ptr.is_null() {
                    // `slice::from_raw_parts` requires a non-null pointer even
                    // for length 0, so a null buffer must never reach it.
                    return if out_len == 0 {
                        Ok(Some(Vec::new()))
                    } else {
                        Err(KvError::IoError)
                    };
                }
                // SAFETY: the pointer is non-null and the callback reported
                // `out_len` readable bytes behind it.
                let val = unsafe { std::slice::from_raw_parts(out_ptr, out_len) }.to_vec();
                // SAFETY: the buffer came from the host and is released with
                // the host's own deallocator exactly once, after copying.
                unsafe { (self.free_buf)(out_ptr, out_len) };
                Ok(Some(val))
            }
            1 => Ok(None),
            _ => Err(KvError::IoError),
        }
    }

    /// Stores `val` under `key`; any non-zero callback status becomes
    /// `KvError::IoError`.
    pub fn set(&self, key: &[u8], val: &[u8]) -> Result<(), KvError> {
        // SAFETY: both slices are live for the duration of the call.
        let rc = unsafe { (self.set)(self.ctx, key.as_ptr(), key.len(), val.as_ptr(), val.len()) };
        if rc != 0 {
            Err(KvError::IoError)
        } else {
            Ok(())
        }
    }

    /// Deletes `key`; any non-zero callback status becomes
    /// `KvError::IoError`.
    pub fn delete(&self, key: &[u8]) -> Result<(), KvError> {
        // SAFETY: `key` is a live slice for the duration of the call.
        let rc = unsafe { (self.delete)(self.ctx, key.as_ptr(), key.len()) };
        if rc != 0 {
            Err(KvError::IoError)
        } else {
            Ok(())
        }
    }

    /// Opens a host iterator over keys starting with `prefix`, in reverse
    /// order when `reverse` is true.
    ///
    /// A null handle from the host is not an error here: the returned
    /// iterator then yields nothing.
    fn iter(&self, prefix: &[u8], reverse: bool) -> KvIter<'_> {
        // SAFETY: `prefix` is a live slice for the duration of the call.
        let handle =
            unsafe { (self.iter_create)(self.ctx, prefix.as_ptr(), prefix.len(), reverse as u8) };
        KvIter { handle, cb: self }
    }
}
/// Rust-side wrapper around a host iterator handle, borrowing the callback
/// table that created it.
struct KvIter<'a> {
/// Opaque handle returned by `iter_create`; may be null.
handle: *mut c_void,
/// Callback table used to advance and free the handle.
cb: &'a KvCallbacks,
}
impl<'a> KvIter<'a> {
    /// Advances the host iterator and returns the next `(key, value)` pair as
    /// owned byte vectors.
    ///
    /// Returns `None` when the handle is null or when the `iter_next`
    /// callback returns a non-zero status. This signature cannot tell
    /// exhaustion apart from a callback failure.
    fn next(&mut self) -> Option<(Vec<u8>, Vec<u8>)> {
        if self.handle.is_null() {
            return None;
        }
        let mut key_ptr: *mut u8 = std::ptr::null_mut();
        let mut key_len: usize = 0;
        let mut val_ptr: *mut u8 = std::ptr::null_mut();
        let mut val_len: usize = 0;
        // SAFETY: the handle is non-null and the out-pointers refer to locals;
        // everything else relies on the host callback honouring its contract.
        let rc = unsafe {
            (self.cb.iter_next)(
                self.handle,
                &mut key_ptr,
                &mut key_len,
                &mut val_ptr,
                &mut val_len,
            )
        };
        if rc != 0 {
            return None;
        }
        // Same order as before: copy and free the key, then the value.
        let key = self.take_buf(key_ptr, key_len);
        let val = self.take_buf(val_ptr, val_len);
        Some((key, val))
    }

    /// Copies a host-allocated buffer into a `Vec` and releases it through
    /// the `free_buf` callback.
    ///
    /// A null pointer (e.g. an empty value) yields an empty `Vec` and is not
    /// passed to `free_buf`; `slice::from_raw_parts` must never see a null
    /// pointer, even with length 0.
    fn take_buf(&self, ptr: *mut u8, len: usize) -> Vec<u8> {
        if ptr.is_null() {
            return Vec::new();
        }
        // SAFETY: the pointer is non-null and the callback reported `len`
        // readable bytes behind it.
        let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }.to_vec();
        // SAFETY: the buffer is released exactly once, after it was copied.
        unsafe { (self.cb.free_buf)(ptr, len) };
        bytes
    }
}
impl<'a> Drop for KvIter<'a> {
    /// Releases the host iterator through `iter_free`, unless the handle is
    /// null (in which case there is nothing to release).
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        unsafe { (self.cb.iter_free)(self.handle) };
    }
}
/// `ShardStore` implementation that keeps shards, the cap and checkpoints in
/// a key-value store reached through `KvCallbacks`.
pub struct KvShardStore {
/// Callback table used for every read and write.
pub(crate) cb: KvCallbacks,
}
impl KvShardStore {
pub fn new(cb: KvCallbacks) -> Self {
Self { cb }
}
}
impl ShardStore for KvShardStore {
type H = MerkleHashVote;
type CheckpointId = u32;
type Error = KvError;
fn get_shard(
&self,
shard_root: Address,
) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
let idx = shard_root.index();
let key = shard_key(idx);
let Some(blob) = self.cb.get(&key)? else {
return Ok(None);
};
match read_shard_vote(&blob) {
Ok(tree) => Ok(LocatedTree::from_parts(shard_root, tree).ok()),
Err(_) => Err(KvError::Deserialization),
}
}
fn last_shard(&self) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
let prefix = [SHARD_PREFIX];
let mut iter = self.cb.iter(&prefix, true );
let Some((key, val)) = iter.next() else {
return Ok(None);
};
if key.len() < 9 {
return Ok(None);
}
let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
let level = Level::from(SHARD_HEIGHT);
let addr = Address::from_parts(level, idx);
match read_shard_vote(&val) {
Ok(tree) => Ok(LocatedTree::from_parts(addr, tree).ok()),
Err(_) => Err(KvError::Deserialization),
}
}
fn put_shard(&mut self, subtree: LocatedPrunableTree<MerkleHashVote>) -> Result<(), KvError> {
let idx = subtree.root_addr().index();
let key = shard_key(idx);
let blob = write_shard_vote(subtree.root()).map_err(|_| KvError::Serialization)?;
self.cb.set(&key, &blob)
}
fn get_shard_roots(&self) -> Result<Vec<Address>, KvError> {
let prefix = [SHARD_PREFIX];
let mut iter = self.cb.iter(&prefix, false);
let level = Level::from(SHARD_HEIGHT);
let mut roots = Vec::new();
while let Some((key, _)) = iter.next() {
if key.len() < 9 {
continue;
}
let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
roots.push(Address::from_parts(level, idx));
}
Ok(roots)
}
fn truncate_shards(&mut self, shard_index: u64) -> Result<(), KvError> {
let prefix = [SHARD_PREFIX];
let mut iter = self.cb.iter(&prefix, false);
let mut to_delete = Vec::new();
while let Some((key, _)) = iter.next() {
if key.len() < 9 {
continue;
}
let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
if idx >= shard_index {
to_delete.push(key);
}
}
drop(iter);
for key in to_delete {
self.cb.delete(&key)?;
}
Ok(())
}
fn get_cap(&self) -> Result<PrunableTree<MerkleHashVote>, KvError> {
let key = cap_key();
let Some(blob) = self.cb.get(&key)? else {
return Ok(Tree::empty());
};
read_shard_vote(&blob).map_err(|_| KvError::Deserialization)
}
fn put_cap(&mut self, cap: PrunableTree<MerkleHashVote>) -> Result<(), KvError> {
let key = cap_key();
let blob = write_shard_vote(&cap).map_err(|_| KvError::Serialization)?;
self.cb.set(&key, &blob)
}
fn min_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
let prefix = [CHECKPOINT_PREFIX];
let mut iter = self.cb.iter(&prefix, false);
Ok(iter.next().and_then(|(k, _)| {
if k.len() >= 5 {
Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
} else {
None
}
}))
}
fn max_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
let prefix = [CHECKPOINT_PREFIX];
let mut iter = self.cb.iter(&prefix, true );
Ok(iter.next().and_then(|(k, _)| {
if k.len() >= 5 {
Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
} else {
None
}
}))
}
fn add_checkpoint(
&mut self,
checkpoint_id: u32,
checkpoint: Checkpoint,
) -> Result<(), KvError> {
let key = checkpoint_key(checkpoint_id);
let blob = write_checkpoint(&checkpoint);
self.cb.set(&key, &blob)
}
fn checkpoint_count(&self) -> Result<usize, KvError> {
let prefix = [CHECKPOINT_PREFIX];
let mut iter = self.cb.iter(&prefix, false);
let mut count = 0usize;
while iter.next().is_some() {
count += 1;
}
Ok(count)
}
fn get_checkpoint_at_depth(
&self,
checkpoint_depth: usize,
) -> Result<Option<(u32, Checkpoint)>, KvError> {
let prefix = [CHECKPOINT_PREFIX];
let mut iter = self.cb.iter(&prefix, true );
let mut seen = 0usize;
while let Some((key, val)) = iter.next() {
if seen == checkpoint_depth {
if key.len() < 5 {
return Ok(None);
}
let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
return Ok(read_checkpoint(&val).ok().map(|cp| (id, cp)));
}
seen += 1;
}
Ok(None)
}
fn get_checkpoint(&self, checkpoint_id: &u32) -> Result<Option<Checkpoint>, KvError> {
let key = checkpoint_key(*checkpoint_id);
let Some(blob) = self.cb.get(&key)? else {
return Ok(None);
};
Ok(read_checkpoint(&blob).ok())
}
fn with_checkpoints<F>(&mut self, limit: usize, mut callback: F) -> Result<(), KvError>
where
F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
{
let prefix = [CHECKPOINT_PREFIX];
let mut iter = self.cb.iter(&prefix, false);
let mut count = 0usize;
while count < limit {
let Some((key, val)) = iter.next() else {
break;
};
if key.len() < 5 {
continue;
}
let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
if let Ok(cp) = read_checkpoint(&val) {
callback(&id, &cp)?;
}
count += 1;
}
Ok(())
}
fn for_each_checkpoint<F>(&self, limit: usize, mut callback: F) -> Result<(), KvError>
where
F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
{
let prefix = [CHECKPOINT_PREFIX];
let mut iter = self.cb.iter(&prefix, false);
let mut count = 0usize;
while count < limit {
let Some((key, val)) = iter.next() else {
break;
};
if key.len() < 5 {
continue;
}
let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
if let Ok(cp) = read_checkpoint(&val) {
callback(&id, &cp)?;
}
count += 1;
}
Ok(())
}
fn update_checkpoint_with<F>(&mut self, checkpoint_id: &u32, update: F) -> Result<bool, KvError>
where
F: Fn(&mut Checkpoint) -> Result<(), KvError>,
{
let key = checkpoint_key(*checkpoint_id);
let Some(blob) = self.cb.get(&key)? else {
return Ok(false);
};
let Ok(mut cp) = read_checkpoint(&blob) else {
return Ok(false);
};
update(&mut cp)?;
let new_blob = write_checkpoint(&cp);
self.cb.set(&key, &new_blob)?;
Ok(true)
}
fn remove_checkpoint(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
let key = checkpoint_key(*checkpoint_id);
self.cb.delete(&key)
}
fn truncate_checkpoints_retaining(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
let prefix = [CHECKPOINT_PREFIX];
let mut iter = self.cb.iter(&prefix, false);
let mut to_delete = Vec::new();
while let Some((key, _)) = iter.next() {
if key.len() < 5 {
continue;
}
let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
if id < *checkpoint_id {
to_delete.push(key);
} else {
break;
}
}
drop(iter);
for key in to_delete {
self.cb.delete(&key)?;
}
let retain_key = checkpoint_key(*checkpoint_id);
if let Some(blob) = self.cb.get(&retain_key)? {
if let Ok(cp) = read_checkpoint(&blob) {
let cleared = Checkpoint::from_parts(cp.tree_state(), BTreeSet::new());
self.cb.set(&retain_key, &write_checkpoint(&cleared))?;
}
}
Ok(())
}
}