import os
import subprocess
from math import ceil, floor, log2
# Directory receiving the generated Rust files: `src/codes` under the parent
# of the directory that contains this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.join(os.path.dirname(_SCRIPT_DIR), "src", "codes")
def get_best_fitting_type(n_bits, signed=False):
    """Return the name of the narrowest Rust integer type with `n_bits` usable bits.

    Args:
        n_bits: number of value bits required (may be a float, e.g. a `log2`).
        signed: when true pick an `iN` type, which has one usable bit less
            than the `uN` type of the same width.

    Returns:
        One of `u8`..`u128` (or `i8`..`i128` when `signed`).

    Raises:
        ValueError: if no type up to 128 bits is wide enough.
    """
    prefix = "i" if signed else "u"
    # A signed type gives up one bit to the sign.
    sign_bits = 1 if signed else 0
    for width in (8, 16, 32, 64, 128):
        if n_bits <= width - sign_bits:
            return "{}{}".format(prefix, width)
    raise ValueError(n_bits)
# Rust template for the table-based reader used when values and lengths are
# stored together: `READ_<BO>` is indexed by the peeked bits and yields a
# `(value, len)` pair. `%(bo)s` / `%(BO)s` are replaced with the lower-case /
# upper-case byte-order name by `gen_table`.
read_func_merged_table = """
/// Reads a value using a decoding table.
///
/// If the result is `Some` the decoding was successful, and
/// the decoded value and the length of the code are returned.
#[inline(always)]
pub fn read_table_%(bo)s<B: BitRead<%(BO)s>>(backend: &mut B) -> Option<(u64, usize)> {
if let Ok(idx) = backend.peek_bits(READ_BITS) {
let idx: usize = idx.as_to();
let (value, len) = READ_%(BO)s[idx];
if len != MISSING_VALUE_LEN_%(BO)s {
backend.skip_bits_after_peek(len as usize);
return Some((value as u64, len as usize));
}
}
None
}
"""
# Rust template for the table-based reader used when values and lengths live
# in two separate tables (`READ_<BO>` and `READ_LEN_<BO>`). `%(bo)s` /
# `%(BO)s` are replaced with the lower-case / upper-case byte-order name by
# `gen_table`.
read_func_two_table = """
/// Reads a value using a decoding table.
///
/// If the result is `Some` the decoding was successful, and
/// the decoded value and the length of the code are returned.
#[inline(always)]
pub fn read_table_%(bo)s<B: BitRead<%(BO)s>>(backend: &mut B) -> Option<(u64, usize)> {
if let Ok(idx) = backend.peek_bits(READ_BITS) {
let idx: usize = idx.as_to();
let len = READ_LEN_%(BO)s[idx];
if len != MISSING_VALUE_LEN_%(BO)s {
backend.skip_bits_after_peek(len as usize);
return Some((READ_%(BO)s[idx] as u64, len as usize));
}
}
None
}
"""
# Rust template for the table-based writer used when `WRITE_<BO>` stores
# `(bits, len)` pairs. Values beyond the end of the table make the generated
# function return `Ok(None)`. `%(bo)s` / `%(BO)s` are replaced with the
# lower-case / upper-case byte-order name by `gen_table`.
write_func_merged_table = """
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
#[allow(clippy::unnecessary_cast)] // rationale: "*bits as u64" is flaky redundant
pub fn write_table_%(bo)s<B: BitWrite<%(BO)s>>(backend: &mut B, n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n >= WRITE_%(BO)s.len() as u64 {
return Ok(None);
}
let n = n as usize;
let (bits, len) = WRITE_%(BO)s[n];
backend.write_bits(bits as u64, len as usize)?;
Ok(Some(len as usize))
}
"""
# Rust template for the table-based writer used when code bits and code
# lengths live in two separate tables (`WRITE_<BO>` and `WRITE_LEN_<BO>`).
# Values beyond the end of the table make the generated function return
# `Ok(None)`. `%(bo)s` / `%(BO)s` are replaced with the lower-case /
# upper-case byte-order name by `gen_table`.
write_func_two_table = """
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
pub fn write_table_%(bo)s<B: BitWrite<%(BO)s>>(backend: &mut B, n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n >= WRITE_%(BO)s.len() as u64 {
return Ok(None);
}
let n = n as usize;
let len = WRITE_LEN_%(BO)s[n] as usize;
backend.write_bits(WRITE_%(BO)s[n] as u64, len)?;
Ok(Some(len))
}
"""
def gen_table(
    read_bits,
    write_max_val,
    len_max_val,
    code_name,
    len_func,
    read_func,
    write_func,
    merged_table,
):
    """Generate `<code_name>_tables.rs` under `ROOT` for a generic code.

    The file contains the table-based read/write functions for both byte
    orders, the decoding tables indexed by `read_bits` peeked bits, the
    encoding tables for the values `0..=write_max_val`, and a `LEN` table
    of code lengths.

    Args:
        read_bits: number of bits used to index the read tables.
        write_max_val: largest value present in the write and `LEN` tables.
        len_max_val: value whose code length determines the element type
            of the `LEN` table.
        code_name: name of the code, used in the file name and doc comments.
        len_func: `len_func(n)` returns the length in bits of the code of `n`.
        read_func: `read_func(bitstream, be)` returns `(value, remaining)`
            or raises `ValueError` when `bitstream` holds no complete code.
        write_func: `write_func(n, bitstream, be)` returns the bit string
            with the code of `n` added.
        merged_table: when true, values and lengths are stored together as
            tuples; otherwise two parallel tables are emitted.
    """
    # Rust source files must be UTF-8, so do not rely on the platform default.
    with open(
        os.path.join(ROOT, "{}_tables.rs".format(code_name)), "w", encoding="utf-8"
    ) as f:
        f.write(
            "#![doc(hidden)]\n// THIS FILE HAS BEEN GENERATED BY THE SCRIPT {}\n".format(
                os.path.basename(__file__)
            )
        )
        f.write("// ~~~~~~~~~~~~~~~~~~~ DO NOT MODIFY ~~~~~~~~~~~~~~~~~~~~~~\n")
        f.write(
            "// Methods for reading and writing values using precomputed tables for {} codes\n".format(
                code_name
            )
        )
        f.write("use crate::traits::{BE, BitRead, BitWrite, LE};\n")
        f.write("use num_primitive::PrimitiveNumber;\n")
        f.write("/// How many bits are needed to read the tables\n")
        f.write("pub const READ_BITS: usize = {};\n".format(read_bits))
        f.write("/// Maximum value writable using the table(s)\n")
        f.write("pub const WRITE_MAX: u64 = {};\n".format(write_max_val))

        if merged_table:
            read_func_template = read_func_merged_table
            write_func_template = write_func_merged_table
        else:
            read_func_template = read_func_two_table
            write_func_template = write_func_two_table

        for bo in ["le", "be"]:
            f.write(read_func_template % {"bo": bo, "BO": bo.upper()})
            f.write(write_func_template % {"bo": bo, "BO": bo.upper()})

        # Read tables: decode every possible `read_bits`-bit pattern.
        for BO in ["BE", "LE"]:
            codes = []
            for pattern in range(0, 2**read_bits):
                bits = ("{:0%sb}" % read_bits).format(pattern)
                try:
                    value, bits_left = read_func(bits, BO == "BE")
                    codes.append((value, read_bits - len(bits_left)))
                except ValueError:
                    # The pattern does not contain a complete code.
                    codes.append((None, None))

            read_max_val = max(x[0] or 0 for x in codes)
            read_max_len = max(x[1] or 0 for x in codes)
            len_ty = "u8"

            f.write(
                "/// The len we assign to a code that cannot be decoded through the table\n"
            )
            f.write(
                "pub const MISSING_VALUE_LEN_{}: {} = {};\n".format(
                    BO, len_ty, read_max_len + 1
                )
            )

            if merged_table:
                f.write(
                    "/// Precomputed table for reading {} codes\n".format(code_name)
                )
                f.write(
                    "pub const READ_%s: &[(%s, %s)] = &["
                    % (
                        BO,
                        get_best_fitting_type(log2(read_max_val + 1)),
                        get_best_fitting_type(log2(read_max_len + 2)),
                    )
                )
                for value, l in codes:
                    f.write("({}, {}), ".format(value or 0, l or (read_max_len + 1)))
                f.write("];\n")
            else:
                f.write(
                    "/// Precomputed table for reading {} codes\n".format(code_name)
                )
                f.write(
                    "pub const READ_%s: &[%s] = &["
                    % (
                        BO,
                        get_best_fitting_type(log2(read_max_val + 1)),
                    )
                )
                for value, _ in codes:
                    f.write("{}, ".format(value or 0))
                f.write("];\n")

                f.write(
                    "/// Precomputed lengths table for reading {} codes\n".format(
                        code_name
                    )
                )
                f.write(
                    "pub const READ_LEN_%s: &[%s] = &["
                    % (
                        BO,
                        get_best_fitting_type(log2(read_max_len + 2)),
                    )
                )
                for _, l in codes:
                    f.write("{}, ".format(l or (read_max_len + 1)))
                f.write("];\n")

        # Write tables: encode every value in 0..=write_max_val.
        for bo in ["BE", "LE"]:
            if merged_table:
                f.write(
                    "/// Precomputed lengths table for writing {} codes\n".format(
                        code_name
                    )
                )
                f.write(
                    "pub const WRITE_%s: &[(%s, u8)] = &["
                    % (bo, get_best_fitting_type(len_func(write_max_val)))
                )
                for n in range(write_max_val + 1):
                    bits = write_func(n, "", bo == "BE")
                    f.write("({}, {}),".format(int(bits, 2), len(bits)))
                f.write("];\n")
            else:
                f.write(
                    "/// Table used to speed up the writing of {} codes\n".format(
                        code_name
                    )
                )
                f.write(
                    "pub const WRITE_%s: &[%s] = &["
                    % (bo, get_best_fitting_type(len_func(write_max_val)))
                )
                len_bits = []
                for n in range(write_max_val + 1):
                    bits = write_func(n, "", bo == "BE")
                    len_bits.append(len(bits))
                    f.write("{},".format(int(bits, 2)))
                f.write("];\n")

                f.write(
                    "/// Table used to speed up the writing of {} codes\n".format(
                        code_name
                    )
                )
                f.write(
                    "pub const WRITE_LEN_%s: &[%s] = &["
                    % (bo, get_best_fitting_type(len_func(write_max_val)))
                )
                for l in len_bits:
                    f.write("{}, ".format(l))
                f.write("];\n")

        # Length table. Storing a length x needs ceil(log2(x + 1)) bits; the
        # "+ 1" matters when x is an exact power of two (256 does not fit u8).
        # NOTE(review): the element type is derived from len_max_val while the
        # table itself only covers 0..=write_max_val; confirm this is intended.
        f.write(
            "/// Table used to speed up the skipping of {} codes\n".format(code_name)
        )
        f.write(
            "pub const LEN: &[%s] = &["
            % (get_best_fitting_type(ceil(log2(len_func(len_max_val) + 1))))
        )
        for n in range(write_max_val + 1):
            f.write("{}, ".format(len_func(n)))
        f.write("];\n")

        f.write(
            "/// Asserts at compile time that `peek_bits` is large enough for these tables.\n"
        )
        f.write("pub const fn check_read_table(peek_bits: usize) {\n")
        f.write(
            ' assert!(peek_bits >= READ_BITS, "BitRead peek word too small for %s code read tables (%d bits required)");\n'
            % (code_name, read_bits)
        )
        f.write("}\n")
