use crate::whisper::config::WhisperConfig;
use trustformers_core::{
device::Device,
errors::{Result, TrustformersError},
layers::{layernorm::LayerNorm, linear::Linear},
tensor::Tensor,
traits::{Config, Layer},
};
fn make_contiguous(t: Tensor) -> Result<Tensor> {
let shape = t.shape().to_vec();
t.reshape(&shape)
}
pub struct Conv1d {
weight: Vec<f32>, bias: Option<Vec<f32>>,
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: usize,
padding: usize,
}
impl Conv1d {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: usize,
padding: usize,
use_bias: bool,
) -> Self {
let weight = vec![0.0f32; out_channels * in_channels * kernel_size];
let bias = if use_bias { Some(vec![0.0f32; out_channels]) } else { None };
Self {
weight,
bias,
in_channels,
out_channels,
kernel_size,
stride,
padding,
}
}
pub fn forward(
&self,
input: &[f32],
batch: usize,
time_in: usize,
) -> Result<(Vec<f32>, usize)> {
let time_out = (time_in + 2 * self.padding - self.kernel_size) / self.stride + 1;
let total = batch * self.out_channels * time_out;
let mut output = vec![0.0f32; total];
for b in 0..batch {
for oc in 0..self.out_channels {
for t_out in 0..time_out {
let t_start = t_out * self.stride;
let mut acc = 0.0f32;
for ic in 0..self.in_channels {
for k in 0..self.kernel_size {
let t_in = t_start + k;
if t_in >= self.padding && t_in < time_in + self.padding {
let actual_t = t_in - self.padding;
let inp_idx =
b * self.in_channels * time_in + ic * time_in + actual_t;
let w_idx = oc * self.in_channels * self.kernel_size
+ ic * self.kernel_size
+ k;
acc += input[inp_idx] * self.weight[w_idx];
}
}
}
if let Some(ref b_vec) = self.bias {
acc += b_vec[oc];
}
acc = gelu_approx(acc);
let out_idx = b * self.out_channels * time_out + oc * time_out + t_out;
output[out_idx] = acc;
}
}
}
Ok((output, time_out))
}
pub fn out_channels(&self) -> usize {
self.out_channels
}
}
fn gelu_approx(x: f32) -> f32 {
let c = 0.044715_f32;
let inner = (2.0_f32 / std::f32::consts::PI).sqrt() * (x + c * x.powi(3));
0.5 * x * (1.0 + inner.tanh())
}
pub struct WhisperEncoderLayer {
self_attn_q: Linear,
self_attn_k: Linear,
self_attn_v: Linear,
self_attn_out: Linear,
self_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
num_heads: usize,
head_dim: usize,
}
impl WhisperEncoderLayer {
pub fn new(config: &WhisperConfig) -> Result<Self> {
let d = config.d_model;
let heads = config.encoder_attention_heads;
let head_dim = config.encoder_head_dim();
Ok(Self {
self_attn_q: Linear::new(d, d, false),
self_attn_k: Linear::new(d, d, false),
self_attn_v: Linear::new(d, d, false),
self_attn_out: Linear::new(d, d, false),
self_attn_layer_norm: LayerNorm::new(vec![d], 1e-5)?,
fc1: Linear::new(d, config.encoder_ffn_dim, true),
fc2: Linear::new(config.encoder_ffn_dim, d, true),
final_layer_norm: LayerNorm::new(vec![d], 1e-5)?,
num_heads: heads,
head_dim,
})
}
pub fn new_with_device(config: &WhisperConfig, device: Device) -> Result<Self> {
let d = config.d_model;
let heads = config.encoder_attention_heads;
let head_dim = config.encoder_head_dim();
Ok(Self {
self_attn_q: Linear::new_with_device(d, d, false, device),
self_attn_k: Linear::new_with_device(d, d, false, device),
self_attn_v: Linear::new_with_device(d, d, false, device),
self_attn_out: Linear::new_with_device(d, d, false, device),
self_attn_layer_norm: LayerNorm::new_with_device(vec![d], 1e-5, device)?,
fc1: Linear::new_with_device(d, config.encoder_ffn_dim, true, device),
fc2: Linear::new_with_device(config.encoder_ffn_dim, d, true, device),
final_layer_norm: LayerNorm::new_with_device(vec![d], 1e-5, device)?,
num_heads: heads,
head_dim,
})
}
pub fn forward(&self, hidden_states: Tensor) -> Result<Tensor> {
let shape = hidden_states.shape().to_vec();
if shape.len() < 2 {
return Err(TrustformersError::shape_error(
"WhisperEncoderLayer expects at least 2D input".to_string(),
));
}
let batch_size = shape[0];
let seq_len = shape[1];
let d_model = shape[2];
let normed = make_contiguous(self.self_attn_layer_norm.forward(hidden_states.clone())?)?;
let q = self.self_attn_q.forward(normed.clone())?;
let k = self.self_attn_k.forward(normed.clone())?;
let v = self.self_attn_v.forward(normed)?;
let q = q.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])?;
let k = k.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])?;
let v = v.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])?;
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
let k_t = k.transpose(2, 3)?;
let scores = q.matmul(&k_t)?;
let scale = (self.head_dim as f32).sqrt();
let scores = scores.div_scalar(scale)?;
let attn_weights = scores.softmax(-1)?;
let attn_out = attn_weights.matmul(&v)?;
let attn_out = attn_out.transpose(1, 2)?;
let attn_out = attn_out.reshape(&[batch_size, seq_len, d_model])?;
let attn_out = self.self_attn_out.forward(attn_out)?;
let hidden_c = make_contiguous(hidden_states.clone())?;
let attn_c = make_contiguous(attn_out)?;
let residual1 = hidden_c.add(&attn_c)?;
let normed2 = make_contiguous(self.final_layer_norm.forward(residual1.clone())?)?;
let ff = self.fc1.forward(normed2)?;
let ff = ff.gelu()?;
let ff = self.fc2.forward(ff)?;
let residual1_c = make_contiguous(residual1)?;
let ff_c = make_contiguous(ff)?;
residual1_c.add(&ff_c)
}
}
pub struct WhisperAudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
positional_embedding: Vec<f32>,
layers: Vec<WhisperEncoderLayer>,
layer_norm: LayerNorm,
config: WhisperConfig,
}
impl WhisperAudioEncoder {
pub fn new(config: &WhisperConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &WhisperConfig, device: Device) -> Result<Self> {
let conv1 = Conv1d::new(config.num_mel_bins, config.d_model, 3, 1, 1, true);
let conv2 = Conv1d::new(config.d_model, config.d_model, 3, 2, 1, true);
let positional_embedding = vec![0.0f32; config.max_source_positions * config.d_model];
let mut layers = Vec::new();
for _ in 0..config.encoder_layers {
layers.push(WhisperEncoderLayer::new_with_device(config, device)?);
}
let layer_norm = LayerNorm::new_with_device(vec![config.d_model], 1e-5, device)?;
Ok(Self {
conv1,
conv2,
positional_embedding,
layers,
layer_norm,
config: config.clone(),
})
}
pub fn forward(&self, mel: &Tensor) -> Result<Tensor> {
let shape = mel.shape().to_vec();
if shape.len() != 3 {
return Err(TrustformersError::shape_error(
"WhisperAudioEncoder expects input of shape [batch, mel_bins, time]".to_string(),
));
}
let batch = shape[0];
let _mel_bins = shape[1];
let time_in = shape[2];
let mel_data = match mel {
Tensor::F32(arr) => arr.as_slice().ok_or_else(|| {
TrustformersError::shape_error("non-contiguous mel tensor".to_string())
})?,
_ => {
return Err(TrustformersError::invalid_input_simple(
"WhisperAudioEncoder expects F32 tensor".to_string(),
))
},
};
let (conv1_out, time1) = self.conv1.forward(mel_data, batch, time_in)?;
let (conv2_out, time2) = self.conv2.forward(&conv1_out, batch, time1)?;
let d = self.config.d_model;
let mut hidden = vec![0.0f32; batch * time2 * d];
for b in 0..batch {
for t in 0..time2 {
for c in 0..d {
hidden[b * time2 * d + t * d + c] = conv2_out[b * d * time2 + c * time2 + t];
}
}
}
let seq_len = time2.min(self.config.max_source_positions);
for b in 0..batch {
for t in 0..seq_len {
for c in 0..d {
hidden[b * time2 * d + t * d + c] += self.positional_embedding[t * d + c];
}
}
}
let mut hidden_states = Tensor::from_vec(hidden, &[batch, time2, d])?;
for layer in &self.layers {
hidden_states = layer.forward(hidden_states)?;
}
make_contiguous(self.layer_norm.forward(hidden_states)?)
}
}
pub struct WhisperDecoderLayer {
self_attn_q: Linear,
self_attn_k: Linear,
self_attn_v: Linear,
self_attn_out: Linear,
self_attn_layer_norm: LayerNorm,
encoder_attn_q: Linear,
encoder_attn_k: Linear,
encoder_attn_v: Linear,
encoder_attn_out: Linear,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
num_heads: usize,
head_dim: usize,
}
impl WhisperDecoderLayer {
pub fn new(config: &WhisperConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &WhisperConfig, device: Device) -> Result<Self> {
let d = config.d_model;
let heads = config.decoder_attention_heads;
let head_dim = config.decoder_head_dim();
Ok(Self {
self_attn_q: Linear::new_with_device(d, d, false, device),
self_attn_k: Linear::new_with_device(d, d, false, device),
self_attn_v: Linear::new_with_device(d, d, false, device),
self_attn_out: Linear::new_with_device(d, d, false, device),
self_attn_layer_norm: LayerNorm::new_with_device(vec![d], 1e-5, device)?,
encoder_attn_q: Linear::new_with_device(d, d, false, device),
encoder_attn_k: Linear::new_with_device(d, d, false, device),
encoder_attn_v: Linear::new_with_device(d, d, false, device),
encoder_attn_out: Linear::new_with_device(d, d, false, device),
encoder_attn_layer_norm: LayerNorm::new_with_device(vec![d], 1e-5, device)?,
fc1: Linear::new_with_device(d, config.decoder_ffn_dim, true, device),
fc2: Linear::new_with_device(config.decoder_ffn_dim, d, true, device),
final_layer_norm: LayerNorm::new_with_device(vec![d], 1e-5, device)?,
num_heads: heads,
head_dim,
})
}
fn multi_head_attention(
&self,
q_proj: &Linear,
k_proj: &Linear,
v_proj: &Linear,
o_proj: &Linear,
query: Tensor,
key_value: Tensor,
num_heads: usize,
head_dim: usize,
causal: bool,
) -> Result<Tensor> {
let q_shape = query.shape().to_vec();
let kv_shape = key_value.shape().to_vec();
let batch = q_shape[0];
let q_len = q_shape[1];
let kv_len = kv_shape[1];
let d_model = q_shape[2];
let q = q_proj.forward(query.clone())?;
let k = k_proj.forward(key_value.clone())?;
let v = v_proj.forward(key_value)?;
let q = q.reshape(&[batch, q_len, num_heads, head_dim])?.transpose(1, 2)?;
let k = k.reshape(&[batch, kv_len, num_heads, head_dim])?.transpose(1, 2)?;
let v = v.reshape(&[batch, kv_len, num_heads, head_dim])?.transpose(1, 2)?;
let k_t = k.transpose(2, 3)?;
let scores = q.matmul(&k_t)?;
let scale = (head_dim as f32).sqrt();
let scores = scores.div_scalar(scale)?;
let scores = if causal {
let mut mask_data = vec![0.0f32; q_len * kv_len];
for i in 0..q_len {
for j in (i + 1)..kv_len {
mask_data[i * kv_len + j] = f32::NEG_INFINITY;
}
}
let mask =
Tensor::from_vec(mask_data, &[q_len, kv_len])?.reshape(&[1, 1, q_len, kv_len])?;
scores.add(&mask)?
} else {
scores
};
let attn_weights = scores.softmax(-1)?;
let attn_out = attn_weights.matmul(&v)?;
let attn_out = attn_out.transpose(1, 2)?;
let attn_out = attn_out.reshape(&[batch, q_len, d_model])?;
o_proj.forward(attn_out)
}
pub fn forward(&self, hidden_states: Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
let num_heads = self.num_heads;
let head_dim = self.head_dim;
let normed = make_contiguous(self.self_attn_layer_norm.forward(hidden_states.clone())?)?;
let self_attn_out = self.multi_head_attention(
&self.self_attn_q,
&self.self_attn_k,
&self.self_attn_v,
&self.self_attn_out,
normed.clone(),
normed,
num_heads,
head_dim,
true,
)?;
let hidden_c = make_contiguous(hidden_states.clone())?;
let residual1 = hidden_c.add(&make_contiguous(self_attn_out)?)?;
let normed2 = make_contiguous(self.encoder_attn_layer_norm.forward(residual1.clone())?)?;
let enc_c = make_contiguous(encoder_hidden_states.clone())?;
let cross_attn_out = self.multi_head_attention(
&self.encoder_attn_q,
&self.encoder_attn_k,
&self.encoder_attn_v,
&self.encoder_attn_out,
normed2,
enc_c,
num_heads,
head_dim,
false,
)?;
let residual1_c = make_contiguous(residual1)?;
let residual2 = residual1_c.add(&make_contiguous(cross_attn_out)?)?;
let normed3 = make_contiguous(self.final_layer_norm.forward(residual2.clone())?)?;
let ff = self.fc1.forward(normed3)?;
let ff = ff.gelu()?;
let ff = self.fc2.forward(ff)?;
let residual2_c = make_contiguous(residual2)?;
residual2_c.add(&make_contiguous(ff)?)
}
}
pub struct WhisperDecoder {
embed_tokens: trustformers_core::layers::embedding::Embedding,
embed_positions: Vec<f32>,
layers: Vec<WhisperDecoderLayer>,
layer_norm: LayerNorm,
config: WhisperConfig,
}
impl WhisperDecoder {
pub fn new(config: &WhisperConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &WhisperConfig, device: Device) -> Result<Self> {
let embed_tokens = trustformers_core::layers::embedding::Embedding::new_with_device(
config.vocab_size,
config.d_model,
None,
device,
)?;
let embed_positions = vec![0.0f32; config.max_target_positions * config.d_model];
let mut layers = Vec::new();
for _ in 0..config.decoder_layers {
layers.push(WhisperDecoderLayer::new_with_device(config, device)?);
}
let layer_norm = LayerNorm::new_with_device(vec![config.d_model], 1e-5, device)?;
Ok(Self {
embed_tokens,
embed_positions,
layers,
layer_norm,
config: config.clone(),
})
}
pub fn forward(&self, input_ids: &[u32], encoder_hidden_states: &Tensor) -> Result<Tensor> {
let seq_len = input_ids.len();
let d = self.config.d_model;
let batch = 1usize;
let emb_tensor = self.embed_tokens.forward(input_ids.to_vec())?;
let emb_data = match &emb_tensor {
Tensor::F32(arr) => arr.to_owned().into_raw_vec_and_offset().0,
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 embedding".to_string(),
))
},
};
let actual_seq = seq_len.min(self.config.max_target_positions);
let mut hidden_data = emb_data;
for t in 0..actual_seq {
for c in 0..d {
hidden_data[t * d + c] += self.embed_positions[t * d + c];
}
}
let mut hidden_states = Tensor::from_vec(hidden_data, &[batch, seq_len, d])?;
for layer in &self.layers {
hidden_states = layer.forward(hidden_states, encoder_hidden_states)?;
}
make_contiguous(self.layer_norm.forward(hidden_states)?)
}
}
pub struct WhisperModel {
pub encoder: WhisperAudioEncoder,
pub decoder: WhisperDecoder,
pub config: WhisperConfig,
}
impl WhisperModel {
pub fn new(config: WhisperConfig) -> Result<Self> {
config.validate()?;
let encoder = WhisperAudioEncoder::new(&config)?;
let decoder = WhisperDecoder::new(&config)?;
Ok(Self {
encoder,
decoder,
config,
})
}
pub fn new_with_device(config: WhisperConfig, device: Device) -> Result<Self> {
config.validate()?;
let encoder = WhisperAudioEncoder::new_with_device(&config, device)?;
let decoder = WhisperDecoder::new_with_device(&config, device)?;
Ok(Self {
encoder,
decoder,
config,
})
}
pub fn encode(&self, mel: &Tensor) -> Result<Tensor> {
self.encoder.forward(mel)
}
pub fn forward(&self, mel: &Tensor, decoder_input_ids: &[u32]) -> Result<Tensor> {
let encoder_out = self.encoder.forward(mel)?;
self.decoder.forward(decoder_input_ids, &encoder_out)
}
}
pub struct WhisperForConditionalGeneration {
pub model: WhisperModel,
pub proj_out: Linear,
}
impl WhisperForConditionalGeneration {
pub fn new(config: WhisperConfig) -> Result<Self> {
let d_model = config.d_model;
let vocab_size = config.vocab_size;
let model = WhisperModel::new(config)?;
let proj_out = Linear::new(d_model, vocab_size, false);
Ok(Self { model, proj_out })
}
pub fn new_with_device(config: WhisperConfig, device: Device) -> Result<Self> {
let d_model = config.d_model;
let vocab_size = config.vocab_size;
let model = WhisperModel::new_with_device(config, device)?;
let proj_out = Linear::new_with_device(d_model, vocab_size, false, device);
Ok(Self { model, proj_out })
}
pub fn forward(&self, mel: &Tensor, decoder_input_ids: &[u32]) -> Result<Tensor> {
let hidden = self.model.forward(mel, decoder_input_ids)?;
self.proj_out.forward(hidden)
}
pub fn weight_map() -> Vec<(&'static str, &'static str)> {
vec![
("model.encoder.conv1.weight", "encoder.conv1.weight"),
("model.encoder.conv2.weight", "encoder.conv2.weight"),
(
"model.encoder.embed_positions.weight",
"encoder.positional_embedding",
),
(
"model.decoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
),
(
"model.decoder.embed_positions.weight",
"decoder.embed_positions",
),
("proj_out.weight", "proj_out.weight"),
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::whisper::config::WhisperConfig;
use trustformers_core::traits::Config;
fn lcg_next(state: &mut u64) -> f32 {
*state = state.wrapping_mul(6364136223846793005u64).wrapping_add(1442695040888963407u64);
(*state as f32) / (u64::MAX as f32)
}
fn lcg_vec(n: usize, seed: u64) -> Vec<f32> {
let mut s = seed;
(0..n).map(|_| lcg_next(&mut s) * 2.0 - 1.0).collect()
}
fn tiny_cfg() -> WhisperConfig {
WhisperConfig {
num_mel_bins: 4,
max_source_positions: 16,
encoder_layers: 1,
encoder_attention_heads: 2,
d_model: 8,
encoder_ffn_dim: 16,
vocab_size: 64,
max_target_positions: 16,
decoder_layers: 1,
decoder_attention_heads: 2,
decoder_ffn_dim: 16,
dropout: 0.0,
attention_dropout: 0.0,
activation_dropout: 0.0,
scale_embedding: false,
model_type: "whisper".to_string(),
}
}
#[test]
fn test_whisper_config_validate_tiny() {
tiny_cfg().validate().expect("tiny config should pass validation");
}
#[test]
fn test_whisper_config_encoder_head_dim() {
let cfg = tiny_cfg();
assert_eq!(
cfg.encoder_head_dim(),
cfg.d_model / cfg.encoder_attention_heads,
"encoder_head_dim should be d_model / encoder_attention_heads"
);
}
#[test]
fn test_whisper_config_decoder_head_dim() {
let cfg = tiny_cfg();
assert_eq!(
cfg.decoder_head_dim(),
cfg.d_model / cfg.decoder_attention_heads,
"decoder_head_dim should be d_model / decoder_attention_heads"
);
}
#[test]
fn test_whisper_config_default_mel_bins() {
let cfg = WhisperConfig::default();
assert_eq!(
cfg.num_mel_bins, 80,
"Whisper default num_mel_bins should be 80"
);
}
#[test]
fn test_whisper_config_validate_fails_zero_mel_bins() {
let mut cfg = tiny_cfg();
cfg.num_mel_bins = 0;
assert!(cfg.validate().is_err(), "should fail with num_mel_bins=0");
}
#[test]
fn test_whisper_config_validate_fails_d_model_not_divisible_by_enc_heads() {
let mut cfg = tiny_cfg();
cfg.encoder_attention_heads = 3; assert!(
cfg.validate().is_err(),
"should fail when d_model not divisible by encoder_attention_heads"
);
}
#[test]
fn test_conv1d_output_time_dimension_stride1() {
let conv = Conv1d::new(4, 8, 3, 1, 1, false);
let batch = 1;
let time_in = 10;
let input = lcg_vec(batch * 4 * time_in, 7);
let (_, time_out) = conv.forward(&input, batch, time_in).expect("Conv1d forward failed");
assert_eq!(
time_out, time_in,
"stride=1, same padding should preserve time dimension"
);
}
#[test]
fn test_conv1d_output_time_dimension_stride2() {
let conv = Conv1d::new(8, 8, 3, 2, 1, false);
let batch = 1;
let time_in = 10;
let input = lcg_vec(batch * 8 * time_in, 13);
let (_, time_out) = conv.forward(&input, batch, time_in).expect("Conv1d forward failed");
assert_eq!(
time_out, 5,
"stride=2 should halve time dimension (10 -> 5)"
);
}
#[test]
fn test_conv1d_out_channels_matches() {
let conv = Conv1d::new(4, 8, 3, 1, 1, false);
assert_eq!(
conv.out_channels(),
8,
"out_channels should match construction arg"
);
}
#[test]
fn test_conv1d_with_bias_forward() {
let conv = Conv1d::new(4, 8, 3, 1, 1, true);
let batch = 1;
let time_in = 4;
let input = lcg_vec(batch * 4 * time_in, 5);
let result = conv.forward(&input, batch, time_in);
assert!(result.is_ok(), "Conv1d with bias should succeed");
}
#[test]
fn test_gelu_approx_zero_input() {
assert!((gelu_approx(0.0) - 0.0).abs() < 1e-6, "gelu(0) should be 0");
}
#[test]
fn test_gelu_approx_positive_large_approaches_input() {
let x = 10.0f32;
assert!(
(gelu_approx(x) - x).abs() < 0.1,
"gelu(10) should be close to 10"
);
}
#[test]
fn test_encoder_forward_output_shape() {
let cfg = tiny_cfg();
let encoder = WhisperAudioEncoder::new(&cfg).expect("should create encoder");
let time_in = 32;
let mel_data = lcg_vec(cfg.num_mel_bins * time_in, 3);
let mel =
Tensor::from_vec(mel_data, &[1, cfg.num_mel_bins, time_in]).expect("build mel tensor");
match encoder.forward(&mel) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[0], 1, "batch dim should be 1");
assert_eq!(shape[1], 16, "encoder should halve time via conv2 stride=2");
assert_eq!(
shape[2], cfg.d_model,
"encoder output hidden should be d_model"
);
},
Err(_) => {
},
}
}
#[test]
fn test_encoder_forward_fails_on_2d_input() {
let cfg = tiny_cfg();
let encoder = WhisperAudioEncoder::new(&cfg).expect("should create encoder");
let bad_input = Tensor::from_vec(lcg_vec(32, 1), &[4, 8]).expect("build 2d tensor");
assert!(
encoder.forward(&bad_input).is_err(),
"encoder should reject non-3D input"
);
}
#[test]
fn test_decoder_layer_forward_output_shape() {
let cfg = tiny_cfg();
let layer = WhisperDecoderLayer::new(&cfg).expect("should create WhisperDecoderLayer");
let enc_time = 8usize;
let dec_time = 3usize;
let enc_data = lcg_vec(enc_time * cfg.d_model, 19);
let dec_data = lcg_vec(dec_time * cfg.d_model, 23);
let enc_states =
Tensor::from_vec(enc_data, &[1, enc_time, cfg.d_model]).expect("build enc tensor");
let dec_states =
Tensor::from_vec(dec_data, &[1, dec_time, cfg.d_model]).expect("build dec tensor");
match layer.forward(dec_states, &enc_states) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[0], 1, "batch preserved");
assert_eq!(shape[1], dec_time, "decoder seq len preserved");
assert_eq!(shape[2], cfg.d_model, "decoder hidden dim preserved");
},
Err(_) => {
},
}
}
#[test]
fn test_whisper_model_creation() {
let cfg = tiny_cfg();
WhisperModel::new(cfg).expect("should create WhisperModel");
}
#[test]
fn test_whisper_model_encode_output_shape() {
let cfg = tiny_cfg();
let model = WhisperModel::new(cfg.clone()).expect("should create WhisperModel");
let time_in = 32;
let mel_data = lcg_vec(cfg.num_mel_bins * time_in, 11);
let mel =
Tensor::from_vec(mel_data, &[1, cfg.num_mel_bins, time_in]).expect("build mel tensor");
match model.encode(&mel) {
Ok(enc_out) => {
let shape = enc_out.shape();
assert_eq!(
shape[2], cfg.d_model,
"encoder output hidden should be d_model"
);
},
Err(_) => {
},
}
}
#[test]
fn test_whisper_model_forward_decoder_output_shape() {
let cfg = tiny_cfg();
let model = WhisperModel::new(cfg.clone()).expect("should create WhisperModel");
let time_in = 16;
let mel_data = lcg_vec(cfg.num_mel_bins * time_in, 31);
let mel =
Tensor::from_vec(mel_data, &[1, cfg.num_mel_bins, time_in]).expect("build mel tensor");
let decoder_ids: Vec<u32> = vec![1, 2, 3];
match model.forward(&mel, &decoder_ids) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[0], 1, "batch preserved");
assert_eq!(shape[1], decoder_ids.len(), "seq len matches decoder input");
assert_eq!(shape[2], cfg.d_model, "hidden dim matches d_model");
},
Err(_) => {
},
}
}
#[test]
fn test_conditional_gen_forward_vocab_logits() {
let cfg = tiny_cfg();
let model = WhisperForConditionalGeneration::new(cfg.clone())
.expect("should create WhisperForConditionalGeneration");
let time_in = 16;
let mel_data = lcg_vec(cfg.num_mel_bins * time_in, 41);
let mel =
Tensor::from_vec(mel_data, &[1, cfg.num_mel_bins, time_in]).expect("build mel tensor");
let decoder_ids: Vec<u32> = vec![1, 2];
match model.forward(&mel, &decoder_ids) {
Ok(logits) => {
let shape = logits.shape();
assert_eq!(
shape[shape.len() - 1],
cfg.vocab_size,
"output last dim should equal vocab_size"
);
},
Err(_) => {
},
}
}
#[test]
fn test_conditional_gen_weight_map_non_empty() {
let wmap = WhisperForConditionalGeneration::weight_map();
assert!(
!wmap.is_empty(),
"weight_map should have at least one entry"
);
}
#[test]
fn test_whisper_tiny_config_layers() {
let cfg = WhisperConfig::whisper_tiny();
assert_eq!(
cfg.encoder_layers, 4,
"whisper_tiny should have 4 encoder layers"
);
assert_eq!(
cfg.decoder_layers, 4,
"whisper_tiny should have 4 decoder layers"
);
}
#[test]
fn test_whisper_base_config_d_model() {
let cfg = WhisperConfig::whisper_base();
assert_eq!(cfg.d_model, 512, "whisper_base d_model should be 512");
}
}