#![forbid(unsafe_code)]
use std::fmt::Display;
use std::fs::File;
use std::io::Cursor;
use std::path::Path;
use std::time::Duration;
use std::{fmt, io};
/// Number of samples per channel stored in one 64-bit slice (each encoded as a 3-bit
/// quantized residual after the slice's 4-bit scale factor).
pub const QOA_SLICE_LEN: usize = 20;
/// Number of taps (history values and weights) of each channel's LMS predictor.
pub const QOA_LMS_LEN: usize = 4;
/// Size in bytes of the file header; frame headers have the same size.
pub const QOA_HEADER_SIZE: usize = 8;
/// Magic number at the start of every QOA file: the ASCII bytes `qoaf` read as a
/// big-endian `u32`.
pub const QOA_MAGIC: u32 = u32::from_be_bytes(*b"qoaf");
/// Maximum number of slices per channel a single frame may contain; the decoder rejects
/// frames that declare more as invalid.
pub const MAX_SLICES_PER_CHANNEL_PER_FRAME: usize = 256;
/// How the decoder treats the stream, as selected by the sample count in the file header.
#[derive(Debug, Clone, PartialEq)]
pub enum ProcessingMode {
    /// The file header announced a non-zero number of samples.
    ///
    /// `channels` and `sample_rate` are taken from the first frame header, and every later
    /// frame must use the same values.
    FixedSamples {
        /// Channel count of the first frame.
        channels: u8,
        /// Sample rate of the first frame.
        sample_rate: u32,
        /// Samples per channel announced by the file header.
        samples: u32,
    },
    /// The file header announced zero samples: frames are decoded as they arrive, and the
    /// decoder does not require later frames to match earlier ones.
    Streaming,
}
/// Decoder for QOA audio read from `R`.
///
/// The decoder is an [`Iterator`] over [`QoaItem`]s: frame headers are yielded before the
/// samples of their frame, and samples are interleaved by channel.
#[derive(Debug)]
pub struct QoaDecoder<R> {
    /// Fixed-samples or streaming, as determined from the file header.
    mode: ProcessingMode,
    /// One LMS predictor state per channel of the current frame.
    lms: Vec<QoaLms>,
    /// Source of the encoded bytes.
    reader: R,
    /// Header of the frame being decoded and how many samples per channel are still
    /// waiting to be decoded from it.
    current_frame: CurrentFrame,
    /// Decoded samples of the most recent slice of every channel, interleaved by channel.
    pending_samples: Box<[i16]>,
    /// Index of the next entry of `pending_samples` to hand out.
    next_pending_sample_idx: usize,
    /// Set the first time `next` runs its initial frame-header step, so the header that
    /// `new` may decode eagerly is reported at most once.
    returned_first_frame_header: bool,
}
/// State of one channel's LMS predictor.
#[derive(Debug, Clone, Default)]
struct QoaLms {
    /// Most recently reconstructed samples, oldest first.
    history: [i32; QOA_LMS_LEN],
    /// Per-tap prediction weights, paired index-wise with `history`.
    weights: [i32; QOA_LMS_LEN],
}
impl<R> QoaDecoder<R>
where
R: io::Read,
{
pub fn new(mut reader: R) -> Result<Self, DecodeError> {
let magic = read_u32_be(&mut reader)?;
if magic != QOA_MAGIC {
return Err(DecodeError::NotQoaFile);
}
let samples = read_u32_be(&mut reader)?;
let mode = if samples == 0 {
ProcessingMode::Streaming
} else {
ProcessingMode::FixedSamples {
channels: 0,
sample_rate: 0,
samples,
}
};
let current_frame: CurrentFrame = Default::default();
let mut to_return = Self {
mode,
lms: Vec::new(),
reader,
current_frame,
pending_samples: Box::new([]),
next_pending_sample_idx: 0,
returned_first_frame_header: false,
};
if to_return.mode != ProcessingMode::Streaming {
let found_frame = to_return.decode_frame_header_and_lms(true)?;
if !found_frame {
return Err(DecodeError::NoSamples);
}
}
Ok(to_return)
}
pub fn into_inner(self) -> R {
self.reader
}
pub fn mode(&self) -> &ProcessingMode {
&self.mode
}
pub fn current_frame_header(&self) -> &FrameHeader {
&self.current_frame.header
}
pub fn total_duration(&self) -> Option<Duration> {
match &self.mode {
ProcessingMode::FixedSamples {
channels: _channels,
sample_rate,
samples,
} => Some(Duration::from_secs_f64(
(*samples as f64) / (*sample_rate as f64),
)),
ProcessingMode::Streaming => None,
}
}
fn decode_frame_header_and_lms(&mut self, first: bool) -> Result<bool, DecodeError> {
let frame_header = match read_u64_be(&mut self.reader) {
Ok(h) => h,
Err(e) => {
return if e.kind() == io::ErrorKind::UnexpectedEof {
Ok(false)
} else {
Err(e.into())
};
}
};
let num_channels = ((frame_header >> 56) & 0x0000ff) as u8;
let sample_rate = ((frame_header >> 32) & 0xffffff) as u32;
let num_samples_per_channel = ((frame_header >> 16) & 0x00ffff) as u16;
let frame_size = (frame_header & 0x00ffff) as u16;
let frame_header = FrameHeader {
num_channels,
sample_rate,
num_samples_per_channel,
};
if num_channels == 0 || sample_rate == 0 {
return Err(DecodeError::InvalidFrameHeader);
}
const LMS_SIZE: usize = 4;
let non_sample_data_size = QOA_HEADER_SIZE + QOA_LMS_LEN * LMS_SIZE * num_channels as usize;
if frame_size as usize <= non_sample_data_size {
return Err(DecodeError::InvalidFrameHeader);
}
let data_size = frame_size as usize - non_sample_data_size;
let num_slices = data_size / 8;
if num_slices % num_channels as usize != 0 {
return Err(DecodeError::InvalidFrameHeader);
}
if num_slices / num_channels as usize > MAX_SLICES_PER_CHANNEL_PER_FRAME {
return Err(DecodeError::InvalidFrameHeader);
}
if let ProcessingMode::FixedSamples {
channels: decoded_channels,
sample_rate: decoded_sample_rate,
..
} = &mut self.mode
{
if first {
*decoded_channels = num_channels;
*decoded_sample_rate = sample_rate;
} else if num_channels != *decoded_channels || sample_rate != *decoded_sample_rate {
return Err(DecodeError::IncompatibleFrame);
}
}
if self.lms.len() != num_channels as usize {
assert!(matches!(self.mode, ProcessingMode::Streaming) || first);
self.lms
.resize_with(num_channels as usize, Default::default);
}
for c in 0..num_channels as usize {
let mut history = read_u64_be(&mut self.reader)?;
let mut weights = read_u64_be(&mut self.reader)?;
for i in 0..QOA_LMS_LEN {
self.lms[c].history[i] = ((history >> 48) as i16) as i32;
history <<= 16;
self.lms[c].weights[i] = ((weights >> 48) as i16) as i32;
weights <<= 16;
}
}
self.current_frame = CurrentFrame {
header: frame_header,
num_samples_per_channel_remaining: num_samples_per_channel,
};
Ok(true)
}
fn decode_one_slice_per_channel(&mut self) -> Result<(), DecodeError> {
assert!(self.next_pending_sample_idx >= self.pending_samples.len());
let channels = self.current_frame.header.num_channels as usize;
let full_slices_num_samples = QOA_SLICE_LEN * channels;
if self.pending_samples.len() != full_slices_num_samples {
self.pending_samples = vec![0_i16; full_slices_num_samples].into_boxed_slice();
}
self.next_pending_sample_idx = 0;
for channel_idx in 0..channels {
let mut slice = read_u64_be(&mut self.reader)?;
let scale_factor = ((slice >> 60) & 0xf) as usize;
for sample_in_channel_slice_idx in 0..QOA_SLICE_LEN {
let prediction = self.lms[channel_idx].predict();
let quantized = ((slice >> 57) & 0x7) as usize;
let dequantized = QOA_DEQUANT_TAB[scale_factor][quantized];
let reconstructed = (prediction + dequantized).clamp(-32768, 32767) as i16;
let data_idx = sample_in_channel_slice_idx * channels + channel_idx;
self.pending_samples[data_idx] = reconstructed;
slice <<= 3;
self.lms[channel_idx].update(reconstructed, dequantized);
}
}
let num_samples_per_channel = self.current_frame.num_samples_per_channel_remaining;
if (num_samples_per_channel as usize) < QOA_SLICE_LEN {
let total_num_samples = num_samples_per_channel as usize * channels;
self.pending_samples = self.pending_samples[0..total_num_samples]
.to_vec()
.into_boxed_slice();
self.current_frame.num_samples_per_channel_remaining -= num_samples_per_channel;
} else {
self.current_frame.num_samples_per_channel_remaining -= QOA_SLICE_LEN as u16;
}
Ok(())
}
}
impl QoaDecoder<io::BufReader<File>> {
pub fn open<P: AsRef<Path>>(path: P) -> Result<QoaDecoder<io::BufReader<File>>, DecodeError> {
let file = File::open(path)?;
QoaDecoder::new(io::BufReader::new(file))
}
}
impl QoaDecoder<Cursor<Vec<u8>>> {
    /// Creates a streaming-mode decoder backed by an in-memory buffer.
    ///
    /// The buffer starts with a synthetic file header whose sample count is zero, so the
    /// decoder is in [`ProcessingMode::Streaming`]. Feed frames with [`Self::decode_frame`].
    pub fn new_streaming() -> Result<Self, DecodeError> {
        let streaming_header: Vec<u8> = [QOA_MAGIC, 0]
            .iter()
            .flat_map(|&x| x.to_be_bytes())
            .collect();
        QoaDecoder::new(Cursor::new(streaming_header))
    }

