use crate::errors::DatasetError;
use image::DynamicImage;
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tracing::debug;
/// File extensions (lower-case, without the leading dot) that the directory scan treats
/// as images. Matching against a file's extension ignores ASCII case, so `PHOTO.JPG`
/// is accepted as well.
const SUPPORTED_EXTENSIONS: &[&str] = &["jpg", "jpeg", "png", "bmp", "tiff", "webp"];
/// A supported image file found by [`scan_images`].
#[derive(Debug, Clone)]
pub struct ImageEntry {
    /// Path to the file as produced by the directory walk (the scanned directory joined
    /// with the file's location inside it).
    pub path: PathBuf,
    /// `path` relative to the directory passed to [`scan_images`]; falls back to the
    /// full `path` if that prefix cannot be stripped.
    pub relative_path: PathBuf,
    /// `relative_path` without its final extension, with every `\` rewritten to `/`
    /// (e.g. `class-a/deep/alpha`), so it is stable across platforms.
    pub stem: String,
}
/// An image that was discovered by the scan but could not be loaded.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SkippedImage {
    /// Path reported by the load error; when produced by this module's dataset loaders
    /// it is the image's path relative to the scanned directory.
    pub path: String,
    /// Human-readable reason the load failed (the decoder's error message when
    /// produced by this module's loaders).
    pub reason: String,
}
/// Tally of what happened while walking a dataset, plus a small sample of failures.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DatasetProcessingSummary {
    /// Number of supported image files found by the scan.
    pub discovered: usize,
    /// Number of images that loaded successfully.
    pub loaded: usize,
    /// Number of images that failed to load — every failure, not only those kept in
    /// `skipped_examples`.
    pub skipped: usize,
    /// The first few skipped images; `record_skipped` and `merge` stop adding entries
    /// once five are stored.
    pub skipped_examples: Vec<SkippedImage>,
}
impl DatasetProcessingSummary {
    /// Maximum number of entries retained in `skipped_examples`. The `skipped` counter
    /// keeps counting past this limit; only the stored examples are capped.
    pub const MAX_SKIPPED_EXAMPLES: usize = 5;

    /// Creates an empty summary for a dataset in which `discovered` images were found.
    pub fn new(discovered: usize) -> Self {
        Self {
            discovered,
            loaded: 0,
            skipped: 0,
            skipped_examples: Vec::new(),
        }
    }

    /// Counts one successfully loaded image.
    pub fn record_loaded(&mut self) {
        self.loaded += 1;
    }

    /// Counts one image that could not be loaded.
    ///
    /// The `path`/`reason` pair is stored only while fewer than
    /// [`Self::MAX_SKIPPED_EXAMPLES`] examples have been kept, so memory stays bounded
    /// even for datasets with many broken files.
    pub fn record_skipped(&mut self, path: impl Into<String>, reason: impl Into<String>) {
        self.skipped += 1;
        if self.skipped_examples.len() < Self::MAX_SKIPPED_EXAMPLES {
            self.skipped_examples.push(SkippedImage {
                path: path.into(),
                reason: reason.into(),
            });
        }
    }

    /// Returns `true` if at least one image was loaded.
    pub fn has_loaded_images(&self) -> bool {
        self.loaded > 0
    }

