1use crate::opencv::{core as cv, prelude::*};
2use crate::tch;
3use crate::{common::*, TchTensorAsImage, TchTensorImageShape, TryFromCv, TryIntoCv};
4use std::borrow::Cow;
5
6use utils::*;
7mod utils {
8 use super::*;
9
10 pub struct TchImageMeta {
11 pub kind: tch::Kind,
12 pub width: i64,
13 pub height: i64,
14 pub channels: i64,
15 }
16
17 pub struct TchTensorMeta {
18 pub kind: tch::Kind,
19 pub shape: Vec<i64>,
20 }
21
22 pub fn tch_kind_to_opencv_depth(kind: tch::Kind) -> Result<i32> {
23 use tch::Kind as K;
24
25 let typ = match kind {
26 K::Uint8 => cv::CV_8U,
27 K::Int8 => cv::CV_8S,
28 K::Int16 => cv::CV_16S,
29 K::Half => cv::CV_16F,
30 K::Int => cv::CV_32S,
31 K::Float => cv::CV_32F,
32 K::Double => cv::CV_64F,
33 kind => bail!("unsupported tensor kind {:?}", kind),
34 };
35
36 Ok(typ)
37 }
38
39 pub fn opencv_depth_to_tch_kind(depth: i32) -> Result<tch::Kind> {
40 use tch::Kind as K;
41
42 let kind = match depth {
43 cv::CV_8U => K::Uint8,
44 cv::CV_8S => K::Int8,
45 cv::CV_16S => K::Int16,
46 cv::CV_32S => K::Int,
47 cv::CV_16F => K::Half,
48 cv::CV_32F => K::Float,
49 cv::CV_64F => K::Double,
50 _ => bail!("unsupported OpenCV Mat depth {}", depth),
51 };
52 Ok(kind)
53 }
54
55 pub fn opencv_mat_to_tch_meta_2d(mat: &cv::Mat) -> Result<TchImageMeta> {
56 let cv::Size { height, width } = mat.size()?;
57 let kind = opencv_depth_to_tch_kind(mat.depth())?;
58 let channels = mat.channels();
59 Ok(TchImageMeta {
60 kind,
61 width: width as i64,
62 height: height as i64,
63 channels: channels as i64,
64 })
65 }
66
67 pub fn opencv_mat_to_tch_meta_nd(mat: &cv::Mat) -> Result<TchTensorMeta> {
68 let shape: Vec<_> = mat
69 .mat_size()
70 .iter()
71 .map(|&dim| dim as i64)
72 .chain([mat.channels() as i64])
73 .collect();
74 let kind = opencv_depth_to_tch_kind(mat.depth())?;
75 Ok(TchTensorMeta { shape, kind })
76 }
77}
78
79pub use tensor_from_mat::*;
80mod tensor_from_mat {
81 use super::*;
82
83 #[derive(Debug)]
85 pub struct OpenCvMatAsTchTensor<'a> {
86 pub(super) tensor: ManuallyDrop<tch::Tensor>,
87 pub(super) _mat: &'a cv::Mat,
88 }
89
90 impl<'a> Drop for OpenCvMatAsTchTensor<'a> {
91 fn drop(&mut self) {
92 unsafe {
93 ManuallyDrop::drop(&mut self.tensor);
94 }
95 }
96 }
97
98 impl<'a> AsRef<tch::Tensor> for OpenCvMatAsTchTensor<'a> {
99 fn as_ref(&self) -> &tch::Tensor {
100 self.tensor.deref()
101 }
102 }
103
104 impl<'a> Deref for OpenCvMatAsTchTensor<'a> {
105 type Target = tch::Tensor;
106
107 fn deref(&self) -> &Self::Target {
108 self.tensor.deref()
109 }
110 }
111
112 impl<'a> DerefMut for OpenCvMatAsTchTensor<'a> {
113 fn deref_mut(&mut self) -> &mut Self::Target {
114 self.tensor.deref_mut()
115 }
116 }
117}
118
119impl<'a> TryFromCv<&'a cv::Mat> for OpenCvMatAsTchTensor<'a> {
120 type Error = Error;
121
122 fn try_from_cv(from: &'a cv::Mat) -> Result<Self, Self::Error> {
123 ensure!(from.is_continuous(), "non-continuous Mat is not supported");
124
125 let TchTensorMeta { kind, shape } = opencv_mat_to_tch_meta_nd(&from)?;
126 let strides = {
127 let mut strides: Vec<_> = shape
128 .iter()
129 .rev()
130 .cloned()
131 .scan(1, |prev, dim| {
132 let stride = *prev;
133 *prev *= dim;
134 Some(stride)
135 })
136 .collect();
137 strides.reverse();
138 strides
139 };
140
141 let tensor = unsafe {
142 let ptr = from.ptr(0)? as *const u8;
143 tch::Tensor::f_from_blob(ptr, &shape, &strides, kind, tch::Device::Cpu)?
144 };
145
146 Ok(Self {
147 tensor: ManuallyDrop::new(tensor),
148 _mat: from,
149 })
150 }
151}
152
153impl TryFromCv<&cv::Mat> for TchTensorAsImage {
154 type Error = Error;
155
156 fn try_from_cv(mat: &cv::Mat) -> Result<Self, Self::Error> {
157 let from = if mat.is_continuous() {
158 Cow::Borrowed(mat)
159 } else {
160 Cow::Owned(mat.try_clone()?)
162 };
163
164 let TchImageMeta {
165 kind,
166 width,
167 height,
168 channels,
169 } = opencv_mat_to_tch_meta_2d(&*from)?;
170
171 let tensor = unsafe {
172 let ptr = from.ptr(0)? as *const u8;
173 let slice_size = (height * width * channels) as usize * kind.elt_size_in_bytes();
174 let slice = slice::from_raw_parts(ptr, slice_size);
175 tch::Tensor::f_from_data_size(slice, &[height, width, channels], kind)?
176 };
177
178 Ok(TchTensorAsImage {
179 tensor,
180 kind: TchTensorImageShape::Hwc,
181 })
182 }
183}
184
185impl TryFromCv<cv::Mat> for TchTensorAsImage {
186 type Error = Error;
187
188 fn try_from_cv(from: cv::Mat) -> Result<Self, Self::Error> {
189 (&from).try_into_cv()
190 }
191}
192
193impl TryFromCv<&cv::Mat> for tch::Tensor {
194 type Error = Error;
195
196 fn try_from_cv(mat: &cv::Mat) -> Result<Self, Self::Error> {
197 let from = if mat.is_continuous() {
198 Cow::Borrowed(mat)
199 } else {
200 Cow::Owned(mat.try_clone()?)
202 };
203
204 let TchTensorMeta { kind, shape } = opencv_mat_to_tch_meta_nd(&*from)?;
205
206 let tensor = unsafe {
207 let ptr = from.ptr(0)? as *const u8;
208 let slice_size =
209 shape.iter().cloned().product::<i64>() as usize * kind.elt_size_in_bytes();
210 let slice = slice::from_raw_parts(ptr, slice_size);
211 tch::Tensor::f_from_data_size(slice, shape.as_ref(), kind)?
212 };
213
214 Ok(tensor)
215 }
216}
217
218impl TryFromCv<cv::Mat> for tch::Tensor {
219 type Error = Error;
220
221 fn try_from_cv(from: cv::Mat) -> Result<Self, Self::Error> {
222 (&from).try_into_cv()
223 }
224}
225
226impl TryFromCv<&TchTensorAsImage> for cv::Mat {
227 type Error = Error;
228
229 fn try_from_cv(from: &TchTensorAsImage) -> Result<Self, Self::Error> {
230 let TchTensorAsImage {
231 ref tensor,
232 kind: convention,
233 } = *from;
234
235 use TchTensorImageShape as S;
236 let (tensor, [channels, rows, cols]) = match (tensor.size3()?, convention) {
237 ((w, h, c), S::Whc) => (tensor.f_permute(&[1, 0, 2])?, [c, h, w]),
238 ((h, w, c), S::Hwc) => (tensor.shallow_clone(), [c, h, w]),
239 ((c, w, h), S::Cwh) => (tensor.f_permute(&[2, 1, 0])?, [c, h, w]),
240 ((c, h, w), S::Chw) => (tensor.f_permute(&[1, 2, 0])?, [c, h, w]),
241 };
242 let tensor = tensor.f_contiguous()?.f_to_device(tch::Device::Cpu)?;
243 let depth = tch_kind_to_opencv_depth(tensor.f_kind()?)?;
244 let typ = cv::CV_MAKE_TYPE(depth, channels as i32);
245
246 let mat = unsafe {
247 cv::Mat::new_rows_cols_with_data(
248 rows as i32,
249 cols as i32,
250 typ,
251 tensor.data_ptr(),
252 cv::Mat_AUTO_STEP,
254 )?
255 .try_clone()?
256 };
257
258 Ok(mat)
259 }
260}
261
262impl TryFromCv<TchTensorAsImage> for cv::Mat {
263 type Error = Error;
264
265 fn try_from_cv(from: TchTensorAsImage) -> Result<Self, Self::Error> {
266 (&from).try_into_cv()
267 }
268}
269
270impl TryFromCv<&tch::Tensor> for cv::Mat {
271 type Error = Error;
272
273 fn try_from_cv(from: &tch::Tensor) -> Result<Self, Self::Error> {
274 let tensor = from.f_contiguous()?.f_to_device(tch::Device::Cpu)?;
275 let size: Vec<_> = tensor.size().into_iter().map(|dim| dim as i32).collect();
276 let depth = tch_kind_to_opencv_depth(tensor.f_kind()?)?;
277 let typ = cv::CV_MAKETYPE(depth, 1);
278 let mat = unsafe { cv::Mat::new_nd_with_data(&size, typ, tensor.data_ptr(), None)? };
279 Ok(mat)
280 }
281}
282
283impl TryFromCv<tch::Tensor> for cv::Mat {
284 type Error = Error;
285
286 fn try_from_cv(from: tch::Tensor) -> Result<Self, Self::Error> {
287 (&from).try_into_cv()
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tch::{self, IndexOp, Tensor};

    /// Number of random inputs each round-trip test is repeated with.
    const ROUNDS: usize = 1000;

    /// Round-trips a random 4-D f32 tensor through an N-d `cv::Mat` and back.
    /// Every element is checked through both the `tch` and the OpenCV indexing APIs.
    #[test]
    fn tensor_mat_conv() -> Result<()> {
        let size = [2, 3, 4, 5];

        for _ in 0..ROUNDS {
            let before = Tensor::randn(size.as_ref(), tch::kind::FLOAT_CPU);
            let mat = cv::Mat::try_from_cv(&before)?;
            // Mat -> Tensor appends a trailing channel axis (size 1 here), so view
            // the result back into the original shape before comparing.
            let after = Tensor::try_from_cv(&mat)?.f_view(size)?;

            {
                // Returns every element index of a tensor with shape `dims`, in
                // row-major order. Each index has its axes reversed (first axis
                // last), because the recursion appends the current axis to the
                // tail with `push`.
                fn enumerate_reversed_index(dims: &[i64]) -> Vec<Vec<i64>> {
                    match dims {
                        [] => vec![vec![]],
                        [dim, remaining @ ..] => {
                            let dim = *dim;
                            let indexes: Vec<_> = (0..dim)
                                .flat_map(move |val| {
                                    enumerate_reversed_index(remaining).into_iter().map(
                                        move |mut tail| {
                                            tail.push(val);
                                            tail
                                        },
                                    )
                                })
                                .collect();
                            indexes
                        }
                    }
                }

                enumerate_reversed_index(&before.size())
                    .into_iter()
                    // Restore the natural axis order.
                    .map(|mut index| {
                        index.reverse();
                        index
                    })
                    .try_for_each(|tch_index| -> Result<_> {
                        // `at_nd` takes i32 coordinates.
                        let cv_index: Vec<_> =
                            tch_index.iter().cloned().map(|val| val as i32).collect();
                        // `f_index` takes one optional index tensor per axis.
                        let tch_index: Vec<_> = tch_index
                            .iter()
                            .cloned()
                            .map(|val| Some(Tensor::from(val)))
                            .collect();
                        let tch_val: f32 = before.f_index(&tch_index)?.try_into().unwrap();
                        let mat_val: f32 = *mat.at_nd(&cv_index)?;
                        ensure!(tch_val == mat_val, "value mismatch");
                        Ok(())
                    })?;
            }

            ensure!(before == after, "value mismatch",);
        }

        Ok(())
    }

    /// Converts a random CHW tensor into a 3-channel HWC `cv::Mat` via
    /// `TchTensorAsImage`, checks every pixel, then converts back.
    #[test]
    fn tensor_as_image_and_mat_conv() -> Result<()> {
        for _ in 0..ROUNDS {
            let channels = 3;
            let height = 16;
            let width = 8;

            let before = Tensor::randn(&[channels, height, width], tch::kind::FLOAT_CPU);
            let mat: cv::Mat =
                TchTensorAsImage::new(before.shallow_clone(), TchTensorImageShape::Chw)?
                    .try_into_cv()?;
            // The Mat converts back as [height, width, channels]; permute to CHW.
            let after = Tensor::try_from_cv(&mat)?.f_permute(&[2, 0, 1])?;

            for row in 0..height {
                for col in 0..width {
                    let pixel: &cv::Vec3f = mat.at_2d(row as i32, col as i32)?;
                    // These names simply label channels 0, 1 and 2 of the tensor;
                    // they do not imply a colour order.
                    let [red, green, blue] = **pixel;
                    ensure!(f32::try_from(before.i((0, row, col))).unwrap() == red, "value mismatch");
                    ensure!(
                        f32::try_from(before.i((1, row, col))).unwrap() == green,
                        "value mismatch"
                    );
                    ensure!(f32::try_from(before.i((2, row, col))).unwrap() == blue, "value mismatch");
                }
            }

            {
                // Full round-trip check: shape first, then values.
                let before_size = before.size();
                let after_size = after.size();
                ensure!(
                    before_size == after_size,
                    "size mismatch: {:?} vs. {:?}",
                    before_size,
                    after_size
                );
                ensure!(before == after, "value mismatch");
            }
        }
        Ok(())
    }

    /// Builds an HWC `cv::Mat` from a random CHW tensor. It then checks that the
    /// zero-copy `OpenCvMatAsTchTensor` view has the expected shape and values.
    #[test]
    fn tensor_from_mat_conv() -> Result<()> {
        for _ in 0..ROUNDS {
            let channel = 3;
            let height = 16;
            let width = 8;

            let before = Tensor::randn(&[channel, height, width], tch::kind::FLOAT_CPU);
            let mat: cv::Mat =
                TchTensorAsImage::new(before.shallow_clone(), TchTensorImageShape::Chw)?
                    .try_into_cv()?;
            // Borrows `mat`; the view must not outlive it.
            let after = OpenCvMatAsTchTensor::try_from_cv(&mat)?;
            {
                ensure!(after.size() == [height, width, channel], "size mismatch",);
                // Compare against the original tensor rearranged into HWC order.
                ensure!(&before.f_permute(&[1, 2, 0])? == &*after, "value mismatch");
            }
        }
        Ok(())
    }
}