def read_fixed(n_bits, bitstream, be):
    """Read an `n_bits`-wide unsigned integer from a string of '0'/'1'.

    Big-endian streams are consumed from the left, little-endian streams
    from the right.

    Args:
        n_bits: width of the field in bits (may be zero).
        bitstream: string of '0' and '1' characters.
        be: true for big-endian order, false for little-endian.

    Returns:
        `(value, remaining_bitstream)`.

    Raises:
        ValueError: if fewer than `n_bits` bits are available.
    """
    if len(bitstream) < n_bits:
        raise ValueError()
    if n_bits == 0:
        # A zero-width field is the value 0 and consumes nothing. Without
        # this guard `bitstream[-0:]` would be the whole stream and
        # `int("", 2)` would raise.
        return 0, bitstream
    if be:
        return int(bitstream[:n_bits], 2), bitstream[n_bits:]
    return int(bitstream[-n_bits:], 2), bitstream[:-n_bits]
def write_fixed(value, n_bits, bitstream, be):
    """Add `value` as an `n_bits`-wide, zero-padded binary field to a bit string.

    Args:
        value: non-negative integer to write.
        n_bits: width of the field in bits (may be zero).
        bitstream: string of '0' and '1' characters written so far.
        be: true appends the field (big-endian), false prepends it
            (little-endian).

    Returns:
        The extended bit string.
    """
    if n_bits == 0:
        # A zero-width field adds no bits; formatting with width 0 would
        # still print one digit.
        return bitstream
    field = ("{:0%sb}" % n_bits).format(value)
    if be:
        return bitstream + field
    return field + bitstream
def read_unary(bitstream, be):
    """Decode a unary code: a run of zeros closed by a one.

    Big-endian streams are scanned from the left, little-endian streams
    from the right.

    Returns:
        `(number_of_zeros, remaining_bitstream)`.

    Raises:
        ValueError: if the stream consists of zeros only (or is empty).
    """
    scan = bitstream if be else reversed(bitstream)
    zeros = 0
    for bit in scan:
        if bit != "0":
            break
        zeros += 1
    else:
        # No terminator was found.
        raise ValueError()
    if be:
        return zeros, bitstream[zeros + 1 :]
    return zeros, bitstream[: -zeros - 1]
def write_unary(n, bitstream, be):
    """Add the unary code of `n` (`n` zeros and a terminating one).

    With `be` true the code is appended; otherwise it is prepended with
    the terminator first.
    """
    zeros = "0" * n
    if not be:
        return "1" + zeros + bitstream
    return bitstream + zeros + "1"
def len_unary(n):
    """Return the length in bits of the unary code of `n`."""
    # n zeros plus the terminating one.
    return 1 + n
# Import-time sanity checks for the unary code: known code words first, then
# a write/read round trip for both byte orders.
for _n, _be_code, _le_code in [
    (0, "1", "1"),
    (1, "01", "10"),
    (2, "001", "100"),
    (3, "0001", "1000"),
]:
    assert write_unary(_n, "", True) == _be_code
    assert write_unary(_n, "", False) == _le_code

for i in range(256):
    l = len_unary(i)
    for _be in (True, False):
        _code = write_unary(i, "", _be)
        assert i == read_unary(_code, _be)[0]
        assert len(_code) == l
def gen_unary(read_bits, write_max_val, len_max_val=None, merged_table=False):
    """Generate the tables for unary codes via `gen_table`.

    The write tables are capped at value 63; `len_max_val` defaults to the
    uncapped `write_max_val`.
    """
    if not len_max_val:
        len_max_val = write_max_val
    # presumably so that every tabulated code (n + 1 bits) fits in 64 bits
    capped_write_max = min(write_max_val, 63)
    return gen_table(
        read_bits=read_bits,
        write_max_val=capped_write_max,
        len_max_val=len_max_val,
        code_name="unary",
        len_func=len_unary,
        read_func=read_unary,
        write_func=write_unary,
        merged_table=merged_table,
    )
def read_gamma(bitstream, be):
    """Decode a gamma code: a unary width followed by that many low bits.

    Returns:
        `(value, remaining_bitstream)`.

    Raises:
        ValueError: propagated from the unary/fixed readers when the stream
            is too short.
    """
    width, rest = read_unary(bitstream, be)
    if width != 0:
        low_bits, rest = read_fixed(width, rest, be)
        return low_bits + (1 << width) - 1, rest
    return 0, rest
def write_gamma(n, bitstream, be):
    """Add the gamma code of the non-negative integer `n` to `bitstream`.

    `n + 1` is written as its bit width minus one in unary, followed by
    its bits below the most significant one.

    Raises:
        ValueError: if `n` is negative.
    """
    n += 1
    if n < 1:
        raise ValueError(n)
    # Exact floor(log2(n)); math.log2 rounds up for large n just below a
    # power of two.
    l = n.bit_length() - 1
    s = n - (1 << l)
    bitstream = write_unary(l, bitstream, be)
    if l != 0:
        bitstream = write_fixed(s, l, bitstream, be)
    return bitstream
