// hpt_common/utils/pointer.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, SubAssign},
4};
5
/// Thin wrapper around a raw `*mut T` so it can be moved and shared across threads.
///
/// The wrapper performs no synchronization of its own: it is marked `Send` and
/// `Sync` unconditionally, so callers are responsible for making sure accesses
/// through copies of the same pointer never race.
///
/// `Copy`, `Clone` and `Debug` are implemented by hand rather than derived: a
/// derive would demand `T: Copy` / `T: Clone` / `T: Debug`, although only the raw
/// pointer (always `Copy` and `Debug`) and an integer length are stored.
///
/// This is for internal use only.
pub struct Pointer<T> {
    /// raw pointer
    pub ptr: *mut T,
    /// Number of elements reachable from `ptr`; only present (and only checked by
    /// the indexing / `+=` / `-=` operators) when the `bound_check` feature is enabled.
    #[cfg(feature = "bound_check")]
    pub len: i64,
}

impl<T> Clone for Pointer<T> {
    #[inline(always)]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Pointer<T> {}

impl<T> Debug for Pointer<T> {
    /// Same shape as the derived output: `Pointer { ptr: 0x.., len: .. }`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Pointer");
        out.field("ptr", &self.ptr);
        #[cfg(feature = "bound_check")]
        {
            out.field("len", &self.len);
        }
        out.finish()
    }
}
18
19impl<T> Pointer<T> {
20    /// return a slice of the pointer
21    ///
22    /// # Returns
23    /// `&[T]`
24    #[cfg(feature = "bound_check")]
25    pub fn as_slice(&self) -> &[T] {
26        unsafe { std::slice::from_raw_parts(self.ptr, self.len as usize) }
27    }
28    /// cast the pointer to a new type
29    ///
30    /// # Arguments
31    /// `U` - the new type
32    ///
33    /// # Returns
34    /// `Pointer<U>`
35    pub fn cast<U>(&self) -> Pointer<U> {
36        #[cfg(feature = "bound_check")]
37        return Pointer::new(self.ptr as *mut U, self.len);
38        #[cfg(not(feature = "bound_check"))]
39        return Pointer::new(self.ptr as *mut U);
40    }
41    /// return raw pointer
42    ///
43    /// # Returns
44    /// `*mut T`
45    ///
46    /// # Example
47    /// ```
48    /// use tensor_pointer::Pointer;
49    /// let mut _a = 10;
50    /// let a = Pointer::<i32>::new(_a as *mut i32);
51    /// let b = a.get_ptr();
52    /// assert_eq!(b, _a as *mut i32);
53    /// ```
54    #[inline(always)]
55    pub fn get_ptr(&self) -> *mut T {
56        self.ptr
57    }
58
59    /// Wrap a raw pointer into a Pointer struct for supporting `Send` in multithreading, zero cost
60    ///
61    /// # Arguments
62    /// `ptr` - `*mut T`
63    ///
64    /// # Returns
65    /// `Pointer<T>`
66    ///
67    /// # Example
68    /// ```
69    /// use tensor_pointer::Pointer;
70    /// let mut _a = 10i32;
71    /// let a = Pointer::<i32>::new(_a as *mut i32);
72    /// assert_eq!(a.read(), 10);
73    /// ```
74    #[cfg(not(feature = "bound_check"))]
75    #[inline(always)]
76    pub fn new(ptr: *mut T) -> Self {
77        Self { ptr }
78    }
79
80    /// Wrap a raw pointer into a Pointer struct for supporting `Send` in multithreading, zero cost
81    ///
82    /// # Arguments
83    /// `ptr` - `*mut T`
84    ///
85    /// # Returns
86    /// `Pointer<T>`
87    ///
88    /// # Example
89    /// ```
90    /// use tensor_pointer::Pointer;
91    /// let mut _a = 10i32;
92    /// let a = Pointer::<i32>::new(_a as *mut i32);
93    /// assert_eq!(a.read(), 10);
94    /// ```
95    #[cfg(feature = "bound_check")]
96    #[inline(always)]
97    pub fn new(ptr: *mut T, len: i64) -> Self {
98        Self { ptr, len }
99    }
100
101    /// modify the value of the pointer in the address by the specified offset
102    ///
103    /// # Arguments
104    /// `offset` - the offset from the current address
105    /// `value` - the value to be written
106    ///
107    /// # Example
108    /// ```
109    /// use tensor_pointer::Pointer;
110    /// let mut _a = unsafe { std::alloc::alloc(std::alloc::Layout::new::<i32>()) as *mut i32 };
111    /// let mut a = Pointer::<i32>::new(_a);
112    /// a.modify(0, 10);
113    /// assert_eq!(a.read(), 10);
114    /// unsafe { std::alloc::dealloc(_a as *mut u8, std::alloc::Layout::new::<i32>()); }
115    /// ```
116    #[inline(always)]
117    pub fn modify(&mut self, offset: i64, value: T) {
118        unsafe {
119            self.ptr.offset(offset as isize).write(value);
120        }
121    }
122
123    /// inplace increment the value of the pointer in the current address
124    ///
125    /// # Arguments
126    /// `value` - the value to be added
127    ///
128    /// # Example
129    /// ```
130    /// use tensor_pointer::Pointer;
131    /// let mut _a = unsafe { std::alloc::alloc(std::alloc::Layout::new::<i32>()) as *mut i32 };
132    /// unsafe { _a.write(10); }
133    /// let mut a = Pointer::<i32>::new(_a);
134    /// a.add(0);
135    /// assert_eq!(a.read(), 10);
136    /// unsafe { std::alloc::dealloc(_a as *mut u8, std::alloc::Layout::new::<i32>()); }
137    /// ```
138    #[inline(always)]
139    pub fn add(&mut self, offset: usize) {
140        unsafe {
141            self.ptr = self.ptr.add(offset);
142        }
143    }
144
145    /// inplace offset the value of the pointer in the current address
146    ///
147    /// # Arguments
148    /// `offset` - the offset to be added
149    ///
150    /// # Example
151    /// ```
152    /// use tensor_pointer::Pointer;
153    /// let mut _a = unsafe { std::alloc::alloc(std::alloc::Layout::new::<i32>()) as *mut i32 };
154    /// unsafe { _a.write(10); }
155    /// let mut a = Pointer::<i32>::new(_a);
156    /// a.offset(0);
157    /// assert_eq!(a.read(), 10);
158    /// unsafe { std::alloc::dealloc(_a as *mut u8, std::alloc::Layout::new::<i32>()); }
159    /// ```
160    #[inline(always)]
161    pub fn offset(&mut self, offset: i64) {
162        unsafe {
163            self.ptr = self.ptr.offset(offset as isize);
164        }
165    }
166}
167
168unsafe impl<T> Send for Pointer<T> {}
169
170impl<T> Index<i64> for Pointer<T> {
171    type Output = T;
172    fn index(&self, index: i64) -> &Self::Output {
173        #[cfg(feature = "bound_check")]
174        {
175            if index < 0 || index >= (self.len as i64) {
176                panic!("index out of bounds. index: {}, len: {}", index, self.len);
177            }
178        }
179        unsafe { &*self.ptr.offset(index as isize) }
180    }
181}
182
183impl<T: Display> Index<isize> for Pointer<T> {
184    type Output = T;
185    fn index(&self, index: isize) -> &Self::Output {
186        #[cfg(feature = "bound_check")]
187        {
188            if index < 0 || (index as i64) >= (self.len as i64) {
189                panic!("index out of bounds. index: {}, len: {}", index, self.len);
190            }
191        }
192        unsafe { &*self.ptr.offset(index) }
193    }
194}
195
196impl<T: Display> Index<usize> for Pointer<T> {
197    type Output = T;
198    fn index(&self, index: usize) -> &Self::Output {
199        #[cfg(feature = "bound_check")]
200        {
201            if (index as i64) >= (self.len as i64) {
202                panic!("index out of bounds. index: {}, len: {}", index, self.len);
203            }
204        }
205        unsafe { &*self.ptr.add(index) }
206    }
207}
208
209impl<T: Display> IndexMut<i64> for Pointer<T> {
210    fn index_mut(&mut self, index: i64) -> &mut Self::Output {
211        #[cfg(feature = "bound_check")]
212        {
213            if index < 0 || index >= (self.len as i64) {
214                panic!("index out of bounds. index: {}, len: {}", index, self.len);
215            }
216        }
217        unsafe { &mut *self.ptr.offset(index as isize) }
218    }
219}
220
221impl<T: Display> IndexMut<isize> for Pointer<T> {
222    fn index_mut(&mut self, index: isize) -> &mut Self::Output {
223        #[cfg(feature = "bound_check")]
224        {
225            if index < 0 || (index as i64) >= (self.len as i64) {
226                panic!("index out of bounds. index: {}, len: {}", index, self.len);
227            }
228        }
229        unsafe { &mut *self.ptr.offset(index) }
230    }
231}
232
233impl<T: Display> IndexMut<usize> for Pointer<T> {
234    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
235        #[cfg(feature = "bound_check")]
236        {
237            if (index as i64) >= (self.len as i64) {
238                panic!("index out of bounds. index: {}, len: {}", index, self.len);
239            }
240        }
241        unsafe { &mut *self.ptr.add(index) }
242    }
243}
244
245impl<T> AddAssign<usize> for Pointer<T> {
246    fn add_assign(&mut self, rhs: usize) {
247        #[cfg(feature = "bound_check")]
248        {
249            self.len -= rhs as i64;
250            assert!(self.len >= 0);
251        }
252        unsafe {
253            self.ptr = self.ptr.add(rhs);
254        }
255    }
256}
257
258impl<T> Add<usize> for Pointer<T> {
259    type Output = Self;
260    fn add(self, rhs: usize) -> Self::Output {
261        #[cfg(feature = "bound_check")]
262        unsafe {
263            Self {
264                ptr: self.ptr.add(rhs),
265                len: self.len,
266            }
267        }
268        #[cfg(not(feature = "bound_check"))]
269        unsafe {
270            Self {
271                ptr: self.ptr.add(rhs),
272            }
273        }
274    }
275}
276
277impl<T> AddAssign<usize> for &mut Pointer<T> {
278    fn add_assign(&mut self, rhs: usize) {
279        #[cfg(feature = "bound_check")]
280        {
281            self.len -= rhs as i64;
282        }
283        unsafe {
284            self.ptr = self.ptr.add(rhs);
285        }
286    }
287}
288
289impl<T> AddAssign<i64> for &mut Pointer<T> {
290    fn add_assign(&mut self, rhs: i64) {
291        #[cfg(feature = "bound_check")]
292        {
293            self.len -= rhs;
294            assert!(self.len >= 0);
295        }
296        unsafe {
297            self.ptr = self.ptr.offset(rhs as isize);
298        }
299    }
300}
301
302impl<T> Add<usize> for &mut Pointer<T> {
303    type Output = Pointer<T>;
304    fn add(self, rhs: usize) -> Self::Output {
305        #[cfg(feature = "bound_check")]
306        unsafe {
307            Pointer::new(self.ptr.add(rhs), self.len)
308        }
309        #[cfg(not(feature = "bound_check"))]
310        unsafe {
311            Pointer::new(self.ptr.add(rhs))
312        }
313    }
314}
315
316impl<T> AddAssign<isize> for Pointer<T> {
317    fn add_assign(&mut self, rhs: isize) {
318        #[cfg(feature = "bound_check")]
319        {
320            self.len -= rhs as i64;
321            assert!(self.len >= 0);
322        }
323        unsafe {
324            self.ptr = self.ptr.offset(rhs);
325        }
326    }
327}
328
329impl<T> Add<isize> for Pointer<T> {
330    type Output = Self;
331    fn add(self, rhs: isize) -> Self::Output {
332        #[cfg(feature = "bound_check")]
333        unsafe {
334            Self {
335                ptr: self.ptr.offset(rhs),
336                len: self.len,
337            }
338        }
339        #[cfg(not(feature = "bound_check"))]
340        unsafe {
341            Self {
342                ptr: self.ptr.offset(rhs),
343            }
344        }
345    }
346}
347
348impl<T> AddAssign<i64> for Pointer<T> {
349    fn add_assign(&mut self, rhs: i64) {
350        #[cfg(feature = "bound_check")]
351        {
352            assert!(self.len >= 0);
353            self.len -= rhs as i64;
354        }
355        unsafe {
356            self.ptr = self.ptr.offset(rhs as isize);
357        }
358    }
359}
360
361impl<T> Add<i64> for Pointer<T> {
362    type Output = Self;
363    fn add(self, rhs: i64) -> Self::Output {
364        #[cfg(feature = "bound_check")]
365        unsafe {
366            Self {
367                ptr: self.ptr.offset(rhs as isize),
368                len: self.len,
369            }
370        }
371        #[cfg(not(feature = "bound_check"))]
372        unsafe {
373            Self {
374                ptr: self.ptr.offset(rhs as isize),
375            }
376        }
377    }
378}
379
380impl<T> SubAssign<usize> for Pointer<T> {
381    fn sub_assign(&mut self, rhs: usize) {
382        #[cfg(feature = "bound_check")]
383        {
384            self.len += rhs as i64;
385        }
386        unsafe {
387            self.ptr = self.ptr.offset(-(rhs as isize));
388        }
389    }
390}
391
392impl<T> SubAssign<isize> for Pointer<T> {
393    fn sub_assign(&mut self, rhs: isize) {
394        #[cfg(feature = "bound_check")]
395        {
396            self.len += rhs as i64;
397        }
398        unsafe {
399            self.ptr = self.ptr.offset(-rhs);
400        }
401    }
402}
403
404impl<T> SubAssign<i64> for Pointer<T> {
405    fn sub_assign(&mut self, rhs: i64) {
406        #[cfg(feature = "bound_check")]
407        {
408            self.len += rhs as i64;
409        }
410        unsafe {
411            self.ptr = self.ptr.offset(-rhs as isize);
412        }
413    }
414}
415
416impl<T> SubAssign<i64> for &mut Pointer<T> {
417    fn sub_assign(&mut self, rhs: i64) {
418        #[cfg(feature = "bound_check")]
419        {
420            self.len += rhs as i64;
421        }
422        unsafe {
423            self.ptr = self.ptr.offset(-rhs as isize);
424        }
425    }
426}
427
428impl<T> Deref for Pointer<T> {
429    type Target = T;
430    fn deref(&self) -> &Self::Target {
431        unsafe { &*self.ptr }
432    }
433}
434
435impl<T> DerefMut for Pointer<T> {
436    fn deref_mut(&mut self) -> &mut Self::Target {
437        unsafe { &mut *self.ptr }
438    }
439}
440
441unsafe impl<T> Sync for Pointer<T> {}
442
443impl<T: Display> Display for Pointer<T> {
444    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
445        write!(
446            f,
447            "Pointer( ptr: {}, val: {} )",
448            self.ptr as usize,
449            unsafe { self.ptr.read() }
450        )
451    }
452}