//! Qwen3-ASR configuration types (part of the `aha` model inference library,
//! which supports Qwen(2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM4,
//! VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0,
//! GLM(ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM(2/2.5/2VL/2.5VL)).
use candle_nn::Activation;

use crate::models::qwen3::config::Qwen3Config;

/// Top-level configuration for a Qwen3-ASR model, deserialized from the
/// model checkpoint's `config.json`.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRConfig {
    /// Model-type tag exactly as stored in the config file.
    pub model_type: String,
    /// Languages the ASR model supports, as declared by the config
    /// (presumably language names or codes — confirm against a real config).
    pub support_languages: Vec<String>,
    /// Nested "thinker" configuration (audio encoder + text decoder).
    pub thinker_config: ThinkerConfig,
}

/// Configuration of the "thinker" sub-model: an audio encoder config, a text
/// decoder config, and the special token ids that delimit audio in the
/// token stream.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct ThinkerConfig {
    pub model_type: String,
    /// Audio encoder configuration.
    pub audio_config: Qwen3ASRAudioConfig,
    /// Token id marking the end of an audio segment.
    pub audio_end_token_id: u32,
    /// Token id marking the start of an audio segment.
    pub audio_start_token_id: u32,
    /// Audio placeholder token id (presumably replaced by audio features at
    /// inference time — confirm against the model code).
    pub audio_token_id: u32,
    /// Weight dtype as a string (e.g. "bfloat16"), taken verbatim from config.
    pub dtype: String,
    pub initializer_range: f64,
    /// Text decoder configuration.
    pub text_config: Qwen3ASRTextConfig,
}

/// Configuration of the Qwen3-ASR audio encoder.
///
/// The field set appears to mirror the Hugging Face config serialization,
/// which is why many generic generation/`PretrainedConfig`-style fields
/// (beam search, penalties, suppress tokens, …) are present even though an
/// encoder is unlikely to consume them; they are kept so deserialization of
/// a full `config.json` succeeds.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRAudioConfig {
    pub activation_dropout: f32,
    /// Activation used inside the encoder (deserialized via candle's enum).
    pub activation_function: Activation,
    pub add_cross_attention: bool,
    pub attention_dropout: f32,
    pub bad_words_ids: Option<Vec<u32>>,
    pub begin_suppress_tokens: Option<Vec<u32>>,
    pub bos_token_id: Option<u32>,
    pub chunk_size_feed_forward: usize,
    /// Chunk size used by the convolutional front-end (semantics defined by
    /// the model code, not visible here).
    pub conv_chunksize: usize,
    pub cross_attention_hidden_size: Option<usize>,
    /// Encoder hidden/model width.
    pub d_model: usize,
    pub decoder_start_token_id: Option<u32>,
    pub diversity_penalty: f64,
    pub do_sample: bool,
    /// Hidden size of the downsampling projection stage.
    pub downsample_hidden_size: usize,
    pub dropout: f32,
    pub early_stopping: bool,
    // Core transformer-encoder shape parameters.
    pub encoder_attention_heads: usize,
    pub encoder_ffn_dim: usize,
    pub encoder_layers: usize,
    pub encoder_no_repeat_ngram_size: usize,
    pub eos_token_id: Option<u32>,
    pub exponential_decay_length_penalty: Option<Vec<f32>>,
    pub forced_bos_token_id: Option<u32>,
    pub forced_eos_token_id: Option<u32>,
    pub id2label: std::collections::HashMap<String, String>,
    pub initializer_range: f64,
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    pub label2id: std::collections::HashMap<String, usize>,
    pub length_penalty: f64,
    pub max_length: usize,
    /// Maximum number of audio positions the encoder accepts.
    pub max_source_positions: usize,
    pub min_length: usize,
    pub model_type: String,
    /// Attention window sizes (training vs. inference); exact semantics are
    /// defined by the encoder implementation — confirm there.
    pub n_window: usize,
    pub n_window_infer: usize,
    pub no_repeat_ngram_size: usize,
    pub num_beam_groups: usize,
    pub num_beams: usize,
    pub num_hidden_layers: usize,
    /// Number of mel filterbank bins in the input features.
    pub num_mel_bins: usize,
    pub num_return_sequences: usize,
    pub output_attentions: bool,
    /// Dimension of the encoder output projected into the text model.
    pub output_dim: usize,
    pub output_hidden_states: bool,
    pub output_scores: bool,
    pub pad_token_id: Option<u32>,
    pub prefix: Option<String>,
    pub problem_type: Option<String>,
    pub pruned_heads: std::collections::HashMap<String, Vec<i32>>,
    pub remove_invalid_values: bool,
    pub repetition_penalty: f64,
    pub return_dict: bool,
    pub return_dict_in_generate: bool,
    pub scale_embedding: bool,
    pub sep_token_id: Option<u32>,
    pub suppress_tokens: Option<Vec<u32>>,
    pub task_specific_params: Option<std::collections::HashMap<String, serde_json::Value>>,
    pub temperature: f64,
    pub tf_legacy_loss: bool,
    pub tie_encoder_decoder: bool,
    pub tie_word_embeddings: bool,
    pub tokenizer_class: Option<String>,
    pub top_k: usize,
    pub top_p: f64,
    pub torchscript: bool,
    pub typical_p: f64,
    pub use_bfloat16: bool,
}

