use std::io;
/// Bitwise writer formed by wrapping a bytewise [`io::Write`].
///
/// Bits are packed into bytes starting at the most significant bit. A trailing
/// partial byte is zero-padded when [`BitWriter::flush_all`] is called.
#[derive(Debug)]
pub struct BitWriter<W: io::Write> {
    /// Underlying byte writer.
    w: W,
    /// Byte holding bits that have not been handed to `w` yet, filled from the
    /// most significant bit downwards.
    cache: u8,
    /// Number of bits currently held in `cache` (at most 8).
    cache_len: usize,
    /// Total number of bits written so far, not counting padding.
    total_written: usize,
}
impl<W: io::Write> From<W> for BitWriter<W> {
fn from(w: W) -> Self {
BitWriter {
w,
cache: 0,
cache_len: 0,
total_written: 0,
}
}
}
impl<W: io::Write> io::Write for BitWriter<W> {
    /// Write every byte of `buf` as eight bits, most significant bit first.
    ///
    /// The bytes do not have to be aligned with previously written bits; they
    /// are appended to the bit stream wherever it currently stands.
    ///
    /// Returns `buf.len()` on success, or the first error reported by the
    /// underlying writer.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        for b in buf {
            for i in 0..8 {
                self.write_bit((b & (1 << (7 - i))) != 0)?;
            }
        }
        Ok(buf.len())
    }

    /// Flush the underlying writer.
    ///
    /// A completely filled cache byte is written out first. Such a byte is
    /// otherwise only emitted when the next bit arrives, so without this step
    /// it would stay buffered across a flush. A partially filled byte stays
    /// cached, because emitting it would require padding. Use
    /// `BitWriter::flush_all` for that.
    fn flush(&mut self) -> io::Result<()> {
        if self.cache_len >= 8 {
            self.w.write_all(&[self.cache])?;
            self.cache_len = 0;
            self.cache = 0;
        }
        self.w.flush()
    }
}
impl<W: io::Write> BitWriter<W> {
    /// Create a bit writer that writes to `w`.
    ///
    /// Equivalent to [`BitWriter::from`].
    pub fn new(w: W) -> BitWriter<W> {
        BitWriter::from(w)
    }

    /// Write a single bit.
    ///
    /// Bits fill each byte starting from the most significant position. A
    /// completed byte is passed to the underlying writer when the next bit
    /// arrives; [`BitWriter::flush_all`] writes out whatever remains.
    ///
    /// # Errors
    ///
    /// Returns any error from the underlying writer while emitting a completed
    /// byte. In that case the bit is not recorded and the cached byte is kept.
    pub fn write_bit(&mut self, b: bool) -> io::Result<()> {
        // The cache is full: hand it off before starting a new byte.
        if self.cache_len >= 8 {
            self.write_cache()?;
        }
        self.cache_len += 1;
        self.total_written += 1;
        if b {
            // `cache_len` is now in 1..=8, so the shift amount is in 0..=7.
            self.cache |= 1 << (8 - self.cache_len);
        }
        Ok(())
    }

    /// Write any cached bits, zero-padding a partial final byte, and then
    /// flush the underlying writer.
    ///
    /// Padding bits are not counted by [`BitWriter::n_total_written`]. After
    /// this call, the next bit starts a fresh byte.
    ///
    /// # Errors
    ///
    /// Returns any error from writing to or flushing the underlying writer.
    pub fn flush_all(&mut self) -> io::Result<()> {
        if self.cache_len > 0 {
            self.write_cache()?;
        }
        io::Write::flush(&mut self.w)
    }

    /// Number of bits written so far, excluding padding.
    pub fn n_total_written(&self) -> usize {
        self.total_written
    }

    /// Write the `len` least-significant bits of `n`, most significant bit
    /// first.
    ///
    /// Returns `len`, the number of bits written.
    ///
    /// # Panics
    ///
    /// Panics if `len > 64`, since `n` only has 64 bits.
    ///
    /// # Errors
    ///
    /// Returns the first error from the underlying writer.
    pub fn write_bits_be(&mut self, n: u64, len: usize) -> io::Result<usize> {
        // Without this check, `len > 64` overflows the shift: a panic in debug
        // builds, and garbage bits in release builds.
        assert!(len <= 64, "cannot write {} bits of a 64-bit value", len);
        for i in (0..len).rev() {
            let bit = (n >> i) & 1 == 1;
            self.write_bit(bit)?;
        }
        Ok(len)
    }

