use crate::error::FerroError;
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::RwLock;
/// Two-level cache of protein sequences keyed by accession.
///
/// Sequences are held in an in-memory map and persisted as one `.fa` (FASTA)
/// file per sanitized accession under `cache_dir`.
#[derive(Debug)]
pub struct ProteinCache {
    /// Directory holding the on-disk `<accession>.fa` files.
    cache_dir: PathBuf,
    /// Accession → sequence map. The `RwLock` lets `&self` methods update it
    /// while allowing concurrent readers.
    memory_cache: RwLock<HashMap<String, String>>,
    /// Soft cap on the size of `memory_cache`; reaching it triggers a bulk
    /// eviction before further insertions.
    max_memory_entries: usize,
}
impl ProteinCache {
    /// Soft limit on in-memory entries used by [`ProteinCache::new`].
    const DEFAULT_MAX_MEMORY_ENTRIES: usize = 10_000;
    /// Number of residues written per line in on-disk FASTA files.
    const FASTA_LINE_WIDTH: usize = 60;

    /// Creates a cache that stores FASTA files under `cache_dir`.
    ///
    /// The directory and any missing parents are created if needed.
    ///
    /// # Errors
    ///
    /// Returns an error if `cache_dir` cannot be created.
    pub fn new(cache_dir: impl AsRef<Path>) -> Result<Self, FerroError> {
        let cache_dir = cache_dir.as_ref().to_path_buf();
        fs::create_dir_all(&cache_dir)?;
        Ok(Self {
            cache_dir,
            memory_cache: RwLock::new(HashMap::new()),
            max_memory_entries: Self::DEFAULT_MAX_MEMORY_ENTRIES,
        })
    }

    /// Returns the full protein sequence for `accession`.
    ///
    /// Lookup order: the in-memory cache, then the on-disk FASTA cache, then
    /// NCBI (only when the `protein-fetch` feature is enabled). Sequences found
    /// on disk or fetched from NCBI are added to the in-memory cache. Fetched
    /// sequences are also written to disk.
    ///
    /// # Errors
    ///
    /// Returns an error if the sequence is not cached and cannot be fetched, or
    /// if a freshly fetched sequence cannot be written to the disk cache.
    pub fn get_sequence(&self, accession: &str) -> Result<String, FerroError> {
        // The read guard is a temporary of this statement, so it is released
        // before any write lock is taken below.
        let cached = self
            .memory_cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(accession)
            .cloned();
        if let Some(seq) = cached {
            return Ok(seq);
        }
        // Any disk failure (missing file, unreadable, empty) falls through to
        // the network fetch.
        if let Ok(seq) = self.load_from_disk(accession) {
            self.add_to_memory_cache(accession, &seq);
            return Ok(seq);
        }
        let seq = self.fetch_from_ncbi(accession)?;
        self.save_to_disk(accession, &seq)?;
        self.add_to_memory_cache(accession, &seq);
        Ok(seq)
    }

    /// Returns residues `start..end` (0-based, end-exclusive) of `accession`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ProteinCache::get_sequence`]. Returns
    /// `FerroError::InvalidCoordinates` when `start` is not inside the
    /// sequence, `end` is past its end, or `start > end`.
    pub fn get_subsequence(
        &self,
        accession: &str,
        start: u64,
        end: u64,
    ) -> Result<String, FerroError> {
        let seq = self.get_sequence(accession)?;
        let len = seq.len();
        // `try_from` instead of `as`: on 32-bit targets a huge coordinate must
        // be rejected, not truncated into range. `str::get` instead of indexing:
        // an offset inside a multi-byte character yields an error, not a panic.
        let slice = usize::try_from(start)
            .ok()
            .zip(usize::try_from(end).ok())
            .filter(|&(s, e)| s < len && e <= len && s <= e)
            .and_then(|(s, e)| seq.get(s..e));
        slice
            .map(str::to_string)
            .ok_or_else(|| FerroError::InvalidCoordinates {
                msg: format!(
                    "Position {}:{}-{} out of bounds for protein {} (length {})",
                    accession, start, end, accession, len
                ),
            })
    }

