aha 0.2.6

aha is a model inference library; it now supports Qwen (2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM (4/5), VoxCPM (0.5B/1.5/2), DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM (ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM (2/2.5/2VL/2.5VL).
Documentation
use crate::{
    models::common::{
        MultiModalData,
        generate::{
            GenerationContext, GenerationDataProvider, PrepareData, generate_generic_text,
            generate_stream_generic_text,
        },
    },
    params::chat::ChatCompletionParameters,
};
use anyhow::{Result, anyhow};
use candle_core::{DType, Device, quantized::gguf_file};
use candle_nn::VarBuilder;
use rocket::futures::Stream;

use crate::{
    chat_template::ChatTemplate,
    models::{
        common::gguf::Gguf,
        qwen3_5::{config::Qwen3_5Config, model::Qwen3_5Model},
        qwen3vl::processor::Qwen3VLProcessor,
    },
    tokenizer::TokenizerModel,
    utils::{find_type_files, get_device, get_dtype},
};

/// Qwen3.5 generation model bundled with everything needed to turn a
/// `ChatCompletionParameters` request into generated text: chat template,
/// tokenizer, optional image/video pre-processor and the network itself.
pub struct Qwen3_5GenerateModel<'a> {
    /// Renders the request's chat messages into the model's prompt text.
    chat_template: ChatTemplate<'a>,
    /// Encodes the rendered prompt into token ids; also handed to the shared
    /// generation helpers.
    tokenizer: TokenizerModel,
    /// Image/video pre-processor. `None` when the model was loaded without the
    /// vision tower or from GGUF without an mmproj file; in that case the
    /// rendered prompt is tokenized as-is and no image/video tensors are passed.
    pre_processor: Option<Qwen3VLProcessor>,
    /// The underlying Qwen3.5 network.
    model: Qwen3_5Model,
    /// Device on which input token ids are created.
    device: Device,
    /// Model name taken from the model directory name or the GGUF file stem
    /// (falls back to "qwen3.5"). Not read in this file; presumably exposed
    /// through `impl_generate_model!` — confirm.
    model_name: String,
    /// Repetition penalty forwarded to generation.
    repeat_penalty: f32,
    /// Forwarded alongside `repeat_penalty`; presumably the number of recent
    /// tokens the penalty considers — confirm against the generation helpers.
    repeat_last_n: usize,
}

impl<'a> Qwen3_5GenerateModel<'a> {
    pub fn init(path: &str, device: Option<&Device>, dtype: Option<DType>) -> Result<Self> {
        let model_name = std::path::Path::new(path)
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("qwen3.5");
        let chat_template = ChatTemplate::init(path)?;
        let tokenizer = TokenizerModel::init(path)?;
        let config_path = path.to_string() + "/config.json";
        let cfg: Qwen3_5Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
        let device = get_device(device);
        let cfg_dtype = cfg.text_config.dtype.as_str();
        let dtype = get_dtype(dtype, cfg_dtype);
        let pre_processor = Qwen3VLProcessor::new(path, &device, dtype)?;
        let model_list = find_type_files(path, "safetensors")?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_list, dtype, &device)? };
        let eos_ids = vec![cfg.text_config.eos_token_id];
        let model = Qwen3_5Model::new_from_vb(vb, cfg, eos_ids)?;

        Ok(Self {
            chat_template,
            tokenizer,
            pre_processor: Some(pre_processor),
            model,
            device,
            model_name: model_name.to_string(),
            repeat_penalty: 1.0,
            repeat_last_n: 64,
        })
    }

    pub fn init_without_visual(
        path: &str,
        device: Option<&Device>,
        dtype: Option<DType>,
    ) -> Result<Self> {
        let model_name = std::path::Path::new(path)
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("qwen3.5");
        let chat_template = ChatTemplate::init(path)?;
        let tokenizer = TokenizerModel::init(path)?;
        let config_path = path.to_string() + "/config.json";
        let cfg: Qwen3_5Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
        let device = get_device(device);
        let cfg_dtype = cfg.text_config.dtype.as_str();
        let dtype = get_dtype(dtype, cfg_dtype);
        // let pre_processor = Qwen3VLProcessor::new(path, &device, dtype)?;
        let pre_processor = None;
        let model_list = find_type_files(path, "safetensors")?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_list, dtype, &device)? };
        let eos_ids = vec![cfg.text_config.eos_token_id];
        // let qwen3_5 = Qwen3_5Model::new_from_vb(vb, cfg, eos_ids)?;
        let model = Qwen3_5Model::new_from_vb_without_visual(vb, cfg, eos_ids)?;

        Ok(Self {
            chat_template,
            tokenizer,
            pre_processor,
            model,
            device,
            model_name: model_name.to_string(),
            repeat_penalty: 1.0,
            repeat_last_n: 64,
        })
    }

    pub fn init_from_gguf(
        model_file: &str,
        mmproj_file: Option<&str>,
        device: Option<&Device>,
    ) -> Result<Self> {
        if !model_file.contains("Qwen3.5") || !model_file.ends_with("gguf") {
            return Err(anyhow!("Qwen3.5 gguf model file name illigal {model_file}"));
        }
        if let Some(mmproj) = mmproj_file
            && (!mmproj.contains("mmproj") || !mmproj.ends_with("gguf"))
        {
            return Err(anyhow!("Qwen3.5 mmproj_file name illigal {model_file}"));
        }

        let mut reader = std::fs::File::open(model_file)?;
        let content = gguf_file::Content::read(&mut reader)?;
        let device = get_device(device);
        let mut model_gguf = Gguf::new(content, reader, device.clone());

        let chat_template_str = model_gguf
            .get_matedata("tokenizer.chat_template")?
            .to_string()?
            .clone();
        let chat_template = ChatTemplate::str_init(&chat_template_str)?;
        let tokenizer = model_gguf.build_tokenizer(Some(false), Some(false), Some(false))?;
        let (pre_processor, mut mmproj_gguf) = if let Some(mmproj_f) = mmproj_file {
            let mut reader = std::fs::File::open(mmproj_f)?;
            let content = gguf_file::Content::read(&mut reader)?;
            let mmproj_gguf = Gguf::new(content, reader, device.clone());
            let processor = Qwen3VLProcessor::new_qwen3_5_default(&device, DType::F32)?;
            (Some(processor), Some(mmproj_gguf))
        } else {
            (None, None)
        };

        let eos_token_id = model_gguf
            .get_matedata("tokenizer.ggml.eos_token_id")?
            .to_u32()?;
        let eos_ids = vec![eos_token_id];
        let model =
            Qwen3_5Model::new_from_gguf(&mut model_gguf, mmproj_gguf.as_mut(), &device, eos_ids)?;
        let stem = std::path::Path::new(model_file)
            .file_stem() // 获取文件名主干(不含扩展名)
            .and_then(|s| s.to_str())
            .unwrap_or("qwen3.5");
        Ok(Self {
            chat_template,
            tokenizer,
            pre_processor,
            model,
            device,
            model_name: stem.to_string(),
            repeat_penalty: 1.2,
            repeat_last_n: 64,
        })
    }

    pub fn generate_text(&mut self, mes: ChatCompletionParameters) -> Result<String> {
        let seed = mes.seed.unwrap_or(32768) as u64;
        let temperature = mes.temperature.unwrap_or(0.4);
        let top_p = mes.top_p.unwrap_or(0.95);
        let mes_render = self.chat_template.apply_chat_template(&mes)?;
        let (mes_text, pixel_values, image_grid_thw, pixel_values_video, video_grid_thw) =
            if let Some(processor) = &self.pre_processor {
                let input = processor.process_info(&mes, &mes_render)?;
                (
                    input.replace_text,
                    input.pixel_values,
                    input.image_grid_thw,
                    input.pixel_values_video,
                    input.video_grid_thw,
                )
            } else {
                (mes_render, None, None, None, None)
            };
        let input_ids = self.tokenizer.text_encode(mes_text, &self.device)?;
        let sample_len = mes.max_tokens.unwrap_or(1024);
        let mut ctx = GenerationContext::new(
            temperature.into(),
            top_p.into(),
            Some(20),
            self.repeat_penalty.into(),
            self.repeat_last_n.into(),
            seed,
            input_ids.dim(1)?,
            sample_len,
            self.device.clone(),
        );
        let data_vec = vec![
            pixel_values,
            image_grid_thw,
            pixel_values_video,
            video_grid_thw,
        ];
        let data = MultiModalData::new(data_vec);

        generate_generic_text(&mut self.model, &self.tokenizer, input_ids, data, &mut ctx)
    }

    pub fn generate_stream_text(
        &mut self,
        mes: ChatCompletionParameters,
    ) -> Result<impl Stream<Item = Result<String, anyhow::Error>>> {
        let mes_render = self.chat_template.apply_chat_template(&mes)?;
        let (mes_text, pixel_values, image_grid_thw, pixel_values_video, video_grid_thw) =
            if let Some(processor) = &self.pre_processor {
                let input = processor.process_info(&mes, &mes_render)?;
                (
                    input.replace_text,
                    input.pixel_values,
                    input.image_grid_thw,
                    input.pixel_values_video,
                    input.video_grid_thw,
                )
            } else {
                (mes_render, None, None, None, None)
            };
        let input_ids = self.tokenizer.text_encode(mes_text, &self.device)?;
        let sample_len = mes.max_tokens.unwrap_or(1024);
        let data_vec = vec![
            pixel_values,
            image_grid_thw,
            pixel_values_video,
            video_grid_thw,
        ];
        let data = MultiModalData::new(data_vec);
        let seed = mes.seed.unwrap_or(34562) as u64;
        generate_stream_generic_text(
            &mut self.model,
            &self.tokenizer,
            input_ids,
            data,
            mes.temperature,
            mes.top_p,
            None,
            self.repeat_penalty.into(),
            self.repeat_last_n.into(),
            seed,
            sample_len,
            &self.device,
        )
    }
}