    /// Adds `other`'s `loaded` and `skipped` counts to `self` and appends its skipped
    /// examples (in order) until [`Self::MAX_SKIPPED_EXAMPLES`] is reached.
    ///
    /// `discovered` is left unchanged; the aggregate's value is whatever it was created
    /// with.
    pub fn merge(&mut self, other: Self) {
        self.loaded += other.loaded;
        self.skipped += other.skipped;
        // `saturating_sub` keeps this safe even if the public field was overfilled directly.
        let room = Self::MAX_SKIPPED_EXAMPLES.saturating_sub(self.skipped_examples.len());
        self.skipped_examples
            .extend(other.skipped_examples.into_iter().take(room));
    }
}
pub fn scan_images(dir: &Path) -> Result<Vec<ImageEntry>, DatasetError> {
if !dir.exists() || !dir.is_dir() {
return Err(DatasetError::DirectoryNotFound(dir.display().to_string()));
}
let mut entries: Vec<ImageEntry> = Vec::new();
scan_images_recursive(dir, dir, &mut entries)?;
if entries.is_empty() {
return Err(DatasetError::NoImages(dir.display().to_string()));
}
entries.sort_by(|a, b| a.relative_path.cmp(&b.relative_path));
Ok(entries)
}
/// Opens and decodes the image stored at `path`.
///
/// # Errors
/// Returns [`DatasetError::ImageLoad`] carrying `path` as displayed and the message of
/// the error from `image::open` if opening or decoding fails.
pub fn load_image(path: &Path) -> Result<DynamicImage, DatasetError> {
    let shown = path.display();
    debug!("Loading image: {}", shown);
    image::open(path).map_err(|err| DatasetError::ImageLoad {
        path: shown.to_string(),
        reason: err.to_string(),
    })
}
/// Loads the images under `dir` one at a time, in sorted relative-path order, and hands
/// each successfully loaded one to `visit`.
///
/// Files that fail to load are tallied as skipped in the returned summary instead of
/// aborting the walk. The summary's `discovered` is the number of scanned files.
///
/// # Errors
/// Returns the scan error for `dir` (converted into `E`), and stops at, and returns,
/// the first error produced by `visit`.
pub fn for_each_image<F, E>(
    dir: &Path,
    show_progress: bool,
    mut visit: F,
) -> Result<DatasetProcessingSummary, E>
where
    F: FnMut(ImageEntry, DynamicImage) -> Result<(), E>,
    E: From<DatasetError>,
{
    let images = DatasetIterator::new(dir, show_progress)?;
    let mut summary = DatasetProcessingSummary::new(images.len());
    for outcome in images {
        let (entry, image) = match outcome {
            Ok(pair) => pair,
            // Undecodable files are recorded and skipped rather than aborting the walk.
            Err(DatasetError::ImageLoad { path, reason }) => {
                summary.record_skipped(path, reason);
                continue;
            }
            Err(other) => return Err(E::from(other)),
        };
        summary.record_loaded();
        visit(entry, image)?;
    }
    Ok(summary)
}
/// Loads every supported image under `dir` on the rayon thread pool and maps it through
/// `visit`, returning the `Some` results in scan order.
///
/// The sorted entry list from [`scan_images`] is split into contiguous chunks (sized by
/// `parallel_chunk_size`). Each chunk builds its own worker state with `init_worker` and
/// then calls `visit(&mut worker, entry, image)` for every image it loads. Images that
/// fail to load are counted as skipped instead of aborting the run. When
/// `show_progress` is set, the bar advances once per attempted image and is finished
/// with "Done" after all chunks return, including when one of them failed.
///
/// Returns the aggregated [`DatasetProcessingSummary`] together with the collected
/// values, ordered by each entry's position in the scan rather than by completion.
///
/// # Errors
/// Returns the scan error if `dir` is missing or contains no images, or an error
/// produced by `init_worker` or `visit`; in that case no partial results are returned.
pub fn map_images_parallel<W, T, Init, Visit, E>(
    dir: &Path,
    show_progress: bool,
    init_worker: Init,
    visit: Visit,
) -> Result<(DatasetProcessingSummary, Vec<T>), E>
where
    T: Send,
    E: From<DatasetError> + Send,
    Init: Fn() -> Result<W, E> + Send + Sync,
    Visit: Fn(&mut W, ImageEntry, DynamicImage) -> Result<Option<T>, E> + Send + Sync,
{
    // Tag every entry with its scan position so results can be re-ordered afterwards.
    let numbered: Vec<(usize, ImageEntry)> = scan_images(dir)?.into_iter().enumerate().collect();
    let total = numbered.len();
    let progress_bar = dataset_progress_bar(total, show_progress);
    // The per-chunk closure captures these references rather than the callbacks themselves.
    let (make_worker, handle) = (&init_worker, &visit);

    let per_chunk = numbered
        .par_chunks(parallel_chunk_size(total))
        .map(
            |slice| -> Result<(DatasetProcessingSummary, Vec<(usize, T)>), E> {
                // `init_worker` runs once per chunk, not once per image.
                let mut state = make_worker()?;
                let mut chunk_summary = DatasetProcessingSummary::new(0);
                let mut produced = Vec::with_capacity(slice.len());
                for (position, entry) in slice.iter().cloned() {
                    let loaded = load_entry_image(&entry);
                    if let Some(bar) = progress_bar.as_ref() {
                        bar.inc(1);
                    }
                    let image = match loaded {
                        Ok(image) => image,
                        Err(DatasetError::ImageLoad { path, reason }) => {
                            chunk_summary.record_skipped(path, reason);
                            continue;
                        }
                        Err(other) => return Err(other.into()),
                    };
                    chunk_summary.record_loaded();
                    if let Some(value) = handle(&mut state, entry, image)? {
                        produced.push((position, value));
                    }
                }
                Ok((chunk_summary, produced))
            },
        )
        .collect::<Result<Vec<_>, E>>();

    // Close the bar before inspecting the outcome so it is finished on error as well.
    if let Some(bar) = progress_bar.as_ref() {
        bar.finish_with_message("Done");
    }

    let mut summary = DatasetProcessingSummary::new(total);
    let mut numbered_items: Vec<(usize, T)> = Vec::new();
    for (chunk_summary, mut produced) in per_chunk? {
        summary.merge(chunk_summary);
        numbered_items.append(&mut produced);
    }
    numbered_items.sort_by_key(|pair| pair.0);
    let items: Vec<T> = numbered_items.into_iter().map(|(_, value)| value).collect();
    Ok((summary, items))
}
/// Depth-first walk of `current`, appending every supported image file to `entries`.
///
/// `root` is the directory originally passed to [`scan_images`]. Each entry's
/// `relative_path` is computed against it, falling back to the full path if `root` is
/// not a prefix. Each directory listing is sorted by path, so traversal order is
/// deterministic.
///
/// Symbolic links are handled as follows:
/// * A link that resolves to a regular file with a supported extension is included, so
///   datasets assembled from symlinked images are not silently dropped.
/// * Links to directories are never followed, which rules out cycles.
/// * Dangling or unreadable links are ignored.
///
/// # Errors
/// Propagates I/O errors from reading a directory or querying an entry's file type.
fn scan_images_recursive(
    root: &Path,
    current: &Path,
    entries: &mut Vec<ImageEntry>,
) -> Result<(), DatasetError> {
    let mut directory_entries = std::fs::read_dir(current)?.collect::<Result<Vec<_>, _>>()?;
    // `DirEntry::path` allocates a `PathBuf`; compute each key once instead of once per
    // comparison.
    directory_entries.sort_by_cached_key(|entry| entry.path());
    for entry in directory_entries {
        let path = entry.path();
        let file_type = entry.file_type()?;
        if file_type.is_dir() {
            scan_images_recursive(root, &path, entries)?;
            continue;
        }
        if !is_supported_image_path(&path) {
            continue;
        }
        // `DirEntry::file_type` does not follow symlinks, so `is_file()` is false for a
        // link to an image. `fs::metadata` follows the link; a failure there (dangling
        // link, permission problem) means "not a usable file" rather than a scan error.
        let is_regular_file = if file_type.is_symlink() {
            std::fs::metadata(&path)
                .map(|metadata| metadata.is_file())
                .unwrap_or(false)
        } else {
            file_type.is_file()
        };
        if !is_regular_file {
            continue;
        }
        let relative_path = path
            .strip_prefix(root)
            .unwrap_or(path.as_path())
            .to_path_buf();
        entries.push(ImageEntry {
            stem: relative_stem(&relative_path),
            path,
            relative_path,
        });
    }
    Ok(())
}
/// Returns `true` when `path` has a UTF-8 extension equal, ignoring ASCII case, to one
/// of `SUPPORTED_EXTENSIONS`. Paths with no extension, or a non-UTF-8 one, are rejected.
fn is_supported_image_path(path: &Path) -> bool {
    let Some(extension) = path.extension().and_then(std::ffi::OsStr::to_str) else {
        return false;
    };
    // Case-insensitive comparison without building a lower-cased copy.
    SUPPORTED_EXTENSIONS
        .iter()
        .any(|known| known.eq_ignore_ascii_case(extension))
}
/// Drops the final extension from `relative_path` and renders it as a string with every
/// `\` turned into `/`, so stems look the same on every platform.
fn relative_stem(relative_path: &Path) -> String {
    let extensionless = relative_path.with_extension("");
    extensionless
        .to_string_lossy()
        .chars()
        .map(|ch| if ch == '\\' { '/' } else { ch })
        .collect()
}
/// Opens and decodes the image for `entry`.
///
/// # Errors
/// Returns [`DatasetError::ImageLoad`] when `image::open` fails. Its `path` is
/// `entry.relative_path` rendered with `/` separators on every platform, which is the
/// same normalization applied to `ImageEntry::stem`. As a result, skip reports built
/// from it (e.g. `SkippedImage::path`) are identical on Windows and Unix.
fn load_entry_image(entry: &ImageEntry) -> Result<DynamicImage, DatasetError> {
    debug!("Loading image: {}", entry.path.display());
    image::open(&entry.path).map_err(|e| DatasetError::ImageLoad {
        // `display()` would keep the platform separator (`\` on Windows), making reports
        // disagree with the `/`-separated stems.
        path: entry.relative_path.to_string_lossy().replace('\\', "/"),
        reason: e.to_string(),
    })
}
/// Sequential iterator over the images of a dataset directory, decoding each file only
/// when `next` reaches it.
///
/// Yields `Ok((entry, image))` for files that load and `Err(..)` for files that do not,
/// leaving it to the caller to decide whether a failure is fatal.
pub struct DatasetIterator {
    /// Sorted entries produced by [`scan_images`].
    entries: Vec<ImageEntry>,
    /// Index into `entries` of the next entry to load.
    index: usize,
    /// Optional progress bar, advanced once per entry handed out.
    progress: Option<ProgressBar>,
}
impl DatasetIterator {
    /// Scans `dir` with [`scan_images`] and prepares to load the sorted entries lazily,
    /// with a progress bar sized to the entry count when `show_progress` is set.
    ///
    /// # Errors
    /// Propagates any error from [`scan_images`], e.g. a missing directory or one that
    /// contains no supported images.
    pub fn new(dir: &Path, show_progress: bool) -> Result<Self, DatasetError> {
        let entries = scan_images(dir)?;
        Ok(Self {
            progress: dataset_progress_bar(entries.len(), show_progress),
            index: 0,
            entries,
        })
    }

