use core::hash::Hash;
use std::collections::HashSet;
use serde::Serialize;
use rocksdb::{DB, DBWithThreadMode, ColumnFamily, ColumnFamilyDescriptor, MergeOperands};
use crate::table_config::TableMetadata;
use crate::TableConfig;
use crate::encode_decode::internal_coder;
use super::encode_decode::Coder;
use super::records::{*};
use super::key_groups::{*};
use super::perf_counters::{*};
// Names of the RocksDB column families used by a table.
// "keys": key-group id -> encoded list of the raw keys in that group.
pub const KEYS_CF_NAME : &str = "keys";
// "rec_data": RecordID -> encoded RecordData (the record's key-group indices).
pub const RECORD_DATA_CF_NAME : &str = "rec_data";
// "values": RecordID -> encoded value payload (encoded with the table's value coder).
pub const VALUES_CF_NAME : &str = "values";
// "metadata": table-level entries — version string at key [], config at key [1].
pub const METADATA_CF_NAME : &str = "metadata";
// "variants": key variant bytes -> encoded list of KeyGroupIDs (maintained via a merge operator).
pub const VARIANTS_CF_NAME : &str = "variants";
/// A connection to the RocksDB database backing one table: a single-threaded
/// DB handle, the filesystem path it was opened at, and the coder used to
/// (de)serialize entries in the "values" column family.
pub struct DBConnection<C: Coder + Send + Sync> {
    // Single-threaded RocksDB handle; all five column families live on it.
    db : DBWithThreadMode<rocksdb::SingleThreaded>,
    // Path the DB was opened at; used by Drop when calling DB::destroy.
    path : String,
    // Coder used by get_value / put_value for the "values" column family.
    value_coder: C,
}
// NOTE: the struct is declared with `C: Coder + Send + Sync`, so this impl must
// carry the same bounds — the original `impl<C: Coder>` does not satisfy the
// struct definition's bounds and fails to compile (E0277).
impl<C: Coder + Send + Sync> DBConnection<C> {

    /// Builds the rocksdb::Options for the "variants" column family, installing
    /// the associative merge operator that unions KeyGroupID lists into a
    /// variant's entry. Shared by `new` and `reset_database` so the two code
    /// paths cannot drift apart.
    fn variants_cf_options() -> rocksdb::Options {
        let mut variants_opts = rocksdb::Options::default();
        variants_opts.create_if_missing(true);
        variants_opts.set_merge_operator_associative("append to RecordID vec",
            |key: &[u8], existing_val: Option<&[u8]>, operands: &MergeOperands| -> Option<Vec<u8>> {
                variant_append_merge(key, existing_val, operands)
            });
        variants_opts
    }

    /// Opens (creating if missing) the database at `path` with all five column
    /// families. `value_coder` is retained for (de)serializing stored values.
    pub fn new(value_coder: C, path : &str) -> Result<Self, String> {
        let keys_cf = ColumnFamilyDescriptor::new(KEYS_CF_NAME, rocksdb::Options::default());
        let rec_data_cf = ColumnFamilyDescriptor::new(RECORD_DATA_CF_NAME, rocksdb::Options::default());
        let values_cf = ColumnFamilyDescriptor::new(VALUES_CF_NAME, rocksdb::Options::default());
        let version_cf = ColumnFamilyDescriptor::new(METADATA_CF_NAME, rocksdb::Options::default());
        let variants_cf = ColumnFamilyDescriptor::new(VARIANTS_CF_NAME, Self::variants_cf_options());

        let mut db_opts = rocksdb::Options::default();
        db_opts.create_missing_column_families(true);
        db_opts.create_if_missing(true);

        let db = DB::open_cf_descriptors(&db_opts, path, vec![keys_cf, rec_data_cf, values_cf, variants_cf, version_cf])?;

        Ok(Self{
            db,
            path : path.to_string(),
            value_coder
        })
    }

