1use crate::tch;
2use crate::{common::*, FromCv, TryFromCv};
3use slice_of_array::prelude::*;
4
5macro_rules! impl_from_array {
6 ($elem:ty) => {
7 impl<'a, const N: usize> TryFromCv<&'a tch::Tensor> for &'a [$elem; N] {
10 type Error = Error;
11
12 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
13 ensure!(from.device() == tch::Device::Cpu);
14 ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
15 ensure!(from.size() == &[N as i64]);
16 let slice: &[$elem] =
17 unsafe { slice::from_raw_parts(from.data_ptr() as *mut $elem, N) };
18 Ok(slice.as_array())
19 }
20 }
21
22 impl<'a, const N1: usize, const N2: usize> TryFromCv<&'a tch::Tensor>
23 for &'a [[$elem; N2]; N1]
24 {
25 type Error = Error;
26
27 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
28 ensure!(from.device() == tch::Device::Cpu);
29 ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
30 ensure!(from.size() == &[N1 as i64, N2 as i64]);
31 let slice: &[$elem] =
32 unsafe { slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2) };
33 Ok(slice.nest().as_array())
34 }
35 }
36
37 impl<'a, const N1: usize, const N2: usize, const N3: usize> TryFromCv<&'a tch::Tensor>
38 for &'a [[[$elem; N3]; N2]; N1]
39 {
40 type Error = Error;
41
42 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
43 ensure!(from.device() == tch::Device::Cpu);
44 ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
45 ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64]);
46 let slice: &[$elem] =
47 unsafe { slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2 * N3) };
48 Ok(slice.nest().nest().as_array())
49 }
50 }
51
52 impl<'a, const N1: usize, const N2: usize, const N3: usize, const N4: usize>
53 TryFromCv<&'a tch::Tensor> for &'a [[[[$elem; N4]; N3]; N2]; N1]
54 {
55 type Error = Error;
56
57 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
58 ensure!(from.device() == tch::Device::Cpu);
59 ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
60 ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
61 let slice: &[$elem] = unsafe {
62 slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2 * N3 * N4)
63 };
64 Ok(slice.nest().nest().nest().as_array())
65 }
66 }
67
68 impl<
69 'a,
70 const N1: usize,
71 const N2: usize,
72 const N3: usize,
73 const N4: usize,
74 const N5: usize,
75 > TryFromCv<&'a tch::Tensor> for &'a [[[[[$elem; N5]; N4]; N3]; N2]; N1]
76 {
77 type Error = Error;
78
79 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
80 ensure!(from.device() == tch::Device::Cpu);
81 ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
82 ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
83 let slice: &[$elem] = unsafe {
84 slice::from_raw_parts(from.data_ptr() as *mut $elem, N1 * N2 * N3 * N4 * N5)
85 };
86 Ok(slice.nest().nest().nest().nest().as_array())
87 }
88 }
89
90 impl<
91 'a,
92 const N1: usize,
93 const N2: usize,
94 const N3: usize,
95 const N4: usize,
96 const N5: usize,
97 const N6: usize,
98 > TryFromCv<&'a tch::Tensor> for &'a [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
99 {
100 type Error = Error;
101
102 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
103 ensure!(from.device() == tch::Device::Cpu);
104 ensure!(from.kind() == <$elem as tch::kind::Element>::KIND);
105 ensure!(
106 from.size()
107 == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
108 );
109 let slice: &[$elem] = unsafe {
110 slice::from_raw_parts(
111 from.data_ptr() as *mut $elem,
112 N1 * N2 * N3 * N4 * N5 * N6,
113 )
114 };
115 Ok(slice.nest().nest().nest().nest().nest().as_array())
116 }
117 }
118
119 impl<const N: usize> TryFromCv<&tch::Tensor> for [$elem; N] {
122 type Error = Error;
123
124 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
125 ensure!(from.size() == &[N as i64]);
126 let mut array = [Default::default(); N];
127 from.f_copy_data(array.as_mut(), N)?;
128 Ok(array)
129 }
130 }
131
132 impl<const N1: usize, const N2: usize> TryFromCv<&tch::Tensor> for [[$elem; N2]; N1] {
133 type Error = Error;
134
135 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
136 ensure!(from.size() == &[N1 as i64, N2 as i64]);
137 let mut array = [[Default::default(); N2]; N1];
138 from.f_copy_data(array.flat_mut(), N1 * N2)?;
139 Ok(array)
140 }
141 }
142
143 impl<const N1: usize, const N2: usize, const N3: usize> TryFromCv<&tch::Tensor>
144 for [[[$elem; N3]; N2]; N1]
145 {
146 type Error = Error;
147
148 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
149 ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64]);
150 let mut array = [[[Default::default(); N3]; N2]; N1];
151 from.f_copy_data(array.flat_mut().flat_mut(), N1 * N2 * N3)?;
152 Ok(array)
153 }
154 }
155
156 impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
157 TryFromCv<&tch::Tensor> for [[[[$elem; N4]; N3]; N2]; N1]
158 {
159 type Error = Error;
160
161 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
162 ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
163 let mut array = [[[[Default::default(); N4]; N3]; N2]; N1];
164 from.f_copy_data(array.flat_mut().flat_mut().flat_mut(), N1 * N2 * N3 * N4)?;
165 Ok(array)
166 }
167 }
168
169 impl<
170 const N1: usize,
171 const N2: usize,
172 const N3: usize,
173 const N4: usize,
174 const N5: usize,
175 > TryFromCv<&tch::Tensor> for [[[[[$elem; N5]; N4]; N3]; N2]; N1]
176 {
177 type Error = Error;
178
179 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
180 ensure!(from.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
181 let mut array = [[[[[Default::default(); N5]; N4]; N3]; N2]; N1];
182 from.f_copy_data(
183 array.flat_mut().flat_mut().flat_mut().flat_mut(),
184 N1 * N2 * N3 * N4 * N5,
185 )?;
186 Ok(array)
187 }
188 }
189
190 impl<
191 const N1: usize,
192 const N2: usize,
193 const N3: usize,
194 const N4: usize,
195 const N5: usize,
196 const N6: usize,
197 > TryFromCv<&tch::Tensor> for [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
198 {
199 type Error = Error;
200
201 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
202 ensure!(
203 from.size()
204 == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
205 );
206 let mut array = [[[[[[Default::default(); N6]; N5]; N4]; N3]; N2]; N1];
207 from.f_copy_data(
208 array.flat_mut().flat_mut().flat_mut().flat_mut().flat_mut(),
209 N1 * N2 * N3 * N4 * N5 * N6,
210 )?;
211 Ok(array)
212 }
213 }
214
215 impl<const N: usize> TryFromCv<tch::Tensor> for [$elem; N] {
218 type Error = Error;
219
220 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
221 Self::try_from_cv(&from)
222 }
223 }
224
225 impl<const N1: usize, const N2: usize> TryFromCv<tch::Tensor> for [[$elem; N2]; N1] {
226 type Error = Error;
227
228 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
229 Self::try_from_cv(&from)
230 }
231 }
232
233 impl<const N1: usize, const N2: usize, const N3: usize> TryFromCv<tch::Tensor>
234 for [[[$elem; N3]; N2]; N1]
235 {
236 type Error = Error;
237
238 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
239 Self::try_from_cv(&from)
240 }
241 }
242
243 impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
244 TryFromCv<tch::Tensor> for [[[[$elem; N4]; N3]; N2]; N1]
245 {
246 type Error = Error;
247
248 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
249 Self::try_from_cv(&from)
250 }
251 }
252
253 impl<
254 const N1: usize,
255 const N2: usize,
256 const N3: usize,
257 const N4: usize,
258 const N5: usize,
259 > TryFromCv<tch::Tensor> for [[[[[$elem; N5]; N4]; N3]; N2]; N1]
260 {
261 type Error = Error;
262
263 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
264 Self::try_from_cv(&from)
265 }
266 }
267
268 impl<
269 const N1: usize,
270 const N2: usize,
271 const N3: usize,
272 const N4: usize,
273 const N5: usize,
274 const N6: usize,
275 > TryFromCv<tch::Tensor> for [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
276 {
277 type Error = Error;
278
279 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
280 Self::try_from_cv(&from)
281 }
282 }
283
284 impl<const N: usize> FromCv<&[$elem; N]> for tch::Tensor {
287 fn from_cv(from: &[$elem; N]) -> Self {
288 Self::from_slice(from.as_ref())
289 }
290 }
291
292 impl<const N1: usize, const N2: usize> FromCv<&[[$elem; N2]; N1]> for tch::Tensor {
293 fn from_cv(from: &[[$elem; N2]; N1]) -> Self {
294 Self::from_slice(from.flat()).view([N1 as i64, N2 as i64])
295 }
296 }
297
298 impl<const N1: usize, const N2: usize, const N3: usize> FromCv<&[[[$elem; N3]; N2]; N1]>
299 for tch::Tensor
300 {
301 fn from_cv(from: &[[[$elem; N3]; N2]; N1]) -> Self {
302 Self::from_slice(from.flat().flat()).view([N1 as i64, N2 as i64, N3 as i64])
303 }
304 }
305
306 impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
307 FromCv<&[[[[$elem; N4]; N3]; N2]; N1]> for tch::Tensor
308 {
309 fn from_cv(from: &[[[[$elem; N4]; N3]; N2]; N1]) -> Self {
310 Self::from_slice(from.flat().flat().flat())
311 .view([N1 as i64, N2 as i64, N3 as i64, N4 as i64])
312 }
313 }
314
315 impl<
316 const N1: usize,
317 const N2: usize,
318 const N3: usize,
319 const N4: usize,
320 const N5: usize,
321 > FromCv<&[[[[[$elem; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
322 {
323 fn from_cv(from: &[[[[[$elem; N5]; N4]; N3]; N2]; N1]) -> Self {
324 Self::from_slice(from.flat().flat().flat().flat())
325 .view([N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64])
326 }
327 }
328
329 impl<
330 const N1: usize,
331 const N2: usize,
332 const N3: usize,
333 const N4: usize,
334 const N5: usize,
335 const N6: usize,
336 > FromCv<&[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
337 {
338 fn from_cv(from: &[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]) -> Self {
339 Self::from_slice(from.flat().flat().flat().flat().flat()).view([
340 N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64,
341 ])
342 }
343 }
344
345 impl<const N: usize> FromCv<[$elem; N]> for tch::Tensor {
348 fn from_cv(from: [$elem; N]) -> Self {
349 Self::from_cv(&from)
350 }
351 }
352
353 impl<const N1: usize, const N2: usize> FromCv<[[$elem; N2]; N1]> for tch::Tensor {
354 fn from_cv(from: [[$elem; N2]; N1]) -> Self {
355 Self::from_cv(&from)
356 }
357 }
358
359 impl<const N1: usize, const N2: usize, const N3: usize> FromCv<[[[$elem; N3]; N2]; N1]>
360 for tch::Tensor
361 {
362 fn from_cv(from: [[[$elem; N3]; N2]; N1]) -> Self {
363 Self::from_cv(&from)
364 }
365 }
366
367 impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
368 FromCv<[[[[$elem; N4]; N3]; N2]; N1]> for tch::Tensor
369 {
370 fn from_cv(from: [[[[$elem; N4]; N3]; N2]; N1]) -> Self {
371 Self::from_cv(&from)
372 }
373 }
374
375 impl<
376 const N1: usize,
377 const N2: usize,
378 const N3: usize,
379 const N4: usize,
380 const N5: usize,
381 > FromCv<[[[[[$elem; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
382 {
383 fn from_cv(from: [[[[[$elem; N5]; N4]; N3]; N2]; N1]) -> Self {
384 Self::from_cv(&from)
385 }
386 }
387
388 impl<
389 const N1: usize,
390 const N2: usize,
391 const N3: usize,
392 const N4: usize,
393 const N5: usize,
394 const N6: usize,
395 > FromCv<[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
396 {
397 fn from_cv(from: [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]) -> Self {
398 Self::from_cv(&from)
399 }
400 }
401 };
402}
403
404impl_from_array!(u8);
405impl_from_array!(i8);
406impl_from_array!(i16);
407impl_from_array!(i32);
408impl_from_array!(i64);
409impl_from_array!(half::f16);
410impl_from_array!(f32);
411impl_from_array!(f64);
412impl_from_array!(bool);
413
414pub use tensor_as_image::*;
415mod tensor_as_image {
416 use super::*;
417
418 #[derive(Debug)]
420 pub struct TchTensorAsImage
421 {
422 pub(crate) tensor: tch::Tensor,
423 pub(crate) kind: TchTensorImageShape,
424 }
425
426 #[derive(Debug, Clone, Copy)]
428 pub enum TchTensorImageShape {
429 Whc,
430 Hwc,
431 Chw,
432 Cwh,
433 }
434
435 impl TchTensorAsImage
436 {
437 pub fn new(tensor: tch::Tensor, kind: TchTensorImageShape) -> Result<Self> {
438 let ndim = tensor.dim();
439 ensure!(
440 ndim == 3,
441 "the tensor must have 3 dimensions, but get {}",
442 ndim
443 );
444 Ok(Self { tensor, kind })
445 }
446
447 pub fn into_inner(self) -> tch::Tensor {
448 self.tensor
449 }
450
451 pub fn kind(&self) -> TchTensorImageShape {
452 self.kind
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tch;
    use crate::TryIntoCv;
    use rand::prelude::*;

    /// Round-trips a random array of type `$ty` through a tensor and checks
    /// the copying, borrowing, and by-value tensor-to-array conversions, plus
    /// both array-to-tensor conversions.
    macro_rules! check_round_trip {
        ($rng:expr, $ty:ty) => {{
            let input: $ty = $rng.gen();
            let tensor = tch::Tensor::from_cv(&input);

            let array: $ty = (&tensor).try_into_cv().unwrap();
            assert_eq!(array, input);

            let array_ref: &$ty = (&tensor).try_into_cv().unwrap();
            assert_eq!(array_ref, &input);

            let by_value: $ty = tensor.shallow_clone().try_into_cv().unwrap();
            assert_eq!(by_value, input);

            let from_owned = tch::Tensor::from_cv(input);
            let array: $ty = (&from_owned).try_into_cv().unwrap();
            assert_eq!(array, input);
        }};
    }

    #[test]
    fn tensor_to_array_ref() {
        let mut rng = rand::thread_rng();

        // Every supported rank, with mixed extents so axis-order mistakes show.
        check_round_trip!(rng, [f32; 3]);
        check_round_trip!(rng, [[f32; 3]; 2]);
        check_round_trip!(rng, [[[f32; 4]; 3]; 2]);
        check_round_trip!(rng, [[[[f32; 2]; 4]; 3]; 2]);
        check_round_trip!(rng, [[[[[f32; 3]; 2]; 4]; 3]; 2]);
        check_round_trip!(rng, [[[[[[f32; 2]; 3]; 2]; 4]; 3]; 2]);

        // A few other element types.
        check_round_trip!(rng, [[u8; 4]; 3]);
        check_round_trip!(rng, [[[i64; 2]; 2]; 2]);
        check_round_trip!(rng, [[bool; 3]; 2]);
    }

    #[test]
    fn non_contiguous_tensor_to_array() {
        let input: [[f32; 3]; 2] = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let transposed = tch::Tensor::from_cv(&input).transpose(0, 1);

        // Copying conversions must honour the view's strides.
        let array: [[f32; 2]; 3] = (&transposed).try_into_cv().unwrap();
        assert_eq!(array, [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
    }

    #[test]
    fn mismatched_tensor_is_rejected() {
        let tensor = tch::Tensor::from_cv(&[1.0f32, 2.0, 3.0]);

        let wrong_shape: Result<[f32; 4], _> = (&tensor).try_into_cv();
        assert!(wrong_shape.is_err());
        let wrong_shape_ref: Result<&[f32; 4], _> = (&tensor).try_into_cv();
        assert!(wrong_shape_ref.is_err());

        let wrong_kind: Result<[f64; 3], _> = (&tensor).try_into_cv();
        assert!(wrong_kind.is_err());
        let wrong_kind_ref: Result<&[f64; 3], _> = (&tensor).try_into_cv();
        assert!(wrong_kind_ref.is_err());
    }
}