aha 0.2.5

aha model inference library, now supports Qwen(2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM4, VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM(ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, LFM(2/2.5/2VL/2.5VL)
Documentation
use crate::{
    models::{
        common::{
            InferenceModel,
            modules::{NaiveAttnTwoLinearMLPBlock, get_layer_norm},
        },
        lfm2::model::Lfm2Decoder,
        lfm2vl::config::{Lfm2VLConfig, Lfm2VLVisionConfig},
    },
    utils::{
        interpolate::interpolate_bilinear,
        tensor_utils::{get_equal_mask, masked_scatter_dim0, prepare_mask},
    },
};
use anyhow::{Result, anyhow};
use candle_core::{D, IndexOp, Tensor};
use candle_nn::{Activation, LayerNorm, Linear, Module, VarBuilder, embedding, linear_b};
use num::integer::Roots;

pub struct Siglip2VisionEmbeddings {
    patch_embedding: Linear,
    postitional_embeddings: Tensor,
}

impl Siglip2VisionEmbeddings {
    pub fn new(vb: VarBuilder, cfg: &Lfm2VLVisionConfig) -> Result<Self> {
        let embed_dim = cfg.hidden_size;
        let patch_size = cfg.patch_size;
        let patch_embedding = linear_b(
            cfg.num_channels * patch_size * patch_size,
            embed_dim,
            true,
            vb.pp("patch_embedding"),
        )?;
        let position_embedding_size = cfg.num_patches.sqrt();
        let position_embedding =
            embedding(cfg.num_patches, embed_dim, vb.pp("position_embedding"))?;
        // (num_patches, embed_dim) -> (h, w, embed_dim)
        //  ->(embed_dim, h, w) -> (1, embe_dim, h, w)
        let postitional_embeddings = position_embedding
            .embeddings()
            .reshape((position_embedding_size, position_embedding_size, ()))?
            .permute((2, 0, 1))?
            .unsqueeze(0)?;
        Ok(Self {
            patch_embedding,
            postitional_embeddings,
        })
    }

    fn resize_positional_embeddings(
        &self,
        spatial_shapes: &Tensor,
        max_length: usize,
    ) -> Result<Tensor> {
        let mut result_pos_embeddings = vec![];
        let bs = spatial_shapes.dim(0)?;
        for i in 0..bs {
            let shape_i = spatial_shapes.i(i)?.to_vec1::<u32>()?;
            let height = *shape_i.first().unwrap_or(&32) as usize;
            let width = *shape_i.get(1).unwrap_or(&32) as usize;

            if height == 0 || width == 0 || height * width > max_length {
                return Err(anyhow!("img height or width illigal"));
            }
            // (1, dim, height, width) -> (height*width, dim)
            let resize_embeddings = interpolate_bilinear(
                &self.postitional_embeddings,
                (height, width),
                Some(false),
                Some(true),
            )?
            .reshape(((), height * width))?
            .transpose(0, 1)?;
            let resize_embeddings = if height * width < max_length {
                let pad = max_length - height * width;
                let pad_embedding = resize_embeddings.i(0)?.unsqueeze(0)?.repeat((pad, 1))?;
                Tensor::cat(&[&resize_embeddings, &pad_embedding], 0)?
            } else {
                resize_embeddings
            };
            result_pos_embeddings.push(resize_embeddings);
        }

        let result_pos_embeddings = Tensor::stack(&result_pos_embeddings, 0)?;
        Ok(result_pos_embeddings)
    }

    pub fn forward(&self, pixel_values: &Tensor, spatial_shapes: &Tensor) -> Result<Tensor> {
        let patch_embeds = self.patch_embedding.forward(pixel_values)?;
        let max_length = pixel_values.dim(1)?;
        let resize_pos_embeddings =
            self.resize_positional_embeddings(spatial_shapes, max_length)?;
        let embedding = patch_embeds.add(&resize_pos_embeddings)?;
        Ok(embedding)
    }
}

pub struct Siglip2Encoder {
    layers: Vec<NaiveAttnTwoLinearMLPBlock>,
}

impl Siglip2Encoder {
    pub fn new(vb: VarBuilder, cfg: &Lfm2VLVisionConfig) -> Result<Self> {
        let vb_layers = vb.pp("layers");
        let mut layers = vec![];
        for i in 0..cfg.num_hidden_layers {
            let layer = NaiveAttnTwoLinearMLPBlock::new(
                vb_layers.pp(i),
                cfg.hidden_size,
                cfg.num_attention_heads,
                None,
                None,
                true,
                "self_attn",
                Some("out_proj"),
                cfg.intermediate_size,
                cfg.hidden_act,
                true,
                "mlp",
                "fc1",
                "fc2",
                cfg.layer_norm_eps,
                "layer_norm1",
                "layer_norm2",
            )?;
            layers.push(layer);
        }
        Ok(Self { layers })
    }

    pub fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in &self.layers {
            xs = layer.forward(&xs, None, None, attention_mask, false)?;
        }
        Ok(xs)
    }
}

pub struct Siglip2VisionModel {
    embeddings: Siglip2VisionEmbeddings,
    encoder: Siglip2Encoder,
    post_layernorm: LayerNorm,
}

impl Siglip2VisionModel {
    pub fn new(vb: VarBuilder, cfg: &Lfm2VLVisionConfig) -> Result<Self> {
        let vb = vb.pp("vision_model");
        let embeddings = Siglip2VisionEmbeddings::new(vb.pp("embeddings"), cfg)?;
        let encoder = Siglip2Encoder::new(vb.pp("encoder"), cfg)?;
        let post_layernorm = get_layer_norm(
            vb.pp("post_layernorm"),
            cfg.layer_norm_eps,
            cfg.hidden_size,
            true,
        )?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
        })
    }

    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: &Tensor,
        spatial_shapes: &Tensor,
    ) -> Result<Tensor> {
        let xs = self.embeddings.forward(pixel_values, spatial_shapes)?;
        let mask = prepare_mask(attention_mask)?.to_dtype(xs.dtype())?;
        let xs = self.encoder.forward(&xs, Some(&mask))?;
        let xs = self.post_layernorm.forward(&xs)?;
        Ok(xs)
    }
}

