use std::collections::HashMap;
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
use bytes::Bytes;
use hound::{WavReader, WavSpec, WavWriter};
use image::{DynamicImage, ImageFormat};
use crate::domain::errors::MediaError;
use crate::domain::ports::MediaLoader;
use crate::domain::types::{CoverMedia, CoverMediaKind};
/// Metadata key for the image width in pixels, stored as a decimal string.
const KEY_WIDTH: &str = "width";
/// Metadata key for the image height in pixels, stored as a decimal string.
const KEY_HEIGHT: &str = "height";
/// Metadata key for the `Debug` rendering of the image's `ImageFormat` (e.g. `Png`).
/// Written by `ImageMediaLoader::load`; `save` picks the encoder from the media kind instead.
const KEY_FORMAT: &str = "format";
/// Metadata key for the WAV sample rate taken from the file's `WavSpec`, as a decimal string.
const KEY_SAMPLE_RATE: &str = "sample_rate";
/// Metadata key for the WAV channel count, as a decimal string.
const KEY_CHANNELS: &str = "channels";
/// Metadata key for the WAV bits per sample, as a decimal string.
const KEY_BITS_PER_SAMPLE: &str = "bits_per_sample";
/// Reserved metadata key; nothing in this module reads or writes it yet.
#[expect(dead_code, reason = "will be used in T13 for palette stego")]
const KEY_PALETTE: &str = "palette";
/// Reserved metadata key; nothing in this module reads or writes it yet.
#[expect(dead_code, reason = "will be used in T16 for adaptive embedding")]
const KEY_QUANT_TABLES: &str = "quant_tables";
/// [`MediaLoader`] for raster cover images (PNG, BMP, JPEG and GIF).
///
/// `load` decodes the file to 8-bit RGBA and stores the raw pixel bytes in
/// [`CoverMedia::data`], recording `width`, `height` and `format` in the
/// metadata. `save` reverses this, using the `width`/`height` metadata for the
/// dimensions and the media kind to choose the encoder.
#[derive(Debug, Default)]
pub struct ImageMediaLoader;
impl MediaLoader for ImageMediaLoader {
    /// Decodes the image at `path` into RGBA8 pixel data.
    ///
    /// The supported container is chosen from the file extension
    /// (case-insensitive).
    ///
    /// # Errors
    ///
    /// - [`MediaError::UnsupportedFormat`] if the path has no UTF-8 extension
    ///   (reported as `"none"`) or the extension is not `png`, `bmp`,
    ///   `jpg`/`jpeg` or `gif`.
    /// - [`MediaError::DecodeFailed`] if the file cannot be opened or decoded.
    fn load(&self, path: &Path) -> Result<CoverMedia, MediaError> {
        let extension = path
            .extension()
            .and_then(|s| s.to_str())
            .ok_or_else(|| MediaError::UnsupportedFormat {
                extension: "none".to_string(),
            })?
            .to_lowercase();
        // Resolve the encoder format and the cover kind together so the two
        // can never disagree and no `unreachable!()` fallback is needed.
        let (format, kind) = match extension.as_str() {
            "png" => (ImageFormat::Png, CoverMediaKind::PngImage),
            "bmp" => (ImageFormat::Bmp, CoverMediaKind::BmpImage),
            "jpg" | "jpeg" => (ImageFormat::Jpeg, CoverMediaKind::JpegImage),
            "gif" => (ImageFormat::Gif, CoverMediaKind::GifImage),
            ext => {
                return Err(MediaError::UnsupportedFormat {
                    extension: ext.to_string(),
                });
            }
        };
        let rgba = image::open(path)
            .map_err(|e| MediaError::DecodeFailed {
                reason: e.to_string(),
            })?
            .to_rgba8();
        let (width, height) = rgba.dimensions();
        let metadata = HashMap::from([
            (KEY_WIDTH.to_string(), width.to_string()),
            (KEY_HEIGHT.to_string(), height.to_string()),
            (KEY_FORMAT.to_string(), format!("{format:?}")),
        ]);
        Ok(CoverMedia {
            kind,
            data: Bytes::from(rgba.into_raw()),
            metadata,
        })
    }
    /// Encodes RGBA8 pixel data from `media` into an image file at `path`.
    ///
    /// The encoder is selected from `media.kind`. The dimensions come from the
    /// `width`/`height` metadata and must describe `media.data` exactly
    /// (`width * height * 4` bytes).
    ///
    /// # Errors
    ///
    /// [`MediaError::EncodeFailed`] if `media.kind` is not an image kind, the
    /// dimension metadata is missing or not a valid `u32`, the pixel buffer
    /// length does not match the dimensions, or encoding/writing fails.
    fn save(&self, media: &CoverMedia, path: &Path) -> Result<(), MediaError> {
        /// Reads a required `u32` dimension (`key`) from the media metadata.
        fn dimension(media: &CoverMedia, key: &str) -> Result<u32, MediaError> {
            media
                .metadata
                .get(key)
                .ok_or_else(|| MediaError::EncodeFailed {
                    reason: format!("missing {key} metadata"),
                })?
                .parse()
                .map_err(|e: std::num::ParseIntError| MediaError::EncodeFailed {
                    reason: e.to_string(),
                })
        }
        // Check the kind first: it is the cheapest check, and it avoids
        // copying the pixel buffer for a request that can never succeed.
        let format = match media.kind {
            CoverMediaKind::PngImage => ImageFormat::Png,
            CoverMediaKind::BmpImage => ImageFormat::Bmp,
            CoverMediaKind::JpegImage => ImageFormat::Jpeg,
            CoverMediaKind::GifImage => ImageFormat::Gif,
            _ => {
                return Err(MediaError::EncodeFailed {
                    reason: format!("unsupported media kind: {:?}", media.kind),
                });
            }
        };
        let width = dimension(media, KEY_WIDTH)?;
        let height = dimension(media, KEY_HEIGHT)?;
        // `RgbaImage::from_raw` only rejects buffers that are too small, so
        // also reject oversized ones; surplus bytes must not be silently left
        // out of the encoded file. The u64 product of two u32s cannot
        // overflow; the conversion and `* 4` are checked.
        let expected_len = usize::try_from(u64::from(width) * u64::from(height))
            .ok()
            .and_then(|pixels| pixels.checked_mul(4));
        if expected_len != Some(media.data.len()) {
            return Err(MediaError::EncodeFailed {
                reason: "invalid image dimensions or data length".to_string(),
            });
        }
        let pixels =
            image::RgbaImage::from_raw(width, height, media.data.to_vec()).ok_or_else(|| {
                MediaError::EncodeFailed {
                    reason: "invalid image dimensions or data length".to_string(),
                }
            })?;
        let img = DynamicImage::ImageRgba8(pixels);
        // JPEG cannot store an alpha channel, and some `image` releases refuse
        // to encode Rgba8 as JPEG. Drop alpha explicitly so behaviour does not
        // depend on the crate version.
        let img = if matches!(format, ImageFormat::Jpeg) {
            DynamicImage::ImageRgb8(img.to_rgb8())
        } else {
            img
        };
        img.save_with_format(path, format)
            .map_err(|e| MediaError::EncodeFailed {
                reason: e.to_string(),
            })
    }
}
/// [`MediaLoader`] for PCM WAV audio.
///
/// `load` stores every sample as a little-endian `i16` in
/// [`CoverMedia::data`] and records the sample rate, channel count and bits per
/// sample in the metadata. `save` writes the samples back using an
/// integer-format [`WavSpec`] rebuilt from that metadata.
#[derive(Debug, Default)]
pub struct AudioMediaLoader;
impl MediaLoader for AudioMediaLoader {
    /// Reads every sample of the WAV file at `path` as an `i16`.
    ///
    /// # Errors
    ///
    /// [`MediaError::DecodeFailed`] if the file cannot be opened, its header
    /// is invalid, or any sample cannot be read as an `i16`.
    fn load(&self, path: &Path) -> Result<CoverMedia, MediaError> {
        let reader = WavReader::open(path).map_err(|e| MediaError::DecodeFailed {
            reason: e.to_string(),
        })?;
        let spec = reader.spec();
        let samples = reader.into_samples::<i16>();
        // Each i16 sample becomes two bytes. The size hint is only a capacity
        // hint, so saturate instead of panicking on overflow.
        let mut data = Vec::with_capacity(samples.size_hint().0.saturating_mul(2));
        for sample in samples {
            let sample = sample.map_err(|e| MediaError::DecodeFailed {
                reason: e.to_string(),
            })?;
            data.extend_from_slice(&sample.to_le_bytes());
        }
        let metadata = HashMap::from([
            (KEY_SAMPLE_RATE.to_string(), spec.sample_rate.to_string()),
            (KEY_CHANNELS.to_string(), spec.channels.to_string()),
            (
                KEY_BITS_PER_SAMPLE.to_string(),
                spec.bits_per_sample.to_string(),
            ),
        ]);
        Ok(CoverMedia {
            kind: CoverMediaKind::WavAudio,
            data: Bytes::from(data),
            metadata,
        })
    }
    /// Writes `media.data`, interpreted as little-endian `i16` samples, to a
    /// WAV file at `path`.
    ///
    /// # Errors
    ///
    /// - [`MediaError::EncodeFailed`] if `sample_rate`, `channels` or
    ///   `bits_per_sample` metadata is missing or unparsable, if the data
    ///   length is odd (a truncated sample), or if the WAV writer rejects a
    ///   sample or cannot be finalized.
    /// - [`MediaError::IoError`] if the output file cannot be created.
    fn save(&self, media: &CoverMedia, path: &Path) -> Result<(), MediaError> {
        /// Reads and parses a required integer field (`key`) from the media metadata.
        fn required<T>(media: &CoverMedia, key: &str) -> Result<T, MediaError>
        where
            T: std::str::FromStr<Err = std::num::ParseIntError>,
        {
            media
                .metadata
                .get(key)
                .ok_or_else(|| MediaError::EncodeFailed {
                    reason: format!("missing {key} metadata"),
                })?
                .parse()
                .map_err(|e: std::num::ParseIntError| MediaError::EncodeFailed {
                    reason: e.to_string(),
                })
        }
        let sample_rate: u32 = required(media, KEY_SAMPLE_RATE)?;
        let channels: u16 = required(media, KEY_CHANNELS)?;
        let bits_per_sample: u16 = required(media, KEY_BITS_PER_SAMPLE)?;
        // Samples are stored as 2-byte pairs. A leftover byte means the buffer
        // is truncated or corrupt. Reject it up front (before creating the
        // file) instead of silently dropping it.
        let pcm = media.data.chunks_exact(2);
        if !pcm.remainder().is_empty() {
            return Err(MediaError::EncodeFailed {
                reason: format!(
                    "audio data length {} is not a multiple of 2 bytes",
                    media.data.len()
                ),
            });
        }
        let spec = WavSpec {
            channels,
            sample_rate,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };
        let file = File::create(path).map_err(|e| MediaError::IoError {
            reason: e.to_string(),
        })?;
        let mut writer =
            WavWriter::new(BufWriter::new(file), spec).map_err(|e| MediaError::EncodeFailed {
                reason: e.to_string(),
            })?;
        for pair in pcm {
            // `chunks_exact(2)` guarantees every chunk has exactly two bytes.
            let sample = i16::from_le_bytes([pair[0], pair[1]]);
            writer
                .write_sample(sample)
                .map_err(|e| MediaError::EncodeFailed {
                    reason: e.to_string(),
                })?;
        }
        writer.finalize().map_err(|e| MediaError::EncodeFailed {
            reason: e.to_string(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Builds a `CoverMedia` of `kind` holding `len` zero bytes and the given
    /// metadata key/value pairs.
    fn media_with(kind: CoverMediaKind, len: usize, metadata: &[(&str, &str)]) -> CoverMedia {
        CoverMedia {
            kind,
            data: Bytes::from(vec![0u8; len]),
            metadata: metadata
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        }
    }

    /// Asserts `result` is `EncodeFailed` with a reason mentioning `needle`,
    /// so a negative test cannot pass by tripping an unrelated validation.
    fn assert_encode_failed(result: Result<(), MediaError>, needle: &str) {
        match result {
            Err(MediaError::EncodeFailed { reason, .. }) => assert!(
                reason.contains(needle),
                "unexpected EncodeFailed reason: {reason}"
            ),
            other => panic!("expected EncodeFailed mentioning {needle:?}, got {other:?}"),
        }
    }

    #[test]
    fn test_image_loader_png_roundtrip() -> TestResult {
        let loader = ImageMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.png");
        let img = DynamicImage::ImageRgba8(image::RgbaImage::from_pixel(
            10,
            10,
            image::Rgba([255, 255, 255, 255]),
        ));
        img.save(&path)?;
        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::PngImage);
        assert_eq!(media.metadata.get(KEY_WIDTH), Some(&"10".to_string()));
        assert_eq!(media.metadata.get(KEY_HEIGHT), Some(&"10".to_string()));
        let out_path = dir.path().join("out.png");
        loader.save(&media, &out_path)?;
        let reloaded = loader.load(&out_path)?;
        assert_eq!(reloaded.data, media.data);
        Ok(())
    }

    #[test]
    fn test_audio_loader_wav_roundtrip() -> TestResult {
        let loader = AudioMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.wav");
        let spec = WavSpec {
            channels: 1,
            sample_rate: 44100,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = WavWriter::create(&path, spec)?;
        for i in 0..1000_i16 {
            writer.write_sample(i)?;
        }
        writer.finalize()?;
        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::WavAudio);
        assert_eq!(
            media.metadata.get(KEY_SAMPLE_RATE),
            Some(&"44100".to_string())
        );
        assert_eq!(media.metadata.get(KEY_CHANNELS), Some(&"1".to_string()));
        let out_path = dir.path().join("out.wav");
        loader.save(&media, &out_path)?;
        let reloaded = loader.load(&out_path)?;
        assert_eq!(reloaded.data, media.data);
        Ok(())
    }

    #[test]
    fn test_image_loader_unsupported_format() {
        let loader = ImageMediaLoader;
        let result = loader.load(Path::new("test.xyz"));
        assert!(matches!(result, Err(MediaError::UnsupportedFormat { .. })));
    }

    #[test]
    fn test_image_loader_no_extension() {
        let loader = ImageMediaLoader;
        let result = loader.load(Path::new("test"));
        assert!(matches!(result, Err(MediaError::UnsupportedFormat { .. })));
    }

    #[test]
    fn test_image_loader_bmp_roundtrip() -> TestResult {
        let loader = ImageMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.bmp");
        let img = DynamicImage::ImageRgba8(image::RgbaImage::from_pixel(
            5,
            5,
            image::Rgba([128, 64, 32, 255]),
        ));
        img.save(&path)?;
        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::BmpImage);
        assert_eq!(media.metadata.get(KEY_WIDTH), Some(&"5".to_string()));
        assert_eq!(media.metadata.get(KEY_HEIGHT), Some(&"5".to_string()));
        let out_path = dir.path().join("out.bmp");
        loader.save(&media, &out_path)?;
        let reloaded = loader.load(&out_path)?;
        assert_eq!(reloaded.data, media.data);
        Ok(())
    }

    #[test]
    fn test_image_loader_jpeg_can_load() -> TestResult {
        let loader = ImageMediaLoader;
        let dir = tempdir()?;
        let path = dir.path().join("test.jpg");
        // JPEG has no alpha channel, so build the fixture as RGB; some `image`
        // versions refuse to encode an RGBA buffer as JPEG.
        let img = DynamicImage::ImageRgb8(image::RgbImage::from_pixel(
            8,
            8,
            image::Rgb([200, 100, 50]),
        ));
        img.save(&path)?;
        let media = loader.load(&path)?;
        assert_eq!(media.kind, CoverMediaKind::JpegImage);
        Ok(())
    }

    #[test]
    fn test_image_save_unsupported_kind() -> TestResult {
        let loader = ImageMediaLoader;
        // The buffer matches 10x10 RGBA exactly (400 bytes), so the only thing
        // wrong is the kind; a short buffer would fail the length check instead.
        let media = media_with(
            CoverMediaKind::WavAudio,
            10 * 10 * 4,
            &[(KEY_WIDTH, "10"), (KEY_HEIGHT, "10")],
        );
        let dir = tempdir()?;
        let result = loader.save(&media, &dir.path().join("test.wav"));
        assert_encode_failed(result, "unsupported media kind");
        Ok(())
    }

    #[test]
    fn test_image_save_missing_width() -> TestResult {
        let loader = ImageMediaLoader;
        let media = media_with(CoverMediaKind::PngImage, 100, &[(KEY_HEIGHT, "10")]);
        let dir = tempdir()?;
        let result = loader.save(&media, &dir.path().join("test.png"));
        assert_encode_failed(result, "missing width metadata");
        Ok(())
    }

    #[test]
    fn test_image_save_missing_height() -> TestResult {
        let loader = ImageMediaLoader;
        let media = media_with(CoverMediaKind::PngImage, 100, &[(KEY_WIDTH, "10")]);
        let dir = tempdir()?;
        let result = loader.save(&media, &dir.path().join("test.png"));
        assert_encode_failed(result, "missing height metadata");
        Ok(())
    }

    #[test]
    fn test_audio_save_missing_sample_rate() -> TestResult {
        let loader = AudioMediaLoader;
        let media = media_with(
            CoverMediaKind::WavAudio,
            100,
            &[(KEY_CHANNELS, "1"), (KEY_BITS_PER_SAMPLE, "16")],
        );
        let dir = tempdir()?;
        let result = loader.save(&media, &dir.path().join("test.wav"));
        assert_encode_failed(result, "missing sample_rate metadata");
        Ok(())
    }

    #[test]
    fn test_audio_save_missing_channels() -> TestResult {
        let loader = AudioMediaLoader;
        let media = media_with(
            CoverMediaKind::WavAudio,
            100,
            &[(KEY_SAMPLE_RATE, "44100"), (KEY_BITS_PER_SAMPLE, "16")],
        );
        let dir = tempdir()?;
        let result = loader.save(&media, &dir.path().join("test.wav"));
        assert_encode_failed(result, "missing channels metadata");
        Ok(())
    }

    #[test]
    fn test_audio_save_missing_bits_per_sample() -> TestResult {
        let loader = AudioMediaLoader;
        let media = media_with(
            CoverMediaKind::WavAudio,
            100,
            &[(KEY_SAMPLE_RATE, "44100"), (KEY_CHANNELS, "1")],
        );
        let dir = tempdir()?;
        let result = loader.save(&media, &dir.path().join("test.wav"));
        assert_encode_failed(result, "missing bits_per_sample metadata");
        Ok(())
    }

    #[test]
    fn test_image_load_nonexistent_file() {
        let loader = ImageMediaLoader;
        let result = loader.load(Path::new("/nonexistent/path/image.png"));
        assert!(matches!(result, Err(MediaError::DecodeFailed { .. })));
    }

    #[test]
    fn test_audio_load_nonexistent_file() {
        let loader = AudioMediaLoader;
        let result = loader.load(Path::new("/nonexistent/path/audio.wav"));
        assert!(matches!(result, Err(MediaError::DecodeFailed { .. })));
    }
}