use needletail::{FastxReader, parse_fastx_file, parse_fastx_stdin};
use sassy::CachedRev;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
/// Default upper bound on total text bytes per batch (1 MiB).
const DEFAULT_BATCH_BYTES: usize = 1024 * 1024;
/// Default number of patterns handed out per batch.
const DEFAULT_BATCH_PATTERNS: usize = 64;
/// Record identifier, taken from the FASTA/FASTQ header line.
pub type ID = String;
/// A query pattern: an id plus its raw sequence bytes.
#[derive(Clone, Debug)]
pub struct PatternRecord {
    /// Identifier of the pattern (e.g. its FASTA header id).
    pub id: ID,
    /// Raw pattern sequence bytes.
    pub seq: Vec<u8>,
}
/// A text (subject) record read from an input file.
#[derive(Debug)]
pub struct TextRecord {
    /// Identifier of the record (FASTA/FASTQ header id).
    pub id: ID,
    /// Sequence bytes wrapped in `CachedRev`, which can additionally cache a
    /// reversed form (presumably the reverse/reverse-complement used for
    /// reverse-strand search — confirm against `sassy::CachedRev`).
    pub seq: CachedRev<Vec<u8>>,
    /// Per-record quality bytes; empty when the reader provides no quality
    /// line (e.g. FASTA input).
    pub quality: Vec<u8>,
}
/// A batch of text records, shared between worker threads via `Arc`.
pub type TextBatch = Arc<Vec<TextRecord>>;
/// One unit of work: the source file path, a chunk of patterns, and a shared
/// batch of text records to match them against.
pub type TaskBatch<'a> = (&'a Path, &'a [PatternRecord], TextBatch);
/// Mutable reader state, kept behind the `InputIterator` mutex.
struct RecordState {
    /// Reader for the file currently being consumed.
    reader: Box<dyn FastxReader + Send>,
    /// Most recently read batch of text records, shared with workers.
    text_batch: Arc<Vec<TextRecord>>,
    /// Index into the path list of the file currently being read.
    file_idx: usize,
    /// Index of the next pattern chunk to pair with `text_batch`.
    pat_idx: usize,
    /// Monotonically increasing id assigned on each `next_batch` call.
    batch_id: usize,
}
/// Thread-safe producer of `(path, pattern-chunk, text-batch)` work items.
///
/// Text records are read lazily from `paths` into byte-bounded batches; each
/// text batch is paired with every chunk of up to `batch_pattern_limit`
/// patterns before the next text batch is read (see `next_batch`).
pub struct InputIterator<'a> {
    /// All query patterns; handed out in chunks of `batch_pattern_limit`.
    patterns: &'a [PatternRecord],
    /// Input files to read, in order.
    paths: &'a Vec<PathBuf>,
    /// Shared mutable cursor state; the mutex makes `next_batch` thread-safe.
    state: Mutex<RecordState>,
    /// Stop growing a text batch once it reaches this many bytes.
    batch_byte_limit: usize,
    /// Maximum number of patterns per work item.
    batch_pattern_limit: usize,
    /// When set, precompute each record's reversed text at batch creation.
    rev: bool,
}
/// Opens a FASTX reader for `path`; an empty path or `"-"` selects stdin.
///
/// # Panics
/// Panics if the file (or stdin) cannot be opened or is not valid
/// FASTA/FASTQ, with the offending path included in the message.
fn parse_file(path: &Path) -> Box<dyn FastxReader> {
    if path == Path::new("") || path == Path::new("-") {
        parse_fastx_stdin().expect("failed to read FASTX data from stdin")
    } else {
        parse_fastx_file(path)
            .unwrap_or_else(|e| panic!("failed to open {}: {e}", path.display()))
    }
}
impl<'a> InputIterator<'a> {
    /// Builds an iterator over `paths` whose batches pair chunks of
    /// `patterns` with byte-bounded batches of text records.
    ///
    /// `max_batch_bytes` and `max_batch_patterns` fall back to
    /// `DEFAULT_BATCH_BYTES` / `DEFAULT_BATCH_PATTERNS` when `None`.
    /// With `rev` set, each text record's reversed form is precomputed when
    /// its batch is created.
    ///
    /// # Panics
    /// Panics if `paths` is empty, its first entry cannot be opened, or
    /// `max_batch_patterns` is `Some(0)` (division by zero).
    pub fn new(
        paths: &'a Vec<PathBuf>,
        patterns: &'a [PatternRecord],
        max_batch_bytes: Option<usize>,
        max_batch_patterns: Option<usize>,
        rev: bool,
    ) -> Self {
        let reader = parse_file(&paths[0]);
        let batch_pattern_limit = max_batch_patterns.unwrap_or(DEFAULT_BATCH_PATTERNS);
        let state = RecordState {
            reader,
            text_batch: Arc::new(Vec::new()),
            file_idx: 0,
            // Sentinel: chosen so that pat_idx * batch_pattern_limit is
            // always >= patterns.len(), i.e. "all pattern chunks used up".
            // This forces the first `next_batch` call to read a text batch.
            pat_idx: patterns.len() / batch_pattern_limit + 2,
            batch_id: 0,
        };
        Self {
            patterns,
            paths,
            state: Mutex::new(state),
            batch_byte_limit: max_batch_bytes.unwrap_or(DEFAULT_BATCH_BYTES),
            batch_pattern_limit,
            rev,
        }
    }