    /// Total number of discovered entries. This does not shrink as iteration advances.
    pub fn len(&self) -> usize {
        let Self { entries, .. } = self;
        entries.len()
    }

    /// Returns `true` when no entries were discovered.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Builds the "Loading images" progress bar for `len` items, or `None` when
/// `show_progress` is false.
///
/// If the custom template fails to parse, the default bar style is used instead; both
/// variants use `=> ` progress characters.
fn dataset_progress_bar(len: usize, show_progress: bool) -> Option<ProgressBar> {
    show_progress.then(|| {
        let bar = ProgressBar::new(len as u64);
        let style = ProgressStyle::default_bar()
            .template("{msg} [{bar:40.cyan/blue}] {pos}/{len}")
            .unwrap_or_else(|_| ProgressStyle::default_bar())
            .progress_chars("=> ");
        bar.set_style(style);
        bar.set_message("Loading images");
        bar
    })
}
fn parallel_chunk_size(len: usize) -> usize {
let workers = std::thread::available_parallelism()
.map(|threads| threads.get())
.unwrap_or(1)
.max(1);
len.div_ceil(workers).max(1)
}
impl Iterator for DatasetIterator {
    type Item = Result<(ImageEntry, DynamicImage), DatasetError>;

    /// Loads the next entry in sorted order and advances the progress bar.
    ///
    /// Returns `Some(Err(..))` for an entry that fails to load, so iteration can continue
    /// past it. Once every entry has been handed out, the progress bar (if any) is
    /// finished with "Done" and `None` is returned.
    fn next(&mut self) -> Option<Self::Item> {
        let Some(entry) = self.entries.get(self.index).cloned() else {
            if let Some(pb) = self.progress.as_ref() {
                pb.finish_with_message("Done");
            }
            return None;
        };
        self.index += 1;
        if let Some(pb) = self.progress.as_ref() {
            pb.inc(1);
        }
        Some(load_entry_image(&entry).map(|img| (entry, img)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_scan_empty_dir() {
        let dir = tempdir().unwrap();
        let result = scan_images(dir.path());
        assert!(matches!(result, Err(DatasetError::NoImages(_))));
    }

    #[test]
    fn test_scan_nonexistent_dir() {
        let result = scan_images(Path::new("/nonexistent/path/12345"));
        assert!(matches!(result, Err(DatasetError::DirectoryNotFound(_))));
    }

    #[test]
    fn test_scan_finds_images() {
        let dir = tempdir().unwrap();
        let img = image::RgbImage::new(4, 4);
        let path = dir.path().join("test.png");
        img.save(&path).unwrap();
        let entries = scan_images(dir.path()).unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].stem, "test");
    }

    #[test]
    fn test_scan_matches_extensions_case_insensitively() {
        let dir = tempdir().unwrap();
        // The scan only inspects extensions, so the file contents do not matter here.
        std::fs::write(dir.path().join("UPPER.PNG"), b"not decoded by the scan").unwrap();
        std::fs::write(dir.path().join("notes.txt"), b"not an image").unwrap();
        let entries = scan_images(dir.path()).unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].stem, "UPPER");
    }

