use crate::error::Result;
use std::cell::RefCell;
use std::fmt;
use std::sync::Arc;
pub mod async_executor;
pub mod cpu;
pub mod gpu;
pub mod kernel_fusion;
pub mod memory_pool;
/// Kind of compute backend a [`Device`] runs on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU backend.
    Cpu,
    /// GPU backend driven through Vulkan.
    Vulkan,
}

impl fmt::Display for DeviceType {
    /// Writes the human-readable backend name (`"CPU"` or `"Vulkan"`).
    ///
    /// Width, fill, alignment and precision flags (e.g. `{:>8}`) are honoured, so the
    /// type can be used directly in aligned log lines and tables.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            DeviceType::Cpu => "CPU",
            DeviceType::Vulkan => "Vulkan",
        };
        // `pad` applies the caller's formatting flags; `write!` with a literal ignores them.
        f.pad(name)
    }
}
thread_local! {
    /// Per-thread device override; `Device::auto_select` returns it when present.
    static DEFAULT_DEVICE: RefCell<Option<Device>> = RefCell::new(None);
}

/// Registers `device` as the calling thread's default device, replacing any previous one.
///
/// The slot is thread-local, so other threads are unaffected.
pub fn set_default_device(device: Device) {
    DEFAULT_DEVICE.with(move |slot| *slot.borrow_mut() = Some(device));
}

/// Forgets the calling thread's default device, if one was registered.
pub fn clear_default_device() {
    DEFAULT_DEVICE.with(|slot| *slot.borrow_mut() = None);
}
/// Descriptive metadata for a compute device, as filled in by its constructor or backend.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Human-readable device name (the CPU device uses `"CPU"`).
    pub name: String,
    /// Backend family the device belongs to.
    pub device_type: DeviceType,
    /// Device memory size when the backend reports it; `None` for the CPU device.
    /// NOTE(review): units are not stated here — presumably bytes; confirm with backends.
    pub memory_size: Option<u64>,
    /// Number of compute units; for the CPU device this is `num_cpus::get()`.
    pub compute_units: Option<u32>,
    /// Whether half-precision (`f16`) data is supported.
    pub supports_f16: bool,
    /// Whether double-precision (`f64`) data is supported.
    pub supports_f64: bool,
}
/// Handle to a compute device: a shared backend plus the metadata describing it.
///
/// Cloning bumps the backend's reference count and clones the [`DeviceInfo`]; both
/// clones drive the same backend instance.
#[derive(Clone)]
pub struct Device {
    /// Backend performing allocation, transfers and kernel execution for this device.
    backend: Arc<dyn Backend + Send + Sync>,
    /// Metadata captured when the device was constructed.
    info: DeviceInfo,
}
impl Device {
    /// Builds a device from an already-initialised backend and its metadata.
    pub fn new(backend: Arc<dyn Backend + Send + Sync>, info: DeviceInfo) -> Self {
        Self { backend, info }
    }

    /// Returns `true` when the `NNL_CPU_ONLY` environment variable is present.
    ///
    /// Only presence matters. `var_os` is used so that a value which is not valid Unicode
    /// still counts as "set"; `env::var(..).is_ok()` would silently treat it as unset.
    fn cpu_only_requested() -> bool {
        std::env::var_os("NNL_CPU_ONLY").is_some()
    }

    /// Chooses a device for the calling thread.
    ///
    /// Selection order:
    /// 1. the thread-local default registered via [`set_default_device`];
    /// 2. the CPU, when `NNL_CPU_ONLY` is set;
    /// 3. a Vulkan device, if one can be created;
    /// 4. the CPU as the final fallback.
    ///
    /// The CPU branches go through [`Device::cpu`], which also registers the CPU device as
    /// this thread's default, so later calls on the same thread return it from step 1.
    ///
    /// # Errors
    /// Fails only if the CPU device itself cannot be created.
    pub fn auto_select() -> Result<Self> {
        if let Some(device) = DEFAULT_DEVICE.with(|d| d.borrow().clone()) {
            return Ok(device);
        }
        if Self::cpu_only_requested() {
            log::info!("NNL_CPU_ONLY set, using CPU device");
            return Self::cpu();
        }
        match Self::vulkan() {
            Ok(device) => {
                log::info!("Selected Vulkan device: {}", device.info.name);
                return Ok(device);
            }
            // Falling back to the CPU is intentional, but keep the reason visible instead of
            // discarding the error silently.
            Err(err) => log::debug!("Vulkan device unavailable, falling back to CPU: {:?}", err),
        }
        let device = Self::cpu()?;
        log::info!("Selected CPU device: {}", device.info.name);
        Ok(device)
    }

