use crate::error::{Error, Result};
use burn::{
backend::Autodiff,
tensor::backend::{Backend, BackendTypes},
};
use std::env;
/// Tensor backend used throughout the crate.
///
/// Builds with the `cuda` feature run on [`burn::backend::Cuda`].
#[cfg(feature = "cuda")]
pub type DefaultBackend = burn::backend::Cuda;
/// Tensor backend used throughout the crate.
///
/// Builds without the `cuda` feature run on the CPU backend [`burn::backend::Flex`].
#[cfg(not(feature = "cuda"))]
pub type DefaultBackend = burn::backend::Flex;

/// [`DefaultBackend`] wrapped in burn's [`Autodiff`] decorator so gradients can be tracked.
pub type DefaultAutodiffBackend = Autodiff<DefaultBackend>;

/// Device handle type of [`DefaultAutodiffBackend`].
pub type Device = <Autodiff<DefaultBackend> as BackendTypes>::Device;
pub fn default_device() -> Result<Device> {
let requested = env::var("MULTISCREEN_DEVICE").unwrap_or_else(|_| "auto".to_string());
match requested.trim().to_ascii_lowercase().as_str() {
"" | "auto" => Ok(Device::default()),
"cpu" | "flex" => {
#[cfg(feature = "cuda")]
{
Err(Error::Config(
"MULTISCREEN_DEVICE=cpu is not supported when compiled with the cuda feature. \
Recompile without --features cuda for CPU training."
.to_string(),
))
}
#[cfg(not(feature = "cuda"))]
{
Ok(Device::default())
}
}
other => Err(Error::Config(format!(
"unsupported MULTISCREEN_DEVICE={other:?}; use auto"
))),
}
}
pub fn device_label(device: &Device) -> String {
<DefaultAutodiffBackend as Backend>::name(device)
}
/// CUDA backend wrapped in burn's [`Autodiff`] decorator (requires the `cuda` feature).
#[cfg(feature = "cuda")]
pub type CudaAutodiffBackend = Autodiff<burn::backend::Cuda>;

/// Device handle type of [`CudaAutodiffBackend`] (requires the `cuda` feature).
#[cfg(feature = "cuda")]
pub type CudaDevice = <Autodiff<burn::backend::Cuda> as BackendTypes>::Device;

/// `crate::model::MultiscreenModel` instantiated on [`CudaAutodiffBackend`]
/// (requires the `cuda` feature).
#[cfg(feature = "cuda")]
pub type CudaMultiscreenModel = crate::model::MultiscreenModel<CudaAutodiffBackend>;