use crate::backend::{Backend, DeviceCapabilities};
/// Compute device targeted by the Burn backend.
///
/// Accelerator variants carry a numeric device id (presumably the device
/// ordinal on that platform — confirm against the code that constructs these).
///
/// Derives `PartialEq`/`Eq`/`Hash` so devices can be compared and used as
/// keys in `HashMap`/`HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BurnDevice {
    /// Host CPU; has no id.
    Cpu,
    /// CUDA device with the given id.
    Cuda(u32),
    /// Metal device with the given id.
    Metal(u32),
    /// Vulkan device with the given id.
    Vulkan(u32),
}

impl BurnDevice {
    /// Returns a textual label for this device.
    ///
    /// The label is `"cpu"` for [`BurnDevice::Cpu`] and `"<family>:<id>"` for
    /// accelerators, e.g. `"cuda:0"`, `"metal:1"`, `"vulkan:2"`.
    ///
    /// Note: despite the `as_` prefix, this allocates a fresh `String` on
    /// every call.
    pub fn as_label(&self) -> String {
        match self {
            BurnDevice::Cpu => "cpu".to_string(),
            BurnDevice::Cuda(id) => format!("cuda:{}", id),
            BurnDevice::Metal(id) => format!("metal:{}", id),
            BurnDevice::Vulkan(id) => format!("vulkan:{}", id),
        }
    }
}
/// Options used to construct a `BurnBackend`.
#[derive(Debug, Clone)]
pub struct BurnBackendConfig {
    /// Device the backend is bound to.
    pub device: BurnDevice,
    /// Whether TF32 math is permitted. The backend's capability report only
    /// consults this flag for CUDA devices.
    pub allow_tf32: bool,
}

impl Default for BurnBackendConfig {
    /// CPU device with TF32 disallowed.
    fn default() -> Self {
        let device = BurnDevice::Cpu;
        let allow_tf32 = false;
        BurnBackendConfig { device, allow_tf32 }
    }
}
/// Burn-based implementation of the crate's `Backend` trait.
///
/// All behavior is driven by the stored [`BurnBackendConfig`].
#[derive(Debug, Clone)]
pub struct BurnBackend {
    config: BurnBackendConfig,
}

impl BurnBackend {
    /// Builds a backend that takes ownership of `config` and stores it unchanged.
    pub fn new(config: BurnBackendConfig) -> Self {
        BurnBackend { config }
    }

    /// Returns a shared reference to the configuration supplied to [`BurnBackend::new`].
    pub fn config(&self) -> &BurnBackendConfig {
        &self.config
    }
}
impl Backend for BurnBackend {
    /// Backend identifier; always `"burn"`.
    fn name(&self) -> &str {
        "burn"
    }

    /// Device family of the configured device, without its numeric id
    /// (`"cpu"`, `"cuda"`, `"metal"` or `"vulkan"`).
    fn device(&self) -> &str {
        let family = match self.config.device {
            BurnDevice::Cpu => "cpu",
            BurnDevice::Cuda(_) => "cuda",
            BurnDevice::Metal(_) => "metal",
            BurnDevice::Vulkan(_) => "vulkan",
        };
        family
    }

    /// Static capability table for the configured device family.
    ///
    /// Per family, the reported values are:
    /// - CPU: no reduced-precision formats are reported.
    /// - CUDA: f16 and bf16 are reported; TF32 mirrors `config.allow_tf32`.
    /// - Metal and Vulkan: only f16 is reported.
    ///
    /// Memory size and compute-unit count are never queried and are always `None`.
    fn capabilities(&self) -> DeviceCapabilities {
        // (f16, bf16, tf32)
        let (supports_f16, supports_bf16, supports_tf32) = match self.config.device {
            BurnDevice::Cpu => (false, false, false),
            BurnDevice::Cuda(_) => (true, true, self.config.allow_tf32),
            BurnDevice::Metal(_) | BurnDevice::Vulkan(_) => (true, false, false),
        };
        DeviceCapabilities {
            supports_f16,
            supports_bf16,
            supports_tf32,
            max_memory_bytes: None,
            compute_units: None,
        }
    }

    /// Always reports `true`.
    ///
    /// NOTE(review): no runtime probe is performed, so this returns `true`
    /// even for an accelerator that is not actually present.
    fn is_available(&self) -> bool {
        true
    }
}