// rlx_kittentts/assets.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Model directory discovery and path layout.
17
18use std::path::{Path, PathBuf};
19
20use anyhow::{Context, Result, bail};
21
22use crate::config::{DEFAULT_HF_REPO, ModelConfig};
23
/// Checkout directory populated by `just fetch-kittentts`.
///
/// The path is relative, so it is looked up from the current working directory.
pub const DEFAULT_LOCAL_DIR: &str = ".cache/kittentts-mini-0.8";
26
27/// Resolved ONNX + voices (+ optional native weights) for one checkpoint directory.
28#[derive(Debug, Clone)]
29pub struct ModelLayout {
30    pub dir: PathBuf,
31    pub config: ModelConfig,
32    pub onnx: PathBuf,
33    pub voices: PathBuf,
34    pub native_weights: Option<PathBuf>,
35}
36
37impl ModelLayout {
38    pub fn resolve(model_dir: &Path) -> Result<Self> {
39        let dir = model_dir
40            .canonicalize()
41            .unwrap_or_else(|_| model_dir.to_path_buf());
42        let config = ModelConfig::load_from_dir(&dir)?;
43        let onnx = dir.join(&config.model_file);
44        let voices = dir.join(&config.voices);
45        let native_weights = find_native_weights(&dir);
46        if !onnx.is_file() && native_weights.is_none() {
47            bail!(
48                "ONNX model missing: {}\n\
49                 Fetch weights: `just fetch-kittentts` or set RLX_KITTENTTS_DIR",
50                onnx.display()
51            );
52        }
53        if !voices.is_file() {
54            bail!("voices NPZ missing: {}", voices.display());
55        }
56        Ok(Self {
57            native_weights,
58            dir,
59            config,
60            onnx,
61            voices,
62        })
63    }
64
65    /// Voice keys from the NPZ plus friendly alias names from config.
66    pub fn voice_names(&self) -> Result<Vec<String>> {
67        let raw = crate::load_npz(&self.voices)
68            .with_context(|| format!("load voices {}", self.voices.display()))?;
69        let mut names: Vec<String> = raw.into_keys().collect();
70        for alias in self.config.voice_aliases.keys() {
71            if !names.iter().any(|n| n == alias) {
72                names.push(alias.clone());
73            }
74        }
75        names.sort();
76        Ok(names)
77    }
78}
79
80/// Best-effort model directory for CLI and tests.
81pub fn default_model_dir() -> Result<PathBuf> {
82    for key in ["RLX_KITTENTTS_DIR", "KITTENTTS_MODEL_DIR"] {
83        if let Ok(raw) = std::env::var(key) {
84            let p = PathBuf::from(&raw);
85            if layout_exists(&p) {
86                return Ok(p);
87            }
88        }
89    }
90
91    let local = PathBuf::from(DEFAULT_LOCAL_DIR);
92    if layout_exists(&local) {
93        return Ok(local);
94    }
95
96    #[cfg(feature = "hf-download")]
97    if let Ok(p) = hf_snapshot_dir(DEFAULT_HF_REPO) {
98        return Ok(p);
99    }
100
101    if let Some(p) = model_dir_from_bundle_manifest() {
102        return Ok(p);
103    }
104
105    if let Some(p) = hf_hub_cache_snapshot(DEFAULT_HF_REPO) {
106        return Ok(p);
107    }
108
109    bail!(
110        "KittenTTS weights not found.\n\
111         Quick start:\n\
112           just fetch-kittentts\n\
113           just kittentts-demo\n\
114         Or set RLX_KITTENTTS_DIR to a directory containing config.json, \
115         the ONNX file, and voices.npz."
116    )
117}
118
/// Returns `true` when `dir` contains a regular file named `config.json`.
///
/// Symlinks are followed, and any metadata error (missing path, permissions) counts as
/// "absent". Only this marker is checked here; the ONNX and voices files are validated
/// by [`ModelLayout::resolve`].
pub fn layout_exists(dir: &Path) -> bool {
    let marker = dir.join("config.json");
    std::fs::metadata(marker)
        .map(|meta| meta.is_file())
        .unwrap_or(false)
}
122
/// User home directory: `HOME`, else `USERPROFILE`, else the current directory (`.`).
fn home_dir() -> PathBuf {
    let from_env = ["HOME", "USERPROFILE"]
        .iter()
        .find_map(|key| std::env::var(key).ok());
    match from_env {
        Some(dir) => PathBuf::from(dir),
        None => PathBuf::from("."),
    }
}
129
130pub fn hf_hub_root() -> PathBuf {
131    if let Ok(h) = std::env::var("HF_HOME") {
132        return PathBuf::from(h).join("hub");
133    }
134    if let Ok(h) = std::env::var("HUGGINGFACE_HUB_CACHE") {
135        return PathBuf::from(h);
136    }
137    home_dir().join(".cache").join("huggingface").join("hub")
138}
139
140/// Locate a cached HF snapshot without downloading (no `hf-hub` dependency).
141fn hf_hub_cache_snapshot(repo_id: &str) -> Option<PathBuf> {
142    let cache_name = format!("models--{}", repo_id.replace('/', "--"));
143    let snapshots = hf_hub_root().join(cache_name).join("snapshots");
144    let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
145        .ok()
146        .into_iter()
147        .flatten()
148        .flatten()
149        .map(|e| e.path())
150        .filter(|snap| layout_exists(snap))
151        .collect();
152    candidates.sort();
153    candidates.into_iter().last()
154}
155
156fn model_dir_from_bundle_manifest() -> Option<PathBuf> {
157    let manifest = Path::new(env!("CARGO_MANIFEST_DIR"))
158        .join("../kitten_tts_mini_rlx/weights/rlx_bundle/manifest.json");
159    let data = std::fs::read_to_string(&manifest).ok()?;
160    let v: serde_json::Value = serde_json::from_str(&data).ok()?;
161    let onnx = v.get("source_onnx")?.as_str()?;
162    let dir = PathBuf::from(onnx).parent()?.to_path_buf();
163    layout_exists(&dir).then_some(dir)
164}
165
/// RLX ONNX bundle (`graph.json` + weights) under a native weights directory.
///
/// Checks these locations in order:
/// 1. `RLX_ONNX_BUNDLE` (or `KITTEN_RLX_BUNDLE` when the former is unset or not valid
///    Unicode), accepted only if it contains `graph.json`;
/// 2. `<weights_dir>/rlx_bundle`, accepted under the same condition.
fn find_rlx_bundle_is_bundle(dir: &Path) -> bool {
    dir.join("graph.json").is_file()
}

