Skip to main content

jam_rs/
query.rs

1use crate::bias::HashBiasTable;
2use crate::format::{BUCKET_COUNT, bucket_id};
3use crate::reader::{JamReader, ReaderError};
4use jamhash::jamhash_u64;
5use needletail::{Sequence, parse_fastx_file};
6use std::collections::{HashMap, HashSet};
7use std::path::Path;
8use std::sync::Arc;
9
10#[derive(Debug)]
11pub struct QuerySketch {
12    pub buckets: [Vec<(u64, u32)>; BUCKET_COUNT],
13    pub sample_names: Vec<String>,
14    pub query_sizes: Vec<usize>,
15}
16
17impl QuerySketch {
18    pub fn new() -> Self {
19        Self {
20            buckets: std::array::from_fn(|_| Vec::new()),
21            sample_names: Vec::new(),
22            query_sizes: Vec::new(),
23        }
24    }
25
26    #[inline]
27    pub fn bucket(&self, idx: usize) -> &[(u64, u32)] {
28        &self.buckets[idx]
29    }
30
31    #[inline]
32    pub fn sample_count(&self) -> usize {
33        self.sample_names.len()
34    }
35
36    #[inline]
37    pub fn total_entries(&self) -> usize {
38        self.buckets.iter().map(|b| b.len()).sum()
39    }
40
41    pub fn from_jam<P: AsRef<Path>>(path: P, db: &JamReader) -> Result<Self, QueryError> {
42        let source = JamReader::open(path)?;
43
44        if source.kmer_size() != db.kmer_size() {
45            return Err(QueryError::ParameterMismatch {
46                parameter: "k-mer size".to_string(),
47                source_value: source.kmer_size().to_string(),
48                target_value: db.kmer_size().to_string(),
49            });
50        }
51
52        if source.threshold() != db.threshold() {
53            return Err(QueryError::ParameterMismatch {
54                parameter: "hash threshold".to_string(),
55                source_value: source.threshold().to_string(),
56                target_value: db.threshold().to_string(),
57            });
58        }
59
60        let stats = source.stats();
61        let expected_sample_count = stats.sample_count as usize;
62
63        let sample_names = source.sample_names().to_vec();
64        if sample_names.len() != expected_sample_count {
65            return Err(QueryError::Parse {
66                path: "JAM file".to_string(),
67                message: format!(
68                    "sample names count ({}) doesn't match header sample_count ({})",
69                    sample_names.len(),
70                    expected_sample_count
71                ),
72            });
73        }
74
75        let stored_sizes = source.sample_sizes();
76        let query_sizes: Vec<usize> = stored_sizes.iter().map(|&s| s as usize).collect();
77
78        let mut buckets: [Vec<(u64, u32)>; BUCKET_COUNT] = std::array::from_fn(|_| Vec::new());
79        for (bucket_idx, bucket) in buckets.iter_mut().enumerate() {
80            for entry in source.bucket_entries(bucket_idx) {
81                bucket.push((entry.hash, entry.sample_id));
82            }
83        }
84
85        Ok(Self {
86            buckets,
87            sample_names,
88            query_sizes,
89        })
90    }
91
92    pub fn from_fasta<P: AsRef<Path>>(
93        input: P,
94        db: &JamReader,
95        singleton: bool,
96    ) -> Result<Self, QueryError> {
97        let input_path = input.as_ref();
98        let kmer_size = db.kmer_size();
99        let threshold = db.threshold();
100        let bias_table = db.bias_table();
101
102        let mut reader = match parse_fastx_file(input_path) {
103            Ok(reader) => reader,
104            Err(e) if e.kind == needletail::errors::ParseErrorKind::EmptyFile => {
105                eprintln!(
106                    "Empty file detected: {}, returning empty sketch",
107                    input_path.display()
108                );
109                return Ok(Self::new());
110            }
111            Err(e) => {
112                return Err(QueryError::Parse {
113                    path: input_path.display().to_string(),
114                    message: e.to_string(),
115                });
116            }
117        };
118
119        let mut buckets: [Vec<(u64, u32)>; BUCKET_COUNT] = std::array::from_fn(|_| Vec::new());
120        let mut sample_names: Vec<String> = Vec::new();
121        let mut sample_hash_sets: Vec<HashSet<u64>> = Vec::new();
122        let mut current_sample_id: u32 = 0;
123
124        if !singleton {
125            sample_names.push(
126                input_path
127                    .file_name()
128                    .and_then(|s| s.to_str())
129                    .unwrap_or("query")
130                    .to_string(),
131            );
132            sample_hash_sets.push(HashSet::new());
133        }
134
135        while let Some(record) = reader.next() {
136            let record = record.map_err(|e| QueryError::Parse {
137                path: input_path.display().to_string(),
138                message: e.to_string(),
139            })?;
140
141            if singleton {
142                let name = std::str::from_utf8(record.id())
143                    .unwrap_or("unknown")
144                    .to_string();
145                sample_names.push(name);
146                sample_hash_sets.push(HashSet::new());
147                current_sample_id = (sample_names.len() - 1) as u32;
148            }
149
150            let sequence = record.normalize(false);
151            if sequence.len() < kmer_size as usize {
152                continue;
153            }
154
155            for (_, kmer, _) in sequence.bit_kmers(kmer_size, true) {
156                let hash = jamhash_u64(kmer.0);
157
158                if hash >= threshold {
159                    continue;
160                }
161
162                if bias_table.as_ref().is_some_and(|b| !b.passes_filter(hash)) {
163                    continue;
164                }
165
166                if sample_hash_sets[current_sample_id as usize].insert(hash) {
167                    buckets[bucket_id(hash)].push((hash, current_sample_id));
168                }
169            }
170        }
171
172        for bucket in &mut buckets {
173            bucket.sort_unstable();
174            bucket.dedup();
175        }
176
177        let query_sizes: Vec<usize> = sample_hash_sets.iter().map(|set| set.len()).collect();
178
179        Ok(Self {
180            buckets,
181            sample_names,
182            query_sizes,
183        })
184    }
185
186    pub fn from_inputs(
187        inputs: &[std::path::PathBuf],
188        db: &JamReader,
189        singleton: bool,
190    ) -> Result<Self, QueryError> {
191        use crate::format::MAGIC;
192        use std::fs::File;
193        use std::io::Read;
194
195        if inputs.is_empty() {
196            return Ok(Self::new());
197        }
198
199        let is_jam_file = |path: &std::path::PathBuf| -> bool {
200            if path
201                .extension()
202                .is_some_and(|ext| ext.eq_ignore_ascii_case("jam"))
203            {
204                return true;
205            }
206            File::open(path)
207                .ok()
208                .and_then(|mut f| {
209                    let mut magic = [0u8; 4];
210                    f.read_exact(&mut magic).ok()?;
211                    Some(magic == MAGIC)
212                })
213                .unwrap_or(false)
214        };
215
216        let mut combined = Self::new();
217
218        for input in inputs {
219            let sketch = if is_jam_file(input) {
220                Self::from_jam(input, db)?
221            } else {
222                Self::from_fasta(input, db, singleton)?
223            };
224
225            let sample_offset = combined.sample_count() as u32;
226            combined.sample_names.extend(sketch.sample_names);
227            combined.query_sizes.extend(sketch.query_sizes);
228
229            for (bucket_idx, bucket) in sketch.buckets.into_iter().enumerate() {
230                for (hash, sample_id) in bucket {
231                    combined.buckets[bucket_idx].push((hash, sample_id + sample_offset));
232                }
233            }
234        }
235
236        for bucket in &mut combined.buckets {
237            bucket.sort_unstable();
238        }
239
240        Ok(combined)
241    }
242}
243
244impl Default for QuerySketch {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
250#[derive(Debug, thiserror::Error)]
251pub enum QueryError {
252    #[error("I/O error: {0}")]
253    Io(#[from] std::io::Error),
254
255    #[error("Database error: {0}")]
256    Database(#[from] ReaderError),
257
258    #[error("Parse error in {path}: {message}")]
259    Parse { path: String, message: String },
260
261    #[error(
262        "Parameter mismatch: {parameter} - source has {source_value}, target database has {target_value}"
263    )]
264    ParameterMismatch {
265        parameter: String,
266        source_value: String,
267        target_value: String,
268    },
269}
270
/// One database sample hit by a query.
#[derive(Debug, Clone)]
pub struct SampleMatch {
    /// Id of the matched database sample.
    pub sample_id: u32,
    /// Number of hits counted for this sample.
    pub hit_count: u32,
    /// `hit_count` divided by the query size.
    pub containment: f64,
}

/// Outcome of querying one sample against the database.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Size of the query the containment values are relative to.
    pub query_size: usize,
    /// Number of query hashes found in the database.
    pub hashes_found: usize,
    /// Per-database-sample matches, in no guaranteed order.
    pub matches: Vec<SampleMatch>,
    /// Number of buckets that could not be searched; non-zero marks the
    /// result as partial.
    pub failed_bucket_count: usize,
}

