space_trav_lr_rust 1.3.0

Spatial gene regulatory network inference and in-silico perturbation (Rust port of SpaceTravLR)
use crate::cnn_gating::CnnGateDecision;
use crate::estimator::ClusterTrainingSummary;
use serde::Serialize;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;

/// One data row of the `# summary:` table in a per-gene training log.
///
/// Fields are declared in the same order as the table's columns; the derived
/// `Serialize` impl emits them in this declaration order.
#[derive(Debug, Clone, Serialize)]
pub struct ClusterTrainingLogRow {
    /// `cluster_id` column.
    pub cluster_id: usize,
    /// `n_cells` column.
    pub n_cells: usize,
    /// `n_modulators` column.
    pub n_modulators: usize,
    /// `lasso_r2` column (the writer rounds it to 6 decimals).
    pub lasso_r2: f64,
    /// `lasso_train_mse` column (the writer rounds it to 6 decimals).
    pub lasso_train_mse: f64,
    /// `lasso_fista_iters` column.
    pub lasso_fista_iters: usize,
    /// `lasso_converged` column; `true` only when the cell is exactly `true`.
    pub lasso_converged: bool,
    /// `cnn_epochs_ran` column: number of recorded CNN training epochs (0 when none).
    pub cnn_epochs_ran: usize,
    /// `cnn_mse_first` column: MSE of the first recorded CNN epoch, or `None` when the
    /// cell was `NA`, empty, or not a valid number.
    pub cnn_mse_first: Option<f64>,
    /// `cnn_mse_last` column: MSE of the last recorded CNN epoch, or `None` when the
    /// cell was `NA`, empty, or not a valid number.
    pub cnn_mse_last: Option<f64>,
}

/// Per-gene summary of one training log: the logged configuration, aggregate
/// statistics over the summary table, the optional gate section, and the raw rows.
#[derive(Debug, Clone, Serialize)]
pub struct GeneTrainingRollup {
    /// Value of the `gene` metadata line.
    pub gene: String,
    /// `seed_only` metadata line; `true` only for the exact text `true`.
    pub seed_only: bool,
    /// `per_cell_cnn_export` metadata line; `true` only for the exact text `true`.
    pub per_cell_cnn_export: bool,
    /// `cnn_epochs_config` metadata line (0 when missing or unparsable).
    pub cnn_epochs_config: usize,
    /// `learning_rate` metadata line (0.0 when missing or unparsable).
    pub learning_rate: f64,
    /// `lasso_n_iter_max` metadata line (0 when missing or unparsable).
    pub lasso_n_iter_max: usize,
    /// `lasso_tol` metadata line (0.0 when missing or unparsable).
    pub lasso_tol: f64,
    /// Arithmetic mean of `lasso_r2` over `clusters`.
    pub mean_lasso_r2: f64,
    /// Smallest `lasso_r2` over `clusters`.
    pub min_lasso_r2: f64,
    /// Largest `lasso_r2` over `clusters`.
    pub max_lasso_r2: f64,
    /// Fraction of `clusters` whose `lasso_converged` is `true`.
    pub frac_lasso_converged: f64,
    /// Sum of `cnn_epochs_ran` over `clusters`.
    pub sum_cnn_epochs_ran: usize,
    /// Number of parsed summary rows (equal to `clusters.len()`).
    pub n_clusters: usize,
    /// Raw `gate_*` key/value strings from the `# hybrid_cnn_gate` section; `None` when
    /// that section is absent or contributed no `gate_*` entries.
    pub gate: Option<HashMap<String, String>>,
    /// Parsed rows of the `# summary:` table, in file order.
    pub clusters: Vec<ClusterTrainingLogRow>,
}

/// Parses one optional floating-point cell of the summary table.
///
/// The writer emits `NA` for an absent value. That marker, an empty cell, and any
/// text that is not a valid `f64` all map to `None`.
fn parse_f64_cell(s: &str) -> Option<f64> {
    match s {
        "" | "NA" => None,
        text => text.parse::<f64>().ok(),
    }
}

