1use std::alloc::Layout;
2use std::cell::UnsafeCell;
3use std::collections::HashMap;
4use std::ops::{Deref, DerefMut};
5use std::ptr::NonNull;
6use std::sync::Mutex;
7
8#[cfg(feature = "cuda")]
9use cudarc::driver::{CudaContext, CudaSlice, UnifiedSlice};
10#[cfg(feature = "cuda")]
11use snafu::ResultExt;
12#[cfg(feature = "cuda")]
13use std::sync::Arc;
14
15use crate::error::*;
16
/// Owned heap byte buffer described by a start pointer and a byte length.
///
/// A buffer with `len == 0` owns no allocation; its pointer is a dangling
/// placeholder that must never be dereferenced.
pub struct AlignedBuffer {
    /// Start of the allocation (dangling when `len == 0`).
    ptr: NonNull<u8>,
    /// Number of bytes in the allocation.
    len: usize,
}
27
/// Alignment, in bytes, requested for every non-empty `AlignedBuffer` allocation.
const BUFFER_ALIGN: usize = 64;
29
30impl AlignedBuffer {
31 pub fn new_zeroed(size: usize) -> Self {
32 if size == 0 {
33 return Self { ptr: NonNull::dangling(), len: 0 };
34 }
35 let layout = Layout::from_size_align(size, BUFFER_ALIGN).expect("invalid buffer layout");
36 let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
37 let ptr = NonNull::new(ptr).unwrap_or_else(|| std::alloc::handle_alloc_error(layout));
38 Self { ptr, len: size }
39 }
40
41 pub fn len(&self) -> usize {
42 self.len
43 }
44
45 pub fn is_empty(&self) -> bool {
46 self.len == 0
47 }
48}
49
50impl Deref for AlignedBuffer {
51 type Target = [u8];
52 fn deref(&self) -> &[u8] {
53 if self.len == 0 { &[] } else { unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) } }
54 }
55}
56
57impl DerefMut for AlignedBuffer {
58 fn deref_mut(&mut self) -> &mut [u8] {
59 if self.len == 0 { &mut [] } else { unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) } }
60 }
61}
62
63impl Drop for AlignedBuffer {
64 fn drop(&mut self) {
65 if self.len > 0 {
66 let layout = Layout::from_size_align(self.len, BUFFER_ALIGN).unwrap();
67 unsafe { std::alloc::dealloc(self.ptr.as_ptr(), layout) };
68 }
69 }
70}
71
72pub enum RawBuffer {
88 Cpu {
89 data: UnsafeCell<AlignedBuffer>,
90 cpu_accessible: bool,
91 },
92 Mmap {
94 data: memmap2::Mmap,
95 size: usize,
96 },
97 #[cfg(feature = "cuda")]
98 CudaDevice {
99 data: UnsafeCell<CudaSlice<u8>>,
100 device: Arc<CudaContext>,
101 },
102 #[cfg(feature = "cuda")]
103 CudaUnified {
104 data: UnsafeCell<UnifiedSlice<u8>>,
105 device: Arc<CudaContext>,
106 },
107}
108
109unsafe impl Send for RawBuffer {}
112unsafe impl Sync for RawBuffer {}
113
114impl std::fmt::Debug for RawBuffer {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 match self {
118 RawBuffer::Cpu { cpu_accessible, .. } => {
119 f.debug_struct("Cpu").field("cpu_accessible", cpu_accessible).finish_non_exhaustive()
120 }
121 RawBuffer::Mmap { size, .. } => f.debug_struct("Mmap").field("size", size).finish_non_exhaustive(),
122 #[cfg(feature = "cuda")]
123 RawBuffer::CudaDevice { device, .. } => {
124 f.debug_struct("CudaDevice").field("device", device).finish_non_exhaustive()
125 }
126 #[cfg(feature = "cuda")]
127 RawBuffer::CudaUnified { device, .. } => {
128 f.debug_struct("CudaUnified").field("device", device).finish_non_exhaustive()
129 }
130 }
131 }
132}
133
134impl RawBuffer {
135 pub fn size(&self) -> usize {
137 match self {
139 RawBuffer::Cpu { data, .. } => unsafe { (&*data.get()).len() },
140 RawBuffer::Mmap { size, .. } => *size,
141 #[cfg(feature = "cuda")]
142 RawBuffer::CudaDevice { data, .. } => unsafe { (&*data.get()).len() },
143 #[cfg(feature = "cuda")]
144 RawBuffer::CudaUnified { data, .. } => unsafe { (&*data.get()).len() },
145 }
146 }
147
148 pub fn cpu_accessible(&self) -> bool {
150 match self {
151 RawBuffer::Cpu { cpu_accessible, .. } => *cpu_accessible,
152 RawBuffer::Mmap { .. } => true,
153 #[cfg(feature = "cuda")]
154 RawBuffer::CudaDevice { .. } => false,
155 #[cfg(feature = "cuda")]
156 RawBuffer::CudaUnified { .. } => true,
157 }
158 }
159}
160
/// Per-allocation options handed to an `Allocator`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
pub struct BufferOptions {
    /// Request that the returned memory be zero-filled.
    pub zero_init: bool,
    /// Request memory that the CPU can access.
    pub cpu_accessible: bool,
}
173
174impl Default for BufferOptions {
175 fn default() -> Self {
176 Self { zero_init: false, cpu_accessible: true }
177 }
178}
179
180pub trait Allocator: Send + Sync + std::fmt::Debug {
181 fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer>;
182 fn free(&self, _buffer: RawBuffer, _options: &BufferOptions) {}
183 fn synchronize(&self) -> Result<()> {
184 Ok(())
185 }
186 fn name(&self) -> &str;
187
188 fn device_spec(&self) -> morok_dtype::DeviceSpec;
190}
191
/// Stateless allocator for host memory.
#[derive(Debug, Clone)]
pub struct CpuAllocator;
195
196impl Allocator for CpuAllocator {
197 fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer> {
198 let data = AlignedBuffer::new_zeroed(size);
199 Ok(RawBuffer::Cpu { data: UnsafeCell::new(data), cpu_accessible: options.cpu_accessible })
200 }
201
202 fn name(&self) -> &str {
203 "CPU"
204 }
205
206 fn device_spec(&self) -> morok_dtype::DeviceSpec {
207 morok_dtype::DeviceSpec::Cpu
208 }
209}
210
/// Allocator that serves buffers by memory-mapping a single file.
#[derive(Debug, Clone)]
pub struct DiskAllocator {
    /// File that every allocation maps.
    path: std::path::PathBuf,
}
217
218impl DiskAllocator {
219 pub fn new(path: std::path::PathBuf) -> Self {
220 Self { path }
221 }
222}
223
224impl Allocator for DiskAllocator {
225 fn alloc(&self, size: usize, _options: &BufferOptions) -> Result<RawBuffer> {
226 let file = std::fs::File::open(&self.path).map_err(|e| crate::Error::CopyFailed {
227 reason: format!("DISK: failed to open {}: {e}", self.path.display()),
228 })?;
229 let file_size = file
230 .metadata()
231 .map_err(|e| crate::Error::CopyFailed {
232 reason: format!("DISK: failed to read metadata for {}: {e}", self.path.display()),
233 })?
234 .len() as usize;
235 if size > file_size {
236 return Err(crate::Error::CopyFailed {
237 reason: format!("DISK: requested {size} bytes but {} is only {file_size} bytes", self.path.display()),
238 });
239 }
240 let mmap = unsafe { memmap2::Mmap::map(&file) }.map_err(|e| crate::Error::CopyFailed {
241 reason: format!("DISK: mmap failed for {}: {e}", self.path.display()),
242 })?;
243 Ok(RawBuffer::Mmap { data: mmap, size })
244 }
245
246 fn name(&self) -> &str {
247 "DISK"
248 }
249
250 fn device_spec(&self) -> morok_dtype::DeviceSpec {
251 morok_dtype::DeviceSpec::Disk { path: self.path.clone() }
252 }
253}
254
/// Allocator bound to one CUDA context.
#[cfg(feature = "cuda")]
#[derive(Debug, Clone)]
pub struct CudaAllocator {
    /// Context every allocation is made with; shared with the buffers it returns.
    device: Arc<CudaContext>,
    /// Device index the context was created for.
    device_id: usize,
}
262
#[cfg(feature = "cuda")]
impl CudaAllocator {
    /// Creates an allocator for the CUDA device with index `device_id`.
    ///
    /// # Errors
    ///
    /// Returns the CUDA error, wrapped via `CudaSnafu`, when the context
    /// cannot be created.
    pub fn new(device_id: usize) -> Result<Self> {
        CudaContext::new(device_id)
            .context(CudaSnafu)
            .map(|device| Self { device, device_id })
    }

