space_trav_lr_rust 1.3.0

Spatial gene regulatory network inference and in-silico perturbation (Rust port of SpaceTravLR)
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::wgpu::WgpuDevice;
use burn::backend::{NdArray, Wgpu};
use burn_autodiff::Autodiff;
use space_trav_lr_rust::config::{
    CnnConfig, CnnTrainingMode, HybridCnnGatingConfig, ModelExportConfig, SpaceshipConfig,
};
use space_trav_lr_rust::spatial_estimator::SpatialCellularProgramsEstimator;
use space_trav_lr_rust::training_hud::TrainingHud;
use std::sync::OnceLock;

/// The Burn compute backend chosen for training, together with the device to run it on.
#[derive(Clone, Debug)]
pub(crate) enum ComputeChoice {
    /// Burn WebGPU backend on the contained device.
    Wgpu(WgpuDevice),
    /// Burn NdArray backend on the contained (CPU) device.
    NdArray(NdArrayDevice),
}

impl ComputeChoice {
    /// Short human-readable name of the selected backend.
    pub(crate) fn label(&self) -> &'static str {
        match *self {
            Self::Wgpu(..) => "WebGPU",
            Self::NdArray(..) => "CPU (NdArray)",
        }
    }
}

/// Returns `true` when the environment variable `name` holds a truthy value.
///
/// Accepted values (ASCII case-insensitive, surrounding whitespace ignored): `1`, `true`,
/// `yes`, `on`. An unset variable, a non-UTF-8 value, or anything else (including `0`,
/// `false` and the empty string) yields `false`.
fn env_truthy(name: &str) -> bool {
    std::env::var(name)
        .map(|v| is_truthy_flag(&v))
        .unwrap_or(false)
}

/// Interprets a flag string the same way `env_truthy` interprets an environment value.
fn is_truthy_flag(value: &str) -> bool {
    // Trim first: values copied from shells / .env files often carry stray whitespace.
    let value = value.trim();
    ["1", "true", "yes", "on"]
        .iter()
        .any(|accepted| value.eq_ignore_ascii_case(accepted))
}

/// Burn **WebGPU** backend when the `wgpu` crate can request an adapter, otherwise Burn **NdArray** on CPU.
/// `SPACETRAVLR_FORCE_CPU` / `SPACETRAVLR_DISABLE_WGPU` force CPU (no adapter probe).
pub(crate) fn select_compute_backend() -> ComputeChoice {
    if env_truthy("SPACETRAVLR_FORCE_CPU") || env_truthy("SPACETRAVLR_DISABLE_WGPU") {
        return ComputeChoice::NdArray(NdArrayDevice::Cpu);
    }

    match wgpu_adapter_probe_cached().as_ref() {
        Some(_) => ComputeChoice::Wgpu(WgpuDevice::default()),
        None => ComputeChoice::NdArray(NdArrayDevice::Cpu),
    }
}

/// Process-wide memo of the wgpu adapter probe. It is filled at most once, on first use, by
/// `preferred_wgpu_adapter_info`; a stored `None` records that no adapter could be obtained.
static WGPU_ADAPTER_PROBE: OnceLock<Option<wgpu::AdapterInfo>> = OnceLock::new();

/// Returns the memoized adapter probe result.
///
/// Only the first call runs the (blocking) probe; later calls return the stored result.
fn wgpu_adapter_probe_cached() -> &'static Option<wgpu::AdapterInfo> {
    let probe = &WGPU_ADAPTER_PROBE;
    probe.get_or_init(preferred_wgpu_adapter_info)
}

/// Asks wgpu for a high-performance adapter and returns its `AdapterInfo`.
///
/// Returns `None` when no adapter could be obtained. Blocks the calling thread on the async
/// adapter request via `pollster`.
///
/// NOTE(review): `force_fallback_adapter: false` does not exclude software adapters, so the
/// result may still describe a CPU rasterizer. Inspect `device_type` if that matters.
/// NOTE(review): Burn's `WgpuDevice::default()` may pick its adapter independently of this
/// probe. Confirm the reported info matches the device actually used.
fn preferred_wgpu_adapter_info() -> Option<wgpu::AdapterInfo> {
    let instance = wgpu::Instance::default();
    let options = wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        force_fallback_adapter: false,
        compatible_surface: None,
    };

    let adapter = pollster::block_on(instance.request_adapter(&options))?;
    Some(adapter.get_info())
}

/// Builds a one-line, human-readable description of the hardware behind `choice`.
///
/// * `Wgpu`: the probed adapter's name, device type and backend. If the probe found no
///   adapter, returns `"adapter details unavailable"`.
/// * `NdArray`: the host OS, CPU architecture and `available_parallelism` thread count.
///   The count falls back to 1 when it cannot be queried.
pub(crate) fn compute_hardware_details(choice: &ComputeChoice) -> String {
    match choice {
        ComputeChoice::Wgpu(_) => wgpu_adapter_probe_cached()
            .as_ref()
            .map(|adapter| {
                format!(
                    "{} ({:?}, {} backend)",
                    adapter.name, adapter.device_type, adapter.backend
                )
            })
            .unwrap_or_else(|| "adapter details unavailable".to_string()),
        ComputeChoice::NdArray(_) => {
            let thread_count = std::thread::available_parallelism().map_or(1, |n| n.get());
            format!(
                "{} {} CPU ({} threads)",
                std::env::consts::OS,
                std::env::consts::ARCH,
                thread_count
            )
        }
    }
}

