use crate::error::{Error, Result};
use crate::model::audio::whisper::WhisperEncoder;
use crate::model::audio::whisper_decoder::WhisperDecoder;
use crate::model::config::AudioConfig;
use crate::nn::VarBuilder;
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, ConditionalOps, ConvOps, IndexingOps, MatmulOps, NormalizationOps,
ReduceOps, ScalarOps, ShapeOps, TensorOps, UnaryOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
#[derive(Debug, Clone)]
pub struct GenerateOptions {
pub max_new_tokens: usize,
pub eos_token_ids: Vec<u32>,
pub suppress_tokens: Vec<u32>,
}
impl Default for GenerateOptions {
fn default() -> Self {
Self {
max_new_tokens: 448,
eos_token_ids: Vec::new(),
suppress_tokens: Vec::new(),
}
}
}
pub struct WhisperModel<R: Runtime> {
pub encoder: WhisperEncoder<R>,
pub decoder: WhisperDecoder<R>,
config: AudioConfig,
}
impl<R: Runtime<DType = DType>> WhisperModel<R> {
pub fn from_varbuilder(vb: &mut VarBuilder<'_, R>, config: &AudioConfig) -> Result<Self> {
let mut model_vb = vb.pp("model");
let mut enc_vb = model_vb.pp("encoder");
let encoder = WhisperEncoder::from_varbuilder(&mut enc_vb, config)?;
drop(enc_vb);
let mut dec_vb = model_vb.pp("decoder");
let decoder = WhisperDecoder::from_varbuilder(&mut dec_vb, config)?;
Ok(Self {
encoder,
decoder,
config: config.clone(),
})
}
pub fn config(&self) -> &AudioConfig {
&self.config
}
pub fn encode<C>(&self, client: &C, mel: &Tensor<R>) -> Result<Tensor<R>>
where
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ ActivationOps<R>
+ NormalizationOps<R>
+ ConvOps<R>
+ ReduceOps<R>
+ ShapeOps<R>,
R::Client: TensorOps<R> + ScalarOps<R> + ConvOps<R> + ReduceOps<R> + BinaryOps<R>,
{
self.encoder.forward_inference(client, mel)
}
pub fn generate<C>(
&self,
client: &C,
encoder_out: &Tensor<R>,
start_tokens: &[u32],
options: &GenerateOptions,
) -> Result<Vec<u32>>
where
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ MatmulOps<R>
+ BinaryOps<R>
+ ActivationOps<R>
+ NormalizationOps<R>
+ ReduceOps<R>
+ ShapeOps<R>
+ UnaryOps<R>
+ ConditionalOps<R>
+ IndexingOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
assert_eq!(
encoder_out.shape()[0],
1,
"WhisperModel::generate currently supports batch=1"
);
if start_tokens.is_empty() {
return Err(Error::ModelError {
reason: "generate requires at least one start token".into(),
});
}
let device = encoder_out.device();
let mut cache = self.decoder.new_cache();
let mut generated: Vec<u32> = Vec::with_capacity(options.max_new_tokens);
let mut position: usize = 0;
let prefix_i64: Vec<i64> = start_tokens.iter().map(|&t| t as i64).collect();
let prefix_tensor = Tensor::<R>::from_slice(&prefix_i64, &[1, prefix_i64.len()], device);
let logits = self.decoder.forward_with_cache(
client,
&prefix_tensor,
encoder_out,
position,
&mut cache,
)?;
position += start_tokens.len();
let mut next_token = greedy_pick_last(
client,
&logits,
self.decoder.vocab_size(),
&options.suppress_tokens,
)?;
for _ in 0..options.max_new_tokens {
if options.eos_token_ids.contains(&next_token) {
break;
}
generated.push(next_token);
let step_ids = Tensor::<R>::from_slice(&[next_token as i64], &[1, 1], device);
let logits = self.decoder.forward_with_cache(
client,
&step_ids,
encoder_out,
position,
&mut cache,
)?;
position += 1;
next_token = greedy_pick_last(
client,
&logits,
self.decoder.vocab_size(),
&options.suppress_tokens,
)?;
}
if !options.eos_token_ids.contains(&next_token) && generated.len() < options.max_new_tokens
{
generated.push(next_token);
}
Ok(generated)
}
}
fn greedy_pick_last<R, C>(
client: &C,
logits: &Tensor<R>,
vocab_size: usize,
suppress: &[u32],
) -> Result<u32>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + BinaryOps<R> + ScalarOps<R>,
R::Client: TensorOps<R>,
{
let shape = logits.shape();
let t = shape[1];
let last = logits.narrow(1, t - 1, 1).map_err(Error::Numr)?;
if suppress.is_empty() {
let data: Vec<f32> = last.to_vec();
return Ok(argmax_f32(&data) as u32);
}
let mut mask = vec![0.0f32; vocab_size];
for &id in suppress {
if (id as usize) < vocab_size {
mask[id as usize] = f32::NEG_INFINITY;
}
}
let device = logits.device();
let mask_t = Tensor::<R>::from_slice(&mask, &[1, 1, vocab_size], device);
let masked = client.add(&last, &mask_t).map_err(Error::Numr)?;
let data: Vec<f32> = masked.to_vec();
Ok(argmax_f32(&data) as u32)
}
fn argmax_f32(xs: &[f32]) -> usize {
let mut best = 0usize;
let mut best_val = f32::NEG_INFINITY;
for (i, &v) in xs.iter().enumerate() {
if v > best_val {
best_val = v;
best = i;
}
}
best
}