/// Parses one data row of the `# summary:` table.
///
/// `cols` are the row's tab-separated cells in the column order written by
/// `write_gene_training_log`. Cells beyond the tenth are ignored.
///
/// Returns `None` in two cases:
/// - fewer than ten cells are present;
/// - any required numeric cell fails to parse.
///
/// `lasso_converged` is `true` only for the exact text `true`. The two CNN MSE cells
/// treat `NA`, an empty cell, or invalid text as "absent".
fn parse_cluster_summary_row(cols: &[&str]) -> Option<ClusterTrainingLogRow> {
    match cols {
        [id, cells, mods, r2, mse, iters, converged, epochs, first, last, ..] => {
            Some(ClusterTrainingLogRow {
                cluster_id: id.parse().ok()?,
                n_cells: cells.parse().ok()?,
                n_modulators: mods.parse().ok()?,
                lasso_r2: r2.parse().ok()?,
                lasso_train_mse: mse.parse().ok()?,
                lasso_fista_iters: iters.parse().ok()?,
                lasso_converged: *converged == "true",
                cnn_epochs_ran: epochs.parse().ok()?,
                cnn_mse_first: parse_f64_cell(first),
                cnn_mse_last: parse_f64_cell(last),
            })
        }
        _ => None,
    }
}

pub fn parse_gene_training_log(path: &Path) -> anyhow::Result<Option<GeneTrainingRollup>> {
    let f = match File::open(path) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };
    let reader = BufReader::new(f);
    let lines: Vec<String> = reader.lines().collect::<Result<_, _>>()?;
    parse_gene_training_log_lines(&lines)
}

/// Parses the lines of a `spacetravlr_training_log` v1 file into a per-gene rollup.
///
/// Expected layout (as written by `write_gene_training_log`):
///
/// 1. The header line `format\tspacetravlr_training_log\tv1`; surrounding whitespace
///    is ignored.
/// 2. `key\tvalue` metadata lines:
///    - unknown keys are ignored;
///    - booleans are `true` only for the exact text `true`;
///    - unparsable numbers fall back to zero.
/// 3. An optional `# hybrid_cnn_gate` section of `gate_*\tvalue` lines, ended by a
///    blank line, a `#` line, or EOF.
/// 4. A `# summary:` line, the `cluster_id\t...` column header, and one tab-separated
///    row per cluster, ended by a blank line, a `#` line, or EOF. Rows that do not
///    parse are skipped.
///
/// Nothing after the summary table (e.g. the per-epoch MSE table) is read.
///
/// Returns `Ok(None)` in three cases:
/// - the header line is missing or different;
/// - no non-empty `gene` value is present;
/// - the summary table yields no cluster rows.
///
/// This function currently never returns `Err`.
fn parse_gene_training_log_lines(lines: &[String]) -> anyhow::Result<Option<GeneTrainingRollup>> {
    if lines.first().map(|l| l.trim()) != Some("format\tspacetravlr_training_log\tv1") {
        return Ok(None);
    }

    let mut gene = String::new();
    let mut seed_only = false;
    let mut per_cell_cnn_export = false;
    let mut cnn_epochs_config = 0usize;
    let mut learning_rate = 0.0f64;
    let mut lasso_n_iter_max = 0usize;
    let mut lasso_tol = 0.0f64;
    let mut gate: Option<HashMap<String, String>> = None;
    let mut clusters: Vec<ClusterTrainingLogRow> = Vec::new();

    let mut i = 1usize;
    while i < lines.len() {
        let line = lines[i].trim_end();
        if line.starts_with("# hybrid_cnn_gate") {
            let (section, next) = parse_gate_section(lines, i + 1);
            // An empty gate section must not erase an earlier non-empty one.
            if !section.is_empty() {
                gate = Some(section);
            }
            // Resume at the line that ended the section so it is handled normally.
            i = next;
            continue;
        }
        if line.starts_with("# summary:") {
            // The summary table is the last section this parser needs.
            clusters = parse_summary_table(lines, i + 1);
            break;
        }
        // Blank lines and other comment lines carry no metadata.
        if !line.is_empty() && !line.starts_with('#') {
            if let Some((k, v)) = line.split_once('\t') {
                match k {
                    "gene" => gene = v.to_string(),
                    "seed_only" => seed_only = v == "true",
                    "per_cell_cnn_export" => per_cell_cnn_export = v == "true",
                    "cnn_epochs_config" => cnn_epochs_config = v.parse().unwrap_or(0),
                    "learning_rate" => learning_rate = v.parse().unwrap_or(0.0),
                    "lasso_n_iter_max" => lasso_n_iter_max = v.parse().unwrap_or(0),
                    "lasso_tol" => lasso_tol = v.parse().unwrap_or(0.0),
                    _ => {}
                }
            }
        }
        i += 1;
    }

    if gene.is_empty() || clusters.is_empty() {
        return Ok(None);
    }

    // `clusters` is non-empty here, so the averages below are well defined.
    let n = clusters.len() as f64;
    let r2_values = || clusters.iter().map(|c| c.lasso_r2);
    let mean_lasso_r2 = r2_values().sum::<f64>() / n;
    let min_lasso_r2 = r2_values().fold(f64::INFINITY, f64::min);
    let max_lasso_r2 = r2_values().fold(f64::NEG_INFINITY, f64::max);
    let n_converged = clusters.iter().filter(|c| c.lasso_converged).count();
    let frac_lasso_converged = n_converged as f64 / n;
    let sum_cnn_epochs_ran: usize = clusters.iter().map(|c| c.cnn_epochs_ran).sum();

    Ok(Some(GeneTrainingRollup {
        gene,
        seed_only,
        per_cell_cnn_export,
        cnn_epochs_config,
        learning_rate,
        lasso_n_iter_max,
        lasso_tol,
        mean_lasso_r2,
        min_lasso_r2,
        max_lasso_r2,
        frac_lasso_converged,
        sum_cnn_epochs_ran,
        n_clusters: clusters.len(),
        gate,
        clusters,
    }))
}

