// cv_convert_fork/with_tch_ndarray.rs

1use crate::ndarray as nd;
2use crate::tch;
3use crate::{common::*, FromCv, TryFromCv};
4
5use to_ndarray_shape::*;
6
7mod to_ndarray_shape {
8    use super::*;
9
10    pub trait ToNdArrayShape<D>
11    where
12        Self::Output: Sized + Into<nd::StrideShape<D>>,
13    {
14        type Output;
15        type Error;
16
17        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error>;
18    }
19
20    impl ToNdArrayShape<nd::IxDyn> for Vec<i64> {
21        type Output = Vec<usize>;
22        type Error = Error;
23
24        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
25            let size: Vec<_> = self.iter().map(|&dim| dim as usize).collect();
26            Ok(size)
27        }
28    }
29
30    impl ToNdArrayShape<nd::Ix0> for Vec<i64> {
31        type Output = [usize; 0];
32        type Error = Error;
33
34        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
35            ensure!(
36                self.is_empty(),
37                "empty empty tensor dimension, but get {:?}",
38                self
39            );
40            Ok([])
41        }
42    }
43
44    impl ToNdArrayShape<nd::Ix1> for Vec<i64> {
45        type Output = [usize; 1];
46        type Error = Error;
47
48        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
49            let shape = match self.as_slice() {
50                &[s0] => [s0 as usize],
51                other => bail!("expect one dimension, but get {:?}", other),
52            };
53            Ok(shape)
54        }
55    }
56
57    impl ToNdArrayShape<nd::Ix2> for Vec<i64> {
58        type Output = [usize; 2];
59        type Error = Error;
60
61        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
62            let shape = match self.as_slice() {
63                &[s0, s1] => [s0 as usize, s1 as usize],
64                other => bail!("expect one dimension, but get {:?}", other),
65            };
66            Ok(shape)
67        }
68    }
69
70    impl ToNdArrayShape<nd::Ix3> for Vec<i64> {
71        type Output = [usize; 3];
72        type Error = Error;
73
74        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
75            let shape = match self.as_slice() {
76                &[s0, s1, s2] => [s0 as usize, s1 as usize, s2 as usize],
77                other => bail!("expect one dimension, but get {:?}", other),
78            };
79            Ok(shape)
80        }
81    }
82
83    impl ToNdArrayShape<nd::Ix4> for Vec<i64> {
84        type Output = [usize; 4];
85        type Error = Error;
86
87        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
88            let shape = match self.as_slice() {
89                &[s0, s1, s2, s3] => [s0 as usize, s1 as usize, s2 as usize, s3 as usize],
90                other => bail!("expect one dimension, but get {:?}", other),
91            };
92            Ok(shape)
93        }
94    }
95
96    impl ToNdArrayShape<nd::Ix5> for Vec<i64> {
97        type Output = [usize; 5];
98        type Error = Error;
99
100        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
101            let shape = match self.as_slice() {
102                &[s0, s1, s2, s3, s4] => [
103                    s0 as usize,
104                    s1 as usize,
105                    s2 as usize,
106                    s3 as usize,
107                    s4 as usize,
108                ],
109                other => bail!("expect one dimension, but get {:?}", other),
110            };
111            Ok(shape)
112        }
113    }
114
115    impl ToNdArrayShape<nd::Ix6> for Vec<i64> {
116        type Output = [usize; 6];
117        type Error = Error;
118
119        fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
120            let shape = match self.as_slice() {
121                &[s0, s1, s2, s3, s4, s5] => [
122                    s0 as usize,
123                    s1 as usize,
124                    s2 as usize,
125                    s3 as usize,
126                    s4 as usize,
127                    s5 as usize,
128                ],
129                other => bail!("expect one dimension, but get {:?}", other),
130            };
131            Ok(shape)
132        }
133    }
134}
135
136impl<A, D> TryFromCv<tch::Tensor> for nd::Array<A, D>
137where
138    D: nd::Dimension,
139    A: tch::kind::Element,
140    Vec<A>: TryFrom<tch::Tensor, Error = tch::TchError>,
141    Vec<i64>: ToNdArrayShape<D, Error = Error>,
142{
143    type Error = Error;
144
145    fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
146        // check element type consistency
147        ensure!(
148            from.kind() == A::KIND,
149            "tensor with kind {:?} cannot converted to array with type {:?}",
150            from.kind(),
151            A::KIND
152        );
153
154        let shape = from.size();
155        let elems = Vec::<A>::try_from(from.flatten(0, -1))?;
156        let array_shape = shape.to_ndarray_shape()?;
157        let array = Self::from_shape_vec(array_shape, elems)?;
158        Ok(array)
159    }
160}
161
162impl<A, D> TryFromCv<&tch::Tensor> for nd::Array<A, D>
163where
164    D: nd::Dimension,
165    A: tch::kind::Element,
166    Vec<A>: TryFrom<tch::Tensor, Error = tch::TchError>,
167    Vec<i64>: ToNdArrayShape<D, Error = Error>,
168{
169    type Error = Error;
170
171    fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
172        Self::try_from_cv(from.shallow_clone())
173    }
174}
175
176impl<A, S, D> FromCv<&nd::ArrayBase<S, D>> for tch::Tensor
177where
178    A: tch::kind::Element + Clone,
179    S: nd::RawData<Elem = A> + nd::Data,
180    D: nd::Dimension,
181{
182    fn from_cv(from: &nd::ArrayBase<S, D>) -> Self {
183        let shape: Vec<_> = from.shape().iter().map(|&s| s as i64).collect();
184
185        match from.as_slice() {
186            Some(slice) => tch::Tensor::from_slice(slice).view(shape.as_slice()),
187            None => {
188                let elems: Vec<_> = from.iter().cloned().collect();
189                tch::Tensor::from_slice(&elems).view(shape.as_slice())
190            }
191        }
192    }
193}
194
195impl<A, S, D> FromCv<nd::ArrayBase<S, D>> for tch::Tensor
196where
197    A: tch::kind::Element + Clone,
198    S: nd::RawData<Elem = A> + nd::Data,
199    D: nd::Dimension,
200{
201    fn from_cv(from: nd::ArrayBase<S, D>) -> Self {
202        Self::from_cv(&from)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tch::{self, IndexOp};
    use crate::TryIntoCv;
    use itertools::{iproduct, izip};
    use rand::prelude::*;

    #[test]
    fn tensor_to_ndarray_conversion() -> Result<()> {
        // ArrayD
        {
            let s0 = 3;
            let s1 = 4;
            let s2 = 5;

            let tensor = tch::Tensor::randn([s0, s1, s2], tch::kind::FLOAT_CPU);
            let array: nd::ArrayD<f32> = (&tensor).try_into_cv()?;

            let is_correct = iproduct!(0..s0, 0..s1, 0..s2).all(|(i0, i1, i2)| {
                let lhs: f32 = tensor.i((i0, i1, i2)).try_into().unwrap();
                let rhs = array[[i0 as usize, i1 as usize, i2 as usize]];
                lhs == rhs
            });

            ensure!(is_correct, "value mismatch");
        }

        // Array0
        {
            let tensor = tch::Tensor::randn([], tch::kind::FLOAT_CPU);
            let array: nd::Array0<f32> = (&tensor).try_into_cv()?;
            let lhs: f32 = tensor.try_into().unwrap();
            let rhs = array[()];
            ensure!(lhs == rhs, "value mismatch");
        }

        // Array1
        {
            let s0 = 10;
            let tensor = tch::Tensor::randn([s0], tch::kind::FLOAT_CPU);
            let array: nd::Array1<f32> = (&tensor).try_into_cv()?;

            let is_correct = (0..s0).all(|ind| {
                let lhs: f32 = tensor.i((ind,)).try_into().unwrap();
                let rhs = array[ind as usize];
                lhs == rhs
            });

            ensure!(is_correct, "value mismatch");
        }

        // Array2
        {
            let s0 = 3;
            let s1 = 5;

            let tensor = tch::Tensor::randn([s0, s1], tch::kind::FLOAT_CPU);
            let array: nd::Array2<f32> = (&tensor).try_into_cv()?;

            let is_correct = iproduct!(0..s0, 0..s1).all(|(i0, i1)| {
                let lhs: f32 = tensor.i((i0, i1)).try_into().unwrap();
                let rhs = array[[i0 as usize, i1 as usize]];
                lhs == rhs
            });

            ensure!(is_correct, "value mismatch");
        }

        // Array3
        {
            let s0 = 3;
            let s1 = 5;
            let s2 = 7;

            let tensor = tch::Tensor::randn([s0, s1, s2], tch::kind::FLOAT_CPU);
            let array: nd::Array3<f32> = (&tensor).try_into_cv()?;

            let is_correct = iproduct!(0..s0, 0..s1, 0..s2).all(|(i0, i1, i2)| {
                let lhs: f32 = tensor.i((i0, i1, i2)).try_into().unwrap();
                let rhs = array[[i0 as usize, i1 as usize, i2 as usize]];
                lhs == rhs
            });

            ensure!(is_correct, "value mismatch");
        }

        // Array4
        {
            let s0 = 3;
            let s1 = 5;
            let s2 = 7;
            let s3 = 11;

            let tensor = tch::Tensor::randn([s0, s1, s2, s3], tch::kind::FLOAT_CPU);
            let array: nd::Array4<f32> = (&tensor).try_into_cv()?;

            let is_correct = iproduct!(0..s0, 0..s1, 0..s2, 0..s3).all(|(i0, i1, i2, i3)| {
                let lhs: f32 = tensor.i((i0, i1, i2, i3)).try_into().unwrap();
                let rhs = array[[i0 as usize, i1 as usize, i2 as usize, i3 as usize]];
                lhs == rhs
            });

            ensure!(is_correct, "value mismatch");
        }

        // Array5
        {
            let s0 = 3;
            let s1 = 5;
            let s2 = 7;
            let s3 = 11;
            let s4 = 13;

            let tensor = tch::Tensor::randn([s0, s1, s2, s3, s4], tch::kind::FLOAT_CPU);
            let array: nd::Array5<f32> = (&tensor).try_into_cv()?;

            let is_correct =
                iproduct!(0..s0, 0..s1, 0..s2, 0..s3, 0..s4).all(|(i0, i1, i2, i3, i4)| {
                    let lhs: f32 = tensor.i((i0, i1, i2, i3, i4)).try_into().unwrap();
                    let rhs = array[[
                        i0 as usize,
                        i1 as usize,
                        i2 as usize,
                        i3 as usize,
                        i4 as usize,
                    ]];
                    lhs == rhs
                });

            ensure!(is_correct, "value mismatch");
        }

        // Array6
        {
            let s0 = 3;
            let s1 = 5;
            let s2 = 7;
            let s3 = 11;
            let s4 = 13;
            let s5 = 17;

            let tensor = tch::Tensor::randn([s0, s1, s2, s3, s4, s5], tch::kind::FLOAT_CPU);
            let array: nd::Array6<f32> = (&tensor).try_into_cv()?;

            let is_correct = iproduct!(0..s0, 0..s1, 0..s2, 0..s3, 0..s4, 0..s5).all(
                |(i0, i1, i2, i3, i4, i5)| {
                    let lhs: f32 = tensor.i((i0, i1, i2, i3, i4, i5)).try_into().unwrap();
                    let rhs = array[[
                        i0 as usize,
                        i1 as usize,
                        i2 as usize,
                        i3 as usize,
                        i4 as usize,
                        i5 as usize,
                    ]];
                    lhs == rhs
                },
            );

            ensure!(is_correct, "value mismatch");
        }

        Ok(())
    }

    /// A tensor whose kind differs from the requested element type is rejected.
    #[test]
    fn tensor_to_ndarray_kind_mismatch() -> Result<()> {
        let tensor = tch::Tensor::randn([2, 3], tch::kind::DOUBLE_CPU);
        let result = nd::Array2::<f32>::try_from_cv(&tensor);
        ensure!(result.is_err(), "expect kind mismatch error");
        Ok(())
    }

    /// A tensor whose rank differs from the fixed array rank is rejected.
    #[test]
    fn tensor_to_ndarray_rank_mismatch() -> Result<()> {
        let tensor = tch::Tensor::randn([2, 3], tch::kind::FLOAT_CPU);

        let higher = nd::Array3::<f32>::try_from_cv(&tensor);
        ensure!(higher.is_err(), "expect rank mismatch error for Array3");

        let scalar = nd::Array0::<f32>::try_from_cv(&tensor);
        ensure!(scalar.is_err(), "expect rank mismatch error for Array0");

        Ok(())
    }

    /// Covers the non-contiguous path (`reversed_axes`), where elements are gathered.
    #[test]
    fn ndarray_to_tensor_conversion() -> Result<()> {
        let mut rng = rand::thread_rng();

        let s0 = 2;
        let s1 = 3;
        let s2 = 4;

        let array = nd::Array3::<f32>::from_shape_simple_fn([s0, s1, s2], || rng.gen());
        let array = array.reversed_axes();

        let tensor = tch::Tensor::from_cv(&array);

        let is_shape_correct = array.shape().len() == tensor.size().len()
            && izip!(array.shape().iter().cloned(), tensor.size().iter().cloned())
                .all(|(lhs, rhs)| lhs == rhs as usize);

        ensure!(is_shape_correct, "shape mismatch");

        let is_value_correct = iproduct!(0..s0, 0..s1, 0..s2).all(|(i0, i1, i2)| {
            let lhs = array[(i2, i1, i0)];
            let rhs: f32 = tensor
                .i((i2 as i64, i1 as i64, i0 as i64))
                .try_into()
                .unwrap();
            lhs == rhs
        });
        ensure!(is_value_correct, "value mismatch");

        Ok(())
    }

    /// Covers the contiguous (standard-layout) path, where the slice is used directly.
    #[test]
    fn contiguous_ndarray_to_tensor_conversion() -> Result<()> {
        let s0: usize = 3;
        let s1: usize = 4;

        let array = nd::Array2::<f32>::from_shape_fn((s0, s1), |(i0, i1)| (i0 * s1 + i1) as f32);
        let tensor = tch::Tensor::from_cv(&array);

        ensure!(
            tensor.size() == vec![s0 as i64, s1 as i64],
            "shape mismatch"
        );

        let is_value_correct = iproduct!(0..s0, 0..s1).all(|(i0, i1)| {
            let rhs: f32 = tensor.i((i0 as i64, i1 as i64)).try_into().unwrap();
            array[(i0, i1)] == rhs
        });
        ensure!(is_value_correct, "value mismatch");

        Ok(())
    }
}