// aha 0.2.5
//
// aha model inference library, now supports Qwen(2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM4, VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM(ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, LFM(2/2.5/2VL/2.5VL)
// Documentation
use anyhow::Result;
use candle_core::{IndexOp, Tensor};
use candle_nn::{Conv1d, LayerNorm, Linear, Module, VarBuilder, linear, linear_no_bias};

use crate::{
    models::{
        common::{
            InferenceModel,
            modules::{
                LlamaForCausalLM, TwoLinearMLP, eager_attention_forward, get_conv1d, get_layer_norm,
            },
        },
        glm_asr_nano::config::{GlmAsrAudioConfig, GlmAsrNanoConfig},
    },
    position_embed::rope::{RoPE, glm_asr_apply_rotary_pos_emb},
    utils::tensor_utils::{get_equal_mask, masked_scatter_dim0},
};

/// Self-attention block used by [`GlmAsrEncoderLayer`].
///
/// Stores the four projection layers together with the head geometry that
/// `forward` needs to reshape activations into per-head tensors.
#[derive(Debug, Clone)]
pub struct GlmAsrAttention {
    /// Query projection, loaded from the `q_proj` prefix (with bias).
    q_proj: Linear,
    /// Key projection, loaded from the `k_proj` prefix (the only one built without bias).
    k_proj: Linear,
    /// Value projection, loaded from the `v_proj` prefix (with bias).
    v_proj: Linear,
    /// Output projection, loaded from the `o_proj` prefix (with bias).
    o_proj: Linear,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    /// `num_heads / num_kv_heads` (integer division), handed to the attention kernel.
    num_kv_groups: usize,
    /// Width of a single head.
    head_dim: usize,
    /// `num_heads * head_dim`: last-dimension size of the attention output before `o_proj`.
    middle_size: usize,
}

impl GlmAsrAttention {
    /// Creates the attention block, loading the `q_proj`, `k_proj`, `v_proj`
    /// and `o_proj` weights from the matching sub-prefixes of `vb`.
    ///
    /// When `head_dim` is `None` it falls back to
    /// `hidden_size / num_attention_heads`. `k_proj` is created without a bias;
    /// the other three projections carry one.
    ///
    /// # Errors
    /// Returns an error if any of the projection weights cannot be loaded.
    ///
    /// # Panics
    /// Panics on integer division by zero if `num_key_value_heads` is `0`, or if
    /// `num_attention_heads` is `0` while `head_dim` is `None`.
    pub fn new(
        vb: VarBuilder,
        hidden_size: usize,
        num_attention_heads: usize,
        num_key_value_heads: usize,
        head_dim: Option<usize>,
    ) -> Result<Self> {
        let num_kv_groups = num_attention_heads / num_key_value_heads;
        let head_dim = head_dim.unwrap_or_else(|| hidden_size / num_attention_heads);
        let middle_size = num_attention_heads * head_dim;

        // Struct fields are evaluated top to bottom, so the weights are still
        // loaded in q -> k -> v -> o order.
        Ok(Self {
            q_proj: linear(hidden_size, middle_size, vb.pp("q_proj"))?,
            k_proj: linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("k_proj"))?,
            v_proj: linear(hidden_size, num_key_value_heads * head_dim, vb.pp("v_proj"))?,
            o_proj: linear(middle_size, hidden_size, vb.pp("o_proj"))?,
            num_heads: num_attention_heads,
            num_kv_heads: num_key_value_heads,
            num_kv_groups,
            head_dim,
            middle_size,
        })
    }

    /// Runs self-attention over `xs`, which must be a rank-3 tensor
    /// (its first two dimensions are treated as batch and sequence length).
    ///
    /// Rotary embeddings are applied to the query/key tensors only when both
    /// `cos` and `sin` are provided; `tof32` is forwarded unchanged to
    /// `glm_asr_apply_rotary_pos_emb`. Attention scores are scaled by
    /// `1 / sqrt(head_dim)` and `attention_mask` is passed straight through to
    /// the attention kernel.
    ///
    /// # Errors
    /// Returns an error if `xs` is not rank 3 or any tensor operation fails.
    pub fn forward(
        &self,
        xs: &Tensor,
        cos: Option<&Tensor>,
        sin: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        tof32: bool,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let q = self.q_proj.forward(xs)?;
        let k = self.k_proj.forward(xs)?;
        let v = self.v_proj.forward(xs)?;

        // (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        let split_heads = |t: &Tensor, n_heads: usize| {
            t.reshape((b_sz, q_len, n_heads, self.head_dim))
                .and_then(|t| t.transpose(1, 2))
        };
        let q = split_heads(&q, self.num_heads)?;
        let k = split_heads(&k, self.num_kv_heads)?;
        let v = split_heads(&v, self.num_kv_heads)?;

        // Rotary position embedding needs both tables; otherwise q/k pass through.
        let (q, k) = match (cos, sin) {
            (Some(cos), Some(sin)) => glm_asr_apply_rotary_pos_emb(&q, &k, cos, sin, tof32)?,
            _ => (q, k),
        };

        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let context = eager_attention_forward(
            &q,
            &k,
            &v,
            Some(self.num_kv_groups),
            attention_mask,
            scale,
        )?;

        // Merge the heads back and project to the hidden size.
        let context = context.reshape((b_sz, q_len, self.middle_size))?;
        Ok(self.o_proj.forward(&context)?)
    }
}