def len_gamma(n):
    """Return the length in bits of the gamma code of the non-negative integer `n`.

    Raises:
        ValueError: if `n` is negative.
    """
    n += 1
    if n < 1:
        raise ValueError(n)
    # Exact floor(log2(n)); math.log2 rounds up for large n just below a
    # power of two.
    l = n.bit_length() - 1
    return 2 * l + 1
# Import-time sanity checks for the gamma code: known code words first, then
# a write/read round trip for both byte orders.
for _n, _be_code, _le_code in [
    (0, "1", "1"),
    (1, "010", "010"),
    (2, "011", "110"),
    (3, "00100", "00100"),
    (4, "00101", "01100"),
    (5, "00110", "10100"),
]:
    assert write_gamma(_n, "", True) == _be_code
    assert write_gamma(_n, "", False) == _le_code

for i in range(256):
    l = len_gamma(i)
    for _be in (True, False):
        _code = write_gamma(i, "", _be)
        assert i == read_gamma(_code, _be)[0]
        assert len(_code) == l
def gen_gamma(read_bits, write_max_val, len_max_val=None, merged_table=False):
    """Generate the tables for gamma codes via `gen_table`.

    `len_max_val` defaults to `write_max_val`; `read_bits` must be positive.
    """
    assert read_bits > 0
    if not len_max_val:
        len_max_val = write_max_val
    return gen_table(
        read_bits=read_bits,
        write_max_val=write_max_val,
        len_max_val=len_max_val,
        code_name="gamma",
        len_func=len_gamma,
        read_func=read_gamma,
        write_func=write_gamma,
        merged_table=merged_table,
    )
def read_delta(bitstream, be):
    """Decode a delta code: a gamma-coded width followed by that many low bits.

    Returns:
        `(value, remaining_bitstream)`.

    Raises:
        ValueError: propagated from the underlying readers when the stream
            is too short.
    """
    width, rest = read_gamma(bitstream, be)
    if width != 0:
        low_bits, rest = read_fixed(width, rest, be)
        return low_bits + (1 << width) - 1, rest
    return 0, rest
def read_delta_partial(bitstream, be):
    """Decode as much of a delta code as `bitstream` contains.

    Returns:
        `(value, total_bits)` when the whole code is present, or
        `(gamma_len, gamma_bits | 0x80)` when only the gamma prefix could
        be decoded; bit 0x80 of the second element flags the partial case.

    Raises:
        ValueError: if not even the gamma prefix can be decoded.
    """
    try:
        gamma_len, after_gamma = read_gamma(bitstream, be)
        gamma_bits = len(bitstream) - len(after_gamma)
        if gamma_len == 0:
            return 0, gamma_bits
        if len(after_gamma) < gamma_len:
            # Only the prefix fits: report its length with the partial flag.
            return gamma_len, gamma_bits | 0x80
        low_bits, after_fixed = read_fixed(gamma_len, after_gamma, be)
        return low_bits + (1 << gamma_len) - 1, len(bitstream) - len(after_fixed)
    except ValueError:
        raise ValueError()
def write_delta(n, bitstream, be):
    """Add the delta code of the non-negative integer `n` to `bitstream`.

    `n + 1` is written as its bit width minus one in gamma, followed by
    its bits below the most significant one.

    Raises:
        ValueError: if `n` is negative.
    """
    n += 1
    if n < 1:
        raise ValueError(n)
    # Exact floor(log2(n)); math.log2 rounds up for large n just below a
    # power of two.
    l = n.bit_length() - 1
    s = n - (1 << l)
    bitstream = write_gamma(l, bitstream, be)
    if l != 0:
        bitstream = write_fixed(s, l, bitstream, be)
    return bitstream
def len_delta(n):
    """Return the length in bits of the delta code of the non-negative integer `n`.

    Raises:
        ValueError: if `n` is negative.
    """
    n += 1
    if n < 1:
        raise ValueError(n)
    # Exact floor(log2(n)); math.log2 rounds up for large n just below a
    # power of two.
    l = n.bit_length() - 1
    return l + len_gamma(l)
# Import-time sanity checks for the delta code: known code words first, then
# a write/read round trip for both byte orders.
for _n, _be_code, _le_code in [
    (0, "1", "1"),
    (1, "0100", "0010"),
    (2, "0101", "1010"),
    (3, "01100", "00110"),
    (4, "01101", "01110"),
    (5, "01110", "10110"),
]:
    assert write_delta(_n, "", True) == _be_code
    assert write_delta(_n, "", False) == _le_code

for i in range(256):
    l = len_delta(i)
    for _be in (True, False):
        _code = write_delta(i, "", _be)
        assert i == read_delta(_code, _be)[0]
        assert len(_code) == l
