1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use std::fmt::{Debug, Formatter};
use std::fmt;

use crate::bits;
use crate::errors::MisalignedBitReaderError;

/// `LEFT_MASKS[k]` clears the `k` most significant bits of a byte and keeps
/// the remaining `8 - k` low bits (i.e. it equals `0xff >> k`).
///
/// `read_u64` indexes this table with the in-byte bit offset to drop the bits
/// of the current byte that sit before the cursor.
const LEFT_MASKS: [u8; 8] = [
  0b1111_1111,
  0b0111_1111,
  0b0011_1111,
  0b0001_1111,
  0b0000_1111,
  0b0000_0111,
  0b0000_0011,
  0b0000_0001,
];
/// `RIGHT_MASKS[k]` keeps only the `k` most significant bits of a byte and
/// clears the remaining `8 - k` low bits (i.e. it equals `!(0xff >> k)`).
///
/// `read_u64` indexes this table with the in-byte end offset of a read to drop
/// the bits of a byte that lie past the end of the requested range.
const RIGHT_MASKS: [u8; 8] = [
  0b0000_0000,
  0b1000_0000,
  0b1100_0000,
  0b1110_0000,
  0b1111_0000,
  0b1111_1000,
  0b1111_1100,
  0b1111_1110,
];

/// Sequential reader that consumes an owned byte buffer bit by bit.
///
/// The position is the pair (`i`, `j`). The cursor advances lazily: after the
/// last bit of a byte is consumed `j` is left at 8 and `i` still points at
/// that byte; the next read that goes through `refresh_if_needed` (or the
/// loop in `read`) moves to byte `i + 1` and resets `j` to 0.
#[derive(Clone)]
pub struct BitReader {
  /// The buffer being read.
  bytes: Vec<u8>,
  /// Index into `bytes` of the byte currently being consumed.
  i: usize,
  /// Number of bits of `bytes[i]` already consumed, in `0..=8`.
  /// `read_u64` treats these as the most significant bits of the byte.
  j: usize,
}

/// Diagnostic rendering: byte index out of the buffer length, bit offset, and
/// the value of the current byte (or `OOB` when the byte index is past the
/// end of the buffer).
impl Debug for BitReader {
  fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
    // `get` folds the bounds check and the lookup into a single step.
    let current_info = match self.bytes.get(self.i) {
      Some(byte) => format!("current byte {}", byte),
      None => String::from("OOB"),
    };
    let total_bytes = self.bytes.len();

    write!(
      f,
      "BitReader(\n\tbyte {}/{} bit {}\n\t{}\n\t)",
      self.i, total_bytes, self.j, current_info,
    )
  }
}

impl From<Vec<u8>> for BitReader {
  fn from(bytes: Vec<u8>) -> BitReader {
    BitReader {
      bytes,
      i: 0,
      j: 0,
    }
  }
}

impl BitReader {
  /// Moves on to the next byte if the current one has been fully consumed.
  ///
  /// Reads leave `j == 8` after consuming the last bit of a byte instead of
  /// advancing eagerly, so the byte index only ever moves past a byte when
  /// another read is about to happen.
  fn refresh_if_needed(&mut self) {
    if self.j == 8 {
      self.i += 1;
      self.j = 0;
    }
  }

  /// Returns the next `n` whole bytes as a slice borrowed from the buffer and
  /// advances past them.
  ///
  /// `n == 0` yields an empty slice and consumes nothing.
  ///
  /// # Errors
  ///
  /// Returns [`MisalignedBitReaderError`] when the reader is not positioned on
  /// a byte boundary; the position is left unchanged in that case.
  ///
  /// # Panics
  ///
  /// Panics if fewer than `n` bytes remain in the buffer.
  pub fn read_bytes(&mut self, n: usize) -> Result<&[u8], MisalignedBitReaderError> {
    self.refresh_if_needed();

    if self.j != 0 {
      return Err(MisalignedBitReaderError {});
    }
    if n == 0 {
      // Without this guard `n - 1` below underflows: a panic in debug builds
      // and a wrapped byte index in release builds.
      let empty: &[u8] = &[];
      return Ok(empty);
    }

    let res = &self.bytes[self.i..self.i + n];
    // Park on the last byte returned, marked as exhausted, so that the byte
    // index stays inside the buffer until another read is requested.
    self.i += n - 1;
    self.j = 8;
    Ok(res)
  }

