use anyhow::Result;
use rayon::prelude::*;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::mpsc;
use std::time::Instant;
use crate::constants::{
COO_MERGE_JOIN_MAX_BUCKETS, DENSE_ACCUMULATOR_MAX_BUCKETS, ESTIMATED_MINIMIZERS_PER_SEQUENCE,
};
use crate::core::extraction::get_paired_minimizers_into;
use crate::core::workspace::MinimizerWorkspace;
use crate::indices::sharded::{ShardManifest, ShardedInvertedIndex};
use crate::indices::{InvertedIndex, QueryInvertedIndex};
use crate::types::{HitResult, QueryRecord};
use crate::log_timing;
use super::common::{collect_negative_minimizers_sharded, filter_negative_mins};
use super::merge_join::{
merge_join_coo_parallel, merge_join_csr, merge_join_pairs_sparse, DenseAccumulator,
HitAccumulator, SparseAccumulator, SparseHit,
};
/// One shard's worth of index data as handed from the loader thread to the
/// merge-join consumer in `classify_shard_loop`.
enum LoadedShard {
/// Result of `load_shard_coo_for_query`: a flat list of `(u64, u32)` pairs
/// (presumably `(minimizer, bucket_id)`, sorted — TODO confirm against the loader).
Coo(Vec<(u64, u32)>),
/// Result of `load_shard_for_query`: the shard as an `InvertedIndex`.
Csr(InvertedIndex),
}
/// Heuristic capacity hint for per-record minimizer buffers.
///
/// Only the first record is inspected: its total length (read 1 plus the
/// optional second read) is turned into `((len - k) / w + 1) * 2`. The result
/// is never smaller than `ESTIMATED_MINIMIZERS_PER_SEQUENCE`, which is also
/// returned for an empty batch or when the first record is not longer than `k`.
fn estimate_minimizers_from_records(records: &[QueryRecord], k: usize, w: usize) -> usize {
    let query_len = match records.first() {
        Some((_, read1, read2)) => read1.len() + read2.map_or(0, |mate| mate.len()),
        None => return ESTIMATED_MINIMIZERS_PER_SEQUENCE,
    };
    if query_len <= k {
        return ESTIMATED_MINIMIZERS_PER_SEQUENCE;
    }
    let windows = (query_len - k) / w + 1;
    (windows * 2).max(ESTIMATED_MINIMIZERS_PER_SEQUENCE)
}
/// Extracts the paired minimizer lists for every record of a batch in parallel.
///
/// Each record is run through `get_paired_minimizers_into` with the given
/// `k`, `w` and `salt`, then passed through `filter_negative_mins` together
/// with `negative_mins`. Workspaces are created by rayon's `map_init` and
/// reused across the records handed to the same worker split; their initial
/// capacity comes from `estimate_minimizers_from_records`.
///
/// Returns one `(Vec<u64>, Vec<u64>)` per input record, in input order.
/// An empty batch yields an empty vector without touching the thread pool.
pub fn extract_batch_minimizers(
    k: usize,
    w: usize,
    salt: u64,
    negative_mins: Option<&HashSet<u64>>,
    records: &[QueryRecord],
) -> Vec<(Vec<u64>, Vec<u64>)> {
    if records.is_empty() {
        return Vec::new();
    }
    let capacity_hint = estimate_minimizers_from_records(records, k, w);
    let new_workspace = || MinimizerWorkspace::with_estimate(capacity_hint);
    records
        .par_iter()
        .map_init(new_workspace, |workspace, record| {
            let (_, read1, read2) = record;
            let (first, second) = get_paired_minimizers_into(read1, *read2, k, w, salt, workspace);
            filter_negative_mins(first, second, negative_mins)
        })
        .collect()
}
/// Largest bucket id present in the manifest's `bucket_names` map, or `0`
/// when the map is empty.
fn max_bucket_id_from_manifest(manifest: &crate::indices::sharded::ShardManifest) -> u32 {
    let largest = manifest.bucket_names.keys().copied().max();
    largest.unwrap_or(0)
}
/// Decides whether the dense accumulator should be used for this index.
///
/// True only when the largest bucket id is non-zero and does not exceed
/// `DENSE_ACCUMULATOR_MAX_BUCKETS`.
fn use_dense_accumulator(manifest: &crate::indices::sharded::ShardManifest) -> bool {
    let largest_id = max_bucket_id_from_manifest(manifest);
    if largest_id == 0 {
        return false;
    }
    (largest_id as usize) <= DENSE_ACCUMULATOR_MAX_BUCKETS
}
/// Decides whether shards should be loaded in COO form: true when the number
/// of named buckets is at most `COO_MERGE_JOIN_MAX_BUCKETS`.
fn use_coo_merge_join(manifest: &crate::indices::sharded::ShardManifest) -> bool {
    let bucket_count = manifest.bucket_names.len();
    bucket_count <= COO_MERGE_JOIN_MAX_BUCKETS
}
/// Returns the pairs of `loaded` that do not occur in `seen`.
///
/// Both slices are walked front to back with a single forward-only cursor into
/// `seen`, so the result is only meaningful when both inputs are sorted
/// ascending. Runs of identical consecutive pairs in `loaded` are collapsed to
/// a single candidate before the lookup.
fn filter_unseen(loaded: &[(u64, u32)], seen: &[(u64, u32)]) -> Vec<(u64, u32)> {
    let mut unseen = Vec::new();
    let mut cursor = 0;
    let mut last: Option<(u64, u32)> = None;
    for candidate in loaded.iter().copied() {
        // `replace` hands back the previous candidate; equal means an in-shard duplicate.
        if last.replace(candidate) == Some(candidate) {
            continue;
        }
        // Skip every already-seen pair that sorts strictly before the candidate.
        cursor += seen[cursor..]
            .iter()
            .take_while(|&&earlier| earlier < candidate)
            .count();
        let already_seen = seen.get(cursor) == Some(&candidate);
        if !already_seen {
            unseen.push(candidate);
        }
    }
    unseen
}
/// Merges two ascending-sorted slices into `out`, keeping duplicates.
///
/// `out` is cleared first and its existing allocation is reused; capacity for
/// `a.len() + b.len()` elements is reserved up front. When the heads compare
/// equal, the element from `a` is emitted first.
fn merge_sorted_into(a: &[(u64, u32)], b: &[(u64, u32)], out: &mut Vec<(u64, u32)>) {
    out.clear();
    out.reserve(a.len() + b.len());
    let (mut left, mut right) = (a, b);
    while let (Some((&head_l, tail_l)), Some((&head_r, tail_r))) =
        (left.split_first(), right.split_first())
    {
        if head_l <= head_r {
            out.push(head_l);
            left = tail_l;
        } else {
            out.push(head_r);
            right = tail_r;
        }
    }
    // At most one of the two remainders is non-empty here.
    out.extend_from_slice(left);
    out.extend_from_slice(right);
}
/// Streams every shard of `sharded` through a merge-join against `query_idx`
/// and scores the accumulated hits.
///
/// A scoped loader thread loads shards one at a time (COO pairs when `use_coo`
/// is set, otherwise an `InvertedIndex`) and hands them over a bounded channel
/// of capacity 1, so loading shard N+1 overlaps with joining shard N.
///
/// When the manifest reports overlapping shards (and there is more than one
/// shard), COO pairs already joined from earlier shards are filtered out so
/// each pair is counted once.
///
/// # Errors
/// Returns the first shard-load error. The loader stops after the first failed
/// load instead of loading further shards whose results would be discarded.
///
/// # Panics
/// Panics if the loader thread itself panicked.
fn classify_shard_loop<A: HitAccumulator>(
    sharded: &ShardedInvertedIndex,
    query_idx: &QueryInvertedIndex,
    query_ids: &[i64],
    threshold: f64,
    read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
    mut accumulator: A,
    use_coo: bool,
) -> Result<Vec<HitResult>> {
    let t_start = Instant::now();
    let manifest = sharded.manifest();
    let mut total_shard_load_ms = 0u128;
    let mut total_merge_join_ms = 0u128;
    let query_minimizers = query_idx.unique_minimizers();
    let needs_dedup = manifest.has_overlapping_shards && manifest.shards.len() > 1;
    let load_result: Result<()> = std::thread::scope(|scope| {
        // Capacity 1: the loader may run at most one shard ahead of the consumer.
        let (tx, rx) = mpsc::sync_channel::<Result<(LoadedShard, u128)>>(1);
        let query_mins_ref = &query_minimizers;
        let loader = scope.spawn(move || {
            for shard_info in &manifest.shards {
                let t_load = Instant::now();
                let loaded: Result<LoadedShard> = if use_coo {
                    sharded
                        .load_shard_coo_for_query(shard_info.shard_id, query_mins_ref, read_options)
                        .map(LoadedShard::Coo)
                        .map_err(Into::into)
                } else {
                    sharded
                        .load_shard_for_query(shard_info.shard_id, query_mins_ref, read_options)
                        .map(LoadedShard::Csr)
                        .map_err(Into::into)
                };
                let load_ms = t_load.elapsed().as_millis();
                let load_failed = loaded.is_err();
                // Stop when the consumer has gone away (send fails) or when this
                // load failed: the consumer bails out on the first error, so any
                // further shard load would only delay the error return.
                if tx.send(loaded.map(|s| (s, load_ms))).is_err() || load_failed {
                    break;
                }
            }
        });
        // `seen` is the merged list of all COO pairs joined so far; only
        // maintained when `needs_dedup` is set. Assumes each shard's pairs
        // arrive sorted — TODO confirm against `load_shard_coo_for_query`.
        let mut seen: Vec<(u64, u32)> = Vec::new();
        let mut merge_buf: Vec<(u64, u32)> = Vec::new();
        for received in rx {
            // An early return drops `rx`, which makes the loader's next send fail.
            let (shard, load_ms) = received?;
            total_shard_load_ms += load_ms;
            let t_merge = Instant::now();
            match shard {
                LoadedShard::Coo(ref pairs) => {
                    if needs_dedup {
                        let filtered = filter_unseen(pairs, &seen);
                        // Merge into the spare buffer, then swap so both
                        // allocations are reused on the next shard.
                        merge_sorted_into(&seen, &filtered, &mut merge_buf);
                        std::mem::swap(&mut seen, &mut merge_buf);
                        merge_join_coo_parallel(query_idx, &filtered, &mut accumulator);
                    } else {
                        merge_join_coo_parallel(query_idx, pairs, &mut accumulator);
                    }
                }
                LoadedShard::Csr(ref idx) => {
                    // NOTE(review): no cross-shard de-duplication happens on this
                    // path even when `needs_dedup` is set — confirm that CSR loading
                    // is never combined with overlapping shards.
                    merge_join_csr(query_idx, idx, &mut accumulator, &query_minimizers);
                }
            }
            total_merge_join_ms += t_merge.elapsed().as_millis();
        }
        loader.join().expect("shard loader thread panicked");
        Ok(())
    });
    load_result?;
    log_timing("merge_join: shard_load_total", total_shard_load_ms);
    log_timing("merge_join: merge_join_total", total_merge_join_ms);
    let t_score = Instant::now();
    let results = accumulator.score_and_filter(query_idx, query_ids, threshold);
    log_timing("merge_join: scoring", t_score.elapsed().as_millis());
    log_timing("merge_join: total", t_start.elapsed().as_millis());
    Ok(results)
}
/// Classifies a pre-built query index against a sharded inverted index using
/// the sequential shard loop.
///
/// Picks the accumulator (dense vs. sparse) and the shard load format
/// (COO vs. CSR) from the manifest, then delegates to `classify_shard_loop`.
/// Returns an empty result immediately when the query index holds no reads.
///
/// # Errors
/// Propagates any error from `classify_shard_loop`.
pub fn classify_from_query_index(
    sharded: &ShardedInvertedIndex,
    query_idx: &QueryInvertedIndex,
    query_ids: &[i64],
    threshold: f64,
    read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
) -> Result<Vec<HitResult>> {
    let read_count = query_idx.num_reads();
    if read_count == 0 {
        return Ok(Vec::new());
    }
    let manifest = sharded.manifest();
    let coo_path = use_coo_merge_join(manifest);
    if !use_dense_accumulator(manifest) {
        return classify_shard_loop(
            sharded,
            query_idx,
            query_ids,
            threshold,
            read_options,
            SparseAccumulator::new(read_count),
            coo_path,
        );
    }
    let largest_bucket_id = max_bucket_id_from_manifest(manifest);
    classify_shard_loop(
        sharded,
        query_idx,
        query_ids,
        threshold,
        read_options,
        DenseAccumulator::new(read_count, largest_bucket_id),
        coo_path,
    )
}
/// Builds a `QueryInvertedIndex` from already-extracted minimizers and
/// classifies it with `classify_from_query_index`.
///
/// The index build time is reported through `log_timing`. An empty
/// `extracted` slice short-circuits to an empty result.
///
/// # Errors
/// Propagates any error from `classify_from_query_index`.
pub fn classify_from_extracted_minimizers(
    sharded: &ShardedInvertedIndex,
    extracted: &[(Vec<u64>, Vec<u64>)],
    query_ids: &[i64],
    threshold: f64,
    read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
) -> Result<Vec<HitResult>> {
    if extracted.is_empty() {
        return Ok(Vec::new());
    }
    let build_timer = Instant::now();
    let query_idx = QueryInvertedIndex::build(extracted);
    let build_ms = build_timer.elapsed().as_millis();
    log_timing("merge_join: build_query_index", build_ms);
    classify_from_query_index(sharded, &query_idx, query_ids, threshold, read_options)
}
/// Classifies a batch of query records against a sharded index via the
/// merge-join path.
///
/// Minimizers are extracted with the manifest's `k`, `w` and `salt`
/// (optionally filtered by `negative_mins`), the record ids are collected in
/// input order, and both are handed to `classify_from_extracted_minimizers`.
/// An empty batch returns an empty result without doing any work.
///
/// # Errors
/// Propagates any error from `classify_from_extracted_minimizers`.
pub fn classify_batch_sharded_merge_join(
    sharded: &ShardedInvertedIndex,
    negative_mins: Option<&HashSet<u64>>,
    records: &[QueryRecord],
    threshold: f64,
    read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
) -> Result<Vec<HitResult>> {
    if records.is_empty() {
        return Ok(Vec::new());
    }
    let manifest = sharded.manifest();
    let (k, w, salt) = (manifest.k, manifest.w, manifest.salt);
    let extract_timer = Instant::now();
    let extracted = extract_batch_minimizers(k, w, salt, negative_mins, records);
    let extract_ms = extract_timer.elapsed().as_millis();
    log_timing("merge_join: extraction", extract_ms);
    let mut query_ids: Vec<i64> = Vec::with_capacity(records.len());
    query_ids.extend(records.iter().map(|(id, _, _)| *id));
    classify_from_extracted_minimizers(sharded, &extracted, &query_ids, threshold, read_options)
}
/// Processes the selected row groups in parallel and scores the result.
///
/// Two strategies are used depending on `num_reads`:
/// * up to `FOLD_REDUCE_MAX_READS` reads: rayon `try_fold` / `try_reduce`,
///   where each fold builds its own accumulator via `make_acc` and the
///   partial accumulators are merged pairwise;
/// * above that: each row group only produces its list of sparse hits in
///   parallel, and a single accumulator is filled sequentially afterwards.
///
/// `total_rg_count` and `filtered_rg_count` are only reported through
/// `log_timing`; `t_start` is the caller's start instant used for the
/// overall timing line.
///
/// # Errors
/// Returns the first error produced by `load_row_group_pairs`.
#[allow(clippy::too_many_arguments)]
fn parallel_rg_inner<A>(
    work_items: Vec<(PathBuf, usize)>,
    query_idx: &QueryInvertedIndex,
    query_ids: &[i64],
    query_minimizers: &[u64],
    threshold: f64,
    num_reads: usize,
    total_rg_count: usize,
    filtered_rg_count: usize,
    t_start: Instant,
    make_acc: impl Fn() -> A + Send + Sync,
) -> Result<Vec<HitResult>>
where
    A: HitAccumulator,
{
    use crate::indices::load_row_group_pairs;
    const FOLD_REDUCE_MAX_READS: usize = 500_000;
    let rg_timer = Instant::now();
    let fold_reduce = num_reads <= FOLD_REDUCE_MAX_READS;
    let merged = if fold_reduce {
        work_items
            .into_par_iter()
            .try_fold(&make_acc, |mut local, (shard_path, rg_idx)| -> Result<A> {
                let pairs = load_row_group_pairs(&shard_path, rg_idx, query_minimizers)?;
                if pairs.is_empty() {
                    return Ok(local);
                }
                for (read_idx, bucket_id, fwd, rc) in merge_join_pairs_sparse(query_idx, &pairs) {
                    local.record_hit_counts(read_idx as usize, bucket_id, fwd, rc);
                }
                Ok(local)
            })
            .try_reduce(&make_acc, |mut left, right| {
                left.merge(right);
                Ok(left)
            })?
    } else {
        let per_rg_hits: Result<Vec<Vec<SparseHit>>> = work_items
            .into_par_iter()
            .map(|(shard_path, rg_idx)| {
                let pairs = load_row_group_pairs(&shard_path, rg_idx, query_minimizers)?;
                Ok(merge_join_pairs_sparse(query_idx, &pairs))
            })
            .collect();
        let mut single = make_acc();
        // Flattening skips row groups that produced no hits.
        for (read_idx, bucket_id, fwd, rc) in per_rg_hits?.into_iter().flatten() {
            single.record_hit_counts(read_idx as usize, bucket_id, fwd, rc);
        }
        single
    };
    let rg_ms = rg_timer.elapsed().as_millis();
    log_timing("parallel_rg: rg_process_total", rg_ms);
    log_timing("parallel_rg: total_rg_count", total_rg_count as u128);
    log_timing("parallel_rg: filtered_rg_count", filtered_rg_count as u128);
    let score_timer = Instant::now();
    let hits = merged.score_and_filter(query_idx, query_ids, threshold);
    log_timing("parallel_rg: scoring", score_timer.elapsed().as_millis());
    log_timing("parallel_rg: total", t_start.elapsed().as_millis());
    Ok(hits)
}
/// Classifies a pre-built query index by processing Parquet row groups in
/// parallel.
///
/// Row groups whose `[min, max]` range does not intersect the query's
/// minimizer range are skipped. The surviving `(shard path, row group)` pairs
/// are handed to `parallel_rg_inner` with a dense or sparse accumulator
/// factory chosen from the manifest.
///
/// Falls back to `classify_from_query_index` (forwarding `_read_options`)
/// when the manifest reports overlapping shards and has more than one shard.
/// Returns an empty result when the query index has no reads or no minimizer
/// range.
///
/// # Errors
/// Propagates errors from `get_row_group_ranges`, `parallel_rg_inner`, or the
/// fallback path.
pub fn classify_from_query_index_parallel_rg(
    sharded: &ShardedInvertedIndex,
    query_idx: &QueryInvertedIndex,
    query_ids: &[i64],
    threshold: f64,
    _read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
) -> Result<Vec<HitResult>> {
    use crate::indices::get_row_group_ranges;
    let t_start = Instant::now();
    let num_reads = query_idx.num_reads();
    if num_reads == 0 {
        return Ok(Vec::new());
    }
    let manifest = sharded.manifest();
    let overlapping = manifest.has_overlapping_shards && manifest.shards.len() > 1;
    if overlapping {
        return classify_from_query_index(sharded, query_idx, query_ids, threshold, _read_options);
    }
    let query_minimizers = query_idx.unique_minimizers();
    let (query_min, query_max) = match query_idx.minimizer_range() {
        None => return Ok(Vec::new()),
        Some(bounds) => bounds,
    };
    let mut total_rg_count = 0usize;
    let mut work_items: Vec<(PathBuf, usize)> = Vec::new();
    let filter_timer = Instant::now();
    let use_cache = sharded.has_rg_cache();
    for (shard_pos, shard_info) in manifest.shards.iter().enumerate() {
        let shard_path =
            ShardManifest::shard_path_parquet(sharded.base_path(), shard_info.shard_id);
        let rg_ranges = if use_cache {
            // NOTE(review): a shard with no cache entry contributes no row
            // groups here — confirm the cache always covers every shard.
            sharded
                .rg_ranges(shard_pos)
                .map(|cached| cached.to_vec())
                .unwrap_or_default()
        } else {
            get_row_group_ranges(&shard_path)?
        };
        total_rg_count += rg_ranges.len();
        // Keep only row groups whose range intersects the query's range.
        work_items.extend(
            rg_ranges
                .into_iter()
                .filter(|info| info.max >= query_min && info.min <= query_max)
                .map(|info| (shard_path.clone(), info.rg_idx)),
        );
    }
    let filtered_rg_count = work_items.len();
    log_timing("parallel_rg: rg_filter", filter_timer.elapsed().as_millis());
    let max_id = max_bucket_id_from_manifest(manifest);
    if !use_dense_accumulator(manifest) {
        return parallel_rg_inner(
            work_items,
            query_idx,
            query_ids,
            &query_minimizers,
            threshold,
            num_reads,
            total_rg_count,
            filtered_rg_count,
            t_start,
            || SparseAccumulator::new(num_reads),
        );
    }
    parallel_rg_inner(
        work_items,
        query_idx,
        query_ids,
        &query_minimizers,
        threshold,
        num_reads,
        total_rg_count,
        filtered_rg_count,
        t_start,
        || DenseAccumulator::new(num_reads, max_id),
    )
}
/// Builds a `QueryInvertedIndex` from already-extracted minimizers and
/// classifies it with `classify_from_query_index_parallel_rg`.
///
/// The index build time is reported through `log_timing`. An empty
/// `extracted` slice short-circuits to an empty result.
///
/// # Errors
/// Propagates any error from `classify_from_query_index_parallel_rg`.
pub fn classify_from_extracted_minimizers_parallel_rg(
    sharded: &ShardedInvertedIndex,
    extracted: &[(Vec<u64>, Vec<u64>)],
    query_ids: &[i64],
    threshold: f64,
    _read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
) -> Result<Vec<HitResult>> {
    if extracted.is_empty() {
        return Ok(Vec::new());
    }
    let build_timer = Instant::now();
    let query_idx = QueryInvertedIndex::build(extracted);
    let build_ms = build_timer.elapsed().as_millis();
    log_timing("parallel_rg: build_query_index", build_ms);
    classify_from_query_index_parallel_rg(sharded, &query_idx, query_ids, threshold, _read_options)
}
/// Classifies a batch of query records against a sharded index via the
/// parallel row-group path.
///
/// Minimizers are extracted with the manifest's `k`, `w` and `salt`
/// (optionally filtered by `negative_mins`), the record ids are collected in
/// input order, and both are handed to
/// `classify_from_extracted_minimizers_parallel_rg`. An empty batch returns an
/// empty result without doing any work.
///
/// # Errors
/// Propagates any error from `classify_from_extracted_minimizers_parallel_rg`.
pub fn classify_batch_sharded_parallel_rg(
    sharded: &ShardedInvertedIndex,
    negative_mins: Option<&HashSet<u64>>,
    records: &[QueryRecord],
    threshold: f64,
    _read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
) -> Result<Vec<HitResult>> {
    if records.is_empty() {
        return Ok(Vec::new());
    }
    let manifest = sharded.manifest();
    let (k, w, salt) = (manifest.k, manifest.w, manifest.salt);
    let extract_timer = Instant::now();
    let extracted = extract_batch_minimizers(k, w, salt, negative_mins, records);
    let extract_ms = extract_timer.elapsed().as_millis();
    log_timing("parallel_rg: extraction", extract_ms);
    let mut query_ids: Vec<i64> = Vec::with_capacity(records.len());
    query_ids.extend(records.iter().map(|(id, _, _)| *id));
    classify_from_extracted_minimizers_parallel_rg(
        sharded,
        &extracted,
        &query_ids,
        threshold,
        _read_options,
    )
}
/// Classifies a batch against `positive_index`, optionally removing minimizers
/// that occur in a sharded `negative_index` first.
///
/// Without a negative index this is exactly
/// `classify_batch_sharded_merge_join` with no negative set. With one:
/// 1. minimizers are extracted unfiltered using the positive manifest's
///    `k`, `w` and `salt`;
/// 2. the sorted, de-duplicated union of all query minimizers is looked up in
///    the negative index via `collect_negative_minimizers_sharded`;
/// 3. every record's minimizers are filtered against the resulting set and
///    classified with `classify_from_extracted_minimizers`.
///
/// An empty batch returns an empty result without touching either index.
///
/// NOTE(review): nothing here checks that the negative index was built with
/// the same `k`/`w`/`salt` as the positive one — confirm the caller or
/// `collect_negative_minimizers_sharded` guarantees that.
///
/// # Errors
/// Propagates errors from the negative lookup and from classification.
pub fn classify_with_sharded_negative(
    positive_index: &ShardedInvertedIndex,
    negative_index: Option<&ShardedInvertedIndex>,
    records: &[QueryRecord],
    threshold: f64,
    read_options: Option<&crate::indices::parquet::ParquetReadOptions>,
) -> Result<Vec<HitResult>> {
    let negative = match negative_index {
        Some(index) => index,
        None => {
            return classify_batch_sharded_merge_join(
                positive_index,
                None,
                records,
                threshold,
                read_options,
            );
        }
    };
    // Same guard as the other batch entry points; avoids querying the negative
    // index for a batch that cannot produce hits.
    if records.is_empty() {
        return Ok(Vec::new());
    }
    let manifest = positive_index.manifest();
    let extracted = extract_batch_minimizers(manifest.k, manifest.w, manifest.salt, None, records);
    let mut all_minimizers: Vec<u64> = extracted
        .iter()
        .flat_map(|(fwd, rc)| fwd.iter().chain(rc.iter()).copied())
        .collect();
    all_minimizers.sort_unstable();
    all_minimizers.dedup();
    let negative_set =
        collect_negative_minimizers_sharded(negative, &all_minimizers, read_options)?;
    let filtered: Vec<(Vec<u64>, Vec<u64>)> = extracted
        .into_iter()
        .map(|(fwd, rc)| filter_negative_mins(fwd, rc, Some(&negative_set)))
        .collect();
    let query_ids: Vec<i64> = records.iter().map(|(id, _, _)| *id).collect();
    classify_from_extracted_minimizers(
        positive_index,
        &filtered,
        &query_ids,
        threshold,
        read_options,
    )
}
// Unit tests for the merge-join and parallel row-group classification paths,
// plus the `filter_unseen` / `merge_sorted_into` helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::{create_parquet_inverted_index, extract_into, BucketData, ParquetWriteOptions};
use tempfile::tempdir;
// Builds a deterministic `len`-byte sequence that cycles through A, C, G, T,
// with `seed` shifting the starting base.
fn generate_sequence(len: usize, seed: u8) -> Vec<u8> {
let bases = [b'A', b'C', b'G', b'T'];
(0..len).map(|i| bases[(i + seed as usize) % 4]).collect()
}
// Creates a two-bucket index (k = 32, w = 10) in a fresh temp directory.
// Bucket 1 holds the minimizers of the first generated sequence, bucket 2
// those of the second. Returns the temp dir handle, the opened index and
// both sequences.
fn create_test_index() -> (tempfile::TempDir, ShardedInvertedIndex, Vec<Vec<u8>>) {
let dir = tempdir().unwrap();
let index_path = dir.path().join("test.ryxdi");
let k = 32;
let w = 10;
let salt = 0x12345u64;
let mut ws = MinimizerWorkspace::new();
let seq1 = generate_sequence(200, 0);
let seq2 = generate_sequence(200, 1);
extract_into(&seq1, k, w, salt, &mut ws);
let mut mins1: Vec<u64> = ws.buffer.drain(..).collect();
mins1.sort();
mins1.dedup();
extract_into(&seq2, k, w, salt, &mut ws);
let mut mins2: Vec<u64> = ws.buffer.drain(..).collect();
mins2.sort();
mins2.dedup();
let buckets = vec![
BucketData {
bucket_id: 1,
bucket_name: "Bucket1".to_string(),
sources: vec!["seq1".to_string()],
minimizers: mins1,
},
BucketData {
bucket_id: 2,
bucket_name: "Bucket2".to_string(),
sources: vec!["seq2".to_string()],
minimizers: mins2,
},
];
let options = ParquetWriteOptions::default();
create_parquet_inverted_index(&index_path, buckets, k, w, salt, None, Some(&options), None)
.unwrap();
let index = ShardedInvertedIndex::open(&index_path).unwrap();
(dir, index, vec![seq1, seq2])
}
// An empty batch must yield no results.
#[test]
fn test_merge_join_empty_records() {
let (_dir, index, _seqs) = create_test_index();
let records: Vec<QueryRecord> = vec![];
let results = classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
assert!(
results.is_empty(),
"Empty records should produce empty results"
);
}
// Querying with the sequence that built bucket 1 must hit bucket 1 with score > 0.9.
#[test]
fn test_merge_join_self_match() {
let (_dir, index, seqs) = create_test_index();
let records: Vec<QueryRecord> = vec![(1, seqs[0].as_slice(), None)];
let results = classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
let bucket1_hit = results.iter().find(|r| r.query_id == 1 && r.bucket_id == 1);
assert!(bucket1_hit.is_some(), "Should have self-match for bucket 1");
assert!(
bucket1_hit.unwrap().score > 0.9,
"Self-match score should be >0.9, got {}",
bucket1_hit.unwrap().score
);
}
// Two queries in one batch: each must self-match its own bucket.
#[test]
fn test_merge_join_multiple_queries() {
let (_dir, index, seqs) = create_test_index();
let records: Vec<QueryRecord> =
vec![(1, seqs[0].as_slice(), None), (2, seqs[1].as_slice(), None)];
let results = classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
let q1_b1 = results.iter().find(|r| r.query_id == 1 && r.bucket_id == 1);
let q2_b2 = results.iter().find(|r| r.query_id == 2 && r.bucket_id == 2);
assert!(q1_b1.is_some(), "Query 1 should match bucket 1");
assert!(q2_b2.is_some(), "Query 2 should match bucket 2");
assert!(
q1_b1.unwrap().score > 0.9,
"Self-match should have high score"
);
assert!(
q2_b2.unwrap().score > 0.9,
"Self-match should have high score"
);
}
// Raising the threshold must never add results; a threshold above 1.0 removes all.
#[test]
fn test_merge_join_threshold_filtering() {
let (_dir, index, seqs) = create_test_index();
let records: Vec<QueryRecord> = vec![(1, seqs[0].as_slice(), None)];
let low_results =
classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
let mid_results =
classify_batch_sharded_merge_join(&index, None, &records, 0.5, None).unwrap();
let high_results =
classify_batch_sharded_merge_join(&index, None, &records, 1.01, None).unwrap();
assert!(!low_results.is_empty(), "Low threshold should have results");
assert!(
mid_results.len() <= low_results.len(),
"Higher threshold should have fewer or equal results"
);
assert!(
high_results.is_empty(),
"Threshold > 1.0 should filter all results"
);
}
// An 8-base query is shorter than k = 32 and is expected to produce no hits.
#[test]
fn test_merge_join_short_sequence() {
let (_dir, index, _seqs) = create_test_index();
let short_seq = b"ACGTACGT"; let records: Vec<QueryRecord> = vec![(1, short_seq.as_slice(), None)];
let results = classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
assert!(results.is_empty(), "Short sequence should have no hits");
}
// Passing every minimizer of the query as the negative set must remove all hits.
#[test]
fn test_merge_join_with_negative_mins() {
let (_dir, index, seqs) = create_test_index();
let records: Vec<QueryRecord> = vec![(1, seqs[0].as_slice(), None)];
let results_without =
classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
assert!(
!results_without.is_empty(),
"Should have results without filtering"
);
let k = index.k();
let w = index.w();
let salt = index.salt();
let mut ws = MinimizerWorkspace::new();
extract_into(&seqs[0], k, w, salt, &mut ws);
let negative_mins: HashSet<u64> = ws.buffer.drain(..).collect();
let results_with =
classify_batch_sharded_merge_join(&index, Some(&negative_mins), &records, 0.1, None)
.unwrap();
assert!(
results_with.is_empty(),
"Filtering all minimizers should produce no hits"
);
}
// A paired record (read 1 from bucket 1's sequence, read 2 from bucket 2's)
// must hit both buckets.
#[test]
fn test_merge_join_paired_end() {
let (_dir, index, seqs) = create_test_index();
let records: Vec<QueryRecord> = vec![(1, seqs[0].as_slice(), Some(seqs[1].as_slice()))];
let results = classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
let b1_hit = results.iter().find(|r| r.bucket_id == 1);
let b2_hit = results.iter().find(|r| r.bucket_id == 2);
assert!(
b1_hit.is_some(),
"Should have hit for bucket 1 (from read1)"
);
assert!(
b2_hit.is_some(),
"Should have hit for bucket 2 (from read2)"
);
}
// With no records the estimate falls back to the default constant.
#[test]
fn test_estimate_minimizers_from_records_empty() {
let records: Vec<QueryRecord> = vec![];
let estimate = estimate_minimizers_from_records(&records, 32, 10);
assert_eq!(estimate, ESTIMATED_MINIMIZERS_PER_SEQUENCE);
}
// A record not longer than k also falls back to the default constant.
#[test]
fn test_estimate_minimizers_from_records_short_sequence() {
let short_seq = b"ACGT"; let records: Vec<QueryRecord> = vec![(1, short_seq.as_slice(), None)];
let estimate = estimate_minimizers_from_records(&records, 32, 10);
assert_eq!(estimate, ESTIMATED_MINIMIZERS_PER_SEQUENCE);
}
// For a 200-base record with k = 32, w = 10 the estimate must land in 30..=50.
#[test]
fn test_estimate_minimizers_from_records_long_sequence() {
let long_seq = generate_sequence(200, 0);
let records: Vec<QueryRecord> = vec![(1, long_seq.as_slice(), None)];
let estimate = estimate_minimizers_from_records(&records, 32, 10);
assert!(
estimate >= 30 && estimate <= 50,
"Estimate should be reasonable"
);
}
// Writes an index at `path` from `(bucket_id, bucket_name, minimizers)`
// triples and opens it.
fn create_test_index_at_path(
path: &std::path::Path,
bucket_data: Vec<(u32, &str, Vec<u64>)>,
k: usize,
w: usize,
salt: u64,
) -> ShardedInvertedIndex {
let buckets: Vec<BucketData> = bucket_data
.into_iter()
.map(|(id, name, mins)| BucketData {
bucket_id: id,
bucket_name: name.to_string(),
sources: vec![format!("source_{}", id)],
minimizers: mins,
})
.collect();
let options = ParquetWriteOptions::default();
create_parquet_inverted_index(path, buckets, k, w, salt, None, Some(&options), None)
.unwrap();
ShardedInvertedIndex::open(path).unwrap()
}
// Without a negative index the sharded-negative entry point must return as
// many results as the plain merge-join entry point.
#[test]
fn test_classify_with_sharded_negative_no_negative() {
let (_dir, index, seqs) = create_test_index();
let records: Vec<QueryRecord> = vec![(1, seqs[0].as_slice(), None)];
let results_standard =
classify_batch_sharded_merge_join(&index, None, &records, 0.1, None).unwrap();
let results_sharded =
classify_with_sharded_negative(&index, None, &records, 0.1, None).unwrap();
assert_eq!(
results_standard.len(),
results_sharded.len(),
"Results should match when no negative index"
);
}
// Adding a negative index must never increase the number of hits.
// NOTE(review): the positive/negative indices hold synthetic minimizers
// (100..110 / 100..105) while the query is a generated sequence, and
// `_query_mins` is extracted but never used — this may only exercise the
// "no hits either way" case; confirm the intended coverage.
#[test]
fn test_classify_with_sharded_negative_filters_correctly() {
let dir = tempdir().unwrap();
let k = 32;
let w = 10;
let salt = 0x12345u64;
let positive_mins: Vec<u64> = (100..110).collect();
let negative_mins: Vec<u64> = (100..105).collect();
let pos_path = dir.path().join("positive.ryxdi");
let pos_index = create_test_index_at_path(
&pos_path,
vec![(1, "target", positive_mins.clone())],
k,
w,
salt,
);
let neg_path = dir.path().join("negative.ryxdi");
let neg_index = create_test_index_at_path(
&neg_path,
vec![(1, "contaminant", negative_mins)],
k,
w,
salt,
);
let seq = generate_sequence(500, 42);
let mut ws = MinimizerWorkspace::new();
extract_into(&seq, k, w, salt, &mut ws);
let _query_mins: Vec<u64> = ws.buffer.drain(..).collect();
let records: Vec<QueryRecord> = vec![(1, seq.as_slice(), None)];
let results_without =
classify_with_sharded_negative(&pos_index, None, &records, 0.0, None).unwrap();
let results_with =
classify_with_sharded_negative(&pos_index, Some(&neg_index), &records, 0.0, None)
.unwrap();
assert!(
results_with.len() <= results_without.len(),
"Negative filtering should not increase hit count"
);
}
// A negative index containing every minimizer of the query must remove all hits.
#[test]
fn test_classify_with_sharded_negative_all_filtered() {
let (_dir, index, seqs) = create_test_index();
let k = index.k();
let w = index.w();
let salt = index.salt();
let mut ws = MinimizerWorkspace::new();
extract_into(&seqs[0], k, w, salt, &mut ws);
let mut seq_mins: Vec<u64> = ws.buffer.drain(..).collect();
seq_mins.sort();
seq_mins.dedup();
let dir = tempdir().unwrap();
let neg_path = dir.path().join("negative.ryxdi");
let neg_index =
create_test_index_at_path(&neg_path, vec![(1, "contaminant", seq_mins)], k, w, salt);
let records: Vec<QueryRecord> = vec![(1, seqs[0].as_slice(), None)];
let results =
classify_with_sharded_negative(&index, Some(&neg_index), &records, 0.1, None).unwrap();
assert!(
results.is_empty(),
"All minimizers filtered should produce no hits"
);
}
// An empty batch with no negative index must yield no results.
#[test]
fn test_classify_with_sharded_negative_empty_records() {
let (_dir, index, _seqs) = create_test_index();
let records: Vec<QueryRecord> = vec![];
let results = classify_with_sharded_negative(&index, None, &records, 0.1, None).unwrap();
assert!(
results.is_empty(),
"Empty records should produce empty results"
);
}
// `classify_from_query_index` on a pre-built query index must agree with
// `classify_from_extracted_minimizers` on ids, buckets and scores.
#[test]
fn test_classify_from_query_index_matches_extracted() {
let (_dir, index, seqs) = create_test_index();
let manifest = index.manifest();
let extracted = extract_batch_minimizers(
manifest.k,
manifest.w,
manifest.salt,
None,
&[
(1i64, seqs[0].as_slice(), None),
(2i64, seqs[1].as_slice(), None),
],
);
let query_ids = vec![1i64, 2];
let results_existing =
classify_from_extracted_minimizers(&index, &extracted, &query_ids, 0.1, None).unwrap();
let query_idx = QueryInvertedIndex::build(&extracted);
let results_new =
classify_from_query_index(&index, &query_idx, &query_ids, 0.1, None).unwrap();
assert_eq!(
results_existing.len(),
results_new.len(),
"classify_from_query_index should produce same number of results"
);
// Sort both result sets by (query_id, bucket_id) so they can be compared pairwise.
let mut existing_sorted = results_existing.clone();
existing_sorted.sort_by(|a, b| {
a.query_id
.cmp(&b.query_id)
.then(a.bucket_id.cmp(&b.bucket_id))
});
let mut new_sorted = results_new.clone();
new_sorted.sort_by(|a, b| {
a.query_id
.cmp(&b.query_id)
.then(a.bucket_id.cmp(&b.bucket_id))
});
for (e, n) in existing_sorted.iter().zip(new_sorted.iter()) {
assert_eq!(e.query_id, n.query_id);
assert_eq!(e.bucket_id, n.bucket_id);
assert!(
(e.score - n.score).abs() < 1e-10,
"Scores should match: {} vs {}",
e.score,
n.score
);
}
}
// Same agreement check for the parallel row-group entry points.
#[test]
fn test_classify_from_query_index_parallel_rg_matches_extracted() {
let (_dir, index, seqs) = create_test_index();
let manifest = index.manifest();
let extracted = extract_batch_minimizers(
manifest.k,
manifest.w,
manifest.salt,
None,
&[
(1i64, seqs[0].as_slice(), None),
(2i64, seqs[1].as_slice(), None),
],
);
let query_ids = vec![1i64, 2];
let results_existing = classify_from_extracted_minimizers_parallel_rg(
&index, &extracted, &query_ids, 0.1, None,
)
.unwrap();
let query_idx = QueryInvertedIndex::build(&extracted);
let results_new =
classify_from_query_index_parallel_rg(&index, &query_idx, &query_ids, 0.1, None)
.unwrap();
assert_eq!(
results_existing.len(),
results_new.len(),
"classify_from_query_index_parallel_rg should produce same number of results"
);
// Sort both result sets by (query_id, bucket_id) so they can be compared pairwise.
let mut existing_sorted = results_existing.clone();
existing_sorted.sort_by(|a, b| {
a.query_id
.cmp(&b.query_id)
.then(a.bucket_id.cmp(&b.bucket_id))
});
let mut new_sorted = results_new.clone();
new_sorted.sort_by(|a, b| {
a.query_id
.cmp(&b.query_id)
.then(a.bucket_id.cmp(&b.bucket_id))
});
for (e, n) in existing_sorted.iter().zip(new_sorted.iter()) {
assert_eq!(e.query_id, n.query_id);
assert_eq!(e.bucket_id, n.bucket_id);
assert!(
(e.score - n.score).abs() < 1e-10,
"Scores should match: {} vs {}",
e.score,
n.score
);
}
}
// A query index built from no records must classify to an empty result.
#[test]
fn test_classify_from_query_index_empty() {
let (_dir, index, _seqs) = create_test_index();
let query_idx = QueryInvertedIndex::build(&[]);
let results = classify_from_query_index(&index, &query_idx, &[], 0.1, None).unwrap();
assert!(results.is_empty());
}
// Filtering through a sharded negative index must match filtering through an
// in-memory `HashSet` holding the same minimizers (first third of the query's).
#[test]
fn test_classify_with_sharded_negative_consistency_with_hashset() {
let (_dir, index, seqs) = create_test_index();
let k = index.k();
let w = index.w();
let salt = index.salt();
let mut ws = MinimizerWorkspace::new();
extract_into(&seqs[0], k, w, salt, &mut ws);
let mut seq_mins: Vec<u64> = ws.buffer.drain(..).collect();
seq_mins.sort();
seq_mins.dedup();
let neg_count = (seq_mins.len() / 3).max(1);
let negative_mins: Vec<u64> = seq_mins[..neg_count].to_vec();
let negative_set: HashSet<u64> = negative_mins.iter().copied().collect();
let dir = tempdir().unwrap();
let neg_path = dir.path().join("negative.ryxdi");
let neg_index = create_test_index_at_path(
&neg_path,
vec![(1, "contaminant", negative_mins)],
k,
w,
salt,
);
let records: Vec<QueryRecord> = vec![(1, seqs[0].as_slice(), None)];
let results_hashset =
classify_batch_sharded_merge_join(&index, Some(&negative_set), &records, 0.1, None)
.unwrap();
let results_sharded =
classify_with_sharded_negative(&index, Some(&neg_index), &records, 0.1, None).unwrap();
assert_eq!(
results_hashset.len(),
results_sharded.len(),
"Both approaches should produce same number of results"
);
if !results_hashset.is_empty() {
let score_hashset = results_hashset[0].score;
let score_sharded = results_sharded[0].score;
let diff = (score_hashset - score_sharded).abs();
assert!(
diff < 1e-10,
"Scores should match: hashset={}, sharded={}",
score_hashset,
score_sharded
);
}
}
// With nothing seen yet, every loaded pair is returned.
#[test]
fn test_filter_unseen_empty_seen() {
let loaded = vec![(1, 0), (2, 0), (3, 0)];
let seen: Vec<(u64, u32)> = vec![];
let result = filter_unseen(&loaded, &seen);
assert_eq!(result, loaded);
}
// When every loaded pair was already seen, nothing is returned.
#[test]
fn test_filter_unseen_all_seen() {
let loaded = vec![(1, 0), (2, 0), (3, 0)];
let seen = vec![(1, 0), (2, 0), (3, 0)];
let result = filter_unseen(&loaded, &seen);
assert!(result.is_empty());
}
// Only the pairs missing from `seen` are returned.
#[test]
fn test_filter_unseen_partial() {
let loaded = vec![(1, 0), (2, 0), (3, 0), (4, 0)];
let seen = vec![(2, 0), (4, 0)];
let result = filter_unseen(&loaded, &seen);
assert_eq!(result, vec![(1, 0), (3, 0)]);
}
// Consecutive duplicates inside `loaded` collapse to a single entry.
#[test]
fn test_filter_unseen_within_shard_duplicates() {
let loaded = vec![(1, 0), (1, 0), (2, 0), (3, 0), (3, 0), (3, 0)];
let seen: Vec<(u64, u32)> = vec![];
let result = filter_unseen(&loaded, &seen);
assert_eq!(result, vec![(1, 0), (2, 0), (3, 0)]);
}
// In-shard duplicates and already-seen pairs are both removed.
#[test]
fn test_filter_unseen_both_within_and_cross_shard() {
let loaded = vec![(1, 0), (1, 0), (2, 0), (3, 0), (3, 0), (4, 0)];
let seen = vec![(1, 0), (3, 0)];
let result = filter_unseen(&loaded, &seen);
assert_eq!(result, vec![(2, 0), (4, 0)]);
}
// Two interleaved sorted inputs merge into one sorted output.
#[test]
fn test_merge_sorted_into_both_nonempty() {
let a = vec![(1, 0), (3, 0), (5, 0)];
let b = vec![(2, 0), (4, 0), (6, 0)];
let mut out = Vec::new();
merge_sorted_into(&a, &b, &mut out);
assert_eq!(out, vec![(1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]);
}
// The output vector's pre-allocated capacity must survive the merge.
#[test]
fn test_merge_sorted_into_reuses_allocation() {
let a = vec![(1, 0), (3, 0)];
let b = vec![(2, 0)];
let mut out = Vec::with_capacity(100);
merge_sorted_into(&a, &b, &mut out);
assert_eq!(out, vec![(1, 0), (2, 0), (3, 0)]);
assert!(out.capacity() >= 100, "should reuse pre-allocated capacity");
}
// Duplicates shard 0 as shard 1, marks the manifest as overlapping, and
// checks that the score equals the single-shard baseline.
#[test]
fn test_overlapping_shards_dedup() {
use crate::indices::parquet::{InvertedManifest, InvertedShardInfo, ParquetManifest};
let dir = tempdir().unwrap();
let index_path = dir.path().join("test.ryxdi");
let k = 32;
let w = 10;
let salt = 0x12345u64;
let seq = generate_sequence(500, 7);
let mut ws = MinimizerWorkspace::new();
extract_into(&seq, k, w, salt, &mut ws);
let mut mins: Vec<u64> = ws.buffer.drain(..).collect();
mins.sort();
mins.dedup();
let buckets = vec![BucketData {
bucket_id: 1,
bucket_name: "TestBucket".to_string(),
sources: vec!["seq".to_string()],
minimizers: mins,
}];
let options = ParquetWriteOptions::default();
create_parquet_inverted_index(&index_path, buckets, k, w, salt, None, Some(&options), None)
.unwrap();
// Baseline: score from the untouched single-shard index.
let index_1shard = ShardedInvertedIndex::open(&index_path).unwrap();
let records: Vec<QueryRecord> = vec![(1, seq.as_slice(), None)];
let baseline =
classify_batch_sharded_merge_join(&index_1shard, None, &records, 0.0, None).unwrap();
assert!(!baseline.is_empty(), "self-match should produce hits");
let baseline_score = baseline[0].score;
// Copy shard 0 to shard 1 and rewrite the manifest to list both as overlapping.
let shard0 = index_path.join("inverted").join("shard.0.parquet");
let shard1 = index_path.join("inverted").join("shard.1.parquet");
std::fs::copy(&shard0, &shard1).unwrap();
let mut manifest = ParquetManifest::load(&index_path).unwrap();
let inv = manifest.inverted.as_ref().unwrap();
let shard0_info = inv.shards[0];
manifest.inverted = Some(InvertedManifest {
format: inv.format,
num_shards: 2,
total_entries: inv.total_entries * 2,
has_overlapping_shards: true,
shards: vec![
shard0_info,
InvertedShardInfo {
shard_id: 1,
..shard0_info
},
],
});
manifest.save(&index_path).unwrap();
let index_2shard = ShardedInvertedIndex::open(&index_path).unwrap();
let deduped =
classify_batch_sharded_merge_join(&index_2shard, None, &records, 0.0, None).unwrap();
assert!(!deduped.is_empty(), "2-shard should still produce hits");
let deduped_score = deduped[0].score;
assert!(
(deduped_score - baseline_score).abs() < 1e-10,
"Deduped 2-shard score ({}) should equal 1-shard baseline ({})",
deduped_score,
baseline_score,
);
}
}