    /// Appends `frame_data` to the internal buffer and decodes every frame it completes.
    ///
    /// Returns the decoded samples, interleaved by channel; frame headers are not returned.
    /// Callers should pass whole frames; a frame split across calls is not handled
    /// specially.
    ///
    /// # Errors
    ///
    /// Returns the first decoding error encountered. Samples decoded earlier in the same
    /// call are discarded in that case.
    pub fn decode_frame(&mut self, frame_data: &[u8]) -> Result<Vec<i16>, DecodeError> {
        // The cursor only ever moves forward, so everything before its position has already
        // been decoded. Drop it so the buffer does not keep every frame ever pushed.
        let consumed = usize::try_from(self.reader.position())
            .unwrap_or(usize::MAX)
            .min(self.reader.get_ref().len());
        self.reader.get_mut().drain(..consumed);
        self.reader.set_position(0);
        self.reader.get_mut().extend_from_slice(frame_data);

        let mut to_return = Vec::new();
        for item in self {
            match item? {
                QoaItem::Sample(s) => to_return.push(s),
                QoaItem::FrameHeader(_) => (),
            }
        }
        Ok(to_return)
    }
}
/// One item produced by iterating a [`QoaDecoder`].
#[derive(Debug)]
pub enum QoaItem {
    /// A single decoded sample; consecutive samples are interleaved by channel.
    Sample(i16),
    /// A frame header, yielded before the samples of its frame.
    FrameHeader(FrameHeader),
}
impl<R: io::Read> Iterator for QoaDecoder<R> {
    type Item = Result<QoaItem, DecodeError>;