pub struct Lfm2VlMultiModalProjector {
    factor: usize,
    layer_norm: Option<LayerNorm>,
    linear_1: Linear,
    act: Activation,
    linear_2: Linear,
}

impl Lfm2VlMultiModalProjector {
    pub fn new(vb: VarBuilder, cfg: &Lfm2VLConfig) -> Result<Self> {
        let in_channels = cfg.vision_config.hidden_size * (cfg.downsample_factor).pow(2);
        let factor = cfg.downsample_factor;
        let layer_norm = if let Some(flag) = cfg.projector_use_layernorm
            && !flag
        {
            None
        } else {
            let layer_norm = get_layer_norm(
                vb.pp("layer_norm"),
                cfg.vision_config.layer_norm_eps,
                in_channels,
                true,
            )?;
            Some(layer_norm)
        };
        let linear_1 = linear_b(
            in_channels,
            cfg.projector_hidden_size,
            cfg.projector_bias,
            vb.pp("linear_1"),
        )?;
        let act = cfg.projector_hidden_act;
        let linear_2 = linear_b(
            cfg.projector_hidden_size,
            cfg.text_config.hidden_size,
            cfg.projector_bias,
            vb.pp("linear_2"),
        )?;
        Ok(Self {
            factor,
            layer_norm,
            linear_1,
            act,
            linear_2,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, w, h, c) = xs.dims4()?;
        let xs = xs.reshape((bs, w, h / self.factor, c * self.factor))?;
        let xs = xs.permute((0, 2, 1, 3))?;
        let xs = xs.reshape((
            bs,
            h / self.factor,
            w / self.factor,
            c * self.factor * self.factor,
        ))?;
        let mut xs = xs.permute((0, 2, 1, 3))?.contiguous()?;
        if let Some(norm) = &self.layer_norm {
            xs = norm.forward(&xs)?;
        }
        xs = self.linear_1.forward(&xs)?.apply(&self.act)?;
        xs = self.linear_2.forward(&xs)?;
        Ok(xs)
    }
}

pub struct Lfm2VLModel {
    vision_tower: Siglip2VisionModel,
    multi_modal_projector: Lfm2VlMultiModalProjector,
    language_model: Lfm2Decoder,
    lm_head: Linear,
    img_id: u32,
    stop_token_ids: Vec<u32>,
}

impl Lfm2VLModel {
    pub fn new(vb: VarBuilder, cfg: &Lfm2VLConfig, eos_ids: Vec<u32>) -> Result<Self> {
        let vb = vb.pp("model");
        let vision_tower = Siglip2VisionModel::new(vb.pp("vision_tower"), &cfg.vision_config)?;
        let multi_modal_projector =
            Lfm2VlMultiModalProjector::new(vb.pp("multi_modal_projector"), cfg)?;
        let language_model = Lfm2Decoder::new(vb.pp("language_model"), &cfg.text_config)?;
        let lm_head = Linear::new(language_model.embed_tokens.embeddings().clone(), None);
        Ok(Self {
            vision_tower,
            multi_modal_projector,
            language_model,
            lm_head,
            img_id: cfg.image_token_id,
            stop_token_ids: eos_ids,
        })
    }

