use std::hash::Hash;
use std::ops::Bound;
use std::sync::Arc;
use armour_rpc::protocol::UpsertKey;
use crate::hook::TypedWriteHook;
use crate::zero_tree::{from_value_bytes, to_bytes};
use crate::{Codec, DbError, DbResult, Key, TypedMap, TypedTree, ZeroMap, ZeroTree};
/// Raw key bytes as they cross the RPC boundary; handlers require them to be exactly
/// `size_of::<K>()` long for the collection's key type `K`.
pub type KeyBytes = Vec<u8>;
/// Raw encoded value bytes as they cross the RPC boundary.
pub type ValueBytes = Vec<u8>;
/// Type-erased, byte-level interface the RPC layer uses to drive one named collection.
///
/// Keys and values cross this boundary as raw bytes; each implementation converts them
/// to its own typed representation and back.
pub trait RpcHandler: Send + Sync {
    /// Name of the collection this handler serves.
    fn name(&self) -> &str;

    /// `(type hash, schema version)` of the stored value type.
    fn info(&self) -> (u64, u16);

    /// Returns the encoded value stored under `key`, or `None` if it is absent.
    fn get(&self, key: &[u8]) -> DbResult<Option<ValueBytes>>;

    /// Returns the encoded length of the value stored under `key`, or `None` if absent.
    fn contains(&self, key: &[u8]) -> DbResult<Option<u32>>;

