any-tts 0.1.0

A Rust TTS library with Candle backends and runtime adapters for modern open TTS models
Documentation
use std::collections::HashMap;
use std::hash::BuildHasherDefault;

use base64::Engine;
use rustc_hash::FxHasher;
use serde::Deserialize;
use tiktoken_rs::CoreBPE;

use crate::error::TtsError;

use super::config::VoxtralConfig;

// Names of special tokens looked up in the `special_tokens` list of `tekken.json`.
// `from_bytes` requires every one of them to be present. The EOS id is only
// checked for presence; the other ids are stored on the tokenizer.

/// Beginning-of-sequence token; its id must match `params.json`.
const BOS_TOKEN: &str = "<s>";
/// End-of-sequence token; required to exist, but its id is not stored.
const EOS_TOKEN: &str = "</s>";
/// Placeholder token repeated once per voice-reference audio position in a prompt.
const AUDIO_TOKEN: &str = "[AUDIO]";
/// Marks the start of an audio segment; its id must match `params.json`.
const BEGIN_AUDIO_TOKEN: &str = "[BEGIN_AUDIO]";
/// Separates the voice-reference audio from the text to be spoken.
const NEXT_AUDIO_TEXT_TOKEN: &str = "[NEXT_AUDIO_TEXT]";
/// Emitted after the text, immediately before the final begin-audio token.
const REPEAT_AUDIO_TEXT_TOKEN: &str = "[REPEAT_AUDIO_TEXT]";

/// `HashMap` that hashes with rustc-hash's `FxHasher` instead of the std default hasher.
type TekkenMap<Key, Value> = HashMap<Key, Value, BuildHasherDefault<FxHasher>>;

/// The parts of a Tekken `tekken.json` file that this tokenizer reads.
///
/// Other top-level keys are ignored by serde (there is no `deny_unknown_fields`).
#[derive(Debug, Deserialize)]
struct TekkenFile {
    /// Vocabulary sizing and the split pattern handed to `CoreBPE::new`.
    config: TekkenConfig,
    /// Ordinary BPE entries, before the special-token id offset is applied.
    vocab: Vec<TekkenVocabEntry>,
    /// Named special tokens, whose `rank` is used directly as the token id.
    special_tokens: Vec<TekkenSpecialToken>,
}

/// The `config` object of `tekken.json`.
#[derive(Debug, Deserialize)]
struct TekkenConfig {
    /// Regex pattern passed unchanged to `CoreBPE::new`.
    pattern: String,
    /// Total vocabulary size.
    ///
    /// `from_bytes` subtracts the special-token count from it to get the number
    /// of ordinary entries to keep.
    default_vocab_size: usize,
    /// Number of ids reserved for special tokens.
    ///
    /// Ordinary vocabulary ranks are shifted up by this amount to form their ids.
    default_num_special_tokens: usize,
}

/// One ordinary (non-special) BPE vocabulary entry.
#[derive(Debug, Deserialize)]
struct TekkenVocabEntry {
    /// Position in the ordinary vocabulary.
    ///
    /// Entries whose rank is at or beyond the truncated size are skipped.
    rank: u32,
    /// The token's raw bytes, encoded as standard-alphabet base64.
    token_bytes: String,
}

/// One named special token from `tekken.json`.
#[derive(Debug, Deserialize)]
struct TekkenSpecialToken {
    /// Used directly as the token id; no offset is applied.
    rank: u32,
    /// Literal token name, e.g. `<s>` or `[AUDIO]`.
    token_str: String,
}

/// Tekken BPE tokenizer used to build Voxtral speech-generation prompts.
///
/// Holds the ordinary-vocabulary BPE plus the special-token ids that prompt
/// construction inserts directly.
pub struct VoxtralTokenizer {
    /// Id of `<s>`, the first token of every prompt.
    bos_token_id: u32,
    /// Id of `[BEGIN_AUDIO]`, which opens the reference audio and ends the prompt.
    begin_audio_token_id: u32,
    /// Id of `[AUDIO]`, repeated once per voice-reference audio position.
    audio_token_id: u32,
    /// Id of `[NEXT_AUDIO_TEXT]`, placed between the reference audio and the text.
    next_audio_text_token_id: u32,
    /// Id of `[REPEAT_AUDIO_TEXT]`, placed right after the text.
    repeat_audio_text_token_id: u32,
    /// BPE over the ordinary vocabulary; it is constructed without special tokens.
    bpe: CoreBPE,
}

