1use rlx_core::weight_map::WeightMap;
19
20#[derive(Debug, Clone)]
21pub struct WhisperWeightPrefix {
22 pub encoder: String,
23 pub decoder: String,
24 pub hf_embed_names: bool,
25}
26
27impl WhisperWeightPrefix {
28 pub fn detect(weights: &WeightMap) -> Self {
29 let (encoder, decoder) = if weights.has("model.encoder.conv1.weight") {
30 ("model.encoder".into(), "model.decoder".into())
31 } else if weights.has("encoder.conv1.weight") {
32 ("encoder".into(), "decoder".into())
33 } else {
34 ("model.encoder".into(), "model.decoder".into())
35 };
36 let hf_embed_names = weights.has(&format!("{decoder}.embed_tokens.weight"));
37 Self {
38 encoder,
39 decoder,
40 hf_embed_names,
41 }
42 }
43
44 pub fn enc_layer(&self, i: usize, suffix: &str) -> String {
45 format!("{}.layers.{i}.{suffix}", self.encoder)
46 }
47
48 pub fn dec_layer(&self, i: usize, suffix: &str) -> String {
49 format!("{}.layers.{i}.{suffix}", self.decoder)
50 }
51
52 pub fn enc_conv1_w(&self) -> String {
53 format!("{}.conv1.weight", self.encoder)
54 }
55
56 pub fn enc_conv1_b(&self) -> String {
57 format!("{}.conv1.bias", self.encoder)
58 }
59
60 pub fn enc_conv2_w(&self) -> String {
61 format!("{}.conv2.weight", self.encoder)
62 }
63
64 pub fn enc_conv2_b(&self) -> String {
65 format!("{}.conv2.bias", self.encoder)
66 }
67
68 pub fn enc_ln_post_w(&self) -> String {
69 format!("{}.layer_norm.weight", self.encoder)
70 }
71
72 pub fn enc_ln_post_b(&self) -> String {
73 format!("{}.layer_norm.bias", self.encoder)
74 }
75
76 pub fn dec_embed_tokens(&self) -> String {
77 if self.hf_embed_names {
78 format!("{}.embed_tokens.weight", self.decoder)
79 } else {
80 format!("{}.token_embedding.weight", self.decoder)
81 }
82 }
83
84 pub fn dec_embed_positions(&self) -> String {
85 if self.hf_embed_names {
86 format!("{}.embed_positions.weight", self.decoder)
87 } else {
88 format!("{}.positional_embedding", self.decoder)
89 }
90 }
91
92 pub fn dec_ln_w(&self) -> String {
93 format!("{}.layer_norm.weight", self.decoder)
94 }
95
96 pub fn dec_ln_b(&self) -> String {
97 format!("{}.layer_norm.bias", self.decoder)
98 }
99}