// multiscreen_rs/runtime.rs

use std::env;

use burn::{
    backend::Autodiff,
    tensor::backend::{Backend, BackendTypes},
};

use crate::error::{Error, Result};

8// ---------------------------------------------------------------------------
9// Inference-only backend (no autodiff)
10// ---------------------------------------------------------------------------
11
12/// Inference-only device type — same physical device, but without autodiff wrapper.
13pub type InferenceDevice = <DefaultBackend as BackendTypes>::Device;
14
15// ---------------------------------------------------------------------------
16// Backend selection (feature-gated)
17// ---------------------------------------------------------------------------
18
19/// Default CPU backend (Flex) — used when the `cuda` feature is not enabled.
20#[cfg(not(feature = "cuda"))]
21pub type DefaultBackend = burn::backend::Flex;
22
23/// Default CUDA backend — used when the `cuda` feature is enabled.
24#[cfg(feature = "cuda")]
25pub type DefaultBackend = burn::backend::Cuda;
26
27/// Default training backend. Burn's `Autodiff` decorator adds backpropagation.
28pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;
29
30/// Device type for the active backend.
31pub type Device = <DefaultAutodiffBackend as BackendTypes>::Device;
32
33/// Chooses the default device based on the active backend.
34///
35/// With `cuda` feature: returns the default CUDA device (GPU 0).
36/// Without `cuda` feature: returns the default Flex (CPU) device.
37///
38/// Also reads `MULTISCREEN_DEVICE` env var for manual override.
39pub fn default_device() -> Result<Device> {
40    let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
41    match requested.trim().to_ascii_lowercase().as_str() {
42        "" | "auto" => Ok(Device::default()),
43        "cpu" | "flex" => {
44            #[cfg(feature = "cuda")]
45            {
46                Err(Error::Config(
47                    "MULTISCREEN_DEVICE=cpu is not supported when compiled with the cuda feature. \
48                     Recompile without --features cuda for CPU training."
49                        .to_string(),
50                ))
51            }
52            #[cfg(not(feature = "cuda"))]
53            {
54                Ok(Device::default())
55            }
56        }
57        other => Err(Error::Config(format!(
58            "unsupported MULTISCREEN_DEVICE={other:?}; use auto"
59        ))),
60    }
61}
62
63/// Human-readable label for the active backend device.
64pub fn device_label(device: &Device) -> String {
65    <DefaultAutodiffBackend as Backend>::name(device)
66}
67
// ---------------------------------------------------------------------------
// CUDA-specific aliases (compiled only with the `cuda` feature)
// ---------------------------------------------------------------------------
//
// Every item below is gated on `feature = "cuda"`, so none of them exist (or
// appear in generated docs) when the crate is built without that feature.
// With the feature enabled, `CudaAutodiffBackend` is the same type as
// `DefaultAutodiffBackend`.

/// Autodiff-wrapped CUDA backend for training on GPU.
///
/// Only available when the crate is built with the `cuda` feature.
#[cfg(feature = "cuda")]
pub type CudaAutodiffBackend = Autodiff<burn::backend::Cuda>;

/// Device type associated with [`CudaAutodiffBackend`].
///
/// Only available when the crate is built with the `cuda` feature.
#[cfg(feature = "cuda")]
pub type CudaDevice = <CudaAutodiffBackend as BackendTypes>::Device;

/// Convenience alias for a Multiscreen model on the CUDA training backend
/// ([`CudaAutodiffBackend`]).
///
/// Only available when the crate is built with the `cuda` feature.
#[cfg(feature = "cuda")]
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;