1use crate::bias::HashBiasTable;
2use crate::core_utils::passes_entropy_filter;
3use crate::format::{BUCKET_COUNT, ENTRY_SIZE, Entry, bucket_id};
4use crate::io::EntryWriter;
5use crossfire::mpsc;
6use crossfire::{MTx, Rx};
7use dashmap::DashMap;
8use indicatif::{ProgressBar, ProgressStyle};
9use jamhash::jamhash_u64;
10use memmap2::Mmap;
11use needletail::{Sequence, parse_fastx_reader};
12use rayon::prelude::*;
13use std::fs::File;
14use std::io;
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
18use std::time::Duration;
19use tempfile::TempDir;
20
/// Buffer size in bytes (8 MiB) handed to each bucket's `EntryWriter`.
const WRITE_BUFFER_SIZE: usize = 8 * 1024 * 1024;
/// Default value of `SketchConfig::memory`, in GiB.
const MIN_MEMORY_GB: usize = 4;
/// Default value of `SketchConfig::send_timeout`.
const DEFAULT_SEND_TIMEOUT: Duration = Duration::from_millis(1);
/// Memory-mapped inputs smaller than this many bytes (1 MiB) are not split into chunks.
const MIN_SPLIT_SIZE: usize = 1024 * 1024;
/// With more input files than this, inputs are streamed from disk instead of memory-mapped.
const MAX_CONCURRENT_MMAPS: usize = 256;
/// Upper bound on the capacity (in entries) of each bucket channel.
const OPTIMAL_CHANNEL_CAPACITY: usize = 512 * 1024;
27
/// Sending half of a bucket's bounded `Entry` channel; every worker holds its own clone.
type Sender = MTx<mpsc::Array<Entry>>;
/// Receiving half of a bucket's bounded `Entry` channel; owned by that bucket's `BucketWriter`.
type Receiver = Rx<mpsc::Array<Entry>>;
30
/// Tunable parameters for a sketching run (see `run`).
#[derive(Clone)]
pub struct SketchConfig {
    /// K-mer length; `run` rejects values outside 1..=31.
    pub kmer_size: u8,
    /// Down-sampling factor: only hashes below `u64::MAX / fscale` are kept. Must be non-zero.
    pub fscale: u64,
    /// Number of worker threads; `run` treats 0 as 1.
    pub num_threads: usize,
    /// Memory budget in GiB; decides between mmap and streaming input and sizes the channels.
    pub memory: usize,
    /// Directory in which the temporary output directory is created; the default temp
    /// location is used when `None`.
    pub temp_dir_base: Option<PathBuf>,
    /// Threshold passed to `passes_entropy_filter`; the filter is skipped unless this is > 0.
    pub min_entropy: f64,
    /// When true, every record gets its own sample id instead of one id per input file.
    pub singleton: bool,
    /// Optional hash filter; hashes for which `passes_filter` returns false are dropped.
    pub bias_table: Option<Arc<HashBiasTable>>,
    /// Timeout of the first attempt to send an entry to a bucket channel before backing off.
    pub send_timeout: Duration,
    /// Whether progress bars are rendered.
    pub show_progress: bool,
}
44
impl Default for SketchConfig {
    /// Single-threaded defaults: k = 21, fscale = 1000, `MIN_MEMORY_GB` of memory,
    /// no entropy or bias filtering, one sample per file, no progress output.
    fn default() -> Self {
        Self {
            kmer_size: 21,
            fscale: 1000,
            num_threads: 1,
            memory: MIN_MEMORY_GB,
            temp_dir_base: None,
            // 0.0 disables the entropy filter (it only runs when the threshold is > 0).
            min_entropy: 0.0,
            singleton: false,
            bias_table: None,
            send_timeout: DEFAULT_SEND_TIMEOUT,
            show_progress: false,
        }
    }
}
61
/// Outcome of a successful `run`.
pub struct SketchResult {
    /// Number of sample ids handed out during the run.
    pub sample_count: u32,
    /// Number of entries written to each bucket file, indexed by bucket id.
    pub bucket_entry_counts: [u64; BUCKET_COUNT],
    /// Hash threshold that was applied (`u64::MAX / fscale`); only smaller hashes were kept.
    pub frac_max: u64,
    /// Temporary directory holding the `bucket_NNN.bin` files.
    pub temp_dir: TempDir,
    /// Sample names indexed by sample id; an id with no recorded name maps to an empty string.
    pub sample_names: Vec<String>,
}
69
/// Errors returned by the sketching pipeline.
#[derive(Debug, thiserror::Error)]
pub enum SketchError {
    /// An underlying I/O operation failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// An input file was too small, had an unrecognized header, or failed FASTA/FASTQ parsing.
    #[error("Parse error in {path}: {message}")]
    Parse { path: PathBuf, message: String },

    /// An entry could not be delivered because the bucket channel was disconnected.
    #[error("Channel send error")]
    Channel,