    /// Yields the next frame header or decoded sample.
    ///
    /// Each frame's header is yielded before its samples, and samples are interleaved by
    /// channel. In fixed-samples mode the first frame header was already decoded by
    /// [`QoaDecoder::new`] and is yielded first. In streaming mode `new` decodes no frame,
    /// so the first item is the first header actually read from the stream. Iteration ends
    /// (`None`) when the input ends at a frame boundary.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(&sample) = self.pending_samples.get(self.next_pending_sample_idx) {
                self.next_pending_sample_idx += 1;
                return Some(Ok(QoaItem::Sample(sample)));
            }
            if !self.returned_first_frame_header {
                self.returned_first_frame_header = true;
                // Only fixed-samples mode has a real header decoded ahead of time; in
                // streaming mode `current_frame.header` is still the all-zero default.
                if !matches!(self.mode, ProcessingMode::Streaming) {
                    return Some(Ok(QoaItem::FrameHeader(self.current_frame.header)));
                }
            }
            if self.current_frame.num_samples_per_channel_remaining == 0 {
                return match self.decode_frame_header_and_lms(false) {
                    Ok(true) => Some(Ok(QoaItem::FrameHeader(self.current_frame.header))),
                    Ok(false) => None,
                    Err(e) => Some(Err(e)),
                };
            }
            if let Err(e) = self.decode_one_slice_per_channel() {
                return Some(Err(e));
            }
            debug_assert!(!self.pending_samples.is_empty());
        }
    }
}