def gen_delta(read_bits, write_max_val, len_max_val=None, merged_table=False):
    """Generate `delta_tables.rs` under `ROOT`.

    Unlike `gen_table`, the read tables also describe partially decoded
    codes: `READ_LEN_*` holds a signed length whose sign bit flags that only
    the gamma prefix fit in `read_bits` bits (see `read_delta_partial`).

    Args:
        read_bits: number of bits used to index the read tables (positive).
        write_max_val: largest value present in the write and `LEN` tables.
        len_max_val: value whose code length determines the element type of
            the `LEN` table; defaults to `write_max_val`.
        merged_table: accepted for signature parity with the other
            generators; this function always emits separate tables.
    """
    assert read_bits > 0
    len_max_val = len_max_val or write_max_val
    code_name = "delta"
    # Rust source files must be UTF-8, so do not rely on the platform default.
    with open(
        os.path.join(ROOT, "{}_tables.rs".format(code_name)), "w", encoding="utf-8"
    ) as f:
        f.write(
            "#![doc(hidden)]\n// THIS FILE HAS BEEN GENERATED BY THE SCRIPT {}\n".format(
                os.path.basename(__file__)
            )
        )
        f.write("// ~~~~~~~~~~~~~~~~~~~ DO NOT MODIFY ~~~~~~~~~~~~~~~~~~~~~~\n")
        f.write(
            "// Methods for reading and writing values using precomputed tables for {} codes\n".format(
                code_name
            )
        )
        f.write("use crate::traits::{BitRead, BitWrite, BE, LE};\n")
        f.write("use num_primitive::PrimitiveNumber;\n")
        f.write("/// How many bits are needed to read the tables\n")
        f.write("pub const READ_BITS: usize = {};\n".format(read_bits))
        f.write("/// Maximum value writable using the table(s)\n")
        f.write("pub const WRITE_MAX: u64 = {};\n".format(write_max_val))

        # Table-based readers, one per byte order.
        for bo in ["le", "be"]:
            BO = bo.upper()
            f.write(
                """
/// Reads from the decoding table.
///
/// Returns `(len_with_flag, value_or_gamma)` where:
/// - If len_with_flag >= 0: complete code, value_or_gamma is decoded value, len_with_flag is code length
/// - If len_with_flag < 0: partial code (gamma decoded), value_or_gamma is gamma_len, (len_with_flag & 0x7F) is gamma code length
/// - If len_with_flag = 0: no valid decoding (gamma not decoded)
///
/// The backend position is always advanced by (len_with_flag & 0x7F) bits.
/// Using signed i8 allows testing with `< 0` instead of masking, which is more efficient.
#[inline(always)]
pub fn read_table_%(bo)s<B: BitRead<%(BO)s>>(backend: &mut B) -> (i8, u64) {
if let Ok(idx) = backend.peek_bits(READ_BITS) {
let idx: usize = idx.as_to();
let len_with_flag = READ_LEN_%(BO)s[idx];
let value_or_gamma = READ_%(BO)s[idx] as u64;
backend.skip_bits_after_peek((len_with_flag & 0x7F) as usize);
(len_with_flag, value_or_gamma)
} else {
// Not enough bits available
(0, 0)
}
}
"""
                % {"bo": bo, "BO": BO}
            )

        # Table-based writers.
        f.write(
            """
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
pub fn write_table_le<B: BitWrite<LE>>(backend: &mut B, n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n >= WRITE_LE.len() as u64 {
return Ok(None);
}
let n = n as usize;
let len = WRITE_LEN_LE[n] as usize;
backend.write_bits(WRITE_LE[n] as u64, len)?;
Ok(Some(len))
}
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
pub fn write_table_be<B: BitWrite<BE>>(backend: &mut B, n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n >= WRITE_BE.len() as u64 {
return Ok(None);
}
let n = n as usize;
let len = WRITE_LEN_BE[n] as usize;
backend.write_bits(WRITE_BE[n] as u64, len)?;
Ok(Some(len))
}
"""
        )

        # Read tables: decode (possibly partially) every `read_bits`-bit pattern.
        for BO in ["BE", "LE"]:
            codes = []
            for pattern in range(0, 2**read_bits):
                bits = ("{:0%sb}" % read_bits).format(pattern)
                try:
                    value_or_gamma, len_with_flag = read_delta_partial(
                        bits, BO == "BE"
                    )
                    codes.append((value_or_gamma, len_with_flag))
                except ValueError:
                    # Not even the gamma prefix could be decoded.
                    codes.append((0, 0))

            read_max_val = max(x[0] for x in codes)
            read_max_len = max(x[1] & 0x7F for x in codes)

            f.write("/// Precomputed table for reading {} codes\n".format(code_name))
            f.write("/// For complete codes: stores the decoded value\n")
            f.write(
                "/// For partial codes: stores the gamma_len (length of the following fixed part)\n"
            )
            f.write(
                "pub const READ_%s: &[%s] = &["
                % (BO, get_best_fitting_type(log2(read_max_val + 1)))
            )
            for value_or_gamma, _ in codes:
                f.write("{}, ".format(value_or_gamma))
            f.write("];\n")

            f.write(
                "/// Precomputed lengths table for reading {} codes\n".format(code_name)
            )
            f.write("/// Positive (< 0x80): complete code length\n")
            f.write(
                "/// Negative (>= 0x80 when viewed as u8): (value & 0x7F) is gamma code length (bits consumed)\n"
            )
            f.write("/// Zero: no valid decoding (gamma not decoded)\n")
            f.write(
                "pub const READ_LEN_%s: &[%s] = &["
                % (
                    BO,
                    get_best_fitting_type(log2(read_max_len + 1), signed=True),
                )
            )
            for _, len_with_flag in codes:
                # Flagged lengths (bit 0x80 set) are emitted as negative i8.
                f.write(
                    "{}, ".format(
                        len_with_flag if len_with_flag < 128 else len_with_flag - 256
                    )
                )
            f.write("];\n")

        # Write tables: encode every value in 0..=write_max_val.
        for bo in ["BE", "LE"]:
            f.write(
                "/// Table used to speed up the writing of {} codes\n".format(code_name)
            )
            f.write(
                "pub const WRITE_%s: &[%s] = &["
                % (bo, get_best_fitting_type(len_delta(write_max_val)))
            )
            len_bits = []
            for n in range(write_max_val + 1):
                bits = write_delta(n, "", bo == "BE")
                len_bits.append(len(bits))
                f.write("{},".format(int(bits, 2)))
            f.write("];\n")

            f.write(
                "/// Table used to speed up the writing of {} codes\n".format(code_name)
            )
            f.write(
                "pub const WRITE_LEN_%s: &[%s] = &["
                % (bo, get_best_fitting_type(len_delta(write_max_val)))
            )
            for l in len_bits:
                f.write("{}, ".format(l))
            f.write("];\n")

        # Length table. Storing a length x needs ceil(log2(x + 1)) bits; the
        # "+ 1" matters when x is an exact power of two (256 does not fit u8).
        f.write(
            "/// Table used to speed up the skipping of {} codes\n".format(code_name)
        )
        f.write(
            "pub const LEN: &[%s] = &["
            % (get_best_fitting_type(ceil(log2(len_delta(len_max_val) + 1))))
        )
        for n in range(write_max_val + 1):
            f.write("{}, ".format(len_delta(n)))
        f.write("];\n")

        f.write(
            "/// Asserts at compile time that `peek_bits` is large enough for these tables.\n"
        )
        f.write("pub const fn check_read_table(peek_bits: usize) {\n")
        f.write(
            ' assert!(peek_bits >= READ_BITS, "BitRead peek word too small for %s code read tables (%d bits required)");\n'
            % (code_name, read_bits)
        )
        f.write("}\n")
def read_minimal_binary(max, bitstream, be):
    """Decode a minimal binary code for values in `[0, max)`.

    With `l = floor(log2(max))`, the first `2**(l + 1) - max` values use
    `l` bits and the remaining ones `l + 1` bits.

    Returns:
        `(value, remaining_bitstream)`.

    Raises:
        ValueError: if `max` is not positive or the stream is too short.
    """
    if max < 1:
        raise ValueError(max)
    # Exact floor(log2(max)); math.log2 rounds up for large values just
    # below a power of two.
    l = max.bit_length() - 1
    v, bitstream = read_fixed(l, bitstream, be)
    limit = (1 << (l + 1)) - max
    if v < limit:
        return v, bitstream
    # Long code word: one more bit follows.
    b, bitstream = read_fixed(1, bitstream, be)
    v = (v << 1) | b
    return v - limit, bitstream
def write_minimal_binary(n, max, bitstream, be):
    """Add the minimal binary code of `n` in `[0, max)` to `bitstream`.

    With `l = floor(log2(max))`, the first `2**(l + 1) - max` values use
    `l` bits and the remaining ones `l + 1` bits.

    Raises:
        ValueError: if `max` is not positive.
    """
    if max < 1:
        raise ValueError(max)
    # Exact floor(log2(max)); math.log2 rounds up for large values just
    # below a power of two.
    l = max.bit_length() - 1
    limit = (1 << (l + 1)) - max
    if n < limit:
        return write_fixed(n, l, bitstream, be)
    # Long code word: the top l bits first, then the lowest bit.
    to_write = n + limit
    bitstream = write_fixed(to_write >> 1, l, bitstream, be)
    return write_fixed(to_write & 1, 1, bitstream, be)
def len_minimal_binary(n, max):
    """Return the length in bits of the minimal binary code of `n` in `[0, max)`.

    Raises:
        ValueError: if `max` is not positive.
    """
    if max < 1:
        raise ValueError(max)
    # Exact floor(log2(max)); math.log2 rounds up for large values just
    # below a power of two.
    l = max.bit_length() - 1
    limit = (1 << (l + 1)) - max
    return l + 1 if n >= limit else l
# Import-time sanity checks for minimal binary codes with max = 10, then a
# write/read round trip with max = 200 for both byte orders.
for _n, _be_code, _le_code in [
    (0, "000", "000"),
    (1, "001", "001"),
    (2, "010", "010"),
    (3, "011", "011"),
    (4, "100", "100"),
    (5, "101", "101"),
    (6, "1100", "0110"),
    (7, "1101", "1110"),
    (8, "1110", "0111"),
    (9, "1111", "1111"),
]:
    assert write_minimal_binary(_n, 10, "", True) == _be_code
    assert write_minimal_binary(_n, 10, "", False) == _le_code

_max = 200
for i in range(_max):
    l = len_minimal_binary(i, _max)
    for _be in (True, False):
        _code = write_minimal_binary(i, _max, "", _be)
        assert i == read_minimal_binary(_max, _code, _be)[0]
        assert len(_code) == l
def read_zeta(bitstream, k, be):
    """Decode a zeta-`k` code.

    A unary `h` selects the interval `[2**(h*k), 2**((h+1)*k))`; the offset
    inside it follows in minimal binary.

    Returns:
        `(value, remaining_bitstream)`.
    """
    h, rest = read_unary(bitstream, be)
    lower = 2 ** (h * k)
    upper = 2 ** ((h + 1) * k)
    offset, rest = read_minimal_binary(upper - lower, rest, be)
    return lower + offset - 1, rest
def write_zeta(n, k, bitstream, be):
    """Add the zeta-`k` code of the non-negative integer `n` to `bitstream`.

    `n + 1` falls in `[2**(h*k), 2**((h+1)*k))`; `h` is written in unary
    and the offset inside the interval in minimal binary.

    Raises:
        ValueError: if `n` is negative.
    """
    n += 1
    if n < 1:
        raise ValueError(n)
    # Exact floor(floor(log2(n)) / k) using integer arithmetic only.
    h = (n.bit_length() - 1) // k
    u = 2 ** ((h + 1) * k)
    l = 2 ** (h * k)
    bitstream = write_unary(h, bitstream, be)
    bitstream = write_minimal_binary(n - l, u - l, bitstream, be)
    return bitstream
