use crate::core::error::{RedicatError, Result};
use log::{info, warn};
use rayon::prelude::*;
use std::collections::HashSet;
use std::fmt;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::str::FromStr;
/// The base substitutions this module can look for.
///
/// Each variant is named `<reference base><alternate base>` as reported by
/// `EditingType::to_bases`; for example `AG` is the A→G substitution.
///
/// The enum carries no data, so it derives `Copy` (values can be passed around
/// without moving or cloning) and `Hash` (values can be used as `HashMap` /
/// `HashSet` keys) in addition to the comparison traits.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EditingType {
    /// A→G substitution.
    AG,
    /// A→C substitution.
    AC,
    /// A→T substitution.
    AT,
    /// C→A substitution.
    CA,
    /// C→G substitution.
    CG,
    /// C→T substitution.
    CT,
}
/// Parses an [`EditingType`] from its two-letter name, ignoring letter case.
impl FromStr for EditingType {
    type Err = String;

    /// Accepts `ag`, `ac`, `at`, `ca`, `cg` or `ct` in any mix of upper and
    /// lower case.
    ///
    /// # Errors
    ///
    /// Returns a message that echoes the input exactly as given and lists the
    /// valid names when the input matches none of them.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Every accepted (lower-case) spelling paired with the variant it selects.
        const NAMES: [(&str, EditingType); 6] = [
            ("ag", EditingType::AG),
            ("ac", EditingType::AC),
            ("at", EditingType::AT),
            ("ca", EditingType::CA),
            ("cg", EditingType::CG),
            ("ct", EditingType::CT),
        ];

        let lowered = s.to_lowercase();
        for (name, variant) in NAMES {
            if name == lowered.as_str() {
                return Ok(variant);
            }
        }