    /// Returns the first entry in key order (ordered collections only).
    fn first(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>>;

    /// Returns the last entry in key order (ordered collections only).
    fn last(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>>;

    /// Returns every entry whose key lies within `start..end` (ordered collections only).
    fn range(
        &self,
        start: Bound<KeyBytes>,
        end: Bound<KeyBytes>,
    ) -> DbResult<Vec<(KeyBytes, ValueBytes)>>;

    /// Like [`RpcHandler::range`], but returns only the keys.
    ///
    /// The element type is `KeyBytes`, which is the same type as `ValueBytes`, so
    /// implementations written against either alias remain compatible.
    fn range_keys(&self, start: Bound<KeyBytes>, end: Bound<KeyBytes>)
        -> DbResult<Vec<KeyBytes>>;

    /// Writes `value` and returns the key bytes it was stored under.
    ///
    /// `key` is either caller-provided or drawn from the collection's sequence
    /// generator. `flag` selects the write mode:
    /// * `None` — insert or overwrite;
    /// * `Some(true)` — update only, fails with `KeyNotFound` if the key is absent;
    /// * `Some(false)` — insert only, fails with `KeyExists` if the key is present.
    fn upsert(&self, key: UpsertKey, flag: Option<bool>, value: ValueBytes)
        -> DbResult<ValueBytes>;

    /// Deletes the entry stored under `key`.
    fn remove(&self, key: &[u8]) -> DbResult<()>;

    /// Number of entries in the collection.
    fn count(&self) -> DbResult<u64>;

    /// Applies `(key, Some(value))` items as writes and `(key, None)` items as
    /// deletes, in batch order.
    fn apply_batch(&self, items: Vec<(KeyBytes, Option<ValueBytes>)>) -> DbResult<()>;
}
/// Decodes `key` as a `K`, accepting it only when it is exactly `size_of::<K>()` bytes.
///
/// # Errors
/// Returns `DbError::KeyNotFound` when the byte width does not match `K`.
fn check_key_len<K: Key>(key: &[u8]) -> DbResult<K> {
    let width_ok = key.len() == size_of::<K>();
    width_ok
        .then(|| K::from_bytes(key))
        .ok_or(DbError::KeyNotFound)
}
/// Turns an [`UpsertKey`] into a typed key plus its raw bytes.
///
/// Caller-provided bytes are used as-is; `UpsertKey::Sequence` draws the next id for
/// collection `name` from `seq` and encodes it little-endian.
///
/// # Errors
/// Propagates any error from `seq.next_id`, and returns
/// `DbError::Config("invalid key bytes")` when the resulting bytes are not exactly
/// `size_of::<K>()` long.
fn resolve_key<K: Key>(
    key: UpsertKey,
    seq: &super::seq::SeqGen,
    name: &str,
) -> DbResult<(K, KeyBytes)> {
    let raw = match key {
        UpsertKey::Provided(bytes) => bytes,
        UpsertKey::Sequence => seq.next_id(name)?.to_le_bytes().to_vec(),
    };
    if raw.len() == size_of::<K>() {
        Ok((K::from_bytes(&raw), raw))
    } else {
        Err(DbError::Config("invalid key bytes"))
    }
}
/// Enforces the upsert write mode against whether the key already exists.
///
/// `flag`: `None` allows anything, `Some(true)` requires the key to exist (update
/// only), `Some(false)` requires it to be absent (insert only).
///
/// # Errors
/// `DbError::KeyNotFound` for an update of a missing key, `DbError::KeyExists` for an
/// insert of an existing key.
fn check_flag(tree_contains: bool, flag: Option<bool>) -> DbResult<()> {
    match (flag, tree_contains) {
        (Some(true), false) => Err(DbError::KeyNotFound),
        (Some(false), true) => Err(DbError::KeyExists),
        _ => Ok(()),
    }
}
/// Converts a byte-level range bound into a typed `Bound<K>`.
///
/// # Errors
/// Returns `DbError::Config("invalid bound key bytes")` when an included or excluded
/// endpoint is not exactly `size_of::<K>()` bytes long.
fn bound_to_key_bound<K: Key>(bound: &Bound<KeyBytes>) -> DbResult<Bound<K>> {
    // Shared width check + decode for both finite bound kinds.
    let decode = |raw: &KeyBytes| -> DbResult<K> {
        if raw.len() == size_of::<K>() {
            Ok(K::from_bytes(raw))
        } else {
            Err(DbError::Config("invalid bound key bytes"))
        }
    };
    Ok(match bound {
        Bound::Included(raw) => Bound::Included(decode(raw)?),
        Bound::Excluded(raw) => Bound::Excluded(decode(raw)?),
        Bound::Unbounded => Bound::Unbounded,
    })
}
/// Error returned by map-backed handlers for ordered operations (`first`, `last`,
/// `range`, `range_keys`) that a hash map cannot serve.
///
/// `DbError::Config` carries a `&'static str`, so each known operation maps to a string
/// literal. This avoids leaking a heap allocation (via `Box::leak`) on every rejected
/// RPC call, which a client could otherwise trigger without bound.
///
/// # Errors
/// Always returns `Err(DbError::Config(..))`; never returns `Ok`.
fn unsupported(op: &str) -> DbResult<()> {
    let msg: &'static str = match op {
        "first" => "first not supported for map collections",
        "last" => "last not supported for map collections",
        "range" => "range not supported for map collections",
        "range_keys" => "range_keys not supported for map collections",
        _ => "operation not supported for map collections",
    };
    Err(DbError::Config(msg))
}
/// RPC adapter exposing a [`TypedTree`] whose values are (de)serialized with codec `C`.
pub struct TypedTreeHandler<K, T, C, H: TypedWriteHook<K, T> = crate::NoHook>
where
    K: Key + Ord,
    T: Send + Sync,
    C: Codec<T>,
{
    /// Collection name; also the name passed to `seq` when allocating sequence keys.
    pub name: String,
    /// Type hash reported (with `version`) by `RpcHandler::info`.
    pub typ_hash: u64,
    /// Schema version reported (with `typ_hash`) by `RpcHandler::info`.
    pub version: u16,
    /// Backing tree the handler reads from and writes to.
    pub tree: Arc<TypedTree<K, T, C, H>>,
    /// Codec used to encode values for the wire and decode incoming value bytes.
    pub codec: Arc<C>,
    /// Sequence generator used for `UpsertKey::Sequence` inserts.
    pub seq: Arc<super::seq::SeqGen>,
}
/// Byte-level RPC access to a [`TypedTree`]: keys are fixed-width `K` encodings and
/// values go through `self.codec`.
impl<K, T, C, H> RpcHandler for TypedTreeHandler<K, T, C, H>
where
    K: Key + Ord + Send + Sync,
    T: Send + Sync,
    C: Codec<T>,
    H: TypedWriteHook<K, T>,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn info(&self) -> (u64, u16) {
        (self.typ_hash, self.version)
    }

