use std::{collections::HashMap, fmt::Debug, io::BufRead, path::Path};
use bytes::Bytes;
use dashmap::DashMap;
use rayon::prelude::*;
use rustc_hash::FxHasher;
use std::hash::BuildHasherDefault;
use crate::{
error::KmeRustError,
format::SequenceFormat,
input::Input,
kmer::{pack_canonical, unpack_to_string, KmerLength},
};
#[cfg(feature = "tracing")]
use tracing::{debug, info, info_span};
/// Counts the k-mers of the sequence file at `path`, keyed by their string form.
///
/// Counting itself is delegated to [`count_kmers_streaming_packed`]; the packed `u64` keys it
/// returns are then turned into strings with `unpack_to_string` on a rayon parallel iterator.
/// The `Debug` bound on `P` is only needed for the optional `tracing` output.
///
/// # Errors
///
/// Propagates the error from `KmerLength::new` when `k` is rejected, and any error returned
/// by [`count_kmers_streaming_packed`].
pub fn count_kmers_streaming<P>(path: P, k: usize) -> Result<HashMap<String, u64>, KmeRustError>
where
    P: AsRef<Path> + Debug,
{
    #[cfg(feature = "tracing")]
    info!(k = k, path = ?path, "Starting streaming k-mer counting");

    let k_len = KmerLength::new(k)?;
    let packed_counts = count_kmers_streaming_packed(&path, k_len)?;

    #[cfg(feature = "tracing")]
    let _unpack_span = info_span!("unpack_kmers", count = packed_counts.len()).entered();

    // Converts one packed entry into its string-keyed equivalent.
    let to_named = |(bits, count): (u64, u64)| (unpack_to_string(bits, k_len), count);
    let named_counts: HashMap<String, u64> =
        packed_counts.into_par_iter().map(to_named).collect();

    #[cfg(feature = "tracing")]
    info!(
        unique_kmers = named_counts.len(),
        "Streaming k-mer counting complete"
    );

    Ok(named_counts)
}
/// Counts the k-mers of the sequence file at `path`, keyed by their packed `u64` form.
///
/// A fresh `StreamingKmerCounter` is created for the call and consumed by its `count_file`
/// method.
///
/// # Errors
///
/// Propagates any error returned by `StreamingKmerCounter::count_file`.
pub fn count_kmers_streaming_packed<P>(
    path: P,
    k: KmerLength,
) -> Result<HashMap<u64, u64>, KmeRustError>
where
    P: AsRef<Path> + Debug,
{
    StreamingKmerCounter::new().count_file(path, k)
}
/// Counts the k-mers of every sequence yielded by `sequences`, keyed by packed `u64`.
///
/// A fresh `StreamingKmerCounter` is created for the call and consumed by its
/// `count_sequences` method.
pub fn count_kmers_from_sequences<I>(sequences: I, k: KmerLength) -> HashMap<u64, u64>
where
    I: Iterator<Item = Bytes>,
{
    StreamingKmerCounter::new().count_sequences(sequences, k)
}
/// Counts the k-mers of the sequence file at `path` with the single-threaded counter.
///
/// Results are keyed by packed `u64` k-mers.
///
/// # Errors
///
/// Propagates the error from `KmerLength::new` when `k` is rejected, and any error returned
/// by `SequentialKmerCounter::count_file`.
pub fn count_kmers_sequential<P>(path: P, k: usize) -> Result<HashMap<u64, u64>, KmeRustError>
where
    P: AsRef<Path> + Debug,
{
    let validated_k = KmerLength::new(k)?;
    SequentialKmerCounter::new().count_file(path, validated_k)
}
/// Counts k-mers read from standard input, keyed by their string form.
///
/// Equivalent to calling [`count_kmers_stdin_with_format`] with `SequenceFormat::Auto`.
///
/// # Errors
///
/// Propagates any error returned by [`count_kmers_stdin_with_format`].
pub fn count_kmers_stdin(k: usize) -> Result<HashMap<String, u64>, KmeRustError> {
    let format = SequenceFormat::Auto;
    count_kmers_stdin_with_format(k, format)
}
/// Counts k-mers read from standard input using the requested `format`, keyed by string.
///
/// `format` is first passed through `SequenceFormat::resolve(None)` (there is no path to
/// consult for stdin) and the resolved value is handed to the reader implementation.
///
/// # Errors
///
/// Propagates the error from `KmerLength::new` when `k` is rejected, and any error returned
/// while reading or parsing standard input.
pub fn count_kmers_stdin_with_format(
    k: usize,
    format: SequenceFormat,
) -> Result<HashMap<String, u64>, KmeRustError> {
    let k_len = KmerLength::new(k)?;
    let resolved_format = format.resolve(None);

    let stdin = std::io::stdin();
    // The lock is moved into the callee, so stdin is released before the unpacking below.
    let packed_counts =
        count_kmers_from_reader_impl_with_format(stdin.lock(), k_len, resolved_format)?;

    let mut named_counts = HashMap::with_capacity(packed_counts.len());
    for (bits, count) in packed_counts {
        named_counts.insert(unpack_to_string(bits, k_len), count);
    }
    Ok(named_counts)
}
/// Counts k-mers read from standard input, keyed by their packed `u64` form.
///
/// Standard input is locked and handed to `count_kmers_from_reader_impl`.
///
/// # Errors
///
/// Propagates any error returned by `count_kmers_from_reader_impl`.
pub fn count_kmers_stdin_packed(k: KmerLength) -> Result<HashMap<u64, u64>, KmeRustError> {
    let stdin = std::io::stdin();
    count_kmers_from_reader_impl(stdin.lock(), k)
}
/// Counts k-mers parsed from an arbitrary buffered `reader`, keyed by their string form.
///
/// The packed counts produced by `count_kmers_from_reader_impl` are converted one by one
/// with `unpack_to_string`.
///
/// # Errors
///
/// Propagates the error from `KmerLength::new` when `k` is rejected, and any error returned
/// by `count_kmers_from_reader_impl`.
pub fn count_kmers_from_reader<R>(reader: R, k: usize) -> Result<HashMap<String, u64>, KmeRustError>
where
    R: BufRead,
{
    let k_len = KmerLength::new(k)?;
    let packed_counts = count_kmers_from_reader_impl(reader, k_len)?;

    let mut named_counts = HashMap::with_capacity(packed_counts.len());
    for (bits, count) in packed_counts {
        named_counts.insert(unpack_to_string(bits, k_len), count);
    }
    Ok(named_counts)
}
/// Counts k-mers parsed from an arbitrary buffered `reader`, keyed by packed `u64`.
///
/// This is the public entry point for `count_kmers_from_reader_impl`, which has one
/// definition per `needletail` feature setting: without the feature it requests
/// `SequenceFormat::Fasta`, with the feature it requests `SequenceFormat::Auto`.
///
/// # Errors
///
/// Propagates any error returned by `count_kmers_from_reader_impl`.
pub fn count_kmers_from_reader_packed<R>(
reader: R,
k: KmerLength,
) -> Result<HashMap<u64, u64>, KmeRustError>
where
R: BufRead,
{
count_kmers_from_reader_impl(reader, k)
}
/// Counts k-mers from whichever source `input` designates, keyed by their string form.
///
/// Standard input goes through [`count_kmers_stdin`]; a file path goes through
/// [`count_kmers_streaming`].
///
/// # Errors
///
/// Propagates any error returned by the selected counting function.
pub fn count_kmers_from_input(
    input: &Input,
    k: usize,
) -> Result<HashMap<String, u64>, KmeRustError> {
    match input {
        Input::Stdin => count_kmers_stdin(k),
        Input::File(path) => count_kmers_streaming(path, k),
    }
}
/// Counts k-mers from whichever source `input` designates, keyed by packed `u64`.
///
/// Standard input goes through [`count_kmers_stdin_packed`]; a file path goes through
/// [`count_kmers_streaming_packed`].
///
/// # Errors
///
/// Propagates any error returned by the selected counting function.
pub fn count_kmers_from_input_packed(
    input: &Input,
    k: KmerLength,
) -> Result<HashMap<u64, u64>, KmeRustError> {
    match input {
        Input::Stdin => count_kmers_stdin_packed(k),
        Input::File(path) => count_kmers_streaming_packed(path, k),
    }
}
/// Reader-based counting used when the `needletail` feature is disabled.
///
/// Always asks `count_kmers_from_reader_impl_with_format` to parse the input as FASTA.
#[cfg(not(feature = "needletail"))]
fn count_kmers_from_reader_impl<R>(
    reader: R,
    k: KmerLength,
) -> Result<HashMap<u64, u64>, KmeRustError>
where
    R: BufRead,
{
    let format = SequenceFormat::Fasta;
    count_kmers_from_reader_impl_with_format(reader, k, format)
}
/// Parses `reader` with the `bio` FASTA/FASTQ readers and counts the k-mers of every record.
///
/// `SequenceFormat::Fastq` selects the FASTQ reader; `Fasta` and `Auto` both select the FASTA
/// reader. No quality filtering is applied: every record is counted with `qual` and
/// `min_quality` set to `None`.
///
/// # Errors
///
/// Returns `KmeRustError::SequenceParse` carrying the parser's message when a record cannot
/// be read.
#[cfg(not(feature = "needletail"))]
fn count_kmers_from_reader_impl_with_format<R>(
    reader: R,
    k: KmerLength,
    format: SequenceFormat,
) -> Result<HashMap<u64, u64>, KmeRustError>
where
    R: BufRead,
{
    use bio::io::{fasta, fastq};

    /// Wraps a record-level parser failure in `KmeRustError::SequenceParse`.
    fn parse_failure<E: ToString>(err: E) -> KmeRustError {
        KmeRustError::SequenceParse {
            details: err.to_string(),
        }
    }

    let mut tally: HashMap<u64, u64, BuildHasherDefault<FxHasher>> = HashMap::default();

    match format {
        SequenceFormat::Fasta | SequenceFormat::Auto => {
            for parsed in fasta::Reader::new(reader).records() {
                let record = parsed.map_err(parse_failure)?;
                process_sequence_into_counts(&mut tally, record.seq(), None, k, None);
            }
        }
        SequenceFormat::Fastq => {
            for parsed in fastq::Reader::new(reader).records() {
                let record = parsed.map_err(parse_failure)?;
                process_sequence_into_counts(&mut tally, record.seq(), None, k, None);
            }
        }
    }

    // Re-collect into a map with the default hasher, as required by the return type.
    Ok(tally.into_iter().collect())
}
/// Reader-based counting used when the `needletail` feature is enabled.
///
/// Passes `SequenceFormat::Auto` to `count_kmers_from_reader_impl_with_format`.
#[cfg(feature = "needletail")]
fn count_kmers_from_reader_impl<R>(
    reader: R,
    k: KmerLength,
) -> Result<HashMap<u64, u64>, KmeRustError>
where
    R: BufRead,
{
    let format = SequenceFormat::Auto;
    count_kmers_from_reader_impl_with_format(reader, k, format)
}
/// Parses `reader` with `needletail` and counts the k-mers of every record.
///
/// The whole input is first read into an in-memory buffer, which is then handed to
/// `needletail::parse_fastx_reader` through a `Cursor`. The `_format` argument is ignored by
/// this implementation. No quality filtering is applied.
///
/// # Errors
///
/// Returns `KmeRustError::SequenceParse` when the input cannot be read, when the parser
/// cannot be created, or when a record fails to parse.
#[cfg(feature = "needletail")]
fn count_kmers_from_reader_impl_with_format<R>(
    mut reader: R,
    k: KmerLength,
    _format: SequenceFormat,
) -> Result<HashMap<u64, u64>, KmeRustError>
where
    R: BufRead,
{
    use std::io::Cursor;

    /// Wraps a parser failure in `KmeRustError::SequenceParse`.
    fn parse_failure<E: ToString>(err: E) -> KmeRustError {
        KmeRustError::SequenceParse {
            details: err.to_string(),
        }
    }

    let mut raw_input = Vec::new();
    if let Err(e) = reader.read_to_end(&mut raw_input) {
        return Err(KmeRustError::SequenceParse {
            details: format!("failed to read input: {e}"),
        });
    }

    let mut parser =
        needletail::parse_fastx_reader(Cursor::new(raw_input)).map_err(parse_failure)?;

    let mut tally: HashMap<u64, u64, BuildHasherDefault<FxHasher>> = HashMap::default();
    while let Some(parsed) = parser.next() {
        let record = parsed.map_err(parse_failure)?;
        process_sequence_into_counts(&mut tally, &record.seq(), None, k, None);
    }

    // Re-collect into a map with the default hasher, as required by the return type.
    Ok(tally.into_iter().collect())
}
/// Adds every countable k-mer of `seq` to `counts`.
///
/// The sequence is scanned with a sliding window of `k` bytes. Each window is packed with
/// `pack_canonical`; on success the resulting key's count is incremented and the window
/// advances by one. When packing fails, scanning resumes just past the `position` reported
/// by the error (presumably the offset of the offending byte within the window — confirm
/// against `pack_canonical`).
///
/// Quality filtering is active only when both `qual` and `min_quality` are `Some`. In that
/// case a window is skipped as soon as any of its quality bytes is below
/// `min_quality + 33` (saturating), and scanning resumes just past the first such byte.
/// If `qual` is shorter than `seq`, windows that extend beyond the available quality bytes
/// cannot be checked and are not counted; previously this case panicked on an
/// out-of-bounds slice.
///
/// Counts saturate at `u64::MAX` instead of overflowing, matching the increment used by
/// `StreamingKmerCounter`.
fn process_sequence_into_counts(
    counts: &mut HashMap<u64, u64, BuildHasherDefault<FxHasher>>,
    seq: &[u8],
    qual: Option<&[u8]>,
    k: KmerLength,
    min_quality: Option<u8>,
) {
    let k_val = k.get();
    if seq.len() < k_val {
        return;
    }
    // Safe: `seq.len() >= k_val` was established above.
    let last_start = seq.len() - k_val;

    // Pair the quality bytes with the raw-byte threshold once, outside the loop.
    let quality_filter = match (qual, min_quality) {
        (Some(scores), Some(min)) => Some((scores, min.saturating_add(33))),
        _ => None,
    };

    let mut i = 0;
    while i <= last_start {
        if let Some((scores, threshold)) = quality_filter {
            match scores.get(i..i + k_val) {
                // The quality data ends before this window does; every later window
                // would run past it as well, so stop here.
                None => break,
                Some(window) => {
                    if let Some(bad_pos) = window.iter().position(|&qv| qv < threshold) {
                        // No window containing the low-quality byte can pass.
                        i += bad_pos + 1;
                        continue;
                    }
                }
            }
        }

        match pack_canonical(&seq[i..i + k_val]) {
            Ok(canonical_bits) => {
                let slot = counts.entry(canonical_bits).or_insert(0);
                *slot = slot.saturating_add(1);
                i += 1;
            }
            Err(err) => {
                i += err.position + 1;
            }
        }
    }
}
struct SequentialKmerCounter {
counts: HashMap<u64, u64, BuildHasherDefault<FxHasher>>,
}
impl SequentialKmerCounter {
fn new() -> Self {
Self {
counts: HashMap::with_hasher(BuildHasherDefault::<FxHasher>::default()),
}
}
#[cfg(not(feature = "needletail"))]
fn count_file<P>(mut self, path: P, k: KmerLength) -> Result<HashMap<u64, u64>, KmeRustError>
where
P: AsRef<Path> + Debug,
{
use bio::io::{fasta, fastq};
let path_ref = path.as_ref();
let format = SequenceFormat::from_extension(path_ref);
#[cfg(feature = "tracing")]
let _span = info_span!("sequential_count", path = ?path_ref, ?format).entered();
#[cfg(feature = "gzip")]
let is_gzip = path_ref.extension().map(|ext| ext == "gz").unwrap_or(false);
#[cfg(feature = "gzip")]
if is_gzip {
use flate2::read::GzDecoder;
use std::{fs::File, io::BufReader};
let file = File::open(path_ref).map_err(|e| KmeRustError::SequenceRead {
source: e,
path: path_ref.to_path_buf(),
})?;
let decoder = GzDecoder::new(file);
let buf_reader = BufReader::new(decoder);
match format {
SequenceFormat::Fastq => {
let reader = fastq::Reader::new(buf_reader);
for result in reader.records() {
let record = result.map_err(|e| KmeRustError::SequenceParse {
details: e.to_string(),
})?;
self.process_sequence(record.seq(), None, k, None);
}
}
SequenceFormat::Fasta | SequenceFormat::Auto => {
let reader = fasta::Reader::new(buf_reader);
for result in reader.records() {
let record = result.map_err(|e| KmeRustError::SequenceParse {
details: e.to_string(),
})?;
self.process_sequence(record.seq(), None, k, None);
}
}
}
return Ok(self.counts.into_iter().collect());
}
match format {
SequenceFormat::Fastq => {
let reader =
fastq::Reader::from_file(path_ref).map_err(|e| KmeRustError::SequenceRead {
source: std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
path: path_ref.to_path_buf(),
})?;
for result in reader.records() {
let record = result.map_err(|e| KmeRustError::SequenceParse {
details: e.to_string(),
})?;
self.process_sequence(record.seq(), None, k, None);
}
}
SequenceFormat::Fasta | SequenceFormat::Auto => {
let reader =
fasta::Reader::from_file(path_ref).map_err(|e| KmeRustError::SequenceRead {
source: std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
path: path_ref.to_path_buf(),
})?;
for result in reader.records() {
let record = result.map_err(|e| KmeRustError::SequenceParse {
details: e.to_string(),
})?;
self.process_sequence(record.seq(), None, k, None);
}
}
}
Ok(self.counts.into_iter().collect())
}
#[cfg(feature = "needletail")]
fn count_file<P>(mut self, path: P, k: KmerLength) -> Result<HashMap<u64, u64>, KmeRustError>
where
P: AsRef<Path> + Debug,
{
let path_ref = path.as_ref();
#[cfg(feature = "tracing")]
let _span = info_span!("sequential_count", path = ?path_ref).entered();
let mut reader =
needletail::parse_fastx_file(path_ref).map_err(|e| KmeRustError::SequenceRead {
source: std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
path: path_ref.to_path_buf(),
})?;
while let Some(result) = reader.next() {
let record = result.map_err(|e| KmeRustError::SequenceParse {
details: e.to_string(),
})?;
self.process_sequence(&record.seq(), None, k, None);
}
Ok(self.counts.into_iter().collect())
}
fn process_sequence(
&mut self,
seq: &[u8],
qual: Option<&[u8]>,
k: KmerLength,
min_quality: Option<u8>,
) {
let k_val = k.get();
if seq.len() < k_val {
return;
}
let quality_threshold = min_quality.map(|q| q.saturating_add(33));
let mut i = 0;
while i <= seq.len() - k_val {
if let (Some(q), Some(threshold)) = (qual, quality_threshold) {
if let Some(bad_pos) = q[i..i + k_val].iter().position(|&qv| qv < threshold) {
i += bad_pos + 1; continue;
}
}
match pack_canonical(&seq[i..i + k_val]) {
Ok(canonical_bits) => {
*self.counts.entry(canonical_bits).or_insert(0) += 1;
i += 1;
}
Err(err) => {
i += err.position + 1;
}
}
}
}
}
/// K-mer counter whose table is a `DashMap`, so sequences can be processed through `&self`
/// from rayon's parallel iterators.
struct StreamingKmerCounter {
    /// Number of occurrences recorded so far for each packed k-mer.
    counts: DashMap<u64, u64, BuildHasherDefault<FxHasher>>,
}

