//! natural-tts 0.1.4
//!
//! High-level bindings to a variety of text-to-speech libraries.
pub mod bs1770;
pub mod utils;

use std::error::Error;
use std::io::Write;
use std::path::PathBuf;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{adapters, gpt, transformer};
use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
use derive_builder::Builder;
use candle_core::{DType, IndexOp, Tensor, Device};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};

use utils::*;

use crate::utils::{get_path, play_wav_file, read_wav_file};
use crate::TtsError;

use super::{did_save, NaturalModelTrait, SynthesizedAudio};

#[derive(Builder, Clone, Default)]
#[builder(setter(into))]
pub struct MetaModelOptions{
    #[builder(default = "false")]
    cpu : bool,
    #[builder(default = "false")]
    tracing : bool, 
    #[builder(default = "false")]
    quantized : bool, 
    #[builder(default = "None")]
    first_stage_meta : Option<String>, 
    #[builder(default = "None")]
    first_stage_weights: Option<String>, 
    #[builder(default = "None")]
    second_stage_weights : Option<String>,
    #[builder(default = "None")]
    encodec_weights: Option<String>,
    #[builder(default = "None")]
    spk_emb : Option<String>,
    #[builder(default = "1024")]
    encodec_ntokens: u32, 
    #[builder(default = "299792458")]
    seed : u64,
    #[builder(default = "2000")]
    max_tokens : u64,
    #[builder(default = "3.0")]
    guidance_scale : f64,
    #[builder(default = "1.0")]
    temperature : f64,
    #[builder(default = "None")]
    repo : Option<String>,
}

/// A loaded MetaVoice text-to-speech pipeline.
///
/// The first-stage transformer is kept in memory. The second-stage and EnCodec
/// weights are stored as paths and loaded by `get_secondary_models`.
#[derive(Clone)]
pub struct MetaModel {
    /// First-stage transformer (full-precision or quantized).
    pub first_stage_model: Transformer,
    /// Device the transformers run on.
    pub device: Device,
    /// Parsed first-stage metadata JSON, used to build the first-stage tokenizer.
    pub first_stage_meta: serde_json::Value,
    /// Floating-point dtype for the safetensors weights and the speaker embedding.
    pub dtype: DType,
    /// Configuration of the second-stage GPT.
    pub second_stage_config: gpt::Config,
    /// Configuration of the EnCodec model.
    pub encodec_config: encodec::Config,
    /// Path of the EnCodec safetensors weights.
    pub encodec_weights: PathBuf,
    /// Path of the second-stage safetensors weights.
    pub second_stage_weights: PathBuf,
    /// Device used for EnCodec decoding.
    pub encodec_device: Device,
    /// Hugging Face repo the MetaVoice files come from.
    pub repo_path: String,
    /// Sampling seed.
    pub seed: u64,
    /// Blend weight for the two first-stage logit rows.
    pub guidance_scale: f64,
    /// First-stage sampling temperature.
    pub temperature: f64,
    /// Maximum number of first-stage tokens to generate.
    pub max_tokens: u64,
    /// Optional local path of the speaker-embedding safetensors file.
    pub spk_emb: Option<String>,
    /// EnCodec token count used by the codebook adapters and as padding id.
    pub encodec_ntokens: u32,
}


