1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::{Arc, OnceLock};
3
4use morok_dtype::DType;
5use smallvec::{SmallVec, smallvec};
6
7use morok_dtype::ext::HasDType;
8use snafu::ResultExt;
9
10use crate::allocator::{Allocator, BufferOptions, RawBuffer};
11use crate::error::{
12 InvalidViewSnafu, NdarrayShapeSnafu, NotCpuAccessibleSnafu, Result, SizeMismatchSnafu, TypeMismatchSnafu,
13};
14
/// Process-wide counter backing [`next_buffer_id`]; starts at zero.
static BUFFER_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Hands out the next unused buffer identifier.
///
/// Each call atomically bumps the counter and returns its previous value, so every
/// caller observes a distinct id (until the `u64` wraps). `Relaxed` is sufficient:
/// uniqueness comes from the atomic read-modify-write itself, not from ordering
/// against other memory operations.
fn next_buffer_id() -> u64 {
    AtomicU64::fetch_add(&BUFFER_ID_COUNTER, 1, Ordering::Relaxed)
}
24
/// Opaque identity of a buffer's backing storage.
///
/// Every [`Buffer`] clone or view that shares the same storage reports the same id.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BufferId(pub u64);
31
32#[cfg(feature = "cuda")]
33use crate::error::CudaSnafu;
34#[cfg(feature = "cuda")]
35use snafu::ResultExt;
36
/// Storage shared by a `Buffer` and every clone or view derived from it.
///
/// Allocation is deferred: `raw` stays empty until `ensure_allocated` fills it,
/// and whatever it holds is returned to `allocator` when this value is dropped.
#[derive(Debug)]
struct BufferData {
    // Identifier minted at construction; every handle sharing this storage reports it.
    id: BufferId,
    // Backing storage, installed at most once.
    raw: OnceLock<RawBuffer>,
    // Allocator that produces `raw` and later receives it back via `free`.
    allocator: Arc<dyn Allocator>,
    // Byte length requested from the allocator.
    total_size: usize,
    // Options passed to both `alloc` and `free`.
    options: BufferOptions,
}
50
51impl BufferData {
52 fn new(allocator: Arc<dyn Allocator>, size: usize, options: BufferOptions) -> Self {
53 Self { id: BufferId(next_buffer_id()), raw: OnceLock::new(), allocator, total_size: size, options }
54 }
55
56 fn ensure_allocated(&self) -> Result<()> {
59 if self.raw.get().is_some() {
60 return Ok(());
61 }
62
63 let raw = self.allocator.alloc(self.total_size, &self.options)?;
65
66 if let Err(raw) = self.raw.set(raw) {
68 self.allocator.free(raw, &self.options);
70 }
71
72 Ok(())
73 }
74
75 fn is_allocated(&self) -> bool {
77 self.raw.get().is_some()
78 }
79
80 fn raw(&self) -> &RawBuffer {
82 self.raw.get().expect("buffer not allocated")
83 }
84}
85
86impl Drop for BufferData {
87 fn drop(&mut self) {
88 if let Some(raw) = self.raw.take() {
90 self.allocator.free(raw, &self.options);
91 }
92 }
93}
94
/// A typed, shaped window onto reference-counted host, memory-mapped, or device storage.
///
/// Cloning a `Buffer` (or calling `Buffer::view`) shares the underlying storage
/// through an `Arc`; the storage is released when the last handle is dropped.
#[derive(Debug, Clone)]
pub struct Buffer {
    // Shared storage; views and clones bump the refcount rather than copying bytes.
    data: Arc<BufferData>,
    // Byte offset of this window from the start of the shared storage.
    offset: usize,
    // Length of this window in bytes.
    size: usize,
    // Element type used to interpret the bytes.
    dtype: DType,
    // Logical shape; views created by `view` are 1-D.
    shape: SmallVec<[usize; 4]>,
}
109
110impl Buffer {
111 pub fn new(allocator: Arc<dyn Allocator>, dtype: DType, shape: Vec<usize>, options: BufferOptions) -> Self {
113 let size = dtype.bytes() * shape.iter().product::<usize>();
114 Self {
115 data: Arc::new(BufferData::new(allocator, size, options)),
116 offset: 0,
117 size,
118 dtype,
119 shape: SmallVec::from_vec(shape),
120 }
121 }
122
123 pub fn allocate(
125 allocator: Arc<dyn Allocator>,
126 dtype: DType,
127 shape: Vec<usize>,
128 options: BufferOptions,
129 ) -> Result<Self> {
130 let buffer = Self::new(allocator, dtype, shape, options);
131 buffer.ensure_allocated()?;
132 Ok(buffer)
133 }
134
135 pub fn view(&self, offset: usize, size: usize) -> Result<Self> {
137 if offset + size > self.size {
139 return InvalidViewSnafu { offset, size, buffer_size: self.size }.fail();
140 }
141
142 Ok(Self {
143 data: Arc::clone(&self.data),
144 offset: self.offset + offset,
145 size,
146 dtype: self.dtype.clone(),
147 shape: smallvec![size / self.dtype.bytes()],
149 })
150 }
151
152 pub fn ensure_allocated(&self) -> Result<()> {
154 self.data.ensure_allocated()
155 }
156
157 pub fn is_allocated(&self) -> bool {
159 self.data.is_allocated()
160 }
161
162 pub fn size(&self) -> usize {
164 self.size
165 }
166
167 pub fn offset(&self) -> usize {
169 self.offset
170 }
171
172 pub fn dtype(&self) -> DType {
174 self.dtype.clone()
175 }
176
177 pub fn shape(&self) -> &[usize] {
179 &self.shape
180 }
181
182 pub fn as_host_bytes(&self) -> Result<&[u8]> {
191 self.ensure_allocated()?;
192 let raw = self.data.raw();
193 match raw {
194 RawBuffer::Cpu { data, .. } => {
195 let bytes = unsafe { &(&(*data.get()))[self.offset..self.offset + self.size] };
199 Ok(bytes)
200 }
201 RawBuffer::Mmap { data, .. } => Ok(&data[self.offset..self.offset + self.size]),
202 #[cfg(feature = "cuda")]
203 _ => NotCpuAccessibleSnafu.fail(),
204 }
205 }
206
207 #[allow(clippy::mut_from_ref)] pub fn as_host_bytes_mut(&self) -> Result<&mut [u8]> {
217 self.ensure_allocated()?;
218 let raw = self.data.raw();
219 match raw {
220 RawBuffer::Cpu { data, .. } => {
221 let bytes = unsafe { &mut (&mut *data.get())[self.offset..self.offset + self.size] };
225 Ok(bytes)
226 }
227 RawBuffer::Mmap { .. } => NotCpuAccessibleSnafu.fail(),
229 #[cfg(feature = "cuda")]
230 _ => NotCpuAccessibleSnafu.fail(),
231 }
232 }
233
234 pub fn as_array<T: HasDType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
244 self.ensure_allocated()?;
245 if self.dtype != T::DTYPE {
246 return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
247 }
248 let raw = self.data.raw();
249 match raw {
250 RawBuffer::Cpu { data, .. } => {
251 let bytes = unsafe { &(&(*data.get()))[self.offset..self.offset + self.size] };
252 let count = bytes.len() / T::DTYPE.bytes();
253 let typed = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const T, count) };
254 ndarray::ArrayViewD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
255 }
256 RawBuffer::Mmap { data, .. } => {
257 let bytes = &data[self.offset..self.offset + self.size];
258 let count = bytes.len() / T::DTYPE.bytes();
259 let typed = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const T, count) };
260 ndarray::ArrayViewD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
261 }
262 #[cfg(feature = "cuda")]
263 _ => NotCpuAccessibleSnafu.fail(),
264 }
265 }
266
267 #[allow(clippy::mut_from_ref)]
275 pub fn as_array_mut<T: HasDType>(&self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
276 self.ensure_allocated()?;
277 if self.dtype != T::DTYPE {
278 return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
279 }
280 let raw = self.data.raw();
281 match raw {
282 RawBuffer::Cpu { data, cpu_accessible } if *cpu_accessible => {
283 let bytes = unsafe { &mut (&mut *data.get())[self.offset..self.offset + self.size] };
284 let count = bytes.len() / T::DTYPE.bytes();
285 let typed = unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut T, count) };
286 ndarray::ArrayViewMutD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
287 }
288 _ => NotCpuAccessibleSnafu.fail(),
289 }
290 }
291
292 pub fn as_slice<T: HasDType>(&self) -> Result<&[T]> {
294 self.ensure_allocated()?;
295 if self.dtype != T::DTYPE {
296 return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
297 }
298 let raw = self.data.raw();
299 match raw {
300 RawBuffer::Cpu { data, cpu_accessible } if *cpu_accessible => {
301 let bytes = unsafe { &(&(*data.get()))[self.offset..self.offset + self.size] };
302 let count = bytes.len() / T::DTYPE.bytes();
303 Ok(unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const T, count) })
304 }
305 _ => NotCpuAccessibleSnafu.fail(),
306 }
307 }
308
309 pub fn item<T: HasDType + Copy>(&self) -> Result<T> {
313 let slice = self.as_slice::<T>()?;
314 assert_eq!(slice.len(), 1, "item() requires exactly 1 element, got {}", slice.len());
315 Ok(slice[0])
316 }
317
318 pub fn allocator(&self) -> &dyn Allocator {
320 &*self.data.allocator
321 }
322
323 pub fn id(&self) -> BufferId {
328 self.data.id
329 }
330
331 pub fn copyin(&mut self, src: &[u8]) -> Result<()> {
333 self.ensure_allocated()?;
334
335 let expected = self.size;
336 let actual = src.len();
337 snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
338
339 let raw = self.data.raw();
340 match raw {
341 RawBuffer::Cpu { data, .. } => {
342 let slice = unsafe {
344 let data_mut = &mut *data.get();
345 &mut data_mut[self.offset..self.offset + self.size]
346 };
347 slice.copy_from_slice(src);
348 Ok(())
349 }
350 RawBuffer::Mmap { .. } => panic!("DISK device is read-only: copyin not supported"),
351 #[cfg(feature = "cuda")]
352 RawBuffer::CudaDevice { data, device } => {
353 let cuda_data = unsafe { &mut *data.get() };
355 let mut view = cuda_data.slice_mut(self.offset..self.offset + self.size);
356 device.default_stream().memcpy_htod(src, &mut view).context(CudaSnafu)
357 }
358 #[cfg(feature = "cuda")]
359 RawBuffer::CudaUnified { data, .. } => {
360 let unified_data = unsafe { &mut *data.get() };
362 let slice = unified_data.as_mut_slice().context(CudaSnafu)?;
363 let target = &mut slice[self.offset..self.offset + self.size];
364 target.copy_from_slice(src);
365 Ok(())
366 }
367 }
368 }
369
370 pub fn copyout(&self, dst: &mut [u8]) -> Result<()> {
372 self.ensure_allocated()?;
373
374 let expected = self.size;
375 let actual = dst.len();
376 snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
377
378 let raw = self.data.raw();
379 match raw {
380 RawBuffer::Cpu { data, .. } => {
381 let data_ref = unsafe { &*data.get() };
383 dst.copy_from_slice(&data_ref[self.offset..self.offset + self.size]);
384 Ok(())
385 }
386 RawBuffer::Mmap { data, .. } => {
387 dst.copy_from_slice(&data[self.offset..self.offset + self.size]);
388 Ok(())
389 }
390 #[cfg(feature = "cuda")]
391 RawBuffer::CudaDevice { data, device } => {
392 device.synchronize().context(CudaSnafu)?;
393 let cuda_data = unsafe { &*data.get() };
395 let view = cuda_data.slice(self.offset..self.offset + self.size);
396 device.default_stream().memcpy_dtoh(&view, dst).context(CudaSnafu)
397 }
398 #[cfg(feature = "cuda")]
399 RawBuffer::CudaUnified { data, .. } => {
400 let unified_data = unsafe { &*data.get() };
402 let slice = unified_data.as_slice().context(CudaSnafu)?;
403 let source = &slice[self.offset..self.offset + self.size];
404 dst.copy_from_slice(source);
405 Ok(())
406 }
407 }
408 }
409
410 pub fn copy_from(&mut self, src: &Buffer) -> Result<()> {
412 self.ensure_allocated()?;
413 src.ensure_allocated()?;
414
415 let expected = self.size;
416 let actual = src.size;
417 snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
418
419 let dst_raw = self.data.raw();
420 let src_raw = src.data.raw();
421
422 match (dst_raw, src_raw) {
425 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::Cpu { data: src_data, .. }) => {
427 let dst_mut = unsafe { &mut *dst_data.get() };
428 let src_ref = unsafe { &*src_data.get() };
429 let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
430 let src_slice = &src_ref[src.offset..src.offset + src.size];
431 dst_slice.copy_from_slice(src_slice);
432 Ok(())
433 }
434 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::Mmap { data: src_data, .. }) => {
436 let dst_mut = unsafe { &mut *dst_data.get() };
437 let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
438 let src_slice = &src_data[src.offset..src.offset + src.size];
439 dst_slice.copy_from_slice(src_slice);
440 Ok(())
441 }
442 (RawBuffer::Mmap { .. }, _) => panic!("DISK device is read-only: copy_from not supported"),
444 #[cfg(feature = "cuda")]
446 (
447 RawBuffer::CudaDevice { data: dst_data, device: dst_device },
448 RawBuffer::CudaDevice { data: src_data, .. },
449 ) => {
450 let dst_cuda = unsafe { &mut *dst_data.get() };
451 let src_cuda = unsafe { &*src_data.get() };
452 let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
453 let src_view = src_cuda.slice(src.offset..src.offset + src.size);
454 dst_device.default_stream().memcpy_dtod(&src_view, &mut dst_view).context(CudaSnafu)
455 }
456 #[cfg(feature = "cuda")]
458 (RawBuffer::CudaDevice { data: dst_data, device }, RawBuffer::Cpu { data: src_data, .. }) => {
459 let dst_cuda = unsafe { &mut *dst_data.get() };
460 let src_ref = unsafe { &*src_data.get() };
461 let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
462 let src_slice = &src_ref[src.offset..src.offset + src.size];
463 device.default_stream().memcpy_htod(src_slice, &mut dst_view).context(CudaSnafu)
464 }
465 #[cfg(feature = "cuda")]
467 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::CudaDevice { data: src_data, device }) => {
468 let dst_mut = unsafe { &mut *dst_data.get() };
469 let src_cuda = unsafe { &*src_data.get() };
470 let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
471 let src_view = src_cuda.slice(src.offset..src.offset + src.size);
472 device.default_stream().memcpy_dtoh(&src_view, dst_slice).context(CudaSnafu)
473 }
474 #[cfg(feature = "cuda")]
476 (RawBuffer::CudaUnified { data: dst_data, .. }, RawBuffer::CudaUnified { data: src_data, .. }) => {
477 let dst_unified = unsafe { &mut *dst_data.get() };
478 let src_unified = unsafe { &*src_data.get() };
479 let dst_slice = dst_unified.as_mut_slice().context(CudaSnafu)?;
480 let src_slice = src_unified.as_slice().context(CudaSnafu)?;
481 let dst_target = &mut dst_slice[self.offset..self.offset + self.size];
482 let src_source = &src_slice[src.offset..src.offset + src.size];
483 dst_target.copy_from_slice(src_source);
484 Ok(())
485 }
486 #[cfg(feature = "cuda")]
488 (RawBuffer::CudaUnified { data: dst_data, .. }, RawBuffer::Cpu { data: src_data, .. }) => {
489 let dst_unified = unsafe { &mut *dst_data.get() };
490 let src_ref = unsafe { &*src_data.get() };
491 let dst_slice = dst_unified.as_mut_slice().context(CudaSnafu)?;
492 let dst_target = &mut dst_slice[self.offset..self.offset + self.size];
493 let src_source = &src_ref[src.offset..src.offset + src.size];
494 dst_target.copy_from_slice(src_source);
495 Ok(())
496 }
497 #[cfg(feature = "cuda")]
499 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::CudaUnified { data: src_data, .. }) => {
500 let dst_mut = unsafe { &mut *dst_data.get() };
501 let src_unified = unsafe { &*src_data.get() };
502 let src_slice = src_unified.as_slice().context(CudaSnafu)?;
503 let dst_target = &mut dst_mut[self.offset..self.offset + self.size];
504 let src_source = &src_slice[src.offset..src.offset + src.size];
505 dst_target.copy_from_slice(src_source);
506 Ok(())
507 }
508 #[cfg(feature = "cuda")]
510 (
511 RawBuffer::CudaUnified { data: dst_data, device: dst_device },
512 RawBuffer::CudaDevice { data: src_data, .. },
513 ) => {
514 let src_cuda = unsafe { &*src_data.get() };
515 let src_view = src_cuda.slice(src.offset..src.offset + src.size);
516 let dst_unified = unsafe { &mut *dst_data.get() };
518 let mut dst_target = dst_unified.slice_mut(self.offset..self.offset + self.size);
519 dst_device.default_stream().memcpy_dtod(&src_view, &mut dst_target).context(CudaSnafu)
521 }
522 #[cfg(feature = "cuda")]
524 (RawBuffer::CudaDevice { data: dst_data, device }, RawBuffer::CudaUnified { data: src_data, .. }) => {
525 let dst_cuda = unsafe { &mut *dst_data.get() };
526 let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
527 let src_unified = unsafe { &*src_data.get() };
529 let src_source = src_unified.slice(src.offset..src.offset + src.size);
530 device.default_stream().memcpy_htod(&src_source, &mut dst_view).context(CudaSnafu)
532 }
533 }
534 }
535
536 pub fn synchronize(&self) -> Result<()> {
538 self.data.allocator.synchronize()
539 }
540
541 pub unsafe fn as_raw_ptr(&self) -> *mut u8 {
555 let raw = self.data.raw();
556 match raw {
557 RawBuffer::Cpu { data, .. } => {
558 unsafe { (&mut *data.get()).as_mut_ptr().add(self.offset) }
561 }
562 RawBuffer::Mmap { data, .. } => {
563 unsafe { data.as_ptr().add(self.offset) as *mut u8 }
565 }
566 #[cfg(feature = "cuda")]
567 RawBuffer::CudaDevice { .. } | RawBuffer::CudaUnified { .. } => {
568 unimplemented!("CUDA buffer raw pointers not yet supported for kernel execution")
572 }
573 }
574 }
575
576 #[cfg(test)]
581 pub(crate) fn raw_data_ptr(&self) -> usize {
582 let raw = self.data.raw();
583 match raw {
584 RawBuffer::Cpu { data, .. } => {
585 unsafe { (*data.get()).as_ptr() as usize }
587 }
588 RawBuffer::Mmap { data, .. } => data.as_ptr() as usize,
589 #[cfg(feature = "cuda")]
590 RawBuffer::CudaDevice { data, .. } => {
591 unsafe { &*data.get() as *const _ as usize }
594 }
595 #[cfg(feature = "cuda")]
596 RawBuffer::CudaUnified { data, .. } => {
597 unsafe { &*data.get() as *const _ as usize }
600 }
601 }
602 }
603}