    /// Creates a CPU device backed by `cpu::CpuBackend`.
    ///
    /// Side effect: the new device is also registered as the calling thread's default
    /// (see [`set_default_device`]), so subsequent [`Device::auto_select`] calls on this
    /// thread return it.
    ///
    /// # Errors
    /// Propagates any error from `CpuBackend::new`.
    pub fn cpu() -> Result<Self> {
        let backend = Arc::new(cpu::CpuBackend::new()?);
        let info = DeviceInfo {
            name: "CPU".to_string(),
            device_type: DeviceType::Cpu,
            memory_size: None,
            compute_units: Some(num_cpus::get() as u32),
            supports_f16: false,
            supports_f64: true,
        };
        let device = Self::new(backend, info);
        set_default_device(device.clone());
        Ok(device)
    }

    /// Creates a Vulkan-backed device whose metadata is queried from the backend.
    ///
    /// Unlike [`Device::cpu`], this does not register a thread default.
    ///
    /// # Errors
    /// Fails if `NNL_CPU_ONLY` is set, or if the Vulkan backend cannot be initialised or
    /// cannot report its device info.
    pub fn vulkan() -> Result<Self> {
        if Self::cpu_only_requested() {
            return Err(crate::error::NnlError::device(
                "Vulkan device creation blocked by NNL_CPU_ONLY environment variable",
            ));
        }
        let backend = Arc::new(gpu::VulkanBackend::new()?);
        let info = backend.device_info()?;
        Ok(Self::new(backend, info))
    }

    /// Metadata captured when the device was created.
    pub fn info(&self) -> &DeviceInfo {
        &self.info
    }

    /// Backend family of this device.
    pub fn device_type(&self) -> DeviceType {
        self.info.device_type
    }

    /// Borrow of the backend that executes work for this device.
    pub fn backend(&self) -> &dyn Backend {
        self.backend.as_ref()
    }

    /// Whether the device reports `f16` support.
    pub fn supports_f16(&self) -> bool {
        self.info.supports_f16
    }

    /// Whether the device reports `f64` support.
    pub fn supports_f64(&self) -> bool {
        self.info.supports_f64
    }

    /// Memory size recorded in [`DeviceInfo`] (`None` for the CPU device).
    pub fn memory_size(&self) -> Option<u64> {
        self.info.memory_size
    }

    /// Delegates to [`Backend::synchronize`]; presumably waits for submitted work to
    /// finish — exact semantics are backend-defined.
    pub fn synchronize(&self) -> Result<()> {
        self.backend.synchronize()
    }
}
impl fmt::Debug for Device {
    /// Prints only the device metadata; the backend trait object does not require `Debug`,
    /// so it is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Device");
        builder.field("info", &self.info);
        builder.finish()
    }
}
pub trait Backend {
fn device_info(&self) -> Result<DeviceInfo>;
fn allocate(&self, size: usize) -> Result<Arc<dyn DeviceMemory>>;
fn allocate_uniform(&self, size: usize) -> Result<Arc<dyn DeviceMemory>> {
self.allocate(size)
}
fn copy_to_device(&self, data: &[f32], memory: &dyn DeviceMemory) -> Result<()>;
fn copy_u32_to_device(&self, data: &[u32], memory: &dyn DeviceMemory) -> Result<()> {
let f32_data: Vec<f32> = data.iter().map(|&x| x as f32).collect();
self.copy_to_device(&f32_data, memory)
}
fn copy_to_host(&self, memory: &dyn DeviceMemory, data: &mut [f32]) -> Result<()>;
fn execute_kernel(
&self,
kernel: &dyn Kernel,
inputs: &[&dyn DeviceMemory],
outputs: &[&dyn DeviceMemory],
) -> Result<()>;
fn execute_kernel_with_uniform(
&self,
kernel: &dyn Kernel,
inputs: &[&dyn DeviceMemory],
outputs: &[&dyn DeviceMemory],
uniform: Option<&dyn DeviceMemory>,
) -> Result<()> {
if uniform.is_some() {
return Err(crate::error::NnlError::device(
"Uniform buffers not supported by this backend",
));
}
self.execute_kernel(kernel, inputs, outputs)
}
fn synchronize(&self) -> Result<()>;
fn is_available(&self) -> bool;
fn as_any(&self) -> &dyn std::any::Any;
}
/// Buffer owned by a backend. Implementors must be `Debug`, `Send` and `Sync`.
pub trait DeviceMemory: std::fmt::Debug + Send + Sync {
    /// Size of the allocation.
    /// NOTE(review): units are not visible here — presumably bytes; confirm with backends.
    fn size(&self) -> usize;

