use core::mem;
use core::slice;
use core::convert::{TryFrom, TryInto};
use core::num::NonZeroU32;
use std::io::{ErrorKind, Error, Write, Seek, SeekFrom, Result};
use super::pulse::PulseDecodeWriter;
use super::{Header, checksum};
/// Size in bytes of the little-endian `u16` length prefix that precedes the data of every
/// TAP chunk.
const LEN_PREFIX_SIZE: u64 = mem::size_of::<u16>() as u64;
/// Writes TAP chunks to a seekable stream.
///
/// Room for each chunk's length prefix is reserved up front. The prefix is filled in once
/// the chunk's size is known.
pub struct TapChunkWriter<W> {
    /// Stream position of the length prefix of the chunk currently being assembled.
    chunk_head: u64,
    /// The wrapped writer, together with the pulse decoder used when pulses are turned
    /// into chunk data.
    mpwr: PulseDecodeWriter<W>
}
/// A transaction that assembles one TAP chunk from arbitrary writes.
///
/// It is created by `TapChunkWriter::begin` and finished with `TapChunkWriteTran::commit`.
/// If it is dropped while holding uncommitted bytes, the parent writer is repositioned to the
/// start of the current chunk's data, so subsequent chunk data is written over those bytes.
pub struct TapChunkWriteTran<'a, W: Write + Seek> {
    /// XOR of the checksums of every byte slice written so far in this transaction.
    pub checksum: u8,
    /// Number of chunks committed while beginning the transaction (a finalized pulse chunk).
    nchunks: usize,
    /// Number of bytes written but not yet committed as a chunk.
    uncommitted: u16,
    /// The chunk writer this transaction writes through.
    writer: &'a mut TapChunkWriter<W>
}
impl<W> TapChunkWriter<W> {
pub fn into_inner(self) -> PulseDecodeWriter<W> {
self.mpwr
}
pub fn get_mut(&mut self) -> &mut PulseDecodeWriter<W> {
&mut self.mpwr
}
pub fn get_ref(&self) -> &PulseDecodeWriter<W> {
&self.mpwr
}
}
impl<'a, W: Write + Seek> Drop for TapChunkWriteTran<'a, W> {
    /// Abandons any bytes written but not committed by this transaction.
    ///
    /// When uncommitted bytes exist, the underlying writer is seeked back to just past the
    /// current chunk's length prefix. Subsequent chunk data is then written over the abandoned
    /// bytes. The stream itself is not truncated.
    ///
    /// This is best effort: seek errors are ignored, because `drop` cannot report them.
    fn drop(&mut self) {
        if self.uncommitted == 0 {
            return;
        }
        // Never panic here: a panic inside `drop` while already unwinding aborts the process.
        if let Some(chunk_start) = self.writer.chunk_head.checked_add(LEN_PREFIX_SIZE) {
            let _ = self.writer.mpwr.get_mut().seek(SeekFrom::Start(chunk_start));
        }
    }
}
impl<'a, W: Write + Seek> Write for TapChunkWriteTran<'a, W> {
    /// Appends bytes to the chunk being assembled by this transaction.
    ///
    /// The whole buffer is refused with `ErrorKind::WriteZero` when accepting it could make
    /// the uncommitted chunk longer than `u16::MAX` bytes.
    ///
    /// Bytes actually accepted by the underlying writer are folded into `checksum` and counted
    /// as uncommitted.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        // Cannot overflow: `uncommitted` fits in a u16 and a slice length is at most isize::MAX.
        let prospective = usize::from(self.uncommitted) + buf.len();
        u16::try_from(prospective).map_err(|err| Error::new(ErrorKind::WriteZero, err))?;