/// A completely decoded QOA stream, as returned by [`decode_all`].
pub struct DecodedQoa {
    /// Number of channels.
    pub num_channels: u8,
    /// Samples per second per channel.
    pub sample_rate: u32,
    /// All decoded samples, interleaved by channel.
    pub samples: Vec<i16>,
}

/// Decodes every frame of the QOA stream read from `reader`.
///
/// All frames must share the channel count and sample rate of the first frame.
///
/// # Errors
///
/// * Any error from [`QoaDecoder::new`] or from decoding a frame.
/// * [`DecodeError::NoSamples`] if the stream contains no frame at all.
/// * [`DecodeError::IncompatibleFrame`] if a later frame changes the channel count or
///   sample rate.
pub fn decode_all<R: io::Read>(reader: R) -> Result<DecodedQoa, DecodeError> {
    let mut decoder = QoaDecoder::new(reader)?;
    let mut samples = Vec::new();
    if let &ProcessingMode::FixedSamples {
        samples: samples_per_channel,
        channels,
        ..
    } = decoder.mode()
    {
        // The count comes from the (untrusted) file header, so it is only a capacity hint: an
        // absurd value must neither overflow nor abort the process on a failed allocation.
        if let Some(total) = (samples_per_channel as usize).checked_mul(channels as usize) {
            let _ = samples.try_reserve_exact(total);
        }
    }
    let (num_channels, sample_rate) = match decoder.next() {
        Some(Ok(QoaItem::FrameHeader(header))) => (header.num_channels, header.sample_rate),
        Some(Ok(QoaItem::Sample(_))) => {
            unreachable!("the decoder yields a frame header before any sample")
        }
        Some(Err(e)) => return Err(e),
        None => return Err(DecodeError::NoSamples),
    };
    for item in decoder {
        match item? {
            QoaItem::Sample(s) => samples.push(s),
            QoaItem::FrameHeader(header) => {
                if num_channels != header.num_channels || sample_rate != header.sample_rate {
                    return Err(DecodeError::IncompatibleFrame);
                }
            }
        }
    }
    Ok(DecodedQoa {
        num_channels,
        sample_rate,
        samples,
    })
}
pub fn open_and_decode_all<P: AsRef<Path>>(path: P) -> Result<DecodedQoa, DecodeError> {
let file = File::open(path.as_ref())?;
let reader = io::BufReader::new(file);
decode_all(reader)
}
/// Bookkeeping for the frame currently being decoded.
#[derive(Debug, Default)]
struct CurrentFrame {
    /// Header of the frame.
    header: FrameHeader,
    /// Samples per channel of this frame that have not yet been decoded from slices.
    num_samples_per_channel_remaining: u16,
}
/// Audio parameters carried in a QOA frame header.
#[derive(Debug, Copy, Clone, Default)]
pub struct FrameHeader {
    /// Number of channels in the frame; never 0 for a header decoded from the stream.
    pub num_channels: u8,
    /// Sample rate of the frame (a 24-bit field in the encoded header); never 0 for a
    /// header decoded from the stream.
    pub sample_rate: u32,
    /// Number of samples per channel in the frame.
    pub num_samples_per_channel: u16,
}
/// Reads four bytes from `reader` and interprets them as a big-endian `u32`.
///
/// # Errors
///
/// Propagates the reader's error, e.g. `UnexpectedEof` if fewer than four bytes remain.
fn read_u32_be<R: io::Read>(reader: R) -> io::Result<u32> {
    read_array(reader).map(u32::from_be_bytes)
}

