use common_traits::*;
use crate::codes::params::{DefaultReadParams, ReadParams};
use crate::traits::*;
use core::convert::Infallible;
use core::{mem, ptr};
use std::error::Error;
/// Buffer type of a `BufBitReader` over the backend `WR`: the `DoubleType` of the
/// backend's word. It must be wide enough that a full word can be appended while some
/// bits are still buffered (see the `debug_assert!` at the top of each `refill`).
type BB<WR> = <<WR as WordRead>::Word as DoubleType>::DoubleType;
/// A bit reader that pulls words from a `WordRead` backend and serves bits from an
/// internal buffer of type `BB<WR>`.
///
/// `E` selects the bit order (`BitRead` implementations for `BE` and `LE` follow);
/// `RP` is a `ReadParams` type, defaulting to `DefaultReadParams`, that is stored only as
/// a type marker (see `_marker`).
#[derive(Debug)]
pub struct BufBitReader<E: Endianness, WR: WordRead, RP: ReadParams = DefaultReadParams>
where
    WR::Word: DoubleType,
{
    /// Source of the words being decoded.
    backend: WR,
    /// Buffered bits. For `BE` the valid bits are the `bits_in_buffer` most significant
    /// ones, for `LE` the `bits_in_buffer` least significant ones. `refill` merges new
    /// words with `|=`, so all bits outside the valid range must be zero.
    buffer: BB<WR>,
    /// Number of valid bits currently held in `buffer`.
    bits_in_buffer: usize,
    /// Ties the endianness `E` and the read parameters `RP` to the type.
    _marker: core::marker::PhantomData<(E, RP)>,
}
impl<E: Endianness, WR: WordRead + Clone, RP: ReadParams> Clone for BufBitReader<E, WR, RP>
where
    WR::Word: DoubleType,
{
    /// Duplicates the reader.
    ///
    /// The backend is cloned. The bit buffer, its fill level and the marker are copied
    /// verbatim, so the copy holds exactly the same pending bits as `self`.
    fn clone(&self) -> Self {
        let backend = self.backend.clone();
        Self { backend, ..*self }
    }
}
impl<E: Endianness, WR: WordRead, RP: ReadParams> BufBitReader<E, WR, RP>
where
    WR::Word: DoubleType,
{
    #[must_use]
    pub fn new(backend: WR) -> Self {
        check_tables(WR::Word::BITS + 1);
        Self {
            backend,
            buffer: BB::<WR>::ZERO,
            bits_in_buffer: 0,
            _marker: core::marker::PhantomData,
        }
    }
    pub fn into_inner(self) -> Result<WR, Infallible> {
        let backend = unsafe { ptr::read(&self.backend) };
        mem::forget(self);
        Ok(backend)
    }
}
impl<WR: WordRead, RP: ReadParams> BufBitReader<BE, WR, RP>
where
    WR::Word: DoubleType,
{
    /// Appends one big-endian word from the backend just below the valid bits.
    ///
    /// The buffer must have room for a whole word. The word is OR-ed in, so the bits
    /// below the valid ones must be zero on entry. If the backend fails, the state is
    /// left untouched.
    #[inline(always)]
    fn refill(&mut self) -> Result<(), <WR as WordRead>::Error> {
        debug_assert!(BB::<WR>::BITS - self.bits_in_buffer >= WR::Word::BITS);
        let word = self.backend.read_word()?;
        let filled = self.bits_in_buffer + WR::Word::BITS;
        let widened: BB<WR> = word.to_be().upcast();
        self.buffer |= widened << (BB::<WR>::BITS - filled);
        self.bits_in_buffer = filled;
        Ok(())
    }
}
impl<WR: WordRead, RP: ReadParams> BitRead<BE> for BufBitReader<BE, WR, RP>
where
    WR::Word: DoubleType + UpcastableInto<u64>,
    BB<WR>: CastableInto<u64>,
{
    type Error = <WR as WordRead>::Error;
    type PeekWord = BB<WR>;

    /// Returns the next `n_bits` bits, right-aligned, without consuming them.
    ///
    /// At most one word is fetched from the backend (through `refill`). That this is
    /// enough to satisfy the request is only checked by `debug_assert!`.
    #[inline(always)]
    fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
        debug_assert!(n_bits > 0);
        debug_assert!(n_bits <= Self::PeekWord::BITS);
        if n_bits > self.bits_in_buffer {
            self.refill()?;
        }
        debug_assert!(n_bits <= self.bits_in_buffer);
        // Valid bits are the most significant ones: bring the first `n_bits` to the
        // bottom. `n_bits > 0`, so the shift is smaller than the buffer width.
        Ok(self.buffer >> (BB::<WR>::BITS - n_bits))
    }

    /// Consumes `n_bits` bits that are already buffered (typically right after
    /// `peek_bits`); the backend is not touched.
    #[inline(always)]
    fn skip_bits_after_table_lookup(&mut self, n_bits: usize) {
        self.bits_in_buffer -= n_bits;
        self.buffer <<= n_bits;
    }

    /// Reads `n_bits` bits (at most 64) and returns them right-aligned, the first bit
    /// read being the most significant one.
    ///
    /// # Errors
    ///
    /// Propagates the backend error if a word cannot be read.
    #[inline]
    fn read_bits(&mut self, mut n_bits: usize) -> Result<u64, Self::Error> {
        debug_assert!(n_bits <= 64);
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        // Fast path: everything needed is already buffered.
        if n_bits <= self.bits_in_buffer {
            // `>> (BITS - n_bits - 1) >> 1` is a right shift by `BITS - n_bits` that
            // stays valid when `n_bits == 0`.
            let result: u64 = (self.buffer >> (BB::<WR>::BITS - n_bits - 1) >> 1_u32).cast();
            self.bits_in_buffer -= n_bits;
            self.buffer <<= n_bits;
            return Ok(result);
        }
        // Take every buffered bit (same shift trick, valid for an empty buffer).
        let mut result: u64 =
            (self.buffer >> (BB::<WR>::BITS - 1 - self.bits_in_buffer) >> 1_u8).cast();
        n_bits -= self.bits_in_buffer;
        // Whole words go straight into the result, bypassing the buffer.
        while n_bits > WR::Word::BITS {
            let new_word: u64 = self.backend.read_word()?.to_be().upcast();
            result = (result << WR::Word::BITS) | new_word;
            n_bits -= WR::Word::BITS;
        }
        debug_assert!(n_bits > 0);
        debug_assert!(n_bits <= WR::Word::BITS);
        // The last word supplies the final `n_bits` bits; the rest stays buffered.
        let new_word = self.backend.read_word()?.to_be();
        self.bits_in_buffer = WR::Word::BITS - n_bits;
        let upcasted: u64 = new_word.upcast();
        let final_bits: u64 = (upcasted >> self.bits_in_buffer).downcast();
        // Split shift: `n_bits` may be 64, which a single `u64` shift cannot do.
        result = (result << (n_bits - 1) << 1) | final_bits;
        // Left-align the unread low bits of the word, with zeros below them; the split
        // shift also handles the case where nothing remains.
        self.buffer = (UpcastableInto::<BB<WR>>::upcast(new_word)
            << (BB::<WR>::BITS - self.bits_in_buffer - 1))
            << 1;
        Ok(result)
    }

    /// Reads a unary code: returns the number of zeros before the next one bit and
    /// consumes that one bit as well.
    ///
    /// # Errors
    ///
    /// Propagates the backend error if a word cannot be read.
    #[inline]
    fn read_unary(&mut self) -> Result<u64, Self::Error> {
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        let zeros: usize = self.buffer.leading_zeros() as _;
        // The terminating one bit lies within the valid part of the buffer.
        if zeros < self.bits_in_buffer {
            self.buffer = self.buffer << zeros << 1;
            self.bits_in_buffer -= zeros + 1;
            return Ok(zeros as u64);
        }
        // All buffered bits are zeros: scan whole words until a nonzero one appears.
        let mut result: u64 = self.bits_in_buffer as _;
        loop {
            let new_word = self.backend.read_word()?.to_be();
            if new_word != WR::Word::ZERO {
                let zeros: usize = new_word.leading_zeros() as _;
                // Shift out everything up to and including the one bit, leaving the rest
                // of the word left-aligned (assumes `BB` is twice the word width).
                self.buffer =
                    UpcastableInto::<BB<WR>>::upcast(new_word) << (WR::Word::BITS + zeros) << 1;
                self.bits_in_buffer = WR::Word::BITS - zeros - 1;
                return Ok(result + zeros as u64);
            }
            result += WR::Word::BITS as u64;
        }
    }

    /// Skips `n_bits` bits, discarding whole words from the backend when needed.
    ///
    /// # Errors
    ///
    /// Propagates the backend error if a word cannot be read.
    #[inline]
    fn skip_bits(&mut self, mut n_bits: usize) -> Result<(), Self::Error> {
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        if n_bits <= self.bits_in_buffer {
            self.bits_in_buffer -= n_bits;
            self.buffer <<= n_bits;
            return Ok(());
        }
        n_bits -= self.bits_in_buffer;
        while n_bits > WR::Word::BITS {
            let _ = self.backend.read_word()?;
            n_bits -= WR::Word::BITS;
        }
        let new_word = self.backend.read_word()?.to_be();
        self.bits_in_buffer = WR::Word::BITS - n_bits;
        // Left-align the unread bits; the split shift handles an empty remainder.
        self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word)
            << (BB::<WR>::BITS - 1 - self.bits_in_buffer)
            << 1;
        Ok(())
    }

    /// Copies the next `n` bits of this reader to `bit_write`.
    ///
    /// Buffered bits are written first, in chunks of at most 64 bits because
    /// `write_bits` takes a `u64`. Then come whole words, then the leading part of one
    /// last word, whose unread bits are kept in the buffer. Every value handed to
    /// `write_bits` has no bits set above its length.
    ///
    /// # Errors
    ///
    /// `CopyError::ReadError` if the backend fails, `CopyError::WriteError` if
    /// `bit_write` fails.
    #[cfg(not(feature = "no_copy_impls"))]
    fn copy_to<F: Endianness, W: BitWrite<F>>(
        &mut self,
        bit_write: &mut W,
        mut n: u64,
    ) -> Result<(), CopyError<Self::Error, W::Error>> {
        let mut from_buffer = Ord::min(n, self.bits_in_buffer as u64) as usize;
        n -= from_buffer as u64;
        // Consume buffered bits with the same shifts as `read_bits`. Rotating them to
        // the bottom instead would leave stale bits below the valid ones, which `refill`
        // would later OR into freshly read words.
        while from_buffer > 0 {
            let chunk = Ord::min(from_buffer, 64);
            let bits: u64 = (self.buffer >> (BB::<WR>::BITS - chunk)).cast();
            bit_write
                .write_bits(bits, chunk)
                .map_err(CopyError::WriteError)?;
            self.bits_in_buffer -= chunk;
            self.buffer <<= chunk;
            from_buffer -= chunk;
        }
        if n == 0 {
            return Ok(());
        }
        // The buffer is now empty: copy whole words straight from the backend.
        while n > WR::Word::BITS as u64 {
            bit_write
                .write_bits(
                    self.backend
                        .read_word()
                        .map_err(CopyError::ReadError)?
                        .to_be()
                        .upcast(),
                    WR::Word::BITS,
                )
                .map_err(CopyError::WriteError)?;
            n -= WR::Word::BITS as u64;
        }
        debug_assert!(n > 0);
        let n = n as usize;
        let new_word = self
            .backend
            .read_word()
            .map_err(CopyError::ReadError)?
            .to_be();
        self.bits_in_buffer = WR::Word::BITS - n;
        bit_write
            .write_bits((new_word >> self.bits_in_buffer).upcast(), n)
            .map_err(CopyError::WriteError)?;
        // Keep the unread low bits of the word left-aligned, with zeros below them.
        self.buffer = (UpcastableInto::<BB<WR>>::upcast(new_word)
            << (BB::<WR>::BITS - self.bits_in_buffer - 1))
            << 1;
        Ok(())
    }
}
impl<
        E: Error + Send + Sync + 'static,
        WR: WordRead<Error = E> + WordSeek<Error = E>,
        RP: ReadParams,
    > BitSeek for BufBitReader<BE, WR, RP>