  /// Reads the next `n` bits, one `bool` per bit, in stream order.
  ///
  /// Each bit is extracted with `bits::bit_from_byte(byte, offset)`, exactly
  /// as [`BitReader::read_one`] does. `n == 0` returns an empty vector without
  /// touching the buffer.
  ///
  /// # Panics
  ///
  /// Panics if the read runs past the end of the buffer.
  pub fn read(&mut self, n: usize) -> Vec<bool> {
    // implementation not well optimized because this is only used in reading header
    (0..n).map(|_| self.read_one()).collect()
  }

  /// Reads the next `n` bits as an unsigned integer, first bit read being the
  /// most significant bit of the result.
  ///
  /// Within each byte, bits are consumed from the most significant end.
  /// `n == 0` returns 0 without touching the buffer. `n` is expected to be at
  /// most 64: for larger values the leading bits do not fit in the result
  /// (they are lost, or the shift overflows and panics in debug builds).
  ///
  /// # Panics
  ///
  /// Panics if the read runs past the end of the buffer.
  pub fn read_u64(&mut self, n: usize) -> u64 {
    if n == 0 {
      return 0;
    }

    self.refresh_if_needed();

    let n_plus_j = n + self.j;
    if n_plus_j < 8 {
      // it's all in the current byte: drop the bits before the cursor and
      // after the end of the range, then right-align what is left
      let shift = 8 - n_plus_j;
      let res = ((self.bytes[self.i] & LEFT_MASKS[self.j] & RIGHT_MASKS[n_plus_j]) >> shift) as u64;
      self.j = n_plus_j;
      res
    } else {
      let mut res = 0u64;
      let mut remaining = n; // how many bits we still need to read

      // head: whatever is left of the current byte
      remaining -= 8 - self.j;
      res |= ((self.bytes[self.i] & LEFT_MASKS[self.j]) as u64) << remaining;

      // body: whole bytes
      while remaining >= 8 {
        self.i += 1;
        remaining -= 8;
        res |= (self.bytes[self.i] as u64) << remaining;
      }

      if remaining > 0 {
        // tail: the top `remaining` bits of the following byte
        self.i += 1;
        let shift = 8 - remaining;
        res |= ((self.bytes[self.i] & RIGHT_MASKS[remaining]) >> shift) as u64;
        self.j = remaining;
      } else {
        // ended exactly on a byte boundary; advance lazily on the next read
        self.j = 8;
      }
      res
    }
  }

  /// Reads a variable-length unsigned integer.
  ///
  /// The encoding is `jumpstart` bits holding the low bits of the value (read
  /// via [`BitReader::read_u64`]), followed by zero or more pairs of
  /// (continuation bit, data bit). While the continuation bit is set, the
  /// data bit that follows supplies the next higher bit of the value; a clear
  /// continuation bit ends the number.
  ///
  /// # Panics
  ///
  /// Panics if the read runs past the end of the buffer.
  pub fn read_varint(&mut self, jumpstart: usize) -> usize {
    let mut res = self.read_u64(jumpstart) as usize;
    // bit position that the next data bit will occupy
    let mut mask = 1usize << jumpstart;
    while self.read_one() {
      if self.read_one() {
        res |= mask;
      }
      mask <<= 1;
    }
    res
  }

  /// Reads a single bit and advances by one bit.
  ///
  /// The bit is extracted with `bits::bit_from_byte(current_byte, offset)`.
  /// NOTE(review): presumably that helper indexes from the most significant
  /// bit, matching `read_u64` — confirm against the `bits` module.
  ///
  /// # Panics
  ///
  /// Panics if the reader is already at the end of the buffer.
  pub fn read_one(&mut self) -> bool {
    self.refresh_if_needed();

    let res = bits::bit_from_byte(self.bytes[self.i], self.j);
    self.j += 1;
    res
  }
}