multiscreen_rs/
runtime.rs1use crate::error::{Error, Result};
2use burn::{
3 backend::{Autodiff, Flex},
4 tensor::backend::{Backend, BackendTypes},
5};
6use std::env;
7
8pub type DefaultBackend = Flex;
10
11pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;
13
14pub type Device = <DefaultAutodiffBackend as BackendTypes>::Device;
16
17pub fn default_device() -> Result<Device> {
21 let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
22 match requested.trim().to_ascii_lowercase().as_str() {
23 "" | "auto" | "cpu" | "flex" => Ok(Device::default()),
24 "cuda" | "gpu" => Err(Error::Config(
25 "Burn CUDA is backend-generic; instantiate MultiscreenModel with a CUDA Burn backend \
26 instead of using the default Flex device"
27 .to_string(),
28 )),
29 other => Err(Error::Config(format!(
30 "unsupported MULTISCREEN_DEVICE={other:?}; use auto, cpu, or flex"
31 ))),
32 }
33}
34
35pub fn device_label(device: &Device) -> String {
37 <DefaultAutodiffBackend as Backend>::name(device)
38}
39
/// Burn's CUDA backend, re-exported only when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub use burn::backend::Cuda;

/// [`Cuda`] wrapped in Burn's [`Autodiff`] decorator so gradients can be tracked.
#[cfg(feature = "cuda")]
pub type CudaAutodiffBackend = Autodiff<Cuda>;

/// Device handle type associated with [`CudaAutodiffBackend`].
#[cfg(feature = "cuda")]
pub type CudaDevice = <CudaAutodiffBackend as BackendTypes>::Device;

/// `MultiscreenModel` instantiated on [`CudaAutodiffBackend`].
#[cfg(feature = "cuda")]
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;
58
/// Returns the default [`CudaDevice`] (`CudaDevice::default()`).
///
/// # Errors
///
/// This function currently always returns `Ok`.
#[cfg(feature = "cuda")]
pub fn cuda_device() -> Result<CudaDevice> {
    let device = CudaDevice::default();
    Ok(device)
}