where
    WR::Word: DoubleType,
{
    type Error = <WR as WordSeek>::Error;

    /// Current position in bits: the bits of all words fetched so far, minus those
    /// still waiting in the buffer.
    #[inline]
    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
        let words_fetched = self.backend.word_pos()?;
        let pending = self.bits_in_buffer as u64;
        Ok(words_fetched * WR::Word::BITS as u64 - pending)
    }

    /// Moves the reader to bit `bit_index`.
    ///
    /// The backend is positioned on the containing word. If the index is not
    /// word-aligned, that word is read and its unread bits are buffered, left-aligned.
    #[inline]
    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
        let word_bits = WR::Word::BITS as u64;
        let word_index = bit_index / word_bits;
        let bit_offset = (bit_index % word_bits) as usize;
        self.backend.set_word_pos(word_index)?;
        self.buffer = BB::<WR>::ZERO;
        self.bits_in_buffer = 0;
        if bit_offset == 0 {
            return Ok(());
        }
        let word: BB<WR> = self.backend.read_word()?.to_be().upcast();
        self.bits_in_buffer = WR::Word::BITS - bit_offset;
        self.buffer = word << (BB::<WR>::BITS - self.bits_in_buffer);
        Ok(())
    }
}
impl<WR: WordRead, RP: ReadParams> BufBitReader<LE, WR, RP>
where
    WR::Word: DoubleType,
{
    /// Appends one little-endian word from the backend just above the valid bits.
    ///
    /// The buffer must have room for a whole word. The word is OR-ed in, so the bits
    /// above the valid ones must be zero on entry. If the backend fails, the state is
    /// left untouched.
    #[inline(always)]
    fn refill(&mut self) -> Result<(), <WR as WordRead>::Error> {
        debug_assert!(BB::<WR>::BITS - self.bits_in_buffer >= WR::Word::BITS);
        let incoming: BB<WR> = self.backend.read_word()?.to_le().upcast();
        let positioned = incoming << self.bits_in_buffer;
        self.buffer |= positioned;
        self.bits_in_buffer += WR::Word::BITS;
        Ok(())
    }
}
impl<WR: WordRead, RP: ReadParams> BitRead<LE> for BufBitReader<LE, WR, RP>
where
    WR::Word: DoubleType + UpcastableInto<u64>,
    BB<WR>: CastableInto<u64>,
{
    type Error = <WR as WordRead>::Error;
    type PeekWord = BB<WR>;
    #[inline(always)]
    fn peek_bits(&mut self, n_bits: usize) -> Result<Self::PeekWord, Self::Error> {
        debug_assert!(n_bits > 0);
        debug_assert!(n_bits <= Self::PeekWord::BITS);
        if n_bits > self.bits_in_buffer {
            self.refill()?;
        }
        debug_assert!(n_bits <= self.bits_in_buffer);
        let shamt = BB::<WR>::BITS - n_bits;
        Ok((self.buffer << shamt) >> shamt)
    }
    #[inline(always)]
    fn skip_bits_after_table_lookup(&mut self, n_bits: usize) {
        self.bits_in_buffer -= n_bits;
        self.buffer >>= n_bits;
    }
    #[inline]
    fn read_bits(&mut self, mut n_bits: usize) -> Result<u64, Self::Error> {
        debug_assert!(n_bits <= 64);
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        if n_bits <= self.bits_in_buffer {
            let result: u64 = (self.buffer & ((BB::<WR>::ONE << n_bits) - BB::<WR>::ONE)).cast();
            self.bits_in_buffer -= n_bits;
            self.buffer >>= n_bits;
            return Ok(result);
        }
        let mut result: u64 = self.buffer.cast();
        let mut bits_in_res = self.bits_in_buffer;
        while n_bits > WR::Word::BITS + bits_in_res {
            let new_word: u64 = self.backend.read_word()?.to_le().upcast();
            result |= new_word << bits_in_res;
            bits_in_res += WR::Word::BITS;
        }
        n_bits -= bits_in_res;
        debug_assert!(n_bits > 0);
        debug_assert!(n_bits <= WR::Word::BITS);
        let new_word = self.backend.read_word()?.to_le();
        self.bits_in_buffer = WR::Word::BITS - n_bits;
        let shamt = 64 - n_bits;
        let upcasted: u64 = new_word.upcast();
        let final_bits: u64 = ((upcasted << shamt) >> shamt).downcast();
        result |= final_bits << bits_in_res;
        self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> n_bits;
        Ok(result)
    }
    #[inline]
    fn read_unary(&mut self) -> Result<u64, Self::Error> {
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        let zeros: usize = self.buffer.trailing_zeros() as usize;
        if zeros < self.bits_in_buffer {
            self.buffer = self.buffer >> zeros >> 1;
            self.bits_in_buffer -= zeros + 1;
            return Ok(zeros as u64);
        }
        let mut result: u64 = self.bits_in_buffer as _;
        loop {
            let new_word = self.backend.read_word()?.to_le();
            if new_word != WR::Word::ZERO {
                let zeros: usize = new_word.trailing_zeros() as _;
                self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> zeros >> 1;
                self.bits_in_buffer = WR::Word::BITS - zeros - 1;
                return Ok(result + zeros as u64);
            }
            result += WR::Word::BITS as u64;
        }
    }
    #[inline]
    fn skip_bits(&mut self, mut n_bits: usize) -> Result<(), Self::Error> {
        debug_assert!(self.bits_in_buffer < BB::<WR>::BITS);
        if n_bits <= self.bits_in_buffer {
            self.bits_in_buffer -= n_bits;
            self.buffer >>= n_bits;
            return Ok(());
        }
        n_bits -= self.bits_in_buffer;
        while n_bits > WR::Word::BITS {
            let _ = self.backend.read_word()?;
            n_bits -= WR::Word::BITS;
        }
        let new_word = self.backend.read_word()?.to_le();
        self.bits_in_buffer = WR::Word::BITS - n_bits;
        self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> n_bits;
        Ok(())
    }
    #[cfg(not(feature = "no_copy_impls"))]
    fn copy_to<F: Endianness, W: BitWrite<F>>(
        &mut self,
        bit_write: &mut W,
        mut n: u64,
    ) -> Result<(), CopyError<Self::Error, W::Error>> {
        let from_buffer = Ord::min(n, self.bits_in_buffer as _);
        bit_write
            .write_bits(self.buffer.cast(), from_buffer as usize)
            .map_err(CopyError::WriteError)?;
        self.buffer >>= from_buffer;
        n -= from_buffer;
        if n == 0 {
            self.bits_in_buffer -= from_buffer as usize;
            return Ok(());
        }
        while n > WR::Word::BITS as u64 {
            bit_write
                .write_bits(
                    self.backend
                        .read_word()
                        .map_err(CopyError::ReadError)?
                        .to_le()
                        .upcast(),
                    WR::Word::BITS,
                )
                .map_err(CopyError::WriteError)?;
            n -= WR::Word::BITS as u64;
        }
        assert!(n > 0);
        let new_word = self
            .backend
            .read_word()
            .map_err(CopyError::ReadError)?
            .to_le();
        self.bits_in_buffer = WR::Word::BITS - n as usize;
        bit_write
            .write_bits(new_word.upcast(), n as usize)
            .map_err(CopyError::WriteError)?;
        self.buffer = UpcastableInto::<BB<WR>>::upcast(new_word) >> n;
        Ok(())
    }
}
impl<
        E: Error + Send + Sync + 'static,
        WR: WordRead<Error = E> + WordSeek<Error = E>,
        RP: ReadParams,
    > BitSeek for BufBitReader<LE, WR, RP>