    /// A `SketchConfig` value was rejected by `run`.
    #[error("Invalid configuration: {0}")]
    Config(String),
}
84
/// One piece of input assigned to a worker: a byte range of a mapped file or a whole file.
struct WorkUnit {
    /// Mapped file contents; when `None` the worker opens `source_path` and streams it.
    mmap: Option<Arc<Mmap>>,
    /// Start byte offset into `mmap` (0 when streaming).
    start: usize,
    /// End byte offset (exclusive) into `mmap` (0 when streaming).
    end: usize,
    /// Sample id shared by all records of this unit; `None` means one fresh id per record.
    sample_id: Option<u32>,
    /// Path of the originating file, used for streaming, sample naming and error messages.
    source_path: Arc<PathBuf>,
}
92
/// Consumer side of one bucket: receives entries from the channel and writes them to a file.
struct BucketWriter {
    /// Receiving half of this bucket's channel.
    receiver: Receiver,
    /// Writer for this bucket's output file.
    writer: EntryWriter,
    /// Index of the bucket this writer serves.
    bucket_id: usize,
}
98
99impl BucketWriter {
100 fn drain(&mut self) -> io::Result<()> {
101 while let Ok(entry) = self.receiver.try_recv() {
102 self.writer.write(&entry)?;
103 }
104 Ok(())
105 }
106
107 fn drain_until_disconnected(&mut self, timeout: Duration) -> io::Result<bool> {
108 match self.receiver.recv_timeout(timeout) {
109 Ok(entry) => {
110 self.writer.write(&entry)?;
111 Ok(true)
112 }
113 Err(crossfire::RecvTimeoutError::Timeout) => Ok(true),
114 Err(crossfire::RecvTimeoutError::Disconnected) => Ok(false),
115 }
116 }
117}
118
/// Borrowed, per-worker view of the shared state needed while sketching records.
struct SketchContext<'a> {
    /// One sender per bucket, indexed by bucket id.
    senders: &'a [Sender],
    /// Run configuration.
    config: &'a SketchConfig,
    /// Shared counter from which new sample ids are drawn.
    sample_counter: &'a AtomicU32,
    /// Hash threshold; hashes greater than or equal to it are discarded.
    frac_max: u64,
    /// Shared map from sample id to sample name.
    sample_names: &'a DashMap<u32, String>,
}
126
/// `io::Read` adapter over the byte range `start..end` of a shared memory map.
struct MmapSliceReader {
    /// Mapping the bytes are read from.
    mmap: Arc<Mmap>,
    /// Offset of the first byte of the range.
    start: usize,
    /// Offset one past the last byte of the range.
    end: usize,
    /// Number of bytes already handed out, relative to `start`.
    pos: usize,
}
133
134impl MmapSliceReader {
135 fn new(mmap: Arc<Mmap>, start: usize, end: usize) -> Self {
136 Self {
137 mmap,
138 start,
139 end,
140 pos: 0,
141 }
142 }
143}
144
145impl io::Read for MmapSliceReader {
146 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
147 let current = self.start + self.pos;
148 if current >= self.end {
149 return Ok(0);
150 }
151 let remaining = &self.mmap[current..self.end];
152 let n = remaining.len().min(buf.len());
153 buf[..n].copy_from_slice(&remaining[..n]);
154 self.pos += n;
155 Ok(n)
156 }
157}
158
/// Splits `total` items into `parts` shares whose sizes differ by at most one.
///
/// The first `total % parts` shares receive the extra item. The iterator yields exactly
/// `parts` sizes that sum to `total`; with `parts == 0` it yields nothing instead of
/// panicking on a division by zero.
fn distribute_evenly(total: usize, parts: usize) -> impl Iterator<Item = usize> {
    let (per_part, remainder) = match parts {
        // No shares requested: the range below is empty, so these values are never used.
        0 => (0, 0),
        _ => (total / parts, total % parts),
    };
    (0..parts).map(move |i| per_part + usize::from(i < remainder))
}
170
/// First two bytes of a gzip stream.
const GZ_MAGIC: [u8; 2] = [0x1F, 0x8B];
/// First two bytes of a bzip2 stream (`"BZ"`).
const BZ_MAGIC: [u8; 2] = [0x42, 0x5A];
/// First two bytes of an xz stream.
const XZ_MAGIC: [u8; 2] = [0xFD, 0x37];
/// First two bytes of a zstd frame.
const ZST_MAGIC: [u8; 2] = [0x28, 0xB5];

