use std::{collections::HashMap, fmt::Debug, path::Path};
use tokio::task;
use crate::{
error::KmeRustError,
kmer::{unpack_to_string, KmerLength},
streaming::count_kmers_streaming_packed,
};
/// Counts the k-mers of the file at `path` without blocking the async executor.
///
/// `k` is validated with [`KmerLength::new`] before any work is scheduled. The
/// streaming count and the conversion of each packed k-mer into its string
/// form both run inside a single [`task::spawn_blocking`] closure, so the
/// potentially large CPU-bound decoding step never runs on an executor thread.
///
/// # Errors
///
/// Returns an error if `k` is rejected by [`KmerLength::new`], if the blocking
/// task cannot be joined, or if `count_kmers_streaming_packed` fails.
pub async fn count_kmers_async<P>(
    path: P,
    k: usize,
) -> Result<HashMap<String, u64>, Box<dyn std::error::Error + Send + Sync>>
where
    P: AsRef<Path> + Debug + Send + 'static,
{
    let k_len = KmerLength::new(k)?;
    let counts = task::spawn_blocking(move || {
        // Decode while still on the blocking pool: unpacking every k-mer into a
        // `String` is CPU- and allocation-heavy for large inputs.
        count_kmers_streaming_packed(path, k_len).map(|packed| {
            packed
                .into_iter()
                .map(|(bits, count)| (unpack_to_string(bits, k_len), count))
                .collect::<HashMap<String, u64>>()
        })
    })
    // First `?` surfaces a join failure, second `?` the counting error.
    .await??;
    Ok(counts)
}
/// Counts the k-mers of the file at `path`, returning them in packed form.
///
/// The call to `count_kmers_streaming_packed` is executed through
/// [`task::spawn_blocking`]; this function only awaits its completion.
///
/// # Errors
///
/// Returns an error if the blocking task cannot be joined or if
/// `count_kmers_streaming_packed` itself fails.
pub async fn count_kmers_packed_async<P>(
    path: P,
    k: KmerLength,
) -> Result<HashMap<u64, u64>, Box<dyn std::error::Error + Send + Sync>>
where
    P: AsRef<Path> + Debug + Send + 'static,
{
    let handle = task::spawn_blocking(move || count_kmers_streaming_packed(path, k));
    // Outer layer: did the task join? Inner layer: did counting succeed?
    let outcome = handle.await?;
    let packed_counts = outcome?;
    Ok(packed_counts)
}
/// Builder-style configuration for counting k-mers from async code.
///
/// Holds an optional validated k-mer length and a minimum-count threshold;
/// the counting methods read both when they run.
#[derive(Debug, Clone)]
pub struct AsyncKmerCounter {
/// Validated k-mer length; `None` until one of the `k` setters is called.
k: Option<KmerLength>,
/// Threshold used by the counting methods to drop entries whose count is
/// lower; the constructor sets it to 1.
min_count: u64,
}
impl Default for AsyncKmerCounter {
fn default() -> Self {
Self::new()
}
}
impl AsyncKmerCounter {
#[must_use]
pub const fn new() -> Self {
Self {
k: None,
min_count: 1,
}
}
pub fn k(mut self, k: usize) -> Result<Self, crate::error::KmerLengthError> {
self.k = Some(KmerLength::new(k)?);
Ok(self)
}
#[must_use]
pub const fn k_validated(mut self, k: KmerLength) -> Self {
self.k = Some(k);
self
}
#[must_use]
pub const fn min_count(mut self, min_count: u64) -> Self {
self.min_count = min_count;
self
}
pub async fn count<P>(
&self,
path: P,
) -> Result<HashMap<String, u64>, Box<dyn std::error::Error + Send + Sync>>
where
P: AsRef<Path> + Debug + Send + 'static,
{
let k = self.k.ok_or(KmeRustError::InvalidKmerLength {
k: 0,
min: 1,
max: 32,
})?;
let min_count = self.min_count;
let counts = count_kmers_async(path, k.get()).await?;
if min_count > 1 {
Ok(counts
.into_iter()
.filter(|(_, count)| *count >= min_count)
.collect())
} else {
Ok(counts)
}
}
pub async fn count_packed<P>(
&self,
path: P,
) -> Result<HashMap<u64, u64>, Box<dyn std::error::Error + Send + Sync>>
where
P: AsRef<Path> + Debug + Send + 'static,
{
let k = self.k.ok_or(KmeRustError::InvalidKmerLength {
k: 0,
min: 1,
max: 32,
})?;
let min_count = self.min_count;
let counts = count_kmers_packed_async(path, k).await?;
if min_count > 1 {
Ok(counts
.into_iter()
.filter(|(_, count)| *count >= min_count)
.collect())
} else {
Ok(counts)
}
}
#[must_use]
pub const fn get_k(&self) -> Option<KmerLength> {
self.k
}
#[must_use]
pub const fn get_min_count(&self) -> u64 {
self.min_count
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A freshly built counter has no k-mer length and a minimum count of 1.
    #[test]
    fn async_counter_default() {
        let fresh = AsyncKmerCounter::new();
        assert!(fresh.get_k().is_none());
        assert_eq!(fresh.get_min_count(), 1);
    }

    /// An accepted `k` is stored and can be read back.
    #[test]
    fn async_counter_k_valid() {
        let configured = AsyncKmerCounter::new().k(21).unwrap();
        let stored = configured.get_k().unwrap();
        assert_eq!(stored.get(), 21);
    }

    /// Lengths of 0 and 33 are both rejected by the `k` setter.
    #[test]
    fn async_counter_k_invalid() {
        for bad_k in [0, 33] {
            let attempt = AsyncKmerCounter::new().k(bad_k);
            assert!(attempt.is_err(), "k = {bad_k} should be rejected");
        }
    }

    /// Builder calls chain, and each setting survives the later ones.
    #[test]
    fn async_counter_chained() {
        let configured = AsyncKmerCounter::new().k(21).unwrap().min_count(5);
        let stored = configured.get_k().unwrap();
        assert_eq!(stored.get(), 21);
        assert_eq!(configured.get_min_count(), 5);
    }
}