impl QueryResult {
    /// Returns up to `n` matches ordered by descending containment.
    pub fn top(&self, n: usize) -> Vec<&SampleMatch> {
        let mut ranked: Vec<&SampleMatch> = self.matches.iter().collect();
        ranked.sort_by(|lhs, rhs| rhs.containment.total_cmp(&lhs.containment));
        ranked.into_iter().take(n).collect()
    }

    /// Returns every match whose containment is at least `min_containment`,
    /// in their stored order.
    pub fn above_threshold(&self, min_containment: f64) -> Vec<&SampleMatch> {
        let mut kept = Vec::new();
        for candidate in &self.matches {
            if candidate.containment >= min_containment {
                kept.push(candidate);
            }
        }
        kept
    }

    /// `true` when at least one database sample matched.
    pub fn has_matches(&self) -> bool {
        self.matches.first().is_some()
    }

    /// `true` when one or more buckets failed, so the result is incomplete.
    pub fn is_partial(&self) -> bool {
        self.failed_bucket_count != 0
    }
}
309
310pub struct QueryEngine {
311    reader: JamReader,
312    bias_table: Option<Arc<HashBiasTable>>,
313}
314
315impl QueryEngine {
316    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, ReaderError> {
317        let reader = JamReader::open(path)?;
318        let bias_table = reader.bias_table();
319        Ok(Self { reader, bias_table })
320    }
321
322    pub fn threshold(&self) -> u64 {
323        self.reader.threshold()
324    }
325
326    pub fn kmer_size(&self) -> u8 {
327        self.reader.kmer_size()
328    }
329
330    pub fn bias_table(&self) -> Option<Arc<HashBiasTable>> {
331        self.bias_table.clone()
332    }
333
334    pub fn has_bias_table(&self) -> bool {
335        self.bias_table.is_some()
336    }
337
338    pub fn reader(&self) -> &JamReader {
339        &self.reader
340    }
341
342    pub fn query(&self, hashes: &[u64]) -> QueryResult {
343        if hashes.is_empty() {
344            return QueryResult {
345                query_size: 0,
346                hashes_found: 0,
347                matches: Vec::new(),
348                failed_bucket_count: 0,
349            };
350        }
351
352        let mut sorted_hashes = hashes.to_vec();
353        sorted_hashes.sort_unstable_by_key(|&h| (h & 0xFF, h));
354
355        let mut sample_hits: HashMap<u32, u32> = HashMap::new();
356        let mut hashes_found = 0;
357
358        for &hash in &sorted_hashes {
359            let mut found = false;
360            for sample_id in self.reader.search(hash) {
361                *sample_hits.entry(sample_id).or_insert(0) += 1;
362                found = true;
363            }
364            if found {
365                hashes_found += 1;
366            }
367        }
368
369        let query_size = hashes.len();
370        let matches: Vec<SampleMatch> = sample_hits
371            .into_iter()
372            .map(|(sample_id, hit_count)| SampleMatch {
373                sample_id,
374                hit_count,
375                containment: hit_count as f64 / query_size as f64,
376            })
377            .collect();
378
379        QueryResult {
380            query_size,
381            hashes_found,
382            matches,
383            failed_bucket_count: 0,
384        }
385    }
386
387    pub fn query_filtered(
388        &self,
389        hashes: &[u64],
390        min_containment: f64,
391        max_results: usize,
392    ) -> QueryResult {
393        let mut result = self.query(hashes);
394        result.matches.retain(|m| m.containment >= min_containment);
395        result
396            .matches
397            .sort_by(|a, b| b.containment.total_cmp(&a.containment));
398        result.matches.truncate(max_results);
399        result
400    }
401
402    pub fn query_batch(&self, queries: &[Vec<u64>]) -> Vec<QueryResult> {
403        use rayon::prelude::*;
404        queries.par_iter().map(|q| self.query(q)).collect()
405    }
406
407    pub fn query_sketch(&self, sketch: &QuerySketch) -> Vec<QueryResult> {
408        use crate::format::{ENTRY_SIZE, PAGE_SIZE};
409        use rayon::prelude::*;
410        use std::sync::atomic::{AtomicU32, Ordering};
411
412        let num_samples = sketch.sample_count();
413        if num_samples == 0 {
414            return Vec::new();
415        }
416
417        let threshold = self.reader.threshold();
418
419        self.reader.advise_random();
420
421        let hashes_found: Vec<AtomicU32> = (0..num_samples)
422            .into_par_iter()
423            .map(|_| AtomicU32::new(0))
424            .collect();
425
426        let bucket_pairs: Vec<Vec<(u32, u32)>> = (0..BUCKET_COUNT)
427            .into_par_iter()
428            .map(|bucket_idx| {
429                let mut pairs = Vec::new();
430                let query_bucket = sketch.bucket(bucket_idx);
431                if query_bucket.is_empty() {
432                    return pairs;
433                }
434
435                let filter = match self.reader.bucket_filter(bucket_idx) {
436                    Some(f) => f,
437                    None => return pairs,
438                };
439
440                let mut survivors = Vec::with_capacity(query_bucket.len() / 10);
441                let mut prev_hash = u64::MAX;
442                let mut prev_passed = false;
443
444                for &(hash, sample_id) in query_bucket {
445                    if hash != prev_hash {
446                        prev_hash = hash;
447                        prev_passed = filter.contains(&hash);
448                    }
449                    if prev_passed {
450                        survivors.push((hash, sample_id));
451                    }
452                }
453
454                let (filter_start, filter_end) = self.reader.bucket_filter_byte_range(bucket_idx);
455                self.reader.release_pages(filter_start, filter_end);
456
457                if survivors.is_empty() {
458                    return pairs;
459                }
460
461                let db_bucket = self.reader.bucket_entries(bucket_idx);
462                let count = db_bucket.len();
463                if count == 0 {
464                    return pairs;
465                }
466
467                let (entry_start, _entry_end) = self.reader.bucket_entry_byte_range(bucket_idx);
468                let mut last_released_page = entry_start & !(PAGE_SIZE - 1);
469
470                let mut q_idx = 0;
471                while q_idx < survivors.len() {
472                    let q_hash = survivors[q_idx].0;
473
474                    let est = ((q_hash as u128 * count as u128) / threshold as u128) as usize;
475                    let mut d_idx = est.saturating_sub(16).min(count.saturating_sub(1));
476
477                    while d_idx > 0 && db_bucket[d_idx].hash > q_hash {
478                        d_idx -= 1;
479                    }
480
481                    while d_idx < count && db_bucket[d_idx].hash < q_hash {
482                        d_idx += 1;
483                    }
484
485                    while d_idx > 0 && db_bucket[d_idx - 1].hash == q_hash {
486                        d_idx -= 1;
487                    }
488
489                    let current_byte = entry_start + d_idx * ENTRY_SIZE;
490                    let current_page = current_byte & !(PAGE_SIZE - 1);
491                    if current_page > last_released_page + PAGE_SIZE {
492                        self.reader
493                            .release_pages(last_released_page, current_page - PAGE_SIZE);
494                        last_released_page = current_page - PAGE_SIZE;
495                    }
496
497                    let db_start = d_idx;
498                    let mut db_end = d_idx;
499                    while db_end < count && db_bucket[db_end].hash == q_hash {
500                        db_end += 1;
501                    }
502                    let has_matches = db_start < db_end;
503
504                    let mut prev_sample = u32::MAX;
505                    while q_idx < survivors.len() && survivors[q_idx].0 == q_hash {
506                        let q_sample = survivors[q_idx].1;
507
508                        if q_sample != prev_sample {
509                            if has_matches {
510                                for db_entry in &db_bucket[db_start..db_end] {
511                                    pairs.push((q_sample, db_entry.sample_id));
512                                }
513                                hashes_found[q_sample as usize].fetch_add(1, Ordering::Relaxed);
514                            }
515                            prev_sample = q_sample;
516                        }
517                        q_idx += 1;
518                    }
519                }
520
521                self.reader.release_bucket(bucket_idx);
522
523                pairs
524            })
525            .collect();
526
527        let bucket_sizes: Vec<usize> = bucket_pairs.iter().map(|v| v.len()).collect();
528        let total_pairs: usize = bucket_sizes.iter().sum();
529        let mut bucket_offsets = Vec::with_capacity(BUCKET_COUNT + 1);
530        bucket_offsets.push(0usize);
531        for size in &bucket_sizes {
532            bucket_offsets.push(bucket_offsets.last().unwrap() + size);
533        }
534
535        let mut all_pairs: Vec<(u32, u32)> = vec![(0, 0); total_pairs];
536        bucket_pairs
537            .into_par_iter()
538            .enumerate()
539            .for_each(|(bucket_idx, pairs)| {
540                let start = bucket_offsets[bucket_idx];
541                let dest = unsafe {
542                    std::slice::from_raw_parts_mut(
543                        all_pairs.as_ptr().add(start) as *mut (u32, u32),
544                        pairs.len(),
545                    )
546                };
547                dest.copy_from_slice(&pairs);
548            });
549
550        let merged_hashes_found: Vec<u32> = hashes_found
551            .into_par_iter()
552            .map(|a| a.load(Ordering::Relaxed))
553            .collect();
554
555        all_pairs.par_sort_unstable();
556
557        if all_pairs.is_empty() {
558            return (0..num_samples)
559                .map(|i| QueryResult {
560                    query_size: sketch.query_sizes[i],
561                    hashes_found: merged_hashes_found[i] as usize,
562                    matches: Vec::new(),
563                    failed_bucket_count: 0,
564                })
565                .collect();
566        }
567
568        let sample_starts: Vec<usize> = (0..num_samples as u32)
569            .into_par_iter()
570            .map(|q_sample| all_pairs.partition_point(|&(qs, _)| qs < q_sample))
571            .collect();
572
573        let results: Vec<QueryResult> = (0..num_samples)
574            .into_par_iter()
575            .map(|sample_idx| {
576                let q_sample = sample_idx as u32;
577                let start = sample_starts[sample_idx];
578                let end = if sample_idx + 1 < num_samples {
579                    sample_starts[sample_idx + 1]
580                } else {
581                    all_pairs.len()
582                };
583
584                let mut matches = Vec::new();
585                let query_size = sketch.query_sizes[sample_idx];
586
587                let mut i = start;
588                while i < end {
589                    let (_, db_sample) = all_pairs[i];
590                    let mut count = 1u32;
591                    while i + (count as usize) < end
592                        && all_pairs[i + count as usize] == (q_sample, db_sample)
593                    {
594                        count += 1;
595                    }
596                    matches.push(SampleMatch {
597                        sample_id: db_sample,
598                        hit_count: count,
599                        containment: if query_size > 0 {
600                            count as f64 / query_size as f64
601                        } else {
602                            0.0
603                        },
604                    });
605                    i += count as usize;
606                }
607
608                QueryResult {
609                    query_size,
610                    hashes_found: merged_hashes_found[sample_idx] as usize,
611                    matches,
612                    failed_bucket_count: 0,
613                }
614            })
615            .collect();
616
617        results
618    }
619
620    pub fn query_fasta<P: AsRef<Path>>(
621        &self,
622        input: P,
623        singleton: bool,
624    ) -> Result<Vec<QueryResult>, QueryError> {
625        let sketch = QuerySketch::from_fasta(input, &self.reader, singleton)?;
626        Ok(self.query_sketch(&sketch))
627    }
628}
629
630#[cfg(test)]
631mod tests {
632    use super::*;
633    use crate::writer::{BuildConfig, build};
634    use std::io::Write;
635    use tempfile::NamedTempFile;
636
637    fn make_fasta(seqs: &[(&str, &str)]) -> NamedTempFile {
638        let mut f = NamedTempFile::with_suffix(".fa").unwrap();
639        for (name, seq) in seqs {
640            writeln!(f, ">{name}").unwrap();
641            writeln!(f, "{seq}").unwrap();
642        }
643        f
644    }
645
646    fn build_test_db(
647        seqs: &[(&str, &str)],
648        singleton: bool,
649    ) -> (tempfile::TempDir, std::path::PathBuf) {
650        let input = make_fasta(seqs);
651        let output_dir = tempfile::tempdir().unwrap();
652        let output_path = output_dir.path().join("test.jam");
653
654        let config = BuildConfig {
655            kmer_size: 11,
656            fscale: 1,
657            singleton,
658            num_threads: 1,
659            memory: 1,
660            ..Default::default()
661        };
662
663        build(&[input.path().to_path_buf()], &output_path, &config).unwrap();
664        (output_dir, output_path)
665    }
666
667    #[test]
668    fn test_query_engine_open() {
669        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
670        let engine = QueryEngine::open(&path).unwrap();
671        assert!(engine.threshold() > 0);
672        assert_eq!(engine.kmer_size(), 11);
673    }
674
675    #[test]
676    fn test_query_basic() {
677        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
678        let engine = QueryEngine::open(&path).unwrap();
679
680        let reader = JamReader::open(&path).unwrap();
681        let mut test_hashes = Vec::new();
682        for bucket_idx in 0..256 {
683            let entries = reader.bucket_entries(bucket_idx);
684            for entry in entries.iter().take(5) {
685                test_hashes.push(entry.hash);
686            }
687            if test_hashes.len() >= 10 {
688                break;
689            }
690        }
691
692        if !test_hashes.is_empty() {
693            let result = engine.query(&test_hashes);
694            assert!(result.has_matches());
695            assert!(result.hashes_found > 0);
696        }
697    }
698
699    #[test]
700    fn test_query_empty() {
701        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
702        let engine = QueryEngine::open(&path).unwrap();
703
704        let result = engine.query(&[]);
705        assert!(!result.has_matches());
706        assert_eq!(result.query_size, 0);
707    }
708
709    #[test]
710    fn test_query_nonexistent() {
711        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
712        let engine = QueryEngine::open(&path).unwrap();
713
714        let fake_hashes: Vec<u64> = (0..10).map(|i| u64::MAX - i).collect();
715        let result = engine.query(&fake_hashes);
716        assert_eq!(result.hashes_found, 0);
717    }
718
719    #[test]
720    fn test_query_filtered() {
721        let (_dir, path) = build_test_db(
722            &[
723                ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
724                ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
725            ],
726            true,
727        );
728        let engine = QueryEngine::open(&path).unwrap();
729
730        let reader = JamReader::open(&path).unwrap();
731        let mut test_hashes = Vec::new();
732        for bucket_idx in 0..256 {
733            for entry in reader.bucket_entries(bucket_idx) {
734                if entry.sample_id == 0 {
735                    test_hashes.push(entry.hash);
736                }
737                if test_hashes.len() >= 20 {
738                    break;
739                }
740            }
741            if test_hashes.len() >= 20 {
742                break;
743            }
744        }
745
746        if !test_hashes.is_empty() {
747            let result = engine.query_filtered(&test_hashes, 0.5, 10);
748            assert!(result.matches.len() <= 10);
749            for m in &result.matches {
750                assert!(m.containment >= 0.5);
751            }
752        }
753    }
754
755    #[test]
756    fn test_query_result_helpers() {
757        let result = QueryResult {
758            query_size: 100,
759            hashes_found: 50,
760            matches: vec![
761                SampleMatch {
762                    sample_id: 0,
763                    hit_count: 50,
764                    containment: 0.5,
765                },
766                SampleMatch {
767                    sample_id: 1,
768                    hit_count: 30,
769                    containment: 0.3,
770                },
771                SampleMatch {
772                    sample_id: 2,
773                    hit_count: 80,
774                    containment: 0.8,
775                },
776            ],
777            failed_bucket_count: 0,
778        };
779
780        let top2 = result.top(2);
781        assert_eq!(top2.len(), 2);
782        assert_eq!(top2[0].sample_id, 2);
783        assert_eq!(top2[1].sample_id, 0);
784
785        let above_threshold = result.above_threshold(0.4);
786        assert_eq!(above_threshold.len(), 2);
787
788        assert!(result.has_matches());
789        assert!(!result.is_partial());
790    }
791
792    #[test]
793    fn test_query_sketch_new() {
794        let sketch = QuerySketch::new();
795
796        assert_eq!(sketch.sample_count(), 0);
797        assert_eq!(sketch.total_entries(), 0);
798        assert_eq!(sketch.buckets.len(), 256);
799        assert!(sketch.sample_names.is_empty());
800        assert!(sketch.query_sizes.is_empty());
801    }
802
803    #[test]
804    fn test_query_sketch_default() {
805        let sketch = QuerySketch::default();
806
807        assert_eq!(sketch.sample_count(), 0);
808        assert_eq!(sketch.total_entries(), 0);
809    }
810
811    #[test]
812    fn test_query_sketch_bucket_accessor() {
813        let mut sketch = QuerySketch::new();
814
815        sketch.buckets[0].push((100, 0));
816        sketch.buckets[0].push((200, 1));
817
818        sketch.buckets[255].push((300, 0));
819
820        let bucket_0 = sketch.bucket(0);
821        assert_eq!(bucket_0.len(), 2);
822        assert_eq!(bucket_0[0], (100, 0));
823        assert_eq!(bucket_0[1], (200, 1));
824
825        let bucket_255 = sketch.bucket(255);
826        assert_eq!(bucket_255.len(), 1);
827        assert_eq!(bucket_255[0], (300, 0));
828
829        let bucket_1 = sketch.bucket(1);
830        assert!(bucket_1.is_empty());
831    }
832
833    #[test]
834    fn test_query_sketch_sample_count() {
835        let mut sketch = QuerySketch::new();
836        assert_eq!(sketch.sample_count(), 0);
837
838        sketch.sample_names.push("sample1".to_string());
839        assert_eq!(sketch.sample_count(), 1);
840
841        sketch.sample_names.push("sample2".to_string());
842        sketch.sample_names.push("sample3".to_string());
843        assert_eq!(sketch.sample_count(), 3);
844    }
845
846    #[test]
847    fn test_query_sketch_total_entries() {
848        let mut sketch = QuerySketch::new();
849        assert_eq!(sketch.total_entries(), 0);
850
851        sketch.buckets[0].push((100, 0));
852        sketch.buckets[0].push((200, 0));
853        assert_eq!(sketch.total_entries(), 2);
854
855        sketch.buckets[50].push((300, 1));
856        assert_eq!(sketch.total_entries(), 3);
857
858        sketch.buckets[255].push((400, 0));
859        sketch.buckets[255].push((500, 1));
860        sketch.buckets[255].push((600, 2));
861        assert_eq!(sketch.total_entries(), 6);
862    }
863
864    #[test]
865    fn test_query_sketch_with_populated_fields() {
866        let mut sketch = QuerySketch::new();
867
868        sketch.sample_names = vec!["query_sample_1".to_string(), "query_sample_2".to_string()];
869
870        sketch.query_sizes = vec![1000, 500];
871
872        for i in 0..10 {
873            sketch.buckets[i].push((i as u64 * 100, 0));
874            sketch.buckets[i].push((i as u64 * 100 + 1, 1));
875        }
876
877        assert_eq!(sketch.sample_count(), 2);
878        assert_eq!(sketch.total_entries(), 20);
879        assert_eq!(sketch.query_sizes[0], 1000);
880        assert_eq!(sketch.query_sizes[1], 500);
881        assert_eq!(sketch.sample_names[0], "query_sample_1");
882    }
883
884    #[test]
885    #[should_panic]
886    fn test_query_sketch_bucket_out_of_bounds() {
887        let sketch = QuerySketch::new();
888        let _ = sketch.bucket(256); // Should panic
889    }
890
891    #[test]
892    fn test_query_sketch_empty() {
893        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
894        let engine = QueryEngine::open(&path).unwrap();
895
896        let sketch = QuerySketch::new();
897        let results = engine.query_sketch(&sketch);
898        assert!(results.is_empty());
899    }
900
901    #[test]
902    fn test_query_sketch_single_sample() {
903        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
904        let engine = QueryEngine::open(&path).unwrap();
905        let reader = JamReader::open(&path).unwrap();
906
907        let mut sketch = QuerySketch::new();
908        sketch.sample_names.push("query_sample".to_string());
909
910        let mut unique_hashes = std::collections::HashSet::new();
911        for bucket_idx in 0..256 {
912            for entry in reader.bucket_entries(bucket_idx) {
913                if unique_hashes.insert(entry.hash) {
914                    sketch.buckets[bucket_idx].push((entry.hash, 0));
915                }
916            }
917        }
918        sketch.query_sizes.push(unique_hashes.len());
919
920        let results = engine.query_sketch(&sketch);
921
922        assert_eq!(results.len(), 1);
923        assert!(results[0].has_matches());
924
925        let db_sample_0_match = results[0].matches.iter().find(|m| m.sample_id == 0);
926        assert!(db_sample_0_match.is_some(), "Should match db sample 0");
927
928        let m = db_sample_0_match.unwrap();
929        assert!(
930            m.hit_count >= results[0].query_size as u32,
931            "Expected hit_count >= query_size, got {} vs {}",
932            m.hit_count,
933            results[0].query_size
934        );
935    }
936
937    #[test]
938    fn test_query_sketch_multiple_samples() {
939        let (_dir, path) = build_test_db(
940            &[
941                ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
942                ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
943            ],
944            true, // singleton mode - each sequence is a separate sample
945        );
946        let engine = QueryEngine::open(&path).unwrap();
947        let reader = JamReader::open(&path).unwrap();
948
949        let mut sketch = QuerySketch::new();
950        sketch.sample_names.push("query_0".to_string());
951        sketch.sample_names.push("query_1".to_string());
952
953        let mut hashes_per_sample: [std::collections::HashSet<u64>; 2] = Default::default();
954
955        for bucket_idx in 0..256 {
956            for entry in reader.bucket_entries(bucket_idx) {
957                let query_sample_id = entry.sample_id;
958                if (query_sample_id as usize) < 2 {
959                    hashes_per_sample[query_sample_id as usize].insert(entry.hash);
960                    sketch.buckets[bucket_idx].push((entry.hash, query_sample_id));
961                }
962            }
963        }
964
965        sketch.query_sizes.push(hashes_per_sample[0].len());
966        sketch.query_sizes.push(hashes_per_sample[1].len());
967
968        let results = engine.query_sketch(&sketch);
969
970        assert_eq!(results.len(), 2);
971
972        for (query_idx, result) in results.iter().enumerate() {
973            assert!(result.has_matches());
974            let self_match = result
975                .matches
976                .iter()
977                .find(|m| m.sample_id == query_idx as u32);
978            if let Some(m) = self_match {
979                assert!(
980                    m.containment >= 0.9,
981                    "Query {} should have high containment with DB sample {}, got {}",
982                    query_idx,
983                    query_idx,
984                    m.containment
985                );
986            }
987        }
988    }
989
990    #[test]
991    fn test_query_sketch_no_matches() {
992        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
993        let engine = QueryEngine::open(&path).unwrap();
994
995        let mut sketch = QuerySketch::new();
996        sketch.sample_names.push("fake_sample".to_string());
997        sketch.query_sizes.push(10);
998
999        for i in 0..10 {
1000            let fake_hash = u64::MAX - i;
1001            let bucket_idx = (fake_hash & 0xFF) as usize;
1002            sketch.buckets[bucket_idx].push((fake_hash, 0));
1003        }
1004
1005        let results = engine.query_sketch(&sketch);
1006
1007        assert_eq!(results.len(), 1);
1008        assert_eq!(results[0].hashes_found, 0);
1009        assert!(results[0].matches.is_empty());
1010    }
1011
1012    #[test]
1013    fn test_query_sketch_containment_calculation() {
1014        let (_dir, path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1015        let engine = QueryEngine::open(&path).unwrap();
1016        let reader = JamReader::open(&path).unwrap();
1017
1018        let mut sketch = QuerySketch::new();
1019        sketch.sample_names.push("half_sample".to_string());
1020
1021        let mut all_hashes = Vec::new();
1022        for bucket_idx in 0..256 {
1023            for entry in reader.bucket_entries(bucket_idx) {
1024                all_hashes.push((entry.hash, bucket_idx));
1025            }
1026        }
1027
1028        let selected_hashes: Vec<_> = all_hashes.iter().step_by(2).collect();
1029        sketch.query_sizes.push(selected_hashes.len());
1030
1031        for &(hash, bucket_idx) in &selected_hashes {
1032            sketch.buckets[*bucket_idx].push((*hash, 0));
1033        }
1034
1035        let results = engine.query_sketch(&sketch);
1036
1037        assert_eq!(results.len(), 1);
1038        assert!(results[0].has_matches());
1039        let top = results[0].top(1);
1040        assert!(!top.is_empty());
1041        assert!(top[0].containment >= 0.9);
1042    }
1043
1044    #[test]
1045    fn test_from_fasta_non_singleton() {
1046        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1047        let db = JamReader::open(&db_path).unwrap();
1048
1049        let query_fasta = make_fasta(&[("query_seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1050
1051        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, false).unwrap();
1052
1053        assert_eq!(sketch.sample_count(), 1);
1054        assert!(!sketch.sample_names[0].is_empty());
1055
1056        assert!(sketch.total_entries() > 0);
1057        assert!(sketch.query_sizes[0] > 0);
1058
1059        assert_eq!(sketch.query_sizes[0], sketch.total_entries());
1060    }
1061
1062    #[test]
1063    fn test_from_fasta_singleton() {
1064        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1065        let db = JamReader::open(&db_path).unwrap();
1066
1067        let query_fasta = make_fasta(&[
1068            ("query_seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
1069            ("query_seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
1070        ]);
1071
1072        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, true).unwrap();
1073
1074        assert_eq!(sketch.sample_count(), 2);
1075        assert_eq!(sketch.sample_names[0], "query_seq1");
1076        assert_eq!(sketch.sample_names[1], "query_seq2");
1077
1078        assert!(sketch.query_sizes[0] > 0);
1079        assert!(sketch.query_sizes[1] > 0);
1080
1081        let total_unique: usize = sketch.query_sizes.iter().sum();
1082        assert!(total_unique <= sketch.total_entries() + sketch.sample_count());
1083    }
1084
1085    #[test]
1086    fn test_from_fasta_uses_db_parameters() {
1087        let input = make_fasta(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1088        let output_dir = tempfile::tempdir().unwrap();
1089        let db_path = output_dir.path().join("test.jam");
1090
1091        let config = BuildConfig {
1092            kmer_size: 15,
1093            fscale: 10,
1094            singleton: false,
1095            num_threads: 1,
1096            memory: 1,
1097            ..Default::default()
1098        };
1099
1100        build(&[input.path().to_path_buf()], &db_path, &config).unwrap();
1101        let db = JamReader::open(&db_path).unwrap();
1102
1103        assert_eq!(db.kmer_size(), 15);
1104
1105        let query_fasta = make_fasta(&[("query", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1106
1107        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, false).unwrap();
1108
1109        assert!(sketch.sample_count() == 1);
1110    }
1111
1112    #[test]
1113    fn test_from_fasta_deduplication() {
1114        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1115        let db = JamReader::open(&db_path).unwrap();
1116
1117        let query_fasta = make_fasta(&[(
1118            "query",
1119            "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG",
1120        )]);
1121
1122        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, false).unwrap();
1123
1124        assert_eq!(sketch.query_sizes[0], sketch.total_entries());
1125    }
1126
1127    #[test]
1128    fn test_from_fasta_bucketization() {
1129        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1130        let db = JamReader::open(&db_path).unwrap();
1131
1132        let query_fasta = make_fasta(&[(
1133            "query",
1134            "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG",
1135        )]);
1136
1137        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, false).unwrap();
1138
1139        for (bucket_idx, bucket) in sketch.buckets.iter().enumerate() {
1140            for &(hash, _sample_id) in bucket {
1141                assert_eq!(
1142                    bucket_id(hash),
1143                    bucket_idx,
1144                    "Hash {} should be in bucket {}, not {}",
1145                    hash,
1146                    bucket_id(hash),
1147                    bucket_idx
1148                );
1149            }
1150        }
1151    }
1152
1153    #[test]
1154    fn test_from_fasta_sorted_buckets() {
1155        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1156        let db = JamReader::open(&db_path).unwrap();
1157
1158        let query_fasta = make_fasta(&[
1159            (
1160                "query1",
1161                "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG",
1162            ),
1163            (
1164                "query2",
1165                "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA",
1166            ),
1167        ]);
1168
1169        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, true).unwrap();
1170
1171        for bucket in &sketch.buckets {
1172            for window in bucket.windows(2) {
1173                assert!(
1174                    window[0] <= window[1],
1175                    "Bucket not sorted: {:?} > {:?}",
1176                    window[0],
1177                    window[1]
1178                );
1179            }
1180        }
1181    }
1182
1183    #[test]
1184    fn test_from_fasta_short_sequences_skipped() {
1185        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1186        let db = JamReader::open(&db_path).unwrap();
1187        assert_eq!(db.kmer_size(), 11);
1188
1189        let query_fasta = make_fasta(&[
1190            ("short", "ATCGATCG"),                        // 8 bp, < 11
1191            ("long", "ATCGATCGATCGATCGATCGATCGATCGATCG"), // 32 bp, > 11
1192        ]);
1193
1194        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, true).unwrap();
1195
1196        assert_eq!(sketch.sample_count(), 2);
1197
1198        assert_eq!(sketch.query_sizes[0], 0);
1199
1200        assert!(sketch.query_sizes[1] > 0);
1201    }
1202
1203    #[test]
1204    fn test_from_fasta_file_not_found() {
1205        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1206        let db = JamReader::open(&db_path).unwrap();
1207
1208        let result = QuerySketch::from_fasta("/nonexistent/path.fasta", &db, false);
1209        assert!(result.is_err());
1210
1211        if let Err(QueryError::Parse { path, message: _ }) = result {
1212            assert!(path.contains("nonexistent"));
1213        } else {
1214            panic!("Expected Parse error");
1215        }
1216    }
1217
1218    #[test]
1219    fn test_from_fasta_integration_with_query_engine() {
1220        let (_dir, db_path) =
1221            build_test_db(&[("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1222        let db = JamReader::open(&db_path).unwrap();
1223        let engine = QueryEngine::open(&db_path).unwrap();
1224
1225        let query_fasta = make_fasta(&[("query_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1226
1227        let sketch = QuerySketch::from_fasta(query_fasta.path(), &db, false).unwrap();
1228
1229        let results = engine.query_sketch(&sketch);
1230
1231        assert_eq!(results.len(), 1);
1232        assert!(results[0].has_matches());
1233
1234        let top = results[0].top(1);
1235        assert!(!top.is_empty());
1236        assert!(
1237            top[0].containment >= 0.9,
1238            "Expected high containment, got {}",
1239            top[0].containment
1240        );
1241    }
1242
1243    fn build_test_db_with_params(
1244        seqs: &[(&str, &str)],
1245        kmer_size: u8,
1246        fscale: u64,
1247        singleton: bool,
1248    ) -> (tempfile::TempDir, std::path::PathBuf) {
1249        let input = make_fasta(seqs);
1250        let output_dir = tempfile::tempdir().unwrap();
1251        let output_path = output_dir.path().join("test.jam");
1252
1253        let config = BuildConfig {
1254            kmer_size,
1255            fscale,
1256            singleton,
1257            num_threads: 1,
1258            memory: 1,
1259            ..Default::default()
1260        };
1261
1262        build(&[input.path().to_path_buf()], &output_path, &config).unwrap();
1263        (output_dir, output_path)
1264    }
1265
1266    #[test]
1267    fn test_from_jam_success() {
1268        let (_dir1, db_path) = build_test_db_with_params(
1269            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1270            11,
1271            1,
1272            false,
1273        );
1274        let (_dir2, query_path) = build_test_db_with_params(
1275            &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1276            11,
1277            1,
1278            false,
1279        );
1280
1281        let db = JamReader::open(&db_path).unwrap();
1282        let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();
1283
1284        assert_eq!(sketch.sample_count(), 1);
1285        assert!(sketch.total_entries() > 0);
1286        assert!(!sketch.sample_names[0].is_empty());
1287        assert!(sketch.query_sizes[0] > 0);
1288    }
1289
1290    #[test]
1291    fn test_from_jam_multiple_samples() {
1292        let (_dir1, db_path) = build_test_db_with_params(
1293            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1294            11,
1295            1,
1296            false,
1297        );
1298        let (_dir2, query_path) = build_test_db_with_params(
1299            &[
1300                ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
1301                ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
1302            ],
1303            11,
1304            1,
1305            true,
1306        );
1307
1308        let db = JamReader::open(&db_path).unwrap();
1309        let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();
1310
1311        assert_eq!(sketch.sample_count(), 2);
1312        assert_eq!(sketch.sample_names[0], "seq1");
1313        assert_eq!(sketch.sample_names[1], "seq2");
1314        assert_eq!(sketch.query_sizes.len(), 2);
1315        assert!(sketch.query_sizes[0] > 0);
1316        assert!(sketch.query_sizes[1] > 0);
1317    }
1318
1319    #[test]
1320    fn test_from_jam_kmer_size_mismatch() {
1321        let (_dir1, db_path) = build_test_db_with_params(
1322            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1323            11,
1324            1,
1325            false,
1326        );
1327        let (_dir2, query_path) = build_test_db_with_params(
1328            &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1329            21,
1330            1,
1331            false,
1332        );
1333
1334        let db = JamReader::open(&db_path).unwrap();
1335        let result = QuerySketch::from_jam(&query_path, &db);
1336
1337        assert!(result.is_err());
1338        let err = result.unwrap_err();
1339        match err {
1340            QueryError::ParameterMismatch {
1341                parameter,
1342                source_value,
1343                target_value,
1344            } => {
1345                assert!(parameter.contains("k-mer"));
1346                assert_eq!(source_value, "21");
1347                assert_eq!(target_value, "11");
1348            }
1349            _ => panic!("Expected ParameterMismatch error, got {:?}", err),
1350        }
1351    }
1352
1353    #[test]
1354    fn test_from_jam_threshold_mismatch() {
1355        let (_dir1, db_path) = build_test_db_with_params(
1356            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1357            11,
1358            1,
1359            false,
1360        );
1361        let (_dir2, query_path) = build_test_db_with_params(
1362            &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1363            11,
1364            1000,
1365            false,
1366        );
1367
1368        let db = JamReader::open(&db_path).unwrap();
1369        let result = QuerySketch::from_jam(&query_path, &db);
1370
1371        assert!(result.is_err());
1372        let err = result.unwrap_err();
1373        match err {
1374            QueryError::ParameterMismatch {
1375                parameter,
1376                source_value,
1377                target_value,
1378            } => {
1379                assert!(parameter.contains("threshold"));
1380                assert_ne!(source_value, target_value);
1381            }
1382            _ => panic!("Expected ParameterMismatch error, got {:?}", err),
1383        }
1384    }
1385
1386    #[test]
1387    fn test_from_jam_preserves_bucketization() {
1388        let (_dir1, db_path) = build_test_db_with_params(
1389            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1390            11,
1391            1,
1392            false,
1393        );
1394        let (_dir2, query_path) = build_test_db_with_params(
1395            &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1396            11,
1397            1,
1398            false,
1399        );
1400
1401        let db = JamReader::open(&db_path).unwrap();
1402        let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();
1403
1404        for bucket_idx in 0..BUCKET_COUNT {
1405            for &(hash, _sample_id) in sketch.bucket(bucket_idx) {
1406                assert_eq!(
1407                    bucket_id(hash),
1408                    bucket_idx,
1409                    "Entry with hash {} is in wrong bucket",
1410                    hash
1411                );
1412            }
1413        }
1414    }
1415
1416    #[test]
1417    fn test_from_jam_query_sizes_correct() {
1418        let (_dir1, db_path) = build_test_db_with_params(
1419            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1420            11,
1421            1,
1422            false,
1423        );
1424        let (_dir2, query_path) = build_test_db_with_params(
1425            &[
1426                ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
1427                ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
1428            ],
1429            11,
1430            1,
1431            true,
1432        );
1433
1434        let db = JamReader::open(&db_path).unwrap();
1435        let sketch = QuerySketch::from_jam(&query_path, &db).unwrap();
1436
1437        for (sample_id, &expected_size) in sketch.query_sizes.iter().enumerate() {
1438            let mut unique_hashes = std::collections::HashSet::new();
1439            for bucket_idx in 0..BUCKET_COUNT {
1440                for &(hash, sid) in sketch.bucket(bucket_idx) {
1441                    if sid as usize == sample_id {
1442                        unique_hashes.insert(hash);
1443                    }
1444                }
1445            }
1446            assert_eq!(
1447                unique_hashes.len(),
1448                expected_size,
1449                "query_sizes[{}] should match actual unique hash count",
1450                sample_id
1451            );
1452        }
1453    }
1454
1455    #[test]
1456    fn test_from_jam_empty_source() {
1457        let (_dir1, db_path) = build_test_db_with_params(
1458            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1459            11,
1460            1_000_000,
1461            false,
1462        );
1463        let (_dir2, query_path) = build_test_db_with_params(
1464            &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1465            11,
1466            1_000_000,
1467            false,
1468        );
1469
1470        let db = JamReader::open(&db_path).unwrap();
1471        let result = QuerySketch::from_jam(&query_path, &db);
1472
1473        assert!(result.is_ok());
1474    }
1475
1476    #[test]
1477    fn test_from_inputs_empty() {
1478        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1479        let db = JamReader::open(&db_path).unwrap();
1480
1481        let sketch = QuerySketch::from_inputs(&[], &db, false).unwrap();
1482
1483        assert_eq!(sketch.sample_count(), 0);
1484        assert_eq!(sketch.total_entries(), 0);
1485    }
1486
1487    #[test]
1488    fn test_from_inputs_single_fasta() {
1489        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1490        let db = JamReader::open(&db_path).unwrap();
1491
1492        let query_fasta = make_fasta(&[("query_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1493
1494        let sketch =
1495            QuerySketch::from_inputs(&[query_fasta.path().to_path_buf()], &db, false).unwrap();
1496
1497        assert_eq!(sketch.sample_count(), 1);
1498        assert!(sketch.total_entries() > 0);
1499    }
1500
1501    #[test]
1502    fn test_from_inputs_single_jam() {
1503        let (_dir1, db_path) = build_test_db_with_params(
1504            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1505            11,
1506            1,
1507            false,
1508        );
1509        let (_dir2, query_jam) = build_test_db_with_params(
1510            &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1511            11,
1512            1,
1513            false,
1514        );
1515
1516        let db = JamReader::open(&db_path).unwrap();
1517
1518        let sketch = QuerySketch::from_inputs(&[query_jam], &db, false).unwrap();
1519
1520        assert_eq!(sketch.sample_count(), 1);
1521        assert!(sketch.total_entries() > 0);
1522    }
1523
1524    #[test]
1525    fn test_from_inputs_multiple_fasta() {
1526        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1527        let db = JamReader::open(&db_path).unwrap();
1528
1529        let fasta1 = make_fasta(&[("query1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1530        let fasta2 = make_fasta(&[("query2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")]);
1531
1532        let sketch = QuerySketch::from_inputs(
1533            &[fasta1.path().to_path_buf(), fasta2.path().to_path_buf()],
1534            &db,
1535            false,
1536        )
1537        .unwrap();
1538
1539        assert_eq!(sketch.sample_count(), 2);
1540        assert!(sketch.total_entries() > 0);
1541        assert_eq!(sketch.query_sizes.len(), 2);
1542    }
1543
1544    #[test]
1545    fn test_from_inputs_multiple_fasta_singleton() {
1546        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1547        let db = JamReader::open(&db_path).unwrap();
1548
1549        let fasta1 = make_fasta(&[
1550            ("seq1a", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
1551            ("seq1b", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
1552        ]);
1553        let fasta2 = make_fasta(&[
1554            ("seq2a", "TATATATATATATATATATATATATATATATA"),
1555            ("seq2b", "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"),
1556        ]);
1557
1558        let sketch = QuerySketch::from_inputs(
1559            &[fasta1.path().to_path_buf(), fasta2.path().to_path_buf()],
1560            &db,
1561            true,
1562        )
1563        .unwrap();
1564
1565        assert_eq!(sketch.sample_count(), 4);
1566        assert_eq!(sketch.sample_names.len(), 4);
1567        assert_eq!(sketch.sample_names[0], "seq1a");
1568        assert_eq!(sketch.sample_names[1], "seq1b");
1569        assert_eq!(sketch.sample_names[2], "seq2a");
1570        assert_eq!(sketch.sample_names[3], "seq2b");
1571    }
1572
1573    #[test]
1574    fn test_from_inputs_mixed_fasta_and_jam() {
1575        let (_dir1, db_path) = build_test_db_with_params(
1576            &[("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1577            11,
1578            1,
1579            false,
1580        );
1581        let db = JamReader::open(&db_path).unwrap();
1582
1583        let query_fasta = make_fasta(&[("fasta_query", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1584
1585        let (_dir2, query_jam) = build_test_db_with_params(
1586            &[("jam_query", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1587            11,
1588            1,
1589            false,
1590        );
1591
1592        let sketch =
1593            QuerySketch::from_inputs(&[query_fasta.path().to_path_buf(), query_jam], &db, false)
1594                .unwrap();
1595
1596        assert_eq!(sketch.sample_count(), 2);
1597        assert!(sketch.total_entries() > 0);
1598    }
1599
1600    #[test]
1601    fn test_from_inputs_sample_id_renumbering() {
1602        let (_dir1, db_path) = build_test_db_with_params(
1603            &[("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1604            11,
1605            1,
1606            false,
1607        );
1608        let db = JamReader::open(&db_path).unwrap();
1609
1610        let (_dir2, jam1) = build_test_db_with_params(
1611            &[
1612                ("seq1a", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
1613                ("seq1b", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
1614            ],
1615            11,
1616            1,
1617            true,
1618        );
1619        let (_dir3, jam2) = build_test_db_with_params(
1620            &[
1621                ("seq2a", "TATATATATATATATATATATATATATATATA"),
1622                ("seq2b", "GCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC"),
1623            ],
1624            11,
1625            1,
1626            true,
1627        );
1628
1629        let sketch = QuerySketch::from_inputs(&[jam1, jam2], &db, false).unwrap();
1630
1631        assert_eq!(sketch.sample_count(), 4);
1632
1633        for bucket in &sketch.buckets {
1634            for &(_hash, sample_id) in bucket {
1635                assert!(sample_id < 4, "Sample ID {} should be < 4", sample_id);
1636            }
1637        }
1638
1639        let mut seen_samples = std::collections::HashSet::new();
1640        for bucket in &sketch.buckets {
1641            for &(_hash, sample_id) in bucket {
1642                seen_samples.insert(sample_id);
1643            }
1644        }
1645        assert_eq!(seen_samples.len(), 4, "All samples should have entries");
1646    }
1647
1648    #[test]
1649    fn test_from_inputs_buckets_sorted() {
1650        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1651        let db = JamReader::open(&db_path).unwrap();
1652
1653        let fasta1 = make_fasta(&[("q1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1654        let fasta2 = make_fasta(&[("q2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")]);
1655
1656        let sketch = QuerySketch::from_inputs(
1657            &[fasta1.path().to_path_buf(), fasta2.path().to_path_buf()],
1658            &db,
1659            false,
1660        )
1661        .unwrap();
1662
1663        for bucket in &sketch.buckets {
1664            for window in bucket.windows(2) {
1665                assert!(
1666                    window[0] <= window[1],
1667                    "Bucket not sorted: {:?} > {:?}",
1668                    window[0],
1669                    window[1]
1670                );
1671            }
1672        }
1673    }
1674
1675    #[test]
1676    fn test_from_inputs_query_sizes_preserved() {
1677        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1678        let db = JamReader::open(&db_path).unwrap();
1679
1680        let fasta1 = make_fasta(&[("q1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1681        let fasta2 = make_fasta(&[("q2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")]);
1682
1683        let sketch1 = QuerySketch::from_fasta(fasta1.path(), &db, false).unwrap();
1684        let sketch2 = QuerySketch::from_fasta(fasta2.path(), &db, false).unwrap();
1685
1686        let fasta1_new = make_fasta(&[("q1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1687        let fasta2_new = make_fasta(&[("q2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")]);
1688
1689        let combined = QuerySketch::from_inputs(
1690            &[
1691                fasta1_new.path().to_path_buf(),
1692                fasta2_new.path().to_path_buf(),
1693            ],
1694            &db,
1695            false,
1696        )
1697        .unwrap();
1698
1699        assert_eq!(combined.query_sizes.len(), 2);
1700        assert_eq!(combined.query_sizes[0], sketch1.query_sizes[0]);
1701        assert_eq!(combined.query_sizes[1], sketch2.query_sizes[0]);
1702    }
1703
1704    #[test]
1705    fn test_from_inputs_jam_detection_by_extension() {
1706        let (_dir1, db_path) = build_test_db_with_params(
1707            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1708            11,
1709            1,
1710            false,
1711        );
1712        let db = JamReader::open(&db_path).unwrap();
1713
1714        let (_dir2, jam_path) = build_test_db_with_params(
1715            &[("jam_seq", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1716            11,
1717            1,
1718            false,
1719        );
1720
1721        assert_eq!(jam_path.extension().unwrap(), "jam");
1722
1723        let sketch = QuerySketch::from_inputs(&[jam_path], &db, false).unwrap();
1724
1725        assert_eq!(sketch.sample_count(), 1);
1726        assert!(!sketch.sample_names[0].is_empty());
1727    }
1728
1729    #[test]
1730    fn test_from_inputs_propagates_errors() {
1731        let (_dir, db_path) = build_test_db(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1732        let db = JamReader::open(&db_path).unwrap();
1733
1734        let result = QuerySketch::from_inputs(
1735            &[std::path::PathBuf::from("/nonexistent/file.fasta")],
1736            &db,
1737            false,
1738        );
1739
1740        assert!(result.is_err());
1741    }
1742
1743    #[test]
1744    fn test_from_inputs_jam_parameter_mismatch_propagates() {
1745        let (_dir1, db_path) = build_test_db_with_params(
1746            &[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1747            11, // k=11
1748            1,
1749            false,
1750        );
1751        let db = JamReader::open(&db_path).unwrap();
1752
1753        let (_dir2, jam_path) = build_test_db_with_params(
1754            &[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1755            21,
1756            1,
1757            false,
1758        );
1759
1760        let result = QuerySketch::from_inputs(&[jam_path], &db, false);
1761
1762        assert!(result.is_err());
1763        match result.unwrap_err() {
1764            QueryError::ParameterMismatch { parameter, .. } => {
1765                assert!(parameter.contains("k-mer"));
1766            }
1767            e => panic!("Expected ParameterMismatch error, got {:?}", e),
1768        }
1769    }
1770
1771    #[test]
1772    fn test_from_inputs_integration_with_query_engine() {
1773        let (_dir1, db_path) = build_test_db_with_params(
1774            &[("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")],
1775            11,
1776            1,
1777            false,
1778        );
1779        let db = JamReader::open(&db_path).unwrap();
1780        let engine = QueryEngine::open(&db_path).unwrap();
1781
1782        let query_fasta = make_fasta(&[("same_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1783        let (_dir2, query_jam) = build_test_db_with_params(
1784            &[("different_seq", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")],
1785            11,
1786            1,
1787            false,
1788        );
1789
1790        let sketch =
1791            QuerySketch::from_inputs(&[query_fasta.path().to_path_buf(), query_jam], &db, false)
1792                .unwrap();
1793
1794        assert_eq!(sketch.sample_count(), 2);
1795
1796        let results = engine.query_sketch(&sketch);
1797
1798        assert_eq!(results.len(), 2);
1799
1800        assert!(results[0].has_matches());
1801        let top0 = results[0].top(1);
1802        assert!(!top0.is_empty());
1803        assert!(
1804            top0[0].containment >= 0.9,
1805            "Same sequence should have high containment, got {}",
1806            top0[0].containment
1807        );
1808
1809    }
1810
1811    #[test]
1812    fn test_query_fasta_non_singleton() {
1813        let (_dir, db_path) =
1814            build_test_db(&[("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1815        let engine = QueryEngine::open(&db_path).unwrap();
1816
1817        let query_fasta = make_fasta(&[("query_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
1818
1819        let results = engine.query_fasta(query_fasta.path(), false).unwrap();
1820
1821        assert_eq!(results.len(), 1);
1822        assert!(results[0].has_matches());
1823
1824        let top = results[0].top(1);
1825        assert!(!top.is_empty());
1826        assert!(
1827            top[0].containment >= 0.9,
1828            "Expected high containment, got {}",
1829            top[0].containment
1830        );
1831    }
1832
1833    #[test]
1834    fn test_query_fasta_singleton() {
1835        let (_dir, db_path) = build_test_db(
1836            &[
1837                ("db_seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
1838                ("db_seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
1839            ],
1840            true,
1841        );
1842        let engine = QueryEngine::open(&db_path).unwrap();
1843
1844        let query_fasta = make_fasta(&[
1845            ("query1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
1846            ("query2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
1847        ]);
1848
1849        let results = engine.query_fasta(query_fasta.path(), true).unwrap();
1850
1851        assert_eq!(results.len(), 2);
1852
1853        assert!(results[0].has_matches());
1854        assert!(results[1].has_matches());
1855
1856        for (i, result) in results.iter().enumerate() {
1857            let self_match = result.matches.iter().find(|m| m.sample_id == i as u32);
1858            if let Some(m) = self_match {
1859                assert!(
1860                    m.containment >= 0.9,
1861                    "Query {} should have high containment with DB sample {}, got {}",
1862                    i,
1863                    i,
1864                    m.containment
1865                );
1866            }
1867        }
1868    }
1869
1870    #[test]
1871    fn test_query_fasta_file_not_found() {
1872        let (_dir, db_path) =
1873            build_test_db(&[("db_seq", "ATCGATCGATCGATCGATCGATCGATCGATCG")], false);
1874        let engine = QueryEngine::open(&db_path).unwrap();
1875
1876        let result = engine.query_fasta("/nonexistent/path.fasta", false);
1877
1878        assert!(result.is_err());
1879        match result.unwrap_err() {
1880            QueryError::Parse { path, message: _ } => {
1881                assert!(path.contains("nonexistent"));
1882            }
1883            e => panic!("Expected Parse error, got {:?}", e),
1884        }
1885    }
1886}