use crate::core::{device::Device, dtype::DataType, error::BellandeError, tensor::Tensor};
use std::collections::HashMap;
use std::io::Read;
use std::path::{Path, PathBuf};
/// Container formats recognised from an image's leading magic bytes.
#[derive(Debug)]
#[derive(PartialEq)]
enum ImageFormat {
    /// Data beginning with `FF D8 FF`.
    JPEG,
    /// Data beginning with `89 50 4E 47` (`\x89PNG`).
    PNG,
    /// Any other signature.
    Unknown,
}
/// One 8-bit-per-channel colour sample.
///
/// NOTE(review): not referenced by the code visible alongside it — confirm whether it is
/// still needed.
#[derive(Clone, Copy, Debug)]
struct RGB {
    /// Red component.
    r: u8,
    /// Green component.
    g: u8,
    /// Blue component.
    b: u8,
}
/// An image held as raw 8-bit samples together with its dimensions.
pub struct ImageDecoder {
    /// Image width in pixels.
    width: usize,
    /// Image height in pixels.
    height: usize,
    /// Samples stored per pixel.
    channels: usize,
    /// Pixel samples, interleaved pixel by pixel in row-major order
    /// (`(y * width + x) * channels + c`), `width * height * channels` bytes long.
    data: Vec<u8>,
}
impl ImageDecoder {
    /// Number of interleaved samples per pixel allocated for every image.
    ///
    /// NOTE(review): hard-coded regardless of the file's actual colour type (presumably
    /// RGB) — confirm before relying on it for grayscale/alpha images.
    const CHANNELS: usize = 3;

    /// Creates a decoder from an encoded JPEG or PNG byte buffer.
    ///
    /// The format is detected from the leading magic bytes, then only the header is parsed
    /// to obtain the image dimensions.
    ///
    /// NOTE(review): no entropy/zlib decoding is performed — `data` is a zero-filled
    /// placeholder, so every image currently produces an all-zero tensor.
    ///
    /// # Errors
    /// Returns `BellandeError::ImageError` if the buffer is shorter than four bytes, has an
    /// unrecognised signature, or has a truncated/malformed header or zero dimensions.
    pub fn new(bytes: &[u8]) -> Result<Self, BellandeError> {
        match Self::detect_format(bytes)? {
            ImageFormat::JPEG => Self::decode_jpeg(bytes),
            ImageFormat::PNG => Self::decode_png(bytes),
            ImageFormat::Unknown => Err(BellandeError::ImageError(
                "Unsupported image format".to_string(),
            )),
        }
    }

    /// Identifies the container format from the first four bytes.
    ///
    /// # Errors
    /// Fails only when fewer than four bytes are available.
    fn detect_format(bytes: &[u8]) -> Result<ImageFormat, BellandeError> {
        if bytes.len() < 4 {
            return Err(BellandeError::ImageError("Invalid image data".to_string()));
        }
        match &bytes[0..4] {
            [0xFF, 0xD8, 0xFF, _] => Ok(ImageFormat::JPEG),
            [0x89, 0x50, 0x4E, 0x47] => Ok(ImageFormat::PNG),
            _ => Ok(ImageFormat::Unknown),
        }
    }

