#![cfg_attr(docsrs, doc(cfg(feature = "bcj2")))]
extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
/// Precision of an adaptive bit probability, in bits (probabilities live in `0..2^11`).
const NUM_MODEL_BITS: u32 = 11;
/// Adaptation speed: each coded bit moves a probability by `1/2^5` of the remaining distance.
const NUM_MOVE_BITS: u32 = 5;
/// Probability scale, `2^NUM_MODEL_BITS`.
const BIT_MODEL_TOTAL: u32 = 1u32 << NUM_MODEL_BITS;
/// Starting probability for every slot: one half of [`BIT_MODEL_TOTAL`].
const PROB_INIT: u16 = (BIT_MODEL_TOTAL >> 1) as u16;
/// The range coder shifts a byte in/out whenever `range` drops below `2^24`.
const TOP_VALUE: u32 = 0x0100_0000;
/// Probability slots: 256 for `E8` (one per preceding byte) plus one each for `E9` and `Jcc`.
const NUM_PROBS: usize = 256 + 2;
/// Returns `true` for the `E8` (CALL rel32) and `E9` (JMP rel32) opcode bytes.
#[inline]
fn is_e8_e9(b: u8) -> bool {
    matches!(b, 0xE8 | 0xE9)
}
/// Returns `true` when `prev` followed by `b` is the two-byte `0F 8x` opcode of a near
/// conditional jump (`Jcc rel32`).
#[inline]
fn is_jcc(prev: u8, b: u8) -> bool {
    if prev != 0x0F {
        return false;
    }
    matches!(b, 0x80..=0x8F)
}
struct RangeDec<'a> {
rc: &'a [u8],
pos: usize,
range: u32,
code: u32,
}
impl<'a> RangeDec<'a> {
fn new(rc: &'a [u8]) -> Result<Self, Error> {
if rc.len() < 5 {
return Err(Error::UnexpectedEnd);
}
if rc[0] != 0 {
return Err(Error::Corrupt);
}
let code = ((rc[1] as u32) << 24)
| ((rc[2] as u32) << 16)
| ((rc[3] as u32) << 8)
| (rc[4] as u32);
Ok(Self {
rc,
pos: 5,
range: 0xFFFF_FFFF,
code,
})
}
#[inline]
fn normalize(&mut self) -> Result<(), Error> {
if self.range < TOP_VALUE {
if self.pos >= self.rc.len() {
return Err(Error::UnexpectedEnd);
}
self.range <<= 8;
self.code = (self.code << 8) | self.rc[self.pos] as u32;
self.pos += 1;
}
Ok(())
}
#[inline]
fn decode_bit(&mut self, prob: &mut u16) -> Result<u32, Error> {
self.normalize()?;
let ttt = *prob as u32;
let bound = (self.range >> NUM_MODEL_BITS) * ttt;
if self.code < bound {
self.range = bound;
*prob = (ttt + ((BIT_MODEL_TOTAL - ttt) >> NUM_MOVE_BITS)) as u16;
Ok(0)
} else {
self.range -= bound;
self.code -= bound;
*prob = (ttt - (ttt >> NUM_MOVE_BITS)) as u16;
Ok(1)
}
}
}
/// Maps a branch-candidate opcode byte to its slot in the probability table.
///
/// `E8` uses one of 256 slots chosen by the preceding byte (`2..=257`), `E9` uses slot 1,
/// and any other byte (in practice the `8x` of a `0F 8x` conditional jump) uses slot 0.
#[inline]
fn prob_index(b: u8, prev: u8) -> usize {
    match b {
        0xE8 => 2 + usize::from(prev),
        0xE9 => 1,
        _ => 0,
    }
}
pub fn decode(
main: &[u8],
call: &[u8],
jump: &[u8],
rc: &[u8],
out_len: usize,
) -> Result<Vec<u8>, Error> {
let mut out = vec![0u8; out_len];
let mut probs = [PROB_INIT; NUM_PROBS];
let mut rd = RangeDec::new(rc)?;
let mut mp = 0usize; let mut cp = 0usize; let mut jp = 0usize; let mut op = 0usize; let mut prev: u8 = 0;
while op < out_len {
if mp >= main.len() {
return Err(Error::UnexpectedEnd);
}
let b = main[mp];
mp += 1;
out[op] = b;
op += 1;
let candidate = is_e8_e9(b) || is_jcc(prev, b);
let prev_before = prev;
prev = b;
if !candidate {
continue;
}
let pidx = prob_index(b, prev_before);
let bit = rd.decode_bit(&mut probs[pidx])?;
if bit == 0 {
continue;
}
if out_len - op < 4 {
return Err(Error::Corrupt);
}
let (src, sp) = if b == 0xE8 {
(call, &mut cp)
} else {
(jump, &mut jp)
};
if *sp + 4 > src.len() {
return Err(Error::UnexpectedEnd);
}
let abs = ((src[*sp] as u32) << 24)
| ((src[*sp + 1] as u32) << 16)
| ((src[*sp + 2] as u32) << 8)
| (src[*sp + 3] as u32);
*sp += 4;
let ip4 = (op as u32).wrapping_add(4);
let dest = abs.wrapping_sub(ip4);
out[op] = dest as u8;
out[op + 1] = (dest >> 8) as u8;
out[op + 2] = (dest >> 16) as u8;
out[op + 3] = (dest >> 24) as u8;
op += 4;
prev = (dest >> 24) as u8;
}
Ok(out)
}
/// Splits x86 code into the four BCJ2 streams consumed by [`decode`].
///
/// Every scanned byte is appended to `main`. When that byte is a branch candidate (`E8`,
/// `E9`, or the second byte of `0F 8x`) and at least four bytes follow, those four bytes are
/// read as a little-endian displacement and converted to an absolute target by adding the
/// offset just past the operand. The target is stored big-endian in `call` (for `E8`) or
/// `jump` (otherwise), and the operand is skipped. One range-coded flag per candidate records
/// whether that conversion happened.
///
/// Returns `(main, call, jump, rc)`.
pub fn encode(input: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>) {
    let len = input.len();
    let mut main = Vec::with_capacity(len);
    let (mut call, mut jump) = (Vec::new(), Vec::new());
    let mut enc = RangeEnc::new();
    let mut probs = [PROB_INIT; NUM_PROBS];
    let mut last: u8 = 0;
    let mut pos = 0usize;

    while pos < len {
        let opcode = input[pos];
        main.push(opcode);
        pos += 1;

        let is_branch = is_e8_e9(opcode) || is_jcc(last, opcode);
        let slot = prob_index(opcode, last);
        last = opcode;
        if !is_branch {
            continue;
        }

        // Too close to the end for a 32-bit operand: flag "not converted" and move on.
        if len - pos < 4 {
            enc.encode_bit(&mut probs[slot], 0);
            continue;
        }

        let operand = [input[pos], input[pos + 1], input[pos + 2], input[pos + 3]];
        let rel = u32::from_le_bytes(operand);
        let target = rel.wrapping_add((pos as u32).wrapping_add(4));
        enc.encode_bit(&mut probs[slot], 1);
        if opcode == 0xE8 {
            call.extend_from_slice(&target.to_be_bytes());
        } else {
            jump.extend_from_slice(&target.to_be_bytes());
        }
        pos += 4;
        // Keep in step with the decoder, which remembers the operand's most significant byte.
        last = operand[3];
    }

    (main, call, jump, enc.finish())
}
/// Binary adaptive range encoder that produces the BCJ2 `rc` stream read by `RangeDec`.
struct RangeEnc {
    /// Lower bound of the current interval; bit 32 holds a carry not yet propagated.
    low: u64,
    /// Width of the current interval.
    range: u32,
    /// Most recent settled-but-unwritten byte, held back because a carry may still bump it.
    cache: u8,
    /// Number of held-back bytes: `cache` plus the run of `0xFF` bytes queued after it.
    cache_size: u64,
    /// Bytes emitted so far.
    out: Vec<u8>,
}

