use crate::backend::{Backend, DeviceCapabilities};
use crate::error::AphelionResult;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Compute device targeted by the Burn backend.
///
/// GPU variants carry a device index (the `device_id` accepted by constructors such as
/// `BurnBackendConfig::cuda`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum BurnDevice {
    /// Host CPU; this is the `Default` device.
    #[default]
    Cpu,
    /// CUDA device with the given index.
    Cuda(u32),
    /// Metal device with the given index.
    Metal(u32),
    /// Vulkan device with the given index.
    Vulkan(u32),
}

impl BurnDevice {
    /// Returns a label such as `"cpu"`, `"cuda:0"`, `"metal:1"` or `"vulkan:2"`.
    ///
    /// This allocates a new `String`. When an owned string is not needed, format the
    /// device directly through its [`Display`](std::fmt::Display) impl, which produces the
    /// same text.
    pub fn as_label(&self) -> String {
        self.to_string()
    }

    /// Returns `true` only for [`BurnDevice::Cpu`].
    pub fn is_cpu(&self) -> bool {
        matches!(self, BurnDevice::Cpu)
    }

    /// Returns `true` for every non-CPU variant (`Cuda`, `Metal`, `Vulkan`).
    pub fn is_gpu(&self) -> bool {
        !self.is_cpu()
    }
}

impl std::fmt::Display for BurnDevice {
    /// Writes `cpu`, or `<backend>:<index>` for GPU variants (e.g. `cuda:0`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BurnDevice::Cpu => f.write_str("cpu"),
            BurnDevice::Cuda(id) => write!(f, "cuda:{}", id),
            BurnDevice::Metal(id) => write!(f, "metal:{}", id),
            BurnDevice::Vulkan(id) => write!(f, "vulkan:{}", id),
        }
    }
}
/// Configuration for a `BurnBackend`: which device to target and numeric options.
///
/// `Default` yields the CPU device with TF32 disabled, identical to [`BurnBackendConfig::cpu`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct BurnBackendConfig {
    /// Device the backend targets; defaults to [`BurnDevice::Cpu`].
    pub device: BurnDevice,
    /// Whether TF32 math may be used; defaults to `false`. The backend reports this as
    /// `supports_tf32` for CUDA devices and ignores it for other devices.
    pub allow_tf32: bool,
}

impl BurnBackendConfig {
    /// Creates a configuration for `device` with TF32 disabled.
    pub fn new(device: BurnDevice) -> Self {
        Self {
            device,
            allow_tf32: false,
        }
    }

    /// Configuration targeting the CPU.
    pub fn cpu() -> Self {
        Self::new(BurnDevice::Cpu)
    }

    /// Configuration targeting CUDA device `device_id`.
    pub fn cuda(device_id: u32) -> Self {
        Self::new(BurnDevice::Cuda(device_id))
    }

    /// Configuration targeting Metal device `device_id`.
    pub fn metal(device_id: u32) -> Self {
        Self::new(BurnDevice::Metal(device_id))
    }

    /// Configuration targeting Vulkan device `device_id`.
    pub fn vulkan(device_id: u32) -> Self {
        Self::new(BurnDevice::Vulkan(device_id))
    }

    /// Builder-style setter for [`allow_tf32`](Self::allow_tf32).
    ///
    /// Consumes and returns `self`, so dropping the result silently discards the change.
    #[must_use]
    pub fn with_tf32(mut self, allow: bool) -> Self {
        self.allow_tf32 = allow;
        self
    }
}
/// Placeholder Burn backend: holds a device configuration and an initialization flag.
///
/// No Burn runtime is created; availability is decided by `check_device_availability`.
#[derive(Debug)]
pub struct BurnBackend {
    /// Device selection and numeric options.
    config: BurnBackendConfig,
    /// Set to `true` by `initialize` and back to `false` by `shutdown`.
    initialized: Arc<AtomicBool>,
}