    /// Inserts `sequence` under `accession` in the in-memory cache.
    ///
    /// When the cache already holds `max_memory_entries` entries and
    /// `accession` is not one of them, an arbitrary half of the entries (at
    /// least one) is evicted first. This bounds memory cheaply; it is not an
    /// LRU policy.
    fn add_to_memory_cache(&self, accession: &str, sequence: &str) {
        let mut cache = self
            .memory_cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if cache.len() >= self.max_memory_entries && !cache.contains_key(accession) {
            // `max(1)` guarantees progress even for very small limits.
            let mut to_evict = (self.max_memory_entries / 2).max(1);
            cache.retain(|_, _| {
                if to_evict == 0 {
                    true
                } else {
                    to_evict -= 1;
                    false
                }
            });
        }
        cache.insert(accession.to_string(), sequence.to_string());
    }

    /// Reads the cached FASTA file for `accession` and returns its sequence.
    ///
    /// # Errors
    ///
    /// Returns `FerroError::Io` if the file cannot be opened, propagates read
    /// errors, and returns `FerroError::ProteinReferenceNotAvailable` if the
    /// file contains no residues.
    fn load_from_disk(&self, accession: &str) -> Result<String, FerroError> {
        let path = self.cache_path(accession);
        let file = File::open(&path).map_err(|e| FerroError::Io {
            msg: format!("Failed to open cache file {:?}: {}", path, e),
        })?;
        let lines = BufReader::new(file)
            .lines()
            .collect::<Result<Vec<_>, _>>()?;
        Self::first_fasta_record(lines.iter().map(String::as_str))
            .ok_or_else(|| Self::not_available(accession))
    }

    /// Writes `sequence` to disk as a single-record FASTA file.
    ///
    /// The file has a `>accession` header followed by lines of at most
    /// `FASTA_LINE_WIDTH` characters. Lines are split on `char` boundaries, so
    /// multi-byte characters are never cut. (The earlier byte chunking with
    /// `unwrap_or("")` silently dropped any line that split one.)
    ///
    /// The data is first written to a uniquely named `.tmp` file in the cache
    /// directory and then renamed over the final path. A write interrupted
    /// part-way therefore leaves at most a stray `.tmp` file, never a
    /// truncated `.fa` file that would later load as a shortened sequence.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from creating, writing, or renaming the file.
    fn save_to_disk(&self, accession: &str, sequence: &str) -> Result<(), FerroError> {
        // Per-process counter so concurrent writers in this process never
        // share a temp file; the PID separates different processes.
        static NEXT_TMP_ID: std::sync::atomic::AtomicUsize =
            std::sync::atomic::AtomicUsize::new(0);
        let path = self.cache_path(accession);
        let tmp_id = NEXT_TMP_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let tmp_path = path.with_extension(format!("fa.{}.{}.tmp", std::process::id(), tmp_id));
        let result = Self::write_fasta(&tmp_path, accession, sequence)
            .and_then(|()| fs::rename(&tmp_path, &path).map_err(FerroError::from));
        if result.is_err() {
            // Best-effort cleanup; the write/rename error is what callers need.
            let _ = fs::remove_file(&tmp_path);
        }
        result
    }

    /// Writes `>accession` followed by `sequence` wrapped at
    /// `FASTA_LINE_WIDTH` characters to a new file at `path`, then flushes it.
    fn write_fasta(path: &Path, accession: &str, sequence: &str) -> Result<(), FerroError> {
        let mut writer = BufWriter::new(File::create(path)?);
        writeln!(writer, ">{}", accession)?;
        let mut rest = sequence;
        while !rest.is_empty() {
            // Byte offset of the first char past this line, or the end of input.
            let split = rest
                .char_indices()
                .nth(Self::FASTA_LINE_WIDTH)
                .map_or(rest.len(), |(idx, _)| idx);
            let (line, tail) = rest.split_at(split);
            writeln!(writer, "{}", line)?;
            rest = tail;
        }
        // Flush explicitly: `BufWriter`'s drop would swallow a final write error.
        writer.flush()?;
        Ok(())
    }