    pub fn forward(
        &mut self,
        input_ids: &Tensor,
        pixel_values: Option<&Tensor>,
        pixel_attention_mask: Option<&Tensor>,
        spatial_shapes: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let mut inputs_embeds = self.language_model.embed_tokens.forward(input_ids)?;
        if let Some(pixel) = pixel_values
            && let Some(mask) = pixel_attention_mask
            && let Some(shapes) = spatial_shapes
        {
            let image_embeds = self.vision_tower.forward(pixel, mask, shapes)?;
            let bs = image_embeds.dim(0)?;
            let img_featrure_length = mask.sum(1)?.to_vec1::<u32>()?;
            let mut image_features = vec![];
            for img_idx in 0..bs {
                let feature = image_embeds.i(img_idx)?;
                let feature = feature.narrow(0, 0, img_featrure_length[img_idx] as usize)?;
                let shape = shapes.i(img_idx)?.to_vec1::<u32>()?;
                let h = shape[0];
                let w = shape[1];
                let feature = feature
                    .reshape((1, h as usize, w as usize, ()))?
                    .contiguous()?;
                let img_embedding = self.multi_modal_projector.forward(&feature)?;
                let dim = img_embedding.dim(D::Minus1)?;
                let img_embedding = img_embedding.reshape(((), dim))?;
                image_features.push(img_embedding);
            }
            let image_embeds = Tensor::cat(&image_features, 0)?;
            let image_mask = get_equal_mask(input_ids, self.img_id)?;
            inputs_embeds = masked_scatter_dim0(&inputs_embeds, &image_embeds, &image_mask)?;
        }
        let output = self
            .language_model
            .forward(input_ids, Some(&inputs_embeds), seqlen_offset)?;
        let seq_len = output.dim(1)?;
        let last = output.narrow(1, seq_len - 1, 1)?;
        let logits = self.lm_head.forward(&last)?;
        Ok(logits)
    }

    pub fn clear_cache(&mut self) {
        self.language_model.clear_cache();
    }
}

impl InferenceModel for Lfm2VLModel {
    fn forward_initial(
        &mut self,
        input_ids: &Tensor,
        seqlen_offset: usize,
        data: crate::models::common::MultiModalData,
    ) -> Result<Tensor> {
        if data.data_vec.len() != 3 {
            return Err(anyhow::anyhow!(
                "Lfm2VL process data error, must have pixel_values, pixel_attention_mask, spatial_shapes"
            ));
        }
        let pixel_values = &data.data_vec[0];
        let pixel_attention_mask = &data.data_vec[1];
        let spatial_shapes = &data.data_vec[2];
        self.forward(
            input_ids,
            pixel_values.as_ref(),
            pixel_attention_mask.as_ref(),
            spatial_shapes.as_ref(),
            seqlen_offset,
        )
    }

    fn forward_step(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        self.forward(input_ids, None, None, None, seqlen_offset)
    }

    fn clear_cache(&mut self) {
        self.clear_cache();
    }

    fn stop_token_ids(&self) -> Vec<u32> {
        self.stop_token_ids.clone()
    }
}