1use pin_project::{pin_project, pinned_drop};
2
3use core::slice;
4use std::{
5 fmt::Debug,
6 marker::{PhantomData, PhantomPinned},
7 mem::transmute,
8 os::raw::c_void,
9 pin::Pin,
10 ptr::{self, NonNull},
11};
12
13use crate::{
14 datatype::DataType,
15 device::Device,
16 ffi::{DLManagedTensor, DLTensor},
17};
18
19#[derive(Debug)]
22#[repr(transparent)]
23pub struct Tensor<'tensor> {
24 pub inner: DLTensor,
25 _marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>, }
27
28impl<'tensor> From<Tensor<'tensor>> for DLTensor {
29 fn from(ts: Tensor<'tensor>) -> Self {
30 ts.inner
31 }
32}
33
34impl<'tensor> From<DLTensor> for Tensor<'tensor> {
35 fn from(dts: DLTensor) -> Self {
36 Tensor {
37 inner: dts,
38 _marker: PhantomData,
39 }
40 }
41}
42
43impl<'tensor> Tensor<'tensor> {
44 pub fn new(
46 data: *mut c_void,
47 device: Device,
48 ndim: i32,
49 dtype: DataType,
50 shape: *mut i64,
51 strides: *mut i64,
52 byte_offset: u64,
53 ) -> Self {
54 let inner = DLTensor {
55 data,
56 device: device.into(),
57 ndim,
58 dtype: dtype.into(),
59 shape,
60 strides,
61 byte_offset,
62 };
63 Tensor {
64 inner,
65 _marker: PhantomData,
66 }
67 }
68
69 pub fn into_inner(self) -> DLTensor {
71 self.inner
72 }
73
74 pub fn into_raw(self) -> *const DLTensor {
76 &self.inner as *const _
77 }
78
79 pub unsafe fn from_raw(ptr: *mut DLTensor) -> Self {
81 debug_assert!(!ptr.is_null());
82 Tensor {
83 inner: *ptr,
84 _marker: PhantomData,
85 }
86 }
87
88 pub fn data(&self) -> *mut c_void {
90 self.inner.data
91 }
92
93 pub fn device(&self) -> Device {
95 self.inner.device.into()
96 }
97
98 pub fn itemsize(&self) -> usize {
100 let ty = self.dtype();
101 ty.lanes() * ty.bits() / 8_usize
102 }
103
104 pub fn ndim(&self) -> usize {
106 self.inner.ndim as usize
107 }
108
109 pub fn dtype(&self) -> DataType {
111 self.inner.dtype.into()
112 }
113
114 pub fn shape(&self) -> Option<&[usize]> {
116 let dlt = self.inner;
117 if dlt.shape.is_null() || dlt.data.is_null() {
118 return None;
119 };
120 let ret = unsafe { slice::from_raw_parts(dlt.shape as *const _, dlt.ndim as usize) };
121 Some(ret)
122 }
123
124 pub fn strides(&self) -> Option<&[usize]> {
126 let dlt = self.inner;
127 if dlt.strides.is_null() || dlt.data.is_null() {
128 return None;
129 };
130 let ret = unsafe { slice::from_raw_parts(dlt.strides as *const _, dlt.ndim as usize) };
131 Some(ret)
132 }
133
134 pub fn byte_offset(&self) -> isize {
136 self.inner.byte_offset as isize
137 }
138
139 pub fn size(&self) -> Option<usize> {
141 let ty = self.dtype();
142 self.shape().map(|v| {
143 v.iter().product::<usize>() * (ty.bits() as usize * ty.lanes() as usize + 7) / 8
144 })
145 }
146}
147
/// Typed holder for a DLPack manager context.
///
/// `ptr` points to the slot that holds the raw context pointer, not to the context itself.
/// Conversions to `DLManagedTensor` read the context value through it. `PhantomPinned` makes the
/// type `!Unpin`.
#[derive(Debug)]
#[repr(C)]
pub struct ManagerContext<C> {
    /// Pointer to the context slot, or `None` when there is no context.
    pub ptr: Option<NonNull<*mut c_void>>,
    /// Associates the caller-chosen type `C` without storing a value of it.
    ty: PhantomData<C>,
    /// Opts the type out of `Unpin`.
    _pin: PhantomPinned,
}
156
157impl<C> ManagerContext<C> {
158 pub fn new(ptr: Option<NonNull<*mut c_void>>) -> Self {
159 Self {
160 ptr,
161 ty: PhantomData,
162 _pin: PhantomPinned,
163 }
164 }
165}
166
167#[pin_project(PinnedDrop)]
170#[repr(C)]
171pub struct ManagedTensorProxy<C> {
172 pub dl_tensor: DLTensor,
174 #[pin]
176 pub manager_ctx: ManagerContext<C>, pub deleter: Option<fn(&mut ManagedTensor<C>)>,
180}
181
182impl<C: Debug> Debug for ManagedTensorProxy<C> {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("ManagedTensorProxy")
185 .field("dl_tensor", &self.dl_tensor)
186 .field("manager_ctx", &self.manager_ctx)
187 .finish()
188 }
189}
190
191impl<C> ManagedTensorProxy<C> {
192 pub fn dl_tensor(&self) -> DLTensor {
193 self.dl_tensor
194 }
195
196 pub fn manager_ctx(self: Pin<&mut Self>) -> Option<NonNull<*mut c_void>> {
197 let mut this = self.project();
198 this.manager_ctx.as_mut().ptr
199 }
200
201 pub fn set_manager_ctx(self: Pin<&mut Self>, manager_ctx: NonNull<*mut c_void>) {
202 let mut this = self.project();
203 let new = ManagerContext::new(Some(manager_ctx));
204 this.manager_ctx.set(new);
205 }
206}
207
208impl<C> From<DLManagedTensor> for ManagedTensorProxy<C> {
209 fn from(mut dlmt: DLManagedTensor) -> Self {
210 let ptr: Option<NonNull<*mut c_void>> = if dlmt.manager_ctx.is_null() {
211 None
212 } else {
213 unsafe { Some(NonNull::new_unchecked(&mut dlmt.manager_ctx as *mut _)) }
214 };
215 let manager_ctx = ManagerContext::new(ptr);
216 let deleter = dlmt.deleter.take().map(|del| unsafe {
217 transmute::<unsafe extern "C" fn(*mut DLManagedTensor), fn(&mut ManagedTensor<C>)>(del)
218 });
219 ManagedTensorProxy {
220 dl_tensor: dlmt.dl_tensor,
221 manager_ctx,
222 deleter,
223 }
224 }
225}
226
227impl<C> From<ManagedTensorProxy<C>> for DLManagedTensor {
228 fn from(pmt: ManagedTensorProxy<C>) -> Self {
229 let dl_tensor = pmt.dl_tensor;
230 let manager_ctx = match pmt.manager_ctx.ptr {
231 None => ptr::null_mut(),
232 Some(nnptr) => unsafe { *nnptr.as_ptr() },
233 };
234 let deleter = unsafe {
235 pmt.deleter.map(|del_fn| {
236 transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
237 del_fn,
238 )
239 })
240 };
241 DLManagedTensor {
242 dl_tensor,
243 manager_ctx,
244 deleter,
245 }
246 }
247}
248
249impl<C> From<Pin<&mut ManagedTensorProxy<C>>> for DLManagedTensor {
250 fn from(pmt: Pin<&mut ManagedTensorProxy<C>>) -> Self {
251 let dl_tensor = pmt.dl_tensor;
252 let manager_ctx = match pmt.manager_ctx.ptr {
253 None => ptr::null_mut(),
254 Some(nnptr) => unsafe { *nnptr.as_ptr() },
255 };
256 let deleter = unsafe {
257 pmt.deleter.map(|del_fn| {
258 transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
259 del_fn,
260 )
261 })
262 };
263 DLManagedTensor {
264 dl_tensor,
265 manager_ctx,
266 deleter,
267 }
268 }
269}
270
// Presumably the `#[pinned_drop]` expansion trips clippy's `needless_lifetimes`; the allow
// silences it for this impl. TODO confirm it is still required.
#[allow(clippy::needless_lifetimes)]
#[pinned_drop]
impl<C> PinnedDrop for ManagedTensorProxy<C> {
    // Invokes the stored deleter (if any) on a `DLManagedTensor` rebuilt from this proxy.
    // `dlm` is a local of this frame, so the deleter must not keep the pointer it receives
    // beyond the call.
    fn drop(mut self: Pin<&mut Self>) {
        // Rebuilt even when there is no deleter. This reads the context value through
        // `manager_ctx.ptr`, so that pointer must still refer to a live slot here.
        let mut dlm: DLManagedTensor = self.as_mut().into();
        if let Some(fptr) = self.deleter {
            unsafe {
                // NOTE(review): `fptr` is typed `fn(&mut ManagedTensor<C>)` but is called here as
                // a Rust-ABI `fn(*mut DLManagedTensor)`. A deleter imported from a
                // `DLManagedTensor` is really an `extern "C"` function, and calling it through a
                // Rust-ABI pointer is undefined behaviour. A deleter installed via `set_deleter`
                // will read this memory as a `ManagedTensor<C>`, whose `manager_ctx` field holds a
                // pointer-to-context, whereas `dlm.manager_ctx` holds the context value itself.
                let cfptr = transmute::<fn(&mut ManagedTensor<C>), fn(*mut DLManagedTensor)>(fptr);
                cfptr(&mut dlm as *mut _);
            };
        }
    }
}
284
285#[derive(Debug)]
289#[repr(transparent)]
290pub struct ManagedTensor<'tensor, C: 'tensor> {
291 pub inner: ManagedTensorProxy<C>,
292 _marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>, }
294
295impl<'tensor, C> From<DLManagedTensor> for ManagedTensor<'tensor, C> {
296 fn from(dlm: DLManagedTensor) -> Self {
297 let proxy: ManagedTensorProxy<C> = dlm.into();
298 ManagedTensor {
299 inner: proxy,
300 _marker: PhantomData,
301 }
302 }
303}
304
305impl<'tensor, C> From<ManagedTensor<'tensor, C>> for DLManagedTensor {
306 fn from(mt: ManagedTensor<'tensor, C>) -> Self {
307 mt.inner.into()
308 }
309}
310
311impl<'tensor, C: 'tensor> ManagedTensor<'tensor, C> {
312 pub fn new(tensor: Tensor<'tensor>, manager_ctx: Option<NonNull<*mut c_void>>) -> Self {
314 let manager_ctx = ManagerContext::new(manager_ctx);
315 let inner = ManagedTensorProxy {
316 dl_tensor: tensor.into_inner(),
317 manager_ctx,
318 deleter: None,
319 };
320
321 ManagedTensor {
322 inner,
323 _marker: PhantomData,
324 }
325 }
326
327 pub fn set_deleter(&mut self, deleter: fn(&mut ManagedTensor<C>)) {
329 self.inner.deleter = Some(deleter);
330 }
331
332 pub fn into_raw(self) -> *const DLManagedTensor {
334 let ret: DLManagedTensor = self.inner.into();
335 &ret as *const _
336 }
337
338 pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Self {
340 debug_assert!(!ptr.is_null());
341 ManagedTensor {
342 inner: (*ptr).into(),
343 _marker: PhantomData,
344 }
345 }
346
347 pub fn into_tensor(self) -> Tensor<'tensor> {
349 self.inner.dl_tensor.into()
350 }
351}