/// Reads eight bytes from `reader` and interprets them as a big-endian `u64`.
///
/// # Errors
///
/// Propagates the reader's error, e.g. `UnexpectedEof` if fewer than eight bytes remain.
fn read_u64_be<R: io::Read>(reader: R) -> io::Result<u64> {
    read_array(reader).map(u64::from_be_bytes)
}

/// Fills a `LEN`-byte array from `reader` using `read_exact`.
///
/// # Errors
///
/// Propagates the error from `read_exact`.
fn read_array<R: io::Read, const LEN: usize>(mut reader: R) -> io::Result<[u8; LEN]> {
    let mut buffer = [0_u8; LEN];
    reader.read_exact(&mut buffer).map(|()| buffer)
}
impl QoaLms {
    /// Predicts the next sample as the weighted sum of `history` and `weights`,
    /// arithmetically shifted right by 13 bits.
    ///
    /// Products and the running sum wrap on overflow instead of panicking.
    #[inline(always)]
    fn predict(&self) -> i32 {
        let weighted_sum = self
            .weights
            .iter()
            .zip(&self.history)
            .fold(0_i32, |acc, (&weight, &past)| {
                acc.wrapping_add(weight.wrapping_mul(past))
            });
        weighted_sum >> 13
    }

    /// Adapts the weights using `residual` and appends `sample` to the history.
    ///
    /// Each weight moves by `residual >> 4`, towards the sign of its paired history value
    /// (a negative history value subtracts, anything else adds). The history then drops its
    /// oldest entry and stores `sample` as the newest.
    #[inline(always)]
    fn update(&mut self, sample: i16, residual: i32) {
        let step = residual >> 4;
        for (weight, &past) in self.weights.iter_mut().zip(&self.history) {
            if past < 0 {
                *weight -= step;
            } else {
                *weight += step;
            }
        }
        self.history.rotate_left(1);
        self.history[QOA_LMS_LEN - 1] = i32::from(sample);
    }
}
/// Dequantization table, indexed as `QOA_DEQUANT_TAB[scale_factor][quantized_residual]`.
///
/// A slice's 4-bit scale factor selects the row and each 3-bit quantized residual selects
/// the column; the entry is the residual added to the LMS prediction. Columns come in
/// positive/negative pairs of equal magnitude.
const QOA_DEQUANT_TAB: [[i32; 8]; 16] = [
    [1, -1, 3, -3, 5, -5, 7, -7],
    [5, -5, 18, -18, 32, -32, 49, -49],
    [16, -16, 53, -53, 95, -95, 147, -147],
    [34, -34, 113, -113, 203, -203, 315, -315],
    [63, -63, 210, -210, 378, -378, 588, -588],
    [104, -104, 345, -345, 621, -621, 966, -966],
    [158, -158, 528, -528, 950, -950, 1477, -1477],
    [228, -228, 760, -760, 1368, -1368, 2128, -2128],
    [316, -316, 1053, -1053, 1895, -1895, 2947, -2947],
    [422, -422, 1405, -1405, 2529, -2529, 3934, -3934],
    [548, -548, 1828, -1828, 3290, -3290, 5117, -5117],
    [696, -696, 2320, -2320, 4176, -4176, 6496, -6496],
    [868, -868, 2893, -2893, 5207, -5207, 8099, -8099],
    [1064, -1064, 3548, -3548, 6386, -6386, 9933, -9933],
    [1286, -1286, 4288, -4288, 7718, -7718, 12005, -12005],
    [1536, -1536, 5120, -5120, 9216, -9216, 14336, -14336],
];
/// Errors produced while decoding a QOA stream.
#[derive(Debug)]
pub enum DecodeError {
    /// The input does not start with the `qoaf` magic bytes.
    NotQoaFile,
    /// No frame was found where at least one was required.
    NoSamples,
    /// A frame header is malformed or its size fields are inconsistent.
    InvalidFrameHeader,
    /// A frame's channel count or sample rate differs from the earlier frames.
    IncompatibleFrame,
    /// The underlying reader failed.
    IoError(io::Error),
}

