Skip to main content

morok_device/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use once_cell::sync::Lazy;
5use parking_lot::RwLock;
6
7pub use morok_dtype::DeviceSpec;
8
9use crate::allocator::{Allocator, CpuAllocator, LruAllocator};
10use crate::error::{InvalidDeviceSnafu, Result};
11
12/// Extension trait for DeviceSpec to add parsing functionality.
13///
14/// This is in the device crate because parsing depends on feature flags
15/// and error types that are device-specific.
16pub trait DeviceSpecExt {
17    /// Parse a device string into a DeviceSpec.
18    ///
19    /// Examples:
20    /// - "CPU" -> DeviceSpec::Cpu
21    /// - "CUDA:0" -> DeviceSpec::Cuda { device_id: 0 }
22    /// - "cuda" -> DeviceSpec::Cuda { device_id: 0 } (default to device 0)
23    fn parse(s: &str) -> Result<DeviceSpec>;
24}
25
26impl DeviceSpecExt for DeviceSpec {
27    fn parse(s: &str) -> Result<Self> {
28        // DISK: preserve path case (don't uppercase)
29        if s.len() >= 5 && s[..5].eq_ignore_ascii_case("DISK:") {
30            return Ok(DeviceSpec::Disk { path: std::path::PathBuf::from(&s[5..]) });
31        }
32
33        let s = s.to_uppercase();
34        let parts: Vec<&str> = s.split(':').collect();
35
36        match parts[0] {
37            "CPU" => Ok(DeviceSpec::Cpu),
38            #[cfg(feature = "cuda")]
39            "CUDA" | "GPU" => {
40                let device_id = if parts.len() > 1 {
41                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
42                } else {
43                    0
44                };
45                Ok(DeviceSpec::Cuda { device_id })
46            }
47            #[cfg(not(feature = "cuda"))]
48            "CUDA" | "GPU" => {
49                let device_id = if parts.len() > 1 {
50                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
51                } else {
52                    0
53                };
54                Ok(DeviceSpec::Cuda { device_id })
55            }
56            #[cfg(feature = "metal")]
57            "METAL" => {
58                let device_id = if parts.len() > 1 {
59                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
60                } else {
61                    0
62                };
63                Ok(DeviceSpec::Metal { device_id })
64            }
65            #[cfg(not(feature = "metal"))]
66            "METAL" => {
67                let device_id = if parts.len() > 1 {
68                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
69                } else {
70                    0
71                };
72                Ok(DeviceSpec::Metal { device_id })
73            }
74            #[cfg(feature = "webgpu")]
75            "WEBGPU" => Ok(DeviceSpec::WebGpu),
76            #[cfg(not(feature = "webgpu"))]
77            "WEBGPU" => Ok(DeviceSpec::WebGpu),
78            _ => InvalidDeviceSnafu { device: s }.fail(),
79        }
80    }
81}
82
83#[derive(Default)]
84pub struct DeviceRegistry {
85    devices: RwLock<HashMap<DeviceSpec, Arc<dyn Allocator>>>,
86}
87
88impl DeviceRegistry {
89    /// Get or create a device allocator.
90    pub fn get(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
91        // Fast path: read lock
92        {
93            let devices = self.devices.read();
94            if let Some(allocator) = devices.get(spec) {
95                return Ok(Arc::clone(allocator));
96            }
97        }
98
99        // Slow path: write lock to create
100        let mut devices = self.devices.write();
101
102        // Double-check after acquiring write lock
103        if let Some(allocator) = devices.get(spec) {
104            return Ok(Arc::clone(allocator));
105        }
106
107        // Create new allocator
108        let allocator = self.create_allocator(spec)?;
109        devices.insert(spec.clone(), Arc::clone(&allocator));
110        Ok(allocator)
111    }
112
113    /// Get a device by parsing a device string.
114    pub fn get_device(&self, device: &str) -> Result<Arc<dyn Allocator>> {
115        let spec = <DeviceSpec as DeviceSpecExt>::parse(device)?;
116        self.get(&spec)
117    }
118
119    fn create_allocator(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
120        // DISK: no LRU caching (Tinygrad: DiskAllocator extends Allocator, not LRUAllocator)
121        if let DeviceSpec::Disk { path } = spec {
122            return Ok(Arc::new(crate::allocator::DiskAllocator::new(path.clone())));
123        }
124
125        let base: Box<dyn Allocator> = match spec {
126            DeviceSpec::Cpu => Box::new(CpuAllocator),
127            #[cfg(feature = "cuda")]
128            DeviceSpec::Cuda { device_id } => Box::new(crate::allocator::CudaAllocator::new(*device_id)?),
129            #[cfg(not(feature = "cuda"))]
130            DeviceSpec::Cuda { .. } => unimplemented!("Cuda allocator - to be implemented"),
131            DeviceSpec::Metal { .. } => unimplemented!("Metal allocator - to be implemented"),
132            DeviceSpec::WebGpu => unimplemented!("WebGPU allocator - to be implemented"),
133            DeviceSpec::Disk { .. } => unreachable!(),
134        };
135
136        // Wrap with LRU cache (already thread-safe via Mutex)
137        let lru = LruAllocator::new(base);
138
139        Ok(Arc::new(lru))
140    }
141}
142
143/// Global device registry instance.
144static REGISTRY: Lazy<DeviceRegistry> = Lazy::new(DeviceRegistry::default);
145
146/// Get the global device registry.
147pub fn registry() -> &'static DeviceRegistry {
148    &REGISTRY
149}
150
151/// Convenience function to get a device allocator by string.
152pub fn get_device(device: &str) -> Result<Arc<dyn Allocator>> {
153    registry().get_device(device)
154}
155
156/// Convenience function to get CPU allocator.
157pub fn cpu() -> Result<Arc<dyn Allocator>> {
158    registry().get(&DeviceSpec::Cpu)
159}
160
/// Returns the allocator for CUDA device `device_id` from the global registry.
///
/// Only compiled when the `cuda` feature is enabled.
///
/// # Errors
///
/// Returns whatever error the registry lookup reports.
#[cfg(feature = "cuda")]
pub fn cuda(device_id: usize) -> Result<Arc<dyn Allocator>> {
    let spec = DeviceSpec::Cuda { device_id };
    registry().get(&spec)
}