use crate::benchmark::types::PatternCategory;
use crate::FerroError;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
/// Draws a category-stratified, seeded sample of patterns from a text file
/// and writes it to `output`, one pattern per line.
///
/// Every non-blank line of `input` (trimmed) is classified with
/// [`PatternCategory::categorize`]. Each category then contributes
/// `ceil(sample_size * share)` patterns, but at least one and never more than
/// it holds. Because of the rounding and the one-per-category floor, the
/// pooled selection can exceed `sample_size`; it is shuffled and truncated to
/// `sample_size` before being written, so the per-category `sampled` counts in
/// the returned stats describe the pool *before* that final truncation.
///
/// # Arguments
///
/// * `input` - file with one pattern per line; blank lines are skipped.
/// * `output` - destination file; missing parent directories are created.
/// * `sample_size` - upper bound on the number of patterns written.
/// * `seed` - seed for the internal PRNG; the same input and seed yield the
///   same sample.
/// * `exclude_protein` - when `true`, patterns whose category reports
///   `is_protein()` are dropped before sampling and not counted in the totals.
///
/// # Errors
///
/// Returns [`FerroError::Io`] when the input cannot be opened or read, when no
/// pattern remains to sample from, or when the output directory or file
/// cannot be created, written, or flushed.
pub fn stratified_sample<P: AsRef<Path>>(
    input: P,
    output: P,
    sample_size: usize,
    seed: u64,
    exclude_protein: bool,
) -> Result<SampleStats, FerroError> {
    let input = input.as_ref();
    let output = output.as_ref();

    let file = File::open(input).map_err(|e| FerroError::Io {
        msg: format!("Failed to open {}: {}", input.display(), e),
    })?;
    let reader = BufReader::new(file);

    // Bucket every retained pattern by its category.
    let mut by_category: HashMap<PatternCategory, Vec<String>> = HashMap::new();
    let mut total = 0usize;
    let mut protein_excluded = 0usize;
    for line in reader.lines() {
        let line = line.map_err(|e| FerroError::Io {
            msg: format!("Error reading: {}", e),
        })?;
        let pattern = line.trim();
        if pattern.is_empty() {
            continue;
        }
        let category = PatternCategory::categorize(pattern);
        if exclude_protein && category.is_protein() {
            protein_excluded += 1;
            continue;
        }
        by_category
            .entry(category)
            .or_default()
            .push(pattern.to_string());
        total += 1;
    }

    if exclude_protein && protein_excluded > 0 {
        eprintln!(
            " Excluded {} protein patterns (--exclude-protein)",
            protein_excluded
        );
    }

    if total == 0 {
        // Distinguish "nothing in the file" from "everything was filtered out".
        let msg = if protein_excluded > 0 {
            "No patterns remaining after filtering"
        } else {
            "Input file is empty"
        };
        return Err(FerroError::Io {
            msg: msg.to_string(),
        });
    }

    let mut sampled = Vec::new();
    let mut stats = SampleStats {
        total_patterns: total,
        sample_size: 0,
        by_category: HashMap::new(),
    };
    let mut rng = SimpleRng::new(seed);

    // HashMap iteration order is unspecified; visit categories in a fixed
    // order (by Debug name) so a given seed always produces the same sample.
    let mut categories: Vec<_> = by_category.into_iter().collect();
    categories.sort_by_cached_key(|(cat, _)| format!("{:?}", cat));

    for (category, mut patterns) in categories {
        let category_total = patterns.len();
        let proportion = category_total as f64 / total as f64;
        // Proportional share, rounded up, with a floor of one per category.
        let target = ((sample_size as f64 * proportion).ceil() as usize).max(1);
        let taken = target.min(category_total);

        // Sort first so the shuffle result does not depend on input line order.
        patterns.sort();
        rng.shuffle(&mut patterns);
        sampled.extend(patterns.into_iter().take(taken));

        stats.by_category.insert(
            category,
            CategoryStats {
                // Per-category count, not the global total.
                total: category_total,
                sampled: taken,
                proportion,
            },
        );
    }

    // Mix categories together, then enforce the requested upper bound.
    rng.shuffle(&mut sampled);
    sampled.truncate(sample_size);
    stats.sample_size = sampled.len();

    if let Some(parent) = output.parent() {
        std::fs::create_dir_all(parent).map_err(|e| FerroError::Io {
            msg: format!("Failed to create directory {}: {}", parent.display(), e),
        })?;
    }
    let out_file = File::create(output).map_err(|e| FerroError::Io {
        msg: format!("Failed to create {}: {}", output.display(), e),
    })?;
    let mut writer = BufWriter::new(out_file);
    for pattern in &sampled {
        writeln!(writer, "{}", pattern).map_err(|e| FerroError::Io {
            msg: format!("Error writing: {}", e),
        })?;
    }
    // Dropping a BufWriter discards flush errors; flush explicitly so a failed
    // write is reported instead of silently producing a truncated file.
    writer.flush().map_err(|e| FerroError::Io {
        msg: format!("Error writing: {}", e),
    })?;

    eprintln!(
        "Created stratified sample: {} patterns from {} total",
        stats.sample_size, stats.total_patterns
    );
    Ok(stats)
}
/// Draws a category-stratified, seeded sample from an in-memory pattern list.
///
/// Patterns are grouped with [`PatternCategory::categorize`]; each category
/// contributes `ceil(sample_size * share)` patterns, but at least one and
/// never more than it holds. The pooled selection is then shuffled and
/// truncated to `sample_size`, so the result holds at most `sample_size`
/// patterns.
///
/// # Arguments
///
/// * `patterns` - candidate patterns; used as given (no trimming).
/// * `sample_size` - upper bound on the number of patterns returned.
/// * `seed` - seed for the internal PRNG; the same input and seed yield the
///   same sample.
/// * `exclude_protein` - when `true`, patterns whose category reports
///   `is_protein()` are dropped before sampling.
///
/// # Errors
///
/// Returns [`FerroError::Io`] when `patterns` is empty or when nothing is left
/// after protein exclusion.
pub fn stratified_sample_vec(
    patterns: &[String],
    sample_size: usize,
    seed: u64,
    exclude_protein: bool,
) -> Result<Vec<String>, FerroError> {
    if patterns.is_empty() {
        return Err(FerroError::Io {
            msg: "Input is empty".to_string(),
        });
    }

    // Group borrowed slices; only the patterns that make it into the final
    // sample are cloned, instead of every input string.
    let mut by_category: HashMap<PatternCategory, Vec<&str>> = HashMap::new();
    let mut protein_excluded = 0usize;
    for pattern in patterns {
        let category = PatternCategory::categorize(pattern);
        if exclude_protein && category.is_protein() {
            protein_excluded += 1;
            continue;
        }
        by_category
            .entry(category)
            .or_default()
            .push(pattern.as_str());
    }

    if exclude_protein && protein_excluded > 0 {
        eprintln!(" Excluded {} protein patterns", protein_excluded);
    }

    let total: usize = by_category.values().map(Vec::len).sum();
    if total == 0 {
        return Err(FerroError::Io {
            msg: "No patterns remaining after filtering".to_string(),
        });
    }

    let mut sampled: Vec<&str> = Vec::new();
    let mut rng = SimpleRng::new(seed);

    // HashMap iteration order is unspecified; visit categories in a fixed
    // order (by Debug name) so a given seed always produces the same sample.
    let mut categories: Vec<_> = by_category.into_iter().collect();
    categories.sort_by_cached_key(|(cat, _)| format!("{:?}", cat));

    for (_category, mut cat_patterns) in categories {
        let proportion = cat_patterns.len() as f64 / total as f64;
        // Proportional share, rounded up, with a floor of one per category.
        let target = ((sample_size as f64 * proportion).ceil() as usize).max(1);
        let taken = target.min(cat_patterns.len());

        // Sort first so the shuffle result does not depend on input order.
        cat_patterns.sort();
        rng.shuffle(&mut cat_patterns);
        sampled.extend(cat_patterns.into_iter().take(taken));
    }

    // Mix categories together, then enforce the requested upper bound.
    rng.shuffle(&mut sampled);
    sampled.truncate(sample_size);
    Ok(sampled.into_iter().map(str::to_owned).collect())
}
/// Summary of one sampling run, as returned by `stratified_sample`.
#[derive(Debug, Clone)]
pub struct SampleStats {
/// Number of non-blank input patterns that were kept for sampling
/// (patterns dropped by protein exclusion are not counted).
pub total_patterns: usize,
/// Number of patterns in the final sample, after truncation to the
/// requested size.
pub sample_size: usize,
/// Per-category breakdown, keyed by the category each pattern was
/// classified into.
pub by_category: HashMap<PatternCategory, CategoryStats>,
}
/// Sampling figures recorded for a single pattern category.
#[derive(Debug, Clone)]
pub struct CategoryStats {
/// Pattern count stored for this entry.
/// NOTE(review): the name suggests the number of patterns in this
/// category — confirm against the code that populates it.
pub total: usize,
/// Number of patterns drawn from this category, counted before the pooled
/// sample is truncated to the requested size.
pub sampled: usize,
/// This category's share of all retained patterns
/// (category count divided by the overall count).
pub proportion: f64,
}
/// Minimal deterministic xorshift-style pseudo-random generator.
///
/// Used only to make sampling reproducible from a seed without pulling in an
/// external RNG crate; it is not suitable for cryptographic use.
struct SimpleRng {
    /// Current 64-bit generator state; never zero after construction.
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is replaced by 1: with an all-zero state every xor/shift
    /// step yields zero again, so the generator would be stuck.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.max(1),
        }
    }

    /// Advances the state by one xorshift round (shifts 13, 7, 17) and
    /// returns the new 64-bit value.
    fn next(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Shuffles `slice` in place, walking from the last index down to 1 and
    /// swapping each element with one at a pseudo-random index `<= i`.
    ///
    /// Slices of length 0 or 1 are left untouched and consume no random
    /// values. Index selection uses a plain modulo, so it carries the usual
    /// (tiny) modulo bias.
    fn shuffle<T>(&mut self, slice: &mut [T]) {
        for i in (1..slice.len()).rev() {
            // Reduce in u64 *before* narrowing: casting the raw draw to
            // `usize` first would drop the high bits on 32-bit targets and
            // make the same seed shuffle differently across platforms.
            let j = (self.next() % (i as u64 + 1)) as usize;
            slice.swap(i, j);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// End-to-end check of file-based sampling, with and without protein
    /// exclusion, on a fixture of 5 genomic, 3 coding and 2 protein patterns.
    #[test]
    fn test_stratified_sample() {
        let workdir = TempDir::new().unwrap();
        let input = workdir.path().join("input.txt");
        let output = workdir.path().join("sample.txt");

        // Write the fixture; the handle is closed at the end of this scope.
        {
            let mut fixture = File::create(&input).unwrap();
            for pos in 1000..1005 {
                writeln!(fixture, "NC_000001.11:g.{}A>G", pos).unwrap();
            }
            for pos in 100..103 {
                writeln!(fixture, "NM_000088.3:c.{}G>T", pos).unwrap();
            }
            for protein in &["NP_000079.2:p.Val200Met", "NP_000079.2:p.Arg300Gln"] {
                writeln!(fixture, "{}", protein).unwrap();
            }
        }

        // All categories included: every pattern counts, output is capped.
        let stats = stratified_sample(&input, &output, 5, 42, false).unwrap();
        assert_eq!(stats.total_patterns, 10);
        assert!(stats.sample_size <= 5);
        let written = std::fs::read_to_string(&output).unwrap();
        assert!(written.lines().count() <= 5);

        // Protein excluded: the two `p.` patterns are neither counted nor written.
        let output_no_protein = workdir.path().join("sample_no_protein.txt");
        let stats_no_protein =
            stratified_sample(&input, &output_no_protein, 5, 42, true).unwrap();
        assert_eq!(stats_no_protein.total_patterns, 8);
        let written = std::fs::read_to_string(&output_no_protein).unwrap();
        if let Some(offender) = written.lines().find(|line| line.contains(":p.")) {
            panic!(
                "Found protein pattern in exclude_protein output: {}",
                offender
            );
        }
    }

    /// Spot-checks the category assigned to one representative pattern per kind.
    #[test]
    fn test_pattern_categorization() {
        let cases = vec![
            ("NC_000001.11:g.12345A>G", PatternCategory::GenomicSnv),
            ("NC_000001.11:g.12345del", PatternCategory::GenomicDel),
            ("NM_000088.3:c.589G>T", PatternCategory::CodingSnv),
            ("NM_000088.3:c.589+5G>T", PatternCategory::CodingIntronic),
            ("NP_000079.2:p.Val200Met", PatternCategory::ProteinSub),
            ("NC_012920.1:m.8993T>G", PatternCategory::Mitochondrial),
        ];
        for (pattern, expected) in cases {
            // Pair the pattern with the category so a failure names the input.
            assert_eq!(
                (pattern, PatternCategory::categorize(pattern)),
                (pattern, expected)
            );
        }
    }
}