use std::collections::{BTreeMap, HashSet};
use std::io::{BufWriter, Write};
use std::path::PathBuf;

use anyhow::{Context, Result};
use clap::Args;
use log::info;

use crate::crf::backend::CrfSuiteModel;
use crate::crf::ClusterCRF;
use crate::hmmer;
use crate::io::tables::{FeatureTable, GeneTable};
// Command-line options of the `train` subcommand, parsed through clap's `Args` derive.
//
// Plain `//` comments are used deliberately: `///` doc comments would be picked up
// by clap as help text and would change the `--help` output of the CLI.
// Short/long flag names are spelled out explicitly; they match what clap would infer
// from the field names (first letter / kebab-case).
#[derive(Args)]
pub struct TrainArgs {
    // Gene table to train on; opened and parsed with `GeneTable::read_to_genes`.
    #[arg(short = 'g', long = "genes")]
    pub genes: PathBuf,

    // One or more feature tables whose domains are attached to genes by protein id.
    #[arg(short = 'f', long = "features", num_args = 1..)]
    pub features: Vec<PathBuf>,

    // Tab-separated cluster annotations used to label genes as in/out of a cluster.
    #[arg(short = 'c', long = "clusters")]
    pub clusters: PathBuf,

    // Directory receiving the outputs (created if missing); `domains.tsv` goes here.
    #[arg(short = 'o', long = "output-dir", default_value = ".")]
    pub output_dir: PathBuf,

    // Optional e-value threshold; e-value filtering is skipped when absent.
    #[arg(short = 'e', long = "e-filter")]
    pub e_filter: Option<f64>,

    // P-value threshold, always applied before training.
    #[arg(short = 'p', long = "p-filter", default_value = "1e-9")]
    pub p_filter: f64,

    // When set, the CRF is fitted without shuffling (its negation is passed to `fit`).
    #[arg(long = "no-shuffle")]
    pub no_shuffle: bool,

    // NOTE(review): not read by `execute` in this file — confirm where the seed is used.
    #[arg(long = "seed", default_value = "42")]
    pub seed: u64,

    // Window size and step handed to `ClusterCRF::new`.
    #[arg(short = 'W', long = "window-size", default_value = "5")]
    pub window_size: usize,
    #[arg(long = "window-step", default_value = "1")]
    pub window_step: usize,

    // Presumably the L1/L2 regularisation coefficients of the CRF — TODO confirm;
    // NOTE(review): neither is read by `execute` in this file.
    #[arg(long = "c1", default_value = "0.15")]
    pub c1: f64,
    #[arg(long = "c2", default_value = "0.15")]
    pub c2: f64,

    // Feature granularity string forwarded to `ClusterCRF::new`.
    #[arg(long = "feature-type", default_value = "protein")]
    pub feature_type: String,

    // NOTE(review): not read by `execute` in this file — confirm intended use.
    #[arg(long = "select")]
    pub select: Option<f64>,
}
impl TrainArgs {
pub fn execute(&self) -> Result<()> {
std::fs::create_dir_all(&self.output_dir)?;
info!("Loading gene table from {:?}", self.genes);
let mut genes =
GeneTable::read_to_genes(std::fs::File::open(&self.genes)?)?;
for feat_path in &self.features {
info!("Loading features from {:?}", feat_path);
let feat_genes =
FeatureTable::read_to_genes(std::fs::File::open(feat_path)?)?;
let domain_map: BTreeMap<String, Vec<_>> = feat_genes
.into_iter()
.map(|g| (g.protein.id.clone(), g.protein.domains))
.collect();
for gene in &mut genes {
if let Some(domains) = domain_map.get(&gene.protein.id) {
gene.protein.domains = domains.clone();
}
}
}
if let Some(e) = self.e_filter {
hmmer::filter_by_evalue(&mut genes, e);
}
hmmer::filter_by_pvalue(&mut genes, self.p_filter);
genes.sort_by(|a, b| {
a.source_id
.cmp(&b.source_id)
.then(a.start.cmp(&b.start))
});
info!("Loading cluster annotations from {:?}", self.clusters);
let cluster_reader = std::fs::File::open(&self.clusters)?;
let mut cluster_rdr = csv::ReaderBuilder::new()
.delimiter(b'\t')
.from_reader(cluster_reader);
let mut cluster_ranges: BTreeMap<String, Vec<(i64, i64, String)>> =
BTreeMap::new();
for result in cluster_rdr.deserialize::<crate::io::tables::ClusterRow>() {
let row = result?;
cluster_ranges
.entry(row.sequence_id.clone())
.or_default()
.push((row.start, row.end, row.cluster_id));
}
for gene in &mut genes {
let in_cluster = cluster_ranges
.get(&gene.source_id)
.map(|ranges| {
ranges
.iter()
.any(|(s, e, _)| gene.start <= *e && gene.end >= *s)
})
.unwrap_or(false);
gene.probability = Some(if in_cluster { 1.0 } else { 0.0 });
for domain in &mut gene.protein.domains {
domain.probability = Some(if in_cluster { 1.0 } else { 0.0 });
}
}
info!(
"Training CRF model (feature_type={}, window={}x{})",
self.feature_type, self.window_size, self.window_step
);
let mut crf = ClusterCRF::new(&self.feature_type, self.window_size, self.window_step);
let init_model: Box<dyn crate::crf::CrfModel> = Box::new(CrfSuiteModel::empty());
crf.set_model(init_model);
crf.fit(&genes, !self.no_shuffle)?;
info!("CRF model trained successfully");
let all_domains: Vec<String> = {
let mut set: HashSet<String> = HashSet::new();
for gene in &genes {
for domain in &gene.protein.domains {
set.insert(domain.name.clone());
}
}
let mut v: Vec<String> = set.into_iter().collect();
v.sort();
v
};
let domains_path = self.output_dir.join("domains.tsv");
let mut domains_file = std::fs::File::create(&domains_path)?;
use std::io::Write;
for d in &all_domains {
writeln!(domains_file, "{}", d)?;
}
info!("Finished training");
Ok(())
}
}