impl RangeEnc {
    /// Creates an encoder with a full-width interval and a single zero byte pending.
    fn new() -> Self {
        RangeEnc {
            out: Vec::new(),
            cache_size: 1,
            cache: 0,
            range: u32::MAX,
            low: 0,
        }
    }

    /// Moves the top byte of `low` out of the coder, resolving carries into held-back bytes.
    fn shift_low(&mut self) {
        let carry = (self.low >> 32) as u8;
        // Bits 24..32 can only be finalised once they are not 0xFF or a carry has occurred.
        if carry != 0 || self.low < 0xFF00_0000 {
            // Flush `cache` and the queued 0xFF run; a carry turns each 0xFF into 0x00.
            self.out.push(self.cache.wrapping_add(carry));
            for _ in 1..self.cache_size {
                self.out.push(0xFFu8.wrapping_add(carry));
            }
            self.cache_size = 0;
            self.cache = (self.low >> 24) as u8;
        }
        self.cache_size += 1;
        self.low = (self.low & 0x00FF_FFFF) << 8;
    }

    /// Encodes `bit` (zero means 0, anything else means 1) with adaptive probability `prob`,
    /// updating `prob` and renormalising afterwards.
    fn encode_bit(&mut self, prob: &mut u16, bit: u32) {
        let p = u32::from(*prob);
        let split = (self.range >> NUM_MODEL_BITS) * p;
        if bit != 0 {
            self.low += u64::from(split);
            self.range -= split;
            *prob -= (p >> NUM_MOVE_BITS) as u16;
        } else {
            self.range = split;
            *prob += ((BIT_MODEL_TOTAL - p) >> NUM_MOVE_BITS) as u16;
        }
        while self.range < TOP_VALUE {
            self.range <<= 8;
            self.shift_low();
        }
    }

