use burn::backend::ndarray::NdArrayDevice;
use burn::backend::wgpu::WgpuDevice;
use burn::backend::{NdArray, Wgpu};
use burn_autodiff::Autodiff;
use spacetravlr::config::{
CnnConfig, CnnTrainingMode, HybridCnnGatingConfig, ModelExportConfig, SpaceshipConfig,
};
use spacetravlr::spatial_estimator::SpatialCellularProgramsEstimator;
use spacetravlr::training_hud::TrainingHud;
use std::sync::OnceLock;
/// Compute backend selected for training, together with the device to run on.
#[derive(Clone, Debug)]
pub(crate) enum ComputeChoice {
    /// Train on burn's `Wgpu` backend using this device.
    Wgpu(WgpuDevice),
    /// Train on burn's CPU `NdArray` backend using this device.
    NdArray(NdArrayDevice),
}
impl ComputeChoice {
    /// Short human-readable name of the selected backend.
    pub(crate) fn label(&self) -> &'static str {
        match *self {
            Self::NdArray(..) => "CPU (NdArray)",
            Self::Wgpu(..) => "WebGPU",
        }
    }
}
/// Returns `true` when the environment variable `name` holds a truthy value.
///
/// Accepted spellings are `1`, `true`, `yes` and `on`. Matching is ASCII
/// case-insensitive and ignores surrounding whitespace, so values such as
/// `"1\n"` written by shell scripts are honoured. Unset variables, non-UTF-8
/// values and any other text yield `false`.
fn env_truthy(name: &str) -> bool {
    std::env::var(name)
        .map(|raw| is_truthy_flag(&raw))
        .unwrap_or(false)
}

