//! # aha 0.2.5
//!
//! `aha` is a model inference library. Currently supported models:
//! Qwen (2.5VL / 3 / 3VL / 3.5 / ASR / 3Embedding / 3Reranker), MiniCPM4,
//! VoxCPM / 1.5, DeepSeek-OCR / 2, Hunyuan-OCR, PaddleOCR-VL / 1.5, RMBG2.0,
//! GLM (ASR-Nano-2512 / OCR), Fun-ASR-Nano-2512, and LFM (2 / 2.5 / 2VL / 2.5VL).
//!
//! ## Documentation
use anyhow::Result;
use candle_core::Tensor;
pub mod embedding;
pub mod generate;
pub mod gguf;
pub mod model_mapping;
pub mod modules;
pub mod reranker;

/// Feature data for multimodal models.
///
/// The contents differ from model to model; entries must be stored and
/// consumed in a fixed, model-specific order.
#[derive(Clone, Debug)]
pub struct MultiModalData {
    // Ordered slots of optional tensors; `None` presumably marks an absent
    // modality entry — confirm against each model's implementation.
    pub data_vec: Vec<Option<Tensor>>,
}
impl MultiModalData {
    pub fn new(data_vec: Vec<Option<Tensor>>) -> Self {
        Self { data_vec }
    }
}

#[allow(unused)]
pub trait InferenceModel {
    /// Initial forward pass, taking multimodal inputs into account.
    ///
    /// The default implementation has no special multimodal handling: it
    /// ignores `data` and simply delegates to [`InferenceModel::forward_step`].
    fn forward_initial(
        &mut self,
        input_ids: &Tensor,
        seqlen_offset: usize,
        data: MultiModalData,
    ) -> Result<Tensor> {
        self.forward_step(input_ids, seqlen_offset)
    }

    /// Subsequent forward pass (one autoregressive decoding step).
    fn forward_step(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor>;

    /// Clears the KV cache.
    fn clear_cache(&mut self);

    /// Returns the end-of-generation (stop) token IDs.
    fn stop_token_ids(&self) -> Vec<u32>;
}