impl std::error::Error for DecodeError {}

impl Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match self {
            DecodeError::IoError(inner) => return write!(f, "IO error: {inner}"),
            DecodeError::NotQoaFile => "File is not a qoa file",
            DecodeError::NoSamples => "File has no samples",
            DecodeError::InvalidFrameHeader => "File has invalid frame header",
            DecodeError::IncompatibleFrame => "Incompatible frame header",
        };
        f.write_str(message)
    }
}

impl From<io::Error> for DecodeError {
    /// Wraps an I/O error as [`DecodeError::IoError`], enabling `?` on I/O results.
    fn from(source: io::Error) -> Self {
        Self::IoError(source)
    }
}
/// Adapter that exposes a [`QoaDecoder`] as a `rodio` audio source; only compiled with
/// the `rodio` feature.
#[cfg(feature = "rodio")]
pub struct QoaRodioSource<R: io::Read> {
    /// The wrapped decoder, driven by this source's `Iterator` impl.
    decoder: QoaDecoder<R>,
}
#[cfg(feature = "rodio")]
mod rodio_integration {
    use super::*;

    impl<R: io::Read> QoaRodioSource<R> {
        /// Wraps `decoder` so it can be played through `rodio`.
        pub fn new(decoder: QoaDecoder<R>) -> QoaRodioSource<R> {
            Self { decoder }
        }
    }

    impl<R: io::Read> Iterator for QoaRodioSource<R> {
        type Item = i16;

        /// Yields the next interleaved sample, skipping frame headers.
        ///
        /// After handing out the last sample of a frame, the next frame header is decoded
        /// eagerly so that `channels()` and `sample_rate()` already describe the samples that
        /// follow. Any decoding error ends the stream.
        fn next(&mut self) -> Option<Self::Item> {
            loop {
                let sample = match self.decoder.next()? {
                    Ok(QoaItem::Sample(sample)) => sample,
                    Ok(QoaItem::FrameHeader(_)) => continue,
                    Err(_) => return None,
                };
                let frame_exhausted = self.decoder.next_pending_sample_idx
                    >= self.decoder.pending_samples.len()
                    && self.decoder.current_frame.num_samples_per_channel_remaining == 0;
                if frame_exhausted {
                    match self.decoder.next() {
                        Some(Ok(QoaItem::FrameHeader(_))) | None => (),
                        // With the frame exhausted, the decoder can only read a header next.
                        Some(Ok(QoaItem::Sample(_))) => unreachable!(),
                        // NOTE(review): this also drops `sample`, the frame's final sample;
                        // keeping it would need a "finished" flag on the source.
                        Some(Err(_)) => return None,
                    }
                }
                return Some(sample);
            }
        }
    }

