use std::{
fs,
path::{Path, PathBuf},
sync::mpsc,
};
use colored::Colorize;
use indexmap::IndexMap;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::{
core::OutputPrint,
helper::{
files,
sequence::SeqParser,
types::{DataType, Header, InputFmt, OutputFmt, SeqMatrix},
utils,
},
stats::sequence::Sites,
writer::sequences::SeqWriter,
};
/// Criterion used to decide which alignment sites are retained when trimming.
///
/// The variant is dispatched in `AlignmentTrimming::trim_sites`.
pub enum TrimmingParameters {
/// Filter on missing data. The value is forwarded unchanged to
/// `Sites::get_site_without_missing_data`.
// NOTE(review): presumably a proportion in 0.0..=1.0 — confirm against `Sites`.
MissingData(f64),
/// Filter on parsimony informativeness. The value is forwarded unchanged to
/// `Sites::get_site_with_pars_informative`.
ParsInf(usize),
/// No criterion selected. `trim_sites` logs a warning and exits the process
/// with status 1 for this variant.
None,
}
/// Borrowed configuration for trimming a batch of alignment files.
///
/// Every field is a reference, so the struct is only valid for as long as the
/// caller's settings (`'a`) are alive.
pub struct AlignmentTrimming<'a> {
/// Alignment files to trim; each one is parsed and processed in parallel.
pub input_files: &'a [PathBuf],
/// Format passed to the sequence parser when reading `input_files`.
pub input_fmt: &'a InputFmt,
/// Data type passed to the sequence parser and to the informative-site search.
pub datatype: &'a DataType,
/// Root output directory. Trimmed alignments go in its `trimmed_alignments`
/// subdirectory and the CSV summary is written directly inside it.
pub output_dir: &'a Path,
/// Format used when writing the trimmed alignments.
pub output_fmt: &'a OutputFmt,
/// Site-filtering criterion to apply.
pub params: &'a TrimmingParameters,
}
impl Default for AlignmentTrimming<'_> {
    /// Builds a placeholder configuration: no input files, FASTA input and
    /// output, DNA data, an empty output path and `MissingData(0.1)`.
    fn default() -> Self {
        // All arguments are constant expressions, so the borrows are promoted
        // to `'static` and satisfy any requested lifetime.
        Self::new(
            &[],
            &InputFmt::Fasta,
            &DataType::Dna,
            Path::new(""),
            &OutputFmt::Fasta,
            &TrimmingParameters::MissingData(0.1),
        )
    }
}
// Empty impl: only the methods that `OutputPrint` already provides are used.
// NOTE(review): `print_output_fmt`, called from `print_output_info`, is presumably one of them — confirm.
impl OutputPrint for AlignmentTrimming<'_> {}
impl<'a> AlignmentTrimming<'a> {
pub fn new(
input_files: &'a [PathBuf],
input_fmt: &'a InputFmt,
datatype: &'a DataType,
output_dir: &'a Path,
output_fmt: &'a OutputFmt,
params: &'a TrimmingParameters,
) -> Self {
Self {
input_files,
input_fmt,
datatype,
output_dir,
output_fmt,
params,
}
}
/// Trims every input alignment, writes the CSV summary and reports the
/// output location.
///
/// Progress is shown on a spinner whose message follows the current phase:
/// first trimming, then summary writing.
///
/// # Panics
///
/// Panics if an output file or the summary cannot be written (see
/// `write_output` and `write_summary`).
pub fn trim(&self) {
    let spinner = utils::set_spinner();
    // Each message is set immediately before the phase it describes, so the
    // spinner shows what is actually running.
    spinner.set_message("Trimming alignment files...");
    let summary = self.trim_sites();
    spinner.set_message("Writing summary...");
    self.write_summary(&summary);
    spinner.finish_with_message("Alignment files trimmed successfully\n");
    self.print_output_info();
}
/// Runs the trimming strategy selected by `self.params` and returns one
/// summary per input file.
///
/// With `TrimmingParameters::None` there is nothing to do: a warning is
/// logged and the process exits with status 1.
fn trim_sites(&self) -> Vec<TrimmingSummary> {
    // Thresholds are plain numbers, so they are copied out of the borrowed enum.
    match *self.params {
        TrimmingParameters::None => {
            log::warn!("No trimming parameters provided. Skipping trimming...");
            std::process::exit(1)
        }
        TrimmingParameters::ParsInf(threshold) => self.par_trim_informative_sites(threshold),
        TrimmingParameters::MissingData(threshold) => self.par_trim_missing_data(threshold),
    }
}
/// Trims all input files in parallel, keeping only the sites selected by
/// the parsimony-informative filter.
///
/// A file for which no site is selected gets no output alignment and is
/// recorded with zero sites after trimming. Summaries are gathered through
/// a channel, so their order is not guaranteed to match `input_files`.
fn par_trim_informative_sites(&self, threshold: usize) -> Vec<TrimmingSummary> {
    let (sender, receiver) = mpsc::channel();
    self.input_files
        .par_iter()
        .for_each_with(sender, |sender, path| {
            let (matrix, header) = self.parse_alignment(path);
            let mut summary = TrimmingSummary::new(path);
            // Write the trimmed alignment when there is one; otherwise report 0 sites.
            let sites_after = self
                .remove_site_with_informative(&matrix, threshold)
                .map_or(0, |trimmed| self.write_output(&trimmed, path));
            summary.add_summary(header.nchar, sites_after);
            sender.send(summary).expect("Failed to send summary");
        });
    // Every sender clone is dropped once the parallel loop ends, which
    // terminates this iteration.
    receiver.into_iter().collect()
}
/// Trims all input files in parallel, keeping only the sites that pass the
/// missing-data filter.
///
/// A file for which no site passes the filter gets no output alignment and
/// is recorded with zero sites after trimming, matching
/// `par_trim_informative_sites`. Summaries are gathered through a channel,
/// so their order is not guaranteed to match `input_files`.
///
/// # Panics
///
/// Panics if an output file cannot be written or a summary cannot be sent.
fn par_trim_missing_data(&self, threshold: f64) -> Vec<TrimmingSummary> {
    let (tx, rx) = mpsc::channel();
    self.input_files.par_iter().for_each_with(tx, |tx, file| {
        let mut summary = TrimmingSummary::new(file);
        let (matrix, header) = self.parse_alignment(file);
        match self.remove_site_with_missing_data(&matrix, threshold) {
            Some(new_matrix) => {
                let nchar = self.write_output(&new_matrix, file);
                summary.add_summary(header.nchar, nchar);
            }
            // Every site was filtered out. Writing the original alignment
            // here would pass untrimmed data off as a trimmed result, so
            // write nothing and report that no site was retained.
            None => summary.add_summary(header.nchar, 0),
        }
        tx.send(summary).expect("Failed to send summary");
    });
    rx.iter().collect()
}
/// Parses `file` with the configured data type and input format, returning
/// the sequence matrix together with its header.
fn parse_alignment(&self, file: &Path) -> (SeqMatrix, Header) {
    let (matrix, header) =
        SeqParser::new(file, self.datatype).get_alignment(self.input_fmt);
    (matrix, header)
}
/// Builds a matrix restricted to the sites reported by
/// `Sites::get_site_with_pars_informative` for `threshold`.
///
/// Returns `None` when no site is reported.
fn remove_site_with_informative(
    &self,
    matrix: &SeqMatrix,
    threshold: usize,
) -> Option<SeqMatrix> {
    // Only the site index of each reported entry is needed.
    let retained: Vec<usize> = Sites::default()
        .get_site_with_pars_informative(matrix, self.datatype, threshold)
        .iter()
        .map(|(i, _)| *i)
        .collect();
    if retained.is_empty() {
        return None;
    }
    Some(self.generate_new_matrix(matrix, retained))
}
/// Builds a matrix restricted to the sites returned by
/// `Sites::get_site_without_missing_data` for `threshold`.
///
/// Returns `None` when no site is returned.
fn remove_site_with_missing_data(
    &self,
    matrix: &SeqMatrix,
    threshold: f64,
) -> Option<SeqMatrix> {
    let retained = Sites::default().get_site_without_missing_data(matrix, threshold);
    match retained.is_empty() {
        true => None,
        false => Some(self.generate_new_matrix(matrix, retained)),
    }
}
fn generate_new_matrix(&self, matrix: &SeqMatrix, site_pos: Vec<usize>) -> SeqMatrix {
let mut new_matrix: SeqMatrix = IndexMap::new();
matrix.iter().for_each(|(k, v)| {
let new_row = v
.bytes()
.enumerate()
.filter(|(i, _)| site_pos.contains(i))
.map(|(_, b)| b)
.collect::<Vec<u8>>();
new_matrix.insert(
k.clone(),
String::from_utf8(new_row).expect("Invalid UTF-8"),
);
});
new_matrix
}
/// Writes `matrix` to the `trimmed_alignments` subdirectory of the output
/// directory, deriving the file name from `file` and the output format.
///
/// Returns the `nchar` value of the header computed from `matrix`.
///
/// # Panics
///
/// Panics if the sequence writer reports an error.
fn write_output(&self, matrix: &SeqMatrix, file: &Path) -> usize {
    let alignment_dir = self.output_dir.join("trimmed_alignments");
    let destination = files::create_output_fname(&alignment_dir, file, self.output_fmt);
    // The header is rebuilt from the trimmed matrix rather than reused from
    // the parsed input.
    let mut header = Header::new();
    header.from_seq_matrix(matrix, true);
    SeqWriter::new(&destination, matrix, &header)
        .write_sequence(self.output_fmt)
        .expect("Failed to write output file");
    header.nchar
}
/// Writes one CSV row per trimming summary to `trimming_summary.csv` inside
/// the output directory, creating the directory first if needed.
///
/// # Panics
///
/// Panics if the directory or file cannot be created, or if a record cannot
/// be serialized or flushed to disk.
fn write_summary(&self, summary: &[TrimmingSummary]) {
    let output_path = self
        .output_dir
        .join("trimming_summary")
        .with_extension("csv");
    fs::create_dir_all(self.output_dir).expect("Failed to create output directory");
    let mut writer = csv::Writer::from_path(output_path).expect("Failed to create CSV writer");
    for record in summary {
        writer.serialize(record).expect("Failed to write summary");
    }
    // Dropping the writer flushes too, but discards any error; flush
    // explicitly so a truncated summary is reported.
    writer.flush().expect("Failed to flush summary");
}
fn print_output_info(&self) {
log::info!("{}", "Output".yellow());
log::info!("{:18}: {}", "Directory", self.output_dir.display());
self.print_output_fmt(self.output_fmt);
}
}
/// Per-file trimming result; each instance is serialized as one row of the
/// CSV summary.
#[derive(Debug, Serialize, Default, Deserialize)]
struct TrimmingSummary {
/// Parent directory of the input file.
parent_path: PathBuf,
/// File name of the input file, without its directory.
file_name: String,
/// Number of sites before trimming.
site_count_before: usize,
/// Number of sites after trimming.
site_count_after: usize,
/// Difference between the two counts above.
site_removed: usize,
}
impl TrimmingSummary {
    /// Creates a summary for the alignment at `path` with all counts at zero.
    ///
    /// # Panics
    ///
    /// Panics if `path` has no parent or no file name, or if the file name
    /// is not valid UTF-8.
    fn new(path: &Path) -> Self {
        let parent_path = path
            .parent()
            .expect("Failed to get parent path")
            .to_path_buf();
        let file_name = path
            .file_name()
            .expect("Failed to get file name")
            .to_str()
            .expect("Failed to convert to string")
            .to_owned();
        Self {
            parent_path,
            file_name,
            site_count_before: 0,
            site_count_after: 0,
            site_removed: 0,
        }
    }

