Skip to main content

multiscreen_rs/
runtime.rs

1use crate::error::{Error, Result};
2use burn::{
3    backend::Autodiff,
4    tensor::backend::{Backend, BackendTypes},
5};
6use std::env;
7
8// ---------------------------------------------------------------------------
9// Backend selection (feature-gated)
10// ---------------------------------------------------------------------------
11
12/// Default CPU backend (Flex) — used when the `cuda` feature is not enabled.
13#[cfg(not(feature = "cuda"))]
14pub type DefaultBackend = burn::backend::Flex;
15
16/// Default CUDA backend — used when the `cuda` feature is enabled.
17#[cfg(feature = "cuda")]
18pub type DefaultBackend = burn::backend::Cuda;
19
20/// Default training backend. Burn's `Autodiff` decorator adds backpropagation.
21pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;
22
23/// Device type for the active backend.
24pub type Device = <DefaultAutodiffBackend as BackendTypes>::Device;
25
26/// Chooses the default device based on the active backend.
27///
28/// With `cuda` feature: returns the default CUDA device (GPU 0).
29/// Without `cuda` feature: returns the default Flex (CPU) device.
30///
31/// Also reads `MULTISCREEN_DEVICE` env var for manual override.
32pub fn default_device() -> Result<Device> {
33    let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
34    match requested.trim().to_ascii_lowercase().as_str() {
35        "" | "auto" => Ok(Device::default()),
36        "cpu" | "flex" => {
37            #[cfg(feature = "cuda")]
38            {
39                Err(Error::Config(
40                    "MULTISCREEN_DEVICE=cpu is not supported when compiled with the cuda feature. \
41                     Recompile without --features cuda for CPU training."
42                        .to_string(),
43                ))
44            }
45            #[cfg(not(feature = "cuda"))]
46            {
47                Ok(Device::default())
48            }
49        }
50        other => Err(Error::Config(format!(
51            "unsupported MULTISCREEN_DEVICE={other:?}; use auto"
52        ))),
53    }
54}
55
56/// Human-readable label for the active backend device.
57pub fn device_label(device: &Device) -> String {
58    <DefaultAutodiffBackend as Backend>::name(device)
59}
60
// ---------------------------------------------------------------------------
// CUDA-specific types (only compiled when the `cuda` feature is enabled)
// ---------------------------------------------------------------------------
64
/// GPU training backend: Burn's CUDA backend wrapped in the `Autodiff` decorator.
#[cfg(feature = "cuda")]
pub type CudaAutodiffBackend = Autodiff<burn::backend::Cuda>;

/// Device type associated with [`CudaAutodiffBackend`].
#[cfg(feature = "cuda")]
pub type CudaDevice = <CudaAutodiffBackend as BackendTypes>::Device;

/// Shorthand for [`crate::model::MultiscreenModel`] instantiated with [`CudaAutodiffBackend`].
#[cfg(feature = "cuda")]
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;