use std::cell::RefCell;
use seq_macro::seq;
use super::constant::MAX_RICE_PARAMETER;
use super::constant::MAX_RICE_PARTITIONS;
use super::constant::MAX_RICE_PARTITION_ORDER;
use super::constant::MIN_RICE_PARTITION_SIZE;
#[derive(Clone, Debug, PartialEq, PartialOrd)]
struct PrcBitTable {
p_to_bits: std::simd::u32x16,
mask: std::simd::Mask<<u32 as std::simd::SimdElement>::Mask, 16>,
}
static ZEROS: std::simd::u32x16 = std::simd::u32x16::from_array([0u32; 16]);
static INDEX: std::simd::u32x16 =
std::simd::u32x16::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
static MAXES: std::simd::u32x16 = std::simd::u32x16::from_array([u32::MAX; 16]);
impl PrcBitTable {
pub fn zero(max_p: usize) -> Self {
debug_assert!(max_p <= MAX_RICE_PARAMETER);
Self {
p_to_bits: ZEROS,
mask: INDEX.lanes_le(std::simd::u32x16::splat(max_p as u32)),
}
}
pub fn from_signal(signal: &[i32], max_p: usize, offset: usize) -> Self {
let mut ret = Self::zero(max_p);
ret.init_with_signal(signal, offset);
ret
}
#[allow(unused_assignments, clippy::identity_op)]
fn init_with_signal(&mut self, signal: &[i32], offset: usize) {
let mut p_to_bits: [u32; 16] = [0u32; 16];
seq!(p in 0..15 {
p_to_bits[p] = (offset + signal.len() * (p + 1)) as u32;
});
for v in signal.iter().map(|x| encode_signbit(*x)) {
let mut v = v;
seq!(p in 0..15 {
p_to_bits[p] += v as u32;
v >>= 1;
});
}
self.p_to_bits = std::simd::u32x16::from_array(p_to_bits);
}
#[cfg(test)]
pub fn bits(&self, p: usize) -> usize {
self.p_to_bits[p] as usize
}
#[inline]
pub fn minimizer(&self) -> (usize, usize) {
let ret_bits = self.mask.select(self.p_to_bits, MAXES).reduce_min();
let ret_p = self
.p_to_bits
.lanes_eq(std::simd::u32x16::splat(ret_bits))
.select(INDEX, ZEROS)
.reduce_max();
(ret_p as usize, ret_bits as usize)
}
#[allow(unused_comparisons)]
#[inline]
pub fn merge(&self, other: &Self, offset: usize) -> Self {
let offset = std::simd::u32x16::splat(offset as u32);
let offset = self.mask.select(offset, ZEROS);
Self {
p_to_bits: self.p_to_bits + other.p_to_bits - offset,
mask: self.mask,
}
}
}
/// Returns the finest (largest) usable Rice partition order for a block.
///
/// The result is the largest `order` such that:
/// * `order <= MAX_RICE_PARTITION_ORDER`,
/// * `2^order` divides `size` (all partitions have equal length), and
/// * each of the `2^order` partitions holds at least `min_part_size` samples.
///
/// A block shorter than `min_part_size` (including an empty one) yields 0,
/// i.e. a single partition.
///
/// # Panics
///
/// Panics if `min_part_size` is zero.
#[inline]
fn finest_partition_order(size: usize, min_part_size: usize) -> usize {
    assert!(min_part_size >= 1);
    let max_splits = size / min_part_size;
    // floor(log2(max_splits)); guarded because log2(0) is undefined and the
    // unguarded `BITS - leading_zeros - 1` underflows when `max_splits == 0`.
    let max_order_for_min_part = if max_splits == 0 {
        0
    } else {
        (usize::BITS - 1 - max_splits.leading_zeros()) as usize
    };
    std::cmp::min(
        MAX_RICE_PARTITION_ORDER,
        std::cmp::min(max_order_for_min_part, size.trailing_zeros() as usize),
    )
}
/// Maps a signed sample to an unsigned value by zig-zag encoding.
///
/// `0, -1, 1, -2, 2, ...` become `0, 1, 2, 3, 4, ...`. Every `i32` is accepted;
/// `i32::MIN` maps to `u32::MAX`.
#[inline]
pub const fn encode_signbit(v: i32) -> u32 {
    // `v << 1` drops the sign bit without overflow checks; `v >> 31` is an
    // arithmetic shift giving all ones for negatives. XOR flips negatives to
    // `2|v| - 1`. This avoids the overflowing `unsigned_abs() * 2` for i32::MIN.
    ((v << 1) ^ (v >> 31)) as u32
}
/// Inverse of `encode_signbit`: maps a zig-zag encoded value back to `i32`.
///
/// `0, 1, 2, 3, 4, ...` become `0, -1, 1, -2, 2, ...`. Every `u32` is accepted;
/// `u32::MAX` maps to `i32::MIN`.
#[inline]
pub const fn decode_signbit(v: u32) -> i32 {
    // The low bit is the sign. `-(v & 1)` is 0 or all ones, so the XOR either
    // keeps the magnitude or turns it into `-(magnitude + 1)`. Unlike explicit
    // negation, it cannot overflow for `v == u32::MAX`.
    ((v >> 1) as i32) ^ -((v & 1) as i32)
}
/// Chooses the cheapest Rice parameter for every partition table.
///
/// The parameter chosen for `tables[i]` is written to `ps[i]`. Returns the sum
/// of the chosen per-partition costs.
///
/// # Panics
///
/// Panics if `ps` is shorter than `tables`.
fn eval_partitions(tables: &[PrcBitTable], ps: &mut [usize]) -> usize {
    assert!(ps.len() >= tables.len());
    tables
        .iter()
        .zip(ps.iter_mut())
        .map(|(table, slot)| {
            let (best_p, best_bits) = table.minimizer();
            *slot = best_p;
            best_bits
        })
        .sum()
}
/// Merges adjacent pairs of partition tables in place, halving the partition count.
///
/// After the call, `tables[i]` for `i < tables.len() / 2` holds the merge of the
/// original `tables[2 * i]` and `tables[2 * i + 1]`. Entries past that point are
/// left stale, and an odd trailing table is ignored. Returns the new count.
///
/// The merge offset of 4 must match the `offset` the tables were built with;
/// `PrcParameterFinder::find` builds them with `PrcBitTable::from_signal(.., 4)`.
///
/// # Panics
///
/// Panics if `tables.len()` exceeds `MAX_RICE_PARTITIONS`.
fn merge_partitions(tables: &mut [PrcBitTable]) -> usize {
    // `<=`: a count equal to the maximum number of partitions is valid input
    // (presumably MAX_RICE_PARTITIONS == 1 << MAX_RICE_PARTITION_ORDER).
    assert!(tables.len() <= MAX_RICE_PARTITIONS);
    let merged_len = tables.len() / 2;
    for part_id in 0..merged_len {
        // Overwriting in place is safe: later iterations only read slots
        // >= 2 * (part_id + 1), which are all above `part_id`.
        tables[part_id] = tables[2 * part_id].merge(&tables[2 * part_id + 1], 4);
    }
    merged_len
}
/// Outcome of a partitioned Rice parameter search.
pub struct PrcParameter {
    /// Partition order; the block is split into `2^order` equal partitions.
    pub order: usize,
    /// Rice parameter chosen for each partition, in partition order.
    pub ps: Vec<u8>,
    /// Estimated number of bits for the residual under these parameters.
    pub code_bits: usize,
}

