multiscreen_rs/
runtime.rs1use crate::error::{Error, Result};
2use burn::{
3 backend::Autodiff,
4 tensor::backend::{Backend, BackendTypes},
5};
6use std::env;
7
8#[cfg(not(feature = "cuda"))]
14pub type DefaultBackend = burn::backend::Flex;
15
16#[cfg(feature = "cuda")]
18pub type DefaultBackend = burn::backend::Cuda;
19
20pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;
22
23pub type Device = <DefaultAutodiffBackend as BackendTypes>::Device;
25
26pub fn default_device() -> Result<Device> {
33 let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
34 match requested.trim().to_ascii_lowercase().as_str() {
35 "" | "auto" => Ok(Device::default()),
36 "cpu" | "flex" => {
37 #[cfg(feature = "cuda")]
38 {
39 Err(Error::Config(
40 "MULTISCREEN_DEVICE=cpu is not supported when compiled with the cuda feature. \
41 Recompile without --features cuda for CPU training."
42 .to_string(),
43 ))
44 }
45 #[cfg(not(feature = "cuda"))]
46 {
47 Ok(Device::default())
48 }
49 }
50 other => Err(Error::Config(format!(
51 "unsupported MULTISCREEN_DEVICE={other:?}; use auto"
52 ))),
53 }
54}
55
56pub fn device_label(device: &Device) -> String {
58 <DefaultAutodiffBackend as Backend>::name(device)
59}
60
/// `MultiscreenModel` instantiated on [`CudaAutodiffBackend`] (only in `cuda` builds).
#[cfg(feature = "cuda")]
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;

/// Device handle type associated with [`CudaAutodiffBackend`] (only in `cuda` builds).
#[cfg(feature = "cuda")]
pub type CudaDevice = <CudaAutodiffBackend as BackendTypes>::Device;

/// The CUDA backend wrapped in burn's `Autodiff` decorator (only in `cuda` builds).
#[cfg(feature = "cuda")]
pub type CudaAutodiffBackend = Autodiff<burn::backend::Cuda>;