stream-wave-parser 0.1.2

`stream-wave-parser` is a crate that handles a stream read from a WAVE file.
Documentation
//! Types to convert multiple channels data into one channel.

use crate::{Error, Result, WaveSpec};
use futures_util::stream::BoxStream;
use futures_util::Stream;

// Messages carried by `Error::MixerConstruction` when `WaveChannelMixer::new` rejects a spec.

/// Reported when `pcm_format` is not `1`.
const INVALID_PCM_FORMAT: &str = "only support PCM integer (`pcm_format` is 1)";
/// Reported when `bits_per_sample` is not a multiple of 8.
const INVALID_BITS_PER_SAMPLE: &str = "only supports a `bits_per_sample` divisible by 8";
/// Reported when `bits_per_sample` is greater than 64.
const INVALID_BYTES_PER_SAMPLE: &str = "only supports a `bits_per_sample` less than or equal 64";

/// The type that converts WAVE multi channels data into a single channel.
///
/// It wraps a stream of raw byte chunks and is itself a [`Stream`] of byte chunks that hold
/// the single-channel result.
pub struct WaveChannelMixer<'a> {
    /// The wrapped input stream of raw sample bytes, with the channels interleaved.
    stream: BoxStream<'a, Result<Vec<u8>>>,

    /// The number of channels of the input, copied from the [`WaveSpec`].
    channels: u16,
    /// The width of one sample in bits, copied from the [`WaveSpec`].
    bits_per_sample: u16,
    /// How the channels are reduced to one.
    options: WaveChannelMixerOptions,

    /// Input bytes that did not complete a frame yet; they are prepended to the next chunk.
    rest: Vec<u8>,
}

/// The type representing the options to mix.
pub enum WaveChannelMixerOptions {
    /// Picks one channel from multiple channels.
    ///
    /// The value is the zero-based channel index. [`WaveChannelMixer::new`] wraps it modulo
    /// the channel count of the spec.
    Pick(u16),

    /// Makes the mean of multiple channels.
    ///
    /// Every output sample is the mean of the samples of all channels at the same position.
    Mean,
}

impl<'a> WaveChannelMixer<'a> {
    /// The constructor.
    pub fn new(
        spec: &WaveSpec,
        options: WaveChannelMixerOptions,
        stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a,
    ) -> Result<Self> {
        if spec.pcm_format != 1 {
            return Err(Error::MixerConstruction(INVALID_PCM_FORMAT));
        }

        if spec.bits_per_sample % 8 != 0 {
            return Err(Error::MixerConstruction(INVALID_BITS_PER_SAMPLE));
        }
        if spec.bits_per_sample > 64 {
            return Err(Error::MixerConstruction(INVALID_BYTES_PER_SAMPLE));
        }

        let options = if let WaveChannelMixerOptions::Pick(idx) = options {
            WaveChannelMixerOptions::Pick(idx % spec.channels)
        } else {
            options
        };

        let ret = Self {
            stream: Box::pin(stream),

            channels: spec.channels,
            bits_per_sample: spec.bits_per_sample,
            options,

            rest: vec![],
        };

        Ok(ret)
    }

    /// Converts multiple channels into ([single channel], [rest data]).
    fn convert(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
        match &self.options {
            WaveChannelMixerOptions::Pick(idx) => self.pick(*idx, input),
            WaveChannelMixerOptions::Mean => self.mean(input),
        }
    }

    /// Picks single channel in multiple channels.
    fn pick(&self, pick_idx: u16, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
        let mut idx = 0u16;
        let mut converted = vec![];

        println!("input = {input:?}");

        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
        let mut head = 0usize;
        let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
        while head < end {
            if idx == pick_idx {
                let sample = self.take_sample(input, head);
                converted.extend(self.sample_into_vec(sample));
            }
            head += bytes_per_sample;
            idx = (idx + 1) % self.channels;
        }

        let rest = input.iter().skip(head).cloned().collect();

        println!("rest = {rest:?}");

        (converted, rest)
    }

