1use std::path::{Path, PathBuf};
19
20use anyhow::{Context, Result, bail};
21
22use crate::config::{DEFAULT_HF_REPO, ModelConfig};
23
/// Relative model directory that [`default_model_dir`] checks after the environment
/// overrides. Being relative, it is resolved against the process working directory.
pub const DEFAULT_LOCAL_DIR: &str = ".cache/kittentts-mini-0.8";
26
/// File locations for one KittenTTS model directory.
///
/// [`ModelLayout::resolve`] constructs this from a directory containing `config.json`.
#[derive(Debug, Clone)]
pub struct ModelLayout {
    /// Model directory. `resolve` canonicalizes it when that succeeds and otherwise keeps
    /// the path it was given.
    pub dir: PathBuf,
    /// Configuration loaded from `dir` via `ModelConfig::load_from_dir`.
    pub config: ModelConfig,
    /// `dir` joined with `config.model_file`. It may not exist: `resolve` accepts a missing
    /// ONNX file when `native_weights` is `Some`.
    pub onnx: PathBuf,
    /// `dir` joined with `config.voices`. `resolve` checks that it is a file.
    pub voices: PathBuf,
    /// Directory holding native weights, as found by [`find_native_weights`], if any.
    pub native_weights: Option<PathBuf>,
}
36
37impl ModelLayout {
38 pub fn resolve(model_dir: &Path) -> Result<Self> {
39 let dir = model_dir
40 .canonicalize()
41 .unwrap_or_else(|_| model_dir.to_path_buf());
42 let config = ModelConfig::load_from_dir(&dir)?;
43 let onnx = dir.join(&config.model_file);
44 let voices = dir.join(&config.voices);
45 let native_weights = find_native_weights(&dir);
46 if !onnx.is_file() && native_weights.is_none() {
47 bail!(
48 "ONNX model missing: {}\n\
49 Fetch weights: `just fetch-kittentts` or set RLX_KITTENTTS_DIR",
50 onnx.display()
51 );
52 }
53 if !voices.is_file() {
54 bail!("voices NPZ missing: {}", voices.display());
55 }
56 Ok(Self {
57 native_weights,
58 dir,
59 config,
60 onnx,
61 voices,
62 })
63 }
64
65 pub fn voice_names(&self) -> Result<Vec<String>> {
67 let raw = crate::load_npz(&self.voices)
68 .with_context(|| format!("load voices {}", self.voices.display()))?;
69 let mut names: Vec<String> = raw.into_keys().collect();
70 for alias in self.config.voice_aliases.keys() {
71 if !names.iter().any(|n| n == alias) {
72 names.push(alias.clone());
73 }
74 }
75 names.sort();
76 Ok(names)
77 }
78}
79
80pub fn default_model_dir() -> Result<PathBuf> {
82 for key in ["RLX_KITTENTTS_DIR", "KITTENTTS_MODEL_DIR"] {
83 if let Ok(raw) = std::env::var(key) {
84 let p = PathBuf::from(&raw);
85 if layout_exists(&p) {
86 return Ok(p);
87 }
88 }
89 }
90
91 let local = PathBuf::from(DEFAULT_LOCAL_DIR);
92 if layout_exists(&local) {
93 return Ok(local);
94 }
95
96 #[cfg(feature = "hf-download")]
97 if let Ok(p) = hf_snapshot_dir(DEFAULT_HF_REPO) {
98 return Ok(p);
99 }
100
101 if let Some(p) = model_dir_from_bundle_manifest() {
102 return Ok(p);
103 }
104
105 if let Some(p) = hf_hub_cache_snapshot(DEFAULT_HF_REPO) {
106 return Ok(p);
107 }
108
109 bail!(
110 "KittenTTS weights not found.\n\
111 Quick start:\n\
112 just fetch-kittentts\n\
113 just kittentts-demo\n\
114 Or set RLX_KITTENTTS_DIR to a directory containing config.json, \
115 the ONNX file, and voices.npz."
116 )
117}
118
/// Returns `true` when `dir` contains a regular `config.json` file (symlinks are
/// followed). This is the only check used to recognise a model directory; the ONNX and
/// voices files are not inspected here.
pub fn layout_exists(dir: &Path) -> bool {
    let config_json = dir.join("config.json");
    config_json.is_file()
}
122
/// Reads environment variable `key` as a path. An unset or empty variable counts as absent.
///
/// `var_os` is used so that non-UTF-8 paths, which are valid on Unix, are not discarded.
fn env_path(key: &str) -> Option<PathBuf> {
    std::env::var_os(key)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}

/// Best-effort user home directory: `HOME`, then `USERPROFILE`, else the current
/// directory (`.`).
fn home_dir() -> PathBuf {
    env_path("HOME")
        .or_else(|| env_path("USERPROFILE"))
        .unwrap_or_else(|| PathBuf::from("."))
}

