use std::collections::BTreeMap;
use std::path::PathBuf;
use anyhow::Result;
use clap::Args;
use log::info;
use crate::crf::backend::CrfSuiteModel;
use crate::crf::ClusterCRF;
use crate::hmmer;
use crate::io::tables::{FeatureTable, GeneTable};
use crate::model::Gene;
use crate::refine::ClusterRefiner;
/// Command-line arguments for cross-validating the cluster CRF.
///
/// Consumed by [`CvArgs::execute`], which trains and evaluates one CRF per
/// fold and writes a per-fold TSV report to `output`.
#[derive(Args)]
pub struct CvArgs {
    /// Path to the gene table (read with `GeneTable::read_to_genes`).
    #[arg(short, long)]
    pub genes: PathBuf,
    /// One or more feature tables whose domains are attached to genes by protein ID.
    #[arg(short, long, num_args = 1..)]
    pub features: Vec<PathBuf>,
    /// Tab-separated cluster table; genes overlapping one of its regions are labelled as cluster genes.
    #[arg(short, long)]
    pub clusters: PathBuf,
    /// Path of the per-fold TSV report to write.
    #[arg(short, long, default_value = "cv.tsv")]
    pub output: PathBuf,
    /// Optional E-value cutoff, applied with `hmmer::filter_by_evalue` when given.
    #[arg(short, long)]
    pub e_filter: Option<f64>,
    /// P-value cutoff, always applied with `hmmer::filter_by_pvalue`.
    #[arg(short, long, default_value = "1e-9")]
    pub p_filter: f64,
    /// Disable shuffling when fitting the CRF.
    // NOTE(review): forwarded as `!no_shuffle` to the second argument of
    // `ClusterCRF::fit`; confirm that argument is the shuffle flag.
    #[arg(long)]
    pub no_shuffle: bool,
    // NOTE(review): not read by `CvArgs::execute`; fold assignment there is a
    // deterministic round-robin and does not use a seed.
    #[arg(long, default_value = "42")]
    pub seed: u64,
    /// Window size passed to `ClusterCRF::new`.
    #[arg(short = 'W', long, default_value = "5")]
    pub window_size: usize,
    /// Window step passed to `ClusterCRF::new`.
    #[arg(long, default_value = "1")]
    pub window_step: usize,
    // NOTE(review): `c1`/`c2` are not read by `CvArgs::execute`; presumably CRF
    // regularization coefficients meant for the model — confirm.
    #[arg(long, default_value = "0.15")]
    pub c1: f64,
    #[arg(long, default_value = "0.15")]
    pub c2: f64,
    /// Feature type passed to `ClusterCRF::new`.
    #[arg(long, default_value = "protein")]
    pub feature_type: String,
    // NOTE(review): not read by `CvArgs::execute`.
    #[arg(long)]
    pub select: Option<f64>,
    /// Hold out one sequence per fold instead of running k-fold cross-validation.
    // NOTE(review): the mode is logged as "Leave-One-Type-Out", but folds are
    // formed per sequence (`source_id`), not per cluster type.
    #[arg(long)]
    pub loto: bool,
    /// Number of folds in k-fold mode (capped at the number of sequences).
    #[arg(long, default_value = "10")]
    pub splits: usize,
}
impl CvArgs {
pub fn execute(&self) -> Result<()> {
let mode = if self.loto {
"Leave-One-Type-Out".to_string()
} else {
format!("{}-fold", self.splits)
};
info!("Running {} cross-validation", mode);
info!("Loading gene table from {:?}", self.genes);
let mut genes = GeneTable::read_to_genes(std::fs::File::open(&self.genes)?)?;
for feat_path in &self.features {
info!("Loading features from {:?}", feat_path);
let feat_genes = FeatureTable::read_to_genes(std::fs::File::open(feat_path)?)?;
let domain_map: BTreeMap<String, Vec<_>> = feat_genes
.into_iter()
.map(|g| (g.protein.id.clone(), g.protein.domains))
.collect();
for gene in &mut genes {
if let Some(domains) = domain_map.get(&gene.protein.id) {
gene.protein.domains = domains.clone();
}
}
}
if let Some(e) = self.e_filter {
hmmer::filter_by_evalue(&mut genes, e);
}
hmmer::filter_by_pvalue(&mut genes, self.p_filter);
genes.sort_by(|a, b| a.source_id.cmp(&b.source_id).then(a.start.cmp(&b.start)));
let cluster_reader = std::fs::File::open(&self.clusters)?;
let mut cluster_rdr = csv::ReaderBuilder::new()
.delimiter(b'\t')
.from_reader(cluster_reader);
let mut cluster_ranges: BTreeMap<String, Vec<(i64, i64)>> = BTreeMap::new();
for result in cluster_rdr.deserialize::<crate::io::tables::ClusterRow>() {
let row = result?;
cluster_ranges
.entry(row.sequence_id.clone())
.or_default()
.push((row.start, row.end));
}
for gene in &mut genes {
let in_cluster = cluster_ranges
.get(&gene.source_id)
.map(|ranges| {
ranges
.iter()
.any(|(s, e)| gene.start <= *e && gene.end >= *s)
})
.unwrap_or(false);
gene.probability = Some(if in_cluster { 1.0 } else { 0.0 });
for domain in &mut gene.protein.domains {
domain.probability = Some(if in_cluster { 1.0 } else { 0.0 });
}
}
let mut sequences: Vec<Vec<Gene>> = Vec::new();
let mut current_source: Option<String> = None;
for gene in &genes {
match ¤t_source {
Some(s) if s == &gene.source_id => {
sequences.last_mut().unwrap().push(gene.clone());
}
_ => {
current_source = Some(gene.source_id.clone());
sequences.push(vec![gene.clone()]);
}
}
}
let n_seqs = sequences.len();
let n_folds = if self.loto {
n_seqs
} else {
self.splits.min(n_seqs)
};
info!("Running {} folds on {} sequences", n_folds, n_seqs);
use std::io::Write;
let mut out = std::fs::File::create(&self.output)?;
writeln!(out, "fold\tsequence\tn_genes\tn_clusters")?;
for fold in 0..n_folds {
info!("Fold {}/{}", fold + 1, n_folds);
let mut train_genes = Vec::new();
let mut test_genes = Vec::new();
for (i, seq) in sequences.iter().enumerate() {
if i % n_folds == fold {
test_genes.extend(seq.clone());
} else {
train_genes.extend(seq.clone());
}
}
let crf_model = CrfSuiteModel::empty();
let mut crf = ClusterCRF::new(&self.feature_type, self.window_size, self.window_step);
crf.set_model(Box::new(crf_model));
if crf.fit(&train_genes, !self.no_shuffle).is_err() {
log::warn!("Training failed for fold {}", fold);
continue;
}
let predicted = crf.predict_probabilities(&test_genes, true, None)?;
let refiner = ClusterRefiner::default();
let clusters = refiner.iter_clusters(&predicted);
let test_source = test_genes
.first()
.map(|g| g.source_id.as_str())
.unwrap_or("?");
writeln!(
out,
"{}\t{}\t{}\t{}",
fold,
test_source,
test_genes.len(),
clusters.len()
)?;
}
info!("Cross-validation results written to {:?}", self.output);
Ok(())
}
}