use super::error::{PdfError, Result};
use bytes::Bytes;
use lopdf::Document;
use serde::{Deserialize, Serialize};
/// A single image extracted from a PDF, with its (possibly re-encoded) bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfImage {
/// 1-based page number the image was found on.
pub page_number: usize,
/// 1-based position of the image within its page.
pub image_index: usize,
/// Image width in pixels, as declared by the PDF image dictionary.
pub width: i64,
/// Image height in pixels, as declared by the PDF image dictionary.
pub height: i64,
/// PDF color-space name (e.g. "DeviceRGB"), when present.
pub color_space: Option<String>,
/// Bits per color component, when present in the image dictionary.
pub bits_per_component: Option<i64>,
/// PDF stream filters applied to the image data (e.g. "FlateDecode").
pub filters: Vec<String>,
/// Decoded (or passed-through) image bytes; see `decoded_format`.
pub data: Bytes,
/// Format tag for `data`: "jpeg", "png", "jpeg2000", "ccitt", "jbig2" or "raw".
pub decoded_format: String,
}
/// Extracts embedded images from a PDF document loaded via `lopdf`.
#[derive(Debug)]
pub struct PdfImageExtractor {
// The parsed (and, if needed, decrypted) PDF document.
document: Document,
}
/// Decode a raw image stream according to its primary (first) PDF filter.
///
/// Returns the (possibly re-encoded) bytes together with a short format tag:
/// "jpeg", "png", "jpeg2000", "ccitt", "jbig2" or "raw". Filters whose output
/// is already a self-contained container (JPEG, JPEG2000) or that we cannot
/// decode here (CCITT, JBIG2) are passed through untouched; FlateDecode data
/// is inflated and re-encoded as PNG when possible, falling back to the raw
/// stream on failure. With no recognized filter the format is sniffed from
/// the magic bytes.
#[cfg(feature = "pdf")]
#[allow(clippy::too_many_arguments)]
fn decode_image_data(
    raw: &[u8],
    filters: &[String],
    color_space: Option<&str>,
    width: i64,
    height: i64,
    bits_per_component: Option<i64>,
    palette: Option<&[u8]>,
    palette_base_channels: u32,
) -> (Bytes, String) {
    // Hand the stream through unchanged, labelled with `tag`.
    let passthrough = |tag: &str| (Bytes::from(raw.to_vec()), tag.to_string());
    match filters.first().map(String::as_str).unwrap_or("") {
        "DCTDecode" => passthrough("jpeg"),
        "JPXDecode" => passthrough("jpeg2000"),
        "CCITTFaxDecode" => passthrough("ccitt"),
        "JBIG2Decode" => passthrough("jbig2"),
        "FlateDecode" => decode_flate_to_png(
            raw,
            color_space,
            width,
            height,
            bits_per_component,
            palette,
            palette_base_channels,
        )
        .map(|png_bytes| (Bytes::from(png_bytes), "png".to_string()))
        .unwrap_or_else(|_| passthrough("raw")),
        // Unknown or absent filter: fall back to magic-byte detection.
        _ => passthrough(&detect_image_format(raw)),
    }
}
/// Sniff a best-effort image format from the magic bytes at the start of `data`.
///
/// Recognizes JPEG, PNG, GIF, TIFF and BMP; anything unrecognized is "raw".
fn detect_image_format(data: &[u8]) -> String {
    if data.starts_with(b"\xff\xd8\xff") {
        "jpeg".to_string()
    } else if data.starts_with(b"\x89PNG\r\n\x1a\n") {
        "png".to_string()
    } else if data.starts_with(b"GIF8") {
        "gif".to_string()
    } else if data.starts_with(b"II*\x00") || data.starts_with(b"MM\x00*") {
        // Full 4-byte TIFF magic (little-/big-endian). Matching a bare "II"/"MM"
        // prefix would misclassify arbitrary data that merely starts with those
        // two letters as TIFF.
        "tiff".to_string()
    } else if data.starts_with(b"BM") {
        "bmp".to_string()
    } else {
        "raw".to_string()
    }
}
/// Inflate a FlateDecode image stream and re-encode the pixel data as PNG.
///
/// Channel count is inferred heuristically from `color_space` substrings
/// ("Indexed" -> 1 palette index per pixel, "RGB" -> 3, "CMYK" -> 4,
/// "Gray"/"grey" -> 1, otherwise 3). For `Indexed` images, `palette` and
/// `palette_base_channels` (falling back to 3 when 0) expand indices into
/// base-color samples; without a palette the raw indices are encoded as
/// 8-bit grayscale.
///
/// # Errors
/// Returns an error when the decompressed buffer length does not match the
/// expected `width x height x channels` size, or when PNG encoding fails.
///
/// NOTE(review): `bytes_per_channel = bpc.div_ceil(8)` assumes whole-byte
/// samples; bit-packed 1/2/4-bpc rows are not unpacked here and will
/// typically fail the length check — confirm that is intended.
#[cfg(feature = "pdf")]
fn decode_flate_to_png(
raw: &[u8],
color_space: Option<&str>,
width: i64,
height: i64,
bits_per_component: Option<i64>,
palette: Option<&[u8]>,
palette_base_channels: u32,
) -> std::result::Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
use flate2::read::ZlibDecoder;
use image::ImageEncoder;
use std::io::Read;
// Fully inflate the zlib-compressed stream into memory.
let mut decoder = ZlibDecoder::new(raw);
let mut decompressed = Vec::new();
decoder.read_to_end(&mut decompressed)?;
let is_indexed = color_space.map(|cs| cs.contains("Indexed")).unwrap_or(false);
// Samples per pixel in the decompressed data (1 for palette indices).
let index_channels: u32 = if is_indexed {
1
} else {
match color_space {
Some(cs) if cs.contains("RGB") => 3,
Some(cs) if cs.contains("CMYK") => 4,
Some(cs) if cs.contains("Gray") || cs.contains("grey") => 1,
_ => 3, }
};
let bpc = bits_per_component.unwrap_or(8) as u32;
let w = width as u32;
let h = height as u32;
let bytes_per_channel = bpc.div_ceil(8);
// A PNG-predictor-encoded stream carries one extra filter-tag byte per row.
let raw_row_stride = w * index_channels * bytes_per_channel; let pred_row_stride = raw_row_stride + 1;
// If the buffer length matches exactly "stride + 1 tag byte" per row, strip
// the per-row tag byte. NOTE(review): only the tag byte is removed; non-zero
// row filters (Sub/Up/Average/Paeth) are not un-applied, so such rows would
// decode to wrong pixel values — confirm inputs use predictor "None".
let pixel_data: Vec<u8> = if h > 0 && decompressed.len() as u32 == pred_row_stride * h {
let mut pixels = Vec::with_capacity((raw_row_stride * h) as usize);
for row in 0..h {
let row_start = (row * pred_row_stride) as usize;
pixels.extend_from_slice(&decompressed[row_start + 1..row_start + pred_row_stride as usize]);
}
pixels
} else {
decompressed
};
if is_indexed {
if let Some(palette_data) = palette {
// Channel count of the palette's base color space; default to RGB.
let base_ch = if palette_base_channels > 0 {
palette_base_channels
} else {
3 };
// Expand each palette index into its base-color sample; out-of-range
// indices become zero-filled samples instead of erroring.
let mut expanded = Vec::with_capacity(pixel_data.len() * base_ch as usize);
for &idx in &pixel_data {
let offset = idx as usize * base_ch as usize;
if offset + base_ch as usize <= palette_data.len() {
expanded.extend_from_slice(&palette_data[offset..offset + base_ch as usize]);
} else {
expanded.extend(std::iter::repeat_n(0u8, base_ch as usize));
}
}
let color_type = match base_ch {
1 => image::ColorType::L8,
3 => image::ColorType::Rgb8,
4 => image::ColorType::Rgba8,
_ => image::ColorType::Rgb8,
};
let expected_len = (w * h * base_ch) as usize;
if expanded.len() != expected_len {
return Err(format!(
"PDF indexed image buffer length mismatch after palette expansion: \
expected {expected_len} bytes ({w}x{h} px, {base_ch} ch) but got {} bytes",
expanded.len()
)
.into());
}
let mut png_bytes: Vec<u8> = Vec::new();
let mut cursor = std::io::Cursor::new(&mut png_bytes);
image::codecs::png::PngEncoder::new(&mut cursor)
.write_image(&expanded, w, h, color_type.into())
.map_err(|e| format!("PNG encoding failed for indexed image: {e}"))?;
return Ok(png_bytes);
}
// Indexed but no palette available: encode raw indices as 8-bit grayscale.
let expected_len = (w * h) as usize;
if pixel_data.len() != expected_len {
return Err(format!(
"PDF indexed image buffer length mismatch (grayscale fallback): \
expected {expected_len} bytes ({w}x{h} px, 1 ch) but got {} bytes",
pixel_data.len()
)
.into());
}
let mut png_bytes: Vec<u8> = Vec::new();
let mut cursor = std::io::Cursor::new(&mut png_bytes);
image::codecs::png::PngEncoder::new(&mut cursor)
.write_image(&pixel_data, w, h, image::ColorType::L8.into())
.map_err(|e| format!("PNG encoding failed for indexed image (grayscale fallback): {e}"))?;
return Ok(png_bytes);
}
// Non-indexed path: map (channels, bits-per-component) to a PNG color type.
let channels = index_channels;
let color_type = match (channels, bpc) {
(1, 8) => image::ColorType::L8,
(1, 16) => image::ColorType::L16,
(3, 8) => image::ColorType::Rgb8,
(3, 16) => image::ColorType::Rgb16,
(4, 8) => image::ColorType::Rgba8,
_ => image::ColorType::Rgb8,
};
let expected_len = (w * h * channels * bytes_per_channel) as usize;
if pixel_data.len() != expected_len {
return Err(format!(
"PDF image buffer length mismatch: expected {expected_len} bytes \
({w}x{h} px, {channels} ch) but got {} bytes — skipping malformed image",
pixel_data.len()
)
.into());
}
let mut png_bytes: Vec<u8> = Vec::new();
let mut cursor = std::io::Cursor::new(&mut png_bytes);
image::codecs::png::PngEncoder::new(&mut cursor)
.write_image(&pixel_data, w, h, color_type.into())
.map_err(|e| format!("PNG encoding failed: {e}"))?;
Ok(png_bytes)
}
/// Extract the palette of an `Indexed` image XObject, if present.
///
/// Looks for a `/ColorSpace [ /Indexed base hival lookup ]` array in the
/// image dictionary and returns `(palette_bytes, base_channel_count)`.
/// The base channel count is inferred from the base color-space name
/// ("RGB" -> 3, "CMYK" -> 4, "Gray"/"grey" -> 1, default 3); for an
/// `ICCBased` base, the profile stream's `/N` entry supplies the count.
/// Returns `None` when the color space is not Indexed or the lookup table
/// cannot be resolved.
#[cfg(feature = "pdf")]
fn extract_indexed_palette(dict: &lopdf::Dictionary, document: &Document) -> Option<(Vec<u8>, u32)> {
use lopdf::Object;
let cs_obj = dict.get(b"ColorSpace").ok()?;
let array = match cs_obj {
Object::Array(arr) => arr,
_ => return None,
};
// An Indexed color space is [ /Indexed base hival lookup ] — 4 entries.
if array.len() < 4 {
return None;
}
let name = array[0].as_name().ok()?;
if name != b"Indexed" {
return None;
}
// Determine the channel count of the base color space (entry 1).
let base_channels = match &array[1] {
Object::Name(name) => {
let name_str = String::from_utf8_lossy(name);
if name_str.contains("RGB") {
3u32
} else if name_str.contains("CMYK") {
4
} else if name_str.contains("Gray") || name_str.contains("grey") {
1
} else {
3 }
}
Object::Array(base_arr) => {
if let Some(first) = base_arr.first() {
let name_str = String::from_utf8_lossy(first.as_name().unwrap_or(b""));
if name_str.contains("ICCBased") {
// ICCBased: the component count lives in the referenced
// profile stream's /N entry; return early on success.
if let Some(stream_ref) = base_arr.get(1)
&& let Ok(obj_id) = stream_ref.as_reference()
&& let Ok(obj) = document.get_object(obj_id)
&& let Ok(stream) = obj.as_stream()
&& let Ok(n) = stream.dict.get(b"N")
&& let Ok(n_val) = n.as_i64()
{
return extract_palette_data(&array[3], document).map(|data| (data, n_val as u32));
}
3 } else {
3
}
} else {
3
}
}
_ => 3,
};
// Entry 3 is the lookup table holding the palette bytes.
extract_palette_data(&array[3], document).map(|data| (data, base_channels))
}
/// Resolve an `Indexed` color-space lookup entry into raw palette bytes.
///
/// A direct string yields its bytes; a reference is resolved and may point
/// at either a stream (its content) or a string. Anything else — including
/// a direct, un-referenced stream — yields `None`.
#[cfg(feature = "pdf")]
fn extract_palette_data(lookup: &lopdf::Object, document: &Document) -> Option<Vec<u8>> {
    use lopdf::Object;
    match lookup {
        Object::String(bytes, _) => Some(bytes.clone()),
        Object::Reference(obj_id) => match document.get_object(*obj_id).ok()? {
            Object::Stream(stream) => Some(stream.content.clone()),
            Object::String(bytes, _) => Some(bytes.clone()),
            _ => None,
        },
        _ => None,
    }
}
impl PdfImageExtractor {
/// Open an unencrypted PDF held in memory.
/// Equivalent to `new_with_password(pdf_bytes, None)`.
pub fn new(pdf_bytes: &[u8]) -> Result<Self> {
Self::new_with_password(pdf_bytes, None)
}
/// Open a PDF held in memory, decrypting it with `password` when needed.
///
/// # Errors
/// * `PdfError::InvalidPdf` — the bytes do not parse as a PDF.
/// * `PdfError::PasswordRequired` — encrypted document and `password` is `None`.
/// * `PdfError::InvalidPassword` — decryption with the given password failed.
pub fn new_with_password(pdf_bytes: &[u8], password: Option<&str>) -> Result<Self> {
let mut doc =
Document::load_mem(pdf_bytes).map_err(|e| PdfError::InvalidPdf(format!("Failed to load PDF: {}", e)))?;
if doc.is_encrypted() {
if let Some(pwd) = password {
doc.decrypt(pwd).map_err(|_| PdfError::InvalidPassword)?;
} else {
return Err(PdfError::PasswordRequired);
}
}
Ok(Self { document: doc })
}
/// Extract every image from every page of the document.
///
/// When `max_images_per_page` is `Some(cap)` and a page contains more than
/// `cap` images, that entire page is skipped with a warning — images are
/// never partially extracted from a capped page. Resulting `image_index`
/// values are 1-based within each page.
pub fn extract_images(&self, max_images_per_page: Option<u32>) -> Result<Vec<PdfImage>> {
let mut all_images = Vec::new();
let pages = self.document.get_pages();
for (page_num, page_id) in pages.iter() {
let images = self
.document
.get_page_images(*page_id)
.map_err(|e| PdfError::MetadataExtractionFailed(format!("Failed to get page images: {}", e)))?;
// Enforce the per-page cap by skipping the whole page, not truncating.
if let Some(cap) = max_images_per_page
&& images.len() > cap as usize
{
tracing::warn!(
page_number = *page_num,
image_count = images.len(),
cap,
"PDF page exceeds max_images_per_page; skipping image extraction for this page"
);
continue;
}
for (img_index, img) in images.iter().enumerate() {
let filters = img.filters.clone().unwrap_or_default();
// With the "pdf" feature: resolve any Indexed palette, then decode
// the stream according to its filter.
#[cfg(feature = "pdf")]
let (palette, palette_base_channels) = extract_indexed_palette(img.origin_dict, &self.document)
.map(|(p, ch)| (Some(p), ch))
.unwrap_or((None, 0));
#[cfg(feature = "pdf")]
let (data, decoded_format) = decode_image_data(
img.content,
&filters,
img.color_space.as_deref(),
img.width,
img.height,
img.bits_per_component,
palette.as_deref(),
palette_base_channels,
);
// Without the feature, pass the raw stream bytes through untouched.
#[cfg(not(feature = "pdf"))]
let (data, decoded_format) = (Bytes::from(img.content.to_vec()), "raw".to_string());
all_images.push(PdfImage {
page_number: *page_num as usize,
image_index: img_index + 1,
width: img.width,
height: img.height,
color_space: img.color_space.clone(),
bits_per_component: img.bits_per_component,
filters,
data,
decoded_format,
});
}
}
Ok(all_images)
}
/// Extract the images of a single page (1-based `page_number`).
///
/// Identical per-image processing to `extract_images`, but with no
/// per-page cap.
///
/// # Errors
/// `PdfError::PageNotFound` when the page number does not exist.
pub fn extract_images_from_page(&self, page_number: u32) -> Result<Vec<PdfImage>> {
let pages = self.document.get_pages();
let page_id = pages
.get(&page_number)
.ok_or(PdfError::PageNotFound(page_number as usize))?;
let images = self
.document
.get_page_images(*page_id)
.map_err(|e| PdfError::MetadataExtractionFailed(format!("Failed to get page images: {}", e)))?;
let mut page_images = Vec::new();
for (img_index, img) in images.iter().enumerate() {
let filters = img.filters.clone().unwrap_or_default();
#[cfg(feature = "pdf")]
let (palette, palette_base_channels) = extract_indexed_palette(img.origin_dict, &self.document)
.map(|(p, ch)| (Some(p), ch))
.unwrap_or((None, 0));
#[cfg(feature = "pdf")]
let (data, decoded_format) = decode_image_data(
img.content,
&filters,
img.color_space.as_deref(),
img.width,
img.height,
img.bits_per_component,
palette.as_deref(),
palette_base_channels,
);
#[cfg(not(feature = "pdf"))]
let (data, decoded_format) = (Bytes::from(img.content.to_vec()), "raw".to_string());
page_images.push(PdfImage {
page_number: page_number as usize,
image_index: img_index + 1,
width: img.width,
height: img.height,
color_space: img.color_space.clone(),
bits_per_component: img.bits_per_component,
filters,
data,
decoded_format,
});
}
Ok(page_images)
}
/// Count all images in the document (no per-page cap applied).
///
/// NOTE(review): this fully extracts — and decodes — every image just to
/// count them; consider a lighter-weight count if this shows up in profiles.
pub fn get_image_count(&self) -> Result<usize> {
let images = self.extract_images(None)?;
Ok(images.len())
}
}
/// Convenience wrapper: load `pdf_bytes` and extract all images in one call.
///
/// `max_images_per_page` caps how many images a single page may contain;
/// pages exceeding the cap are skipped entirely (see
/// `PdfImageExtractor::extract_images`).
pub fn extract_images_from_pdf(pdf_bytes: &[u8], max_images_per_page: Option<u32>) -> Result<Vec<PdfImage>> {
    PdfImageExtractor::new(pdf_bytes)?.extract_images(max_images_per_page)
}
/// Like [`extract_images_from_pdf`], but decrypts the document with
/// `password` before extracting.
pub fn extract_images_from_pdf_with_password(
    pdf_bytes: &[u8],
    password: &str,
    max_images_per_page: Option<u32>,
) -> Result<Vec<PdfImage>> {
    PdfImageExtractor::new_with_password(pdf_bytes, Some(password))?.extract_images(max_images_per_page)
}
/// Re-extract images whose earlier decoding left them in an unusable format
/// ("raw", "ccitt", "jbig2") by rendering through pdfium, replacing their
/// payload in place with RGBA-encoded PNG.
///
/// Returns the number of images successfully re-extracted. Returns `Ok(0)`
/// without binding pdfium at all when no image needs the fallback.
///
/// # Errors
/// Fails when pdfium cannot be bound or the PDF cannot be loaded by pdfium;
/// individual images that cannot be re-rendered are silently left as-is.
#[cfg(feature = "pdf")]
pub fn reextract_raw_images_via_pdfium(pdf_bytes: &[u8], images: &mut [PdfImage]) -> Result<u32> {
use image::ImageEncoder;
use pdfium_render::prelude::*;
// Cheap pre-scan: avoid binding pdfium entirely when nothing needs it.
let needs_fallback = images
.iter()
.any(|img| matches!(img.decoded_format.as_str(), "raw" | "ccitt" | "jbig2"));
if !needs_fallback {
return Ok(0);
}
let pdfium = super::bindings::bind_pdfium(PdfError::RenderingFailed, "image fallback rendering", None)?;
let document = pdfium
.load_pdf_from_byte_slice(pdf_bytes, None)
.map_err(|e| PdfError::InvalidPdf(super::error::format_pdfium_error(e)))?;
let mut reextracted = 0u32;
for img in images.iter_mut() {
if !matches!(img.decoded_format.as_str(), "raw" | "ccitt" | "jbig2") {
continue;
}
// `page_number` is 1-based; pdfium pages are 0-indexed.
let page_idx: i32 = img.page_number.saturating_sub(1) as i32;
let Ok(page) = document.pages().get(page_idx) else {
continue;
};
// Walk the page's image objects, counting 1-based to match `image_index`.
let target_index = img.image_index;
let mut current_image = 0usize;
for obj in page.objects().iter() {
if let Some(image_obj) = obj.as_image_object() {
current_image += 1;
if current_image == target_index {
if let Ok(dynamic_image) = image_obj.get_processed_image(&document) {
let w = dynamic_image.width();
let h = dynamic_image.height();
let rgba = dynamic_image.to_rgba8();
let mut png_buf: Vec<u8> = Vec::new();
if image::codecs::png::PngEncoder::new(&mut png_buf)
.write_image(rgba.as_raw(), w, h, image::ExtendedColorType::Rgba8)
.is_ok()
{
// Replace payload and dimensions with the re-rendered PNG.
img.data = Bytes::from(png_buf);
img.decoded_format = "png".to_string();
img.width = w as i64;
img.height = h as i64;
reextracted += 1;
}
}
// Target found (whether or not re-rendering succeeded) — stop scanning.
break;
}
}
}
}
Ok(reextracted)
}
#[cfg(test)]
mod tests {
// Unit tests: error paths for invalid/empty PDFs, magic-byte sniffing,
// filter passthrough behavior, Flate decoding, and the per-page image cap.
use super::*;
#[test]
fn test_extractor_creation() {
let result = PdfImageExtractor::new(b"not a pdf");
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), PdfError::InvalidPdf(_)));
}
#[test]
fn test_extract_images_invalid_pdf() {
let result = extract_images_from_pdf(b"not a pdf", None);
assert!(result.is_err());
}
#[test]
fn test_extract_images_empty_pdf() {
let result = extract_images_from_pdf(b"", None);
assert!(result.is_err());
}
#[test]
fn test_detect_image_format_jpeg() {
let jpeg_magic = b"\xff\xd8\xff\xe0some_jpeg_data";
assert_eq!(detect_image_format(jpeg_magic), "jpeg");
}
#[test]
fn test_detect_image_format_png() {
let png_magic = b"\x89PNG\r\n\x1a\nsome_png_data";
assert_eq!(detect_image_format(png_magic), "png");
}
#[test]
fn test_detect_image_format_unknown() {
assert_eq!(detect_image_format(b"\x00\x01\x02\x03"), "raw");
}
// DCTDecode (JPEG) streams must be passed through unmodified.
#[cfg(feature = "pdf")]
#[test]
fn test_decode_dct_passthrough() {
let jpeg_bytes = b"\xff\xd8\xff\xe0fake_jpeg";
let filters = vec!["DCTDecode".to_string()];
let (data, format) = decode_image_data(jpeg_bytes, &filters, Some("DeviceRGB"), 100, 100, Some(8), None, 0);
assert_eq!(format, "jpeg");
assert_eq!(data.as_ref(), jpeg_bytes);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_jpx_passthrough() {
let jpx_bytes = b"\x00\x00\x00\x0cjP fake_jpx";
let filters = vec!["JPXDecode".to_string()];
let (data, format) = decode_image_data(jpx_bytes, &filters, Some("DeviceRGB"), 10, 10, Some(8), None, 0);
assert_eq!(format, "jpeg2000");
assert_eq!(data.as_ref(), jpx_bytes);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_unknown_filter_passthrough() {
let raw_bytes = b"\x00\x01\x02\x03";
let filters = vec!["RunLengthDecode".to_string()];
let (data, format) = decode_image_data(raw_bytes, &filters, None, 2, 2, Some(8), None, 0);
assert_eq!(format, "raw");
assert_eq!(data.as_ref(), raw_bytes);
}
// With no filter at all, the format comes from magic-byte detection.
#[cfg(feature = "pdf")]
#[test]
fn test_decode_no_filter_uses_detection() {
let jpeg_bytes = b"\xff\xd8\xff\xe0fake";
let filters: Vec<String> = vec![];
let (data, format) = decode_image_data(jpeg_bytes, &filters, None, 10, 10, None, None, 0);
assert_eq!(format, "jpeg");
assert_eq!(data.as_ref(), jpeg_bytes);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_flate_valid_rgb_image() {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use std::io::Write;
// 2x2 RGB pixels with a 0 ("None") predictor tag byte prefixed per row.
let raw_pixels: Vec<u8> = vec![
255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 0, ];
let row_stride = 2 * 3; let mut rows_with_predictor: Vec<u8> = Vec::new();
for row in 0..2usize {
rows_with_predictor.push(0); rows_with_predictor.extend_from_slice(&raw_pixels[row * row_stride..(row + 1) * row_stride]);
}
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&rows_with_predictor).unwrap();
let compressed = encoder.finish().unwrap();
let filters = vec!["FlateDecode".to_string()];
let (data, format) = decode_image_data(&compressed, &filters, Some("DeviceRGB"), 2, 2, Some(8), None, 0);
assert_eq!(format, "png", "FlateDecode images should be re-encoded as PNG");
assert!(
data.starts_with(b"\x89PNG\r\n\x1a\n"),
"Decoded data should be a valid PNG (got {} bytes, first bytes: {:?})",
data.len(),
&data[..data.len().min(8)]
);
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_flate_indexed_with_palette() {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use std::io::Write;
let indices: Vec<u8> = vec![0, 1, 2, 0];
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&indices).unwrap();
let compressed = encoder.finish().unwrap();
// Three RGB palette entries: red, green, blue.
let palette: Vec<u8> = vec![
255, 0, 0, 0, 255, 0, 0, 0, 255, ];
let filters = vec!["FlateDecode".to_string()];
let (data, format) =
decode_image_data(&compressed, &filters, Some("Indexed"), 2, 2, Some(8), Some(&palette), 3);
assert_eq!(format, "png", "Indexed FlateDecode should produce PNG");
assert!(
data.starts_with(b"\x89PNG\r\n\x1a\n"),
"Decoded data should be a valid PNG"
);
}
#[cfg(feature = "pdf")]
#[test]
fn test_max_images_per_page_cap_skips_dense_page() {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use lopdf::{Document, Object, Stream, dictionary};
use std::io::Write;
// Build a 2-page PDF in memory: page 1 holds 3 images, page 2 holds 1.
let make_compressed_pixel = || {
let raw = vec![255u8, 0u8, 0u8]; let mut enc = ZlibEncoder::new(Vec::new(), Compression::default());
enc.write_all(&raw).unwrap();
enc.finish().unwrap()
};
let mut doc = Document::with_version("1.4");
let add_image = |doc: &mut Document| {
let stream = Stream::new(
dictionary! {
"Type" => Object::Name(b"XObject".to_vec()),
"Subtype" => Object::Name(b"Image".to_vec()),
"Width" => 1i64,
"Height" => 1i64,
"ColorSpace" => Object::Name(b"DeviceRGB".to_vec()),
"BitsPerComponent" => 8i64,
"Filter" => Object::Name(b"FlateDecode".to_vec())
},
make_compressed_pixel(),
);
doc.add_object(stream)
};
let img1a = add_image(&mut doc);
let img1b = add_image(&mut doc);
let img1c = add_image(&mut doc);
let img2a = add_image(&mut doc);
let pages_id = doc.new_object_id();
let page1_id = doc.add_object(dictionary! {
"Type" => Object::Name(b"Page".to_vec()),
"Parent" => Object::Reference(pages_id),
"MediaBox" => Object::Array(vec![0i64.into(), 0i64.into(), 100i64.into(), 100i64.into()]),
"Resources" => dictionary! {
"XObject" => dictionary! {
"Im0" => Object::Reference(img1a),
"Im1" => Object::Reference(img1b),
"Im2" => Object::Reference(img1c)
}
}
});
let page2_id = doc.add_object(dictionary! {
"Type" => Object::Name(b"Page".to_vec()),
"Parent" => Object::Reference(pages_id),
"MediaBox" => Object::Array(vec![0i64.into(), 0i64.into(), 100i64.into(), 100i64.into()]),
"Resources" => dictionary! {
"XObject" => dictionary! {
"Im0" => Object::Reference(img2a)
}
}
});
doc.set_object(
pages_id,
dictionary! {
"Type" => Object::Name(b"Pages".to_vec()),
"Kids" => Object::Array(vec![Object::Reference(page1_id), Object::Reference(page2_id)]),
"Count" => 2i64
},
);
let catalog_id = doc.add_object(dictionary! {
"Type" => Object::Name(b"Catalog".to_vec()),
"Pages" => Object::Reference(pages_id)
});
doc.trailer.set("Root", Object::Reference(catalog_id));
let mut pdf_bytes = Vec::new();
doc.save_to(&mut pdf_bytes).unwrap();
let all = extract_images_from_pdf(&pdf_bytes, None).expect("should parse");
assert_eq!(all.len(), 4, "without cap: all images extracted");
let capped = extract_images_from_pdf(&pdf_bytes, Some(2)).expect("should parse");
assert_eq!(capped.len(), 1, "with cap=2: only page-2 image extracted");
assert_eq!(capped[0].width, 1);
let zero_capped = extract_images_from_pdf(&pdf_bytes, Some(0)).expect("should parse");
assert_eq!(zero_capped.len(), 0, "cap=0: all pages skipped");
let exact_capped = extract_images_from_pdf(&pdf_bytes, Some(3)).expect("should parse");
assert_eq!(exact_capped.len(), 4, "cap=exactly-page-count: all images extracted");
}
#[cfg(feature = "pdf")]
#[test]
fn test_decode_flate_indexed_without_palette_grayscale_fallback() {
use flate2::Compression;
use flate2::write::ZlibEncoder;
use std::io::Write;
let indices: Vec<u8> = vec![10, 50, 100, 200];
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&indices).unwrap();
let compressed = encoder.finish().unwrap();
let filters = vec!["FlateDecode".to_string()];
let (data, format) = decode_image_data(&compressed, &filters, Some("Indexed"), 2, 2, Some(8), None, 0);
assert_eq!(
format, "png",
"Indexed without palette should still produce PNG (grayscale)"
);
assert!(
data.starts_with(b"\x89PNG\r\n\x1a\n"),
"Decoded data should be a valid PNG"
);
}
}