where
    WR::Word: DoubleType,
{
    type Error = <WR as WordSeek>::Error;

    /// Current position in bits: the bits of all words fetched so far, minus those
    /// still waiting in the buffer.
    #[inline]
    fn bit_pos(&mut self) -> Result<u64, Self::Error> {
        let fetched_bits = self.backend.word_pos()? * WR::Word::BITS as u64;
        Ok(fetched_bits - self.bits_in_buffer as u64)
    }

    /// Moves the reader to bit `bit_index`.
    ///
    /// The backend is positioned on the containing word. If the index is not
    /// word-aligned, that word is read and its unread high bits are buffered.
    #[inline]
    fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
        let word_bits = WR::Word::BITS as u64;
        self.backend.set_word_pos(bit_index / word_bits)?;
        let skip = (bit_index % word_bits) as usize;
        self.bits_in_buffer = 0;
        self.buffer = BB::<WR>::ZERO;
        if skip != 0 {
            let word: BB<WR> = self.backend.read_word()?.to_le().upcast();
            self.buffer = word >> skip;
            self.bits_in_buffer = WR::Word::BITS - skip;
        }
        Ok(())
    }
}
/// Generates a round-trip test named `$f` for `BufBitReader` over words of type `$word`.
///
/// The test writes a pseudo-random mix of gamma, delta, fixed-width and unary codes with
/// a big-endian and a little-endian `BufBitWriter`. It then reads them back with
/// `BufBitReader`s, driving the same seeded RNG, and checks every value.
macro_rules! test_buf_bit_reader {
    ($f: ident, $word:ty) => {
        #[test]
        fn $f() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
            use super::MemWordWriterVec;
            use crate::{
                codes::{GammaRead, GammaWrite},
                prelude::{
                    len_delta, len_gamma, BufBitWriter, DeltaRead, DeltaWrite, MemWordReader,
                },
            };
            use rand::Rng;
            use rand::{rngs::SmallRng, SeedableRng};
            let mut buffer_be: Vec<$word> = vec![];
            let mut buffer_le: Vec<$word> = vec![];
            let mut big = super::BufBitWriter::<BE, _>::new(MemWordWriterVec::new(&mut buffer_be));
            let mut little = BufBitWriter::<LE, _>::new(MemWordWriterVec::new(&mut buffer_le));
            let mut r = SmallRng::seed_from_u64(0);
            const ITER: usize = 1_000_000;
            for _ in 0..ITER {
                let value = r.gen_range(0..128);
                assert_eq!(big.write_gamma(value)?, len_gamma(value));
                let value = r.gen_range(0..128);
                assert_eq!(little.write_gamma(value)?, len_gamma(value));
                let value = r.gen_range(0..128);
                assert_eq!(big.write_gamma(value)?, len_gamma(value));
                let value = r.gen_range(0..128);
                assert_eq!(little.write_gamma(value)?, len_gamma(value));
                let value = r.gen_range(0..128);
                assert_eq!(big.write_delta(value)?, len_delta(value));
                let value = r.gen_range(0..128);
                assert_eq!(little.write_delta(value)?, len_delta(value));
                let value = r.gen_range(0..128);
                assert_eq!(big.write_delta(value)?, len_delta(value));
                let value = r.gen_range(0..128);
                assert_eq!(little.write_delta(value)?, len_delta(value));
                let n_bits = r.gen_range(0..=64);
                if n_bits == 0 {
                    big.write_bits(0, 0)?;
                } else {
                    big.write_bits(1, n_bits)?;
                }
                let n_bits = r.gen_range(0..=64);
                if n_bits == 0 {
                    little.write_bits(0, 0)?;
                } else {
                    little.write_bits(1, n_bits)?;
                }
                let value = r.gen_range(0..128);
                assert_eq!(big.write_unary(value)?, value as usize + 1);
                let value = r.gen_range(0..128);
                assert_eq!(little.write_unary(value)?, value as usize + 1);
            }
            // Dropping the writers ends their mutable borrows of the buffers (and
            // presumably flushes any pending bits).
            drop(big);
            drop(little);
            // Plain borrows: no reinterpretation of the word vectors is needed.
            let be_words: &[$word] = &buffer_be;
            let le_words: &[$word] = &buffer_le;
            let mut big_buff = BufBitReader::<BE, _>::new(MemWordReader::new(be_words));
            let mut little_buff = BufBitReader::<LE, _>::new(MemWordReader::new(le_words));
            // Re-seed so the reads replay exactly the values drawn while writing.
            let mut r = SmallRng::seed_from_u64(0);
            for _ in 0..ITER {
                assert_eq!(big_buff.read_gamma()?, r.gen_range(0..128));
                assert_eq!(little_buff.read_gamma()?, r.gen_range(0..128));
                assert_eq!(big_buff.read_gamma()?, r.gen_range(0..128));
                assert_eq!(little_buff.read_gamma()?, r.gen_range(0..128));
                assert_eq!(big_buff.read_delta()?, r.gen_range(0..128));
                assert_eq!(little_buff.read_delta()?, r.gen_range(0..128));
                assert_eq!(big_buff.read_delta()?, r.gen_range(0..128));
                assert_eq!(little_buff.read_delta()?, r.gen_range(0..128));
                let n_bits = r.gen_range(0..=64);
                if n_bits == 0 {
                    assert_eq!(big_buff.read_bits(0)?, 0);
                } else {
                    assert_eq!(big_buff.read_bits(n_bits)?, 1);
                }
                let n_bits = r.gen_range(0..=64);
                if n_bits == 0 {
                    assert_eq!(little_buff.read_bits(0)?, 0);
                } else {
                    assert_eq!(little_buff.read_bits(n_bits)?, 1);
                }
                assert_eq!(big_buff.read_unary()?, r.gen_range(0..128));
                assert_eq!(little_buff.read_unary()?, r.gen_range(0..128));
            }
            Ok(())
        }
    };
}
// Round-trip tests for 16-, 32- and 64-bit backend words.
test_buf_bit_reader!(test_u16, u16);
test_buf_bit_reader!(test_u32, u32);
test_buf_bit_reader!(test_u64, u64);