1use crate::wrappers::image::{load_hwc, load_hwc_from_mem, resize_hwc, save_hwc};
3use crate::{Device, TchError, Tensor};
4use std::io;
5use std::path::Path;
6
7pub(crate) fn hwc_to_chw(tensor: &Tensor) -> Tensor {
8 tensor.permute([2, 0, 1])
9}
10
11pub(crate) fn chw_to_hwc(tensor: &Tensor) -> Tensor {
12 tensor.permute([1, 2, 0])
13}
14
15pub fn load<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
19 let tensor = load_hwc(path)?;
20 Ok(hwc_to_chw(&tensor))
21}
22
23pub fn load_from_memory(img_data: &[u8]) -> Result<Tensor, TchError> {
27 let tensor = load_hwc_from_mem(img_data)?;
28 Ok(hwc_to_chw(&tensor))
29}
30
31pub fn save<T: AsRef<Path>>(t: &Tensor, path: T) -> Result<(), TchError> {
39 let t = t.to_kind(crate::Kind::Uint8);
40 match t.size().as_slice() {
41 [1, _, _, _] => save_hwc(&chw_to_hwc(&t.squeeze_dim(0)).to_device(Device::Cpu), path),
42 [_, _, _] => save_hwc(&chw_to_hwc(&t).to_device(Device::Cpu), path),
43 sz => Err(TchError::FileFormat(format!("unexpected size for image tensor {sz:?}"))),
44 }
45}
46
47pub fn resize(t: &Tensor, out_w: i64, out_h: i64) -> Result<Tensor, TchError> {
52 Ok(hwc_to_chw(&resize_hwc(&chw_to_hwc(t), out_w, out_h)?))
53}
54
55pub fn resize_preserve_aspect_ratio_hwc(
56 t: &Tensor,
57 out_w: i64,
58 out_h: i64,
59) -> Result<Tensor, TchError> {
60 let tensor_size = t.size();
61 let (w, h) = (tensor_size[0], tensor_size[1]);
62 if w * out_h == h * out_w {
63 Ok(hwc_to_chw(&resize_hwc(t, out_w, out_h)?))
64 } else {
65 let (resize_w, resize_h) = {
66 let ratio_w = out_w as f64 / w as f64;
67 let ratio_h = out_h as f64 / h as f64;
68 let ratio = ratio_w.max(ratio_h);
69 ((ratio * h as f64) as i64, (ratio * w as f64) as i64)
70 };
71 let resize_w = i64::max(resize_w, out_w);
72 let resize_h = i64::max(resize_h, out_h);
73 let t = hwc_to_chw(&resize_hwc(t, resize_w, resize_h)?);
74 let t = if resize_w == out_w { t } else { t.f_narrow(2, (resize_w - out_w) / 2, out_w)? };
75 let t = if resize_h == out_h { t } else { t.f_narrow(1, (resize_h - out_h) / 2, out_h)? };
76 Ok(t)
77 }
78}
79
80pub fn resize_preserve_aspect_ratio(
84 t: &Tensor,
85 out_w: i64,
86 out_h: i64,
87) -> Result<Tensor, TchError> {
88 resize_preserve_aspect_ratio_hwc(&chw_to_hwc(t), out_w, out_h)
89}
90
91pub fn load_and_resize<T: AsRef<Path>>(
93 path: T,
94 out_w: i64,
95 out_h: i64,
96) -> Result<Tensor, TchError> {
97 let tensor = load_hwc(path)?;
98 resize_preserve_aspect_ratio_hwc(&tensor, out_w, out_h)
99}
100
101pub fn load_and_resize_from_memory(
103 img_data: &[u8],
104 out_w: i64,
105 out_h: i64,
106) -> Result<Tensor, TchError> {
107 let tensor = load_hwc_from_mem(img_data)?;
108 resize_preserve_aspect_ratio_hwc(&tensor, out_w, out_h)
109}
110
111fn visit_dirs(dir: &Path, files: &mut Vec<std::fs::DirEntry>) -> Result<(), TchError> {
112 if dir.is_dir() {
113 for entry in std::fs::read_dir(dir)? {
114 let entry = entry?;
115 let path = entry.path();
116 if path.is_dir() {
117 visit_dirs(&path, files)?;
118 } else if entry
119 .file_name()
120 .to_str()
121 .map_or(false, |s| s.ends_with(".png") || s.ends_with(".jpg"))
122 {
123 files.push(entry);
124 }
125 }
126 }
127 Ok(())
128}
129
130pub fn load_dir<T: AsRef<Path>>(path: T, out_w: i64, out_h: i64) -> Result<Tensor, TchError> {
132 let mut files: Vec<std::fs::DirEntry> = vec![];
133 visit_dirs(path.as_ref(), &mut files)?;
134 if files.is_empty() {
135 return Err(TchError::Io(io::Error::new(
136 io::ErrorKind::NotFound,
137 format!("no image found in {:?}", path.as_ref(),),
138 )));
139 }
140 let v: Vec<_> = files
141 .iter()
142 .filter_map(|x| load_and_resize(x.path(), out_w, out_h).ok())
144 .collect();
145 Ok(Tensor::stack(&v, 0))
146}