use nodedb_query::msgpack_scan;
use super::merge_join_docs_binary;
pub(super) fn hash_join_key(
doc: &[u8],
keys: &[&str],
state: &std::collections::hash_map::RandomState,
) -> (u64, Vec<(usize, usize)>) {
use std::hash::{BuildHasher, Hasher};
let mut hasher = state.build_hasher();
let mut ranges = Vec::with_capacity(keys.len());
for key in keys {
if let Some((start, end)) = extract_join_key_range(doc, key) {
hasher.write(&doc[start..end]);
ranges.push((start, end));
} else {
hasher.write_u8(0xc0); ranges.push((0, 0));
}
}
(hasher.finish(), ranges)
}
fn extract_join_key_range(doc: &[u8], key: &str) -> Option<(usize, usize)> {
msgpack_scan::extract_field(doc, 0, key).or_else(|| {
key.rsplit_once('.')
.and_then(|(_, field)| msgpack_scan::extract_field(doc, 0, field))
})
}
type BucketEntry = (usize, Vec<(usize, usize)>);
pub(super) struct HashIndex {
pub(super) buckets: std::collections::HashMap<u64, Vec<BucketEntry>>,
pub(super) state: std::collections::hash_map::RandomState,
}
impl HashIndex {
pub(super) fn build(docs: &[(String, Vec<u8>)], keys: &[&str]) -> Self {
let state = std::collections::hash_map::RandomState::new();
let mut buckets: std::collections::HashMap<u64, Vec<BucketEntry>> =
std::collections::HashMap::with_capacity(docs.len());
for (i, (_, value)) in docs.iter().enumerate() {
let (hash, ranges) = hash_join_key(value, keys, &state);
buckets.entry(hash).or_default().push((i, ranges));
}
Self { buckets, state }
}
pub(super) fn probe(
&self,
probe_doc: &[u8],
probe_keys: &[&str],
build_docs: &[(String, Vec<u8>)],
) -> (u64, Vec<(usize, usize)>, Vec<usize>) {
let (hash, probe_ranges) = hash_join_key(probe_doc, probe_keys, &self.state);
let mut matched = Vec::new();
if let Some(bucket) = self.buckets.get(&hash) {
for (doc_idx, idx_ranges) in bucket {
let mut all_match = !probe_ranges.is_empty();
for (i, &(ps, pe)) in probe_ranges.iter().enumerate() {
if let Some(&(bs, be)) = idx_ranges.get(i) {
let build_doc = &build_docs[*doc_idx].1;
if pe - ps != be - bs || probe_doc[ps..pe] != build_doc[bs..be] {
all_match = false;
break;
}
} else {
all_match = false;
break;
}
}
if all_match {
matched.push(*doc_idx);
}
}
}
(hash, probe_ranges, matched)
}
}
pub(super) struct ProbeParams<'a> {
pub(super) probe_docs: &'a [(String, Vec<u8>)],
pub(super) index: &'a HashIndex,
pub(super) index_docs: &'a [(String, Vec<u8>)],
pub(super) probe_keys: &'a [&'a str],
pub(super) join_type: &'a str,
pub(super) limit: usize,
pub(super) probe_collection: &'a str,
pub(super) index_collection: &'a str,
pub(super) emit_unmatched_right: bool,
}
pub(super) fn probe_hash_index(p: &ProbeParams<'_>) -> Vec<Vec<u8>> {
let is_left = p.join_type == "left" || p.join_type == "full";
let is_right = p.join_type == "right" || p.join_type == "full";
let is_semi = p.join_type == "semi";
let is_anti = p.join_type == "anti";
let is_cross = p.join_type == "cross";
if is_cross {
let mut results = Vec::new();
for (_, left_val) in p.probe_docs {
for (_, right_val) in p.index_docs {
if results.len() >= p.limit {
return results;
}
results.push(merge_join_docs_binary(
left_val,
Some(right_val),
p.probe_collection,
p.index_collection,
));
}
}
return results;
}
let mut index_matched: Vec<bool> = if is_right {
vec![false; p.index_docs.len()]
} else {
Vec::new()
};
let mut results = Vec::new();
for (_, value) in p.probe_docs {
if !is_right && results.len() >= p.limit {
break;
}
let (_, _, matched_indices) = p.index.probe(value, p.probe_keys, p.index_docs);
if !matched_indices.is_empty() {
if is_semi {
if results.len() < p.limit {
results.push(merge_join_docs_binary(value, None, p.probe_collection, ""));
}
} else if is_anti {
} else {
for &mi in &matched_indices {
if is_right {
index_matched[mi] = true;
}
if results.len() < p.limit {
results.push(merge_join_docs_binary(
value,
Some(&p.index_docs[mi].1),
p.probe_collection,
p.index_collection,
));
}
}
}
} else if is_anti && results.len() < p.limit {
results.push(merge_join_docs_binary(value, None, p.probe_collection, ""));
} else if is_left && results.len() < p.limit {
results.push(merge_join_docs_binary(
value,
None,
p.probe_collection,
p.index_collection,
));
}
}
if is_right && p.emit_unmatched_right {
for (i, (_, bytes)) in p.index_docs.iter().enumerate() {
if results.len() >= p.limit {
break;
}
if !index_matched[i] {
results.push(merge_join_docs_binary(
&[],
Some(bytes),
"",
p.index_collection,
));
}
}
}
results
}