use core::hash::Hash;
use std::collections::{HashSet};
use serde::{Serialize};
use bincode::Options;
use rocksdb::{DB, DBWithThreadMode, ColumnFamily, ColumnFamilyDescriptor, MergeOperands};
use super::bincode_helpers::{*};
use super::records::{*};
use super::key_groups::{*};
use super::perf_counters::{*};
/// Column family mapping a key group id (little-endian bytes) to its serialized key list.
pub const KEYS_CF_NAME: &str = "keys";
/// Column family mapping a record id (little-endian bytes) to its serialized `RecordData`.
pub const RECORD_DATA_CF_NAME: &str = "rec_data";
/// Column family mapping a record id (little-endian bytes) to its serialized value.
pub const VALUES_CF_NAME: &str = "values";
/// Column family mapping variant bytes to the serialized list of key groups containing them.
pub const VARIANTS_CF_NAME: &str = "variants";
/// Owner of a single-threaded RocksDB handle opened with the `keys`, `rec_data`, `values`
/// and `variants` column families.
pub struct DBConnection {
    /// The open RocksDB database.
    db: DBWithThreadMode<rocksdb::SingleThreaded>,
    /// Filesystem path the database was opened at.
    path: String,
}
impl DBConnection {
pub fn new(path : &str) -> Result<Self, String> {
let keys_cf = ColumnFamilyDescriptor::new(KEYS_CF_NAME, rocksdb::Options::default());
let rec_data_cf = ColumnFamilyDescriptor::new(RECORD_DATA_CF_NAME, rocksdb::Options::default());
let values_cf = ColumnFamilyDescriptor::new(VALUES_CF_NAME, rocksdb::Options::default());
let mut variants_opts = rocksdb::Options::default();
variants_opts.create_if_missing(true);
variants_opts.set_merge_operator_associative("append to RecordID vec", variant_append_merge);
let variants_cf = ColumnFamilyDescriptor::new(VARIANTS_CF_NAME, variants_opts);
let mut db_opts = rocksdb::Options::default();
db_opts.create_missing_column_families(true);
db_opts.create_if_missing(true);
let db = DB::open_cf_descriptors(&db_opts, path, vec![keys_cf, rec_data_cf, values_cf, variants_cf])?;
Ok(Self{
db,
path : path.to_string(),
})
}
pub fn reset_database(&mut self) -> Result<(), String> {
self.db.drop_cf(KEYS_CF_NAME)?;
self.db.drop_cf(RECORD_DATA_CF_NAME)?;
self.db.drop_cf(VALUES_CF_NAME)?;
self.db.drop_cf(VARIANTS_CF_NAME)?;
self.db.create_cf(KEYS_CF_NAME, &rocksdb::Options::default())?;
self.db.create_cf(RECORD_DATA_CF_NAME, &rocksdb::Options::default())?;
self.db.create_cf(VALUES_CF_NAME, &rocksdb::Options::default())?;
let mut variants_opts = rocksdb::Options::default();
variants_opts.create_if_missing(true);
variants_opts.set_merge_operator_associative("append to RecordID vec", variant_append_merge);
self.db.create_cf(VARIANTS_CF_NAME, &variants_opts)?;
Ok(())
}
pub fn record_count(&self) -> Result<usize, String> {
let rec_data_cf_handle = self.db.cf_handle(RECORD_DATA_CF_NAME).unwrap();
let record_count = probe_for_max_sequential_key(&self.db, rec_data_cf_handle, 255)?;
Ok(record_count)
}
#[inline(always)]
pub fn get_record_key_groups(&self, record_id : RecordID) -> Result<impl Iterator<Item=KeyGroupID>, String> {
let rec_data_cf_handle = self.db.cf_handle(RECORD_DATA_CF_NAME).unwrap();
if let Some(rec_data_vec_bytes) = self.db.get_pinned_cf(rec_data_cf_handle, record_id.to_le_bytes())? {
let record_coder = bincode::DefaultOptions::new().with_varint_encoding().with_little_endian();
let rec_data : RecordData = record_coder.deserialize(&rec_data_vec_bytes).unwrap();
if !rec_data.key_groups.is_empty() {
Ok(rec_data.key_groups.into_iter().map(move |group_idx| KeyGroupID::from_record_and_idx(record_id, group_idx)))
} else {
Err("Invalid record_id".to_string())
}
} else {
Err("Invalid record_id".to_string())
}
}
pub fn put_record_key_groups(&self, record_id : RecordID, key_groups_vec : &[usize]) -> Result<(), String> {
let rec_data_cf_handle = self.db.cf_handle(RECORD_DATA_CF_NAME).unwrap();
let record_coder = bincode::DefaultOptions::new().with_varint_encoding().with_little_endian();
let new_rec_data = RecordData::new(key_groups_vec);
let rec_data_bytes = record_coder.serialize(&new_rec_data).unwrap();
self.db.put_cf(rec_data_cf_handle, record_id.to_le_bytes(), rec_data_bytes)?;
Ok(())
}
#[inline(always)]
#[allow(unused_variables)] pub fn get_keys_in_group<OwnedKeyT : 'static + Sized + Serialize + serde::de::DeserializeOwned>(&self, key_group : KeyGroupID, perf_counters : &PerfCounters) -> Result<impl Iterator<Item=OwnedKeyT>, String> {
let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
if let Some(keys_vec_bytes) = self.db.get_pinned_cf(keys_cf_handle, key_group.to_le_bytes())? {
let record_coder = bincode::DefaultOptions::new().with_varint_encoding().with_little_endian();
let keys_vec : Vec<OwnedKeyT> = record_coder.deserialize(&keys_vec_bytes).unwrap();
#[cfg(feature = "perf_counters")]
{
let mut counter_fields = perf_counters.get();
counter_fields.key_group_load_count += 1;
counter_fields.keys_found_count += keys_vec.len();
perf_counters.set(counter_fields);
}
if !keys_vec.is_empty() {
Ok(keys_vec.into_iter())
} else {
Err("Invalid record_id".to_string())
}
} else {
Err("Invalid record_id".to_string())
}
}
#[inline(always)]
pub fn keys_count_in_group(&self, key_group : KeyGroupID) -> Result<usize, String> {
let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
if let Some(keys_vec_bytes) = self.db.get_pinned_cf(keys_cf_handle, key_group.to_le_bytes())? {
let mut skip_bytes = 0;
let keys_count = bincode_u64_le_varint(&keys_vec_bytes, &mut skip_bytes);
Ok(keys_count as usize)
} else {
panic!(); }
}
pub fn put_key_group_entry<K : Eq + Hash + Serialize>(&mut self, key_group_id : KeyGroupID, raw_keys : &HashSet<K>) -> Result<(), String> {
let record_coder = bincode::DefaultOptions::new().with_varint_encoding().with_little_endian();
let keys_bytes = record_coder.serialize(&raw_keys).unwrap();
let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
self.db.put_cf(keys_cf_handle, key_group_id.to_le_bytes(), keys_bytes)?;
Ok(())
}
pub fn delete_key_group_entry(&mut self, key_group : KeyGroupID) -> Result<(), String> {
let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
self.db.delete_cf(keys_cf_handle, key_group.to_le_bytes())?;
Ok(())
}
#[inline(always)]
pub fn get_value<ValueT : 'static + Serialize + serde::de::DeserializeOwned>(&self, record_id : RecordID) -> Result<ValueT, String> {
let values_cf_handle = self.db.cf_handle(VALUES_CF_NAME).unwrap();
if let Some(value_bytes) = self.db.get_pinned_cf(values_cf_handle, record_id.to_le_bytes())? {
let record_coder = bincode::DefaultOptions::new().with_varint_encoding().with_little_endian();
let value : ValueT = record_coder.deserialize(&value_bytes).unwrap();
Ok(value)
} else {
Err("Invalid record_id".to_string())
}
}
pub fn delete_value(&mut self, record_id : RecordID) -> Result<(), String> {
let value_cf_handle = self.db.cf_handle(VALUES_CF_NAME).unwrap();
self.db.delete_cf(value_cf_handle, record_id.to_le_bytes())?;
Ok(())
}
pub fn put_value<ValueT : 'static + Serialize + serde::de::DeserializeOwned>(&mut self, record_id : RecordID, value : &ValueT) -> Result<(), String> {
let value_cf_handle = self.db.cf_handle(VALUES_CF_NAME).unwrap();
let record_coder = bincode::DefaultOptions::new().with_varint_encoding().with_little_endian();
let value_bytes = record_coder.serialize(value).unwrap();
self.db.put_cf(value_cf_handle, record_id.to_le_bytes(), value_bytes)?;
Ok(())
}
#[inline(always)]
pub fn visit_variants<F : FnMut(&[u8])>(&self, variants : HashSet<Vec<u8>>, mut visitor_closure : F) -> Result<(), String> {
let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
for variant in variants {
if let Some(variant_vec_bytes) = self.db.get_pinned_cf(variants_cf_handle, variant)? {
visitor_closure(&variant_vec_bytes);
}
}
Ok(())
}
#[inline(always)]
pub fn visit_exact_variant<F : FnMut(&[u8])>(&self, variant : &[u8], mut visitor_closure : F) -> Result<(), String> {
let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
if let Some(variant_vec_bytes) = self.db.get_pinned_cf(variants_cf_handle, variant)? {
visitor_closure(&variant_vec_bytes);
}
Ok(())
}
pub fn delete_variant_references(&mut self, key_group : KeyGroupID, variants : HashSet<Vec<u8>>) -> Result<(), String> {
let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
for variant in variants.iter() {
if let Some(variant_entry_bytes) = self.db.get_pinned_cf(variants_cf_handle, variant)? {
let variant_entry_len = bincode_vec_fixint_len(&variant_entry_bytes);
if variant_entry_len > 1 {
let mut new_vec : Vec<KeyGroupID> = Vec::with_capacity(variant_entry_len-1);
for key_group_id_bytes in bincode_vec_iter::<KeyGroupID>(&variant_entry_bytes) {
let other_key_group_id = KeyGroupID::from(usize::from_le_bytes(key_group_id_bytes.try_into().unwrap()));
if other_key_group_id != key_group {
new_vec.push(other_key_group_id);
}
}
let vec_coder = bincode::DefaultOptions::new().with_fixint_encoding().with_little_endian();
self.db.put_cf(variants_cf_handle, variant, vec_coder.serialize(&new_vec).unwrap())?;
} else {
self.db.delete_cf(variants_cf_handle, variant)?;
}
}
}
Ok(())
}
pub fn put_variant_references(&mut self, key_group : KeyGroupID, variants : HashSet<Vec<u8>>) -> Result<(), String> {
fn new_variant_vec(key_group : KeyGroupID) -> Vec<u8> {
let new_vec = vec![key_group];
let vec_coder = bincode::DefaultOptions::new().with_fixint_encoding().with_little_endian();
vec_coder.serialize(&new_vec).unwrap()
}
let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
for variant in variants {
let val_bytes = new_variant_vec(key_group);
self.db.merge_cf(variants_cf_handle, variant, val_bytes)?;
}
Ok(())
}
}
impl Drop for DBConnection {
fn drop(&mut self) {
self.db.flush().unwrap();
let _ = DB::destroy(&rocksdb::Options::default(), self.path.as_str());
}
}
/// Associative RocksDB merge operator for the variants column family.
///
/// The existing value and every operand are fixint, little-endian bincode sequences of
/// `KeyGroupID`s. They are unioned into a `HashSet`, so the merged list contains each key
/// group once, in no particular order.
///
/// If any input fails to decode (or the result fails to encode), this returns `None`,
/// which is how a rust-rocksdb merge function reports failure. It does not panic, because
/// this function is invoked as a callback from RocksDB's C++ code.
fn variant_append_merge(_key: &[u8], existing_val: Option<&[u8]>, operands: &mut MergeOperands) -> Option<Vec<u8>> {
    let vec_coder = bincode::DefaultOptions::new().with_fixint_encoding().with_little_endian();
    let mut variant_set : HashSet<KeyGroupID> = match existing_val {
        Some(existing_bytes) => vec_coder.deserialize(existing_bytes).ok()?,
        None => HashSet::with_capacity(operands.size_hint().0),
    };
    for op in operands {
        let operand_set : HashSet<KeyGroupID> = vec_coder.deserialize(op).ok()?;
        variant_set.extend(operand_set);
    }
    vec_coder.serialize(&variant_set).ok()
}
/// Returns the smallest `n` such that the key `n.to_le_bytes()` is absent from `cf`,
/// assuming the present keys form the dense run `0, 1, …, n-1`.
///
/// The search starts at `starting_hint`. While probes hit present keys, the upper end of
/// the search window is doubled (starting from `starting_hint²`). Once an absent key is
/// found, the window is bisected between the highest known-present and the lowest
/// known-absent key.
///
/// # Errors
/// Propagates any RocksDB read error.
fn probe_for_max_sequential_key(db : &DBWithThreadMode<rocksdb::SingleThreaded>, cf : &ColumnFamily, starting_hint : usize) -> Result<usize, rocksdb::Error> {
    // Invariant: every key below `min` is present; the key `max` is absent (or is the
    // untested upper limit of the key space).
    let mut min = 0;
    let mut max = usize::MAX;
    // Initial upper guess for the window: hint², saturating so huge hints just search
    // the whole range. This works on any pointer width, not only 64-bit; a hint of 0
    // still starts with a window of 1.
    let mut guess_max = starting_hint.saturating_mul(starting_hint).max(1);
    let mut cur_val = starting_hint;
    loop {
        if max == min {
            return Ok(cur_val)
        }
        if db.get_pinned_cf(cf, cur_val.to_le_bytes())?.is_some() {
            min = cur_val + 1;
            // Still growing: widen the window, but never past the known-absent bound.
            guess_max = if guess_max < max/2 { guess_max * 2 } else { max };
        } else {
            max = cur_val;
            guess_max = max;
            if max == min {
                return Ok(cur_val)
            }
        }
        cur_val = ((guess_max - min) / 2) + min;
    }
}