Skip to main content

rlx_whisper/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Checkpoint key prefixes — HuggingFace-style keys (`model.encoder.*`) vs. OpenAI `.pt`-style keys (`encoder.*`).
17
18use rlx_core::weight_map::WeightMap;
19
/// Key-naming scheme used to address Whisper tensors inside a checkpoint.
///
/// Stores the dotted prefixes under which encoder and decoder tensors live,
/// together with a flag choosing between the two supported spellings of the
/// decoder embedding keys.
#[derive(Clone, Debug)]
pub struct WhisperWeightPrefix {
    /// Prefix placed (followed by `.`) in front of every encoder tensor key.
    pub encoder: String,
    /// Prefix placed (followed by `.`) in front of every decoder tensor key.
    pub decoder: String,
    /// `true` selects the `embed_tokens.weight` / `embed_positions.weight`
    /// key names; `false` selects `token_embedding.weight` /
    /// `positional_embedding`.
    pub hf_embed_names: bool,
}
26
27impl WhisperWeightPrefix {
28    pub fn detect(weights: &WeightMap) -> Self {
29        let (encoder, decoder) = if weights.has("model.encoder.conv1.weight") {
30            ("model.encoder".into(), "model.decoder".into())
31        } else if weights.has("encoder.conv1.weight") {
32            ("encoder".into(), "decoder".into())
33        } else {
34            ("model.encoder".into(), "model.decoder".into())
35        };
36        let hf_embed_names = weights.has(&format!("{decoder}.embed_tokens.weight"));
37        Self {
38            encoder,
39            decoder,
40            hf_embed_names,
41        }
42    }
43
44    pub fn enc_layer(&self, i: usize, suffix: &str) -> String {
45        format!("{}.layers.{i}.{suffix}", self.encoder)
46    }
47
48    pub fn dec_layer(&self, i: usize, suffix: &str) -> String {
49        format!("{}.layers.{i}.{suffix}", self.decoder)
50    }
51
52    pub fn enc_conv1_w(&self) -> String {
53        format!("{}.conv1.weight", self.encoder)
54    }
55
56    pub fn enc_conv1_b(&self) -> String {
57        format!("{}.conv1.bias", self.encoder)
58    }
59
60    pub fn enc_conv2_w(&self) -> String {
61        format!("{}.conv2.weight", self.encoder)
62    }
63
64    pub fn enc_conv2_b(&self) -> String {
65        format!("{}.conv2.bias", self.encoder)
66    }
67
68    pub fn enc_ln_post_w(&self) -> String {
69        format!("{}.layer_norm.weight", self.encoder)
70    }
71
72    pub fn enc_ln_post_b(&self) -> String {
73        format!("{}.layer_norm.bias", self.encoder)
74    }
75
76    pub fn dec_embed_tokens(&self) -> String {
77        if self.hf_embed_names {
78            format!("{}.embed_tokens.weight", self.decoder)
79        } else {
80            format!("{}.token_embedding.weight", self.decoder)
81        }
82    }
83
84    pub fn dec_embed_positions(&self) -> String {
85        if self.hf_embed_names {
86            format!("{}.embed_positions.weight", self.decoder)
87        } else {
88            format!("{}.positional_embedding", self.decoder)
89        }
90    }
91
92    pub fn dec_ln_w(&self) -> String {
93        format!("{}.layer_norm.weight", self.decoder)
94    }
95
96    pub fn dec_ln_b(&self) -> String {
97        format!("{}.layer_norm.bias", self.decoder)
98    }
99}