use std::{
ops::Bound,
ops::RangeBounds,
sync::atomic::AtomicBool,
sync::{Arc, atomic::Ordering},
};
use armour_derive::armour_metrics;
use derive_more::Debug;
use fjall::{Keyspace, UserValue};
use parking_lot::Mutex;
use rayon::iter::IntoParallelIterator;
use xxhash_rust::xxh3::Xxh3Default;
use super::{ByteValue, MaybeParIter, db::Db, events::ChangeEvent};
use crate::{
DbError, DbResult,
logdb::{RawIterTree, raw_filter_map},
types::{attribute::EntityAttribute, num_ops::g4bits},
utils::{CheckSumVec, CollectionInfo, GroupVal, HashPoints},
};
/// State shared by every clone of a [`RawTree`]: collection metadata plus
/// the per-group hash cache used for cheap change detection.
#[derive(Debug)]
pub(crate) struct InnerFields {
    /// Collection metadata (type hash + version); see `CollectionInfo`.
    pub(crate) info: CollectionInfo,
    /// Per-group checksum cache: `invalidate_hash` marks entries stale,
    /// `RawTree::recalcucate_hash` refreshes them lazily.
    pub(crate) hashpoints: HashPoints,
}
impl InnerFields {
    /// Marks the cached checksum for `group_id` as stale so the next
    /// `recalcucate_hash` pass recomputes it. The stored hash is reset to 0
    /// and the `changed` flag raised.
    pub(crate) fn invalidate_hash(&self, group_id: u32) {
        let stale = GroupVal {
            hash: 0,
            changed: true,
        };
        self.hashpoints.insert(group_id, stale);
    }
}
/// Handle to one raw (byte-keyed, byte-valued) collection backed by a fjall
/// keyspace. Cheap to clone; all clones share `inner` and `meta_saved`, and
/// the last clone dropped persists collection metadata (see `Drop`).
#[derive(Clone, Debug)]
pub struct RawTree {
    pub(crate) db: Db,
    /// Runtime collection name (tracing-span field; also part of the
    /// sequence-tree key in `next_id`).
    pub name: String,
    pub partition_name: String,
    // NOTE(review): presumably a hash of the collection name — not used in
    // this chunk; confirm at the construction site.
    pub hashname: u64,
    /// Static schema attributes (name, group_bits, type hash, version).
    pub attributes: &'static EntityAttribute,
    /// Primary key/value storage.
    #[debug(skip)]
    pub(crate) tree: Keyspace,
    // NOTE(review): presumably holds removed entries/tombstones — not touched
    // in this chunk; confirm.
    #[debug(skip)]
    pub(crate) removed: Keyspace,
    /// Shared metadata + per-group hash cache.
    pub(crate) inner: Arc<InnerFields>,
    /// Set once `close()` has persisted metadata; its strong count also
    /// doubles as the live-clone counter consulted in `Drop`.
    pub(crate) meta_saved: Arc<AtomicBool>,
    /// Serializes the read-increment-write cycle in `next_id`.
    #[debug(skip)]
    pub(crate) seq_lock: Arc<Mutex<()>>,
}
impl Drop for RawTree {
    fn drop(&mut self) {
        // `meta_saved` is shared by every clone, so its strong count tracks
        // live handles: only the last one dropped persists metadata.
        if Arc::strong_count(&self.meta_saved) == 1 {
            self.close();
        }
    }
}
impl RawTree {
/// The compile-time collection name from the entity's static attributes —
/// used as the metrics label (`self.name` is the runtime instance name).
pub fn static_name(&self) -> &'static str {
    self.attributes.name
}
#[instrument(skip_all, fields(name = self.name, ret))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn checksum(&self) -> u32 {
let mut hasher = crc32fast::Hasher::new();
let iter = self.tree.iter();
iter.for_each(|item| {
let item = item.into_inner();
match item {
Ok((k, v)) => {
hasher.update(&k);
hasher.update(&v);
}
Err(err) => {
error!(%err);
}
}
});
let checksum = hasher.finalize();
if checksum != 0 {
debug!("checksum: {checksum:#X}");
}
checksum
}
/// True when the tree contains no entries (no first key/value exists).
pub fn is_empty(&self) -> bool {
    matches!(self.tree.first_key_value(), None)
}
/// Approximate number of entries, as estimated by the underlying keyspace.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
pub fn count(&self) -> usize {
    self.tree.approximate_len()
}
/// Snapshot of the per-group checksum cache as `(group, hash)` pairs.
///
/// Stale (`changed`) entries are returned as-is; call `recalcucate_hash`
/// first if fresh values are needed.
#[instrument(skip_all, fields(name = self.name))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn hashpoints(&self) -> CheckSumVec {
    self.inner
        .hashpoints
        .iter()
        .map(|entry| (*entry.key(), entry.value().hash))
        .collect()
}
/// Streams every `(key, value)` pair whose big-endian u32 key prefix falls
/// inside the given group.
///
/// A group spans `2^(32 - group_bits)` consecutive prefix values starting at
/// `group`. The upper bound is exclusive; when it would overflow `u32` (the
/// last group, or `group_bits == 0`) the scan runs to the end of the tree.
///
/// Fix over the previous version: `2u32.pow(u32::BITS - group_bits)` panicked
/// for `group_bits == 0`, and `group + size` overflowed for the last group —
/// both now fall back to an unbounded upper end instead of panicking.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
pub fn scan_group(&self, group: u32) -> RawIterTree {
    counter!("armour_db_rawtree_scan_group_total", "name" => self.static_name()).increment(1);
    let group_bits = self.attributes.group_bits;
    // 2^(32 - group_bits); None when the shift is the full width (group_bits == 0).
    let group_size = 1u32.checked_shl(u32::BITS - group_bits);
    let start = Bound::Included(group.to_be_bytes());
    let end = match group_size.and_then(|size| group.checked_add(size)) {
        Some(end) => Bound::Excluded(end.to_be_bytes()),
        // Overflow: this group reaches the end of the key space.
        None => Bound::Unbounded,
    };
    self.tree.range((start, end)).filter_map(raw_filter_map)
}
/// Recomputes and returns the combined hash over all groups.
///
/// Groups flagged `changed` are re-hashed by scanning their key range (XXH3
/// over every key and value) and the cache entry is refreshed with
/// `changed: false`; unchanged groups reuse their cached hash. All per-group
/// hashes are then folded (as little-endian bytes) into one XXH3 digest.
///
/// NOTE(review): the fold visits groups in the map's iteration order —
/// presumably `HashPoints` iterates deterministically; confirm, otherwise the
/// combined digest is order-dependent.
/// NOTE(review): the name is misspelled ("recalculate"), but renaming would
/// break callers.
#[instrument(skip_all, fields(name = self.name, ret))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn recalcucate_hash(&self) -> u64 {
    let hash = self
        .inner
        .hashpoints
        .iter()
        .map(|item| {
            let group_val = item.value();
            if group_val.changed {
                let group = *item.key();
                // Release the map entry before the insert below — presumably
                // to avoid a shard-lock deadlock in a DashMap-style map;
                // TODO confirm against `HashPoints`.
                drop(item);
                let mut hash_val = Xxh3Default::new();
                for (key, value) in self.scan_group(group) {
                    hash_val.update(&key);
                    hash_val.update(&value);
                }
                let hash = hash_val.digest();
                // Store the fresh hash and clear the stale flag.
                self.inner.hashpoints.insert(
                    group,
                    GroupVal {
                        hash,
                        changed: false,
                    },
                );
                hash
            } else {
                item.value().hash
            }
        })
        .fold(Xxh3Default::new(), |mut hasher, item| {
            hasher.update(&item.to_le_bytes());
            hasher
        });
    hash.digest()
}
/// Persists collection metadata (type hash + version) into the database's
/// info store. Idempotent: the `meta_saved` flag ensures only the first call
/// does the work; later calls just log a warning.
#[instrument(skip_all, fields(name = self.name))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn close(&self) {
    // swap() both reads and sets the flag atomically, so concurrent callers
    // race safely: exactly one observes `false` and proceeds.
    if self.meta_saved.swap(true, Ordering::AcqRel) {
        warn!("tree already closed");
        return;
    }
    let count = Arc::strong_count(&self.meta_saved);
    if count != 1 {
        // Closing while other clones are still alive is suspicious.
        error!(count, "strong refs");
    }
    let info = CollectionInfo {
        typ_hash: self.attributes.ty.h(),
        version: self.attributes.version,
    };
    self.db.db_info.update(|db_info| {
        db_info.collections.insert(self.name.clone(), info);
    });
}
/// Looks up `id` and, on a hit, returns the key and value concatenated into
/// one buffer (`id` bytes followed by the stored value bytes).
///
/// Returns `Ok(None)` when the key is absent.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn get(&self, id: &[u8]) -> DbResult<Option<Vec<u8>>> {
    let found = self.tree.get(id).map_err(DbError::from)?;
    Ok(found.map(|value| {
        let mut joined = Vec::with_capacity(id.len() + value.len());
        joined.extend_from_slice(id);
        joined.extend_from_slice(&value);
        joined
    }))
}
/// Streams every `(key, value)` pair in the tree, logging and skipping
/// entries that fail to load.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
pub fn iter(&self) -> super::RawIterTree {
    counter!("armour_db_rawtree_range_total", "name" => self.static_name()).increment(1);
    self.tree.iter().filter_map(|item| match item.into_inner() {
        Err(e) => {
            error!(%e);
            None
        }
        Ok(pair) => Some(pair),
    })
}
/// Streams `(key, value)` pairs whose keys fall inside `range`, logging and
/// skipping entries that fail to load.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
pub fn range<K: AsRef<[u8]>, R: RangeBounds<K> + std::fmt::Debug>(
    &self,
    range: R,
) -> super::RawIterTree {
    counter!("armour_db_rawtree_range_total", "name" => self.static_name()).increment(1);
    self.tree.range(range).filter_map(|item| match item.into_inner() {
        Err(e) => {
            error!(%e);
            None
        }
        Ok(pair) => Some(pair),
    })
}
/// Streams `(key, value)` pairs whose keys start with `prefix`, logging and
/// skipping entries that fail to load.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
pub fn prefix<K: AsRef<[u8]> + std::fmt::Debug>(&self, prefix: K) -> super::RawIterTree {
    counter!("armour_db_rawtree_range_total", "name" => self.static_name()).increment(1);
    self.tree.prefix(prefix).filter_map(|item| match item.into_inner() {
        Err(e) => {
            error!(%e);
            None
        }
        Ok(pair) => Some(pair),
    })
}
/// Marks the hash-cache group owning `id` as stale.
///
/// The group is derived from the key's first 4 bytes (big-endian u32),
/// masked down to `group_bits` via `g4bits`. Panics if `id` is shorter than
/// 4 bytes, as before.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
pub(crate) fn invalidate_hash(&self, id: &[u8]) {
    let raw = {
        let mut prefix = [0u8; 4];
        prefix.copy_from_slice(&id[..4]);
        u32::from_be_bytes(prefix)
    };
    let group = g4bits(raw, self.attributes.group_bits);
    self.inner.invalidate_hash(group);
}
/// Applies a single change event (upsert or delete) to the tree, then marks
/// the affected group's cached hash as stale.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn apply_event(&self, event: ChangeEvent) -> DbResult<()> {
    match &event {
        ChangeEvent::Delete(key) => {
            self.tree.remove(key.clone())?;
        }
        ChangeEvent::Upsert((key, value)) => {
            self.tree.insert(key.clone(), value.clone())?;
        }
    }
    // Invalidate only after the write succeeded, as the original did.
    self.invalidate_hash(event.key());
    Ok(())
}
/// Applies a batch of changes atomically: `Some(val)` means upsert, `None`
/// means delete. Each key's hash-cache group is invalidated as it is queued;
/// the whole batch is committed in one transaction at the end.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn apply_batch<Val>(
    &self,
    iter: impl Iterator<Item = (ByteValue, Option<Val>)>,
) -> DbResult<()>
where
    Val: Into<UserValue>,
{
    let mut batch = self.db.batch();
    for (key, maybe_val) in iter {
        self.invalidate_hash(&key);
        if let Some(val) = maybe_val {
            batch.insert(&self.tree, key, val);
        } else {
            batch.remove(&self.tree, key);
        }
    }
    batch.commit()?;
    Ok(())
}
/// Plans a possibly-parallel scan of the tree.
///
/// Returns `Empty` for an empty tree, `Seq` when the estimated element count
/// (`seq`, defaulting to `count()`) or the group spread is too small to be
/// worth parallelism, and otherwise `Par` with a rayon iterator of key-range
/// chunks covering `[first_group, last_group]`.
///
/// Fix over the previous version: the chunking loop computed `start + step`
/// unchecked, which panics (debug) or wraps (release) when `last_group` is
/// near `u32::MAX`; the final chunk now falls back to an unbounded upper end.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
pub fn par_iter(&self, seq: Option<usize>) -> DbResult<MaybeParIter> {
    counter!("armour_db_rawtree_par_iter_total", "name" => self.static_name()).increment(1);
    if self.tree.is_empty()? {
        return Ok(MaybeParIter::Empty);
    }
    let seq = seq.unwrap_or_else(|| self.count());
    if seq < MIN_PAR_SIZE {
        return Ok(MaybeParIter::Seq);
    }
    // One chunk per MIN_PAR_SIZE elements, capped at the CPU count.
    let chunk_count = seq / MIN_PAR_SIZE;
    let cpus = num_cpus::get();
    let workers_count = chunk_count.min(cpus);
    if workers_count < 2 {
        return Ok(MaybeParIter::Seq);
    }
    let first = self.tree.first_key_value().ok_or(DbError::Empty)?;
    let first_key = first.key()?;
    let last = self.tree.last_key_value().ok_or(DbError::Empty)?;
    let last_key = last.key()?;
    // Keys are prefixed with a big-endian u32 from which the group is derived.
    let first = &first_key.as_ref()[..4];
    let first = u32::from_be_bytes(first.try_into().expect("Invalid byte array length"));
    let first_group = g4bits(first, self.attributes.group_bits);
    let last = &last_key.as_ref()[..4];
    let last = u32::from_be_bytes(last.try_into().expect("Invalid byte array length"));
    let last_group = g4bits(last, self.attributes.group_bits);
    // Keys are sorted, so last_group >= first_group and this cannot underflow.
    let diff = last_group - first_group;
    if diff < 2 {
        return Ok(MaybeParIter::Seq);
    }
    let step = diff / (workers_count as u32);
    if step == 0 {
        return Ok(MaybeParIter::Seq);
    }
    // Split [first_group, last_group] into half-open chunks of `step` groups.
    let mut arr = Vec::with_capacity(workers_count);
    let mut start = first_group;
    while start <= last_group {
        let start_bound = Bound::Included(start.to_be_bytes());
        match start.checked_add(step) {
            Some(end) => {
                arr.push((start_bound, Bound::Excluded(end.to_be_bytes())));
                start = end;
            }
            None => {
                // `start + step` would overflow u32: the final chunk scans to
                // the end of the key space.
                arr.push((start_bound, Bound::Unbounded));
                break;
            }
        }
    }
    info!(workers_count, step, first_group, last_group, "par_iter");
    let res = arr.into_par_iter();
    Ok(MaybeParIter::Par(res))
}
/// Allocates the next id for this collection from the shared sequence tree.
///
/// Returns the current counter value (starting at 0) and persists the
/// incremented value; `seq_lock` serializes the read-increment-write cycle
/// so concurrent callers never receive the same id.
#[instrument(level = "debug", skip_all, fields(name = self.name))]
#[armour_metrics(prefix = "armour_db_rawtree", name = self.static_name())]
pub fn next_id(&self) -> DbResult<u64> {
    const NEXT_ID: &str = "next_id";
    let name = format!("__{}-{}", NEXT_ID, self.name);
    let _guard = self.seq_lock.lock();
    let current = match self.db.seq_tree.get(&name)? {
        Some(raw) => {
            // Stored as 8 little-endian bytes; anything else is corruption.
            let bytes = raw.as_ref().try_into().expect("Invalid byte array length");
            u64::from_le_bytes(bytes)
        }
        None => 0,
    };
    self.db.seq_tree.insert(&name, (current + 1).to_le_bytes())?;
    Ok(current)
}
/// Escape hatch: direct access to the underlying keyspace. Writes through
/// this handle bypass hash-cache invalidation and metrics.
#[doc(hidden)]
pub fn inner(&self) -> &Keyspace {
    &self.tree
}
}
/// Minimum approximate entry count per worker before `RawTree::par_iter`
/// splits a scan into parallel chunks: 2^16 = 65,536.
pub const MIN_PAR_SIZE: usize = 2usize.pow(16);