def len_zeta(n, k):
    """Return the length in bits of the zeta-`k` code of the non-negative integer `n`.

    Raises:
        ValueError: if `n` is negative.
    """
    n += 1
    if n < 1:
        raise ValueError(n)
    # Exact floor(floor(log2(n)) / k) using integer arithmetic only.
    h = (n.bit_length() - 1) // k
    u = 2 ** ((h + 1) * k)
    l = 2 ** (h * k)
    return len_unary(h) + len_minimal_binary(n - l, u - l)
# Import-time sanity checks for zeta-3 codes: known big-endian code words,
# known little-endian code words, then a write/read round trip.
for _n, _code in enumerate(
    ["100", "1010", "1011", "1100", "1101", "1110", "1111", "0100000", "0100001"]
):
    assert write_zeta(_n, 3, "", True) == _code

for _n, _code in enumerate(
    ["001", "0011", "1011", "0101", "1101", "0111", "1111", "0000010", "0000110"]
):
    assert write_zeta(_n, 3, "", False) == _code

for i in range(256):
    l = len_zeta(i, 3)
    for _be in (True, False):
        _code = write_zeta(i, 3, "", _be)
        _decoded = read_zeta(_code, 3, _be)[0]
        assert i == _decoded, "%s %s %s" % (i, _decoded, _code)
        assert len(_code) == l
def gen_zeta(read_bits, write_max_val, len_max_val=None, k=3, merged_table=False):
    """Generate the tables for zeta-`k` codes and record `k` in the file.

    The tables are produced by `gen_table`; a `K` constant is then appended
    to `zeta_tables.rs`. `len_max_val` defaults to `write_max_val`.
    """
    assert read_bits > 0
    if not len_max_val:
        len_max_val = write_max_val

    # Bind k so that the helpers match the signatures gen_table expects.
    def zeta_len(n):
        return len_zeta(n, k)

    def zeta_read(bitstream, be):
        return read_zeta(bitstream, k, be)

    def zeta_write(n, bitstream, be):
        return write_zeta(n, k, bitstream, be)

    gen_table(
        read_bits,
        write_max_val,
        len_max_val,
        "zeta",
        zeta_len,
        zeta_read,
        zeta_write,
        merged_table,
    )
    with open(os.path.join(ROOT, "zeta_tables.rs"), "a") as f:
        f.writelines(
            [
                "/// The K of the zeta codes for these tables\n",
                "pub const K: usize = {};".format(k),
            ]
        )
def read_omega_partial(bitstream, be):
    """Decode as much of an omega code as `bitstream` contains.

    Blocks are consumed one at a time; each block of `n + 1` bits yields the
    next `n`, and a zero bit terminates the code. Big-endian streams are read
    from the left, little-endian streams from the right (where each block is
    stored rotated so that its lowest bit is a one).

    Returns:
        `(value, code_length)` for a complete code; otherwise
        `(n, bits_consumed | 0x80)`, or `(n, 0)` when no block was consumed.
    """
    n = 1
    bits_consumed = 0

    def out_of_bits():
        # Partial result: current state plus the flagged consumed length.
        return n, (bits_consumed | 0x80) if bits_consumed > 0 else 0

    while True:
        if not bitstream:
            return out_of_bits()
        next_bit = bitstream[0] if be else bitstream[-1]
        if next_bit == "0":
            # Terminator reached.
            return n - 1, bits_consumed + 1
        block_len = n + 1
        if len(bitstream) < block_len:
            return out_of_bits()
        if be:
            block, bitstream = bitstream[:block_len], bitstream[block_len:]
            n = int(block, 2)
        else:
            block, bitstream = bitstream[-block_len:], bitstream[:-block_len]
            # Undo the rotation: drop the low marker bit, restore the top bit.
            n = (int(block, 2) >> 1) | (1 << (block_len - 1))
        bits_consumed += block_len
def read_omega(bitstream, be):
    """Decode a complete omega code.

    Returns:
        `(value, remaining_bitstream)`.

    Raises:
        ValueError: if `bitstream` does not contain a complete code.
    """
    value, flagged_len = read_omega_partial(bitstream, be)
    complete = flagged_len != 0 and not (flagged_len & 0x80)
    if not complete:
        raise ValueError()
    if be:
        remaining = bitstream[flagged_len:]
    else:
        remaining = bitstream[:-flagged_len]
    return value, remaining
def _recursive_write_omega(n, bitstream, be):
    """Add the omega blocks for `n` (without the final terminator bit)."""
    if n <= 1:
        return bitstream
    # Exact floor(log2(n)); math.log2 rounds up for large n just below a
    # power of two.
    l = n.bit_length() - 1
    # The block describing the width comes before the block for n itself.
    bitstream = _recursive_write_omega(l, bitstream, be)
    if be:
        return bitstream + ("{:0%sb}" % (l + 1)).format(n)
    # Little-endian: rotate the l + 1 bits left so the leading one becomes
    # the lowest bit of the block.
    n = (n << 1) | 1
    mask = (1 << (l + 1)) - 1
    n &= mask
    return ("{:0%sb}" % (l + 1)).format(n) + bitstream


def write_omega(n, bitstream, be):
    """Add the omega code of `n` to `bitstream`.

    The blocks for `n + 1` are followed by a terminating zero bit; the code
    is appended when `be` is true and prepended otherwise.
    """
    n += 1
    bitstream = _recursive_write_omega(n, bitstream, be)
    if be:
        return bitstream + "0"
    return "0" + bitstream
def _recursive_len_omega(n):
    """Return the omega code length for the already incremented value `n`."""
    if n <= 1:
        # Only the terminator bit remains.
        return 1
    # Exact floor(log2(n)); math.log2 rounds up for large n just below a
    # power of two.
    l = n.bit_length() - 1
    return _recursive_len_omega(l) + l + 1


def len_omega(n):
    """Return the length in bits of the omega code of `n`."""
    return _recursive_len_omega(n + 1)
# Import-time sanity checks for omega codes: known big-endian code words,
# known little-endian code words, then a write/read round trip.
for _n, _code in [
    (0, "0"),
    (1, "100"),
    (2, "110"),
    (3, "101000"),
    (4, "101010"),
    (5, "101100"),
    (6, "101110"),
    (7, "1110000"),
    (15, "10100100000"),
    (99, "1011011001000"),
    (999, "11100111111010000"),
    (999999, "1010010011111101000010010000000"),
]:
    assert write_omega(_n, "", True) == _code

for _n, _code in [
    (0, "0"),
    (1, "001"),
    (2, "011"),
    (3, "000101"),
    (4, "001101"),
    (5, "010101"),
    (6, "011101"),
    (7, "0000111"),
    (15, "00000100101"),
    (99, "0100100110101"),
    (999, "01111010001001111"),
    (999999, "0111010000100100000010011100101"),
]:
    assert write_omega(_n, "", False) == _code

for i in range(256):
    l = len_omega(i)
    for _be in (True, False):
        _code = write_omega(i, "", _be)
        _decoded, _rest = read_omega(_code, _be)
        assert _rest == ""
        assert i == _decoded
        assert len(_code) == l
