use crate::{Error, Result, WaveSpec};
use futures_util::stream::BoxStream;
use futures_util::Stream;
// Messages carried by `Error::MixerConstruction` when `WaveChannelMixer::new`
// rejects a `WaveSpec`.

/// Reported when `pcm_format` is anything other than 1.
const INVALID_PCM_FORMAT: &str = "only support PCM integer (`pcm_format` is 1)";
/// Reported when `bits_per_sample` is not a multiple of 8.
const INVALID_BITS_PER_SAMPLE: &str = "only supports a `bits_per_sample` divisible by 8";
/// Reported when `bits_per_sample` is greater than 64 (despite the constant's
/// name, the check is made on bits, not bytes).
const INVALID_BYTES_PER_SAMPLE: &str = "only supports a `bits_per_sample` less than or equal 64";
/// A byte stream adapter that reduces interleaved multi-channel PCM data to a
/// single channel.
///
/// It wraps a stream of raw byte chunks and, as a [`Stream`] itself, yields
/// chunks converted according to the configured [`WaveChannelMixerOptions`].
pub struct WaveChannelMixer<'a> {
    /// The wrapped source of raw byte chunks.
    stream: BoxStream<'a, Result<Vec<u8>>>,
    /// Number of interleaved channels, copied from the `WaveSpec`.
    channels: u16,
    /// Width of one sample in bits, copied from the `WaveSpec`.
    bits_per_sample: u16,
    /// How the channels of each frame are reduced to one sample.
    options: WaveChannelMixerOptions,
    /// Bytes received so far that did not complete a whole frame; they are
    /// prepended to the next incoming chunk.
    rest: Vec<u8>,
}
/// Strategy used by `WaveChannelMixer` to reduce every multi-channel frame to
/// a single sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaveChannelMixerOptions {
    /// Keep only the channel with the given zero-based index.
    ///
    /// The mixer reduces the index modulo the channel count on construction,
    /// so an out-of-range index wraps around instead of failing.
    Pick(u16),
    /// Replace each frame with the mean of the samples of all its channels.
    Mean,
}
impl<'a> WaveChannelMixer<'a> {
pub fn new(
spec: &WaveSpec,
options: WaveChannelMixerOptions,
stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a,
) -> Result<Self> {
if spec.pcm_format != 1 {
return Err(Error::MixerConstruction(INVALID_PCM_FORMAT));
}
if spec.bits_per_sample % 8 != 0 {
return Err(Error::MixerConstruction(INVALID_BITS_PER_SAMPLE));
}
if spec.bits_per_sample > 64 {
return Err(Error::MixerConstruction(INVALID_BYTES_PER_SAMPLE));
}
let options = if let WaveChannelMixerOptions::Pick(idx) = options {
WaveChannelMixerOptions::Pick(idx % spec.channels)
} else {
options
};
let ret = Self {
stream: Box::pin(stream),
channels: spec.channels,
bits_per_sample: spec.bits_per_sample,
options,
rest: vec![],
};
Ok(ret)
}
fn convert(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
match &self.options {
WaveChannelMixerOptions::Pick(idx) => self.pick(*idx, input),
WaveChannelMixerOptions::Mean => self.mean(input),
}
}
fn pick(&self, pick_idx: u16, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
let mut idx = 0u16;
let mut converted = vec![];
println!("input = {input:?}");
let bytes_per_sample = (self.bits_per_sample / 8) as usize;
let mut head = 0usize;
let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
while head < end {
if idx == pick_idx {
let sample = self.take_sample(input, head);
converted.extend(self.sample_into_vec(sample));
}
head += bytes_per_sample;
idx = (idx + 1) % self.channels;
}
let rest = input.iter().skip(head).cloned().collect();
println!("rest = {rest:?}");
(converted, rest)
}
fn mean(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
let mut idx = 0u16;
let mut sum = 0f32;
let mut converted = vec![];
let bytes_per_sample = (self.bits_per_sample / 8) as usize;
let mut head = 0usize;
let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
while head < end {
let sample = self.take_sample(input, head);
sum += sample;
head += bytes_per_sample;
idx = (idx + 1) % self.channels;
if idx == 0 {
converted.extend(self.sample_into_vec(sum / self.channels as f32));
sum = 0.;
}
}
let rest = input.iter().skip(head).cloned().collect();
(converted, rest)
}
fn take_sample(&self, input: &[u8], head: usize) -> f32 {
let bytes_per_sample = (self.bits_per_sample / 8) as usize;
match bytes_per_sample {
1..=1 => from_bytes::take_i8_sample(input, head) as f32,
2..=2 => from_bytes::take_i16_sample(input, head) as f32,
3..=4 => from_bytes::take_i32_sample(input, head, bytes_per_sample) as f32,
5..=8 => from_bytes::take_i64_sample(input, head, bytes_per_sample) as f32,
_ => unreachable!(),
}
}
fn sample_into_vec(&self, sample: f32) -> Vec<u8> {
let bytes_per_sample = (self.bits_per_sample / 8) as usize;
match bytes_per_sample {
1..=1 => (sample as i8).to_le_bytes().to_vec(),
2..=2 => (sample as i16).to_le_bytes().to_vec(),
3..=4 => (sample as i32).to_le_bytes()[..bytes_per_sample].to_vec(),
5..=8 => (sample as i64).to_le_bytes()[..bytes_per_sample].to_vec(),
_ => unreachable!(),
}
}
}
mod impls {
    use super::*;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    impl<'a> Stream for WaveChannelMixer<'a> {
        type Item = Result<Vec<u8>>;

        /// Pulls one chunk from the wrapped stream and yields its conversion.
        ///
        /// Bytes that do not complete a frame are buffered and prepended to
        /// the next chunk, so a yielded chunk may be empty when the input was
        /// too short to complete a frame. Errors of the wrapped stream are
        /// forwarded unchanged.
        ///
        /// When the wrapped stream ends while an incomplete frame is still
        /// buffered, `Error::DataIsNotEnough` is yielded once; the buffer is
        /// then discarded so that the following poll can end the stream
        /// instead of repeating the error forever.
        fn poll_next(
            mut self: Pin<&mut Self>,
            context: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            let polled = self.stream.as_mut().poll_next(context);
            let chunk = match polled {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(Ok(chunk))) => chunk,
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => {
                    if self.rest.is_empty() {
                        return Poll::Ready(None);
                    }
                    // Report the truncated frame exactly once.
                    self.rest.clear();
                    return Poll::Ready(Some(Err(Error::DataIsNotEnough)));
                }
            };

            self.rest.extend_from_slice(&chunk);
            let (converted, rest) = self.convert(&self.rest);
            self.rest = rest;
            Poll::Ready(Some(Ok(converted)))
        }
    }
}
/// Helpers decoding little-endian signed integer samples from a byte slice.
///
/// Every function panics when the requested bytes lie outside `input`.
mod from_bytes {
    /// Reads the one-byte signed sample at `input[head]`.
    pub fn take_i8_sample(input: &[u8], head: usize) -> i8 {
        i8::from_le_bytes([input[head]])
    }