/// Returns true when `magic` equals one of the known compressed-stream signatures.
#[inline]
fn is_compressed(magic: [u8; 2]) -> bool {
    [GZ_MAGIC, BZ_MAGIC, XZ_MAGIC, ZST_MAGIC].contains(&magic)
}
180
181fn validate_format(magic: [u8; 2], path: &Path) -> Result<(), SketchError> {
182 if !is_compressed(magic) && !matches!(magic[0], b'>' | b'@') {
183 return Err(SketchError::Parse {
184 path: path.to_path_buf(),
185 message: format!(
186 "unrecognized format (bytes: [{:#04X}, {:#04X}])",
187 magic[0], magic[1]
188 ),
189 });
190 }
191 Ok(())
192}
193
194fn validate_file_header(path: &Path) -> Result<[u8; 2], SketchError> {
195 let mut file = File::open(path)?;
196 let mut magic = [0u8; 2];
197 use std::io::Read;
198 let bytes_read = file.read(&mut magic)?;
199 if bytes_read < 2 {
200 return Err(SketchError::Parse {
201 path: path.to_path_buf(),
202 message: format!("file too small ({bytes_read} bytes) to be valid FASTA/FASTQ"),
203 });
204 }
205 validate_format(magic, path)?;
206 Ok(magic)
207}
208
209fn validate_mmap_header(mmap: &Mmap, path: &Path) -> Result<[u8; 2], SketchError> {
210 if mmap.len() < 2 {
211 return Err(SketchError::Parse {
212 path: path.to_path_buf(),
213 message: format!(
214 "file too small ({} bytes) to be valid FASTA/FASTQ",
215 mmap.len()
216 ),
217 });
218 }
219 let magic = [mmap[0], mmap[1]];
220 validate_format(magic, path)?;
221 Ok(magic)
222}
223
/// Returns the byte offsets at which FASTA records start in `data`.
///
/// Offset 0 is always reported; every further offset is the position of a `>` that
/// directly follows a newline.
fn scan_fasta_boundaries(data: &[u8]) -> Vec<usize> {
    let mut bounds = vec![0];
    for (idx, pair) in data.windows(2).enumerate() {
        if pair[0] == b'\n' && pair[1] == b'>' {
            // `idx` points at the newline; the record starts one byte later.
            bounds.push(idx + 1);
        }
    }
    bounds
}
233
/// Returns true when `b` is an IUPAC nucleotide code (including `U` and `N`) in either
/// upper or lower case.
#[inline]
fn is_iupac_nucleotide(b: u8) -> bool {
    // Setting bit 0x20 maps ASCII upper-case letters onto their lower-case form.
    const CODES: &[u8] = b"acgtunryswkmbdhv";
    let folded = b | 0x20;
    CODES.contains(&folded)
}
255
256fn scan_fastq_boundaries(data: &[u8]) -> Vec<usize> {
257 let mut bounds = vec![0];
258 let mut i = 0;
259
260 while i + 1 < data.len() {
261 if data[i] == b'\n' && data[i + 1] == b'@' {
262 let header_start = i + 1;
263 let header_end = data[header_start..]
264 .iter()
265 .position(|&b| b == b'\n')
266 .map(|p| header_start + p)
267 .unwrap_or(data.len());
268
269 if header_end > header_start + 1 {
270 let seq_start = header_end + 1;
271 if seq_start < data.len() && is_iupac_nucleotide(data[seq_start]) {
272 bounds.push(header_start);
273 }
274 }
275 }
276 i += 1;
277 }
278
279 bounds
280}
281
282fn setup_channels(
283 num_threads: usize,
284 memory_gb: usize,
285 input_size_bytes: u64,
286 temp_path: &Path,
287) -> Result<(Vec<Sender>, Vec<Vec<BucketWriter>>), SketchError> {
288 let capacity = compute_channel_capacity(memory_gb, input_size_bytes);
289
290 let (senders, receivers): (Vec<_>, Vec<_>) = (0..BUCKET_COUNT)
291 .map(|_| mpsc::bounded_blocking(capacity))
292 .unzip();
293
294 let bucket_threads = num_threads.min(BUCKET_COUNT);
295 let chunk_sizes = distribute_evenly(BUCKET_COUNT, bucket_threads);
296
297 let mut rx_iter = receivers.into_iter().enumerate();
298 let mut bucket_writers: Vec<Vec<BucketWriter>> = chunk_sizes
299 .map(|count| {
300 rx_iter
301 .by_ref()
302 .take(count)
303 .map(|(bucket_id, receiver)| {
304 let writer = EntryWriter::new(
305 temp_path.join(format!("bucket_{bucket_id:03}.bin")),
306 WRITE_BUFFER_SIZE,
307 )?;
308 Ok(BucketWriter {
309 receiver,
310 writer,
311 bucket_id,
312 })
313 })
314 .collect::<Result<Vec<_>, std::io::Error>>()
315 })
316 .collect::<Result<Vec<_>, _>>()?;
317
318 bucket_writers.resize_with(num_threads, Vec::new);
319
320 Ok((senders, bucket_writers))
321}
322
323fn compute_channel_capacity(memory_gb: usize, input_size_bytes: u64) -> usize {
324 let memory_bytes = memory_gb as u64 * 1024 * 1024 * 1024;
325 let writer_memory = BUCKET_COUNT as u64 * WRITE_BUFFER_SIZE as u64;
326
327 let available = memory_bytes
328 .saturating_sub(input_size_bytes)
329 .saturating_sub(writer_memory);
330
331 let computed = (available / (BUCKET_COUNT as u64 * ENTRY_SIZE as u64)) as usize;
332
333 computed.clamp(1024, OPTIMAL_CHANNEL_CAPACITY)
334}
335
336fn scan_boundaries(mmap: &Mmap, magic: [u8; 2]) -> Vec<usize> {
337 if mmap.len() < MIN_SPLIT_SIZE {
338 return vec![0];
339 }
340 match magic[0] {
341 b'>' => scan_fasta_boundaries(mmap),
342 _ => scan_fastq_boundaries(mmap),
343 }
344}
345
346fn distribute_work_units(
347 positions: Vec<(Arc<Mmap>, Arc<PathBuf>, usize, usize)>,
348 num_threads: usize,
349 singleton: bool,
350 sample_counter: &AtomicU32,
351 thread_work: &mut [Vec<WorkUnit>],
352) {
353 if positions.is_empty() {
354 return;
355 }
356
357 let mut file_sample_ids: std::collections::HashMap<PathBuf, u32> =
358 std::collections::HashMap::new();
359
360 let chunk_sizes: Vec<_> = distribute_evenly(positions.len(), num_threads).collect();
361 let mut offset = 0;
362
363 for (t, &count) in chunk_sizes.iter().enumerate() {
364 for (mmap, path, start_byte, end_byte) in &positions[offset..offset + count] {
365 let sample_id = (!singleton).then(|| {
366 *file_sample_ids
367 .entry((**path).clone())
368 .or_insert_with(|| sample_counter.fetch_add(1, Ordering::SeqCst))
369 });
370
371 thread_work[t].push(WorkUnit {
372 mmap: Some(Arc::clone(mmap)),
373 start: *start_byte,
374 end: *end_byte,
375 sample_id,
376 source_path: Arc::clone(path),
377 });
378 }
379 offset += count;
380 }
381}
382
/// Output of `build_work_units`.
struct WorkUnitResult {
    /// Work units per worker thread, indexed by thread.
    thread_work: Vec<Vec<WorkUnit>>,
    /// Input size charged against the memory budget when sizing channels; reported as 0
    /// when inputs are streamed rather than memory-mapped.
    total_input_bytes: u64,
}
387
388fn build_work_units(
389 input_files: &[PathBuf],
390 num_threads: usize,
391 singleton: bool,
392 memory_gb: usize,
393 sample_counter: &AtomicU32,
394 show_progress: bool,
395) -> Result<WorkUnitResult, SketchError> {
396 let mut thread_work: Vec<Vec<WorkUnit>> = (0..num_threads).map(|_| Vec::new()).collect();
397 let next_sample = || (!singleton).then(|| sample_counter.fetch_add(1, Ordering::SeqCst));
398
399 let total_input_bytes: u64 = input_files
400 .iter()
401 .filter_map(|p| std::fs::metadata(p).ok())
402 .map(|m| m.len())
403 .sum();
404
405 let memory_bytes = memory_gb as u64 * 1024 * 1024 * 1024;
406
407 let skip_mmap = input_files.len() > MAX_CONCURRENT_MMAPS || total_input_bytes > memory_bytes;
408
409 if skip_mmap {
410 let validation_pb = if show_progress {
411 let pb = ProgressBar::new(input_files.len() as u64);
412 pb.set_style(
413 ProgressStyle::default_bar()
414 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} files validated")
415 .unwrap()
416 .progress_chars("#>-"),
417 );
418 Some(pb)
419 } else {
420 None
421 };
422
423 let validation_results: Vec<Result<(), SketchError>> = input_files
424 .par_iter()
425 .map(|path| {
426 let result = validate_file_header(path).map(|_| ());
427 if let Some(ref pb) = validation_pb {
428 pb.inc(1);
429 }
430 result
431 })
432 .collect();
433
434 if let Some(pb) = validation_pb {
435 pb.finish_with_message("validation complete");
436 }
437
438 for result in validation_results {
439 result?;
440 }
441
442 for (i, path) in input_files.iter().enumerate() {
443 thread_work[i % num_threads].push(WorkUnit {
444 mmap: None,
445 start: 0,
446 end: 0,
447 sample_id: next_sample(),
448 source_path: Arc::new(path.clone()),
449 });
450 }
451 return Ok(WorkUnitResult {
452 thread_work,
453 total_input_bytes: 0,
454 });
455 }
456
457 let mut flat_positions: Vec<(Arc<Mmap>, Arc<PathBuf>, usize, usize)> = Vec::new();
458 let mut compressed_files: Vec<(Arc<Mmap>, Arc<PathBuf>)> = Vec::new();
459
460 for path in input_files {
461 let file = File::open(path)?;
462 let mmap = Arc::new(unsafe { Mmap::map(&file)? });
463 let path = Arc::new(path.clone());
464 let magic = validate_mmap_header(&mmap, &path)?;
465
466 if is_compressed(magic) {
467 compressed_files.push((mmap, path));
468 continue;
469 }
470
471 let boundaries = scan_boundaries(&mmap, magic);
472 for (i, &start) in boundaries.iter().enumerate() {
473 let end = boundaries.get(i + 1).copied().unwrap_or(mmap.len());
474 flat_positions.push((Arc::clone(&mmap), Arc::clone(&path), start, end));
475 }
476 }
477
478 distribute_work_units(
479 flat_positions,
480 num_threads,
481 singleton,
482 sample_counter,
483 &mut thread_work,
484 );
485
486 for (i, (mmap, path)) in compressed_files.into_iter().enumerate() {
487 let end = mmap.len();
488 thread_work[i % num_threads].push(WorkUnit {
489 mmap: Some(mmap),
490 start: 0,
491 end,
492 sample_id: next_sample(),
493 source_path: path,
494 });
495 }
496
497 Ok(WorkUnitResult {
498 thread_work,
499 total_input_bytes,
500 })
501}
502
/// Sketches `input_files` into per-bucket entry files inside a fresh temporary directory.
///
/// The inputs are turned into work units, one task per worker thread is spawned, and
/// every task both produces entries for all buckets and writes the buckets it owns.
/// On success the returned `SketchResult` carries the per-bucket entry counts, the hash
/// threshold, the sample names and the temporary directory holding the bucket files.
///
/// # Errors
/// * `SketchError::Config` if `fscale` is 0 or `kmer_size` is outside 1..=31.
/// * `SketchError::Io` if the temporary directory or a bucket file cannot be created, or
///   an input cannot be opened.
/// * `SketchError::Parse` / `SketchError::Channel` as reported by the workers.
pub fn run(input_files: &[PathBuf], config: &SketchConfig) -> Result<SketchResult, SketchError> {
    if config.fscale == 0 {
        return Err(SketchError::Config("fscale must be non-zero".to_string()));
    }
    if config.kmer_size == 0 || config.kmer_size > 31 {
        return Err(SketchError::Config(format!(
            "kmer_size must be between 1 and 31, got {}",
            config.kmer_size
        )));
    }

    let temp_dir = match &config.temp_dir_base {
        Some(base) => tempfile::Builder::new().prefix("jam_").tempdir_in(base)?,
        None => tempfile::Builder::new().prefix("jam_").tempdir()?,
    };

    // Only hashes strictly below this threshold are kept.
    let frac_max = u64::MAX / config.fscale;
    let sample_counter = Arc::new(AtomicU32::new(0));
    let num_threads = config.num_threads.max(1);

    let WorkUnitResult {
        thread_work,
        total_input_bytes,
    } = build_work_units(
        input_files,
        num_threads,
        config.singleton,
        config.memory,
        &sample_counter,
        config.show_progress,
    )?;
    let (senders, resources) = setup_channels(
        num_threads,
        config.memory,
        total_input_bytes,
        temp_dir.path(),
    )?;

    // Each worker reports its result (bucket counts or an error) through this channel.
    let (result_tx, result_rx) = std::sync::mpsc::channel();

    let sample_names_map: DashMap<u32, String> = DashMap::new();

    // Counts work units, so a file that was split into chunks contributes several.
    let total_files: u64 = thread_work.iter().map(|w| w.len() as u64).sum();
    let files_processed = Arc::new(AtomicU64::new(0));

    let progress_bar = if config.show_progress {
        let pb = ProgressBar::new(total_files);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} files ({msg})")
                .unwrap()
                .progress_chars("#>-"),
        );
        pb.enable_steady_tick(Duration::from_millis(100));
        Some(pb)
    } else {
        None
    };

    // NOTE(review): a worker only leaves its final drain loop once every sender clone has
    // been dropped, so this looks like it relies on all spawned tasks making progress
    // concurrently — confirm the rayon pool has at least `num_threads` threads.
    rayon::scope(|s| {
        let sample_counter = &sample_counter;
        let files_processed = &files_processed;
        let progress_bar = &progress_bar;
        let sample_names_map = &sample_names_map;
        for (work, res) in thread_work.into_iter().zip(resources) {
            // Every worker gets its own clone of all bucket senders.
            let thread_senders = senders.to_vec();
            let result_tx = result_tx.clone();

            s.spawn(move |_| {
                let result = process_thread_work(
                    work,
                    thread_senders,
                    res,
                    config,
                    sample_counter,
                    frac_max,
                    files_processed,
                    progress_bar,
                    sample_names_map,
                );
                let _ = result_tx.send(result);
            });
        }
        // Drop the originals so that channels disconnect once all workers are done sending.
        drop(senders);
        drop(result_tx);
    });

    if let Some(pb) = progress_bar {
        let final_samples = sample_counter.load(Ordering::SeqCst);
        pb.finish_with_message(format!("{} samples", final_samples));
    }

    // Every bucket is owned by exactly one worker, so each slot is written at most once.
    let mut bucket_entry_counts = [0u64; BUCKET_COUNT];
    for result in result_rx {
        for (bucket_idx, count) in result? {
            bucket_entry_counts[bucket_idx] = count;
        }
    }

    let sample_count = sample_counter.load(Ordering::SeqCst);
    let mut sample_names: Vec<String> = vec![String::new(); sample_count as usize];

    // Ids without a recorded name keep the empty string.
    for entry in sample_names_map.iter() {
        if (*entry.key() as usize) < sample_names.len() {
            sample_names[*entry.key() as usize] = entry.value().clone();
        }
    }

    Ok(SketchResult {
        sample_count,
        bucket_entry_counts,
        frac_max,
        temp_dir,
        sample_names,
    })
}
619
/// Body of one worker: sketches its work units, then finishes writing the buckets it owns.
///
/// For every unit a reader is built (a slice of the mapping, or the file itself when the
/// unit has no mapping), parsed as FASTA/FASTQ and passed to `sketch_records`. Afterwards
/// the worker drops its senders and keeps polling its own buckets until all of their
/// channels are disconnected, then flushes the writers.
///
/// Returns `(bucket_id, entry_count)` for every bucket owned by this worker.
///
/// # Errors
/// * `SketchError::Io` if a file cannot be opened or a bucket write/flush fails.
/// * `SketchError::Parse` if an input cannot be parsed (inputs the parser reports as
///   empty are skipped with a message on stderr instead).
/// * `SketchError::Channel` as reported by `sketch_records`.
#[allow(clippy::too_many_arguments)]
fn process_thread_work(
    work_units: Vec<WorkUnit>,
    senders: Vec<Sender>,
    mut bucket_writers: Vec<BucketWriter>,
    config: &SketchConfig,
    sample_counter: &AtomicU32,
    frac_max: u64,
    files_processed: &AtomicU64,
    progress_bar: &Option<ProgressBar>,
    sample_names_map: &DashMap<u32, String>,
) -> Result<Vec<(usize, u64)>, SketchError> {
    let ctx = SketchContext {
        senders: &senders,
        config,
        sample_counter,
        frac_max,
        sample_names: sample_names_map,
    };

    for unit in &work_units {
        let reader: Box<dyn io::Read + Send> = match &unit.mmap {
            Some(mmap) => Box::new(MmapSliceReader::new(Arc::clone(mmap), unit.start, unit.end)),
            None => Box::new(io::BufReader::new(File::open(&*unit.source_path)?)),
        };

        let mut fastx = match parse_fastx_reader(reader) {
            Ok(reader) => reader,
            // An empty input is not fatal: count it as processed and move on.
            Err(e) if e.kind == needletail::errors::ParseErrorKind::EmptyFile => {
                eprintln!(
                    "Empty file detected: {}, skipping",
                    unit.source_path.display()
                );
                let processed = files_processed.fetch_add(1, Ordering::Relaxed) + 1;
                if let Some(pb) = progress_bar {
                    pb.set_position(processed);
                }
                continue;
            }
            Err(e) => {
                return Err(SketchError::Parse {
                    path: (*unit.source_path).clone(),
                    message: e.to_string(),
                });
            }
        };

        sketch_records(
            fastx.as_mut(),
            unit.sample_id,
            &unit.source_path,
            &ctx,
            &mut bucket_writers,
        )?;

        let processed = files_processed.fetch_add(1, Ordering::Relaxed) + 1;
        if let Some(pb) = progress_bar {
            pb.set_position(processed);
            let samples = sample_counter.load(Ordering::Relaxed);
            pb.set_message(format!("{} samples", samples));
        }
    }

    // This worker is done producing; release its senders so channels can disconnect.
    drop(senders);

    // Keep servicing the owned buckets until every one of them reports disconnection
    // within the same pass.
    const DRAIN_TIMEOUT: Duration = Duration::from_millis(1);
    loop {
        let mut any_pending = false;
        for bw in bucket_writers.iter_mut() {
            if bw.drain_until_disconnected(DRAIN_TIMEOUT)? {
                any_pending = true;
            }
        }
        if !any_pending {
            break;
        }
    }

    bucket_writers
        .iter_mut()
        .map(|bw| {
            bw.writer.flush()?;
            Ok((bw.bucket_id, bw.writer.count()))
        })
        .collect()
}
706
/// Hashes the k-mers of every record from `reader` and sends the surviving hashes to
/// their bucket channels.
///
/// A record uses `file_sample_id` when given, otherwise it draws a fresh id from the
/// shared counter. The first time an id is seen its name is recorded: the file name for
/// per-file ids, the record id otherwise. A hash is dropped when it is not below
/// `frac_max`, fails the entropy filter (only applied when `min_entropy > 0`), or is
/// rejected by the bias table.
///
/// When a send times out, the worker drains its own `bucket_writers`, sleeps, and retries
/// with a growing timeout; after the last retry it falls back to a blocking send.
///
/// # Errors
/// * `SketchError::Parse` if a record cannot be parsed.
/// * `SketchError::Io` if writing a drained entry fails.
/// * `SketchError::Channel` if a retry finds the channel disconnected or the final
///   blocking send fails.
fn sketch_records(
    reader: &mut dyn needletail::FastxReader,
    file_sample_id: Option<u32>,
    source_path: &Path,
    ctx: &SketchContext,
    bucket_writers: &mut [BucketWriter],
) -> Result<(), SketchError> {
    let k = ctx.config.kmer_size;
    let min_entropy = ctx.config.min_entropy;
    let timeout = ctx.config.send_timeout;

    while let Some(record) = reader.next() {
        let record = record.map_err(|e| SketchError::Parse {
            path: source_path.to_path_buf(),
            message: e.to_string(),
        })?;

        // Note that an id is drawn even for records that turn out to be shorter than k.
        let sample_id =
            file_sample_id.unwrap_or_else(|| ctx.sample_counter.fetch_add(1, Ordering::SeqCst));

        if !ctx.sample_names.contains_key(&sample_id) {
            let name = if file_sample_id.is_some() {
                source_path
                    .file_name()
                    .and_then(|s| s.to_str())
                    .unwrap_or("sample")
                    .to_string()
            } else {
                String::from_utf8_lossy(record.id()).to_string()
            };
            ctx.sample_names.insert(sample_id, name);
        }

        // NOTE(review): the `false` / `true` flags below are presumably needletail's
        // "keep IUPAC codes" and "canonical k-mers" switches — confirm against its docs.
        let sequence = record.normalize(false);
        if sequence.len() < k as usize {
            continue;
        }

        for (_, kmer, _) in sequence.bit_kmers(k, true) {
            let hash = jamhash_u64(kmer.0);

            // Cheapest test first: most hashes are rejected by the threshold alone.
            if hash >= ctx.frac_max {
                continue;
            }

            if min_entropy > 0.0 && !passes_entropy_filter(kmer.0, k, min_entropy) {
                continue;
            }

            if ctx
                .config
                .bias_table
                .as_ref()
                .is_some_and(|b| !b.passes_filter(hash))
            {
                continue;
            }

            let entry = Entry::new(hash, sample_id);
            let bucket = bucket_id(hash);

            // NOTE(review): only a timeout is handled for this first attempt; a
            // `Disconnected` result is ignored here, unlike in the retry loop below —
            // confirm that dropping the entry silently in that case is intended.
            if let Err(crossfire::SendTimeoutError::Timeout(mut entry)) =
                ctx.senders[bucket].send_timeout(entry, timeout)
            {
                const MAX_RETRIES: u32 = 10;

                for retry in 0..MAX_RETRIES {
                    // Make room in the buckets this worker owns before trying again.
                    for bw in bucket_writers.iter_mut() {
                        bw.drain()?;
                    }

                    // Sleep and timeout both double per retry, capped at 16x.
                    let backoff_sleep = Duration::from_micros(100 << retry.min(4));
                    std::thread::sleep(backoff_sleep);

                    let backoff_timeout = timeout.saturating_mul(1 << retry.min(4));
                    match ctx.senders[bucket].send_timeout(entry, backoff_timeout) {
                        Ok(()) => break,
                        Err(crossfire::SendTimeoutError::Timeout(e)) => {
                            entry = e;
                            // Out of retries: block until the entry is accepted.
                            if retry == MAX_RETRIES - 1 {
                                if ctx.senders[bucket].send(entry).is_err() {
                                    return Err(SketchError::Channel);
                                }
                            }
                        }
                        Err(crossfire::SendTimeoutError::Disconnected(_)) => {
                            return Err(SketchError::Channel);
                        }
                    }
                }
            }
        }
    }

    Ok(())
}
803
/// Unit tests for the sketching pipeline and its helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `(name, sequence)` pairs as a FASTA file into a temporary `.fa` file.
    fn make_fasta(seqs: &[(&str, &str)]) -> NamedTempFile {
        let mut f = NamedTempFile::with_suffix(".fa").unwrap();
        for (name, seq) in seqs {
            writeln!(f, ">{name}").unwrap();
            writeln!(f, "{seq}").unwrap();
        }
        f
    }

    // One file yields one sample and at least one entry.
    #[test]
    fn test_sketch_basic() {
        let input = make_fasta(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
        let config = SketchConfig {
            kmer_size: 11,
            fscale: 1,
            num_threads: 2,
            memory: 1,
            ..Default::default()
        };

        let result = run(&[input.path().to_path_buf()], &config).unwrap();
        assert_eq!(result.sample_count, 1);
        assert!(result.bucket_entry_counts.iter().sum::<u64>() > 0);
    }

    // In singleton mode every record is its own sample.
    #[test]
    fn test_sketch_singleton() {
        let input = make_fasta(&[
            ("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG"),
            ("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA"),
        ]);
        let config = SketchConfig {
            kmer_size: 11,
            fscale: 1,
            singleton: true,
            num_threads: 2,
            memory: 1,
            ..Default::default()
        };

        let result = run(&[input.path().to_path_buf()], &config).unwrap();
        assert_eq!(result.sample_count, 2);
    }

    // A larger fscale must keep fewer entries than fscale = 1.
    #[test]
    fn test_sketch_fracmin_filters() {
        let input = make_fasta(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);

        let result_all = run(
            &[input.path().to_path_buf()],
            &SketchConfig {
                kmer_size: 11,
                fscale: 1,
                memory: 1,
                ..Default::default()
            },
        )
        .unwrap();

        let result_filtered = run(
            &[input.path().to_path_buf()],
            &SketchConfig {
                kmer_size: 11,
                fscale: 100,
                memory: 1,
                ..Default::default()
            },
        )
        .unwrap();

        let total_all: u64 = result_all.bucket_entry_counts.iter().sum();
        let total_filtered: u64 = result_filtered.bucket_entry_counts.iter().sum();
        assert!(total_filtered < total_all);
    }

    // Two files yield two samples.
    #[test]
    fn test_sketch_multiple_files() {
        let input1 = make_fasta(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
        let input2 = make_fasta(&[("seq2", "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA")]);
        let config = SketchConfig {
            kmer_size: 11,
            fscale: 1,
            num_threads: 2,
            memory: 1,
            ..Default::default()
        };

        let result = run(
            &[input1.path().to_path_buf(), input2.path().to_path_buf()],
            &config,
        )
        .unwrap();
        assert_eq!(result.sample_count, 2);
        assert!(result.bucket_entry_counts.iter().sum::<u64>() > 0);
    }

    // The run still completes with a very short send timeout.
    #[test]
    fn test_sketch_backpressure() {
        let input = make_fasta(&[(
            "seq1",
            "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG",
        )]);
        let config = SketchConfig {
            kmer_size: 11,
            fscale: 1,
            num_threads: 2,
            memory: MIN_MEMORY_GB,
            send_timeout: Duration::from_micros(100),
            ..Default::default()
        };

        let result = run(&[input.path().to_path_buf()], &config).unwrap();
        assert!(result.bucket_entry_counts.iter().sum::<u64>() > 0);
    }

    // Capacity is capped at the optimum, shrinks with input size and never drops below 1024.
    #[test]
    fn test_channel_capacity_calculation() {
        let cap_4gb = compute_channel_capacity(4, 0);
        assert_eq!(cap_4gb, OPTIMAL_CHANNEL_CAPACITY);

        let cap_4gb_with_input = compute_channel_capacity(4, 3 * 1024 * 1024 * 1024);
        assert!(cap_4gb_with_input < cap_4gb);

        let cap_exceeded = compute_channel_capacity(4, 10 * 1024 * 1024 * 1024);
        assert_eq!(cap_exceeded, 1024);

        assert_eq!(compute_channel_capacity(0, 0), 1024);
    }

    #[test]
    fn test_scan_fasta_boundaries() {
        let data = b">seq1\nATCG\n>seq2\nGCTA\n>seq3\nAAAA\n";
        assert_eq!(scan_fasta_boundaries(data), vec![0, 11, 22]);
    }

    #[test]
    fn test_scan_fastq_boundaries() {
        let data = b"@read1\nATCG\n+\nIIII\n@read2\nGCTA\n+\nIIII\n";
        assert_eq!(scan_fastq_boundaries(data), vec![0, 19]);
    }

    // Multi-line sequence and quality blocks do not hide the next record.
    #[test]
    fn test_scan_fastq_boundaries_wrapped() {
        let data = b"@read1\nATCG\nGCTA\n+\nIIII\nJJJJ\n@read2\nAAAA\n+\nKKKK\n";
        let bounds = scan_fastq_boundaries(data);
        assert_eq!(bounds[0], 0);
        assert!(bounds.contains(&29));
    }

    // A quality line that starts with `@` is not mistaken for a header.
    #[test]
    fn test_scan_fastq_boundaries_at_in_quality() {
        let data = b"@read1\nATCG\n+\n@@@I\n@read2\nGCTA\n+\nIIII\n";
        let bounds = scan_fastq_boundaries(data);
        assert_eq!(bounds, vec![0, 19]);
    }

    // A trailing `@` line with nothing after it yields no boundary.
    #[test]
    fn test_scan_fastq_boundaries_at_followed_by_at_skipped() {
        let data = b"@read1\nATCG\n+\nIIII\n@@ambiguous\n";
        assert_eq!(scan_fastq_boundaries(data), vec![0]);
    }

    // Sequences that start with ambiguity codes are still recognized.
    #[test]
    fn test_scan_fastq_boundaries_iupac_codes() {
        let data = b"@read1\nRYSW\n+\nIIII\n@read2\nKMBD\n+\nIIII\n";
        let bounds = scan_fastq_boundaries(data);
        assert_eq!(bounds, vec![0, 19]);
    }

    #[test]
    fn test_is_compressed() {
        assert!(is_compressed([0x1F, 0x8B]));
        assert!(is_compressed([0x42, 0x5A]));
        assert!(is_compressed([0xFD, 0x37]));
        assert!(is_compressed([0x28, 0xB5]));
        assert!(!is_compressed([b'>', b's']));
        assert!(!is_compressed([b'@', b'r']));
    }

    // The reader exposes exactly the requested byte range of the mapping.
    #[test]
    fn test_mmap_slice_reader() {
        use std::io::Read;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.bin");
        std::fs::write(&path, b"Hello, World!").unwrap();

        let file = File::open(&path).unwrap();
        let mmap = Arc::new(unsafe { Mmap::map(&file).unwrap() });

        let mut reader = MmapSliceReader::new(mmap, 7, 12);
        let mut buf = String::new();
        reader.read_to_string(&mut buf).unwrap();
        assert_eq!(buf, "World");
    }

    // Files with fewer than two bytes are rejected with a "too small" error.
    #[test]
    fn test_tiny_file_errors() {
        let dir = tempfile::tempdir().unwrap();

        let empty_path = dir.path().join("empty.fa");
        std::fs::write(&empty_path, b"").unwrap();

        let config = SketchConfig::default();
        let result = run(&[empty_path], &config);
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected error for empty file"),
        };
        assert!(err.to_string().contains("too small"));

        let tiny_path = dir.path().join("tiny.fa");
        std::fs::write(&tiny_path, b">").unwrap();

        let result = run(&[tiny_path], &config);
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected error for 1-byte file"),
        };
        assert!(err.to_string().contains("too small"));
    }

    #[test]
    fn test_fscale_zero_errors() {
        let input = make_fasta(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);
        let config = SketchConfig {
            fscale: 0,
            ..Default::default()
        };

        let result = run(&[input.path().to_path_buf()], &config);
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected error for fscale=0"),
        };
        assert!(err.to_string().contains("fscale must be non-zero"));
    }

    // k = 0 and k = 32 are rejected; k = 31 is the largest accepted value.
    #[test]
    fn test_kmer_size_validation() {
        let input = make_fasta(&[("seq1", "ATCGATCGATCGATCGATCGATCGATCGATCG")]);

        let config = SketchConfig {
            kmer_size: 0,
            memory: 1,
            ..Default::default()
        };
        let err = match run(&[input.path().to_path_buf()], &config) {
            Err(e) => e,
            Ok(_) => panic!("expected error for kmer_size=0"),
        };
        assert!(
            err.to_string()
                .contains("kmer_size must be between 1 and 31")
        );

        let config = SketchConfig {
            kmer_size: 32,
            memory: 1,
            ..Default::default()
        };
        let err = match run(&[input.path().to_path_buf()], &config) {
            Err(e) => e,
            Ok(_) => panic!("expected error for kmer_size=32"),
        };
        assert!(
            err.to_string()
                .contains("kmer_size must be between 1 and 31")
        );

        let config = SketchConfig {
            kmer_size: 31,
            fscale: 1,
            memory: 1,
            ..Default::default()
        };
        let result = run(&[input.path().to_path_buf()], &config);
        assert!(result.is_ok());
    }
}