use std::fs;
use std::io;
use std::mem;
use std::io::Write;
use std::path;
use super::{Error, Result, Sample, SampleFormat, WavSpec};
/// Extension methods on `io::Write` for emitting fixed-width numbers.
///
/// Every multi-byte value is written least-significant byte first.
pub trait WriteExt: io::Write {
    /// Writes a single byte.
    fn write_u8(&mut self, x: u8) -> io::Result<()>;

    /// Writes a signed 16-bit integer in little-endian order.
    fn write_le_i16(&mut self, x: i16) -> io::Result<()>;

    /// Writes an unsigned 16-bit integer in little-endian order.
    fn write_le_u16(&mut self, x: u16) -> io::Result<()>;

    /// Writes the low 24 bits of a signed integer in little-endian order.
    fn write_le_i24(&mut self, x: i32) -> io::Result<()>;

    /// Writes the low 24 bits of an unsigned integer in little-endian order.
    fn write_le_u24(&mut self, x: u32) -> io::Result<()>;

    /// Writes a signed 32-bit integer in little-endian order.
    fn write_le_i32(&mut self, x: i32) -> io::Result<()>;

    /// Writes an unsigned 32-bit integer in little-endian order.
    fn write_le_u32(&mut self, x: u32) -> io::Result<()>;

    /// Writes the bit pattern of a 32-bit float in little-endian order.
    fn write_le_f32(&mut self, x: f32) -> io::Result<()>;
}

