aha 0.2.6

aha model inference library, now supports Qwen(2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM(4/5), VoxCPM(0.5B/1.5/2), DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM(ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, LFM(2/2.5/2VL/2.5VL)
Documentation
use std::time::Instant;

use crate::{
    models::common::{
        MultiModalData,
        generate::{GenerationContext, generate_generic_text},
        modules::{AsrResult, VadFrameResult},
        sample::get_logit_processor,
    },
    params::chat::{ChatCompletionChunkResponse, ChatCompletionParameters, ChatCompletionResponse},
    utils::response_utils::{build_chunk_response_with_usage, build_completion_response_with_time},
};
use anyhow::{Result, anyhow};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use rocket::async_stream::stream;
use rocket::futures::Stream;

use crate::{
    chat_template::ChatTemplate,
    models::{
        GenerateModel,
        feature_extractor::config::FeatureExtractor,
        qwen3_asr::{
            config::{Qwen3ASRConfig, Qwen3ASRGenerationConfig},
            model::Qwen3ASRModel,
            processor::Qwen3AsrProcessor,
        },
    },
    tokenizer::TokenizerModel,
    utils::{
        find_type_files, get_device, get_dtype, response_utils::build_completion_chunk_response,
    },
};

pub struct Qwen3AsrGenerateModel<'a> {
    chat_template: ChatTemplate<'a>,
    tokenizer: TokenizerModel,
    processor: Qwen3AsrProcessor,
    model: Qwen3ASRModel,
    device: Device,
    dtype: DType,
    eos_token_id1: u32,
    eos_token_id2: u32,
    generation_config: Qwen3ASRGenerationConfig,
    model_name: String,
    default_template: String,
}

impl<'a> Qwen3AsrGenerateModel<'a> {
    pub fn init(path: &str, device: Option<&Device>, dtype: Option<DType>) -> Result<Self> {
        let chat_template = ChatTemplate::init(path)?;
        let tokenizer = TokenizerModel::init(path)?;
        let generation_config_path = path.to_string() + "/generation_config.json";
        let generation_config: Qwen3ASRGenerationConfig =
            serde_json::from_slice(&std::fs::read(generation_config_path)?)?;
        let device = get_device(device);
        let preprocess_config_path = path.to_string() + "/preprocessor_config.json";
        let preprocess_config: FeatureExtractor =
            serde_json::from_slice(&std::fs::read(preprocess_config_path)?)?;
        let processor = Qwen3AsrProcessor::new(&device, &preprocess_config)?;
        let config_path = path.to_string() + "/config.json";
        let cfg: Qwen3ASRConfig = serde_json::from_slice(&std::fs::read(config_path)?)?;
        let cfg_dtype = cfg.thinker_config.dtype.as_str();
        let dtype = get_dtype(dtype, cfg_dtype);
        let model_list = find_type_files(path, "safetensors")?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_list, dtype, &device)? };
        let model = Qwen3ASRModel::new(vb, &cfg, generation_config.eos_token_id.clone())?;
        let model_name = std::path::Path::new(path)
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("qwen3-asr")
            .to_string();
        Ok(Self {
            chat_template,
            tokenizer,
            processor,
            model,
            device,
            dtype,
            eos_token_id1: generation_config.eos_token_id[0] as u32,
            eos_token_id2: generation_config.eos_token_id[1] as u32,
            generation_config,
            model_name,
            default_template: "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n<|im_start|>assistant\n".to_string(),
        })
    }

    pub fn asr_vad_res(&mut self, vad_res: VadFrameResult) -> Result<AsrResult> {
        if !vad_res.is_speech || vad_res.orig_audio.is_none() {
            return Ok(AsrResult::init_empty());
        }
        if let Some(audio) = vad_res.orig_audio {
            self.asr_audio(&audio)
        } else {
            Ok(AsrResult::init_empty())
        }
    }

    pub fn asr_audio(&mut self, audio: &Tensor) -> Result<AsrResult> {
        let audio_data =
            self.processor
                .process_audio_tensor(&self.default_template, audio, &self.tokenizer)?;
        let input_ids = audio_data.input_ids.clone();
        let input_features = Some(audio_data.input_features.clone().to_dtype(self.dtype)?);
        let mut ctx = GenerationContext::new(
            None,
            None,
            None,
            None,
            None,
            32432,
            input_ids.dim(1)?,
            512,
            self.device.clone(),
        );
        let data_vec = vec![input_features];
        let data = MultiModalData::new(data_vec);
        let mut text =
            generate_generic_text(&mut self.model, &self.tokenizer, input_ids, data, &mut ctx)?;
        if text.contains("<asr_text>") {
            let mut split: Vec<&str> = text.split("<asr_text>").collect();
            text = split.pop().unwrap_or(&text).to_string();
        }
        Ok(AsrResult::init(text))
    }
}

