use std::fmt;
use std::io::Read;
use std::io::Write;
use std::io::Seek;
use std::io::SeekFrom;
use std::convert::TryInto;
/// A four-byte chunk identifier (FourCC), stored as its raw bytes.
#[derive(Hash, Copy, Clone, Eq, PartialEq)]
pub struct ChunkId {
    pub value: [u8; 4],
}

/// The `RIFF` chunk id.
pub static RIFF_ID: ChunkId = ChunkId { value: *b"RIFF" };
/// The `LIST` chunk id.
pub static LIST_ID: ChunkId = ChunkId { value: *b"LIST" };
/// The `seqt` chunk id.
pub static SEQT_ID: ChunkId = ChunkId { value: *b"seqt" };
impl ChunkId {
pub fn as_str(&self) -> &str {
std::str::from_utf8(&self.value).unwrap()
}
pub fn new(s: &str) -> Result<ChunkId, &str> {
let bytes = s.as_bytes();
if bytes.len() != 4 {
Err("Invalid length")
} else {
let mut a: [u8; 4] = Default::default();
a.copy_from_slice(&bytes[..]);
Ok(ChunkId { value: a })
}
}
}
impl fmt::Display for ChunkId {
    /// Writes the id wrapped in single quotes, e.g. `'RIFF'`.
    ///
    /// Ids read from a stream can be arbitrary bytes. Invalid UTF-8 is rendered
    /// with U+FFFD replacement characters instead of panicking. Valid ids
    /// format exactly as their text.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "'{}'", String::from_utf8_lossy(&self.value))
    }
}
impl fmt::Debug for ChunkId {
    // Debug output reuses Display, so ids print as `'RIFF'` rather than as a
    // raw byte array.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
/// In-memory chunk tree; `ChunkContents::write` serializes it to a seekable writer.
#[derive(PartialEq, Debug)]
pub enum ChunkContents {
/// Leaf chunk: (id, payload bytes).
Data(ChunkId, Vec<u8>),
/// Container chunk: (id, four-byte list type written right after the length
/// field, child chunks).
Children(ChunkId, ChunkId, Vec<ChunkContents>),
/// Container chunk without a list-type field: (id, child chunks).
ChildrenNoType(ChunkId, Vec<ChunkContents>)
}
impl ChunkContents {
pub fn write<T>(&self, writer: &mut T) -> std::io::Result<u64>
where T: Seek + Write {
match &self {
&ChunkContents::Data(id, data) => {
if data.len() as u64 > u32::MAX as u64 {
use std::io::{Error, ErrorKind};
return Err(Error::new(ErrorKind::InvalidData, "Data too big"));
}
let len = data.len() as u32;
writer.write_all(&id.value)?;
writer.write_all(&len.to_le_bytes())?;
writer.write_all(&data)?;
if len % 2 != 0 {
let single_byte: [u8; 1] = [0];
writer.write_all(&single_byte)?;
}
Ok((8 + len + (len % 2)).into())
}
&ChunkContents::Children(id, chunk_type, children) => {
writer.write_all(&id.value)?;
let len_pos = writer.seek(SeekFrom::Current(0))?;
let zeros: [u8; 4] = [0, 0, 0, 0];
writer.write_all(&zeros)?;
writer.write_all(&chunk_type.value)?;
let mut total_len: u64 = 4;
for child in children {
total_len = total_len + child.write(writer)?;
}
if total_len > u32::MAX as u64 {
use std::io::{Error, ErrorKind};
return Err(Error::new(ErrorKind::InvalidData, "Data too big"));
}
let end_pos = writer.seek(SeekFrom::Current(0))?;
writer.seek(SeekFrom::Start(len_pos))?;
writer.write_all(&(total_len as u32).to_le_bytes())?;
writer.seek(SeekFrom::Start(end_pos))?;
Ok((8 + total_len + (total_len % 2)).into())
}
&ChunkContents::ChildrenNoType(id, children) => {
writer.write_all(&id.value)?;
let len_pos = writer.seek(SeekFrom::Current(0))?;
let zeros: [u8; 4] = [0, 0, 0, 0];
writer.write_all(&zeros)?;
let mut total_len: u64 = 0;
for child in children {
total_len = total_len + child.write(writer)?;
}
if total_len > u32::MAX as u64 {
use std::io::{Error, ErrorKind};
return Err(Error::new(ErrorKind::InvalidData, "Data too big"));
}
let end_pos = writer.seek(SeekFrom::Current(0))?;
writer.seek(SeekFrom::Start(len_pos))?;
writer.write_all(&(total_len as u32).to_le_bytes())?;
writer.seek(SeekFrom::Start(end_pos))?;
Ok((8 + total_len + (total_len % 2)).into())
}
}
}
}
/// Location and header of a chunk found in a seekable stream.
#[derive(PartialEq, Eq, Debug)]
pub struct Chunk {
/// Absolute stream offset of the chunk's 8-byte header.
pos: u64,
/// Four-byte identifier read from the header.
id: ChunkId,
/// Payload length from the header. Excludes the 8 header bytes and the pad
/// byte that follows an odd-length payload.
len: u32,
}
/// Iterator over consecutive sub-chunks of a parent chunk.
///
/// Each header is read lazily from `stream` as the iterator advances.
pub struct Iter<'a, T>
where
    T: Seek + Read,
{
    /// Stream the chunk headers are read from.
    stream: &'a mut T,
    /// Absolute offset of the next header to read.
    cur: u64,
    /// Iteration stops once `cur` reaches or passes this offset.
    end: u64,
}
impl<'a, T> Iterator for Iter<'a, T>
where
    T: Seek + Read,
{
    type Item = std::io::Result<Chunk>;

    /// Reads the chunk header at the current offset, then advances past the
    /// chunk's header, payload and pad byte.
    ///
    /// After an error the iterator is exhausted. Without a valid header there
    /// is no way to locate the next sibling, and retrying the same offset would
    /// typically produce the same error forever.
    fn next(&mut self) -> Option<Self::Item> {
        if self.cur >= self.end {
            return None;
        }
        match Chunk::read(&mut *self.stream, self.cur) {
            Ok(chunk) => {
                let len = u64::from(chunk.len());
                // 8-byte header + payload + pad byte for odd-length payloads.
                self.cur += 8 + len + len % 2;
                Some(Ok(chunk))
            }
            Err(err) => {
                self.cur = self.end;
                Some(Err(err))
            }
        }
    }
}
impl Chunk {
    /// Size of a chunk header in bytes: four-byte id plus little-endian `u32` length.
    const HEADER_LEN: u64 = 8;

