#![allow(dead_code)]
use crate::core::error::{RedicatError, Result};
use std::path::Path;
/// Numeric run-time parameters whose ranges are checked by [`ValidationConfig::validate`].
///
/// Derives `Debug` so a rejected configuration can be logged verbatim, and `Clone`/`PartialEq`
/// so callers and tests can copy and compare configurations.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationConfig {
    /// Must be finite and within `[0.0, 1.0]`.
    pub max_other_threshold: f64,
    /// Must be finite and within `[0.0, 1.0]`.
    pub min_edited_threshold: f64,
    /// Must be finite and within `[0.0, 1.0]`.
    pub min_ref_threshold: f64,
    /// Must be greater than 0.
    pub min_coverage: u16,
    /// Must be greater than 0; values above 100 000 are accepted but logged as a warning.
    pub chunk_size: usize,
    /// Must be greater than 0; values above twice `num_cpus::get()` are accepted but logged
    /// as a warning.
    pub num_threads: usize,
}
impl ValidationConfig {
    /// Checks every field and returns the first problem found.
    ///
    /// The checks run in this order:
    /// 1. the three thresholds must be finite and within `[0.0, 1.0]`;
    /// 2. `min_coverage` must be non-zero;
    /// 3. `chunk_size` must be non-zero (values above 100 000 only log a warning);
    /// 4. `num_threads` must be non-zero (values above `2 * num_cpus::get()` only log a warning).
    ///
    /// # Errors
    ///
    /// Returns `RedicatError::InvalidInput` for a non-finite threshold or a zero count, and
    /// `RedicatError::ThresholdValidation` for a finite threshold outside its range.
    pub fn validate(&self) -> Result<()> {
        let thresholds = [
            ("max_other_threshold", self.max_other_threshold),
            ("min_edited_threshold", self.min_edited_threshold),
            ("min_ref_threshold", self.min_ref_threshold),
        ];
        for &(field, value) in &thresholds {
            self.validate_threshold(field, value, 0.0, 1.0)?;
        }

        if self.min_coverage == 0 {
            return Err(RedicatError::InvalidInput(
                "min_coverage must be greater than 0".to_string(),
            ));
        }

        match self.chunk_size {
            0 => {
                return Err(RedicatError::InvalidInput(
                    "chunk_size must be greater than 0".to_string(),
                ));
            }
            // Allowed, but flagged: very large chunks are suspected to be memory-hungry.
            size if size > 100_000 => {
                log::warn!("Large chunk_size ({}) may cause memory issues", size);
            }
            _ => {}
        }

        if self.num_threads == 0 {
            return Err(RedicatError::InvalidInput(
                "num_threads must be greater than 0".to_string(),
            ));
        }

        // Oversubscription is tolerated; it only produces a warning.
        let recommended_max = num_cpus::get() * 2;
        if self.num_threads > recommended_max {
            log::warn!(
                "num_threads ({}) exceeds recommended maximum ({})",
                self.num_threads,
                recommended_max
            );
        }

        Ok(())
    }