    /// Maps `accession` to its cache file. Characters that are invalid in
    /// file names on common platforms are replaced with `_`.
    fn cache_path(&self, accession: &str) -> PathBuf {
        let safe_name = accession.replace(['/', '\\', ':', '*', '?', '"', '<', '>', '|'], "_");
        self.cache_dir.join(format!("{}.fa", safe_name))
    }

    /// Concatenates the residue lines of the first record in FASTA `lines`.
    ///
    /// `>` header lines are skipped. Collection stops at the header of a second
    /// record, so multi-record input never merges sequences. Each residue line
    /// is trimmed, which also drops a trailing `\r`. Returns `None` when no
    /// residues were found.
    fn first_fasta_record<'a>(lines: impl IntoIterator<Item = &'a str>) -> Option<String> {
        let mut sequence = String::new();
        let mut in_record = false;
        for line in lines {
            if line.starts_with('>') {
                if in_record {
                    break;
                }
                in_record = true;
            } else {
                let residues = line.trim();
                if !residues.is_empty() {
                    in_record = true;
                    sequence.push_str(residues);
                }
            }
        }
        if sequence.is_empty() {
            None
        } else {
            Some(sequence)
        }
    }

    /// Builds the error reported when no sequence is available for `accession`.
    fn not_available(accession: &str) -> FerroError {
        FerroError::ProteinReferenceNotAvailable {
            accession: accession.to_string(),
            start: 0,
            end: 0,
        }
    }

    /// Downloads the FASTA record for `accession` from NCBI E-utilities `efetch`.
    ///
    /// # Errors
    ///
    /// Returns `FerroError::Io` on request or body-read failures. Returns
    /// `FerroError::ProteinReferenceNotAvailable` when the response is not
    /// FASTA or contains no residues.
    #[cfg(feature = "protein-fetch")]
    fn fetch_from_ncbi(&self, accession: &str) -> Result<String, FerroError> {
        // Percent-encode every byte outside the RFC 3986 "unreserved" set, so an
        // accession containing `&`, `=`, `#`, spaces, etc. cannot alter the
        // query string.
        let id: String = accession
            .bytes()
            .map(|b| match b {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                    char::from(b).to_string()
                }
                _ => format!("%{:02X}", b),
            })
            .collect();
        let url = format!(
            "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=protein&id={}&rettype=fasta&retmode=text",
            id
        );
        let response = ureq::get(&url).call().map_err(|e| FerroError::Io {
            msg: format!("Failed to fetch protein {} from NCBI: {}", accession, e),
        })?;
        let body = response.into_string().map_err(|e| FerroError::Io {
            msg: format!("Failed to read response for {}: {}", accession, e),
        })?;
        // Accept only a FASTA payload. Any other text would otherwise be parsed
        // as residues and then persisted to the disk cache.
        if !body.trim_start().starts_with('>') {
            return Err(Self::not_available(accession));
        }
        Self::first_fasta_record(body.lines()).ok_or_else(|| Self::not_available(accession))
    }

    /// Fallback used when the `protein-fetch` feature is disabled: always fails.
    #[cfg(not(feature = "protein-fetch"))]
    fn fetch_from_ncbi(&self, accession: &str) -> Result<String, FerroError> {
        Err(FerroError::Io {
            msg: format!(
                "Cannot fetch protein {} from NCBI: protein-fetch feature not enabled",
                accession
            ),
        })
    }

    /// Loads each accession through [`ProteinCache::get_sequence`] so it is
    /// cached. Returns how many succeeded; failures are skipped.
    pub fn prefetch(&self, accessions: &[&str]) -> usize {
        accessions
            .iter()
            .filter(|accession| self.get_sequence(accession).is_ok())
            .count()
    }

    /// Returns `true` if a cache file exists on disk for `accession`.
    pub fn is_cached(&self, accession: &str) -> bool {
        self.cache_path(accession).exists()
    }

    /// Counts the `.fa` files currently in the cache directory.
    ///
    /// Other entries, such as leftover `.tmp` files from interrupted writes,
    /// are ignored. Returns 0 if the directory cannot be read.
    pub fn cached_count(&self) -> usize {
        fs::read_dir(&self.cache_dir)
            .map(|entries| {
                entries
                    .filter_map(Result::ok)
                    .filter(|entry| {
                        entry.path().extension().and_then(|ext| ext.to_str()) == Some("fa")
                    })
                    .count()
            })
            .unwrap_or(0)
    }

    /// Drops every in-memory entry. On-disk files are left untouched.
    pub fn clear_memory_cache(&self) {
        self.memory_cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Accession under which fixtures are written to disk.
    const ACCESSION: &str = "TEST_PROTEIN.1";
    /// 51-residue fixture sequence shared by the on-disk tests.
    const FIXTURE: &str = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSH";

    /// Creates a cache rooted in a new temporary directory.
    ///
    /// The `TempDir` is returned as well. Dropping it deletes the directory, so
    /// callers must keep it bound (e.g. as `_dir`, not `_`) for the whole test.
    fn temp_cache() -> (TempDir, ProteinCache) {
        let dir = TempDir::new().unwrap();
        let cache = ProteinCache::new(dir.path()).unwrap();
        (dir, cache)
    }

    /// Creates a temporary cache with `sequence` already saved under `ACCESSION`.
    fn seeded_cache(sequence: &str) -> (TempDir, ProteinCache) {
        let (dir, cache) = temp_cache();
        cache.save_to_disk(ACCESSION, sequence).unwrap();
        (dir, cache)
    }

    #[test]
    fn test_protein_cache_new() {
        let (_dir, cache) = temp_cache();
        assert_eq!(cache.cached_count(), 0);
    }

    #[test]
    fn test_cache_path_sanitization() {
        let (_dir, cache) = temp_cache();

        let plain = cache.cache_path("NP_000079.2");
        assert!(plain.to_string_lossy().contains("NP_000079.2.fa"));

        let sanitized = cache.cache_path("test/acc:1");
        let name = sanitized.file_name().unwrap().to_string_lossy();
        assert!(!name.contains('/'));
        assert!(!name.contains(':'));
        assert!(name.contains("test_acc_1.fa"));
    }

    #[test]
    fn test_save_and_load_sequence() {
        let (_dir, cache) = seeded_cache(FIXTURE);
        assert!(cache.is_cached(ACCESSION));
        assert_eq!(cache.load_from_disk(ACCESSION).unwrap(), FIXTURE);
    }

    #[test]
    fn test_get_subsequence() {
        let (_dir, cache) = seeded_cache(FIXTURE);
        let cases = [(0, 5, "MVLSP"), (10, 20, "VKAAWGKVGA")];
        for &(start, end, expected) in cases.iter() {
            assert_eq!(cache.get_subsequence(ACCESSION, start, end).unwrap(), expected);
        }
    }

    #[test]
    fn test_get_subsequence_out_of_bounds() {
        let (_dir, cache) = seeded_cache(FIXTURE);
        assert!(cache.get_subsequence(ACCESSION, 0, 1000).is_err());
    }

    #[test]
    fn test_memory_cache() {
        let (_dir, cache) = seeded_cache("MVLSPADKTN");
        cache.get_sequence(ACCESSION).unwrap();
        // Each read guard is a temporary dropped at the end of its statement,
        // so `clear_memory_cache` can take the write lock afterwards.
        assert!(cache.memory_cache.read().unwrap().contains_key(ACCESSION));
        cache.clear_memory_cache();
        assert!(cache.memory_cache.read().unwrap().is_empty());
    }
}