    /// Records the site counts before and after trimming and derives the
    /// number of removed sites.
    ///
    /// # Panics
    ///
    /// Panics if `after` is greater than `before`; trimming can never add
    /// sites.
    fn add_summary(&mut self, before: usize, after: usize) {
        assert!(
            before >= after,
            "Number of sites after trimming is greater than before"
        );
        self.site_count_before = before;
        self.site_count_after = after;
        // Equal counts naturally give zero, so no special case is needed.
        self.site_removed = before - after;
    }
}
#[cfg(test)]
mod tests {
    use tempdir::TempDir;

    use super::*;

    // Builds an `AlignmentTrimming` for DNA input with auto-detected format
    // and FASTA output, writing into the given temporary directory.
    macro_rules! init_trimming {
        ($input:expr_2021, $output:expr_2021, $params:expr_2021) => {
            AlignmentTrimming::new(
                $input,
                &InputFmt::Auto,
                &DataType::Dna,
                &$output.path(),
                &OutputFmt::Fasta,
                $params,
            )
        };
    }

    const INPUT_PATH: &str = "tests/files/trimming.fas";

    /// Checks the summary produced by missing-data trimming against the site
    /// statistics computed directly from the fixture alignment.
    #[test]
    fn test_trim_missing_data() {
        let input_files = vec![PathBuf::from(INPUT_PATH)];
        let output_dir = TempDir::new("test").expect("Failed to create temp dir");
        let params = TrimmingParameters::MissingData(0.4);
        let align_trim = init_trimming!(&input_files, output_dir, &params);
        let summary = align_trim.trim_sites();
        let (matrix, header) = align_trim.parse_alignment(&input_files[0]);
        let site = Sites::default();
        let site_missing = site.get_site_without_missing_data(&matrix, 0.6);
        let index_site = site.index_site_with_missing_data(&matrix);
        assert_eq!(index_site.len(), 8);
        assert_eq!(index_site.get(&0).unwrap().len(), 4);
        assert_eq!(summary.len(), 1);
        assert_eq!(summary[0].site_count_before, 8);
        assert_eq!(summary[0].site_count_after, 7);
        assert_eq!(matrix.len(), 4);
        assert_eq!(header.nchar, 8);
        assert_eq!(site_missing.len(), 7);
    }

    /// Checks that a full run leaves exactly two entries in the output
    /// directory: the trimmed-alignment folder and the CSV summary.
    #[test]
    fn test_trimming_results() {
        let input_files = vec![PathBuf::from(INPUT_PATH)];
        let output_dir = TempDir::new("test").expect("Failed to create temp dir");
        let params = TrimmingParameters::MissingData(0.4);
        let align_trim = init_trimming!(&input_files, output_dir, &params);
        align_trim.trim();
        let output_files = output_dir.path().read_dir().unwrap();
        assert_eq!(output_files.count(), 2);
    }
}