/// Blanket implementation: anything that is `io::Write` gets the helpers.
impl<W> WriteExt for W
    where W: io::Write
{
    #[inline(always)]
    fn write_u8(&mut self, x: u8) -> io::Result<()> {
        self.write_all(&[x])
    }

    #[inline(always)]
    fn write_le_i16(&mut self, x: i16) -> io::Result<()> {
        // Reinterpret the two's-complement bits; the byte layout is shared.
        self.write_le_u16(x as u16)
    }

    #[inline(always)]
    fn write_le_u16(&mut self, x: u16) -> io::Result<()> {
        let bytes = [x as u8, (x >> 8) as u8];
        self.write_all(&bytes)
    }

    #[inline(always)]
    fn write_le_i24(&mut self, x: i32) -> io::Result<()> {
        self.write_le_u24(x as u32)
    }

    #[inline(always)]
    fn write_le_u24(&mut self, x: u32) -> io::Result<()> {
        // Only the three low bytes are emitted; the top byte is discarded.
        let bytes = [x as u8, (x >> 8) as u8, (x >> 16) as u8];
        self.write_all(&bytes)
    }

    #[inline(always)]
    fn write_le_i32(&mut self, x: i32) -> io::Result<()> {
        self.write_le_u32(x as u32)
    }

    #[inline(always)]
    fn write_le_u32(&mut self, x: u32) -> io::Result<()> {
        let bytes = [x as u8, (x >> 8) as u8, (x >> 16) as u8, (x >> 24) as u8];
        self.write_all(&bytes)
    }

    #[inline(always)]
    fn write_le_f32(&mut self, x: f32) -> io::Result<()> {
        // SAFETY: `f32` and `u32` have the same size and every bit pattern
        // is a valid `u32`.
        let bits: u32 = unsafe { mem::transmute(x) };
        self.write_le_u32(bits)
    }
}
/// Generates a bitmask with `channels` ones in the least significant bits.
///
/// The mask is 32 bits wide, so at most 32 channels can be represented;
/// for `channels >= 32` every bit is set instead of overflowing the shift
/// (which would panic in debug builds).
fn channel_mask(channels: u16) -> u32 {
    if channels >= 32 {
        0xffff_ffff
    } else {
        // `channels` is at most 31 here, so the shift cannot overflow.
        (1u32 << channels) - 1
    }
}
/// Checks `channel_mask` against a table of known channel counts.
#[test]
fn verify_channel_mask() {
    let cases: [(u16, u32); 5] = [(0, 0), (1, 1), (2, 3), (3, 7), (4, 15)];
    for &(channels, expected) in cases.iter() {
        assert_eq!(channel_mask(channels), expected);
    }
}
/// A writer that accepts samples and writes them in the WAVE format.
///
/// The header is written when the writer is constructed; its size fields
/// are filled in when the writer is finalized, either explicitly or on drop.
pub struct WavWriter<W>
where W: io::Write + io::Seek
{
/// Specification of the stream being written, as passed to `new`.
spec: WavSpec,
/// `spec.bits_per_sample` rounded up to a whole number of bytes.
bytes_per_sample: u16,
/// The underlying writer that receives the header and the sample data.
writer: W,
/// Number of bytes of sample data written so far.
data_bytes_written: u32,
/// Set once finalization has been attempted, so that drop does not repeat it.
finalized: bool,
/// Scratch buffer lent out to `SampleWriter16` by `get_i16_writer`.
sample_writer_buffer: Vec<u8>,
/// Whether the header uses the extensible format layout; true when there
/// are more than two channels or more than 16 bits per sample.
extensible: bool,
}
impl<W> WavWriter<W>
    where W: io::Write + io::Seek
{
    /// Creates a writer that writes the WAVE format to the underlying writer.
    ///
    /// The header for `spec` is written immediately, with placeholder sizes
    /// that are filled in on finalization. The writer must be seekable
    /// because finalization seeks back into the header.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the header to `writer` fails.
    ///
    /// # Panics
    ///
    /// Panics if `spec.sample_format` is `SampleFormat::Float` and
    /// `spec.bits_per_sample` is not 32.
    pub fn new(writer: W, spec: WavSpec) -> Result<WavWriter<W>> {
        let mut writer = WavWriter {
            spec: spec,
            // Round the bit depth up to a whole number of bytes.
            bytes_per_sample: (spec.bits_per_sample as f32 / 8.0).ceil() as u16,
            writer: writer,
            data_bytes_written: 0,
            sample_writer_buffer: Vec::new(),
            finalized: false,
            // The plain format tag only covers up to two channels of at most
            // 16 bits; anything else uses the extensible layout.
            extensible: spec.channels > 2 || spec.bits_per_sample > 16,
        };

        try!(writer.write_headers());

        Ok(writer)
    }

    /// Writes the RIFF, fmt and data chunk headers to the underlying writer.
    ///
    /// Both size fields are written as zero; `finalize_internal` patches them.
    fn write_headers(&mut self) -> io::Result<()> {
        // Large enough for the extensible header (68 bytes); the plain
        // header only uses the first 44 bytes.
        let mut header = [0u8; 68];

        {
            let mut buffer = io::Cursor::new(&mut header[..]);
            try!(buffer.write_all("RIFF".as_bytes()));
            // Placeholder for the RIFF chunk size.
            try!(buffer.write_le_u32(0));
            try!(buffer.write_all("WAVE".as_bytes()));
            try!(buffer.write_all("fmt ".as_bytes()));

            if self.extensible {
                try!(self.write_waveformatextensible(&mut buffer));
            } else {
                try!(self.write_waveformatex(&mut buffer));
            }

            try!(buffer.write_all("data".as_bytes()));
            // Placeholder for the data chunk size.
            try!(buffer.write_le_u32(0));
        }

        let header_len = if self.extensible { 68 } else { 44 };
        self.writer.write_all(&header[..header_len])
    }

    /// Writes the fields shared by both fmt layouts: channel count, sample
    /// rate, bytes per second and block align.
    fn write_waveformat(&self, buffer: &mut io::Cursor<&mut [u8]>) -> io::Result<()> {
        let spec = &self.spec;

        try!(buffer.write_le_u16(spec.channels));
        try!(buffer.write_le_u32(spec.sample_rate));

        // Bytes per block (one sample for every channel). Computed directly
        // rather than as `bytes_per_sec / sample_rate`, which would divide
        // by zero for a sample rate of zero.
        let block_align = self.bytes_per_sample as u32 * spec.channels as u32;
        let bytes_per_sec = spec.sample_rate * block_align;

        try!(buffer.write_le_u32(bytes_per_sec));
        try!(buffer.write_le_u16(block_align as u16));

        Ok(())
    }

    /// Writes the 16-byte fmt chunk body, preceded by its size.
    ///
    /// # Panics
    ///
    /// Panics for `SampleFormat::Float` with a bit depth other than 32.
    fn write_waveformatex(&mut self, buffer: &mut io::Cursor<&mut [u8]>) -> io::Result<()> {
        // Size of the fmt chunk body.
        try!(buffer.write_le_u32(16));

        // Format tag: 1 for integer samples, 3 for 32-bit float samples.
        match self.spec.sample_format {
            SampleFormat::Int => {
                try!(buffer.write_le_u16(1));
            },
            SampleFormat::Float => {
                if self.spec.bits_per_sample == 32 {
                    try!(buffer.write_le_u16(3));
                } else {
                    panic!("Invalid number of bits per sample. \
                            When writing SampleFormat::Float, \
                            bits_per_sample must be 32.");
                }
            },
        };

        try!(self.write_waveformat(buffer));

        try!(buffer.write_le_u16(self.spec.bits_per_sample));

        Ok(())
    }

    /// Writes the 40-byte extensible fmt chunk body, preceded by its size.
    ///
    /// # Panics
    ///
    /// Panics for `SampleFormat::Float` with a bit depth other than 32.
    fn write_waveformatextensible(&mut self, buffer: &mut io::Cursor<&mut [u8]>) -> io::Result<()> {
        // Size of the fmt chunk body.
        try!(buffer.write_le_u32(40));

        // Format tag marking the extensible layout.
        try!(buffer.write_le_u16(0xfffe));

        try!(self.write_waveformat(buffer));

        // Container size in bits: the byte-rounded sample size.
        try!(buffer.write_le_u16(self.bytes_per_sample as u16 * 8));

        // Size of the extension that follows (22 bytes).
        try!(buffer.write_le_u16(22));

        // Number of bits per sample as given in the spec.
        try!(buffer.write_le_u16(self.spec.bits_per_sample));

        try!(buffer.write_le_u32(channel_mask(self.spec.channels)));

        let subformat_guid = match self.spec.sample_format {
            SampleFormat::Int => super::KSDATAFORMAT_SUBTYPE_PCM,
            SampleFormat::Float => {
                if self.spec.bits_per_sample == 32 {
                    super::KSDATAFORMAT_SUBTYPE_IEEE_FLOAT
                } else {
                    panic!("Invalid number of bits per sample. \
                            When writing SampleFormat::Float, \
                            bits_per_sample must be 32.");
                }
            }
        };
        try!(buffer.write_all(&subformat_guid));

        Ok(())
    }

    /// Writes a single sample for one channel.
    ///
    /// The encoding is delegated to `Sample::write` with the bit depth from
    /// the spec. On success the running data size grows by one sample.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Sample::write`.
    #[inline]
    pub fn write_sample<S: Sample>(&mut self, sample: S) -> Result<()> {
        try!(sample.write(&mut self.writer, self.spec.bits_per_sample));
        self.data_bytes_written += self.bytes_per_sample as u32;
        Ok(())
    }

    /// Returns a writer that buffers exactly `num_samples` 16-bit samples
    /// before writing them to the underlying writer in one go.
    ///
    /// # Panics
    ///
    /// Panics if the spec's sample format is not `SampleFormat::Int` or its
    /// bit depth is not 16.
    pub fn get_i16_writer<'s>(&'s mut self,
                              num_samples: u32)
                              -> SampleWriter16<'s, W> {
        if self.spec.sample_format != SampleFormat::Int {
            panic!("When calling get_i16_writer, the sample format must be int.");
        }
        if self.spec.bits_per_sample != 16 {
            panic!("When calling get_i16_writer, the number of bits per sample must be 16.");
        }

        let num_bytes = num_samples as usize * 2;

        if self.sample_writer_buffer.len() < num_bytes {
            // The old contents are about to be overwritten, so allocate a
            // fresh buffer instead of growing the old one. It is zero-filled
            // rather than extended with `set_len`, which would expose
            // uninitialized memory.
            self.sample_writer_buffer = vec![0u8; num_bytes];
        }

        SampleWriter16 {
            writer: &mut self.writer,
            buffer: &mut self.sample_writer_buffer[..num_bytes],
            data_bytes_written: &mut self.data_bytes_written,
            index: 0,
        }
    }

    /// Flushes the writer and patches the size fields in the header.
    ///
    /// The writer is marked as finalized up front, so a failure here is not
    /// retried on drop. The header is updated before the completeness check,
    /// so it is patched even when `Error::UnfinishedSample` is returned.
    ///
    /// # Panics
    ///
    /// Panics with a division by zero if the bytes per sample or the channel
    /// count is zero.
    fn finalize_internal(&mut self) -> Result<()> {
        self.finalized = true;

        try!(self.writer.flush());

        // Offset of the data chunk's size field from the start of the header.
        let header_size = if self.extensible { 64 } else { 40 };

        // The RIFF size counts everything after its own 8-byte chunk header:
        // `header_size + 4` bytes of header minus those 8, plus the data.
        let file_size = self.data_bytes_written + (header_size - 4);

        // The RIFF size field sits right after the 4-byte "RIFF" tag.
        try!(self.writer.seek(io::SeekFrom::Start(4)));
        try!(self.writer.write_le_u32(file_size));
        try!(self.writer.seek(io::SeekFrom::Start(header_size as u64)));
        try!(self.writer.write_le_u32(self.data_bytes_written));

        // Every block must hold a sample for each channel.
        if (self.data_bytes_written / self.bytes_per_sample as u32)
            % self.spec.channels as u32 != 0 {
            return Err(Error::UnfinishedSample);
        }

        Ok(())
    }

    /// Writes the size fields into the header and flushes the writer.
    ///
    /// Dropping the writer performs the same work but discards the result;
    /// call this method to observe errors.
    ///
    /// # Errors
    ///
    /// Returns `Error::UnfinishedSample` if the number of samples written is
    /// not a multiple of the channel count, or the error of a failed flush,
    /// seek or write.
    pub fn finalize(mut self) -> Result<()> {
        self.finalize_internal()
    }
}
/// Finalizes the file on drop unless that has already been attempted.
impl<W> Drop for WavWriter<W>
    where W: io::Write + io::Seek
{
    fn drop(&mut self) {
        if self.finalized {
            return;
        }
        // `drop` has no way to report failure, so the result is discarded.
        let _outcome = self.finalize_internal();
    }
}
impl WavWriter<io::BufWriter<fs::File>> {
    /// Creates the file at `filename` and wraps it in a buffered `WavWriter`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created or if `WavWriter::new`
    /// fails.
    pub fn create<P: AsRef<path::Path>>(filename: P,
                                        spec: WavSpec)
                                        -> Result<WavWriter<io::BufWriter<fs::File>>> {
        let buffered = io::BufWriter::new(try!(fs::File::create(filename)));
        WavWriter::new(buffered, spec)
    }
}
/// A writer that buffers a fixed number of 16-bit samples and hands them to
/// the parent `WavWriter`'s underlying writer in a single write on `flush`.
pub struct SampleWriter16<'parent, W> where W: io::Write + io::Seek + 'parent {
/// The underlying writer borrowed from the parent `WavWriter`.
writer: &'parent mut W,
/// The slice of the parent's scratch buffer reserved for this writer.
buffer: &'parent mut [u8],
/// The parent's count of data bytes written, updated on `flush`.
data_bytes_written: &'parent mut u32,
/// Byte offset in `buffer` at which the next sample is stored.
index: u32,
}
impl<'parent, W: io::Write + io::Seek> SampleWriter16<'parent, W> {
    /// Stores a single sample in the buffer, in little-endian order.
    ///
    /// Nothing reaches the underlying writer until `flush` is called.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has no room left for another sample.
    #[inline(always)]
    pub fn write_sample<S: Sample>(&mut self, sample: S) {
        let start = self.index as usize;
        let capacity = self.buffer.len();
        // Phrased so that it cannot underflow when fewer than two bytes
        // were reserved (for example an empty buffer).
        assert!(start <= capacity && capacity - start >= 2,
                "Trying to write more samples than reserved for the sample writer.");

        let s = sample.as_i16() as u16;
        self.buffer[start] = s as u8;
        self.buffer[start + 1] = (s >> 8) as u8;
        self.index += 2;
    }