/// Configuration of the Qwen3-ASR text decoder.
///
/// Like [`Qwen3ASRAudioConfig`], this appears to mirror the Hugging Face
/// config serialization, so generic generation fields are retained for
/// deserialization even if unused. The decoder-relevant subset is converted
/// to a plain Qwen3 config by `qwen3asr_text_config2qwen3_config`.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRTextConfig {
    pub add_cross_attention: bool,
    /// Whether attention projections carry a bias term.
    pub attention_bias: bool,
    pub attention_dropout: f64,
    pub bad_words_ids: Option<Vec<u32>>,
    pub begin_suppress_tokens: Option<Vec<u32>>,
    pub bos_token_id: Option<u32>,
    pub chunk_size_feed_forward: usize,
    pub cross_attention_hidden_size: Option<usize>,
    pub decoder_start_token_id: Option<u32>,
    pub diversity_penalty: f64,
    pub do_sample: bool,
    /// Weight dtype as a string (e.g. "bfloat16"), if declared by the config.
    pub dtype: Option<String>,
    pub early_stopping: bool,
    pub encoder_no_repeat_ngram_size: usize,
    pub eos_token_id: Option<u32>,
    pub exponential_decay_length_penalty: Option<Vec<f32>>,
    pub forced_bos_token_id: Option<u32>,
    pub forced_eos_token_id: Option<u32>,
    /// Per-head dimension (may differ from hidden_size / num_attention_heads).
    pub head_dim: usize,
    /// MLP activation (deserialized via candle's enum).
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub id2label: std::collections::HashMap<String, String>,
    pub initializer_range: f64,
    pub intermediate_size: usize,
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    pub label2id: std::collections::HashMap<String, u32>,
    pub length_penalty: f64,
    pub max_length: usize,
    pub max_position_embeddings: usize,
    pub min_length: usize,
    pub model_type: String,
    pub no_repeat_ngram_size: usize,
    // Core decoder shape parameters (GQA: num_key_value_heads may be fewer
    // than num_attention_heads).
    pub num_attention_heads: usize,
    pub num_beam_groups: usize,
    pub num_beams: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub num_return_sequences: usize,
    pub output_attentions: bool,
    pub output_hidden_states: bool,
    pub output_scores: bool,
    pub pad_token_id: Option<u32>,
    pub prefix: Option<String>,
    pub problem_type: Option<String>,
    pub remove_invalid_values: bool,
    pub repetition_penalty: f64,
    pub return_dict: bool,
    pub return_dict_in_generate: bool,
    pub rms_norm_eps: f64,
    /// Rotary-embedding scaling settings.
    pub rope_scaling: Qwen3ASRRopeScaling,
    pub rope_theta: f32,
    pub sep_token_id: Option<u32>,
    pub suppress_tokens: Option<Vec<u32>>,
    pub temperature: f64,
    pub tf_legacy_loss: bool,
    pub tie_encoder_decoder: bool,
    /// Whether input and output embeddings share weights.
    pub tie_word_embeddings: bool,
    pub tokenizer_class: Option<String>,
    pub top_k: usize,
    pub top_p: f64,
    pub torchscript: bool,
    pub typical_p: f64,
    pub use_bfloat16: bool,
    pub use_cache: bool,
    pub vocab_size: usize,
}