impl<'a> GenerationDataProvider for Qwen3_5GenerateModel<'a> {
    /// Sampling temperature: the requested value, or 0.4 when none was given.
    fn get_temperature(&self, req_temp: Option<f32>) -> Option<f32> {
        req_temp.or(Some(0.4))
    }

    /// Nucleus (top-p) threshold: the requested value, or 0.95 by default.
    fn get_top_p(&self, req_top_p: Option<f32>) -> Option<f32> {
        req_top_p.or(Some(0.95))
    }

    /// Top-k cutoff: the requested value, or 40 by default.
    fn get_top_k(&self, top_k: Option<usize>) -> Option<usize> {
        top_k.or(Some(40))
    }

    /// Build the generation inputs for `mes`.
    ///
    /// Renders the chat template, records whether the rendered prompt is in
    /// reasoning mode, runs the image/video pre-processor when one is loaded,
    /// and tokenizes the final prompt. Without a pre-processor the rendered
    /// text is tokenized unchanged and every multimodal slot is `None`.
    fn get_data(&self, mes: &ChatCompletionParameters) -> Result<PrepareData> {
        let rendered = self.chat_template.apply_chat_template(mes)?;
        let in_reasoning = self.is_in_reasoning(&rendered);

        // Slot order: pixel_values, image_grid_thw, pixel_values_video, video_grid_thw.
        let (prompt, media) = match &self.pre_processor {
            Some(processor) => {
                let processed = processor.process_info(mes, &rendered)?;
                let media = vec![
                    processed.pixel_values,
                    processed.image_grid_thw,
                    processed.pixel_values_video,
                    processed.video_grid_thw,
                ];
                (processed.replace_text, media)
            }
            None => (rendered, vec![None, None, None, None]),
        };

        let input_ids = self.tokenizer.text_encode(prompt, &self.device)?;
        Ok(PrepareData {
            in_reasoning,
            input_ids,
            multi_model_data: MultiModalData::new(media),
        })
    }
}

// Generate the crate's shared generation-model plumbing for this type via the
// `impl_generate_model!` macro (defined elsewhere in the crate). It presumably
// builds on the `GenerationDataProvider` impl above; confirm the exact items
// it generates against the macro definition.
crate::impl_generate_model!(Qwen3_5GenerateModel<'a>);