use std::fs;
use std::io;
use std::mem;
use std::io::{Seek, Write};
use std::mem::MaybeUninit;
use std::path;
use super::{Error, Result, Sample, SampleFormat, WavSpec, WavSpecEx};
use ::read;
/// Extends `io::Write` with helpers that encode integers and floats as
/// little-endian bytes, the byte order used throughout the WAVE format.
///
/// Every method encodes the value and hands the bytes to `write_all`, so a
/// short write is reported as an error rather than silently truncated.
pub trait WriteExt: io::Write {
    /// Writes a single byte.
    fn write_u8(&mut self, x: u8) -> io::Result<()>;
    /// Writes a two-byte signed integer, little-endian.
    fn write_le_i16(&mut self, x: i16) -> io::Result<()>;
    /// Writes a two-byte unsigned integer, little-endian.
    fn write_le_u16(&mut self, x: u16) -> io::Result<()>;
    /// Writes the low three bytes of `x` (two's complement), little-endian.
    fn write_le_i24(&mut self, x: i32) -> io::Result<()>;
    /// Writes the low three bytes of `x`, little-endian, followed by a zero
    /// byte: a 24-bit value stored in a 4-byte container.
    fn write_le_i24_4(&mut self, x: i32) -> io::Result<()>;
    /// Writes the low three bytes of `x`, little-endian; the high byte is dropped.
    fn write_le_u24(&mut self, x: u32) -> io::Result<()>;
    /// Writes a four-byte signed integer, little-endian.
    fn write_le_i32(&mut self, x: i32) -> io::Result<()>;
    /// Writes a four-byte unsigned integer, little-endian.
    fn write_le_u32(&mut self, x: u32) -> io::Result<()>;
    /// Writes the IEEE 754 bit pattern of `x` as four little-endian bytes.
    fn write_le_f32(&mut self, x: f32) -> io::Result<()>;
}

