any-tts 0.1.1

A Rust TTS library with Candle backends and runtime adapters for modern open TTS models
Documentation
use std::io::{Read, Seek};

use nanomp3::{Channels as Mp3Channels, Decoder as Mp3Decoder, MAX_SAMPLES_PER_FRAME};

use super::AudioSamples;
use crate::TtsError;

pub(super) fn decode_wav_bytes(bytes: &[u8]) -> Result<AudioSamples, TtsError> {
    let (format, data) = parse_wav_chunks(bytes)?;
    let decoded = decode_wav_data(format, data)?;
    Ok(AudioSamples::new(
        downmix_to_mono(decoded, format.channels),
        format.sample_rate,
    ))
}

pub(super) fn decode_audio_stream<R>(stream: R) -> Result<AudioSamples, TtsError>
where
    R: Read + Seek + Send + Sync + 'static,
{
    let mut bytes = Vec::new();
    let mut stream = stream;
    stream
        .read_to_end(&mut bytes)
        .map_err(|err| TtsError::AudioError(format!("Failed to read audio stream: {err}")))?;

    decode_audio_bytes(&bytes)
}

pub(super) fn decode_audio_bytes(bytes: &[u8]) -> Result<AudioSamples, TtsError> {
    if is_wav_bytes(bytes) {
        return decode_wav_bytes(bytes);
    }

    decode_mp3_bytes(bytes)
}

fn is_wav_bytes(bytes: &[u8]) -> bool {
    bytes.len() >= 12 && &bytes[0..4] == b"RIFF" && &bytes[8..12] == b"WAVE"
}

fn parse_wav_chunks(bytes: &[u8]) -> Result<(WavFormat, &[u8]), TtsError> {
    if !is_wav_bytes(bytes) {
        return Err(TtsError::AudioError("Invalid WAV header".into()));
    }

    let mut offset = 12usize;
    let mut format = None;
    let mut data = None;

    while offset + 8 <= bytes.len() {
        let chunk_id = &bytes[offset..offset + 4];
        let chunk_size = u32::from_le_bytes([
            bytes[offset + 4],
            bytes[offset + 5],
            bytes[offset + 6],
            bytes[offset + 7],
        ]) as usize;
        let chunk_start = offset + 8;
        let chunk_end = chunk_start.saturating_add(chunk_size);

        if chunk_end > bytes.len() {
            return Err(TtsError::AudioError("Malformed WAV chunk size".into()));
        }

        match chunk_id {
            b"fmt " => format = Some(parse_wav_format(bytes, chunk_start, chunk_size)?),
            b"data" => data = Some(&bytes[chunk_start..chunk_end]),
            _ => {}
        }

        offset = chunk_end;
        if chunk_size % 2 == 1 {
            offset = offset.saturating_add(1);
        }
    }

    let format = format.ok_or_else(|| TtsError::AudioError("Missing WAV fmt chunk".into()))?;
    let data = data.ok_or_else(|| TtsError::AudioError("Missing WAV data chunk".into()))?;
    Ok((format, data))
}

fn parse_wav_format(
    bytes: &[u8],
    chunk_start: usize,
    chunk_size: usize,
) -> Result<WavFormat, TtsError> {
    if chunk_size < 16 {
        return Err(TtsError::AudioError("WAV fmt chunk is too small".into()));
    }

    let format = WavFormat {
        audio_format: u16::from_le_bytes([bytes[chunk_start], bytes[chunk_start + 1]]),
        channels: u16::from_le_bytes([bytes[chunk_start + 2], bytes[chunk_start + 3]]),
        sample_rate: u32::from_le_bytes([
            bytes[chunk_start + 4],
            bytes[chunk_start + 5],
            bytes[chunk_start + 6],
            bytes[chunk_start + 7],
        ]),
        bits_per_sample: u16::from_le_bytes([bytes[chunk_start + 14], bytes[chunk_start + 15]]),
    };
    if format.channels == 0 {
        return Err(TtsError::AudioError(
            "WAV file declares zero channels".into(),
        ));
    }

    Ok(format)
}