/// Collects the `gate_*\tvalue` pairs of a `# hybrid_cnn_gate` section.
///
/// `start` is the index of the line right after the section marker. Lines without a
/// tab, or whose key does not start with `gate_`, are ignored.
///
/// Returns the pairs and the index of the first unconsumed line: the terminating
/// blank or `#` line, or `lines.len()`.
fn parse_gate_section(lines: &[String], start: usize) -> (HashMap<String, String>, usize) {
    let mut entries = HashMap::new();
    let mut i = start;
    while let Some(raw) = lines.get(i) {
        let line = raw.trim_end();
        if line.is_empty() || line.starts_with('#') {
            break;
        }
        if let Some((key, value)) = line.split_once('\t') {
            if key.starts_with("gate_") {
                entries.insert(key.to_string(), value.to_string());
            }
        }
        i += 1;
    }
    (entries, i)
}

/// Reads the cluster rows of a `# summary:` table.
///
/// `start` is the index of the line right after the `# summary:` marker. The column
/// header row is skipped only if it actually starts with `cluster_id`, so a log
/// lacking that row keeps its first cluster.
///
/// Reading stops at a blank line, a `#` line, or EOF. Rows rejected by
/// `parse_cluster_summary_row` are dropped.
fn parse_summary_table(lines: &[String], start: usize) -> Vec<ClusterTrainingLogRow> {
    let mut rows = lines.get(start..).unwrap_or(&[]);
    if rows.first().map_or(false, |l| l.starts_with("cluster_id")) {
        rows = &rows[1..];
    }
    rows.iter()
        .map(|l| l.trim_end())
        .take_while(|l| !l.is_empty() && !l.starts_with('#'))
        .filter_map(|l| parse_cluster_summary_row(&l.split('\t').collect::<Vec<_>>()))
        .collect()
}