    /// Reads the two-byte little-endian signed sample starting at `head`.
    pub fn take_i16_sample(input: &[u8], head: usize) -> i16 {
        let mut buf = [0u8; 2];
        buf.copy_from_slice(&input[head..(head + 2)]);
        i16::from_le_bytes(buf)
    }

    /// Reads a `width`-byte little-endian signed sample starting at `head`
    /// and widens it to `i32`.
    ///
    /// Samples narrower than four bytes (e.g. 24-bit audio) are sign-extended
    /// so that negative values stay negative. Panics when `width` exceeds 4.
    pub fn take_i32_sample(input: &[u8], head: usize, width: usize) -> i32 {
        let mut buf = [0u8; 4];
        buf[..width].copy_from_slice(&input[head..(head + width)]);
        // The sign lives in the top bit of the most significant byte read;
        // replicate it into the bytes that were not read.
        if width > 0 && buf[width - 1] & 0x80 != 0 {
            buf[width..].fill(0xFF);
        }
        i32::from_le_bytes(buf)
    }

    /// Reads a `width`-byte little-endian signed sample starting at `head`
    /// and widens it to `i64`.
    ///
    /// Samples narrower than eight bytes are sign-extended so that negative
    /// values stay negative. Panics when `width` exceeds 8.
    pub fn take_i64_sample(input: &[u8], head: usize, width: usize) -> i64 {
        let mut buf = [0u8; 8];
        buf[..width].copy_from_slice(&input[head..(head + width)]);
        // Same sign extension as for the 32-bit variant.
        if width > 0 && buf[width - 1] & 0x80 != 0 {
            buf[width..].fill(0xFF);
        }
        i64::from_le_bytes(buf)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::iter;
    use futures_util::StreamExt as _;

    /// The 16-bit stereo specification shared by every test.
    fn stereo_16bit_spec() -> WaveSpec {
        WaveSpec {
            pcm_format: 1,
            channels: 2,
            sample_rate: 8000,
            bits_per_sample: 16,
        }
    }

    /// Serializes `samples` as little-endian bytes and splits them into
    /// 31-byte chunks, so that chunk boundaries fall in the middle of frames.
    fn chunked_input(samples: impl Iterator<Item = i16>) -> Vec<Result<Vec<u8>>> {
        let bytes: Vec<u8> = samples.flat_map(|sample| sample.to_le_bytes()).collect();
        bytes.chunks(31).map(|chunk| Ok(chunk.to_vec())).collect()
    }

    /// Drives a mixer over `input` to completion and decodes its output as
    /// 16-bit samples.
    async fn mix(
        spec: &WaveSpec,
        options: WaveChannelMixerOptions,
        input: Vec<Result<Vec<u8>>>,
    ) -> Vec<i16> {
        let mut mixer = WaveChannelMixer::new(spec, options, iter(input)).unwrap();
        let mut bytes = vec![];
        while let Some(chunk) = mixer.next().await {
            bytes.extend(chunk.unwrap());
        }
        bytes
            .chunks(2)
            .map(|pair| from_bytes::take_i16_sample(pair, 0))
            .collect()
    }

    #[tokio::test]
    async fn test_pick() {
        let spec = stereo_16bit_spec();

        // The first channel carries the even values.
        let converted = mix(&spec, WaveChannelMixerOptions::Pick(0), chunked_input(0i16..200)).await;
        let expected: Vec<_> = (0i16..100).map(|x| x * 2).collect();
        assert_eq!(converted, expected);

        // The second channel carries the odd values.
        let converted = mix(&spec, WaveChannelMixerOptions::Pick(1), chunked_input(0i16..200)).await;
        let expected: Vec<_> = (0i16..100).map(|x| x * 2 + 1).collect();
        assert_eq!(converted, expected);
    }

    #[tokio::test]
    async fn test_mean() {
        let spec = stereo_16bit_spec();

        // Frames are (0, 2), (4, 6), ... whose means are 1, 5, ...
        let input = chunked_input((0i16..200).map(|x| x * 2));
        let converted = mix(&spec, WaveChannelMixerOptions::Mean, input).await;
        let expected: Vec<_> = (0i16..100).map(|x| x * 4 + 1).collect();
        assert_eq!(converted, expected);
    }
}