use std::{collections::HashMap, fmt::Debug, io::Write, path::Path};
use crate::{
cli::OutputFormat,
error::{BuilderError, KmerLengthError},
format::SequenceFormat,
kmer::KmerLength,
progress::Progress,
run::{count_kmers_with_format, count_kmers_with_progress, run_with_options},
};
/// Builder-style configuration for counting k-mers in a sequence file.
///
/// Construct with [`KmerCounter::new`] (or `Default`), chain the setter
/// methods, then call one of the counting/output methods.
#[derive(Debug, Clone)]
pub struct KmerCounter {
    /// k-mer length; `None` until set. Every counting method fails with
    /// `BuilderError::KmerLengthNotSet` while this is `None`.
    k: Option<KmerLength>,
    /// Minimum occurrence count a k-mer needs in order to be kept.
    min_count: u64,
    /// Output format used by `run` and `count_to_writer`.
    format: OutputFormat,
    /// Input sequence format forwarded by `count`; not every counting
    /// method forwards it.
    input_format: SequenceFormat,
}
impl Default for KmerCounter {
fn default() -> Self {
Self::new()
}
}
impl KmerCounter {
    /// Creates a counter with no k-mer length, a minimum count of 1,
    /// FASTA output and automatic input-format detection.
    ///
    /// A k-mer length must be set with [`KmerCounter::k`] or
    /// [`KmerCounter::k_validated`] before any counting method is called.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            k: None,
            min_count: 1,
            format: OutputFormat::Fasta,
            input_format: SequenceFormat::Auto,
        }
    }

    /// Sets the k-mer length, validating it with [`KmerLength::new`].
    ///
    /// # Errors
    ///
    /// Returns the [`KmerLengthError`] produced by [`KmerLength::new`] when
    /// `k` is not an accepted length.
    pub fn k(mut self, k: usize) -> Result<Self, KmerLengthError> {
        self.k = Some(KmerLength::new(k)?);
        Ok(self)
    }

    /// Sets an already-validated k-mer length (infallible counterpart of
    /// [`KmerCounter::k`]).
    #[must_use]
    pub const fn k_validated(mut self, k: KmerLength) -> Self {
        self.k = Some(k);
        self
    }

    /// Sets the minimum count a k-mer needs to be kept.
    ///
    /// The in-memory counting methods drop k-mers seen fewer than
    /// `min_count` times; values of 0 or 1 disable that filtering. The
    /// value is also passed to [`KmerCounter::run`].
    #[must_use]
    pub const fn min_count(mut self, min_count: u64) -> Self {
        self.min_count = min_count;
        self
    }

    /// Sets the output format used by [`KmerCounter::run`] and
    /// [`KmerCounter::count_to_writer`].
    #[must_use]
    pub const fn format(mut self, format: OutputFormat) -> Self {
        self.format = format;
        self
    }

    /// Sets the input sequence format.
    ///
    /// Only [`KmerCounter::count`] (and the methods built on it,
    /// [`KmerCounter::histogram`] and [`KmerCounter::count_to_writer`])
    /// forward this setting.
    #[must_use]
    pub const fn input_format(mut self, format: SequenceFormat) -> Self {
        self.input_format = format;
        self
    }

    /// Returns the configured k-mer length, or the "not set" error.
    fn require_k(&self) -> Result<KmerLength, BuilderError> {
        self.k.ok_or(BuilderError::KmerLengthNotSet)
    }

    /// Drops every k-mer whose count is below `min_count`.
    ///
    /// Filtering is skipped when `min_count <= 1`. After filtering the map is
    /// shrunk so that a heavily filtered result does not keep the (possibly
    /// much larger) allocation of the unfiltered table.
    fn apply_min_count(&self, mut counts: HashMap<String, u64>) -> HashMap<String, u64> {
        if self.min_count > 1 {
            counts.retain(|_, count| *count >= self.min_count);
            counts.shrink_to_fit();
        }
        counts
    }

    /// Counts the k-mers in the file at `path`, reading it as the configured
    /// input format, and drops k-mers below `min_count`.
    ///
    /// # Errors
    ///
    /// Returns [`BuilderError::KmerLengthNotSet`] if no k-mer length was set,
    /// or any error reported by the underlying counting routine (converted
    /// into [`BuilderError`]).
    pub fn count<P>(&self, path: P) -> Result<HashMap<String, u64>, BuilderError>
    where
        P: AsRef<Path> + Debug,
    {
        let k = self.require_k()?;
        let counts = count_kmers_with_format(&path, k.get(), self.input_format)?;
        Ok(self.apply_min_count(counts))
    }

    /// Counts k-mers as [`KmerCounter::count`] does and summarises the
    /// (min-count filtered) result as a histogram.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`KmerCounter::count`].
    pub fn histogram<P>(&self, path: P) -> Result<crate::histogram::KmerHistogram, BuilderError>
    where
        P: AsRef<Path> + Debug,
    {
        use crate::histogram::compute_histogram;
        let counts = self.count(&path)?;
        Ok(compute_histogram(&counts))
    }

    /// Counts k-mers while reporting [`Progress`] to `callback`, then drops
    /// k-mers below `min_count`.
    ///
    /// NOTE(review): the configured `input_format` is not forwarded by this
    /// method. The `Send + Sync + 'static` bounds suggest `callback` may be
    /// invoked from another thread — confirm against `count_kmers_with_progress`.
    ///
    /// # Errors
    ///
    /// Returns [`BuilderError::KmerLengthNotSet`] if no k-mer length was set,
    /// or any error reported by the underlying counting routine.
    pub fn count_with_progress<P, F>(
        &self,
        path: P,
        callback: F,
    ) -> Result<HashMap<String, u64>, BuilderError>
    where
        P: AsRef<Path> + Debug,
        F: Fn(Progress) + Send + Sync + 'static,
    {
        let k = self.require_k()?;
        let counts = count_kmers_with_progress(&path, k.get(), callback)?;
        Ok(self.apply_min_count(counts))
    }

    /// Runs the full pipeline via `run_with_options` with the configured
    /// k-mer length, output format and minimum count.
    ///
    /// NOTE(review): the configured `input_format` is not forwarded here.
    ///
    /// # Errors
    ///
    /// Returns [`BuilderError::KmerLengthNotSet`] if no k-mer length was set,
    /// or any error reported by `run_with_options`.
    pub fn run<P>(&self, path: P) -> Result<(), BuilderError>
    where
        P: AsRef<Path> + Debug,
    {
        let k = self.require_k()?;
        run_with_options(path, k.get(), self.format, self.min_count)?;
        Ok(())
    }

    /// Counts k-mers as [`KmerCounter::count`] does and writes the result to
    /// `writer` in the configured output format, flushing it at the end.
    ///
    /// For FASTA, TSV and JSON output the entries appear in `HashMap`
    /// iteration order, which is unspecified and may differ between runs.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`KmerCounter::count`], plus write,
    /// flush or JSON-serialisation failures converted into [`BuilderError`].
    pub fn count_to_writer<P, W>(&self, path: P, mut writer: W) -> Result<(), BuilderError>
    where
        P: AsRef<Path> + Debug,
        W: Write,
    {
        let counts = self.count(&path)?;
        match self.format {
            OutputFormat::Fasta => {
                // Header line carries the count, sequence line the k-mer.
                for (kmer, count) in counts {
                    writeln!(writer, ">{count}\n{kmer}")?;
                }
            }
            OutputFormat::Tsv => {
                for (kmer, count) in counts {
                    writeln!(writer, "{kmer}\t{count}")?;
                }
            }
            OutputFormat::Json => {
                #[derive(serde::Serialize)]
                struct KmerCount {
                    kmer: String,
                    count: u64,
                }
                let json_data: Vec<KmerCount> = counts
                    .into_iter()
                    .map(|(kmer, count)| KmerCount { kmer, count })
                    .collect();
                serde_json::to_writer_pretty(&mut writer, &json_data)?;
                // to_writer_pretty does not emit a trailing newline.
                writeln!(writer)?;
            }
            OutputFormat::Histogram => {
                use crate::histogram::compute_histogram;
                let histogram = compute_histogram(&counts);
                for (count, frequency) in histogram {
                    writeln!(writer, "{count}\t{frequency}")?;
                }
            }
        }
        writer.flush()?;
        Ok(())
    }

    /// Counts k-mers through `crate::run::count_kmers_mmap` (requires the
    /// `mmap` feature), then drops k-mers below `min_count`.
    ///
    /// NOTE(review): the configured `input_format` is not forwarded here.
    ///
    /// # Errors
    ///
    /// Returns [`BuilderError::KmerLengthNotSet`] if no k-mer length was set,
    /// or any error reported by the underlying counting routine.
    #[cfg(feature = "mmap")]
    pub fn count_mmap<P>(&self, path: P) -> Result<HashMap<String, u64>, BuilderError>
    where
        P: AsRef<Path> + Debug,
    {
        use crate::run::count_kmers_mmap;
        let k = self.require_k()?;
        let counts = count_kmers_mmap(&path, k.get())?;
        Ok(self.apply_min_count(counts))
    }

    /// Counts k-mers through `crate::streaming::count_kmers_streaming`, then
    /// drops k-mers below `min_count`.
    ///
    /// NOTE(review): the configured `input_format` is not forwarded here.
    ///
    /// # Errors
    ///
    /// Returns [`BuilderError::KmerLengthNotSet`] if no k-mer length was set,
    /// or any error reported by the underlying counting routine.
    pub fn count_streaming<P>(&self, path: P) -> Result<HashMap<String, u64>, BuilderError>
    where
        P: AsRef<Path> + Debug,
    {
        use crate::streaming::count_kmers_streaming;
        let k = self.require_k()?;
        let counts = count_kmers_streaming(&path, k.get())?;
        Ok(self.apply_min_count(counts))
    }

    /// Returns the configured k-mer length, if any.
    #[must_use]
    pub const fn get_k(&self) -> Option<KmerLength> {
        self.k
    }

    /// Returns the configured minimum count.
    #[must_use]
    pub const fn get_min_count(&self) -> u64 {
        self.min_count
    }

    /// Returns the configured output format.
    #[must_use]
    pub const fn get_format(&self) -> OutputFormat {
        self.format
    }

    /// Returns the configured input sequence format.
    #[must_use]
    pub const fn get_input_format(&self) -> SequenceFormat {
        self.input_format
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap keeps test setup terse; a panic is a test failure anyway
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn builder_default() {
        let counter = KmerCounter::new();
        assert!(counter.get_k().is_none());
        assert_eq!(counter.get_min_count(), 1);
        assert!(matches!(counter.get_format(), OutputFormat::Fasta));
        assert!(matches!(counter.get_input_format(), SequenceFormat::Auto));
    }

    #[test]
    fn builder_default_trait_matches_new() {
        let counter = KmerCounter::default();
        assert!(counter.get_k().is_none());
        assert_eq!(counter.get_min_count(), 1);
        assert!(matches!(counter.get_format(), OutputFormat::Fasta));
        assert!(matches!(counter.get_input_format(), SequenceFormat::Auto));
    }

    #[test]
    fn builder_k_valid() {
        let counter = KmerCounter::new().k(21).unwrap();
        assert_eq!(counter.get_k().unwrap().get(), 21);
    }

    #[test]
    fn builder_k_invalid() {
        let result = KmerCounter::new().k(0);
        assert!(result.is_err());
        let result = KmerCounter::new().k(33);
        assert!(result.is_err());
    }

    #[test]
    fn builder_k_validated() {
        let k = KmerLength::new(21).unwrap();
        let counter = KmerCounter::new().k_validated(k);
        assert_eq!(counter.get_k().unwrap().get(), 21);
    }

    #[test]
    fn builder_min_count() {
        let counter = KmerCounter::new().min_count(5);
        assert_eq!(counter.get_min_count(), 5);
    }

    #[test]
    fn builder_format() {
        let counter = KmerCounter::new().format(OutputFormat::Tsv);
        assert!(matches!(counter.get_format(), OutputFormat::Tsv));
    }

    #[test]
    fn builder_chained() {
        let counter = KmerCounter::new()
            .k(21)
            .unwrap()
            .min_count(3)
            .format(OutputFormat::Json);
        assert_eq!(counter.get_k().unwrap().get(), 21);
        assert_eq!(counter.get_min_count(), 3);
        assert!(matches!(counter.get_format(), OutputFormat::Json));
    }

    #[test]
    fn builder_count_without_k_fails() {
        let counter = KmerCounter::new();
        let result = counter.count("nonexistent.fa");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("k-mer length not set"));
    }

    #[test]
    fn builder_count_to_writer_without_k_fails_and_writes_nothing() {
        let counter = KmerCounter::new().format(OutputFormat::Tsv);
        let mut output: Vec<u8> = Vec::new();
        let result = counter.count_to_writer("nonexistent.fa", &mut output);
        assert!(matches!(result, Err(BuilderError::KmerLengthNotSet)));
        assert!(output.is_empty());
    }

    /// Checks the TSV line shape (`kmer<TAB>count`) used by `count_to_writer`.
    #[test]
    fn builder_write_tsv_format() {
        let counts: HashMap<String, u64> = [("ACGT".to_string(), 5), ("TGCA".to_string(), 3)]
            .into_iter()
            .collect();
        let counter = KmerCounter::new().k(4).unwrap().format(OutputFormat::Tsv);
        assert!(matches!(counter.get_format(), OutputFormat::Tsv));
        let mut output = Cursor::new(Vec::new());
        for (kmer, count) in &counts {
            writeln!(output, "{kmer}\t{count}").unwrap();
        }
        let result = String::from_utf8(output.into_inner()).unwrap();
        // HashMap iteration order is unspecified, so compare the sorted lines;
        // both entries must be present (the old `||` check accepted either).
        let mut lines: Vec<&str> = result.lines().collect();
        lines.sort_unstable();
        assert_eq!(lines, ["ACGT\t5", "TGCA\t3"]);
    }
}