/// One layer of the GLM-ASR audio encoder: a self-attention sub-block and an
/// MLP sub-block, each preceded by its own layer norm and wrapped in a
/// residual connection.
pub struct GlmAsrEncoderLayer {
    /// Self-attention sub-block (weights under the `self_attn` prefix).
    self_attn: GlmAsrAttention,
    /// Two-layer MLP (weights under `mlp`, with inner layers `fc1` / `fc2`).
    mlp: TwoLinearMLP,
    /// Layer norm applied to the input before self-attention.
    input_layernorm: LayerNorm,
    /// Layer norm applied to the attention residual sum before the MLP.
    post_attention_layernorm: LayerNorm,
}

impl GlmAsrEncoderLayer {
    /// Builds one encoder layer from `audio_cfg`, loading weights from the
    /// `self_attn`, `mlp`, `input_layernorm` and `post_attention_layernorm`
    /// sub-prefixes of `vb`.
    ///
    /// # Errors
    /// Returns an error if any sub-module fails to load its weights.
    pub fn new(vb: VarBuilder, audio_cfg: &GlmAsrAudioConfig) -> Result<Self> {
        let hidden_size = audio_cfg.hidden_size;

        // Fields are initialised in declaration order, matching the order in
        // which the weights are requested from the `VarBuilder`.
        Ok(Self {
            self_attn: GlmAsrAttention::new(
                vb.pp("self_attn"),
                hidden_size,
                audio_cfg.num_attention_heads,
                audio_cfg.num_key_value_heads,
                Some(audio_cfg.head_dim),
            )?,
            mlp: TwoLinearMLP::new(
                vb.pp("mlp"),
                hidden_size,
                audio_cfg.intermediate_size,
                hidden_size,
                audio_cfg.hidden_act,
                true,
                "fc1",
                "fc2",
            )?,
            input_layernorm: get_layer_norm(vb.pp("input_layernorm"), 1e-5, hidden_size, true)?,
            post_attention_layernorm: get_layer_norm(
                vb.pp("post_attention_layernorm"),
                1e-5,
                hidden_size,
                true,
            )?,
        })
    }

    /// Applies the layer to `xs`.
    ///
    /// Computes `h = xs + self_attn(input_layernorm(xs))` followed by
    /// `h + mlp(post_attention_layernorm(h))`. `cos`, `sin`, `attention_mask`
    /// and `tof32` are forwarded unchanged to the attention block.
    ///
    /// # Errors
    /// Returns an error if any tensor operation fails.
    pub fn forward(
        &self,
        xs: &Tensor,
        cos: Option<&Tensor>,
        sin: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        tof32: bool,
    ) -> Result<Tensor> {
        // Attention sub-block with residual connection.
        let normed = xs.apply(&self.input_layernorm)?;
        let attn_out = self
            .self_attn
            .forward(&normed, cos, sin, attention_mask, tof32)?;
        let hidden = xs.add(&attn_out)?;

        // MLP sub-block with residual connection.
        let normed = hidden.apply(&self.post_attention_layernorm)?;
        let mlp_out = self.mlp.forward(&normed)?;
        Ok(hidden.add(&mlp_out)?)
    }
}

/// Audio encoder (the `audio_tower` of [`GlmAsrNanoModel`]): two 1-D
/// convolutions, a stack of [`GlmAsrEncoderLayer`]s and a final layer norm.
pub struct GlmAsrEncoder {
    /// First convolution: `num_mel_bins` -> `hidden_size` channels.
    conv1: Conv1d,
    /// Second convolution: `hidden_size` -> `hidden_size` channels.
    conv2: Conv1d,
    /// Encoder layers, applied in order.
    layers: Vec<GlmAsrEncoderLayer>,
    /// Layer norm applied after the last encoder layer.
    norm: LayerNorm,
    /// Produces the rotary cos/sin tables passed to every layer.
    rotary_emb: RoPE,
}