    /// Stores `value` at the current index in little-endian order, without
    /// bounds checks.
    ///
    /// The value is stored byte by byte: the buffer is a `u8` slice with no
    /// alignment guarantee, so it must not be written through a `*mut u16`.
    ///
    /// # Safety
    ///
    /// `self.index + 1` must be a valid index into `self.buffer`.
    #[inline]
    unsafe fn write_u16_le_unchecked(&mut self, value: u16) {
        let idx = self.index as usize;
        *self.buffer.get_unchecked_mut(idx) = value as u8;
        *self.buffer.get_unchecked_mut(idx + 1) = (value >> 8) as u8;
    }

    /// Like `write_sample`, but without checking that the buffer has room.
    ///
    /// # Safety
    ///
    /// The caller must not write more samples than the buffer was reserved
    /// for; doing so writes out of bounds.
    #[inline(always)]
    pub unsafe fn write_sample_unchecked<S: Sample>(&mut self, sample: S) {
        self.write_u16_le_unchecked(sample.as_i16() as u16);
        self.index += 2;
    }

    /// Writes the buffered samples to the underlying writer and adds their
    /// size to the parent's count of data bytes written.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to the underlying writer fails.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has not been filled completely.
    pub fn flush(self) -> Result<()> {
        if self.index as usize != self.buffer.len() {
            panic!("Insufficient samples written to the sample writer.");
        }

        try!(self.writer.write_all(&self.buffer));
        *self.data_bytes_written += self.buffer.len() as u32;
        Ok(())
    }
}
/// Finalizing with an incomplete final frame must yield `UnfinishedSample`.
#[test]
fn short_write_should_signal_error() {
    use SampleFormat;

    let spec = WavSpec {
        channels: 17,
        sample_rate: 48000,
        bits_per_sample: 8,
        sample_format: SampleFormat::Int,
    };
    let mut sink = io::Cursor::new(Vec::new());
    let mut wav = WavWriter::new(&mut sink, spec).unwrap();

    // One sample short of five complete 17-channel frames.
    let sample_count = 17 * 5 - 1;
    for value in 0..sample_count {
        wav.write_sample(value as i16).unwrap();
    }

    let failure = wav.finalize().err().unwrap();
    if let Error::UnfinishedSample = failure {
        // This is the expected outcome.
    } else {
        panic!("UnfinishedSample error should have been returned.");
    }
}
/// Samples that do not fit in the configured bit depth must be rejected,
/// while the largest value that does fit must be accepted.
#[test]
fn wide_write_should_signal_error() {
    let narrow = WavSpec {
        channels: 1,
        sample_rate: 44100,
        bits_per_sample: 8,
        sample_format: SampleFormat::Int,
    };
    let medium = WavSpec { bits_per_sample: 16, ..narrow };
    let wide = WavSpec { bits_per_sample: 24, ..narrow };

    let mut sink = io::Cursor::new(Vec::new());

    // 8 bits: 127 is the largest value that fits, whatever the input type.
    {
        let mut wav = WavWriter::new(&mut sink, narrow).unwrap();
        assert!(wav.write_sample(127_i8).is_ok());
        assert!(wav.write_sample(127_i16).is_ok());
        assert!(wav.write_sample(127_i32).is_ok());
        assert!(wav.write_sample(128_i16).is_err());
        assert!(wav.write_sample(128_i32).is_err());
    }

    // 16 bits: 32767 fits, 32768 does not.
    {
        let mut wav = WavWriter::new(&mut sink, medium).unwrap();
        assert!(wav.write_sample(32767_i16).is_ok());
        assert!(wav.write_sample(32767_i32).is_ok());
        assert!(wav.write_sample(32768_i32).is_err());
    }

    // 24 bits: 8388607 fits, 8388608 does not.
    {
        let mut wav = WavWriter::new(&mut sink, wide).unwrap();
        assert!(wav.write_sample(8_388_607_i32).is_ok());
        assert!(wav.write_sample(8_388_608_i32).is_err());
    }
}