cv_convert/
with_tch.rs

1use crate::{ToCv, TryAsRefCv, TryToCv};
2use anyhow::{ensure, Error, Result};
3use slice_of_array::prelude::*;
4use std::{mem::ManuallyDrop, ops::Deref, slice};
5
6// Helper macros for implementing conversions between tensors and different dimensioned arrays
7macro_rules! impl_from_array {
8    ($elem:ty, 1) => {
9        // Borrowed tensor to borrowed array
10        impl<'a, const N: usize> TryAsRefCv<'a, TensorAsArray<'a, [$elem; N]>> for tch::Tensor {
11            type Error = Error;
12
13            fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [$elem; N]>, Self::Error> {
14                ensure!(self.device() == tch::Device::Cpu);
15                ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
16                ensure!(self.size() == &[N as i64]);
17
18                let slice: &[$elem] =
19                    unsafe { slice::from_raw_parts(self.data_ptr() as *mut $elem, N) };
20                #[allow(unstable_name_collisions)]
21                let array = slice.as_array();
22
23                Ok(TensorAsArray {
24                    data: ManuallyDrop::new(*array),
25                    _tensor: self,
26                })
27            }
28        }
29
30        // Borrowed tensor to owned array
31        impl<const N: usize> TryToCv<[$elem; N]> for tch::Tensor {
32            type Error = Error;
33
34            fn try_to_cv(&self) -> Result<[$elem; N], Self::Error> {
35                ensure!(self.size() == &[N as i64]);
36                let mut array = [Default::default(); N];
37                self.f_copy_data(array.as_mut(), N)?;
38                Ok(array)
39            }
40        }
41
42        // Borrowed array to tensor
43        impl<const N: usize> ToCv<tch::Tensor> for [$elem; N] {
44            fn to_cv(&self) -> tch::Tensor {
45                tch::Tensor::from_slice(self.as_ref())
46            }
47        }
48    };
49
50    ($elem:ty, 2) => {
51        // Borrowed tensor to borrowed array
52        impl<'a, const N1: usize, const N2: usize> TryAsRefCv<'a, TensorAsArray<'a, [[$elem; N2]; N1]>>
53            for tch::Tensor
54        {
55            type Error = Error;
56
57            fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[$elem; N2]; N1]>, Self::Error> {
58                ensure!(self.device() == tch::Device::Cpu);
59                ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
60                ensure!(self.size() == &[N1 as i64, N2 as i64]);
61
62                let slice: &[$elem] =
63                    unsafe { slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2) };
64                #[allow(unstable_name_collisions)]
65                let array = slice.nest().as_array();
66
67                Ok(TensorAsArray {
68                    data: ManuallyDrop::new(*array),
69                    _tensor: self,
70                })
71            }
72        }
73
74        // Borrowed tensor to owned array
75        impl<const N1: usize, const N2: usize> TryToCv<[[$elem; N2]; N1]> for tch::Tensor {
76            type Error = Error;
77
78            fn try_to_cv(&self) -> Result<[[$elem; N2]; N1], Self::Error> {
79                ensure!(self.size() == &[N1 as i64, N2 as i64]);
80                let mut array = [[Default::default(); N2]; N1];
81                self.f_copy_data(array.flat_mut(), N1 * N2)?;
82                Ok(array)
83            }
84        }
85
86        // Borrowed array to tensor
87        impl<const N1: usize, const N2: usize> ToCv<tch::Tensor> for [[$elem; N2]; N1] {
88            fn to_cv(&self) -> tch::Tensor {
89                tch::Tensor::from_slice(self.flat()).view([N1 as i64, N2 as i64])
90            }
91        }
92    };
93
94    ($elem:ty, 3) => {
95        // Borrowed tensor to borrowed array
96        impl<'a, const N1: usize, const N2: usize, const N3: usize>
97            TryAsRefCv<'a, TensorAsArray<'a, [[[$elem; N3]; N2]; N1]>> for tch::Tensor
98        {
99            type Error = Error;
100
101            fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[$elem; N3]; N2]; N1]>, Self::Error> {
102                ensure!(self.device() == tch::Device::Cpu);
103                ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
104                ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64]);
105
106                let slice: &[$elem] =
107                    unsafe { slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2 * N3) };
108                #[allow(unstable_name_collisions)]
109                let array = slice.nest().nest().as_array();
110
111                Ok(TensorAsArray {
112                    data: ManuallyDrop::new(*array),
113                    _tensor: self,
114                })
115            }
116        }
117
118        // Borrowed tensor to owned array
119        impl<const N1: usize, const N2: usize, const N3: usize> TryToCv<[[[$elem; N3]; N2]; N1]>
120            for tch::Tensor
121        {
122            type Error = Error;
123
124            fn try_to_cv(&self) -> Result<[[[$elem; N3]; N2]; N1], Self::Error> {
125                ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64]);
126                let mut array = [[[Default::default(); N3]; N2]; N1];
127                self.f_copy_data(array.flat_mut().flat_mut(), N1 * N2 * N3)?;
128                Ok(array)
129            }
130        }
131
132        // Borrowed array to tensor
133        impl<const N1: usize, const N2: usize, const N3: usize> ToCv<tch::Tensor>
134            for [[[$elem; N3]; N2]; N1]
135        {
136            fn to_cv(&self) -> tch::Tensor {
137                tch::Tensor::from_slice(self.flat().flat()).view([N1 as i64, N2 as i64, N3 as i64])
138            }
139        }
140    };
141
142    ($elem:ty, 4) => {
143        // Borrowed tensor to borrowed array
144        impl<'a, const N1: usize, const N2: usize, const N3: usize, const N4: usize>
145            TryAsRefCv<'a, TensorAsArray<'a, [[[[$elem; N4]; N3]; N2]; N1]>> for tch::Tensor
146        {
147            type Error = Error;
148
149            fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[[$elem; N4]; N3]; N2]; N1]>, Self::Error> {
150                ensure!(self.device() == tch::Device::Cpu);
151                ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
152                ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
153
154                let slice: &[$elem] = unsafe {
155                    slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2 * N3 * N4)
156                };
157                #[allow(unstable_name_collisions)]
158                let array = slice.nest().nest().nest().as_array();
159
160                Ok(TensorAsArray {
161                    data: ManuallyDrop::new(*array),
162                    _tensor: self,
163                })
164            }
165        }
166
167        // Borrowed tensor to owned array
168        impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
169            TryToCv<[[[[$elem; N4]; N3]; N2]; N1]> for tch::Tensor
170        {
171            type Error = Error;
172
173            fn try_to_cv(&self) -> Result<[[[[$elem; N4]; N3]; N2]; N1], Self::Error> {
174                ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
175                let mut array = [[[[Default::default(); N4]; N3]; N2]; N1];
176                self.f_copy_data(array.flat_mut().flat_mut().flat_mut(), N1 * N2 * N3 * N4)?;
177                Ok(array)
178            }
179        }
180
181        // Borrowed array to tensor
182        impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
183            ToCv<tch::Tensor> for [[[[$elem; N4]; N3]; N2]; N1]
184        {
185            fn to_cv(&self) -> tch::Tensor {
186                tch::Tensor::from_slice(self.flat().flat().flat())
187                    .view([N1 as i64, N2 as i64, N3 as i64, N4 as i64])
188            }
189        }
190    };
191
192    ($elem:ty, 5) => {
193        // Borrowed tensor to borrowed array
194        impl<
195                'a,
196                const N1: usize,
197                const N2: usize,
198                const N3: usize,
199                const N4: usize,
200                const N5: usize,
201            > TryAsRefCv<'a, TensorAsArray<'a, [[[[[$elem; N5]; N4]; N3]; N2]; N1]>> for tch::Tensor
202        {
203            type Error = Error;
204
205            fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[[[$elem; N5]; N4]; N3]; N2]; N1]>, Self::Error> {
206                ensure!(self.device() == tch::Device::Cpu);
207                ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
208                ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
209
210                let slice: &[$elem] = unsafe {
211                    slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2 * N3 * N4 * N5)
212                };
213                #[allow(unstable_name_collisions)]
214                let array = slice.nest().nest().nest().nest().as_array();
215
216                Ok(TensorAsArray {
217                    data: ManuallyDrop::new(*array),
218                    _tensor: self,
219                })
220            }
221        }
222
223        // Borrowed tensor to owned array
224        impl<
225                const N1: usize,
226                const N2: usize,
227                const N3: usize,
228                const N4: usize,
229                const N5: usize,
230            > TryToCv<[[[[[$elem; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
231        {
232            type Error = Error;
233
234            fn try_to_cv(&self) -> Result<[[[[[$elem; N5]; N4]; N3]; N2]; N1], Self::Error> {
235                ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
236                let mut array = [[[[[Default::default(); N5]; N4]; N3]; N2]; N1];
237                self.f_copy_data(
238                    array.flat_mut().flat_mut().flat_mut().flat_mut(),
239                    N1 * N2 * N3 * N4 * N5,
240                )?;
241                Ok(array)
242            }
243        }
244
245        // Borrowed array to tensor
246        impl<
247                const N1: usize,
248                const N2: usize,
249                const N3: usize,
250                const N4: usize,
251                const N5: usize,
252            > ToCv<tch::Tensor> for [[[[[$elem; N5]; N4]; N3]; N2]; N1]
253        {
254            fn to_cv(&self) -> tch::Tensor {
255                tch::Tensor::from_slice(self.flat().flat().flat().flat())
256                    .view([N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64])
257            }
258        }
259    };
260
261    ($elem:ty, 6) => {
262        // Borrowed tensor to borrowed array
263        impl<
264                'a,
265                const N1: usize,
266                const N2: usize,
267                const N3: usize,
268                const N4: usize,
269                const N5: usize,
270                const N6: usize,
271            > TryAsRefCv<'a, TensorAsArray<'a, [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]>> for tch::Tensor
272        {
273            type Error = Error;
274
275            fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]>, Self::Error> {
276                ensure!(self.device() == tch::Device::Cpu);
277                ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
278                ensure!(
279                    self.size()
280                        == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
281                );
282
283                let slice: &[$elem] = unsafe {
284                    slice::from_raw_parts(
285                        self.data_ptr() as *mut $elem,
286                        N1 * N2 * N3 * N4 * N5 * N6,
287                    )
288                };
289                #[allow(unstable_name_collisions)]
290                let array = slice.nest().nest().nest().nest().nest().as_array();
291
292                Ok(TensorAsArray {
293                    data: ManuallyDrop::new(*array),
294                    _tensor: self,
295                })
296            }
297        }
298
299        // Borrowed tensor to owned array
300        impl<
301                const N1: usize,
302                const N2: usize,
303                const N3: usize,
304                const N4: usize,
305                const N5: usize,
306                const N6: usize,
307            > TryToCv<[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
308        {
309            type Error = Error;
310
311            fn try_to_cv(&self) -> Result<[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1], Self::Error> {
312                ensure!(
313                    self.size()
314                        == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
315                );
316                let mut array = [[[[[[Default::default(); N6]; N5]; N4]; N3]; N2]; N1];
317                self.f_copy_data(
318                    array.flat_mut().flat_mut().flat_mut().flat_mut().flat_mut(),
319                    N1 * N2 * N3 * N4 * N5 * N6,
320                )?;
321                Ok(array)
322            }
323        }
324
325        // Borrowed array to tensor
326        impl<
327                const N1: usize,
328                const N2: usize,
329                const N3: usize,
330                const N4: usize,
331                const N5: usize,
332                const N6: usize,
333            > ToCv<tch::Tensor> for [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
334        {
335            fn to_cv(&self) -> tch::Tensor {
336                tch::Tensor::from_slice(self.flat().flat().flat().flat().flat()).view([
337                    N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64,
338                ])
339            }
340        }
341    };
342}
343
344// Generate implementations for each element type and dimension
345impl_from_array!(u8, 1);
346impl_from_array!(u8, 2);
347impl_from_array!(u8, 3);
348impl_from_array!(u8, 4);
349impl_from_array!(u8, 5);
350impl_from_array!(u8, 6);
351
352impl_from_array!(i8, 1);
353impl_from_array!(i8, 2);
354impl_from_array!(i8, 3);
355impl_from_array!(i8, 4);
356impl_from_array!(i8, 5);
357impl_from_array!(i8, 6);
358
359impl_from_array!(i16, 1);
360impl_from_array!(i16, 2);
361impl_from_array!(i16, 3);
362impl_from_array!(i16, 4);
363impl_from_array!(i16, 5);
364impl_from_array!(i16, 6);
365
366impl_from_array!(i32, 1);
367impl_from_array!(i32, 2);
368impl_from_array!(i32, 3);
369impl_from_array!(i32, 4);
370impl_from_array!(i32, 5);
371impl_from_array!(i32, 6);
372
373impl_from_array!(i64, 1);
374impl_from_array!(i64, 2);
375impl_from_array!(i64, 3);
376impl_from_array!(i64, 4);
377impl_from_array!(i64, 5);
378impl_from_array!(i64, 6);
379
380impl_from_array!(half::f16, 1);
381impl_from_array!(half::f16, 2);
382impl_from_array!(half::f16, 3);
383impl_from_array!(half::f16, 4);
384impl_from_array!(half::f16, 5);
385impl_from_array!(half::f16, 6);
386
387impl_from_array!(f32, 1);
388impl_from_array!(f32, 2);
389impl_from_array!(f32, 3);
390impl_from_array!(f32, 4);
391impl_from_array!(f32, 5);
392impl_from_array!(f32, 6);
393
394impl_from_array!(f64, 1);
395impl_from_array!(f64, 2);
396impl_from_array!(f64, 3);
397impl_from_array!(f64, 4);
398impl_from_array!(f64, 5);
399impl_from_array!(f64, 6);
400
401impl_from_array!(bool, 1);
402impl_from_array!(bool, 2);
403impl_from_array!(bool, 3);
404impl_from_array!(bool, 4);
405impl_from_array!(bool, 5);
406impl_from_array!(bool, 6);
407
408pub use tensors::*;
409mod tensors {
410    use super::*;
411
412    /// A wrapper for a borrowed array reference from a tensor.
413    #[derive(Debug)]
414    pub struct TensorAsArray<'a, T> {
415        pub(crate) data: ManuallyDrop<T>,
416        pub(crate) _tensor: &'a tch::Tensor,
417    }
418
419    impl<'a, T> Drop for TensorAsArray<'a, T> {
420        fn drop(&mut self) {
421            unsafe {
422                ManuallyDrop::drop(&mut self.data);
423            }
424        }
425    }
426
427    impl<'a, T> AsRef<T> for TensorAsArray<'a, T> {
428        fn as_ref(&self) -> &T {
429            &self.data
430        }
431    }
432
433    impl<'a, T> Deref for TensorAsArray<'a, T> {
434        type Target = T;
435
436        fn deref(&self) -> &Self::Target {
437            &self.data
438        }
439    }
440
441    /// An 2D image [Tensor](tch::Tensor) with dimension order.
442    #[derive(Debug)]
443    pub struct TchTensorAsImage {
444        pub(crate) tensor: tch::Tensor,
445        pub(crate) kind: TchTensorImageShape,
446    }
447
448    /// Describes the image channel order of a [Tensor](tch::Tensor).
449    #[derive(Debug, Clone, Copy)]
450    pub enum TchTensorImageShape {
451        Whc,
452        Hwc,
453        Chw,
454        Cwh,
455    }
456
457    impl TchTensorAsImage {
458        pub fn new(tensor: tch::Tensor, kind: TchTensorImageShape) -> Result<Self> {
459            let ndim = tensor.dim();
460            ensure!(
461                ndim == 3,
462                "the tensor must have 3 dimensions, but get {}",
463                ndim
464            );
465            Ok(Self { tensor, kind })
466        }
467
468        pub fn into_inner(self) -> tch::Tensor {
469            self.tensor
470        }
471
472        pub fn kind(&self) -> TchTensorImageShape {
473            self.kind
474        }
475
476        pub fn try_to_cv<T>(&self) -> Result<T, <Self as TryToCv<T>>::Error>
477        where
478            Self: TryToCv<T>,
479        {
480            TryToCv::try_to_cv(self)
481        }
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToCv, TryAsRefCv, TryToCv};
    use rand::prelude::*;

    /// Builds a tensor from a random array of type `$ty`, then checks that both the
    /// owned (`try_to_cv`) and the wrapped (`try_as_ref_cv`) conversions give the
    /// original array back.
    macro_rules! assert_round_trip {
        ($rng:ident, $ty:ty) => {{
            let expected: $ty = $rng.gen();
            let tensor = expected.to_cv();

            let owned: $ty = tensor.try_to_cv().unwrap();
            assert!(owned == expected);

            let wrapped: TensorAsArray<$ty> = (&tensor).try_as_ref_cv().unwrap();
            assert!(*wrapped == expected);
        }};
    }

    #[test]
    fn tensor_to_array_ref() {
        let mut rng = rand::thread_rng();

        assert_round_trip!(rng, [f32; 3]);
        assert_round_trip!(rng, [[f32; 3]; 2]);
        assert_round_trip!(rng, [[[f32; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[f32; 2]; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[[f32; 3]; 2]; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[[[f32; 2]; 3]; 2]; 4]; 3]; 2]);
    }
}