impl<W> WriteExt for W
where W: io::Write
{
    #[inline(always)]
    fn write_u8(&mut self, x: u8) -> io::Result<()> {
        self.write_all(&[x])
    }

    #[inline(always)]
    fn write_le_i16(&mut self, x: i16) -> io::Result<()> {
        self.write_all(&x.to_le_bytes())
    }

    #[inline(always)]
    fn write_le_u16(&mut self, x: u16) -> io::Result<()> {
        self.write_all(&x.to_le_bytes())
    }

    #[inline(always)]
    fn write_le_i24(&mut self, x: i32) -> io::Result<()> {
        self.write_le_u24(x as u32)
    }

    #[inline(always)]
    fn write_le_i24_4(&mut self, x: i32) -> io::Result<()> {
        // Clear the top byte so the padding byte of the 4-byte container is 0.
        self.write_le_u32((x as u32) & 0x00_ff_ff_ff)
    }

    #[inline(always)]
    fn write_le_u24(&mut self, x: u32) -> io::Result<()> {
        // Little-endian puts the least significant byte first, so the three
        // low-order bytes are exactly the first three of the 4-byte encoding.
        self.write_all(&x.to_le_bytes()[..3])
    }

    #[inline(always)]
    fn write_le_i32(&mut self, x: i32) -> io::Result<()> {
        self.write_all(&x.to_le_bytes())
    }

    #[inline(always)]
    fn write_le_u32(&mut self, x: u32) -> io::Result<()> {
        self.write_all(&x.to_le_bytes())
    }

    #[inline(always)]
    fn write_le_f32(&mut self, x: f32) -> io::Result<()> {
        // `to_bits` reinterprets the float's bit pattern without `unsafe`.
        self.write_le_u32(x.to_bits())
    }
}
/// Builds a WAVEFORMATEXTENSIBLE channel mask with the lowest `channels` bits set.
///
/// The count is capped at 18, so at most bits 0..=17 are ever set. This is
/// presumably because only 18 speaker positions have defined mask bits.
fn channel_mask(channels: u16) -> u32 {
    let assigned = u32::from(channels.min(18));
    // 2^n - 1 has exactly the n lowest bits set (and is 0 for n == 0).
    (1u32 << assigned) - 1
}
#[test]
fn verify_channel_mask() {
    // (channel count, expected mask); counts above 18 saturate at 18 bits.
    let cases: &[(u16, u32)] = &[
        (0, 0),
        (1, 1),
        (2, 3),
        (3, 7),
        (4, 0xF),
        (8, 0xFF),
        (16, 0xFFFF),
        (18, 0x3FFFF),
        (32, 0x3FFFF),
        (64, 0x3FFFF),
        (129, 0x3FFFF),
    ];
    for &(channels, expected) in cases {
        assert_eq!(channel_mask(channels), expected);
    }
}
/// A writer that accepts samples and writes the WAVE format.
///
/// The header is written when the writer is constructed. The RIFF size and the
/// data chunk length are patched in later by `flush`, `finalize`, or
/// (best-effort) when the writer is dropped without being finalized.
pub struct WavWriter<W>
where W: io::Write + io::Seek
{
    /// The spec that the header was written with.
    spec: WavSpec,
    /// Number of bytes each sample occupies in the data chunk. This can exceed
    /// the bytes strictly needed for `bits_per_sample`, e.g. 24-bit samples
    /// stored in 4-byte containers.
    bytes_per_sample: u16,
    /// The underlying writer; it must be seekable so the header can be patched.
    writer: W,
    /// Number of bytes written to the data chunk so far.
    data_bytes_written: u32,
    /// Set by `finalize` (which updates the header itself) so that `Drop`
    /// does not update the header a second time.
    finalized: bool,
    /// Scratch buffer lent to `SampleWriter16`. It is kept here so it can be
    /// reused across `get_i16_writer` calls, and only reallocated when too small.
    sample_writer_buffer: Vec<MaybeUninit<u8>>,
    /// Absolute byte offset of the data chunk's 4-byte length field.
    data_len_offset: u32,
}
/// Which layout of the `fmt ` chunk to write.
enum FmtKind {
    /// The 16-byte PCMWAVEFORMAT layout. It is used for at most two channels
    /// and at most 16 bits per sample.
    PcmWaveFormat,
    /// The 40-byte WAVEFORMATEXTENSIBLE layout (format tag 0xfffe). It is
    /// required for more than two channels or more than 16 bits per sample.
    WaveFormatExtensible,
}
impl<W> WavWriter<W>
where W: io::Write + io::Seek
{
pub fn new(writer: W, spec: WavSpec) -> Result<WavWriter<W>> {
let spec_ex = WavSpecEx {
spec: spec,
bytes_per_sample: (spec.bits_per_sample + 7) / 8,
};
WavWriter::new_with_spec_ex(writer, spec_ex)
}
pub fn new_with_spec_ex(writer: W, spec_ex: WavSpecEx) -> Result<WavWriter<W>> {
let spec = spec_ex.spec;
let fmt_kind = if spec.channels > 2 || spec.bits_per_sample > 16 {
FmtKind::WaveFormatExtensible
} else {
FmtKind::PcmWaveFormat
};
let mut writer = WavWriter {
spec: spec,
bytes_per_sample: spec_ex.bytes_per_sample,
writer: writer,
data_bytes_written: 0,
sample_writer_buffer: Vec::new(),
finalized: false,
data_len_offset: match fmt_kind {
FmtKind::WaveFormatExtensible => 64,
FmtKind::PcmWaveFormat => 40,
},
};
let supported = match spec.bits_per_sample {
8 => true,
16 => true,
24 => true,
32 => true,
_ => false,
};
if !supported {
return Err(Error::Unsupported)
}
try!(writer.write_headers(fmt_kind));
Ok(writer)
}
fn write_headers(&mut self, fmt_kind: FmtKind) -> io::Result<()> {
let mut header = [0u8; 68];
{
let mut buffer = io::Cursor::new(&mut header[..]);
try!(buffer.write_all("RIFF".as_bytes()));
try!(buffer.write_le_u32(0));
try!(buffer.write_all("WAVE".as_bytes()));
try!(buffer.write_all("fmt ".as_bytes()));
match fmt_kind {
FmtKind::PcmWaveFormat => {
try!(self.write_pcmwaveformat(&mut buffer));
}
FmtKind::WaveFormatExtensible => {
try!(self.write_waveformatextensible(&mut buffer));
}
}
try!(buffer.write_all("data".as_bytes()));
try!(buffer.write_le_u32(0));
}
let header_len = self.data_len_offset as usize + 4;
self.writer.write_all(&header[..header_len])
}
fn write_waveformat(&self, buffer: &mut io::Cursor<&mut [u8]>) -> io::Result<()> {
let spec = &self.spec;
try!(buffer.write_le_u16(spec.channels));
try!(buffer.write_le_u32(spec.sample_rate));
let bytes_per_sec = spec.sample_rate
* self.bytes_per_sample as u32
* spec.channels as u32;
try!(buffer.write_le_u32(bytes_per_sec));
try!(buffer.write_le_u16((bytes_per_sec / spec.sample_rate) as u16));
Ok(())
}
fn write_pcmwaveformat(&mut self, buffer: &mut io::Cursor<&mut [u8]>) -> io::Result<()> {
try!(buffer.write_le_u32(16));
match self.spec.sample_format {
SampleFormat::Int => {
try!(buffer.write_le_u16(1));
},
SampleFormat::Float => {
if self.spec.bits_per_sample == 32 {
try!(buffer.write_le_u16(3));
} else {
panic!("Invalid number of bits per sample. \
When writing SampleFormat::Float, \
bits_per_sample must be 32.");
}
},
};
try!(self.write_waveformat(buffer));
try!(buffer.write_le_u16(self.spec.bits_per_sample));
Ok(())
}
fn write_waveformatextensible(&mut self, buffer: &mut io::Cursor<&mut [u8]>) -> io::Result<()> {
try!(buffer.write_le_u32(40));
try!(buffer.write_le_u16(0xfffe));
try!(self.write_waveformat(buffer));
try!(buffer.write_le_u16(self.bytes_per_sample as u16 * 8));
try!(buffer.write_le_u16(22));
try!(buffer.write_le_u16(self.spec.bits_per_sample));
try!(buffer.write_le_u32(channel_mask(self.spec.channels)));
let subformat_guid = match self.spec.sample_format {
SampleFormat::Int => super::KSDATAFORMAT_SUBTYPE_PCM,
SampleFormat::Float => {
if self.spec.bits_per_sample == 32 {
super::KSDATAFORMAT_SUBTYPE_IEEE_FLOAT
} else {
panic!("Invalid number of bits per sample. \
When writing SampleFormat::Float, \
bits_per_sample must be 32.");
}
}
};
try!(buffer.write_all(&subformat_guid));
Ok(())
}
#[inline]
pub fn write_sample<S: Sample>(&mut self, sample: S) -> Result<()> {
try!(sample.write_padded(
&mut self.writer,
self.spec.bits_per_sample,
self.bytes_per_sample,
));
self.data_bytes_written += self.bytes_per_sample as u32;
Ok(())
}
pub fn get_i16_writer<'s>(&'s mut self,
num_samples: u32)
-> SampleWriter16<'s, W> {
if self.spec.sample_format != SampleFormat::Int {
panic!("When calling get_i16_writer, the sample format must be int.");
}
if self.spec.bits_per_sample != 16 {
panic!("When calling get_i16_writer, the number of bits per sample must be 16.");
}
let num_bytes = num_samples as usize * 2;
if self.sample_writer_buffer.len() < num_bytes {
let mut new_buffer = Vec::<MaybeUninit<u8>>::with_capacity(num_bytes);
unsafe { new_buffer.set_len(num_bytes); }
self.sample_writer_buffer = new_buffer;
}
SampleWriter16 {
writer: &mut self.writer,
buffer: &mut self.sample_writer_buffer[..num_bytes],
data_bytes_written: &mut self.data_bytes_written,
index: 0,
}
}
fn update_header(&mut self) -> Result<()> {
let header_size = self.data_len_offset + 4 - 8;
let file_size = self.data_bytes_written + header_size;
try!(self.writer.seek(io::SeekFrom::Start(4)));
try!(self.writer.write_le_u32(file_size));
try!(self.writer.seek(io::SeekFrom::Start(self.data_len_offset as u64)));
try!(self.writer.write_le_u32(self.data_bytes_written));
if (self.data_bytes_written / self.bytes_per_sample as u32)
% self.spec.channels as u32 != 0 {
Err(Error::UnfinishedSample)
} else {
Ok(())
}
}
pub fn flush(&mut self) -> Result<()> {
let current_pos = try!(self.writer.seek(io::SeekFrom::Current(0)));
try!(self.update_header());
try!(self.writer.flush());
try!(self.writer.seek(io::SeekFrom::Start(current_pos)));
Ok(())
}
pub fn finalize(mut self) -> Result<()> {
self.finalized = true;
try!(self.update_header());
try!(self.writer.flush());
Ok(())
}
pub fn spec(&self) -> WavSpec {
self.spec
}
pub fn duration(&self) -> u32 {
self.data_bytes_written / (self.bytes_per_sample as u32 * self.spec.channels as u32)
}
pub fn len(&self) -> u32 {
self.data_bytes_written / self.bytes_per_sample as u32
}
}
impl<W> Drop for WavWriter<W>
where W: io::Write + io::Seek
{
    /// Best-effort header update for writers that were not finalized.
    ///
    /// `drop` cannot report failure, so any error from `update_header`
    /// (including `UnfinishedSample`) is discarded. Call `finalize` to
    /// observe it.
    fn drop(&mut self) {
        if self.finalized {
            return;
        }
        let _ = self.update_header();
    }
}
fn read_append<W: io::Read + io::Seek>(mut reader: &mut W) -> Result<(WavSpecEx, u32, u32)> {
let (spec_ex, data_len) = {
try!(read::read_wave_header(&mut reader));
try!(read::read_until_data(&mut reader))
};
let data_len_offset = try!(reader.seek(io::SeekFrom::Current(0))) as u32 - 4;
let spec = spec_ex.spec;
let num_samples = data_len / spec_ex.bytes_per_sample as u32;
if num_samples * spec_ex.bytes_per_sample as u32 != data_len {
let msg = "data chunk length is not a multiple of sample size";
return Err(Error::FormatError(msg));
}
let supported = match (spec_ex.bytes_per_sample, spec.bits_per_sample) {
(1, 8) => true,
(2, 16) => true,
(3, 24) => true,
(4, 32) => true,
_ => false,
};
if !supported {
return Err(Error::Unsupported);
}
if num_samples % spec_ex.spec.channels as u32 != 0 {
return Err(Error::FormatError("invalid data chunk length"));
}
Ok((spec_ex, data_len, data_len_offset))
}
impl WavWriter<io::BufWriter<fs::File>> {
pub fn create<P: AsRef<path::Path>>(filename: P,
spec: WavSpec)
-> Result<WavWriter<io::BufWriter<fs::File>>> {
let file = try!(fs::File::create(filename));
let buf_writer = io::BufWriter::new(file);
WavWriter::new(buf_writer, spec)
}
pub fn append<P: AsRef<path::Path>>(filename: P) -> Result<WavWriter<io::BufWriter<fs::File>>> {
let mut file = try!(fs::OpenOptions::new().read(true).write(true).open(filename));
try!(file.seek(io::SeekFrom::Start(0)));
let mut buf_reader = io::BufReader::new(file);
let (spec_ex, data_len, data_len_offset) = try!(read_append(&mut buf_reader));
let mut file = buf_reader.into_inner();
try!(file.seek(io::SeekFrom::Current(data_len as i64)));
let buf_writer = io::BufWriter::new(file);
let writer = WavWriter {
spec: spec_ex.spec,
bytes_per_sample: spec_ex.bytes_per_sample,
writer: buf_writer,
data_bytes_written: data_len,
sample_writer_buffer: Vec::new(),
finalized: false,
data_len_offset: data_len_offset,
};
Ok(writer)
}
}
impl<W> WavWriter<W> where W: io::Read + io::Write + io::Seek {
pub fn new_append(mut writer: W) -> Result<WavWriter<W>> {
let (spec_ex, data_len, data_len_offset) = try!(read_append(&mut writer));
try!(writer.seek(io::SeekFrom::Current(data_len as i64)));
let writer = WavWriter {
spec: spec_ex.spec,
bytes_per_sample: spec_ex.bytes_per_sample,
writer: writer,
data_bytes_written: data_len,
sample_writer_buffer: Vec::new(),
finalized: false,
data_len_offset: data_len_offset,
};
Ok(writer)
}
}
/// A writer that buffers a fixed number of 16-bit samples, obtained from
/// `WavWriter::get_i16_writer`.
///
/// Samples are stored little-endian in `buffer`. They are handed to the
/// underlying writer only by `flush`, which requires the buffer to be full.
pub struct SampleWriter16<'parent, W> where W: io::Write + io::Seek + 'parent {
    /// The parent `WavWriter`'s underlying writer.
    writer: &'parent mut W,
    /// The parent's scratch buffer, sliced to exactly the reserved number of bytes.
    buffer: &'parent mut [MaybeUninit<u8>],
    /// The parent's data byte counter, advanced by `flush`.
    data_bytes_written: &'parent mut u32,
    /// Byte offset in `buffer` where the next sample will be stored.
    index: u32,
}
impl<'parent, W: io::Write + io::Seek> SampleWriter16<'parent, W> {
    /// Stores a single sample in the buffer.
    ///
    /// # Panics
    ///
    /// Panics if more samples are written than were reserved in
    /// `WavWriter::get_i16_writer`.
    #[inline(always)]
    pub fn write_sample<S: Sample>(&mut self, sample: S) {
        assert!((self.index as usize) + 2 <= self.buffer.len(),
                "Trying to write more samples than reserved for the sample writer.");
        // SAFETY: the assertion above guarantees that bytes `index` and
        // `index + 1` are in bounds.
        unsafe { self.write_sample_unchecked(sample) };
    }