/// Converts the ASR text-decoder config into the plain [`Qwen3Config`] used
/// by the shared Qwen3 decoder implementation.
///
/// Fields `Qwen3Config` requires but the ASR config does not carry are given
/// fixed values: sliding-window attention is disabled (so `max_window_layers`
/// is irrelevant and set to 0), and the conventional Qwen3 special-token ids
/// (151643 BOS / 151645 EOS) serve as fallbacks when the config omits them.
pub fn qwen3asr_text_config2qwen3_config(cfg: &Qwen3ASRTextConfig) -> Qwen3Config {
    Qwen3Config {
        attention_bias: cfg.attention_bias,
        attention_dropout: cfg.attention_dropout,
        // Standard Qwen3 token ids used only when the config leaves them unset.
        bos_token_id: cfg.bos_token_id.unwrap_or(151643),
        eos_token_id: cfg.eos_token_id.unwrap_or(151645),
        head_dim: cfg.head_dim,
        hidden_act: cfg.hidden_act,
        hidden_size: cfg.hidden_size,
        initializer_range: cfg.initializer_range,
        intermediate_size: cfg.intermediate_size,
        max_position_embeddings: cfg.max_position_embeddings,
        // Sliding-window attention is disabled below, so this value is unused.
        max_window_layers: 0,
        num_attention_heads: cfg.num_attention_heads,
        num_hidden_layers: cfg.num_hidden_layers,
        num_key_value_heads: cfg.num_key_value_heads,
        rms_norm_eps: cfg.rms_norm_eps,
        rope_theta: cfg.rope_theta,
        // Was hard-coded to `true`; honor the value the config declares
        // (the struct deserializes this field, so ignoring it was a bug).
        tie_word_embeddings: cfg.tie_word_embeddings,
        // Was hard-coded; prefer the dtype the config declares, falling back
        // to bfloat16 when absent (the previous effective behavior).
        torch_dtype: cfg.dtype.clone().unwrap_or_else(|| "bfloat16".to_string()),
        use_cache: cfg.use_cache,
        use_sliding_window: false,
        vocab_size: cfg.vocab_size,
    }
}

/// Rotary-position-embedding scaling settings for the text decoder.
///
/// Both the modern `rope_type` key and the legacy `type` key are captured,
/// since the config file may contain either/both (the raw keyword `type`
/// requires the `r#` escape in Rust).
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRRopeScaling {
    pub interleaved: bool,
    pub mrope_interleaved: bool,
    /// Per-axis section sizes for multimodal RoPE.
    pub mrope_section: Vec<usize>,
    pub rope_type: String,
    /// Legacy duplicate of `rope_type`.
    pub r#type: String,
}

/// Generation settings deserialized from the model's `generation_config.json`.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRGenerationConfig {
    pub do_sample: bool,
    /// One or more token ids that terminate generation.
    pub eos_token_id: Vec<u32>,
    // NOTE(review): typed `usize` while the other token-id fields in this
    // file are `u32` — likely an inconsistency, but changing it would alter
    // the deserialization interface; confirm before unifying.
    pub pad_token_id: usize,
    pub temperature: f32,
}