use crate::core::shared::{BitWriter, ByteWriter};
use super::{
Buffer, MsbFirst, OrderConfig, RawBuffer
};
/// Bit writer that appends into a growable `RawBuffer`.
///
/// `Order` selects the bit order inside a byte (via `Order::IS_MSB_FIRST`).
/// `BIT_CODER_ENABLED` selects the mode: `Writer<_, true>` implements `BitWriter` and
/// `Writer<_, false>` implements `ByteWriter`. Switching modes casts `&mut` pointers
/// between the two instantiations in place, so both must share one layout; `#[repr(C)]`
/// presumably exists to guarantee that (NOTE(review): do not reorder or add fields in
/// only one mode).
#[derive(Debug)]
#[repr(C)]
pub struct Writer<Order: OrderConfig = MsbFirst, const BIT_CODER_ENABLED: bool = false> {
/// Write cursor into `buffer`, tracking byte `num_bits >> 3` (the partially filled byte,
/// or the next fresh byte when `pos_in_curr_byte == 0`).
ptr: *mut u8,
/// Number of bits written so far, including padding added when switching to byte mode.
num_bits: usize,
/// Bits already used in the byte at `ptr`; 0 when the cursor is byte-aligned.
pos_in_curr_byte: u8,
/// Backing storage; grown with `double` when more room is needed.
buffer: RawBuffer,
/// Ties the type to `Order` without storing a value of it.
_phantom: std::marker::PhantomData<Order>,
}
/// Finishes writing and hands the written data over as a `Buffer`.
///
/// The buffer takes ownership of the backing storage, and its `len` is the writer's
/// exact bit count (`num_bits`); a trailing partial byte is not rounded up.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still provides
/// `writer.into()`, and `Buffer::from(writer)` works as well.
impl<Order: OrderConfig, const BIT_CODER_ENABLED: bool> From<Writer<Order, BIT_CODER_ENABLED>>
    for Buffer<Order>
{
    fn from(writer: Writer<Order, BIT_CODER_ENABLED>) -> Self {
        Self {
            data: writer.buffer,
            len: writer.num_bits,
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<Order: OrderConfig> BitWriter for Writer<Order, true> {
    type ByteWriter = Writer<Order, false>;

    /// Appends the low `size` bits of `value` in `Order`'s bit order.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0 or greater than 64. In debug builds, also panics if
    /// `value` has bits set at or above bit `size`.
    fn write_bits(&mut self, (size, value): (u8, u64)) {
        assert!(size <= 64 && size > 0, "Invalid size: {}", size);
        // `write_bits_unchecked` always writes the (possibly empty) trailing byte at
        // index `(num_bits + size) >> 3`, so that byte must exist: cap * 8 > num_bits + size.
        while size as usize + self.num_bits >= self.buffer.cap << 3 {
            self.buffer.double();
        }
        // SAFETY: num_bits < cap * 8 after the loop, so `num_bits >> 3` is in bounds.
        // The cursor is re-derived because `double` may have moved the allocation.
        unsafe {
            self.ptr = self.buffer.as_ptr().add(self.num_bits >> 3);
        }
        debug_assert!(
            size==64 || value >> size==0,
            "Invalid Data: 'value' has more than 'size' bits of data: {:?}",
            (size, value)
        );
        // SAFETY: 1 <= size <= 64, the cursor points at byte `num_bits >> 3`, and the
        // buffer holds at least `((num_bits + size) >> 3) + 1` bytes (loop above).
        unsafe { self.write_bits_unchecked((size, value)) }
    }

    /// Pads to the next byte boundary and reinterprets `self` as the byte-mode writer.
    ///
    /// The padding bits are counted in `num_bits`.
    fn into_byte_writer(&mut self) -> &mut Self::ByteWriter {
        if self.pos_in_curr_byte != 0 {
            while self.num_bits + 8 >= self.buffer.cap << 3 {
                self.buffer.double();
            }
            self.num_bits += 8 - self.pos_in_curr_byte as usize;
            self.pos_in_curr_byte = 0;
            // Re-derive the cursor from the buffer: `double` may have reallocated, and
            // advancing the old `ptr` would leave a dangling cursor for `write_byte`.
            // SAFETY: before padding num_bits + 8 < cap * 8; padding adds at most 7 bits,
            // so `num_bits >> 3 < cap` and the offset stays inside the allocation.
            self.ptr = unsafe { self.buffer.as_ptr().add(self.num_bits >> 3) };
        }
        // SAFETY: `Writer<Order, true>` and `Writer<Order, false>` are the same
        // `#[repr(C)]` struct differing only in a const parameter that affects no field.
        unsafe { &mut *(self as *mut Self as *mut Writer<Order, false>) }
    }
}
impl<Order: OrderConfig> ByteWriter for Writer<Order, false> {
    type BitWriter = Writer<Order, true>;

    /// Appends one whole byte at the (byte-aligned) cursor.
    fn write_byte(&mut self, data: u8) {
        // Grow *before* touching memory: the byte after this one becomes the new cursor
        // byte and must exist too, hence `num_bits + 8 < cap * 8`.
        while self.num_bits + 8 >= self.buffer.cap << 3 {
            self.buffer.double();
        }
        let index = self.num_bits >> 3;
        // SAFETY: after the loop, (num_bits + 8) >> 3 == index + 1 < cap, so both `index`
        // and `index + 1` are inside the allocation. The base is re-read from the buffer
        // because `double` may have moved it.
        unsafe {
            let base = self.buffer.as_ptr();
            base.add(index).write(data);
            self.ptr = base.add(index + 1);
        }
        self.num_bits += 8;
    }

    /// Reinterprets `self` as the bit-mode writer. The cursor is byte-aligned, so no
    /// state needs adjusting.
    fn into_bit_writer(&mut self) -> &mut Self::BitWriter {
        let ptr = self as *mut Self;
        // SAFETY: `Writer<Order, false>` and `Writer<Order, true>` are the same
        // `#[repr(C)]` struct differing only in a const parameter that affects no field.
        unsafe {
            &mut *(ptr as *mut Writer<Order, true>)
        }
    }
}
impl<Order: OrderConfig> Writer<Order, true> {
    /// Appends the low `size` bits of `value` without any capacity or argument checks.
    ///
    /// Bits are emitted most-significant-first when `Order::IS_MSB_FIRST`, otherwise
    /// least-significant-first. A partially filled current byte is completed first by
    /// OR-ing into it. Whole bytes are then written. Finally, the remaining `< 8` bits are
    /// written into a fresh byte. That trailing byte is always written, even when it
    /// ends up holding no bits.
    ///
    /// # Safety
    ///
    /// * `size <= 64` (`write_bits` additionally requires `size > 0`), and `value` has no
    ///   bits set at or above bit `size`.
    /// * `self.ptr` points at byte `self.num_bits >> 3` of the buffer.
    /// * The buffer holds at least `((self.num_bits + size) >> 3) + 1` bytes, because the
    ///   trailing byte is written unconditionally.
    /// * `self.num_bits + size` does not overflow `usize`.
    pub unsafe fn write_bits_unchecked(&mut self, (size, value): (u8, u64)) {
        self.num_bits = self.num_bits.unchecked_add(size as usize);
        // MSB-first: `offset` = number of (top) bits of `value` still to be written.
        // LSB-first: `offset` = index of the next bit of `value` to be written.
        let mut offset = if Order::IS_MSB_FIRST { size } else { 0 };
        if self.pos_in_curr_byte != 0 {
            let num_remaining_in_curr_byte = 8_u8.unchecked_sub(self.pos_in_curr_byte);
            if size <= num_remaining_in_curr_byte {
                // Everything fits into the current byte.
                unsafe {
                    *self.ptr |= if Order::IS_MSB_FIRST {
                        (value & ((1 << num_remaining_in_curr_byte) - 1))
                            << num_remaining_in_curr_byte.unchecked_sub(size)
                    } else {
                        value << self.pos_in_curr_byte
                    } as u8;
                }
                self.pos_in_curr_byte = if size == num_remaining_in_curr_byte {
                    // The byte is now full; move the cursor to the next one.
                    self.ptr = unsafe { self.ptr.add(1) };
                    0
                } else {
                    self.pos_in_curr_byte.unchecked_add(size)
                };
                return;
            }
            // Fill the rest of the current byte, then continue byte-aligned.
            unsafe {
                *self.ptr |= if Order::IS_MSB_FIRST {
                    (value >> size.unchecked_sub(num_remaining_in_curr_byte))
                        & ((1 << num_remaining_in_curr_byte) - 1)
                } else {
                    value << self.pos_in_curr_byte
                } as u8;
            }
            self.ptr = self.ptr.add(1);
            offset = if Order::IS_MSB_FIRST {
                size.unchecked_sub(num_remaining_in_curr_byte)
            } else {
                num_remaining_in_curr_byte
            };
        }
        // Whole bytes.
        let pending = if Order::IS_MSB_FIRST { offset } else { size.unchecked_sub(offset) };
        for _ in 0..pending >> 3 {
            if Order::IS_MSB_FIRST {
                offset = offset.unchecked_sub(8);
            }
            unsafe {
                self.ptr.write((value >> offset) as u8);
            }
            if !Order::IS_MSB_FIRST {
                offset = offset.unchecked_add(8);
            }
            self.ptr = self.ptr.add(1);
        }
        // Trailing partial byte (possibly empty), written fresh so no stale bits remain.
        let tail = if Order::IS_MSB_FIRST {
            (value & ((1 << offset) - 1)) << 8_u8.unchecked_sub(offset)
        } else {
            // LSB-first `offset` reaches 64 when a 64-bit value starts on a byte boundary.
            // A plain `value >> 64` overflows (panic in debug, wrong byte in release);
            // no bits remain in that case, so the tail is 0.
            value.checked_shr(u32::from(offset)).unwrap_or(0)
        };
        unsafe {
            self.ptr.write(tail as u8);
        }
        self.pos_in_curr_byte = if Order::IS_MSB_FIRST {
            offset
        } else {
            size.unchecked_sub(offset) & 7
        };
    }
}
impl<Order: OrderConfig, const BIT_CODER_ENABLED: bool> Writer<Order, BIT_CODER_ENABLED> {
    /// Creates an empty writer with a one-byte initial buffer.
    pub fn new() -> Self {
        Self::with_cap(0)
    }

    /// Creates an empty writer that can hold `len` bits without reallocating.
    ///
    /// The writer always keeps the byte at index `num_bits >> 3` allocated (it is the
    /// current write position). Holding `len` bits therefore needs `(len >> 3) + 1` bytes.
    /// That count is never zero, so the cursor byte exists even for `len == 0`.
    pub fn with_cap(len: usize) -> Self {
        let buffer = RawBuffer::with_capacity((len >> 3) + 1);
        Self {
            ptr: buffer.as_ptr(),
            num_bits: 0,
            pos_in_curr_byte: 0,
            buffer,
            _phantom: std::marker::PhantomData,
        }
    }
}