use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{BuildHasher, BuildHasherDefault, Hash};
use std::str;
use nohash_hasher::BuildNoHashHasher;
use once_cell::sync::Lazy;
use vec_collections::AbstractVecSet;
use crate::Error;
/// A color: the 64-bit value identifying one set of indices stored in [`Colors`].
pub type Color = u64;
/// An index stored in a color's index set.
pub type Idx = u32;
/// Per-color state: `.0` is the set of indices, `.1` is a counter that
/// `Colors::update` increments/decrements (the entry is removed when it reaches zero).
type IdxTracker = (vec_collections::VecSet<[Idx; 8]>, u64);
/// Color -> tracker map. `BuildNoHashHasher` is used because the key is already a
/// hash value (NOTE(review): relies on nohash-hasher passing integer keys through).
type ColorToIdx = HashMap<Color, IdxTracker, BuildNoHashHasher<Color>>;
/// Hash-function / molecule-type selector.
///
/// The built-in variants round-trip through their `Display` names
/// (`"DNA"`, `"protein"`, `"dayhoff"`, `"hp"`, `"skipm1n3"`, `"skipm2n3"`) and the
/// case-insensitive `TryFrom<&str>` parser. `Custom` carries an arbitrary name that
/// `Display` prints verbatim and that `TryFrom<&str>` never produces.
///
/// NOTE(review): the derived `Hash` includes the variant discriminant, and with the
/// `rkyv` feature the archived layout depends on variant order — keep the order stable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Serialize, rkyv::Deserialize, rkyv::Archive)
)]
#[non_exhaustive]
pub enum HashFunctions {
    /// Displayed as `"DNA"`.
    Murmur64Dna,
    /// Displayed as `"protein"`.
    Murmur64Protein,
    /// Displayed as `"dayhoff"`.
    Murmur64Dayhoff,
    /// Displayed as `"hp"`.
    Murmur64Hp,
    /// Displayed as `"skipm1n3"`.
    Murmur64Skipm1n3,
    /// Displayed as `"skipm2n3"`.
    Murmur64Skipm2n3,
    /// A user-supplied name, displayed verbatim.
    Custom(String),
}
impl HashFunctions {
    /// Returns `true` only for [`HashFunctions::Murmur64Dna`].
    pub fn dna(&self) -> bool {
        matches!(self, HashFunctions::Murmur64Dna)
    }

    /// Returns `true` only for [`HashFunctions::Murmur64Protein`].
    pub fn protein(&self) -> bool {
        matches!(self, HashFunctions::Murmur64Protein)
    }

    /// Returns `true` only for [`HashFunctions::Murmur64Dayhoff`].
    pub fn dayhoff(&self) -> bool {
        matches!(self, HashFunctions::Murmur64Dayhoff)
    }

    /// Returns `true` only for [`HashFunctions::Murmur64Hp`].
    pub fn hp(&self) -> bool {
        matches!(self, HashFunctions::Murmur64Hp)
    }

    /// Returns `true` only for [`HashFunctions::Murmur64Skipm1n3`].
    pub fn skipm1n3(&self) -> bool {
        matches!(self, HashFunctions::Murmur64Skipm1n3)
    }