        let accepted = self.writer.mpwr.get_mut().write(buf)?;
        let stored = &buf[..accepted];
        self.checksum ^= checksum(stored);
        self.uncommitted += accepted as u16;
        Ok(accepted)
    }

    /// Flushes the stream underlying the parent `TapChunkWriter`.
    fn flush(&mut self) -> Result<()> {
        self.writer.flush()
    }
}
impl<'a, W> TapChunkWriteTran<'a, W>
where W: Write + Seek
{
    /// Commits the data written during this transaction as a single TAP chunk.
    ///
    /// If `with_checksum` is `true`, the accumulated `checksum` byte is appended to the chunk
    /// first. When the chunk would be empty, nothing is committed.
    ///
    /// Returns the number of TAP chunks written. This count includes a pulse chunk that was
    /// finalized when the transaction began.
    ///
    /// # Errors
    /// - `ErrorKind::WriteZero` if appending the checksum would exceed the maximum chunk size.
    /// - Any I/O error from writing the checksum or committing the chunk.
    ///
    /// On error, the transaction is dropped with its data still uncommitted.
    pub fn commit(mut self, with_checksum: bool) -> Result<usize> {
        let mut nchunks = self.nchunks;
        let mut uncommitted = self.uncommitted;
        if with_checksum {
            // Check the size limit before writing anything.
            uncommitted = uncommitted.checked_add(1).ok_or_else(||
                Error::new(ErrorKind::WriteZero, "chunk is larger than the maximum allowed size"))?;
            let checksum = self.checksum;
            self.write_all(slice::from_ref(&checksum))?;
        }
        if let Some(size) = NonZeroU32::new(uncommitted.into()) {
            self.writer.commit_chunk(size)?;
            nchunks += 1;
        }
        // The chunk is now part of the stream. Clear the pending count so `Drop` does not
        // perform a redundant rollback seek once `self` goes out of scope.
        self.uncommitted = 0;
        Ok(nchunks)
    }
}
impl<W> TapChunkWriter<W>
where W: Write + Seek
{
    /// Creates a new `TapChunkWriter` on top of `wr`.
    ///
    /// The first chunk begins at the current position of `wr`. Room for its length prefix is
    /// reserved by seeking past it; the prefix is filled in when the chunk is committed.
    ///
    /// # Errors
    /// - Any error from the initial seek.
    /// - `ErrorKind::InvalidInput` if the position reported by the seek is smaller than the
    ///   length prefix.
    pub fn try_new(wr: W) -> Result<Self> {
        let mut mpwr = PulseDecodeWriter::new(wr);
        let chunk_start = mpwr.get_mut().seek(SeekFrom::Current(LEN_PREFIX_SIZE as i64))?;
        let chunk_head = chunk_start.checked_sub(LEN_PREFIX_SIZE)
                                    .ok_or_else(Self::position_error)?;
        Ok(TapChunkWriter { chunk_head, mpwr })
    }

    /// Flushes the stream wrapped by the pulse decoder.
    pub fn flush(&mut self) -> Result<()> {
        self.mpwr.get_mut().flush()
    }

    /// Finishes the pulse-decoded chunk in progress, if any, and commits it.
    ///
    /// Returns `1` if a chunk was committed and `0` if there was nothing to commit.
    pub fn end_pulse_chunk(&mut self) -> Result<usize> {
        match self.mpwr.end()? {
            Some(size) => {
                self.commit_chunk(size)?;
                Ok(1)
            }
            None => Ok(0)
        }
    }

    /// Writes `header` as a TAP chunk.
    ///
    /// Returns the number of chunks written. This includes a pending pulse chunk that had to
    /// be committed first.
    pub fn write_header(&mut self, header: &Header) -> Result<usize> {
        self.write_chunk(header.to_tap_chunk())
    }

    /// Writes `chunk` as a complete TAP chunk: a little-endian `u16` length followed by the data.
    ///
    /// Any pending pulse chunk is committed first. Returns the number of chunks written,
    /// including that pulse chunk.
    ///
    /// # Errors
    /// - `ErrorKind::InvalidData` if `chunk` is longer than `u16::MAX` bytes. Nothing is
    ///   written in that case.
    /// - Any I/O error from the underlying writer.
    pub fn write_chunk<D: AsRef<[u8]>>(&mut self, chunk: D) -> Result<usize> {
        let data = chunk.as_ref();
        let size = u16::try_from(data.len()).map_err(|_|
            Error::new(ErrorKind::InvalidData, "TAP chunk too large."))?;
        let nchunks = self.end_pulse_chunk()?;
        let wr = self.mpwr.get_mut();
        let chunk_head = wr.seek(SeekFrom::Start(self.chunk_head))?;
        debug_assert_eq!(chunk_head, self.chunk_head);
        wr.write_all(&size.to_le_bytes())?;
        wr.write_all(data)?;
        // Reserve room for the next chunk's length prefix.
        let chunk_start = wr.seek(SeekFrom::Current(LEN_PREFIX_SIZE as i64))?;
        self.chunk_head = chunk_start.checked_sub(LEN_PREFIX_SIZE)
                                     .ok_or_else(Self::position_error)?;
        Ok(nchunks + 1)
    }

    /// Starts a transaction that assembles a TAP chunk from arbitrary writes.
    ///
    /// Any pending pulse chunk is committed first. It is included in the count returned by
    /// `TapChunkWriteTran::commit`.
    pub fn begin(&mut self) -> Result<TapChunkWriteTran<'_, W>> {
        let nchunks = self.end_pulse_chunk()?;
        Ok(TapChunkWriteTran { checksum: 0, nchunks, uncommitted: 0, writer: self })
    }

    /// Feeds `iter` into the pulse decoder and commits every chunk the decoder completes.
    ///
    /// Returns the number of chunks committed by this call.
    ///
    /// NOTE(review): data the decoder has not yet completed into a chunk presumably stays
    /// pending until `end_pulse_chunk` is called. This depends on `PulseDecodeWriter`; verify.
    pub fn write_pulses_as_tap_chunks<I>(&mut self, mut iter: I) -> Result<usize>
        where I: Iterator<Item=NonZeroU32>
    {
        let mut chunks = 0;
        while let Some(size) = self.mpwr.write_decoded_pulses(iter.by_ref())? {
            self.commit_chunk(size)?;
            chunks += 1;
        }
        Ok(chunks)
    }

    /// Finalizes the current chunk, whose `size` data bytes have already been written.
    ///
    /// The steps are:
    /// 1. Write the chunk's length prefix.
    /// 2. Advance `chunk_head` past the prefix and the data.
    /// 3. Seek to where the next chunk's data will begin.
    fn commit_chunk(&mut self, size: NonZeroU32) -> Result<()> {
        let size = u16::try_from(size.get()).map_err(|_|
            Error::new(ErrorKind::InvalidData, "TAP chunk too large."))?;
        // Validate the position arithmetic before touching the stream, so an overflow is
        // reported without leaving a half-written prefix behind.
        let next_head = self.chunk_head
                            .checked_add(LEN_PREFIX_SIZE + u64::from(size))
                            .ok_or_else(Self::position_error)?;
        let next_start = next_head.checked_add(LEN_PREFIX_SIZE)
                                  .ok_or_else(Self::position_error)?;
        let wr = self.mpwr.get_mut();
        let chunk_head = wr.seek(SeekFrom::Start(self.chunk_head))?;
        debug_assert_eq!(chunk_head, self.chunk_head);
        wr.write_all(&size.to_le_bytes())?;
        self.chunk_head = next_head;
        let pos_next = wr.seek(SeekFrom::Start(next_start))?;
        debug_assert_eq!(pos_next, next_start);
        Ok(())
    }

    /// Builds the error reported when a stream position derived from a seek result is out of
    /// the representable `u64` range.
    fn position_error() -> Error {
        Error::new(ErrorKind::InvalidInput, "TAP chunk stream position out of range")
    }
}