impl MetaModel{
    pub fn new(options : MetaModelOptions) -> Result<Self, Box<dyn Error>> {
        use tracing_chrome::ChromeLayerBuilder;
        use tracing_subscriber::prelude::*;

        let _guard = if options.tracing {
            let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
            tracing_subscriber::registry().with(chrome_layer).init();
            Some(guard)
        } else {
            None
        };

        let device = device(options.cpu)?;
        let api = Api::new()?;
        let repo_path = match options.repo{
            None => "lmz/candle-metavoice".to_string(),
            Some(x) => x,
        };

        let repo = api.model(repo_path.clone());
        
        let first_stage_meta = match &options.first_stage_meta {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("first_stage.meta.json")?,
        };

        let first_stage_meta: serde_json::Value =
            serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;

        let dtype = DType::F16;

        let first_stage_config = transformer::Config::cfg1b_v0_1();
        
        let first_stage_model = if options.quantized {
            let filename = match &options.first_stage_weights {
                Some(w) => std::path::PathBuf::from(w),
                None => repo.get("first_stage_q4k.gguf")?,
            };

            let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
            
            let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
            
            Transformer::Quantized(first_stage_model)
        } else {
            let first_stage_weights = match &options.first_stage_weights {
                Some(w) => std::path::PathBuf::from(w),
                None => repo.get("first_stage.safetensors")?,
            };

            let first_stage_vb = unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
            
            let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
            
            Transformer::Normal(first_stage_model)
        };

        let encodec_device = if device.is_metal() {
            candle_core::Device::Cpu
        } else {
            device.clone()
        };

        let second_stage_config = gpt::Config::cfg1b_v0_1();
        let encodec_config = encodec::Config::default();

        let second_stage_weights = match &options.second_stage_weights {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("second_stage.safetensors")?,
        };

        let encodec_weights = match &options.encodec_weights {
            Some(w) => std::path::PathBuf::from(w.clone()),
            None => Api::new()?
                .model("facebook/encodec_24khz".to_string())
                .get("model.safetensors")?,
        };
        
        return Ok(Self{
            first_stage_model,
            device,
            first_stage_meta,
            dtype,
            encodec_weights,
            second_stage_weights,
            encodec_config,
            second_stage_config,
            encodec_device,
            repo_path,
            seed : options.seed,
            guidance_scale : options.guidance_scale,
            temperature : options.temperature,
            max_tokens : options.max_tokens,
            spk_emb : options.spk_emb,
            encodec_ntokens : options.encodec_ntokens
        });
    }

    pub fn get_secondary_models(&self) -> Result<(gpt::Model, encodec::Model), Box<dyn Error>>{
        let second_stage_vb = unsafe { VarBuilder::from_mmaped_safetensors(&[self.second_stage_weights.clone()], self.dtype, &self.device)? };
        let second_stage_model = gpt::Model::new(self.second_stage_config.clone(), second_stage_vb)?;
        
        let encodec_vb = unsafe { VarBuilder::from_mmaped_safetensors(&[self.encodec_weights.clone()], self.dtype, &self.encodec_device)? };
        let encodec_model = encodec::Model::new(&self.encodec_config, encodec_vb)?;

        return Ok((second_stage_model, encodec_model));
    }

