1#[cfg(feature = "image-folder")]
31pub use inner::*;
32
33#[cfg(feature = "image-folder")]
34mod inner {
35 use std::path::{Path, PathBuf};
36
37 use image::imageops::FilterType;
38 use image::GenericImageView;
39
40 use crate::dataset::{Dataset, Sample};
41
42 const EXTENSIONS: &[&str] = &["jpg", "jpeg", "png", "bmp", "gif", "tiff", "tif", "webp"];
44
45 fn is_image(path: &Path) -> bool {
46 path.extension()
47 .and_then(|e| e.to_str())
48 .map(|e| EXTENSIONS.contains(&e.to_ascii_lowercase().as_str()))
49 .unwrap_or(false)
50 }
51
52 pub struct ImageFolderBuilder {
56 root: PathBuf,
57 resize: Option<(u32, u32)>,
58 grayscale: bool,
59 }
60
61 impl ImageFolderBuilder {
62 pub fn new<P: AsRef<Path>>(root: P) -> Self {
64 ImageFolderBuilder {
65 root: root.as_ref().to_path_buf(),
66 resize: None,
67 grayscale: false,
68 }
69 }
70
71 pub fn resize(mut self, width: u32, height: u32) -> Self {
73 self.resize = Some((width, height));
74 self
75 }
76
77 pub fn grayscale(mut self, yes: bool) -> Self {
79 self.grayscale = yes;
80 self
81 }
82
83 pub fn build(self) -> Result<ImageFolder, crate::ImageFolderError> {
85 ImageFolder::scan(self.root, self.resize, self.grayscale)
86 }
87 }
88
89 #[derive(Debug)]
93 pub struct ImageFolder {
94 class_names: Vec<String>,
96 entries: Vec<(PathBuf, usize)>,
98 resize: Option<(u32, u32)>,
100 grayscale: bool,
102 channels: usize,
104 width: u32,
106 height: u32,
108 }
109
110 impl ImageFolder {
111 pub fn new<P: AsRef<Path>>(root: P) -> ImageFolderBuilder {
113 ImageFolderBuilder::new(root)
114 }
115
116 fn scan(
118 root: PathBuf,
119 resize: Option<(u32, u32)>,
120 grayscale: bool,
121 ) -> Result<Self, crate::ImageFolderError> {
122 if !root.is_dir() {
123 return Err(crate::ImageFolderError::NotADirectory(
124 root.display().to_string(),
125 ));
126 }
127
128 let mut class_dirs: Vec<(String, PathBuf)> = Vec::new();
130 for entry in std::fs::read_dir(&root).map_err(|e| crate::ImageFolderError::Io(e))? {
131 let entry = entry.map_err(|e| crate::ImageFolderError::Io(e))?;
132 let path = entry.path();
133 if path.is_dir() {
134 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
135 class_dirs.push((name.to_string(), path));
136 }
137 }
138 }
139 class_dirs.sort_by(|a, b| a.0.cmp(&b.0));
140
141 if class_dirs.is_empty() {
142 return Err(crate::ImageFolderError::NoClasses(
143 root.display().to_string(),
144 ));
145 }
146
147 let class_names: Vec<String> = class_dirs.iter().map(|(n, _)| n.clone()).collect();
148
149 let mut entries: Vec<(PathBuf, usize)> = Vec::new();
151 for (class_idx, (_name, dir)) in class_dirs.iter().enumerate() {
152 let mut paths: Vec<PathBuf> = Vec::new();
153 Self::collect_images(dir, &mut paths);
154 paths.sort();
155 for p in paths {
156 entries.push((p, class_idx));
157 }
158 }
159
160 if entries.is_empty() {
161 return Err(crate::ImageFolderError::NoImages(
162 root.display().to_string(),
163 ));
164 }
165
166 let channels = if grayscale { 1 } else { 3 };
167 let (width, height) = resize.unwrap_or((0, 0));
168
169 Ok(ImageFolder {
170 class_names,
171 entries,
172 resize,
173 grayscale,
174 channels,
175 width,
176 height,
177 })
178 }
179
180 fn collect_images(dir: &Path, out: &mut Vec<PathBuf>) {
182 if let Ok(rd) = std::fs::read_dir(dir) {
183 for entry in rd.flatten() {
184 let path = entry.path();
185 if path.is_dir() {
186 Self::collect_images(&path, out);
187 } else if is_image(&path) {
188 out.push(path);
189 }
190 }
191 }
192 }
193
194 pub fn class_names(&self) -> &[String] {
196 &self.class_names
197 }
198
199 pub fn num_classes(&self) -> usize {
201 self.class_names.len()
202 }
203
204 pub fn class_of(&self, index: usize) -> usize {
206 self.entries[index].1
207 }
208
209 pub fn path_of(&self, index: usize) -> &Path {
211 &self.entries[index].0
212 }
213
214 fn load_image(
217 &self,
218 index: usize,
219 ) -> Result<(Vec<f64>, [usize; 3]), crate::ImageFolderError> {
220 let path = &self.entries[index].0;
221 let img = image::open(path).map_err(|e| {
222 crate::ImageFolderError::ImageDecode(path.display().to_string(), e.to_string())
223 })?;
224
225 let img = match self.resize {
227 Some((w, h)) => img.resize_exact(w, h, FilterType::Lanczos3),
228 None => img,
229 };
230
231 let (w, h) = img.dimensions();
233 let (pixels, c) = if self.grayscale {
234 let gray = img.to_luma8();
235 let data: Vec<f64> = gray.as_raw().iter().map(|&v| v as f64 / 255.0).collect();
236 (data, 1usize)
237 } else {
238 let rgb = img.to_rgb8();
239 let raw = rgb.as_raw();
240 let npix = (w * h) as usize;
242 let mut data = vec![0.0f64; 3 * npix];
243 for i in 0..npix {
244 data[i] = raw[i * 3] as f64 / 255.0; data[npix + i] = raw[i * 3 + 1] as f64 / 255.0; data[2 * npix + i] = raw[i * 3 + 2] as f64 / 255.0; }
248 (data, 3usize)
249 };
250
251 Ok((pixels, [c, h as usize, w as usize]))
252 }
253 }
254
255 impl Dataset for ImageFolder {
256 fn len(&self) -> usize {
257 self.entries.len()
258 }
259
260 fn get(&self, index: usize) -> Sample {
261 match self.load_image(index) {
262 Ok((features, shape)) => Sample {
263 features,
264 feature_shape: shape.to_vec(),
265 target: vec![self.entries[index].1 as f64],
266 target_shape: vec![1],
267 },
268 Err(e) => {
269 let c = self.channels;
271 let (w, h) = self.resize.unwrap_or((1, 1));
272 eprintln!(
273 "ImageFolder: failed to load {:?}: {}",
274 self.entries[index].0, e
275 );
276 Sample {
277 features: vec![0.0; c * (h as usize) * (w as usize)],
278 feature_shape: vec![c, h as usize, w as usize],
279 target: vec![self.entries[index].1 as f64],
280 target_shape: vec![1],
281 }
282 }
283 }
284 }
285
286 fn feature_shape(&self) -> &[usize] {
287 &[]
291 }
292
293 fn target_shape(&self) -> &[usize] {
294 &[]
295 }
296
297 fn name(&self) -> &str {
298 "ImageFolder"
299 }
300 }
301
302 unsafe impl Send for ImageFolder {}
304 unsafe impl Sync for ImageFolder {}
305}