1use std::{rc::Rc, sync::Arc};
2
3use thiserror::Error;
4
5#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
8#[error("allocation error")]
9pub struct AllocError;
10
11#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
12#[error("copy error")]
13pub struct CopyError;
14
/// Which way a transfer requested through
/// [`DeviceMemory::copy_nonoverlapping`] moves data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CopyDirection {
    /// From host memory into device memory.
    HostToDevice,
    /// From device memory back to host memory.
    DeviceToHost,
    /// Between two locations that are both in device memory.
    DeviceToDevice,
}
22
23pub trait DeviceMemory {
25 unsafe fn copy_nonoverlapping(
27 &self,
28 src: *const u8,
29 dst: *mut u8,
30 size: usize,
31 direction: CopyDirection,
32 ) -> Result<(), CopyError>;
33
34 unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError>;
38}
39
40impl<T: DeviceMemory> DeviceMemory for &T {
41 #[inline]
42 unsafe fn copy_nonoverlapping(
43 &self,
44 src: *const u8,
45 dst: *mut u8,
46 size: usize,
47 direction: CopyDirection,
48 ) -> Result<(), CopyError> {
49 (**self).copy_nonoverlapping(src, dst, size, direction)
50 }
51
52 #[inline]
53 unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
54 (**self).write_bytes(dst, value, size)
55 }
56}
57
58impl<T: DeviceMemory> DeviceMemory for Rc<T> {
59 #[inline]
60 unsafe fn copy_nonoverlapping(
61 &self,
62 src: *const u8,
63 dst: *mut u8,
64 size: usize,
65 direction: CopyDirection,
66 ) -> Result<(), CopyError> {
67 (**self).copy_nonoverlapping(src, dst, size, direction)
68 }
69
70 #[inline]
71 unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
72 (**self).write_bytes(dst, value, size)
73 }
74}
75
76impl<T: DeviceMemory> DeviceMemory for Arc<T> {
77 #[inline]
78 unsafe fn copy_nonoverlapping(
79 &self,
80 src: *const u8,
81 dst: *mut u8,
82 size: usize,
83 direction: CopyDirection,
84 ) -> Result<(), CopyError> {
85 (**self).copy_nonoverlapping(src, dst, size, direction)
86 }
87
88 #[inline]
89 unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
90 (**self).write_bytes(dst, value, size)
91 }
92}