use candle_nn::Activation;
use crate::models::qwen3::config::Qwen3Config;
/// Top-level configuration for a Qwen3-ASR checkpoint, deserialized from the
/// model's `config.json`. The actual model hyper-parameters live in
/// [`ThinkerConfig`]; this wrapper only carries identification metadata.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRConfig {
    // Model-family discriminator as written in the checkpoint JSON.
    pub model_type: String,
    // Language tags the checkpoint advertises support for — presumably
    // ISO-style codes; verify against an actual config.json.
    pub support_languages: Vec<String>,
    // Configuration of the "thinker" (audio encoder + text decoder) submodel.
    pub thinker_config: ThinkerConfig,
}
/// Configuration of the Qwen3-ASR "thinker" submodel: an audio encoder
/// ([`Qwen3ASRAudioConfig`]) feeding a Qwen3 text decoder
/// ([`Qwen3ASRTextConfig`]), plus the special token ids used to splice
/// audio embeddings into the text token stream.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct ThinkerConfig {
    pub model_type: String,
    pub audio_config: Qwen3ASRAudioConfig,
    // Token ids delimiting/representing audio in the prompt — names suggest
    // start/end markers and a per-frame placeholder token; confirm against
    // the tokenizer's special-token map.
    pub audio_end_token_id: u32,
    pub audio_start_token_id: u32,
    pub audio_token_id: u32,
    // Weight dtype as a string (e.g. "bfloat16") straight from config.json.
    pub dtype: String,
    pub initializer_range: f64,
    pub text_config: Qwen3ASRTextConfig,
}
/// Audio-encoder configuration for Qwen3-ASR.
///
/// NOTE(review): most fields here look like serialized Hugging Face
/// `PretrainedConfig` / generation boilerplate carried along in config.json;
/// only the architecture fields (d_model, encoder_*, num_mel_bins, windowing,
/// downsampling) appear relevant to inference. Confirm before pruning.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRAudioConfig {
    pub activation_dropout: f32,
    // Activation used inside the encoder, parsed by candle_nn from the
    // config's string form (e.g. "gelu").
    pub activation_function: Activation,
    pub add_cross_attention: bool,
    pub attention_dropout: f32,
    pub bad_words_ids: Option<Vec<u32>>,
    pub begin_suppress_tokens: Option<Vec<u32>>,
    pub bos_token_id: Option<u32>,
    pub chunk_size_feed_forward: usize,
    // Chunk length for the convolutional front-end — exact unit (frames vs
    // samples) not evident from this file; TODO confirm.
    pub conv_chunksize: usize,
    pub cross_attention_hidden_size: Option<usize>,
    // Transformer width of the audio encoder.
    pub d_model: usize,
    pub decoder_start_token_id: Option<u32>,
    pub diversity_penalty: f64,
    pub do_sample: bool,
    // Hidden size of the projection that downsamples encoder output before
    // it is handed to the text model — inferred from name.
    pub downsample_hidden_size: usize,
    pub dropout: f32,
    pub early_stopping: bool,
    // --- Core encoder architecture ---
    pub encoder_attention_heads: usize,
    pub encoder_ffn_dim: usize,
    pub encoder_layers: usize,
    pub encoder_no_repeat_ngram_size: usize,
    pub eos_token_id: Option<u32>,
    pub exponential_decay_length_penalty: Option<Vec<f32>>,
    pub forced_bos_token_id: Option<u32>,
    pub forced_eos_token_id: Option<u32>,
    pub id2label: std::collections::HashMap<String, String>,
    pub initializer_range: f64,
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    pub label2id: std::collections::HashMap<String, usize>,
    pub length_penalty: f64,
    pub max_length: usize,
    // Maximum number of input positions (audio frames) the encoder accepts.
    pub max_source_positions: usize,
    pub min_length: usize,
    pub model_type: String,
    // Attention windowing parameters; `n_window_infer` presumably overrides
    // `n_window` at inference time — confirm against the model code.
    pub n_window: usize,
    pub n_window_infer: usize,
    pub no_repeat_ngram_size: usize,
    pub num_beam_groups: usize,
    pub num_beams: usize,
    pub num_hidden_layers: usize,
    // Number of mel filterbank bins in the input features.
    pub num_mel_bins: usize,
    pub num_return_sequences: usize,
    pub output_attentions: bool,
    // Dimension of the encoder's final projection (to the text model's
    // embedding space, going by the name — verify).
    pub output_dim: usize,
    pub output_hidden_states: bool,
    pub output_scores: bool,
    pub pad_token_id: Option<u32>,
    pub prefix: Option<String>,
    pub problem_type: Option<String>,
    pub pruned_heads: std::collections::HashMap<String, Vec<i32>>,
    pub remove_invalid_values: bool,
    pub repetition_penalty: f64,
    pub return_dict: bool,
    pub return_dict_in_generate: bool,
    // Whether input embeddings are scaled (typically by sqrt(d_model)) —
    // inferred from the conventional meaning of this flag; confirm.
    pub scale_embedding: bool,
    pub sep_token_id: Option<u32>,
    pub suppress_tokens: Option<Vec<u32>>,
    pub task_specific_params: Option<std::collections::HashMap<String, serde_json::Value>>,
    pub temperature: f64,
    pub tf_legacy_loss: bool,
    pub tie_encoder_decoder: bool,
    pub tie_word_embeddings: bool,
    pub tokenizer_class: Option<String>,
    pub top_k: usize,
    pub top_p: f64,
    pub torchscript: bool,
    pub typical_p: f64,
    pub use_bfloat16: bool,
}
/// Text-decoder configuration for Qwen3-ASR. The architecture fields are
/// mapped onto the shared [`Qwen3Config`] by
/// [`qwen3asr_text_config2qwen3_config`]; the remaining fields look like
/// Hugging Face config/generation boilerplate preserved for faithful
/// round-tripping of config.json.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRTextConfig {
    pub add_cross_attention: bool,
    // --- Core decoder architecture (consumed by the Qwen3 implementation) ---
    pub attention_bias: bool,
    pub attention_dropout: f64,
    pub bad_words_ids: Option<Vec<u32>>,
    pub begin_suppress_tokens: Option<Vec<u32>>,
    pub bos_token_id: Option<u32>,
    pub chunk_size_feed_forward: usize,
    pub cross_attention_hidden_size: Option<usize>,
    pub decoder_start_token_id: Option<u32>,
    pub diversity_penalty: f64,
    pub do_sample: bool,
    // Optional weight dtype string (e.g. "bfloat16"); absent in some configs.
    pub dtype: Option<String>,
    pub early_stopping: bool,
    pub encoder_no_repeat_ngram_size: usize,
    pub eos_token_id: Option<u32>,
    pub exponential_decay_length_penalty: Option<Vec<f32>>,
    pub forced_bos_token_id: Option<u32>,
    pub forced_eos_token_id: Option<u32>,
    pub head_dim: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub id2label: std::collections::HashMap<String, String>,
    pub initializer_range: f64,
    pub intermediate_size: usize,
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    pub label2id: std::collections::HashMap<String, u32>,
    pub length_penalty: f64,
    pub max_length: usize,
    pub max_position_embeddings: usize,
    pub min_length: usize,
    pub model_type: String,
    pub no_repeat_ngram_size: usize,
    pub num_attention_heads: usize,
    pub num_beam_groups: usize,
    pub num_beams: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub num_return_sequences: usize,
    pub output_attentions: bool,
    pub output_hidden_states: bool,
    pub output_scores: bool,
    pub pad_token_id: Option<u32>,
    pub prefix: Option<String>,
    pub problem_type: Option<String>,
    pub remove_invalid_values: bool,
    pub repetition_penalty: f64,
    pub return_dict: bool,
    pub return_dict_in_generate: bool,
    pub rms_norm_eps: f64,
    // Multimodal RoPE scaling parameters — not forwarded to `Qwen3Config`
    // by the conversion below.
    pub rope_scaling: Qwen3ASRRopeScaling,
    pub rope_theta: f32,
    pub sep_token_id: Option<u32>,
    pub suppress_tokens: Option<Vec<u32>>,
    pub temperature: f64,
    pub tf_legacy_loss: bool,
    pub tie_encoder_decoder: bool,
    pub tie_word_embeddings: bool,
    pub tokenizer_class: Option<String>,
    pub top_k: usize,
    pub top_p: f64,
    pub torchscript: bool,
    pub typical_p: f64,
    pub use_bfloat16: bool,
    pub use_cache: bool,
    pub vocab_size: usize,
}
/// Converts the ASR thinker's text-model configuration into the plain
/// [`Qwen3Config`] consumed by the shared Qwen3 decoder implementation.
///
/// Fields `Qwen3ASRTextConfig` does not carry are filled with values that
/// disable the corresponding feature (sliding-window attention is turned off,
/// so `max_window_layers` is irrelevant and set to 0).
pub fn qwen3asr_text_config2qwen3_config(cfg: &Qwen3ASRTextConfig) -> Qwen3Config {
    Qwen3Config {
        attention_bias: cfg.attention_bias,
        attention_dropout: cfg.attention_dropout,
        // Canonical Qwen3 special-token ids are used as fallbacks when the
        // ASR config leaves them unset.
        bos_token_id: cfg.bos_token_id.unwrap_or(151643),
        eos_token_id: cfg.eos_token_id.unwrap_or(151645),
        head_dim: cfg.head_dim,
        hidden_act: cfg.hidden_act,
        hidden_size: cfg.hidden_size,
        initializer_range: cfg.initializer_range,
        intermediate_size: cfg.intermediate_size,
        max_position_embeddings: cfg.max_position_embeddings,
        // No layer uses sliding-window attention (see use_sliding_window).
        max_window_layers: 0,
        num_attention_heads: cfg.num_attention_heads,
        num_hidden_layers: cfg.num_hidden_layers,
        num_key_value_heads: cfg.num_key_value_heads,
        rms_norm_eps: cfg.rms_norm_eps,
        rope_theta: cfg.rope_theta,
        // Was hard-coded to `true`, silently ignoring the checkpoint's own
        // setting; honor what the config declares.
        tie_word_embeddings: cfg.tie_word_embeddings,
        // Prefer the dtype the text config pins; fall back to the previous
        // hard-coded bf16 when it is absent.
        torch_dtype: cfg.dtype.clone().unwrap_or_else(|| "bfloat16".to_string()),
        use_cache: cfg.use_cache,
        use_sliding_window: false,
        vocab_size: cfg.vocab_size,
    }
}
/// RoPE scaling section of the text config. Carries both `rope_type` and the
/// legacy `type` key (exposed as the raw identifier `r#type`) so either JSON
/// spelling deserializes; multimodal-RoPE (`mrope_*`) fields describe how
/// rotary dimensions are split across position streams — semantics not
/// established by this file, confirm against the model code.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRRopeScaling {
    pub interleaved: bool,
    pub mrope_interleaved: bool,
    pub mrope_section: Vec<usize>,
    pub rope_type: String,
    // `type` is a Rust keyword, hence the raw identifier.
    pub r#type: String,
}
/// Sampling parameters deserialized from the checkpoint's
/// `generation_config.json`.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen3ASRGenerationConfig {
    pub do_sample: bool,
    // Several ids may terminate generation, hence a Vec rather than a scalar.
    pub eos_token_id: Vec<u32>,
    pub pad_token_id: usize,
    pub temperature: f32,
}