    impl<R: io::Read> rodio::Source for QoaRodioSource<R> {
        /// Number of samples left before the audio parameters may change.
        ///
        /// In streaming mode each frame may change channel count or sample rate, so this is
        /// the number of samples left in the current frame: samples already decoded but not
        /// yet returned, plus the samples not yet decoded. Fixed-samples files keep their
        /// parameters for the whole file, so `None` is returned.
        fn current_frame_len(&self) -> Option<usize> {
            if matches!(self.decoder.mode, ProcessingMode::Streaming) {
                let channels = self.decoder.current_frame.header.num_channels as usize;
                let not_yet_decoded = self.decoder.current_frame.num_samples_per_channel_remaining
                    as usize
                    * channels;
                let decoded_not_returned = self
                    .decoder
                    .pending_samples
                    .len()
                    .saturating_sub(self.decoder.next_pending_sample_idx);
                Some(not_yet_decoded + decoded_not_returned)
            } else {
                None
            }
        }

        /// Channel count of the current frame.
        fn channels(&self) -> u16 {
            self.decoder.current_frame.header.num_channels.into()
        }

        /// Sample rate of the current frame.
        fn sample_rate(&self) -> u32 {
            self.decoder.current_frame.header.sample_rate
        }

        /// Total duration announced by the file header, if any.
        fn total_duration(&self) -> Option<Duration> {
            self.decoder.total_duration()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    static QOA_BYTES: &[u8] = include_bytes!("../fixtures/julien_baker_sprained_ankle.qoa");

    /// Returns the frame-size field of the frame header that starts at `offset` in the fixture.
    fn frame_size_at(offset: usize) -> usize {
        let header = read_u64_be(&QOA_BYTES[offset..offset + QOA_HEADER_SIZE]).unwrap();
        (header & 0x00ffff) as usize
    }

    #[test]
    fn test_iterating_through_whole_file() {
        let decoder = QoaDecoder::new(Cursor::new(QOA_BYTES)).unwrap();
        assert!(matches!(
            decoder.mode(),
            ProcessingMode::FixedSamples {
                channels: 2,
                sample_rate: 44100,
                samples: 2394122,
                ..
            }
        ));
        let (mut headers_seen, mut sample_count) = (0, 0);
        for item in decoder {
            match item.expect("No io errors should happen") {
                QoaItem::Sample(_) => sample_count += 1,
                QoaItem::FrameHeader(header) => {
                    assert_eq!(header.num_channels, 2);
                    assert_eq!(header.sample_rate, 44100);
                    headers_seen += 1;
                    // Every frame is full except the final one.
                    let expected_len = if headers_seen < 468 { 5120 } else { 3082 };
                    assert_eq!(header.num_samples_per_channel, expected_len);
                }
            }
        }
        assert_eq!(headers_seen, 468);
        assert_eq!(sample_count, 2394122 * 2);
    }

    #[test]
    fn test_decode_streaming_frames() {
        let mut qoa = QoaDecoder::new_streaming().unwrap();
        assert!(matches!(qoa.mode(), ProcessingMode::Streaming));

        let first_start = QOA_HEADER_SIZE;
        let first_end = first_start + frame_size_at(first_start);
        let samples = qoa
            .decode_frame(&QOA_BYTES[first_start..first_end])
            .unwrap();
        assert_eq!(samples.len(), 5120 * 2);

        let second_end = first_end + frame_size_at(first_end);
        let samples = qoa
            .decode_frame(&QOA_BYTES[first_end..second_end])
            .unwrap();
        assert_eq!(samples.len(), 5120 * 2);
    }

    #[test]
    fn test_decode_all() {
        let DecodedQoa {
            num_channels,
            sample_rate,
            samples,
        } = decode_all(Cursor::new(QOA_BYTES)).unwrap();
        assert_eq!(sample_rate, 44100);
        assert_eq!(num_channels, 2);
        assert_eq!(samples.len(), 2394122 * 2);
    }
}