1use anyhow::{Result, bail};
19use rlx_core::safetensors_checkpoint::SafetensorsCheckpoint;
20use rlx_core::weight_map::WeightMap;
21use std::collections::{HashMap, HashSet};
22use std::path::{Path, PathBuf};
23use std::sync::Arc;
24
/// Owned copy of checkpoint tensors, keyed by tensor name.
///
/// NOTE(review): each value looks like `(flat f32 data, dimensions)` — confirm
/// against `WeightMap::take`.
pub type WeightSnapshot = HashMap<String, (Vec<f32>, Vec<usize>)>;

// --- Tensor-name prefixes used to slice the consolidated checkpoint. ---

/// Prefix shared by the backbone's per-layer tensors.
pub const PREFIX_BACKBONE: &str = "layers.";
/// Prefix shared by the multimodal audio-embedding tensors.
pub const PREFIX_MM_EMBED: &str = "mm_audio_embeddings.";
/// Prefix shared by the acoustic-transformer tensors.
pub const PREFIX_ACOUSTIC: &str = "acoustic_transformer.";
/// Prefix shared by the audio-tokenizer (codec) tensors.
pub const PREFIX_CODEC: &str = "audio_tokenizer.";
31
32#[derive(Clone)]
33pub struct VoxtralTtsWeightStore {
34 dir: PathBuf,
35 checkpoint: Arc<SafetensorsCheckpoint>,
36 keys: Arc<HashSet<String>>,
37}
38
39impl VoxtralTtsWeightStore {
40 pub fn open(model_dir: &Path) -> Result<Self> {
41 let dir = crate::config::resolve_model_dir(model_dir)?;
42 let consolidated = dir.join(crate::config::CONSOLIDATED_WEIGHTS);
43 if !consolidated.is_file() {
44 bail!(
45 "missing {} — run `just fetch-voxtral-tts`",
46 consolidated.display()
47 );
48 }
49 let checkpoint = Arc::new(SafetensorsCheckpoint::open(&dir)?);
50 let keys = Arc::new(checkpoint.keys().map(str::to_string).collect());
51 Ok(Self {
52 dir,
53 checkpoint,
54 keys,
55 })
56 }
57
58 pub fn model_dir(&self) -> &Path {
59 &self.dir
60 }
61
62 pub fn keys(&self) -> &std::collections::HashSet<String> {
63 &self.keys
64 }
65
66 pub fn load_prefix(&self, prefix: &str) -> Result<WeightMap> {
67 let want: HashSet<String> = self
68 .keys
69 .iter()
70 .filter(|k| k.starts_with(prefix))
71 .cloned()
72 .collect();
73 if want.is_empty() {
74 bail!("no tensors with prefix {prefix:?} under {:?}", self.dir);
75 }
76 self.checkpoint.load_selected(&want)
77 }
78
79 pub fn load_codec(&self) -> Result<WeightMap> {
80 self.load_prefix(PREFIX_CODEC)
81 }
82
83 pub fn load_acoustic(&self) -> Result<WeightMap> {
84 self.load_prefix(PREFIX_ACOUSTIC)
85 }
86
87 pub fn load_backbone(&self) -> Result<WeightMap> {
88 let mut want: HashSet<String> = self
89 .keys
90 .iter()
91 .filter(|k| k.starts_with(PREFIX_BACKBONE) || *k == "norm.weight")
92 .cloned()
93 .collect();
94 for k in self.keys.iter() {
95 if k.starts_with(PREFIX_MM_EMBED) || k.starts_with("tok_embeddings.") {
96 want.insert(k.clone());
97 }
98 }
99 if want.is_empty() {
100 bail!("no backbone tensors found under {:?}", self.dir);
101 }
102 self.checkpoint.load_selected(&want)
103 }
104
105 pub fn tensor_snapshot_for_embed(&self) -> Result<WeightSnapshot> {
106 let want: HashSet<String> = self
107 .keys
108 .iter()
109 .filter(|k| {
110 k.starts_with(PREFIX_MM_EMBED)
111 || k.starts_with("tok_embeddings.")
112 || k.as_str() == "norm.weight"
113 })
114 .cloned()
115 .collect();
116 if want.is_empty() {
117 bail!("no embedding tensors found under {:?}", self.dir);
118 }
119 let mut wm = self.checkpoint.load_selected(&want)?;
120 let keys: Vec<String> = wm.keys().map(str::to_string).collect();
121 let mut out = HashMap::with_capacity(keys.len());
122 for key in keys {
123 out.insert(key.clone(), wm.take(&key)?);
124 }
125 Ok(out)
126 }
127
128 pub fn tensor_snapshot_for_backbone(&self) -> Result<WeightSnapshot> {
129 let mut wm = self.load_backbone()?;
130 let keys: Vec<String> = wm.keys().map(str::to_string).collect();
131 let mut out = HashMap::with_capacity(keys.len());
132 for key in keys {
133 out.insert(key.clone(), wm.take(&key)?);
134 }
135 Ok(out)
136 }
137
138 pub fn tensor_snapshot(&self, prefix: &str) -> Result<WeightSnapshot> {
139 let mut wm = self.load_prefix(prefix)?;
140 let keys: Vec<String> = wm.keys().map(str::to_string).collect();
141 let mut out = HashMap::with_capacity(keys.len());
142 for key in keys {
143 out.insert(key.clone(), wm.take(&key)?);
144 }
145 Ok(out)
146 }
147}