    /// Returns `true` only for [`HashFunctions::Murmur64Skipm2n3`].
    pub fn skipm2n3(&self) -> bool {
        matches!(self, HashFunctions::Murmur64Skipm2n3)
    }
}
impl std::fmt::Display for HashFunctions {
    /// Writes the short name of the hash function; `Custom` names are written as-is.
    /// Width/fill flags of the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name: &str = match self {
            HashFunctions::Murmur64Dna => "DNA",
            HashFunctions::Murmur64Protein => "protein",
            HashFunctions::Murmur64Dayhoff => "dayhoff",
            HashFunctions::Murmur64Hp => "hp",
            HashFunctions::Murmur64Skipm1n3 => "skipm1n3",
            HashFunctions::Murmur64Skipm2n3 => "skipm2n3",
            HashFunctions::Custom(custom) => custom.as_str(),
        };
        f.write_str(name)
    }
}
impl TryFrom<&str> for HashFunctions {
type Error = Error;
fn try_from(moltype: &str) -> Result<Self, Self::Error> {
match moltype.to_lowercase().as_ref() {
"dna" => Ok(HashFunctions::Murmur64Dna),
"dayhoff" => Ok(HashFunctions::Murmur64Dayhoff),
"hp" => Ok(HashFunctions::Murmur64Hp),
"protein" => Ok(HashFunctions::Murmur64Protein),
"skipm1n3" => Ok(HashFunctions::Murmur64Skipm1n3),
"skipm2n3" => Ok(HashFunctions::Murmur64Skipm2n3),
v => Err(Error::InvalidHashFunction {
function: v.to_string(),
}),
}
}
}
/// Byte-indexed complement table for DNA bases.
///
/// Only uppercase `A`, `C`, `G`, `T` and `N` have entries. Every other byte,
/// including lowercase bases, maps to `0`.
const COMPLEMENT: [u8; 256] = {
    let pairs: [(u8, u8); 5] = [
        (b'A', b'T'),
        (b'C', b'G'),
        (b'G', b'C'),
        (b'T', b'A'),
        (b'N', b'N'),
    ];
    let mut table = [0u8; 256];
    let mut i = 0;
    // `for` loops are not allowed in const initializers, so walk the pairs manually.
    while i < pairs.len() {
        table[pairs[i].0 as usize] = pairs[i].1;
        i += 1;
    }
    table
};
/// Returns the reverse complement of a DNA sequence.
///
/// Each byte is complemented through `COMPLEMENT`, so only uppercase `A`, `C`, `G`,
/// `T` and `N` are complemented. Any other byte (including lowercase bases) becomes
/// `0` in the output.
#[inline]
pub fn revcomp(seq: &[u8]) -> Vec<u8> {
    let mut complemented = Vec::with_capacity(seq.len());
    for &base in seq.iter().rev() {
        complemented.push(COMPLEMENT[usize::from(base)]);
    }
    complemented
}
/// Standard genetic code: uppercase codon -> one-letter amino acid (`*` = stop).
///
/// Besides the 64 `ACGT` codons, the table has an `N`-third-base entry for each of the
/// eight codon families whose amino acid is fixed by the first two bases (e.g.
/// `"GGN"` -> `G`). `translate_codon` uses these entries for two-base codons.
static CODONTABLE: Lazy<HashMap<&'static str, u8>> = Lazy::new(|| {
    HashMap::from([
        // First base T.
        ("TTT", b'F'),
        ("TTC", b'F'),
        ("TTA", b'L'),
        ("TTG", b'L'),
        ("TCT", b'S'),
        ("TCC", b'S'),
        ("TCA", b'S'),
        ("TCG", b'S'),
        ("TCN", b'S'),
        ("TAT", b'Y'),
        ("TAC", b'Y'),
        ("TAA", b'*'),
        ("TAG", b'*'),
        ("TGA", b'*'),
        ("TGT", b'C'),
        ("TGC", b'C'),
        ("TGG", b'W'),
        // First base C.
        ("CTT", b'L'),
        ("CTC", b'L'),
        ("CTA", b'L'),
        ("CTG", b'L'),
        ("CTN", b'L'),
        ("CCT", b'P'),
        ("CCC", b'P'),
        ("CCA", b'P'),
        ("CCG", b'P'),
        ("CCN", b'P'),
        ("CAT", b'H'),
        ("CAC", b'H'),
        ("CAA", b'Q'),
        ("CAG", b'Q'),
        ("CGT", b'R'),
        ("CGC", b'R'),
        ("CGA", b'R'),
        ("CGG", b'R'),
        ("CGN", b'R'),
        // First base A.
        ("ATT", b'I'),
        ("ATC", b'I'),
        ("ATA", b'I'),
        ("ATG", b'M'),
        ("ACT", b'T'),
        ("ACC", b'T'),
        ("ACA", b'T'),
        ("ACG", b'T'),
        ("ACN", b'T'),
        ("AAT", b'N'),
        ("AAC", b'N'),
        ("AAA", b'K'),
        ("AAG", b'K'),
        ("AGT", b'S'),
        ("AGC", b'S'),
        ("AGA", b'R'),
        ("AGG", b'R'),
        // First base G.
        ("GTT", b'V'),
        ("GTC", b'V'),
        ("GTA", b'V'),
        ("GTG", b'V'),
        ("GTN", b'V'),
        ("GCT", b'A'),
        ("GCC", b'A'),
        ("GCA", b'A'),
        ("GCG", b'A'),
        ("GCN", b'A'),
        ("GAT", b'D'),
        ("GAC", b'D'),
        ("GAA", b'E'),
        ("GAG", b'E'),
        ("GGT", b'G'),
        ("GGC", b'G'),
        ("GGA", b'G'),
        ("GGG", b'G'),
        ("GGN", b'G'),
    ])
});
/// Amino acid -> Dayhoff class letter.
///
/// Classes: `a` = C; `b` = A G P S T; `c` = D E N Q; `d` = H K R; `e` = I L M V;
/// `f` = F W Y. The stop symbol `*` maps to itself.
static DAYHOFFTABLE: Lazy<HashMap<u8, u8>> = Lazy::new(|| {
    HashMap::from([
        (b'C', b'a'),
        (b'A', b'b'),
        (b'G', b'b'),
        (b'P', b'b'),
        (b'S', b'b'),
        (b'T', b'b'),
        (b'D', b'c'),
        (b'E', b'c'),
        (b'N', b'c'),
        (b'Q', b'c'),
        (b'H', b'd'),
        (b'K', b'd'),
        (b'R', b'd'),
        (b'I', b'e'),
        (b'L', b'e'),
        (b'M', b'e'),
        (b'V', b'e'),
        (b'F', b'f'),
        (b'W', b'f'),
        (b'Y', b'f'),
        (b'*', b'*'),
    ])
});
/// Amino acid -> HP class letter.
///
/// `h` = A F G I L M P V W Y; `p` = N C S T D E R H K Q. The stop symbol `*` maps
/// to itself.
static HPTABLE: Lazy<HashMap<u8, u8>> = Lazy::new(|| {
    HashMap::from([
        // Class `h`.
        (b'A', b'h'),
        (b'F', b'h'),
        (b'G', b'h'),
        (b'I', b'h'),
        (b'L', b'h'),
        (b'M', b'h'),
        (b'P', b'h'),
        (b'V', b'h'),
        (b'W', b'h'),
        (b'Y', b'h'),
        // Class `p`.
        (b'N', b'p'),
        (b'C', b'p'),
        (b'S', b'p'),
        (b'T', b'p'),
        (b'D', b'p'),
        (b'E', b'p'),
        (b'R', b'p'),
        (b'H', b'p'),
        (b'K', b'p'),
        (b'Q', b'p'),
        // Stop.
        (b'*', b'*'),
    ])
});
#[inline]
pub fn translate_codon(codon: &[u8]) -> Result<u8, Error> {
if codon.len() == 1 {
return Ok(b'X');
}
if codon.len() == 2 {
let mut v = codon.to_vec();
v.push(b'N');
match CODONTABLE.get(str::from_utf8(v.as_slice()).unwrap()) {
Some(aa) => return Ok(*aa),
None => return Ok(b'X'),
}
}
if codon.len() == 3 {
match CODONTABLE.get(str::from_utf8(codon).unwrap()) {
Some(aa) => return Ok(*aa),
None => return Ok(b'X'),
}
}
Err(Error::InvalidCodonLength {
message: format!("{}", codon.len()),
})
}
/// Maps a one-letter amino acid code to its Dayhoff class letter via `DAYHOFFTABLE`.
///
/// Residues missing from the table (lowercase letters included) map to `X`.
#[inline]
pub fn aa_to_dayhoff(aa: u8) -> u8 {
    DAYHOFFTABLE.get(&aa).copied().unwrap_or(b'X')
}
/// Maps a one-letter amino acid code to its HP class letter (`h`/`p`, `*` for stop)
/// via `HPTABLE`.
///
/// Residues missing from the table (lowercase letters included) map to `X`.
pub fn aa_to_hp(aa: u8) -> u8 {
    HPTABLE.get(&aa).copied().unwrap_or(b'X')
}
/// Translates a nucleotide sequence into one residue per complete codon.
///
/// Trailing bases that do not form a full codon are ignored. When `dayhoff` is set,
/// each residue is reduced with `aa_to_dayhoff`. Otherwise, when `hp` is set, it is
/// reduced with `aa_to_hp`. `dayhoff` takes precedence if both flags are set.
///
/// # Errors
///
/// Propagates any error returned by `translate_codon`.
#[inline]
pub fn to_aa(seq: &[u8], dayhoff: bool, hp: bool) -> Result<Vec<u8>, Error> {
    let mut protein = Vec::with_capacity(seq.len() / 3);
    // `chunks_exact` skips the incomplete tail, same as stopping at a short chunk.
    for codon in seq.chunks_exact(3) {
        let aa = translate_codon(codon)?;
        let residue = match (dayhoff, hp) {
            (true, _) => aa_to_dayhoff(aa),
            (false, true) => aa_to_hp(aa),
            (false, false) => aa,
        };
        protein.push(residue);
    }
    Ok(protein)
}
/// Byte-indexed table marking the uppercase DNA bases `A`, `C`, `G`, `T` as valid.
///
/// Every other byte, including `N` and lowercase bases, is `false`.
pub const VALID: [bool; 256] = {
    let bases: [u8; 4] = [b'A', b'C', b'G', b'T'];
    let mut table = [false; 256];
    let mut i = 0;
    // `for` loops are not allowed in const initializers.
    while i < bases.len() {
        table[bases[i] as usize] = true;
        i += 1;
    }
    table
};
/// Table of colors: each color is a 64-bit value identifying one set of indices,
/// stored with a counter that [`Colors::update`] maintains.
#[derive(Serialize, Deserialize, Default)]
pub struct Colors {
    // Color -> (index set, counter). The field name is part of the serde format.
    colors: ColorToIdx,
}
impl Colors {
    /// Creates an empty color table.
    pub fn new() -> Colors {
        Default::default()
    }

