aha 0.2.6

aha is a model inference library that now supports Qwen(2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM(4/5), VoxCPM(0.5B/1.5/2), DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM(ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, LFM(2/2.5/2VL/2.5VL)
Documentation
use anyhow::Result;
use candle_core::Tensor;
use candle_nn::{
    Activation, Embedding, Init, LayerNorm, Linear, Module, VarBuilder, embedding, linear_b,
};

use crate::{
    models::common::modules::{TwoLinearMLP, eager_attention_forward, get_layer_norm},
    position_embed::rope::{RoPE, apply_rotary_pos_emb_interleave},
    utils::tensor_utils::prepare_causal_attention_mask,
};
pub mod config;

/// Self-attention layer used by the GPT-2 style blocks in this module.
///
/// Holds a fused query/key/value projection, an output projection, and the
/// keys/values accumulated by earlier `forward` calls.
pub struct GPT2Attention {
    /// Number of attention heads the projected hidden state is split into.
    num_heads: usize,
    /// Width of a single head, computed as `hidden_size / num_heads` in `new`.
    head_dim: usize,
    /// Fused projection from `hidden_size` to `3 * hidden_size` (query, key, value).
    c_attn: Linear,
    /// Output projection from `hidden_size` back to `hidden_size`.
    c_proj: Linear,
    /// Keys and values from earlier `forward` calls; `None` before the first
    /// call and after `clear_kv_cache`.
    kv_cache: Option<(Tensor, Tensor)>,
}

impl GPT2Attention {
    pub fn new(vb: VarBuilder, hidden_size: usize, num_heads: usize) -> Result<Self> {
        let c_attn = linear_b(hidden_size, 3 * hidden_size, true, vb.pp("c_attn"))?;
        let c_proj = linear_b(hidden_size, hidden_size, true, vb.pp("c_proj"))?;
        let head_dim = hidden_size / num_heads;
        Ok(Self {
            num_heads,
            head_dim,
            c_attn,
            c_proj,
            kv_cache: None,
        })
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        cos: Option<&Tensor>,
        sin: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (b, seq_len, _) = xs.dims3()?;
        let xs = self.c_attn.forward(xs)?;
        let xs_splits = xs.chunk(3, 2)?;
        let query_states = xs_splits[0]
            .as_ref()
            .reshape((b, seq_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = xs_splits[1]
            .as_ref()
            .reshape((b, seq_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = xs_splits[2]
            .as_ref()
            .reshape((b, seq_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let (query_states, key_states) = if let Some(cos) = cos
            && let Some(sin) = sin
        {
            apply_rotary_pos_emb_interleave(&query_states, &key_states, cos, sin, false)?
        } else {
            (query_states, key_states)
        };
        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));
        let scale: f64 = 1f64 / f64::sqrt(self.head_dim as f64);
        let attn_output = eager_attention_forward(
            &query_states,
            &key_states,
            &value_states,
            None,
            attention_mask,
            scale,
        )?;
        let attn_output = attn_output.reshape((b, seq_len, self.num_heads * self.head_dim))?;
        let attn_output = attn_output.apply(&self.c_proj)?;
        Ok(attn_output)
    }

    pub fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

/// Two-layer feed-forward network: linear, activation, linear.
pub struct GPT2MLP {
    /// First projection, built from the `c_fc` weight and bias in `new`.
    linear1: Linear,
    /// Second projection, built from the `c_proj` weight and bias in `new`.
    linear2: Linear,
    /// Activation applied between the two projections.
    act: Activation,
}

impl GPT2MLP {
    pub fn new(
        vb: VarBuilder,
        in_dim: usize,
        middle_dim: usize,
        out_dim: usize,
        act: Activation,
    ) -> Result<Self> {
        let c_fc_weight = vb
            .get_with_hints((in_dim, middle_dim), "c_fc.weight", Init::Const(1.0))?
            .t()?;
        let c_fc_bias = vb.get_with_hints(middle_dim, "c_fc.bias", Init::Const(0.0))?;
        let c_fc = Linear::new(c_fc_weight, Some(c_fc_bias));

        let c_proj_weight = vb
            .get_with_hints((middle_dim, out_dim), "c_proj.weight", Init::Const(1.0))?
            .t()?;
        let c_proj_bias = vb.get_with_hints(out_dim, "c_proj.bias", Init::Const(0.0))?;
        let c_proj = Linear::new(c_proj_weight, Some(c_proj_bias));

        Ok(Self {
            linear1: c_fc,
            linear2: c_proj,
            act,
        })
    }
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs
            .apply(&self.linear1)?
            .apply(&self.act)?
            .apply(&self.linear2)?;
        Ok(xs)
    }
}

/// One transformer block: layer norm + attention, then layer norm + MLP, each
/// wrapped in a residual connection.
pub struct GPT2Block {
    /// Layer norm applied to the block input before attention.
    ln_1: LayerNorm,
    /// Self-attention sub-layer (owns the key/value cache).
    attn: GPT2Attention,
    /// Layer norm applied before the MLP.
    ln_2: LayerNorm,
    /// Feed-forward sub-layer.
    mlp: TwoLinearMLP,
}

impl GPT2Block {
    /// Loads the block's sub-layers from `vb` (`ln_1`, `attn`, `ln_2`, `mlp`).
    ///
    /// `inner_dim` is the MLP's hidden width; when `None` it defaults to
    /// `4 * hidden_size`.
    pub fn new(
        vb: VarBuilder,
        hidden_size: usize,
        num_heads: usize,
        inner_dim: Option<usize>,
    ) -> Result<Self> {
        let mlp_width = inner_dim.unwrap_or(4 * hidden_size);
        let layer_norm = |name: &str| get_layer_norm(vb.pp(name), 1e-5, hidden_size, true);
        // Field initialisers run top to bottom, so the sub-layers are loaded in
        // the order ln_1, attn, ln_2, mlp.
        Ok(Self {
            ln_1: layer_norm("ln_1")?,
            attn: GPT2Attention::new(vb.pp("attn"), hidden_size, num_heads)?,
            ln_2: layer_norm("ln_2")?,
            mlp: TwoLinearMLP::new(
                vb.pp("mlp"),
                hidden_size,
                mlp_width,
                hidden_size,
                Activation::NewGelu,
                true,
                "fc_in",
                "fc_out",
            )?,
        })
    }

    /// Runs the block: `h = xs + attn(ln_1(xs))`, then returns `h + mlp(ln_2(h))`.
    ///
    /// `cos`, `sin` and `attention_mask` are handed to the attention sub-layer
    /// untouched.
    pub fn forward(
        &mut self,
        xs: &Tensor,
        cos: Option<&Tensor>,
        sin: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let normed_input = self.ln_1.forward(xs)?;
        let attn_out = self
            .attn
            .forward(&normed_input, cos, sin, attention_mask)?;
        let hidden = attn_out.add(xs)?;

        let normed_hidden = self.ln_2.forward(&hidden)?;
        let mlp_out = self.mlp.forward(&normed_hidden)?;
        Ok(mlp_out.add(&hidden)?)
    }

    /// Clears the attention sub-layer's key/value cache.
    pub fn clear_kv_cache(&mut self) {
        self.attn.clear_kv_cache()
    }
}

/// Stack of `GPT2Block`s followed by a final layer norm, operating on
/// already-embedded inputs.
// NOTE(review): nothing in this file reads `wte`; presumably that is what the
// `allow(unused)` below is for — confirm.
#[allow(unused)]
pub struct GPT2Model {
    /// Optional token-embedding table. `forward` does not use it; it is `None`
    /// when the model is built with `new_without_wte`.
    pub wte: Option<Embedding>,
    // wpe: Embedding, // learned position table left out; original note: not needed with RoPE
    /// Transformer blocks, applied in order.
    h: Vec<GPT2Block>,
    /// Final layer norm applied after the last block.
    ln_f: LayerNorm,
    /// Rotary-embedding generator; `None` when built with `new_with_wte`, in
    /// which case no rotation is applied in the blocks.
    rope: Option<RoPE>,
}

#[allow(unused)]
impl GPT2Model {
    pub fn new(
        vb: VarBuilder,
        hidden_size: usize,
        num_heads: usize,
        num_hidden_layers: usize,
        vocab_size: usize,
        // n_positions: usize,
    ) -> Result<Self> {
        let wte = Some(embedding(vocab_size, hidden_size, vb.pp("wte"))?);
        let vb_layers = vb.pp("h");
        let mut h = vec![];
        for i in 0..num_hidden_layers {
            let block = GPT2Block::new(vb_layers.pp(i), hidden_size, num_heads, None)?;
            h.push(block);
        }
        let ln_f = get_layer_norm(vb.pp("ln_f"), 1e-5, hidden_size, true)?;
        let head_dim = hidden_size / num_heads;
        let rope = RoPE::new(head_dim, 10000.0, vb.device())?;
        Ok(Self {
            wte,
            h,
            ln_f,
            rope: Some(rope),
        })
    }

    pub fn new_without_wte(
        vb: VarBuilder,
        hidden_size: usize,
        num_heads: usize,
        num_hidden_layers: usize,
        vocab_size: usize,
        // n_positions: usize,
    ) -> Result<Self> {
        let wte = None;
        let vb_layers = vb.pp("h");
        let mut h = vec![];
        for i in 0..num_hidden_layers {
            let block = GPT2Block::new(vb_layers.pp(i), hidden_size, num_heads, None)?;
            h.push(block);
        }
        let ln_f = get_layer_norm(vb.pp("ln_f"), 1e-5, hidden_size, true)?;
        let head_dim = hidden_size / num_heads;
        let rope = RoPE::new(head_dim, 10000.0, vb.device())?;
        Ok(Self {
            wte,
            h,
            ln_f,
            rope: Some(rope),
        })
    }

    pub fn new_with_wte(
        vb: VarBuilder,
        hidden_size: usize,
        num_heads: usize,
        num_hidden_layers: usize,
        wte_embeddings: &Tensor,
    ) -> Result<Self> {
        let wte = Some(Embedding::new(wte_embeddings.clone(), hidden_size));
        let vb_layers = vb.pp("h");
        let mut h = vec![];
        for i in 0..num_hidden_layers {
            let block = GPT2Block::new(vb_layers.pp(i), hidden_size, num_heads, None)?;
            h.push(block);
        }
        let ln_f = get_layer_norm(vb.pp("ln_f"), 1e-5, hidden_size, true)?;
        Ok(Self {
            wte,
            h,
            ln_f,
            rope: None,
        })
    }

    pub fn forward(&mut self, inputs_embeds: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len, _) = inputs_embeds.dims3()?;
        let (cos, sin) = if let Some(rope) = &self.rope {
            let (cos, sin) =
                rope.forward_repeat_interleave(seqlen_offset, seq_len, inputs_embeds.device())?;
            (Some(cos), Some(sin))
        } else {
            (None, None)
        };

        let mut xs = inputs_embeds.clone();
        let attention_mask: Option<Tensor> = {
            if seq_len <= 1 {
                None
            } else {
                Some(prepare_causal_attention_mask(
                    b_size,
                    seq_len,
                    0,
                    xs.device(),
                )?)
            }
        };
        for block in &mut self.h {
            xs = block.forward(&xs, cos.as_ref(), sin.as_ref(), attention_mask.as_ref())?;
        }
        xs = self.ln_f.forward(&xs)?;
        Ok(xs)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.h.iter_mut() {
            layer.clear_kv_cache()
        }
    }
}