// dlpackrs/tensor.rs

use core::slice;
use std::{
    fmt::Debug,
    marker::{PhantomData, PhantomPinned},
    mem::transmute,
    os::raw::c_void,
    pin::Pin,
    ptr::{self, NonNull},
};

use pin_project::{pin_project, pinned_drop};

use crate::{
    datatype::DataType,
    device::Device,
    ffi::{DLManagedTensor, DLTensor},
};

/// Non-owned Tensor type interface.
/// See [DLTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv48DLTensor)
#[derive(Debug)]
#[repr(transparent)]
pub struct Tensor<'tensor> {
    /// Raw FFI tensor; `repr(transparent)` is valid because the marker
    /// below is a zero-sized type.
    pub inner: DLTensor,
    _marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>, // invariant wrt 'tensor
}
27
28impl<'tensor> From<Tensor<'tensor>> for DLTensor {
29    fn from(ts: Tensor<'tensor>) -> Self {
30        ts.inner
31    }
32}
33
34impl<'tensor> From<DLTensor> for Tensor<'tensor> {
35    fn from(dts: DLTensor) -> Self {
36        Tensor {
37            inner: dts,
38            _marker: PhantomData,
39        }
40    }
41}
42
43impl<'tensor> Tensor<'tensor> {
44    /// Constructor
45    pub fn new(
46        data: *mut c_void,
47        device: Device,
48        ndim: i32,
49        dtype: DataType,
50        shape: *mut i64,
51        strides: *mut i64,
52        byte_offset: u64,
53    ) -> Self {
54        let inner = DLTensor {
55            data,
56            device: device.into(),
57            ndim,
58            dtype: dtype.into(),
59            shape,
60            strides,
61            byte_offset,
62        };
63        Tensor {
64            inner,
65            _marker: PhantomData,
66        }
67    }
68
69    /// Returns the underlying DLTensor where lifetime parameter is removed.
70    pub fn into_inner(self) -> DLTensor {
71        self.inner
72    }
73
74    /// Consumes the Tensor and returns the raw pointer to its underlying DLTensor.
75    pub fn into_raw(self) -> *const DLTensor {
76        &self.inner as *const _
77    }
78
79    /// Creates a Tensor from a raw DLTensor pointer (must be non-null).
80    pub unsafe fn from_raw(ptr: *mut DLTensor) -> Self {
81        debug_assert!(!ptr.is_null());
82        Tensor {
83            inner: *ptr,
84            _marker: PhantomData,
85        }
86    }
87
88    /// Returns a *mut pointer to the underlying data of the Tensor.
89    pub fn data(&self) -> *mut c_void {
90        self.inner.data
91    }
92
93    /// Returns the device type.
94    pub fn device(&self) -> Device {
95        self.inner.device.into()
96    }
97
98    /// Returns the size of an entry/item in the Tensor.
99    pub fn itemsize(&self) -> usize {
100        let ty = self.dtype();
101        ty.lanes() * ty.bits() / 8_usize
102    }
103
104    /// Returns the number of dimensions of the Tensor.
105    pub fn ndim(&self) -> usize {
106        self.inner.ndim as usize
107    }
108
109    /// Returns the type of the entries of the Tensor.
110    pub fn dtype(&self) -> DataType {
111        self.inner.dtype.into()
112    }
113
114    /// Returns the shape of the Tensor.
115    pub fn shape(&self) -> Option<&[usize]> {
116        let dlt = self.inner;
117        if dlt.shape.is_null() || dlt.data.is_null() {
118            return None;
119        };
120        let ret = unsafe { slice::from_raw_parts(dlt.shape as *const _, dlt.ndim as usize) };
121        Some(ret)
122    }
123
124    /// Returns the strides of the underlying Tensor.
125    pub fn strides(&self) -> Option<&[usize]> {
126        let dlt = self.inner;
127        if dlt.strides.is_null() || dlt.data.is_null() {
128            return None;
129        };
130        let ret = unsafe { slice::from_raw_parts(dlt.strides as *const _, dlt.ndim as usize) };
131        Some(ret)
132    }
133
134    /// Returns the byte offset of the underlying Tensor.
135    pub fn byte_offset(&self) -> isize {
136        self.inner.byte_offset as isize
137    }
138
139    /// Returns the size of the memory required to store the underlying data of the Tensor.
140    pub fn size(&self) -> Option<usize> {
141        let ty = self.dtype();
142        self.shape().map(|v| {
143            v.iter().product::<usize>() * (ty.bits() as usize * ty.lanes() as usize + 7) / 8
144        })
145    }
146}
147
/// A typed ManagerContext type that is `!Unpin` i.e. pinnable for safety since it holds a pointer to the underlying DLTensor.
#[derive(Debug)]
#[repr(C)]
pub struct ManagerContext<C> {
    /// Pointer to the slot holding the opaque `manager_ctx` value, or
    /// `None` when the context is null.
    pub ptr: Option<NonNull<*mut c_void>>,
    // Ties the context to the owner type `C` without storing one.
    ty: PhantomData<C>,
    // Opts this struct out of `Unpin` so its address can be relied upon
    // once pinned.
    _pin: PhantomPinned,
}
156
157impl<C> ManagerContext<C> {
158    pub fn new(ptr: Option<NonNull<*mut c_void>>) -> Self {
159        Self {
160            ptr,
161            ty: PhantomData,
162            _pin: PhantomPinned,
163        }
164    }
165}
166
/// Safe proxy to ffi::DLManagedTensor which is self-referential by design.
/// See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
#[pin_project(PinnedDrop)]
#[repr(C)]
pub struct ManagedTensorProxy<C> {
    /// Holds the underlying tensor.
    pub dl_tensor: DLTensor,
    /// The context holding the underlying DLTensor.
    #[pin]
    pub manager_ctx: ManagerContext<C>, // safe typed wrapper for *mut c_void which is !Unpin i.e. pinnable
    /// Deleter function pointer.
    ///
    /// NOTE(review): `From<DLManagedTensor>` fills this by transmuting an
    /// `unsafe extern "C" fn(*mut DLManagedTensor)` into this Rust-ABI fn
    /// type — calling through the transmuted type is an ABI mismatch;
    /// confirm soundness.
    // TODO: should this be `#[pin]`?
    pub deleter: Option<fn(&mut ManagedTensor<C>)>,
}
181
182impl<C: Debug> Debug for ManagedTensorProxy<C> {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        f.debug_struct("ManagedTensorProxy")
185            .field("dl_tensor", &self.dl_tensor)
186            .field("manager_ctx", &self.manager_ctx)
187            .finish()
188    }
189}
190
impl<C> ManagedTensorProxy<C> {
    /// Returns a copy of the raw `DLTensor` payload.
    pub fn dl_tensor(&self) -> DLTensor {
        self.dl_tensor
    }

