stream-wave-parser 0.1.2

`stream-wave-parser` is a crate for parsing a WAVE file that arrives as a stream of byte chunks.
Documentation
//! Types for parsing a byte stream into a WAVE header and a stream of its `data` chunk payload.

use crate::{Error, Result};
use futures_util::stream::{iter, BoxStream};
use futures_util::{Stream, StreamExt as _};
use std::collections::VecDeque;

/// Incremental WAVE parser over a stream of byte chunks.
///
/// Nothing is read until [`WaveStream::spec`] or [`WaveStream::into_data`] is awaited;
/// both parse the header up to the start of the `data` chunk.
pub struct WaveStream<'a> {
    /// Source of the raw file bytes, delivered in arbitrarily sized pieces.
    stream: BoxStream<'a, Result<Vec<u8>>>,
    /// Bytes already pulled from `stream` that the parser has not consumed yet.
    current: VecDeque<u8>,
    /// Size field of the `RIFF` header, once it has been read.
    riff_size: Option<u32>,
    /// Metadata of the `fmt ` chunk, if one was seen before the `data` chunk.
    spec: Option<WaveSpec>,
    /// Size field of the `data` chunk, once the parser has reached it.
    data_size: Option<u32>,
}

/// Format metadata read from the `fmt ` chunk of a WAVE file.
///
/// All fields are plain integers, so specs can be compared and hashed, e.g. to check
/// that two files share the same format.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct WaveSpec {
    /// Audio format tag (`1`: integer PCM; other values denote other encodings).
    pub pcm_format: u16,

    /// Number of channels.
    pub channels: u16,

    /// Sample rate (samples per second).
    pub sample_rate: u32,

    /// Number of bits per sample.
    pub bits_per_sample: u16,
}

/// Stream adapter that forwards the remainder of the `data` chunk from the underlying
/// stream, truncating its output to the declared chunk size.
///
/// It is created by `WaveStream::into_data` after the buffered prefix of the `data`
/// chunk has already been emitted.
struct DataChunk<'a> {
    /// Underlying byte stream, positioned right after the bytes already buffered.
    stream: BoxStream<'a, Result<Vec<u8>>>,
    /// Declared size of the `data` chunk, in bytes.
    data_size: u32,
    /// Number of `data` bytes already yielded, including the buffered prefix.
    consumed: u32,
}

impl<'a> WaveStream<'a> {
    /// The constructor.
    pub fn new(stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a) -> Self {
        Self {
            stream: Box::pin(stream),
            current: VecDeque::new(),
            riff_size: None,
            spec: None,
            data_size: None,
        }
    }

    /// Parses a WAVE header and returns [`WaveSpec`].
    pub async fn spec(&mut self) -> Result<WaveSpec> {
        self.take_riff().await?;
        self.skip_to_data_chunk().await?;

        let spec = self.spec.as_ref().ok_or(Error::FmtChunkIsNotFound)?;
        Ok(spec.clone())
    }

    /// Returns the stream that returns data chunks.
    pub async fn into_data(mut self) -> BoxStream<'a, Result<Vec<u8>>> {
        if let Err(e) = self.take_riff().await {
            return Box::pin(iter(vec![Err(e)]));
        }

        if let Err(e) = self.skip_to_data_chunk().await {
            return Box::pin(iter(vec![Err(e)]));
        }

        let data_size = self.data_size.unwrap(); // If `skip_to_data_chunk()` is `Ok(_)`, it is `Some(_)` and can be `unwrap()`.

        if data_size <= self.current.len() as u32 {
            return Box::pin(iter(vec![Ok(self
                .current
                .into_iter()
                .take(data_size as usize)
                .collect())]));
        }

        let consumed = self.current.len() as u32;
        let data_chunk = DataChunk {
            stream: self.stream,
            data_size,
            consumed,
        };

