use crate::deflate;
use crate::deflate_encode;
use crate::pbm::Bitmap;
/// The eight-byte magic number that opens every PNG file:
/// `0x89`, the ASCII letters `PNG`, CR LF, `0x1A`, LF.
const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";
/// Serialises `bitmap` as a PNG file held in memory.
///
/// The output consists of the PNG signature, an `IHDR` chunk declaring bit
/// depth 8 with every other header byte left at zero (colour type 0, no
/// interlace), a single `IDAT` chunk and a closing `IEND` chunk.
///
/// The `IDAT` payload is a zlib stream: the two header bytes `0x78 0x01`, the
/// output of `deflate_encode::deflate_fixed_with_adler32` over the filtered
/// scanlines, and the big-endian checksum that function returns.
pub fn encode_grayscale(bitmap: &Bitmap) -> Vec<u8> {
    let scanlines = build_filtered_raw(bitmap);
    // Every scanline carries one filter-type byte in front of its pixels.
    debug_assert_eq!(scanlines.len(), (bitmap.width + 1) * bitmap.height);

    let (compressed, checksum) = deflate_encode::deflate_fixed_with_adler32(&scanlines);

    // signature (8) + IHDR chunk (25) + IDAT framing (12) + zlib stream + IEND chunk (12)
    let zlib_len = 2 + compressed.len() + 4;
    let mut png = Vec::with_capacity(8 + 25 + 12 + zlib_len + 12);
    png.extend_from_slice(&PNG_SIGNATURE);

    let mut header = [0u8; 13];
    header[..4].copy_from_slice(&(bitmap.width as u32).to_be_bytes());
    header[4..8].copy_from_slice(&(bitmap.height as u32).to_be_bytes());
    header[8] = 8;
    write_chunk(&mut png, b"IHDR", &header);

    // IDAT is written in place: reserve the length field, append type and
    // payload, then patch the length and append the CRC. This avoids copying
    // the compressed data into a temporary buffer first.
    let length_at = png.len();
    png.extend_from_slice(&[0u8; 4]);
    let type_at = png.len();
    png.extend_from_slice(b"IDAT");
    png.extend_from_slice(&[0x78, 0x01]);
    png.extend_from_slice(&compressed);
    png.extend_from_slice(&checksum.to_be_bytes());

    let payload_len = (png.len() - type_at - 4) as u32;
    png[length_at..type_at].copy_from_slice(&payload_len.to_be_bytes());
    // The chunk CRC covers the type and payload, not the length field.
    let idat_crc = crc32(&png[type_at..]);
    png.extend_from_slice(&idat_crc.to_be_bytes());

    write_chunk(&mut png, b"IEND", &[]);
    png
}
/// Converts `bitmap` into PNG scanlines: one filter-type byte followed by one
/// byte per pixel (set pixel -> 0, clear pixel -> 255).
///
/// The first row is always stored with filter type 0 (unfiltered). Each later
/// row is stored either unfiltered or as the byte-wise difference from the row
/// above (filter type 2), whichever has the smaller sum of absolute values
/// when its bytes are read as signed; ties go to filter type 2.
fn build_filtered_raw(bitmap: &Bitmap) -> Vec<u8> {
    let width = bitmap.width;
    let rows = bitmap.height;
    let mut raw = Vec::with_capacity(rows * (width + 1));

    // Scratch rows reused across iterations to avoid per-row allocations.
    let mut above = vec![0u8; width];
    let mut current = vec![0u8; width];
    let mut delta = vec![0u8; width];

    for y in 0..rows {
        let source = &bitmap.pixels[y * width..(y + 1) * width];
        for (gray, &set) in current.iter_mut().zip(source.iter()) {
            *gray = if set { 0 } else { 255 };
        }

        let prefer_up = if y == 0 {
            false
        } else {
            let mut up_cost: u64 = 0;
            let mut none_cost: u64 = 0;
            for ((diff, &here), &over) in delta.iter_mut().zip(current.iter()).zip(above.iter()) {
                *diff = here.wrapping_sub(over);
                up_cost += u64::from((*diff as i8).unsigned_abs());
                none_cost += u64::from((here as i8).unsigned_abs());
            }
            up_cost <= none_cost
        };

        if prefer_up {
            raw.push(2u8);
            raw.extend_from_slice(&delta);
        } else {
            raw.push(0u8);
            raw.extend_from_slice(&current);
        }

        // The row just emitted becomes the reference row for the next one;
        // the old reference buffer is fully overwritten on the next pass.
        std::mem::swap(&mut above, &mut current);
    }
    raw
}
/// Appends one complete PNG chunk to `out`.
///
/// Layout: big-endian 32-bit payload length, the four-byte chunk type `kind`,
/// the payload `data`, then the big-endian CRC-32 of type and payload.
fn write_chunk(out: &mut Vec<u8>, kind: &[u8; 4], data: &[u8]) {
    let length_field = (data.len() as u32).to_be_bytes();
    out.extend_from_slice(&length_field);

    // Everything appended from here on is covered by the checksum; the
    // length field written above is deliberately excluded.
    let covered_from = out.len();
    out.extend_from_slice(&kind[..]);
    out.extend_from_slice(data);

    let checksum = crc32(&out[covered_from..]);
    out.extend_from_slice(&checksum.to_be_bytes());
}
/// Polynomial constant XORed into the remainder while building the CRC table;
/// the table loop shifts right, i.e. processes bits least-significant first.
const CRC32_POLY: u32 = 0xEDB88320;