    /// Returns the stored manager-context slot pointer, if any.
    ///
    /// Takes `Pin<&mut Self>` so the pinned `manager_ctx` field can be
    /// projected; the field itself is only read.
    pub fn manager_ctx(self: Pin<&mut Self>) -> Option<NonNull<*mut c_void>> {
        let mut this = self.project();
        this.manager_ctx.as_mut().ptr
    }

    /// Overwrites the manager context with a new non-null slot pointer.
    pub fn set_manager_ctx(self: Pin<&mut Self>, manager_ctx: NonNull<*mut c_void>) {
        let mut this = self.project();
        let new = ManagerContext::new(Some(manager_ctx));
        // `Pin::set` replaces (and drops) the old value in place without
        // moving the pinned field.
        this.manager_ctx.set(new);
    }
}
207
impl<C> From<DLManagedTensor> for ManagedTensorProxy<C> {
    /// Lifts a raw `DLManagedTensor` into the typed proxy.
    fn from(mut dlmt: DLManagedTensor) -> Self {
        // NOTE(review): this takes the address of the `manager_ctx` field of
        // the local parameter `dlmt`; that pointer dangles as soon as this
        // function returns. Any later dereference (see the `From` impls
        // lowering back to `DLManagedTensor`) is a use-after-free — this
        // needs a redesign rather than a local patch.
        let ptr: Option<NonNull<*mut c_void>> = if dlmt.manager_ctx.is_null() {
            None
        } else {
            unsafe { Some(NonNull::new_unchecked(&mut dlmt.manager_ctx as *mut _)) }
        };
        let manager_ctx = ManagerContext::new(ptr);
        // NOTE(review): transmuting `unsafe extern "C" fn` to a plain Rust
        // `fn` changes the calling convention — confirm soundness before
        // ever invoking the stored pointer.
        let deleter = dlmt.deleter.take().map(|del| unsafe {
            transmute::<unsafe extern "C" fn(*mut DLManagedTensor), fn(&mut ManagedTensor<C>)>(del)
        });
        ManagedTensorProxy {
            dl_tensor: dlmt.dl_tensor,
            manager_ctx,
            deleter,
        }
    }
}
226
impl<C> From<ManagedTensorProxy<C>> for DLManagedTensor {
    /// Lowers the typed proxy back into the raw FFI struct.
    fn from(pmt: ManagedTensorProxy<C>) -> Self {
        let dl_tensor = pmt.dl_tensor;
        // Recover the original `*mut c_void` context by reading through the
        // stored slot pointer; a `None` slot maps back to a null context.
        // NOTE(review): if `pmt` came from `From<DLManagedTensor>`, the slot
        // pointer targets a dead stack frame — dereferencing it here is UB.
        let manager_ctx = match pmt.manager_ctx.ptr {
            None => ptr::null_mut(),
            Some(nnptr) => unsafe { *nnptr.as_ptr() },
        };
        // NOTE(review): fn-pointer ABI transmute, mirror of the lifting impl.
        let deleter = unsafe {
            pmt.deleter.map(|del_fn| {
                transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
                    del_fn,
                )
            })
        };
        DLManagedTensor {
            dl_tensor,
            manager_ctx,
            deleter,
        }
    }
}
248
impl<C> From<Pin<&mut ManagedTensorProxy<C>>> for DLManagedTensor {
    /// Same lowering as the by-value impl, but reading through a pinned
    /// mutable reference (used by the drop glue).
    fn from(pmt: Pin<&mut ManagedTensorProxy<C>>) -> Self {
        let dl_tensor = pmt.dl_tensor;
        // NOTE(review): same dangling-slot hazard as the by-value impl when
        // the proxy originated from `From<DLManagedTensor>`.
        let manager_ctx = match pmt.manager_ctx.ptr {
            None => ptr::null_mut(),
            Some(nnptr) => unsafe { *nnptr.as_ptr() },
        };
        // NOTE(review): fn-pointer ABI transmute, mirror of the lifting impl.
        let deleter = unsafe {
            pmt.deleter.map(|del_fn| {
                transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
                    del_fn,
                )
            })
        };
        DLManagedTensor {
            dl_tensor,
            manager_ctx,
            deleter,
        }
    }
}
270
#[allow(clippy::needless_lifetimes)]
#[pinned_drop]
impl<C> PinnedDrop for ManagedTensorProxy<C> {
    /// Runs the stored deleter (if any) when the pinned proxy is dropped.
    fn drop(mut self: Pin<&mut Self>) {
        // Rebuild a raw DLManagedTensor view of ourselves for the deleter.
        let mut dlm: DLManagedTensor = self.as_mut().into();
        if let Some(fptr) = self.deleter {
            unsafe {
                // NOTE(review): two hazards here — the transmute drops the
                // declared argument type, and the deleter is handed the
                // address of `dlm`, a stack-local copy rather than the
                // original allocation; a C deleter that frees the struct it
                // receives would free the wrong pointer. Confirm intent.
                let cfptr = transmute::<fn(&mut ManagedTensor<C>), fn(*mut DLManagedTensor)>(fptr);
                cfptr(&mut dlm as *mut _);
            };
        }
    }
}
284
/// ManagedTensor type with Rust as the main owner of the underlying data.
///
///  See [DLManagedTensor](https://dmlc.github.io/dlpack/latest/c_api.html#_CPPv415DLManagedTensor)
#[derive(Debug)]
#[repr(transparent)]
pub struct ManagedTensor<'tensor, C: 'tensor> {
    /// The proxy actually holding the tensor, manager context and deleter;
    /// `repr(transparent)` is valid because the marker below is zero-sized.
    pub inner: ManagedTensorProxy<C>,
    _marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>, // invariant wrt 'tensor
}
294
295impl<'tensor, C> From<DLManagedTensor> for ManagedTensor<'tensor, C> {
296    fn from(dlm: DLManagedTensor) -> Self {
297        let proxy: ManagedTensorProxy<C> = dlm.into();
298        ManagedTensor {
299            inner: proxy,
300            _marker: PhantomData,
301        }
302    }
303}
304
305impl<'tensor, C> From<ManagedTensor<'tensor, C>> for DLManagedTensor {
306    fn from(mt: ManagedTensor<'tensor, C>) -> Self {
307        mt.inner.into()
308    }
309}
310
311impl<'tensor, C: 'tensor> ManagedTensor<'tensor, C> {
312    /// Contructor.
313    pub fn new(tensor: Tensor<'tensor>, manager_ctx: Option<NonNull<*mut c_void>>) -> Self {
314        let manager_ctx = ManagerContext::new(manager_ctx);
315        let inner = ManagedTensorProxy {
316            dl_tensor: tensor.into_inner(),
317            manager_ctx,
318            deleter: None,
319        };
320
321        ManagedTensor {
322            inner,
323            _marker: PhantomData,
324        }
325    }
326
327    /// Sets a deleter function pointer.
328    pub fn set_deleter(&mut self, deleter: fn(&mut ManagedTensor<C>)) {
329        self.inner.deleter = Some(deleter);
330    }
331
332    /// Consumes the ManagedTensor and returns the raw pointer to its underlying DLManagedTensor.
333    pub fn into_raw(self) -> *const DLManagedTensor {
334        let ret: DLManagedTensor = self.inner.into();
335        &ret as *const _
336    }
337
338    /// Returns a ManagedTensor instances from a raw pointer to DLManagedTensor.
339    pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Self {
340        debug_assert!(!ptr.is_null());
341        ManagedTensor {
342            inner: (*ptr).into(),
343            _marker: PhantomData,
344        }
345    }
346
347    /// Consumes the ManagedTensor and returns Tensor.
348    pub fn into_tensor(self) -> Tensor<'tensor> {
349        self.inner.dl_tensor.into()
350    }
351}