//! aha 0.2.5
//!
//! aha model inference library, now supports Qwen(2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker),
//! MiniCPM4, VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0,
//! GLM(ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, LFM(2/2.5/2VL/2.5VL).
use crate::{
    models::common::embedding::{NormalizeType, TextEmbedding},
    tokenizer::TokenizerModel,
    utils::{find_type_files, get_device, get_dtype},
};
use anyhow::{Result, anyhow};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};

/// Sentence-embedding model wrapping the all-MiniLM-L6-v2 BERT encoder.
///
/// Produces L2-normalized, mean-pooled embeddings via the [`TextEmbedding`]
/// trait. Construct with [`AllMiniLML6V2Embedding::init`].
pub struct AllMiniLML6V2Embedding {
    tokenizer: TokenizerModel, // tokenizer loaded from the model directory
    model: BertModel,          // candle BERT encoder
    device: Device,            // device all input/output tensors live on
    normalize: NormalizeType,  // always `NormalizeType::L2` (set in `init`)
}

impl AllMiniLML6V2Embedding {
    /// Loads the all-MiniLM-L6-v2 model from the directory at `path`.
    ///
    /// The directory must contain the tokenizer files, `config.json`, and
    /// one or more `.safetensors` weight files. `device` defaults via
    /// `get_device` when `None`; `dtype` falls back to `float32`.
    ///
    /// # Errors
    /// Returns an error if the tokenizer, config, or weights fail to load.
    pub fn init(path: &str, device: Option<&Device>, dtype: Option<DType>) -> Result<Self> {
        let tokenizer = TokenizerModel::init(path)?;
        // Build the config path with Path::join rather than string concatenation.
        let config_path = std::path::Path::new(path).join("config.json");
        let cfg: BertConfig = serde_json::from_slice(&std::fs::read(config_path)?)?;
        let device = get_device(device);
        let dtype = get_dtype(dtype, "float32");
        let model_list = find_type_files(path, "safetensors")?;
        // SAFETY: the safetensors files are memory-mapped; per the candle API
        // contract they must not be modified while the model is alive.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_list, dtype, &device)? };
        let model = BertModel::load(vb, &cfg)?;
        Ok(Self {
            tokenizer,
            model,
            device,
            normalize: NormalizeType::L2,
        })
    }

    /// Tokenizes `text` (with special tokens) and drops id-0 tokens.
    ///
    /// NOTE(review): id 0 is presumably the tokenizer's padding token —
    /// confirm against the model's tokenizer config.
    ///
    /// # Errors
    /// Fails if tokenization fails or nothing remains after filtering.
    fn prepare_token_ids(&self, text: &str) -> Result<Vec<u32>> {
        let mut token_ids = self.tokenizer.text_encode_vec(text.to_string(), true)?;
        // Filter in place; `retain` avoids reallocating a second vector.
        token_ids.retain(|&x| x != 0);
        if token_ids.is_empty() {
            return Err(anyhow!("embedding tokenized input cannot be empty"));
        }
        Ok(token_ids)
    }

    /// Embeds a single text: BERT forward pass, mean pooling over the
    /// sequence dimension (dim 1), then normalization per `self.normalize`.
    ///
    /// The attention mask is all ones because the single-sequence batch
    /// contains no padding.
    fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
        let token_ids = self.prepare_token_ids(text)?;
        let seq_len = token_ids.len();
        let input_ids = Tensor::from_slice(&token_ids, (1, seq_len), &self.device)?;
        let token_type_ids = Tensor::zeros((1, seq_len), DType::U32, &self.device)?;
        let attention_mask = Tensor::ones((1, seq_len), DType::U32, &self.device)?;
        let hidden = self
            .model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask))?
            .to_dtype(DType::F32)?;
        // Mean-pool across the sequence dimension: (1, seq, dim) -> (1, dim).
        let hidden = hidden.mean(1)?;
        let embed = self
            .normalize
            .normalize(&hidden, hidden.rank() - 1)?
            .squeeze(0)?;
        let embed = embed.to_vec1::<f32>()?;
        Ok(embed)
    }
}

impl TextEmbedding for AllMiniLML6V2Embedding {
    /// Embeds each input text independently, failing fast on the first
    /// error. An empty input slice is rejected up front.
    fn embed_texts(&mut self, input: &[String]) -> Result<Vec<Vec<f32>>> {
        if input.is_empty() {
            return Err(anyhow!("embedding input cannot be empty"));
        }
        // Collect into Result<Vec<_>> to short-circuit on the first failure.
        input.iter().map(|text| self.embed_one(text)).collect()
    }
}