impl VoxtralTokenizer {
    pub fn from_bytes(bytes: impl AsRef<[u8]>, config: &VoxtralConfig) -> Result<Self, TtsError> {
        let tekken: TekkenFile = serde_json::from_slice(bytes.as_ref())?;

        let num_special_tokens = tekken.config.default_num_special_tokens as u32;
        let inner_vocab_size = tekken
            .config
            .default_vocab_size
            .checked_sub(tekken.config.default_num_special_tokens)
            .ok_or_else(|| {
                TtsError::TokenizerError(
                    "Tekken default vocab size is smaller than the number of special tokens"
                        .to_string(),
                )
            })?;

        let mut encoder = TekkenMap::default();
        for entry in tekken
            .vocab
            .iter()
            .filter(|entry| entry.rank < inner_vocab_size as u32)
        {
            let token_bytes = base64::engine::general_purpose::STANDARD
                .decode(&entry.token_bytes)
                .map_err(|err| {
                    TtsError::TokenizerError(format!(
                        "Failed to decode Tekken token bytes '{}': {}",
                        entry.token_bytes, err
                    ))
                })?;
            encoder.insert(token_bytes, entry.rank + num_special_tokens);
        }
        if encoder.len() != inner_vocab_size {
            return Err(TtsError::TokenizerError(format!(
                "Tekken vocabulary truncation produced {} tokens, expected {}",
                encoder.len(),
                inner_vocab_size
            )));
        }

        let special_tokens = tekken
            .special_tokens
            .iter()
            .map(|entry| (entry.token_str.clone(), entry.rank))
            .collect::<TekkenMap<_, _>>();

        let bpe = CoreBPE::new(encoder, TekkenMap::default(), &tekken.config.pattern)
            .map_err(|err| TtsError::TokenizerError(err.to_string()))?;

        let bos_token_id = *special_tokens
            .get(BOS_TOKEN)
            .ok_or_else(|| missing_token(BOS_TOKEN))?;
        special_tokens
            .get(EOS_TOKEN)
            .ok_or_else(|| missing_token(EOS_TOKEN))?;
        let audio_token_id = *special_tokens
            .get(AUDIO_TOKEN)
            .ok_or_else(|| missing_token(AUDIO_TOKEN))?;
        let begin_audio_token_id = *special_tokens
            .get(BEGIN_AUDIO_TOKEN)
            .ok_or_else(|| missing_token(BEGIN_AUDIO_TOKEN))?;
        let next_audio_text_token_id = *special_tokens
            .get(NEXT_AUDIO_TEXT_TOKEN)
            .ok_or_else(|| missing_token(NEXT_AUDIO_TEXT_TOKEN))?;
        let repeat_audio_text_token_id = *special_tokens
            .get(REPEAT_AUDIO_TEXT_TOKEN)
            .ok_or_else(|| missing_token(REPEAT_AUDIO_TEXT_TOKEN))?;

        if bos_token_id != config.multimodal.bos_token_id {
            return Err(TtsError::TokenizerError(format!(
                "Tekken bos token id {} does not match params.json value {}",
                bos_token_id, config.multimodal.bos_token_id
            )));
        }
        if audio_token_id != config.multimodal.audio_model_args.audio_token_id {
            return Err(TtsError::TokenizerError(format!(
                "Tekken audio token id {} does not match params.json value {}",
                audio_token_id, config.multimodal.audio_model_args.audio_token_id
            )));
        }
        if begin_audio_token_id != config.multimodal.audio_model_args.begin_audio_token_id {
            return Err(TtsError::TokenizerError(format!(
                "Tekken begin-audio token id {} does not match params.json value {}",
                begin_audio_token_id, config.multimodal.audio_model_args.begin_audio_token_id
            )));
        }

        Ok(Self {
            bpe,
            bos_token_id,
            audio_token_id,
            begin_audio_token_id,
            next_audio_text_token_id,
            repeat_audio_text_token_id,
        })
    }

    pub fn encode_text(&self, text: &str) -> Vec<u32> {
        self.bpe.encode_ordinary(text)
    }

    pub fn build_speech_prompt(&self, text: &str, voice_audio_tokens: usize) -> Vec<u32> {
        let text_tokens = self.encode_text(text);
        let mut prompt = Vec::with_capacity(text_tokens.len() + voice_audio_tokens + 5);
        prompt.push(self.bos_token_id);
        prompt.push(self.begin_audio_token_id);
        prompt.extend(std::iter::repeat_n(self.audio_token_id, voice_audio_tokens));
        prompt.push(self.next_audio_text_token_id);
        prompt.extend(text_tokens);
        prompt.push(self.repeat_audio_text_token_id);
        prompt.push(self.begin_audio_token_id);
        prompt
    }
}

/// Builds the error reported when `tekken.json` lacks a required special token.
fn missing_token(token: &str) -> TtsError {
    let message = format!("Tekken tokenizer is missing special token {}", token);
    TtsError::TokenizerError(message)
}