Skip to main content

morok_device/
allocator.rs

1use std::alloc::Layout;
2use std::cell::UnsafeCell;
3use std::collections::HashMap;
4use std::ops::{Deref, DerefMut};
5use std::ptr::NonNull;
6use std::sync::Mutex;
7
8#[cfg(feature = "cuda")]
9use cudarc::driver::{CudaContext, CudaSlice, UnifiedSlice};
10#[cfg(feature = "cuda")]
11use snafu::ResultExt;
12#[cfg(feature = "cuda")]
13use std::sync::Arc;
14
15use crate::error::*;
16
/// Heap buffer whose storage is aligned to [`BUFFER_ALIGN`] (64) bytes.
///
/// Generated C kernels declare vector types with explicit alignment attributes (for
/// example `aligned(32)` on `double4`). Clang lowers accesses to those types into aligned
/// load/store instructions such as `vmovaps`, which fault on misaligned addresses.
/// Aligning every allocation to 64 bytes satisfies SSE, AVX and AVX-512 widths.
///
/// A zero-length buffer owns no allocation and holds a dangling pointer instead.
pub struct AlignedBuffer {
    ptr: NonNull<u8>,
    len: usize,
}

/// Alignment, in bytes, of every non-empty [`AlignedBuffer`] allocation.
const BUFFER_ALIGN: usize = 64;

impl AlignedBuffer {
    /// Layout shared by allocation and deallocation of a `len`-byte buffer.
    fn layout_for(len: usize) -> Layout {
        Layout::from_size_align(len, BUFFER_ALIGN).expect("invalid buffer layout")
    }

    /// Allocate `size` zero-filled bytes aligned to [`BUFFER_ALIGN`].
    ///
    /// # Panics
    /// Panics if `size`, rounded up to the alignment, exceeds `isize::MAX`. Allocation
    /// failure is reported through [`std::alloc::handle_alloc_error`].
    pub fn new_zeroed(size: usize) -> Self {
        match size {
            0 => Self { ptr: NonNull::dangling(), len: 0 },
            len => {
                let layout = Self::layout_for(len);
                // SAFETY: `layout` has a non-zero size, as `alloc_zeroed` requires.
                let raw = unsafe { std::alloc::alloc_zeroed(layout) };
                match NonNull::new(raw) {
                    Some(ptr) => Self { ptr, len },
                    None => std::alloc::handle_alloc_error(layout),
                }
            }
        }
    }