    #[test]
    fn test_scan_recurses_and_sorts_by_relative_path() {
        let dir = tempdir().unwrap();
        let nested_a = dir.path().join("class-b");
        let nested_b = dir.path().join("class-a").join("deep");
        std::fs::create_dir_all(&nested_a).unwrap();
        std::fs::create_dir_all(&nested_b).unwrap();
        image::RgbImage::new(4, 4)
            .save(dir.path().join("root.png"))
            .unwrap();
        image::RgbImage::new(4, 4)
            .save(nested_a.join("beta.png"))
            .unwrap();
        image::RgbImage::new(4, 4)
            .save(nested_b.join("alpha.png"))
            .unwrap();
        let entries = scan_images(dir.path()).unwrap();
        // Compare `PathBuf`s rather than `display()` strings: relative paths use the
        // platform separator, so `/`-separated string literals would fail on Windows.
        let relative_paths = entries
            .iter()
            .map(|entry| entry.relative_path.clone())
            .collect::<Vec<_>>();
        let stems = entries
            .iter()
            .map(|entry| entry.stem.as_str())
            .collect::<Vec<_>>();
        assert_eq!(
            relative_paths,
            vec![
                Path::new("class-a").join("deep").join("alpha.png"),
                Path::new("class-b").join("beta.png"),
                PathBuf::from("root.png"),
            ]
        );
        // Stems are always `/`-separated, whatever the platform.
        assert_eq!(stems, vec!["class-a/deep/alpha", "class-b/beta", "root"]);
    }