/// Parses every `*.log` file directly inside `log_dir` into a [`GeneTrainingRollup`].
///
/// Entries that are not regular files are skipped. `DirEntry::file_type` does not
/// follow symlinks, so symlinked logs are skipped as well. Logs rejected by
/// `parse_gene_training_log` (`Ok(None)`) are also skipped. A missing or
/// non-directory `log_dir` yields an empty vector.
///
/// The result is sorted case-insensitively by gene name. Ties are broken by the exact
/// name, so the order does not depend on directory iteration order.
///
/// # Errors
///
/// Propagates directory-listing errors and any error returned by
/// `parse_gene_training_log` for an individual file.
pub fn scan_gene_training_logs(log_dir: &Path) -> anyhow::Result<Vec<GeneTrainingRollup>> {
    let mut out = Vec::new();
    if !log_dir.is_dir() {
        return Ok(out);
    }
    for entry in std::fs::read_dir(log_dir)? {
        let entry = entry?;
        if !entry.file_type()?.is_file() {
            continue;
        }
        let path = entry.path();
        if path.extension().and_then(|s| s.to_str()) != Some("log") {
            continue;
        }
        if let Some(rollup) = parse_gene_training_log(&path)? {
            out.push(rollup);
        }
    }
    // Compute each lowercase key once rather than allocating two per comparison.
    out.sort_by_cached_key(|r| (r.gene.to_lowercase(), r.gene.clone()));
    Ok(out)
}