    /// Hand the cached byte, including any zero padding, to the underlying
    /// writer and clear the cache. The cache is left untouched on error.
    fn write_cache(&mut self) -> io::Result<()> {
        self.w.write_all(&[self.cache])?;
        self.cache_len = 0;
        self.cache = 0;
        Ok(())
    }
}
/// Run `f` against a fresh [`BitWriter`] backed by a `Vec<u8>` and return the
/// bytes it produced.
///
/// A trailing partial byte is zero-padded via [`BitWriter::flush_all`]. The
/// `usize` returned by `f` is ignored.
///
/// # Panics
///
/// Panics if `f` returns an error. Writing into a `Vec<u8>` does not fail, so
/// this only happens when `f` produces an error of its own.
pub fn write_to_vec<F>(f: F) -> Vec<u8>
where
    F: FnOnce(&mut BitWriter<&mut Vec<u8>>) -> io::Result<usize>,
{
    let mut buffer = Vec::new();
    {
        let mut writer = BitWriter::new(&mut buffer);
        f(&mut writer)
            .and_then(|_bits| writer.flush_all())
            .expect("I/O to vector never fails");
    }
    buffer
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet::Core;
    use crate::node::CoreConstructible;
    use crate::types;
    use crate::ConstructNode;
    use std::sync::Arc;

    /// Encoding a trivial program through `write_to_vec` must not panic.
    #[test]
    fn vec() {
        types::Context::with_context(|ctx| {
            let program = Arc::<ConstructNode<Core>>::unit(&ctx);
            let _ = write_to_vec(|w| program.encode_without_witness(w));
        })
    }

    /// Writing nothing produces no bytes (no padding byte is emitted).
    #[test]
    fn empty_vec() {
        let vec = write_to_vec(|_| Ok(0));
        assert!(vec.is_empty());
    }

    /// Bits fill bytes from the most significant bit, and a partial final
    /// byte is zero-padded.
    #[test]
    fn bits_are_msb_first_and_zero_padded() {
        let bytes = write_to_vec(|w| {
            w.write_bit(true)?;
            w.write_bit(false)?;
            w.write_bit(true)?;
            Ok(3)
        });
        assert_eq!(bytes, [0b1010_0000]);
    }

    /// `write_bits_be` writes the low `len` bits of its value, most significant
    /// first, across byte boundaries.
    #[test]
    fn write_bits_be_spans_bytes() {
        let bytes = write_to_vec(|w| w.write_bits_be(0x1_2345, 20));
        assert_eq!(bytes, [0x12, 0x34, 0x50]);
    }

    /// Exactly eight bits produce exactly one byte, without an extra padding byte.
    #[test]
    fn exact_byte_has_no_extra_padding() {
        let bytes = write_to_vec(|w| w.write_bits_be(0xA5, 8));
        assert_eq!(bytes, [0xA5]);
    }

    /// The bit counter excludes padding added by `flush_all`.
    #[test]
    fn total_written_counts_bits_not_padding() {
        let mut bytes = Vec::new();
        let mut w = BitWriter::new(&mut bytes);
        w.write_bits_be(0b101, 3).unwrap();
        w.write_bit(true).unwrap();
        assert_eq!(w.n_total_written(), 4);
        w.flush_all().unwrap();
        assert_eq!(w.n_total_written(), 4);
        assert_eq!(bytes, [0b1011_0000]);
    }

    /// Bytes written through `io::Write` come out unchanged when aligned.
    #[test]
    fn io_write_preserves_aligned_bytes() {
        let bytes = write_to_vec(|w| {
            std::io::Write::write_all(w, &[0xde, 0xad, 0xbe, 0xef])?;
            Ok(32)
        });
        assert_eq!(bytes, [0xde, 0xad, 0xbe, 0xef]);
    }
}