use itertools::{iproduct, Itertools};
use ndarray::Array;
use rayon::prelude::*;
use std::ffi::OsStr;
use std::fs::File;
use std::io::{stdout, BufReader, BufWriter, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use anyhow::{Context, Result};
use clap::Parser;
use log::info;
use log::LevelFilter;
use noodles_fasta as fasta;
use psdm::{hamming_distance, ToTable, Transformer};
fn path_exists<S: AsRef<OsStr> + ?Sized>(s: &S) -> Result<PathBuf, String> {
let path = PathBuf::from(s);
if path.exists() {
Ok(path)
} else {
Err(format!("{:?} does not exist", path))
}
}
fn parse_delim(s: &str) -> Result<char, String> {
let strip = &['\'', '"', ' '][..];
let stripped = s.replace(strip, "").replace("\\\\", "\\");
let chars: Vec<char> = stripped.chars().collect();
if chars.len() > 1 {
if chars[0] == '\\' && chars[1] == 't' {
Ok('\t')
} else {
Err(format!("Too many delimiter characters given {:?}", chars))
}
} else {
Ok(chars[0])
}
}
#[derive(Parser, Debug)]
#[clap(author, version, about, verbatim_doc_comment)]
struct Opt {
#[clap(min_values = 1, max_values = 2, parse(try_from_os_str = path_exists))]
alignments: Vec<PathBuf>,
#[clap(short, long, parse(from_os_str))]
output: Option<PathBuf>,
#[clap(short, long, default_value = "1")]
threads: usize,
#[clap(short, long = "long")]
long_form: bool,
#[clap(short, long = "delim", default_value = ",", parse(try_from_str=parse_delim))]
delimiter: char,
#[clap(short = 'P', long = "progress")]
show_progress: bool,
#[clap(short, long)]
quiet: bool,
#[clap(flatten)]
transformer: Transformer,
}
fn main() -> Result<()> {
let opts = Opt::parse();
let log_lvl = if opts.quiet {
LevelFilter::Error
} else {
LevelFilter::Info
};
let mut log_builder = env_logger::builder();
log_builder
.filter(None, log_lvl)
.format_module_path(false)
.init();
rayon::ThreadPoolBuilder::new()
.num_threads(opts.threads)
.build_global()?;
info!("Using {} thread(s)", rayon::current_num_threads());
let mut ostream: Box<dyn Write> = match opts.output {
None => Box::new(stdout()),
Some(ref p) => {
let file = File::create(p).context("Failed to create output file")?;
Box::new(BufWriter::new(file))
}
};
let mut reader1 = niffler::from_path(&opts.alignments[0])
.map(|(r, _)| BufReader::new(r))
.map(fasta::Reader::new)
.context("Could not open first alignment file")?;
info!("Loading first alignment file...");
let (names1, seqs1) = opts
.transformer
.load_alignment(&mut reader1, 0)
.context("Failed to load first alignment file")?;
info!(
"Loaded {} sequences with length {}bp",
seqs1.len(),
seqs1[0].len()
);
let (names2, seqs2) = match opts.alignments.get(1) {
Some(p) => {
let mut reader2 = niffler::from_path(p)
.map(|(r, _)| BufReader::new(r))
.map(fasta::Reader::new)
.context("Could not open second alignment file")?;
info!("Loading second alignment file...");
let (n, s) = opts
.transformer
.load_alignment(&mut reader2, seqs1[0].len())
.context("Failed to load second alignment file")?;
info!("Loaded {} sequences with length {}bp", s.len(), s[0].len());
(Some(n), Some(s))
}
None => (None, None),
};
let n_seqs1 = seqs1.len();
let n_seqs2: usize = match seqs2 {
None => 0,
Some(ref s) => s.len(),
};
let pairwise_indices: Vec<Vec<usize>> = match n_seqs2 {
0 => (0..n_seqs1).combinations_with_replacement(2).collect(),
i => iproduct!(0..n_seqs1, 0..i)
.map(|t| vec![t.0, t.1])
.collect(),
};
let num_items = pairwise_indices.len();
let counter = Arc::new(AtomicUsize::new(0));
let progress_interval = std::cmp::min((num_items as f64 / 100.0).ceil() as usize, 100);
info!("Calculating {num_items} pairwise distances...",);
let dists: Vec<u64> = pairwise_indices
.as_slice()
.into_par_iter()
.map_with(Arc::clone(&counter), |counter, ix| {
let i = ix[0];
let j = ix[1];
let distance = match &seqs2 {
None if i == j => 0, None => hamming_distance(&seqs1[i], &seqs1[j]),
Some(ref s) => hamming_distance(&seqs1[i], &s[j]),
};
let current_count = counter.fetch_add(1, Ordering::SeqCst) + 1;
if opts.show_progress && current_count % progress_interval == 0 {
let progress = (current_count as f64 / num_items as f64) * 100.0;
eprint!(
"\rProgress: {:.2}% ({} / {})",
progress, current_count, num_items
);
match std::io::stderr().flush() {
Ok(_) => (),
Err(e) => eprintln!("Error occurred when flushing stderr: {:?}", e),
}
}
distance
})
.collect();
if opts.show_progress {
eprintln!();
}
let matrix =
if n_seqs2 > 0 {
Array::from_shape_vec((n_seqs1, n_seqs2), dists).context(
"Failed to create matrix. This shouldn't happen, please raise an issue on GitHub",
)?.t().to_owned()
} else {
let mut mtx = Array::zeros((n_seqs1, n_seqs1));
for (ix, d) in pairwise_indices.iter().zip(dists) {
let i = ix[0];
let j = ix[1];
mtx[[i, j]] = d;
if i != j {
mtx[[j, i]] = d;
}
}
mtx
};
info!("Finished computing distances");
let row_names: &Vec<Vec<u8>> = match &names2 {
Some(n) => n,
None => &names1,
};
let col_names: &Vec<Vec<u8>> = &names1;
if opts.long_form {
info!("Writing long-form table...");
matrix
.to_long(&mut ostream, opts.delimiter, col_names, row_names)
.context("Failed to write output table")?;
} else {
info!("Writing matrix...");
matrix
.to_csv(&mut ostream, opts.delimiter, col_names, row_names)
.context("Failed to write output table")?;
}
info!("Done!");
Ok(())
}
#[cfg(test)]
mod tests {
use std::ffi::OsStr;
use super::*;
#[test]
fn check_path_exists_it_doesnt() {
let result = path_exists(OsStr::new("fake.path"));
assert!(result.is_err())
}
#[test]
fn check_path_it_does() {
let actual = path_exists(OsStr::new("Cargo.toml")).unwrap();
let expected = PathBuf::from("Cargo.toml");
assert_eq!(actual, expected)
}
}