    /// Number of bytes in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the buffer holds zero bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

impl Deref for AlignedBuffer {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        if self.is_empty() {
            return &[];
        }
        // SAFETY: a non-empty buffer owns `len` initialized (zeroed at allocation) bytes at `ptr`.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }
}

impl DerefMut for AlignedBuffer {
    fn deref_mut(&mut self) -> &mut [u8] {
        if self.is_empty() {
            return &mut [];
        }
        // SAFETY: as in `deref`; `&mut self` guarantees exclusive access to the bytes.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }
}

impl Drop for AlignedBuffer {
    fn drop(&mut self) {
        if self.is_empty() {
            // Empty buffers never allocated; the dangling pointer must not be freed.
            return;
        }
        // SAFETY: non-empty buffers were allocated in `new_zeroed` with exactly this layout.
        unsafe { std::alloc::dealloc(self.ptr.as_ptr(), Self::layout_for(self.len)) };
    }
}
71
72/// Opaque handle to device memory.
73///
74/// # Safety
75///
76/// `RawBuffer` uses `UnsafeCell` for interior mutability without locking overhead.
77/// Thread safety is guaranteed at a higher level by the scheduler:
78///
79/// 1. **Allocation**: `OnceLock` in `BufferData` ensures single initialization
80/// 2. **Buffer Access**: The scheduler guarantees exclusive access to each buffer
81///    during kernel execution - no two kernels access the same buffer concurrently
82/// 3. **Kernel Execution**: Raw pointers passed to JIT code; Rust doesn't access
83///    buffer data during execution
84///
85/// This design follows Tinygrad's approach where buffer synchronization is the
86/// scheduler's responsibility, not the buffer's.
87pub enum RawBuffer {
88    Cpu {
89        data: UnsafeCell<AlignedBuffer>,
90        cpu_accessible: bool,
91    },
92    /// Memory-mapped file region (read-only). Used by DISK device.
93    Mmap {
94        data: memmap2::Mmap,
95        size: usize,
96    },
97    #[cfg(feature = "cuda")]
98    CudaDevice {
99        data: UnsafeCell<CudaSlice<u8>>,
100        device: Arc<CudaContext>,
101    },
102    #[cfg(feature = "cuda")]
103    CudaUnified {
104        data: UnsafeCell<UnifiedSlice<u8>>,
105        device: Arc<CudaContext>,
106    },
107}
108
109// SAFETY: RawBuffer access is synchronized by the scheduler at a higher level.
110// See RawBuffer documentation for detailed safety invariants.
111unsafe impl Send for RawBuffer {}
112unsafe impl Sync for RawBuffer {}
113
114// UnsafeCell doesn't implement Debug, so we implement it manually
115impl std::fmt::Debug for RawBuffer {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        match self {
118            RawBuffer::Cpu { cpu_accessible, .. } => {
119                f.debug_struct("Cpu").field("cpu_accessible", cpu_accessible).finish_non_exhaustive()
120            }
121            RawBuffer::Mmap { size, .. } => f.debug_struct("Mmap").field("size", size).finish_non_exhaustive(),
122            #[cfg(feature = "cuda")]
123            RawBuffer::CudaDevice { device, .. } => {
124                f.debug_struct("CudaDevice").field("device", device).finish_non_exhaustive()
125            }
126            #[cfg(feature = "cuda")]
127            RawBuffer::CudaUnified { device, .. } => {
128                f.debug_struct("CudaUnified").field("device", device).finish_non_exhaustive()
129            }
130        }
131    }
132}
133
134impl RawBuffer {
135    /// Get the size of the buffer in bytes.
136    pub fn size(&self) -> usize {
137        // SAFETY: Reading .len() doesn't alias with content access and is immutable after allocation
138        match self {
139            RawBuffer::Cpu { data, .. } => unsafe { (&*data.get()).len() },
140            RawBuffer::Mmap { size, .. } => *size,
141            #[cfg(feature = "cuda")]
142            RawBuffer::CudaDevice { data, .. } => unsafe { (&*data.get()).len() },
143            #[cfg(feature = "cuda")]
144            RawBuffer::CudaUnified { data, .. } => unsafe { (&*data.get()).len() },
145        }
146    }
147
148    /// Get whether this buffer is CPU-accessible.
149    pub fn cpu_accessible(&self) -> bool {
150        match self {
151            RawBuffer::Cpu { cpu_accessible, .. } => *cpu_accessible,
152            RawBuffer::Mmap { .. } => true,
153            #[cfg(feature = "cuda")]
154            RawBuffer::CudaDevice { .. } => false,
155            #[cfg(feature = "cuda")]
156            RawBuffer::CudaUnified { .. } => true,
157        }
158    }
159}
160
/// Options that control how a buffer is allocated.
///
/// The default requests CPU-accessible memory without zero-initialization.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
pub struct BufferOptions {
    /// Request zero-initialized contents.
    ///
    /// The CPU allocator zero-fills every buffer regardless of this flag.
    pub zero_init: bool,
    /// Request memory that the host can access directly.
    ///
    /// CPU allocator: host memory is always accessible; the flag is recorded on the
    /// buffer as given.
    /// CUDA allocator: `false` selects device-only memory (cuMemAlloc), `true` selects
    /// unified memory (cuMemAllocManaged).
    pub cpu_accessible: bool,
}

impl Default for BufferOptions {
    /// CPU-accessible, not zero-initialized.
    fn default() -> Self {
        let cpu_accessible = true;
        let zero_init = false;
        Self { zero_init, cpu_accessible }
    }
}
179
180pub trait Allocator: Send + Sync + std::fmt::Debug {
181    fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer>;
182    fn free(&self, _buffer: RawBuffer, _options: &BufferOptions) {}
183    fn synchronize(&self) -> Result<()> {
184        Ok(())
185    }
186    fn name(&self) -> &str;
187
188    /// Get the device specification for this allocator.
189    fn device_spec(&self) -> morok_dtype::DeviceSpec;
190}
191
192/// CPU allocator using system memory.
193#[derive(Debug, Clone)]
194pub struct CpuAllocator;
195
196impl Allocator for CpuAllocator {
197    fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer> {
198        let data = AlignedBuffer::new_zeroed(size);
199        Ok(RawBuffer::Cpu { data: UnsafeCell::new(data), cpu_accessible: options.cpu_accessible })
200    }
201
202    fn name(&self) -> &str {
203        "CPU"
204    }
205
206    fn device_spec(&self) -> morok_dtype::DeviceSpec {
207        morok_dtype::DeviceSpec::Cpu
208    }
209}
210
211/// DISK allocator using memory-mapped files (Tinygrad: ops_disk.py).
212/// Read-only — cannot execute kernels. Data is transferred via COPY.
213#[derive(Debug, Clone)]
214pub struct DiskAllocator {
215    path: std::path::PathBuf,
216}
217
218impl DiskAllocator {
219    pub fn new(path: std::path::PathBuf) -> Self {
220        Self { path }
221    }
222}
223
224impl Allocator for DiskAllocator {
225    fn alloc(&self, size: usize, _options: &BufferOptions) -> Result<RawBuffer> {
226        let file = std::fs::File::open(&self.path).map_err(|e| crate::Error::CopyFailed {
227            reason: format!("DISK: failed to open {}: {e}", self.path.display()),
228        })?;
229        let file_size = file
230            .metadata()
231            .map_err(|e| crate::Error::CopyFailed {
232                reason: format!("DISK: failed to read metadata for {}: {e}", self.path.display()),
233            })?
234            .len() as usize;
235        if size > file_size {
236            return Err(crate::Error::CopyFailed {
237                reason: format!("DISK: requested {size} bytes but {} is only {file_size} bytes", self.path.display()),
238            });
239        }
240        let mmap = unsafe { memmap2::Mmap::map(&file) }.map_err(|e| crate::Error::CopyFailed {
241            reason: format!("DISK: mmap failed for {}: {e}", self.path.display()),
242        })?;
243        Ok(RawBuffer::Mmap { data: mmap, size })
244    }
245
246    fn name(&self) -> &str {
247        "DISK"
248    }
249
250    fn device_spec(&self) -> morok_dtype::DeviceSpec {
251        morok_dtype::DeviceSpec::Disk { path: self.path.clone() }
252    }
253}
254
/// CUDA allocator that hands out GPU memory on one device.
///
/// `BufferOptions::cpu_accessible` selects the memory kind:
/// - `true` yields unified (managed) memory, wrapped in `RawBuffer::CudaUnified`.
/// - `false` yields device-only memory, wrapped in `RawBuffer::CudaDevice`.
///
/// Zeroing, when requested, is issued on the context's default stream.
#[cfg(feature = "cuda")]
#[derive(Debug, Clone)]
pub struct CudaAllocator {
    device: Arc<CudaContext>,
    device_id: usize,
}

#[cfg(feature = "cuda")]
impl CudaAllocator {
    /// Create an allocator for the CUDA device with ordinal `device_id`.
    ///
    /// # Errors
    /// Fails with a CUDA error if the device context cannot be created.
    pub fn new(device_id: usize) -> Result<Self> {
        let device = CudaContext::new(device_id).context(CudaSnafu)?;
        Ok(Self { device, device_id })
    }

