1use crate::ndarray as nd;
2use crate::tch;
3use crate::{common::*, FromCv, TryFromCv};
4
5use to_ndarray_shape::*;
6
7mod to_ndarray_shape {
8 use super::*;
9
10 pub trait ToNdArrayShape<D>
11 where
12 Self::Output: Sized + Into<nd::StrideShape<D>>,
13 {
14 type Output;
15 type Error;
16
17 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error>;
18 }
19
20 impl ToNdArrayShape<nd::IxDyn> for Vec<i64> {
21 type Output = Vec<usize>;
22 type Error = Error;
23
24 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
25 let size: Vec<_> = self.iter().map(|&dim| dim as usize).collect();
26 Ok(size)
27 }
28 }
29
30 impl ToNdArrayShape<nd::Ix0> for Vec<i64> {
31 type Output = [usize; 0];
32 type Error = Error;
33
34 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
35 ensure!(
36 self.is_empty(),
37 "empty empty tensor dimension, but get {:?}",
38 self
39 );
40 Ok([])
41 }
42 }
43
44 impl ToNdArrayShape<nd::Ix1> for Vec<i64> {
45 type Output = [usize; 1];
46 type Error = Error;
47
48 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
49 let shape = match self.as_slice() {
50 &[s0] => [s0 as usize],
51 other => bail!("expect one dimension, but get {:?}", other),
52 };
53 Ok(shape)
54 }
55 }
56
57 impl ToNdArrayShape<nd::Ix2> for Vec<i64> {
58 type Output = [usize; 2];
59 type Error = Error;
60
61 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
62 let shape = match self.as_slice() {
63 &[s0, s1] => [s0 as usize, s1 as usize],
64 other => bail!("expect one dimension, but get {:?}", other),
65 };
66 Ok(shape)
67 }
68 }
69
70 impl ToNdArrayShape<nd::Ix3> for Vec<i64> {
71 type Output = [usize; 3];
72 type Error = Error;
73
74 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
75 let shape = match self.as_slice() {
76 &[s0, s1, s2] => [s0 as usize, s1 as usize, s2 as usize],
77 other => bail!("expect one dimension, but get {:?}", other),
78 };
79 Ok(shape)
80 }
81 }
82
83 impl ToNdArrayShape<nd::Ix4> for Vec<i64> {
84 type Output = [usize; 4];
85 type Error = Error;
86
87 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
88 let shape = match self.as_slice() {
89 &[s0, s1, s2, s3] => [s0 as usize, s1 as usize, s2 as usize, s3 as usize],
90 other => bail!("expect one dimension, but get {:?}", other),
91 };
92 Ok(shape)
93 }
94 }
95
96 impl ToNdArrayShape<nd::Ix5> for Vec<i64> {
97 type Output = [usize; 5];
98 type Error = Error;
99
100 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
101 let shape = match self.as_slice() {
102 &[s0, s1, s2, s3, s4] => [
103 s0 as usize,
104 s1 as usize,
105 s2 as usize,
106 s3 as usize,
107 s4 as usize,
108 ],
109 other => bail!("expect one dimension, but get {:?}", other),
110 };
111 Ok(shape)
112 }
113 }
114
115 impl ToNdArrayShape<nd::Ix6> for Vec<i64> {
116 type Output = [usize; 6];
117 type Error = Error;
118
119 fn to_ndarray_shape(&self) -> Result<Self::Output, Self::Error> {
120 let shape = match self.as_slice() {
121 &[s0, s1, s2, s3, s4, s5] => [
122 s0 as usize,
123 s1 as usize,
124 s2 as usize,
125 s3 as usize,
126 s4 as usize,
127 s5 as usize,
128 ],
129 other => bail!("expect one dimension, but get {:?}", other),
130 };
131 Ok(shape)
132 }
133 }
134}
135
136impl<A, D> TryFromCv<tch::Tensor> for nd::Array<A, D>
137where
138 D: nd::Dimension,
139 A: tch::kind::Element,
140 Vec<A>: TryFrom<tch::Tensor, Error = tch::TchError>,
141 Vec<i64>: ToNdArrayShape<D, Error = Error>,
142{
143 type Error = Error;
144
145 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
146 ensure!(
148 from.kind() == A::KIND,
149 "tensor with kind {:?} cannot converted to array with type {:?}",
150 from.kind(),
151 A::KIND
152 );
153
154 let shape = from.size();
155 let elems = Vec::<A>::try_from(from.flatten(0, -1))?;
156 let array_shape = shape.to_ndarray_shape()?;
157 let array = Self::from_shape_vec(array_shape, elems)?;
158 Ok(array)
159 }
160}
161
162impl<A, D> TryFromCv<&tch::Tensor> for nd::Array<A, D>
163where
164 D: nd::Dimension,
165 A: tch::kind::Element,
166 Vec<A>: TryFrom<tch::Tensor, Error = tch::TchError>,
167 Vec<i64>: ToNdArrayShape<D, Error = Error>,
168{
169 type Error = Error;
170
171 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
172 Self::try_from_cv(from.shallow_clone())
173 }
174}
175
176impl<A, S, D> FromCv<&nd::ArrayBase<S, D>> for tch::Tensor
177where
178 A: tch::kind::Element + Clone,
179 S: nd::RawData<Elem = A> + nd::Data,
180 D: nd::Dimension,
181{
182 fn from_cv(from: &nd::ArrayBase<S, D>) -> Self {
183 let shape: Vec<_> = from.shape().iter().map(|&s| s as i64).collect();
184
185 match from.as_slice() {
186 Some(slice) => tch::Tensor::from_slice(slice).view(shape.as_slice()),
187 None => {
188 let elems: Vec<_> = from.iter().cloned().collect();
189 tch::Tensor::from_slice(&elems).view(shape.as_slice())
190 }
191 }
192 }
193}
194
195impl<A, S, D> FromCv<nd::ArrayBase<S, D>> for tch::Tensor
196where
197 A: tch::kind::Element + Clone,
198 S: nd::RawData<Elem = A> + nd::Data,
199 D: nd::Dimension,
200{
201 fn from_cv(from: nd::ArrayBase<S, D>) -> Self {
202 Self::from_cv(&from)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tch::{self, IndexOp};
    use crate::TryIntoCv;
    use itertools::{iproduct, izip};
    use rand::prelude::*;

    /// Converts random tensors of dynamic rank and of ranks 0 through 6 into
    /// arrays and checks every element against the source tensor.
    #[test]
    fn tensor_to_ndarray_conversion() -> Result<()> {
        // Dynamic dimensionality.
        {
            let (d0, d1, d2) = (3, 4, 5);
            let tensor = tch::Tensor::randn([d0, d1, d2], tch::kind::FLOAT_CPU);
            let array: nd::ArrayD<f32> = (&tensor).try_into_cv()?;

            for (i0, i1, i2) in iproduct!(0..d0, 0..d1, 0..d2) {
                let expected: f32 = tensor.i((i0, i1, i2)).try_into().unwrap();
                let actual = array[[i0 as usize, i1 as usize, i2 as usize]];
                ensure!(expected == actual, "value mismatch");
            }
        }

        // Rank 0.
        {
            let tensor = tch::Tensor::randn([], tch::kind::FLOAT_CPU);
            let array: nd::Array0<f32> = (&tensor).try_into_cv()?;
            let expected: f32 = tensor.try_into().unwrap();
            let actual = array[()];
            ensure!(expected == actual, "value mismatch");
        }

        // Rank 1.
        {
            let d0 = 10;
            let tensor = tch::Tensor::randn([d0], tch::kind::FLOAT_CPU);
            let array: nd::Array1<f32> = (&tensor).try_into_cv()?;

            for i0 in 0..d0 {
                let expected: f32 = tensor.i((i0,)).try_into().unwrap();
                let actual = array[i0 as usize];
                ensure!(expected == actual, "value mismatch");
            }
        }

        // Rank 2.
        {
            let (d0, d1) = (3, 5);
            let tensor = tch::Tensor::randn([d0, d1], tch::kind::FLOAT_CPU);
            let array: nd::Array2<f32> = (&tensor).try_into_cv()?;

            for (i0, i1) in iproduct!(0..d0, 0..d1) {
                let expected: f32 = tensor.i((i0, i1)).try_into().unwrap();
                let actual = array[[i0 as usize, i1 as usize]];
                ensure!(expected == actual, "value mismatch");
            }
        }

        // Rank 3.
        {
            let (d0, d1, d2) = (3, 5, 7);
            let tensor = tch::Tensor::randn([d0, d1, d2], tch::kind::FLOAT_CPU);
            let array: nd::Array3<f32> = (&tensor).try_into_cv()?;

            for (i0, i1, i2) in iproduct!(0..d0, 0..d1, 0..d2) {
                let expected: f32 = tensor.i((i0, i1, i2)).try_into().unwrap();
                let actual = array[[i0 as usize, i1 as usize, i2 as usize]];
                ensure!(expected == actual, "value mismatch");
            }
        }

        // Rank 4.
        {
            let (d0, d1, d2, d3) = (3, 5, 7, 11);
            let tensor = tch::Tensor::randn([d0, d1, d2, d3], tch::kind::FLOAT_CPU);
            let array: nd::Array4<f32> = (&tensor).try_into_cv()?;

            for (i0, i1, i2, i3) in iproduct!(0..d0, 0..d1, 0..d2, 0..d3) {
                let expected: f32 = tensor.i((i0, i1, i2, i3)).try_into().unwrap();
                let actual = array[[i0 as usize, i1 as usize, i2 as usize, i3 as usize]];
                ensure!(expected == actual, "value mismatch");
            }
        }

        // Rank 5.
        {
            let (d0, d1, d2, d3, d4) = (3, 5, 7, 11, 13);
            let tensor = tch::Tensor::randn([d0, d1, d2, d3, d4], tch::kind::FLOAT_CPU);
            let array: nd::Array5<f32> = (&tensor).try_into_cv()?;

            for (i0, i1, i2, i3, i4) in iproduct!(0..d0, 0..d1, 0..d2, 0..d3, 0..d4) {
                let expected: f32 = tensor.i((i0, i1, i2, i3, i4)).try_into().unwrap();
                let index = [
                    i0 as usize,
                    i1 as usize,
                    i2 as usize,
                    i3 as usize,
                    i4 as usize,
                ];
                let actual = array[index];
                ensure!(expected == actual, "value mismatch");
            }
        }

        // Rank 6.
        {
            let (d0, d1, d2, d3, d4, d5) = (3, 5, 7, 11, 13, 17);
            let tensor = tch::Tensor::randn([d0, d1, d2, d3, d4, d5], tch::kind::FLOAT_CPU);
            let array: nd::Array6<f32> = (&tensor).try_into_cv()?;

            for (i0, i1, i2, i3, i4, i5) in
                iproduct!(0..d0, 0..d1, 0..d2, 0..d3, 0..d4, 0..d5)
            {
                let expected: f32 = tensor.i((i0, i1, i2, i3, i4, i5)).try_into().unwrap();
                let index = [
                    i0 as usize,
                    i1 as usize,
                    i2 as usize,
                    i3 as usize,
                    i4 as usize,
                    i5 as usize,
                ];
                let actual = array[index];
                ensure!(expected == actual, "value mismatch");
            }
        }

        Ok(())
    }

    /// Converts a random rank-3 array with reversed axes into a tensor and
    /// checks that both the shape and every element are preserved.
    #[test]
    fn ndarray_to_tensor_conversion() -> Result<()> {
        let mut rng = rand::thread_rng();

        let (n0, n1, n2) = (2, 3, 4);

        // Reversing the axes exercises the conversion on an array whose axis
        // order differs from the one it was created with.
        let array =
            nd::Array3::<f32>::from_shape_simple_fn([n0, n1, n2], || rng.gen()).reversed_axes();

        let tensor = tch::Tensor::from_cv(&array);

        let array_dims = array.shape();
        let tensor_dims = tensor.size();
        let same_rank = array_dims.len() == tensor_dims.len();
        let same_extents = izip!(array_dims.iter().copied(), tensor_dims.iter().copied())
            .all(|(from_array, from_tensor)| from_array == from_tensor as usize);
        ensure!(same_rank && same_extents, "shape mismatch");

        for (i0, i1, i2) in iproduct!(0..n0, 0..n1, 0..n2) {
            // Indices are given in reversed order to match the reversed axes.
            let expected = array[(i2, i1, i0)];
            let actual: f32 = tensor
                .i((i2 as i64, i1 as i64, i0 as i64))
                .try_into()
                .unwrap();
            ensure!(expected == actual, "value mismatch");
        }

        Ok(())
    }
}
405}