    /// Empties the table by dropping and recreating the data column families.
    /// The metadata column family is deliberately left untouched so version /
    /// config entries survive a reset.
    pub fn reset_database(&mut self) -> Result<(), String> {
        self.db.drop_cf(KEYS_CF_NAME)?;
        self.db.drop_cf(RECORD_DATA_CF_NAME)?;
        self.db.drop_cf(VALUES_CF_NAME)?;
        self.db.drop_cf(VARIANTS_CF_NAME)?;

        self.db.create_cf(KEYS_CF_NAME, &rocksdb::Options::default())?;
        self.db.create_cf(RECORD_DATA_CF_NAME, &rocksdb::Options::default())?;
        self.db.create_cf(VALUES_CF_NAME, &rocksdb::Options::default())?;
        // Recreate "variants" with the same merge-operator options used at open time.
        self.db.create_cf(VARIANTS_CF_NAME, &Self::variants_cf_options())?;
        Ok(())
    }

    /// Returns the number of records in the table, by probing the rec_data
    /// column family for the highest dense sequential RecordID key.
    pub fn record_count(&self) -> Result<usize, String> {
        let rec_data_cf_handle = self.db.cf_handle(RECORD_DATA_CF_NAME).unwrap();
        let record_count = probe_for_max_sequential_key(&self.db, rec_data_cf_handle, 255)?;
        Ok(record_count)
    }

    /// Returns true if the table contains no records (no entry for RecordID 0).
    pub fn table_is_empty(&self) -> Result<bool, String> {
        let rec_data_cf_handle = self.db.cf_handle(RECORD_DATA_CF_NAME).unwrap();
        Ok(self.db.get_pinned_cf(rec_data_cf_handle, RecordID(0).to_le_bytes())?.is_none())
    }

    /// Returns an iterator over the KeyGroupIDs belonging to `record_id`.
    ///
    /// Errors with "Invalid record_id" if the record does not exist or its
    /// key-group list is empty (a record must always have at least one group).
    #[inline(always)]
    pub fn get_record_key_groups(&self, record_id : RecordID) -> Result<impl Iterator<Item=KeyGroupID>, String> {
        let rec_data_cf_handle = self.db.cf_handle(RECORD_DATA_CF_NAME).unwrap();
        if let Some(rec_data_vec_bytes) = self.db.get_pinned_cf(rec_data_cf_handle, record_id.to_le_bytes())? {
            let rec_data : RecordData = internal_coder!().decode_fmt1_from_bytes(&rec_data_vec_bytes).unwrap();
            if !rec_data.key_groups.is_empty() {
                Ok(rec_data.key_groups.into_iter().map(move |group_idx| KeyGroupID::from_record_and_idx(record_id, group_idx)))
            } else {
                Err("Invalid record_id".to_string())
            }
        } else {
            Err("Invalid record_id".to_string())
        }
    }

    /// Writes (or overwrites) the key-group index list for `record_id`.
    pub fn put_record_key_groups(&self, record_id : RecordID, key_groups_vec : &[usize]) -> Result<(), String> {
        let rec_data_cf_handle = self.db.cf_handle(RECORD_DATA_CF_NAME).unwrap();
        let new_rec_data = RecordData::new(key_groups_vec);
        let rec_data_bytes = internal_coder!().encode_fmt1_to_buf(&new_rec_data).unwrap();
        self.db.put_cf(rec_data_cf_handle, record_id.to_le_bytes(), rec_data_bytes)?;
        Ok(())
    }

