use std::collections::HashMap;
use std::io::{Seek, SeekFrom, Write};
use super::error::WzResult;
use super::header::WzHeader;
use super::keys::WzKey;
use crate::crypto::constants::WZ_OFFSET_CONSTANT;
pub struct WzBinaryWriter<W: Write + Seek> {
pub(crate) writer: W,
pub wz_key: WzKey,
pub hash: u32,
pub header: WzHeader,
pub string_cache: HashMap<String, u32>,
}
macro_rules! impl_write_le {
($($name:ident($ty:ty)),+ $(,)?) => { $(
pub fn $name(&mut self, val: $ty) -> WzResult<()> {
self.writer.write_all(&val.to_le_bytes())?;
Ok(())
}
)+ };
}
impl<W: Write + Seek> WzBinaryWriter<W> {
pub fn new(writer: W, iv: [u8; 4], header: WzHeader) -> Self {
WzBinaryWriter {
writer,
wz_key: WzKey::new(iv),
hash: 0,
header,
string_cache: HashMap::new(),
}
}
pub fn position(&mut self) -> WzResult<u64> {
Ok(self.writer.stream_position()?)
}
pub fn seek(&mut self, pos: u64) -> WzResult<()> {
self.writer.seek(SeekFrom::Start(pos))?;
Ok(())
}
impl_write_le! {
write_u8(u8),
write_i16(i16),
write_u16(u16),
write_i32(i32),
write_u32(u32),
write_i64(i64),
write_f32(f32),
write_f64(f64),
}
pub fn write_bytes(&mut self, data: &[u8]) -> WzResult<()> {
self.writer.write_all(data)?;
Ok(())
}
pub fn write_compressed_int(&mut self, val: i32) -> WzResult<()> {
if (-127..=127).contains(&val) && val != -128 {
self.write_u8(val as u8)
} else {
self.write_u8(0x80)?; self.write_i32(val)
}
}
pub fn write_compressed_long(&mut self, val: i64) -> WzResult<()> {
if (-127..=127).contains(&val) && val != -128 {
self.write_u8(val as u8)
} else {
self.write_u8(0x80)?;
self.write_i64(val)
}
}
pub fn write_wz_string(&mut self, s: &str) -> WzResult<()> {
if s.is_ascii() {
self.write_wz_ascii_string(s)
} else {
self.write_wz_unicode_string(s)
}
}
fn write_wz_unicode_string(&mut self, s: &str) -> WzResult<()> {
let chars: Vec<u16> = s.encode_utf16().collect();
let length = chars.len();
if length >= 127 {
self.write_u8(127)?; self.write_i32(length as i32)?;
} else {
self.write_u8(length as u8)?;
}
self.wz_key.ensure_size(length * 2);
let mut mask: u16 = super::WZ_UNICODE_MASK_INIT;
for (i, &ch) in chars.iter().enumerate() {
let key_lo = self.wz_key[i * 2] as u16;
let key_hi = self.wz_key[i * 2 + 1] as u16;
let key_word = key_lo | (key_hi << 8);
let encrypted = ch ^ key_word ^ mask;
mask = mask.wrapping_add(1);
self.write_u16(encrypted)?;
}
Ok(())
}
fn write_wz_ascii_string(&mut self, s: &str) -> WzResult<()> {
let bytes = s.as_bytes();
let length = bytes.len();
if length > 127 {
self.write_u8(0x80)?; self.write_i32(length as i32)?;
} else {
self.write_u8((-(length as i32)) as u8)?;
}
self.wz_key.ensure_size(length);
let mut mask: u8 = super::WZ_ASCII_MASK_INIT;
for (i, &byte) in bytes.iter().enumerate() {
let encrypted = byte ^ self.wz_key[i] ^ mask;
mask = mask.wrapping_add(1);
self.write_u8(encrypted)?;
}
Ok(())
}
pub fn write_string_value(
&mut self,
s: &str,
without_offset: u8,
with_offset: u8,
) -> WzResult<()> {
if s.len() > 4 {
if let Some(&cached_offset) = self.string_cache.get(s) {
self.write_u8(with_offset)?;
return self.write_i32(cached_offset as i32);
}
}
self.write_u8(without_offset)?;
let str_offset = self.position()? as u32;
self.write_wz_string(s)?;
if !self.string_cache.contains_key(s) {
self.string_cache.insert(s.to_string(), str_offset);
}
Ok(())
}
pub fn write_wz_object_value(&mut self, name: &str, entry_type: u8) -> WzResult<()> {
let cache_key = format!("{}_{}", entry_type, name);
if let Some(&cached_offset) = self.string_cache.get(&cache_key) {
self.write_u8(0x02)?; return self.write_i32(cached_offset as i32);
}
let str_offset = (self.position()? as u32).wrapping_sub(self.header.data_start);
self.write_u8(entry_type)?;
self.write_wz_string(name)?;
self.string_cache.insert(cache_key, str_offset);
Ok(())
}
pub fn write_null_terminated_string(&mut self, s: &str) -> WzResult<()> {
self.write_bytes(s.as_bytes())?;
self.write_u8(0)
}
pub fn write_wz_offset(&mut self, value: u32) -> WzResult<()> {
let cur_pos = self.position()? as u32;
let fstart = self.header.data_start;
let mut enc = (cur_pos.wrapping_sub(fstart)) ^ 0xFFFF_FFFF;
enc = enc.wrapping_mul(self.hash);
enc = enc.wrapping_sub(WZ_OFFSET_CONSTANT);
enc = enc.rotate_left(enc & 0x1F);
let write_val = enc ^ (value.wrapping_sub(fstart.wrapping_mul(2)));
self.write_u32(write_val)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::wz::test_utils::{dummy_header, make_reader, make_reader_with_header};
use crate::wz::header::WzHeader;
use std::io::Cursor;
fn make_writer() -> WzBinaryWriter<Cursor<Vec<u8>>> {
WzBinaryWriter::new(Cursor::new(Vec::new()), [0; 4], dummy_header(0))
}
fn finish_writer(writer: WzBinaryWriter<Cursor<Vec<u8>>>) -> Vec<u8> {
writer.writer.into_inner()
}
#[test]
fn test_write_string_value_inline() {
let mut writer = make_writer();
writer.write_string_value("Property", 0x73, 0x1B).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
let s = reader.read_string_block(0).unwrap();
assert_eq!(s, "Property");
}
#[test]
fn test_write_string_value_cached() {
let mut writer = make_writer();
writer.write_string_value("Property", 0x73, 0x1B).unwrap();
let first_end = writer.position().unwrap();
writer.write_string_value("Property", 0x73, 0x1B).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
let s1 = reader.read_string_block(0).unwrap();
assert_eq!(s1, "Property");
reader.seek(first_end).unwrap();
let s2 = reader.read_string_block(0).unwrap();
assert_eq!(s2, "Property");
}
#[test]
fn test_write_string_value_short_string_no_cache() {
let mut writer = make_writer();
writer.write_string_value("ab", 0x00, 0x01).unwrap();
writer.write_string_value("ab", 0x00, 0x01).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
let s1 = reader.read_string_block(0).unwrap();
assert_eq!(s1, "ab");
let s2 = reader.read_string_block(0).unwrap();
assert_eq!(s2, "ab");
}
#[test]
fn test_write_wz_object_value_inline() {
let mut writer = make_writer();
writer.write_wz_object_value("test.img", 4).unwrap();
let data = finish_writer(writer);
assert_eq!(data[0], 4);
}
#[test]
fn test_write_wz_object_value_cached() {
let mut writer = make_writer();
writer.write_wz_object_value("test.img", 4).unwrap();
let pos_after_first = writer.position().unwrap();
writer.write_wz_object_value("test.img", 4).unwrap();
let data = finish_writer(writer);
assert_eq!(data[pos_after_first as usize], 0x02);
}
#[test]
fn test_write_null_terminated_string() {
let mut writer = make_writer();
writer.write_null_terminated_string("hello").unwrap();
let data = finish_writer(writer);
assert_eq!(&data, &[b'h', b'e', b'l', b'l', b'o', 0]);
}
#[test]
fn test_write_null_terminated_string_empty() {
let mut writer = make_writer();
writer.write_null_terminated_string("").unwrap();
let data = finish_writer(writer);
assert_eq!(&data, &[0]);
}
#[test]
fn test_write_wz_offset_roundtrip() {
let mut writer = make_writer();
let desired: u32 = 1000;
writer.write_wz_offset(desired).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
let result = reader.read_wz_offset().unwrap();
assert_eq!(result, desired as u64);
}
#[test]
fn test_write_wz_offset_roundtrip_with_hash() {
let data_start: u32 = 60;
let hash: u32 = 713421;
let desired: u32 = 200;
let header = WzHeader {
ident: String::new(),
file_size: 256,
data_start,
copyright: String::new(),
};
let mut writer = WzBinaryWriter::new(Cursor::new(vec![0u8; 256]), [0; 4], header);
writer.hash = hash;
writer.seek(data_start as u64).unwrap();
writer.write_wz_offset(desired).unwrap();
let data = writer.writer.into_inner();
let mut reader = make_reader_with_header(data, data_start, 256);
reader.hash = hash;
reader.seek(data_start as u64).unwrap();
let result = reader.read_wz_offset().unwrap();
assert_eq!(result, desired as u64);
}
#[test]
fn test_write_wz_offset_position_dependent() {
let hash: u32 = 12345;
let desired: u32 = 500;
let mut w1 = make_writer();
w1.hash = hash;
w1.write_wz_offset(desired).unwrap();
let data1 = finish_writer(w1);
let header = dummy_header(256);
let mut w2 = WzBinaryWriter::new(Cursor::new(vec![0u8; 256]), [0; 4], header);
w2.hash = hash;
w2.seek(8).unwrap();
w2.write_wz_offset(desired).unwrap();
let data2 = w2.writer.into_inner();
assert_ne!(&data1[0..4], &data2[8..12]);
let mut r1 = make_reader(data1);
r1.hash = hash;
assert_eq!(r1.read_wz_offset().unwrap(), desired as u64);
let mut r2 = make_reader(data2);
r2.hash = hash;
r2.seek(8).unwrap();
assert_eq!(r2.read_wz_offset().unwrap(), desired as u64);
}
#[test]
fn test_write_compressed_int_roundtrip() {
for &val in &[0, 1, -1, 127, -127, 128, -128, i32::MAX, i32::MIN] {
let mut writer = make_writer();
writer.write_compressed_int(val).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
assert_eq!(reader.read_compressed_int().unwrap(), val, "Failed for {}", val);
}
}
#[test]
fn test_write_compressed_long_roundtrip() {
for &val in &[0i64, 1, -1, 127, -127, 128, -128, i64::MAX, i64::MIN] {
let mut writer = make_writer();
writer.write_compressed_long(val).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
assert_eq!(reader.read_compressed_long().unwrap(), val, "Failed for {}", val);
}
}
#[test]
fn test_write_wz_string_ascii_roundtrip() {
let mut writer = make_writer();
writer.write_wz_string("Hello").unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
assert_eq!(reader.read_wz_string().unwrap(), "Hello");
let long_str: String = "B".repeat(200);
let mut writer = make_writer();
writer.write_wz_string(&long_str).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
assert_eq!(reader.read_wz_string().unwrap(), long_str);
}
#[test]
fn test_write_wz_string_unicode_roundtrip() {
let unicode_str = "\u{AC00}\u{B098}\u{B2E4}"; let mut writer = make_writer();
writer.write_wz_string(unicode_str).unwrap();
let data = finish_writer(writer);
let mut reader = make_reader(data);
assert_eq!(reader.read_wz_string().unwrap(), unicode_str);
}
}