use std::collections::HashMap;
use std::sync::Arc;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
pub use morok_dtype::DeviceSpec;
use crate::allocator::{Allocator, CpuAllocator, LruAllocator};
use crate::error::{InvalidDeviceSnafu, Result};
pub trait DeviceSpecExt {
fn parse(s: &str) -> Result<DeviceSpec>;
}
impl DeviceSpecExt for DeviceSpec {
fn parse(s: &str) -> Result<Self> {
if s.len() >= 5 && s[..5].eq_ignore_ascii_case("DISK:") {
return Ok(DeviceSpec::Disk { path: std::path::PathBuf::from(&s[5..]) });
}
let s = s.to_uppercase();
let parts: Vec<&str> = s.split(':').collect();
match parts[0] {
"CPU" => Ok(DeviceSpec::Cpu),
#[cfg(feature = "cuda")]
"CUDA" | "GPU" => {
let device_id = if parts.len() > 1 {
parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
} else {
0
};
Ok(DeviceSpec::Cuda { device_id })
}
#[cfg(not(feature = "cuda"))]
"CUDA" | "GPU" => {
let device_id = if parts.len() > 1 {
parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
} else {
0
};
Ok(DeviceSpec::Cuda { device_id })
}
#[cfg(feature = "metal")]
"METAL" => {
let device_id = if parts.len() > 1 {
parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
} else {
0
};
Ok(DeviceSpec::Metal { device_id })
}
#[cfg(not(feature = "metal"))]
"METAL" => {
let device_id = if parts.len() > 1 {
parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
} else {
0
};
Ok(DeviceSpec::Metal { device_id })
}
#[cfg(feature = "webgpu")]
"WEBGPU" => Ok(DeviceSpec::WebGpu),
#[cfg(not(feature = "webgpu"))]
"WEBGPU" => Ok(DeviceSpec::WebGpu),
_ => InvalidDeviceSnafu { device: s }.fail(),
}
}
}
#[derive(Default)]
pub struct DeviceRegistry {
devices: RwLock<HashMap<DeviceSpec, Arc<dyn Allocator>>>,
}
impl DeviceRegistry {
pub fn get(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
{
let devices = self.devices.read();
if let Some(allocator) = devices.get(spec) {
return Ok(Arc::clone(allocator));
}
}
let mut devices = self.devices.write();
if let Some(allocator) = devices.get(spec) {
return Ok(Arc::clone(allocator));
}
let allocator = self.create_allocator(spec)?;
devices.insert(spec.clone(), Arc::clone(&allocator));
Ok(allocator)
}
pub fn get_device(&self, device: &str) -> Result<Arc<dyn Allocator>> {
let spec = <DeviceSpec as DeviceSpecExt>::parse(device)?;
self.get(&spec)
}
fn create_allocator(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
if let DeviceSpec::Disk { path } = spec {
return Ok(Arc::new(crate::allocator::DiskAllocator::new(path.clone())));
}
let base: Box<dyn Allocator> = match spec {
DeviceSpec::Cpu => Box::new(CpuAllocator),
#[cfg(feature = "cuda")]
DeviceSpec::Cuda { device_id } => Box::new(crate::allocator::CudaAllocator::new(*device_id)?),
#[cfg(not(feature = "cuda"))]
DeviceSpec::Cuda { .. } => unimplemented!("Cuda allocator - to be implemented"),
DeviceSpec::Metal { .. } => unimplemented!("Metal allocator - to be implemented"),
DeviceSpec::WebGpu => unimplemented!("WebGPU allocator - to be implemented"),
DeviceSpec::Disk { .. } => unreachable!(),
};
let lru = LruAllocator::new(base);
Ok(Arc::new(lru))
}
}
/// Process-wide device registry, initialized on first access.
static REGISTRY: Lazy<DeviceRegistry> = Lazy::new(DeviceRegistry::default);

/// Returns the process-wide [`DeviceRegistry`].
pub fn registry() -> &'static DeviceRegistry {
    &REGISTRY
}
pub fn get_device(device: &str) -> Result<Arc<dyn Allocator>> {
registry().get_device(device)
}
pub fn cpu() -> Result<Arc<dyn Allocator>> {
registry().get(&DeviceSpec::Cpu)
}
/// Returns the shared allocator for CUDA device `device_id` from the global registry.
///
/// # Errors
///
/// Propagates any error from allocator creation.
#[cfg(feature = "cuda")]
pub fn cuda(device_id: usize) -> Result<Arc<dyn Allocator>> {
    let spec = DeviceSpec::Cuda { device_id };
    DeviceRegistry::get(registry(), &spec)
}