    /// Returns the chunk's four-byte identifier.
    pub fn id(&self) -> ChunkId {
        self.id
    }

    /// Returns the payload length stored in the header.
    ///
    /// This excludes the 8-byte header itself and the pad byte that follows
    /// odd-length payloads.
    pub fn len(&self) -> u32 {
        self.len
    }

    /// Returns the absolute stream offset at which this chunk's header starts.
    pub fn offset(&self) -> u64 {
        self.pos
    }

    /// Reads the four bytes stored immediately after the header.
    ///
    /// For container chunks these bytes are the list type. For other chunks
    /// they are simply the first four payload bytes.
    ///
    /// # Errors
    /// Propagates seek/read errors, including `UnexpectedEof` if the stream
    /// ends early.
    pub fn read_type<T>(&self, stream: &mut T) -> std::io::Result<ChunkId>
    where
        T: Read + Seek,
    {
        stream.seek(SeekFrom::Start(self.pos + Self::HEADER_LEN))?;
        let mut fourcc = [0u8; 4];
        stream.read_exact(&mut fourcc)?;
        Ok(ChunkId { value: fourcc })
    }

    /// Reads the chunk header (id and length) located at absolute offset `pos`.
    ///
    /// # Errors
    /// Propagates seek/read errors, including `UnexpectedEof` when fewer than
    /// 8 bytes are available at `pos`.
    pub fn read<T>(stream: &mut T, pos: u64) -> std::io::Result<Chunk>
    where
        T: Read + Seek,
    {
        stream.seek(SeekFrom::Start(pos))?;
        let mut fourcc = [0u8; 4];
        stream.read_exact(&mut fourcc)?;
        let mut len = [0u8; 4];
        stream.read_exact(&mut len)?;
        Ok(Chunk {
            pos,
            id: ChunkId { value: fourcc },
            len: u32::from_le_bytes(len),
        })
    }

    /// Reads the `len()` payload bytes that follow the header.
    ///
    /// The pad byte, if any, is not included.
    ///
    /// # Errors
    /// Returns `UnexpectedEof` if the stream ends before `len()` bytes are
    /// read. Propagates any other I/O error.
    pub fn read_contents<T>(&self, stream: &mut T) -> std::io::Result<Vec<u8>>
    where
        T: Read + Seek,
    {
        use std::io::{Error, ErrorKind};
        stream.seek(SeekFrom::Start(self.pos + Self::HEADER_LEN))?;
        let expected = u64::from(self.len);
        let mut data = Vec::new();
        // The length comes from the stream itself. Reading through `take`
        // grows the buffer only as bytes actually arrive, so a corrupt header
        // claiming ~4 GiB cannot force a huge allocation up front.
        stream.by_ref().take(expected).read_to_end(&mut data)?;
        if (data.len() as u64) < expected {
            return Err(Error::new(ErrorKind::UnexpectedEof, "failed to fill whole buffer"));
        }
        Ok(data)
    }