impl PrcParameter {
    /// Bundles a search result from its parts.
    pub fn new(order: usize, ps: Vec<u8>, code_bits: usize) -> Self {
        Self {
            code_bits,
            ps,
            order,
        }
    }
}
/// Reusable scratch state for the partitioned Rice parameter search.
///
/// The vectors survive between calls to `find` so their allocations can be
/// reused; a per-thread instance lives in `RICE_PARAMETER_FINDER`.
struct PrcParameterFinder {
    /// One bit table per partition; merged in place as the order is coarsened.
    pub tables: Vec<PrcBitTable>,
    /// Rice parameters chosen for the partition order currently being evaluated.
    pub ps: Vec<usize>,
    /// Rice parameters for the cheapest partition order seen so far.
    pub min_ps: Vec<usize>,
}

impl PrcParameterFinder {
    /// Creates a finder with empty scratch buffers.
    pub const fn new() -> Self {
        Self {
            tables: Vec::new(),
            ps: Vec::new(),
            min_ps: Vec::new(),
        }
    }

    /// Finds the partition order and per-partition Rice parameters that
    /// minimize the estimated code length of `signal`.
    ///
    /// Builds bit tables at the finest admissible order, then repeatedly merges
    /// adjacent partition pairs (one order coarser each step). It keeps the
    /// cheapest order; on ties the finer order wins.
    ///
    /// * `warmup_length` – number of leading samples excluded from the cost
    ///   (presumably predictor warm-up samples that are not Rice coded).
    /// * `max_p` – largest Rice parameter enabled in the bit tables.
    ///
    /// # Panics
    ///
    /// Panics if `warmup_length` exceeds `signal.len()`, because the warm-up
    /// prefix is sliced off `signal`.
    pub fn find(&mut self, signal: &[i32], warmup_length: usize, max_p: usize) -> PrcParameter {
        let mut partition_order = finest_partition_order(
            signal.len(),
            std::cmp::max(MIN_RICE_PARTITION_SIZE, warmup_length),
        );
        let mut nparts = 1usize << partition_order;
        self.tables.clear();
        self.tables.reserve(nparts);
        self.min_ps.resize(nparts, 0);

        // Every partition is at least `warmup_length` long, so only the first
        // partition ever loses samples to the warm-up prefix.
        let part_size = signal.len() / nparts;
        for p in 0..nparts {
            let start = std::cmp::max(p * part_size, warmup_length);
            let end = (p + 1) * part_size;
            self.tables
                .push(PrcBitTable::from_signal(&signal[start..end], max_p, 4));
        }

        let mut min_bits = eval_partitions(&self.tables, &mut self.min_ps);
        let mut min_order = partition_order;
        while nparts > 1 {
            nparts = merge_partitions(&mut self.tables[..nparts]);
            partition_order -= 1;
            self.ps.resize(nparts, 0);
            let next_bits = eval_partitions(&self.tables[..nparts], &mut self.ps);
            if next_bits < min_bits {
                min_bits = next_bits;
                // `clone_from` reuses `min_ps`'s buffer instead of allocating a
                // fresh vector on every improvement.
                self.min_ps.clone_from(&self.ps);
                min_order = partition_order;
            }
        }

        self.min_ps.truncate(1usize << min_order);
        PrcParameter::new(
            min_order,
            self.min_ps.iter().map(|&p| p as u8).collect(),
            min_bits,
        )
    }
}
thread_local! {
    /// Per-thread scratch buffers for `find_partitioned_rice_parameter`,
    /// reused across calls. `PrcParameterFinder::new` is `const`, so the slot
    /// is initialized at compile time and needs no lazy-init check on access.
    static RICE_PARAMETER_FINDER: RefCell<PrcParameterFinder> =
        const { RefCell::new(PrcParameterFinder::new()) };
}

