use super::error::{PdfError, Result};
use bytes::Bytes;
use lopdf::Document;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfImage {
pub page_number: usize,
pub image_index: usize,
pub width: i64,
pub height: i64,
pub color_space: Option<String>,
pub bits_per_component: Option<i64>,
pub filters: Vec<String>,
pub data: Bytes,
pub decoded_format: String,
}
#[derive(Debug)]
pub struct PdfImageExtractor {
document: Document,
}
#[cfg(feature = "pdf")]
#[allow(clippy::too_many_arguments)]
fn decode_image_data(
raw: &[u8],
filters: &[String],
color_space: Option<&str>,
width: i64,
height: i64,
bits_per_component: Option<i64>,
palette: Option<&[u8]>,
palette_base_channels: u32,
) -> (Bytes, String) {
let primary = filters.first().map(String::as_str).unwrap_or("");
match primary {
"DCTDecode" => {
(Bytes::from(raw.to_vec()), "jpeg".to_string())
}
"FlateDecode" => {
match decode_flate_to_png(
raw,
color_space,
width,
height,
bits_per_component,
palette,
palette_base_channels,
) {
Ok(png_bytes) => (Bytes::from(png_bytes), "png".to_string()),
Err(_) => {
(Bytes::from(raw.to_vec()), "raw".to_string())
}
}
}
"JPXDecode" => {
(Bytes::from(raw.to_vec()), "jpeg2000".to_string())
}
"CCITTFaxDecode" => {
(Bytes::from(raw.to_vec()), "ccitt".to_string())
}
"JBIG2Decode" => {
(Bytes::from(raw.to_vec()), "jbig2".to_string())
}
_ => {
let format = detect_image_format(raw);
(Bytes::from(raw.to_vec()), format)
}
}
}
fn detect_image_format(data: &[u8]) -> String {
if data.starts_with(b"\xff\xd8\xff") {
"jpeg".to_string()
} else if data.starts_with(b"\x89PNG\r\n\x1a\n") {
"png".to_string()
} else if data.starts_with(b"GIF8") {
"gif".to_string()
} else if data.starts_with(b"II") || data.starts_with(b"MM") {
"tiff".to_string()
} else if data.starts_with(b"BM") {
"bmp".to_string()
} else {
"raw".to_string()
}
}
#[cfg(feature = "pdf")]
fn decode_flate_to_png(
raw: &[u8],
color_space: Option<&str>,
width: i64,
height: i64,
bits_per_component: Option<i64>,
palette: Option<&[u8]>,
palette_base_channels: u32,
) -> std::result::Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
use flate2::read::ZlibDecoder;
use image::ImageEncoder;
use std::io::Read;
let mut decoder = ZlibDecoder::new(raw);
let mut decompressed = Vec::new();
decoder.read_to_end(&mut decompressed)?;
let is_indexed = color_space.map(|cs| cs.contains("Indexed")).unwrap_or(false);
let index_channels: u32 = if is_indexed {
1
} else {
match color_space {
Some(cs) if cs.contains("RGB") => 3,
Some(cs) if cs.contains("CMYK") => 4,
Some(cs) if cs.contains("Gray") || cs.contains("grey") => 1,
_ => 3, }
};
let bpc = bits_per_component.unwrap_or(8) as u32;
let w = width as u32;
let h = height as u32;
let bytes_per_channel = bpc.div_ceil(8);
let raw_row_stride = w * index_channels * bytes_per_channel; let pred_row_stride = raw_row_stride + 1;
let pixel_data: Vec<u8> = if h > 0 && decompressed.len() as u32 == pred_row_stride * h {
let mut pixels = Vec::with_capacity((raw_row_stride * h) as usize);
for row in 0..h {
let row_start = (row * pred_row_stride) as usize;
pixels.extend_from_slice(&decompressed[row_start + 1..row_start + pred_row_stride as usize]);
}
pixels
} else {
decompressed
};
if is_indexed {
if let Some(palette_data) = palette {
let base_ch = if palette_base_channels > 0 {
palette_base_channels
} else {
3 };
let mut expanded = Vec::with_capacity(pixel_data.len() * base_ch as usize);
for &idx in &pixel_data {
let offset = idx as usize * base_ch as usize;
if offset + base_ch as usize <= palette_data.len() {
expanded.extend_from_slice(&palette_data[offset..offset + base_ch as usize]);
} else {
expanded.extend(std::iter::repeat_n(0u8, base_ch as usize));
}
}
let color_type = match base_ch {
1 => image::ColorType::L8,
3 => image::ColorType::Rgb8,
4 => image::ColorType::Rgba8,
_ => image::ColorType::Rgb8,
};
let expected_len = (w * h * base_ch) as usize;
if expanded.len() != expected_len {
return Err(format!(
"PDF indexed image buffer length mismatch after palette expansion: \
expected {expected_len} bytes ({w}x{h} px, {base_ch} ch) but got {} bytes",
expanded.len()
)
.into());
}
let mut png_bytes: Vec<u8> = Vec::new();
let mut cursor = std::io::Cursor::new(&mut png_bytes);
image::codecs::png::PngEncoder::new(&mut cursor)
.write_image(&expanded, w, h, color_type.into())
.map_err(|e| format!("PNG encoding failed for indexed image: {e}"))?;
return Ok(png_bytes);
}
let expected_len = (w * h) as usize;
if pixel_data.len() != expected_len {
return Err(format!(
"PDF indexed image buffer length mismatch (grayscale fallback): \
expected {expected_len} bytes ({w}x{h} px, 1 ch) but got {} bytes",
pixel_data.len()
)
.into());
}
let mut png_bytes: Vec<u8> = Vec::new();
let mut cursor = std::io::Cursor::new(&mut png_bytes);
image::codecs::png::PngEncoder::new(&mut cursor)
.write_image(&pixel_data, w, h, image::ColorType::L8.into())
.map_err(|e| format!("PNG encoding failed for indexed image (grayscale fallback): {e}"))?;
return Ok(png_bytes);
}
let channels = index_channels;
let color_type = match (channels, bpc) {
(1, 8) => image::ColorType::L8,
(1, 16) => image::ColorType::L16,
(3, 8) => image::ColorType::Rgb8,
(3, 16) => image::ColorType::Rgb16,
(4, 8) => image::ColorType::Rgba8,
_ => image::ColorType::Rgb8,
};
let expected_len = (w * h * channels * bytes_per_channel) as usize;
if pixel_data.len() != expected_len {
return Err(format!(
"PDF image buffer length mismatch: expected {expected_len} bytes \
({w}x{h} px, {channels} ch) but got {} bytes — skipping malformed image",
pixel_data.len()
)
.into());
}
let mut png_bytes: Vec<u8> = Vec::new();
let mut cursor = std::io::Cursor::new(&mut png_bytes);
image::codecs::png::PngEncoder::new(&mut cursor)
.write_image(&pixel_data, w, h, color_type.into())
.map_err(|e| format!("PNG encoding failed: {e}"))?;
Ok(png_bytes)
}
#[cfg(feature = "pdf")]
fn extract_indexed_palette(dict: &lopdf::Dictionary, document: &Document) -> Option<(Vec<u8>, u32)> {
use lopdf::Object;
let cs_obj = dict.get(b"ColorSpace").ok()?;
let array = match cs_obj {
Object::Array(arr) => arr,
_ => return None,
};
if array.len() < 4 {
return None;
}
let name = array[0].as_name().ok()?;
if name != b"Indexed" {
return None;
}
let base_channels = match &array[1] {
Object::Name(name) => {
let name_str = String::from_utf8_lossy(name);
if name_str.contains("RGB") {
3u32
} else if name_str.contains("CMYK") {
4
} else if name_str.contains("Gray") || name_str.contains("grey") {
1
} else {
3 }
}
Object::Array(base_arr) => {
if let Some(first) = base_arr.first() {
let name_str = String::from_utf8_lossy(first.as_name().unwrap_or(b""));
if name_str.contains("ICCBased") {
if let Some(stream_ref) = base_arr.get(1)
&& let Ok(obj_id) = stream_ref.as_reference()
&& let Ok(obj) = document.get_object(obj_id)
&& let Ok(stream) = obj.as_stream()
&& let Ok(n) = stream.dict.get(b"N")
&& let Ok(n_val) = n.as_i64()
{
return extract_palette_data(&array[3], document).map(|data| (data, n_val as u32));
}
3 } else {
3
}
} else {
3
}
}
_ => 3,
};
extract_palette_data(&array[3], document).map(|data| (data, base_channels))
}
#[cfg(feature = "pdf")]
fn extract_palette_data(lookup: &lopdf::Object, document: &Document) -> Option<Vec<u8>> {
use lopdf::Object;
match lookup {
Object::String(bytes, _) => Some(bytes.clone()),
Object::Reference(obj_id) => {
let obj = document.get_object(*obj_id).ok()?;
match obj {
Object::Stream(stream) => {
Some(stream.content.clone())
}
Object::String(bytes, _) => Some(bytes.clone()),
_ => None,
}
}
_ => None,
}
}
impl PdfImageExtractor {
pub fn new(pdf_bytes: &[u8]) -> Result<Self> {
Self::new_with_password(pdf_bytes, None)
}
pub fn new_with_password(pdf_bytes: &[u8], password: Option<&str>) -> Result<Self> {
let mut doc =
Document::load_mem(pdf_bytes).map_err(|e| PdfError::InvalidPdf(format!("Failed to load PDF: {}", e)))?;
if doc.is_encrypted() {
if let Some(pwd) = password {
doc.decrypt(pwd).map_err(|_| PdfError::InvalidPassword)?;
} else {
return Err(PdfError::PasswordRequired);
}
}
Ok(Self { document: doc })
}
pub fn extract_images(&self) -> Result<Vec<PdfImage>> {
let mut all_images = Vec::new();
let pages = self.document.get_pages();
for (page_num, page_id) in pages.iter() {
let images = self
.document
.get_page_images(*page_id)
.map_err(|e| PdfError::MetadataExtractionFailed(format!("Failed to get page images: {}", e)))?;
for (img_index, img) in images.iter().enumerate() {
let filters = img.filters.clone().unwrap_or_default();
#[cfg(feature = "pdf")]
let (palette, palette_base_channels) = extract_indexed_palette(img.origin_dict, &self.document)
.map(|(p, ch)| (Some(p), ch))
.unwrap_or((None, 0));
#[cfg(feature = "pdf")]
let (data, decoded_format) = decode_image_data(
img.content,
&filters,
img.color_space.as_deref(),
img.width,
img.height,
img.bits_per_component,
palette.as_deref(),
palette_base_channels,
);
#[cfg(not(feature = "pdf"))]
let (data, decoded_format) = (Bytes::from(img.content.to_vec()), "raw".to_string());
all_images.push(PdfImage {
page_number: *page_num as usize,
image_index: img_index + 1,
width: img.width,
height: img.height,
color_space: img.color_space.clone(),
bits_per_component: img.bits_per_component,
filters,
data,
decoded_format,
});
}
}
Ok(all_images)
}
pub fn extract_images_from_page(&self, page_number: u32) -> Result<Vec<PdfImage>> {
let pages = self.document.get_pages();
let page_id = pages
.get(&page_number)
.ok_or(PdfError::PageNotFound(page_number as usize))?;
let images = self
.document
.get_page_images(*page_id)
.map_err(|e| PdfError::MetadataExtractionFailed(format!("Failed to get page images: {}", e)))?;
let mut page_images = Vec::new();
for (img_index, img) in images.iter().enumerate() {
let filters = img.filters.clone().unwrap_or_default();
#[cfg(feature = "pdf")]
let (palette, palette_base_channels) = extract_indexed_palette(img.origin_dict, &self.document)
.map(|(p, ch)| (Some(p), ch))
.unwrap_or((None, 0));
#[cfg(feature = "pdf")]
let (data, decoded_format) = decode_image_data(
img.content,
&filters,
img.color_space.as_deref(),
img.width,
img.height,
img.bits_per_component,
palette.as_deref(),
palette_base_channels,
);
#[cfg(not(feature = "pdf"))]
let (data, decoded_format) = (Bytes::from(img.content.to_vec()), "raw".to_string());
page_images.push(PdfImage {
page_number: page_number as usize,
image_index: img_index + 1,
width: img.width,
height: img.height,
color_space: img.color_space.clone(),
bits_per_component: img.bits_per_component,
filters,
data,
decoded_format,
});
}
Ok(page_images)
}
pub fn get_image_count(&self) -> Result<usize> {
let images = self.extract_images()?;
Ok(images.len())
}
}
pub fn extract_images_from_pdf(pdf_bytes: &[u8]) -> Result<Vec<PdfImage>> {
let extractor = PdfImageExtractor::new(pdf_bytes)?;
extractor.extract_images()
}
pub fn extract_images_from_pdf_with_password(pdf_bytes: &[u8], password: &str) -> Result<Vec<PdfImage>> {
let extractor = PdfImageExtractor::new_with_password(pdf_bytes, Some(password))?;
extractor.extract_images()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extractor_creation() {
let result = PdfImageExtractor::new(b"not a pdf");
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), PdfError::InvalidPdf(_)));
}
#[test]
fn test_extract_images_invalid_pdf() {
let result = extract_images_from_pdf(b"not a pdf");
assert!(result.is_err());
}
#[test]
fn test_extract_images_empty_pdf() {
let result = extract_images_from_pdf(b"");
assert!(result.is_err());
}
#[test]
fn test_detect_image_format_jpeg() {
let jpeg_magic = b"\xff\xd8\xff\xe0some_jpeg_data";
assert_eq!(detect_image_format(jpeg_magic), "jpeg");
}
#[test]
fn test_detect_image_format_png() {
let png_magic = b"\x89PNG\r\n\x1a\nsome_png_data";
assert_eq!(detect_image_format(png_magic), "png");
}
#[test]
fn test_detect_image_format_unknown() {
assert_eq!(detect_image_format(b"\x00\x01\x02\x03"), "raw");
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_dct_passthrough() {
let jpeg_bytes = b"\xff\xd8\xff\xe0fake_jpeg";
let filters = vec!["DCTDecode".to_string()];
let (data, format) = decode_image_data(jpeg_bytes, &filters, Some("DeviceRGB"), 100, 100, Some(8), None, 0);
assert_eq!(format, "jpeg");
assert_eq!(data.as_ref(), jpeg_bytes);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_jpx_passthrough() {
let jpx_bytes = b"\x00\x00\x00\x0cjP fake_jpx";
let filters = vec!["JPXDecode".to_string()];
let (data, format) = decode_image_data(jpx_bytes, &filters, Some("DeviceRGB"), 10, 10, Some(8), None, 0);
assert_eq!(format, "jpeg2000");
assert_eq!(data.as_ref(), jpx_bytes);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_unknown_filter_passthrough() {
let raw_bytes = b"\x00\x01\x02\x03";
let filters = vec!["RunLengthDecode".to_string()];
let (data, format) = decode_image_data(raw_bytes, &filters, None, 2, 2, Some(8), None, 0);
assert_eq!(format, "raw");
assert_eq!(data.as_ref(), raw_bytes);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_no_filter_uses_detection() {
let jpeg_bytes = b"\xff\xd8\xff\xe0fake";
let filters: Vec<String> = vec![];
let (data, format) = decode_image_data(jpeg_bytes, &filters, None, 10, 10, None, None, 0);
assert_eq!(format, "jpeg");
assert_eq!(data.as_ref(), jpeg_bytes);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_flate_valid_rgb_image() {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use std::io::Write;
let raw_pixels: Vec<u8> = vec![
255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 0, ];
let row_stride = 2 * 3; let mut rows_with_predictor: Vec<u8> = Vec::new();
for row in 0..2usize {
rows_with_predictor.push(0); rows_with_predictor.extend_from_slice(&raw_pixels[row * row_stride..(row + 1) * row_stride]);
}
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&rows_with_predictor).unwrap();
let compressed = encoder.finish().unwrap();
let filters = vec!["FlateDecode".to_string()];
let (data, format) = decode_image_data(&compressed, &filters, Some("DeviceRGB"), 2, 2, Some(8), None, 0);
assert_eq!(format, "png", "FlateDecode images should be re-encoded as PNG");
assert!(
data.starts_with(b"\x89PNG\r\n\x1a\n"),
"Decoded data should be a valid PNG (got {} bytes, first bytes: {:?})",
data.len(),
&data[..data.len().min(8)]
);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_flate_indexed_with_palette() {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use std::io::Write;
let indices: Vec<u8> = vec![0, 1, 2, 0];
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&indices).unwrap();
let compressed = encoder.finish().unwrap();
let palette: Vec<u8> = vec![
255, 0, 0, 0, 255, 0, 0, 0, 255, ];
let filters = vec!["FlateDecode".to_string()];
let (data, format) =
decode_image_data(&compressed, &filters, Some("Indexed"), 2, 2, Some(8), Some(&palette), 3);
assert_eq!(format, "png", "Indexed FlateDecode should produce PNG");
assert!(
data.starts_with(b"\x89PNG\r\n\x1a\n"),
"Decoded data should be a valid PNG"
);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_flate_indexed_without_palette_grayscale_fallback() {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use std::io::Write;
let indices: Vec<u8> = vec![10, 50, 100, 200];
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&indices).unwrap();
let compressed = encoder.finish().unwrap();
let filters = vec!["FlateDecode".to_string()];
let (data, format) = decode_image_data(&compressed, &filters, Some("Indexed"), 2, 2, Some(8), None, 0);
assert_eq!(
format, "png",
"Indexed without palette should still produce PNG (grayscale)"
);
assert!(
data.starts_with(b"\x89PNG\r\n\x1a\n"),
"Decoded data should be a valid PNG"
);
}
}