use rand::Rng;
use crate::error::{Result, WriteError};
/// Fixed seed for building a reproducible random number generator.
///
/// The unit tests in this file pass it to `StdRng::seed_from_u64`.
// NOTE(review): the constant is `pub`, so callers outside this file presumably
// seed their generators with it as well — confirm against the call sites.
pub const RNG_SEED: u64 = 42;
/// Selects what [`Policy::handle`] does with the sequence it is given.
///
/// Only the uppercase bytes `A`, `C`, `G` and `T` are kept as-is by the
/// replacing variants; every other byte (lowercase letters included) is
/// substituted.
#[derive(Debug, Clone, Copy, Default)]
pub enum Policy {
/// Default variant: `handle` writes nothing and returns `Ok(false)`.
#[default]
IgnoreSequence,
/// `handle` returns an error carrying the sequence text.
BreakOnInvalid,
/// Each substituted byte is drawn at random from `A`, `C`, `G`, `T`.
RandomDraw,
/// Each substituted byte becomes `A`.
SetToA,
/// Each substituted byte becomes `C`.
SetToC,
/// Each substituted byte becomes `G`.
SetToG,
/// Each substituted byte becomes `T`.
SetToT,
}
impl Policy {
    /// Appends `sequence` to `ibuf`, writing `val` in place of every byte that
    /// is not an uppercase `A`, `C`, `G` or `T`.
    fn fill_with_known(sequence: &[u8], val: u8, ibuf: &mut Vec<u8>) {
        let substituted = sequence.iter().map(|&base| {
            if matches!(base, b'A' | b'C' | b'G' | b'T') {
                base
            } else {
                val
            }
        });
        ibuf.extend(substituted);
    }

    /// Appends `sequence` to `ibuf`, writing a randomly chosen `A`, `C`, `G`
    /// or `T` in place of every byte that is not already one of those four.
    ///
    /// `rng` is consulted once per substituted byte, in sequence order, so a
    /// generator in the same state always yields the same output.
    fn fill_with_random<R: Rng>(sequence: &[u8], rng: &mut R, ibuf: &mut Vec<u8>) {
        let substituted = sequence.iter().map(|&base| {
            if matches!(base, b'A' | b'C' | b'G' | b'T') {
                base
            } else {
                match rng.random_range(0..4) {
                    0 => b'A',
                    1 => b'C',
                    2 => b'G',
                    3 => b'T',
                    // The half-open range 0..4 only produces 0, 1, 2 or 3.
                    _ => unreachable!(),
                }
            }
        });
        ibuf.extend(substituted);
    }

    /// Applies this policy to `sequence`, using `ibuf` as the output buffer.
    ///
    /// `ibuf` is cleared first, whatever the policy. The return value is:
    ///
    /// * `Ok(false)` for [`Policy::IgnoreSequence`]; `ibuf` is left empty.
    /// * `Ok(true)` for [`Policy::RandomDraw`] and the `SetTo*` variants, after
    ///   `ibuf` has been filled with `sequence` in which every byte other than
    ///   uppercase `A`/`C`/`G`/`T` is replaced (randomly via `rng`, or by the
    ///   variant's fixed base).
    ///
    /// `rng` is only used by [`Policy::RandomDraw`].
    ///
    /// # Errors
    ///
    /// [`Policy::BreakOnInvalid`] always returns an error; the contents of
    /// `sequence` are not inspected for validity here. The error is
    /// `WriteError::InvalidNucleotideSequence` carrying the sequence text, or
    /// the propagated UTF-8 conversion error when `sequence` is not valid UTF-8.
    pub fn handle<R: Rng>(
        &self,
        sequence: &[u8],
        ibuf: &mut Vec<u8>,
        rng: &mut R,
    ) -> Result<bool> {
        ibuf.clear();

        // The non-replacing policies and the random one return from inside the
        // match; the `SetTo*` variants only differ in the byte they substitute.
        let fill_byte = match self {
            Self::IgnoreSequence => return Ok(false),
            Self::BreakOnInvalid => {
                let text = std::str::from_utf8(sequence)?;
                return Err(WriteError::InvalidNucleotideSequence(text.to_string()).into());
            }
            Self::RandomDraw => {
                Self::fill_with_random(sequence, rng, ibuf);
                return Ok(true);
            }
            Self::SetToA => b'A',
            Self::SetToC => b'C',
            Self::SetToG => b'G',
            Self::SetToT => b'T',
        };

        Self::fill_with_known(sequence, fill_byte, ibuf);
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Builds the deterministic generator every test hands to `Policy::handle`.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(RNG_SEED)
    }

    /// Runs `policy` on `sequence` with an empty buffer and a freshly seeded
    /// generator, returning the unwrapped flag and the resulting buffer.
    fn apply(policy: Policy, sequence: &[u8]) -> (bool, Vec<u8>) {
        let mut buffer = Vec::new();
        let flag = policy
            .handle(sequence, &mut buffer, &mut seeded_rng())
            .unwrap();
        (flag, buffer)
    }

    /// Asserts that `policy` reports `true` and rewrites `sequence` to `expected`.
    fn assert_rewritten(policy: Policy, sequence: &[u8], expected: &[u8]) {
        let (should_process, output) = apply(policy, sequence);
        assert!(should_process);
        assert_eq!(output, expected);
    }