/// Finds the partition order and per-partition Rice parameters that minimize
/// the estimated coded size of `signal`.
///
/// The first `warmup_length` samples are excluded from the cost. `max_p` is
/// forwarded to the per-partition bit tables as the largest allowed parameter.
/// The search runs on this thread's reusable scratch buffers.
///
/// # Panics
///
/// Panics if `warmup_length` exceeds `signal.len()`.
pub fn find_partitioned_rice_parameter(
    signal: &[i32],
    warmup_length: usize,
    max_p: usize,
) -> PrcParameter {
    RICE_PARAMETER_FINDER.with(|finder| finder.borrow_mut().find(signal, warmup_length, max_p))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper;

    /// Two hand-filled tables (max_p = 4) shared by the evaluation and merging tests.
    fn sample_partitions() -> (PrcBitTable, PrcBitTable) {
        let mut left = PrcBitTable::zero(4);
        left.p_to_bits[0..5].copy_from_slice(&[17, 19, 15, 11, 19]);
        let mut right = PrcBitTable::zero(4);
        right.p_to_bits[0..5].copy_from_slice(&[12, 14, 16, 18, 20]);
        (left, right)
    }

    #[test]
    fn bit_table_initialization() {
        let table = PrcBitTable::from_signal(&[3, 4, 5, 6], 2, 4);
        // p = 0: zig-zag values, one stop bit per sample (4), plus 4 offset bits.
        assert_eq!(table.bits(0), 3 * 2 + 4 * 2 + 5 * 2 + 6 * 2 + 8);
        // p = 1: halved zig-zag values, two bits per sample (8), plus 4 offset bits.
        assert_eq!(table.bits(1), 3 + 4 + 5 + 6 + 8 + 4);
    }

    #[test]
    fn prc_parameter_search() {
        let signal = test_helper::constant_plus_noise(64, 0, 4096);
        let table = PrcBitTable::from_signal(&signal, 14, 4);
        let (best_p, _best_bits) = table.minimizer();
        eprintln!("Table = {:?}", table);
        eprintln!("Found p = {}", best_p);
        assert!((1..13).contains(&best_p));
    }

    #[test]
    fn finest_partition_order_search() {
        let cases: [(usize, usize, usize); 5] =
            [(64, 4, 4), (64, 3, 4), (192, 1, 6), (192, 3, 6), (192, 4, 5)];
        for &(size, min_part_size, expected) in cases.iter() {
            assert_eq!(finest_partition_order(size, min_part_size), expected);
        }
    }

    #[test]
    fn partitioned_rice_parameter_search() {
        let noisy = test_helper::constant_plus_noise(64, 0, 2048);
        let quiet = test_helper::constant_plus_noise(64, 0, 12);
        let signal = [noisy, quiet].concat();
        let (_, single_bits) = PrcBitTable::from_signal(&signal[4..], 14, 4).minimizer();
        let found = find_partitioned_rice_parameter(&signal, 4, 14);
        // Splitting the loud and quiet halves must beat a single partition.
        assert!(found.code_bits <= single_bits);
        assert_eq!(found.order, 1);
    }

    #[test]
    fn partition_evaluation() {
        let (left, right) = sample_partitions();
        let mut params = [0, 0];
        let min_bits = eval_partitions(&[left, right], &mut params);
        assert_eq!(min_bits, 23);
        assert_eq!(params, [3, 0]);
    }

    #[test]
    fn partition_merging() {
        let (left, right) = sample_partitions();
        let mut tables = [left, right];
        let merged = merge_partitions(&mut tables);
        assert_eq!(merged, 1);
        assert_eq!(tables[0].p_to_bits[0..5], [25, 29, 27, 25, 35]);
    }

    #[test]
    fn minimizer_search() {
        // (lane costs for p = 0..8, expected (p, bits)); only p <= 4 is enabled.
        let cases: [([u32; 8], (usize, usize)); 4] = [
            ([6, 7, 4, 5, 9, 0, 0, 0], (2, 4)),
            ([6, 7, 8, 5, 3, 0, 0, 0], (4, 3)),
            ([1, 7, 8, 5, 3, 0, 0, 0], (0, 1)),
            // Ties resolve to the largest parameter.
            ([1, 7, 1, 1, 3, 0, 0, 0], (3, 1)),
        ];
        for (lanes, expected) in cases.iter() {
            let mut bt = PrcBitTable::zero(4);
            bt.p_to_bits[0..8].copy_from_slice(lanes);
            assert_eq!(bt.minimizer(), *expected);
        }
    }
}