    #[test]
    fn test_for_each_image_skips_corrupt_supported_files() {
        let dir = tempdir().unwrap();
        let nested = dir.path().join("nested");
        std::fs::create_dir_all(&nested).unwrap();
        image::RgbImage::new(4, 4)
            .save(dir.path().join("good.png"))
            .unwrap();
        std::fs::write(nested.join("broken.png"), b"not a real image").unwrap();
        let mut visited = Vec::new();
        let summary = for_each_image(dir.path(), false, |entry, _image| {
            visited.push(entry.stem);
            Ok::<(), DatasetError>(())
        })
        .unwrap();
        assert_eq!(visited, vec!["good".to_string()]);
        assert_eq!(summary.discovered, 2);
        assert_eq!(summary.loaded, 1);
        assert_eq!(summary.skipped, 1);
        assert_eq!(summary.skipped_examples.len(), 1);
        // Separator-agnostic so the assertion holds on Windows too.
        assert_eq!(
            summary.skipped_examples[0].path.replace('\\', "/"),
            "nested/broken.png"
        );
    }

    #[test]
    fn test_map_images_parallel_preserves_order_and_skips_corrupt_files() {
        let dir = tempdir().unwrap();
        let nested = dir.path().join("class-a");
        std::fs::create_dir_all(&nested).unwrap();
        image::RgbImage::new(4, 4)
            .save(dir.path().join("root.png"))
            .unwrap();
        image::RgbImage::new(4, 4)
            .save(nested.join("leaf.png"))
            .unwrap();
        std::fs::write(dir.path().join("broken.png"), b"not a real image").unwrap();
        let (summary, stems) = map_images_parallel(
            dir.path(),
            false,
            || Ok::<usize, DatasetError>(0),
            |_, entry, _| Ok::<_, DatasetError>(Some(entry.stem)),
        )
        .unwrap();
        assert_eq!(summary.discovered, 3);
        assert_eq!(summary.loaded, 2);
        assert_eq!(summary.skipped, 1);
        assert_eq!(summary.skipped_examples[0].path, "broken.png");
        assert_eq!(stems, vec!["class-a/leaf".to_string(), "root".to_string()]);
    }

    #[test]
    fn test_summary_caps_skipped_examples_and_merges_counts() {
        let mut summary = DatasetProcessingSummary::new(10);
        for index in 0..3 {
            summary.record_skipped(format!("a{index}.png"), "bad");
        }
        let mut other = DatasetProcessingSummary::new(0);
        other.record_loaded();
        for index in 0..4 {
            other.record_skipped(format!("b{index}.png"), "bad");
        }
        summary.merge(other);
        assert_eq!(summary.discovered, 10);
        assert_eq!(summary.loaded, 1);
        assert!(summary.has_loaded_images());
        assert_eq!(summary.skipped, 7);
        let example_paths = summary
            .skipped_examples
            .iter()
            .map(|skipped| skipped.path.as_str())
            .collect::<Vec<_>>();
        assert_eq!(
            example_paths,
            vec!["a0.png", "a1.png", "a2.png", "b0.png", "b1.png"]
        );
    }
}