    /// Ordinal of the CUDA device this allocator targets.
    pub fn device_id(&self) -> usize {
        self.device_id
    }

    /// Allocate `size` bytes of unified (CPU-accessible) memory, zeroing it on the
    /// default stream when `zero_init` is set.
    fn alloc_unified_buffer(&self, size: usize, zero_init: bool) -> Result<RawBuffer> {
        // SAFETY: cudarc marks this allocation `unsafe` (presumably because the memory starts
        // uninitialized). It is zeroed below when requested; otherwise the caller opted out
        // of initialization.
        let mut slice = unsafe { self.device.alloc_unified::<u8>(size, true) }.context(CudaSnafu)?;
        if zero_init {
            self.device.default_stream().memset_zeros(&mut slice).context(CudaSnafu)?;
        }
        Ok(RawBuffer::CudaUnified { data: UnsafeCell::new(slice), device: Arc::clone(&self.device) })
    }

    /// Allocate `size` bytes of device-only memory (faster GPU access) on the default stream.
    fn alloc_device_buffer(&self, size: usize, zero_init: bool) -> Result<RawBuffer> {
        let stream = self.device.default_stream();
        let allocated = if zero_init {
            stream.alloc_zeros::<u8>(size)
        } else {
            // SAFETY: zero-initialization was not requested, so uninitialized contents
            // are acceptable to the caller.
            unsafe { stream.alloc::<u8>(size) }
        };
        let slice = allocated.context(CudaSnafu)?;
        Ok(RawBuffer::CudaDevice { data: UnsafeCell::new(slice), device: Arc::clone(&self.device) })
    }
}

#[cfg(feature = "cuda")]
impl Allocator for CudaAllocator {
    fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer> {
        if options.cpu_accessible {
            self.alloc_unified_buffer(size, options.zero_init)
        } else {
            self.alloc_device_buffer(size, options.zero_init)
        }
    }

    /// Block until all work queued on the default stream has finished.
    fn synchronize(&self) -> Result<()> {
        let stream = self.device.default_stream();
        stream.synchronize().context(CudaSnafu)
    }

    fn name(&self) -> &str {
        "CUDA"
    }