    /// Reads the frame dimensions from a JPEG stream.
    ///
    /// Walks the marker segments until the first start-of-frame (SOFn) marker, then reads
    /// height and width from its frame header.
    fn decode_jpeg(bytes: &[u8]) -> Result<Self, BellandeError> {
        let mut reader = std::io::Cursor::new(bytes);
        let mut marker = [0u8; 2];
        loop {
            reader.read_exact(&mut marker).map_err(|e| {
                BellandeError::ImageError(format!("Failed to read JPEG marker: {}", e))
            })?;
            if marker[0] != 0xFF {
                return Err(BellandeError::ImageError("Invalid JPEG marker".to_string()));
            }
            match marker[1] {
                // SOF0..SOF15. DHT (0xC4), JPG (0xC8) and DAC (0xCC) share the range but
                // are ordinary length-prefixed segments, handled by the `_` arm.
                0xC0..=0xCF if !matches!(marker[1], 0xC4 | 0xC8 | 0xCC) => break,
                0xD9 => return Err(BellandeError::ImageError("Reached end of JPEG".to_string())),
                // SOI, RST0-7 and TEM are stand-alone markers with no length field.
                0x01 | 0xD0..=0xD8 => {}
                // A 0xFF fill byte may precede a marker: step back one byte so it is
                // re-read as the 0xFF prefix of the real marker.
                0xFF => reader.set_position(reader.position() - 1),
                _ => {
                    let mut length = [0u8; 2];
                    reader.read_exact(&mut length).map_err(|e| {
                        BellandeError::ImageError(format!("Failed to read length: {}", e))
                    })?;
                    let length = u16::from_be_bytes(length);
                    // The length counts its own two bytes; anything smaller is corrupt and
                    // would otherwise underflow.
                    let payload = length.checked_sub(2).ok_or_else(|| {
                        BellandeError::ImageError(format!(
                            "Invalid JPEG segment length: {}",
                            length
                        ))
                    })?;
                    reader.set_position(reader.position() + u64::from(payload));
                }
            }
        }

        // Frame header following SOFn: Lf (2 bytes), P (1), Y = height (2), X = width (2).
        let mut header = [0u8; 7];
        reader
            .read_exact(&mut header)
            .map_err(|e| BellandeError::ImageError(format!("Failed to read SOF0 header: {}", e)))?;
        let height = usize::from(u16::from_be_bytes([header[3], header[4]]));
        let width = usize::from(u16::from_be_bytes([header[5], header[6]]));
        Self::blank(width, height)
    }

    /// Reads the image dimensions from a PNG stream's leading IHDR chunk.
    fn decode_png(bytes: &[u8]) -> Result<Self, BellandeError> {
        const SIGNATURE: [u8; 8] = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1A, b'\n'];

        let mut reader = std::io::Cursor::new(bytes);
        let mut header = [0u8; 8];
        reader
            .read_exact(&mut header)
            .map_err(|e| BellandeError::ImageError(format!("Failed to read PNG header: {}", e)))?;
        if header != SIGNATURE {
            return Err(BellandeError::ImageError("Invalid PNG signature".to_string()));
        }

        // Each chunk is: length (4), type (4), data, CRC (4). The length is not needed
        // because only the fixed-layout start of IHDR is read.
        let mut length = [0u8; 4];
        reader.read_exact(&mut length).map_err(|e| {
            BellandeError::ImageError(format!("Failed to read chunk length: {}", e))
        })?;
        let mut chunk_type = [0u8; 4];
        reader.read_exact(&mut chunk_type).map_err(|e| {
            BellandeError::ImageError(format!("Failed to read chunk type: {}", e))
        })?;
        if chunk_type != *b"IHDR" {
            return Err(BellandeError::ImageError(
                "PNG does not start with an IHDR chunk".to_string(),
            ));
        }