    /// Iterates over the sub-chunks of a container whose payload starts with
    /// a four-byte list type (see `read_type`).
    pub fn iter<'a, T>(&self, stream: &'a mut T) -> Iter<'a, T>
    where
        T: Seek + Read,
    {
        self.sub_chunks(stream, Self::HEADER_LEN + 4)
    }

    /// Iterates over the sub-chunks of a container with no list-type field.
    ///
    /// Children start immediately after the header.
    pub fn iter_no_type<'a, T>(&self, stream: &'a mut T) -> Iter<'a, T>
    where
        T: Seek + Read,
    {
        self.sub_chunks(stream, Self::HEADER_LEN)
    }

    /// Builds an iterator whose first child header is `first_child` bytes past
    /// this chunk's start.
    ///
    /// The iterator only starts a child whose header fits inside this chunk's payload.
    fn sub_chunks<'a, T>(&self, stream: &'a mut T, first_child: u64) -> Iter<'a, T>
    where
        T: Seek + Read,
    {
        let payload_end = self.pos + Self::HEADER_LEN + u64::from(self.len);
        Iter {
            cur: self.pos + first_child,
            // `Iter` stops once `cur >= end`. Setting `end` to
            // `payload_end - (HEADER_LEN - 1)` means a child is read only when
            // its full 8-byte header still fits before `payload_end`.
            end: payload_end - (Self::HEADER_LEN - 1),
            stream,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunkid_from_str() {
        let known = [("RIFF", RIFF_ID), ("LIST", LIST_ID), ("seqt", SEQT_ID)];
        for &(text, id) in known.iter() {
            assert_eq!(ChunkId::new(text).unwrap(), id);
        }
        assert_eq!(
            ChunkId::new("123 ").unwrap(),
            ChunkId { value: [0x31, 0x32, 0x33, 0x20] }
        );
        assert_eq!(ChunkId::new("123"), Err("Invalid length"));
        assert_eq!(ChunkId::new("12345"), Err("Invalid length"));
    }

    #[test]
    fn chunkid_to_str() {
        let known = [(RIFF_ID, "RIFF"), (LIST_ID, "LIST"), (SEQT_ID, "seqt")];
        for &(id, text) in known.iter() {
            assert_eq!(id.as_str(), text);
        }
        assert_eq!(ChunkId::new("123 ").unwrap().as_str(), "123 ");
    }

    #[test]
    fn chunkid_format() {
        let known = [(RIFF_ID, "'RIFF'"), (LIST_ID, "'LIST'"), (SEQT_ID, "'seqt'")];
        for &(id, text) in known.iter() {
            assert_eq!(format!("{}", id), text);
            assert_eq!(format!("{:?}", id), text);
        }
    }

    /// Writes a small chunk tree and reads it back through `Chunk`/`Iter`.
    #[test]
    fn write_then_read_roundtrip() {
        let abcd = ChunkId::new("abcd").unwrap();
        let info = ChunkId::new("INFO").unwrap();
        let efgh = ChunkId::new("efgh").unwrap();
        let tree = ChunkContents::Children(
            RIFF_ID,
            SEQT_ID,
            vec![
                ChunkContents::Data(abcd, vec![1, 2, 3]),
                ChunkContents::Children(LIST_ID, info, vec![ChunkContents::Data(efgh, vec![])]),
            ],
        );

        let mut cursor = std::io::Cursor::new(Vec::new());
        let written = tree.write(&mut cursor).unwrap();
        assert_eq!(written, cursor.get_ref().len() as u64);

        let riff = Chunk::read(&mut cursor, 0).unwrap();
        assert_eq!(riff.id(), RIFF_ID);
        assert_eq!(u64::from(riff.len()) + 8, written);
        assert_eq!(riff.read_type(&mut cursor).unwrap(), SEQT_ID);

        let children: Vec<Chunk> = riff
            .iter(&mut cursor)
            .collect::<std::io::Result<_>>()
            .unwrap();
        assert_eq!(children.len(), 2);
        assert_eq!(children[0].id(), abcd);
        assert_eq!(children[0].read_contents(&mut cursor).unwrap(), vec![1, 2, 3]);
        assert_eq!(children[1].id(), LIST_ID);
        assert_eq!(children[1].read_type(&mut cursor).unwrap(), info);

        let grandchildren: Vec<Chunk> = children[1]
            .iter(&mut cursor)
            .collect::<std::io::Result<_>>()
            .unwrap();
        assert_eq!(grandchildren.len(), 1);
        assert_eq!(grandchildren[0].id(), efgh);
        assert_eq!(grandchildren[0].len(), 0);
    }
}