use anyhow::{anyhow, Context, Result};
use clap::Parser;
use itertools::iproduct;
use ndarray::{ArrayBase, Ix2, OwnedRepr};
use noodles_fasta as fasta;
use std::collections::HashSet;
use std::fmt::Write as _;
use std::io::{BufRead, Error, Write};
use std::iter::FromIterator;
/// Placeholder byte written over every ignored character of a sequence.
/// A position where either sequence holds this byte never counts as a
/// difference (see `dist`).
const IGNORE: u8 = b'.';
/// Sorting helpers that let one collection be reordered to follow the sorted
/// order of another.
trait SortExt<T> {
/// Returns the indices of the elements in ascending order of the elements,
/// i.e. `result[0]` is the position of the smallest element.
fn argsort(&self) -> Vec<usize>;
/// Reorders `self` in place so that the element at position `i` is the
/// element that was previously at position `indices[i]`.
///
/// `indices` is used as scratch space: the implementation in this file
/// overwrites its entries with `usize::MAX` while it works.
fn sort_by_indices(&mut self, indices: &mut Vec<usize>);
}
impl<T: Ord + Clone> SortExt<T> for Vec<T> {
    /// Returns the positions of the elements in ascending element order.
    ///
    /// The underlying sort is stable, so equal elements keep their relative
    /// input order.
    fn argsort(&self) -> Vec<usize> {
        let mut order: Vec<usize> = (0..self.len()).collect();
        order.sort_by(|&lhs, &rhs| self[lhs].cmp(&self[rhs]));
        order
    }

    /// Applies the permutation `indices` to `self` in place: afterwards the
    /// element at position `i` is the one that used to live at `indices[i]`.
    ///
    /// Each permutation cycle is walked once and resolved with swaps.
    /// `indices` doubles as the "already placed" marker, so every entry that
    /// is visited is overwritten with `usize::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if an index that gets visited is out of bounds.
    fn sort_by_indices(&mut self, indices: &mut Vec<usize>) {
        // Sentinel marking a slot whose final element is already in place.
        const PLACED: usize = usize::MAX;

        for start in 0..self.len() {
            if indices[start] == PLACED {
                continue;
            }
            let mut pos = start;
            loop {
                // Read where this slot's element comes from and mark the slot done.
                let source = std::mem::replace(&mut indices[pos], PLACED);
                if indices[source] == PLACED {
                    // The cycle has closed; everything on it is in place.
                    break;
                }
                self.swap(pos, source);
                pos = source;
            }
        }
    }
}
/// Collects the distinct bytes of `s` into a set.
///
/// Used as the clap value parser for the ignored-characters option. The
/// string is treated as raw bytes, so a multi-byte UTF-8 character
/// contributes each of its bytes separately.
fn parse_ignored_chars(s: &str) -> HashSet<u8> {
    // Iterate the bytes directly; no intermediate Vec is needed.
    s.bytes().collect()
}
// Options controlling how alignment sequences are loaded and normalised.
// NOTE(review): plain `//` comments are used here on purpose; clap's derive is
// documented to reuse `///` doc comments as CLI help text, which would change
// the program's output. Confirm before converting these to doc comments.
#[derive(Parser, Debug, Default)]
pub struct Transformer {
// When false, `transform` upper-cases every sequence byte so that comparison
// ignores ASCII case.
#[clap(short, long)]
case_sensitive: bool,
// When true, `load_alignment` returns the records ordered by name.
#[clap(short, long)]
sort: bool,
// Bytes that `transform` replaces with `IGNORE`. On the command line the value
// is parsed by `parse_ignored_chars` and defaults to "N-"; note that the derived
// `Default` impl yields an empty set instead.
#[clap(short = 'e', long, default_value="N-", parse(from_str=parse_ignored_chars), allow_hyphen_values = true)]
ignored_chars: HashSet<u8>,
}
/// Record names and their sequences as two parallel vectors: the sequence at
/// position `i` belongs to the name at position `i`.
type NamesAndSeqs = (Vec<Vec<u8>>, Vec<Vec<u8>>);
impl Transformer {
    /// Reads every record from `reader` and returns the record names and the
    /// (normalised) sequences as parallel vectors.
    ///
    /// * `starting_seqlen` - the length every sequence must have. Pass `0` to
    ///   take the required length from the first record instead.
    ///
    /// When `sort` is set, the records are returned ordered by name. Sequences
    /// are passed through [`Transformer::transform`] whenever that could
    /// change them, i.e. when comparison is case-insensitive or at least one
    /// ignored character is configured.
    ///
    /// # Errors
    ///
    /// Returns an error if a record cannot be parsed, or if a sequence's
    /// length differs from the required alignment length (the message names
    /// the offending record).
    pub fn load_alignment<R: BufRead>(
        &self,
        reader: &mut fasta::Reader<R>,
        starting_seqlen: usize,
    ) -> Result<NamesAndSeqs, anyhow::Error> {
        let mut seqlen: usize = starting_seqlen;
        let mut names: Vec<Vec<u8>> = vec![];
        let mut seqs: Vec<Vec<u8>> = vec![];

        for result in reader.records() {
            let record = result.context("Failed to parse record")?;
            let record_len = record.sequence().len();

            if seqlen == 0 {
                // No length fixed yet: this record defines it.
                seqlen = record_len;
            } else if seqlen != record_len {
                return Err(anyhow!(
                    "Alignment sequences must all be the same length [id: {}]",
                    String::from_utf8_lossy(record.name())
                ));
            }

            names.push(record.name().to_owned());
            let seq = record.sequence().as_ref();
            seqs.push(seq.to_vec());
        }

        if self.sort {
            // `argsort` and `sort` are both stable, so the permutation applied
            // to `seqs` matches the final order of `names` even with
            // duplicate names.
            let mut indices = names.argsort();
            names.sort();
            seqs.sort_by_indices(&mut indices);
        }

        // `transform` is a no-op only when case is preserved AND nothing is
        // ignored. Case-insensitive mode must always run it so that the
        // sequences are upper-cased before comparison.
        let needs_transform = !self.case_sensitive || !self.ignored_chars.is_empty();
        if needs_transform {
            for seq in seqs.iter_mut() {
                self.transform(seq);
            }
        }

        Ok((names, seqs))
    }

