Skip to main content

morok_device/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use once_cell::sync::Lazy;
5use parking_lot::RwLock;
6
7pub use morok_dtype::DeviceSpec;
8
9use crate::allocator::{Allocator, CpuAllocator, LruAllocator};
10use crate::error::{InvalidDeviceSnafu, Result};
11
12/// Extension trait for DeviceSpec to add parsing functionality.
13///
14/// This is in the device crate because parsing depends on feature flags
15/// and error types that are device-specific.
16pub trait DeviceSpecExt {
17    /// Parse a device string into a DeviceSpec.
18    ///
19    /// Examples:
20    /// - "CPU" -> DeviceSpec::Cpu
21    /// - "CUDA:0" -> DeviceSpec::Cuda { device_id: 0 }
22    /// - "cuda" -> DeviceSpec::Cuda { device_id: 0 } (default to device 0)
23    fn parse(s: &str) -> Result<DeviceSpec>;
24}
25
26impl DeviceSpecExt for DeviceSpec {
27    fn parse(s: &str) -> Result<Self> {
28        let s = s.to_uppercase();
29        let parts: Vec<&str> = s.split(':').collect();
30
31        match parts[0] {
32            "CPU" => Ok(DeviceSpec::Cpu),
33            #[cfg(feature = "cuda")]
34            "CUDA" | "GPU" => {
35                let device_id = if parts.len() > 1 {
36                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
37                } else {
38                    0
39                };
40                Ok(DeviceSpec::Cuda { device_id })
41            }
42            #[cfg(not(feature = "cuda"))]
43            "CUDA" | "GPU" => {
44                let device_id = if parts.len() > 1 {
45                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
46                } else {
47                    0
48                };
49                Ok(DeviceSpec::Cuda { device_id })
50            }
51            #[cfg(feature = "metal")]
52            "METAL" => {
53                let device_id = if parts.len() > 1 {
54                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
55                } else {
56                    0
57                };
58                Ok(DeviceSpec::Metal { device_id })
59            }
60            #[cfg(not(feature = "metal"))]
61            "METAL" => {
62                let device_id = if parts.len() > 1 {
63                    parts[1].parse().map_err(|_| crate::error::Error::InvalidDevice { device: s.to_string() })?
64                } else {
65                    0
66                };
67                Ok(DeviceSpec::Metal { device_id })
68            }
69            #[cfg(feature = "webgpu")]
70            "WEBGPU" => Ok(DeviceSpec::WebGpu),
71            #[cfg(not(feature = "webgpu"))]
72            "WEBGPU" => Ok(DeviceSpec::WebGpu),
73            _ => InvalidDeviceSnafu { device: s }.fail(),
74        }
75    }
76}
77
78#[derive(Default)]
79pub struct DeviceRegistry {
80    devices: RwLock<HashMap<DeviceSpec, Arc<dyn Allocator>>>,
81}
82
83impl DeviceRegistry {
84    /// Get or create a device allocator.
85    pub fn get(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
86        // Fast path: read lock
87        {
88            let devices = self.devices.read();
89            if let Some(allocator) = devices.get(spec) {
90                return Ok(Arc::clone(allocator));
91            }
92        }
93
94        // Slow path: write lock to create
95        let mut devices = self.devices.write();
96
97        // Double-check after acquiring write lock
98        if let Some(allocator) = devices.get(spec) {
99            return Ok(Arc::clone(allocator));
100        }
101
102        // Create new allocator
103        let allocator = self.create_allocator(spec)?;
104        devices.insert(spec.clone(), Arc::clone(&allocator));
105        Ok(allocator)
106    }
107
108    /// Get a device by parsing a device string.
109    pub fn get_device(&self, device: &str) -> Result<Arc<dyn Allocator>> {
110        let spec = <DeviceSpec as DeviceSpecExt>::parse(device)?;
111        self.get(&spec)
112    }
113
114    fn create_allocator(&self, spec: &DeviceSpec) -> Result<Arc<dyn Allocator>> {
115        let base: Box<dyn Allocator> = match spec {
116            DeviceSpec::Cpu => Box::new(CpuAllocator),
117            #[cfg(feature = "cuda")]
118            DeviceSpec::Cuda { device_id } => Box::new(crate::allocator::CudaAllocator::new(*device_id)?),
119            #[cfg(not(feature = "cuda"))]
120            DeviceSpec::Cuda { .. } => unimplemented!("Cuda allocator - to be implemented"),
121            DeviceSpec::Metal { .. } => unimplemented!("Metal allocator - to be implemented"),
122            DeviceSpec::WebGpu => unimplemented!("WebGPU allocator - to be implemented"),
123        };
124
125        // Wrap with LRU cache (already thread-safe via Mutex)
126        let lru = LruAllocator::new(base);
127
128        Ok(Arc::new(lru))
129    }
130}
131
132/// Global device registry instance.
133static REGISTRY: Lazy<DeviceRegistry> = Lazy::new(DeviceRegistry::default);
134
135/// Get the global device registry.
136pub fn registry() -> &'static DeviceRegistry {
137    &REGISTRY
138}
139
140/// Convenience function to get a device allocator by string.
141pub fn get_device(device: &str) -> Result<Arc<dyn Allocator>> {
142    registry().get_device(device)
143}
144
145/// Convenience function to get CPU allocator.
146pub fn cpu() -> Result<Arc<dyn Allocator>> {
147    registry().get(&DeviceSpec::Cpu)
148}
149
/// Fetch the allocator for CUDA device `device_id` from the global registry.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub fn cuda(device_id: usize) -> Result<Arc<dyn Allocator>> {
    let spec = DeviceSpec::Cuda { device_id };
    registry().get(&spec)
}