/// Lookup table for [`crc32`], computed at compile time: entry `n` is the
/// result of running `n` through eight shift-and-conditionally-XOR rounds.
const CRC32_TABLE: [u32; 256] = {
    let mut entries = [0u32; 256];
    let mut byte = 0usize;
    while byte < 256 {
        let mut rem = byte as u32;
        let mut round = 0;
        while round < 8 {
            // `mask` is all ones when the low bit is set and zero otherwise,
            // so the polynomial is XORed in only for a set low bit.
            let mask = (rem & 1).wrapping_neg();
            rem = (rem >> 1) ^ (CRC32_POLY & mask);
            round += 1;
        }
        entries[byte] = rem;
        byte += 1;
    }
    entries
};

/// Computes the CRC-32 of `data`: the state starts as all ones, each byte is
/// folded in through [`CRC32_TABLE`], and the final state is bit-inverted.
fn crc32(data: &[u8]) -> u32 {
    let state = data.iter().fold(u32::MAX, |state, &byte| {
        let slot = usize::from(state as u8 ^ byte);
        (state >> 8) ^ CRC32_TABLE[slot]
    });
    state ^ u32::MAX
}
/// Modulus applied to both running sums of the Adler-32 checksum.
const ADLER_MOD: u32 = 65521;

/// Computes the Adler-32 checksum of `data`.
///
/// The low sum starts at 1 and the high sum at 0; the result packs the high
/// sum into the upper 16 bits and the low sum into the lower 16 bits.
fn adler32(data: &[u8]) -> u32 {
    // Bytes processed between reductions. Deferring the `%` keeps the inner
    // loop cheap; the block size is small enough that the `u32` sums do not
    // overflow before they are reduced.
    const NMAX: usize = 5552;

    let mut low: u32 = 1;
    let mut high: u32 = 0;
    for block in data.chunks(NMAX) {
        for &byte in block {
            low += u32::from(byte);
            high += low;
        }
        low %= ADLER_MOD;
        high %= ADLER_MOD;
    }
    (high << 16) | low
}
/// Decodes a PNG file held in `data` into a [`Bitmap`].
///
/// Steps, in order: verify the signature, split the remainder into chunks,
/// parse the first `IHDR`, concatenate the payloads of all `IDAT` chunks,
/// read the first `PLTE` chunk if present, unwrap the zlib stream, undo the
/// scanline filters and convert the samples to bitmap pixels.
///
/// # Errors
///
/// Returns a static message when the signature is wrong, `IHDR` or `IDAT` is
/// missing, the palette length is not a multiple of three, or any of the
/// later stages rejects the data.
pub fn decode(data: &[u8]) -> Result<Bitmap, &'static str> {
    if !data.starts_with(&PNG_SIGNATURE) {
        return Err("PNG: bad signature");
    }
    let chunks = parse_chunks(&data[PNG_SIGNATURE.len()..])?;

    let header = match chunks.iter().find(|c| &c.kind == b"IHDR") {
        Some(ihdr) => parse_ihdr(ihdr.data)?,
        None => return Err("PNG: missing IHDR"),
    };

    // The compressed image may be split across several IDAT chunks; their
    // payloads form one continuous zlib stream.
    let mut zlib_stream = Vec::new();
    for idat in chunks.iter().filter(|c| &c.kind == b"IDAT") {
        zlib_stream.extend_from_slice(idat.data);
    }
    if zlib_stream.is_empty() {
        return Err("PNG: no IDAT");
    }

    let palette: Option<Vec<[u8; 3]>> = match chunks.iter().find(|c| &c.kind == b"PLTE") {
        None => None,
        Some(plte) if plte.data.len() % 3 != 0 => {
            return Err("PNG: PLTE not multiple of 3");
        }
        Some(plte) => Some(
            plte.data
                .chunks(3)
                .map(|rgb| [rgb[0], rgb[1], rgb[2]])
                .collect(),
        ),
    };

    let raw = zlib_unwrap(&zlib_stream)?;
    let unfiltered = unfilter(&raw, &header)?;
    pixels_to_bitmap(&unfiltered, &header, palette.as_deref())
}
#[derive(Debug)]
struct Chunk<'a> {
kind: [u8; 4],
data: &'a [u8],
}
fn parse_chunks(mut data: &[u8]) -> Result<Vec<Chunk<'_>>, &'static str> {
let mut out = Vec::new();
while !data.is_empty() {
if data.len() < 12 {
return Err("PNG: chunk truncated");
}
let length = u32::from_be_bytes(data[..4].try_into().unwrap()) as usize;
if data.len() < 12 + length {
return Err("PNG: chunk length exceeds file");
}
let mut kind = [0u8; 4];
kind.copy_from_slice(&data[4..8]);
let chunk_data = &data[8..8 + length];
let crc_expected = u32::from_be_bytes(data[8 + length..12 + length].try_into().unwrap());
let crc_actual = crc32(&data[4..8 + length]);
if crc_actual != crc_expected {
return Err("PNG: chunk CRC mismatch");
}
out.push(Chunk { kind, data: chunk_data });
if &kind == b"IEND" {
return Ok(out);
}
data = &data[12 + length..];
}
Err("PNG: missing IEND")
}
/// The fields of a PNG `IHDR` chunk that the decoder keeps.
#[derive(Debug, Clone, Copy)]
struct IhdrInfo {
    /// Image width in pixels.
    width: u32,
    /// Image height in pixels.
    height: u32,
    /// Bits per sample.
    bit_depth: u8,
    /// Colour-type code as stored in the header.
    color_type: u8,
}

