use std::time::Instant;
use crate::{
models::common::generate::get_logit_processor,
params::chat::{ChatCompletionChunkResponse, ChatCompletionParameters, ChatCompletionResponse},
utils::response_utils::{build_chunk_response_with_usage, build_completion_response_with_time},
};
use anyhow::{Result, anyhow};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use rocket::async_stream::stream;
use rocket::futures::Stream;
use crate::{
chat_template::ChatTemplate,
models::{
GenerateModel,
feature_extractor::config::FeatureExtractor,
qwen3_asr::{
config::{Qwen3ASRConfig, Qwen3ASRGenerationConfig},
model::Qwen3ASRModel,
processor::Qwen3AsrProcessor,
},
},
tokenizer::TokenizerModel,
utils::{
find_type_files, get_device, get_dtype, response_utils::build_completion_chunk_response,
},
};
pub struct Qwen3AsrGenerateModel<'a> {
chat_template: ChatTemplate<'a>,
tokenizer: TokenizerModel,
processor: Qwen3AsrProcessor,
qwen3_asr: Qwen3ASRModel,
device: Device,
dtype: DType,
eos_token_id1: u32,
eos_token_id2: u32,
generation_config: Qwen3ASRGenerationConfig,
model_name: String,
}
impl<'a> Qwen3AsrGenerateModel<'a> {
pub fn init(path: &str, device: Option<&Device>, dtype: Option<DType>) -> Result<Self> {
let chat_template = ChatTemplate::init(path)?;
let tokenizer = TokenizerModel::init(path)?;
let generation_config_path = path.to_string() + "/generation_config.json";
let generation_config: Qwen3ASRGenerationConfig =
serde_json::from_slice(&std::fs::read(generation_config_path)?)?;
let device = get_device(device);
let preprocess_config_path = path.to_string() + "/preprocessor_config.json";
let preprocess_config: FeatureExtractor =
serde_json::from_slice(&std::fs::read(preprocess_config_path)?)?;
let processor = Qwen3AsrProcessor::new(&device, &preprocess_config)?;
let config_path = path.to_string() + "/config.json";
let cfg: Qwen3ASRConfig = serde_json::from_slice(&std::fs::read(config_path)?)?;
let cfg_dtype = cfg.thinker_config.dtype.as_str();
let dtype = get_dtype(dtype, cfg_dtype);
let model_list = find_type_files(path, "safetensors")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_list, dtype, &device)? };
let qwen3_asr = Qwen3ASRModel::new(vb, &cfg)?;
let model_name = std::path::Path::new(path)
.file_name()
.and_then(|s| s.to_str())
.unwrap_or("qwen3-asr")
.to_string();
Ok(Self {
chat_template,
tokenizer,
processor,
qwen3_asr,
device,
dtype,
eos_token_id1: generation_config.eos_token_id[0] as u32,
eos_token_id2: generation_config.eos_token_id[1] as u32,
generation_config,
model_name,
})
}
}
impl<'a> GenerateModel for Qwen3AsrGenerateModel<'a> {
fn generate(&mut self, mes: ChatCompletionParameters) -> Result<ChatCompletionResponse> {
let temperature = mes
.temperature
.unwrap_or(self.generation_config.temperature);
let seed = mes.seed.unwrap_or(34562) as u64;
let mut logit_processor = get_logit_processor(Some(temperature), mes.top_p, None, seed);
let render_text = self.chat_template.apply_chat_template(&mes)?;
let audio_datas = self
.processor
.process_info(&mes, &render_text, &self.tokenizer)?;
let sample_len = mes.max_tokens.unwrap_or(1024);
let mut generate = Vec::new();
let mut prompt_tokens = 0u32;
let mut prompt_secs = 0.0f64;
let mut completion_secs = 0.0f64;
for data in audio_datas.iter() {
let mut input_ids = data.input_ids.clone();
let mut input_features = Some(data.input_features.clone().to_dtype(self.dtype)?);
let mut seq_len = input_ids.dim(1)?;
prompt_tokens += seq_len as u32;
let mut seqlen_offset = 0;
for _ in 0..sample_len {
let i_start = Instant::now();
let logits =
self.qwen3_asr
.forward(&input_ids, seqlen_offset, input_features.as_ref())?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = logit_processor.sample(&logits)?;
let i_duration = i_start.elapsed();
if seqlen_offset == 0 {
prompt_secs += i_duration.as_secs_f64();
} else {
completion_secs += i_duration.as_secs_f64();
};
generate.push(next_token);
if next_token == self.eos_token_id1 || next_token == self.eos_token_id2 {
break;
}
seqlen_offset += seq_len;
seq_len = 1;
input_ids = Tensor::from_vec(vec![next_token], (1, 1), &self.device)?;
input_features = None;
}
self.qwen3_asr.clear_kv_cache();
}
let num_token = generate.len() as u32;
let res = self.tokenizer.token_decode(generate)?;
let response = build_completion_response_with_time(
res,
&self.model_name,
num_token.into(),
completion_secs.into(),
prompt_tokens.into(),
prompt_secs.into(),
);
Ok(response)
}
fn generate_stream(
&mut self,
mes: ChatCompletionParameters,
) -> Result<
Box<
dyn Stream<Item = Result<ChatCompletionChunkResponse, anyhow::Error>>
+ Send
+ Unpin
+ '_,
>,
> {
let temperature = mes
.temperature
.unwrap_or(self.generation_config.temperature);
let seed = mes.seed.unwrap_or(34562) as u64;
let mut logit_processor = get_logit_processor(Some(temperature), mes.top_p, None, seed);
let render_text = self.chat_template.apply_chat_template(&mes)?;
let audio_datas = self
.processor
.process_info(&mes, &render_text, &self.tokenizer)?;
let sample_len = mes.max_tokens.unwrap_or(1024);
let stream = stream! {
let mut error_tokens = Vec::new();
let mut prompt_tokens = 0u32;
let mut completion_tokens = 0u32;
let mut prompt_secs = 0.0f64;
let mut completion_secs = 0.0f64;
for data in audio_datas.iter() {
let mut input_ids = data.input_ids.clone();
let mut input_features = Some(data.input_features.clone().to_dtype(self.dtype)?);
let mut seq_len = input_ids.dim(1)?;
prompt_tokens += seq_len as u32;
let mut seqlen_offset = 0;
for _ in 0..sample_len {
let i_start = Instant::now();
let logits =
self.qwen3_asr
.forward(&input_ids, seqlen_offset, input_features.as_ref())?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = logit_processor.sample(&logits)?;
completion_tokens += 1;
let i_duration = i_start.elapsed();
if seqlen_offset == 0 {
prompt_secs += i_duration.as_secs_f64();
} else {
completion_secs += i_duration.as_secs_f64();
};
let mut decode_ids = Vec::new();
if !error_tokens.is_empty() {
decode_ids.extend_from_slice(&error_tokens);
}
decode_ids.push(next_token);
let decoded_token = self.tokenizer.token_decode(decode_ids).map_err(|e| anyhow!(format!("stream decode error{e}")))?;
if decoded_token.contains("�") {
error_tokens.push(next_token);
if error_tokens.len() > 3 {
error_tokens.clear();
}
seqlen_offset += seq_len;
seq_len = 1;
input_ids = Tensor::from_vec(vec![next_token], (1, 1), &self.device)?;
input_features = None;
continue;
}
error_tokens.clear();
let chunk = build_completion_chunk_response(decoded_token, &self.model_name, None, None);
yield Ok(chunk);
if next_token == self.eos_token_id1 || next_token == self.eos_token_id2 {
yield Ok(build_chunk_response_with_usage(&self.model_name, completion_tokens.into(), completion_secs.into(), prompt_tokens.into(), prompt_secs.into()));
break;
}
seqlen_offset += seq_len;
seq_len = 1;
input_ids = Tensor::from_vec(vec![next_token], (1, 1), &self.device)?;
input_features = None;
}
self.qwen3_asr.clear_kv_cache();
}
};
Ok(Box::new(Box::pin(stream)))
}
}