use std::cmp::max;
use std::collections::HashSet;
use std::hash::{BuildHasher, BuildHasherDefault};
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use byteorder::{LittleEndian, WriteBytesExt};
use log::{info, trace};
use rayon::prelude::*;
use rocksdb::MergeOperands;
use crate::Result;
use crate::collection::{Collection, CollectionSet};
use crate::encodings::{Color, Idx};
use crate::index::revindex::{
self as module, CounterGather, DatasetPicklist, Datasets, DbStats, QueryColors, RevIndexOps,
stats_for_cf,
};
use crate::index::{GatherResult, SigCounter, calculate_gather_stats};
use crate::manifest::Manifest;
use crate::prelude::*;
use crate::sketch::Sketch;
use crate::sketch::minhash::{KmerMinHash, KmerMinHashBTree};
use crate::storage::{
InnerStorage, RocksDBStorage, Storage,
rocksdb::{ALL_CFS, DB, HASHES, METADATA, cf_descriptors, db_options},
};
/// On-disk format version; stored under the `VERSION` metadata key and
/// checked on open (see `load_collection_from_rocksdb`).
const DB_VERSION: u8 = 1;

// Keys used inside the METADATA column family.
const MANIFEST: &str = "manifest"; // serialized collection manifest
const STORAGE_SPEC: &str = "storage_spec"; // storage backend spec string
const VERSION: &str = "version"; // on-disk format version byte
const PROCESSED: &str = "processed"; // serialized Datasets of already-indexed ids
/// Derive a stable `Color` identifier for a set of dataset ids by hashing
/// the set with the crate's xxh3-128 hasher.
fn compute_color(idxs: &Datasets) -> Color {
    BuildHasherDefault::<crate::encodings::Xxh3Hash128>::default().hash_one(idxs)
}
/// A reverse index persisted in a RocksDB database on disk.
///
/// Cloning is cheap: all fields are behind `Arc`s and share the same
/// underlying database handle and collection.
#[derive(Clone)]
pub struct DiskRevIndex {
    // Filesystem path the database was created at / opened from.
    location: String,
    // Shared handle to the underlying RocksDB instance.
    db: Arc<DB>,
    // The set of signatures this index covers.
    collection: Arc<CollectionSet>,
    // Dataset ids whose hashes have already been written to the index;
    // used to resume an interrupted build (see `create` / `update`).
    processed: Arc<RwLock<Datasets>>,
}
/// RocksDB merge operator for serialized `Datasets` values.
///
/// Deserializes the existing value (if any), unions every queued operand
/// into it, and re-serializes the result. Values of this shape are merged
/// via `merge_cf` in `map_hashes_colors` (both the HASHES entries and the
/// METADATA `PROCESSED` key).
///
/// # Panics
/// Panics if the stored value or any operand cannot be deserialized —
/// that would indicate database corruption.
pub(crate) fn merge_datasets(
    _: &[u8],
    existing_val: Option<&[u8]>,
    operands: &MergeOperands,
) -> Option<Vec<u8>> {
    // Start from the current stored value, or an empty set on first merge.
    let mut datasets = existing_val
        .map(|val| Datasets::from_slice(val).expect("cannot unpack existing value"))
        .unwrap_or_default();

    // Fold every queued operand into the accumulator.
    for op in operands {
        let new_vals = Datasets::from_slice(op).expect("cannot unpack merge operand");
        datasets.union(new_vals);
    }

    datasets.as_bytes()
}
impl DiskRevIndex {
    /// Create (or resume creating) a disk index at `path` from `collection`.
    ///
    /// Every dataset's hashes are mapped into the HASHES column family in
    /// parallel. Progress is recorded under the METADATA `PROCESSED` key via
    /// `map_hashes_colors`, so datasets already present are skipped — this
    /// makes an interrupted build resumable. SSTs are compacted before
    /// returning.
    pub fn create(path: &Path, collection: CollectionSet) -> Result<module::RevIndex> {
        let mut opts = db_options();
        opts.create_if_missing(true);
        opts.create_missing_column_families(true);
        let cfs = cf_descriptors();
        // NOTE(review): unwrap panics if the DB cannot be opened here,
        // whereas `open` below propagates the error — confirm intended.
        let db = Arc::new(DB::open_cf_descriptors(&opts, path, cfs).unwrap());
        let processed_sigs = AtomicUsize::new(0);
        let collection = Arc::new(collection);
        // Fresh DBs have no PROCESSED key; `assume_empty = true` makes that
        // case start from an empty set instead of "everything done".
        let processed = Arc::new(RwLock::new(Self::load_processed(
            db.clone(),
            collection.clone(),
            true,
        )?));
        let index = Self {
            location: String::from(path.to_str().expect("cannot extract path")),
            db,
            collection,
            processed: processed.clone(),
        };
        // Map hash -> dataset id for every not-yet-processed dataset.
        index.collection.par_iter().for_each(|(dataset_id, _)| {
            if !processed.read().unwrap().contains(&dataset_id) {
                let i = processed_sigs.fetch_add(1, Ordering::SeqCst);
                if i % 1000 == 0 {
                    info!("Processed {} reference sigs", i);
                }
                index.map_hashes_colors(dataset_id as Idx);
                // Record completion so a restart can skip this dataset.
                processed.write().unwrap().extend([dataset_id]);
            }
        });
        index.save_collection().expect("Error saving collection");
        info!("Compact SSTs");
        index.compact();
        info!(
            "Done! Processed {} reference sigs",
            processed_sigs.into_inner()
        );
        Ok(module::RevIndex::Disk(index))
    }
    /// Open an existing disk index at `path`.
    ///
    /// With `read_only` the DB is opened in RocksDB read-only mode. The
    /// collection is reloaded from the METADATA column family;
    /// `storage_spec`, when given, overrides the spec stored in the DB.
    pub fn open<P: AsRef<Path>>(
        path: P,
        read_only: bool,
        storage_spec: Option<&str>,
    ) -> Result<module::RevIndex> {
        let mut opts = db_options();
        opts.create_if_missing(true);
        opts.create_missing_column_families(true);
        let cfs = cf_descriptors();
        let db = if read_only {
            Arc::new(DB::open_cf_descriptors_read_only(
                &opts,
                path.as_ref(),
                cfs,
                false,
            )?)
        } else {
            Arc::new(DB::open_cf_descriptors(&opts, path.as_ref(), cfs)?)
        };
        let collection = Arc::new(Self::load_collection_from_rocksdb(
            db.clone(),
            storage_spec,
        )?);
        // `assume_empty = false`: an existing DB without a PROCESSED key is
        // treated as fully processed (see `load_processed`).
        let processed = Arc::new(RwLock::new(Self::load_processed(
            db.clone(),
            collection.clone(),
            false,
        )?));
        Ok(module::RevIndex::Disk(Self {
            location: String::from(path.as_ref().to_str().expect("cannot extract path")),
            db,
            collection,
            processed,
        }))
    }
    /// Return a shared handle to the raw RocksDB instance.
    ///
    /// # Safety
    /// Callers get direct access to the database and can bypass the
    /// invariants this type maintains (column family contents, metadata
    /// keys); hence the `unsafe` marker even though no memory-unsafe
    /// operation happens here.
    pub unsafe fn db(&self) -> Arc<DB> {
        self.db.clone()
    }
    /// Load the set of already-indexed dataset ids from METADATA.
    ///
    /// If the PROCESSED key is absent: with `assume_empty` return an empty
    /// set (a build in progress on a fresh DB); otherwise assume every
    /// dataset in the manifest was processed (a DB written before this key
    /// existed, or a finished build).
    fn load_processed(
        db: Arc<DB>,
        collection: Arc<CollectionSet>,
        assume_empty: bool,
    ) -> Result<Datasets> {
        let cf_metadata = db.cf_handle(METADATA).unwrap();
        if let Some(rdr) = db.get_pinned_cf(&cf_metadata, PROCESSED)? {
            // Key present: deserialize the stored set.
            Datasets::from_slice(&rdr)
        } else if assume_empty {
            Ok(Datasets::default())
        } else {
            // No key: treat every manifest entry as already processed.
            let all_datasets: Vec<_> = (0..collection.manifest().len()).map(|v| v as Idx).collect();
            Ok(Datasets::new(&all_datasets))
        }
    }
    /// Rebuild the `CollectionSet` from the METADATA column family:
    /// verify the format version, read the manifest, and resolve the
    /// storage backend (DB-internal `rocksdb://` or an external spec).
    fn load_collection_from_rocksdb(
        db: Arc<DB>,
        storage_spec: Option<&str>,
    ) -> Result<CollectionSet> {
        let cf_metadata = db.cf_handle(METADATA).unwrap();
        let rdr = db.get_cf(&cf_metadata, VERSION)?.unwrap();
        // Refuse to read a DB written with a different format version.
        assert_eq!(rdr[0], DB_VERSION);
        let rdr = db.get_cf(&cf_metadata, MANIFEST)?.unwrap();
        let manifest = Manifest::from_reader(&rdr[..])?;
        // Caller-supplied spec takes precedence over the stored one.
        let spec = match storage_spec {
            Some(spec) => spec.into(),
            None => {
                let db_spec = db.get_cf(&cf_metadata, STORAGE_SPEC)?;
                String::from_utf8(db_spec.unwrap()).map_err(|e| e.utf8_error())?
            }
        };
        let storage = if spec == "rocksdb://" {
            // Signature data lives inside this same RocksDB database.
            InnerStorage::new(RocksDBStorage::from_db(db.clone()))
        } else {
            InnerStorage::from_spec(spec)?
        };
        Collection::new(manifest, storage).try_into()
    }
    /// Persist version, manifest, and storage spec into METADATA so the
    /// collection can be reconstructed by `load_collection_from_rocksdb`.
    fn save_collection(&self) -> Result<()> {
        let cf_metadata = self.db.cf_handle(METADATA).unwrap();
        // Save DB version
        self.db.put_cf(&cf_metadata, VERSION, [DB_VERSION])?;
        // Serialize the manifest to an in-memory buffer, then store it.
        let mut wtr = vec![];
        {
            self.collection.manifest().to_writer(&mut wtr)?;
        }
        self.db.put_cf(&cf_metadata, MANIFEST, &wtr[..])?;
        let spec = self.collection.storage().spec();
        self.db.put_cf(&cf_metadata, STORAGE_SPEC, spec)?;
        Ok(())
    }
    /// Insert every hash of `dataset_id`'s first sketch into HASHES
    /// (hash -> Datasets, combined by the merge operator), then merge the
    /// dataset id into the METADATA `PROCESSED` key.
    fn map_hashes_colors(&self, dataset_id: Idx) {
        let search_sig = self
            .collection
            .sig_for_dataset(dataset_id)
            .expect("Couldn't find a compatible Signature");
        let search_mh = &search_sig.sketches()[0];
        // The merge operand: a singleton set holding this dataset id.
        let colors = Datasets::new(&[dataset_id]).as_bytes().unwrap();
        let cf_hashes = self.db.cf_handle(HASHES).unwrap();
        let hashes = match search_mh {
            Sketch::MinHash(mh) => mh.mins(),
            Sketch::LargeMinHash(mh) => mh.mins(),
            _ => unimplemented!(),
        };
        // Hash keys are stored as little-endian u64; reuse one 8-byte buffer.
        let mut hash_bytes = [0u8; 8];
        for hash in hashes {
            (&mut hash_bytes[..])
                .write_u64::<LittleEndian>(hash)
                .expect("error writing bytes");
            self.db
                .merge_cf(&cf_hashes, &hash_bytes[..], colors.as_slice())
                .expect("error merging");
        }
        // Mark this dataset as processed (same merge-union semantics).
        let cf_metadata = self.db.cf_handle(METADATA).unwrap();
        self.db
            .merge_cf(&cf_metadata, PROCESSED, colors.as_slice())
            .expect("error merging");
    }
}
impl RevIndexOps for DiskRevIndex {
    /// Filesystem path this index was created at / opened from.
    fn location(&self) -> &str {
        self.location.as_str()
    }
    /// Build a counter of dataset ids: for each query min, look up which
    /// datasets contain that hash and count one hit per (hash, dataset).
    ///
    /// All lookups go through a single `multi_get_cf`; hashes missing from
    /// the index (and lookup errors) are silently skipped. When a picklist
    /// is given, only its dataset ids are counted.
    fn counter_for_query(
        &self,
        query: &KmerMinHash,
        picklist: Option<DatasetPicklist>,
    ) -> SigCounter {
        info!("Collecting hashes");
        let cf_hashes = self.db.cf_handle(HASHES).unwrap();
        // Encode each query min as the little-endian u64 key used in HASHES.
        let hashes_iter = query.iter_mins().map(|hash| {
            let mut v = vec![0_u8; 8];
            (&mut v[..])
                .write_u64::<LittleEndian>(*hash)
                .expect("error writing bytes");
            (&cf_hashes, v)
        });
        info!("Multi get");
        self.db
            .multi_get_cf(hashes_iter)
            .into_iter()
            // Drop both Err results and keys not present in the DB.
            .filter_map(|r| r.ok().unwrap_or(None))
            .flat_map(|raw_datasets| {
                let new_vals = Datasets::from_slice(&raw_datasets).unwrap();
                if let Some(pl) = &picklist {
                    // Keep only dataset ids selected by the picklist.
                    let new_vals: HashSet<_> = new_vals
                        .into_iter()
                        .filter(|&i| pl.dataset_ids.contains(&i))
                        .collect();
                    Box::new(new_vals.into_iter())
                } else {
                    new_vals.into_iter()
                }
            })
            .collect()
    }
    /// Prepare the state needed for gather: a per-dataset counter, a map
    /// from color -> dataset set, and a map from query hash -> color.
    ///
    /// Hashes absent from the index (or emptied by the picklist filter) are
    /// dropped; identical dataset sets share one color (see `compute_color`).
    fn prepare_gather_counters(
        &self,
        query: &KmerMinHash,
        picklist: Option<DatasetPicklist>,
    ) -> CounterGather {
        let cf_hashes = self.db.cf_handle(HASHES).unwrap();
        // Encode each query min as the little-endian u64 key used in HASHES.
        let hashes_iter = query.iter_mins().map(|hash| {
            let mut v = vec![0_u8; 8];
            (&mut v[..])
                .write_u64::<LittleEndian>(*hash)
                .expect("error writing bytes");
            (&cf_hashes, v)
        });
        let mut query_colors: QueryColors = Default::default();
        let mut counter: SigCounter = Default::default();
        info!("Building hash_to_color and query_colors");
        // Pair every query min with its multi-get result (same order).
        let hash_to_color = query
            .iter_mins()
            .zip(self.db.multi_get_cf(hashes_iter))
            .filter_map(|(k, r)| {
                let raw: Option<Vec<u8>> = r.ok().unwrap_or(None);
                if let Some(r) = raw {
                    let mut new_vals = Datasets::from_slice(&r).unwrap();
                    if let Some(pl) = &picklist {
                        // Restrict the dataset set to picklist members.
                        let val_set: Vec<Idx> = new_vals
                            .into_iter()
                            .filter(|&i| pl.dataset_ids.contains(&i))
                            .collect();
                        new_vals = Datasets::new(&val_set[..]);
                    }
                    if !new_vals.is_empty() {
                        // Same dataset set -> same color; register it once.
                        let color = compute_color(&new_vals);
                        query_colors
                            .entry(color)
                            .or_insert_with(|| new_vals.clone());
                        counter.update(new_vals);
                        Some((*k, color))
                    } else {
                        None
                    }
                } else {
                    None
                }
            })
            .collect();
        CounterGather {
            counter,
            query_colors,
            hash_to_color,
        }
    }
    /// Greedy gather: repeatedly take the best remaining match above
    /// `threshold`, compute its stats, subtract its hashes from the query
    /// and the counter, and continue until no match clears the threshold.
    fn gather(
        &self,
        mut cg: CounterGather,
        threshold: usize,
        orig_query: &KmerMinHash,
        selection: Option<Selection>,
    ) -> Result<Vec<GatherResult>> {
        let match_size = usize::MAX;
        let mut matches = vec![];
        let mut query = KmerMinHashBTree::from(orig_query.clone());
        let mut sum_weighted_found = 0;
        let _selection = selection.unwrap_or_else(|| self.collection.selection());
        let total_weighted_hashes = orig_query.sum_abunds();
        // track_abundance() determines whether weighted stats are computed.
        let calc_abund_stats = orig_query.track_abundance();
        let calc_ani_ci = false;
        let ani_confidence_interval_fraction = None;
        // NOTE(review): the `match_size` in this condition is the outer
        // binding (usize::MAX) — the one destructured below shadows it and
        // never feeds back. Termination therefore relies on `cg.peek`
        // returning None / `cg` draining. Confirm intended.
        while match_size > threshold && !cg.is_empty() {
            trace!("counter len: {}", cg.len());
            trace!("match size: {}", match_size);
            // Best candidate above threshold, or stop.
            let result = cg.peek(threshold);
            if result.is_none() {
                break;
            }
            let (dataset_id, match_size) = result.unwrap();
            let match_sig = self.collection.sig_for_dataset(dataset_id)?;
            let match_mh = match_sig.minhash().unwrap().clone();
            // Bring match and query to a common (coarser) scaled value.
            let max_scaled = max(match_mh.scaled(), query.scaled());
            let match_mh = match_mh
                .downsample_scaled(max_scaled)
                .expect("cannot downsample match");
            query = query
                .downsample_scaled(max_scaled)
                .expect("cannot downsample query");
            let query_mh = KmerMinHash::from(query.clone());
            let gather_result_rank = matches.len() as u32;
            let (gather_result, isect) = calculate_gather_stats(
                orig_query,
                query_mh,
                match_sig,
                match_size,
                gather_result_rank,
                sum_weighted_found,
                total_weighted_hashes,
                calc_abund_stats,
                calc_ani_ci,
                ani_confidence_interval_fraction,
            )
            .expect("could not calculate gather stats");
            // Rebuild the intersection as a minhash with the match's params.
            let mut isect_mh = match_mh.clone();
            isect_mh.clear();
            isect_mh.add_many(&isect.0)?;
            // Running total of weighted hashes found so far (for next rank).
            sum_weighted_found = gather_result.sum_weighted_found();
            matches.push(gather_result);
            trace!("Preparing counter for next round");
            // Remove the intersection from both the query and the counter
            // so the next round scores only still-unclaimed hashes.
            query.remove_many(isect_mh.iter_mins().copied())?;
            cg.consume(&isect_mh);
        }
        Ok(matches)
    }
    /// Extend this index with `collection`, which must be a superset of the
    /// current one. Only datasets not yet in `processed` are mapped; the
    /// manifest/metadata are re-saved and SSTs compacted afterwards.
    fn update(mut self, collection: CollectionSet) -> Result<module::RevIndex> {
        // Refuse updates that would drop existing datasets.
        self.collection.check_superset(&collection)?;
        info!("sigs in the original index: {}", self.collection.len());
        self.collection = Arc::new(collection);
        info!(
            "sigs in the new index once finished: {}",
            self.collection.len()
        );
        let processed = self.processed.clone();
        info!(
            "sigs left to process: {}",
            self.collection.len() - processed.read().unwrap().len()
        );
        let processed_sigs = AtomicUsize::new(0);
        // Same parallel mapping loop as in `create`.
        self.collection.par_iter().for_each(|(dataset_id, _)| {
            if !processed.read().unwrap().contains(&dataset_id) {
                let i = processed_sigs.fetch_add(1, Ordering::SeqCst);
                if i % 1000 == 0 {
                    info!("Processed {} reference sigs", i);
                }
                self.map_hashes_colors(dataset_id as Idx);
                processed.write().unwrap().extend([dataset_id]);
            }
        });
        self.save_collection().expect("Error saving collection");
        info!("Compact SSTs");
        self.compact();
        info!(
            "Processed additional {} reference sigs",
            processed_sigs.into_inner()
        );
        Ok(module::RevIndex::Disk(self))
    }
    /// Collect statistics for the HASHES column family.
    fn check(&self, quick: bool) -> DbStats {
        stats_for_cf(self.db.clone(), HASHES, true, quick)
    }
    /// Trigger a full range compaction on every column family.
    fn compact(&self) {
        for cf_name in ALL_CFS {
            let cf = self.db.cf_handle(cf_name).unwrap();
            self.db.compact_range_cf(&cf, None::<&[u8]>, None::<&[u8]>)
        }
    }
    /// Flush the WAL and the HASHES/METADATA memtables to disk.
    fn flush(&self) -> Result<()> {
        self.db.flush_wal(true)?;
        for cf_name in [HASHES, METADATA] {
            let cf = self.db.cf_handle(cf_name).unwrap();
            self.db.flush_cf(&cf)?;
        }
        Ok(())
    }
    /// The collection of signatures covered by this index.
    fn collection(&self) -> &CollectionSet {
        &self.collection
    }
    /// Copy all signature data from the current external storage into this
    /// RocksDB database and switch the collection to `rocksdb://` storage.
    /// No-op if the storage is already internal.
    fn internalize_storage(&mut self) -> Result<()> {
        // check if collection is already internal, if so return
        if self.collection.storage().spec() == "rocksdb://" {
            return Ok(());
        }
        // build new rocksdb storage from db
        let new_storage = RocksDBStorage::from_db(self.db.clone());
        // use manifest to copy from current storage to new one
        self.collection()
            .par_iter()
            .try_for_each(|(_, record)| -> Result<()> {
                let path = record.internal_location().as_str();
                let sig_data = self.collection.storage().load(path).unwrap();
                new_storage.save(path, &sig_data)?;
                Ok(())
            })?;
        // Swap the storage backend in place.
        // SAFETY: every record was just copied into `new_storage` above, so
        // the manifest remains consistent with the new backend.
        // NOTE(review): if other Arc clones of the collection exist,
        // `Arc::get_mut` returns None and the in-memory collection silently
        // keeps the old storage while the spec below is still rewritten —
        // confirm intended.
        unsafe {
            if let Some(v) = Arc::get_mut(&mut self.collection) {
                v.set_storage_unchecked(InnerStorage::new(new_storage))
            }
        }
        // write storage spec
        let cf_metadata = self.db.cf_handle(METADATA).unwrap();
        let spec = "rocksdb://";
        self.db.put_cf(&cf_metadata, STORAGE_SPEC, spec)?;
        Ok(())
    }
    /// Conversion to another index flavor; not implemented yet.
    fn convert(&self, _output_db: module::RevIndex) -> Result<()> {
        todo!()
    }
    /// Find all signatures whose Jaccard similarity with `query_mh` is at
    /// least `threshold`, using `counter_for_query` to pre-select
    /// candidates. Returns (similarity, signature, index location) triples.
    fn find_signatures(
        &self,
        query_mh: &KmerMinHash,
        threshold: f64,
        picklist: Option<DatasetPicklist>,
    ) -> Result<Vec<(f64, Signature, String)>> {
        let counter = self.counter_for_query(query_mh, picklist);
        let filename = self.location();
        let results: Vec<(f64, Signature, String)> = counter
            .most_common()
            .into_iter()
            .filter_map(|(dataset_id, _size)| {
                let sig: Signature = self
                    .collection()
                    .sig_for_dataset(dataset_id)
                    .expect("dataset not found")
                    .into();
                let match_mh = sig.minhash().expect("cannot retrieve match");
                // Downsample the match when scaled values differ so Jaccard
                // is computed on comparable sketches.
                let f_match = if match_mh.scaled() != query_mh.scaled() {
                    let match_ds = match_mh
                        .clone()
                        .downsample_scaled(query_mh.scaled())
                        .expect("cannot downsample");
                    query_mh
                        .jaccard(&match_ds)
                        .expect("cannot calculate Jaccard")
                } else {
                    query_mh
                        .jaccard(match_mh)
                        .expect("cannot calculate Jaccard")
                };
                if f_match >= threshold {
                    Some((f_match, sig, filename.to_owned()))
                } else {
                    None
                }
            })
            .collect();
        Ok(results)
    }
}