fn decode_wav_data(format: WavFormat, data: &[u8]) -> Result<Vec<f32>, TtsError> {
    match (format.audio_format, format.bits_per_sample) {
        (1, 8) => Ok(data
            .iter()
            .map(|&sample| (sample as f32 - 128.0) / 127.0)
            .collect()),
        (1, 16) => Ok(data
            .chunks_exact(2)
            .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32)
            .collect()),
        (1, 24) => Ok(data
            .chunks_exact(3)
            .map(|chunk| {
                let value =
                    ((chunk[2] as i32) << 24 >> 8) | ((chunk[1] as i32) << 8) | chunk[0] as i32;
                value as f32 / 8_388_607.0
            })
            .collect()),
        (1, 32) => Ok(data
            .chunks_exact(4)
            .map(|chunk| {
                i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f32
                    / i32::MAX as f32
            })
            .collect()),
        (3, 32) => Ok(data
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()),
        _ => Err(TtsError::AudioError(format!(
            "Unsupported WAV format: format={}, bits={}",
            format.audio_format, format.bits_per_sample
        ))),
    }
}

fn downmix_to_mono(samples: Vec<f32>, channels: u16) -> Vec<f32> {
    if channels == 1 {
        return samples;
    }

    samples
        .chunks(channels as usize)
        .map(|frame| frame.iter().copied().sum::<f32>() / frame.len() as f32)
        .collect()
}

fn decode_mp3_bytes(bytes: &[u8]) -> Result<AudioSamples, TtsError> {
    let mut decoder = Mp3Decoder::new();
    let mut pcm = [0.0f32; MAX_SAMPLES_PER_FRAME];
    let mut sample_rate = None;
    let mut samples = Vec::new();
    let mut offset = 0usize;

    while offset < bytes.len() {
        let (consumed, frame) = decoder.decode(&bytes[offset..], &mut pcm);

        if consumed == 0 && frame.is_none() {
            break;
        }

        offset = offset.saturating_add(consumed);

        let Some(frame) = frame else {
            continue;
        };

        if frame.sample_rate == 0 {
            return Err(TtsError::AudioError(
                "MP3 stream is missing a sample rate".into(),
            ));
        }

        match sample_rate {
            Some(existing) if existing != frame.sample_rate => {
                return Err(TtsError::AudioError(format!(
                    "MP3 stream changed sample rate from {existing} Hz to {} Hz",
                    frame.sample_rate
                )));
            }
            None => sample_rate = Some(frame.sample_rate),
            _ => {}
        }

        append_decoded_mp3_frame(&mut samples, &pcm, frame.samples_produced, frame.channels)?;
    }

    let sample_rate = sample_rate.ok_or_else(|| {
        TtsError::AudioError("Unsupported audio stream: expected WAV or MP3 data".into())
    })?;

    if samples.is_empty() {
        return Err(TtsError::AudioError(
            "Unsupported audio stream: expected WAV or MP3 data".into(),
        ));
    }

    Ok(AudioSamples::new(samples, sample_rate))
}

fn append_decoded_mp3_frame(
    samples: &mut Vec<f32>,
    pcm: &[f32],
    samples_produced: usize,
    channels: Mp3Channels,
) -> Result<(), TtsError> {
    let channel_count = channels.num() as usize;
    let total_samples = samples_produced.checked_mul(channel_count).ok_or_else(|| {
        TtsError::AudioError("MP3 frame sample count overflowed while decoding".into())
    })?;
    let frame = pcm.get(..total_samples).ok_or_else(|| {
        TtsError::AudioError("MP3 decoder returned an invalid frame length".into())
    })?;

    if channel_count == 1 {
        samples.extend_from_slice(frame);
        return Ok(());
    }

    for frame in frame.chunks_exact(channel_count) {
        samples.push(frame.iter().copied().sum::<f32>() / channel_count as f32);
    }

    Ok(())
}

#[derive(Debug, Clone, Copy)]
struct WavFormat {
    audio_format: u16,
    channels: u16,
    sample_rate: u32,
    bits_per_sample: u16,
}