aha 0.2.6

aha is a model inference library. It currently supports Qwen (2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM (4/5), VoxCPM (0.5B/1.5/2), DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM (ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM (2/2.5/2VL/2.5VL).
Documentation
use std::collections::HashMap;

use crate::{
    models::voxcpm::{audio_vae::AudioVAE, tokenizer::SingleChineseTokenizer},
    utils::audio_utils::load_audio_with_resample,
};
use anyhow::Result;
use candle_core::{D, Device, IndexOp, Tensor};

/// Prepares model inputs for VoxCPM: text token ids, patched audio features
/// and the mask that tells audio positions apart from text positions.
pub struct VoxCPMProcessor {
    /// Sample rate requested from `load_audio_with_resample` and passed on to
    /// `AudioVAE::encode`.
    sample_rate: usize,
    /// Multiplied with `patch_size` to get the sample count that prompt audio
    /// is zero-padded to a multiple of.
    chunk_size: usize,
    /// Size of the middle axis of each audio patch after reshaping; also a
    /// factor of the padding multiple (see `chunk_size`).
    patch_size: usize,
    /// Token id appended directly after the text token ids.
    audio_start_token: u32,
    /// Token id placed before the audio placeholder ids when `processor` is
    /// given prompt audio without prompt text.
    ref_audio_start_token: u32,
    /// Token id placed after the audio placeholder ids when `processor` is
    /// given prompt audio without prompt text.
    ref_audio_end_token: u32,
    /// Device on which every tensor produced by the processor is created.
    device: Device,
}

impl VoxCPMProcessor {
    pub fn new(sample_rate: usize, chunk_size: usize, patch_size: usize, device: Device) -> Self {
        Self {
            sample_rate,
            chunk_size,
            patch_size,
            audio_start_token: 101,
            ref_audio_start_token: 103,
            ref_audio_end_token: 104,
            device,
        }
    }

    pub fn build_prompt_cache(
        &mut self,
        prompt_text: String,
        prompt_wav_path: String,
        tokenizer: &SingleChineseTokenizer,
        audio_vae: &AudioVAE,
    ) -> Result<HashMap<String, Tensor>> {
        let (text_token, _) = tokenizer.encode_tensor(prompt_text, &self.device)?;
        let mut audio = load_audio_with_resample(
            &prompt_wav_path,
            &self.device,
            Some(self.sample_rate),
            Some(1),
        )?;
        let patch_len = self.patch_size * self.chunk_size;
        if audio.dim(1)? % patch_len != 0 {
            audio = audio.pad_with_zeros(D::Minus1, 0, patch_len - audio.dim(1)? % patch_len)?;
        }
        let audio_feat = audio_vae.encode(&audio, Some(self.sample_rate))?;
        let audio_feat = audio_feat
            .reshape((audio_vae.latent_dim, (), self.patch_size))?
            .permute((1, 2, 0))?;
        let dim0 = audio_feat.dim(0)? - 1;
        let audio_feat = audio_feat.i(..dim0)?;
        let mut hashmap = HashMap::new();
        hashmap.insert("text_token".to_string(), text_token);
        hashmap.insert("audio_feat".to_string(), audio_feat);
        Ok(hashmap)
    }

    pub fn processor(
        &self,
        target_text: String,
        prompt_text: Option<String>,
        prompt_wav_path: Option<String>,
        tokenizer: &SingleChineseTokenizer,
        audio_vae: &AudioVAE,
    ) -> Result<(Tensor, Option<Tensor>, Option<Tensor>)> {
        let text = if let Some(prompt_text) = &prompt_text {
            prompt_text.clone() + &target_text
        } else {
            target_text
        };
        let (text_token, _) = tokenizer.encode_tensor(text, &self.device)?;
        let audio_start = Tensor::new(vec![self.audio_start_token], &self.device)?;
        let mut text_token = Tensor::cat(&[text_token, audio_start], D::Minus1)?;

        let (audio_feat, audio_mask) = if let Some(path) = prompt_wav_path {
            let mut audio =
                load_audio_with_resample(&path, &self.device, Some(self.sample_rate), Some(1))?;
            let patch_len = self.patch_size * self.chunk_size;
            if audio.dim(1)? % patch_len != 0 {
                audio =
                    audio.pad_with_zeros(D::Minus1, patch_len - audio.dim(1)? % patch_len, 0)?;
            }
            let audio_feat = audio_vae.encode(&audio, Some(self.sample_rate))?;
            let audio_feat = audio_feat
                .reshape((audio_vae.latent_dim, (), self.patch_size))?
                .permute((1, 2, 0))?;
            let text_length = text_token.dim(0)?;
            let audio_length = audio_feat.dim(0)?;
            let audio_mask = if prompt_text.is_some() {
                let text_pad_token =
                    Tensor::zeros(audio_length, candle_core::DType::U32, &self.device)?;
                text_token = Tensor::cat(&[text_token, text_pad_token], D::Minus1)?;
                let mask = Tensor::cat(
                    &[
                        Tensor::zeros(text_length, candle_core::DType::U32, &self.device)?,
                        Tensor::ones(audio_length, candle_core::DType::U32, &self.device)?,
                    ],
                    D::Minus1,
                )?
                .unsqueeze(0)?;
                Some(mask)
            } else {
                let ref_start = Tensor::new(vec![self.ref_audio_start_token], &self.device)?;
                let ref_end = Tensor::new(vec![self.ref_audio_end_token], &self.device)?;
                let ref_token = Tensor::zeros(audio_length, candle_core::DType::U32, &self.device)?;
                text_token = Tensor::cat(&[&ref_start, &ref_token, &ref_end, &text_token], 0)?;
                let mask = Tensor::cat(
                    &[
                        Tensor::new(vec![0u32], &self.device)?,
                        Tensor::ones(audio_length, candle_core::DType::U32, &self.device)?,
                        Tensor::new(vec![0u32], &self.device)?,
                        Tensor::zeros(text_length, candle_core::DType::U32, &self.device)?,
                    ],
                    D::Minus1,
                )?
                .unsqueeze(0)?;
                Some(mask)
            };
            let audio_feat = audio_feat.unsqueeze(0)?;
            (Some(audio_feat), audio_mask)
        } else {
            (None, None)
        };
        let text_token = text_token.unsqueeze(0)?;
        Ok((text_token, audio_feat, audio_mask))
    }

    pub fn processor_use_cache(
        &self,
        target_text: String,
        prompt_cache: &HashMap<String, Tensor>,
        tokenizer: &SingleChineseTokenizer,
    ) -> Result<(Tensor, Option<Tensor>, Option<Tensor>)> {
        let (target_text_token, _) = tokenizer.encode_tensor(target_text, &self.device)?;
        let text_token = match prompt_cache.get("text_token") {
            Some(token) => Tensor::cat(&[token, &target_text_token], 0)?,
            None => target_text_token,
        };
        let audio_start = Tensor::new(vec![self.audio_start_token], &self.device)?;
        let mut text_token = Tensor::cat(&[text_token, audio_start], D::Minus1)?;
        let text_length = text_token.dim(0)?;
        let (audio_length, audio_feat) = match prompt_cache.get("audio_feat") {
            Some(feat) => (feat.dim(0)?, Some(feat.clone().unsqueeze(0)?)),
            None => (0, None),
        };
        let audio_mask = if audio_length > 0 {
            let text_pad_token =
                Tensor::zeros(audio_length, candle_core::DType::U32, &self.device)?;
            text_token = Tensor::cat(&[text_token, text_pad_token], D::Minus1)?;
            let mask = Tensor::cat(
                &[
                    Tensor::zeros(text_length, candle_core::DType::U32, &self.device)?,
                    Tensor::ones(audio_length, candle_core::DType::U32, &self.device)?,
                ],
                D::Minus1,
            )?
            .unsqueeze(0)?;
            Some(mask)
        } else {
            None
        };
        let text_token = text_token.unsqueeze(0)?;
        Ok((text_token, audio_feat, audio_mask))
    }
}