    /// Device index this allocator was created with.
    pub fn device_id(&self) -> usize {
        self.device_id
    }
}
274
#[cfg(feature = "cuda")]
impl Allocator for CudaAllocator {
    /// Allocates `size` bytes of CUDA memory.
    ///
    /// With `options.cpu_accessible` set, unified memory is used; otherwise
    /// the memory is allocated on the context's default stream. In both cases
    /// `options.zero_init` requests zero-filled memory.
    ///
    /// # Errors
    ///
    /// Returns the CUDA error, wrapped via `CudaSnafu`, when allocation or
    /// zero-filling fails.
    fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer> {
        if !options.cpu_accessible {
            let stream = self.device.default_stream();
            let allocation = if options.zero_init {
                stream.alloc_zeros::<u8>(size)
            } else {
                // SAFETY: NOTE(review): presumably returns uninitialised device
                // memory, which is acceptable because the caller did not ask
                // for zero-fill — confirm against the cudarc docs.
                unsafe { stream.alloc::<u8>(size) }
            };
            let data = allocation.context(CudaSnafu)?;
            return Ok(RawBuffer::CudaDevice { data: UnsafeCell::new(data), device: Arc::clone(&self.device) });
        }

        // SAFETY: NOTE(review): the contract of `alloc_unified` and the meaning
        // of its `true` argument are not visible here — confirm against cudarc.
        let mut data = unsafe { self.device.alloc_unified::<u8>(size, true) }.context(CudaSnafu)?;
        if options.zero_init {
            self.device.default_stream().memset_zeros(&mut data).context(CudaSnafu)?;
        }
        Ok(RawBuffer::CudaUnified { data: UnsafeCell::new(data), device: Arc::clone(&self.device) })
    }