    /// Adds `new_idxs` to the index set identified by `current_color` and returns the
    /// color of the resulting set.
    ///
    /// Every entry keeps a counter next to its index set (`IdxTracker.1`):
    ///
    /// * `current_color == None`: a set is built from `new_idxs`. Its color's counter
    ///   is created at 1, or incremented if that color already exists.
    /// * `current_color == Some(c)` and every index in `new_idxs` is already in `c`:
    ///   `c`'s counter is incremented and `c` is returned.
    /// * Otherwise the extended set is hashed to a new color. The old color's counter
    ///   is decremented, and the old entry is removed once it reaches zero. The new
    ///   color's counter is created at 1, or incremented if it already exists.
    ///
    /// Currently always returns `Ok`.
    ///
    /// # Panics
    ///
    /// * Via `unimplemented!`, if `current_color` is `Some` but not in the table.
    /// * Via `assert_eq!`, if an existing color maps to a different index set than the
    ///   one just hashed (a color collision).
    pub fn update<'a, I: IntoIterator<Item = &'a Idx>>(
        &mut self,
        current_color: Option<Color>,
        new_idxs: I,
    ) -> Result<Color, Error> {
        if let Some(color) = current_color {
            if let Some(idxs) = self.colors.get_mut(&color) {
                // Keep only the indices that the current color does not have yet.
                let idx_to_add: Vec<_> = new_idxs
                    .into_iter()
                    .filter(|new_idx| !idxs.0.contains(new_idx))
                    .collect();
                if idx_to_add.is_empty() {
                    // Set unchanged: same color, bump its counter.
                    idxs.1 += 1;
                    Ok(color)
                } else {
                    // Extend a copy so the existing color keeps its original set.
                    let mut idxs = idxs.clone();
                    idxs.0.extend(idx_to_add.into_iter().cloned());
                    let new_color = Colors::compute_color(&idxs);
                    if new_color != color {
                        // Move one reference off the old color; drop it once unused.
                        self.colors.get_mut(&color).unwrap().1 -= 1;
                        if self.colors[&color].1 == 0 {
                            self.colors.remove(&color);
                        };
                    };
                    // An existing entry must hold the same set (otherwise the 64-bit
                    // color collided). A new entry starts with counter 1.
                    self.colors
                        .entry(new_color)
                        .and_modify(|old_idxs| {
                            assert_eq!(old_idxs.0, idxs.0);
                            old_idxs.1 += 1;
                        })
                        .or_insert_with(|| (idxs.0, 1));
                    Ok(new_color)
                }
            } else {
                unimplemented!(
                    "throw error, current_color must exist in order to be updated. current_color: {:?}, colors: {:#?}",
                    current_color,
                    &self.colors
                );
            }
        } else {
            // No current color: start a fresh set from `new_idxs`.
            let mut idxs = IdxTracker::default();
            idxs.0.extend(new_idxs.into_iter().cloned());
            // The counter is not hashed by `compute_color`; the inserted entry uses 1.
            idxs.1 = 1;
            let new_color = Colors::compute_color(&idxs);
            self.colors
                .entry(new_color)
                .and_modify(|old_idxs| {
                    assert_eq!(old_idxs.0, idxs.0);
                    old_idxs.1 += 1;
                })
                .or_insert_with(|| (idxs.0, 1));
            Ok(new_color)
        }
    }