impl GlmAsrEncoder {
    /// Builds the encoder from `audio_cfg`, loading weights from the `conv1`,
    /// `conv2`, `layers.{i}` and `norm` sub-prefixes of `vb`.
    ///
    /// The rotary embedding dimension is
    /// `head_dim * partial_rotary_factor`, truncated to an integer.
    ///
    /// # Errors
    /// Returns an error if any weight cannot be loaded or the rotary embedding
    /// cannot be created.
    pub fn new(vb: VarBuilder, audio_cfg: &GlmAsrAudioConfig) -> Result<Self> {
        let hidden_size = audio_cfg.hidden_size;

        let conv1 = get_conv1d(
            vb.pp("conv1"),
            audio_cfg.num_mel_bins,
            hidden_size,
            3,
            1,
            1,
            1,
            1,
            true,
        )?;
        let conv2 = get_conv1d(vb.pp("conv2"), hidden_size, hidden_size, 3, 1, 2, 1, 1, true)?;

        // Stops at the first layer that fails to load.
        let vb_layers = vb.pp("layers");
        let layers = (0..audio_cfg.num_hidden_layers)
            .map(|idx| GlmAsrEncoderLayer::new(vb_layers.pp(idx), audio_cfg))
            .collect::<Result<Vec<_>>>()?;

        let norm = get_layer_norm(vb.pp("norm"), 1e-5, hidden_size, true)?;

        let rotary_dim = (audio_cfg.head_dim as f64 * audio_cfg.partial_rotary_factor) as usize;
        let rotary_emb = RoPE::new(
            rotary_dim,
            audio_cfg.rope_parameters.rope_theta,
            vb.device(),
        )?;

        Ok(Self {
            conv1,
            conv2,
            layers,
            norm,
            rotary_emb,
        })
    }

    /// Encodes `xs` into hidden states.
    ///
    /// Applies `conv1` and `conv2` (each followed by GELU), swaps dimensions 1
    /// and 2, runs every encoder layer with rotary tables generated for
    /// positions starting at 0 and no attention mask, then applies the final
    /// layer norm.
    ///
    /// # Errors
    /// Returns an error if any tensor operation fails.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden = xs
            .apply(&self.conv1)?
            .gelu()?
            .apply(&self.conv2)?
            .gelu()?
            .transpose(1, 2)?;

        let (_, seq_len, _) = hidden.dims3()?;
        let (cos, sin) = self.rotary_emb.forward(0, seq_len, hidden.device())?;

        // Thread the hidden state through the layers, bailing out on the first error.
        let hidden = self.layers.iter().try_fold(hidden, |state, layer| {
            layer.forward(&state, Some(&cos), Some(&sin), None, false)
        })?;

        Ok(self.norm.forward(&hidden)?)
    }
}

/// GLM-ASR-Nano model: an audio encoder whose projected outputs are written
/// into the token embeddings of a Llama-style causal language model.
pub struct GlmAsrNanoModel {
    /// Full model configuration (audio, text and projector settings).
    config: GlmAsrNanoConfig,
    /// Audio encoder (weights under the `audio_tower` prefix).
    audio_tower: GlmAsrEncoder,
    /// Maps audio features to the text hidden size (weights under `multi_modal_projector`).
    multi_modal_projector: TwoLinearMLP,
    /// Causal language model that produces the logits (weights under `language_model`).
    language_model: LlamaForCausalLM,
    /// Token ids reported through `InferenceModel::stop_token_ids`.
    stop_token_ids: Vec<u32>,
}

impl GlmAsrNanoModel {
    /// Builds the full model, loading weights from the `audio_tower`,
    /// `multi_modal_projector` and `language_model` sub-prefixes of `vb`.
    ///
    /// `eos_ids` is stored verbatim and later returned as the stop-token list.
    ///
    /// # Errors
    /// Returns an error if any sub-module fails to load.
    pub fn new(vb: VarBuilder, config: GlmAsrNanoConfig, eos_ids: Vec<u32>) -> Result<Self> {
        let audio_cfg = &config.audio_config;
        let text_cfg = &config.text_config;

        let audio_tower = GlmAsrEncoder::new(vb.pp("audio_tower"), audio_cfg)?;

        // Projector input width is the audio `intermediate_size`, matching the
        // reshape done in `get_audio_features`.
        let multi_modal_projector = TwoLinearMLP::new(
            vb.pp("multi_modal_projector"),
            audio_cfg.intermediate_size,
            text_cfg.hidden_size * 2,
            text_cfg.hidden_size,
            config.projector_hidden_act,
            true,
            "linear_1",
            "linear_2",
        )?;

        let language_model = LlamaForCausalLM::new(
            vb.pp("language_model"),
            text_cfg.vocab_size,
            text_cfg.hidden_size,
            text_cfg.num_hidden_layers,
            text_cfg.num_attention_heads,
            Some(text_cfg.num_key_value_heads),
            Some(text_cfg.head_dim),
            text_cfg.attention_bias,
            "self_attn",
            Some("o_proj"),
            text_cfg.intermediate_size,
            text_cfg.hidden_act,
            text_cfg.mlp_bias,
            "mlp",
            text_cfg.rms_norm_eps,
            "input_layernorm",
            "post_attention_layernorm",
            text_cfg.rope_parameters.rope_theta,
        )?;

        Ok(Self {
            config,
            audio_tower,
            multi_modal_projector,
            language_model,
            stop_token_ids: eos_ids,
        })
    }

