aha 0.2.6

aha is a model inference library that now supports Qwen(2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM(4/5), VoxCPM(0.5B/1.5/2), DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM(ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM(2/2.5/2VL/2.5VL).
Documentation
use std::collections::HashMap;

use crate::{
    models::{
        moss_audio_tokenizer_nano::{MossAudioTokenizer, config::MossAudioTokenizerConfig},
        moss_tts_nano::{
            config::MossTTSConfig,
            model::{MossTTSMode, MossTTSModel},
            processor::MossTTSProcessor,
        },
    },
    utils::{find_type_files, get_device, get_dtype},
};
use anyhow::{Result, anyhow};
use candle_core::{DType, Device, pickle::read_all_with_key};
use candle_nn::VarBuilder;
use sentencepiece::SentencePieceProcessor;

/// Bundles everything needed to run MOSS TTS inference: the audio tokenizer,
/// the SentencePiece text tokenizer, the input processor, the TTS model and
/// the device they were loaded on.
///
/// Build one with [`MossTTSGenerate::init`].
pub struct MossTTSGenerate {
    /// Audio tokenizer model; handed to the processor when building inputs
    /// and to the TTS model during generation.
    pub audio_tokenizer: MossAudioTokenizer,
    /// SentencePiece text tokenizer loaded from `tokenizer.model` in the TTS
    /// model directory.
    pub text_tokenizer: SentencePieceProcessor,
    /// Processor that resolves the generation mode and builds the inference
    /// input ids.
    pub processor: MossTTSProcessor,
    /// The TTS model itself.
    pub model: MossTTSModel,
    /// Device the weights were loaded onto; also passed to the processor.
    pub device: Device,
}

impl MossTTSGenerate {
    pub fn init(
        tts_path: &str,
        audio_tokenizer_path: &str,
        device: Option<&Device>,
        dtype: Option<DType>,
    ) -> Result<Self> {
        let audio_tokenizer_config_path = audio_tokenizer_path.to_string() + "/config.json";
        let audio_tokenizer_cfg: MossAudioTokenizerConfig =
            serde_json::from_slice(&std::fs::read(audio_tokenizer_config_path)?)?;
        let model_list = find_type_files(audio_tokenizer_path, "safetensors")?;
        let audio_dtype = get_dtype(dtype, &audio_tokenizer_cfg.dtype);
        let device = get_device(device);
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_list, audio_dtype, &device)? };
        let audio_tokenizer = MossAudioTokenizer::new(vb, &audio_tokenizer_cfg)?;
        let text_tokenizer_path = tts_path.to_string() + "/tokenizer.model";
        let text_tokenizer = SentencePieceProcessor::open(text_tokenizer_path)
            .map_err(|e| anyhow!(format!("load bpe.model file error:{}", e)))?;
        let tts_cfg_path = tts_path.to_string() + "/config.json";
        let tts_cfg: MossTTSConfig = serde_json::from_slice(&std::fs::read(tts_cfg_path)?)?;
        let processor = MossTTSProcessor::new(
            &tts_cfg,
            audio_tokenizer_cfg.sample_rate,
            audio_tokenizer_cfg.number_channels,
            &text_tokenizer,
        )?;
        let model_list = find_type_files(tts_path, "bin")?;
        let mut dict_to_hashmap = HashMap::new();
        let m_dtype = get_dtype(dtype, "bfloat16");
        for m in model_list {
            let dict = read_all_with_key(m, None)?;
            for (k, v) in dict {
                dict_to_hashmap.insert(k, v);
            }
        }
        let vb = VarBuilder::from_tensors(dict_to_hashmap, m_dtype, &device);
        let model = MossTTSModel::new(vb, &tts_cfg)?;

        Ok(Self {
            audio_tokenizer,
            text_tokenizer,
            processor,
            model,
            device,
        })
    }

    pub fn generate(
        &mut self,
        text: &str,
        prompt_audio_path: Option<&str>,
        prompt_text: Option<&str>,
        mode: Option<MossTTSMode>,
    ) -> Result<()> {
        let mode = self.processor.resolved_mode(
            mode,
            prompt_text.is_some(),
            prompt_audio_path.is_some(),
        )?;
        let input_ids = self.processor.build_inference_input_ids(
            text,
            prompt_audio_path,
            prompt_text,
            mode.clone(),
            &self.audio_tokenizer,
            &self.text_tokenizer,
            &self.device,
        )?;
        self.model.generate(&input_ids, &self.audio_tokenizer)?;
        Ok(())
    }
}