aha 0.2.5

aha is a model inference library. It now supports Qwen (2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM4, VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM (ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM (2/2.5/2VL/2.5VL).
Documentation
use crate::params::chat::ChatCompletionParameters;
use anyhow::Result;
use candle_core::{DType, Device, Tensor};

use crate::utils::img_utils::dynamic_preprocess;
use crate::{
    tokenizer::TokenizerModel,
    utils::{
        extract_mes,
        img_utils::{extract_images, img_transform, resize_with_edge_padding},
    },
};

/// Turns a chat request into DeepSeek-OCR model inputs (token ids, image tensors and masks).
///
/// Holds the device/dtype for the produced tensors and the constants used to lay out the
/// `<image>` placeholder tokens in the prompt.
#[derive(Debug, Clone)]
pub struct DeepseekOCRProcessor {
    /// Device on which every output tensor is created.
    device: Device,
    /// Dtype of the image tensors (token-id and mask tensors are always `u32`).
    dtype: DType,
    /// Placeholder substring marking where an image goes in the prompt.
    image_token: String,
    /// Token id emitted for every image placeholder position.
    image_token_id: u32,
    /// Divisor that turns an image side length into a patch-grid side length.
    patch_size: u32,
    /// Factor by which the patch grid is reduced into vision queries.
    downsample_ratio: u32,
    /// Model variant; `1` and `2` select different tiling limits and token layouts.
    version: usize,
}

impl DeepseekOCRProcessor {
    pub fn new(device: &Device, dtype: DType, version: usize) -> Result<Self> {
        Ok(Self {
            device: device.clone(),
            dtype,
            image_token: "<image>".to_string(),
            image_token_id: 128815,
            patch_size: 16,
            downsample_ratio: 4,
            version,
        })
    }

    fn get_prompt(&self, mes_vec: Vec<(String, String)>) -> Result<String> {
        let sep = "\n";
        let sep2 = "";
        let mut ret = "".to_string();
        for (i, (_, message)) in mes_vec.iter().enumerate() {
            if message.chars().count() > 0 {
                if i % 2 == 0 {
                    ret = ret + message + sep;
                } else {
                    ret = ret + message + sep2;
                }
            }
        }
        ret = ret.trim().to_string();
        Ok(ret)
    }

    pub fn process_info(
        &self,
        mes: &ChatCompletionParameters,
        tokenizer: &TokenizerModel,
        base_size: u32,
        image_size: u32,
        crop_mode: bool,
    ) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor)> {
        let imgs = extract_images(mes)?;
        let mes_vec = extract_mes(mes)?;
        let prompt = self.get_prompt(mes_vec.clone())?;
        let text_splits: Vec<&str> = prompt.split(&self.image_token).collect();
        let img_mean =
            Tensor::from_slice(&[0.5, 0.5, 0.5], (3, 1, 1), &self.device)?.to_dtype(self.dtype)?;
        let img_std =
            Tensor::from_slice(&[0.5, 0.5, 0.5], (3, 1, 1), &self.device)?.to_dtype(self.dtype)?;
        let mut images_list = Vec::new();
        let mut images_crop_list = Vec::new();
        let mut images_seq_mask = vec![0u32];
        let mut tokenized_id = vec![0u32];
        let mut images_spatial_crop = Vec::new();
        let min_img_size = if self.version == 2 { 768 } else { 640 };
        let max_num = if self.version == 2 { 6 } else { 9 };
        for (text_seq, image) in text_splits.iter().zip(imgs) {
            if !text_seq.is_empty() {
                let token_ids = tokenizer.text_encode_vec(text_seq.to_string(), false)?;
                tokenized_id.extend_from_slice(&token_ids);
                let seq_mask = vec![0u32; token_ids.len()];
                images_seq_mask.extend_from_slice(&seq_mask);
            }
            if crop_mode {
                let mut images_crop_raw = Vec::new();
                let crop_ratio = if image.height() <= min_img_size && image.width() <= min_img_size
                {
                    (1u32, 1u32)
                } else {
                    let (img_crop, ratio) =
                        dynamic_preprocess(&image, 2, max_num, min_img_size, false)?;
                    images_crop_raw = img_crop.clone();
                    ratio
                };

                let gloabal_view =
                    resize_with_edge_padding(&image, base_size, base_size, [127u8; 3]);

                let global_img_trans =
                    img_transform(&gloabal_view, &img_mean, &img_std, &self.device, self.dtype)?;
                images_list.push(global_img_trans);

                images_spatial_crop.push(vec![crop_ratio.0, crop_ratio.1]);

                if crop_ratio.0 > 1 || crop_ratio.1 > 1 {
                    for img in images_crop_raw {
                        let img_t =
                            img_transform(&img, &img_mean, &img_std, &self.device, self.dtype)?;
                        images_crop_list.push(img_t);
                    }
                }

                let num_queries = image_size / self.patch_size / self.downsample_ratio;
                let num_queries_base = base_size / self.patch_size / self.downsample_ratio;
                let mut token_repeat = if self.version == 1 {
                    num_queries_base.pow(2) + num_queries_base + 1
                } else {
                    num_queries_base.pow(2) + 1
                };
                if crop_ratio.0 > 1 || crop_ratio.1 > 1 {
                    let add_num = if self.version == 1 {
                        (num_queries * crop_ratio.0 + 1) * (num_queries * crop_ratio.1)
                    } else {
                        (num_queries * crop_ratio.0) * (num_queries * crop_ratio.1)
                    };
                    token_repeat += add_num;
                }
                let tokenized_image = vec![self.image_token_id; token_repeat as usize];
                tokenized_id.extend_from_slice(&tokenized_image);
                let seq_mask = vec![1u32; tokenized_image.len()];
                images_seq_mask.extend_from_slice(&seq_mask);
            } else {
                let global_view = if image_size <= min_img_size {
                    image.resize_exact(
                        image_size,
                        image_size,
                        image::imageops::FilterType::CatmullRom,
                    )
                } else {
                    resize_with_edge_padding(&image, image_size, image_size, [127u8; 3])
                };
                let global_img_trans =
                    img_transform(&global_view, &img_mean, &img_std, &self.device, self.dtype)?;
                images_list.push(global_img_trans);

                images_spatial_crop.push(vec![1, 1]);
                let num_queries = image_size / self.patch_size / self.downsample_ratio;
                let token_repeat = if self.version == 1 {
                    num_queries.pow(2) + num_queries + 1
                } else {
                    num_queries.pow(2) + 1
                };
                let tokenized_image = vec![self.image_token_id; token_repeat as usize];
                tokenized_id.extend_from_slice(&tokenized_image);
                let seq_mask = vec![1u32; tokenized_image.len()];
                images_seq_mask.extend_from_slice(&seq_mask);
            }
        }
        let token_ids =
            tokenizer.text_encode_vec(text_splits[text_splits.len() - 1].to_string(), false)?;
        tokenized_id.extend_from_slice(&token_ids);
        let seq_mask = vec![0u32; token_ids.len()];
        images_seq_mask.extend_from_slice(&seq_mask);
        let input_ids = Tensor::new(tokenized_id, &self.device)?.unsqueeze(0)?;
        let image_seq_mask = Tensor::new(images_seq_mask, &self.device)?.unsqueeze(0)?;
        let (images_ori, images_spatial_crop_t, image_crop) = if images_list.is_empty() {
            let images_ori = Tensor::zeros(
                (1usize, 3usize, image_size as usize, image_size as usize),
                self.dtype,
                &self.device,
            )?;
            let images_spatial_crop_t = Tensor::zeros((1, 2), DType::F64, &self.device)?;
            let image_crop = Tensor::zeros(
                (1usize, 3usize, base_size as usize, base_size as usize),
                self.dtype,
                &self.device,
            )?;
            (images_ori, images_spatial_crop_t, image_crop)
        } else {
            let images_ori = Tensor::stack(&images_list, 0)?;
            let images_spatial_crop_t = Tensor::new(images_spatial_crop, &self.device)?;
            let image_crop = if !images_crop_list.is_empty() {
                Tensor::stack(&images_crop_list, 0)?
            } else {
                Tensor::zeros(
                    (1usize, 3usize, base_size as usize, base_size as usize),
                    self.dtype,
                    &self.device,
                )?
            };
            (images_ori, images_spatial_crop_t, image_crop)
        };

        Ok((
            input_ids,
            images_ori,
            image_crop,
            image_seq_mask,
            images_spatial_crop_t,
        ))
    }
}