use bzimage::{BzImageHeader, MAGIC, VERSION};
use flate2::write::GzEncoder;
use flate2::Compression;
use sha2::{Digest, Sha256};
use std::io::{Cursor, Seek, SeekFrom, Write, Read};
/// Writes a header followed by a gzip payload, reads both back, and checks
/// every stage: magic, size fields, payload bytes, SHA-256 checksum, and
/// decompression back to the original data.
#[test]
fn round_trip_write_read_validate_decompress() {
    let payload = b"hello daemonizer world".to_vec();
    let uncompressed_size = payload.len() as u64;

    let mut enc = GzEncoder::new(Vec::new(), Compression::best());
    enc.write_all(&payload).unwrap();
    let compressed = enc.finish().unwrap();
    let compressed_size = compressed.len() as u64;

    // The checksum stored in the header is the SHA-256 of the *compressed* bytes.
    let mut hasher = Sha256::new();
    hasher.update(&compressed);
    let checksum: [u8; 32] = hasher.finalize().into();

    let header = BzImageHeader {
        magic: *MAGIC,
        version: VERSION.into(),
        reserved1: 0u32.into(),
        uncompressed_size: uncompressed_size.into(),
        compressed_size: compressed_size.into(),
        checksum,
        reserved2: 0u32.into(),
    };

    let mut cur = Cursor::new(Vec::new());
    header.write_to(&mut cur).unwrap();
    cur.write_all(&compressed).unwrap();
    cur.seek(SeekFrom::Start(0)).unwrap();

    let read_header = BzImageHeader::read_from(&mut cur).unwrap();
    assert_eq!(&read_header.magic_copy(), MAGIC);

    // SAFETY: `addr_of!` yields a raw pointer to a live, fully initialized field
    // of `read_header` without creating a (possibly misaligned) reference; the
    // header is presumably `#[repr(packed)]`, which is why `read_unaligned`
    // (which has no alignment requirement) is used. The field is plain
    // little-endian integer data, so the bitwise copy is sound.
    let uncompressed_field: simple_endian::u64le =
        unsafe { std::ptr::read_unaligned(std::ptr::addr_of!(read_header.uncompressed_size)) };
    // SAFETY: same justification as for `uncompressed_size` above.
    let compressed_field: simple_endian::u64le =
        unsafe { std::ptr::read_unaligned(std::ptr::addr_of!(read_header.compressed_size)) };

    let un: u64 = uncompressed_field.into();
    let comp: u64 = compressed_field.into();
    assert_eq!(un, uncompressed_size);
    assert_eq!(comp, compressed_size);

    // Everything after the header is the payload; it must come back byte-for-byte.
    let mut compressed_read = Vec::new();
    cur.read_to_end(&mut compressed_read).unwrap();
    assert_eq!(compressed_read.len() as u64, compressed_size);
    assert_eq!(compressed_read, compressed);

    assert!(read_header.validate_checksum(&compressed_read));
    let decompressed = BzImageHeader::decompress_data(&compressed_read).unwrap();
    assert_eq!(decompressed, payload);
}
/// A header whose magic bytes are wrong must be rejected by `read_from`.
///
/// The fixture is a fully valid, correctly sized header with only the first
/// magic byte flipped. The failure can therefore only come from the magic
/// check, not from a zero version or a short buffer.
#[test]
fn invalid_magic_fails() {
    let header = BzImageHeader {
        magic: *MAGIC,
        version: VERSION.into(),
        reserved1: 0u32.into(),
        uncompressed_size: 0u64.into(),
        compressed_size: 0u64.into(),
        checksum: [0u8; 32],
        reserved2: 0u32.into(),
    };
    let mut cur = Cursor::new(Vec::new());
    header.write_to(&mut cur).unwrap();
    let mut bytes = cur.into_inner();

    // Make the layout assumption explicit: the serialized header begins with MAGIC.
    assert!(bytes.starts_with(MAGIC), "serialized header does not start with MAGIC");
    // XOR with 0xff always changes the byte, so the magic is guaranteed to differ.
    bytes[0] ^= 0xff;

    let mut cur = Cursor::new(bytes);
    let res = BzImageHeader::read_from(&mut cur);
    assert!(res.is_err());
}
/// `validate_checksum` must accept the exact compressed bytes the checksum was
/// computed over, and reject a copy with one corrupted byte.
#[test]
fn checksum_mismatch_detected() {
    let payload = b"somedata".to_vec();
    let mut enc = GzEncoder::new(Vec::new(), Compression::best());
    enc.write_all(&payload).unwrap();
    let compressed = enc.finish().unwrap();

    let mut hasher = Sha256::new();
    hasher.update(&compressed);
    let checksum: [u8; 32] = hasher.finalize().into();

    let header = BzImageHeader {
        magic: *MAGIC,
        version: VERSION.into(),
        reserved1: 0u32.into(),
        uncompressed_size: (payload.len() as u64).into(),
        compressed_size: (compressed.len() as u64).into(),
        checksum,
        reserved2: 0u32.into(),
    };
    let mut cur = Cursor::new(Vec::new());
    header.write_to(&mut cur).unwrap();

    // gzip output always starts with a multi-byte header, so index 0 exists.
    // Indexing directly (rather than guarding on emptiness) means an empty
    // encoder output fails loudly instead of silently skipping the corruption.
    let mut corrupted = compressed.clone();
    corrupted[0] ^= 0xff;
    cur.write_all(&corrupted).unwrap();
    cur.seek(SeekFrom::Start(0)).unwrap();

    let read_header = BzImageHeader::read_from(&mut cur).unwrap();
    let mut compressed_read = Vec::new();
    cur.read_to_end(&mut compressed_read).unwrap();
    assert_eq!(compressed_read, corrupted);

    // Positive control: without this, a `validate_checksum` that always
    // returned false would make the negative assertion below pass vacuously.
    assert!(read_header.validate_checksum(&compressed));
    assert!(!read_header.validate_checksum(&compressed_read));
}
/// Serializing a header (here one with all numeric fields zeroed) must emit
/// exactly `HEADER_SIZE` bytes.
#[test]
fn header_write_size() {
    let zeroed = BzImageHeader {
        magic: *MAGIC,
        version: VERSION.into(),
        reserved1: 0u32.into(),
        uncompressed_size: 0u64.into(),
        compressed_size: 0u64.into(),
        checksum: [0u8; 32],
        reserved2: 0u32.into(),
    };

    let mut sink = Cursor::new(Vec::new());
    zeroed.write_to(&mut sink).unwrap();

    let written = sink.into_inner();
    assert_eq!(written.len(), bzimage::HEADER_SIZE);
}
/// Every strict prefix of a valid serialized header must be rejected.
///
/// Truncating a *valid* header, rather than padding MAGIC with zeros, means
/// the failure is attributable to the short input. Checking every length from
/// 0 up to `HEADER_SIZE - 1` covers the boundaries: an empty stream, a stream
/// holding only part of the magic, and a buffer exactly one byte short.
#[test]
fn truncated_header_fails() {
    let header = BzImageHeader {
        magic: *MAGIC,
        version: VERSION.into(),
        reserved1: 0u32.into(),
        uncompressed_size: 0u64.into(),
        compressed_size: 0u64.into(),
        checksum: [0u8; 32],
        reserved2: 0u32.into(),
    };
    let mut cur = Cursor::new(Vec::new());
    header.write_to(&mut cur).unwrap();
    let full = cur.into_inner();
    assert_eq!(full.len(), bzimage::HEADER_SIZE);

    for len in 0..full.len() {
        // Same reader type as the other tests (`Cursor<Vec<u8>>`).
        let mut cur = Cursor::new(full[..len].to_vec());
        let r = BzImageHeader::read_from(&mut cur);
        assert!(
            r.is_err(),
            "read_from accepted a truncated header of {} bytes (full size {})",
            len,
            full.len()
        );
    }
}