multiscreen_rs/runtime.rs
1use crate::error::{Error, Result};
2use burn::{
3 backend::Autodiff,
4 tensor::backend::{Backend, BackendTypes},
5};
6use std::env;
7
8// ---------------------------------------------------------------------------
9// Inference-only backend (no autodiff)
10// ---------------------------------------------------------------------------
11
12/// Inference-only device type — same physical device, but without autodiff wrapper.
13pub type InferenceDevice = <DefaultBackend as BackendTypes>::Device;
14
15// ---------------------------------------------------------------------------
16// Backend selection (feature-gated)
17// ---------------------------------------------------------------------------
18
19/// Default CPU backend (Flex) — used when the `cuda` feature is not enabled.
20#[cfg(not(feature = "cuda"))]
21pub type DefaultBackend = burn::backend::Flex;
22
23/// Default CUDA backend — used when the `cuda` feature is enabled.
24#[cfg(feature = "cuda")]
25pub type DefaultBackend = burn::backend::Cuda;
26
27/// Default training backend. Burn's `Autodiff` decorator adds backpropagation.
28pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;
29
30/// Device type for the active backend.
31pub type Device = <DefaultAutodiffBackend as BackendTypes>::Device;
32
33/// Chooses the default device based on the active backend.
34///
35/// With `cuda` feature: returns the default CUDA device (GPU 0).
36/// Without `cuda` feature: returns the default Flex (CPU) device.
37///
38/// Also reads `MULTISCREEN_DEVICE` env var for manual override.
39pub fn default_device() -> Result<Device> {
40 let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
41 match requested.trim().to_ascii_lowercase().as_str() {
42 "" | "auto" => Ok(Device::default()),
43 "cpu" | "flex" => {
44 #[cfg(feature = "cuda")]
45 {
46 Err(Error::Config(
47 "MULTISCREEN_DEVICE=cpu is not supported when compiled with the cuda feature. \
48 Recompile without --features cuda for CPU training."
49 .to_string(),
50 ))
51 }
52 #[cfg(not(feature = "cuda"))]
53 {
54 Ok(Device::default())
55 }
56 }
57 other => Err(Error::Config(format!(
58 "unsupported MULTISCREEN_DEVICE={other:?}; use auto"
59 ))),
60 }
61}
62
63/// Human-readable label for the active backend device.
64pub fn device_label(device: &Device) -> String {
65 <DefaultAutodiffBackend as Backend>::name(device)
66}
67
68// ---------------------------------------------------------------------------
// CUDA-specific types (only compiled when the `cuda` feature is enabled)
70// ---------------------------------------------------------------------------
71
#[cfg(feature = "cuda")]
/// Training backend for GPUs: Burn's CUDA backend wrapped in `Autodiff`.
pub type CudaAutodiffBackend = burn::backend::Autodiff<burn::backend::Cuda>;

#[cfg(feature = "cuda")]
/// Device type used by [`CudaAutodiffBackend`].
pub type CudaDevice = <CudaAutodiffBackend as burn::tensor::backend::BackendTypes>::Device;

#[cfg(feature = "cuda")]
/// Shorthand for a `crate::model::MultiscreenModel` on [`CudaAutodiffBackend`].
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;