    /// Ensures `value` is a finite number lying inside the inclusive range `[min, max]`.
    ///
    /// The finiteness test runs first so that NaN/inf report a dedicated message rather than
    /// a range violation.
    fn validate_threshold(&self, name: &str, value: f64, min: f64, max: f64) -> Result<()> {
        if !value.is_finite() {
            return Err(RedicatError::InvalidInput(format!(
                "{} must be a finite number",
                name
            )));
        }

        if value < min || value > max {
            Err(RedicatError::ThresholdValidation {
                field: name.to_string(),
                min,
                max,
                value,
            })
        } else {
            Ok(())
        }
    }
}
/// Checks that every input file the run needs is present on disk.
///
/// Checks, in order: `input` exists and ends in `.h5ad`; the reference FASTA `fa` exists;
/// its index `<fa>.fai` exists; and the `rediportal` file exists. Only presence (and the
/// input's extension) is checked. No file is opened or parsed.
///
/// # Errors
///
/// Returns `RedicatError::FileNotFound` for the first missing file, or
/// `RedicatError::InvalidInput` when `input` does not end in `.h5ad`.
pub fn validate_input_files(input: &str, fa: &str, rediportal: &str) -> Result<()> {
    let is_missing = |candidate: &str| !Path::new(candidate).exists();

    if is_missing(input) {
        return Err(RedicatError::FileNotFound(format!(
            "Input file not found: {}",
            input
        )));
    }
    if !input.ends_with(".h5ad") {
        return Err(RedicatError::InvalidInput(
            "Input file must have .h5ad extension".to_string(),
        ));
    }

    if is_missing(fa) {
        return Err(RedicatError::FileNotFound(format!(
            "Reference genome file not found: {}",
            fa
        )));
    }

    // The index is expected next to the FASTA under the same name plus `.fai`;
    // the error text tells the user how to generate it.
    let index_path = format!("{}.fai", fa);
    if is_missing(index_path.as_str()) {
        return Err(RedicatError::FileNotFound(format!(
            "FASTA index file not found: {}. Please create it using: samtools faidx {}",
            index_path, fa
        )));
    }

    if is_missing(rediportal) {
        return Err(RedicatError::FileNotFound(format!(
            "REDIPortal file not found: {}",
            rediportal
        )));
    }

    Ok(())
}
/// Validates the output path and clears any existing file at it.
///
/// `output` must end in `.h5ad`. If something already exists at that path it is removed,
/// so the caller can write a fresh file. The extension is checked *before* anything is
/// deleted. A rejected path therefore never causes an existing file to be destroyed.
///
/// # Errors
///
/// Returns `RedicatError::InvalidInput` if `output` does not end in `.h5ad`, or if an
/// existing entry at `output` cannot be removed (e.g. it is a directory, which
/// `std::fs::remove_file` refuses to delete).
pub fn validate_output_path(output: &str) -> Result<()> {
    // Validate first: deleting before this check used to wipe out existing files
    // whose names were then rejected anyway.
    if !output.ends_with(".h5ad") {
        return Err(RedicatError::InvalidInput(
            "Output file must have .h5ad extension".to_string(),
        ));
    }

    let path = Path::new(output);
    if path.exists() {
        std::fs::remove_file(path).map_err(|e| {
            RedicatError::InvalidInput(format!(
                "Failed to remove existing output file '{}': {}",
                output, e
            ))
        })?;
    }

    Ok(())
}
/// Checks that a matrix's shape tuple equals the expected one.
///
/// `_matrix_name` is accepted but not currently used; it does not appear in the error.
///
/// # Errors
///
/// Returns `RedicatError::DimensionMismatch` when the shapes differ. Both shapes are
/// rendered as `"<.0> × <.1>"`.
pub fn validate_matrix_dimensions(
    _matrix_name: &str,
    actual_shape: (usize, usize),
    expected_shape: (usize, usize),
) -> Result<()> {
    if actual_shape == expected_shape {
        return Ok(());
    }

    let render = |(first, second): (usize, usize)| format!("{} × {}", first, second);
    Err(RedicatError::DimensionMismatch {
        expected: render(expected_shape),
        actual: render(actual_shape),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A configuration that passes every check in `ValidationConfig::validate`.
    fn valid_config() -> ValidationConfig {
        ValidationConfig {
            max_other_threshold: 0.01,
            min_edited_threshold: 0.05,
            min_ref_threshold: 0.1,
            min_coverage: 5,
            chunk_size: 1000,
            num_threads: 2,
        }
    }

    /// Mutates a valid configuration with `tweak` and returns its rendered validation error.
    fn validation_error(tweak: impl FnOnce(&mut ValidationConfig)) -> String {
        let mut cfg = valid_config();
        tweak(&mut cfg);
        cfg.validate().unwrap_err().to_string()
    }

    /// Asserts that `haystack` contains `needle`, showing both strings on failure.
    fn assert_contains(haystack: &str, needle: &str) {
        assert!(
            haystack.contains(needle),
            "expected {:?} to contain {:?}",
            haystack,
            needle
        );
    }

    #[test]
    fn validation_config_accepts_valid_values() {
        valid_config().validate().unwrap();
    }

    #[test]
    fn validation_config_rejects_invalid_thresholds_and_sizes() {
        assert_contains(
            &validation_error(|c| c.max_other_threshold = 1.5),
            "max_other_threshold",
        );
        assert_contains(
            &validation_error(|c| c.min_ref_threshold = f64::NAN),
            "finite number",
        );
        assert_contains(
            &validation_error(|c| c.min_coverage = 0),
            "min_coverage must be greater than 0",
        );
        assert_contains(
            &validation_error(|c| c.chunk_size = 0),
            "chunk_size must be greater than 0",
        );
        assert_contains(
            &validation_error(|c| c.num_threads = 0),
            "num_threads must be greater than 0",
        );
    }

    #[test]
    fn validate_input_files_checks_presence_and_extensions() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        let input = root.join("input.h5ad");
        let fasta = root.join("ref.fa");
        let fai = root.join("ref.fa.fai");
        let redi = root.join("redi.tsv.gz");

        let fixtures = [
            (&input, "dummy"),
            (&fasta, ">chr1\nA\n"),
            (&fai, "chr1\t1\t6\t1\t2\n"),
            (&redi, "dummy"),
        ];
        for &(path, contents) in &fixtures {
            fs::write(path, contents).unwrap();
        }

        // Runs the validator on the given input/FASTA, always with the REDIPortal fixture.
        let check = |input: &std::path::Path, fasta: &std::path::Path| {
            validate_input_files(
                input.to_str().unwrap(),
                fasta.to_str().unwrap(),
                redi.to_str().unwrap(),
            )
        };

        check(input.as_path(), fasta.as_path()).unwrap();

        let bad_input = root.join("input.txt");
        fs::write(&bad_input, b"dummy").unwrap();
        let msg = check(bad_input.as_path(), fasta.as_path())
            .unwrap_err()
            .to_string();
        assert_contains(&msg, ".h5ad");

        let missing_fasta = root.join("missing.fa");
        let msg = check(input.as_path(), missing_fasta.as_path())
            .unwrap_err()
            .to_string();
        assert_contains(&msg, "Reference genome file not found");

        fs::remove_file(&fai).unwrap();
        let msg = check(input.as_path(), fasta.as_path())
            .unwrap_err()
            .to_string();
        assert_contains(&msg, "FASTA index file not found");
    }

    #[test]
    fn validate_output_path_enforces_extension_and_overwrite() {
        let dir = tempdir().unwrap();

        let out = dir.path().join("out.h5ad");
        fs::write(&out, b"old").unwrap();
        validate_output_path(out.to_str().unwrap()).unwrap();
        assert!(!out.exists());

        let bad_out = dir.path().join("out.txt");
        let msg = validate_output_path(bad_out.to_str().unwrap())
            .unwrap_err()
            .to_string();
        assert_contains(&msg, "Output file must have .h5ad extension");
    }

    #[test]
    fn validate_matrix_dimensions_detects_mismatch() {
        validate_matrix_dimensions("m", (10, 5), (10, 5)).unwrap();
        let msg = validate_matrix_dimensions("m", (10, 4), (10, 5))
            .unwrap_err()
            .to_string();
        assert_contains(&msg, "10 × 5");
    }
}