1use crate::bias::HashBiasTable;
2use crate::format::{BUCKET_COUNT, bucket_id};
3use crate::reader::{JamReader, ReaderError};
4use jamhash::jamhash_u64;
5use needletail::{Sequence, parse_fastx_file};
6use std::collections::{HashMap, HashSet};
7use std::path::Path;
8use std::sync::Arc;
9
10#[derive(Debug)]
11pub struct QuerySketch {
12 pub buckets: [Vec<(u64, u32)>; BUCKET_COUNT],
13 pub sample_names: Vec<String>,
14 pub query_sizes: Vec<usize>,
15}
16
17impl QuerySketch {
18 pub fn new() -> Self {
19 Self {
20 buckets: std::array::from_fn(|_| Vec::new()),
21 sample_names: Vec::new(),
22 query_sizes: Vec::new(),
23 }
24 }
25
26 #[inline]
27 pub fn bucket(&self, idx: usize) -> &[(u64, u32)] {
28 &self.buckets[idx]
29 }
30
31 #[inline]
32 pub fn sample_count(&self) -> usize {
33 self.sample_names.len()
34 }
35
36 #[inline]
37 pub fn total_entries(&self) -> usize {
38 self.buckets.iter().map(|b| b.len()).sum()
39 }
40
41 pub fn from_jam<P: AsRef<Path>>(path: P, db: &JamReader) -> Result<Self, QueryError> {
42 let source = JamReader::open(path)?;
43
44 if source.kmer_size() != db.kmer_size() {
45 return Err(QueryError::ParameterMismatch {
46 parameter: "k-mer size".to_string(),
47 source_value: source.kmer_size().to_string(),
48 target_value: db.kmer_size().to_string(),
49 });
50 }
51
52 if source.threshold() != db.threshold() {
53 return Err(QueryError::ParameterMismatch {
54 parameter: "hash threshold".to_string(),
55 source_value: source.threshold().to_string(),
56 target_value: db.threshold().to_string(),
57 });
58 }
59
60 let stats = source.stats();
61 let expected_sample_count = stats.sample_count as usize;
62
63 let sample_names = source.sample_names().to_vec();
64 if sample_names.len() != expected_sample_count {
65 return Err(QueryError::Parse {
66 path: "JAM file".to_string(),
67 message: format!(
68 "sample names count ({}) doesn't match header sample_count ({})",
69 sample_names.len(),
70 expected_sample_count
71 ),
72 });
73 }
74
75 let stored_sizes = source.sample_sizes();
76 let query_sizes: Vec<usize> = stored_sizes.iter().map(|&s| s as usize).collect();
77
78 let mut buckets: [Vec<(u64, u32)>; BUCKET_COUNT] = std::array::from_fn(|_| Vec::new());
79 for (bucket_idx, bucket) in buckets.iter_mut().enumerate() {
80 for entry in source.bucket_entries(bucket_idx) {
81 bucket.push((entry.hash, entry.sample_id));
82 }
83 }
84
85 Ok(Self {
86 buckets,
87 sample_names,
88 query_sizes,
89 })
90 }
91
92 pub fn from_fasta<P: AsRef<Path>>(
93 input: P,
94 db: &JamReader,
95 singleton: bool,
96 ) -> Result<Self, QueryError> {
97 let input_path = input.as_ref();
98 let kmer_size = db.kmer_size();
99 let threshold = db.threshold();
100 let bias_table = db.bias_table();
101
102 let mut reader = match parse_fastx_file(input_path) {
103 Ok(reader) => reader,
104 Err(e) if e.kind == needletail::errors::ParseErrorKind::EmptyFile => {
105 eprintln!(
106 "Empty file detected: {}, returning empty sketch",
107 input_path.display()
108 );
109 return Ok(Self::new());
110 }
111 Err(e) => {
112 return Err(QueryError::Parse {
113 path: input_path.display().to_string(),
114 message: e.to_string(),
115 });
116 }
117 };
118
119 let mut buckets: [Vec<(u64, u32)>; BUCKET_COUNT] = std::array::from_fn(|_| Vec::new());
120 let mut sample_names: Vec<String> = Vec::new();
121 let mut sample_hash_sets: Vec<HashSet<u64>> = Vec::new();
122 let mut current_sample_id: u32 = 0;
123
124 if !singleton {
125 sample_names.push(
126 input_path
127 .file_name()
128 .and_then(|s| s.to_str())
129 .unwrap_or("query")
130 .to_string(),
131 );
132 sample_hash_sets.push(HashSet::new());
133 }
134
135 while let Some(record) = reader.next() {
136 let record = record.map_err(|e| QueryError::Parse {
137 path: input_path.display().to_string(),
138 message: e.to_string(),
139 })?;
140
141 if singleton {
142 let name = std::str::from_utf8(record.id())
143 .unwrap_or("unknown")
144 .to_string();
145 sample_names.push(name);
146 sample_hash_sets.push(HashSet::new());
147 current_sample_id = (sample_names.len() - 1) as u32;
148 }
149
150 let sequence = record.normalize(false);
151 if sequence.len() < kmer_size as usize {
152 continue;
153 }
154
155 for (_, kmer, _) in sequence.bit_kmers(kmer_size, true) {
156 let hash = jamhash_u64(kmer.0);
157
158 if hash >= threshold {
159 continue;
160 }
161
162 if bias_table.as_ref().is_some_and(|b| !b.passes_filter(hash)) {
163 continue;
164 }
165
166 if sample_hash_sets[current_sample_id as usize].insert(hash) {
167 buckets[bucket_id(hash)].push((hash, current_sample_id));
168 }
169 }
170 }
171
172 for bucket in &mut buckets {
173 bucket.sort_unstable();
174 bucket.dedup();
175 }
176
177 let query_sizes: Vec<usize> = sample_hash_sets.iter().map(|set| set.len()).collect();
178
179 Ok(Self {
180 buckets,
181 sample_names,
182 query_sizes,
183 })
184 }
185
186 pub fn from_inputs(
187 inputs: &[std::path::PathBuf],
188 db: &JamReader,
189 singleton: bool,
190 ) -> Result<Self, QueryError> {
191 use crate::format::MAGIC;
192 use std::fs::File;
193 use std::io::Read;
194
195 if inputs.is_empty() {
196 return Ok(Self::new());
197 }
198
199 let is_jam_file = |path: &std::path::PathBuf| -> bool {
200 if path
201 .extension()
202 .is_some_and(|ext| ext.eq_ignore_ascii_case("jam"))
203 {
204 return true;
205 }
206 File::open(path)
207 .ok()
208 .and_then(|mut f| {
209 let mut magic = [0u8; 4];
210 f.read_exact(&mut magic).ok()?;
211 Some(magic == MAGIC)
212 })
213 .unwrap_or(false)
214 };
215
216 let mut combined = Self::new();
217
218 for input in inputs {
219 let sketch = if is_jam_file(input) {
220 Self::from_jam(input, db)?
221 } else {
222 Self::from_fasta(input, db, singleton)?
223 };
224
225 let sample_offset = combined.sample_count() as u32;
226 combined.sample_names.extend(sketch.sample_names);
227 combined.query_sizes.extend(sketch.query_sizes);
228
229 for (bucket_idx, bucket) in sketch.buckets.into_iter().enumerate() {
230 for (hash, sample_id) in bucket {
231 combined.buckets[bucket_idx].push((hash, sample_id + sample_offset));
232 }
233 }
234 }
235
236 for bucket in &mut combined.buckets {
237 bucket.sort_unstable();
238 }
239
240 Ok(combined)
241 }
242}
243
244impl Default for QuerySketch {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
250#[derive(Debug, thiserror::Error)]
251pub enum QueryError {
252 #[error("I/O error: {0}")]
253 Io(#[from] std::io::Error),
254
255 #[error("Database error: {0}")]
256 Database(#[from] ReaderError),
257
258 #[error("Parse error in {path}: {message}")]
259 Parse { path: String, message: String },
260
261 #[error(
262 "Parameter mismatch: {parameter} - source has {source_value}, target database has {target_value}"
263 )]
264 ParameterMismatch {
265 parameter: String,
266 source_value: String,
267 target_value: String,
268 },
269}
270
/// One database sample hit by a query, with its hit statistics.
///
/// `PartialEq` is derived so matches can be compared directly, for example in
/// assertions. `Eq` is not available because `containment` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleMatch {
    /// Id of the matched database sample.
    pub sample_id: u32,
    /// Number of hits counted for this sample.
    pub hit_count: u32,
    /// `hit_count` divided by the query size.
    pub containment: f64,
}
277
278#[derive(Debug, Clone)]
279pub struct QueryResult {
280 pub query_size: usize,
281 pub hashes_found: usize,
282 pub matches: Vec<SampleMatch>,
283 pub failed_bucket_count: usize,
284}
285
286impl QueryResult {
287 pub fn top(&self, n: usize) -> Vec<&SampleMatch> {
288 let mut sorted: Vec<_> = self.matches.iter().collect();
289 sorted.sort_by(|a, b| b.containment.total_cmp(&a.containment));
290 sorted.truncate(n);
291 sorted
292 }
293
294 pub fn above_threshold(&self, min_containment: f64) -> Vec<&SampleMatch> {
295 self.matches
296 .iter()
297 .filter(|m| m.containment >= min_containment)
298 .collect()
299 }
300
301 pub fn has_matches(&self) -> bool {
302 !self.matches.is_empty()
303 }
304
305 pub fn is_partial(&self) -> bool {
306 self.failed_bucket_count > 0
307 }
308}
309
310pub struct QueryEngine {
311 reader: JamReader,
312 bias_table: Option<Arc<HashBiasTable>>,
313}
314
315impl QueryEngine {
316 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, ReaderError> {
317 let reader = JamReader::open(path)?;
318 let bias_table = reader.bias_table();
319 Ok(Self { reader, bias_table })
320 }
321
322 pub fn threshold(&self) -> u64 {
323 self.reader.threshold()
324 }
325
326 pub fn kmer_size(&self) -> u8 {
327 self.reader.kmer_size()
328 }
329
330 pub fn bias_table(&self) -> Option<Arc<HashBiasTable>> {
331 self.bias_table.clone()
332 }
333
334 pub fn has_bias_table(&self) -> bool {
335 self.bias_table.is_some()
336 }
337
338 pub fn reader(&self) -> &JamReader {
339 &self.reader
340 }
341
342 pub fn query(&self, hashes: &[u64]) -> QueryResult {
343 if hashes.is_empty() {
344 return QueryResult {
345 query_size: 0,
346 hashes_found: 0,
347 matches: Vec::new(),
348 failed_bucket_count: 0,
349 };
350 }
351
352 let mut sorted_hashes = hashes.to_vec();
353 sorted_hashes.sort_unstable_by_key(|&h| (h & 0xFF, h));
354
355 let mut sample_hits: HashMap<u32, u32> = HashMap::new();
356 let mut hashes_found = 0;
357
358 for &hash in &sorted_hashes {
359 let mut found = false;
360 for sample_id in self.reader.search(hash) {
361 *sample_hits.entry(sample_id).or_insert(0) += 1;
362 found = true;
363 }
364 if found {
365 hashes_found += 1;
366 }
367 }
368
369 let query_size = hashes.len();
370 let matches: Vec<SampleMatch> = sample_hits
371 .into_iter()
372 .map(|(sample_id, hit_count)| SampleMatch {
373 sample_id,
374 hit_count,
375 containment: hit_count as f64 / query_size as f64,
376 })
377 .collect();
378
379 QueryResult {
380 query_size,
381 hashes_found,
382 matches,
383 failed_bucket_count: 0,
384 }
385 }
386
387 pub fn query_filtered(
388 &self,
389 hashes: &[u64],
390 min_containment: f64,
391 max_results: usize,
392 ) -> QueryResult {
393 let mut result = self.query(hashes);
394 result.matches.retain(|m| m.containment >= min_containment);
395 result
396 .matches
397 .sort_by(|a, b| b.containment.total_cmp(&a.containment));
398 result.matches.truncate(max_results);
399 result
400 }
401
402 pub fn query_batch(&self, queries: &[Vec<u64>]) -> Vec<QueryResult> {
403 use rayon::prelude::*;
404 queries.par_iter().map(|q| self.query(q)).collect()
405 }
406
407 pub fn query_sketch(&self, sketch: &QuerySketch) -> Vec<QueryResult> {
408 use crate::format::{ENTRY_SIZE, PAGE_SIZE};
409 use rayon::prelude::*;
410 use std::sync::atomic::{AtomicU32, Ordering};
411
412 let num_samples = sketch.sample_count();
413 if num_samples == 0 {
414 return Vec::new();
415 }
416
417 let threshold = self.reader.threshold();
418
419 self.reader.advise_random();
420
421 let hashes_found: Vec<AtomicU32> = (0..num_samples)
422 .into_par_iter()
423 .map(|_| AtomicU32::new(0))
424 .collect();
425
426 let bucket_pairs: Vec<Vec<(u32, u32)>> = (0..BUCKET_COUNT)
427 .into_par_iter()
428 .map(|bucket_idx| {
429 let mut pairs = Vec::new();
430 let query_bucket = sketch.bucket(bucket_idx);
431 if query_bucket.is_empty() {
432 return pairs;
433 }
434
435 let filter = match self.reader.bucket_filter(bucket_idx) {
436 Some(f) => f,
437 None => return pairs,
438 };
439
440 let mut survivors = Vec::with_capacity(query_bucket.len() / 10);
441 let mut prev_hash = u64::MAX;
442 let mut prev_passed = false;
443
444 for &(hash, sample_id) in query_bucket {
445 if hash != prev_hash {
446 prev_hash = hash;
447 prev_passed = filter.contains(&hash);
448 }
449 if prev_passed {
450 survivors.push((hash, sample_id));
451 }
452 }
453
454 let (filter_start, filter_end) = self.reader.bucket_filter_byte_range(bucket_idx);
455 self.reader.release_pages(filter_start, filter_end);
456
457 if survivors.is_empty() {
458 return pairs;
459 }
460
461 let db_bucket = self.reader.bucket_entries(bucket_idx);
462 let count = db_bucket.len();
463 if count == 0 {
464 return pairs;
465 }
466
467 let (entry_start, _entry_end) = self.reader.bucket_entry_byte_range(bucket_idx);
468 let mut last_released_page = entry_start & !(PAGE_SIZE - 1);
469
470 let mut q_idx = 0;
471 while q_idx < survivors.len() {
472 let q_hash = survivors[q_idx].0;
473
474 let est = ((q_hash as u128 * count as u128) / threshold as u128) as usize;
475 let mut d_idx = est.saturating_sub(16).min(count.saturating_sub(1));
476
477 while d_idx > 0 && db_bucket[d_idx].hash > q_hash {
478 d_idx -= 1;
479 }
480
481 while d_idx < count && db_bucket[d_idx].hash < q_hash {
482 d_idx += 1;
483 }
484
485 while d_idx > 0 && db_bucket[d_idx - 1].hash == q_hash {
486 d_idx -= 1;
487 }
488
489 let current_byte = entry_start + d_idx * ENTRY_SIZE;
490 let current_page = current_byte & !(PAGE_SIZE - 1);
491 if current_page > last_released_page + PAGE_SIZE {
492 self.reader
493 .release_pages(last_released_page, current_page - PAGE_SIZE);
494 last_released_page = current_page - PAGE_SIZE;
495 }
496
497 let db_start = d_idx;
498 let mut db_end = d_idx;
499 while db_end < count && db_bucket[db_end].hash == q_hash {
500 db_end += 1;
501 }
502 let has_matches = db_start < db_end;
503
504 let mut prev_sample = u32::MAX;
505 while q_idx < survivors.len() && survivors[q_idx].0 == q_hash {
506 let q_sample = survivors[q_idx].1;
507
508 if q_sample != prev_sample {
509 if has_matches {
510 for db_entry in &db_bucket[db_start..db_end] {
511 pairs.push((q_sample, db_entry.sample_id));
512 }
513 hashes_found[q_sample as usize].fetch_add(1, Ordering::Relaxed);
514 }
515 prev_sample = q_sample;
516 }
517 q_idx += 1;
518 }
519 }
520
521 self.reader.release_bucket(bucket_idx);
522
523 pairs
524 })
525 .collect();
526
527 let bucket_sizes: Vec<usize> = bucket_pairs.iter().map(|v| v.len()).collect();
528 let total_pairs: usize = bucket_sizes.iter().sum();
529 let mut bucket_offsets = Vec::with_capacity(BUCKET_COUNT + 1);
530 bucket_offsets.push(0usize);
531 for size in &bucket_sizes {
532 bucket_offsets.push(bucket_offsets.last().unwrap() + size);
533 }
534
535 let mut all_pairs: Vec<(u32, u32)> = vec![(0, 0); total_pairs];
536 bucket_pairs
537 .into_par_iter()
538 .enumerate()
539 .for_each(|(bucket_idx, pairs)| {
540 let start = bucket_offsets[bucket_idx];
541 let dest = unsafe {
542 std::slice::from_raw_parts_mut(
543 all_pairs.as_ptr().add(start) as *mut (u32, u32),
544 pairs.len(),
545 )
546 };
547 dest.copy_from_slice(&pairs);
548 });
549
550 let merged_hashes_found: Vec<u32> = hashes_found
551 .into_par_iter()
552 .map(|a| a.load(Ordering::Relaxed))
553 .collect();
554
555 all_pairs.par_sort_unstable();
556
557 if all_pairs.is_empty() {
558 return (0..num_samples)
559 .map(|i| QueryResult {
560 query_size: sketch.query_sizes[i],
561 hashes_found: merged_hashes_found[i] as usize,
562 matches: Vec::new(),
563 failed_bucket_count: 0,
564 })
565 .collect();
566 }
567
568 let sample_starts: Vec<usize> = (0..num_samples as u32)
569 .into_par_iter()
570 .map(|q_sample| all_pairs.partition_point(|&(qs, _)| qs < q_sample))
571 .collect();
572
573 let results: Vec<QueryResult> = (0..num_samples)
574 .into_par_iter()
575 .map(|sample_idx| {
576 let q_sample = sample_idx as u32;
577 let start = sample_starts[sample_idx];
578 let end = if sample_idx + 1 < num_samples {
579 sample_starts[sample_idx + 1]
580 } else {
581 all_pairs.len()
582 };
583
584 let mut matches = Vec::new();
585 let query_size = sketch.query_sizes[sample_idx];
586
587 let mut i = start;
588 while i < end {
589 let (_, db_sample) = all_pairs[i];
590 let mut count = 1u32;
591 while i + (count as usize) < end
592 && all_pairs[i + count as usize] == (q_sample, db_sample)
593 {
594 count += 1;
595 }
596 matches.push(SampleMatch {
597 sample_id: db_sample,
598 hit_count: count,
599 containment: if query_size > 0 {
600 count as f64 / query_size as f64
601 } else {
602 0.0
603 },
604 });
605 i += count as usize;
606 }
607
608 QueryResult {
609 query_size,
610 hashes_found: merged_hashes_found[sample_idx] as usize,
611 matches,
612 failed_bucket_count: 0,
613 }
614 })
615 .collect();
616
617 results
618 }
619
620 pub fn query_fasta<P: AsRef<Path>>(
621 &self,
622 input: P,
623 singleton: bool,
624 ) -> Result<Vec<QueryResult>, QueryError> {
625 let sketch = QuerySketch::from_fasta(input, &self.reader, singleton)?;
626 Ok(self.query_sketch(&sketch))
627 }
628}
629
630#[cfg(test)]
631mod tests {
632 use super::*;
633 use crate::writer::{BuildConfig, build};
634 use std::io::Write;
635 use tempfile::NamedTempFile;
636
637 fn make_fasta(seqs: &[(&str, &str)]) -> NamedTempFile {
638 let mut f = NamedTempFile::with_suffix(".fa").unwrap();
639 for (name, seq) in seqs {
640 writeln!(f, ">{name}").unwrap();
641 writeln!(f, "{seq}").unwrap();
642 }
643 f
644 }
645
646 fn build_test_db(
647 seqs: &[(&str, &str)],
648 singleton: bool,
649 ) -> (tempfile::TempDir, std::path::PathBuf) {
650 let input = make_fasta(seqs);
651 let output_dir = tempfile::tempdir().unwrap();
652 let output_path = output_dir.path().join("test.jam");
653
654 let config = BuildConfig {
655 kmer_size: 11,
656 fscale: 1,
657 singleton,
658 num_threads: 1,
659 memory: 1,
660 ..Default::default()
661 };
662
663 build(&[input.path().to_path_buf()], &output_path, &config).unwrap();
664 (output_dir, output_path)
665 }
666
/// An opened engine reports the database's threshold and k-mer size.
#[test]
fn test_query_engine_open() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let opened = QueryEngine::open(&db_path).unwrap();
    assert!(opened.threshold() > 0);
    assert_eq!(opened.kmer_size(), 11);
}
674
/// Hashes taken straight from the database must be found again by `query`.
#[test]
fn test_query_basic() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let engine = QueryEngine::open(&db_path).unwrap();
    let reader = JamReader::open(&db_path).unwrap();

    // Up to five hashes per bucket, stopping once at least ten were gathered.
    let mut sampled = Vec::new();
    for bucket_idx in 0..256 {
        sampled.extend(
            reader
                .bucket_entries(bucket_idx)
                .iter()
                .take(5)
                .map(|entry| entry.hash),
        );
        if sampled.len() >= 10 {
            break;
        }
    }

    if sampled.is_empty() {
        return;
    }

    let outcome = engine.query(&sampled);
    assert!(outcome.has_matches());
    assert!(outcome.hashes_found > 0);
}
698
/// Querying with no hashes yields an empty result of size zero.
#[test]
fn test_query_empty() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let engine = QueryEngine::open(&db_path).unwrap();

    let no_hashes: [u64; 0] = [];
    let outcome = engine.query(&no_hashes);
    assert!(!outcome.has_matches());
    assert_eq!(outcome.query_size, 0);
}
708
/// The ten largest `u64` values are not expected to hit anything.
#[test]
fn test_query_nonexistent() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let engine = QueryEngine::open(&db_path).unwrap();

    let mut absent: Vec<u64> = Vec::with_capacity(10);
    for step in 0..10 {
        absent.push(u64::MAX - step);
    }

    let outcome = engine.query(&absent);
    assert_eq!(outcome.hashes_found, 0);
}
718
/// `query_filtered` honours both the containment floor and the result cap.
#[test]
fn test_query_filtered() {
    let records = [
        ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let (_tmp, db_path) = build_test_db(&records, true);
    let engine = QueryEngine::open(&db_path).unwrap();
    let reader = JamReader::open(&db_path).unwrap();

    // Gather hashes belonging to database sample 0, stopping at twenty.
    let mut sample0_hashes = Vec::new();
    'collect: for bucket_idx in 0..256 {
        for entry in reader.bucket_entries(bucket_idx) {
            if entry.sample_id == 0 {
                sample0_hashes.push(entry.hash);
            }
            if sample0_hashes.len() >= 20 {
                break 'collect;
            }
        }
    }

    if sample0_hashes.is_empty() {
        return;
    }

    let outcome = engine.query_filtered(&sample0_hashes, 0.5, 10);
    assert!(outcome.matches.len() <= 10);
    for hit in &outcome.matches {
        assert!(hit.containment >= 0.5);
    }
}
754
/// Exercises `top`, `above_threshold`, `has_matches` and `is_partial` on a
/// hand-built result.
#[test]
fn test_query_result_helpers() {
    let hit = |sample_id: u32, hit_count: u32, containment: f64| SampleMatch {
        sample_id,
        hit_count,
        containment,
    };
    let result = QueryResult {
        query_size: 100,
        hashes_found: 50,
        matches: vec![hit(0, 50, 0.5), hit(1, 30, 0.3), hit(2, 80, 0.8)],
        failed_bucket_count: 0,
    };

    let best_two = result.top(2);
    assert_eq!(best_two.len(), 2);
    assert_eq!(best_two[0].sample_id, 2);
    assert_eq!(best_two[1].sample_id, 0);

    let strong = result.above_threshold(0.4);
    assert_eq!(strong.len(), 2);

    assert!(result.has_matches());
    assert!(!result.is_partial());
}
791
/// A new sketch has 256 empty buckets and no sample metadata.
#[test]
fn test_query_sketch_new() {
    let fresh = QuerySketch::new();

    assert_eq!(fresh.sample_count(), 0);
    assert_eq!(fresh.total_entries(), 0);
    assert_eq!(fresh.buckets.len(), 256);
    assert!(fresh.sample_names.is_empty());
    assert!(fresh.query_sizes.is_empty());
}
802
/// `Default` produces an empty sketch.
#[test]
fn test_query_sketch_default() {
    let fresh: QuerySketch = Default::default();

    assert_eq!(fresh.sample_count(), 0);
    assert_eq!(fresh.total_entries(), 0);
}
810
/// `bucket` exposes exactly the entries pushed into the given bucket.
#[test]
fn test_query_sketch_bucket_accessor() {
    let mut sketch = QuerySketch::new();

    sketch.buckets[0].extend([(100, 0), (200, 1)]);
    sketch.buckets[255].push((300, 0));

    let first = sketch.bucket(0);
    assert_eq!(first.len(), 2);
    assert_eq!(first[0], (100, 0));
    assert_eq!(first[1], (200, 1));

    let last = sketch.bucket(255);
    assert_eq!(last.len(), 1);
    assert_eq!(last[0], (300, 0));

    assert!(sketch.bucket(1).is_empty());
}
832
/// `sample_count` follows the number of sample names.
#[test]
fn test_query_sketch_sample_count() {
    let mut sketch = QuerySketch::new();
    assert_eq!(sketch.sample_count(), 0);

    sketch.sample_names.push(String::from("sample1"));
    assert_eq!(sketch.sample_count(), 1);

    sketch
        .sample_names
        .extend([String::from("sample2"), String::from("sample3")]);
    assert_eq!(sketch.sample_count(), 3);
}
845
/// `total_entries` sums entries over all buckets.
#[test]
fn test_query_sketch_total_entries() {
    let mut sketch = QuerySketch::new();
    assert_eq!(sketch.total_entries(), 0);

    sketch.buckets[0].extend([(100, 0), (200, 0)]);
    assert_eq!(sketch.total_entries(), 2);

    sketch.buckets[50].push((300, 1));
    assert_eq!(sketch.total_entries(), 3);

    sketch.buckets[255].extend([(400, 0), (500, 1), (600, 2)]);
    assert_eq!(sketch.total_entries(), 6);
}
863
/// Manually populated fields are reflected by the accessors.
#[test]
fn test_query_sketch_with_populated_fields() {
    let mut sketch = QuerySketch::new();

    sketch.sample_names = ["query_sample_1", "query_sample_2"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    sketch.query_sizes = vec![1000, 500];

    for (idx, bucket) in sketch.buckets.iter_mut().take(10).enumerate() {
        let base = idx as u64 * 100;
        bucket.push((base, 0));
        bucket.push((base + 1, 1));
    }

    assert_eq!(sketch.sample_count(), 2);
    assert_eq!(sketch.total_entries(), 20);
    assert_eq!(sketch.query_sizes[0], 1000);
    assert_eq!(sketch.query_sizes[1], 500);
    assert_eq!(sketch.sample_names[0], "query_sample_1");
}
883
/// Asking for bucket 256 must panic.
#[test]
#[should_panic]
fn test_query_sketch_bucket_out_of_bounds() {
    let empty = QuerySketch::new();
    let _ = empty.bucket(256);
}
890
/// An empty sketch produces no results.
#[test]
fn test_query_sketch_empty() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let engine = QueryEngine::open(&db_path).unwrap();

    let outcome = engine.query_sketch(&QuerySketch::new());
    assert!(outcome.is_empty());
}
900
/// A sketch made of every distinct database hash must hit db sample 0 at
/// least once per query hash.
#[test]
fn test_query_sketch_single_sample() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let engine = QueryEngine::open(&db_path).unwrap();
    let reader = JamReader::open(&db_path).unwrap();

    let mut sketch = QuerySketch::new();
    sketch.sample_names.push("query_sample".to_string());

    let mut seen = std::collections::HashSet::new();
    for (bucket_idx, bucket) in sketch.buckets.iter_mut().enumerate() {
        for entry in reader.bucket_entries(bucket_idx) {
            if seen.insert(entry.hash) {
                bucket.push((entry.hash, 0));
            }
        }
    }
    sketch.query_sizes.push(seen.len());

    let results = engine.query_sketch(&sketch);

    assert_eq!(results.len(), 1);
    let only = &results[0];
    assert!(only.has_matches());

    let sample0 = only.matches.iter().find(|m| m.sample_id == 0);
    assert!(sample0.is_some(), "Should match db sample 0");

    let sample0 = sample0.unwrap();
    assert!(
        sample0.hit_count >= only.query_size as u32,
        "Expected hit_count >= query_size, got {} vs {}",
        sample0.hit_count,
        only.query_size
    );
}
936
/// Two query samples mirroring db samples 0 and 1 each match, and any
/// self-match has high containment.
#[test]
fn test_query_sketch_multiple_samples() {
    let records = [
        ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let (_tmp, db_path) = build_test_db(&records, true);
    let engine = QueryEngine::open(&db_path).unwrap();
    let reader = JamReader::open(&db_path).unwrap();

    let mut sketch = QuerySketch::new();
    for name in ["query_0", "query_1"] {
        sketch.sample_names.push(name.to_string());
    }

    let mut distinct: [std::collections::HashSet<u64>; 2] = Default::default();

    for bucket_idx in 0..256 {
        for entry in reader.bucket_entries(bucket_idx) {
            let owner = entry.sample_id;
            if (owner as usize) >= 2 {
                continue;
            }
            distinct[owner as usize].insert(entry.hash);
            sketch.buckets[bucket_idx].push((entry.hash, owner));
        }
    }

    for set in &distinct {
        sketch.query_sizes.push(set.len());
    }

    let results = engine.query_sketch(&sketch);

    assert_eq!(results.len(), 2);

    for (query_idx, result) in results.iter().enumerate() {
        assert!(result.has_matches());
        let own = result
            .matches
            .iter()
            .find(|m| m.sample_id == query_idx as u32);
        if let Some(m) = own {
            assert!(
                m.containment >= 0.9,
                "Query {} should have high containment with DB sample {}, got {}",
                query_idx,
                query_idx,
                m.containment
            );
        }
    }
}
989
/// Hashes that are not expected in the database produce a result with no
/// hits and no matches.
#[test]
fn test_query_sketch_no_matches() {
    let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let engine = QueryEngine::open(&path).unwrap();

    let mut sketch = QuerySketch::new();
    sketch.sample_names.push("fake_sample".to_string());
    sketch.query_sizes.push(10);

    for i in 0..10 {
        let fake_hash = u64::MAX - i;
        // Place each hash with the crate's own bucket function so the sketch
        // is laid out the same way real sketches are.
        sketch.buckets[bucket_id(fake_hash)].push((fake_hash, 0));
    }

    let results = engine.query_sketch(&sketch);

    assert_eq!(results.len(), 1);
    assert_eq!(results[0].hashes_found, 0);
    assert!(results[0].matches.is_empty());
}
1011
/// Querying every second database entry still yields high containment,
/// because every queried hash is present in the database.
#[test]
fn test_query_sketch_containment_calculation() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let engine = QueryEngine::open(&db_path).unwrap();
    let reader = JamReader::open(&db_path).unwrap();

    let mut sketch = QuerySketch::new();
    sketch.sample_names.push("half_sample".to_string());

    let mut located = Vec::new();
    for bucket_idx in 0..256 {
        for entry in reader.bucket_entries(bucket_idx) {
            located.push((entry.hash, bucket_idx));
        }
    }

    let picked: Vec<_> = located.iter().step_by(2).collect();
    sketch.query_sizes.push(picked.len());

    for &&(hash, bucket_idx) in &picked {
        sketch.buckets[bucket_idx].push((hash, 0));
    }

    let results = engine.query_sketch(&sketch);

    assert_eq!(results.len(), 1);
    assert!(results[0].has_matches());
    let best = results[0].top(1);
    assert!(!best.is_empty());
    assert!(best[0].containment >= 0.9);
}
1043
/// Non-singleton mode folds the whole file into one named sample whose size
/// equals the number of stored entries.
#[test]
fn test_from_fasta_non_singleton() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let query = make_fasta(&[("query_seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let sketch = QuerySketch::from_fasta(query.path(), &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 1);
    assert!(!sketch.sample_names[0].is_empty());

    let entries = sketch.total_entries();
    assert!(entries > 0);
    assert!(sketch.query_sizes[0] > 0);

    assert_eq!(sketch.query_sizes[0], entries);
}
1061
/// Singleton mode creates one sample per record, named after the record id.
#[test]
fn test_from_fasta_singleton() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let records = [
        ("query_seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("query_seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let query = make_fasta(&records);

    let sketch = QuerySketch::from_fasta(query.path(), &db, true).unwrap();

    assert_eq!(sketch.sample_count(), 2);
    assert_eq!(sketch.sample_names[0], "query_seq1");
    assert_eq!(sketch.sample_names[1], "query_seq2");

    assert!(sketch.query_sizes[0] > 0);
    assert!(sketch.query_sizes[1] > 0);

    let summed_sizes = sketch.query_sizes.iter().sum::<usize>();
    assert!(summed_sizes <= sketch.total_entries() + sketch.sample_count());
}
1084
/// Sketching works against a database built with k = 15 and fscale = 10.
#[test]
fn test_from_fasta_uses_db_parameters() {
    let db_fasta = make_fasta(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let dir = tempfile::tempdir().unwrap();
    let db_path = dir.path().join("test.jam");

    let config = BuildConfig {
        kmer_size: 15,
        fscale: 10,
        singleton: false,
        num_threads: 1,
        memory: 1,
        ..Default::default()
    };
    let inputs = [db_fasta.path().to_path_buf()];

    build(&inputs, &db_path, &config).unwrap();
    let db = JamReader::open(&db_path).unwrap();

    assert_eq!(db.kmer_size(), 15);

    let query = make_fasta(&[("query", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let sketch = QuerySketch::from_fasta(query.path(), &db, false).unwrap();

    assert!(sketch.sample_count() == 1);
}
1111
/// A highly repetitive query stores each distinct hash only once.
#[test]
fn test_from_fasta_deduplication() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let repetitive = "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG";
    let query = make_fasta(&[("query", repetitive)]);

    let sketch = QuerySketch::from_fasta(query.path(), &db, false).unwrap();

    assert_eq!(sketch.query_sizes[0], sketch.total_entries());
}
1126
/// Every stored hash sits in the bucket that `bucket_id` assigns to it.
#[test]
fn test_from_fasta_bucketization() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let sequence = "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG";
    let query = make_fasta(&[("query", sequence)]);

    let sketch = QuerySketch::from_fasta(query.path(), &db, false).unwrap();

    for (bucket_idx, bucket) in sketch.buckets.iter().enumerate() {
        for entry in bucket {
            let hash = entry.0;
            assert_eq!(
                bucket_id(hash),
                bucket_idx,
                "Hash {} should be in bucket {}, not {}",
                hash,
                bucket_id(hash),
                bucket_idx
            );
        }
    }
}
1152
/// Every bucket produced by `from_fasta` is in non-decreasing order.
#[test]
fn test_from_fasta_sorted_buckets() {
    let (_tmp, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let records = [
        (
            "query1",
            "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG",
        ),
        (
            "query2",
            "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA",
        ),
    ];
    let query = make_fasta(&records);

    let sketch = QuerySketch::from_fasta(query.path(), &db, true).unwrap();

    for bucket in sketch.buckets.iter() {
        for pair in bucket.windows(2) {
            let (earlier, later) = (pair[0], pair[1]);
            assert!(
                earlier <= later,
                "Bucket not sorted: {:?} > {:?}",
                earlier,
                later
            );
        }
    }
}
1182
/// With a k-mer size of 11, an 8-base record still counts as a sample but
/// has a query size of zero, while the longer record has a non-zero size.
#[test]
fn test_from_fasta_short_sequences_skipped() {
    let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();
    assert_eq!(db.kmer_size(), 11);

    let records = [
        ("short", "ATCGATCG"),
        ("long", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
    ];
    let query_fasta = make_fasta(&records);

    let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, true).unwrap();

    // Both records are registered as samples.
    assert_eq!(sketch.sample_count(), 2);

    // The record shorter than k has size zero.
    assert_eq!(sketch.query_sizes[0], 0);

    // The full-length record has a non-zero size.
    assert!(sketch.query_sizes[1] > 0);
}
1202
/// A missing FASTA path must yield `QueryError::Parse` whose `path` field
/// mentions the offending file.
#[test]
fn test_from_fasta_file_not_found() {
    let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let result = QuerySketch::from_fasta("/nonexistent/path.fasta", &db, false);
    assert!(result.is_err());

    match result {
        Err(QueryError::Parse { path, .. }) => assert!(path.contains("nonexistent")),
        _ => panic!("Expected Parse error"),
    }
}
1217
/// End-to-end check: a FASTA query identical to the database sequence is
/// sketched, run through `QueryEngine::query_sketch`, and must produce a top
/// hit with containment of at least 0.9.
#[test]
fn test_from_fasta_integration_with_query_engine() {
    let reference = [("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir, db_path) = build_test_db(&reference, false);
    let db = JamReader::open(&db_path).unwrap();
    let engine = QueryEngine::open(&db_path).unwrap();

    let query_fasta = make_fasta(&[("query_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, false).unwrap();

    let results = engine.query_sketch(&sketch);
    assert_eq!(results.len(), 1);

    let only = &results[0];
    assert!(only.has_matches());

    let best = only.top(1);
    assert!(!best.is_empty());
    assert!(
        best[0].containment >= 0.9,
        "Expected high containment, got {}",
        best[0].containment
    );
}
1242
1243 fn build_test_db_with_params(
1244 seqs: &[(&str, &str)],
1245 kmer_size: u8,
1246 fscale: u64,
1247 singleton: bool,
1248 ) -> (tempfile::TempDir, std::path::PathBuf) {
1249 let input = make_fasta(seqs);
1250 let output_dir = tempfile::tempdir().unwrap();
1251 let output_path = output_dir.path().join("test.jam");
1252
1253 let config = BuildConfig {
1254 kmer_size,
1255 fscale,
1256 singleton,
1257 num_threads: 1,
1258 memory: 1,
1259 ..Default::default()
1260 };
1261
1262 build(&[input.path().to_path_buf()], &output_path, &config).unwrap();
1263 (output_dir, output_path)
1264 }
1265
/// Loading a JAM file built with the same parameters as the database yields
/// one named sample with a non-zero size and at least one entry.
#[test]
fn test_from_jam_success() {
    let db_seqs = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let query_seqs = [("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];

    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let (_dir2, query_path) = build_test_db_with_params(&query_seqs, 11, 1, false);

    let db = JamReader::open(&db_path).unwrap();
    let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();

    assert_eq!(sketch.sample_count(), 1);
    assert!(sketch.total_entries() > 0);

    let first_name = &sketch.sample_names[0];
    assert!(!first_name.is_empty());

    let first_size = sketch.query_sizes[0];
    assert!(first_size > 0);
}
1289
/// A singleton-built JAM file with two records loads as two samples whose
/// names follow input order and whose sizes are both non-zero.
#[test]
fn test_from_jam_multiple_samples() {
    let (_dir1, db_path) = build_test_db_with_params(
        &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
        11,
        1,
        false,
    );

    let query_seqs = [
        ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let (_dir2, query_path) = build_test_db_with_params(&query_seqs, 11, 1, true);

    let db = JamReader::open(&db_path).unwrap();
    let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();

    assert_eq!(sketch.sample_count(), 2);

    let names = &sketch.sample_names;
    assert_eq!(names[0], "seq1");
    assert_eq!(names[1], "seq2");

    let sizes = &sketch.query_sizes;
    assert_eq!(sizes.len(), 2);
    assert!(sizes[0] > 0);
    assert!(sizes[1] > 0);
}
1318
/// A source JAM built with k=21 against a k=11 database must be rejected with
/// `ParameterMismatch` naming the k-mer size and reporting both values.
#[test]
fn test_from_jam_kmer_size_mismatch() {
    let db_seqs = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let query_seqs = [("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];

    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let (_dir2, query_path) = build_test_db_with_params(&query_seqs, 21, 1, false);

    let db = JamReader::open(&db_path).unwrap();
    let result = QuerySketch::from_jam(&query_path, &db);

    assert!(result.is_err());
    match result.unwrap_err() {
        QueryError::ParameterMismatch {
            parameter,
            source_value,
            target_value,
        } => {
            assert!(parameter.contains("k-mer"));
            assert_eq!(source_value, "21");
            assert_eq!(target_value, "11");
        }
        other => panic!("Expected ParameterMismatch error, got {:?}", other),
    }
}
1352
/// A source JAM built with `fscale` 1000 against an `fscale` 1 database must
/// be rejected with `ParameterMismatch` naming the threshold, with differing
/// source and target values.
#[test]
fn test_from_jam_threshold_mismatch() {
    let db_seqs = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let query_seqs = [("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];

    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let (_dir2, query_path) = build_test_db_with_params(&query_seqs, 11, 1000, false);

    let db = JamReader::open(&db_path).unwrap();
    let result = QuerySketch::from_jam(&query_path, &db);

    assert!(result.is_err());
    match result.unwrap_err() {
        QueryError::ParameterMismatch {
            parameter,
            source_value,
            target_value,
        } => {
            assert!(parameter.contains("threshold"));
            assert_ne!(source_value, target_value);
        }
        other => panic!("Expected ParameterMismatch error, got {:?}", other),
    }
}
1385
/// Entries loaded through `from_jam` must each live in the bucket that
/// `bucket_id` assigns to their hash.
#[test]
fn test_from_jam_preserves_bucketization() {
    let (_dir1, db_path) = build_test_db_with_params(
        &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
        11,
        1,
        false,
    );
    let (_dir2, query_path) = build_test_db_with_params(
        &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
        11,
        1,
        false,
    );

    let db = JamReader::open(&db_path).unwrap();
    let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();

    for slot in 0..BUCKET_COUNT {
        sketch.bucket(slot).iter().for_each(|&(hash, _)| {
            assert_eq!(
                bucket_id(hash),
                slot,
                "Entry with hash {} is in wrong bucket",
                hash
            );
        });
    }
}
1415
/// For each sample loaded via `from_jam`, the recorded query size must equal
/// the number of distinct hashes stored for that sample across all buckets.
#[test]
fn test_from_jam_query_sizes_correct() {
    let (_dir1, db_path) = build_test_db_with_params(
        &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
        11,
        1,
        false,
    );

    let query_seqs = [
        ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let (_dir2, query_path) = build_test_db_with_params(&query_seqs, 11, 1, true);

    let db = JamReader::open(&db_path).unwrap();
    let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();

    for (sample_id, &expected_size) in sketch.query_sizes.iter().enumerate() {
        // Collect the distinct hashes attributed to this sample.
        let unique_hashes: std::collections::HashSet<u64> = (0..BUCKET_COUNT)
            .flat_map(|slot| sketch.bucket(slot).iter())
            .filter(|&&(_, sid)| sid as usize == sample_id)
            .map(|&(hash, _)| hash)
            .collect();

        assert_eq!(
            unique_hashes.len(),
            expected_size,
            "query_sizes[{}] should match actual unique hash count",
            sample_id
        );
    }
}
1454
/// Builds both database and source with `fscale` 1_000_000 and requires only
/// that `from_jam` succeeds.
///
/// NOTE(review): presumably such a large `fscale` leaves the source with few
/// or no hashes; the test does not assert that.
#[test]
fn test_from_jam_empty_source() {
    let db_seqs = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let query_seqs = [("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];

    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1_000_000, false);
    let (_dir2, query_path) = build_test_db_with_params(&query_seqs, 11, 1_000_000, false);

    let db = JamReader::open(&db_path).unwrap();
    let outcome = QuerySketch::from_jam(&query_path, &db);

    assert!(outcome.is_ok());
}
1475
/// An empty input list produces a sketch with no samples and no entries.
#[test]
fn test_from_inputs_empty() {
    let reference = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir, db_path) = build_test_db(&reference, false);
    let db = JamReader::open(&db_path).unwrap();

    let sketch = QuerySketch::from_inputs(&[], &db, false).unwrap();

    let samples = sketch.sample_count();
    assert_eq!(samples, 0);

    let entries = sketch.total_entries();
    assert_eq!(entries, 0);
}
1486
/// A single FASTA input in non-singleton mode yields one sample with at least
/// one entry.
#[test]
fn test_from_inputs_single_fasta() {
    let reference = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir, db_path) = build_test_db(&reference, false);
    let db = JamReader::open(&db_path).unwrap();

    let query_fasta = make_fasta(&[("query_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let inputs = [query_fasta.path().to_path_buf()];

    let sketch = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 1);
    assert!(sketch.total_entries() > 0);
}
1500
/// A single JAM input yields one sample with at least one entry.
#[test]
fn test_from_inputs_single_jam() {
    let db_seqs = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let query_seqs = [("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];

    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let (_dir2, query_jam) = build_test_db_with_params(&query_seqs, 11, 1, false);

    let db = JamReader::open(&db_path).unwrap();

    let inputs = [query_jam];
    let sketch = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 1);
    assert!(sketch.total_entries() > 0);
}
1523
/// Two FASTA inputs in non-singleton mode yield two samples, two recorded
/// sizes and at least one entry.
#[test]
fn test_from_inputs_multiple_fasta() {
    let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let fasta1 = make_fasta(&[("query1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let fasta2 = make_fasta(&[("query2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")]);

    let inputs = [fasta1.path().to_path_buf(), fasta2.path().to_path_buf()];
    let sketch = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 2);
    assert!(sketch.total_entries() > 0);
    assert_eq!(sketch.query_sizes.len(), 2);
}
1543
/// In singleton mode, two FASTA files with two records each yield four
/// samples named after the records, in file-then-record order.
#[test]
fn test_from_inputs_multiple_fasta_singleton() {
    let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let first_records = [
        ("seq1a", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("seq1b", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let second_records = [
        ("seq2a", "TATATATATATATATATATATATATATATATA"),
        ("seq2b", "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"),
    ];
    let fasta1 = make_fasta(&first_records);
    let fasta2 = make_fasta(&second_records);

    let inputs = [fasta1.path().to_path_buf(), fasta2.path().to_path_buf()];
    let sketch = QuerySketch::from_inputs(&inputs, &db, true).unwrap();

    assert_eq!(sketch.sample_count(), 4);

    let names = &sketch.sample_names;
    assert_eq!(names.len(), 4);
    assert_eq!(names[0], "seq1a");
    assert_eq!(names[1], "seq1b");
    assert_eq!(names[2], "seq2a");
    assert_eq!(names[3], "seq2b");
}
1572
/// A FASTA file and a JAM file passed together yield two samples and at least
/// one entry.
#[test]
fn test_from_inputs_mixed_fasta_and_jam() {
    let db_seqs = [("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let db = JamReader::open(&db_path).unwrap();

    let query_fasta = make_fasta(&[("fasta_query", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);

    let jam_seqs = [("jam_query", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];
    let (_dir2, query_jam) = build_test_db_with_params(&jam_seqs, 11, 1, false);

    // FASTA first, JAM second.
    let inputs = [query_fasta.path().to_path_buf(), query_jam];
    let sketch = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 2);
    assert!(sketch.total_entries() > 0);
}
1599
/// Combining two singleton-built JAM files of two samples each must give four
/// samples, with every stored sample id below 4 and all four ids present
/// among the entries.
#[test]
fn test_from_inputs_sample_id_renumbering() {
    let db_seqs = [("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let db = JamReader::open(&db_path).unwrap();

    let first_seqs = [
        ("seq1a", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("seq1b", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let second_seqs = [
        ("seq2a", "TATATATATATATATATATATATATATATATA"),
        ("seq2b", "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"),
    ];
    let (_dir2, jam1) = build_test_db_with_params(&first_seqs, 11, 1, true);
    let (_dir3, jam2) = build_test_db_with_params(&second_seqs, 11, 1, true);

    let sketch = QuerySketch::from_inputs(&[jam1, jam2], &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 4);

    // One pass: range-check each id and record which ids occur.
    let mut seen_samples = std::collections::HashSet::new();
    for &(_hash, sample_id) in sketch.buckets.iter().flatten() {
        assert!(sample_id < 4, "Sample ID {} should be < 4", sample_id);
        seen_samples.insert(sample_id);
    }

    assert_eq!(seen_samples.len(), 4, "All samples should have entries");
}
1647
/// After combining two FASTA inputs, adjacent entries within each bucket must
/// be in non-decreasing order.
#[test]
fn test_from_inputs_buckets_sorted() {
    let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let fasta1 = make_fasta(&[("q1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let fasta2 = make_fasta(&[("q2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")]);

    let inputs = [fasta1.path().to_path_buf(), fasta2.path().to_path_buf()];
    let sketch = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    for entries in sketch.buckets.iter() {
        entries.windows(2).for_each(|pair| {
            assert!(
                pair[0] <= pair[1],
                "Bucket not sorted: {:?} > {:?}",
                pair[0],
                pair[1]
            );
        });
    }
}
1674
/// The per-sample sizes reported by `from_inputs` over two FASTA files must
/// equal the sizes obtained by sketching the same records individually with
/// `from_fasta`.
#[test]
fn test_from_inputs_query_sizes_preserved() {
    let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
    let db = JamReader::open(&db_path).unwrap();

    let first_record = [("q1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let second_record = [("q2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];

    // Baseline: each file sketched on its own.
    let fasta1 = make_fasta(&first_record);
    let fasta2 = make_fasta(&second_record);
    let sketch1 = QuerySketch::from_fasta(fasta1.path(), &db, false).unwrap();
    let sketch2 = QuerySketch::from_fasta(fasta2.path(), &db, false).unwrap();

    // Fresh files with the same records for the combined run.
    let fasta1_new = make_fasta(&first_record);
    let fasta2_new = make_fasta(&second_record);
    let inputs = [
        fasta1_new.path().to_path_buf(),
        fasta2_new.path().to_path_buf(),
    ];
    let combined = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    assert_eq!(combined.query_sizes.len(), 2);
    assert_eq!(combined.query_sizes[0], sketch1.query_sizes[0]);
    assert_eq!(combined.query_sizes[1], sketch2.query_sizes[0]);
}
1703
/// Confirms the built file carries the `jam` extension and that passing it to
/// `from_inputs` yields one sample with a non-empty name.
#[test]
fn test_from_inputs_jam_detection_by_extension() {
    let db_seqs = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let db = JamReader::open(&db_path).unwrap();

    let jam_seqs = [("jam_seq", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];
    let (_dir2, jam_path) = build_test_db_with_params(&jam_seqs, 11, 1, false);

    // Precondition: the helper really produced a `.jam` file.
    assert_eq!(jam_path.extension().unwrap(), "jam");

    let inputs = [jam_path];
    let sketch = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 1);
    assert!(!sketch.sample_names[0].is_empty());
}
1728
/// A non-existent FASTA path among the inputs makes `from_inputs` return an
/// error.
#[test]
fn test_from_inputs_propagates_errors() {
    let reference = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir, db_path) = build_test_db(&reference, false);
    let db = JamReader::open(&db_path).unwrap();

    let missing = std::path::PathBuf::from("/nonexistent/file.fasta");
    let result = QuerySketch::from_inputs(&[missing], &db, false);

    assert!(result.is_err());
}
1742
/// A JAM input built with k=21 against a k=11 database must surface from
/// `from_inputs` as `ParameterMismatch` naming the k-mer size.
#[test]
fn test_from_inputs_jam_parameter_mismatch_propagates() {
    let db_seqs = [("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let db = JamReader::open(&db_path).unwrap();

    let jam_seqs = [("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];
    let (_dir2, jam_path) = build_test_db_with_params(&jam_seqs, 21, 1, false);

    let result = QuerySketch::from_inputs(&[jam_path], &db, false);

    assert!(result.is_err());
    let err = result.unwrap_err();
    if let QueryError::ParameterMismatch { parameter, .. } = &err {
        assert!(parameter.contains("k-mer"));
    } else {
        panic!("Expected ParameterMismatch error, got {:?}", err);
    }
}
1770
/// End-to-end check for mixed inputs: a FASTA query identical to the database
/// sequence plus a JAM query with a different sequence are combined and run
/// through `QueryEngine::query_sketch`.
///
/// Two results are expected; only the first (the identical sequence) is
/// inspected, and its top hit must have containment of at least 0.9.
#[test]
fn test_from_inputs_integration_with_query_engine() {
    let db_seqs = [("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir1, db_path) = build_test_db_with_params(&db_seqs, 11, 1, false);
    let db = JamReader::open(&db_path).unwrap();
    let engine = QueryEngine::open(&db_path).unwrap();

    let query_fasta = make_fasta(&[("same_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
    let jam_seqs = [("different_seq", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")];
    let (_dir2, query_jam) = build_test_db_with_params(&jam_seqs, 11, 1, false);

    let inputs = [query_fasta.path().to_path_buf(), query_jam];
    let sketch = QuerySketch::from_inputs(&inputs, &db, false).unwrap();

    assert_eq!(sketch.sample_count(), 2);

    let results = engine.query_sketch(&sketch);

    assert_eq!(results.len(), 2);

    let same_seq_result = &results[0];
    assert!(same_seq_result.has_matches());

    let top0 = same_seq_result.top(1);
    assert!(!top0.is_empty());
    assert!(
        top0[0].containment >= 0.9,
        "Same sequence should have high containment, got {}",
        top0[0].containment
    );
}
1810
/// `QueryEngine::query_fasta` in non-singleton mode on a query identical to
/// the database sequence returns one result whose top hit has containment of
/// at least 0.9.
#[test]
fn test_query_fasta_non_singleton() {
    let reference = [("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir, db_path) = build_test_db(&reference, false);
    let engine = QueryEngine::open(&db_path).unwrap();

    let query_fasta = make_fasta(&[("query_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);

    let results = engine.query_fasta(query_fasta.path(), false).unwrap();
    assert_eq!(results.len(), 1);

    let only = &results[0];
    assert!(only.has_matches());

    let best = only.top(1);
    assert!(!best.is_empty());
    assert!(
        best[0].containment >= 0.9,
        "Expected high containment, got {}",
        best[0].containment
    );
}
1832
/// `QueryEngine::query_fasta` in singleton mode against a two-sample
/// singleton database returns one result per query record, each with matches.
///
/// For query `i`, if a match with `sample_id == i` is present, its containment
/// must be at least 0.9; when no such match exists the check is skipped.
#[test]
fn test_query_fasta_singleton() {
    let db_seqs = [
        ("db_seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("db_seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let (_dir, db_path) = build_test_db(&db_seqs, true);
    let engine = QueryEngine::open(&db_path).unwrap();

    let query_records = [
        ("query1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
        ("query2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
    ];
    let query_fasta = make_fasta(&query_records);

    let results = engine.query_fasta(query_fasta.path(), true).unwrap();

    assert_eq!(results.len(), 2);

    assert!(results[0].has_matches());
    assert!(results[1].has_matches());

    for (idx, result) in results.iter().enumerate() {
        // Only the first match with the same sample id is examined.
        let Some(hit) = result.matches.iter().find(|m| m.sample_id == idx as u32) else {
            continue;
        };
        assert!(
            hit.containment >= 0.9,
            "Query {} should have high containment with DB sample {}, got {}",
            idx,
            idx,
            hit.containment
        );
    }
}
1869
/// `QueryEngine::query_fasta` on a missing path must fail with
/// `QueryError::Parse` whose `path` field mentions the offending file.
#[test]
fn test_query_fasta_file_not_found() {
    let reference = [("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")];
    let (_dir, db_path) = build_test_db(&reference, false);
    let engine = QueryEngine::open(&db_path).unwrap();

    let result = engine.query_fasta("/nonexistent/path.fasta", false);

    assert!(result.is_err());
    let err = result.unwrap_err();
    if let QueryError::Parse { path, .. } = &err {
        assert!(path.contains("nonexistent"));
    } else {
        panic!("Expected Parse error, got {:?}", err);
    }
}
1886}