    /// Makes the mean of multiple channels into a single channel.
    fn mean(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
        let mut idx = 0u16;
        let mut sum = 0f32;
        let mut converted = vec![];

        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
        let mut head = 0usize;
        let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
        while head < end {
            let sample = self.take_sample(input, head);
            sum += sample;

            head += bytes_per_sample;
            idx = (idx + 1) % self.channels;

            if idx == 0 {
                converted.extend(self.sample_into_vec(sum / self.channels as f32));
                sum = 0.;
            }
        }

        let rest = input.iter().skip(head).cloned().collect();

        (converted, rest)
    }

    fn take_sample(&self, input: &[u8], head: usize) -> f32 {
        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
        match bytes_per_sample {
            1..=1 => from_bytes::take_i8_sample(input, head) as f32,
            2..=2 => from_bytes::take_i16_sample(input, head) as f32,
            3..=4 => from_bytes::take_i32_sample(input, head, bytes_per_sample) as f32,
            5..=8 => from_bytes::take_i64_sample(input, head, bytes_per_sample) as f32,
            _ => unreachable!(),
        }
    }

    fn sample_into_vec(&self, sample: f32) -> Vec<u8> {
        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
        match bytes_per_sample {
            1..=1 => (sample as i8).to_le_bytes().to_vec(),
            2..=2 => (sample as i16).to_le_bytes().to_vec(),
            3..=4 => (sample as i32).to_le_bytes()[..bytes_per_sample].to_vec(),
            5..=8 => (sample as i64).to_le_bytes()[..bytes_per_sample].to_vec(),
            _ => unreachable!(),
        }
    }
}

mod impls {
    //! Implements [`Stream`] for [`WaveChannelMixer`].

    use super::*;

    use std::pin::Pin;
    use std::task::{Context, Poll};

    impl<'a> Stream for WaveChannelMixer<'a> {
        type Item = Result<Vec<u8>>;

        /// Polls the wrapped stream once and yields the converted form of its next chunk.
        ///
        /// Bytes that do not complete a frame are kept and prepended to the following chunk,
        /// so a yielded chunk may be empty when no frame was completed. Errors of the wrapped
        /// stream are passed through unchanged. When the wrapped stream ends while an
        /// incomplete frame is still buffered, [`Error::DataIsNotEnough`] is yielded once.
        fn poll_next(
            mut self: Pin<&mut Self>,
            context: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            let polled = self.stream.as_mut().poll_next(context);

            let chunk = match polled {
                Poll::Pending => return Poll::Pending,

                // the input stream is exhausted
                Poll::Ready(None) => {
                    if self.rest.is_empty() {
                        return Poll::Ready(None);
                    }

                    // Drop the incomplete frame so that the error is reported only once
                    // instead of on every following poll.
                    self.rest.clear();
                    return Poll::Ready(Some(Err(Error::DataIsNotEnough)));
                }

                // the input stream has error
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),

                Poll::Ready(Some(Ok(chunk))) => chunk,
            };

            // convert
            self.rest.extend_from_slice(&chunk);
            let (converted, rest) = self.convert(&self.rest);
            self.rest = rest;

            Poll::Ready(Some(Ok(converted)))
        }
    }
}

mod from_bytes {
    //! Utilities of `from_le_bytes()`.
    //!
    //! Every function decodes one little-endian, two's complement sample that starts at
    //! `input[head]` and panics when `input` is too short for it.

    /// Reads the byte at `head` as an `i8`.
    pub fn take_i8_sample(input: &[u8], head: usize) -> i8 {
        i8::from_le_bytes([input[head]])
    }

    /// Reads the two bytes starting at `head` as an `i16`.
    pub fn take_i16_sample(input: &[u8], head: usize) -> i16 {
        let mut buf = [0u8; 2];
        buf.copy_from_slice(&input[head..(head + 2)]);
        i16::from_le_bytes(buf)
    }