    /// Flushes every pending byte and returns the finished stream.
    fn finish(mut self) -> Vec<u8> {
        // Five shifts push all 32 bits of `low` plus the held-back cache byte out.
        (0..5).for_each(|_| self.shift_low());
        self.out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `input`, decodes the four streams again, and asserts the bytes survive intact.
    fn roundtrip(input: &[u8]) {
        let (main, call, jump, rc) = encode(input);
        let got = decode(&main, &call, &jump, &rc, input.len()).expect("decode");
        assert_eq!(got, input, "BCJ2 round-trip mismatch");
    }

    #[test]
    fn empty() {
        roundtrip(&[]);
    }

    #[test]
    fn no_branches() {
        roundtrip(b"the quick brown fox jumps over the lazy dog");
        roundtrip(&[0u8; 64]);
        let ramp: Vec<u8> = (0..200u32)
            .map(|x| x as u8)
            .filter(|&b| b != 0xE8 && b != 0xE9)
            .collect();
        roundtrip(&ramp);
    }

    #[test]
    fn single_call() {
        let mut v = vec![0x90u8, 0x90, 0xE8, 0x10, 0x20, 0x30, 0x00, 0xCC, 0xCC];
        v.extend_from_slice(&[0u8; 8]);
        roundtrip(&v);
    }

    #[test]
    fn single_jmp() {
        let v = vec![0xE9u8, 0xFF, 0xFF, 0xFF, 0xFF, 0x90, 0x90, 0x90, 0x90, 0x90];
        roundtrip(&v);
    }

    #[test]
    fn conditional_jump() {
        let v = vec![0x0Fu8, 0x84, 0x01, 0x02, 0x03, 0x04, 0x55, 0x55, 0x55, 0x55];
        roundtrip(&v);
    }

    #[test]
    fn mixed_branches() {
        let mut v = Vec::new();
        for k in 0..50u32 {
            v.push(0x55);
            v.push(0xE8);
            v.extend_from_slice(&(k.wrapping_mul(7)).to_le_bytes());
            v.push(0xE9);
            v.extend_from_slice(&(0x1000u32.wrapping_sub(k)).to_le_bytes());
            v.push(0x0F);
            v.push(0x8C);
            v.extend_from_slice(&k.to_le_bytes());
        }
        v.extend_from_slice(&[0u8; 8]);
        roundtrip(&v);
    }

    #[test]
    fn branch_opcode_at_tail_no_room() {
        roundtrip(&[0x90, 0x90, 0xE8, 0x01, 0x02]);
        roundtrip(&[0xE9]);
        roundtrip(&[0x0F, 0x80]);
    }

    #[test]
    fn e8_prev_byte_models() {
        let mut v = Vec::new();
        for p in 0..256u32 {
            v.push(p as u8);
            v.push(0xE8);
            v.extend_from_slice(&p.to_le_bytes());
        }
        v.extend_from_slice(&[0u8; 8]);
        roundtrip(&v);
    }

    #[test]
    fn jcc_detected_after_converted_operand() {
        // The call's operand ends in 0x0F, so the following 0x85 must be treated as the
        // second byte of a `0F 85` conditional jump by both encoder and decoder.
        let v = [
            0xE8u8, 0x11, 0x22, 0x33, 0x0F, 0x85, 0x44, 0x55, 0x66, 0x77, 0x90,
        ];
        let (main, call, jump, _rc) = encode(&v);
        assert_eq!(main, [0xE8u8, 0x85, 0x90]);
        assert_eq!(call.len(), 4);
        assert_eq!(jump.len(), 4);
        roundtrip(&v);
    }

    #[test]
    fn pseudo_random_roundtrip() {
        // Deterministic LCG data biased toward branch opcodes so many contexts are exercised;
        // also round-trip several tail truncations to hit the "no room for operand" path.
        let mut state: u32 = 0x1234_5678;
        let data: Vec<u8> = (0..4096)
            .map(|_| {
                state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                match state >> 29 {
                    0 => 0xE8,
                    1 => 0xE9,
                    2 => 0x0F,
                    3 => 0x84,
                    _ => (state >> 16) as u8,
                }
            })
            .collect();
        for cut in 0..8 {
            roundtrip(&data[..data.len() - cut]);
        }
    }

    #[test]
    fn truncated_rc_errors() {
        assert_eq!(
            decode(&[0x90], &[], &[], &[0, 0], 1),
            Err(Error::UnexpectedEnd)
        );
    }

    #[test]
    fn bad_rc_first_byte() {
        assert_eq!(
            decode(&[0x90], &[], &[], &[1, 0, 0, 0, 0], 1),
            Err(Error::Corrupt)
        );
    }

    #[test]
    fn truncated_main_errors() {
        let (main, call, jump, rc) = encode(b"abc");
        assert_eq!(
            decode(&main, &call, &jump, &rc, 100),
            Err(Error::UnexpectedEnd)
        );
    }

    #[test]
    fn truncated_call_stream_errors() {
        let (main, call, jump, rc) = encode(&[0xE8, 0x01, 0x02, 0x03, 0x04]);
        assert_eq!(call.len(), 4);
        assert_eq!(
            decode(&main, &call[..2], &jump, &rc, 5),
            Err(Error::UnexpectedEnd)
        );
    }

    #[test]
    fn converted_operand_past_out_len_is_corrupt() {
        // The flag stream says "converted", but only three output bytes remain for the operand.
        let (main, call, jump, rc) = encode(&[0xE8, 0x01, 0x02, 0x03, 0x04]);
        assert_eq!(decode(&main, &call, &jump, &rc, 4), Err(Error::Corrupt));
    }
}