impl IhdrInfo {
    /// Number of samples per pixel for this colour type, or 0 when the
    /// colour-type code is not one of 0, 2, 3, 4 or 6.
    fn channels(&self) -> usize {
        // Indexed by colour-type code; the unassigned codes 1 and 5 map to 0.
        const SAMPLES_PER_PIXEL: [usize; 7] = [1, 0, 3, 1, 2, 0, 4];
        SAMPLES_PER_PIXEL
            .get(usize::from(self.color_type))
            .copied()
            .unwrap_or(0)
    }

    /// Bytes occupied by one unfiltered scanline, rounded up to a whole byte
    /// (excluding the leading filter-type byte).
    fn row_bytes(&self) -> usize {
        let pixel_bits = self.bit_depth as usize * self.channels();
        let row_bits = self.width as usize * pixel_bits;
        (row_bits + 7) / 8
    }
}
fn parse_ihdr(data: &[u8]) -> Result<IhdrInfo, &'static str> {
if data.len() != 13 {
return Err("PNG: IHDR not 13 bytes");
}
let info = IhdrInfo {
width: u32::from_be_bytes(data[0..4].try_into().unwrap()),
height: u32::from_be_bytes(data[4..8].try_into().unwrap()),
bit_depth: data[8],
color_type: data[9],
};
if data[10] != 0 {
return Err("PNG: unsupported compression method");
}
if data[11] != 0 {
return Err("PNG: unsupported filter method");
}
if data[12] != 0 {
return Err("PNG: interlaced PNG not supported");
}
if info.channels() == 0 {
return Err("PNG: bad color_type");
}
Ok(info)
}
fn zlib_unwrap(stream: &[u8]) -> Result<Vec<u8>, &'static str> {
if stream.len() < 6 {
return Err("zlib: stream too short");
}
let cmf = stream[0];
let flg = stream[1];
if (cmf & 0x0F) != 8 {
return Err("zlib: not DEFLATE");
}
if ((cmf as u32) * 256 + flg as u32) % 31 != 0 {
return Err("zlib: bad header checksum");
}
if flg & 0x20 != 0 {
return Err("zlib: FDICT not supported");
}
let deflated = &stream[2..stream.len() - 4];
let inflated = deflate::inflate(deflated)?;
let adler_expected = u32::from_be_bytes(stream[stream.len() - 4..].try_into().unwrap());
let adler_actual = adler32(&inflated);
if adler_actual != adler_expected {
return Err("zlib: Adler-32 mismatch");
}
Ok(inflated)
}
/// Reverses PNG scanline filtering.
///
/// `raw` must hold `h.height` scanlines, each one filter-type byte followed
/// by `h.row_bytes()` filtered bytes. The result holds the reconstructed
/// bytes of all rows back to back, without the filter-type bytes.
///
/// Filter types: 0 leaves the byte as is; 1 adds the byte `bpp` positions to
/// the left; 2 adds the byte directly above; 3 adds the floored average of
/// those two; 4 adds the Paeth predictor of left, above and upper-left.
/// Neighbours outside the image count as zero, and additions wrap modulo 256.
///
/// # Errors
///
/// * `"PNG: unfilter expected size mismatch"` - `raw` does not have the
///   length implied by the header, including headers whose implied size does
///   not fit in `usize`.
/// * `"PNG: bad filter type"` - a non-empty scanline uses a filter type above 4.
fn unfilter(raw: &[u8], h: &IhdrInfo) -> Result<Vec<u8>, &'static str> {
    const SIZE_MISMATCH: &str = "PNG: unfilter expected size mismatch";

    let row_bytes = h.row_bytes();
    let height = h.height as usize;

    // Width, height and bit depth come from an untrusted header, so the
    // expected size is computed with checked arithmetic: an overflowing
    // product can never equal `raw.len()` and is reported as a mismatch
    // instead of panicking (debug) or wrapping (release).
    let stride = row_bytes.checked_add(1).ok_or(SIZE_MISMATCH)?;
    let expected = stride.checked_mul(height).ok_or(SIZE_MISMATCH)?;
    if raw.len() != expected {
        return Err(SIZE_MISMATCH);
    }

    // Distance in bytes to the corresponding byte of the previous pixel;
    // sub-byte pixel formats use a distance of one byte.
    let bpp = ((h.bit_depth as usize * h.channels() + 7) / 8).max(1);

    // Cannot overflow: `row_bytes * height` is smaller than `expected`.
    let mut out = vec![0u8; row_bytes * height];
    for (y, line) in raw.chunks_exact(stride).enumerate() {
        let filter_type = line[0];
        let filtered = &line[1..];

        // Split the output into the rows already reconstructed and the rest,
        // so the row above can be read while the current row is written.
        let (done, pending) = out.split_at_mut(y * row_bytes);
        let above: Option<&[u8]> = if y == 0 {
            None
        } else {
            Some(&done[done.len() - row_bytes..])
        };
        let current = &mut pending[..row_bytes];

        for x in 0..row_bytes {
            let left = if x >= bpp { current[x - bpp] } else { 0 };
            let up = above.map_or(0, |row| row[x]);
            let up_left = match above {
                Some(row) if x >= bpp => row[x - bpp],
                _ => 0,
            };
            let predicted = match filter_type {
                0 => 0,
                1 => left,
                2 => up,
                3 => ((left as u16 + up as u16) / 2) as u8,
                4 => paeth_predictor(left, up, up_left),
                _ => return Err("PNG: bad filter type"),
            };
            current[x] = filtered[x].wrapping_add(predicted);
        }
    }
    Ok(out)
}
/// Paeth predictor: returns whichever of `a` (left), `b` (above) or `c`
/// (upper-left) is closest to the estimate `a + b - c`.
///
/// Ties are resolved in the order `a`, then `b`, then `c`.
fn paeth_predictor(a: u8, b: u8, c: u8) -> u8 {
    let (left, above, diagonal) = (i32::from(a), i32::from(b), i32::from(c));

    // With estimate = left + above - diagonal, each distance simplifies:
    //   estimate - left     = above - diagonal
    //   estimate - above    = left - diagonal
    //   estimate - diagonal = left + above - 2 * diagonal
    let to_left = (above - diagonal).abs();
    let to_above = (left - diagonal).abs();
    let to_diagonal = (left + above - 2 * diagonal).abs();

    if to_left <= to_above && to_left <= to_diagonal {
        a
    } else if to_above <= to_diagonal {
        b
    } else {
        c
    }
}
/// Thresholds unfiltered PNG samples into a [`Bitmap`].
///
/// `pixels` holds `h.height` rows of `h.row_bytes()` bytes each. Every pixel
/// is reduced to one brightness value by `sample_pixel`; the bitmap pixel is
/// set when that brightness is below 128.
///
/// # Errors
///
/// Propagates any error returned by `sample_pixel`.
fn pixels_to_bitmap(
    pixels: &[u8],
    h: &IhdrInfo,
    palette: Option<&[[u8; 3]]>,
) -> Result<Bitmap, &'static str> {
    let columns = h.width as usize;
    let rows = h.height as usize;
    let mut bitmap = Bitmap::new(columns, rows);
    let stride = h.row_bytes();

    for y in 0..rows {
        let line = &pixels[y * stride..(y + 1) * stride];
        for x in 0..columns {
            let is_dark = sample_pixel(line, x, h, palette)? < 128;
            bitmap.set(x, y, is_dark);
        }
    }
    Ok(bitmap)
}
/// Returns the brightness (0-255) of pixel `x` in the unfiltered scanline `row`.
///
/// * Colour type 0 (grayscale): the sample scaled to eight bits.
/// * Colour types 2 and 6 (RGB, RGBA): the integer mean of red, green and
///   blue; the alpha byte is ignored.
/// * Colour type 3 (palette): the integer mean of the palette entry.
/// * Colour type 4 (gray + alpha): the gray byte; alpha is ignored.
///
/// # Errors
///
/// Returns an error when the bit depth is not supported for the colour type
/// (1, 2, 4 or 8 for grayscale and palette; 8 for the others), when a palette
/// image has no palette or uses an index beyond it, or when the colour type
/// is unknown.
///
/// # Panics
///
/// Panics if `row` is too short to contain pixel `x`.
fn sample_pixel(
    row: &[u8],
    x: usize,
    h: &IhdrInfo,
    palette: Option<&[[u8; 3]]>,
) -> Result<u8, &'static str> {
    // Integer mean of three colour channels.
    fn mean_rgb(r: u8, g: u8, b: u8) -> u8 {
        ((r as u16 + g as u16 + b as u16) / 3) as u8
    }

    let bd = h.bit_depth as usize;
    // `read_sub_byte_sample` yields 0 for any other depth, which would turn
    // e.g. a 16-bit grayscale image into a uniformly black one without any
    // error. Reject those depths explicitly, like the 8-bit-only branches do.
    let packed_depth_ok = matches!(bd, 1 | 2 | 4 | 8);

    match h.color_type {
        0 => {
            if !packed_depth_ok {
                return Err("PNG: grayscale only at 1/2/4/8-bit");
            }
            let v = read_sub_byte_sample(row, x, bd);
            Ok(scale_to_u8(v, bd))
        }
        2 => {
            if bd != 8 {
                return Err("PNG: RGB only supported at 8-bit depth");
            }
            Ok(mean_rgb(row[x * 3], row[x * 3 + 1], row[x * 3 + 2]))
        }
        3 => {
            if !packed_depth_ok {
                return Err("PNG: palette only at 1/2/4/8-bit");
            }
            let idx = read_sub_byte_sample(row, x, bd) as usize;
            let pal = palette.ok_or("PNG: palette image missing PLTE")?;
            if idx >= pal.len() {
                return Err("PNG: palette index out of range");
            }
            let [r, g, b] = pal[idx];
            Ok(mean_rgb(r, g, b))
        }
        4 => {
            if bd != 8 {
                return Err("PNG: gray+alpha only at 8-bit");
            }
            Ok(row[x * 2])
        }
        6 => {
            if bd != 8 {
                return Err("PNG: RGBA only at 8-bit");
            }
            Ok(mean_rgb(row[x * 4], row[x * 4 + 1], row[x * 4 + 2]))
        }
        _ => Err("PNG: unknown color_type"),
    }
}
/// Extracts sample number `x` from a packed scanline.
///
/// For bit depths 1, 2 and 4 the samples are packed most-significant bits
/// first within each byte; for bit depth 8 the sample is simply `row[x]`.
/// Any other bit depth yields 0 without reading `row`.
fn read_sub_byte_sample(row: &[u8], x: usize, bit_depth: usize) -> u8 {
    if bit_depth == 8 {
        return row[x];
    }
    if !matches!(bit_depth, 1 | 2 | 4) {
        return 0;
    }

    let per_byte = 8 / bit_depth;
    // The first sample in a byte sits in the topmost bits, so the shift
    // shrinks as the position within the byte grows.
    let shift = 8 - bit_depth * (x % per_byte + 1);
    let mask = (1u8 << bit_depth) - 1;
    (row[x / per_byte] >> shift) & mask
}
/// Stretches a sample of the given bit depth to the full 0-255 range.
///
/// The largest sample of each depth maps to 255 (1 * 255, 3 * 85, 15 * 17).
/// Samples of any other bit depth, including 8, are returned unchanged.
fn scale_to_u8(v: u8, bit_depth: usize) -> u8 {
    let factor: u8 = match bit_depth {
        1 => 255,
        2 => 85,
        4 => 17,
        _ => 1,
    };
    v * factor
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `width` x `height` bitmap in which pixel `(x, y)` is set
    /// exactly when `is_set(x, y)` returns true.
    fn patterned(width: usize, height: usize, is_set: impl Fn(usize, usize) -> bool) -> Bitmap {
        let mut bm = Bitmap::new(width, height);
        for y in 0..height {
            for x in 0..width {
                bm.set(x, y, is_set(x, y));
            }
        }
        bm
    }

    /// CRC-32 agrees with the standard check value and the empty-input value.
    #[test]
    fn crc32_known_value() {
        assert_eq!(crc32(b"123456789"), 0xCBF43926);
        assert_eq!(crc32(b""), 0x00000000);
    }

    /// Adler-32 agrees with a published example and the empty-input value.
    #[test]
    fn adler32_known_value() {
        assert_eq!(adler32(b"Wikipedia"), 0x11E60398);
        assert_eq!(adler32(b""), 0x00000001);
    }

    /// The encoder output starts with the signature followed by a 13-byte
    /// IHDR chunk carrying the bitmap dimensions.
    #[test]
    fn png_signature_and_chunks() {
        let png = encode_grayscale(&Bitmap::new(2, 2));
        let (signature, rest) = png.split_at(8);
        assert_eq!(signature, &PNG_SIGNATURE);
        assert_eq!(&rest[..4], &[0, 0, 0, 13]);
        assert_eq!(&rest[4..8], b"IHDR");
        assert_eq!(&rest[8..12], &2u32.to_be_bytes());
        assert_eq!(&rest[12..16], &2u32.to_be_bytes());
    }

    /// Encoding and then decoding reproduces the original bitmap.
    #[test]
    fn round_trip_encode_decode_grayscale() {
        let bm = patterned(7, 5, |x, y| (x * 3 + y) % 2 == 0);
        let decoded = decode(&encode_grayscale(&bm)).unwrap();
        assert_eq!(decoded.width, 7);
        assert_eq!(decoded.height, 5);
        assert_eq!(decoded, bm);
    }

    /// The encoded file opens with the signature and contains the IEND
    /// chunk type followed by its fixed CRC.
    #[test]
    fn png_round_trip_via_file_signature() {
        const IEND_TAIL: &[u8] = b"IEND\xAE\x42\x60\x82";

        let bm = patterned(5, 5, |x, y| (x + y) % 2 == 0);
        let png = encode_grayscale(&bm);
        assert_eq!(&png[..8], &PNG_SIGNATURE);
        assert!(
            png.windows(8).any(|w| w == IEND_TAIL),
            "PNG 末尾应包含 IEND chunk"
        );
    }
}