1use crate::{ToCv, TryAsRefCv, TryToCv};
2use anyhow::{ensure, Error, Result};
3use slice_of_array::prelude::*;
4use std::{mem::ManuallyDrop, ops::Deref, slice};
5
/// Implements conversions between `tch::Tensor` and fixed-size, possibly nested, arrays.
///
/// `impl_from_array!(ELEM, RANK)`, where `RANK` is a literal from `1` to `6`, covers the array
/// type with `RANK` dimensions (`[ELEM; N1]`, `[[ELEM; N2]; N1]`, ...) and generates:
///
/// - `TryAsRefCv<'a, TensorAsArray<'a, ARRAY>>` for `tch::Tensor`: requires a contiguous CPU
///   tensor whose kind is `<ELEM as tch::kind::Element>::KIND` and whose size is
///   `[N1, N2, ...]`, then copies its elements into the returned `TensorAsArray` (which also
///   keeps a borrow of the tensor).
/// - `TryToCv<ARRAY>` for `tch::Tensor`: requires the size `[N1, N2, ...]`, then copies the
///   elements out with `Tensor::f_copy_data`.
/// - `ToCv<tch::Tensor>` for `ARRAY`: builds a tensor from the flattened elements and views it
///   with the size `[N1, N2, ...]`.
///
/// The per-rank arms only spell out the array type; the shared `@impl` arm holds the code.
macro_rules! impl_from_array {
    // Shared implementation. Listed first so that internal `@impl` calls are matched here
    // before the public arms try to parse `@impl` as a type.
    //
    // - `$elem`:     element type.
    // - `$array`:    the nested array type, written with the const generics in `$dim`.
    // - `$dim`:      const generic lengths, outermost dimension first.
    // - `$init`:     expression building a default-filled `$array`.
    // - `$flat`:     method chain from `&$array` to `&[$elem]`.
    // - `$flat_mut`: method chain from `&mut $array` to `&mut [$elem]`.
    (
        @impl $elem:ty,
        $array:ty,
        [$($dim:ident),+],
        $init:expr,
        [$($flat:ident),+],
        [$($flat_mut:ident),+]
    ) => {
        impl<'a, $(const $dim: usize),+> TryAsRefCv<'a, TensorAsArray<'a, $array>> for tch::Tensor {
            type Error = Error;

            fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, $array>, Self::Error> {
                let device = self.device();
                ensure!(
                    device == tch::Device::Cpu,
                    "the tensor must be on the CPU, but got {:?}",
                    device
                );
                let kind = self.kind();
                let expected_kind = <$elem as tch::kind::Element>::KIND;
                ensure!(
                    kind == expected_kind,
                    "the tensor kind must be {:?}, but got {:?}",
                    expected_kind,
                    kind
                );
                let expected_size = [$($dim as i64),+];
                let size = self.size();
                ensure!(
                    size == expected_size,
                    "the tensor size must be {:?}, but got {:?}",
                    expected_size,
                    size
                );
                // The storage is read linearly below, which is only meaningful for a row-major
                // contiguous layout: a transposed view would yield wrongly ordered data and an
                // expanded (stride 0) view would read past the end of its storage.
                ensure!(self.is_contiguous(), "the tensor must be contiguous");

                let numel: usize = [$($dim),+].iter().product();
                let elements: &'a [$elem] = if numel == 0 {
                    // An empty tensor may report a null data pointer, which
                    // `slice::from_raw_parts` forbids even for a zero length.
                    &[]
                } else {
                    // SAFETY: `self` is a contiguous CPU tensor of kind `$elem` holding exactly
                    // `numel` elements, so its data pointer addresses `numel` consecutive,
                    // initialized, aligned `$elem` values, kept alive by the `'a` borrow of
                    // `self`.
                    unsafe { slice::from_raw_parts(self.data_ptr() as *const $elem, numel) }
                };

                // `$array` is `$elem` nested in arrays of the lengths in `$dim`; nested arrays
                // have no padding, so it has the size and alignment of `[$elem; numel]`.
                debug_assert_eq!(std::mem::size_of::<$array>(), std::mem::size_of_val(elements));
                // SAFETY: `elements` holds exactly the `numel` values making up one `$array`,
                // in row-major order, and its pointer is aligned for `$elem` (dangling but
                // aligned when empty, where `$array` is zero-sized).
                let array: &'a $array = unsafe { &*(elements.as_ptr() as *const $array) };

                Ok(TensorAsArray {
                    data: ManuallyDrop::new(*array),
                    _tensor: self,
                })
            }
        }

        impl<$(const $dim: usize),+> TryToCv<$array> for tch::Tensor {
            type Error = Error;

            fn try_to_cv(&self) -> Result<$array, Self::Error> {
                let expected_size = [$($dim as i64),+];
                let size = self.size();
                ensure!(
                    size == expected_size,
                    "the tensor size must be {:?}, but got {:?}",
                    expected_size,
                    size
                );
                let mut array: $array = $init;
                // Element-kind checking is left to `f_copy_data` (NOTE(review): assumed to
                // reject a mismatched kind — confirm against the tch version in use).
                let numel: usize = [$($dim),+].iter().product();
                self.f_copy_data(array $(.$flat_mut())+, numel)?;
                Ok(array)
            }
        }

        impl<$(const $dim: usize),+> ToCv<tch::Tensor> for $array {
            fn to_cv(&self) -> tch::Tensor {
                tch::Tensor::from_slice(self $(.$flat())+).view([$($dim as i64),+])
            }
        }
    };

    ($elem:ty, 1) => {
        impl_from_array!(
            @impl $elem,
            [$elem; N1],
            [N1],
            [Default::default(); N1],
            [as_ref],
            [as_mut]
        );
    };

    ($elem:ty, 2) => {
        impl_from_array!(
            @impl $elem,
            [[$elem; N2]; N1],
            [N1, N2],
            [[Default::default(); N2]; N1],
            [flat],
            [flat_mut]
        );
    };

    ($elem:ty, 3) => {
        impl_from_array!(
            @impl $elem,
            [[[$elem; N3]; N2]; N1],
            [N1, N2, N3],
            [[[Default::default(); N3]; N2]; N1],
            [flat, flat],
            [flat_mut, flat_mut]
        );
    };

    ($elem:ty, 4) => {
        impl_from_array!(
            @impl $elem,
            [[[[$elem; N4]; N3]; N2]; N1],
            [N1, N2, N3, N4],
            [[[[Default::default(); N4]; N3]; N2]; N1],
            [flat, flat, flat],
            [flat_mut, flat_mut, flat_mut]
        );
    };

    ($elem:ty, 5) => {
        impl_from_array!(
            @impl $elem,
            [[[[[$elem; N5]; N4]; N3]; N2]; N1],
            [N1, N2, N3, N4, N5],
            [[[[[Default::default(); N5]; N4]; N3]; N2]; N1],
            [flat, flat, flat, flat],
            [flat_mut, flat_mut, flat_mut, flat_mut]
        );
    };

    ($elem:ty, 6) => {
        impl_from_array!(
            @impl $elem,
            [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1],
            [N1, N2, N3, N4, N5, N6],
            [[[[[[Default::default(); N6]; N5]; N4]; N3]; N2]; N1],
            [flat, flat, flat, flat, flat],
            [flat_mut, flat_mut, flat_mut, flat_mut, flat_mut]
        );
    };
}
343
344impl_from_array!(u8, 1);
346impl_from_array!(u8, 2);
347impl_from_array!(u8, 3);
348impl_from_array!(u8, 4);
349impl_from_array!(u8, 5);
350impl_from_array!(u8, 6);
351
352impl_from_array!(i8, 1);
353impl_from_array!(i8, 2);
354impl_from_array!(i8, 3);
355impl_from_array!(i8, 4);
356impl_from_array!(i8, 5);
357impl_from_array!(i8, 6);
358
359impl_from_array!(i16, 1);
360impl_from_array!(i16, 2);
361impl_from_array!(i16, 3);
362impl_from_array!(i16, 4);
363impl_from_array!(i16, 5);
364impl_from_array!(i16, 6);
365
366impl_from_array!(i32, 1);
367impl_from_array!(i32, 2);
368impl_from_array!(i32, 3);
369impl_from_array!(i32, 4);
370impl_from_array!(i32, 5);
371impl_from_array!(i32, 6);
372
373impl_from_array!(i64, 1);
374impl_from_array!(i64, 2);
375impl_from_array!(i64, 3);
376impl_from_array!(i64, 4);
377impl_from_array!(i64, 5);
378impl_from_array!(i64, 6);
379
380impl_from_array!(half::f16, 1);
381impl_from_array!(half::f16, 2);
382impl_from_array!(half::f16, 3);
383impl_from_array!(half::f16, 4);
384impl_from_array!(half::f16, 5);
385impl_from_array!(half::f16, 6);
386
387impl_from_array!(f32, 1);
388impl_from_array!(f32, 2);
389impl_from_array!(f32, 3);
390impl_from_array!(f32, 4);
391impl_from_array!(f32, 5);
392impl_from_array!(f32, 6);
393
394impl_from_array!(f64, 1);
395impl_from_array!(f64, 2);
396impl_from_array!(f64, 3);
397impl_from_array!(f64, 4);
398impl_from_array!(f64, 5);
399impl_from_array!(f64, 6);
400
401impl_from_array!(bool, 1);
402impl_from_array!(bool, 2);
403impl_from_array!(bool, 3);
404impl_from_array!(bool, 4);
405impl_from_array!(bool, 5);
406impl_from_array!(bool, 6);
407
408pub use tensors::*;
409mod tensors {
410 use super::*;
411
412 #[derive(Debug)]
414 pub struct TensorAsArray<'a, T> {
415 pub(crate) data: ManuallyDrop<T>,
416 pub(crate) _tensor: &'a tch::Tensor,
417 }
418
419 impl<'a, T> Drop for TensorAsArray<'a, T> {
420 fn drop(&mut self) {
421 unsafe {
422 ManuallyDrop::drop(&mut self.data);
423 }
424 }
425 }
426
427 impl<'a, T> AsRef<T> for TensorAsArray<'a, T> {
428 fn as_ref(&self) -> &T {
429 &self.data
430 }
431 }
432
433 impl<'a, T> Deref for TensorAsArray<'a, T> {
434 type Target = T;
435
436 fn deref(&self) -> &Self::Target {
437 &self.data
438 }
439 }
440
441 #[derive(Debug)]
443 pub struct TchTensorAsImage {
444 pub(crate) tensor: tch::Tensor,
445 pub(crate) kind: TchTensorImageShape,
446 }
447
448 #[derive(Debug, Clone, Copy)]
450 pub enum TchTensorImageShape {
451 Whc,
452 Hwc,
453 Chw,
454 Cwh,
455 }
456
457 impl TchTensorAsImage {
458 pub fn new(tensor: tch::Tensor, kind: TchTensorImageShape) -> Result<Self> {
459 let ndim = tensor.dim();
460 ensure!(
461 ndim == 3,
462 "the tensor must have 3 dimensions, but get {}",
463 ndim
464 );
465 Ok(Self { tensor, kind })
466 }
467
468 pub fn into_inner(self) -> tch::Tensor {
469 self.tensor
470 }
471
472 pub fn kind(&self) -> TchTensorImageShape {
473 self.kind
474 }
475
476 pub fn try_to_cv<T>(&self) -> Result<T, <Self as TryToCv<T>>::Error>
477 where
478 Self: TryToCv<T>,
479 {
480 TryToCv::try_to_cv(self)
481 }
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToCv, TryAsRefCv, TryToCv};
    use rand::prelude::*;

    /// Converts a random `$ty` into a tensor, then back through both `try_to_cv` (copy) and
    /// `try_as_ref_cv` (`TensorAsArray`), asserting both results equal the input.
    macro_rules! assert_round_trip {
        ($rng:ident, $ty:ty) => {{
            let input: $ty = $rng.gen();
            let tensor: tch::Tensor = input.to_cv();

            let array: $ty = tensor.try_to_cv().unwrap();
            assert_eq!(array, input);

            let array_wrapper: TensorAsArray<$ty> = (&tensor).try_as_ref_cv().unwrap();
            assert_eq!(*array_wrapper, input);
        }};
    }

    #[test]
    fn tensor_to_array_ref() {
        let mut rng = rand::thread_rng();

        assert_round_trip!(rng, [f32; 3]);
        assert_round_trip!(rng, [[f32; 3]; 2]);
        assert_round_trip!(rng, [[[f32; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[f32; 2]; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[[f32; 3]; 2]; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[[[f32; 2]; 3]; 2]; 4]; 3]; 2]);

        // Other element kinds go through the same generated code paths.
        assert_round_trip!(rng, [[f64; 3]; 2]);
        assert_round_trip!(rng, [[i32; 2]; 3]);
        assert_round_trip!(rng, [u8; 5]);
    }

    #[test]
    fn tensor_to_array_rejects_mismatched_kind_or_size() {
        let tensor: tch::Tensor = [1f64, 2., 3.].to_cv();

        let wrong_kind: Result<TensorAsArray<[f32; 3]>, _> = tensor.try_as_ref_cv();
        assert!(wrong_kind.is_err());

        let wrong_size: Result<TensorAsArray<[f64; 2]>, _> = tensor.try_as_ref_cv();
        assert!(wrong_size.is_err());

        let wrong_size: Result<[f64; 2], _> = tensor.try_to_cv();
        assert!(wrong_size.is_err());
    }

    #[test]
    fn tensor_to_array_copies_non_contiguous() {
        let tensor: tch::Tensor = [[1f32, 2., 3.], [4., 5., 6.]].to_cv();
        // Size [3, 2], but the underlying storage is still laid out as the [2, 3] original.
        let transposed = tensor.transpose(0, 1);

        let array: [[f32; 2]; 3] = transposed.try_to_cv().unwrap();
        assert_eq!(array, [[1., 4.], [2., 5.], [3., 6.]]);
    }
}