aha 0.2.5

aha is a model inference library. It currently supports Qwen (2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM4, VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM (ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM (2/2.5/2VL/2.5VL).
Documentation
use std::collections::HashMap;

use crate::{
    models::common::{
        MultiModalData,
        generate::{GenerationContext, generate_generic, generate_stream_generic},
    },
    params::chat::{ChatCompletionChunkResponse, ChatCompletionParameters, ChatCompletionResponse},
};
use anyhow::{Result, anyhow};
use candle_core::{DType, Device, pickle::read_all_with_key};
use candle_nn::VarBuilder;
use rocket::futures::Stream;

use crate::{
    models::{
        GenerateModel,
        fun_asr_nano::{
            config::FunASRNanoConfig, model::FunAsrNanoModel, processor::FunAsrNanoProcessor,
        },
        qwen3::config::{Qwen3Config, Qwen3GenerationConfig},
    },
    tokenizer::TokenizerModel,
    utils::{find_type_files, get_device, get_dtype},
};

/// Fun-ASR-Nano model bundled with everything needed to serve chat-completion
/// style requests: tokenizer, audio pre-processor, weights, and default
/// sampling settings.
///
/// Instances are built with `FunAsrNanoGenerateModel::init` and driven through
/// the `GenerateModel` trait.
pub struct FunAsrNanoGenerateModel {
    /// Tokenizer loaded from the `Qwen3-0.6B` sub-directory of the model path.
    tokenizer: TokenizerModel,
    /// Turns an incoming request into `(speech, fbank_mask, input_ids)`.
    processor: FunAsrNanoProcessor,
    /// The underlying network whose weights come from the checkpoint files.
    fun_asr_nano: FunAsrNanoModel,
    /// Device the model was built on; also handed to the generation helpers.
    device: Device,
    /// Dtype used to build the weights; speech features are cast to it before inference.
    dtype: DType,
    /// Defaults for `temperature`, `top_p` and `top_k`, read from
    /// `Qwen3-0.6B/generation_config.json`.
    generation_config: Qwen3GenerationConfig,
    /// Last component of the model path, or `"fun-asr-nano"` when that is unavailable.
    model_name: String,
}

impl FunAsrNanoGenerateModel {
    /// Loads a Fun-ASR-Nano model from the directory `path`.
    ///
    /// The following files are read, all relative to `path`:
    /// - `Qwen3-0.6B/` is handed to `TokenizerModel::init`;
    /// - `Qwen3-0.6B/generation_config.json` is parsed as `Qwen3GenerationConfig`;
    /// - `Qwen3-0.6B/config.json` is parsed as `Qwen3Config`;
    /// - `config.yaml` is parsed as `FunASRNanoConfig`;
    /// - every file returned by `find_type_files(path, "pt")` is loaded as a
    ///   pickle checkpoint and merged into a single tensor map.
    ///
    /// `device` and `dtype` are forwarded to `get_device` / `get_dtype`; the
    /// latter also receives `llm_conf.llm_dtype` from `config.yaml`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the files above cannot be read or parsed
    /// (the message names the offending file), if a checkpoint can be read
    /// neither under the `"state_dict"` key nor at the top level, or if the
    /// tokenizer, processor or model construction fails.
    pub fn init(path: &str, device: Option<&Device>, dtype: Option<DType>) -> Result<Self> {
        // Attach the file name to I/O failures; a bare "No such file or
        // directory" does not tell the user which of the inputs is missing.
        let read_file = |file: &str| -> Result<Vec<u8>> {
            std::fs::read(file).map_err(|e| anyhow!("failed to read {}: {}", file, e))
        };

        let llm_dir = format!("{}/Qwen3-0.6B", path);
        let tokenizer = TokenizerModel::init(&llm_dir)?;

        let generation_config_path = format!("{}/generation_config.json", llm_dir);
        let generation_config: Qwen3GenerationConfig =
            serde_json::from_slice(&read_file(&generation_config_path)?)
                .map_err(|e| anyhow!("failed to parse {}: {}", generation_config_path, e))?;

        let llm_config_path = format!("{}/config.json", llm_dir);
        let llm_cfg: Qwen3Config = serde_json::from_slice(&read_file(&llm_config_path)?)
            .map_err(|e| anyhow!("failed to parse {}: {}", llm_config_path, e))?;

        let device = get_device(device);

        let model_config_path = format!("{}/config.yaml", path);
        let cfg: FunASRNanoConfig = serde_yaml::from_slice(&read_file(&model_config_path)?)
            .map_err(|e| anyhow!("failed to parse {}: {}", model_config_path, e))?;

        let dtype = get_dtype(dtype, cfg.llm_conf.llm_dtype.as_str());
        let processor = FunAsrNanoProcessor::new(&cfg.frontend_conf, &device)?;

        // Merge all checkpoints into one map; when the same tensor name occurs
        // more than once, the entry loaded last wins.
        let mut tensors = HashMap::new();
        for file in find_type_files(path, "pt")? {
            // Try the tensors nested under "state_dict" first, then fall back
            // to reading them from the top level of the pickle.
            let entries = match read_all_with_key(&file, Some("state_dict")) {
                Ok(entries) => entries,
                Err(e) => {
                    println!(
                        "model read_all_with_key {} get state_dict err: {}, use None try again",
                        &file, e
                    );
                    read_all_with_key(&file, None).map_err(|e| {
                        anyhow!("model read_all_with_key({}, None): e: {}", &file, e)
                    })?
                }
            };
            tensors.extend(entries);
        }

        let vb = VarBuilder::from_tensors(tensors, dtype, &device);
        let fun_asr_nano =
            FunAsrNanoModel::new(vb, &cfg, &llm_cfg, generation_config.eos_token_id.clone())?;

        // Report the directory name as the model name; fall back to a fixed
        // label when the path has no usable final component.
        let model_name = std::path::Path::new(path)
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("fun-asr-nano")
            .to_string();

        Ok(Self {
            tokenizer,
            processor,
            fun_asr_nano,
            device,
            dtype,
            generation_config,
            model_name,
        })
    }
}

impl GenerateModel for FunAsrNanoGenerateModel {
    /// Runs a complete (non-streaming) generation for the request `mes`.
    ///
    /// `temperature` and `top_p` come from the request when present and from
    /// the stored generation config otherwise; `top_k` always comes from the
    /// generation config. The seed defaults to `34562` and the token budget to
    /// `1024`.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be pre-processed, if the speech tensor
    /// cannot be cast to the model dtype, if the prompt length cannot be read
    /// from `input_ids`, or if generation itself fails.
    fn generate(&mut self, mes: ChatCompletionParameters) -> Result<ChatCompletionResponse> {
        // Resolve sampling parameters: request value first, config default second.
        let defaults = &self.generation_config;
        let temperature = mes.temperature.unwrap_or(defaults.temperature);
        let top_p = mes.top_p.unwrap_or(defaults.top_p);
        let top_k = defaults.top_k;
        let seed = mes.seed.map_or(34562, |s| s as u64);
        let max_tokens = mes.max_tokens.unwrap_or(1024);

        // Turn the request into model inputs and match the speech features to
        // the dtype the weights were loaded with.
        let (speech, fbank_mask, input_ids) = self.processor.process_info(&mes, &self.tokenizer)?;
        let speech = speech.to_dtype(self.dtype)?;

        // Second dimension of `input_ids` is passed on as the prompt length.
        let prompt_len = input_ids.dim(1)?;
        let mut ctx = GenerationContext::new(
            temperature.into(),
            top_p.into(),
            top_k.into(),
            mes.repeat_penalty,
            mes.repeat_last_n,
            seed,
            prompt_len,
            max_tokens,
            self.device.clone(),
        );

        let data = MultiModalData::new(vec![speech.into(), fbank_mask.into()]);
        generate_generic(
            &mut self.fun_asr_nano,
            &self.tokenizer,
            input_ids,
            data,
            &mut ctx,
            &self.model_name,
        )
    }

    /// Streaming counterpart of `generate`: returns a stream of response chunks.
    ///
    /// Sampling parameters are resolved exactly as in `generate`. The returned
    /// stream borrows from `self` for as long as it is alive.
    ///
    /// # Errors
    ///
    /// Fails up front if the request cannot be pre-processed, if the speech
    /// tensor cannot be cast to the model dtype, or if the stream cannot be
    /// created; each stream item carries its own `Result`.
    fn generate_stream(
        &mut self,
        mes: ChatCompletionParameters,
    ) -> Result<
        Box<
            dyn Stream<Item = Result<ChatCompletionChunkResponse, anyhow::Error>>
                + Send
                + Unpin
                + '_,
        >,
    > {
        // Resolve sampling parameters: request value first, config default second.
        let defaults = &self.generation_config;
        let temperature = mes.temperature.unwrap_or(defaults.temperature);
        let top_p = mes.top_p.unwrap_or(defaults.top_p);
        let top_k = defaults.top_k;
        let seed = mes.seed.map_or(34562, |s| s as u64);
        let max_tokens = mes.max_tokens.unwrap_or(1024);

        // Turn the request into model inputs and match the speech features to
        // the dtype the weights were loaded with.
        let (speech, fbank_mask, input_ids) = self.processor.process_info(&mes, &self.tokenizer)?;
        let speech = speech.to_dtype(self.dtype)?;
        let data = MultiModalData::new(vec![speech.into(), fbank_mask.into()]);

        let chunks = generate_stream_generic(
            &mut self.fun_asr_nano,
            &self.tokenizer,
            input_ids,
            data,
            temperature.into(),
            top_p.into(),
            top_k.into(),
            mes.repeat_penalty,
            mes.repeat_last_n,
            seed,
            max_tokens,
            // NOTE(review): meaning of this flag is defined by
            // `generate_stream_generic` — confirm against its signature.
            false,
            &self.device,
            &self.model_name,
        )?;

        // Pin so the boxed stream satisfies the `Unpin` bound in the return type.
        Ok(Box::new(Box::pin(chunks)))
    }
}