    /// Returns the next `(batch_id, (path, patterns, texts))` work item, or
    /// `None` once all files are exhausted. Safe to call from many threads.
    ///
    /// Each text batch is paired with successive pattern chunks; a new text
    /// batch (up to `batch_byte_limit` bytes, never spanning two files) is
    /// read only after every pattern chunk was handed out for the current one.
    ///
    /// # Panics
    /// Panics on malformed records, non-UTF-8 record ids, unreadable files,
    /// or a poisoned internal mutex.
    pub fn next_batch(&self) -> Option<(usize, TaskBatch<'a>)> {
        let mut state = self.state.lock().unwrap();
        let batch_id = state.batch_id;
        state.batch_id += 1;
        // All pattern chunks for the current text batch were handed out:
        // read the next text batch.
        if state.pat_idx * self.batch_pattern_limit >= self.patterns.len() {
            if state.file_idx >= self.paths.len() {
                log::debug!("No more files to read, returning None");
                return None;
            }
            let mut text_batch = Vec::new();
            let mut bytes_in_batch = 0usize;
            'outer: loop {
                // Pull the next record, advancing to the next file on EOF.
                let current_record = loop {
                    match state.reader.next() {
                        Some(Ok(rec)) => {
                            let id = String::from_utf8(rec.id().to_vec())
                                .expect("record ID is not valid UTF-8");
                            let seq = rec.seq().into_owned();
                            let static_text = CachedRev::new(seq, false);
                            break TextRecord {
                                id,
                                seq: static_text,
                                // FASTA records carry no quality line.
                                quality: rec.qual().unwrap_or(&[]).to_vec(),
                            };
                        }
                        Some(Err(e)) => panic!("Error reading FASTA record: {e}"),
                        None => {
                            // End of the current file: flush a partial batch
                            // so a batch never mixes records from two files.
                            if !text_batch.is_empty() {
                                break 'outer;
                            }
                            state.file_idx += 1;
                            let end_of_files = state.file_idx >= self.paths.len();
                            if end_of_files {
                                log::debug!("No more files to read, returning None");
                                return None;
                            }
                            state.reader = parse_file(&self.paths[state.file_idx]);
                            continue;
                        }
                    }
                };
                let record_len = current_record.seq.text.len();
                bytes_in_batch += record_len;
                log::trace!(
                    "Push record of len {record_len:>5} total len {bytes_in_batch:>8} limit {}",
                    self.batch_byte_limit
                );
                text_batch.push(current_record);
                if bytes_in_batch >= self.batch_byte_limit {
                    log::debug!("New batch of {} kB", bytes_in_batch / 1024);
                    break;
                }
            }
            if self.rev {
                // Precompute reversed texts once per batch, instead of lazily
                // (and possibly concurrently) in each worker.
                for text_record in &mut text_batch {
                    text_record.seq.initialize_rev();
                }
            }
            state.text_batch = Arc::new(text_batch);
            state.pat_idx = 0;
        }
        // Hand out the next chunk of patterns with the current text batch.
        let start = state.pat_idx * self.batch_pattern_limit;
        let end = (start + self.batch_pattern_limit).min(self.patterns.len());
        state.pat_idx += 1;
        log::debug!(
            "Batch {batch_id:>3}: {} seqs {} patterns",
            state.text_batch.len(),
            end - start,
        );
        Some((
            batch_id,
            (
                &self.paths[state.file_idx],
                &self.patterns[start..end],
                state.text_batch.clone(),
            ),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Returns a random DNA sequence of `len` bases drawn uniformly from ACGT.
    // Renamed from `random_dan_seq` (typo).
    fn random_dna_seq(len: usize) -> Vec<u8> {
        let mut rng = rand::rng();
        let bases = b"ACGT";
        (0..len)
            .map(|_| bases[rng.random_range(0..bases.len())])
            .collect()
    }

    /// Smoke test: iterate a random FASTA file in small byte-bounded batches
    /// and print per-batch statistics (no assertions; inspect output).
    #[test]
    fn test_record_iterator() {
        let mut rng = rand::rng();

        // Write 100 random records to a temporary FASTA file.
        let mut file = NamedTempFile::new().unwrap();
        for i in 0..100 {
            let seq = random_dna_seq(rng.random_range(100..1000));
            file.write_all(
                format!(">seq_{}\n{}\n", i, String::from_utf8(seq).unwrap()).as_bytes(),
            )
            .unwrap();
        }
        file.flush().unwrap();

        // Ten random patterns, long enough to overlap several records.
        let patterns: Vec<PatternRecord> = (0..10)
            .map(|i| PatternRecord {
                id: format!("pattern_{}", i),
                seq: random_dna_seq(rng.random_range(250..1000)),
            })
            .collect();

        let paths = vec![file.path().to_path_buf()];
        // Tiny 500-byte batch limit to exercise the batching logic.
        let iter = InputIterator::new(&paths, &patterns, Some(500), None, true);
        let mut batch_id = 0;
        while let Some((_, (_, batch_patterns, texts))) = iter.next_batch() {
            batch_id += 1;
            let unique_texts = texts
                .iter()
                .map(|item| item.seq.text.clone())
                .collect::<std::collections::HashSet<_>>();
            let text_len = unique_texts.iter().map(|text| text.len()).sum::<usize>();
            let n_patterns = batch_patterns
                .iter()
                .map(|item| item.id.clone())
                .collect::<std::collections::HashSet<_>>()
                .len();
            let n_texts = unique_texts.len();
            println!(
                "Batch {batch_id} (tot_size: {text_len}, n_texts: {n_texts}): {n_patterns} patterns"
            );
        }
        drop(file);
    }
}