impl StreamingKmerCounter {
    /// Creates a counter with an empty table.
    fn new() -> Self {
        Self {
            counts: DashMap::with_hasher(BuildHasherDefault::<FxHasher>::default()),
        }
    }

    /// Wraps a record-level parser failure in `KmeRustError::SequenceParse`.
    fn parse_error<E: ToString>(err: E) -> KmeRustError {
        KmeRustError::SequenceParse {
            details: err.to_string(),
        }
    }

    /// Wraps a parser's failure to open `path` in `KmeRustError::SequenceRead`.
    ///
    /// The original error is flattened to its message and carried as an
    /// `io::ErrorKind::InvalidData` error.
    fn open_error<E: ToString>(path: &Path, err: E) -> KmeRustError {
        KmeRustError::SequenceRead {
            source: std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string()),
            path: path.to_path_buf(),
        }
    }

    /// Copies the sequence bytes of every parsed record into an owned `Bytes` buffer.
    ///
    /// `seq_of` extracts the sequence slice from a record. The first record that fails to
    /// parse aborts the collection with `KmeRustError::SequenceParse`.
    #[cfg(not(feature = "needletail"))]
    fn collect_sequences<I, T, E, F>(records: I, seq_of: F) -> Result<Vec<Bytes>, KmeRustError>
    where
        I: Iterator<Item = Result<T, E>>,
        E: ToString,
        F: Fn(&T) -> &[u8],
    {
        records
            .map(|parsed| match parsed {
                Ok(record) => Ok(Bytes::copy_from_slice(seq_of(&record))),
                Err(err) => Err(Self::parse_error(err)),
            })
            .collect()
    }

    /// Opens `path` and wraps it in a buffered gzip decoder.
    ///
    /// The file is opened directly, so the real I/O error is kept as the error source.
    #[cfg(all(not(feature = "needletail"), feature = "gzip"))]
    fn open_gzip(
        path: &Path,
    ) -> Result<std::io::BufReader<flate2::read::GzDecoder<std::fs::File>>, KmeRustError> {
        let file = std::fs::File::open(path).map_err(|e| KmeRustError::SequenceRead {
            source: e,
            path: path.to_path_buf(),
        })?;
        Ok(std::io::BufReader::new(flate2::read::GzDecoder::new(file)))
    }

    /// Counts the k-mers of all `sequences` in parallel and returns the final table.
    ///
    /// No quality filtering is applied.
    fn count_parallel(self, sequences: &[Bytes], k: KmerLength) -> HashMap<u64, u64> {
        // Underscore-prefixed: the guard is held only to keep the span entered.
        #[cfg(feature = "tracing")]
        let _process_span = info_span!("process_sequences", count = sequences.len()).entered();

        sequences.par_iter().for_each(|seq| {
            self.process_sequence(seq, None, k, None);
        });
        self.counts.into_iter().collect()
    }

    /// Counts the k-mers of the file at `path` using the `bio` readers (no gzip support).
    ///
    /// All sequences are first loaded into memory and then counted in parallel. The format
    /// comes from `SequenceFormat::from_extension`; `Fasta` and `Auto` both use the FASTA
    /// reader.
    ///
    /// # Errors
    ///
    /// Returns `KmeRustError::SequenceRead` when the file cannot be opened and
    /// `KmeRustError::SequenceParse` when a record cannot be parsed.
    #[cfg(all(not(feature = "needletail"), not(feature = "gzip")))]
    fn count_file<P>(self, path: P, k: KmerLength) -> Result<HashMap<u64, u64>, KmeRustError>
    where
        P: AsRef<Path> + Debug,
    {
        use bio::io::{fasta, fastq};

        let path_ref = path.as_ref();
        let format = SequenceFormat::from_extension(path_ref);

        #[cfg(feature = "tracing")]
        let read_span = info_span!("read_sequences", path = ?path_ref, ?format).entered();

        let sequences: Vec<Bytes> = match format {
            SequenceFormat::Fastq => {
                let reader = fastq::Reader::from_file(path_ref)
                    .map_err(|e| Self::open_error(path_ref, e))?;
                Self::collect_sequences(reader.records(), fastq::Record::seq)?
            }
            SequenceFormat::Fasta | SequenceFormat::Auto => {
                let reader = fasta::Reader::from_file(path_ref)
                    .map_err(|e| Self::open_error(path_ref, e))?;
                Self::collect_sequences(reader.records(), fasta::Record::seq)?
            }
        };

        #[cfg(feature = "tracing")]
        {
            drop(read_span);
            debug!(sequences = sequences.len(), "Read sequences from file");
        }

        Ok(self.count_parallel(&sequences, k))
    }

    /// Counts the k-mers of the file at `path` using the `bio` readers, with gzip support.
    ///
    /// A path whose final extension is `gz` is decompressed on the fly. All sequences are
    /// first loaded into memory and then counted in parallel. The format comes from
    /// `SequenceFormat::from_extension`; `Fasta` and `Auto` both use the FASTA reader.
    ///
    /// # Errors
    ///
    /// Returns `KmeRustError::SequenceRead` when the file cannot be opened and
    /// `KmeRustError::SequenceParse` when a record cannot be parsed.
    #[cfg(all(not(feature = "needletail"), feature = "gzip"))]
    fn count_file<P>(self, path: P, k: KmerLength) -> Result<HashMap<u64, u64>, KmeRustError>
    where
        P: AsRef<Path> + Debug,
    {
        use bio::io::{fasta, fastq};

        let path_ref = path.as_ref();
        let format = SequenceFormat::from_extension(path_ref);
        let is_gzip = path_ref.extension().map(|ext| ext == "gz").unwrap_or(false);

        #[cfg(feature = "tracing")]
        let read_span = info_span!("read_sequences", path = ?path_ref, ?format).entered();

        let sequences: Vec<Bytes> = match (format, is_gzip) {
            (SequenceFormat::Fastq, true) => {
                let reader = fastq::Reader::new(Self::open_gzip(path_ref)?);
                Self::collect_sequences(reader.records(), fastq::Record::seq)?
            }
            (SequenceFormat::Fastq, false) => {
                let reader = fastq::Reader::from_file(path_ref)
                    .map_err(|e| Self::open_error(path_ref, e))?;
                Self::collect_sequences(reader.records(), fastq::Record::seq)?
            }
            (SequenceFormat::Fasta | SequenceFormat::Auto, true) => {
                let reader = fasta::Reader::new(Self::open_gzip(path_ref)?);
                Self::collect_sequences(reader.records(), fasta::Record::seq)?
            }
            (SequenceFormat::Fasta | SequenceFormat::Auto, false) => {
                let reader = fasta::Reader::from_file(path_ref)
                    .map_err(|e| Self::open_error(path_ref, e))?;
                Self::collect_sequences(reader.records(), fasta::Record::seq)?
            }
        };

        #[cfg(feature = "tracing")]
        {
            drop(read_span);
            debug!(sequences = sequences.len(), "Read sequences from file");
        }

        Ok(self.count_parallel(&sequences, k))
    }

    /// Counts the k-mers of the file at `path` using `needletail::parse_fastx_file`.
    ///
    /// All sequences are first loaded into memory and then counted in parallel.
    ///
    /// # Errors
    ///
    /// Returns `KmeRustError::SequenceRead` when the parser cannot be created for `path` and
    /// `KmeRustError::SequenceParse` when a record cannot be parsed.
    #[cfg(feature = "needletail")]
    fn count_file<P>(self, path: P, k: KmerLength) -> Result<HashMap<u64, u64>, KmeRustError>
    where
        P: AsRef<Path> + Debug,
    {
        let path_ref = path.as_ref();

        #[cfg(feature = "tracing")]
        let read_span = info_span!("read_fasta", path = ?path_ref).entered();

        let mut reader =
            needletail::parse_fastx_file(path_ref).map_err(|e| Self::open_error(path_ref, e))?;
        let mut sequences = Vec::new();
        while let Some(record) = reader.next() {
            let record = record.map_err(Self::parse_error)?;
            sequences.push(Bytes::copy_from_slice(&record.seq()));
        }

        #[cfg(feature = "tracing")]
        {
            drop(read_span);
            debug!(sequences = sequences.len(), "Read sequences from file");
        }

        Ok(self.count_parallel(&sequences, k))
    }

    /// Counts the k-mers of every sequence yielded by `sequences`, one sequence at a time.
    ///
    /// No quality filtering is applied.
    fn count_sequences<I>(self, sequences: I, k: KmerLength) -> HashMap<u64, u64>
    where
        I: Iterator<Item = Bytes>,
    {
        for seq in sequences {
            self.process_sequence(&seq, None, k, None);
        }
        self.counts.into_iter().collect()
    }

    /// Adds every countable k-mer of `seq` to the shared table.
    ///
    /// The sequence is scanned with a sliding window of `k` bytes. Each window is packed
    /// with `pack_canonical`; on success the key's count is incremented (saturating) and the
    /// window advances by one. When packing fails, scanning resumes just past the `position`
    /// reported by the error.
    ///
    /// Quality filtering is active only when both `qual` and `min_quality` are `Some`: a
    /// window is skipped as soon as any of its quality bytes is below `min_quality + 33`
    /// (saturating). If `qual` is shorter than `seq`, windows extending beyond the available
    /// quality bytes cannot be checked and are not counted; previously this case panicked on
    /// an out-of-bounds slice.
    fn process_sequence(
        &self,
        seq: &Bytes,
        qual: Option<&[u8]>,
        k: KmerLength,
        min_quality: Option<u8>,
    ) {
        let k_val = k.get();
        if seq.len() < k_val {
            return;
        }
        // Safe: `seq.len() >= k_val` was established above.
        let last_start = seq.len() - k_val;

        // Pair the quality bytes with the raw-byte threshold once, outside the loop.
        let quality_filter = match (qual, min_quality) {
            (Some(scores), Some(min)) => Some((scores, min.saturating_add(33))),
            _ => None,
        };

        let mut i = 0;
        while i <= last_start {
            if let Some((scores, threshold)) = quality_filter {
                match scores.get(i..i + k_val) {
                    // The quality data ends before this window does; every later window
                    // would run past it as well, so stop here.
                    None => break,
                    Some(window) => {
                        if let Some(bad_pos) = window.iter().position(|&qv| qv < threshold) {
                            // No window containing the low-quality byte can pass.
                            i += bad_pos + 1;
                            continue;
                        }
                    }
                }
            }

            match pack_canonical(&seq[i..i + k_val]) {
                Ok(canonical_bits) => {
                    self.counts
                        .entry(canonical_bits)
                        .and_modify(|c| *c = c.saturating_add(1))
                        .or_insert(1);
                    i += 1;
                }
                Err(err) => {
                    i += err.position + 1;
                }
            }
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Table type accepted by `process_sequence_into_counts`.
    type FxCounts = HashMap<u64, u64, BuildHasherDefault<FxHasher>>;

    /// Sequence shared by the quality-filtering tests.
    const SEQ: &[u8] = b"ACGTACGT";
    /// Quality bytes shared by the quality-filtering tests: four `I` (73) then four `!` (33).
    const QUAL: &[u8] = b"IIII!!!!";

    /// The k-mer length used by every test in this module.
    fn k4() -> KmerLength {
        KmerLength::new(4).unwrap()
    }

    /// Runs `count_kmers_from_sequences` over the given static byte strings with k = 4.
    fn count_all(seqs: &[&'static [u8]]) -> HashMap<u64, u64> {
        count_kmers_from_sequences(seqs.iter().copied().map(Bytes::from_static), k4())
    }

    /// Runs `process_sequence_into_counts` over `SEQ` with k = 4 and the given filter inputs.
    fn count_with_quality(qual: Option<&[u8]>, min_quality: Option<u8>) -> FxCounts {
        let mut counts = FxCounts::default();
        process_sequence_into_counts(&mut counts, SEQ, qual, k4(), min_quality);
        counts
    }

    #[test]
    fn count_from_sequences_basic() {
        let counts = count_all(&[b"ACGTACGT"]);
        assert!(!counts.is_empty());
    }

    #[test]
    fn count_from_sequences_empty() {
        let counts = count_all(&[]);
        assert!(counts.is_empty());
    }

    #[test]
    fn count_from_sequences_short_sequence() {
        // Three bases cannot hold a single 4-mer.
        let counts = count_all(&[b"ACG"]);
        assert!(counts.is_empty());
    }

    #[test]
    fn count_from_sequences_multiple() {
        // Both sequences are expected to land on the same canonical key.
        let counts = count_all(&[b"AAAA", b"TTTT"]);
        assert_eq!(counts.len(), 1);
        let count = counts.values().next().unwrap();
        assert_eq!(*count, 2);
    }

    #[test]
    fn quality_filtering_skips_low_quality_kmers() {
        // Threshold byte is 20 + 33 = 53, so only the first window (all `I`) passes.
        let counts = count_with_quality(Some(QUAL), Some(20u8));
        assert_eq!(counts.len(), 1);
        let count = counts.values().next().unwrap();
        assert_eq!(*count, 1);
    }

    #[test]
    fn quality_filtering_with_no_threshold_counts_all() {
        let counts = count_with_quality(Some(QUAL), None);
        assert!(!counts.is_empty());
    }

    #[test]
    fn quality_filtering_with_zero_threshold_counts_all() {
        let counts = count_with_quality(Some(QUAL), Some(0));
        assert!(!counts.is_empty());
    }

    #[test]
    fn quality_filtering_without_quality_data_counts_all() {
        let counts = count_with_quality(None, Some(20));
        assert!(!counts.is_empty());
    }
}