    #[test]
    fn test_default_policy() {
        assert!(matches!(Policy::default(), Policy::IgnoreSequence));
    }

    #[test]
    fn test_ignore_sequence_policy() {
        let (should_process, output) = apply(Policy::IgnoreSequence, b"ACGTNX");
        assert!(!should_process);
        assert!(output.is_empty());
    }

    #[test]
    fn test_break_on_invalid_policy() {
        let mut output = Vec::new();
        let outcome = Policy::BreakOnInvalid.handle(b"ACGTNX", &mut output, &mut seeded_rng());
        assert!(outcome.is_err());
        let error = outcome.unwrap_err();
        assert!(matches!(
            error,
            crate::error::Error::WriteError(WriteError::InvalidNucleotideSequence(_))
        ));
    }

    // `BreakOnInvalid` does not inspect the bytes: a fully valid sequence
    // still produces an error.
    #[test]
    fn test_break_on_invalid_with_valid_sequence() {
        let mut output = Vec::new();
        let outcome = Policy::BreakOnInvalid.handle(b"ACGT", &mut output, &mut seeded_rng());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_set_to_a_policy() {
        assert_rewritten(Policy::SetToA, b"ACGTNX", b"ACGTAA");
    }

    #[test]
    fn test_set_to_c_policy() {
        assert_rewritten(Policy::SetToC, b"ACGTNX", b"ACGTCC");
    }

    #[test]
    fn test_set_to_g_policy() {
        assert_rewritten(Policy::SetToG, b"ACGTNX", b"ACGTGG");
    }

    #[test]
    fn test_set_to_t_policy() {
        assert_rewritten(Policy::SetToT, b"ACGTNX", b"ACGTTT");
    }

    #[test]
    fn test_all_valid_nucleotides_unchanged() {
        assert_rewritten(Policy::SetToA, b"ACGTACGT", b"ACGTACGT");
    }

    #[test]
    fn test_random_draw_policy() {
        let (should_process, output) = apply(Policy::RandomDraw, b"ACGTNX");
        assert!(should_process);
        assert_eq!(output.len(), 6);
        // The valid prefix is copied through untouched.
        assert_eq!(&output[..4], b"ACGT");
        // The two invalid bytes each became one of the four canonical bases.
        assert!(output[4..]
            .iter()
            .all(|&base| matches!(base, b'A' | b'C' | b'G' | b'T')));
    }

    #[test]
    fn test_random_draw_deterministic_with_seed() {
        let (_, first) = apply(Policy::RandomDraw, b"NNNN");
        let (_, second) = apply(Policy::RandomDraw, b"NNNN");
        assert_eq!(first, second);
    }

    #[test]
    fn test_buffer_cleared_before_processing() {
        // Start from a buffer that already holds unrelated bytes.
        let mut output = b"XYZ".to_vec();
        Policy::SetToA
            .handle(b"ACGT", &mut output, &mut seeded_rng())
            .unwrap();
        assert_eq!(output, b"ACGT");
    }

    #[test]
    fn test_multiple_calls_clear_buffer() {
        let policy = Policy::SetToA;
        let mut rng = seeded_rng();
        let mut output = Vec::new();

        policy.handle(b"ACGT", &mut output, &mut rng).unwrap();
        assert_eq!(output, b"ACGT");

        // The second call must replace, not append to, the earlier contents.
        policy.handle(b"TT", &mut output, &mut rng).unwrap();
        assert_eq!(output, b"TT");
    }

    #[test]
    fn test_empty_sequence() {
        let (should_process, output) = apply(Policy::SetToA, b"");
        assert!(should_process);
        assert!(output.is_empty());
    }

    #[test]
    fn test_all_invalid_nucleotides() {
        assert_rewritten(Policy::SetToG, b"NNNXXX", b"GGGGGG");
    }

    // The duplicate is made by plain assignment, which relies on `Policy`
    // being `Copy`; both values must stay usable and behave alike.
    #[test]
    fn test_policy_clone() {
        let original = Policy::SetToA;
        let duplicate = original;
        let mut rng = seeded_rng();
        let mut output = Vec::new();

        original.handle(b"NT", &mut output, &mut rng).unwrap();
        assert_eq!(output, b"AT");

        duplicate.handle(b"NT", &mut output, &mut rng).unwrap();
        assert_eq!(output, b"AT");
    }

    #[test]
    fn test_policy_debug() {
        let rendered = format!("{:?}", Policy::SetToA);
        assert!(rendered.contains("SetToA"));
    }

    #[test]
    fn test_lowercase_nucleotides_treated_as_invalid() {
        let (_, output) = apply(Policy::SetToA, b"acgt");
        assert_eq!(output, b"AAAA");
    }

    #[test]
    fn test_mixed_case_nucleotides() {
        // Only the lowercase `c` and `t` are replaced.
        let (_, output) = apply(Policy::SetToC, b"AcGt");
        assert_eq!(output, b"ACGC");
    }

    #[test]
    fn test_ambiguous_nucleotide_codes() {
        let (_, output) = apply(Policy::SetToT, b"RYWSMK");
        assert_eq!(output, b"TTTTTT");
    }
}