1use std::collections::HashMap;
2use std::sync::Arc;
3
4use once_cell::sync::Lazy;
5use parking_lot::RwLock;
6
7pub use morok_dtype::DeviceSpec;
8
9use crate::allocator::{Allocator, CpuAllocator, LruAllocator};
10use crate::error::{InvalidDeviceSnafu, Result};
11
12pub trait DeviceSpecExt {
17 fn parse(s: &str) -> Result<DeviceSpec>;
24}
25
26impl DeviceSpecExt for DeviceSpec {
27 fn parse(s: &str) -> Result<Self> {
28 let s = s.to_uppercase();
29 let parts: Vec<&str> = s.split(':').collect();
30
31 match parts[0] {
32 "CPU" => Ok(DeviceSpec::Cpu),
33 #[cfg(feature = "cuda")]
34 "CUDA" | "GPU" => {
35 let device_id = if parts.len() > 1 {
36 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
37 } else {
38 0
39 };
40 Ok(DeviceSpec::Cuda { device_id })
41 }
42 #[cfg(not(feature = "cuda"))]
43 "CUDA" | "GPU" => {
44 let device_id = if parts.len() > 1 {
45 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
46 } else {
47 0
48 };
49 Ok(DeviceSpec::Cuda { device_id })
50 }
51 #[cfg(feature = "metal")]
52 "METAL" => {
53 let device_id = if parts.len() > 1 {
54 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
55 } else {
56 0
57 };
58 Ok(DeviceSpec::Metal { device_id })
59 }
60 #[cfg(not(feature = "metal"))]
61 "METAL" => {
62 let device_id = if parts.len() > 1 {
63 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
64 } else {
65 0
66 };
67 Ok(DeviceSpec::Metal { device_id })
68 }
69 #[cfg(feature = "webgpu")]
70 "WEBGPU" => Ok(DeviceSpec::WebGpu),
71 #[cfg(not(feature = "webgpu"))]
72 "WEBGPU" => Ok(DeviceSpec::WebGpu),
73 _ => InvalidDeviceSnafu { device: s }.fail(),
74 }
75 }
76}
77
78#[derive(Default)]
79pub struct DeviceRegistry {
80 devices: RwLock<HashMap<DeviceSpec, Arc<dyn Allocator>>>,
81}
82
83impl DeviceRegistry {
84 pub fn get(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
86 {
88 let devices = self.devices.read();
89 if let Some(allocator) = devices.get(spec) {
90 return Ok(Arc::clone(allocator));
91 }
92 }
93
94 let mut devices = self.devices.write();
96
97 if let Some(allocator) = devices.get(spec) {
99 return Ok(Arc::clone(allocator));
100 }
101
102 let allocator = self.create_allocator(spec)?;
104 devices.insert(spec.clone(), Arc::clone(&allocator));
105 Ok(allocator)
106 }
107
108 pub fn get_device(&self, device: &str) -> Result<Arc<dyn Allocator>> {
110 let spec = <DeviceSpec as DeviceSpecExt>::parse(device)?;
111 self.get(&spec)
112 }
113
114 fn create_allocator(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
115 let base: Box<dyn Allocator> = match spec {
116 DeviceSpec::Cpu => Box::new(CpuAllocator),
117 #[cfg(feature = "cuda")]
118 DeviceSpec::Cuda { device_id } => Box::new(crate::allocator::CudaAllocator::new(*device_id)?),
119 #[cfg(not(feature = "cuda"))]
120 DeviceSpec::Cuda { .. } => unimplemented!("Cuda allocator - to be implemented"),
121 DeviceSpec::Metal { .. } => unimplemented!("Metal allocator - to be implemented"),
122 DeviceSpec::WebGpu => unimplemented!("WebGPU allocator - to be implemented"),
123 };
124
125 let lru = LruAllocator::new(base);
127
128 Ok(Arc::new(lru))
129 }
130}
131
132static REGISTRY: Lazy<DeviceRegistry> = Lazy::new(DeviceRegistry::default);
134
135pub fn registry() -> &'static DeviceRegistry {
137 ®ISTRY
138}
139
140pub fn get_device(device: &str) -> Result<Arc<dyn Allocator>> {
142 registry().get_device(device)
143}
144
145pub fn cpu() -> Result<Arc<dyn Allocator>> {
147 registry().get(&DeviceSpec::Cpu)
148}
149
/// Returns the shared allocator for CUDA device `device_id` from the global
/// registry. Only compiled when the `cuda` feature is enabled.
///
/// # Errors
///
/// Propagates any error from creating the allocator on first use.
#[cfg(feature = "cuda")]
pub fn cuda(device_id: usize) -> Result<Arc<dyn Allocator>> {
    let spec = DeviceSpec::Cuda { device_id };
    DeviceRegistry::get(registry(), &spec)
}