use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, Cursor, Read};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::error::ChainError;
#[cfg(feature = "gzip")]
use flate2::read::MultiGzDecoder;
use twobit::{TwoBitFile, TwoBitPhysicalFile};
/// First two bytes of a gzip stream (`0x1f 0x8b`); used to sniff compressed standard input.
const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// The value `0x1A412743` laid out in little-endian byte order; standard input starting
/// with these bytes is handed to the 2bit reader.
const TWOBIT_MAGIC: [u8; 4] = [0x43, 0x27, 0x41, 0x1a];
/// `TWOBIT_MAGIC` with its bytes reversed, i.e. the same value in big-endian byte order.
const TWOBIT_REV_MAGIC: [u8; 4] = [0x1a, 0x41, 0x27, 0x43];
/// In-memory sequences keyed by sequence name; names and sequence data are both raw bytes.
pub type SequenceMap = HashMap<Vec<u8>, Vec<u8>>;
/// Backing store a `SequenceResolver` reads sequence data from.
#[derive(Debug, Clone)]
enum SequenceSource {
/// A 2bit file identified by its path; ranges are read on demand through a `SequenceCache`.
TwoBit(PathBuf),
/// Sequences that were loaded into memory when the resolver was created.
Loaded {
/// Path the sequences were loaded from; only used to build error messages.
path: PathBuf,
/// The loaded sequences, reference-counted so that cloning the resolver does not copy them.
sequences: Arc<SequenceMap>,
},
}
/// On-disk sequence formats that `detect_sequence_format` can recognise from a file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SequenceFormat {
/// A `.2bit` file.
TwoBit,
/// A FASTA file (`.fa`, `.fasta`, `.fna`), optionally with a trailing `.gz`.
Fasta,
}
/// Resolves `(sequence name, start, length)` requests against a single sequence file.
///
/// Construct it with `SequenceResolver::new` and read ranges with `SequenceResolver::fetch`.
#[derive(Debug, Clone)]
pub struct SequenceResolver {
/// Where the sequence data comes from (2bit file on disk or sequences held in memory).
source: SequenceSource,
}
impl SequenceResolver {
    /// Creates a resolver for the sequence file at `path`.
    ///
    /// The format is chosen from the file extension. A 2bit file is only recorded by
    /// path here and is opened later, on the first `fetch`. A FASTA file is loaded into
    /// memory immediately through `get_sequences`.
    ///
    /// # Errors
    ///
    /// Returns an error when the extension is not a supported sequence format, or when
    /// loading the FASTA file fails.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, ChainError> {
        let path = path.as_ref().to_path_buf();
        let Some(format) = detect_sequence_format(&path) else {
            return Err(sequence_error(format!(
                "unsupported sequence format for {} (expected .2bit, .fa, .fasta, .fna, or gzipped FASTA)",
                path.display()
            )));
        };
        let source = match format {
            SequenceFormat::TwoBit => SequenceSource::TwoBit(path),
            SequenceFormat::Fasta => {
                // Load first so the owned path can be moved into the variant afterwards.
                let sequences = Arc::new(get_sequences(&path)?);
                SequenceSource::Loaded { path, sequences }
            }
        };
        Ok(Self { source })
    }

    /// Returns `length` bytes of sequence `seq_name`, beginning at offset `start`.
    ///
    /// `cache` holds the opened 2bit readers and is only consulted when this resolver is
    /// backed by a 2bit file; in-memory sequences are sliced directly.
    ///
    /// # Errors
    ///
    /// Returns an error when the sequence is unknown, the requested range does not fit
    /// inside it, or the underlying 2bit file cannot be opened or read.
    pub fn fetch(
        &self,
        cache: &mut SequenceCache,
        seq_name: &[u8],
        start: u32,
        length: u32,
    ) -> Result<Vec<u8>, ChainError> {
        match &self.source {
            SequenceSource::Loaded { path, sequences } => {
                fetch_loaded_sequence(path, sequences, seq_name, start, length)
            }
            SequenceSource::TwoBit(twobit_path) => {
                cache.fetch_twobit(twobit_path, seq_name, start, length)
            }
        }
    }
}
/// Cache of opened 2bit readers, keyed by the path of the file they were opened from.
///
/// Starts out empty (`Default`) and is filled by `fetch_twobit` the first time a path is used.
#[derive(Default)]
pub struct SequenceCache {
/// One opened reader per 2bit file path.
files: HashMap<PathBuf, TwoBitPhysicalFile>,
}
impl SequenceCache {
    /// Reads the range `start..start + length` of `seq_name` from the 2bit file at `path`.
    ///
    /// The reader for `path` is opened (with soft-masking enabled) the first time the
    /// path is seen and reused on every later call.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// - `seq_name` is not valid UTF-8,
    /// - `start + length` overflows `u32`,
    /// - the 2bit file cannot be opened or the range cannot be read,
    /// - the reader returns a different number of bytes than `length`, which is
    ///   reported as the range exceeding the file.
    fn fetch_twobit(
        &mut self,
        path: &Path,
        seq_name: &[u8],
        start: u32,
        length: u32,
    ) -> Result<Vec<u8>, ChainError> {
        let seq_name = bytes_to_utf8(seq_name, "2bit sequence name")?;
        let end = start
            .checked_add(length)
            .ok_or_else(|| sequence_error("requested 2bit range overflows u32"))?;
        // Look the reader up by borrowed `&Path` first: a cache hit is the common case and
        // must not pay for allocating an owned `PathBuf` key, which the entry API needs.
        let reader = match self.files.get_mut(path) {
            Some(reader) => reader,
            None => {
                let file = TwoBitFile::open(path)
                    .map_err(|err| {
                        sequence_error(format!("cannot open 2bit {}: {err}", path.display()))
                    })?
                    .enable_softmask(true);
                self.files.entry(path.to_path_buf()).or_insert(file)
            }
        };
        let sequence = reader
            .read_sequence(seq_name, start as usize..end as usize)
            .map_err(|err| {
                sequence_error(format!(
                    "cannot read {seq_name}:{start}-{end} from {}: {err}",
                    path.display()
                ))
            })?;
        // A result shorter than requested means the range ran past the end of the sequence.
        if sequence.len() != length as usize {
            return Err(sequence_error(format!(
                "sequence range {seq_name}:{start}-{end} exceeds {}",
                path.display()
            )));
        }
        Ok(sequence.into_bytes())
    }
}
/// Loads every sequence from `sequence` into memory.
///
/// The path `-` means standard input, whose format is sniffed from its content. Any
/// other path is dispatched on its file extension to the 2bit or FASTA loader.
///
/// # Errors
///
/// Returns an error when the format cannot be determined from the extension, or when
/// the selected loader fails.
pub fn get_sequences<P: AsRef<Path>>(sequence: P) -> Result<SequenceMap, ChainError> {
    let path = sequence.as_ref();
    if path == Path::new("-") {
        return from_stdin();
    }
    let Some(format) = detect_sequence_format(path) else {
        return Err(sequence_error(format!(
            "cannot determine supported sequence format for {}",
            path.display()
        )));
    };
    match format {
        SequenceFormat::TwoBit => from_2bit(path),
        SequenceFormat::Fasta => from_fa(path),
    }
}
/// Reads all of standard input and parses it as gzipped FASTA, 2bit, or plain FASTA.
///
/// The format is sniffed from the content, in this order: gzip magic bytes, 2bit magic
/// bytes (either byte order), then a `>` as the first non-whitespace byte.
///
/// # Errors
///
/// Returns an error when standard input cannot be read, is empty, is gzip-compressed
/// while the `gzip` feature is disabled, matches none of the known formats, or fails to
/// parse.
fn from_stdin() -> Result<SequenceMap, ChainError> {
    let mut bytes = Vec::new();
    std::io::stdin().read_to_end(&mut bytes)?;
    if bytes.is_empty() {
        return Err(sequence_error(
            "missing sequence input and standard input is empty",
        ));
    }

    if bytes.starts_with(&GZIP_MAGIC) {
        #[cfg(feature = "gzip")]
        {
            let decoder = MultiGzDecoder::new(Cursor::new(bytes));
            return parse_fasta_reader(BufReader::new(decoder), "stdin");
        }
        #[cfg(not(feature = "gzip"))]
        {
            return Err(sequence_error(
                "gzip-compressed sequence input requires the `gzip` feature",
            ));
        }
    }

    let is_twobit = bytes.starts_with(&TWOBIT_MAGIC) || bytes.starts_with(&TWOBIT_REV_MAGIC);
    if is_twobit {
        return from_2bit_buf(bytes, "stdin");
    }

    // FASTA is recognised by `>` being the first byte that is not ASCII whitespace.
    let starts_with_header = matches!(
        bytes.iter().find(|byte| !byte.is_ascii_whitespace()),
        Some(&b'>')
    );
    if starts_with_header {
        return parse_fasta_reader(BufReader::new(Cursor::new(bytes)), "stdin");
    }

    Err(sequence_error("unsupported standard input sequence format"))
}
/// Loads every sequence of the 2bit file at `path` into memory, with soft-masking enabled.
///
/// # Errors
///
/// Returns an error when the file cannot be opened as 2bit or when any of its sequences
/// cannot be read.
pub fn from_2bit<P: AsRef<Path>>(path: P) -> Result<SequenceMap, ChainError> {
    let path = path.as_ref();
    let genome = match TwoBitFile::open(path) {
        Ok(file) => file.enable_softmask(true),
        Err(err) => {
            return Err(sequence_error(format!(
                "cannot open 2bit {}: {err}",
                path.display()
            )))
        }
    };
    // Label used by the reader in its error messages.
    let label = format!("file {}", path.display());
    collect_2bit_sequences(genome, &label)
}
/// Loads every sequence from an in-memory 2bit image, with soft-masking enabled.
///
/// `source` names the origin of `buf` and is only used in error messages.
///
/// # Errors
///
/// Returns an error when `buf` cannot be parsed as 2bit or when any of its sequences
/// cannot be read.
fn from_2bit_buf(buf: Vec<u8>, source: &str) -> Result<SequenceMap, ChainError> {
    let genome = match TwoBitFile::from_buf(buf) {
        Ok(file) => file.enable_softmask(true),
        Err(err) => {
            return Err(sequence_error(format!(
                "cannot read 2bit from {source}: {err}"
            )))
        }
    };
    collect_2bit_sequences(genome, source)
}
fn collect_2bit_sequences<R: Read + std::io::Seek>(
mut genome: TwoBitFile<R>,
source: &str,
) -> Result<SequenceMap, ChainError> {
let mut sequences = HashMap::new();
for chr in genome.chrom_names() {
let seq = genome
.read_sequence(&chr, ..)
.map_err(|err| sequence_error(format!("cannot read {chr} from {source}: {err}")))?;
sequences.insert(chr.into_bytes(), seq.into_bytes());
}
Ok(sequences)
}
/// Loads every record of the FASTA file at `path` into memory.
///
/// A file whose extension is `gz` (compared case-insensitively) is decompressed while it
/// is read; this needs the `gzip` feature. The decision is made from the extension only.
///
/// # Errors
///
/// Returns an error when the file cannot be opened, when it is gzip-compressed but the
/// `gzip` feature is disabled, or when its content is not valid FASTA.
pub fn from_fa<P: AsRef<Path>>(path: P) -> Result<SequenceMap, ChainError> {
    let path = path.as_ref();
    let file = File::open(path)?;
    let source = format!("file {}", path.display());
    let is_gzipped = matches!(
        path.extension().and_then(|ext| ext.to_str()),
        Some(ext) if ext.eq_ignore_ascii_case("gz")
    );
    if is_gzipped {
        #[cfg(feature = "gzip")]
        {
            let decoder = MultiGzDecoder::new(file);
            return parse_fasta_reader(BufReader::new(decoder), &source);
        }
        #[cfg(not(feature = "gzip"))]
        {
            return Err(sequence_error(
                "gzip-compressed FASTA requires the `gzip` feature",
            ));
        }
    }
    parse_fasta_reader(BufReader::new(file), &source)
}
/// Parses FASTA records from `reader` into a map from record name to sequence bytes.
///
/// - A line whose first non-whitespace byte is `>` starts a record. The record name is
///   the first whitespace-delimited token after `>`; the rest of the header is discarded.
/// - Every other non-blank line is sequence data for the current record. Its bytes are
///   appended unchanged (case is preserved), except that ASCII whitespace is dropped so
///   stray spaces or tabs cannot shift sequence coordinates.
/// - Empty and whitespace-only lines are skipped.
/// - When a record name occurs more than once, the last record replaces earlier ones.
///
/// `source` names the input and is only used in error messages.
///
/// # Errors
///
/// Returns an error when reading from `reader` fails, a header has an empty record name,
/// sequence data appears before the first header, or the input has no records at all.
fn parse_fasta_reader<R: BufRead>(mut reader: R, source: &str) -> Result<SequenceMap, ChainError> {
    let mut records = HashMap::new();
    let mut line = Vec::new();
    let mut header: Option<Vec<u8>> = None;
    let mut seq = Vec::new();
    loop {
        line.clear();
        if reader.read_until(b'\n', &mut line)? == 0 {
            break;
        }
        trim_line_endings(&mut line);
        // Ignore indentation. A line with no visible byte (empty or whitespace-only)
        // carries neither a header nor sequence data.
        let Some(first) = line.iter().position(|byte| !byte.is_ascii_whitespace()) else {
            continue;
        };
        let content = &line[first..];
        if content[0] == b'>' {
            let record_name = fasta_record_name(&content[1..]);
            if record_name.is_empty() {
                return Err(sequence_error(format!(
                    "invalid FASTA in {source}: empty record name"
                )));
            }
            // Starting a new record finishes the previous one, if any.
            if let Some(prev_header) = header.replace(record_name.to_vec()) {
                records.insert(prev_header, std::mem::take(&mut seq));
            }
        } else {
            if header.is_none() {
                return Err(sequence_error(format!(
                    "invalid FASTA in {source}: sequence data before header"
                )));
            }
            // Fast path: a clean line is copied as one slice. Only lines that actually
            // contain whitespace pay for the byte-wise filter.
            if content.iter().any(u8::is_ascii_whitespace) {
                seq.extend(
                    content
                        .iter()
                        .copied()
                        .filter(|byte| !byte.is_ascii_whitespace()),
                );
            } else {
                seq.extend_from_slice(content);
            }
        }
    }
    match header {
        Some(last_header) => {
            records.insert(last_header, seq);
            Ok(records)
        }
        None => Err(sequence_error(format!(
            "no FASTA records found in {source}"
        ))),
    }
}
/// Removes one trailing `\n` and then one trailing `\r` from `line`, when present.
///
/// This strips a single `\n`, `\r\n` or bare `\r` terminator; at most one of each byte
/// is removed, so earlier terminators inside the buffer are left untouched.
fn trim_line_endings(line: &mut Vec<u8>) {
    for terminator in [b'\n', b'\r'] {
        if line.last() == Some(&terminator) {
            line.pop();
        }
    }
}
/// Returns the record name of a FASTA header: its first whitespace-delimited token.
///
/// `header` is the header line without the leading `>`. Leading ASCII whitespace is
/// skipped and the name ends at the next ASCII whitespace byte or at the end of the
/// line. The result is empty when the header holds nothing but whitespace.
fn fasta_record_name(header: &[u8]) -> &[u8] {
    let name_start = header
        .iter()
        .position(|byte| !byte.is_ascii_whitespace())
        .unwrap_or(header.len());
    let rest = &header[name_start..];
    let name_len = rest
        .iter()
        .position(|byte| byte.is_ascii_whitespace())
        .unwrap_or(rest.len());
    &rest[..name_len]
}
/// Determines the sequence format of `path` from its file extension alone.
///
/// Extensions are compared case-insensitively:
/// - `2bit` gives [`SequenceFormat::TwoBit`];
/// - a FASTA extension gives [`SequenceFormat::Fasta`];
/// - `gz` gives [`SequenceFormat::Fasta`] only when the name without `.gz` ends in a
///   FASTA extension (for example `genome.fa.gz`).
///
/// Returns `None` for every other path, including one with no extension or with an
/// extension that is not valid UTF-8.
fn detect_sequence_format(path: &Path) -> Option<SequenceFormat> {
    let ext = path.extension()?.to_str()?;
    if ext.eq_ignore_ascii_case("2bit") {
        return Some(SequenceFormat::TwoBit);
    }
    if is_fasta_extension(ext) {
        return Some(SequenceFormat::Fasta);
    }
    if !ext.eq_ignore_ascii_case("gz") {
        return None;
    }
    // Look one extension further in: `name.fa.gz` has the stem `name.fa`.
    let inner_ext = Path::new(path.file_stem()?).extension()?.to_str()?;
    is_fasta_extension(inner_ext).then_some(SequenceFormat::Fasta)
}
/// Reports whether `ext` is `fa`, `fasta` or `fna`, compared case-insensitively.
///
/// `ext` is the bare extension, without a leading dot.
fn is_fasta_extension(ext: &str) -> bool {
    const FASTA_EXTENSIONS: [&str; 3] = ["fa", "fasta", "fna"];
    FASTA_EXTENSIONS
        .iter()
        .any(|known| ext.eq_ignore_ascii_case(known))
}
/// Copies the range `start..start + length` out of the in-memory sequence `seq_name`.
///
/// `path` is the file the sequences were loaded from and is only used in error messages.
///
/// # Errors
///
/// Returns an error when `seq_name` is not present in `sequences`, when
/// `start + length` overflows `usize`, or when the range extends past the end of the
/// sequence.
fn fetch_loaded_sequence(
    path: &Path,
    sequences: &SequenceMap,
    seq_name: &[u8],
    start: u32,
    length: u32,
) -> Result<Vec<u8>, ChainError> {
    let Some(sequence) = sequences.get(seq_name) else {
        return Err(sequence_error(format!(
            "sequence {} not found in {}",
            String::from_utf8_lossy(seq_name),
            path.display()
        )));
    };
    let begin = start as usize;
    let Some(end) = begin.checked_add(length as usize) else {
        return Err(sequence_error("requested sequence range overflows usize"));
    };
    // `begin <= end` always holds, so the checked slice fails exactly when the range
    // reaches past the end of the sequence.
    match sequence.get(begin..end) {
        Some(bases) => Ok(bases.to_vec()),
        None => Err(sequence_error(format!(
            "sequence range {}:{}-{} exceeds {}",
            String::from_utf8_lossy(seq_name),
            begin,
            end,
            path.display()
        ))),
    }
}
/// Borrows `value` as a `&str`, failing when it is not valid UTF-8.
///
/// `context` describes the value and becomes the start of the error message
/// (`"<context> must be valid UTF-8"`).
fn bytes_to_utf8<'a>(value: &'a [u8], context: &str) -> Result<&'a str, ChainError> {
    match std::str::from_utf8(value) {
        Ok(text) => Ok(text),
        Err(_) => Err(sequence_error(format!("{context} must be valid UTF-8"))),
    }
}
/// Wraps `message` in a `ChainError::Unsupported`, the error variant used for every
/// failure reported directly by the sequence-loading code in this file.
fn sequence_error(message: impl Into<String>) -> ChainError {
    let text: String = message.into();
    ChainError::Unsupported { msg: text.into() }
}