        Box::pin(iter(vec![Ok(self.current.into())]).chain(data_chunk))
    }

    async fn take_riff(&mut self) -> Result<()> {
        if self.riff_size.is_some() {
            return Ok(());
        }

        let four = self.take::<4>().await?;
        if b"RIFF" != &four {
            return Err(Error::RiffChunkHeaderIsNotFound);
        }

        self.riff_size = Some(self.take_u32().await?);

        let four = self.take::<4>().await?;
        if b"WAVE" != &four {
            return Err(Error::WaveChunkHeaderIsNotFound);
        }

        Ok(())
    }

    async fn skip_to_data_chunk(&mut self) -> Result<()> {
        if self.data_size.is_some() {
            return Ok(());
        }

        loop {
            let four = self.take::<4>().await?;
            let size = self.take_u32().await?;
            match &four {
                b"data" => {
                    self.data_size = Some(size);
                    return Ok(());
                }

                b"fmt " => {
                    let spec = self.parse_fmt(size).await?;
                    self.spec = Some(spec);
                }

                // skip other chunk
                _ => {
                    for _ in 0..size {
                        self.next().await?;
                    }
                }
            }
        }
    }

    async fn take_u16(&mut self) -> Result<u16> {
        let four = self.take::<2>().await?;
        Ok(u16::from_le_bytes(four))
    }

    async fn take_u32(&mut self) -> Result<u32> {
        let four = self.take::<4>().await?;
        Ok(u32::from_le_bytes(four))
    }

    async fn parse_fmt(&mut self, size: u32) -> Result<WaveSpec> {
        let pcm_format = self.take_u16().await?;

        let channels = self.take_u16().await?;
        let sample_rate = self.take_u32().await?;
        let _bit_rate = self.take_u32().await?;
        let _block_size = self.take_u16().await?;
        let bits_per_sample = self.take_u16().await?;

        // skip extension
        if size > 16 {
            for _ in 0..(size - 16) {
                self.next().await?;
            }
        }

        let spec = WaveSpec {
            pcm_format,
            channels,
            sample_rate,
            bits_per_sample,
        };
        Ok(spec)
    }

    async fn take<const N: usize>(&mut self) -> Result<[u8; N]> {
        let mut bytes = [0; N];
        for item in bytes.iter_mut() {
            *item = self.next().await?;
        }
        Ok(bytes)
    }

    async fn next(&mut self) -> Result<u8> {
        while self.current.is_empty() {
            self.current = self
                .stream
                .next()
                .await
                .ok_or(Error::DataIsNotEnough)??
                .into();
        }

        Ok(self.current.pop_front().unwrap())
    }
}

mod impls {
    //! Implements [`Stream`] for [`DataChunk`].

    use super::*;

    use std::pin::Pin;
    use std::task::{Context, Poll};

    impl<'a> Stream for DataChunk<'a> {
        type Item = Result<Vec<u8>>;