    /// Derives a color from the index set `idxs.0`; the counter `idxs.1` is not hashed.
    /// Uses `Xxh3Hash128`, i.e. XXH3-128 truncated to 64 bits.
    fn compute_color(idxs: &IdxTracker) -> Color {
        let s = BuildHasherDefault::<Xxh3Hash128>::default();
        s.hash_one(&idxs.0)
    }

    /// Number of distinct colors currently stored.
    pub fn len(&self) -> usize {
        self.colors.len()
    }

    /// Returns `true` when no colors are stored.
    pub fn is_empty(&self) -> bool {
        self.colors.is_empty()
    }

    /// Returns `true` if `color` exists and its index set contains `idx`.
    /// Unknown colors return `false`.
    pub fn contains(&self, color: Color, idx: Idx) -> bool {
        if let Some(idxs) = self.colors.get(&color) {
            idxs.0.contains(&idx)
        } else {
            false
        }
    }

    /// Iterates over the indices stored under `color`.
    ///
    /// # Panics
    ///
    /// Panics if `color` is not in the table.
    pub fn indices(&self, color: &Color) -> Indices<'_> {
        Indices {
            iter: self.colors.get(color).unwrap().0.iter(),
        }
    }

    /// Keeps only the colors for which `f` returns `true`.
    /// `f` may also mutate each entry's tracker.
    pub fn retain<F>(&mut self, f: F)
    where
        F: FnMut(&Color, &mut IdxTracker) -> bool,
    {
        self.colors.retain(f)
    }
}
/// Iterator over the indices stored under one color, as returned by
/// `Colors::indices`.
pub struct Indices<'a> {
    iter: vec_collections::VecSetIter<core::slice::Iter<'a, Idx>>,
}

