Skip to main content

rlx_voxtral_tts/
load.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Mmap-backed weight access for `consolidated.safetensors`.
17
18use anyhow::{Result, bail};
19use rlx_core::safetensors_checkpoint::SafetensorsCheckpoint;
20use rlx_core::weight_map::WeightMap;
21use std::collections::{HashMap, HashSet};
22use std::path::{Path, PathBuf};
23use std::sync::Arc;
24
/// Owned tensors keyed by tensor name.
///
/// Each value pairs a flat `f32` buffer with a `Vec<usize>`
/// (presumably the tensor's dimensions — confirm against `WeightMap::take`).
pub type WeightSnapshot = HashMap<String, (Vec<f32>, Vec<usize>)>;
26
/// Key prefix of the tensors selected by `VoxtralTtsWeightStore::load_codec`.
pub const PREFIX_CODEC: &str = "audio_tokenizer.";
/// Key prefix of the tensors selected by `VoxtralTtsWeightStore::load_acoustic`.
pub const PREFIX_ACOUSTIC: &str = "acoustic_transformer.";
/// Key prefix of the per-layer tensors selected by `VoxtralTtsWeightStore::load_backbone`.
pub const PREFIX_BACKBONE: &str = "layers.";
/// Key prefix of the `mm_audio_embeddings.` tensors, which are included in both the
/// backbone and the embedding selections.
pub const PREFIX_MM_EMBED: &str = "mm_audio_embeddings.";
31
32#[derive(Clone)]
33pub struct VoxtralTtsWeightStore {
34    dir: PathBuf,
35    checkpoint: Arc<SafetensorsCheckpoint>,
36    keys: Arc<HashSet<String>>,
37}
38
39impl VoxtralTtsWeightStore {
40    pub fn open(model_dir: &Path) -> Result<Self> {
41        let dir = crate::config::resolve_model_dir(model_dir)?;
42        let consolidated = dir.join(crate::config::CONSOLIDATED_WEIGHTS);
43        if !consolidated.is_file() {
44            bail!(
45                "missing {} — run `just fetch-voxtral-tts`",
46                consolidated.display()
47            );
48        }
49        let checkpoint = Arc::new(SafetensorsCheckpoint::open(&dir)?);
50        let keys = Arc::new(checkpoint.keys().map(str::to_string).collect());
51        Ok(Self {
52            dir,
53            checkpoint,
54            keys,
55        })
56    }
57
58    pub fn model_dir(&self) -> &Path {
59        &self.dir
60    }
61
62    pub fn keys(&self) -> &std::collections::HashSet<String> {
63        &self.keys
64    }
65
66    pub fn load_prefix(&self, prefix: &str) -> Result<WeightMap> {
67        let want: HashSet<String> = self
68            .keys
69            .iter()
70            .filter(|k| k.starts_with(prefix))
71            .cloned()
72            .collect();
73        if want.is_empty() {
74            bail!("no tensors with prefix {prefix:?} under {:?}", self.dir);
75        }
76        self.checkpoint.load_selected(&want)
77    }
78
79    pub fn load_codec(&self) -> Result<WeightMap> {
80        self.load_prefix(PREFIX_CODEC)
81    }
82
83    pub fn load_acoustic(&self) -> Result<WeightMap> {
84        self.load_prefix(PREFIX_ACOUSTIC)
85    }
86
87    pub fn load_backbone(&self) -> Result<WeightMap> {
88        let mut want: HashSet<String> = self
89            .keys
90            .iter()
91            .filter(|k| k.starts_with(PREFIX_BACKBONE) || *k == "norm.weight")
92            .cloned()
93            .collect();
94        for k in self.keys.iter() {
95            if k.starts_with(PREFIX_MM_EMBED) || k.starts_with("tok_embeddings.") {
96                want.insert(k.clone());
97            }
98        }
99        if want.is_empty() {
100            bail!("no backbone tensors found under {:?}", self.dir);
101        }
102        self.checkpoint.load_selected(&want)
103    }
104
105    pub fn tensor_snapshot_for_embed(&self) -> Result<WeightSnapshot> {
106        let want: HashSet<String> = self
107            .keys
108            .iter()
109            .filter(|k| {
110                k.starts_with(PREFIX_MM_EMBED)
111                    || k.starts_with("tok_embeddings.")
112                    || k.as_str() == "norm.weight"
113            })
114            .cloned()
115            .collect();
116        if want.is_empty() {
117            bail!("no embedding tensors found under {:?}", self.dir);
118        }
119        let mut wm = self.checkpoint.load_selected(&want)?;
120        let keys: Vec<String> = wm.keys().map(str::to_string).collect();
121        let mut out = HashMap::with_capacity(keys.len());
122        for key in keys {
123            out.insert(key.clone(), wm.take(&key)?);
124        }
125        Ok(out)
126    }
127
128    pub fn tensor_snapshot_for_backbone(&self) -> Result<WeightSnapshot> {
129        let mut wm = self.load_backbone()?;
130        let keys: Vec<String> = wm.keys().map(str::to_string).collect();
131        let mut out = HashMap::with_capacity(keys.len());
132        for key in keys {
133            out.insert(key.clone(), wm.take(&key)?);
134        }
135        Ok(out)
136    }
137
138    pub fn tensor_snapshot(&self, prefix: &str) -> Result<WeightSnapshot> {
139        let mut wm = self.load_prefix(prefix)?;
140        let keys: Vec<String> = wm.keys().map(str::to_string).collect();
141        let mut out = HashMap::with_capacity(keys.len());
142        for key in keys {
143            out.insert(key.clone(), wm.take(&key)?);
144        }
145        Ok(out)
146    }
147}