1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use std::path::PathBuf;
use crate::model::{config::WhisperConfig, decoder::WhisperDecoder, encoder::WhisperEncoder};
use burn::{
module::Module,
tensor::{Int, Tensor, backend::Backend},
};
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
/// Top-level Whisper model.
///
/// Wires the audio encoder and text decoder together:
///
/// mel spectrogram → encoder → hidden states
/// token ids + hidden states → decoder → logits
#[derive(Module, Debug)]
pub struct WhisperModel<B: Backend> {
    /// Audio encoder: maps a log-mel spectrogram to hidden states.
    pub encoder: WhisperEncoder<B>,
    /// Text decoder: maps token ids (cross-attending to the encoder
    /// output) to vocabulary logits.
    pub decoder: WhisperDecoder<B>,
}
impl<B: Backend> WhisperModel<B> {
    /// Construct a freshly-initialized `WhisperModel` from a config.
    ///
    /// Both sub-modules are created with random/default initialization on
    /// `device`; use [`WhisperModel::from_safetensors`] to load pre-trained
    /// weights instead.
    pub fn new(config: &WhisperConfig, device: &B::Device) -> Self {
        Self {
            encoder: WhisperEncoder::new(
                config.num_mel_bins,
                config.d_model,
                config.num_heads,
                config.encoder_layers,
                device,
            ),
            decoder: WhisperDecoder::new(
                config.vocab_size,
                config.d_model,
                config.num_heads,
                config.decoder_layers,
                config.max_target_positions,
                device,
            ),
        }
    }

    /// Full forward pass: mel spectrogram → token logits.
    ///
    /// # Arguments
    /// * `mel` - Log-mel spectrogram [batch, n_mels, time_steps]
    /// * `tokens` - Decoder input token IDs [batch, tgt_len]
    ///
    /// # Returns
    /// Logits over vocabulary [batch, tgt_len, vocab_size]
    pub fn forward(&self, mel: Tensor<B, 3>, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let encoder_out = self.encoder.forward(mel);
        self.decoder.forward(tokens, encoder_out)
    }

    /// Encode audio only (useful to cache encoder output during inference).
    pub fn encode(&self, mel: Tensor<B, 3>) -> Tensor<B, 3> {
        self.encoder.forward(mel)
    }

    /// Decode one step given cached encoder output.
    pub fn decode(&self, tokens: Tensor<B, 2, Int>, encoder_out: Tensor<B, 3>) -> Tensor<B, 3> {
        self.decoder.forward(tokens, encoder_out)
    }

    /// Load pre-trained weights from a HuggingFace safetensors file.
    ///
    /// Remaps PyTorch/HF key names to the Burn module paths used by this
    /// crate (e.g. `model.encoder.layers.N` → `encoder.blocks.N`) before
    /// applying the snapshot onto a freshly constructed model.
    ///
    /// # Errors
    /// Returns an error if the safetensors file cannot be read or its
    /// tensors do not match this model's structure.
    pub fn from_safetensors(
        config: &WhisperConfig,
        path: impl Into<PathBuf>,
        device: &B::Device,
    ) -> anyhow::Result<Self> {
        let mut model = Self::new(config, device);
        let mut store = SafetensorsStore::from_file(path.into())
            .with_from_adapter(PyTorchToBurnAdapter)
            // Remove "model." prefix
            .with_key_remapping("^model\\.", "")
            // encoder.layers.N -> encoder.blocks.N
            .with_key_remapping("encoder\\.layers\\.(\\d+)\\.", "encoder.blocks.$1.")
            // decoder.layers.N -> decoder.blocks.N
            .with_key_remapping("decoder\\.layers\\.(\\d+)\\.", "decoder.blocks.$1.")
            // encoder.layer_norm -> encoder.norm
            .with_key_remapping("encoder\\.layer_norm\\.", "encoder.norm.")
            // decoder.layer_norm -> decoder.norm
            .with_key_remapping("decoder\\.layer_norm\\.", "decoder.norm.")
            // self_attn_layer_norm -> norm1
            .with_key_remapping("\\.self_attn_layer_norm\\.", ".norm1.")
            // encoder_attn_layer_norm -> norm2 (decoder blocks only; the
            // encoder has no cross-attention, so its ffn norm takes norm2 below)
            .with_key_remapping("\\.encoder_attn_layer_norm\\.", ".norm2.")
            // encoder blocks: final_layer_norm -> norm2
            .with_key_remapping("(encoder\\.blocks\\.\\d+)\\.final_layer_norm", "$1.norm2")
            // decoder blocks: final_layer_norm -> norm3
            .with_key_remapping("(decoder\\.blocks\\.\\d+)\\.final_layer_norm", "$1.norm3")
            // encoder_attn -> cross_attn
            .with_key_remapping("\\.encoder_attn\\.", ".cross_attn.")
            // fc1/fc2 -> ffn.fc1/ffn.fc2
            .with_key_remapping("\\.fc1\\.", ".ffn.fc1.")
            .with_key_remapping("\\.fc2\\.", ".ffn.fc2.")
            // embed_tokens -> token_embedding
            .with_key_remapping("decoder\\.embed_tokens\\.", "decoder.token_embedding.")
            // embed_positions -> positional_embedding
            .with_key_remapping(
                "decoder\\.embed_positions\\.",
                "decoder.positional_embedding.",
            );
        // Propagate load failures instead of panicking: this function already
        // returns `anyhow::Result`, so a missing/corrupt file or a shape
        // mismatch should be a recoverable error for the caller, not a crash.
        model
            .load_from(&mut store)
            .map_err(|e| anyhow::anyhow!("failed to load safetensors weights: {e:?}"))?;
        Ok(model)
    }
}