pub fn find_rlx_bundle(weights_dir: &Path) -> Option<PathBuf> {
    let from_env = std::env::var("RLX_ONNX_BUNDLE")
        .or_else(|_| std::env::var("KITTEN_RLX_BUNDLE"))
        .ok()
        .map(PathBuf::from)
        .filter(|p| find_rlx_bundle_is_bundle(p));

    from_env.or_else(|| {
        let local = weights_dir.join("rlx_bundle");
        if find_rlx_bundle_is_bundle(&local) {
            Some(local)
        } else {
            None
        }
    })
}
182
183/// Workspace decomposed weights (`crates/kitten_tts_mini_rlx/weights`).
184pub fn default_native_weights_dir() -> Option<PathBuf> {
185    if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
186        let p = PathBuf::from(raw);
187        if p.join("model.safetensors").is_file() || p.join("rlx_bundle/graph.json").is_file() {
188            return Some(p);
189        }
190    }
191    let sibling = Path::new(env!("CARGO_MANIFEST_DIR")).join("../kitten_tts_mini_rlx/weights");
192    if sibling.join("model.safetensors").is_file() && find_rlx_bundle(&sibling).is_some() {
193        return Some(sibling);
194    }
195    if find_rlx_bundle(&sibling).is_some() {
196        return Some(sibling);
197    }
198    None
199}
200
201/// Decomposed RLX weights (`model.safetensors`), if present.
202pub fn find_native_weights(model_dir: &Path) -> Option<PathBuf> {
203    if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
204        let p = PathBuf::from(raw);
205        if p.join("model.safetensors").is_file() || p.join("rlx_bundle/graph.json").is_file() {
206            return Some(p);
207        }
208    }
209    if model_dir.join("model.safetensors").is_file() {
210        return Some(model_dir.to_path_buf());
211    }
212    if model_dir.join("rlx_bundle/graph.json").is_file() {
213        return Some(model_dir.to_path_buf());
214    }
215    default_native_weights_dir()
216}
217
/// Snapshot directory of `repo_id` in the Hugging Face cache, via the `hf-hub` crate.
///
/// The repo id is passed through [`normalize_repo_id`], and the cache rooted at
/// [`hf_hub_root`] is queried for `config.json`. Depending on `hf-hub`, this may fetch
/// the file when it is not cached. The returned path is the directory containing
/// that `config.json`.
///
/// # Errors
/// Fails when the API client cannot be built, when `config.json` cannot be obtained,
/// or when the returned path has no parent.
#[cfg(feature = "hf-download")]
pub fn hf_snapshot_dir(repo_id: &str) -> Result<PathBuf> {
    let api = hf_hub::api::sync::ApiBuilder::new()
        .with_cache_dir(hf_hub_root())
        .build()
        .context("hf_hub ApiBuilder")?;
    let config_path = api
        .model(normalize_repo_id(repo_id))
        .get("config.json")
        .with_context(|| {
            format!(
                "locate {repo_id} in Hugging Face cache under {}\n\
                 Download once: `just fetch-kittentts`",
                hf_hub_root().display()
            )
        })?;
    let snapshot = config_path
        .parent()
        .context("config.json has no parent (snapshot dir)")?;
    Ok(snapshot.to_path_buf())
}
237
238#[cfg(not(feature = "hf-download"))]
239pub fn hf_snapshot_dir(_repo_id: &str) -> Result<PathBuf> {
240    bail!("rebuild with `--features hf-download` on rlx-kittentts")
241}
242
/// Qualifies a bare model name with the `KittenML/` organization.
///
/// Ids that already contain a `/` are returned unchanged.
pub fn normalize_repo_id(repo_id: &str) -> String {
    match repo_id.find('/') {
        Some(_) => repo_id.to_owned(),
        None => format!("KittenML/{}", repo_id),
    }
}