    /// Backend family this buffer belongs to.
    fn device_type(&self) -> DeviceType;

    /// Shared view for downcasting to the concrete buffer type (`Any::downcast_ref`).
    fn as_any(&self) -> &dyn std::any::Any;

    /// Mutable view for downcasting to the concrete buffer type (`Any::downcast_mut`).
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
/// Compute kernel that a backend can run via `Backend::execute_kernel`.
pub trait Kernel {
    /// Kernel identifier; how a backend interprets it is implementation-defined.
    fn name(&self) -> &str;

    /// Optional `[x, y, z]` local (workgroup) size.
    /// NOTE(review): presumably `None` lets the backend choose — confirm.
    fn local_size(&self) -> Option<[u32; 3]>;

    /// Enables downcasting to the concrete kernel type.
    fn as_any(&self) -> &dyn std::any::Any;
}
pub mod utils {
use super::*;
pub fn list_devices() -> Vec<DeviceInfo> {
let mut devices = Vec::new();
if let Ok(cpu) = Device::cpu() {
devices.push(cpu.info().clone());
}
if let Ok(vulkan) = Device::vulkan() {
devices.push(vulkan.info().clone());
}
devices
}
pub fn benchmark_devices() -> Result<Vec<(DeviceInfo, f64)>> {
let devices = list_devices();
let mut results = Vec::new();
for device_info in devices {
let _device = match device_info.device_type {
DeviceType::Cpu => Device::cpu()?,
DeviceType::Vulkan => Device::vulkan()?,
};
let start = std::time::Instant::now();
benchmark_matrix_multiply(&_device)?;
let duration = start.elapsed().as_secs_f64();
results.push((device_info, duration));
}
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
Ok(results)
}
fn benchmark_matrix_multiply(_device: &Device) -> Result<()> {
std::thread::sleep(std::time::Duration::from_millis(1));
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Selection always succeeds because the CPU is the final fallback.
    #[test]
    fn test_device_auto_select() {
        let selected = Device::auto_select().expect("auto_select should always yield a device");
        println!("Auto-selected device: {:?}", selected.device_type());
    }

    /// The CPU device is always constructible and advertises `f64` support.
    #[test]
    fn test_cpu_device() {
        let cpu = Device::cpu().expect("CPU device should be constructible");
        assert_eq!(cpu.device_type(), DeviceType::Cpu);
        assert!(cpu.supports_f64());
    }

    /// Device listing is never empty and always contains the CPU.
    #[test]
    fn test_list_devices() {
        let listed = utils::list_devices();
        assert!(!listed.is_empty());
        let has_cpu = listed.iter().any(|entry| entry.device_type == DeviceType::Cpu);
        assert!(has_cpu);
    }

    /// `DeviceType`'s `Display` output for the CPU variant.
    #[test]
    fn test_device_info_display() {
        let sample = DeviceInfo {
            name: "Test Device".to_string(),
            device_type: DeviceType::Cpu,
            memory_size: Some(8_000_000_000),
            compute_units: Some(8),
            supports_f16: false,
            supports_f64: true,
        };
        let rendered = sample.device_type.to_string();
        assert_eq!(rendered, "CPU");
    }
}