    pub fn run (&self, prompt : String, filename : String,) -> Result<(), Box<dyn Error>>{
        let (second_stage_model, encodec_model) = self.get_secondary_models()?;

        let fs_tokenizer = get_fs_tokenizer(self.first_stage_meta.clone())?;

        let prompt_tokens = fs_tokenizer.encode(&prompt)?;
        let mut tokens = prompt_tokens.clone();
        
        let api = Api::new()?;
        let repo = api.model(self.repo_path.clone());
        
        let spk_emb_file = match &self.spk_emb {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("spk_emb.safetensors")?,
        };

        let spk_emb = candle_core::safetensors::load(&spk_emb_file, &candle_core::Device::Cpu)?;
        
        let spk_emb = match spk_emb.get("spk_emb") {
            None => return Err(TtsError::Tensor.into()),
            Some(spk_emb) => spk_emb.to_dtype(self.dtype)?,
        };
        
        let spk_emb = spk_emb.to_device(&self.device)?;
        let mut logits_processor = LogitsProcessor::new(self.seed, Some(self.temperature), Some(0.95));

        for index in 0..self.max_tokens {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?;
            let input = Tensor::stack(&[&input, &input], 0)?;
            let logits = match &self.first_stage_model {
                Transformer::Normal(m) => m.clone().forward(&input, &spk_emb, tokens.len() - context_size)?,
                Transformer::Quantized(m) => {
                    m.clone().forward(&input, &spk_emb, tokens.len() - context_size)?
                }
            };
            let logits0 = logits.i((0, 0))?;
            let logits1 = logits.i((1, 0))?;
            let logits = ((logits0 * self.guidance_scale)? + logits1 * (1. - self.guidance_scale))?;
            let logits = logits.to_dtype(DType::F32)?;
            let next_token = logits_processor.sample(&logits)?;
            tokens.push(next_token);
            print!(".");
            std::io::stdout().flush()?;
            if next_token == 2048 {
                break;
            }
        }
        let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(self.encodec_ntokens);
        let (_, ids1, ids2) = fie2c.decode(&tokens);
        let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed + 1337);
        let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
        let mut hierarchies_in1 =
            [encoded_text.as_slice(), ids1.as_slice(), &[self.encodec_ntokens]].concat();
        let mut hierarchies_in2 = [
            vec![self.encodec_ntokens; encoded_text.len()].as_slice(),
            ids2.as_slice(),
            &[self.encodec_ntokens],
        ]
        .concat();
        hierarchies_in1.resize(self.second_stage_config.block_size, self.encodec_ntokens);
        hierarchies_in2.resize(self.second_stage_config.block_size, self.encodec_ntokens);
        let in_x1 = Tensor::new(hierarchies_in1, &self.device)?;
        let in_x2 = Tensor::new(hierarchies_in2, &self.device)?;
        let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
        let logits = second_stage_model.forward(&in_x)?;
        let mut codes = vec![];
        for logits in logits.iter() {
            let logits = logits.squeeze(0)?;
            let (seq_len, _) = logits.dims2()?;
            let mut codes_ = Vec::with_capacity(seq_len);
            for step in 0..seq_len {
                let logits = logits.i(step)?.to_dtype(DType::F32)?;
                let logits = &(&logits / 1.0)?;
                let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
                let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
                let sample = distr.sample(&mut rng) as u32;
                codes_.push(sample)
            }
            codes.push(codes_)
        }

        let codes = Tensor::new(codes, &self.device)?.unsqueeze(0)?;
        let codes = Tensor::cat(&[in_x, codes], 1)?;
        let tilted_encodec = adapters::TiltedEncodec::new(self.encodec_ntokens);
        let codes = codes.i(0)?.to_vec2::<u32>()?;
        let (_, audio_ids) = tilted_encodec.decode(&codes);
        let audio_ids = Tensor::new(audio_ids, &self.encodec_device)?.unsqueeze(0)?;
        let pcm = encodec_model.decode(&audio_ids)?;
        let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
        let pcm = normalize_loudness(&pcm, 24_000, true)?;
        let pcm = pcm.to_vec1::<f32>()?;
        let mut output = std::fs::File::create(&filename)?;
        write_pcm_as_wav(&mut output, &pcm, 24_000)?;
        Ok(())
    }
}

impl Default for MetaModel{
    fn default() -> Self {
        return Self::new(MetaModelOptions::default()).unwrap();
    }
}

impl NaturalModelTrait for MetaModel{
    type SynthesizeType = f32;

    fn save(&mut self, message: String, path : String) -> Result<(), Box<dyn Error>>{
        let _ = self.run(message, path.clone())?;
        did_save(path.as_str())
    }

    fn say(&mut self, message : String)  -> Result<(), Box<dyn Error>>{
        let path = get_path("temp.wav".to_string());
        self.save(message, path.clone())?;
        play_wav_file(&path)?;
        std::fs::remove_file(path)?;
        Ok(())
    }

    fn synthesize(&mut self, message : String) -> Result<SynthesizedAudio<Self::SynthesizeType>, Box<dyn Error>> {
        let path = get_path("temp.wav".to_string());
        self.save(message, path.clone())?;
        let d = read_wav_file(&path)?;
        std::fs::remove_file(path)?;
        Ok(d)
    }
}