    fn get(&self, key: &[u8]) -> DbResult<Option<ValueBytes>> {
        let key: K = check_key_len(key)?;
        match self.tree.get(&key) {
            Some(val) => {
                let mut buf = Vec::new();
                self.codec.encode_to(&val, &mut buf);
                Ok(Some(buf))
            }
            None => Ok(None),
        }
    }

    fn contains(&self, key: &[u8]) -> DbResult<Option<u32>> {
        let key: K = check_key_len(key)?;
        match self.tree.get(&key) {
            Some(val) => {
                // The value is encoded only to report its encoded length.
                let mut buf = Vec::new();
                self.codec.encode_to(&val, &mut buf);
                // NOTE(review): `as u32` silently truncates lengths above u32::MAX.
                Ok(Some(buf.len() as u32))
            }
            None => Ok(None),
        }
    }

    fn first(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        match self.tree.first() {
            Some((k, v)) => {
                let mut buf = Vec::new();
                self.codec.encode_to(&v, &mut buf);
                Ok(Some((k.as_bytes().to_vec(), buf)))
            }
            None => Ok(None),
        }
    }

    fn last(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        match self.tree.last() {
            Some((k, v)) => {
                let mut buf = Vec::new();
                self.codec.encode_to(&v, &mut buf);
                Ok(Some((k.as_bytes().to_vec(), buf)))
            }
            None => Ok(None),
        }
    }

    fn range(
        &self,
        start: Bound<KeyBytes>,
        end: Bound<KeyBytes>,
    ) -> DbResult<Vec<(KeyBytes, ValueBytes)>> {
        let sb = bound_to_key_bound::<K>(&start)?;
        let eb = bound_to_key_bound::<K>(&end)?;
        let iter = self.tree.range_bounds(sb.as_ref(), eb.as_ref());
        let mut result = Vec::new();
        for (k, v) in iter {
            let mut buf = Vec::new();
            self.codec.encode_to(v, &mut buf);
            result.push((k.as_bytes().to_vec(), buf));
        }
        Ok(result)
    }

    fn range_keys(
        &self,
        start: Bound<KeyBytes>,
        end: Bound<KeyBytes>,
    ) -> DbResult<Vec<ValueBytes>> {
        let sb = bound_to_key_bound::<K>(&start)?;
        let eb = bound_to_key_bound::<K>(&end)?;
        let iter = self.tree.range_bounds(sb.as_ref(), eb.as_ref());
        let mut result = Vec::new();
        for (k, _) in iter {
            result.push(k.as_bytes().to_vec());
        }
        Ok(result)
    }

    fn upsert(
        &self,
        key: UpsertKey,
        flag: Option<bool>,
        value: ValueBytes,
    ) -> DbResult<ValueBytes> {
        // Decode first so a bad payload is rejected before a sequence id is drawn.
        let typed_value = self.codec.decode_from(&value)?;
        // NOTE(review): for `UpsertKey::Sequence` the id is drawn before `check_flag`
        // runs, so a rejected write still consumes an id.
        let (typed_key, key_bytes) = resolve_key::<K>(key, &self.seq, &self.name)?;
        // NOTE(review): the existence check and the write are separate calls; confirm
        // whether concurrent writers need them to be atomic.
        check_flag(self.tree.contains(&typed_key), flag)?;
        self.tree.put(&typed_key, typed_value)?;
        Ok(key_bytes)
    }

    fn remove(&self, key: &[u8]) -> DbResult<()> {
        let key: K = check_key_len(key)?;
        self.tree.delete(&key)?;
        Ok(())
    }

    fn count(&self) -> DbResult<u64> {
        // Walks every entry.
        Ok(self.tree.iter().count() as u64)
    }