/// Interprets a raw flag string using the rules documented on [`env_truthy`].
fn is_truthy_flag(raw: &str) -> bool {
    let value = raw.trim();
    ["1", "true", "yes", "on"]
        .iter()
        .any(|accepted| value.eq_ignore_ascii_case(accepted))
}
/// Chooses the compute backend used for training.
///
/// The CPU (`NdArray`) backend is returned when either `SPACETRAVLR_FORCE_CPU` or
/// `SPACETRAVLR_DISABLE_WGPU` is truthy according to `env_truthy`, or when the cached
/// WebGPU adapter probe found no adapter. Otherwise the default WebGPU device is used.
/// When CPU is forced, the adapter probe is never run.
pub(crate) fn select_compute_backend() -> ComputeChoice {
    let cpu_forced = ["SPACETRAVLR_FORCE_CPU", "SPACETRAVLR_DISABLE_WGPU"]
        .iter()
        .any(|var| env_truthy(var));
    // `&&` short-circuits, so the probe only runs when CPU is not forced.
    if !cpu_forced && wgpu_adapter_probe_cached().is_some() {
        ComputeChoice::Wgpu(WgpuDevice::default())
    } else {
        ComputeChoice::NdArray(NdArrayDevice::Cpu)
    }
}
/// Memoised result of `preferred_wgpu_adapter_info`; `None` means no adapter was found.
static WGPU_ADAPTER_PROBE: OnceLock<Option<wgpu::AdapterInfo>> = OnceLock::new();
/// Returns the WebGPU adapter probe result.
///
/// The probe runs lazily on the first call, and later calls reuse the stored value.
fn wgpu_adapter_probe_cached() -> &'static Option<wgpu::AdapterInfo> {
    OnceLock::get_or_init(&WGPU_ADAPTER_PROBE, preferred_wgpu_adapter_info)
}
/// Requests a high-performance WebGPU adapter and returns its `AdapterInfo`.
///
/// The request uses a default `wgpu::Instance`, does not force the fallback adapter,
/// and requires no surface compatibility. It blocks the calling thread until the
/// request resolves (via `pollster::block_on`). Returns `None` when wgpu reports
/// no suitable adapter.
fn preferred_wgpu_adapter_info() -> Option<wgpu::AdapterInfo> {
    let instance = wgpu::Instance::default();
    let options = wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        force_fallback_adapter: false,
        compatible_surface: None,
    };
    pollster::block_on(instance.request_adapter(&options)).map(|adapter| adapter.get_info())
}
/// Builds a one-line description of the hardware behind `choice`.
///
/// - WebGPU: adapter name, device type and backend from the cached adapter probe,
///   or `"adapter details unavailable"` if the probe found no adapter.
/// - CPU: operating system, architecture and available parallelism. The thread
///   count falls back to 1 when it cannot be queried.
pub(crate) fn compute_hardware_details(choice: &ComputeChoice) -> String {
    match choice {
        ComputeChoice::Wgpu(_) => wgpu_adapter_probe_cached().as_ref().map_or_else(
            || "adapter details unavailable".to_string(),
            |info| {
                format!(
                    "{} ({:?}, {} backend)",
                    info.name, info.device_type, info.backend
                )
            },
        ),
        ComputeChoice::NdArray(_) => {
            let threads = std::thread::available_parallelism().map_or(1, |n| n.get());
            format!(
                "{} {} CPU ({} threads)",
                std::env::consts::OS,
                std::env::consts::ARCH,
                threads
            )
        }
    }
}
/// Arguments for [`fit_all_genes_dispatch`], bundled so they can be passed by
/// reference whichever compute backend is selected.
///
/// `dispatch_fit_all_genes!` forwards each field, in declaration order, as a
/// positional argument to `SpatialCellularProgramsEstimator::fit_all_genes`.
/// When adding or removing parameters, keep this field order in sync with that
/// macro and with `fit_all_genes`'s signature. At dispatch time, owned fields are
/// cloned or lent via `as_deref`, so the struct itself is not consumed.
pub(crate) struct FitAllGenesParams<'a> {
/// Input dataset path, read through the `anndata_hdf5::H5` backend
/// (presumably an `.h5ad` file; TODO confirm).
pub path: &'a str,
/// Optional subset of rows to fit on. Presumably these are indices into the
/// observation axis; confirm against `fit_all_genes`. The `Arc` makes the
/// per-dispatch clone cheap.
pub obs_row_subset: Option<std::sync::Arc<[usize]>>,
pub radius: f64,
pub spatial_dim: usize,
pub contact_distance: f64,
pub tf_ligand_cutoff: f64,
pub max_ligands: Option<usize>,
pub use_tf_modulators: bool,
pub use_lr_modulators: bool,
pub use_tfl_modulators: bool,
pub layer: &'a str,
pub cluster_annot: &'a str,
pub cnn: &'a CnnConfig,
pub epochs: usize,
pub learning_rate: f64,
pub score_threshold: f64,
pub l1_reg: f64,
pub group_reg: f64,
pub n_iter: usize,
pub tol: f64,
pub cnn_training_mode: CnnTrainingMode,
pub hybrid_pass2_full_cnn: bool,
pub hybrid_gating: &'a HybridCnnGatingConfig,
pub min_mean_lasso_r2_for_cnn: f64,
/// Optional gene list to restrict fitting to; cloned for each dispatch.
pub gene_filter: Option<Vec<String>>,
pub max_genes: Option<usize>,
pub n_parallel: usize,
pub output_dir: &'a str,
pub model_export: &'a ModelExportConfig,
/// Optional `TrainingHud` handle, cloned into the estimator call.
pub hud: Option<TrainingHud>,
/// Optional directory path, passed to `fit_all_genes` as `Option<&str>`.
pub network_data_dir: Option<String>,
/// Optional path (presumably a Feather file of TF priors; confirm), passed as `Option<&str>`.
pub tf_priors_feather: Option<String>,
pub write_minimal_repro_h5ad: bool,
pub spaceship_config: &'a SpaceshipConfig,
/// Path of the config file this run was built from, if any; cloned for each dispatch.
pub config_source_path: Option<std::path::PathBuf>,
pub join_training: bool,
}
/// Expands to a call of
/// `SpatialCellularProgramsEstimator::<Autodiff<$backend>, anndata_hdf5::H5>::fit_all_genes`
/// that passes every field of a [`FitAllGenesParams`] as a positional argument,
/// followed by `$device`.
///
/// - `$backend`: a burn backend type, which the macro wraps in `Autodiff<..>`.
/// - `$p`: the params value (a `&FitAllGenesParams` at the only call site).
///   NOTE(review): `$p` is substituted once per field, so it is evaluated many
///   times. Pass a plain variable, never an expression with side effects.
/// - `$device`: the device handle, passed as the last argument.
///
/// The argument order below must match `fit_all_genes`'s parameter order exactly.
/// Several neighbouring parameters share a type (`f64`, `bool`, `&str`), so a
/// misordering would likely still compile. A macro rather than a generic fn is
/// presumably used to avoid restating the backend trait bounds; confirm.
macro_rules! dispatch_fit_all_genes {
($backend:ty, $p:expr, $device:expr) => {
SpatialCellularProgramsEstimator::<Autodiff<$backend>, anndata_hdf5::H5>::fit_all_genes(
$p.path,
// `$p` is only borrowed, so owned fields are cloned (an `Arc` bump here).
$p.obs_row_subset.clone(),
$p.radius,
$p.spatial_dim,
$p.contact_distance,
$p.tf_ligand_cutoff,
$p.max_ligands,
$p.use_tf_modulators,
$p.use_lr_modulators,
$p.use_tfl_modulators,
$p.layer,
$p.cluster_annot,
$p.cnn,
$p.epochs,
$p.learning_rate,
$p.score_threshold,
$p.l1_reg,
$p.group_reg,
$p.n_iter,
$p.tol,
$p.cnn_training_mode,
$p.hybrid_pass2_full_cnn,
$p.hybrid_gating,
$p.min_mean_lasso_r2_for_cnn,
$p.gene_filter.clone(),
$p.max_genes,
$p.n_parallel,
$p.output_dir,
$p.model_export,
$p.hud.clone(),
// Optional owned strings are lent as `Option<&str>` instead of being cloned.
$p.network_data_dir.as_deref(),
$p.tf_priors_feather.as_deref(),
$p.write_minimal_repro_h5ad,
$p.spaceship_config,
$p.config_source_path.clone(),
$p.join_training,
$device,
)
};
}
/// Runs `SpatialCellularProgramsEstimator::fit_all_genes` with the parameters in `p`
/// on the backend selected by `choice`.
///
/// A WebGPU choice trains on `Autodiff<Wgpu>`. A CPU choice trains on
/// `Autodiff<NdArray<f32, i32>>`. In both cases the device stored in `choice`
/// is passed through unchanged.
///
/// # Errors
///
/// Returns whatever error `fit_all_genes` reports.
pub(crate) fn fit_all_genes_dispatch(
    p: &FitAllGenesParams<'_>,
    choice: &ComputeChoice,
) -> anyhow::Result<()> {
    match *choice {
        ComputeChoice::NdArray(ref cpu_device) => {
            dispatch_fit_all_genes!(NdArray<f32, i32>, p, cpu_device)
        }
        ComputeChoice::Wgpu(ref gpu_device) => dispatch_fit_all_genes!(Wgpu, p, gpu_device),
    }
}