/// Cloning copies the configuration and *snapshots* the initialization flag into a fresh
/// `AtomicBool`. The clone does not share the flag with the original, so shutting one down
/// leaves the other's state untouched.
impl Clone for BurnBackend {
    fn clone(&self) -> Self {
        let was_initialized = self.is_initialized();
        Self {
            config: self.config.clone(),
            initialized: Arc::new(AtomicBool::new(was_initialized)),
        }
    }
}

impl BurnBackend {
    /// Creates a backend for `config` in the uninitialized state.
    pub fn new(config: BurnBackendConfig) -> Self {
        let initialized = Arc::new(AtomicBool::new(false));
        Self {
            config,
            initialized,
        }
    }

    /// Uninitialized backend targeting the CPU.
    pub fn cpu() -> Self {
        let config = BurnBackendConfig::cpu();
        Self::new(config)
    }

    /// Uninitialized backend targeting CUDA device `device_id`.
    pub fn cuda(device_id: u32) -> Self {
        let config = BurnBackendConfig::cuda(device_id);
        Self::new(config)
    }

    /// Borrows the configuration this backend was built with.
    pub fn config(&self) -> &BurnBackendConfig {
        &self.config
    }

    /// Returns the current value of the initialization flag.
    pub fn is_initialized(&self) -> bool {
        self.initialized.load(Ordering::SeqCst)
    }

    /// Placeholder availability check; no runtime probing is performed.
    ///
    /// - The CPU is always available.
    /// - Metal is available only when compiled for macOS (a compile-time `cfg!` check).
    /// - CUDA and Vulkan are always reported unavailable.
    fn check_device_availability(&self) -> bool {
        match self.config.device {
            BurnDevice::Cpu => true,
            BurnDevice::Metal(_) => cfg!(target_os = "macos"),
            BurnDevice::Cuda(_) | BurnDevice::Vulkan(_) => false,
        }
    }
}

/// The default backend is the uninitialized CPU backend.
impl Default for BurnBackend {
    fn default() -> Self {
        BurnBackend::cpu()
    }
}
impl Backend for BurnBackend {
    /// Backend identifier; always `"burn"`.
    fn name(&self) -> &str {
        "burn"
    }

    /// Device family name without its index (e.g. `"cuda"` for `Cuda(1)`).
    fn device(&self) -> &str {
        match self.config.device {
            BurnDevice::Cpu => "cpu",
            BurnDevice::Cuda(_) => "cuda",
            BurnDevice::Metal(_) => "metal",
            BurnDevice::Vulkan(_) => "vulkan",
        }
    }

    /// Hard-coded capability table per device family; nothing is queried from hardware.
    ///
    /// For CUDA, `supports_tf32` mirrors the config's `allow_tf32` flag.
    fn capabilities(&self) -> DeviceCapabilities {
        match self.config.device {
            BurnDevice::Cpu => DeviceCapabilities {
                supports_f16: false,
                supports_bf16: false,
                supports_tf32: false,
                max_memory_bytes: None,
                compute_units: None,
            },
            BurnDevice::Cuda(_) => DeviceCapabilities {
                supports_f16: true,
                supports_bf16: true,
                supports_tf32: self.config.allow_tf32,
                // Fixed placeholder figures (8 GiB, 128 units), not read from the device.
                max_memory_bytes: Some(8 * 1024 * 1024 * 1024),
                compute_units: Some(128),
            },
            // Metal and Vulkan currently advertise the same capabilities.
            BurnDevice::Metal(_) | BurnDevice::Vulkan(_) => DeviceCapabilities {
                supports_f16: true,
                supports_bf16: false,
                supports_tf32: false,
                max_memory_bytes: None,
                compute_units: None,
            },
        }
    }

    /// Delegates to the backend's placeholder availability check.
    fn is_available(&self) -> bool {
        self.check_device_availability()
    }

