use anyhow::Result;
use image::RgbImage;
use std::path::Path;
/// Opens the image file at `path` with `image::open` and converts the decoded
/// image to an 8-bit RGB buffer.
///
/// # Errors
///
/// Propagates any error returned by `image::open` (I/O or decoding failures),
/// converted into an `anyhow::Error`.
pub fn load_rgbimage<P>(path: P) -> Result<RgbImage>
where
    P: AsRef<Path>,
{
    image::open(path.as_ref())
        .map(|decoded| decoded.to_rgb8())
        .map_err(Into::into)
}
/// A batch of 8-bit RGB images whose pixel bytes are stored contiguously in a
/// single buffer, described by a 4-dimensional `shape`
/// (`[batch, height, width, channels]`).
#[derive(Debug, Clone)]
pub struct RgbImageBatch {
    /// Interleaved `R, G, B` bytes of every image pushed so far, in push order.
    pub data: Vec<u8>,
    /// Batch shape as `[batch, height, width, channels]`; [`RgbImageBatch::new`]
    /// requires the channel dimension to be 3.
    pub shape: Vec<usize>,
}

impl RgbImageBatch {
    /// Creates an empty batch with room for `shape.iter().product()` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `shape` does not have exactly four dimensions, or if the last
    /// (channel) dimension is not 3.
    pub fn new(shape: &[usize]) -> Self {
        // Validate first so that an invalid shape never triggers an allocation
        // (or an overflowing product) before the assertion fires.
        assert_eq!(shape.len(), 4);
        assert_eq!(shape[3], 3);
        let shape = shape.to_vec();
        let data = Vec::with_capacity(shape.iter().product());
        Self { shape, data }
    }

    /// Appends the pixels of `img` to `data` as interleaved `R, G, B` bytes, in
    /// the order yielded by `RgbImage::pixels`.
    ///
    /// This does not check that `img` matches the batch's height and width; the
    /// caller is responsible for that.
    pub(crate) fn push_rgb_pixels(&mut self, img: &RgbImage) {
        for px in img.pixels() {
            self.data.extend_from_slice(&px.0);
        }
    }

    /// Number of images in the batch (`shape[0]`).
    pub fn batch_size(&self) -> usize {
        self.shape[0]
    }

    /// Height of every image in the batch (`shape[1]`).
    pub fn height(&self) -> usize {
        self.shape[1]
    }

    /// Width of every image in the batch (`shape[2]`).
    pub fn width(&self) -> usize {
        self.shape[2]
    }

    /// Number of channels per pixel (`shape[3]`).
    pub fn channels(&self) -> usize {
        self.shape[3]
    }

    /// Total number of bytes the full batch occupies: the product of all
    /// dimensions in `shape`.
    ///
    /// This is derived from `shape` rather than `data.capacity()`, because a
    /// `Vec`'s capacity is only guaranteed to be *at least* the amount requested
    /// and it grows further if more bytes are pushed.
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }
}
pub fn load_batch<T, P>(
paths: &[P],
on_dims: fn(&[usize; 4]) -> Result<T>,
on_img: fn(&mut T, idx: usize, img: &RgbImage) -> Result<()>,
) -> Result<T>
where
P: AsRef<Path>,
{
let batch_size = paths.len();
let path = paths.first().unwrap().as_ref();
let img = load_rgbimage(path)?;
let (width, height) = img.dimensions();
let shape = [batch_size, height as usize, width as usize, 3];
let mut batch = on_dims(&shape)?;
on_img(&mut batch, 0, &img)?;
for i in 1..batch_size {
let path = paths.get(i).unwrap();
let img = load_rgbimage(path)?;
assert_eq!(
img.dimensions(),
(width, height),
"Image dimensions do not match"
);
on_img(&mut batch, i, &img)?;
}
Ok(batch)
}
pub fn load_bhwc_rgbimagebatch<P>(paths: &[P]) -> Result<RgbImageBatch>
where
P: AsRef<Path>,
{
load_batch::<RgbImageBatch, _>(
paths,
|shape| Ok(RgbImageBatch::new(shape)),
|batch, _idx, img| {
batch.push_rgb_pixels(img);
Ok(())
},
)
}