aha 0.2.5

aha is a model inference library. It currently supports Qwen (2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM4, VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM (ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM (2/2.5/2VL/2.5VL).
Documentation
use crate::{
    models::{
        common::embedding::{NormalizeType, TextEmbedding},
        qwen3::{config::Qwen3Config, model::Qwen3Model},
    },
    tokenizer::TokenizerModel,
    utils::{find_type_files, get_device, get_dtype},
};
use anyhow::{Result, anyhow};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;

/// Text-embedding frontend over a Qwen3 decoder model.
///
/// Wraps a tokenizer and a [`Qwen3Model`] and exposes text → embedding
/// vectors, normalizing the final hidden state (L2 by default, set in
/// [`Qwen3Embedding::init`]).
pub struct Qwen3Embedding {
    /// Tokenizer used to turn raw text into input ids.
    tokenizer: TokenizerModel,
    /// Underlying Qwen3 decoder; its KV cache is cleared after every embed.
    model: Qwen3Model,
    /// Device the model weights and inputs live on.
    device: Device,
    /// Normalization applied to the final hidden state (L2 by default).
    normalize: NormalizeType,
}

impl Qwen3Embedding {
    /// Loads tokenizer, config, and safetensor weights from `path`.
    ///
    /// `device`/`dtype` override auto-detection when provided; otherwise the
    /// device is picked by `get_device` and the dtype falls back to the
    /// config's `torch_dtype`. Normalization defaults to L2.
    ///
    /// # Errors
    /// Returns an error if the tokenizer, `config.json`, or weight files
    /// cannot be read or parsed, or if model construction fails.
    pub fn init(path: &str, device: Option<&Device>, dtype: Option<DType>) -> Result<Self> {
        let tokenizer = TokenizerModel::init(path)?;
        // Join paths properly instead of string concatenation.
        let config_path = std::path::Path::new(path).join("config.json");
        let cfg: Qwen3Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
        let device = get_device(device);
        let dtype = get_dtype(dtype, cfg.torch_dtype.as_str());
        let model_list = find_type_files(path, "safetensors")?;
        // SAFETY: mmap of the safetensor files; the files must not be
        // mutated while the VarBuilder (and model) are alive.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_list, dtype, &device)? };
        let model = Qwen3Model::new(&cfg, vb, vec![])?;
        Ok(Self {
            tokenizer,
            model,
            device,
            normalize: NormalizeType::L2,
        })
    }

    /// Embeds each text in `input` and stacks the vectors into one tensor
    /// of shape `(input.len(), dim)`.
    ///
    /// # Errors
    /// Returns an error for an empty input slice, or if any single
    /// embedding fails (in which case earlier results are discarded).
    pub fn embed_multi(&mut self, input: &[String]) -> Result<Tensor> {
        if input.is_empty() {
            return Err(anyhow!("embedding input cannot be empty"));
        }
        let mut out = Vec::with_capacity(input.len());
        for text in input {
            out.push(self.embed_one(text)?);
        }
        let out = Tensor::stack(&out, 0)?;
        Ok(out)
    }

    /// Embeds a single text: tokenize, forward through the model, take the
    /// final hidden state as F32, and normalize it.
    ///
    /// # Errors
    /// Propagates tokenizer, forward-pass, and tensor-op failures.
    pub fn embed_one(&mut self, text: &str) -> Result<Tensor> {
        let input_ids = self.tokenizer.text_encode(text.to_string(), &self.device)?;
        // Run the forward pass, but clear the KV cache BEFORE propagating
        // any error: previously a failed forward/squeeze/to_dtype left the
        // cache populated and poisoned the next embed_one call.
        let forward = self.model.forward_hidden(Some(&input_ids), None, 0);
        self.model.clear_kv_cache();
        let hidden = forward?.squeeze(0)?.to_dtype(DType::F32)?;
        // Normalize over the last dimension. NOTE(review): assumes the
        // squeezed hidden state has rank >= 1 — `rank() - 1` would
        // underflow on a scalar; confirm forward_hidden's output shape.
        let norm = self
            .normalize
            .normalize(&hidden, hidden.rank() - 1)?
            .squeeze(0)?;
        Ok(norm)
    }
}

impl TextEmbedding for Qwen3Embedding {
    /// Embeds a batch of texts and returns them as plain `f32` row vectors,
    /// one per input string.
    fn embed_texts(&mut self, input: &[String]) -> Result<Vec<Vec<f32>>> {
        // Stack the embeddings, then lower the 2-D tensor to nested Vecs.
        Ok(self.embed_multi(input)?.to_vec2::<f32>()?)
    }
}