    /// Stores `value` little-endian at byte offsets `index` and `index + 1`.
    ///
    /// # Safety
    ///
    /// The caller must ensure `self.index as usize + 2 <= self.buffer.len()`.
    unsafe fn write_u16_le_unchecked(&mut self, value: u16) {
        let i = self.index as usize;
        // Storing `MaybeUninit::new(..)` is what initializes each byte; `flush`
        // relies on every byte of the buffer having been written this way.
        *self.buffer.get_unchecked_mut(i) = MaybeUninit::new(value as u8);
        *self.buffer.get_unchecked_mut(i + 1) = MaybeUninit::new((value >> 8) as u8);
    }

    /// Stores a single sample without checking that there is room left.
    ///
    /// # Safety
    ///
    /// The caller must ensure that fewer samples than reserved have been
    /// written so far (`index + 2 <= buffer.len()`); otherwise this writes
    /// out of bounds.
    #[inline(always)]
    pub unsafe fn write_sample_unchecked<S: Sample>(&mut self, sample: S) {
        self.write_u16_le_unchecked(sample.as_i16() as u16);
        self.index += 2;
    }

    /// Writes all buffered samples to the underlying writer.
    ///
    /// # Panics
    ///
    /// Panics if fewer samples were written than were reserved.
    ///
    /// # Errors
    ///
    /// I/O errors from the underlying writer are propagated.
    pub fn flush(self) -> Result<()> {
        if self.index as usize != self.buffer.len() {
            panic!("Insufficient samples written to the sample writer.");
        }

        // SAFETY: `index == buffer.len()`, and every write advanced `index` by 2
        // after initializing the two bytes at the old `index`, so all bytes are
        // initialized. `MaybeUninit<u8>` has the same layout as `u8`.
        let slice = unsafe { &*(self.buffer as *const [MaybeUninit<u8>] as *const [u8]) };
        self.writer.write_all(slice)?;
        *self.data_bytes_written += self.buffer.len() as u32;
        Ok(())
    }
}
#[test]
fn short_write_should_signal_error() {
    use SampleFormat;

    // 17 channels × 5 frames minus one sample, so the last frame is incomplete.
    let write_spec = WavSpec {
        channels: 17,
        sample_rate: 48000,
        bits_per_sample: 8,
        sample_format: SampleFormat::Int,
    };
    let mut buffer = io::Cursor::new(Vec::new());
    let mut writer = WavWriter::new(&mut buffer, write_spec).unwrap();

    let num_samples = 17 * 5 - 1;
    for s in 0..num_samples {
        writer.write_sample(s as i16).unwrap();
    }

    let result = writer.finalize();
    match result.err().unwrap() {
        Error::UnfinishedSample => {}
        _ => panic!("UnfinishedSample error should have been returned."),
    }
}
#[test]
fn wide_write_should_signal_error() {
    // `WavWriter` assumes its writer starts at offset 0 (the header is patched
    // at absolute offsets), so every writer gets its own fresh buffer. Sharing
    // one cursor would make later writers start mid-stream and clobber
    // earlier headers.
    let spec8 = WavSpec {
        channels: 1,
        sample_rate: 44100,
        bits_per_sample: 8,
        sample_format: SampleFormat::Int,
    };
    {
        let mut buffer = io::Cursor::new(Vec::new());
        let mut writer = WavWriter::new(&mut buffer, spec8).unwrap();
        assert!(writer.write_sample(127_i8).is_ok());
        assert!(writer.write_sample(127_i16).is_ok());
        assert!(writer.write_sample(127_i32).is_ok());
        assert!(writer.write_sample(128_i16).is_err());
        assert!(writer.write_sample(128_i32).is_err());
    }

    let spec16 = WavSpec { bits_per_sample: 16, ..spec8 };
    {
        let mut buffer = io::Cursor::new(Vec::new());
        let mut writer = WavWriter::new(&mut buffer, spec16).unwrap();
        assert!(writer.write_sample(32767_i16).is_ok());
        assert!(writer.write_sample(32767_i32).is_ok());
        assert!(writer.write_sample(32768_i32).is_err());
    }

    let spec24 = WavSpec { bits_per_sample: 24, ..spec8 };
    {
        let mut buffer = io::Cursor::new(Vec::new());
        let mut writer = WavWriter::new(&mut buffer, spec24).unwrap();
        assert!(writer.write_sample(8_388_607_i32).is_ok());
        assert!(writer.write_sample(8_388_608_i32).is_err());
    }
}
/// Writes 24-bit samples in 4-byte containers and compares the output
/// byte-for-byte with a reference file.
#[test]
fn s24_wav_write() {
    use std::fs::File;
    use std::io::Read;
    let mut buffer = io::Cursor::new(Vec::new());
    // `new` would derive a 3-byte container from 24 bits, so the 4-byte
    // container is requested explicitly through `WavSpecEx`.
    let spec = WavSpecEx {
        spec: WavSpec {
            channels: 2,
            sample_rate: 48000,
            bits_per_sample: 24,
            sample_format: SampleFormat::Int,
        },
        bytes_per_sample: 4,
    };
    {
        // The writer is dropped (not finalized) at the end of this scope;
        // `Drop` patches the header before the buffer is inspected below.
        let mut writer = WavWriter::new_with_spec_ex(&mut buffer, spec).unwrap();
        assert!(writer.write_sample(-96_i32).is_ok());
        assert!(writer.write_sample(23_052_i32).is_ok());
        assert!(writer.write_sample(8_388_607_i32).is_ok());
        assert!(writer.write_sample(-8_360_672_i32).is_ok());
    }
    // Reference output on disk, relative to the test's working directory.
    let mut expected = Vec::new();
    File::open("testsamples/waveformatextensible-24bit-4byte-48kHz-stereo.wav")
        .unwrap()
        .read_to_end(&mut expected)
        .unwrap();
    assert_eq!(buffer.into_inner(), expected);
}