    /// Returns an iterator over the owned keys stored in `key_group`.
    ///
    /// Errors if the group does not exist or is empty. `perf_counters` is only
    /// consulted when the "perf_counters" feature is enabled.
    #[inline(always)]
    #[allow(unused_variables)] pub fn get_keys_in_group<OwnedKeyT : 'static + Sized + Serialize + serde::de::DeserializeOwned>(&self, key_group : KeyGroupID, perf_counters : &PerfCounters) -> Result<impl Iterator<Item=OwnedKeyT>, String> {
        let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
        if let Some(keys_vec_bytes) = self.db.get_pinned_cf(keys_cf_handle, key_group.to_le_bytes())? {
            let keys_vec : Vec<OwnedKeyT> = internal_coder!().decode_fmt1_from_bytes(&keys_vec_bytes).unwrap();

            #[cfg(feature = "perf_counters")]
            {
                let mut counter_fields = perf_counters.get();
                counter_fields.key_group_load_count += 1;
                counter_fields.keys_found_count += keys_vec.len();
                perf_counters.set(counter_fields);
            }

            if !keys_vec.is_empty() {
                Ok(keys_vec.into_iter())
            } else {
                Err("Invalid record_id".to_string())
            }
        } else {
            Err("Invalid record_id".to_string())
        }
    }

    /// Returns the number of keys stored in `key_group` without decoding them.
    #[inline(always)]
    pub fn keys_count_in_group(&self, key_group : KeyGroupID) -> Result<usize, String> {
        let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
        if let Some(keys_vec_bytes) = self.db.get_pinned_cf(keys_cf_handle, key_group.to_le_bytes())? {
            Ok(internal_coder!().fmt1_list_len(&keys_vec_bytes).unwrap())
        } else {
            // Previously a bare panic!(); a missing key group is a recoverable
            // caller error, so report it like the sibling accessors do.
            Err("Invalid key_group".to_string())
        }
    }

    /// Writes (or overwrites) the set of raw keys for `key_group_id`.
    pub fn put_key_group_entry<K : Eq + Hash + Serialize>(&mut self, key_group_id : KeyGroupID, raw_keys : &HashSet<K>) -> Result<(), String> {
        let keys_bytes = internal_coder!().encode_fmt1_list_to_buf(&raw_keys).unwrap();
        let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
        self.db.put_cf(keys_cf_handle, key_group_id.to_le_bytes(), keys_bytes)?;
        Ok(())
    }

    /// Removes the entry for `key_group` from the keys column family.
    pub fn delete_key_group_entry(&mut self, key_group : KeyGroupID) -> Result<(), String> {
        let keys_cf_handle = self.db.cf_handle(KEYS_CF_NAME).unwrap();
        self.db.delete_cf(keys_cf_handle, key_group.to_le_bytes())?;
        Ok(())
    }

    /// Fetches and decodes the value stored for `record_id` using the table's
    /// value coder. Errors with "Invalid record_id" if no value exists.
    #[inline(always)]
    pub fn get_value<ValueT : 'static + Serialize + serde::de::DeserializeOwned>(&self, record_id : RecordID) -> Result<ValueT, String> {
        let values_cf_handle = self.db.cf_handle(VALUES_CF_NAME).unwrap();
        if let Some(value_bytes) = self.db.get_pinned_cf(values_cf_handle, record_id.to_le_bytes())? {
            let value : ValueT = self.value_coder.decode_fmt1_from_bytes(&value_bytes).unwrap();
            Ok(value)
        } else {
            Err("Invalid record_id".to_string())
        }
    }

    /// Removes the value stored for `record_id`.
    pub fn delete_value(&mut self, record_id : RecordID) -> Result<(), String> {
        let value_cf_handle = self.db.cf_handle(VALUES_CF_NAME).unwrap();
        self.db.delete_cf(value_cf_handle, record_id.to_le_bytes())?;
        Ok(())
    }

    /// Encodes `value` with the table's value coder and stores it for `record_id`.
    pub fn put_value<ValueT : 'static + Serialize + serde::de::DeserializeOwned>(&mut self, record_id : RecordID, value : &ValueT) -> Result<(), String> {
        let value_cf_handle = self.db.cf_handle(VALUES_CF_NAME).unwrap();
        let value_bytes = self.value_coder.encode_fmt1_to_buf(value).unwrap();
        self.db.put_cf(value_cf_handle, record_id.to_le_bytes(), value_bytes)?;
        Ok(())
    }

    /// Reads the fuzzy_rocks version string the table was written with, stored
    /// in the metadata column family under the empty key.
    pub fn get_version(&self) -> Result<String, String> {
        let version_cf_handle = self.db.cf_handle(METADATA_CF_NAME).unwrap();
        if let Some(value_bytes) = self.db.get_pinned_cf(version_cf_handle, [])? {
            // On-disk bytes are external input: report corruption instead of panicking.
            let value = String::from_utf8(value_bytes.to_owned())
                .map_err(|_e| "Error: corrupt version metadata for table".to_string())?;
            Ok(value)
        } else {
            Err("Error: No config metadata found for table. The table was likely built with an old version of fuzzy_rocks".to_string())
        }
    }

    /// Stamps the table with the current crate version (metadata CF, empty key).
    pub fn put_version(&mut self) -> Result<(), String> {
        let value_cf_handle = self.db.cf_handle(METADATA_CF_NAME).unwrap();
        let value_bytes = env!("CARGO_PKG_VERSION").as_bytes();
        self.db.put_cf(value_cf_handle, [], value_bytes)?;
        Ok(())
    }

    /// Reads the TableMetadata the table was built with, stored in the metadata
    /// column family under key [1].
    pub fn get_config(&self) -> Result<TableMetadata, String> {
        let metadata_cf_handle = self.db.cf_handle(METADATA_CF_NAME).unwrap();
        if let Some(config_bytes) = self.db.get_pinned_cf(metadata_cf_handle, [1])? {
            let config = internal_coder!().decode_fmt1_from_bytes(&config_bytes)
                .map_err(|_e| "Error decoding table metadata, probably caused by a change to fuzzy_rocks coder feature flags".to_string())?;
            Ok(config)
        } else {
            Err("Error: No config metadata found for table. The table was likely built with an old version of fuzzy_rocks".to_string())
        }
    }

    /// Stores ConfigT's metadata in the metadata column family under key [1].
    pub fn put_config<ConfigT : TableConfig>(&mut self) -> Result<(), String> {
        let metadata_cf_handle = self.db.cf_handle(METADATA_CF_NAME).unwrap();
        let config = ConfigT::metadata();
        let config_bytes = internal_coder!().encode_fmt1_to_buf(&config).unwrap();
        self.db.put_cf(metadata_cf_handle, [1], config_bytes)?;
        Ok(())
    }

    /// Invokes `visitor_closure` with the raw entry bytes for each variant in
    /// `variants` that exists in the variants column family; missing variants
    /// are silently skipped.
    #[inline(always)]
    pub fn visit_variants<F : FnMut(&[u8])>(&self, variants : HashSet<Vec<u8>>, mut visitor_closure : F) -> Result<(), String> {
        let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
        for variant in variants {
            if let Some(variant_vec_bytes) = self.db.get_pinned_cf(variants_cf_handle, variant)? {
                visitor_closure(&variant_vec_bytes);
            }
        }
        Ok(())
    }

    /// Invokes `visitor_closure` with the raw entry bytes for exactly one
    /// variant, if it exists.
    #[inline(always)]
    pub fn visit_exact_variant<F : FnMut(&[u8])>(&self, variant : &[u8], mut visitor_closure : F) -> Result<(), String> {
        let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
        if let Some(variant_vec_bytes) = self.db.get_pinned_cf(variants_cf_handle, variant)? {
            visitor_closure(&variant_vec_bytes);
        }
        Ok(())
    }

    /// Removes `key_group` from every listed variant's entry, deleting the
    /// entry outright when `key_group` was its only member.
    pub fn delete_variant_references(&mut self, key_group : KeyGroupID, variants : HashSet<Vec<u8>>) -> Result<(), String> {
        let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
        for variant in variants.iter() {
            if let Some(variant_entry_bytes) = self.db.get_pinned_cf(variants_cf_handle, variant)? {
                let variant_entry_len = internal_coder!().fmt2_list_len(&variant_entry_bytes).unwrap();
                if variant_entry_len > 1 {
                    // Rewrite the entry without this key group.
                    let mut new_vec : Vec<KeyGroupID> = Vec::with_capacity(variant_entry_len-1);
                    let decoded_vec : Vec<KeyGroupID> = internal_coder!().decode_fmt2_from_bytes(&variant_entry_bytes).unwrap();
                    for other_key_group_id in decoded_vec {
                        if other_key_group_id != key_group {
                            new_vec.push(other_key_group_id);
                        }
                    }
                    self.db.put_cf(variants_cf_handle, variant, internal_coder!().encode_fmt2_list_to_buf(&new_vec).unwrap())?;
                } else {
                    // key_group was the last member; drop the whole entry.
                    self.db.delete_cf(variants_cf_handle, variant)?;
                }
            }
        }
        Ok(())
    }

    /// Adds `key_group` to every listed variant's entry via the merge operator
    /// (which unions it into any existing list).
    pub fn put_variant_references(&mut self, key_group : KeyGroupID, variants : HashSet<Vec<u8>>) -> Result<(), String> {
        // The merged-in value is the same one-element list for every variant,
        // so encode it once outside the loop instead of once per variant.
        let val_bytes = internal_coder!().encode_fmt2_list_to_buf(&vec![key_group]).unwrap();
        let variants_cf_handle = self.db.cf_handle(VARIANTS_CF_NAME).unwrap();
        for variant in variants {
            self.db.merge_cf(variants_cf_handle, variant, &val_bytes)?;
        }
        Ok(())
    }
}
impl<C: Coder + Send + Sync> Drop for DBConnection<C> {
    /// Best-effort flush on drop.
    fn drop(&mut self) {
        // Never panic inside drop: the original `flush().unwrap()` would abort
        // the whole process if the flush failed while another panic was already
        // unwinding. Flush errors here are unrecoverable anyway, so ignore them.
        let _ = self.db.flush();
        // NOTE(review): `self.db` is still open at this point, so DB::destroy on
        // the same path will normally fail on the DB's lock file and silently do
        // nothing. If destroying the on-disk data on drop is really intended,
        // `self.db` must be closed first — confirm intent.
        let _ = DB::destroy(&rocksdb::Options::default(), self.path.as_str());
    }
}
/// Associative merge operator for the variants column family: unions the
/// KeyGroupID sets from the existing entry (if any) and every pending operand,
/// then re-encodes the combined set. Returning `None` (encode failure) tells
/// RocksDB the merge failed.
fn variant_append_merge(_key: &[u8], existing_val: Option<&[u8]>, operands: &MergeOperands) -> Option<Vec<u8>> {
    let ops = operands.into_iter();

    // Seed the set from the existing entry, or start empty with a capacity
    // hint taken from the number of pending operands.
    let mut merged: HashSet<KeyGroupID> = match existing_val {
        Some(existing_bytes) => internal_coder!().decode_fmt2_from_bytes(existing_bytes).unwrap(),
        None => HashSet::with_capacity(ops.size_hint().0),
    };

    // Fold every operand's decoded set into the union.
    for operand_bytes in ops {
        let operand_groups: HashSet<KeyGroupID> = internal_coder!().decode_fmt2_from_bytes(operand_bytes).unwrap();
        merged.extend(operand_groups);
    }

    internal_coder!().encode_fmt2_to_buf(&merged).ok()
}
/// Finds the count of densely sequential keys in `cf` by probing with a
/// binary search. Assumes keys are little-endian usize values forming a dense
/// run 0..N (every key below N present, N itself absent) and returns N.
/// `starting_hint` seeds the first probe so small tables converge quickly.
fn probe_for_max_sequential_key(db : &DBWithThreadMode<rocksdb::SingleThreaded>, cf : &ColumnFamily, starting_hint : usize) -> Result<usize, rocksdb::Error> {
    // Invariant maintained by the loop: all keys < min exist, and `max` is an
    // upper bound on the first missing key.
    let mut min = 0;
    let mut max = usize::MAX;

    // The squaring below relies on a 64-bit usize: a hint clamped to
    // <= 0xFFFFFFFF squared cannot overflow 64 bits.
    debug_assert!(::std::mem::size_of::<usize>() == 8);
    // `guess_max` is a soft upper bound that grows by doubling until a missing
    // key is found; seeded at hint^2 (clamped) so the first expansion phase
    // overshoots a table that is near, but above, the hint.
    let mut guess_max = if starting_hint > 0xFFFFFFFF {
        usize::MAX
    } else if starting_hint < 1 {
        1
    } else {
        starting_hint * starting_hint
    };

    let mut cur_val = starting_hint;
    loop {
        // Bounds have met: cur_val is the first missing key, i.e. the count.
        if max == min {
            return Ok(cur_val)
        }
        if let Some(_value) = db.get_pinned_cf(cf, cur_val.to_le_bytes())? {
            // cur_val exists, so the first missing key is above it.
            min = cur_val + 1;
            // Keep expanding the soft bound until it would pass max/2, then
            // fall back to the hard bound (the guard also prevents overflow
            // in the doubling).
            if guess_max < max/2 {
                guess_max *= 2;
            } else {
                guess_max = max;
            }
        } else {
            // cur_val is missing: tighten the hard bound and stop expanding.
            max = cur_val;
            guess_max = max;
            if max == min {
                return Ok(cur_val)
            }
        }
        // Next probe: midpoint of the current search window.
        cur_val = ((guess_max - min) / 2) + min;
    }
}