    /// Encodes `input_features` and returns the audio embeddings to be
    /// scattered into the text embeddings.
    ///
    /// The encoder output is reshaped so its last dimension equals the audio
    /// `intermediate_size`, then projected. For batch entry `i` only the first
    /// `audio_token_lengths[i]` rows are kept; the kept rows of all entries are
    /// concatenated along dimension 0.
    ///
    /// # Errors
    /// Returns an error if any tensor operation fails, including when a length
    /// or batch index is out of range or `audio_token_lengths` is empty
    /// (nothing to concatenate).
    pub fn get_audio_features(
        &self,
        input_features: &Tensor,
        audio_token_lengths: &[u32],
    ) -> Result<Tensor> {
        let encoded = self.audio_tower.forward(input_features)?;
        let batch = encoded.dim(0)?;
        let encoded = encoded.reshape((batch, (), self.config.audio_config.intermediate_size))?;
        let projected = self.multi_modal_projector.forward(&encoded)?;

        // Keep only the valid prefix of every batch entry.
        let valid_rows = audio_token_lengths
            .iter()
            .enumerate()
            .map(|(batch_idx, &n_tokens)| projected.i((batch_idx, 0..n_tokens as usize, ..)))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Tensor::cat(&valid_rows, 0)?)
    }

    /// Computes logits for `input_ids`.
    ///
    /// When both `input_features` and `audio_token_lengths` are given, the
    /// positions of `input_ids` equal to `config.audio_token_id` have their
    /// embeddings replaced by the audio embeddings. If either is `None`, the
    /// plain token embeddings are used. `audio_token_lengths` is read as a 1-D
    /// `u32` tensor.
    ///
    /// # Errors
    /// Returns an error if any tensor operation fails.
    pub fn forward(
        &mut self,
        input_features: Option<&Tensor>,
        audio_token_lengths: Option<&Tensor>,
        input_ids: &Tensor,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let token_embeds = self.language_model.model.embed_tokens.forward(input_ids)?;

        let inputs_embeds = match (input_features, audio_token_lengths) {
            (Some(features), Some(lengths)) => {
                let lengths = lengths.to_vec1::<u32>()?;
                let audio_mask = get_equal_mask(input_ids, self.config.audio_token_id)?;
                let audio_embeds = self.get_audio_features(features, &lengths)?;
                masked_scatter_dim0(&token_embeds, &audio_embeds, &audio_mask)?
            }
            _ => token_embeds,
        };

        let logits = self.language_model.forward(&inputs_embeds, seqlen_offset)?;
        Ok(logits)
    }

    /// Clears the language model's key/value cache.
    pub fn clear_kv_cache(&mut self) {
        self.language_model.clear_kv_cache();
    }
}

impl InferenceModel for GlmAsrNanoModel {
    /// First forward pass of a generation, with the audio inputs attached.
    ///
    /// `data.data_vec` must hold exactly two entries: the input features
    /// followed by the audio token lengths.
    ///
    /// # Errors
    /// Returns an error if `data.data_vec` does not have exactly two entries,
    /// or if the underlying forward pass fails.
    fn forward_initial(
        &mut self,
        input_ids: &Tensor,
        seqlen_offset: usize,
        data: crate::models::common::MultiModalData,
    ) -> Result<Tensor> {
        match data.data_vec.len() {
            2 => self.forward(
                data.data_vec[0].as_ref(),
                data.data_vec[1].as_ref(),
                input_ids,
                seqlen_offset,
            ),
            _ => Err(anyhow::anyhow!(
                "GlmAsrNano process data error, must have input_features, audio_token_lengths"
            )),
        }
    }

    /// Subsequent forward pass: text tokens only, no audio inputs.
    fn forward_step(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let no_audio: Option<&Tensor> = None;
        self.forward(no_audio, no_audio, input_ids, seqlen_offset)
    }

    /// Clears the language model's key/value cache.
    fn clear_cache(&mut self) {
        self.language_model.clear_kv_cache();
    }

    /// Returns a copy of the stop-token ids supplied at construction.
    fn stop_token_ids(&self) -> Vec<u32> {
        self.stop_token_ids.to_vec()
    }
}