cv_convert_fork/
with_opencv_tch.rs

1use crate::opencv::{core as cv, prelude::*};
2use crate::tch;
3use crate::{common::*, TchTensorAsImage, TchTensorImageShape, TryFromCv, TryIntoCv};
4use std::borrow::Cow;
5
6use utils::*;
7mod utils {
8    use super::*;
9
10    pub struct TchImageMeta {
11        pub kind: tch::Kind,
12        pub width: i64,
13        pub height: i64,
14        pub channels: i64,
15    }
16
17    pub struct TchTensorMeta {
18        pub kind: tch::Kind,
19        pub shape: Vec<i64>,
20    }
21
22    pub fn tch_kind_to_opencv_depth(kind: tch::Kind) -> Result<i32> {
23        use tch::Kind as K;
24
25        let typ = match kind {
26            K::Uint8 => cv::CV_8U,
27            K::Int8 => cv::CV_8S,
28            K::Int16 => cv::CV_16S,
29            K::Half => cv::CV_16F,
30            K::Int => cv::CV_32S,
31            K::Float => cv::CV_32F,
32            K::Double => cv::CV_64F,
33            kind => bail!("unsupported tensor kind {:?}", kind),
34        };
35
36        Ok(typ)
37    }
38
39    pub fn opencv_depth_to_tch_kind(depth: i32) -> Result<tch::Kind> {
40        use tch::Kind as K;
41
42        let kind = match depth {
43            cv::CV_8U => K::Uint8,
44            cv::CV_8S => K::Int8,
45            cv::CV_16S => K::Int16,
46            cv::CV_32S => K::Int,
47            cv::CV_16F => K::Half,
48            cv::CV_32F => K::Float,
49            cv::CV_64F => K::Double,
50            _ => bail!("unsupported OpenCV Mat depth {}", depth),
51        };
52        Ok(kind)
53    }
54
55    pub fn opencv_mat_to_tch_meta_2d(mat: &cv::Mat) -> Result<TchImageMeta> {
56        let cv::Size { height, width } = mat.size()?;
57        let kind = opencv_depth_to_tch_kind(mat.depth())?;
58        let channels = mat.channels();
59        Ok(TchImageMeta {
60            kind,
61            width: width as i64,
62            height: height as i64,
63            channels: channels as i64,
64        })
65    }
66
67    pub fn opencv_mat_to_tch_meta_nd(mat: &cv::Mat) -> Result<TchTensorMeta> {
68        let shape: Vec<_> = mat
69            .mat_size()
70            .iter()
71            .map(|&dim| dim as i64)
72            .chain([mat.channels() as i64])
73            .collect();
74        let kind = opencv_depth_to_tch_kind(mat.depth())?;
75        Ok(TchTensorMeta { shape, kind })
76    }
77}
78
79pub use tensor_from_mat::*;
80mod tensor_from_mat {
81    use super::*;
82
83    /// A [Tensor](tch::Tensor) which data reference borrows from a [Mat](cv::Mat). It can be dereferenced to a [Tensor](tch::Tensor).
84    #[derive(Debug)]
85    pub struct OpenCvMatAsTchTensor<'a> {
86        pub(super) tensor: ManuallyDrop<tch::Tensor>,
87        pub(super) _mat: &'a cv::Mat,
88    }
89
90    impl<'a> Drop for OpenCvMatAsTchTensor<'a> {
91        fn drop(&mut self) {
92            unsafe {
93                ManuallyDrop::drop(&mut self.tensor);
94            }
95        }
96    }
97
98    impl<'a> AsRef<tch::Tensor> for OpenCvMatAsTchTensor<'a> {
99        fn as_ref(&self) -> &tch::Tensor {
100            self.tensor.deref()
101        }
102    }
103
104    impl<'a> Deref for OpenCvMatAsTchTensor<'a> {
105        type Target = tch::Tensor;
106
107        fn deref(&self) -> &Self::Target {
108            self.tensor.deref()
109        }
110    }
111
112    impl<'a> DerefMut for OpenCvMatAsTchTensor<'a> {
113        fn deref_mut(&mut self) -> &mut Self::Target {
114            self.tensor.deref_mut()
115        }
116    }
117}
118
119impl<'a> TryFromCv<&'a cv::Mat> for OpenCvMatAsTchTensor<'a> {
120    type Error = Error;
121
122    fn try_from_cv(from: &'a cv::Mat) -> Result<Self, Self::Error> {
123        ensure!(from.is_continuous(), "non-continuous Mat is not supported");
124
125        let TchTensorMeta { kind, shape } = opencv_mat_to_tch_meta_nd(&from)?;
126        let strides = {
127            let mut strides: Vec<_> = shape
128                .iter()
129                .rev()
130                .cloned()
131                .scan(1, |prev, dim| {
132                    let stride = *prev;
133                    *prev *= dim;
134                    Some(stride)
135                })
136                .collect();
137            strides.reverse();
138            strides
139        };
140
141        let tensor = unsafe {
142            let ptr = from.ptr(0)? as *const u8;
143            tch::Tensor::f_from_blob(ptr, &shape, &strides, kind, tch::Device::Cpu)?
144        };
145
146        Ok(Self {
147            tensor: ManuallyDrop::new(tensor),
148            _mat: from,
149        })
150    }
151}
152
153impl TryFromCv<&cv::Mat> for TchTensorAsImage {
154    type Error = Error;
155
156    fn try_from_cv(mat: &cv::Mat) -> Result<Self, Self::Error> {
157        let from = if mat.is_continuous() {
158            Cow::Borrowed(mat)
159        } else {
160            // Mat created from clone() is implicitly continuous
161            Cow::Owned(mat.try_clone()?)
162        };
163
164        let TchImageMeta {
165            kind,
166            width,
167            height,
168            channels,
169        } = opencv_mat_to_tch_meta_2d(&*from)?;
170
171        let tensor = unsafe {
172            let ptr = from.ptr(0)? as *const u8;
173            let slice_size = (height * width * channels) as usize * kind.elt_size_in_bytes();
174            let slice = slice::from_raw_parts(ptr, slice_size);
175            tch::Tensor::f_from_data_size(slice, &[height, width, channels], kind)?
176        };
177
178        Ok(TchTensorAsImage {
179            tensor,
180            kind: TchTensorImageShape::Hwc,
181        })
182    }
183}
184
185impl TryFromCv<cv::Mat> for TchTensorAsImage {
186    type Error = Error;
187
188    fn try_from_cv(from: cv::Mat) -> Result<Self, Self::Error> {
189        (&from).try_into_cv()
190    }
191}
192
193impl TryFromCv<&cv::Mat> for tch::Tensor {
194    type Error = Error;
195
196    fn try_from_cv(mat: &cv::Mat) -> Result<Self, Self::Error> {
197        let from = if mat.is_continuous() {
198            Cow::Borrowed(mat)
199        } else {
200            // Mat created from clone() is implicitly continuous
201            Cow::Owned(mat.try_clone()?)
202        };
203
204        let TchTensorMeta { kind, shape } = opencv_mat_to_tch_meta_nd(&*from)?;
205
206        let tensor = unsafe {
207            let ptr = from.ptr(0)? as *const u8;
208            let slice_size =
209                shape.iter().cloned().product::<i64>() as usize * kind.elt_size_in_bytes();
210            let slice = slice::from_raw_parts(ptr, slice_size);
211            tch::Tensor::f_from_data_size(slice, shape.as_ref(), kind)?
212        };
213
214        Ok(tensor)
215    }
216}
217
218impl TryFromCv<cv::Mat> for tch::Tensor {
219    type Error = Error;
220
221    fn try_from_cv(from: cv::Mat) -> Result<Self, Self::Error> {
222        (&from).try_into_cv()
223    }
224}
225
226impl TryFromCv<&TchTensorAsImage> for cv::Mat {
227    type Error = Error;
228
229    fn try_from_cv(from: &TchTensorAsImage) -> Result<Self, Self::Error> {
230        let TchTensorAsImage {
231            ref tensor,
232            kind: convention,
233        } = *from;
234
235        use TchTensorImageShape as S;
236        let (tensor, [channels, rows, cols]) = match (tensor.size3()?, convention) {
237            ((w, h, c), S::Whc) => (tensor.f_permute(&[1, 0, 2])?, [c, h, w]),
238            ((h, w, c), S::Hwc) => (tensor.shallow_clone(), [c, h, w]),
239            ((c, w, h), S::Cwh) => (tensor.f_permute(&[2, 1, 0])?, [c, h, w]),
240            ((c, h, w), S::Chw) => (tensor.f_permute(&[1, 2, 0])?, [c, h, w]),
241        };
242        let tensor = tensor.f_contiguous()?.f_to_device(tch::Device::Cpu)?;
243        let depth = tch_kind_to_opencv_depth(tensor.f_kind()?)?;
244        let typ = cv::CV_MAKE_TYPE(depth, channels as i32);
245
246        let mat = unsafe {
247            cv::Mat::new_rows_cols_with_data(
248                rows as i32,
249                cols as i32,
250                typ,
251                tensor.data_ptr(),
252                /* step = */
253                cv::Mat_AUTO_STEP,
254            )?
255            .try_clone()?
256        };
257
258        Ok(mat)
259    }
260}
261
262impl TryFromCv<TchTensorAsImage> for cv::Mat {
263    type Error = Error;
264
265    fn try_from_cv(from: TchTensorAsImage) -> Result<Self, Self::Error> {
266        (&from).try_into_cv()
267    }
268}
269
270impl TryFromCv<&tch::Tensor> for cv::Mat {
271    type Error = Error;
272
273    fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
274        let tensor = from.f_contiguous()?.f_to_device(tch::Device::Cpu)?;
275        let size: Vec<_> = tensor.size().into_iter().map(|dim| dim as i32).collect();
276        let depth = tch_kind_to_opencv_depth(tensor.f_kind()?)?;
277        let typ = cv::CV_MAKETYPE(depth, 1);
278        let mat = unsafe { cv::Mat::new_nd_with_data(&size, typ, tensor.data_ptr(), None)? };
279        Ok(mat)
280    }
281}
282
283impl TryFromCv<tch::Tensor> for cv::Mat {
284    type Error = Error;
285
286    fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
287        (&from).try_into_cv()
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tch::{self, IndexOp, Tensor};

    /// Number of randomized repetitions performed by each test.
    const ROUNDS: usize = 1000;

    /// Round-trips a random 4-d tensor through a Mat and checks both the Mat
    /// elements and the recovered tensor against the original.
    #[test]
    fn tensor_mat_conv() -> Result<()> {
        /// Lists every index of a tensor with the given extents, in row-major
        /// order (the last axis varies fastest).
        fn enumerate_index(dims: &[i64]) -> Vec<Vec<i64>> {
            let mut indexes: Vec<Vec<i64>> = vec![vec![]];
            for &dim in dims {
                indexes = indexes
                    .into_iter()
                    .flat_map(|prefix| {
                        (0..dim).map(move |val| {
                            let mut index = prefix.clone();
                            index.push(val);
                            index
                        })
                    })
                    .collect();
            }
            indexes
        }

        let size = [2, 3, 4, 5];

        for _ in 0..ROUNDS {
            let before = Tensor::randn(size.as_ref(), tch::kind::FLOAT_CPU);
            let mat = cv::Mat::try_from_cv(&before)?;
            let after = Tensor::try_from_cv(&mat)?.f_view(size)?;

            // compare Tensor and Mat values element by element
            for index in enumerate_index(&before.size()) {
                let cv_index: Vec<_> = index.iter().map(|&val| val as i32).collect();
                let tch_index: Vec<_> = index
                    .iter()
                    .map(|&val| Some(Tensor::from(val)))
                    .collect();
                let tch_val: f32 = before.f_index(&tch_index)?.try_into().unwrap();
                let mat_val: f32 = *mat.at_nd(&cv_index)?;
                ensure!(tch_val == mat_val, "value mismatch");
            }

            // compare original and recovered Tensor values
            ensure!(before == after, "value mismatch",);
        }

        Ok(())
    }

    /// Converts a random CHW image tensor to a Mat and back, checking every
    /// pixel of the Mat and the recovered tensor against the original.
    #[test]
    fn tensor_as_image_and_mat_conv() -> Result<()> {
        let channels = 3;
        let height = 16;
        let width = 8;

        for _ in 0..ROUNDS {
            let before = Tensor::randn(&[channels, height, width], tch::kind::FLOAT_CPU);
            let image = TchTensorAsImage::new(before.shallow_clone(), TchTensorImageShape::Chw)?;
            let mat: cv::Mat = image.try_into_cv()?;
            let after = Tensor::try_from_cv(&mat)?.f_permute(&[2, 0, 1])?; // hwc -> chw

            // compare Tensor and Mat values, one channel of one pixel at a time
            for row in 0..height {
                for col in 0..width {
                    let pixel: &cv::Vec3f = mat.at_2d(row as i32, col as i32)?;
                    let values: [f32; 3] = **pixel;
                    for (channel, expected) in values.iter().enumerate() {
                        let actual =
                            f32::try_from(before.i((channel as i64, row, col))).unwrap();
                        ensure!(actual == *expected, "value mismatch");
                    }
                }
            }

            // compare original and recovered Tensor values
            let before_size = before.size();
            let after_size = after.size();
            ensure!(
                before_size == after_size,
                "size mismatch: {:?} vs. {:?}",
                before_size,
                after_size
            );
            ensure!(before == after, "value mismatch");
        }

        Ok(())
    }

    /// Builds a Mat from a random CHW image tensor, views it as a borrowed
    /// tensor and checks its shape and values against the original.
    #[test]
    fn tensor_from_mat_conv() -> Result<()> {
        let channel = 3;
        let height = 16;
        let width = 8;

        for _ in 0..ROUNDS {
            let before = Tensor::randn(&[channel, height, width], tch::kind::FLOAT_CPU);
            let image = TchTensorAsImage::new(before.shallow_clone(), TchTensorImageShape::Chw)?;
            let mat: cv::Mat = image.try_into_cv()?;
            let after = OpenCvMatAsTchTensor::try_from_cv(&mat)?; // in hwc

            // compare original and recovered Tensor values
            ensure!(after.size() == [height, width, channel], "size mismatch",);
            let expected = before.f_permute(&[1, 2, 0])?;
            ensure!(expected == *after, "value mismatch");
        }

        Ok(())
    }
}