def gen_omega(read_bits, write_max_val, len_max_val=None, merged_table=False):
    """Generate `omega_tables.rs` under `ROOT`.

    The read tables also describe partially decoded codes: `READ_LEN_*`
    holds a signed length whose sign bit flags that only some complete
    blocks fit in `read_bits` bits (see `read_omega_partial`). The generated
    writers fall back to composing the code from the table when the value is
    beyond `WRITE_MAX`.

    Args:
        read_bits: number of bits used to index the read tables (positive).
        write_max_val: largest value present in the write and `LEN` tables.
        len_max_val: value whose code length determines the element type of
            the `LEN` table; defaults to `write_max_val`.
        merged_table: accepted for signature parity with the other
            generators; this function always emits separate tables.
    """
    assert read_bits > 0
    len_max_val = len_max_val or write_max_val
    code_name = "omega"
    # The generated source contains non-ASCII characters, and Rust source
    # files must be UTF-8: do not rely on the platform default encoding.
    with open(
        os.path.join(ROOT, "{}_tables.rs".format(code_name)), "w", encoding="utf-8"
    ) as f:
        f.write(
            "#![doc(hidden)]\n// THIS FILE HAS BEEN GENERATED BY THE SCRIPT {}\n".format(
                os.path.basename(__file__)
            )
        )
        f.write("// ~~~~~~~~~~~~~~~~~~~ DO NOT MODIFY ~~~~~~~~~~~~~~~~~~~~~~\n")
        f.write(
            "// Methods for reading and writing values using precomputed tables for {} codes\n".format(
                code_name
            )
        )
        f.write("use crate::traits::{BitRead, BitWrite, BE, LE};\n")
        f.write("use num_primitive::PrimitiveNumber;\n")
        f.write("/// How many bits are needed to read the tables\n")
        f.write("pub const READ_BITS: usize = {};\n".format(read_bits))
        f.write(
            'const _: () = assert!(READ_BITS >= 2, "Read tables for Elias ω must use at least 2 bits");\n'
        )
        # Keep the doc comment directly above WRITE_MAX so that it documents
        # the constant rather than the anonymous assertion.
        f.write("/// Maximum value writable using the table(s)\n")
        f.write("pub const WRITE_MAX: u64 = {};\n".format(write_max_val))
        f.write(
            'const _: () = assert!(WRITE_MAX >= 62, "Write tables for Elias ω must represent 62");\n'
        )

        # Table-based readers, one per byte order.
        for bo in ["le", "be"]:
            BO = bo.upper()
            f.write(
                """
/// Reads from the decoding table.
///
/// Returns `(len_with_flag, value)` where:
/// - If len_with_flag >= 0: complete code, value is decoded value, len_with_flag is code length
/// - If len_with_flag < 0: partial code, value is partial_n, (len_with_flag & 0x7F) is partial_len
/// - If len_with_flag = 0: no valid decoding (cannot occur with >= 2 bit tables)
///
/// The backend position is always advanced by (len_with_flag & 0x7F) bits.
/// Using signed i8 allows testing with `< 0` instead of masking, which is more efficient.
#[inline(always)]
pub fn read_table_%(bo)s<B: BitRead<%(BO)s>>(backend: &mut B) -> (i8, u64) {
if let Ok(idx) = backend.peek_bits(READ_BITS) {
let idx: usize = idx.as_to();
let len_with_flag = READ_LEN_%(BO)s[idx];
let value = READ_%(BO)s[idx] as u64;
backend.skip_bits_after_peek((len_with_flag & 0x7F) as usize);
(len_with_flag, value)
} else {
// Not enough bits available - return initial state
(0, 1)
}
}
"""
                % {"bo": bo, "BO": BO}
            )

        # Table-based writers with a fallback for values beyond the table.
        f.write(
            """
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
pub fn write_table_le<B: BitWrite<LE>>(backend: &mut B, mut n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n < WRITE_LE.len() as u64 {
let n = n as usize;
let len = WRITE_LEN_LE[n] as usize;
backend.write_bits(WRITE_LE[n] as u64, len)?;
return Ok(Some(len));
}
n += 1;
let λ = n.ilog2() as usize;
let bits = WRITE_LE[λ - 1];
let len = WRITE_LEN_LE[λ - 1] as usize;
backend.write_bits(bits as u64, len - 1)?;
#[cfg(feature = "checks")]
{
// Clean up after the lowest λ bits in case checks are enabled
n &= u64::MAX >> (u64::BITS - (λ as u32));
}
// Little-endian case: rotate left the lower λ + 1 bits (the bit in
// position λ is a one) so that the lowest bit can be peeked to find the
// block.
backend.write_bits(n << 1 | 1, λ + 1)?;
backend.write_bits(0, 1)?;
Ok(Some(λ + len + 1))
}
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
pub fn write_table_be<B: BitWrite<BE>>(backend: &mut B, mut n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n < WRITE_BE.len() as u64 {
let n = n as usize;
let len = WRITE_LEN_BE[n] as usize;
backend.write_bits(WRITE_BE[n] as u64, len)?;
return Ok(Some(len));
}
n += 1;
let λ = n.ilog2() as usize;
let bits = WRITE_BE[λ - 1];
let len = WRITE_LEN_BE[λ - 1] as usize;
backend.write_bits(bits as u64 >> 1, len - 1)?;
backend.write_bits(n, λ + 1)?;
backend.write_bits(0, 1)?;
Ok(Some(λ + len + 1))
}
"""
        )

        # Read tables: decode (possibly partially) every `read_bits`-bit pattern.
        for BO in ["BE", "LE"]:
            codes = []
            for pattern in range(0, 2**read_bits):
                bits = ("{:0%sb}" % read_bits).format(pattern)
                value_or_n, len_with_flag = read_omega_partial(bits, BO == "BE")
                codes.append((value_or_n, len_with_flag))

            read_max_val = max(x[0] for x in codes)
            read_max_len = max(x[1] & 0x7F for x in codes)

            f.write("/// Precomputed table for reading {} codes\n".format(code_name))
            f.write("/// For complete codes: stores the decoded value\n")
            f.write("/// For partial codes: stores the partial n state\n")
            f.write(
                "pub const READ_%s: &[%s] = &["
                % (BO, get_best_fitting_type(log2(read_max_val + 1)))
            )
            for value_or_n, _ in codes:
                f.write("{}, ".format(value_or_n))
            f.write("];\n")

            f.write(
                "/// Precomputed lengths table for reading {} codes\n".format(code_name)
            )
            f.write("/// Positive (< 0x80): complete code length\n")
            f.write(
                "/// Negative (>= 0x80 when viewed as u8): (value & 0x7F) is partial_len (bits consumed by complete blocks)\n"
            )
            f.write("/// Zero: no valid decoding (cannot occur with >= 2 bit tables)\n")
            f.write(
                "pub const READ_LEN_%s: &[%s] = &["
                % (
                    BO,
                    get_best_fitting_type(log2(read_max_len + 1), signed=True),
                )
            )
            for _, len_with_flag in codes:
                # Flagged lengths (bit 0x80 set) are emitted as negative i8.
                f.write(
                    "{}, ".format(
                        len_with_flag if len_with_flag < 128 else len_with_flag - 256
                    )
                )
            f.write("];\n")

        # Write tables: encode every value in 0..=write_max_val.
        for bo in ["BE", "LE"]:
            f.write(
                "/// Table used to speed up the writing of {} codes\n".format(code_name)
            )
            f.write(
                "pub const WRITE_%s: &[%s] = &["
                % (bo, get_best_fitting_type(len_omega(write_max_val)))
            )
            len_bits = []
            for n in range(write_max_val + 1):
                bits = write_omega(n, "", bo == "BE")
                len_bits.append(len(bits))
                f.write("{},".format(int(bits, 2)))
            f.write("];\n")

            f.write(
                "/// Table used to speed up the writing of {} codes\n".format(code_name)
            )
            f.write(
                "pub const WRITE_LEN_%s: &[%s] = &["
                % (bo, get_best_fitting_type(len_omega(write_max_val)))
            )
            for l in len_bits:
                f.write("{}, ".format(l))
            f.write("];\n")

        # Length table. Storing a length x needs ceil(log2(x + 1)) bits; the
        # "+ 1" matters when x is an exact power of two (256 does not fit u8).
        f.write(
            "/// Table used to speed up the skipping of {} codes\n".format(code_name)
        )
        f.write(
            "pub const LEN: &[%s] = &["
            % (get_best_fitting_type(ceil(log2(len_omega(len_max_val) + 1))))
        )
        for n in range(write_max_val + 1):
            f.write("{}, ".format(len_omega(n)))
        f.write("];\n")

        f.write(
            "/// Asserts at compile time that `peek_bits` is large enough for these tables.\n"
        )
        f.write("pub const fn check_read_table(peek_bits: usize) {\n")
        f.write(
            ' assert!(peek_bits >= READ_BITS, "BitRead peek word too small for %s code read tables (%d bits required)");\n'
            % (code_name, read_bits)
        )
        f.write("}\n")