/// Argument bundle for `fit_all_genes_dispatch`, which forwards it to
/// `SpatialCellularProgramsEstimator::fit_all_genes` for whichever Burn backend a
/// `ComputeChoice` names.
///
/// `dispatch_fit_all_genes!` passes these fields *positionally*, in the same order they are
/// declared here. When adding or reordering fields, keep this struct, that macro and
/// `fit_all_genes`'s parameter list in sync.
///
/// The dispatcher only borrows the params (`&FitAllGenesParams`). For that reason the
/// non-`Copy` owned fields (`obs_row_subset`, `gene_filter`, `hud`, `config_source_path`) are
/// cloned per call, and the `Option<String>` paths are lent out as `Option<&str>`.
pub(crate) struct FitAllGenesParams<'a> {
    /// Dataset path, opened with the `anndata_hdf5::H5` storage type
    /// (presumably an `.h5ad` file — confirm against the estimator).
    pub path: &'a str,
    /// Optional subset of row indices (name suggests AnnData `obs` rows — TODO confirm).
    /// Stored as an `Arc`, so the per-dispatch clone is only a refcount bump.
    pub obs_row_subset: Option<std::sync::Arc<[usize]>>,
    pub radius: f64,
    pub spatial_dim: usize,
    pub contact_distance: f64,
    pub tf_ligand_cutoff: f64,
    pub max_ligands: Option<usize>,
    pub use_tf_modulators: bool,
    pub use_lr_modulators: bool,
    pub use_tfl_modulators: bool,
    pub layer: &'a str,
    pub cluster_annot: &'a str,
    pub cnn: &'a CnnConfig,
    pub epochs: usize,
    pub learning_rate: f64,
    pub score_threshold: f64,
    pub l1_reg: f64,
    pub group_reg: f64,
    pub n_iter: usize,
    pub tol: f64,
    pub cnn_training_mode: CnnTrainingMode,
    pub hybrid_pass2_full_cnn: bool,
    pub hybrid_gating: &'a HybridCnnGatingConfig,
    pub min_mean_lasso_r2_for_cnn: f64,
    /// Optional gene filter. Cloned on each dispatch.
    pub gene_filter: Option<Vec<String>>,
    pub max_genes: Option<usize>,
    pub n_parallel: usize,
    pub output_dir: &'a str,
    pub model_export: &'a ModelExportConfig,
    /// Optional progress HUD. Cloned on each dispatch.
    pub hud: Option<TrainingHud>,
    /// Passed to the estimator as `Option<&str>` via `as_deref()`.
    pub network_data_dir: Option<String>,
    /// Passed to the estimator as `Option<&str>` via `as_deref()`.
    pub tf_priors_feather: Option<String>,
    pub write_minimal_repro_h5ad: bool,
    pub spaceship_config: &'a SpaceshipConfig,
    pub config_source_path: Option<std::path::PathBuf>,
    /// Loaded from shared `spacetravlr_run_repro.toml` (`--join-output-dir`); skips overwriting that file at end.
    pub join_training: bool,
}

/// Expands to a call of `SpatialCellularProgramsEstimator::fit_all_genes`. The backend is
/// `Autodiff<$backend>` and the storage type parameter is `anndata_hdf5::H5`.
///
/// * `$backend` — a concrete Burn backend type (e.g. `Wgpu`, `NdArray<f32, i32>`).
/// * `$p` — a `FitAllGenesParams` or a reference to one. It is expanded once per field, so
///   pass a plain binding, not an expression with side effects or non-trivial cost.
/// * `$device` — the backend's device, forwarded as the final argument.
///
/// NOTE(review): a macro is presumably used instead of a generic fn so each caller can name a
/// concrete backend without restating `fit_all_genes`'s trait bounds — confirm.
macro_rules! dispatch_fit_all_genes {
    ($backend:ty, $p:expr, $device:expr) => {
        // Arguments are positional and must mirror `fit_all_genes`'s parameter order exactly.
        // Several neighbours share a type (`f64`, `bool`), so a swapped pair would still compile.
        SpatialCellularProgramsEstimator::<Autodiff<$backend>, anndata_hdf5::H5>::fit_all_genes(
            $p.path,
            // Owned (non-`Copy`) fields are cloned because `$p` is only borrowed by callers.
            $p.obs_row_subset.clone(),
            $p.radius,
            $p.spatial_dim,
            $p.contact_distance,
            $p.tf_ligand_cutoff,
            $p.max_ligands,
            $p.use_tf_modulators,
            $p.use_lr_modulators,
            $p.use_tfl_modulators,
            $p.layer,
            $p.cluster_annot,
            $p.cnn,
            $p.epochs,
            $p.learning_rate,
            $p.score_threshold,
            $p.l1_reg,
            $p.group_reg,
            $p.n_iter,
            $p.tol,
            $p.cnn_training_mode,
            $p.hybrid_pass2_full_cnn,
            $p.hybrid_gating,
            $p.min_mean_lasso_r2_for_cnn,
            $p.gene_filter.clone(),
            $p.max_genes,
            $p.n_parallel,
            $p.output_dir,
            $p.model_export,
            $p.hud.clone(),
            // `Option<String>` -> `Option<&str>` without allocating.
            $p.network_data_dir.as_deref(),
            $p.tf_priors_feather.as_deref(),
            $p.write_minimal_repro_h5ad,
            $p.spaceship_config,
            $p.config_source_path.clone(),
            $p.join_training,
            $device,
        )
    };
}

pub(crate) fn fit_all_genes_dispatch(
    p: &FitAllGenesParams<'_>,
    choice: &ComputeChoice,
) -> anyhow::Result<()> {
    match choice {
        ComputeChoice::Wgpu(device) => dispatch_fit_all_genes!(Wgpu, p, device),
        ComputeChoice::NdArray(device) => {
            dispatch_fit_all_genes!(NdArray<f32, i32>, p, device)
        }
    }
}