        /// Yields the next piece of the `data` chunk.
        ///
        /// The piece that reaches `data_size` is truncated. After it, the stream ends
        /// without polling the underlying stream again, so bytes of any chunk that
        /// follows `data` (e.g. a trailing `LIST` chunk) are never emitted. Errors
        /// from the underlying stream are forwarded unchanged.
        fn poll_next(
            mut self: Pin<&mut Self>,
            context: &mut Context<'_>,
        ) -> Poll<Option<<Self as Stream>::Item>> {
            // The whole declared payload has been yielded already.
            if self.consumed >= self.data_size {
                return Poll::Ready(None);
            }

            let mut chunk = match self.stream.as_mut().poll_next(context) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(chunk))) => chunk,
            };

            let rest_size = (self.data_size - self.consumed) as usize;
            if chunk.len() < rest_size {
                // `chunk.len() < rest_size <= u32::MAX`, so the cast is lossless.
                self.consumed += chunk.len() as u32;
            } else {
                chunk.truncate(rest_size);
                self.consumed = self.data_size;
            }
            Poll::Ready(Some(Ok(chunk)))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use tokio::fs::read;

    /// Path of the WAVE fixture. The file comes from `魔王魂®︎`. <https://maou.audio/se_system49/>
    const FILE: &str = "./assets/test/maou_se_system49.wav";

    /// Splits `bytes` into 64 KiB pieces, each wrapped in `Ok`, to emulate a chunked source.
    fn split_into_chunks(bytes: &[u8]) -> Vec<Result<Vec<u8>>> {
        bytes.chunks(65536).map(|piece| Ok(piece.to_vec())).collect()
    }

    /// Drains the `data` stream of `stream`.
    ///
    /// Returns the declared `data` size and the number of bytes actually yielded.
    async fn declared_and_streamed(stream: WaveStream<'_>) -> (u32, u32) {
        let declared = stream.data_size.unwrap();
        let mut data = stream.into_data().await;
        let mut streamed = 0;
        while let Some(chunk) = data.next().await {
            streamed += chunk.unwrap().len();
        }
        (declared, streamed as u32)
    }

    #[tokio::test]
    async fn test_one_chunk() {
        let bytes = read(FILE).await.unwrap();
        let mut stream = WaveStream::new(iter(vec![Ok(bytes)]));

        let spec = stream.spec().await.unwrap();
        assert_eq!(
            (spec.pcm_format, spec.channels, spec.sample_rate, spec.bits_per_sample),
            (1, 2, 44100, 24)
        );

        let (declared, streamed) = declared_and_streamed(stream).await;
        assert_eq!(declared, streamed);
    }

    #[tokio::test]
    async fn test_chunks() {
        let bytes = read(FILE).await.unwrap();
        let mut stream = WaveStream::new(iter(split_into_chunks(&bytes)));

        let spec = stream.spec().await.unwrap();
        assert_eq!(
            (spec.pcm_format, spec.channels, spec.sample_rate, spec.bits_per_sample),
            (1, 2, 44100, 24)
        );

        let (declared, streamed) = declared_and_streamed(stream).await;
        assert_eq!(declared, streamed);
    }

    #[tokio::test]
    async fn test_generate() {
        use std::f32::consts::PI;

        // One second of a 440 Hz sine: 8000 mono 16-bit samples at 8 kHz.
        let samples = (0..8000)
            .flat_map(|idx: i32| {
                let t = idx as f32 / 8000.0;
                let amplitude = (t * 440. * 2. * PI).sin();
                ((amplitude * i16::MAX as f32) as i16).to_le_bytes()
            })
            .collect::<Vec<u8>>();

        let mut wave: Vec<u8> = Vec::with_capacity(44 + samples.len());
        wave.extend_from_slice(b"RIFF");
        wave.extend_from_slice(&((samples.len() + 36) as u32).to_le_bytes());
        wave.extend_from_slice(b"WAVE");
        wave.extend_from_slice(b"fmt ");
        wave.extend_from_slice(&16u32.to_le_bytes()); // `fmt ` chunk size
        wave.extend_from_slice(&1u16.to_le_bytes()); // PCM format
        wave.extend_from_slice(&1u16.to_le_bytes()); // channels
        wave.extend_from_slice(&8000u32.to_le_bytes()); // sample rate
        wave.extend_from_slice(&16000u32.to_le_bytes()); // byte rate
        wave.extend_from_slice(&2u16.to_le_bytes()); // block align
        wave.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        wave.extend_from_slice(b"data");
        wave.extend_from_slice(&(samples.len() as u32).to_le_bytes());
        wave.extend_from_slice(&samples);

        let mut stream = WaveStream::new(iter(split_into_chunks(&wave)));

        let spec = stream.spec().await.unwrap();
        assert_eq!(
            (spec.pcm_format, spec.channels, spec.sample_rate, spec.bits_per_sample),
            (1, 1, 8000, 16)
        );

        let (declared, streamed) = declared_and_streamed(stream).await;
        assert_eq!(declared, streamed);
        assert_eq!(declared, samples.len() as u32);
    }
}