    /// Synchronizes the context's default stream.
    fn synchronize(&self) -> Result<()> {
        let stream = self.device.default_stream();
        stream.synchronize().context(CudaSnafu)
    }

    /// Always `"CUDA"`.
    fn name(&self) -> &str {
        "CUDA"
    }

    /// `DeviceSpec::Cuda` carrying this allocator's device index.
    fn device_spec(&self) -> morok_dtype::DeviceSpec {
        morok_dtype::DeviceSpec::Cuda { device_id: self.device_id }
    }
}
310
/// Key under which freed buffers are grouped for reuse: two buffers are
/// interchangeable when both fields match.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct CacheKey {
    /// Buffer size in bytes.
    size: usize,
    /// Value of `BufferOptions::cpu_accessible` the buffer is associated with.
    cpu_accessible: bool,
}
327
328#[derive(Debug)]
330pub(crate) struct LruAllocator {
331 inner: Box<dyn Allocator>,
332 cache: Mutex<HashMap<CacheKey, Vec<RawBuffer>>>,
333 max_buffers_per_size: usize,
334 name: String,
335}
336
337impl LruAllocator {
338 pub fn new(inner: Box<dyn Allocator>) -> Self {
339 Self::with_capacity(inner, 32)
340 }
341
342 pub fn with_capacity(inner: Box<dyn Allocator>, max_buffers_per_size: usize) -> Self {
343 let name = inner.name().to_string();
344 Self { inner, cache: Mutex::new(HashMap::new()), max_buffers_per_size, name }
345 }
346
347 #[cfg(test)]
350 pub(crate) fn cache_count(&self, size: usize, cpu_accessible: bool) -> usize {
351 let key = CacheKey { size, cpu_accessible };
352 let cache = self.cache.lock().unwrap();
353 cache.get(&key).map(|v| v.len()).unwrap_or(0)
354 }
355
356 #[cfg(test)]
359 #[allow(dead_code)]
360 pub(crate) fn total_cached(&self) -> usize {
361 let cache = self.cache.lock().unwrap();
362 cache.values().map(|v| v.len()).sum()
363 }
364}
365
366impl Allocator for LruAllocator {
367 fn alloc(&self, size: usize, options: &BufferOptions) -> Result<RawBuffer> {
368 let key = CacheKey { size, cpu_accessible: options.cpu_accessible };
369
370 let buffer = {
372 let mut cache = self.cache.lock().unwrap();
373 if let Some(buffers) = cache.get_mut(&key)
374 && let Some(buffer) = buffers.pop()
375 {
376 if buffers.is_empty() {
377 cache.remove(&key);
378 }
379 Some(buffer)
380 } else {
381 None
382 }
383 }; if let Some(buffer) = buffer {
387 if options.zero_init {
388 match &buffer {
391 RawBuffer::Cpu { data, .. } => {
392 unsafe { (*data.get()).fill(0) };
393 }
394 RawBuffer::Mmap { .. } => panic!("DISK device is read-only: cannot zero-init mmap buffer"),
395 #[cfg(feature = "cuda")]
396 RawBuffer::CudaDevice { data, device } => {
397 let cuda_data = unsafe { &mut *data.get() };
398 device.default_stream().memset_zeros(cuda_data).context(CudaSnafu)?;
399 }
400 #[cfg(feature = "cuda")]
401 RawBuffer::CudaUnified { data, device } => {
402 let unified_data = unsafe { &mut *data.get() };
403 device.default_stream().memset_zeros(unified_data).context(CudaSnafu)?;
404 }
405 }
406 }
407 return Ok(buffer);
408 }
409
410 match self.inner.alloc(size, options) {
412 Ok(buffer) => Ok(buffer),
413 Err(e) => {
414 self.cache.lock().unwrap().clear();
416 self.inner.alloc(size, options).map_err(|_| e)
417 }
418 }
419 }
420
421 fn free(&self, buffer: RawBuffer, options: &BufferOptions) {
422 let key = CacheKey { size: buffer.size(), cpu_accessible: options.cpu_accessible };
423
424 let mut cache = self.cache.lock().unwrap();
425 let buffers = cache.entry(key).or_default();
426 if buffers.len() < self.max_buffers_per_size {
427 buffers.push(buffer);
428 }
429 }
430
431 fn synchronize(&self) -> Result<()> {
432 self.inner.synchronize()
433 }
434
435 fn name(&self) -> &str {
436 &self.name
437 }
438
439 fn device_spec(&self) -> morok_dtype::DeviceSpec {
440 self.inner.device_spec()
441 }
442}