use crate::deflate;
use crate::pbm::Bitmap;
/// The fixed 8-byte magic number that begins every PNG file.
const PNG_SIGNATURE: [u8; 8] = *b"\x89PNG\r\n\x1a\n";
pub fn encode_grayscale(bitmap: &Bitmap) -> Vec<u8> {
let w = bitmap.width as u32;
let h = bitmap.height as u32;
let row_bytes = bitmap.width + 1;
let raw_len = row_bytes * bitmap.height;
let n_blocks = raw_len.div_ceil(65535).max(1);
let zlib_size = 2 + 5 * n_blocks + raw_len + 4;
let total = 8 + 25 + 12 + zlib_size + 12;
let mut out = Vec::with_capacity(total);
out.extend_from_slice(&PNG_SIGNATURE);
let mut ihdr = [0u8; 13];
ihdr[0..4].copy_from_slice(&w.to_be_bytes());
ihdr[4..8].copy_from_slice(&h.to_be_bytes());
ihdr[8] = 8; write_chunk(&mut out, b"IHDR", &ihdr);
let idat_len_pos = out.len();
out.extend_from_slice(&[0u8; 4]);
let idat_start = out.len();
out.extend_from_slice(b"IDAT");
let idat_body_start = out.len();
out.push(0x78); out.push(0x01);
let mut raw = Vec::with_capacity(raw_len);
let mut adler_a: u32 = 1;
let mut adler_b: u32 = 0;
let mut adler_since_mod = 0usize;
const ADLER_NMAX: usize = 5552;
for y in 0..bitmap.height {
raw.push(0u8);
adler_b = adler_b.wrapping_add(adler_a);
adler_since_mod += 1;
let row_slice = &bitmap.pixels[y * bitmap.width..(y + 1) * bitmap.width];
for &px in row_slice {
let v: u8 = if px { 0 } else { 255 };
raw.push(v);
adler_a = adler_a.wrapping_add(v as u32);
adler_b = adler_b.wrapping_add(adler_a);
adler_since_mod += 1;
}
if adler_since_mod >= ADLER_NMAX {
adler_a %= 65521;
adler_b %= 65521;
adler_since_mod = 0;
}
}
debug_assert_eq!(raw.len(), raw_len);
adler_a %= 65521;
adler_b %= 65521;
let adler = (adler_b << 16) | adler_a;
let mut i = 0usize;
while i < raw.len() {
let this_block = (raw.len() - i).min(65535);
let bfinal = i + this_block == raw.len();
out.push(if bfinal { 0x01 } else { 0x00 });
out.extend_from_slice(&(this_block as u16).to_le_bytes());
out.extend_from_slice(&(!(this_block as u16)).to_le_bytes());
out.extend_from_slice(&raw[i..i + this_block]);
i += this_block;
}
out.extend_from_slice(&adler.to_be_bytes());
let idat_body_len = (out.len() - idat_body_start) as u32;
out[idat_len_pos..idat_len_pos + 4].copy_from_slice(&idat_body_len.to_be_bytes());
let crc = crc32(&out[idat_start..]);
out.extend_from_slice(&crc.to_be_bytes());
write_chunk(&mut out, b"IEND", &[]);
out
}
/// Appends one complete PNG chunk to `out`.
///
/// The chunk is laid out as: the big-endian data length, the 4-byte chunk
/// type, the data itself, and the big-endian CRC-32 of type + data.
fn write_chunk(out: &mut Vec<u8>, kind: &[u8; 4], data: &[u8]) {
    let length = data.len() as u32;
    out.extend_from_slice(&length.to_be_bytes());
    // The CRC covers the chunk type and data, but not the length field.
    let covered_from = out.len();
    out.extend(kind.iter().chain(data).copied());
    let checksum = crc32(&out[covered_from..]);
    out.extend_from_slice(&checksum.to_be_bytes());
}
/// Bit-reflected CRC-32 polynomial used by PNG chunk checksums.
const CRC32_POLY: u32 = 0xEDB88320;
/// Byte-indexed lookup table for [`crc32`], computed at compile time.
const CRC32_TABLE: [u32; 256] = {
    let mut table = [0u32; 256];
    let mut n: usize = 0;
    while n < 256 {
        let mut crc = n as u32;
        let mut bit = 0;
        while bit < 8 {
            let lsb_set = crc & 1 != 0;
            crc >>= 1;
            if lsb_set {
                crc ^= CRC32_POLY;
            }
            bit += 1;
        }
        table[n] = crc;
        n += 1;
    }
    table
};
/// Computes the CRC-32 of `data`: initial value all ones, reflected table
/// lookup, and a final bitwise inversion.
fn crc32(data: &[u8]) -> u32 {
    let register = data.iter().fold(u32::MAX, |crc, &byte| {
        let index = (crc as u8 ^ byte) as usize;
        CRC32_TABLE[index] ^ (crc >> 8)
    });
    !register
}
/// Adler-32 modulus: the largest prime below 2^16.
const ADLER_MOD: u32 = 65521;
/// Computes the Adler-32 checksum of `data`. An empty input yields 1.
fn adler32(data: &[u8]) -> u32 {
    // Longest run of bytes whose unreduced sums are guaranteed to fit in a u32,
    // so the modulo only has to be taken once per run.
    const NMAX: usize = 5552;
    let (sum_a, sum_b) = data.chunks(NMAX).fold((1u32, 0u32), |(mut a, mut b), run| {
        for &byte in run {
            a += u32::from(byte);
            b += a;
        }
        (a % ADLER_MOD, b % ADLER_MOD)
    });
    (sum_b << 16) | sum_a
}
/// Decodes a PNG file into a [`Bitmap`].
///
/// Every pixel is reduced to a 0–255 brightness. Pixels darker than 128
/// become set (`true`) in the returned bitmap.
///
/// # Errors
///
/// Returns a static description of the first problem found: bad signature,
/// malformed or corrupt chunks, missing IHDR/IDAT, a PLTE whose length is not
/// a multiple of 3, or a failure while inflating, unfiltering or sampling.
pub fn decode(data: &[u8]) -> Result<Bitmap, &'static str> {
    let Some(after_signature) = data.strip_prefix(&PNG_SIGNATURE[..]) else {
        return Err("PNG: bad signature");
    };
    let chunks = parse_chunks(after_signature)?;
    let header = match chunks.iter().find(|chunk| &chunk.kind == b"IHDR") {
        Some(chunk) => parse_ihdr(chunk.data)?,
        None => return Err("PNG: missing IHDR"),
    };
    // The compressed image data may be split across several IDAT chunks;
    // together they form one zlib stream.
    let zlib_stream: Vec<u8> = chunks
        .iter()
        .filter(|chunk| &chunk.kind == b"IDAT")
        .flat_map(|chunk| chunk.data.iter().copied())
        .collect();
    if zlib_stream.is_empty() {
        return Err("PNG: no IDAT");
    }
    let palette: Option<Vec<[u8; 3]>> = match chunks.iter().find(|chunk| &chunk.kind == b"PLTE") {
        None => None,
        Some(plte) if plte.data.len() % 3 != 0 => return Err("PNG: PLTE not multiple of 3"),
        Some(plte) => Some(
            plte.data
                .chunks_exact(3)
                .map(|rgb| [rgb[0], rgb[1], rgb[2]])
                .collect(),
        ),
    };
    let inflated = zlib_unwrap(&zlib_stream)?;
    let scanlines = unfilter(&inflated, &header)?;
    pixels_to_bitmap(&scanlines, &header, palette.as_deref())
}
/// One PNG chunk, borrowing its payload from the input buffer.
#[derive(Debug)]
struct Chunk<'a> {
    /// Four-byte chunk type, e.g. `*b"IHDR"`.
    kind: [u8; 4],
    /// Chunk payload, excluding the length, type and CRC fields.
    data: &'a [u8],
}
/// Splits the bytes that follow the PNG signature into CRC-verified chunks.
///
/// Parsing stops at (and includes) the first `IEND` chunk; anything after it
/// is ignored.
///
/// # Errors
///
/// Returns an error if a chunk is truncated, its declared length runs past
/// the end of the input, its CRC does not match, or no `IEND` is found.
fn parse_chunks(mut data: &[u8]) -> Result<Vec<Chunk<'_>>, &'static str> {
    let mut out = Vec::new();
    while !data.is_empty() {
        // Every chunk has at least length (4) + type (4) + CRC (4) bytes.
        if data.len() < 12 {
            return Err("PNG: chunk truncated");
        }
        let length = u32::from_be_bytes(data[..4].try_into().unwrap()) as usize;
        // Compare against the space that is left rather than computing
        // `12 + length`, which can overflow `usize` on 32-bit targets when the
        // length field is hostile (up to 0xFFFF_FFFF).
        if length > data.len() - 12 {
            return Err("PNG: chunk length exceeds file");
        }
        let mut kind = [0u8; 4];
        kind.copy_from_slice(&data[4..8]);
        let chunk_data = &data[8..8 + length];
        let crc_expected = u32::from_be_bytes(data[8 + length..12 + length].try_into().unwrap());
        // The CRC covers the type and data fields, not the length.
        let crc_actual = crc32(&data[4..8 + length]);
        if crc_actual != crc_expected {
            return Err("PNG: chunk CRC mismatch");
        }
        out.push(Chunk { kind, data: chunk_data });
        if &kind == b"IEND" {
            return Ok(out);
        }
        data = &data[12 + length..];
    }
    Err("PNG: missing IEND")
}
/// Image properties decoded from the 13-byte IHDR chunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct IhdrInfo {
    /// Image width in pixels.
    width: u32,
    /// Image height in pixels.
    height: u32,
    /// Bits per sample (or per palette index for colour type 3).
    bit_depth: u8,
    /// PNG colour type: 0 gray, 2 RGB, 3 palette, 4 gray+alpha, 6 RGBA.
    color_type: u8,
}
impl IhdrInfo {
    /// Samples per pixel for this colour type, or 0 for an unknown type.
    fn channels(&self) -> usize {
        match self.color_type {
            0 => 1, 2 => 3, 3 => 1, 4 => 2, 6 => 4, _ => 0,
        }
    }
    /// Bytes in one unfiltered scanline, excluding the filter-type byte.
    /// Sub-byte pixels are packed and the total is rounded up to a whole byte.
    fn row_bytes(&self) -> usize {
        let bits_per_pixel = self.bit_depth as usize * self.channels();
        (self.width as usize * bits_per_pixel + 7) / 8
    }
}
/// Parses and validates the IHDR payload.
///
/// # Errors
///
/// Rejects payloads that are not 13 bytes long, non-zero compression or
/// filter methods, interlaced images, unknown colour types, and bit depths
/// that the PNG specification does not allow for the given colour type.
fn parse_ihdr(data: &[u8]) -> Result<IhdrInfo, &'static str> {
    if data.len() != 13 {
        return Err("PNG: IHDR not 13 bytes");
    }
    let info = IhdrInfo {
        width: u32::from_be_bytes(data[0..4].try_into().unwrap()),
        height: u32::from_be_bytes(data[4..8].try_into().unwrap()),
        bit_depth: data[8],
        color_type: data[9],
    };
    if data[10] != 0 {
        return Err("PNG: unsupported compression method");
    }
    if data[11] != 0 {
        return Err("PNG: unsupported filter method");
    }
    if data[12] != 0 {
        return Err("PNG: interlaced PNG not supported");
    }
    if info.channels() == 0 {
        return Err("PNG: bad color_type");
    }
    // Allowed combinations from the PNG specification's IHDR table. Without
    // this check an illegal depth (0, 3, 16-bit palette, ...) decodes as
    // silent garbage instead of failing.
    let depth_ok = match info.color_type {
        0 => matches!(info.bit_depth, 1 | 2 | 4 | 8 | 16),
        3 => matches!(info.bit_depth, 1 | 2 | 4 | 8),
        2 | 4 | 6 => matches!(info.bit_depth, 8 | 16),
        _ => false,
    };
    if !depth_ok {
        return Err("PNG: bad bit_depth for color_type");
    }
    Ok(info)
}
/// Strips and validates the zlib wrapper around a DEFLATE stream, returning
/// the inflated bytes.
///
/// # Errors
///
/// Fails if the stream is too short, is not DEFLATE, declares a window size
/// larger than 32 KiB, has a bad header check, requests a preset
/// dictionary, fails to inflate, or its trailing Adler-32 does not match.
fn zlib_unwrap(stream: &[u8]) -> Result<Vec<u8>, &'static str> {
    // 2-byte header + at least 0 bytes of data + 4-byte Adler-32 trailer
    // (the extra bytes are needed for any real DEFLATE block).
    if stream.len() < 6 {
        return Err("zlib: stream too short");
    }
    let cmf = stream[0];
    let flg = stream[1];
    if (cmf & 0x0F) != 8 {
        return Err("zlib: not DEFLATE");
    }
    // CINFO (high nibble) is log2(window size) - 8; RFC 1950 forbids values above 7.
    if cmf >> 4 > 7 {
        return Err("zlib: invalid window size");
    }
    if ((cmf as u32) * 256 + flg as u32) % 31 != 0 {
        return Err("zlib: bad header checksum");
    }
    if flg & 0x20 != 0 {
        return Err("zlib: FDICT not supported");
    }
    let deflated = &stream[2..stream.len() - 4];
    let inflated = deflate::inflate(deflated)?;
    let adler_expected = u32::from_be_bytes(stream[stream.len() - 4..].try_into().unwrap());
    let adler_actual = adler32(&inflated);
    if adler_actual != adler_expected {
        return Err("zlib: Adler-32 mismatch");
    }
    Ok(inflated)
}
/// Reverses PNG per-scanline filtering (None, Sub, Up, Average, Paeth).
///
/// `raw` is the inflated image data: `height` rows, each made of one
/// filter-type byte followed by `h.row_bytes()` filtered bytes. The result
/// holds the reconstructed rows back to back, with the filter bytes removed.
///
/// # Errors
///
/// Fails if `raw` does not have exactly the size implied by the header
/// (including when that size does not fit in `usize`), or if a row uses an
/// unknown filter type.
fn unfilter(raw: &[u8], h: &IhdrInfo) -> Result<Vec<u8>, &'static str> {
    let row_bytes = h.row_bytes();
    let height = h.height as usize;
    // Each stored scanline is prefixed by one filter-type byte.
    let stride = row_bytes + 1;
    // A hostile IHDR can make this product overflow `usize`. Checked
    // arithmetic stops wrap-around from bypassing the size check in release
    // builds and stops the overflow panic in debug builds.
    let expected_len = stride
        .checked_mul(height)
        .ok_or("PNG: unfilter expected size mismatch")?;
    if raw.len() != expected_len {
        return Err("PNG: unfilter expected size mismatch");
    }
    // Filters work on bytes: "left" is the matching byte of the previous
    // pixel, and formats with fewer than 8 bits per pixel use a distance of 1.
    let bpp = ((h.bit_depth as usize * h.channels() + 7) / 8).max(1);
    let mut out = vec![0u8; row_bytes * height];
    for (y, in_row) in raw.chunks_exact(stride).enumerate() {
        let filter_type = in_row[0];
        // Validate up front so a bad filter byte is caught even for empty rows.
        if filter_type > 4 {
            return Err("PNG: bad filter type");
        }
        let filtered = &in_row[1..];
        // Rows before `y` are already reconstructed. Split the buffer so the
        // previous row can be read while the current one is written.
        let (done, pending) = out.split_at_mut(y * row_bytes);
        let cur = &mut pending[..row_bytes];
        let prev: Option<&[u8]> = if y == 0 { None } else { Some(&done[(y - 1) * row_bytes..]) };
        for x in 0..row_bytes {
            let left = if x >= bpp { cur[x - bpp] } else { 0 };
            let up = prev.map_or(0, |p| p[x]);
            let up_left = match prev {
                Some(p) if x >= bpp => p[x - bpp],
                _ => 0,
            };
            let predictor = match filter_type {
                0 => 0,
                1 => left,
                2 => up,
                3 => ((u16::from(left) + u16::from(up)) / 2) as u8,
                4 => paeth_predictor(left, up, up_left),
                _ => unreachable!("filter type validated above"),
            };
            cur[x] = filtered[x].wrapping_add(predictor);
        }
    }
    Ok(out)
}
/// Paeth predictor used by PNG filter type 4.
///
/// Forms the estimate `a + b - c` and returns whichever of `a` (left),
/// `b` (up) or `c` (upper-left) is closest to it. Ties prefer `a`, then `b`.
fn paeth_predictor(a: u8, b: u8, c: u8) -> u8 {
    let (left, up, diag) = (i32::from(a), i32::from(b), i32::from(c));
    let estimate = left + up - diag;
    let dist_left = (estimate - left).abs();
    let dist_up = (estimate - up).abs();
    let dist_diag = (estimate - diag).abs();
    match (dist_left <= dist_up && dist_left <= dist_diag, dist_up <= dist_diag) {
        (true, _) => a,
        (false, true) => b,
        (false, false) => c,
    }
}
/// Thresholds unfiltered scanlines into a [`Bitmap`].
///
/// `pixels` holds `h.height` rows of `h.row_bytes()` bytes each. A pixel is
/// set when its brightness (from `sample_pixel`) is below 128.
fn pixels_to_bitmap(
    pixels: &[u8],
    h: &IhdrInfo,
    palette: Option<&[[u8; 3]]>,
) -> Result<Bitmap, &'static str> {
    let (width, height) = (h.width as usize, h.height as usize);
    let mut bitmap = Bitmap::new(width, height);
    let stride = h.row_bytes();
    for y in 0..height {
        let start = y * stride;
        let scanline = &pixels[start..start + stride];
        for x in 0..width {
            let is_dark = sample_pixel(scanline, x, h, palette)? < 128;
            bitmap.set(x, y, is_dark);
        }
    }
    Ok(bitmap)
}
/// Returns the 0–255 brightness of pixel `x` within one unfiltered scanline.
///
/// How each colour type is handled:
/// - Grayscale samples of 1, 2 or 4 bits are stretched to the full range.
/// - 16-bit grayscale uses its most significant byte.
/// - Palette entries and RGB(A) pixels use the plain average of red, green and blue.
/// - Alpha channels are ignored.
///
/// # Errors
///
/// Fails for RGB, gray+alpha and RGBA images that are not 8-bit, for palette
/// images without a PLTE or with an out-of-range index, and for unknown
/// colour types.
fn sample_pixel(
    row: &[u8],
    x: usize,
    h: &IhdrInfo,
    palette: Option<&[[u8; 3]]>,
) -> Result<u8, &'static str> {
    let bd = h.bit_depth as usize;
    // Unweighted mean of the colour channels (not perceptual luminance).
    let rgb_mean = |r: u8, g: u8, b: u8| ((u16::from(r) + u16::from(g) + u16::from(b)) / 3) as u8;
    match h.color_type {
        // 16-bit samples are big-endian. The high byte alone is enough for an
        // 8-bit brightness. Previously this fell through to the sub-byte reader,
        // which returned 0 and made every pixel black.
        0 if bd == 16 => Ok(row[x * 2]),
        0 => {
            let v = read_sub_byte_sample(row, x, bd);
            Ok(scale_to_u8(v, bd))
        }
        2 => {
            if bd != 8 {
                return Err("PNG: RGB only supported at 8-bit depth");
            }
            Ok(rgb_mean(row[x * 3], row[x * 3 + 1], row[x * 3 + 2]))
        }
        3 => {
            let idx = read_sub_byte_sample(row, x, bd) as usize;
            let pal = palette.ok_or("PNG: palette image missing PLTE")?;
            let [r, g, b] = *pal.get(idx).ok_or("PNG: palette index out of range")?;
            Ok(rgb_mean(r, g, b))
        }
        4 => {
            if bd != 8 {
                return Err("PNG: gray+alpha only at 8-bit");
            }
            // Gray sample comes first; the alpha byte that follows is ignored.
            Ok(row[x * 2])
        }
        6 => {
            if bd != 8 {
                return Err("PNG: RGBA only at 8-bit");
            }
            Ok(rgb_mean(row[x * 4], row[x * 4 + 1], row[x * 4 + 2]))
        }
        _ => Err("PNG: unknown color_type"),
    }
}
/// Reads the `x`-th sample from a scanline packed at `bit_depth` bits per
/// sample. Samples are packed most-significant bit first.
///
/// Depths 1, 2, 4 and 8 are supported; any other depth yields 0.
fn read_sub_byte_sample(row: &[u8], x: usize, bit_depth: usize) -> u8 {
    match bit_depth {
        8 => row[x],
        1 | 2 | 4 => {
            let per_byte = 8 / bit_depth;
            // Earlier samples occupy the higher bits of the byte.
            let shift = 8 - bit_depth * (x % per_byte + 1);
            let mask = (1u8 << bit_depth) - 1;
            (row[x / per_byte] >> shift) & mask
        }
        _ => 0,
    }
}
/// Stretches a 1-, 2- or 4-bit sample so its maximum maps to 255.
///
/// Any other bit depth returns `v` unchanged.
fn scale_to_u8(v: u8, bit_depth: usize) -> u8 {
    if matches!(bit_depth, 1 | 2 | 4) {
        let max_sample = (1u8 << bit_depth) - 1;
        v * (u8::MAX / max_sample)
    } else {
        v
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn crc32_known_value() {
        assert_eq!(crc32(b"123456789"), 0xCBF43926);
        assert_eq!(crc32(b""), 0x00000000);
    }
    #[test]
    fn adler32_known_value() {
        assert_eq!(adler32(b"Wikipedia"), 0x11E60398);
        assert_eq!(adler32(b""), 0x00000001);
    }
    #[test]
    fn png_signature_and_chunks() {
        let bm = Bitmap::new(2, 2);
        let png = encode_grayscale(&bm);
        assert_eq!(&png[..8], &PNG_SIGNATURE);
        assert_eq!(&png[8..12], &[0, 0, 0, 13]);
        assert_eq!(&png[12..16], b"IHDR");
        assert_eq!(&png[16..20], &2u32.to_be_bytes());
        assert_eq!(&png[20..24], &2u32.to_be_bytes());
    }
    #[test]
    fn round_trip_encode_decode_grayscale() {
        let mut bm = Bitmap::new(7, 5);
        for y in 0..5 {
            for x in 0..7 {
                bm.set(x, y, (x * 3 + y) % 2 == 0);
            }
        }
        let png = encode_grayscale(&bm);
        let decoded = decode(&png).unwrap();
        assert_eq!(decoded.width, 7);
        assert_eq!(decoded.height, 5);
        assert_eq!(decoded, bm);
    }
    /// The file must finish with an empty IEND chunk and decode back to the input.
    #[test]
    fn png_ends_with_iend_and_round_trips() {
        let mut bm = Bitmap::new(5, 5);
        for y in 0..5 {
            for x in 0..5 {
                bm.set(x, y, (x + y) % 2 == 0);
            }
        }
        let png = encode_grayscale(&bm);
        assert_eq!(&png[..8], &PNG_SIGNATURE);
        // Zero data length, the type "IEND", then the fixed CRC-32 of "IEND".
        let iend_chunk = [0, 0, 0, 0, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82];
        // Message text: "the end of the PNG should contain the IEND chunk".
        assert!(png.ends_with(&iend_chunk), "PNG 末尾应包含 IEND chunk");
        assert_eq!(decode(&png).unwrap(), bm);
    }
    #[test]
    fn decode_rejects_bad_signature() {
        assert!(matches!(decode(b"not a png"), Err("PNG: bad signature")));
        assert!(matches!(decode(&PNG_SIGNATURE[..4]), Err("PNG: bad signature")));
    }
    #[test]
    fn decode_rejects_corrupt_or_truncated_files() {
        let png = encode_grayscale(&Bitmap::new(3, 3));
        let mut corrupted = png.clone();
        // Byte 16 is the first byte of the IHDR width field.
        corrupted[16] ^= 0xFF;
        assert!(matches!(decode(&corrupted), Err("PNG: chunk CRC mismatch")));
        // Dropping the 12-byte IEND chunk leaves the chunk list unterminated.
        assert!(matches!(decode(&png[..png.len() - 12]), Err("PNG: missing IEND")));
    }
    #[test]
    fn paeth_predictor_picks_nearest_neighbour() {
        assert_eq!(paeth_predictor(100, 50, 0), 100);
        assert_eq!(paeth_predictor(10, 20, 12), 20);
        assert_eq!(paeth_predictor(10, 20, 15), 15);
    }
}