1use std::collections::HashMap;
2use std::sync::Arc;
3
4use once_cell::sync::Lazy;
5use parking_lot::RwLock;
6
7pub use morok_dtype::DeviceSpec;
8
9use crate::allocator::{Allocator, CpuAllocator, LruAllocator};
10use crate::error::{InvalidDeviceSnafu, Result};
11
/// Extension trait that adds string parsing to [`DeviceSpec`].
///
/// `DeviceSpec` is re-exported from the `morok_dtype` crate, so inherent methods cannot be
/// added to it here; this trait provides the parsing entry point instead.
pub trait DeviceSpecExt {
    /// Parses a device string such as `"CPU"`, `"CUDA:1"` or `"DISK:/path"` into a
    /// [`DeviceSpec`].
    ///
    /// # Errors
    ///
    /// Returns an error if the string does not describe a recognised device.
    fn parse(s: &str) -> Result<DeviceSpec>;
}
25
26impl DeviceSpecExt for DeviceSpec {
27 fn parse(s: &str) -> Result<Self> {
28 if s.len() >= 5 && s[..5].eq_ignore_ascii_case("DISK:") {
30 return Ok(DeviceSpec::Disk { path: std::path::PathBuf::from(&s[5..]) });
31 }
32
33 let s = s.to_uppercase();
34 let parts: Vec<&str> = s.split(':').collect();
35
36 match parts[0] {
37 "CPU" => Ok(DeviceSpec::Cpu),
38 #[cfg(feature = "cuda")]
39 "CUDA" | "GPU" => {
40 let device_id = if parts.len() > 1 {
41 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
42 } else {
43 0
44 };
45 Ok(DeviceSpec::Cuda { device_id })
46 }
47 #[cfg(not(feature = "cuda"))]
48 "CUDA" | "GPU" => {
49 let device_id = if parts.len() > 1 {
50 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
51 } else {
52 0
53 };
54 Ok(DeviceSpec::Cuda { device_id })
55 }
56 #[cfg(feature = "metal")]
57 "METAL" => {
58 let device_id = if parts.len() > 1 {
59 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
60 } else {
61 0
62 };
63 Ok(DeviceSpec::Metal { device_id })
64 }
65 #[cfg(not(feature = "metal"))]
66 "METAL" => {
67 let device_id = if parts.len() > 1 {
68 parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
69 } else {
70 0
71 };
72 Ok(DeviceSpec::Metal { device_id })
73 }
74 #[cfg(feature = "webgpu")]
75 "WEBGPU" => Ok(DeviceSpec::WebGpu),
76 #[cfg(not(feature = "webgpu"))]
77 "WEBGPU" => Ok(DeviceSpec::WebGpu),
78 _ => InvalidDeviceSnafu { device: s }.fail(),
79 }
80 }
81}
82
83#[derive(Default)]
84pub struct DeviceRegistry {
85 devices: RwLock<HashMap<DeviceSpec, Arc<dyn Allocator>>>,
86}
87
88impl DeviceRegistry {
89 pub fn get(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
91 {
93 let devices = self.devices.read();
94 if let Some(allocator) = devices.get(spec) {
95 return Ok(Arc::clone(allocator));
96 }
97 }
98
99 let mut devices = self.devices.write();
101
102 if let Some(allocator) = devices.get(spec) {
104 return Ok(Arc::clone(allocator));
105 }
106
107 let allocator = self.create_allocator(spec)?;
109 devices.insert(spec.clone(), Arc::clone(&allocator));
110 Ok(allocator)
111 }
112
113 pub fn get_device(&self, device: &str) -> Result<Arc<dyn Allocator>> {
115 let spec = <DeviceSpec as DeviceSpecExt>::parse(device)?;
116 self.get(&spec)
117 }
118
119 fn create_allocator(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
120 if let DeviceSpec::Disk { path } = spec {
122 return Ok(Arc::new(crate::allocator::DiskAllocator::new(path.clone())));
123 }
124
125 let base: Box<dyn Allocator> = match spec {
126 DeviceSpec::Cpu => Box::new(CpuAllocator),
127 #[cfg(feature = "cuda")]
128 DeviceSpec::Cuda { device_id } => Box::new(crate::allocator::CudaAllocator::new(*device_id)?),
129 #[cfg(not(feature = "cuda"))]
130 DeviceSpec::Cuda { .. } => unimplemented!("Cuda allocator - to be implemented"),
131 DeviceSpec::Metal { .. } => unimplemented!("Metal allocator - to be implemented"),
132 DeviceSpec::WebGpu => unimplemented!("WebGPU allocator - to be implemented"),
133 DeviceSpec::Disk { .. } => unreachable!(),
134 };
135
136 let lru = LruAllocator::new(base);
138
139 Ok(Arc::new(lru))
140 }
141}
142
143static REGISTRY: Lazy<DeviceRegistry> = Lazy::new(DeviceRegistry::default);
145
146pub fn registry() -> &'static DeviceRegistry {
148 ®ISTRY
149}
150
151pub fn get_device(device: &str) -> Result<Arc<dyn Allocator>> {
153 registry().get_device(device)
154}
155
156pub fn cpu() -> Result<Arc<dyn Allocator>> {
158 registry().get(&DeviceSpec::Cpu)
159}
160
/// Returns the shared allocator for CUDA device `device_id` from the global registry,
/// creating it on first use. Only compiled with the `cuda` feature.
///
/// # Errors
///
/// Propagates any error from creating the CUDA allocator.
#[cfg(feature = "cuda")]
pub fn cuda(device_id: usize) -> Result<Arc<dyn Allocator>> {
    let spec = DeviceSpec::Cuda { device_id };
    registry().get(&spec)
}
165}