use needletail::parse_fastx_file;
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::num::{NonZeroU8, NonZeroUsize};
use std::path::{Path, PathBuf};
use std::time::Instant;
use std::{error::Error, fs::File};
use log::{debug, info};
mod cluster;
pub use cluster::cluster;
pub const AUTHOR_AND_EMAIL: &str =
"Ben J. Woodcroft, Centre for Microbiome Research, School of Biomedical Sciences, Faculty of Health, Queensland University of Technology <benjwoodcroft near gmail.com>";
pub const CURRENT_DB_VERSION: u32 = 2;
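/// A nucleotide sequence packed into `u64`s: each base is a 5-bit one-hot code
/// (A, C, G, T/U, or N/ambiguous), with 12 bases stored per `u64`.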
#[derive(Serialize, Deserialize, Debug, Clone)]
struct SeqEncoding(Vec<u64>);
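/// A packed sequence encoding together with its original length in bases.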
#[derive(Clone)]
struct SeqEncodingLength {
encoding: SeqEncoding,
len: usize,
}
impl SeqEncodingLength {
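/// Encode a raw sequence into packed 5-bit-per-base `u64`s. Panics, naming the
/// record `identifier`, if a byte is not a recognised nucleotide code.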
fn from_bytes(identifier: &[u8], seq: &[u8]) -> Self {
let encoding = seq.chunks(12).enumerate().map(|(chunk_num, chunk)| {
chunk.iter().enumerate().fold(0u64, |acc, (i, &byte)| {
let b = encode_single(byte)
.unwrap_or_else(|| {
let seqname = String::from_utf8_lossy(identifier);
panic!(
"Byte {} cannot be interpreted as nucleotide, in sequence \"{}\" at position {}",
byte, seqname, 12 * chunk_num + i
)
})
.get();
acc | ((b as u64) << (5 * i))
})
}).collect();
Self {
encoding: SeqEncoding(encoding),
len: seq.len(),
}
}
}
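/// A database of equal-length encoded sequences ("windows") plus the database
/// format version. `len` is the common window length, fixed when the first
/// window is added.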
#[derive(Serialize, Deserialize, Debug)]
struct WindowSet {
version: u32,
windows: Vec<SeqEncoding>,
len: Option<NonZeroUsize>,
}
impl WindowSet {
fn new(version: u32) -> Self {
WindowSet {
version,
windows: Vec::new(),
len: None,
}
}
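/// Write the Hamming distance between `seq` and each stored window into
/// `distances` (one slot per window). Panics if `seq` is not the same length
/// as the windows.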
fn get_distances(&self, seq: &SeqEncodingLength, distances: &mut [usize]) {
if let Some(n) = self.len {
if n.get() != seq.len {
panic!(
"Cannot compute distances between a sequence of length {} and windows of length {}",
seq.len, n.get()
)
}
}
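// Identical bases XOR to zero; differing one-hot codes XOR to exactly two set
// bits, so popcount(xor) / 2 counts mismatching positions (Hamming distance).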
for (window, distance) in self.windows.iter().zip(distances.iter_mut()) {
*distance = window
.0
.iter()
.zip(seq.encoding.0.iter())
.map(|(a, b)| (a ^ b).count_ones() as usize)
.sum::<usize>()
/ 2
}
}
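/// Add an encoded sequence to the set. The first sequence fixes the common
/// window length; subsequent sequences must match it, and empty sequences are
/// rejected.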
fn push_encoding(&mut self, encoding: SeqEncodingLength) {
if let Some(n) = self.len {
if n.get() != encoding.len {
panic!(
"WindowSet sequence length is {}, got a new sequence of length {}",
n, encoding.len
)
}
} else {
self.len = Some(
encoding
.len
.try_into()
.expect("Cannot add empty sequence to WindowSet"),
);
}
self.windows.push(encoding.encoding)
}
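/// Decode the window at `index` back into an uppercase nucleotide string.
/// Lowercase input was uppercased and ambiguity codes were collapsed to 'N'
/// during encoding, so the round trip is lossy for those bases.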
fn get_as_string(&self, index: usize) -> String {
let uints = &self.windows[index].0;
let v = (0..self.len.map(NonZeroUsize::get).unwrap_or(0))
.map(|i| {
let d = i / 12;
let r = i % 12;
let b = ((uints[d] >> (5 * r)) & 31) as u8;
match b {
0b10000 => b'A',
0b01000 => b'C',
0b00100 => b'G',
0b00010 => b'T',
0b00001 => b'N',
_ => {
panic!("Invalid character in query sequence: {b}")
}
}
})
.collect();
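// SAFETY: every byte pushed above is one of b'A', b'C', b'G', b'T' or b'N',
// so `v` is valid ASCII and therefore valid UTF-8.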
unsafe { String::from_utf8_unchecked(v) }
}
}
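/// Encode every sequence in `subject_fasta` into a `WindowSet` and serialise
/// it to `db_path` with postcard. All subject sequences must share the same
/// length.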
pub fn makedb(subject_fasta: &Path, db_path: &Path) -> Result<(), Box<dyn Error>> {
debug!("Opening subject fasta file: {:?}", subject_fasta);
let mut subject_reader =
parse_fastx_file(subject_fasta).expect("valid path/file of subject fasta");
info!("Encoding subject sequences ..");
let mut windows = WindowSet::new(CURRENT_DB_VERSION);
while let Some(record) = subject_reader.next() {
let record = record.expect("valid record");
let encoded = SeqEncodingLength::from_bytes(record.id(), &record.seq());
windows.push_encoding(encoded);
}
info!(
"Encoding of {} sequences complete, writing db file {}",
windows.windows.len(),
db_path.to_string_lossy()
);
let mut ferris_file = File::create(db_path)?;
ferris_file.write_all(&postcard::to_allocvec(&windows).unwrap())?;
info!("DB file written");
Ok(())
}
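/// Build a 256-entry lookup table mapping ASCII nucleotide codes (upper or
/// lower case) to 5-bit one-hot encodings. IUPAC ambiguity codes and '-' map
/// to the N bit; anything else maps to 0, meaning "not a nucleotide".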
const fn create_lut() -> [u8; 256] {
let mut lut = [0; 256];
let mut i = 0;
while i < 256 {
let b = match i as u8 {
b'A' | b'a' => 0b10000,
b'C' | b'c' => 0b01000,
b'G' | b'g' => 0b00100,
b'T' | b't' | b'U' | b'u' => 0b00010,
b'N' | b'W' | b'S' | b'M' | b'K' | b'R' | b'Y' | b'B' | b'D' | b'H' | b'V' | b'-'
| b'n' | b'w' | b's' | b'm' | b'k' | b'r' | b'y' | b'b' | b'd' | b'h' | b'v' => 0b00001,
_ => 0,
};
lut[i] = b;
i += 1;
}
lut
}
const BYTE_LUT: [u8; 256] = create_lut();
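/// Look up the 5-bit encoding for a single byte, returning `None` if the byte
/// is not a recognised nucleotide code.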
#[inline(always)]
fn encode_single(c: u8) -> Option<NonZeroU8> {
// SAFETY: `c as usize` is at most 255, so it always indexes within the 256-entry BYTE_LUT.
let encoding = unsafe { *BYTE_LUT.get_unchecked(c as usize) };
NonZeroU8::new(encoding)
}
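/// Query each sequence in `query_fasta` against the database at `db_path`,
/// printing tab-separated hits: query number, window index, distance, hit
/// sequence. With `max_num_hits` > 1, all windows whose distance is within
/// that of the `max_num_hits`-th best window are reported; otherwise only
/// windows at the minimum distance are reported. `max_divergence` drops hits
/// above that distance, and `limit_per_sequence` caps repeated identical hit
/// sequences (only supported when `max_num_hits` > 1).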
pub fn query(
db_path: &Path,
query_fasta: &Path,
max_divergence: Option<u32>,
max_num_hits: Option<u32>,
limit_per_sequence: Option<u32>,
) -> Result<(), Box<dyn Error>> {
info!("Decoding db file {:?}", db_path);
let start = Instant::now();
let mut ferris_file = File::open(db_path)?;
let mut buffer = Vec::new();
ferris_file.read_to_end(&mut buffer)?;
// The version field is serialized first, so it can be decoded from the front
// of the buffer before committing to deserializing the whole WindowSet.
let version: u32 = postcard::from_bytes(&buffer)?;
if version != CURRENT_DB_VERSION {
panic!("Unsupported db file version: {}. This version of smafa only works with version {} databases. The last version to support version 1 databases was v0.7.1.", version, CURRENT_DB_VERSION);
}
let windows: WindowSet = postcard::from_bytes(&buffer)?;
let mut query_reader = parse_fastx_file(query_fasta).expect("valid path/file of query fasta");
// Multi-hit mode is enabled when max_num_hits is given and greater than 1.
let multi_hit_limit = max_num_hits.filter(|&max_num_hits| max_num_hits != 1);
let mut distances = vec![0; windows.windows.len()];
info!("Querying ..");
let mut query_number: u32 = 0;
while let Some(record) = query_reader.next() {
let record = record.expect("Failed to parse query sequence");
let query_vec = SeqEncodingLength::from_bytes(record.id(), &record.seq());
windows.get_distances(&query_vec, &mut distances);
match multi_hit_limit {
Some(max_num_hits) => {
let mut min_distances = distances
.iter()
.enumerate()
.map(|(i, d)| (*d, i))
.collect::<Vec<_>>();
min_distances.sort();
// If more hits are requested than windows exist, accept every window;
// otherwise use the distance of the max_num_hits-th best window as the cutoff.
let max_distance = if max_num_hits as usize > min_distances.len() {
*distances.iter().max().unwrap()
} else {
min_distances[(max_num_hits - 1) as usize].0
};
// Track the most recently reported hit sequence and how many times it has
// been reported, so limit_per_sequence can cap repeated identical hits.
let mut last_sequence: Option<(String, u32)> = None;
let mut new_last_sequence: Option<(String, u32)>;
for (distance, i) in min_distances.iter() {
if *distance <= max_distance
&& (max_divergence.is_none()
|| *distance <= max_divergence.unwrap() as usize)
{
let s = windows.get_as_string(*i);
debug!("Found hit sequence {} at distance {}", s, distance);
if let Some(limit_per_sequence_unwrapped) = limit_per_sequence {
match &last_sequence {
Some((last_seq, last_seq_count)) => {
if last_seq == &s {
if last_seq_count >= &limit_per_sequence_unwrapped {
continue;
} else {
new_last_sequence =
Some((s.clone(), last_seq_count + 1));
}
} else {
new_last_sequence = Some((s.clone(), 1));
}
}
None => {
new_last_sequence = Some((s.clone(), 1));
}
}
last_sequence = new_last_sequence;
}
println!("{}\t{}\t{}\t{}", query_number, i, distance, s);
}
}
}
None => {
let min_distance = distances.iter().min().unwrap();
debug!("Min distance: {}", min_distance);
if limit_per_sequence.is_some() {
panic!("limit_per_sequence is implemented unless max_num_hits > 1. It can be implemented by analogy, just haven't gotten around to it.");
}
if max_divergence.is_none() || *min_distance <= max_divergence.unwrap() as usize {
for (i, distance) in distances.iter().enumerate() {
if distance == min_distance {
let s = windows.get_as_string(i);
println!("{}\t{}\t{}\t{}", query_number, i, distance, s);
}
}
}
}
}
query_number += 1;
}
info!(
"Querying complete, took {} seconds",
start.elapsed().as_secs()
);
Ok(())
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use super::*;
#[test]
fn test_makedb() {
let temp_dir = tempfile::tempdir().unwrap();
let db_path = temp_dir.path().join("test.db");
let subject_fasta = std::path::PathBuf::from_str("tests/data/subjects.fa").unwrap();
assert!(makedb(&subject_fasta, &db_path).is_ok());
assert!(db_path.exists());
let mut ferris_file = File::open(&db_path).unwrap();
let mut encoded = Vec::new();
ferris_file.read_to_end(&mut encoded).unwrap();
let windows = postcard::from_bytes::<WindowSet>(&encoded).unwrap();
assert_eq!(windows.windows.len(), 5);
let expected_encoded = vec![
vec![0b10000],
vec![0b01000],
vec![0b00100],
vec![0b00010],
vec![0b00001],
];
for (i, j) in expected_encoded.iter().zip(windows.windows.iter()) {
assert_eq!(i, &j.0)
}
}
}
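/// Per-file read and base counts, printed as JSON by `count`.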
#[derive(Serialize, Deserialize, Debug)]
struct CountResult {
path: PathBuf,
num_reads: usize,
num_bases: usize,
}
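/// Count the reads and bases in each FASTA/FASTQ file in `paths` and print
/// the results as a JSON array.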
pub fn count<T: Iterator<Item = P>, P: AsRef<Path>>(paths: T) -> Result<(), Box<dyn Error>> {
let mut results = Vec::new();
for path in paths {
let mut reader = parse_fastx_file(&path)?;
let mut read_count = 0;
let mut bases_count = 0;
while let Some(record) = reader.next() {
let record = record?;
read_count += 1;
bases_count += record.seq().len();
}
results.push(CountResult {
path: path.as_ref().to_owned(),
num_reads: read_count,
num_bases: bases_count,
});
}
println!("{}", serde_json::to_string(&results).unwrap());
Ok(())
}