def read_rice(bitstream, k, be):
    """Decode a Rice code with parameter `k`.

    The quotient is read in unary and, when `k` is non-zero, the remainder
    as a `k`-bit field.

    Returns:
        `(value, remaining_bitstream)`.
    """
    quotient, rest = read_unary(bitstream, be)
    if k != 0:
        remainder, rest = read_fixed(k, rest, be)
        return (quotient << k) | remainder, rest
    return quotient, rest
def write_rice(n, k, bitstream, be):
    """Add the Rice code of `n` with parameter `k` to `bitstream`.

    `n >> k` is written in unary, followed by the low `k` bits of `n` when
    `k` is non-zero.
    """
    quotient, remainder = n >> k, n & ((1 << k) - 1)
    bitstream = write_unary(quotient, bitstream, be)
    if k == 0:
        return bitstream
    return write_fixed(remainder, k, bitstream, be)
def len_rice(n, k):
    """Return the length in bits of the Rice code of `n` with parameter `k`."""
    # Unary quotient (quotient + 1 bits) followed by the k-bit remainder.
    unary_bits = (n >> k) + 1
    return unary_bits + k
# Import-time sanity checks for Rice codes: known code words for k = 2, then
# a write/read round trip for k in 0..3 and both byte orders.
for _n, _be_code, _le_code in [
    (0, "100", "001"),
    (1, "101", "011"),
    (2, "110", "101"),
    (3, "111", "111"),
    (4, "0100", "0010"),
]:
    assert write_rice(_n, 2, "", True) == _be_code
    assert write_rice(_n, 2, "", False) == _le_code

for k in range(4):
    for i in range(256):
        l = len_rice(i, k)
        for _be in (True, False):
            _code = write_rice(i, k, "", _be)
            assert i == read_rice(_code, k, _be)[0]
            assert len(_code) == l
def read_pi(bitstream, k, be):
    """Decode one pi code with parameter ``k`` from the front of ``bitstream``.

    A Rice-coded ``lam`` is read first. If it is zero the decoded value is 0;
    otherwise ``lam`` further bits are read with ``read_fixed`` and the value
    is ``(1 << lam) + bits - 1``.

    Args:
        bitstream: the bits to decode.
        k: Rice parameter used for ``lam``.
        be: bit-order flag forwarded unchanged to the helpers.

    Returns:
        A ``(value, remaining_bitstream)`` pair.
    """
    lam, rest = read_rice(bitstream, k, be)
    if lam != 0:
        offset, rest = read_fixed(lam, rest, be)
        return (1 << lam) + offset - 1, rest
    return 0, rest
def read_pi_partial(bitstream, k, be):
    """Decode as much of a pi code as ``bitstream`` allows.

    Used to build the read tables, where only a fixed window of bits is
    available and a code may extend past it.

    Args:
        bitstream: the available bits.
        k: Rice parameter used for the ``lam`` prefix.
        be: bit-order flag forwarded unchanged to the helpers.

    Returns:
        A ``(value_or_lam, len_with_flag)`` pair:

        * complete code: the decoded value and the total number of bits read;
        * only the Rice prefix fits: ``lam`` and the number of bits read for
          the prefix, OR-ed with ``0x80`` as a "partial" marker.

    Raises:
        ValueError: whatever ``read_rice`` / ``read_fixed`` raise propagates
            unchanged, so the original message stays available to the caller.
    """
    lam, after_rice = read_rice(bitstream, k, be)
    rice_len = len(bitstream) - len(after_rice)
    if lam == 0:
        # No fixed part follows: the code is complete and encodes 0.
        return 0, rice_len
    if len(after_rice) < lam:
        # The fixed part does not fit in what is left: report lam and flag
        # the length as partial.
        return lam, rice_len | 0x80
    offset, after_fixed = read_fixed(lam, after_rice, be)
    return (1 << lam) + offset - 1, len(bitstream) - len(after_fixed)
def write_pi(n, k, bitstream, be):
    """Append the pi code of ``n`` with parameter ``k`` to ``bitstream``.

    With ``m = n + 1`` and ``lam = floor(log2(m))``, the code is the Rice code
    of ``lam`` followed, when ``lam`` is non-zero, by ``m - 2**lam`` written
    in ``lam`` bits.

    Args:
        n: non-negative integer to encode.
        k: Rice parameter used for ``lam``.
        bitstream: bits written so far.
        be: bit-order flag forwarded unchanged to the helpers.

    Returns:
        The bitstream returned by the last helper call.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("pi codes are defined for non-negative integers, got %r" % (n,))
    n += 1
    # bit_length() is exact for arbitrarily large integers, whereas
    # floor(log2(n)) goes through a float and is off by one for large n
    # (e.g. float(2**60 - 1) == 2.0**60).
    lam = n.bit_length() - 1
    offset = n - (1 << lam)
    bitstream = write_rice(lam, k, bitstream, be)
    if lam != 0:
        bitstream = write_fixed(offset, lam, bitstream, be)
    return bitstream
def len_pi(n, k):
    """Return the length in bits of the pi code of ``n`` with parameter ``k``.

    With ``lam = floor(log2(n + 1))`` the length is the Rice code length of
    ``lam`` plus ``lam`` bits.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("pi codes are defined for non-negative integers, got %r" % (n,))
    # Exact integer floor(log2(n + 1)); the float-based floor(log2(...)) is
    # off by one for large n (e.g. float(2**60 - 1) == 2.0**60).
    lam = (n + 1).bit_length() - 1
    return len_rice(lam, k) + lam
# Import-time sanity checks for the pi helpers.
# First, fixed expected encodings for k = 2 in both bit orders
# (fourth argument True / False).
assert write_pi(0, 2, "", True) == "100"
assert write_pi(0, 2, "", False) == "001"
assert write_pi(1, 2, "", True) == "1010"
assert write_pi(1, 2, "", False) == "0011"
assert write_pi(2, 2, "", True) == "1011"
assert write_pi(2, 2, "", False) == "1011"
assert write_pi(3, 2, "", True) == "11000"
assert write_pi(3, 2, "", False) == "00101"
assert write_pi(4, 2, "", True) == "11001"
assert write_pi(4, 2, "", False) == "01101"
assert write_pi(5, 2, "", True) == "11010"
assert write_pi(5, 2, "", False) == "10101"
assert write_pi(6, 2, "", True) == "11011"
assert write_pi(6, 2, "", False) == "11101"
assert write_pi(7, 2, "", True) == "111000"
assert write_pi(7, 2, "", False) == "000111"

# Then a round-trip check: every value below 256, for k in 0..3, must decode
# back to itself in both bit orders and have the length predicted by len_pi.
for k in range(4):
    for i in range(256):
        wbe = write_pi(i, k, "", True)
        rbe = read_pi(wbe, k, True)[0]
        wle = write_pi(i, k, "", False)
        rle = read_pi(wle, k, False)[0]
        expected_len = len_pi(i, k)
        assert i == rbe, "%s %s %s" % (i, rbe, wbe)
        assert i == rle, "%s %s %s" % (i, rle, wle)
        assert len(wbe) == expected_len
        assert len(wle) == expected_len
