Skip to main content

multiscreen_rs/
runtime.rs

1use crate::error::{Error, Result};
2use burn::{
3    backend::{Autodiff, Flex},
4    tensor::backend::{Backend, BackendTypes},
5};
6use std::env;
7
8/// Default Candle-free CPU backend used by this crate.
9pub type DefaultBackend = Flex;
10
11/// Default training backend. Burn's `Autodiff` decorator adds backpropagation.
12pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;
13
14/// Device type for the default Burn backend.
15pub type Device = <DefaultAutodiffBackend as BackendTypes>::Device;
16
17/// Chooses the default Burn device from `MULTISCREEN_DEVICE`.
18///
19/// Values: `auto`, `cpu`, `flex`.
20pub fn default_device() -> Result<Device> {
21    let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
22    match requested.trim().to_ascii_lowercase().as_str() {
23        "" | "auto" | "cpu" | "flex" => Ok(Device::default()),
24        "cuda" | "gpu" => Err(Error::Config(
25            "Burn CUDA is backend-generic; instantiate MultiscreenModel with a CUDA Burn backend \
26             instead of using the default Flex device"
27                .to_string(),
28        )),
29        other => Err(Error::Config(format!(
30            "unsupported MULTISCREEN_DEVICE={other:?}; use auto, cpu, or flex"
31        ))),
32    }
33}
34
35/// Human-readable label for the default Burn device.
36pub fn device_label(device: &Device) -> String {
37    <DefaultAutodiffBackend as Backend>::name(device)
38}
39
40// ---------------------------------------------------------------------------
41// CUDA support (enabled by the `cuda` feature flag)
42// ---------------------------------------------------------------------------
43
/// Burn's CUDA backend, re-exported when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub use burn::backend::Cuda;

/// GPU training backend: Burn's [`Autodiff`] decorator layered over [`Cuda`].
#[cfg(feature = "cuda")]
pub type CudaAutodiffBackend = Autodiff<Cuda>;

/// Device handle type associated with [`CudaAutodiffBackend`].
#[cfg(feature = "cuda")]
pub type CudaDevice = <CudaAutodiffBackend as BackendTypes>::Device;

/// Shorthand for a `MultiscreenModel` instantiated on [`CudaAutodiffBackend`].
#[cfg(feature = "cuda")]
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;
58
/// Returns the default CUDA device, i.e. `CudaDevice::default()` (GPU index 0).
///
/// This is only available with the `cuda` feature flag. The Burn CUDA runtime and a
/// compatible NVIDIA driver must also be present on the system:
///
/// ```toml
/// [dependencies]
/// multiscreen-rs = { version = "0.1", features = ["cuda"] }
/// ```
///
/// # Errors
///
/// The current implementation never returns an error.
#[cfg(feature = "cuda")]
pub fn cuda_device() -> Result<CudaDevice> {
    let device: CudaDevice = Default::default();
    Ok(device)
}