    /// Normalises `seq` in place.
    ///
    /// Each byte is upper-cased first (unless `case_sensitive` is set) and is
    /// then replaced by `IGNORE` if it is one of `ignored_chars`. Because the
    /// upper-casing happens first, in case-insensitive mode a lower-case byte
    /// is ignored when its upper-case form is in `ignored_chars`.
    fn transform(&self, seq: &mut [u8]) {
        for b in seq.iter_mut() {
            if !self.case_sensitive {
                b.make_ascii_uppercase();
            }
            if self.ignored_chars.contains(b) {
                *b = IGNORE;
            }
        }
    }
}
/// Distance contributed by a single alignment column: `1` when the two bytes
/// differ and neither is the `IGNORE` placeholder, otherwise `0`.
fn dist(a: u8, b: u8) -> u64 {
    if a == b || a == IGNORE || b == IGNORE {
        0
    } else {
        1
    }
}
/// Counts the positions at which `a` and `b` differ, skipping any position
/// where either side holds the `IGNORE` placeholder.
///
/// The sequences are compared pairwise from the start; if they differ in
/// length, only the common prefix is compared.
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u64 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| dist(lhs, rhs))
        .sum()
}
/// Serialisation of a labelled matrix to delimited text.
pub trait ToTable {
/// Writes the matrix to `ostream` as a wide table.
///
/// In the implementation in this file the first line is a header that starts
/// with `delimiter` (an empty corner cell) followed by `column_names`; each
/// following line is a row name followed by that row's values.
///
/// # Errors
///
/// Returns any I/O error raised while writing to `ostream`.
fn to_csv(
&self,
ostream: &mut Box<dyn Write>,
delimiter: char,
column_names: &[Vec<u8>],
row_names: &[Vec<u8>],
) -> Result<(), Error>;
/// Writes the matrix to `ostream` in long form: one line per cell.
///
/// In the implementation in this file each line holds the column name, the
/// row name and the cell value, separated by `delimiter`.
///
/// # Errors
///
/// Returns any I/O error raised while writing to `ostream`.
fn to_long(
&self,
ostream: &mut Box<dyn Write>,
delimiter: char,
column_names: &[Vec<u8>],
row_names: &[Vec<u8>],
) -> Result<(), Error>;
}
impl ToTable for ArrayBase<OwnedRepr<u64>, Ix2> {
    /// Writes the matrix as a wide, delimited table.
    ///
    /// The header line is an empty corner cell followed by every column name.
    /// Each subsequent line is a row name followed by the values of the matrix
    /// row with the same position. Names are rendered with lossy UTF-8
    /// decoding.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing to `ostream`.
    ///
    /// # Panics
    ///
    /// Panics if `row_names` has more entries than the matrix has rows.
    fn to_csv(
        &self,
        ostream: &mut Box<dyn Write>,
        delimiter: char,
        column_names: &[Vec<u8>],
        row_names: &[Vec<u8>],
    ) -> Result<(), Error> {
        // Header: the leading delimiter leaves the corner cell empty.
        let separator = delimiter.to_string();
        let labels: Vec<_> = column_names
            .iter()
            .map(|name| String::from_utf8_lossy(name))
            .collect();
        write!(ostream, "{}", delimiter)?;
        writeln!(ostream, "{}", labels.join(&separator))?;

        for (idx, name) in row_names.iter().enumerate() {
            write!(ostream, "{}", String::from_utf8_lossy(name))?;

            // Assemble the value cells first so each row body is emitted in one write.
            let values = self.row(idx);
            let mut cells = String::new();
            for value in values.iter() {
                // Formatting into a `String` cannot fail.
                let _ = write!(cells, "{}{}", delimiter, value);
            }
            writeln!(ostream, "{}", cells)?;
        }
        Ok(())
    }