        // The error quotes the caller's original text, not the lower-cased copy.
        Err(format!(
            "Invalid editing type: {}. Valid types: ag, ac, at, ca, cg, ct",
            s
        ))
    }
}
/// Formats an [`EditingType`] as its lower-case two-letter name.
impl fmt::Display for EditingType {
    /// Writes `ag`, `ac`, `at`, `ca`, `cg` or `ct`; the text is written as-is,
    /// without applying any width or alignment requested by the caller.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            EditingType::AG => "ag",
            EditingType::AC => "ac",
            EditingType::AT => "at",
            EditingType::CA => "ca",
            EditingType::CG => "cg",
            EditingType::CT => "ct",
        };
        f.write_str(label)
    }
}
impl EditingType {
pub fn to_bases(&self) -> (char, char) {
match self {
EditingType::AG => ('A', 'G'),
EditingType::AC => ('A', 'C'),
EditingType::AT => ('A', 'T'),
EditingType::CA => ('C', 'A'),
EditingType::CG => ('C', 'G'),
EditingType::CT => ('C', 'T'),
}
}
pub fn get_strand_aware_ref_bases(&self) -> [char; 2] {
match self {
EditingType::AG => ['A', 'T'],
EditingType::AC => ['A', 'T'],
EditingType::AT => ['A', 'T'],
EditingType::CA => ['C', 'G'],
EditingType::CG => ['C', 'G'],
EditingType::CT => ['C', 'G'],
}
}
pub fn get_alt_base_for_ref(&self, ref_base: char) -> char {
match self {
EditingType::AG => match ref_base {
'A' => 'G',
'T' => 'C',
_ => 'N',
},
EditingType::AC => match ref_base {
'A' => 'C',
'T' => 'G',
_ => 'N',
},
EditingType::AT => match ref_base {
'A' => 'T',
'T' => 'A',
_ => 'N',
},
EditingType::CA => match ref_base {
'C' => 'A',
'G' => 'T',
_ => 'N',
},
EditingType::CG => match ref_base {
'C' => 'G',
'G' => 'C',
_ => 'N',
},
EditingType::CT => match ref_base {
'C' => 'T',
'G' => 'A',
_ => 'N',
},
}
}
pub fn is_valid_ref_base(&self, base: char) -> bool {
self.get_strand_aware_ref_bases().contains(&base)
}
pub fn all_types() -> Vec<EditingType> {
vec![
EditingType::AG,
EditingType::AC,
EditingType::AT,
EditingType::CA,
EditingType::CG,
EditingType::CT,
]
}
}
/// Loads the set of editing-site keys from a REDIPortal table.
///
/// The whole file is read into memory first (gzip-decoded when `path` ends
/// with `.gz`), the first line is skipped as a header, and the remaining lines
/// are parsed in parallel with rayon. A line contributes the key
/// `"<column 1>:<column 2>"` when it has at least two tab-separated columns and
/// the second one parses as a `u64`; every other line is reported with `warn!`
/// and ignored.
///
/// # Errors
///
/// * [`RedicatError::FileNotFound`] – `path` does not exist or cannot be opened.
/// * [`RedicatError::Io`] – reading or decompressing a line failed.
/// * [`RedicatError::EmptyData`] – the file has no lines at all, or no line
///   produced a valid site.
pub fn load_rediportal_parallel(path: &str) -> Result<HashSet<String>> {
    info!("Loading REDIPortal from: {}", path);
    if !std::path::Path::new(path).exists() {
        return Err(RedicatError::FileNotFound(format!(
            "REDIPortal file not found: {}",
            path
        )));
    }
    let file = File::open(path).map_err(|e| {
        RedicatError::FileNotFound(format!("Failed to open REDIPortal file {}: {}", path, e))
    })?;
    let reader: Box<dyn BufRead + Send> = if path.ends_with(".gz") {
        // `GzDecoder` stops after the first gzip member, which silently
        // truncates multi-member archives (bgzip output, concatenated .gz
        // files). `MultiGzDecoder` keeps decoding until the end of the file.
        Box::new(BufReader::new(flate2::read::MultiGzDecoder::new(file)))
    } else {
        Box::new(BufReader::new(file))
    };

    // Materialise the lines so they can be handed to rayon as a slice.
    let lines: Vec<String> = reader
        .lines()
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(RedicatError::Io)?;
    if lines.is_empty() {
        return Err(RedicatError::EmptyData(
            "REDIPortal file is empty".to_string(),
        ));
    }
    info!("Read {} lines from REDIPortal file", lines.len());

    let editing_sites: HashSet<String> = lines
        .par_iter()
        .skip(1) // header row
        .filter_map(|line| {
            // Only the first two columns matter, so pull them straight from
            // the iterator instead of collecting every column into a Vec.
            let mut columns = line.split('\t');
            match (columns.next(), columns.next()) {
                (Some(chrom), Some(pos)) => {
                    // The parse only validates the column; the key reuses the
                    // column text verbatim.
                    if pos.parse::<u64>().is_ok() {
                        Some(format!("{}:{}", chrom, pos))
                    } else {
                        warn!("Invalid position in line: {}", line);
                        None
                    }
                }
                _ => {
                    warn!("Invalid line format (insufficient columns): {}", line);
                    None
                }
            }
        })
        .collect();
    if editing_sites.is_empty() {
        return Err(RedicatError::EmptyData(
            "No valid editing sites found in REDIPortal file".to_string(),
        ));
    }
    info!(
        "Loaded {} editing sites from REDIPortal",
        editing_sites.len()
    );
    // Guarded so the sample is only collected when it will actually be printed.
    if log::log_enabled!(log::Level::Debug) {
        let sample_sites: Vec<&String> = editing_sites.iter().take(5).collect();
        log::debug!("Sample editing sites: {:?}", sample_sites);
    }
    Ok(editing_sites)
}
/// Loads REDIPortal editing sites and narrows them down by chromosome.
///
/// Sites are first loaded with [`load_rediportal_parallel`]; the resulting
/// `"<chromosome>:<position>"` keys are then filtered in two optional steps:
///
/// * `allowed_chromosomes` – when `Some`, only sites whose chromosome is listed
///   are kept.
/// * `min_chromosome_length` – when greater than zero, only sites whose
///   chromosome *name* is at least this many bytes long are kept (this is a
///   length check on the name string, not on the sequence).
///
/// # Errors
///
/// Propagates any error from [`load_rediportal_parallel`] and returns
/// [`RedicatError::EmptyData`] when no site survives the filters.
pub fn load_rediportal_with_filters(
    path: &str,
    min_chromosome_length: usize,
    allowed_chromosomes: Option<&[&str]>,
) -> Result<HashSet<String>> {
    /// Returns the chromosome part of a `"<chromosome>:<position>"` key.
    ///
    /// The split happens at the *last* colon: the position is purely numeric,
    /// while contig names may themselves contain colons, so splitting at the
    /// first colon would cut such names short.
    fn chromosome_of(key: &str) -> &str {
        key.rsplit_once(':').map_or(key, |(chrom, _)| chrom)
    }

    info!("Loading REDIPortal with filters from: {}", path);
    let mut editing_sites = load_rediportal_parallel(path)?;
    let original_count = editing_sites.len();

    if let Some(allowed_chroms) = allowed_chromosomes {
        editing_sites.retain(|key| allowed_chroms.contains(&chromosome_of(key)));
        info!(
            "Filtered by allowed chromosomes: {} -> {} sites",
            original_count,
            editing_sites.len()
        );
    }

    if min_chromosome_length > 0 {
        let before_filter = editing_sites.len();
        editing_sites.retain(|key| chromosome_of(key).len() >= min_chromosome_length);
        info!(
            "Filtered by chromosome name length (>= {}): {} -> {} sites",
            min_chromosome_length,
            before_filter,
            editing_sites.len()
        );
    }

    if editing_sites.is_empty() {
        return Err(RedicatError::EmptyData(
            "No editing sites remain after filtering".to_string(),
        ));
    }
    info!("Final filtered editing sites: {}", editing_sites.len());
    Ok(editing_sites)
}
/// Performs a quick sanity check of a REDIPortal table.
///
/// The first line is logged as the header. Up to `MAX_CHECK_LINES` (100)
/// following lines are then inspected: a line counts as valid when it has at
/// least two tab-separated columns and the second one parses as a `u64`. The
/// first five invalid lines are logged with their 1-based line number, and a
/// warning is emitted when invalid lines outnumber half of the valid ones.
///
/// # Errors
///
/// * [`RedicatError::FileNotFound`] – the file cannot be opened.
/// * [`RedicatError::EmptyData`] – the file contains no lines at all.
/// * [`RedicatError::Io`] – reading or decompressing a line failed.
/// * [`RedicatError::InvalidInput`] – none of the inspected lines is valid.
pub fn validate_rediportal_format(path: &str) -> Result<()> {
    info!("Validating REDIPortal file format: {}", path);
    let file = File::open(path).map_err(|e| {
        RedicatError::FileNotFound(format!("Failed to open REDIPortal file {}: {}", path, e))
    })?;
    let reader: Box<dyn BufRead> = if path.ends_with(".gz") {
        // `MultiGzDecoder` reads every gzip member; `GzDecoder` would stop
        // after the first one, which could end a multi-member archive early.
        Box::new(BufReader::new(flate2::read::MultiGzDecoder::new(file)))
    } else {
        Box::new(BufReader::new(file))
    };
    let mut lines = reader.lines();
    // First `?` handles "no lines at all", second `?` the I/O result of that line.
    let header = lines
        .next()
        .ok_or_else(|| RedicatError::EmptyData("REDIPortal file is empty".to_string()))??;
    info!("Header: {}", header);

    const MAX_CHECK_LINES: usize = 100;
    let mut valid_lines: usize = 0;
    let mut invalid_lines: usize = 0;
    for (i, line_result) in lines.enumerate().take(MAX_CHECK_LINES) {
        let line = line_result.map_err(RedicatError::Io)?;
        // Only the second column is examined; no need to collect all columns.
        let has_numeric_position = matches!(
            line.split('\t').nth(1),
            Some(pos) if pos.parse::<u64>().is_ok()
        );
        if has_numeric_position {
            valid_lines += 1;
        } else {
            invalid_lines += 1;
            if invalid_lines <= 5 {
                // `i` is 0-based over data lines; +2 accounts for the header
                // and converts to a 1-based file line number.
                warn!("Invalid line {}: {}", i + 2, line);
            }
        }
    }
    if valid_lines == 0 {
        return Err(RedicatError::InvalidInput(
            "No valid data lines found in REDIPortal file".to_string(),
        ));
    }
    // Report how many lines were really inspected; short files have fewer
    // than MAX_CHECK_LINES data rows.
    info!(
        "Validation complete: {} valid lines, {} invalid lines (checked first {} lines)",
        valid_lines,
        invalid_lines,
        valid_lines + invalid_lines
    );
    if invalid_lines > valid_lines / 2 {
        warn!("High proportion of invalid lines detected. Please check file format.");
    }
    Ok(())
}
// Unit tests for `EditingType` and for the REDIPortal loading/validation helpers.
// File-based tests write small fixtures into a temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use flate2::{write::GzEncoder, Compression};
use std::fs::File;
use std::io::Write;
use tempfile::tempdir;
// Writes a plain-text, tab-separated fixture: a header row, three rows with a
// numeric position, and one row ("bad") whose position is not a number.
fn write_plain_rediportal(path: &std::path::Path) {
let mut f = File::create(path).unwrap();
writeln!(f, "chrom\tpos").unwrap();
writeln!(f, "chr22\t50783283").unwrap();
writeln!(f, "chr1\t100").unwrap();
writeln!(f, "bad\tnot_a_pos").unwrap();
writeln!(f, "chrM\t200").unwrap();
}
// Writes a gzip-compressed fixture with a header row and two valid rows.
fn write_gz_rediportal(path: &std::path::Path) {
let file = File::create(path).unwrap();
let mut gz = GzEncoder::new(file, Compression::default());
writeln!(gz, "chrom\tpos").unwrap();
writeln!(gz, "chr22\t50783283").unwrap();
writeln!(gz, "chrX\t12345").unwrap();
// `finish` writes the remaining compressed data; its result is unwrapped so a failure aborts the test.
gz.finish().unwrap();
}
// Parsing accepts any letter case and rejects unknown names.
#[test]
fn test_editing_type_parsing() {
assert_eq!(EditingType::from_str("ag").unwrap(), EditingType::AG);
assert_eq!(EditingType::from_str("AG").unwrap(), EditingType::AG);
assert_eq!(EditingType::from_str("Ag").unwrap(), EditingType::AG);
assert!(EditingType::from_str("invalid").is_err());
}
// For AG: reference bases are A/T, with alternates G/C; other bases map to 'N'.
#[test]
fn test_strand_aware_bases() {
let ag_type = EditingType::AG;
assert_eq!(ag_type.get_strand_aware_ref_bases(), ['A', 'T']);
assert_eq!(ag_type.get_alt_base_for_ref('A'), 'G');
assert_eq!(ag_type.get_alt_base_for_ref('T'), 'C');
assert_eq!(ag_type.get_alt_base_for_ref('G'), 'N');
}
// `Display` yields the lower-case two-letter name.
#[test]
fn test_editing_type_display() {
assert_eq!(EditingType::AG.to_string(), "ag");
assert_eq!(EditingType::CT.to_string(), "ct");
}
// Only A and T are valid reference bases for AG.
#[test]
fn test_valid_ref_base() {
let ag_type = EditingType::AG;
assert!(ag_type.is_valid_ref_base('A'));
assert!(ag_type.is_valid_ref_base('T'));
assert!(!ag_type.is_valid_ref_base('G'));
assert!(!ag_type.is_valid_ref_base('C'));
}
// `all_types` lists six variants, including AG and CT.
#[test]
fn test_all_types_contains_expected_entries() {
let types = EditingType::all_types();
assert!(types.contains(&EditingType::AG));
assert!(types.contains(&EditingType::CT));
assert_eq!(types.len(), 6);
}
// Checks the reference -> alternate mapping of every variant, including the
// 'N' fallback for reference bases that do not apply to the variant.
#[test]
fn test_all_editing_type_alt_base_mappings() {
assert_eq!(EditingType::AG.get_alt_base_for_ref('A'), 'G');
assert_eq!(EditingType::AG.get_alt_base_for_ref('T'), 'C');
assert_eq!(EditingType::AG.get_alt_base_for_ref('C'), 'N');
assert_eq!(EditingType::AG.get_alt_base_for_ref('G'), 'N');
assert_eq!(EditingType::AC.get_alt_base_for_ref('A'), 'C');
assert_eq!(EditingType::AC.get_alt_base_for_ref('T'), 'G');
assert_eq!(EditingType::AC.get_alt_base_for_ref('C'), 'N');
assert_eq!(EditingType::AT.get_alt_base_for_ref('A'), 'T');
assert_eq!(EditingType::AT.get_alt_base_for_ref('T'), 'A');
assert_eq!(EditingType::AT.get_alt_base_for_ref('G'), 'N');
assert_eq!(EditingType::CA.get_alt_base_for_ref('C'), 'A');
assert_eq!(EditingType::CA.get_alt_base_for_ref('G'), 'T');
assert_eq!(EditingType::CA.get_alt_base_for_ref('A'), 'N');
assert_eq!(EditingType::CG.get_alt_base_for_ref('C'), 'G');
assert_eq!(EditingType::CG.get_alt_base_for_ref('G'), 'C');
assert_eq!(EditingType::CG.get_alt_base_for_ref('A'), 'N');
assert_eq!(EditingType::CT.get_alt_base_for_ref('C'), 'T');
assert_eq!(EditingType::CT.get_alt_base_for_ref('G'), 'A');
assert_eq!(EditingType::CT.get_alt_base_for_ref('A'), 'N');
}
// A* variants accept A/T as reference bases, C* variants accept C/G.
#[test]
fn test_strand_aware_ref_bases_for_all_types() {
assert_eq!(EditingType::AG.get_strand_aware_ref_bases(), ['A', 'T']);
assert_eq!(EditingType::AC.get_strand_aware_ref_bases(), ['A', 'T']);
assert_eq!(EditingType::AT.get_strand_aware_ref_bases(), ['A', 'T']);
assert_eq!(EditingType::CA.get_strand_aware_ref_bases(), ['C', 'G']);
assert_eq!(EditingType::CG.get_strand_aware_ref_bases(), ['C', 'G']);
assert_eq!(EditingType::CT.get_strand_aware_ref_bases(), ['C', 'G']);
}
// `to_bases` returns the (reference, alternate) pair spelled by the variant name.
#[test]
fn test_to_bases_for_all_types() {
assert_eq!(EditingType::AG.to_bases(), ('A', 'G'));
assert_eq!(EditingType::AC.to_bases(), ('A', 'C'));
assert_eq!(EditingType::AT.to_bases(), ('A', 'T'));
assert_eq!(EditingType::CA.to_bases(), ('C', 'A'));
assert_eq!(EditingType::CG.to_bases(), ('C', 'G'));
assert_eq!(EditingType::CT.to_bases(), ('C', 'T'));
}
// All four upper/lower-case combinations of "ct" parse to CT.
#[test]
fn test_editing_type_case_insensitive_parsing() {
assert_eq!(EditingType::from_str("ct").unwrap(), EditingType::CT);
assert_eq!(EditingType::from_str("CT").unwrap(), EditingType::CT);
assert_eq!(EditingType::from_str("Ct").unwrap(), EditingType::CT);
assert_eq!(EditingType::from_str("cT").unwrap(), EditingType::CT);
}
// Reference-base validity for AG and CT, including the 'N' placeholder.
#[test]
fn test_is_valid_ref_base_all_types() {
assert!(EditingType::AG.is_valid_ref_base('A'));
assert!(EditingType::AG.is_valid_ref_base('T'));
assert!(!EditingType::AG.is_valid_ref_base('C'));
assert!(!EditingType::AG.is_valid_ref_base('G'));
assert!(!EditingType::AG.is_valid_ref_base('N'));
assert!(EditingType::CT.is_valid_ref_base('C'));
assert!(EditingType::CT.is_valid_ref_base('G'));
assert!(!EditingType::CT.is_valid_ref_base('A'));
}
// A file that holds only the header row yields the "No valid editing sites" error.
#[test]
fn test_load_rediportal_empty_after_header() {
let dir = tempdir().unwrap();
let path = dir.path().join("empty_data.tsv");
let mut f = File::create(&path).unwrap();
writeln!(f, "chrom\tpos").unwrap();
let err = load_rediportal_parallel(path.to_str().unwrap()).unwrap_err();
assert!(format!("{}", err).contains("No valid editing sites"));
}
// A row with a single column is dropped; the remaining valid row is loaded.
#[test]
fn test_load_rediportal_single_column_lines_skipped() {
let dir = tempdir().unwrap();
let path = dir.path().join("single_col.tsv");
let mut f = File::create(&path).unwrap();
writeln!(f, "chrom\tpos").unwrap();
// Writes one single-column row and one valid row, then loads the file.
writeln!(f, "chr1").unwrap(); writeln!(f, "chr2\t500").unwrap(); let sites = load_rediportal_parallel(path.to_str().unwrap()).unwrap();
assert_eq!(sites.len(), 1);
assert!(sites.contains("chr2:500"));
}
// Loading works for both the plain and the gzip fixture; the row with a
// non-numeric position never becomes a key.
#[test]
fn test_load_rediportal_parallel_plain_and_gz() {
let dir = tempdir().unwrap();
let plain = dir.path().join("redi.tsv");
let gz = dir.path().join("redi.tsv.gz");
write_plain_rediportal(&plain);
write_gz_rediportal(&gz);
let plain_sites = load_rediportal_parallel(plain.to_str().unwrap()).unwrap();
assert!(plain_sites.contains("chr22:50783283"));
assert!(plain_sites.contains("chr1:100"));
assert!(plain_sites.contains("chrM:200"));
assert!(!plain_sites.contains("bad:not_a_pos"));
let gz_sites = load_rediportal_parallel(gz.to_str().unwrap()).unwrap();
assert!(gz_sites.contains("chr22:50783283"));
assert!(gz_sites.contains("chrX:12345"));
}
// With an allow-list of chr22/chr1 and a minimum name length of 4, the chrM
// site is filtered out while the chr22 and chr1 sites remain.
#[test]
fn test_load_rediportal_with_filters() {
let dir = tempdir().unwrap();
let plain = dir.path().join("redi.tsv");
write_plain_rediportal(&plain);
let filtered = load_rediportal_with_filters(
plain.to_str().unwrap(),
4,
Some(&["chr22", "chr1"]),
)
.unwrap();
assert!(filtered.contains("chr22:50783283"));
assert!(filtered.contains("chr1:100"));
assert!(!filtered.contains("chrM:200"));
}
// Covers three cases: a well-formed file validates, a file without any numeric
// position fails validation, and loading a missing file reports "not found".
#[test]
fn test_validate_rediportal_format_and_errors() {
let dir = tempdir().unwrap();
let ok = dir.path().join("ok.tsv");
write_plain_rediportal(&ok);
validate_rediportal_format(ok.to_str().unwrap()).unwrap();
let bad = dir.path().join("bad.tsv");
let mut f = File::create(&bad).unwrap();
writeln!(f, "chrom\tpos").unwrap();
writeln!(f, "chr1\tnotnum").unwrap();
writeln!(f, "chr2\tNaN").unwrap();
let err = validate_rediportal_format(bad.to_str().unwrap()).unwrap_err();
assert!(format!("{}", err).contains("No valid data lines"));
// This path is never created, so the loader must fail before reading anything.
let missing = dir.path().join("missing.tsv");
let err = load_rediportal_parallel(missing.to_str().unwrap()).unwrap_err();
assert!(format!("{}", err).contains("not found"));
}
}