use crate::error::{Error, Result};
use tiff_core::{ByteOrder, Compression, Predictor};
use crate::sample::TiffWriteSample;
/// Encodes `samples` to bytes, applies the forward predictor row by row,
/// and compresses the result with the requested scheme.
///
/// `index` identifies the strip/tile being written and is used only to
/// label errors. A row is `row_width_pixels * samples_per_pixel` samples
/// wide; when that works out to zero bytes the predictor stage is skipped
/// entirely.
pub fn compress_block<T: TiffWriteSample>(
    samples: &[T],
    byte_order: ByteOrder,
    compression: Compression,
    predictor: Predictor,
    samples_per_pixel: u16,
    row_width_pixels: usize,
    index: usize,
) -> Result<Vec<u8>> {
    let mut raw = T::encode_slice(samples, byte_order);
    let bytes_per_row = row_width_pixels * T::BYTES_PER_SAMPLE * usize::from(samples_per_pixel);
    // `chunks_exact_mut` panics on a chunk size of zero, so guard first.
    if bytes_per_row != 0 {
        raw.chunks_exact_mut(bytes_per_row).try_for_each(|row| {
            apply_forward_predictor(
                row,
                predictor,
                T::BITS_PER_SAMPLE,
                samples_per_pixel,
                byte_order,
            )
        })?;
    }
    compress(&raw, compression, index)
}
pub fn compress(data: &[u8], compression: Compression, index: usize) -> Result<Vec<u8>> {
match compression {
Compression::None => Ok(data.to_vec()),
Compression::Lzw => compress_lzw(data, index),
Compression::Deflate | Compression::DeflateOld => compress_deflate(data, index),
#[cfg(feature = "zstd")]
Compression::Zstd => compress_zstd(data, index),
other => Err(Error::CompressionFailed {
index,
reason: format!("compression {:?} is not supported for writing", other),
}),
}
}
/// Dispatches to the forward (encode-side) predictor for one row.
///
/// Always succeeds today; the `Result` return keeps room for predictor
/// implementations that can fail.
fn apply_forward_predictor(
    row: &mut [u8],
    predictor: Predictor,
    bits_per_sample: u16,
    samples_per_pixel: u16,
    byte_order: ByteOrder,
) -> Result<()> {
    match predictor {
        Predictor::None => {}
        Predictor::Horizontal => {
            forward_horizontal_differencing(row, bits_per_sample, samples_per_pixel, byte_order)
        }
        Predictor::FloatingPoint => {
            forward_float_predictor(row, bits_per_sample, samples_per_pixel, byte_order)
        }
    }
    Ok(())
}
/// Applies forward horizontal differencing (TIFF Predictor 2) in place.
///
/// Every value becomes `value - value of the same channel in the previous
/// pixel`, i.e. the value `samples` positions earlier. Iterating from the
/// end of the buffer backwards guarantees each subtraction still reads the
/// original (un-differenced) left neighbour. Multi-byte values are read and
/// rewritten through `byte_order`, so the buffer stays in file byte order
/// throughout.
///
/// NOTE(review): bit depths are bucketed to whole-byte value widths, so
/// depths below 8 are differenced byte-wise — this assumes one sample per
/// byte; confirm callers never pass packed bilevel/4-bit data.
fn forward_horizontal_differencing(
    buf: &mut [u8],
    bit_depth: u16,
    samples: u16,
    byte_order: ByteOrder,
) {
    // Bytes per value for the bucketed bit depth.
    let bpv = match bit_depth {
        0..=8 => 1usize,
        9..=16 => 2,
        17..=32 => 4,
        _ => 8,
    };
    let n_values = buf.len() / bpv;
    let skip = usize::from(samples);
    // A zero channel count would make every value subtract itself and zero
    // the whole row; treat it (and rows shorter than one pixel) as a no-op.
    if skip == 0 || skip >= n_values {
        return;
    }
    for vi in (skip..n_values).rev() {
        let pos = vi * bpv;
        let prev = (vi - skip) * bpv;
        match bpv {
            1 => {
                buf[pos] = buf[pos].wrapping_sub(buf[prev]);
            }
            2 => {
                let cur = byte_order.read_u16([buf[pos], buf[pos + 1]]);
                let prv = byte_order.read_u16([buf[prev], buf[prev + 1]]);
                let d = byte_order.write_u16(cur.wrapping_sub(prv));
                buf[pos..pos + 2].copy_from_slice(&d);
            }
            4 => {
                let cur = byte_order.read_u32(buf[pos..pos + 4].try_into().unwrap());
                let prv = byte_order.read_u32(buf[prev..prev + 4].try_into().unwrap());
                let d = byte_order.write_u32(cur.wrapping_sub(prv));
                buf[pos..pos + 4].copy_from_slice(&d);
            }
            _ => {
                let cur = byte_order.read_u64(buf[pos..pos + 8].try_into().unwrap());
                let prv = byte_order.read_u64(buf[prev..prev + 8].try_into().unwrap());
                let d = byte_order.write_u64(cur.wrapping_sub(prv));
                buf[pos..pos + 8].copy_from_slice(&d);
            }
        }
    }
}
/// Applies the forward floating-point predictor (TIFF Predictor 3) in place.
///
/// Two stages:
/// 1. Byte-plane shuffle: each value is split into its bytes in big-endian
///    significance order (little-endian input is byte-swapped during the
///    shuffle), and all first bytes are grouped together, then all second
///    bytes, and so on.
/// 2. Byte-wise horizontal differencing over the shuffled bytes with a
///    stride of `samples` (one byte per channel of the previous pixel).
///
/// Only 16/32/64-bit data is transformed; other depths are left untouched,
/// as is an empty buffer or a zero channel count (a stride of zero would
/// make every byte subtract itself and destroy the data).
fn forward_float_predictor(buf: &mut [u8], bit_depth: u16, samples: u16, byte_order: ByteOrder) {
    let bps = match bit_depth {
        16 => 2usize,
        32 => 4,
        64 => 8,
        _ => return,
    };
    let n_values = buf.len() / bps;
    if n_values == 0 || samples == 0 {
        return;
    }
    // Planes must be in big-endian byte significance; LE input is reversed
    // per value while shuffling.
    let need_swap = matches!(byte_order, ByteOrder::LittleEndian);
    let mut tmp = vec![0u8; buf.len()];
    for i in 0..n_values {
        let base = i * bps;
        for b in 0..bps {
            let src_b = if need_swap { bps - 1 - b } else { b };
            tmp[b * n_values + i] = buf[base + src_b];
        }
    }
    let samples = usize::from(samples);
    for i in (samples..tmp.len()).rev() {
        tmp[i] = tmp[i].wrapping_sub(tmp[i - samples]);
    }
    // NOTE(review): if buf.len() is not a whole multiple of bps, the
    // trailing bytes were never copied into `tmp` and are overwritten with
    // differenced zeros here — callers appear to always pass whole values,
    // but confirm.
    buf.copy_from_slice(&tmp);
}
fn compress_lzw(data: &[u8], index: usize) -> Result<Vec<u8>> {
use weezl::encode::Encoder;
use weezl::BitOrder;
let mut encoder = Encoder::with_tiff_size_switch(BitOrder::Msb, 8);
encoder.encode(data).map_err(|e| Error::CompressionFailed {
index,
reason: format!("LZW: {e}"),
})
}
fn compress_deflate(data: &[u8], index: usize) -> Result<Vec<u8>> {
use flate2::write::ZlibEncoder;
use std::io::Write;
let mut encoder = ZlibEncoder::new(Vec::new(), flate2::Compression::default());
encoder
.write_all(data)
.map_err(|e| Error::CompressionFailed {
index,
reason: format!("deflate write: {e}"),
})?;
encoder.finish().map_err(|e| Error::CompressionFailed {
index,
reason: format!("deflate finish: {e}"),
})
}
/// ZSTD-compresses `data` at compression level 3.
/// `index` labels the strip/tile in errors.
#[cfg(feature = "zstd")]
fn compress_zstd(data: &[u8], index: usize) -> Result<Vec<u8>> {
    // `&[u8]` already implements `Read`, so no cursor wrapper is needed.
    match zstd::stream::encode_all(data, 3) {
        Ok(out) => Ok(out),
        Err(e) => Err(Error::CompressionFailed {
            index,
            reason: format!("ZSTD: {e}"),
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_no_compression() {
        let input = vec![1u8, 2, 3, 4, 5, 6];
        let out = compress(&input, Compression::None, 0).unwrap();
        assert_eq!(out, input);
    }

    #[test]
    fn roundtrip_lzw() {
        // A run of zeros must shrink, and decode back losslessly.
        let input = vec![0u8; 256];
        let packed = compress(&input, Compression::Lzw, 0).unwrap();
        assert!(packed.len() < input.len());
        let unpacked = weezl::decode::Decoder::with_tiff_size_switch(weezl::BitOrder::Msb, 8)
            .decode(&packed)
            .unwrap();
        assert_eq!(unpacked, input);
    }

    #[test]
    fn roundtrip_deflate() {
        use flate2::read::ZlibDecoder;
        use std::io::Read;
        let input: Vec<u8> = (0u16..256).map(|v| v as u8).collect();
        let packed = compress(&input, Compression::Deflate, 0).unwrap();
        let mut unpacked = Vec::new();
        ZlibDecoder::new(packed.as_slice())
            .read_to_end(&mut unpacked)
            .unwrap();
        assert_eq!(unpacked, input);
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn roundtrip_zstd() {
        let input: Vec<u8> = (0u16..256).map(|v| v as u8).collect();
        let packed = compress(&input, Compression::Zstd, 0).unwrap();
        let unpacked = zstd::stream::decode_all(std::io::Cursor::new(&packed)).unwrap();
        assert_eq!(unpacked, input);
    }

    #[test]
    fn forward_horizontal_u8() {
        let mut row = vec![1u8, 2, 4, 7];
        forward_horizontal_differencing(&mut row, 8, 1, ByteOrder::LittleEndian);
        assert_eq!(row, [1, 1, 2, 3]);
    }

    #[test]
    fn forward_horizontal_u16_le() {
        let bo = ByteOrder::LittleEndian;
        let mut row: Vec<u8> = [1u16, 2, 4]
            .iter()
            .flat_map(|&v| bo.write_u16(v))
            .collect();
        forward_horizontal_differencing(&mut row, 16, 1, bo);
        let vals: Vec<u16> = row
            .chunks_exact(2)
            .map(|c| bo.read_u16([c[0], c[1]]))
            .collect();
        assert_eq!(vals, [1, 1, 2]);
    }
}