    /// Marks the backend initialized once the configured device is confirmed available.
    ///
    /// Calling this again while already initialized returns `Ok(())` without logging.
    ///
    /// # Errors
    ///
    /// Returns a backend error naming the device label when the device is unavailable.
    fn initialize(&mut self) -> AphelionResult<()> {
        if self.is_initialized() {
            return Ok(());
        }
        if !self.is_available() {
            let label = self.config.device.as_label();
            return Err(crate::error::AphelionError::backend(format!(
                "Burn device {} is not available",
                label
            )));
        }
        self.initialized.store(true, Ordering::SeqCst);
        tracing::info!(
            backend = "burn",
            device = %self.config.device.as_label(),
            tf32 = self.config.allow_tf32,
            "Burn backend initialized (placeholder)"
        );
        Ok(())
    }

    /// Clears the initialized flag; a no-op (no log) if the backend was not initialized.
    fn shutdown(&mut self) -> AphelionResult<()> {
        // `swap` clears the flag and hands back the previous state in one step.
        let was_initialized = self.initialized.swap(false, Ordering::SeqCst);
        if was_initialized {
            tracing::info!(
                backend = "burn",
                device = %self.config.device.as_label(),
                "Burn backend shutdown (placeholder)"
            );
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_burn_device_as_label() {
        assert_eq!(BurnDevice::Cpu.as_label(), "cpu");
        assert_eq!(BurnDevice::Cuda(0).as_label(), "cuda:0");
        assert_eq!(BurnDevice::Cuda(1).as_label(), "cuda:1");
        assert_eq!(BurnDevice::Metal(0).as_label(), "metal:0");
        assert_eq!(BurnDevice::Vulkan(2).as_label(), "vulkan:2");
    }

    #[test]
    fn test_burn_device_default() {
        let device = BurnDevice::default();
        assert_eq!(device, BurnDevice::Cpu);
    }

    #[test]
    fn test_burn_device_is_cpu_gpu() {
        assert!(BurnDevice::Cpu.is_cpu());
        assert!(!BurnDevice::Cpu.is_gpu());
        assert!(!BurnDevice::Cuda(0).is_cpu());
        assert!(BurnDevice::Cuda(0).is_gpu());
        assert!(!BurnDevice::Metal(0).is_cpu());
        assert!(BurnDevice::Metal(0).is_gpu());
        assert!(!BurnDevice::Vulkan(0).is_cpu());
        assert!(BurnDevice::Vulkan(0).is_gpu());
    }

    #[test]
    fn test_burn_backend_config_default() {
        let config = BurnBackendConfig::default();
        assert_eq!(config.device, BurnDevice::Cpu);
        assert!(!config.allow_tf32);
    }

    #[test]
    fn test_burn_backend_config_builder() {
        let config = BurnBackendConfig::cuda(0).with_tf32(true);
        assert_eq!(config.device, BurnDevice::Cuda(0));
        assert!(config.allow_tf32);
    }

    #[test]
    fn test_burn_backend_new() {
        let backend = BurnBackend::new(BurnBackendConfig::default());
        assert_eq!(backend.name(), "burn");
        assert_eq!(backend.device(), "cpu");
        assert!(!backend.is_initialized());
    }

    #[test]
    fn test_burn_backend_cpu() {
        let backend = BurnBackend::cpu();
        assert_eq!(backend.device(), "cpu");
        assert!(backend.config().device.is_cpu());
    }

    #[test]
    fn test_burn_backend_cuda() {
        let backend = BurnBackend::cuda(0);
        assert_eq!(backend.device(), "cuda");
        assert!(backend.config().device.is_gpu());
    }

    #[test]
    fn test_burn_backend_device_names_metal_vulkan() {
        let metal = BurnBackend::new(BurnBackendConfig::new(BurnDevice::Metal(1)));
        assert_eq!(metal.device(), "metal");
        let vulkan = BurnBackend::new(BurnBackendConfig::new(BurnDevice::Vulkan(2)));
        assert_eq!(vulkan.device(), "vulkan");
    }

    #[test]
    fn test_burn_backend_default() {
        let backend = BurnBackend::default();
        assert_eq!(backend.device(), "cpu");
    }

    #[test]
    fn test_burn_backend_capabilities_cpu() {
        let backend = BurnBackend::cpu();
        let caps = backend.capabilities();
        assert!(!caps.supports_f16);
        assert!(!caps.supports_bf16);
        assert!(!caps.supports_tf32);
    }

    #[test]
    fn test_burn_backend_capabilities_cuda() {
        let backend = BurnBackend::new(BurnBackendConfig::cuda(0).with_tf32(true));
        let caps = backend.capabilities();
        assert!(caps.supports_f16);
        assert!(caps.supports_bf16);
        assert!(caps.supports_tf32);
        assert!(caps.max_memory_bytes.is_some());
    }

    #[test]
    fn test_burn_backend_capabilities_cuda_tf32_disabled() {
        // TF32 support for CUDA follows the config flag, which defaults to off.
        let caps = BurnBackend::cuda(0).capabilities();
        assert!(!caps.supports_tf32);
        assert!(caps.supports_f16);
    }

    #[test]
    fn test_burn_backend_capabilities_metal_vulkan() {
        for device in [BurnDevice::Metal(0), BurnDevice::Vulkan(0)] {
            let backend = BurnBackend::new(BurnBackendConfig::new(device).with_tf32(true));
            let caps = backend.capabilities();
            assert!(caps.supports_f16);
            assert!(!caps.supports_bf16);
            assert!(!caps.supports_tf32);
            assert!(caps.max_memory_bytes.is_none());
            assert!(caps.compute_units.is_none());
        }
    }

    #[test]
    fn test_burn_backend_is_available_cpu() {
        let backend = BurnBackend::cpu();
        assert!(backend.is_available());
    }

    #[test]
    fn test_burn_backend_is_available_cuda() {
        let backend = BurnBackend::cuda(0);
        assert!(!backend.is_available());
    }

    #[test]
    fn test_burn_backend_is_available_metal() {
        let backend = BurnBackend::new(BurnBackendConfig::new(BurnDevice::Metal(0)));
        assert_eq!(backend.is_available(), cfg!(target_os = "macos"));
    }

    #[test]
    fn test_burn_backend_is_available_vulkan() {
        let mut backend = BurnBackend::new(BurnBackendConfig::new(BurnDevice::Vulkan(0)));
        assert!(!backend.is_available());
        assert!(backend.initialize().is_err());
        assert!(!backend.is_initialized());
    }

    #[test]
    fn test_burn_backend_initialize_cpu() {
        let mut backend = BurnBackend::cpu();
        assert!(!backend.is_initialized());
        let result = backend.initialize();
        assert!(result.is_ok());
        assert!(backend.is_initialized());
        let result = backend.initialize();
        assert!(result.is_ok());
    }

    #[test]
    fn test_burn_backend_initialize_cuda_fails() {
        let mut backend = BurnBackend::cuda(0);
        let result = backend.initialize();
        assert!(result.is_err());
        // A failed initialization must not leave the flag set.
        assert!(!backend.is_initialized());
    }

    #[test]
    fn test_burn_backend_shutdown() {
        let mut backend = BurnBackend::cpu();
        backend.initialize().unwrap();
        assert!(backend.is_initialized());
        let result = backend.shutdown();
        assert!(result.is_ok());
        assert!(!backend.is_initialized());
        let result = backend.shutdown();
        assert!(result.is_ok());
    }

    #[test]
    fn test_burn_backend_shutdown_without_initialize() {
        let mut backend = BurnBackend::cpu();
        assert!(backend.shutdown().is_ok());
        assert!(!backend.is_initialized());
    }

    #[test]
    fn test_burn_backend_clone() {
        let mut backend = BurnBackend::cpu();
        backend.initialize().unwrap();
        let cloned = backend.clone();
        assert!(cloned.is_initialized());
        assert_eq!(cloned.device(), backend.device());
    }

    #[test]
    fn test_burn_backend_clone_is_independent() {
        // `Clone` snapshots the init flag instead of sharing it.
        let mut backend = BurnBackend::cpu();
        backend.initialize().unwrap();
        let mut cloned = backend.clone();
        cloned.shutdown().unwrap();
        assert!(!cloned.is_initialized());
        assert!(backend.is_initialized());
    }
}