def gen_pi(read_bits, write_max_val, len_max_val=None, k=2, merged_table=False):
    """Generate ``pi_tables.rs`` (under ``ROOT``) with the tables for pi codes.

    The generated Rust file contains, in this order: the ``READ_BITS``,
    ``WRITE_MAX`` and ``K`` constants, the ``read_table_{le,be}`` and
    ``write_table_{le,be}`` functions, the ``READ_*`` / ``READ_LEN_*`` decoding
    tables, the ``WRITE_*`` / ``WRITE_LEN_*`` encoding tables, the ``LEN``
    table and ``check_read_table``.

    Args:
        read_bits: width in bits of the decoding-table index; the decoding
            tables have ``2**read_bits`` entries. Must be positive.
        write_max_val: largest value covered by the encoding tables.
        len_max_val: largest value covered by the ``LEN`` table; a falsy value
            (the default ``None``, or 0) means ``write_max_val``.
        k: Rice parameter of the pi code.
        merged_table: accepted for signature parity with the other generators;
            not used by this function, which always emits separate value and
            length tables.
    """
    assert read_bits > 0
    len_max_val = len_max_val or write_max_val
    code_name = "pi"
    with open(os.path.join(ROOT, "{}_tables.rs".format(code_name)), "w") as f:
        # File header and shared constants.
        f.write(
            "#![doc(hidden)]\n// THIS FILE HAS BEEN GENERATED BY THE SCRIPT {}\n".format(
                os.path.basename(__file__)
            )
        )
        f.write("// ~~~~~~~~~~~~~~~~~~~ DO NOT MODIFY ~~~~~~~~~~~~~~~~~~~~~~\n")
        f.write(
            "// Methods for reading and writing values using precomputed tables for {} codes\n".format(
                code_name
            )
        )
        f.write("use crate::traits::{BitRead, BitWrite, BE, LE};\n")
        f.write("use num_primitive::PrimitiveNumber;\n")
        f.write("/// How many bits are needed to read the tables\n")
        f.write("pub const READ_BITS: usize = {};\n".format(read_bits))
        f.write("/// Maximum value writable using the table(s)\n")
        f.write("pub const WRITE_MAX: u64 = {};\n".format(write_max_val))
        f.write("/// The K of the pi codes for these tables\n")
        f.write("pub const K: usize = {};\n".format(k))

        # One table-driven read function per bit order.
        for bo in ["le", "be"]:
            BO = bo.upper()
            f.write(
                """
/// Reads from the decoding table.
///
/// Returns `(len_with_flag, value_or_lambda)` where:
/// - If len_with_flag >= 0: complete code, value_or_lambda is decoded value, len_with_flag is code length
/// - If len_with_flag < 0: partial code (rice decoded), value_or_lambda is lambda, (len_with_flag & 0x7F) is rice code length
/// - If len_with_flag = 0: no valid decoding (rice not decoded)
///
/// The backend position is always advanced by (len_with_flag & 0x7F) bits.
/// Using signed i8 allows testing with `< 0` instead of masking, which is more efficient.
#[inline(always)]
pub fn read_table_%(bo)s<B: BitRead<%(BO)s>>(backend: &mut B) -> (i8, u64) {
if let Ok(idx) = backend.peek_bits(READ_BITS) {
let idx: usize = idx.as_to();
let len_with_flag = READ_LEN_%(BO)s[idx];
let value_or_lambda = READ_%(BO)s[idx] as u64;
backend.skip_bits_after_peek((len_with_flag & 0x7F) as usize);
(len_with_flag, value_or_lambda)
} else {
// Not enough bits available
(0, 0)
}
}
"""
                % {"bo": bo, "BO": BO}
            )

        # The write functions are emitted verbatim for both bit orders.
        f.write(
            """
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
pub fn write_table_le<B: BitWrite<LE>>(backend: &mut B, n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n >= WRITE_LE.len() as u64 {
return Ok(None);
}
let n = n as usize;
let len = WRITE_LEN_LE[n] as usize;
backend.write_bits(WRITE_LE[n] as u64, len)?;
Ok(Some(len))
}
/// Writes a value using an encoding table.
///
/// If the result is `Some` the encoding was successful, and
/// length of the code is returned.
#[inline(always)]
pub fn write_table_be<B: BitWrite<BE>>(backend: &mut B, n: u64) -> Result<Option<usize>, B::Error> {
// We cannot use .get() here because n is a u64
if n >= WRITE_BE.len() as u64 {
return Ok(None);
}
let n = n as usize;
let len = WRITE_LEN_BE[n] as usize;
backend.write_bits(WRITE_BE[n] as u64, len)?;
Ok(Some(len))
}
"""
        )

        # Decoding tables: one entry per possible read_bits-wide bit pattern.
        for BO in ["BE", "LE"]:
            codes = []
            for value in range(0, 2**read_bits):
                bits = ("{:0%sb}" % read_bits).format(value)
                try:
                    value_or_lam, len_with_flag = read_pi_partial(bits, k, BO == "BE")
                    codes.append((value_or_lam, len_with_flag))
                except ValueError:
                    # Pattern that cannot be decoded at all: stored as (0, 0).
                    codes.append((0, 0))
            read_max_val = max(x[0] for x in codes)
            # The 0x80 "partial" marker is not part of the length itself.
            read_max_len = max(x[1] & 0x7F for x in codes)
            f.write("/// Precomputed table for reading {} codes\n".format(code_name))
            f.write("/// For complete codes: stores the decoded value\n")
            f.write(
                "/// For partial codes: stores the lambda (length of the following fixed part)\n"
            )
            f.write(
                "pub const READ_%s: &[%s] = &["
                % (BO, get_best_fitting_type(log2(read_max_val + 1)))
            )
            for value_or_lam, _ in codes:
                f.write("{}, ".format(value_or_lam))
            f.write("];\n")
            f.write(
                "/// Precomputed lengths table for reading {} codes\n".format(code_name)
            )
            f.write("/// Positive (< 0x80): complete code length\n")
            f.write(
                "/// Negative (>= 0x80 when viewed as u8): (value & 0x7F) is rice code length (bits consumed)\n"
            )
            f.write("/// Zero: no valid decoding (rice not decoded)\n")
            f.write(
                "pub const READ_LEN_%s: &[%s] = &["
                % (
                    BO,
                    get_best_fitting_type(log2(read_max_len + 1), signed=True),
                )
            )
            for _, len_with_flag in codes:
                # Entries with the 0x80 marker are emitted as the negative
                # number with the same 8-bit two's-complement pattern.
                f.write(
                    "{}, ".format(
                        len_with_flag if len_with_flag < 128 else len_with_flag - 256
                    )
                )
            f.write("];\n")

        # Encoding tables: code bits and code length for 0..=write_max_val.
        for BO in ["BE", "LE"]:
            f.write(
                "/// Table used to speed up the writing of {} codes\n".format(code_name)
            )
            f.write(
                "pub const WRITE_%s: &[%s] = &["
                % (BO, get_best_fitting_type(len_pi(write_max_val, k)))
            )
            len_bits = []
            for n in range(write_max_val + 1):
                bits = write_pi(n, k, "", BO == "BE")
                len_bits.append(len(bits))
                f.write("{},".format(int(bits, 2)))
            f.write("];\n")
            f.write(
                "/// Table used to speed up the writing of {} codes\n".format(code_name)
            )
            f.write(
                "pub const WRITE_LEN_%s: &[%s] = &["
                % (BO, get_best_fitting_type(len_pi(write_max_val, k)))
            )
            for code_len in len_bits:
                f.write("{}, ".format(code_len))
            f.write("];\n")

        # Length-only table. Both the element type and the number of entries
        # follow len_max_val, so a caller-supplied len_max_val is honoured.
        f.write(
            "/// Table used to speed up the skipping of {} codes\n".format(code_name)
        )
        f.write(
            "pub const LEN: &[%s] = &["
            # bit_length() is the exact number of bits needed for the largest
            # stored length; ceil(log2(x)) is one short when x is a power of two.
            % (get_best_fitting_type(len_pi(len_max_val, k).bit_length()))
        )
        for n in range(len_max_val + 1):
            f.write("{}, ".format(len_pi(n, k)))
        f.write("];\n")

        f.write(
            "/// Asserts at compile time that `peek_bits` is large enough for these tables.\n"
        )
        f.write("pub const fn check_read_table(peek_bits: usize) {\n")
        f.write(
            ' assert!(peek_bits >= READ_BITS, "BitRead peek word too small for %s code read tables (%d bits required)");\n'
            % (code_name, read_bits)
        )
        f.write("}\n")
def generate_default_tables():
    """Regenerate every table file with the default parameters, then format.

    The generators run in the order gamma, delta, zeta, omega, pi, each with
    ``merged_table=False``; afterwards ``cargo fmt`` is invoked through the
    shell and a non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    defaults = (
        (gen_gamma, {"read_bits": 9, "write_max_val": 63}),
        (gen_delta, {"read_bits": 10, "write_max_val": 255}),
        (gen_zeta, {"read_bits": 12, "write_max_val": 1023, "k": 3}),
        (gen_omega, {"read_bits": 10, "write_max_val": 63}),
        (gen_pi, {"read_bits": 10, "write_max_val": 1023, "k": 2}),
    )
    for generator, params in defaults:
        generator(merged_table=False, **params)
    # The generated sources are written unformatted; let cargo tidy them up.
    subprocess.check_call("cargo fmt", shell=True)
# Script entry point: regenerate all default tables when run directly
# (nothing is generated on import).
if __name__ == "__main__":
    generate_default_tables()