    /// Writes the matrix in long form, one `column<d>row<d>value` line per
    /// cell, where `<d>` is `delimiter`.
    ///
    /// Cells are visited column by column; within a column, rows are visited
    /// in order. Names are rendered with lossy UTF-8 decoding.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing to `ostream`.
    ///
    /// # Panics
    ///
    /// Panics if a name list is longer than the matching matrix dimension.
    fn to_long(
        &self,
        ostream: &mut Box<dyn Write>,
        delimiter: char,
        column_names: &[Vec<u8>],
        row_names: &[Vec<u8>],
    ) -> Result<(), Error> {
        iproduct!(0..column_names.len(), 0..row_names.len()).try_for_each(|(col, row)| {
            // The matrix is indexed [row, column].
            let value = self[[row, col]];
            writeln!(
                ostream,
                "{}{d}{}{d}{}",
                String::from_utf8_lossy(&column_names[col]),
                String::from_utf8_lossy(&row_names[row]),
                value,
                d = delimiter
            )
        })
    }
}
// Unit tests for option parsing, alignment loading, the sorting helpers,
// sequence normalisation and the distance function.
#[cfg(test)]
mod tests {
use super::*;
// Every distinct byte of the option string ends up in the set.
#[test]
fn parse_chars() {
let s = "N-X";
let actual = parse_ignored_chars(s);
let expected: HashSet<u8> = HashSet::from_iter([b'N', b'-', b'X']);
assert_eq!(actual, expected)
}
// Equal-length records load successfully and keep their input order when
// sorting is off.
#[test]
fn alignments_all_have_same_length() {
let data = b">s1\nACGT\n>s0\nCCCC\n";
let mut reader = fasta::Reader::new(&data[..]);
let t: Transformer = Default::default();
let actual = t.load_alignment(&mut reader, 0).unwrap();
let expected = (
vec![b"s1".to_vec(), b"s0".to_vec()],
vec![b"ACGT".to_vec(), b"CCCC".to_vec()],
);
assert_eq!(actual, expected)
}
// A record whose length differs from the first record's is rejected, and the
// error names that record.
#[test]
fn alignments_do_not_all_have_same_length() {
let data = b">s0\nACGT\n>s1\nCCCCC\n";
let mut reader = fasta::Reader::new(&data[..]);
let t: Transformer = Default::default();
let actual = t.load_alignment(&mut reader, 0).unwrap_err();
assert!(actual.to_string().contains("[id: s1]"))
}
// A non-zero starting length is enforced from the very first record.
#[test]
fn alignments_do_not_have_same_length_as_starting_seqlen() {
let data = b">s0\nACGT\n>s1\nCCCC\n";
let mut reader = fasta::Reader::new(&data[..]);
let t: Transformer = Default::default();
let actual = t.load_alignment(&mut reader, 1).unwrap_err();
assert!(actual.to_string().contains("[id: s0]"))
}
// With `sort` enabled, names come back sorted and each sequence stays paired
// with its name.
#[test]
fn alignments_sorted_by_id() {
let data = b">s10\nACGT\n>s51\nCCCC\n>s0\nGGCC\n";
let mut reader = fasta::Reader::new(&data[..]);
let t: Transformer = Transformer {
sort: true,
..Default::default()
};
let actual = t.load_alignment(&mut reader, 0).unwrap();
let expected = (
vec![b"s0".to_vec(), b"s10".to_vec(), b"s51".to_vec()],
vec![b"GGCC".to_vec(), b"ACGT".to_vec(), b"CCCC".to_vec()],
);
assert_eq!(actual, expected)
}
// argsort yields the positions of the elements in ascending order.
#[test]
fn argsort() {
let v = vec![1, 7, 4, 2];
let i = v.argsort();
assert_eq!(i, &[0, 3, 2, 1]);
}
// String slices compare bytewise, so '-' < 'B' < 'a' < 'c'.
#[test]
fn argsort_with_str() {
let v = vec!["a", "c", "B", "-"];
let i = v.argsort();
assert_eq!(i, &[3, 2, 0, 1]);
}
// An empty vector has an empty argsort.
#[test]
fn argsort_on_empty() {
let v: Vec<u8> = vec![];
let i = v.argsort();
assert!(i.is_empty());
}
// After sort_by_indices, position `k` holds the element that was originally at
// `i[k]` (e.g. position 0 receives original element 5, "f").
#[test]
fn sort_by_index() {
let mut i = vec![5, 3, 2, 1, 6, 4, 0];
let mut v = vec!["a", "b", "c", "d", "e", "f", "g"];
v.sort_by_indices(&mut i);
assert_eq!(v, &["f", "d", "c", "b", "g", "e", "a"]);
}
// Case-sensitive with no ignored characters: the sequence is left untouched.
#[test]
fn transform_preserve_case() {
let mut s = b"aC-t".to_vec();
let expected = s.clone();
let t: Transformer = Transformer {
case_sensitive: true,
..Default::default()
};
t.transform(&mut s);
assert_eq!(s, expected)
}
// Case-insensitive: letters are upper-cased; '-' is kept because the default
// ignore set used here is empty.
#[test]
fn transform_ignore_case() {
let mut s = b"aC-t".to_vec();
let t = Transformer {
case_sensitive: false,
..Default::default()
};
t.transform(&mut s);
let expected = b"AC-T".to_vec();
assert_eq!(s, expected)
}
// Every byte in the ignore set is replaced by the IGNORE placeholder.
#[test]
fn transform_ignore_chars() {
let ignore = HashSet::from_iter(b"N-x".to_vec());
let t = Transformer {
ignored_chars: ignore,
case_sensitive: true,
..Default::default()
};
let mut s = b"AxC-GNt".to_vec();
t.transform(&mut s);
let expected = vec![b'A', IGNORE, b'C', IGNORE, b'G', IGNORE, b't'];
assert_eq!(s, expected)
}
// Case-sensitive: lower-case 'n' does not match ignored 'N' and is kept.
#[test]
fn transform_ignore_chars_case_sensitive() {
let ignore = HashSet::from_iter(b"N".to_vec());
let t = Transformer {
ignored_chars: ignore,
case_sensitive: true,
..Default::default()
};
let mut s = b"ACGnt".to_vec();
t.transform(&mut s);
let expected = vec![b'A', b'C', b'G', b'n', b't'];
assert_eq!(s, expected)
}
// Case-insensitive: 'n' is upper-cased first, so it matches ignored 'N'.
#[test]
fn transform_ignore_chars_case_insensitive() {
let ignore = HashSet::from_iter(b"N".to_vec());
let t = Transformer {
ignored_chars: ignore,
case_sensitive: false,
..Default::default()
};
let mut s = b"ACGnt".to_vec();
t.transform(&mut s);
let expected = vec![b'A', b'C', b'G', IGNORE, b'T'];
assert_eq!(s, expected)
}
// Only two columns count: 't' vs 'T' (raw bytes differ) and '-' vs 'G' ('-' is
// not the IGNORE placeholder). The IGNORE vs 'T' column is skipped.
#[test]
fn test_hamming_distance() {
let a = vec![b'A', IGNORE, b't', b'C', b'-'];
let b = vec![b'A', b'T', b'T', b'C', b'G'];
let actual = hamming_distance(&a, &b);
let expected = 2;
assert_eq!(actual, expected)
}
}