Skip to main content

jam_rs/
sketch.rs

1use crate::bias::HashBiasTable;
2use crate::core_utils::passes_entropy_filter;
3use crate::format::{BUCKET_COUNT, ENTRY_SIZE, Entry, bucket_id};
4use crate::io::EntryWriter;
5use crossfire::mpsc;
6use crossfire::{MTx, Rx};
7use dashmap::DashMap;
8use indicatif::{ProgressBar, ProgressStyle};
9use jamhash::jamhash_u64;
10use memmap2::Mmap;
11use needletail::{Sequence, parse_fastx_reader};
12use rayon::prelude::*;
13use std::fs::File;
14use std::io;
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
18use std::time::Duration;
19use tempfile::TempDir;
20
/// Buffer size, in bytes, passed to every bucket's `EntryWriter` (8 MiB).
const WRITE_BUFFER_SIZE: usize = 8 << 20;
/// Default value of `SketchConfig::memory`, expressed in GiB.
const MIN_MEMORY_GB: usize = 4;
/// Default timeout for the first attempt at pushing an entry into a bucket channel.
const DEFAULT_SEND_TIMEOUT: Duration = Duration::from_millis(1);
/// Mapped inputs smaller than this many bytes (1 MiB) are processed as a single work unit.
const MIN_SPLIT_SIZE: usize = 1 << 20;
/// With more input files than this, inputs are streamed from disk instead of memory-mapped.
const MAX_CONCURRENT_MMAPS: usize = 256;
/// Upper bound, in entries, on the capacity of a single bucket channel.
const OPTIMAL_CHANNEL_CAPACITY: usize = 512 << 10;
27
28type Sender = MTx<mpsc::Array<Entry>>;
29type Receiver = Rx<mpsc::Array<Entry>>;
30
31#[derive(Clone)]
32pub struct SketchConfig {
33    pub kmer_size: u8,
34    pub fscale: u64,
35    pub num_threads: usize,
36    pub memory: usize,
37    pub temp_dir_base: Option<PathBuf>,
38    pub min_entropy: f64,
39    pub singleton: bool,
40    pub bias_table: Option<Arc<HashBiasTable>>,
41    pub send_timeout: Duration,
42    pub show_progress: bool,
43}
44
45impl Default for SketchConfig {
46    fn default() -> Self {
47        Self {
48            kmer_size: 21,
49            fscale: 1000,
50            num_threads: 1,
51            memory: MIN_MEMORY_GB,
52            temp_dir_base: None,
53            min_entropy: 0.0,
54            singleton: false,
55            bias_table: None,
56            send_timeout: DEFAULT_SEND_TIMEOUT,
57            show_progress: false,
58        }
59    }
60}
61
62pub struct SketchResult {
63    pub sample_count: u32,
64    pub bucket_entry_counts: [u64; BUCKET_COUNT],
65    pub frac_max: u64,
66    pub temp_dir: TempDir,
67    pub sample_names: Vec<String>,
68}
69
70#[derive(Debug, thiserror::Error)]
71pub enum SketchError {
72    #[error("I/O error: {0}")]
73    Io(#[from] std::io::Error),
74
75    #[error("Parse error in {path}: {message}")]
76    Parse { path: PathBuf, message: String },
77
78    #[error("Channel send error")]
79    Channel,
80
81    #[error("Invalid configuration: {0}")]
82    Config(String),
83}
84
85struct WorkUnit {
86    mmap: Option<Arc<Mmap>>,
87    start: usize,
88    end: usize,
89    sample_id: Option<u32>,
90    source_path: Arc<PathBuf>,
91}
92
93struct BucketWriter {
94    receiver: Receiver,
95    writer: EntryWriter,
96    bucket_id: usize,
97}
98
99impl BucketWriter {
100    fn drain(&mut self) -> io::Result<()> {
101        while let Ok(entry) = self.receiver.try_recv() {
102            self.writer.write(&entry)?;
103        }
104        Ok(())
105    }
106
107    fn drain_until_disconnected(&mut self, timeout: Duration) -> io::Result<bool> {
108        match self.receiver.recv_timeout(timeout) {
109            Ok(entry) => {
110                self.writer.write(&entry)?;
111                Ok(true)
112            }
113            Err(crossfire::RecvTimeoutError::Timeout) => Ok(true),
114            Err(crossfire::RecvTimeoutError::Disconnected) => Ok(false),
115        }
116    }
117}
118
119struct SketchContext<'a> {
120    senders: &'a [Sender],
121    config: &'a SketchConfig,
122    sample_counter: &'a AtomicU32,
123    frac_max: u64,
124    sample_names: &'a DashMap<u32, String>,
125}
126
127struct MmapSliceReader {
128    mmap: Arc<Mmap>,
129    start: usize,
130    end: usize,
131    pos: usize,
132}
133
134impl MmapSliceReader {
135    fn new(mmap: Arc<Mmap>, start: usize, end: usize) -> Self {
136        Self {
137            mmap,
138            start,
139            end,
140            pos: 0,
141        }
142    }
143}
144
145impl io::Read for MmapSliceReader {
146    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
147        let current = self.start + self.pos;
148        if current >= self.end {
149            return Ok(0);
150        }
151        let remaining = &self.mmap[current..self.end];
152        let n = remaining.len().min(buf.len());
153        buf[..n].copy_from_slice(&remaining[..n]);
154        self.pos += n;
155        Ok(n)
156    }
157}
158
/// Splits `total` items into `parts` shares whose sizes differ by at most one.
///
/// Yields exactly `parts` sizes that sum to `total`; the first `total % parts` shares are
/// the ones that receive the extra item. With `parts == 0` the iterator is empty instead
/// of panicking on a division by zero.
fn distribute_evenly(total: usize, parts: usize) -> impl Iterator<Item = usize> {
    let (per_part, remainder) = match parts {
        0 => (0, 0),
        n => (total / n, total % n),
    };
    (0..parts).map(move |i| per_part + usize::from(i < remainder))
}
170
/// Leading two bytes that identify a gzip stream.
const GZ_MAGIC: [u8; 2] = [0x1F, 0x8B];
/// Leading two bytes that identify a bzip2 stream (`"BZ"`).
const BZ_MAGIC: [u8; 2] = [0x42, 0x5A];
/// Leading two bytes that identify an xz stream.
const XZ_MAGIC: [u8; 2] = [0xFD, 0x37];
/// Leading two bytes that identify a zstd frame.
const ZST_MAGIC: [u8; 2] = [0x28, 0xB5];