/// Writes a per-gene training log in the `spacetravlr_training_log` v1 text format.
///
/// Parent directories of `log_path` are created if needed, and an existing file is
/// truncated. The file contains, in order:
///
/// 1. The format header, then `key\tvalue` lines for the training configuration.
/// 2. If `gate` is `Some`, a `# hybrid_cnn_gate` section of `gate_*\tvalue` lines.
/// 3. A `# summary:` table with one row per entry of `summaries`. The CNN MSE columns
///    hold `NA` when no CNN epoch was recorded.
/// 4. A `# cnn_mse_by_epoch:` table with one row per recorded CNN epoch.
///
/// # Errors
///
/// Returns any I/O error from creating the directory or file, or from writing and
/// flushing it.
pub fn write_gene_training_log(
    log_path: &Path,
    gene: &str,
    seed_only: bool,
    per_cell_cnn_export: bool,
    epochs: usize,
    learning_rate: f64,
    lasso_n_iter_max: usize,
    lasso_tol: f64,
    summaries: &[ClusterTrainingSummary],
    gate: Option<&CnnGateDecision>,
) -> std::io::Result<()> {
    if let Some(parent) = log_path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let f = File::create(log_path)?;
    let mut w = BufWriter::with_capacity(256 * 1024, f);

    writeln!(w, "format\tspacetravlr_training_log\tv1")?;
    writeln!(w, "gene\t{}", gene)?;
    writeln!(w, "seed_only\t{}", seed_only)?;
    writeln!(w, "per_cell_cnn_export\t{}", per_cell_cnn_export)?;
    writeln!(w, "cnn_epochs_config\t{}", epochs)?;
    writeln!(w, "learning_rate\t{}", learning_rate)?;
    writeln!(w, "lasso_n_iter_max\t{}", lasso_n_iter_max)?;
    writeln!(w, "lasso_tol\t{}", lasso_tol)?;
    writeln!(w)?;

    if let Some(g) = gate {
        writeln!(
            w,
            "# hybrid_cnn_gate (empty use_cnn means non-hybrid or pass-2 full CNN)"
        )?;
        writeln!(w, "gate_use_cnn\t{}", g.use_cnn)?;
        // The reason is free text. A tab would split the field, and a line break would
        // end the gate section early or inject a bogus line when the log is parsed
        // back, so flatten all of them to spaces.
        writeln!(
            w,
            "gate_reason\t{}",
            g.reason
                .replace(|c: char| matches!(c, '\t' | '\n' | '\r'), " ")
        )?;
        writeln!(w, "gate_min_cells_per_cluster\t{}", g.min_cells_per_cluster)?;
        writeln!(w, "gate_n_modulators\t{}", g.n_modulators)?;
        writeln!(w, "gate_n_lr_pairs\t{}", g.n_lr_pairs)?;
        writeln!(w, "gate_n_tfl_pairs\t{}", g.n_tfl_pairs)?;
        writeln!(
            w,
            "gate_modulator_spatial_fraction\t{:.6}",
            g.modulator_spatial_fraction
        )?;
        writeln!(w, "gate_mean_lasso_r2\t{:.6}", g.mean_lasso_r2)?;
        writeln!(w, "gate_all_lasso_converged\t{}", g.all_lasso_converged)?;
        writeln!(w, "gate_moran_i\t{:.8}", g.moran_i)?;
        writeln!(w, "gate_moran_p_value\t{:.8}", g.moran_p_value)?;
        writeln!(w, "gate_moran_permutations\t{}", g.moran_permutations)?;
        writeln!(w, "gate_forced_allowlist\t{}", g.forced_by_allowlist)?;
        writeln!(w, "gate_blocked_skip_list\t{}", g.blocked_by_denylist)?;
        if let Some(m) = g.mean_target_expression {
            writeln!(w, "gate_mean_target_expression\t{:.8}", m)?;
        } else {
            writeln!(w, "gate_mean_target_expression\tNA")?;
        }
        writeln!(w, "gate_rank_score\t{:.6}", g.rank_score)?;
        writeln!(w)?;
    }

    writeln!(
        w,
        "# summary: cluster_id, n_cells, n_modulators, lasso_r2, lasso_train_mse, lasso_fista_iters, lasso_converged, cnn_epochs_ran, cnn_mse_first, cnn_mse_last"
    )?;
    writeln!(
        w,
        "cluster_id\tn_cells\tn_modulators\tlasso_r2\tlasso_train_mse\tlasso_fista_iters\tlasso_converged\tcnn_epochs_ran\tcnn_mse_first\tcnn_mse_last"
    )?;

    // NOTE(review): `{:.6}` keeps only six decimals, so very small MSE values are
    // heavily rounded (e.g. 0.000007). Kept as-is for file-format compatibility.
    for s in summaries {
        let mse = &s.cnn_train_mse_epochs;
        // `NA` marks clusters for which no CNN epoch was recorded.
        let (ran, first_s, last_s) = match (mse.first(), mse.last()) {
            (Some(first), Some(last)) => (
                mse.len(),
                format!("{:.6}", first),
                format!("{:.6}", last),
            ),
            _ => (0usize, "NA".to_string(), "NA".to_string()),
        };
        writeln!(
            w,
            "{}\t{}\t{}\t{:.6}\t{:.6}\t{}\t{}\t{}\t{}\t{}",
            s.cluster_id,
            s.n_cells,
            s.n_modulators,
            s.lasso_r2,
            s.lasso_train_mse,
            s.lasso_fista_iters,
            s.lasso_converged,
            ran,
            first_s,
            last_s,
        )?;
    }

    writeln!(w)?;
    writeln!(w, "# cnn_mse_by_epoch: cluster_id, epoch, train_mse")?;
    writeln!(w, "cluster_id\tepoch\ttrain_mse")?;
    for s in summaries {
        for (epoch, &mse) in s.cnn_train_mse_epochs.iter().enumerate() {
            writeln!(w, "{}\t{}\t{:.6}", s.cluster_id, epoch, mse)?;
        }
    }

    // Flush explicitly: dropping a BufWriter would swallow a final write error.
    w.flush()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a raw log fixture into owned lines, as `parse_gene_training_log` does.
    fn to_lines(raw: &str) -> Vec<String> {
        raw.lines().map(str::to_string).collect()
    }

    #[test]
    fn parses_kit_style_log() {
        let raw = "format\tspacetravlr_training_log\tv1
gene\tKit
seed_only\ttrue
per_cell_cnn_export\tfalse
cnn_epochs_config\t10
learning_rate\t0.0002
lasso_n_iter_max\t100
lasso_tol\t0.0001

# summary: cluster_id, n_cells, n_modulators, lasso_r2, lasso_train_mse, lasso_fista_iters, lasso_converged, cnn_epochs_ran, cnn_mse_first, cnn_mse_last
cluster_id\tn_cells\tn_modulators\tlasso_r2\tlasso_train_mse\tlasso_fista_iters\tlasso_converged\tcnn_epochs_ran\tcnn_mse_first\tcnn_mse_last
0\t1107\t376\t0.284636\t0.000007\t100\tfalse\t0\tNA\tNA
1\t841\t376\t0.011789\t0.000004\t100\tfalse\t0\tNA\tNA
";
        let r = parse_gene_training_log_lines(&to_lines(raw))
            .unwrap()
            .expect("rollup");
        assert_eq!(r.gene, "Kit");
        assert!(r.seed_only);
        assert_eq!(r.clusters.len(), 2);
        assert_eq!(r.clusters[0].cluster_id, 0);
        assert_eq!(r.clusters[0].cnn_mse_first, None);
        assert!((r.mean_lasso_r2 - 0.1482125).abs() < 1e-5);
        assert!(r.gate.is_none());
    }

    #[test]
    fn parses_gate_section_and_cnn_columns() {
        let raw = "format\tspacetravlr_training_log\tv1
gene\tSox2
seed_only\tfalse
cnn_epochs_config\t5

# hybrid_cnn_gate (empty use_cnn means non-hybrid or pass-2 full CNN)
gate_use_cnn\ttrue
gate_reason\tspatial signal
not_a_gate_key\tignored

# summary: cluster_id, n_cells, n_modulators, lasso_r2, lasso_train_mse, lasso_fista_iters, lasso_converged, cnn_epochs_ran, cnn_mse_first, cnn_mse_last
cluster_id\tn_cells\tn_modulators\tlasso_r2\tlasso_train_mse\tlasso_fista_iters\tlasso_converged\tcnn_epochs_ran\tcnn_mse_first\tcnn_mse_last
0\t10\t5\t0.500000\t0.010000\t20\ttrue\t3\t0.100000\t0.050000
1\t12\t5\t0.300000\t0.020000\t100\tfalse\t0\tNA\tNA

# cnn_mse_by_epoch: cluster_id, epoch, train_mse
cluster_id\tepoch\ttrain_mse
0\t0\t0.100000
";
        let r = parse_gene_training_log_lines(&to_lines(raw))
            .unwrap()
            .expect("rollup");
        assert_eq!(r.gene, "Sox2");
        assert!(!r.seed_only);
        assert_eq!(r.cnn_epochs_config, 5);

        let gate = r.gate.as_ref().expect("gate section");
        assert_eq!(gate.len(), 2);
        assert_eq!(gate.get("gate_use_cnn").map(String::as_str), Some("true"));
        assert_eq!(
            gate.get("gate_reason").map(String::as_str),
            Some("spatial signal")
        );

        assert_eq!(r.n_clusters, 2);
        assert_eq!(r.clusters[0].cnn_mse_first, Some(0.1));
        assert_eq!(r.clusters[0].cnn_mse_last, Some(0.05));
        assert_eq!(r.clusters[1].cnn_mse_last, None);
        assert_eq!(r.sum_cnn_epochs_ran, 3);
        assert!((r.frac_lasso_converged - 0.5).abs() < 1e-12);
        assert!((r.mean_lasso_r2 - 0.4).abs() < 1e-12);
        assert!((r.min_lasso_r2 - 0.3).abs() < 1e-12);
        assert!((r.max_lasso_r2 - 0.5).abs() < 1e-12);
    }

    #[test]
    fn rejects_wrong_or_missing_format_header() {
        assert!(parse_gene_training_log_lines(&[]).unwrap().is_none());
        let lines = to_lines("format\tsomething_else\tv1\ngene\tKit\n");
        assert!(parse_gene_training_log_lines(&lines).unwrap().is_none());
    }

    #[test]
    fn requires_gene_and_summary_rows() {
        let no_summary = to_lines("format\tspacetravlr_training_log\tv1\ngene\tKit\n");
        assert!(parse_gene_training_log_lines(&no_summary)
            .unwrap()
            .is_none());

        let no_gene = to_lines(
            "format\tspacetravlr_training_log\tv1\n# summary: x\ncluster_id\tn_cells\n0\t1\t1\t0.5\t0.1\t1\ttrue\t0\tNA\tNA\n",
        );
        assert!(parse_gene_training_log_lines(&no_gene).unwrap().is_none());
    }

    #[test]
    fn missing_log_file_and_dir_are_not_errors() {
        let dir = std::env::temp_dir().join("spacetravlr_definitely_missing_log_dir");
        assert!(scan_gene_training_logs(&dir).unwrap().is_empty());
        let file = dir.join("Kit.log");
        assert!(parse_gene_training_log(&file).unwrap().is_none());
    }
}