cv_convert_fork/
with_tch.rs

1use crate::tch;
2use crate::{common::*, FromCv, TryFromCv};
3use slice_of_array::prelude::*;
4
5macro_rules! impl_from_array {
6    ($elem:ty) => {
7        // borrowed tensor to array reference
8
9        impl<'a, const N: usize> TryFromCv<&'a tch::Tensor> for &'a [$elem; N] {
10            type Error = Error;
11
12            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
13                ensure!(from.device() == tch::Device::Cpu);
14                ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
15                ensure!(from.size() == &[N as i64]);
16                let slice: &[$elem] =
17                    unsafe { slice::from_raw_parts(from.data_ptr() as *mut $elem, N) };
18                Ok(slice.as_array())
19            }
20        }
21
22        impl<'a, const N1: usize, const N2: usize> TryFromCv<&'a tch::Tensor>
23            for &'a [[$elem; N2]; N1]
24        {
25            type Error = Error;
26
27            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
28                ensure!(from.device() == tch::Device::Cpu);
29                ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
30                ensure!(from.size() == &[N1 as i64, N2 as i64]);
31                let slice: &[$elem] =
32                    unsafe { slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2) };
33                Ok(slice.nest().as_array())
34            }
35        }
36
37        impl<'a, const N1: usize, const N2: usize, const N3: usize> TryFromCv<&'a tch::Tensor>
38            for &'a [[[$elem; N3]; N2]; N1]
39        {
40            type Error = Error;
41
42            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
43                ensure!(from.device() == tch::Device::Cpu);
44                ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
45                ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64]);
46                let slice: &[$elem] =
47                    unsafe { slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2 * N3) };
48                Ok(slice.nest().nest().as_array())
49            }
50        }
51
52        impl<'a, const N1: usize, const N2: usize, const N3: usize, const N4: usize>
53            TryFromCv<&'a tch::Tensor> for &'a [[[[$elem; N4]; N3]; N2]; N1]
54        {
55            type Error = Error;
56
57            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
58                ensure!(from.device() == tch::Device::Cpu);
59                ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
60                ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
61                let slice: &[$elem] = unsafe {
62                    slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2 * N3 * N4)
63                };
64                Ok(slice.nest().nest().nest().as_array())
65            }
66        }
67
68        impl<
69                'a,
70                const N1: usize,
71                const N2: usize,
72                const N3: usize,
73                const N4: usize,
74                const N5: usize,
75            > TryFromCv<&'a tch::Tensor> for &'a [[[[[$elem; N5]; N4]; N3]; N2]; N1]
76        {
77            type Error = Error;
78
79            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
80                ensure!(from.device() == tch::Device::Cpu);
81                ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
82                ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
83                let slice: &[$elem] = unsafe {
84                    slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2 * N3 * N4 * N5)
85                };
86                Ok(slice.nest().nest().nest().nest().as_array())
87            }
88        }
89
90        impl<
91                'a,
92                const N1: usize,
93                const N2: usize,
94                const N3: usize,
95                const N4: usize,
96                const N5: usize,
97                const N6: usize,
98            > TryFromCv<&'a tch::Tensor> for &'a [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
99        {
100            type Error = Error;
101
102            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
103                ensure!(from.device() == tch::Device::Cpu);
104                ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
105                ensure!(
106                    from.size()
107                        == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
108                );
109                let slice: &[$elem] = unsafe {
110                    slice::from_raw_parts(
111                        from.data_ptr() as *mut $elem,
112                        N1 * N2 * N3 * N4 * N5 * N6,
113                    )
114                };
115                Ok(slice.nest().nest().nest().nest().nest().as_array())
116            }
117        }
118
119        // borrowed tensor to array
120
121        impl<const N: usize> TryFromCv<&tch::Tensor> for [$elem; N] {
122            type Error = Error;
123
124            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
125                ensure!(from.size() == &[N as i64]);
126                let mut array = [Default::default(); N];
127                from.f_copy_data(array.as_mut(), N)?;
128                Ok(array)
129            }
130        }
131
132        impl<const N1: usize, const N2: usize> TryFromCv<&tch::Tensor> for [[$elem; N2]; N1] {
133            type Error = Error;
134
135            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
136                ensure!(from.size() == &[N1 as i64, N2 as i64]);
137                let mut array = [[Default::default(); N2]; N1];
138                from.f_copy_data(array.flat_mut(), N1 * N2)?;
139                Ok(array)
140            }
141        }
142
143        impl<const N1: usize, const N2: usize, const N3: usize> TryFromCv<&tch::Tensor>
144            for [[[$elem; N3]; N2]; N1]
145        {
146            type Error = Error;
147
148            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
149                ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64]);
150                let mut array = [[[Default::default(); N3]; N2]; N1];
151                from.f_copy_data(array.flat_mut().flat_mut(), N1 * N2 * N3)?;
152                Ok(array)
153            }
154        }
155
156        impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
157            TryFromCv<&tch::Tensor> for [[[[$elem; N4]; N3]; N2]; N1]
158        {
159            type Error = Error;
160
161            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
162                ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
163                let mut array = [[[[Default::default(); N4]; N3]; N2]; N1];
164                from.f_copy_data(array.flat_mut().flat_mut().flat_mut(), N1 * N2 * N3 * N4)?;
165                Ok(array)
166            }
167        }
168
169        impl<
170                const N1: usize,
171                const N2: usize,
172                const N3: usize,
173                const N4: usize,
174                const N5: usize,
175            > TryFromCv<&tch::Tensor> for [[[[[$elem; N5]; N4]; N3]; N2]; N1]
176        {
177            type Error = Error;
178
179            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
180                ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
181                let mut array = [[[[[Default::default(); N5]; N4]; N3]; N2]; N1];
182                from.f_copy_data(
183                    array.flat_mut().flat_mut().flat_mut().flat_mut(),
184                    N1 * N2 * N3 * N4 * N5,
185                )?;
186                Ok(array)
187            }
188        }
189
190        impl<
191                const N1: usize,
192                const N2: usize,
193                const N3: usize,
194                const N4: usize,
195                const N5: usize,
196                const N6: usize,
197            > TryFromCv<&tch::Tensor> for [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
198        {
199            type Error = Error;
200
201            fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
202                ensure!(
203                    from.size()
204                        == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
205                );
206                let mut array = [[[[[[Default::default(); N6]; N5]; N4]; N3]; N2]; N1];
207                from.f_copy_data(
208                    array.flat_mut().flat_mut().flat_mut().flat_mut().flat_mut(),
209                    N1 * N2 * N3 * N4 * N5 * N6,
210                )?;
211                Ok(array)
212            }
213        }
214
215        // owned tensor to array
216
217        impl<const N: usize> TryFromCv<tch::Tensor> for [$elem; N] {
218            type Error = Error;
219
220            fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
221                Self::try_from_cv(&from)
222            }
223        }
224
225        impl<const N1: usize, const N2: usize> TryFromCv<tch::Tensor> for [[$elem; N2]; N1] {
226            type Error = Error;
227
228            fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
229                Self::try_from_cv(&from)
230            }
231        }
232
233        impl<const N1: usize, const N2: usize, const N3: usize> TryFromCv<tch::Tensor>
234            for [[[$elem; N3]; N2]; N1]
235        {
236            type Error = Error;
237
238            fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
239                Self::try_from_cv(&from)
240            }
241        }
242
243        impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
244            TryFromCv<tch::Tensor> for [[[[$elem; N4]; N3]; N2]; N1]
245        {
246            type Error = Error;
247
248            fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
249                Self::try_from_cv(&from)
250            }
251        }
252
253        impl<
254                const N1: usize,
255                const N2: usize,
256                const N3: usize,
257                const N4: usize,
258                const N5: usize,
259            > TryFromCv<tch::Tensor> for [[[[[$elem; N5]; N4]; N3]; N2]; N1]
260        {
261            type Error = Error;
262
263            fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
264                Self::try_from_cv(&from)
265            }
266        }
267
268        impl<
269                const N1: usize,
270                const N2: usize,
271                const N3: usize,
272                const N4: usize,
273                const N5: usize,
274                const N6: usize,
275            > TryFromCv<tch::Tensor> for [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
276        {
277            type Error = Error;
278
279            fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
280                Self::try_from_cv(&from)
281            }
282        }
283
284        // borrowed array to tensor
285
286        impl<const N: usize> FromCv<&[$elem; N]> for tch::Tensor {
287            fn from_cv(from: &[$elem; N]) -> Self {
288                Self::from_slice(from.as_ref())
289            }
290        }
291
292        impl<const N1: usize, const N2: usize> FromCv<&[[$elem; N2]; N1]> for tch::Tensor {
293            fn from_cv(from: &[[$elem; N2]; N1]) -> Self {
294                Self::from_slice(from.flat()).view([N1 as i64, N2 as i64])
295            }
296        }
297
298        impl<const N1: usize, const N2: usize, const N3: usize> FromCv<&[[[$elem; N3]; N2]; N1]>
299            for tch::Tensor
300        {
301            fn from_cv(from: &[[[$elem; N3]; N2]; N1]) -> Self {
302                Self::from_slice(from.flat().flat()).view([N1 as i64, N2 as i64, N3 as i64])
303            }
304        }
305
306        impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
307            FromCv<&[[[[$elem; N4]; N3]; N2]; N1]> for tch::Tensor
308        {
309            fn from_cv(from: &[[[[$elem; N4]; N3]; N2]; N1]) -> Self {
310                Self::from_slice(from.flat().flat().flat())
311                    .view([N1 as i64, N2 as i64, N3 as i64, N4 as i64])
312            }
313        }
314
315        impl<
316                const N1: usize,
317                const N2: usize,
318                const N3: usize,
319                const N4: usize,
320                const N5: usize,
321            > FromCv<&[[[[[$elem; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
322        {
323            fn from_cv(from: &[[[[[$elem; N5]; N4]; N3]; N2]; N1]) -> Self {
324                Self::from_slice(from.flat().flat().flat().flat())
325                    .view([N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64])
326            }
327        }
328
329        impl<
330                const N1: usize,
331                const N2: usize,
332                const N3: usize,
333                const N4: usize,
334                const N5: usize,
335                const N6: usize,
336            > FromCv<&[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
337        {
338            fn from_cv(from: &[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]) -> Self {
339                Self::from_slice(from.flat().flat().flat().flat().flat()).view([
340                    N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64,
341                ])
342            }
343        }
344
345        // owned array to tensor
346
347        impl<const N: usize> FromCv<[$elem; N]> for tch::Tensor {
348            fn from_cv(from: [$elem; N]) -> Self {
349                Self::from_cv(&from)
350            }
351        }
352
353        impl<const N1: usize, const N2: usize> FromCv<[[$elem; N2]; N1]> for tch::Tensor {
354            fn from_cv(from: [[$elem; N2]; N1]) -> Self {
355                Self::from_cv(&from)
356            }
357        }
358
359        impl<const N1: usize, const N2: usize, const N3: usize> FromCv<[[[$elem; N3]; N2]; N1]>
360            for tch::Tensor
361        {
362            fn from_cv(from: [[[$elem; N3]; N2]; N1]) -> Self {
363                Self::from_cv(&from)
364            }
365        }
366
367        impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
368            FromCv<[[[[$elem; N4]; N3]; N2]; N1]> for tch::Tensor
369        {
370            fn from_cv(from: [[[[$elem; N4]; N3]; N2]; N1]) -> Self {
371                Self::from_cv(&from)
372            }
373        }
374
375        impl<
376                const N1: usize,
377                const N2: usize,
378                const N3: usize,
379                const N4: usize,
380                const N5: usize,
381            > FromCv<[[[[[$elem; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
382        {
383            fn from_cv(from: [[[[[$elem; N5]; N4]; N3]; N2]; N1]) -> Self {
384                Self::from_cv(&from)
385            }
386        }
387
388        impl<
389                const N1: usize,
390                const N2: usize,
391                const N3: usize,
392                const N4: usize,
393                const N5: usize,
394                const N6: usize,
395            > FromCv<[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
396        {
397            fn from_cv(from: [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]) -> Self {
398                Self::from_cv(&from)
399            }
400        }
401    };
402}
403
404impl_from_array!(u8);
405impl_from_array!(i8);
406impl_from_array!(i16);
407impl_from_array!(i32);
408impl_from_array!(i64);
409impl_from_array!(half::f16);
410impl_from_array!(f32);
411impl_from_array!(f64);
412impl_from_array!(bool);
413
414pub use tensor_as_image::*;
415mod tensor_as_image {
416    use super::*;
417
418    /// An 2D image [Tensor](tch::Tensor) with dimension order.
419    #[derive(Debug)]
420    pub struct TchTensorAsImage
421    {
422        pub(crate) tensor: tch::Tensor,
423        pub(crate) kind: TchTensorImageShape,
424    }
425
426    /// Describes the image channel order of a [Tensor](tch::Tensor).
427    #[derive(Debug, Clone, Copy)]
428    pub enum TchTensorImageShape {
429        Whc,
430        Hwc,
431        Chw,
432        Cwh,
433    }
434
435    impl TchTensorAsImage
436    {
437        pub fn new(tensor: tch::Tensor, kind: TchTensorImageShape) -> Result<Self> {
438            let ndim = tensor.dim();
439            ensure!(
440                ndim == 3,
441                "the tensor must have 3 dimensions, but get {}",
442                ndim
443            );
444            Ok(Self { tensor, kind })
445        }
446
447        pub fn into_inner(self) -> tch::Tensor {
448            self.tensor
449        }
450
451        pub fn kind(&self) -> TchTensorImageShape {
452            self.kind
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tch;
    use crate::TryIntoCv;
    use rand::prelude::*;

    /// Round-trips a random value of array type `$ty` through a tensor and
    /// checks that both the copying (`$ty`) and the zero-copy borrowing
    /// (`&$ty`) conversions reproduce the input.
    macro_rules! check_round_trip {
        ($rng:ident, $ty:ty) => {{
            let input: $ty = $rng.gen();
            let tensor = tch::Tensor::from_cv(&input);

            let array: $ty = (&tensor).try_into_cv().unwrap();
            assert_eq!(array, input);

            let array_ref: &$ty = (&tensor).try_into_cv().unwrap();
            assert_eq!(array_ref, &input);
        }};
    }

    #[test]
    fn tensor_to_array_ref() {
        let mut rng = rand::thread_rng();

        // One case per supported rank, 1 through 6.
        check_round_trip!(rng, [f32; 3]);
        check_round_trip!(rng, [[f32; 3]; 2]);
        check_round_trip!(rng, [[[f32; 4]; 3]; 2]);
        check_round_trip!(rng, [[[[f32; 2]; 4]; 3]; 2]);
        check_round_trip!(rng, [[[[[f32; 3]; 2]; 4]; 3]; 2]);
        check_round_trip!(rng, [[[[[[f32; 2]; 3]; 2]; 4]; 3]; 2]);
    }

    #[test]
    fn borrowing_rejects_non_contiguous_tensor() {
        let input: [[f32; 3]; 2] = [[0., 1., 2.], [3., 4., 5.]];
        // A transpose has the requested shape but strided memory, so it must
        // not be reinterpreted in place as a row-major array.
        let transposed = tch::Tensor::from_cv(&input).transpose(0, 1);

        let borrowed: Result<&[[f32; 2]; 3], _> = (&transposed).try_into_cv();
        assert!(borrowed.is_err());

        // The copying conversion must still yield the transposed values.
        let copied: [[f32; 2]; 3] = (&transposed).try_into_cv().unwrap();
        assert_eq!(copied, [[0., 3.], [1., 4.], [2., 5.]]);
    }

    #[test]
    fn borrowing_rejects_kind_mismatch() {
        let tensor = tch::Tensor::from_cv(&[1f32, 2., 3.]);
        let borrowed: Result<&[f64; 3], _> = (&tensor).try_into_cv();
        assert!(borrowed.is_err());
    }

    #[test]
    fn conversions_reject_shape_mismatch() {
        let tensor = tch::Tensor::from_cv(&[1f32, 2., 3.]);

        let borrowed: Result<&[f32; 4], _> = (&tensor).try_into_cv();
        assert!(borrowed.is_err());

        let copied: Result<[f32; 4], _> = (&tensor).try_into_cv();
        assert!(copied.is_err());
    }

    #[test]
    fn borrowing_empty_tensor() {
        let input: [f32; 0] = [];
        let tensor = tch::Tensor::from_cv(&input);

        let borrowed: &[f32; 0] = (&tensor).try_into_cv().unwrap();
        assert!(borrowed.is_empty());
    }
}