        // IHDR data starts with width (4 bytes BE) then height (4 bytes BE).
        let mut ihdr = [0u8; 8];
        reader
            .read_exact(&mut ihdr)
            .map_err(|e| BellandeError::ImageError(format!("Failed to read IHDR: {}", e)))?;
        let width = u32::from_be_bytes([ihdr[0], ihdr[1], ihdr[2], ihdr[3]]) as usize;
        let height = u32::from_be_bytes([ihdr[4], ihdr[5], ihdr[6], ihdr[7]]) as usize;
        Self::blank(width, height)
    }

    /// Builds a decoder of the given size with a zero-filled pixel buffer.
    ///
    /// # Errors
    /// Rejects zero-sized images (later resizing would index an empty buffer) and sizes
    /// whose byte count overflows `usize`.
    fn blank(width: usize, height: usize) -> Result<Self, BellandeError> {
        if width == 0 || height == 0 {
            return Err(BellandeError::ImageError(format!(
                "Invalid image dimensions: {}x{}",
                width, height
            )));
        }
        let len = width
            .checked_mul(height)
            .and_then(|pixels| pixels.checked_mul(Self::CHANNELS))
            .ok_or_else(|| {
                BellandeError::ImageError(format!(
                    "Image dimensions too large: {}x{}",
                    width, height
                ))
            })?;
        Ok(Self {
            width,
            height,
            channels: Self::CHANNELS,
            data: vec![0u8; len],
        })
    }

    /// Converts the pixels into a `[1, channels, height, width]` `Float32` CPU tensor.
    ///
    /// Samples are scaled from `0..=255` to `0.0..=1.0`. `data` is stored interleaved
    /// (`(y * width + x) * channels + c`), so it is de-interleaved into one plane per
    /// channel to match the declared channel-first shape.
    pub fn to_tensor(&self) -> Result<Tensor, BellandeError> {
        let plane = self.width * self.height;
        let mut tensor_data = Vec::with_capacity(self.data.len());
        for channel in 0..self.channels {
            for pixel in 0..plane {
                let byte = self.data[pixel * self.channels + channel];
                tensor_data.push(f32::from(byte) / 255.0);
            }
        }
        Ok(Tensor::new(
            tensor_data,
            vec![1, self.channels, self.height, self.width],
            false,
            Device::CPU,
            DataType::Float32,
        ))
    }

    /// Resizes the image in place using nearest-neighbour sampling.
    ///
    /// Destination pixel `(x, y)` copies source pixel
    /// `(x * width / new_width, y * height / new_height)` (integer floor). Returns
    /// immediately when the size is unchanged.
    ///
    /// # Errors
    /// Fails if the current or requested size has a zero dimension, or if the new buffer
    /// size overflows `usize`.
    pub fn resize(&mut self, new_width: usize, new_height: usize) -> Result<(), BellandeError> {
        if new_width == self.width && new_height == self.height {
            return Ok(());
        }
        if new_width == 0 || new_height == 0 || self.width == 0 || self.height == 0 {
            return Err(BellandeError::ImageError(format!(
                "Cannot resize image from {}x{} to {}x{}",
                self.width, self.height, new_width, new_height
            )));
        }
        let channels = self.channels;
        let len = new_width
            .checked_mul(new_height)
            .and_then(|pixels| pixels.checked_mul(channels))
            .ok_or_else(|| {
                BellandeError::ImageError(format!(
                    "Image dimensions too large: {}x{}",
                    new_width, new_height
                ))
            })?;
        let mut new_data = vec![0u8; len];
        for y in 0..new_height {
            // Integer floor stays exact and strictly below `height`. The float version
            // can round up to an out-of-range index on very large images.
            let src_y = y * self.height / new_height;
            for x in 0..new_width {
                let src_x = x * self.width / new_width;
                let src = (src_y * self.width + src_x) * channels;
                let dst = (y * new_width + x) * channels;
                new_data[dst..dst + channels].copy_from_slice(&self.data[src..src + channels]);
            }
        }
        self.width = new_width;
        self.height = new_height;
        self.data = new_data;
        Ok(())
    }
}
/// A directory of image files that are decoded into tensors on demand and cached.
pub struct ImageFolder {
    /// Root directory that is scanned for images.
    path: PathBuf,
    /// Tensors already produced by `load_image`, keyed by the path that was passed in.
    cache: HashMap<PathBuf, Tensor>,
    /// Accepted file extensions, written without the leading dot.
    supported_extensions: Vec<String>,
}
impl ImageFolder {
    /// Edge length, in pixels, that every decoded image is resized to.
    const TARGET_SIZE: usize = 224;

