use crate::core::error::{RedicatError, Result};
use bio::io::fasta::{Index, IndexedReader};
use log::{debug, info, warn};
use parking_lot::RwLock;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
/// Thread-safe random-access reader over an indexed FASTA reference genome.
///
/// Single-base lookups are memoized in a shared `"chrom:pos" -> base` cache;
/// all file access goes through a mutex-guarded `IndexedReader`.
pub struct ReferenceGenome {
// Underlying indexed FASTA reader; `fetch`/`read` need `&mut`, so every
// file access is serialized through this mutex.
reader: parking_lot::Mutex<IndexedReader<std::fs::File>>,
// Sequence (chromosome) names parsed from the `.fai` index, in file order.
sequences: Vec<String>,
// Cache of "chrom:pos" -> uppercase base ('A'/'T'/'C'/'G'/'N').
cache: Arc<RwLock<HashMap<String, char>>>,
// Soft upper bound on cache entries; see `cache_base` for the eviction policy.
cache_size_limit: usize,
}
// NOTE(review): every field is Mutex/Arc<RwLock>-wrapped, so `Send`/`Sync`
// should be auto-derived provided `IndexedReader<File>` is itself `Send`.
// These unsafe impls are only sound under that assumption — confirm it, or
// delete them and let the compiler derive the markers.
unsafe impl Send for ReferenceGenome {}
unsafe impl Sync for ReferenceGenome {}
impl ReferenceGenome {
/// Opens `fasta_path` with the default single-base cache capacity.
///
/// # Errors
/// See [`Self::new_with_cache_size`].
pub fn new(fasta_path: &str) -> Result<Self> {
    const DEFAULT_CACHE_SIZE: usize = 100_000;
    Self::new_with_cache_size(fasta_path, DEFAULT_CACHE_SIZE)
}
pub fn new_with_cache_size(fasta_path: &str, cache_size: usize) -> Result<Self> {
let index_path = format!("{}.fai", fasta_path);
if !std::path::Path::new(fasta_path).exists() {
return Err(RedicatError::FileNotFound(format!(
"FASTA file not found: {}",
fasta_path
)));
}
if !std::path::Path::new(&index_path).exists() {
return Err(RedicatError::FileNotFound(format!(
"FASTA index file not found: {}. Please create it using: samtools faidx {}",
index_path, fasta_path
)));
}
let index = Index::from_file(&index_path).map_err(|e| {
RedicatError::ReferenceGenome(format!("Failed to load FASTA index: {:?}", e))
})?;
let reader = IndexedReader::with_index(
std::fs::File::open(fasta_path).map_err(|e| {
RedicatError::FileNotFound(format!("Cannot open FASTA file: {}", e))
})?,
index,
);
let sequences = Self::read_sequence_names(&index_path)?;
info!(
"Loaded reference genome with {} sequences, cache size: {}",
sequences.len(),
cache_size
);
Ok(ReferenceGenome {
reader: parking_lot::Mutex::new(reader),
sequences,
cache: Arc::new(RwLock::new(HashMap::new())),
cache_size_limit: cache_size,
})
}
/// Returns the uppercase reference base at `genomic_pos` (`"chrom:pos"`,
/// 1-based).
///
/// Invalid input — bad format, non-positive or unparseable position,
/// unknown chromosome — and failed file reads all yield `Ok('N')` rather
/// than an error. Successful lookups are memoized in the shared cache.
pub fn get_ref_of_pos(&self, genomic_pos: &str) -> Result<char> {
    // Fast path: serve from the cache under a read lock.
    if let Some(&hit) = self.cache.read().get(genomic_pos) {
        return Ok(hit);
    }

    let parts: Vec<&str> = genomic_pos.split(':').collect();
    let (chrom, pos_str) = match parts.as_slice() {
        [chrom, pos_str] => (*chrom, *pos_str),
        _ => {
            debug!("Invalid genomic position format: {}", genomic_pos);
            return Ok('N');
        }
    };

    let pos = match pos_str.parse::<u64>() {
        Ok(p) if p > 0 => p,
        _ => {
            debug!("Invalid position in {}: {}", genomic_pos, pos_str);
            return Ok('N');
        }
    };

    if !self.sequences.iter().any(|seq| seq == chrom) {
        warn!("Chromosome {} not found in reference genome", chrom);
        return Ok('N');
    }

    let base = self.fetch_base_from_file(chrom, pos)?;
    self.cache_base(genomic_pos.to_string(), base);
    Ok(base)
}
/// Reads a single base straight from the FASTA file, bypassing the cache.
///
/// `pos` is 1-based; callers guarantee `pos > 0`, so `pos - 1` cannot
/// underflow. Any fetch/read failure or non-ACGT byte is reported as `'N'`.
fn fetch_base_from_file(&self, chrom: &str, pos: u64) -> Result<char> {
    let mut buf = Vec::new();
    {
        let mut reader = self.reader.lock();
        // bio's fetch takes a 0-based, half-open interval.
        if let Err(e) = reader.fetch(chrom, pos - 1, pos) {
            debug!("Failed to fetch position {}:{}: {:?}", chrom, pos, e);
            return Ok('N');
        }
        if let Err(e) = reader.read(&mut buf) {
            debug!("Failed to read sequence at {}:{}: {:?}", chrom, pos, e);
            return Ok('N');
        }
    }

    let raw = match buf.first() {
        Some(&b) => b,
        None => {
            debug!("Empty sequence at {}:{}", chrom, pos);
            return Ok('N');
        }
    };

    let base = (raw as char).to_ascii_uppercase();
    if matches!(base, 'A' | 'T' | 'C' | 'G') {
        Ok(base)
    } else {
        debug!("Non-standard base at {}:{}: {}", chrom, pos, base);
        Ok('N')
    }
}
/// Inserts `position -> base` into the cache, evicting a batch of old
/// entries first when the cache is at capacity.
///
/// Eviction removes roughly 10% of `cache_size_limit` entries in HashMap
/// iteration order (effectively random) — this is a cheap memory bound,
/// not an LRU policy.
fn cache_base(&self, position: String, base: char) {
    let mut cache = self.cache.write();
    if cache.len() >= self.cache_size_limit {
        // Bug fix: for limits < 10 the old `limit / 10` computed zero,
        // evicting nothing and letting the cache grow without bound past
        // its limit. Always evict at least one entry.
        let to_remove = (self.cache_size_limit / 10).max(1);
        let keys_to_remove: Vec<String> = cache.keys().take(to_remove).cloned().collect();
        for key in keys_to_remove {
            cache.remove(&key);
        }
    }
    cache.insert(position, base);
}
/// Resolves many `"chrom:pos"` lookups with one FASTA fetch per chromosome.
///
/// Positions are grouped by chromosome and the spanning region
/// `[min_pos, max_pos]` is fetched once; each requested base is then sliced
/// out of that buffer and cached. Invalid entries (bad format, `pos == 0`,
/// unknown chromosome) keep the default `'N'`. If the spanning fetch fails
/// (e.g. one position lies past the chromosome end), the group falls back
/// to per-position lookups.
///
/// NOTE: widely separated positions on one chromosome force the spanning
/// fetch to read the whole interval between them; callers should batch
/// positions that are reasonably close together.
pub fn get_multiple_refs_batched(&self, positions: &[String]) -> Result<Vec<char>> {
    if positions.is_empty() {
        return Ok(Vec::new());
    }

    let mut chrom_groups: HashMap<String, Vec<(usize, u64)>> = HashMap::new();
    let mut results = vec!['N'; positions.len()];

    // Group (output index, 1-based position) pairs by chromosome, silently
    // skipping malformed entries so their slots stay 'N'.
    for (idx, pos_str) in positions.iter().enumerate() {
        let parts: Vec<&str> = pos_str.split(':').collect();
        if parts.len() != 2 {
            continue;
        }
        let chrom = parts[0].to_string();
        if let Ok(p) = parts[1].parse::<u64>() {
            if p > 0 {
                chrom_groups.entry(chrom).or_default().push((idx, p));
            }
        }
    }

    for (chrom, group) in &chrom_groups {
        if !self.sequences.iter().any(|s| s == chrom) {
            continue;
        }
        // `group` is non-empty by construction, so min/max exist.
        let min_pos = group.iter().map(|&(_, p)| p).min().unwrap();
        let max_pos = group.iter().map(|&(_, p)| p).max().unwrap();

        let mut seq = Vec::new();
        {
            let mut reader = self.reader.lock();
            // fetch takes a 0-based, half-open interval.
            if reader.fetch(chrom, min_pos - 1, max_pos).is_ok() {
                let _ = reader.read(&mut seq);
            }
        }

        if seq.is_empty() {
            // Spanning fetch failed; retry each position individually.
            for &(result_idx, pos) in group {
                results[result_idx] = self.fetch_base_from_file(chrom, pos).unwrap_or('N');
            }
            continue;
        }

        // Bug fix: this loop used to be duplicated — one pass filled
        // `results` and a second recomputed every base just to cache it.
        // A single pass now does both.
        for &(result_idx, pos) in group {
            let offset = (pos - min_pos) as usize;
            if offset < seq.len() {
                let base_char = (seq[offset] as char).to_ascii_uppercase();
                let base = match base_char {
                    'A' | 'T' | 'C' | 'G' => base_char,
                    _ => 'N',
                };
                results[result_idx] = base;
                self.cache_base(format!("{}:{}", chrom, pos), base);
            }
        }
    }

    Ok(results)
}
/// Resolves a batch of `"chrom:pos"` lookups, switching to rayon for
/// batches of 100 or more; output order matches the input.
///
/// Individual lookup errors are mapped to `'N'`.
pub fn get_multiple_refs_parallel(&self, positions: &[String]) -> Result<Vec<char>> {
    const PARALLEL_THRESHOLD: usize = 100;

    if positions.is_empty() {
        return Ok(Vec::new());
    }

    let lookup = |pos: &String| self.get_ref_of_pos(pos).unwrap_or('N');
    let results: Vec<char> = if positions.len() < PARALLEL_THRESHOLD {
        // Small batch: thread-pool overhead outweighs the win.
        positions.iter().map(lookup).collect()
    } else {
        positions.par_iter().map(lookup).collect()
    };
    Ok(results)
}
/// Like [`Self::get_multiple_refs_parallel`], but processes the input in
/// `chunk_size`-sized slices to bound peak parallel work per step.
///
/// `chunk_size` must be non-zero (`slice::chunks` panics on 0).
pub fn get_multiple_refs_chunked(
    &self,
    positions: &[String],
    chunk_size: usize,
) -> Result<Vec<char>> {
    if positions.is_empty() {
        return Ok(Vec::new());
    }
    let mut all_bases = Vec::with_capacity(positions.len());
    for chunk in positions.chunks(chunk_size) {
        all_bases.extend(self.get_multiple_refs_parallel(chunk)?);
    }
    Ok(all_bases)
}
/// Names of all sequences (chromosomes) in the reference, in `.fai` order.
pub fn get_sequence_names(&self) -> &[String] {
    self.sequences.as_slice()
}
/// True when `genomic_pos` resolves to a concrete base (i.e. not `'N'`).
pub fn validate_position(&self, genomic_pos: &str) -> bool {
    matches!(self.get_ref_of_pos(genomic_pos), Ok(base) if base != 'N')
}
/// Runs [`Self::validate_position`] over a batch in parallel; the output
/// order matches the input order.
pub fn validate_positions_parallel(&self, positions: &[String]) -> Vec<bool> {
    let check = |pos: &String| self.validate_position(pos);
    positions.par_iter().map(check).collect()
}
/// Returns `(current entries, capacity limit, utilization percent)`.
///
/// Robustness fix: a zero `cache_size_limit` (constructible via
/// `new_with_cache_size(_, 0)`) previously produced a NaN/infinite usage
/// figure from the division; it now reports `0.0`.
pub fn get_cache_stats(&self) -> (usize, usize, f64) {
    let size = self.cache.read().len();
    let limit = self.cache_size_limit;
    let usage = if limit == 0 {
        0.0
    } else {
        size as f64 / limit as f64 * 100.0
    };
    (size, limit, usage)
}
/// Drops every cached base; subsequent lookups will hit the FASTA file again.
pub fn clear_cache(&self) {
    self.cache.write().clear();
    info!("Reference genome cache cleared");
}
/// Warms the cache by resolving `positions` in parallel.
///
/// Individual lookup failures are ignored — preloading is best-effort;
/// caching happens inside `get_ref_of_pos` as a side effect.
pub fn preload_positions(&self, positions: &[String]) -> Result<()> {
    info!("Preloading {} positions into cache", positions.len());
    // `for_each` instead of collecting a Vec<char> that was immediately
    // discarded (needless allocation in the original).
    positions.par_iter().for_each(|pos| {
        let _ = self.get_ref_of_pos(pos);
    });
    let (cache_size, _, usage) = self.get_cache_stats();
    info!(
        "Preloading complete. Cache size: {}, usage: {:.1}%",
        cache_size, usage
    );
    Ok(())
}
/// Parses sequence names (the first tab-separated column of each line)
/// from a `samtools faidx`-style `.fai` index file.
///
/// # Errors
/// `Io` when the file cannot be read; `Parse` for a blank line; `EmptyData`
/// when the index contains no sequences.
///
/// Bug fix: the previous `fields.is_empty()` check could never fire —
/// `str::split` always yields at least one (possibly empty) field — so a
/// blank line silently produced an empty sequence name instead of the
/// intended parse error. The first field is now checked for emptiness.
fn read_sequence_names(index_path: &str) -> Result<Vec<String>> {
    use std::fs::File;
    use std::io::{BufRead, BufReader};

    let file = File::open(index_path).map_err(RedicatError::Io)?;
    let reader = BufReader::new(file);
    let sequences: Vec<String> = reader
        .lines()
        .map(|line| {
            line.map_err(RedicatError::Io).and_then(|l| {
                let name = l.split('\t').next().unwrap_or("");
                if name.is_empty() {
                    Err(RedicatError::Parse("Empty line in FASTA index".to_string()))
                } else {
                    Ok(name.to_string())
                }
            })
        })
        .collect::<Result<Vec<_>>>()?;

    if sequences.is_empty() {
        return Err(RedicatError::EmptyData(
            "No sequences found in FASTA index".to_string(),
        ));
    }
    Ok(sequences)
}
}
#[cfg(test)]
mod tests {
use super::ReferenceGenome;
use tempfile::tempdir;
// Builds a tiny two-chromosome FASTA plus a hand-written `.fai` index and
// returns the FASTA path. The index byte offsets (6 and 17) are exact:
// ">chr1\n" is 6 bytes, and ">chr1\nATCG\n>chr2\n" is 17 — keep the
// fixture bytes and the index in sync if either changes.
// NOTE(review): `dir.keep()` persists the temp dir so the files outlive
// this function, which means every test run leaks a directory — consider
// returning the TempDir guard alongside the path instead.
fn create_indexed_fasta() -> std::path::PathBuf {
let dir = tempdir().unwrap();
let fasta_path = dir.path().join("ref.fa");
let fai_path = dir.path().join("ref.fa.fai");
std::fs::write(&fasta_path, b">chr1\nATCG\n>chr2\nGCTA\n").unwrap();
// .fai columns: name, length, byte offset of sequence, bases/line, bytes/line.
std::fs::write(&fai_path, b"chr1\t4\t6\t4\t5\nchr2\t4\t17\t4\t5\n").unwrap();
let persistent = dir.keep();
persistent.join("ref.fa")
}
// Constructor must fail loudly (FileNotFound) when the FASTA is absent.
#[test]
fn new_reports_missing_reference_files() {
match ReferenceGenome::new("/definitely/not/exist.fa") {
Ok(_) => panic!("expected missing file error"),
Err(err) => assert!(format!("{}", err).contains("FASTA file not found")),
}
}
// Happy-path lookups, out-of-range -> 'N', and cache stats/clear behavior.
#[test]
fn get_ref_and_cache_stats_work() {
let fasta = create_indexed_fasta();
let rg = ReferenceGenome::new_with_cache_size(fasta.to_str().unwrap(), 16).unwrap();
assert_eq!(rg.get_ref_of_pos("chr1:1").unwrap(), 'A');
assert_eq!(rg.get_ref_of_pos("chr1:2").unwrap(), 'T');
assert_eq!(rg.get_ref_of_pos("chr2:4").unwrap(), 'A');
// Position past the 4-base chromosome end resolves to 'N'.
assert_eq!(rg.get_ref_of_pos("chr1:100").unwrap(), 'N');
let (size, limit, usage) = rg.get_cache_stats();
assert!(size >= 3);
assert_eq!(limit, 16);
assert!(usage >= 0.0);
rg.clear_cache();
let (size_after, _, _) = rg.get_cache_stats();
assert_eq!(size_after, 0);
}
// All malformed/unknown inputs degrade to 'N' (never an Err) and fail
// validate_position.
#[test]
fn invalid_positions_return_n_and_validate_false() {
let fasta = create_indexed_fasta();
let rg = ReferenceGenome::new(fasta.to_str().unwrap()).unwrap();
assert_eq!(rg.get_ref_of_pos("chr1").unwrap(), 'N');
assert_eq!(rg.get_ref_of_pos("chr1:0").unwrap(), 'N');
assert_eq!(rg.get_ref_of_pos("chr1:notanint").unwrap(), 'N');
assert_eq!(rg.get_ref_of_pos("chr9:1").unwrap(), 'N');
assert!(!rg.validate_position("chr1:notanint"));
assert!(!rg.validate_position("chr9:1"));
}
// 120 positions crosses the 100-item parallel threshold, so both the
// serial and rayon paths of get_multiple_refs_parallel are exercised.
#[test]
fn parallel_and_chunked_match_serial() {
let fasta = create_indexed_fasta();
let rg = ReferenceGenome::new(fasta.to_str().unwrap()).unwrap();
let mut positions = Vec::new();
for i in 1..=60 {
let p = ((i - 1) % 4) + 1;
positions.push(format!("chr1:{}", p));
positions.push(format!("chr2:{}", p));
}
let serial: Vec<char> = positions
.iter()
.map(|p| rg.get_ref_of_pos(p).unwrap())
.collect();
let parallel = rg.get_multiple_refs_parallel(&positions).unwrap();
// Chunk size 13 deliberately does not divide 120, covering the ragged
// final chunk.
let chunked = rg.get_multiple_refs_chunked(&positions, 13).unwrap();
let valid = rg.validate_positions_parallel(&positions);
assert_eq!(serial, parallel);
assert_eq!(serial, chunked);
assert!(valid.iter().all(|&x| x));
}
// Batched grouping must agree with one-at-a-time lookups, including the
// unknown-chromosome, out-of-range, and bad-format fallbacks to 'N'.
#[test]
fn batched_matches_serial() {
let fasta = create_indexed_fasta();
let rg = ReferenceGenome::new(fasta.to_str().unwrap()).unwrap();
let positions: Vec<String> = vec![
"chr1:1", "chr1:2", "chr1:3", "chr1:4",
"chr2:1", "chr2:2", "chr2:3", "chr2:4",
"chrX:1", "chr1:100", "bad_format", ].into_iter().map(String::from).collect();
let serial: Vec<char> = positions
.iter()
.map(|p| rg.get_ref_of_pos(p).unwrap_or('N'))
.collect();
// Clear so the batched path actually reads the file instead of the cache.
rg.clear_cache();
let batched = rg.get_multiple_refs_batched(&positions).unwrap();
assert_eq!(serial, batched);
}
#[test]
fn batched_empty_input() {
let fasta = create_indexed_fasta();
let rg = ReferenceGenome::new(fasta.to_str().unwrap()).unwrap();
let result = rg.get_multiple_refs_batched(&[]).unwrap();
assert!(result.is_empty());
}
}