    /// Reads the `width` bytes starting at `head` as a sign-extended `i32`.
    ///
    /// # Panics
    ///
    /// Panics when `width` is greater than 4.
    pub fn take_i32_sample(input: &[u8], head: usize, width: usize) -> i32 {
        let bytes = &input[head..(head + width)];
        // Pre-fill the unused high bytes with the sign so that a negative sample narrower
        // than the buffer (e.g. 24-bit) stays negative.
        let mut buf = [sign_fill(bytes); 4];
        buf[..width].copy_from_slice(bytes);
        i32::from_le_bytes(buf)
    }

    /// Reads the `width` bytes starting at `head` as a sign-extended `i64`.
    ///
    /// # Panics
    ///
    /// Panics when `width` is greater than 8.
    pub fn take_i64_sample(input: &[u8], head: usize, width: usize) -> i64 {
        let bytes = &input[head..(head + width)];
        // See `take_i32_sample` for why the buffer is pre-filled with the sign.
        let mut buf = [sign_fill(bytes); 8];
        buf[..width].copy_from_slice(bytes);
        i64::from_le_bytes(buf)
    }

    /// Returns the byte that extends the sign of the little-endian value in `bytes`:
    /// `0xFF` when the top bit of the last byte is set, otherwise `0x00`.
    fn sign_fill(bytes: &[u8]) -> u8 {
        match bytes.last() {
            Some(&msb) if msb & 0x80 != 0 => 0xFF,
            _ => 0x00,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use futures_util::stream::iter;
    use futures_util::StreamExt as _;

    /// The 2-channel, 16-bit PCM spec shared by all tests.
    fn stereo_spec() -> WaveSpec {
        WaveSpec {
            pcm_format: 1,
            channels: 2,
            sample_rate: 8000,
            bits_per_sample: 16,
        }
    }

    /// Serializes `samples` as little-endian bytes and cuts them into 31-byte chunks, so
    /// that chunk boundaries fall in the middle of samples and frames.
    fn chunked(samples: impl Iterator<Item = i16>) -> Vec<Result<Vec<u8>>> {
        let bytes: Vec<u8> = samples.flat_map(i16::to_le_bytes).collect();
        bytes.chunks(31).map(|part| Ok(part.to_vec())).collect()
    }

    /// Feeds `samples` through a mixer built with `options` and decodes the whole output
    /// as 16-bit samples.
    async fn mix(options: WaveChannelMixerOptions, samples: impl Iterator<Item = i16>) -> Vec<i16> {
        let spec = stereo_spec();
        let stream = iter(chunked(samples));

        let mut mixer = WaveChannelMixer::new(&spec, options, stream).unwrap();
        let mut bytes = vec![];
        while let Some(chunk) = mixer.next().await {
            bytes.extend(chunk.unwrap());
        }

        bytes
            .chunks(2)
            .map(|pair| from_bytes::take_i16_sample(pair, 0))
            .collect()
    }

    #[tokio::test]
    async fn test_pick() {
        // channel 0 holds the even numbers
        let picked = mix(WaveChannelMixerOptions::Pick(0), 0i16..200).await;
        let expected: Vec<_> = (0i16..100).map(|x| x * 2).collect();
        assert_eq!(picked, expected);

        // channel 1 holds the odd numbers
        let picked = mix(WaveChannelMixerOptions::Pick(1), 0i16..200).await;
        let expected: Vec<_> = (0i16..100).map(|x| x * 2 + 1).collect();
        assert_eq!(picked, expected);
    }

    #[tokio::test]
    async fn test_mean() {
        // the frames are (0, 2), (4, 6), ... so the means are 1, 5, ...
        let mixed = mix(WaveChannelMixerOptions::Mean, (0i16..200).map(|x| x * 2)).await;
        let expected: Vec<_> = (0i16..100).map(|x| x * 4 + 1).collect();
        assert_eq!(mixed, expected);
    }
}