1use std::path::{Path, PathBuf};
19
20use anyhow::{Context, Result, bail};
21
22use crate::config::{DEFAULT_HF_REPO, ModelConfig};
23
/// Model directory probed by [`default_model_dir`] right after the environment overrides.
///
/// The path is relative, so it is resolved against the current working directory.
pub const DEFAULT_LOCAL_DIR: &str = ".cache/kittentts-mini-0.8";
26
/// Resolved on-disk layout of a KittenTTS model directory.
///
/// Normally built with [`ModelLayout::resolve`]. That function checks that the ONNX model
/// and the voices archive exist as regular files before returning.
#[derive(Debug, Clone)]
pub struct ModelLayout {
    /// Model directory. `resolve` canonicalizes it when possible and otherwise keeps it as given.
    pub dir: PathBuf,
    /// Model configuration, loaded by `resolve` through `ModelConfig::load_from_dir(dir)`.
    pub config: ModelConfig,
    /// ONNX model path. `resolve` sets it to `dir` joined with `config.model_file`.
    pub onnx: PathBuf,
    /// Voices NPZ archive path. `resolve` sets it to `dir` joined with `config.voices`.
    pub voices: PathBuf,
    /// Directory containing `model.safetensors`, when `find_native_weights` found one.
    pub native_weights: Option<PathBuf>,
}
36
37impl ModelLayout {
38 pub fn resolve(model_dir: &Path) -> Result<Self> {
39 let dir = model_dir
40 .canonicalize()
41 .unwrap_or_else(|_| model_dir.to_path_buf());
42 let config = ModelConfig::load_from_dir(&dir)?;
43 let onnx = dir.join(&config.model_file);
44 let voices = dir.join(&config.voices);
45 if !onnx.is_file() {
46 bail!(
47 "ONNX model missing: {}\n\
48 Fetch weights: `just fetch-kittentts` or set RLX_KITTENTTS_DIR",
49 onnx.display()
50 );
51 }
52 if !voices.is_file() {
53 bail!("voices NPZ missing: {}", voices.display());
54 }
55 Ok(Self {
56 native_weights: find_native_weights(&dir),
57 dir,
58 config,
59 onnx,
60 voices,
61 })
62 }
63
64 pub fn voice_names(&self) -> Result<Vec<String>> {
66 let raw = crate::load_npz(&self.voices)
67 .with_context(|| format!("load voices {}", self.voices.display()))?;
68 let mut names: Vec<String> = raw.into_keys().collect();
69 for alias in self.config.voice_aliases.keys() {
70 if !names.iter().any(|n| n == alias) {
71 names.push(alias.clone());
72 }
73 }
74 names.sort();
75 Ok(names)
76 }
77}
78
79pub fn default_model_dir() -> Result<PathBuf> {
81 for key in ["RLX_KITTENTTS_DIR", "KITTENTTS_MODEL_DIR"] {
82 if let Ok(raw) = std::env::var(key) {
83 let p = PathBuf::from(&raw);
84 if layout_exists(&p) {
85 return Ok(p);
86 }
87 }
88 }
89
90 let local = PathBuf::from(DEFAULT_LOCAL_DIR);
91 if layout_exists(&local) {
92 return Ok(local);
93 }
94
95 #[cfg(feature = "hf-download")]
96 if let Ok(p) = hf_snapshot_dir(DEFAULT_HF_REPO) {
97 return Ok(p);
98 }
99
100 if let Some(p) = model_dir_from_bundle_manifest() {
101 return Ok(p);
102 }
103
104 if let Some(p) = hf_hub_cache_snapshot(DEFAULT_HF_REPO) {
105 return Ok(p);
106 }
107
108 bail!(
109 "KittenTTS weights not found.\n\
110 Quick start:\n\
111 just fetch-kittentts\n\
112 just kittentts-demo\n\
113 Or set RLX_KITTENTTS_DIR to a directory containing config.json, \
114 the ONNX file, and voices.npz."
115 )
116}
117
/// Reports whether `dir` looks like a model directory, meaning it holds a regular
/// `config.json` file.
///
/// Missing or unreadable paths count as "no".
pub fn layout_exists(dir: &Path) -> bool {
    let config_json = dir.join("config.json");
    std::fs::metadata(config_json)
        .map(|meta| meta.is_file())
        .unwrap_or(false)
}
121
/// Best-effort home directory.
///
/// Uses `HOME`, then `USERPROFILE`, and falls back to the current directory (`.`).
fn home_dir() -> PathBuf {
    ["HOME", "USERPROFILE"]
        .iter()
        .find_map(|key| std::env::var(key).ok())
        .map_or_else(|| PathBuf::from("."), PathBuf::from)
}
128
129pub fn hf_hub_root() -> PathBuf {
130 if let Ok(h) = std::env::var("HF_HOME") {
131 return PathBuf::from(h).join("hub");
132 }
133 if let Ok(h) = std::env::var("HUGGINGFACE_HUB_CACHE") {
134 return PathBuf::from(h);
135 }
136 home_dir().join(".cache").join("huggingface").join("hub")
137}
138
139fn hf_hub_cache_snapshot(repo_id: &str) -> Option<PathBuf> {
141 let cache_name = format!("models--{}", repo_id.replace('/', "--"));
142 let snapshots = hf_hub_root().join(cache_name).join("snapshots");
143 let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
144 .ok()
145 .into_iter()
146 .flatten()
147 .flatten()
148 .map(|e| e.path())
149 .filter(|snap| layout_exists(snap))
150 .collect();
151 candidates.sort();
152 candidates.into_iter().last()
153}
154
155fn model_dir_from_bundle_manifest() -> Option<PathBuf> {
156 let manifest = Path::new(env!("CARGO_MANIFEST_DIR"))
157 .join("../kitten_tts_mini_rlx/weights/rlx_bundle/manifest.json");
158 let data = std::fs::read_to_string(&manifest).ok()?;
159 let v: serde_json::Value = serde_json::from_str(&data).ok()?;
160 let onnx = v.get("source_onnx")?.as_str()?;
161 let dir = PathBuf::from(onnx).parent()?.to_path_buf();
162 layout_exists(&dir).then_some(dir)
163}
164
/// Locates an RLX ONNX bundle directory, i.e. one containing `graph.json`.
///
/// Candidates, in order: `$RLX_ONNX_BUNDLE`, `$KITTEN_RLX_BUNDLE`, then
/// `<weights_dir>/rlx_bundle`. Each environment variable is checked on its own, so a stale
/// `RLX_ONNX_BUNDLE` no longer hides a valid `KITTEN_RLX_BUNDLE`. Empty values are ignored,
/// since they would otherwise mean the current directory.
pub fn find_rlx_bundle(weights_dir: &Path) -> Option<PathBuf> {
    let is_bundle = |dir: &Path| dir.join("graph.json").is_file();
    let from_env = ["RLX_ONNX_BUNDLE", "KITTEN_RLX_BUNDLE"]
        .iter()
        .filter_map(|key| std::env::var(key).ok())
        .filter(|raw| !raw.trim().is_empty())
        .map(PathBuf::from)
        .find(|p| is_bundle(p.as_path()));
    from_env.or_else(|| {
        let in_dir = weights_dir.join("rlx_bundle");
        is_bundle(&in_dir).then_some(in_dir)
    })
}
181
182pub fn default_native_weights_dir() -> Option<PathBuf> {
184 if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
185 let p = PathBuf::from(raw);
186 if p.join("model.safetensors").is_file() {
187 return Some(p);
188 }
189 }
190 let sibling = Path::new(env!("CARGO_MANIFEST_DIR")).join("../kitten_tts_mini_rlx/weights");
191 if sibling.join("model.safetensors").is_file() && find_rlx_bundle(&sibling).is_some() {
192 return Some(sibling);
193 }
194 None
195}
196
197pub fn find_native_weights(model_dir: &Path) -> Option<PathBuf> {
199 if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
200 let p = PathBuf::from(raw);
201 if p.join("model.safetensors").is_file() {
202 return Some(p);
203 }
204 }
205 if model_dir.join("model.safetensors").is_file() {
206 return Some(model_dir.to_path_buf());
207 }
208 default_native_weights_dir()
209}
210
/// Returns the Hugging Face snapshot directory that holds `config.json` for `repo_id`.
///
/// Uses the `hf_hub` sync API with [`hf_hub_root`] as the cache dir. Bare repo names are
/// expanded with [`normalize_repo_id`]. The snapshot directory is taken to be the parent of
/// the `config.json` path that the API returns.
///
/// # Errors
/// Fails if the API cannot be built, if `config.json` cannot be obtained for the repo, or
/// if the returned path has no parent.
#[cfg(feature = "hf-download")]
pub fn hf_snapshot_dir(repo_id: &str) -> Result<PathBuf> {
    let api = hf_hub::api::sync::ApiBuilder::new()
        .with_cache_dir(hf_hub_root())
        .build()
        .context("hf_hub ApiBuilder")?;
    let full_id = normalize_repo_id(repo_id);
    let config_json = api.model(full_id).get("config.json").with_context(|| {
        format!(
            "locate {repo_id} in Hugging Face cache under {}\n\
             Download once: `just fetch-kittentts`",
            hf_hub_root().display()
        )
    })?;
    config_json
        .parent()
        .map(|snapshot| snapshot.to_path_buf())
        .context("config.json has no parent (snapshot dir)")
}
230
231#[cfg(not(feature = "hf-download"))]
232pub fn hf_snapshot_dir(_repo_id: &str) -> Result<PathBuf> {
233 bail!("rebuild with `--features hf-download` on rlx-kittentts")
234}
235
/// Expands a bare Hugging Face repo name to `KittenML/<name>`.
///
/// Ids that already include an owner (`owner/name`) are returned unchanged.
pub fn normalize_repo_id(repo_id: &str) -> String {
    if !repo_id.contains('/') {
        return format!("KittenML/{repo_id}");
    }
    repo_id.to_owned()
}