use crate::error::{Error, Result};
use burn::{
backend::{Autodiff, Flex},
tensor::backend::{Backend, BackendTypes},
};
use std::env;
/// Backend used when no other Burn backend is chosen at the type level (Burn's `Flex`).
pub type DefaultBackend = Flex;
/// [`DefaultBackend`] wrapped in Burn's [`Autodiff`] decorator.
pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;
/// Device type of [`DefaultAutodiffBackend`]; this is what [`default_device`] returns.
pub type Device = <DefaultAutodiffBackend as BackendTypes>::Device;
pub fn default_device() -> Result<Device> {
let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
match requested.trim().to_ascii_lowercase().as_str() {
"" | "auto" | "cpu" | "flex" => Ok(Device::default()),
"cuda" | "gpu" => Err(Error::Config(
"Burn CUDA is backend-generic; instantiate MultiscreenModel with a CUDA Burn backend \
instead of using the default Flex device"
.to_string(),
)),
other => Err(Error::Config(format!(
"unsupported MULTISCREEN_DEVICE={other:?}; use auto, cpu, or flex"
))),
}
}
/// Returns the name that [`DefaultAutodiffBackend`] reports for `device`.
///
/// This is a thin wrapper over Burn's `Backend::name` for the default autodiff backend.
pub fn device_label(device: &Device) -> String {
    type Ad = DefaultAutodiffBackend;
    <Ad as Backend>::name(device)
}
/// Re-export of Burn's CUDA backend; available only with the `cuda` feature.
#[cfg(feature = "cuda")]
pub use burn::backend::Cuda;
/// [`Cuda`] wrapped in Burn's [`Autodiff`] decorator (`cuda` feature only).
#[cfg(feature = "cuda")]
pub type CudaAutodiffBackend = Autodiff<Cuda>;
/// Device type of [`CudaAutodiffBackend`]; this is what [`cuda_device`] returns.
#[cfg(feature = "cuda")]
pub type CudaDevice = <CudaAutodiffBackend as BackendTypes>::Device;
/// `MultiscreenModel` instantiated on [`CudaAutodiffBackend`] (`cuda` feature only).
#[cfg(feature = "cuda")]
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;
/// Returns `CudaDevice::default()` for use with [`CudaAutodiffBackend`].
///
/// Compiled only with the `cuda` feature. Unlike [`default_device`], this function does
/// not consult `MULTISCREEN_DEVICE`.
///
/// # Errors
///
/// Never fails in practice: the body always returns `Ok`.
#[cfg(feature = "cuda")]
pub fn cuda_device() -> Result<CudaDevice> {
    let device: CudaDevice = Default::default();
    Ok(device)
}