    /// Opens `path` as an image folder that accepts `jpg`, `jpeg` and `png` files.
    ///
    /// # Errors
    /// Returns `BellandeError::ImageError` if `path` does not exist or is not a directory.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, BellandeError> {
        let path = path.as_ref().to_path_buf();
        if !path.exists() {
            return Err(BellandeError::ImageError(format!(
                "Image folder does not exist: {}",
                path.display()
            )));
        }
        if !path.is_dir() {
            return Err(BellandeError::ImageError(format!(
                "Path is not a directory: {}",
                path.display()
            )));
        }
        Ok(Self {
            path,
            cache: HashMap::new(),
            supported_extensions: vec!["jpg".to_string(), "jpeg".to_string(), "png".to_string()],
        })
    }

    /// Reports whether `path`'s extension is one of `supported_extensions`.
    ///
    /// The comparison ignores ASCII case, so `IMG.JPG` is accepted. Paths without an
    /// extension return `false`.
    fn has_supported_extension(&self, path: &Path) -> bool {
        match path.extension() {
            Some(ext) => {
                let ext = ext.to_string_lossy();
                self.supported_extensions
                    .iter()
                    .any(|supported| supported.eq_ignore_ascii_case(&ext))
            }
            None => false,
        }
    }

    /// Decodes an encoded image, resizes it to `TARGET_SIZE` x `TARGET_SIZE`, and converts
    /// it to a tensor.
    fn decode_image(bytes: &[u8]) -> Result<Tensor, BellandeError> {
        let mut decoder = ImageDecoder::new(bytes)?;
        if decoder.width != Self::TARGET_SIZE || decoder.height != Self::TARGET_SIZE {
            decoder.resize(Self::TARGET_SIZE, Self::TARGET_SIZE)?;
        }
        decoder.to_tensor()
    }

    /// Loads `image_path` as a tensor, returning a cached clone if it was loaded before.
    ///
    /// The path is used exactly as given: it is not resolved against the folder root or
    /// canonicalised, so different spellings of the same file get separate cache entries.
    ///
    /// # Errors
    /// Returns `BellandeError::ImageError` in these cases:
    /// - the file is missing;
    /// - its extension is present but unsupported;
    /// - the file cannot be read;
    /// - the contents fail to decode.
    pub fn load_image<P: AsRef<Path>>(&mut self, image_path: P) -> Result<Tensor, BellandeError> {
        let path = image_path.as_ref().to_path_buf();
        if let Some(tensor) = self.cache.get(&path) {
            return Ok(tensor.clone());
        }
        if !path.exists() {
            return Err(BellandeError::ImageError(format!(
                "Image file does not exist: {}",
                path.display()
            )));
        }
        // Extension-less files are let through; the decoder sniffs the format from the
        // file's magic bytes.
        if path.extension().is_some() && !self.has_supported_extension(&path) {
            return Err(BellandeError::ImageError(format!(
                "Unsupported image format: {}",
                path.display()
            )));
        }
        let bytes = std::fs::read(&path).map_err(|e| {
            BellandeError::ImageError(format!(
                "Failed to read image file {}: {}",
                path.display(),
                e
            ))
        })?;
        let tensor = Self::decode_image(&bytes)?;
        self.cache.insert(path, tensor.clone());
        Ok(tensor)
    }

    /// Lists the regular files directly inside the folder that have a supported
    /// extension, sorted by path.
    ///
    /// The listing is not recursive.
    ///
    /// # Errors
    /// Returns `BellandeError::ImageError` if the directory or one of its entries cannot
    /// be read.
    pub fn list_images(&self) -> Result<Vec<PathBuf>, BellandeError> {
        let entries = std::fs::read_dir(&self.path).map_err(|e| {
            BellandeError::ImageError(format!(
                "Failed to read directory {}: {}",
                self.path.display(),
                e
            ))
        })?;
        let mut images = Vec::new();
        for entry in entries {
            let path = entry
                .map_err(|e| {
                    BellandeError::ImageError(format!("Failed to read directory entry: {}", e))
                })?
                .path();
            // Skip directories (or other non-files) that merely have an image-like name.
            if path.is_file() && self.has_supported_extension(&path) {
                images.push(path);
            }
        }
        // `read_dir` order is platform-dependent; sort so iteration order is reproducible.
        images.sort();
        Ok(images)
    }

    /// Drops every cached tensor.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }

    /// Returns the folder's root directory.
    pub fn path(&self) -> &Path {
        &self.path
    }
}