/// Root directory of the Hugging Face hub cache (the folder holding `models--*`).
///
/// This follows the precedence of the Python `huggingface_hub` library, so caches it
/// populated are found:
/// 1. `HF_HUB_CACHE`
/// 2. `HUGGINGFACE_HUB_CACHE` (legacy name)
/// 3. `$HF_HOME/hub`
/// 4. `$XDG_CACHE_HOME/huggingface/hub`
/// 5. `~/.cache/huggingface/hub`
pub fn hf_hub_root() -> PathBuf {
    if let Some(cache) = env_path("HF_HUB_CACHE").or_else(|| env_path("HUGGINGFACE_HUB_CACHE")) {
        return cache;
    }
    if let Some(hf_home) = env_path("HF_HOME") {
        return hf_home.join("hub");
    }
    if let Some(xdg_cache) = env_path("XDG_CACHE_HOME") {
        return xdg_cache.join("huggingface").join("hub");
    }
    home_dir().join(".cache").join("huggingface").join("hub")
}
139
140fn hf_hub_cache_snapshot(repo_id: &str) -> Option<PathBuf> {
142 let cache_name = format!("models--{}", repo_id.replace('/', "--"));
143 let snapshots = hf_hub_root().join(cache_name).join("snapshots");
144 let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
145 .ok()
146 .into_iter()
147 .flatten()
148 .flatten()
149 .map(|e| e.path())
150 .filter(|snap| layout_exists(snap))
151 .collect();
152 candidates.sort();
153 candidates.into_iter().last()
154}
155
156fn model_dir_from_bundle_manifest() -> Option<PathBuf> {
157 let manifest = Path::new(env!("CARGO_MANIFEST_DIR"))
158 .join("../kitten_tts_mini_rlx/weights/rlx_bundle/manifest.json");
159 let data = std::fs::read_to_string(&manifest).ok()?;
160 let v: serde_json::Value = serde_json::from_str(&data).ok()?;
161 let onnx = v.get("source_onnx")?.as_str()?;
162 let dir = PathBuf::from(onnx).parent()?.to_path_buf();
163 layout_exists(&dir).then_some(dir)
164}
165
/// Locates an RLX ONNX bundle, i.e. a directory containing `graph.json`.
///
/// Candidates are checked in order, and the first hit wins:
/// 1. `RLX_ONNX_BUNDLE`, then `KITTEN_RLX_BUNDLE`. Each override is tried on its own, so
///    an unusable first override does not hide a valid second one.
/// 2. `weights_dir/rlx_bundle`.
///
/// Returns `None` when none of them contains `graph.json`.
pub fn find_rlx_bundle(weights_dir: &Path) -> Option<PathBuf> {
    for key in ["RLX_ONNX_BUNDLE", "KITTEN_RLX_BUNDLE"] {
        // `var_os` keeps non-UTF-8 paths usable instead of treating them as unset.
        if let Some(raw) = std::env::var_os(key) {
            let p = PathBuf::from(raw);
            if p.join("graph.json").is_file() {
                return Some(p);
            }
        }
    }
    let in_dir = weights_dir.join("rlx_bundle");
    in_dir.join("graph.json").is_file().then_some(in_dir)
}
182
183pub fn default_native_weights_dir() -> Option<PathBuf> {
185 if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
186 let p = PathBuf::from(raw);
187 if p.join("model.safetensors").is_file() || p.join("rlx_bundle/graph.json").is_file() {
188 return Some(p);
189 }
190 }
191 let sibling = Path::new(env!("CARGO_MANIFEST_DIR")).join("../kitten_tts_mini_rlx/weights");
192 if sibling.join("model.safetensors").is_file() && find_rlx_bundle(&sibling).is_some() {
193 return Some(sibling);
194 }
195 if find_rlx_bundle(&sibling).is_some() {
196 return Some(sibling);
197 }
198 None
199}
200
201pub fn find_native_weights(model_dir: &Path) -> Option<PathBuf> {
203 if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
204 let p = PathBuf::from(raw);
205 if p.join("model.safetensors").is_file() || p.join("rlx_bundle/graph.json").is_file() {
206 return Some(p);
207 }
208 }
209 if model_dir.join("model.safetensors").is_file() {
210 return Some(model_dir.to_path_buf());
211 }
212 if model_dir.join("rlx_bundle/graph.json").is_file() {
213 return Some(model_dir.to_path_buf());
214 }
215 default_native_weights_dir()
216}
217
218#[cfg(feature = "hf-download")]
219pub fn hf_snapshot_dir(repo_id: &str) -> Result<PathBuf> {
220 let api = hf_hub::api::sync::ApiBuilder::new()
221 .with_cache_dir(hf_hub_root())
222 .build()
223 .context("hf_hub ApiBuilder")?;
224 let repo = api.model(normalize_repo_id(repo_id));
225 let config = repo.get("config.json").with_context(|| {
226 format!(
227 "locate {repo_id} in Hugging Face cache under {}\n\
228 Download once: `just fetch-kittentts`",
229 hf_hub_root().display()
230 )
231 })?;
232 config
233 .parent()
234 .map(Path::to_path_buf)
235 .context("config.json has no parent (snapshot dir)")
236}
237
238#[cfg(not(feature = "hf-download"))]
239pub fn hf_snapshot_dir(_repo_id: &str) -> Result<PathBuf> {
240 bail!("rebuild with `--features hf-download` on rlx-kittentts")
241}
242
/// Qualifies a bare Hugging Face repo name with the `KittenML` organisation.
///
/// Ids that already contain a `/` (`org/name`) are returned unchanged.
pub fn normalize_repo_id(repo_id: &str) -> String {
    if !repo_id.contains('/') {
        return format!("KittenML/{repo_id}");
    }
    repo_id.to_owned()
}