impl<'a> GenerateModel for Qwen3AsrGenerateModel<'a> {
    fn generate(&mut self, mes: ChatCompletionParameters) -> Result<ChatCompletionResponse> {
        let temperature = mes
            .temperature
            .unwrap_or(self.generation_config.temperature);
        let seed = mes.seed.unwrap_or(34562) as u64;
        let mut logit_processor = get_logit_processor(Some(temperature), mes.top_p, None, seed);
        let render_text = self.chat_template.apply_chat_template(&mes)?;
        let audio_datas = self
            .processor
            .process_info(&mes, &render_text, &self.tokenizer)?;
        let sample_len = mes.max_tokens.unwrap_or(1024);
        let mut generate = Vec::new();
        let mut prompt_tokens = 0u32;
        let mut prompt_secs = 0.0f64;
        let mut completion_secs = 0.0f64;
        for data in audio_datas.iter() {
            let mut input_ids = data.input_ids.clone();
            let mut input_features = Some(data.input_features.clone().to_dtype(self.dtype)?);
            let mut seq_len = input_ids.dim(1)?;
            prompt_tokens += seq_len as u32;
            let mut seqlen_offset = 0;
            for _ in 0..sample_len {
                let i_start = Instant::now();
                let logits =
                    self.model
                        .forward(&input_ids, seqlen_offset, input_features.as_ref())?;
                let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
                let next_token = logit_processor.sample(&logits)?;
                let i_duration = i_start.elapsed();
                if seqlen_offset == 0 {
                    prompt_secs += i_duration.as_secs_f64();
                } else {
                    completion_secs += i_duration.as_secs_f64();
                };
                generate.push(next_token);
                if next_token == self.eos_token_id1 || next_token == self.eos_token_id2 {
                    break;
                }
                seqlen_offset += seq_len;
                seq_len = 1;
                input_ids = Tensor::from_vec(vec![next_token], (1, 1), &self.device)?;
                input_features = None;
            }
            self.model.clear_kv_cache();
        }
        let num_token = generate.len() as u32;
        let res = self.tokenizer.token_decode(generate)?;
        let response = build_completion_response_with_time(
            res,
            &self.model_name,
            num_token.into(),
            completion_secs.into(),
            prompt_tokens.into(),
            prompt_secs.into(),
        );
        Ok(response)
    }

    fn generate_stream(
        &mut self,
        mes: ChatCompletionParameters,
    ) -> Result<
        Box<
            dyn Stream<Item = Result<ChatCompletionChunkResponse, anyhow::Error>>
                + Send
                + Unpin
                + '_,
        >,
    > {
        let temperature = mes
            .temperature
            .unwrap_or(self.generation_config.temperature);
        let seed = mes.seed.unwrap_or(34562) as u64;
        let mut logit_processor = get_logit_processor(Some(temperature), mes.top_p, None, seed);
        let render_text = self.chat_template.apply_chat_template(&mes)?;
        let audio_datas = self
            .processor
            .process_info(&mes, &render_text, &self.tokenizer)?;
        let sample_len = mes.max_tokens.unwrap_or(1024);
        let stream = stream! {
            let mut error_tokens = Vec::new();
            let mut prompt_tokens = 0u32;
            let mut completion_tokens = 0u32;
            let mut prompt_secs = 0.0f64;
            let mut completion_secs = 0.0f64;
            for data in audio_datas.iter() {
                let mut input_ids = data.input_ids.clone();
                let mut input_features = Some(data.input_features.clone().to_dtype(self.dtype)?);
                let mut seq_len = input_ids.dim(1)?;
                prompt_tokens += seq_len as u32;
                let mut seqlen_offset = 0;
                for _ in 0..sample_len {
                    let i_start = Instant::now();
                    let logits =
                        self.model
                            .forward(&input_ids, seqlen_offset, input_features.as_ref())?;
                    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
                    let next_token = logit_processor.sample(&logits)?;
                    completion_tokens += 1;
                    let i_duration = i_start.elapsed();
                    if seqlen_offset == 0 {
                        prompt_secs += i_duration.as_secs_f64();
                    } else {
                        completion_secs += i_duration.as_secs_f64();
                    };
                    let mut decode_ids = Vec::new();
                    if !error_tokens.is_empty() {
                        decode_ids.extend_from_slice(&error_tokens);
                    }
                    decode_ids.push(next_token);
                    let decoded_token = self.tokenizer.token_decode(decode_ids).map_err(|e| anyhow!(format!("stream decode error{e}")))?;
                    if decoded_token.contains("") {
                        error_tokens.push(next_token);
                        if error_tokens.len() > 3 {
                            error_tokens.clear();
                        }
                        seqlen_offset += seq_len;
                        seq_len = 1;
                        input_ids = Tensor::from_vec(vec![next_token], (1, 1), &self.device)?;
                        input_features = None;
                        continue;
                    }
                    error_tokens.clear();
                    let chunk = build_completion_chunk_response(decoded_token, &self.model_name, None, None);
                    yield Ok(chunk);
                    if next_token == self.eos_token_id1 || next_token == self.eos_token_id2 {
                        yield Ok(build_chunk_response_with_usage(&self.model_name, completion_tokens.into(), completion_secs.into(), prompt_tokens.into(), prompt_secs.into()));
                        break;
                    }
                    seqlen_offset += seq_len;
                    seq_len = 1;
                    input_ids = Tensor::from_vec(vec![next_token], (1, 1), &self.device)?;
                    input_features = None;
                }
                self.model.clear_kv_cache();
            }
        };
        Ok(Box::new(Box::pin(stream)))
    }
}