impl<'a> Iterator for Indices<'a> {
    type Item = &'a Idx;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    // Forward the inner iterator's bounds so adapters such as `collect` can
    // preallocate. The default implementation would report `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// `std::hash::Hasher` adapter over twox-hash's 128-bit XXH3 hasher.
///
/// `finish` returns only the low 64 bits of the 128-bit digest. `Colors` uses this
/// (through `BuildHasherDefault`) to derive colors.
#[derive(Default)]
pub(crate) struct Xxh3Hash128(twox_hash::XxHash3_128);

impl std::hash::Hasher for Xxh3Hash128 {
    #[inline(always)]
    fn write(&mut self, bytes: &[u8]) {
        self.0.write(bytes);
    }

    #[inline(always)]
    fn finish(&self) -> u64 {
        let digest = self.0.finish_128();
        // Deliberate truncation: keep the low 64 bits of the 128-bit digest.
        digest as u64
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::convert::TryFrom;

    #[test]
    fn colors_update() {
        let mut colors = Colors::new();
        let color = colors.update(None, &[1_u32]).unwrap();
        assert_eq!(colors.len(), 1);

        // Re-adding an index that is already present keeps the same color.
        let new_color = colors.update(Some(color), &[1_u32]).unwrap();
        assert_eq!(colors.len(), 1);
        assert_eq!(color, new_color);

        // Adding a new index creates a new color; the old one is still referenced.
        let new_color = colors.update(Some(color), &[2_u32]).unwrap();
        assert_eq!(colors.len(), 2);
        assert_ne!(color, new_color);
        assert!(colors.contains(new_color, 1));
        assert!(colors.contains(new_color, 2));
        assert!(!colors.contains(color, 2));
    }

    #[test]
    fn colors_retain() {
        let mut colors = Colors::new();
        let color1 = colors.update(None, &[1_u32]).unwrap();
        assert_eq!(colors.len(), 1);

        let same_color = colors.update(Some(color1), &[1_u32]).unwrap();
        assert_eq!(colors.len(), 1);
        assert_eq!(color1, same_color);

        let color2 = colors.update(Some(color1), &[2_u32]).unwrap();
        assert_eq!(colors.len(), 2);
        assert_ne!(color1, color2);

        let same_color = colors.update(Some(color2), &[2_u32]).unwrap();
        assert_eq!(colors.len(), 2);
        assert_eq!(color2, same_color);

        // color1's last reference moves to color3, so color1 is removed.
        let color3 = colors.update(Some(color1), &[3_u32]).unwrap();
        assert_ne!(color1, color3);
        assert_ne!(color2, color3);
        assert_eq!(colors.len(), 2);
        assert!(!colors.contains(color1, 1));

        let mut idxs: Vec<Idx> = colors.indices(&color3).copied().collect();
        idxs.sort_unstable();
        assert_eq!(idxs, vec![1, 3]);

        // Exercise `retain` itself: keep only color2.
        colors.retain(|color, _| *color == color2);
        assert_eq!(colors.len(), 1);
        assert!(colors.contains(color2, 2));
        assert!(!colors.contains(color3, 3));
    }

    #[test]
    fn test_revcomp() {
        assert_eq!(revcomp(b"ACGTN"), b"NACGT".to_vec());
        assert_eq!(revcomp(b"AACG"), b"CGTT".to_vec());
        assert!(revcomp(b"").is_empty());
    }

    #[test]
    fn test_translate_codon() {
        assert_eq!(translate_codon(b"ATG").unwrap(), b'M');
        assert_eq!(translate_codon(b"TAA").unwrap(), b'*');
        // Two-base codons are padded with `N`.
        assert_eq!(translate_codon(b"GG").unwrap(), b'G');
        assert_eq!(translate_codon(b"AT").unwrap(), b'X');
        assert_eq!(translate_codon(b"A").unwrap(), b'X');
        assert_eq!(translate_codon(b"NNN").unwrap(), b'X');
        assert!(translate_codon(b"").is_err());
        assert!(translate_codon(b"ATGC").is_err());
    }

    #[test]
    fn test_to_aa() {
        assert_eq!(to_aa(b"ATGTTTTAA", false, false).unwrap(), b"MF*".to_vec());
        assert_eq!(to_aa(b"ATGTTTTAA", true, false).unwrap(), b"ef*".to_vec());
        assert_eq!(to_aa(b"ATGTTTTAA", false, true).unwrap(), b"hh*".to_vec());
        // dayhoff wins when both reductions are requested.
        assert_eq!(to_aa(b"ATG", true, true).unwrap(), b"e".to_vec());
        // Incomplete trailing codons are dropped.
        assert_eq!(to_aa(b"ATGTT", false, false).unwrap(), b"M".to_vec());
        assert!(to_aa(b"", false, false).unwrap().is_empty());
    }

    #[test]
    fn test_reduced_alphabets() {
        assert_eq!(aa_to_dayhoff(b'C'), b'a');
        assert_eq!(aa_to_dayhoff(b'Z'), b'X');
        assert_eq!(aa_to_hp(b'A'), b'h');
        assert_eq!(aa_to_hp(b'K'), b'p');
        assert_eq!(aa_to_hp(b'Z'), b'X');
    }

    #[test]
    fn test_valid_table() {
        assert!(VALID[b'A' as usize]);
        assert!(VALID[b'T' as usize]);
        assert!(!VALID[b'N' as usize]);
        assert!(!VALID[b'a' as usize]);
    }

    #[test]
    fn test_dna_method() {
        assert!(HashFunctions::Murmur64Dna.dna());
        assert!(!HashFunctions::Murmur64Protein.dna());
        assert!(!HashFunctions::Murmur64Dayhoff.dna());
    }

    #[test]
    fn test_protein_method() {
        assert!(HashFunctions::Murmur64Protein.protein());
        assert!(!HashFunctions::Murmur64Dna.protein());
        assert!(!HashFunctions::Murmur64Dayhoff.protein());
    }

    #[test]
    fn test_dayhoff_method() {
        assert!(HashFunctions::Murmur64Dayhoff.dayhoff());
        assert!(!HashFunctions::Murmur64Dna.dayhoff());
        assert!(!HashFunctions::Murmur64Protein.dayhoff());
    }

    #[test]
    fn test_hp_method() {
        assert!(HashFunctions::Murmur64Hp.hp());
        assert!(!HashFunctions::Murmur64Dna.hp());
        assert!(!HashFunctions::Murmur64Protein.hp());
    }

    #[test]
    fn test_skipm1n3_method() {
        assert!(HashFunctions::Murmur64Skipm1n3.skipm1n3());
        assert!(!HashFunctions::Murmur64Dna.skipm1n3());
        assert!(!HashFunctions::Murmur64Protein.skipm1n3());
    }

    #[test]
    fn test_skipm2n3_method() {
        assert!(HashFunctions::Murmur64Skipm2n3.skipm2n3());
        assert!(!HashFunctions::Murmur64Dna.skipm2n3());
        assert!(!HashFunctions::Murmur64Protein.skipm2n3());
    }

    #[test]
    fn test_display_hashfunctions() {
        assert_eq!(HashFunctions::Murmur64Dna.to_string(), "DNA");
        assert_eq!(HashFunctions::Murmur64Protein.to_string(), "protein");
        assert_eq!(HashFunctions::Murmur64Dayhoff.to_string(), "dayhoff");
        assert_eq!(HashFunctions::Murmur64Hp.to_string(), "hp");
        assert_eq!(HashFunctions::Murmur64Skipm1n3.to_string(), "skipm1n3");
        assert_eq!(HashFunctions::Murmur64Skipm2n3.to_string(), "skipm2n3");
        assert_eq!(
            HashFunctions::Custom("custom_string".into()).to_string(),
            "custom_string"
        );
    }

    #[test]
    fn test_try_from_str_valid() {
        assert_eq!(
            HashFunctions::try_from("dna").unwrap(),
            HashFunctions::Murmur64Dna
        );
        assert_eq!(
            HashFunctions::try_from("protein").unwrap(),
            HashFunctions::Murmur64Protein
        );
        assert_eq!(
            HashFunctions::try_from("dayhoff").unwrap(),
            HashFunctions::Murmur64Dayhoff
        );
        assert_eq!(
            HashFunctions::try_from("hp").unwrap(),
            HashFunctions::Murmur64Hp
        );
        assert_eq!(
            HashFunctions::try_from("skipm1n3").unwrap(),
            HashFunctions::Murmur64Skipm1n3
        );
        assert_eq!(
            HashFunctions::try_from("skipm2n3").unwrap(),
            HashFunctions::Murmur64Skipm2n3
        );
        // Parsing is case-insensitive.
        assert_eq!(
            HashFunctions::try_from("DNA").unwrap(),
            HashFunctions::Murmur64Dna
        );
    }

    #[test]
    fn test_display_try_from_round_trip() {
        let builtins = [
            HashFunctions::Murmur64Dna,
            HashFunctions::Murmur64Protein,
            HashFunctions::Murmur64Dayhoff,
            HashFunctions::Murmur64Hp,
            HashFunctions::Murmur64Skipm1n3,
            HashFunctions::Murmur64Skipm2n3,
        ];
        for hf in builtins.iter() {
            assert_eq!(&HashFunctions::try_from(hf.to_string().as_str()).unwrap(), hf);
        }
    }

    #[test]
    fn test_try_from_str_invalid() {
        let result = HashFunctions::try_from("unknown");
        assert!(result.is_err());
        let error_message = format!("{}", result.unwrap_err());
        assert!(error_message.contains("Invalid hash function"));
    }
}