    fn apply_batch(&self, items: Vec<(KeyBytes, Option<ValueBytes>)>) -> DbResult<()> {
        // Phase 1: decode every key and value before writing anything. A malformed
        // entry anywhere in the batch then rejects the whole batch, instead of leaving
        // the tree with only a prefix of it applied. Per item, the key is checked
        // before the value.
        let mut decoded = Vec::with_capacity(items.len());
        for (key, value) in items {
            let typed_key: K = check_key_len(&key)?;
            let typed_value = match value {
                Some(v) => Some(self.codec.decode_from(&v)?),
                None => None,
            };
            decoded.push((typed_key, typed_value));
        }
        // Phase 2: apply writes and deletes in batch order. A storage error here can
        // still leave the batch partially applied.
        for (typed_key, typed_value) in decoded {
            match typed_value {
                Some(value) => {
                    self.tree.put(&typed_key, value)?;
                }
                None => {
                    self.tree.delete(&typed_key)?;
                }
            }
        }
        Ok(())
    }
}
/// RPC adapter exposing a [`TypedMap`] whose values are (de)serialized with codec `C`.
///
/// Its `RpcHandler` impl rejects the ordered operations (`first`, `last`, `range`,
/// `range_keys`).
pub struct TypedMapHandler<K, T, C, H: TypedWriteHook<K, T> = crate::NoHook>
where
    K: Key + Send + Sync + Hash + Eq,
    T: Send + Sync,
    C: Codec<T>,
{
    /// Collection name; also the name passed to `seq` when allocating sequence keys.
    pub name: String,
    /// Type hash reported (with `version`) by `RpcHandler::info`.
    pub typ_hash: u64,
    /// Schema version reported (with `typ_hash`) by `RpcHandler::info`.
    pub version: u16,
    /// Backing map the handler reads from and writes to.
    pub map: Arc<TypedMap<K, T, C, H>>,
    /// Codec used to encode values for the wire and decode incoming value bytes.
    pub codec: Arc<C>,
    /// Sequence generator used for `UpsertKey::Sequence` inserts.
    pub seq: Arc<super::seq::SeqGen>,
}
/// Byte-level RPC access to a [`TypedMap`]. Ordered operations are rejected via
/// `unsupported`.
impl<K, T, C, H> RpcHandler for TypedMapHandler<K, T, C, H>
where
    K: Key + Send + Sync + Hash + Eq,
    T: Send + Sync,
    C: Codec<T>,
    H: TypedWriteHook<K, T>,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn info(&self) -> (u64, u16) {
        (self.typ_hash, self.version)
    }

    fn get(&self, key: &[u8]) -> DbResult<Option<ValueBytes>> {
        let key: K = check_key_len(key)?;
        match self.map.get(&key) {
            Some(val) => {
                let mut buf = Vec::new();
                self.codec.encode_to(&val, &mut buf);
                Ok(Some(buf))
            }
            None => Ok(None),
        }
    }

    fn contains(&self, key: &[u8]) -> DbResult<Option<u32>> {
        let key: K = check_key_len(key)?;
        match self.map.get(&key) {
            Some(val) => {
                // The value is encoded only to report its encoded length.
                let mut buf = Vec::new();
                self.codec.encode_to(&val, &mut buf);
                // NOTE(review): `as u32` silently truncates lengths above u32::MAX.
                Ok(Some(buf.len() as u32))
            }
            None => Ok(None),
        }
    }

    fn first(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        unsupported("first")?;
        unreachable!()
    }

    fn last(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        unsupported("last")?;
        unreachable!()
    }

    fn range(
        &self,
        _start: Bound<KeyBytes>,
        _end: Bound<KeyBytes>,
    ) -> DbResult<Vec<(KeyBytes, ValueBytes)>> {
        unsupported("range")?;
        unreachable!()
    }

    fn range_keys(
        &self,
        _start: Bound<KeyBytes>,
        _end: Bound<KeyBytes>,
    ) -> DbResult<Vec<ValueBytes>> {
        unsupported("range_keys")?;
        unreachable!()
    }

    fn upsert(
        &self,
        key: UpsertKey,
        flag: Option<bool>,
        value: ValueBytes,
    ) -> DbResult<ValueBytes> {
        // Decode first so a bad payload is rejected before a sequence id is drawn.
        let typed_value = self.codec.decode_from(&value)?;
        // NOTE(review): for `UpsertKey::Sequence` the id is drawn before `check_flag`
        // runs, so a rejected write still consumes an id.
        let (typed_key, key_bytes) = resolve_key::<K>(key, &self.seq, &self.name)?;
        // NOTE(review): the existence check and the write are separate calls; confirm
        // whether concurrent writers need them to be atomic.
        check_flag(self.map.contains(&typed_key), flag)?;
        self.map.put(&typed_key, typed_value)?;
        Ok(key_bytes)
    }

    fn remove(&self, key: &[u8]) -> DbResult<()> {
        let key: K = check_key_len(key)?;
        self.map.delete(&key)?;
        Ok(())
    }

    fn count(&self) -> DbResult<u64> {
        Ok(self.map.len() as u64)
    }

    fn apply_batch(&self, items: Vec<(KeyBytes, Option<ValueBytes>)>) -> DbResult<()> {
        // Phase 1: decode every key and value before writing anything. A malformed
        // entry anywhere in the batch then rejects the whole batch, instead of leaving
        // the map with only a prefix of it applied. Per item, the key is checked
        // before the value.
        let mut decoded = Vec::with_capacity(items.len());
        for (key, value) in items {
            let typed_key: K = check_key_len(&key)?;
            let typed_value = match value {
                Some(v) => Some(self.codec.decode_from(&v)?),
                None => None,
            };
            decoded.push((typed_key, typed_value));
        }
        // Phase 2: apply writes and deletes in batch order. A storage error here can
        // still leave the batch partially applied.
        for (typed_key, typed_value) in decoded {
            match typed_value {
                Some(value) => {
                    self.map.put(&typed_key, value)?;
                }
                None => {
                    self.map.delete(&typed_key)?;
                }
            }
        }
        Ok(())
    }
}
/// RPC adapter exposing a [`ZeroTree`] (with the `Bitcask` durability backend) whose
/// `Copy` values travel over the wire as exactly `V` bytes.
pub struct ZeroTreeHandler<K, const V: usize, T, H: TypedWriteHook<K, T> = crate::NoHook>
where
    K: Key + Ord,
    T: Copy + Send + Sync,
{
    /// Collection name; also the name passed to `seq` when allocating sequence keys.
    pub name: String,
    /// Type hash reported (with `version`) by `RpcHandler::info`.
    pub typ_hash: u64,
    /// Schema version reported (with `typ_hash`) by `RpcHandler::info`.
    pub version: u16,
    /// Backing tree the handler reads from and writes to.
    pub tree: Arc<ZeroTree<K, V, T, H, crate::durability::Bitcask>>,
    /// Sequence generator used for `UpsertKey::Sequence` inserts.
    pub seq: Arc<super::seq::SeqGen>,
}
/// Byte-level RPC access to a [`ZeroTree`]: values are converted with
/// `to_bytes` / `from_value_bytes` and must be exactly `V` bytes on the wire.
impl<K, const V: usize, T, H> RpcHandler for ZeroTreeHandler<K, V, T, H>
where
    K: Key + Ord + Send + Sync,
    T: Copy + Send + Sync,
    H: TypedWriteHook<K, T>,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn info(&self) -> (u64, u16) {
        (self.typ_hash, self.version)
    }

    fn get(&self, key: &[u8]) -> DbResult<Option<ValueBytes>> {
        let key: K = check_key_len(key)?;
        match self.tree.get(&key) {
            Some(val) => Ok(Some(to_bytes::<V, T>(&val).to_vec())),
            None => Ok(None),
        }
    }

    fn contains(&self, key: &[u8]) -> DbResult<Option<u32>> {
        let key: K = check_key_len(key)?;
        // Every value has the fixed wire width `V`, so no encoding is needed.
        if self.tree.contains(&key) {
            Ok(Some(V as u32))
        } else {
            Ok(None)
        }
    }

    fn first(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        match self.tree.first() {
            Some((k, v)) => Ok(Some((k.as_bytes().to_vec(), to_bytes::<V, T>(&v).to_vec()))),
            None => Ok(None),
        }
    }

    fn last(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        match self.tree.last() {
            Some((k, v)) => Ok(Some((k.as_bytes().to_vec(), to_bytes::<V, T>(&v).to_vec()))),
            None => Ok(None),
        }
    }

    fn range(
        &self,
        start: Bound<KeyBytes>,
        end: Bound<KeyBytes>,
    ) -> DbResult<Vec<(KeyBytes, ValueBytes)>> {
        let sb = bound_to_key_bound::<K>(&start)?;
        let eb = bound_to_key_bound::<K>(&end)?;
        Ok(self
            .tree
            .range_bounds(sb.as_ref(), eb.as_ref())
            .map(|(k, v)| (k.as_bytes().to_vec(), to_bytes::<V, T>(&v).to_vec()))
            .collect())
    }

    fn range_keys(
        &self,
        start: Bound<KeyBytes>,
        end: Bound<KeyBytes>,
    ) -> DbResult<Vec<ValueBytes>> {
        let sb = bound_to_key_bound::<K>(&start)?;
        let eb = bound_to_key_bound::<K>(&end)?;
        Ok(self
            .tree
            .range_bounds(sb.as_ref(), eb.as_ref())
            .map(|(k, _)| k.as_bytes().to_vec())
            .collect())
    }

    fn upsert(
        &self,
        key: UpsertKey,
        flag: Option<bool>,
        value: ValueBytes,
    ) -> DbResult<ValueBytes> {
        if value.len() != V {
            return Err(DbError::Config("invalid value bytes"));
        }
        // The conversion cannot fail after the width check above.
        let typed_value = from_value_bytes::<V, T>(
            value[..V]
                .try_into()
                .map_err(|_| DbError::CorruptedEntry { offset: 0 })?,
        );
        // NOTE(review): for `UpsertKey::Sequence` the id is drawn before `check_flag`
        // runs, so a rejected write still consumes an id.
        let (typed_key, key_bytes) = resolve_key::<K>(key, &self.seq, &self.name)?;
        check_flag(self.tree.contains(&typed_key), flag)?;
        self.tree.put(&typed_key, &typed_value)?;
        Ok(key_bytes)
    }

    fn remove(&self, key: &[u8]) -> DbResult<()> {
        let key: K = check_key_len(key)?;
        self.tree.delete(&key)?;
        Ok(())
    }

    fn count(&self) -> DbResult<u64> {
        // Walks every entry.
        Ok(self.tree.iter().count() as u64)
    }

    fn apply_batch(&self, items: Vec<(KeyBytes, Option<ValueBytes>)>) -> DbResult<()> {
        // Phase 1: validate every key and value width before writing anything. A
        // malformed entry then rejects the whole batch, instead of leaving the tree
        // with only a prefix of it applied. Per item, the key is checked before the
        // value.
        let mut typed_keys: Vec<K> = Vec::with_capacity(items.len());
        for (key, value) in &items {
            typed_keys.push(check_key_len(key)?);
            if matches!(value, Some(v) if v.len() != V) {
                return Err(DbError::Config("invalid value bytes"));
            }
        }
        // Phase 2: decode and apply in batch order. The value bytes stay borrowed
        // from `items` for the duration of each write. A storage error here can still
        // leave the batch partially applied.
        for (typed_key, (_, value)) in typed_keys.iter().zip(&items) {
            match value {
                Some(v) => {
                    // The conversion cannot fail: the width was validated in phase 1.
                    let typed_value = from_value_bytes::<V, T>(
                        v[..V]
                            .try_into()
                            .map_err(|_| DbError::CorruptedEntry { offset: 0 })?,
                    );
                    self.tree.put(typed_key, &typed_value)?;
                }
                None => {
                    self.tree.delete(typed_key)?;
                }
            }
        }
        Ok(())
    }
}
/// RPC adapter exposing a [`ZeroMap`] (with the `Bitcask` durability backend) whose
/// `Copy` values travel over the wire as exactly `V` bytes.
///
/// Its `RpcHandler` impl rejects the ordered operations (`first`, `last`, `range`,
/// `range_keys`).
pub struct ZeroMapHandler<K, const V: usize, T, H: TypedWriteHook<K, T> = crate::NoHook>
where
    K: Key + Send + Sync + Hash + Eq,
    T: Copy + Send + Sync,
{
    /// Collection name; also the name passed to `seq` when allocating sequence keys.
    pub name: String,
    /// Type hash reported (with `version`) by `RpcHandler::info`.
    pub typ_hash: u64,
    /// Schema version reported (with `typ_hash`) by `RpcHandler::info`.
    pub version: u16,
    /// Backing map the handler reads from and writes to.
    pub map: Arc<ZeroMap<K, V, T, H, crate::durability::Bitcask>>,
    /// Sequence generator used for `UpsertKey::Sequence` inserts.
    pub seq: Arc<super::seq::SeqGen>,
}
/// Byte-level RPC access to a [`ZeroMap`]: values must be exactly `V` bytes on the
/// wire. Ordered operations are rejected via `unsupported`.
impl<K, const V: usize, T, H> RpcHandler for ZeroMapHandler<K, V, T, H>
where
    K: Key + Send + Sync + Hash + Eq,
    T: Copy + Send + Sync,
    H: TypedWriteHook<K, T>,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn info(&self) -> (u64, u16) {
        (self.typ_hash, self.version)
    }

    fn get(&self, key: &[u8]) -> DbResult<Option<ValueBytes>> {
        let key: K = check_key_len(key)?;
        match self.map.get(&key) {
            Some(val) => Ok(Some(to_bytes::<V, T>(&val).to_vec())),
            None => Ok(None),
        }
    }

    fn contains(&self, key: &[u8]) -> DbResult<Option<u32>> {
        let key: K = check_key_len(key)?;
        // Every value has the fixed wire width `V`, so no encoding is needed.
        if self.map.contains(&key) {
            Ok(Some(V as u32))
        } else {
            Ok(None)
        }
    }

    fn first(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        unsupported("first")?;
        unreachable!()
    }

    fn last(&self) -> DbResult<Option<(KeyBytes, ValueBytes)>> {
        unsupported("last")?;
        unreachable!()
    }

    fn range(
        &self,
        _start: Bound<KeyBytes>,
        _end: Bound<KeyBytes>,
    ) -> DbResult<Vec<(KeyBytes, ValueBytes)>> {
        unsupported("range")?;
        unreachable!()
    }

    fn range_keys(
        &self,
        _start: Bound<KeyBytes>,
        _end: Bound<KeyBytes>,
    ) -> DbResult<Vec<ValueBytes>> {
        unsupported("range_keys")?;
        unreachable!()
    }

    fn upsert(
        &self,
        key: UpsertKey,
        flag: Option<bool>,
        value: ValueBytes,
    ) -> DbResult<ValueBytes> {
        if value.len() != V {
            return Err(DbError::Config("invalid value bytes"));
        }
        // The conversion cannot fail after the width check above.
        let typed_value = from_value_bytes::<V, T>(
            value[..V]
                .try_into()
                .map_err(|_| DbError::CorruptedEntry { offset: 0 })?,
        );
        // NOTE(review): for `UpsertKey::Sequence` the id is drawn before `check_flag`
        // runs, so a rejected write still consumes an id.
        let (typed_key, key_bytes) = resolve_key::<K>(key, &self.seq, &self.name)?;
        check_flag(self.map.contains(&typed_key), flag)?;
        self.map.put(&typed_key, &typed_value)?;
        Ok(key_bytes)
    }

    fn remove(&self, key: &[u8]) -> DbResult<()> {
        let key: K = check_key_len(key)?;
        self.map.delete(&key)?;
        Ok(())
    }

    fn count(&self) -> DbResult<u64> {
        Ok(self.map.len() as u64)
    }

    fn apply_batch(&self, items: Vec<(KeyBytes, Option<ValueBytes>)>) -> DbResult<()> {
        // Phase 1: validate every key and value width before writing anything. A
        // malformed entry then rejects the whole batch, instead of leaving the map
        // with only a prefix of it applied. Per item, the key is checked before the
        // value.
        let mut typed_keys: Vec<K> = Vec::with_capacity(items.len());
        for (key, value) in &items {
            typed_keys.push(check_key_len(key)?);
            if matches!(value, Some(v) if v.len() != V) {
                return Err(DbError::Config("invalid value bytes"));
            }
        }
        // Phase 2: decode and apply in batch order. The value bytes stay borrowed
        // from `items` for the duration of each write. A storage error here can still
        // leave the batch partially applied.
        for (typed_key, (_, value)) in typed_keys.iter().zip(&items) {
            match value {
                Some(v) => {
                    // The conversion cannot fail: the width was validated in phase 1.
                    let typed_value = from_value_bytes::<V, T>(
                        v[..V]
                            .try_into()
                            .map_err(|_| DbError::CorruptedEntry { offset: 0 })?,
                    );
                    self.map.put(typed_key, &typed_value)?;
                }
                None => {
                    self.map.delete(typed_key)?;
                }
            }
        }
        Ok(())
    }
}