    fn device_spec(&self) -> morok_dtype::DeviceSpec {
        morok_dtype::DeviceSpec::Cuda { device_id: self.device_id }
    }
}
310
/// Cache key for buffer reuse in the LRU allocator.
///
/// Two buffers are interchangeable only if both fields match:
/// - `size`: the exact byte size (no rounding to size classes).
/// - `cpu_accessible`: on CUDA this selects a different kind of memory. `false` means
///   device-only memory (cuMemAlloc, faster GPU access); `true` means unified memory
///   (cuMemAllocManaged, CPU-accessible). The kind is fixed at allocation, so buffers
///   with different values must never be swapped for one another.
///
/// `zero_init` is deliberately excluded. Following Tinygrad, zeroing is a software step
/// applied after a buffer is taken out of the cache.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
struct CacheKey {
    /// Buffer size in bytes.
    size: usize,
    /// Whether the buffer was allocated as CPU-accessible memory.
    cpu_accessible: bool,
}
327
328/// LRU allocator that caches freed buffers for reuse.
329#[derive(Debug)]
330pub(crate) struct LruAllocator {
331    inner: Box<dyn Allocator>,
332    cache: Mutex<HashMap<CacheKey, Vec<RawBuffer>>>,
333    max_buffers_per_size: usize,
334    name: String,
335}
336
337impl LruAllocator {
338    pub fn new(inner: Box<dyn Allocator>) -> Self {
339        Self::with_capacity(inner, 32)
340    }
341
342    pub fn with_capacity(inner: Box<dyn Allocator>, max_buffers_per_size: usize) -> Self {
343        let name = inner.name().to_string();
344        Self { inner, cache: Mutex::new(HashMap::new()), max_buffers_per_size, name }
345    }
346
347    /// Get the number of cached buffers for a specific size and cpu_accessible flag.
348    /// Only available in tests for cache introspection.
349    #[cfg(test)]
350    pub(crate) fn cache_count(&self, size: usize, cpu_accessible: bool) -> usize {
351        let key = CacheKey { size, cpu_accessible };
352        let cache = self.cache.lock().unwrap();
353        cache.get(&key).map(|v| v.len()).unwrap_or(0)
354    }
355
356    /// Get the total number of cached buffers across all keys.
357    /// Only available in tests for cache introspection.
358    #[cfg(test)]
359    #[allow(dead_code)]
360    pub(crate) fn total_cached(&self) -> usize {
361        let cache = self.cache.lock().unwrap();
362        cache.values().map(|v| v.len()).sum()
363    }
364}
365
366impl Allocator for LruAllocator {
367    fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer> {
368        let key = CacheKey { size, cpu_accessible: options.cpu_accessible };
369
370        // Try cache first
371        let buffer = {
372            let mut cache = self.cache.lock().unwrap();
373            if let Some(buffers) = cache.get_mut(&key)
374                && let Some(buffer) = buffers.pop()
375            {
376                if buffers.is_empty() {
377                    cache.remove(&key);
378                }
379                Some(buffer)
380            } else {
381                None
382            }
383        }; // Drop lock before expensive allocation
384
385        // If found in cache, optionally zero and return
386        if let Some(buffer) = buffer {
387            if options.zero_init {
388                // Zero the cached buffer if requested
389                // SAFETY: Buffer just retrieved from cache, not yet returned - no other references exist
390                match &buffer {
391                    RawBuffer::Cpu { data, .. } => {
392                        unsafe { (*data.get()).fill(0) };
393                    }
394                    RawBuffer::Mmap { .. } => panic!("DISK device is read-only: cannot zero-init mmap buffer"),
395                    #[cfg(feature = "cuda")]
396                    RawBuffer::CudaDevice { data, device } => {
397                        let cuda_data = unsafe { &mut *data.get() };
398                        device.default_stream().memset_zeros(cuda_data).context(CudaSnafu)?;
399                    }
400                    #[cfg(feature = "cuda")]
401                    RawBuffer::CudaUnified { data, device } => {
402                        let unified_data = unsafe { &mut *data.get() };
403                        device.default_stream().memset_zeros(unified_data).context(CudaSnafu)?;
404                    }
405                }
406            }
407            return Ok(buffer);
408        }
409
410        // Cache miss - allocate from inner
411        match self.inner.alloc(size, options) {
412            Ok(buffer) => Ok(buffer),
413            Err(e) => {
414                // On allocation failure, clear cache and retry
415                self.cache.lock().unwrap().clear();
416                self.inner.alloc(size, options).map_err(|_| e)
417            }
418        }
419    }
420
421    fn free(&self, buffer: RawBuffer, options: &BufferOptions) {
422        let key = CacheKey { size: buffer.size(), cpu_accessible: options.cpu_accessible };
423
424        let mut cache = self.cache.lock().unwrap();
425        let buffers = cache.entry(key).or_default();
426        if buffers.len() < self.max_buffers_per_size {
427            buffers.push(buffer);
428        }
429    }
430
431    fn synchronize(&self) -> Result<()> {
432        self.inner.synchronize()
433    }
434
435    fn name(&self) -> &str {
436        &self.name
437    }
438
439    fn device_spec(&self) -> morok_dtype::DeviceSpec {
440        self.inner.device_spec()
441    }
442}