/// Returns `true` when `magic` equals one of the compression signatures above.
#[inline]
fn is_compressed(magic: [u8; 2]) -> bool {
    [GZ_MAGIC, BZ_MAGIC, XZ_MAGIC, ZST_MAGIC].contains(&magic)
}
180
181fn validate_format(magic: [u8; 2], path: &Path) -> Result<(), SketchError> {
182    if !is_compressed(magic) && !matches!(magic[0], b'>' | b'@') {
183        return Err(SketchError::Parse {
184            path: path.to_path_buf(),
185            message: format!(
186                "unrecognized format (bytes: [{:#04X}, {:#04X}])",
187                magic[0], magic[1]
188            ),
189        });
190    }
191    Ok(())
192}
193
194fn validate_file_header(path: &Path) -> Result<[u8; 2], SketchError> {
195    let mut file = File::open(path)?;
196    let mut magic = [0u8; 2];
197    use std::io::Read;
198    let bytes_read = file.read(&mut magic)?;
199    if bytes_read < 2 {
200        return Err(SketchError::Parse {
201            path: path.to_path_buf(),
202            message: format!("file too small ({bytes_read} bytes) to be valid FASTA/FASTQ"),
203        });
204    }
205    validate_format(magic, path)?;
206    Ok(magic)
207}
208
209fn validate_mmap_header(mmap: &Mmap, path: &Path) -> Result<[u8; 2], SketchError> {
210    if mmap.len() < 2 {
211        return Err(SketchError::Parse {
212            path: path.to_path_buf(),
213            message: format!(
214                "file too small ({} bytes) to be valid FASTA/FASTQ",
215                mmap.len()
216            ),
217        });
218    }
219    let magic = [mmap[0], mmap[1]];
220    validate_format(magic, path)?;
221    Ok(magic)
222}
223
/// Returns the offsets at which FASTA records start.
///
/// Offset `0` is always reported; every further offset is the position of a `>` that
/// directly follows a newline.
fn scan_fasta_boundaries(data: &[u8]) -> Vec<usize> {
    let header_starts = data
        .windows(2)
        .enumerate()
        .filter(|(_, pair)| pair[0] == b'\n' && pair[1] == b'>')
        .map(|(newline_at, _)| newline_at + 1);

    std::iter::once(0).chain(header_starts).collect()
}
233
/// Returns `true` when `b` is one of the letters `ACGTUNRYSWKMBDHV`, in either case.
#[inline]
fn is_iupac_nucleotide(b: u8) -> bool {
    const CODES: &[u8] = b"acgtunryswkmbdhv";
    CODES.contains(&b.to_ascii_lowercase())
}
255
256fn scan_fastq_boundaries(data: &[u8]) -> Vec<usize> {
257    let mut bounds = vec![0];
258    let mut i = 0;
259
260    while i + 1 < data.len() {
261        if data[i] == b'\n' && data[i + 1] == b'@' {
262            let header_start = i + 1;
263            let header_end = data[header_start..]
264                .iter()
265                .position(|&b| b == b'\n')
266                .map(|p| header_start + p)
267                .unwrap_or(data.len());
268
269            if header_end > header_start + 1 {
270                let seq_start = header_end + 1;
271                if seq_start < data.len() && is_iupac_nucleotide(data[seq_start]) {
272                    bounds.push(header_start);
273                }
274            }
275        }
276        i += 1;
277    }
278
279    bounds
280}
281
282fn setup_channels(
283    num_threads: usize,
284    memory_gb: usize,
285    input_size_bytes: u64,
286    temp_path: &Path,
287) -> Result<(Vec<Sender>, Vec<Vec<BucketWriter>>), SketchError> {
288    let capacity = compute_channel_capacity(memory_gb, input_size_bytes);
289
290    let (senders, receivers): (Vec<_>, Vec<_>) = (0..BUCKET_COUNT)
291        .map(|_| mpsc::bounded_blocking(capacity))
292        .unzip();
293
294    let bucket_threads = num_threads.min(BUCKET_COUNT);
295    let chunk_sizes = distribute_evenly(BUCKET_COUNT, bucket_threads);
296
297    let mut rx_iter = receivers.into_iter().enumerate();
298    let mut bucket_writers: Vec<Vec<BucketWriter>> = chunk_sizes
299        .map(|count| {
300            rx_iter
301                .by_ref()
302                .take(count)
303                .map(|(bucket_id, receiver)| {
304                    let writer = EntryWriter::new(
305                        temp_path.join(format!("bucket_{bucket_id:03}.bin")),
306                        WRITE_BUFFER_SIZE,
307                    )?;
308                    Ok(BucketWriter {
309                        receiver,
310                        writer,
311                        bucket_id,
312                    })
313                })
314                .collect::<Result<Vec<_>, std::io::Error>>()
315        })
316        .collect::<Result<Vec<_>, _>>()?;
317
318    bucket_writers.resize_with(num_threads, Vec::new);
319
320    Ok((senders, bucket_writers))
321}
322
323fn compute_channel_capacity(memory_gb: usize, input_size_bytes: u64) -> usize {
324    let memory_bytes = memory_gb as u64 * 1024 * 1024 * 1024;
325    let writer_memory = BUCKET_COUNT as u64 * WRITE_BUFFER_SIZE as u64;
326
327    let available = memory_bytes
328        .saturating_sub(input_size_bytes)
329        .saturating_sub(writer_memory);
330
331    let computed = (available / (BUCKET_COUNT as u64 * ENTRY_SIZE as u64)) as usize;
332
333    computed.clamp(1024, OPTIMAL_CHANNEL_CAPACITY)
334}
335
336fn scan_boundaries(mmap: &Mmap, magic: [u8; 2]) -> Vec<usize> {
337    if mmap.len() < MIN_SPLIT_SIZE {
338        return vec![0];
339    }
340    match magic[0] {
341        b'>' => scan_fasta_boundaries(mmap),
342        _ => scan_fastq_boundaries(mmap),
343    }
344}
345
346fn distribute_work_units(
347    positions: Vec<(Arc<Mmap>, Arc<PathBuf>, usize, usize)>,
348    num_threads: usize,
349    singleton: bool,
350    sample_counter: &AtomicU32,
351    thread_work: &mut [Vec<WorkUnit>],
352) {
353    if positions.is_empty() {
354        return;
355    }
356
357    let mut file_sample_ids: std::collections::HashMap<PathBuf, u32> =
358        std::collections::HashMap::new();
359
360    let chunk_sizes: Vec<_> = distribute_evenly(positions.len(), num_threads).collect();
361    let mut offset = 0;
362
363    for (t, &count) in chunk_sizes.iter().enumerate() {
364        for (mmap, path, start_byte, end_byte) in &positions[offset..offset + count] {
365            let sample_id = (!singleton).then(|| {
366                *file_sample_ids
367                    .entry((**path).clone())
368                    .or_insert_with(|| sample_counter.fetch_add(1, Ordering::SeqCst))
369            });
370
371            thread_work[t].push(WorkUnit {
372                mmap: Some(Arc::clone(mmap)),
373                start: *start_byte,
374                end: *end_byte,
375                sample_id,
376                source_path: Arc::clone(path),
377            });
378        }
379        offset += count;
380    }
381}
382
383struct WorkUnitResult {
384    thread_work: Vec<Vec<WorkUnit>>,
385    total_input_bytes: u64,
386}
387
388fn build_work_units(
389    input_files: &[PathBuf],
390    num_threads: usize,
391    singleton: bool,
392    memory_gb: usize,
393    sample_counter: &AtomicU32,
394    show_progress: bool,
395) -> Result<WorkUnitResult, SketchError> {
396    let mut thread_work: Vec<Vec<WorkUnit>> = (0..num_threads).map(|_| Vec::new()).collect();
397    let next_sample = || (!singleton).then(|| sample_counter.fetch_add(1, Ordering::SeqCst));
398
399    let total_input_bytes: u64 = input_files
400        .iter()
401        .filter_map(|p| std::fs::metadata(p).ok())
402        .map(|m| m.len())
403        .sum();
404
405    let memory_bytes = memory_gb as u64 * 1024 * 1024 * 1024;
406
407    let skip_mmap = input_files.len() > MAX_CONCURRENT_MMAPS || total_input_bytes > memory_bytes;
408
409    if skip_mmap {
410        let validation_pb = if show_progress {
411            let pb = ProgressBar::new(input_files.len() as u64);
412            pb.set_style(
413                ProgressStyle::default_bar()
414                    .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} files validated")
415                    .unwrap()
416                    .progress_chars("#>-"),
417            );
418            Some(pb)
419        } else {
420            None
421        };
422
423        let validation_results: Vec<Result<(), SketchError>> = input_files
424            .par_iter()
425            .map(|path| {
426                let result = validate_file_header(path).map(|_| ());
427                if let Some(ref pb) = validation_pb {
428                    pb.inc(1);
429                }
430                result
431            })
432            .collect();
433
434        if let Some(pb) = validation_pb {
435            pb.finish_with_message("validation complete");
436        }
437
438        for result in validation_results {
439            result?;
440        }
441
442        for (i, path) in input_files.iter().enumerate() {
443            thread_work[i % num_threads].push(WorkUnit {
444                mmap: None,
445                start: 0,
446                end: 0,
447                sample_id: next_sample(),
448                source_path: Arc::new(path.clone()),
449            });
450        }
451        return Ok(WorkUnitResult {
452            thread_work,
453            total_input_bytes: 0,
454        });
455    }
456
457    let mut flat_positions: Vec<(Arc<Mmap>, Arc<PathBuf>, usize, usize)> = Vec::new();
458    let mut compressed_files: Vec<(Arc<Mmap>, Arc<PathBuf>)> = Vec::new();
459
460    for path in input_files {
461        let file = File::open(path)?;
462        let mmap = Arc::new(unsafe { Mmap::map(&file)? });
463        let path = Arc::new(path.clone());
464        let magic = validate_mmap_header(&mmap, &path)?;
465
466        if is_compressed(magic) {
467            compressed_files.push((mmap, path));
468            continue;
469        }
470
471        let boundaries = scan_boundaries(&mmap, magic);
472        for (i, &start) in boundaries.iter().enumerate() {
473            let end = boundaries.get(i + 1).copied().unwrap_or(mmap.len());
474            flat_positions.push((Arc::clone(&mmap), Arc::clone(&path), start, end));
475        }
476    }
477
478    distribute_work_units(
479        flat_positions,
480        num_threads,
481        singleton,
482        sample_counter,
483        &mut thread_work,
484    );
485
486    for (i, (mmap, path)) in compressed_files.into_iter().enumerate() {
487        let end = mmap.len();
488        thread_work[i % num_threads].push(WorkUnit {
489            mmap: Some(mmap),
490            start: 0,
491            end,
492            sample_id: next_sample(),
493            source_path: path,
494        });
495    }
496
497    Ok(WorkUnitResult {
498        thread_work,
499        total_input_bytes,
500    })
501}
502
503pub fn run(input_files: &[PathBuf], config: &SketchConfig) -> Result<SketchResult, SketchError> {
504    if config.fscale == 0 {
505        return Err(SketchError::Config("fscale must be non-zero".to_string()));
506    }
507    if config.kmer_size == 0 || config.kmer_size > 31 {
508        return Err(SketchError::Config(format!(
509            "kmer_size must be between 1 and 31, got {}",
510            config.kmer_size
511        )));
512    }
513
514    let temp_dir = match &config.temp_dir_base {
515        Some(base) => tempfile::Builder::new().prefix("jam_").tempdir_in(base)?,
516        None => tempfile::Builder::new().prefix("jam_").tempdir()?,
517    };
518
519    let frac_max = u64::MAX / config.fscale;
520    let sample_counter = Arc::new(AtomicU32::new(0));
521    let num_threads = config.num_threads.max(1);
522
523    let WorkUnitResult {
524        thread_work,
525        total_input_bytes,
526    } = build_work_units(
527        input_files,
528        num_threads,
529        config.singleton,
530        config.memory,
531        &sample_counter,
532        config.show_progress,
533    )?;
534    let (senders, resources) = setup_channels(
535        num_threads,
536        config.memory,
537        total_input_bytes,
538        temp_dir.path(),
539    )?;
540
541    let (result_tx, result_rx) = std::sync::mpsc::channel();
542
543    let sample_names_map: DashMap<u32, String> = DashMap::new();
544
545    let total_files: u64 = thread_work.iter().map(|w| w.len() as u64).sum();
546    let files_processed = Arc::new(AtomicU64::new(0));
547
548    let progress_bar = if config.show_progress {
549        let pb = ProgressBar::new(total_files);
550        pb.set_style(
551            ProgressStyle::default_bar()
552                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} files ({msg})")
553                .unwrap()
554                .progress_chars("#>-"),
555        );
556        pb.enable_steady_tick(Duration::from_millis(100));
557        Some(pb)
558    } else {
559        None
560    };
561
562    rayon::scope(|s| {
563        let sample_counter = &sample_counter;
564        let files_processed = &files_processed;
565        let progress_bar = &progress_bar;
566        let sample_names_map = &sample_names_map;
567        for (work, res) in thread_work.into_iter().zip(resources) {
568            let thread_senders = senders.to_vec();
569            let result_tx = result_tx.clone();
570
571            s.spawn(move |_| {
572                let result = process_thread_work(
573                    work,
574                    thread_senders,
575                    res,
576                    config,
577                    sample_counter,
578                    frac_max,
579                    files_processed,
580                    progress_bar,
581                    sample_names_map,
582                );
583                let _ = result_tx.send(result);
584            });
585        }
586        drop(senders);
587        drop(result_tx);
588    });
589
590    if let Some(pb) = progress_bar {
591        let final_samples = sample_counter.load(Ordering::SeqCst);
592        pb.finish_with_message(format!("{} samples", final_samples));
593    }
594
595    let mut bucket_entry_counts = [0u64; BUCKET_COUNT];
596    for result in result_rx {
597        for (bucket_idx, count) in result? {
598            bucket_entry_counts[bucket_idx] = count;
599        }
600    }
601
602    let sample_count = sample_counter.load(Ordering::SeqCst);
603    let mut sample_names: Vec<String> = vec![String::new(); sample_count as usize];
604
605    for entry in sample_names_map.iter() {
606        if (*entry.key() as usize) < sample_names.len() {
607            sample_names[*entry.key() as usize] = entry.value().clone();
608        }
609    }
610
611    Ok(SketchResult {
612        sample_count,
613        bucket_entry_counts,
614        frac_max,
615        temp_dir,
616        sample_names,
617    })
618}
619
620#[allow(clippy::too_many_arguments)]
621fn process_thread_work(
622    work_units: Vec<WorkUnit>,
623    senders: Vec<Sender>,
624    mut bucket_writers: Vec<BucketWriter>,
625    config: &SketchConfig,
626    sample_counter: &AtomicU32,
627    frac_max: u64,
628    files_processed: &AtomicU64,
629    progress_bar: &Option<ProgressBar>,
630    sample_names_map: &DashMap<u32, String>,
631) -> Result<Vec<(usize, u64)>, SketchError> {
632    let ctx = SketchContext {
633        senders: &senders,
634        config,
635        sample_counter,
636        frac_max,
637        sample_names: sample_names_map,
638    };
639
640    for unit in &work_units {
641        let reader: Box<dyn io::Read + Send> = match &unit.mmap {
642            Some(mmap) => Box::new(MmapSliceReader::new(Arc::clone(mmap), unit.start, unit.end)),
643            None => Box::new(io::BufReader::new(File::open(&*unit.source_path)?)),
644        };
645
646        let mut fastx = match parse_fastx_reader(reader) {
647            Ok(reader) => reader,
648            Err(e) if e.kind == needletail::errors::ParseErrorKind::EmptyFile => {
649                eprintln!(
650                    "Empty file detected: {}, skipping",
651                    unit.source_path.display()
652                );
653                let processed = files_processed.fetch_add(1, Ordering::Relaxed) + 1;
654                if let Some(pb) = progress_bar {
655                    pb.set_position(processed);
656                }
657                continue;
658            }
659            Err(e) => {
660                return Err(SketchError::Parse {
661                    path: (*unit.source_path).clone(),
662                    message: e.to_string(),
663                });
664            }
665        };
666
667        sketch_records(
668            fastx.as_mut(),
669            unit.sample_id,
670            &unit.source_path,
671            &ctx,
672            &mut bucket_writers,
673        )?;
674
675        let processed = files_processed.fetch_add(1, Ordering::Relaxed) + 1;
676        if let Some(pb) = progress_bar {
677            pb.set_position(processed);
678            let samples = sample_counter.load(Ordering::Relaxed);
679            pb.set_message(format!("{} samples", samples));
680        }
681    }
682
683    drop(senders);
684
685    const DRAIN_TIMEOUT: Duration = Duration::from_millis(1);
686    loop {
687        let mut any_pending = false;
688        for bw in bucket_writers.iter_mut() {
689            if bw.drain_until_disconnected(DRAIN_TIMEOUT)? {
690                any_pending = true;
691            }
692        }
693        if !any_pending {
694            break;
695        }
696    }
697
698    bucket_writers
699        .iter_mut()
700        .map(|bw| {
701            bw.writer.flush()?;
702            Ok((bw.bucket_id, bw.writer.count()))
703        })
704        .collect()
705}
706
707fn sketch_records(
708    reader: &mut dyn needletail::FastxReader,
709    file_sample_id: Option<u32>,
710    source_path: &Path,
711    ctx: &SketchContext,
712    bucket_writers: &mut [BucketWriter],
713) -> Result<(), SketchError> {
714    let k = ctx.config.kmer_size;
715    let min_entropy = ctx.config.min_entropy;
716    let timeout = ctx.config.send_timeout;
717
718    while let Some(record) = reader.next() {
719        let record = record.map_err(|e| SketchError::Parse {
720            path: source_path.to_path_buf(),
721            message: e.to_string(),
722        })?;
723
724        let sample_id =
725            file_sample_id.unwrap_or_else(|| ctx.sample_counter.fetch_add(1, Ordering::SeqCst));
726
727        if !ctx.sample_names.contains_key(&sample_id) {
728            let name = if file_sample_id.is_some() {
729                source_path
730                    .file_name()
731                    .and_then(|s| s.to_str())
732                    .unwrap_or("sample")
733                    .to_string()
734            } else {
735                String::from_utf8_lossy(record.id()).to_string()
736            };
737            ctx.sample_names.insert(sample_id, name);
738        }
739
740        let sequence = record.normalize(false);
741        if sequence.len() < k as usize {
742            continue;
743        }
744
745        for (_, kmer, _) in sequence.bit_kmers(k, true) {
746            let hash = jamhash_u64(kmer.0);
747
748            if hash >= ctx.frac_max {
749                continue;
750            }
751
752            if min_entropy > 0.0 && !passes_entropy_filter(kmer.0, k, min_entropy) {
753                continue;
754            }
755
756            if ctx
757                .config
758                .bias_table
759                .as_ref()
760                .is_some_and(|b| !b.passes_filter(hash))
761            {
762                continue;
763            }
764
765            let entry = Entry::new(hash, sample_id);
766            let bucket = bucket_id(hash);
767
768            if let Err(crossfire::SendTimeoutError::Timeout(mut entry)) =
769                ctx.senders[bucket].send_timeout(entry, timeout)
770            {
771                const MAX_RETRIES: u32 = 10;
772
773                for retry in 0..MAX_RETRIES {
774                    for bw in bucket_writers.iter_mut() {
775                        bw.drain()?;
776                    }
777
778                    let backoff_sleep = Duration::from_micros(100 << retry.min(4));
779                    std::thread::sleep(backoff_sleep);
780
781                    let backoff_timeout = timeout.saturating_mul(1 << retry.min(4));
782                    match ctx.senders[bucket].send_timeout(entry, backoff_timeout) {
783                        Ok(()) => break,
784                        Err(crossfire::SendTimeoutError::Timeout(e)) => {
785                            entry = e;
786                            if retry == MAX_RETRIES - 1 {
787                                if ctx.senders[bucket].send(entry).is_err() {
788                                    return Err(SketchError::Channel);
789                                }
790                            }
791                        }
792                        Err(crossfire::SendTimeoutError::Disconnected(_)) => {
793                            return Err(SketchError::Channel);
794                        }
795                    }
796                }
797            }
798        }
799    }
800
801    Ok(())
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    const SEQ_A: &str = "ATCGATCGATCGATCGATCGATCGATCGATCG";
    const SEQ_B: &str = "GCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA";

    /// Writes the given `(name, sequence)` pairs as a FASTA file with a `.fa` suffix.
    fn make_fasta(seqs: &[(&str, &str)]) -> NamedTempFile {
        let mut f = NamedTempFile::with_suffix(".fa").unwrap();
        for (name, seq) in seqs {
            writeln!(f, ">{name}").unwrap();
            writeln!(f, "{seq}").unwrap();
        }
        f
    }

    /// Config shared by most end-to-end tests: k = 11 and a 1 GiB memory budget.
    fn small_config(fscale: u64, num_threads: usize) -> SketchConfig {
        SketchConfig {
            kmer_size: 11,
            fscale,
            num_threads,
            memory: 1,
            ..Default::default()
        }
    }

    /// Sum of the entry counts over all buckets.
    fn total_entries(result: &SketchResult) -> u64 {
        result.bucket_entry_counts.iter().sum()
    }

    /// Unwraps the error of a run; `SketchResult` has no `Debug`, so `unwrap_err` is
    /// not available.
    fn expect_err(result: Result<SketchResult, SketchError>, what: &str) -> SketchError {
        match result {
            Err(e) => e,
            Ok(_) => panic!("expected error for {what}"),
        }
    }

    #[test]
    fn test_sketch_basic() {
        let input = make_fasta(&[("seq1", SEQ_A)]);
        let config = small_config(1, 2);

        let result = run(&[input.path().to_path_buf()], &config).unwrap();
        assert_eq!(result.sample_count, 1);
        assert!(total_entries(&result) > 0);
    }

    #[test]
    fn test_sketch_singleton() {
        let input = make_fasta(&[("seq1", SEQ_A), ("seq2", SEQ_B)]);
        let config = SketchConfig {
            singleton: true,
            ..small_config(1, 2)
        };

        let result = run(&[input.path().to_path_buf()], &config).unwrap();
        assert_eq!(result.sample_count, 2);
    }

    #[test]
    fn test_sketch_fracmin_filters() {
        let input = make_fasta(&[("seq1", SEQ_A)]);
        let files = [input.path().to_path_buf()];

        let result_all = run(&files, &small_config(1, 1)).unwrap();
        let result_filtered = run(&files, &small_config(100, 1)).unwrap();

        let total_all = total_entries(&result_all);
        let total_filtered = total_entries(&result_filtered);
        assert!(total_filtered < total_all);
    }

    #[test]
    fn test_sketch_multiple_files() {
        let input1 = make_fasta(&[("seq1", SEQ_A)]);
        let input2 = make_fasta(&[("seq2", SEQ_B)]);
        let config = small_config(1, 2);

        let result = run(
            &[input1.path().to_path_buf(), input2.path().to_path_buf()],
            &config,
        )
        .unwrap();
        assert_eq!(result.sample_count, 2);
        assert!(total_entries(&result) > 0);
    }

    #[test]
    fn test_sketch_backpressure() {
        let input = make_fasta(&[(
            "seq1",
            "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG",
        )]);
        let config = SketchConfig {
            memory: MIN_MEMORY_GB,
            send_timeout: Duration::from_micros(100),
            ..small_config(1, 2)
        };

        let result = run(&[input.path().to_path_buf()], &config).unwrap();
        assert!(total_entries(&result) > 0);
    }

    #[test]
    fn test_channel_capacity_calculation() {
        const GIB: u64 = 1024 * 1024 * 1024;

        let cap_4gb = compute_channel_capacity(4, 0);
        assert_eq!(cap_4gb, OPTIMAL_CHANNEL_CAPACITY);

        let cap_4gb_with_input = compute_channel_capacity(4, 3 * GIB);
        assert!(cap_4gb_with_input < cap_4gb);

        let cap_exceeded = compute_channel_capacity(4, 10 * GIB);
        assert_eq!(cap_exceeded, 1024);

        assert_eq!(compute_channel_capacity(0, 0), 1024);
    }

    #[test]
    fn test_scan_fasta_boundaries() {
        let data = b">seq1\nATCG\n>seq2\nGCTA\n>seq3\nAAAA\n";
        assert_eq!(scan_fasta_boundaries(data), vec![0, 11, 22]);
    }

    #[test]
    fn test_scan_fastq_boundaries() {
        let data = b"@read1\nATCG\n+\nIIII\n@read2\nGCTA\n+\nIIII\n";
        assert_eq!(scan_fastq_boundaries(data), vec![0, 19]);
    }

    #[test]
    fn test_scan_fastq_boundaries_wrapped() {
        let data = b"@read1\nATCG\nGCTA\n+\nIIII\nJJJJ\n@read2\nAAAA\n+\nKKKK\n";
        let bounds = scan_fastq_boundaries(data);
        assert_eq!(bounds[0], 0);
        assert!(bounds.contains(&29));
    }

    #[test]
    fn test_scan_fastq_boundaries_at_in_quality() {
        let data = b"@read1\nATCG\n+\n@@@I\n@read2\nGCTA\n+\nIIII\n";
        let bounds = scan_fastq_boundaries(data);
        assert_eq!(bounds, vec![0, 19]);
    }

    #[test]
    fn test_scan_fastq_boundaries_at_followed_by_at_skipped() {
        let data = b"@read1\nATCG\n+\nIIII\n@@ambiguous\n";
        assert_eq!(scan_fastq_boundaries(data), vec![0]);
    }

    #[test]
    fn test_scan_fastq_boundaries_iupac_codes() {
        let data = b"@read1\nRYSW\n+\nIIII\n@read2\nKMBD\n+\nIIII\n";
        let bounds = scan_fastq_boundaries(data);
        assert_eq!(bounds, vec![0, 19]);
    }

    #[test]
    fn test_is_compressed() {
        for magic in [[0x1F, 0x8B], [0x42, 0x5A], [0xFD, 0x37], [0x28, 0xB5]] {
            assert!(is_compressed(magic));
        }
        for magic in [[b'>', b's'], [b'@', b'r']] {
            assert!(!is_compressed(magic));
        }
    }

    #[test]
    fn test_mmap_slice_reader() {
        use std::io::Read;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.bin");
        std::fs::write(&path, b"Hello, World!").unwrap();

        let file = File::open(&path).unwrap();
        let mmap = Arc::new(unsafe { Mmap::map(&file).unwrap() });

        let mut reader = MmapSliceReader::new(mmap, 7, 12);
        let mut buf = String::new();
        reader.read_to_string(&mut buf).unwrap();
        assert_eq!(buf, "World");
    }

    #[test]
    fn test_tiny_file_errors() {
        let dir = tempfile::tempdir().unwrap();
        let config = SketchConfig::default();

        let empty_path = dir.path().join("empty.fa");
        std::fs::write(&empty_path, b"").unwrap();
        let err = expect_err(run(&[empty_path], &config), "empty file");
        assert!(err.to_string().contains("too small"));

        let tiny_path = dir.path().join("tiny.fa");
        std::fs::write(&tiny_path, b">").unwrap();
        let err = expect_err(run(&[tiny_path], &config), "1-byte file");
        assert!(err.to_string().contains("too small"));
    }

    #[test]
    fn test_fscale_zero_errors() {
        let input = make_fasta(&[("seq1", SEQ_A)]);
        let config = SketchConfig {
            fscale: 0,
            ..Default::default()
        };

        let err = expect_err(run(&[input.path().to_path_buf()], &config), "fscale=0");
        assert!(err.to_string().contains("fscale must be non-zero"));
    }

    #[test]
    fn test_kmer_size_validation() {
        let input = make_fasta(&[("seq1", SEQ_A)]);
        let files = [input.path().to_path_buf()];

        for (kmer_size, what) in [(0u8, "kmer_size=0"), (32u8, "kmer_size=32")] {
            let config = SketchConfig {
                kmer_size,
                memory: 1,
                ..Default::default()
            };
            let err = expect_err(run(&files, &config), what);
            assert!(
                err.to_string()
                    .contains("kmer_size must be between 1 and 31")
            );
        }

        let config = SketchConfig {
            kmer_size: 31,
            fscale: 1,
            memory: 1,
            ..Default::default()
        };
        let result = run(&files, &config);
        assert!(result.is_ok());
    }
}
1088}