Skip to main content

rlx_kittentts/
assets.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Model directory discovery and path layout.
17
18use std::path::{Path, PathBuf};
19
20use anyhow::{Context, Result, bail};
21
22use crate::config::{DEFAULT_HF_REPO, ModelConfig};
23
/// Default local checkout written by `just fetch-kittentts`.
///
/// This is a relative path, so it is resolved against the current working directory.
pub const DEFAULT_LOCAL_DIR: &str = ".cache/kittentts-mini-0.8";
26
27/// Resolved ONNX + voices (+ optional native weights) for one checkpoint directory.
28#[derive(Debug, Clone)]
29pub struct ModelLayout {
30    pub dir: PathBuf,
31    pub config: ModelConfig,
32    pub onnx: PathBuf,
33    pub voices: PathBuf,
34    pub native_weights: Option<PathBuf>,
35}
36
37impl ModelLayout {
38    pub fn resolve(model_dir: &Path) -> Result<Self> {
39        let dir = model_dir
40            .canonicalize()
41            .unwrap_or_else(|_| model_dir.to_path_buf());
42        let config = ModelConfig::load_from_dir(&dir)?;
43        let onnx = dir.join(&config.model_file);
44        let voices = dir.join(&config.voices);
45        if !onnx.is_file() {
46            bail!(
47                "ONNX model missing: {}\n\
48                 Fetch weights: `just fetch-kittentts` or set RLX_KITTENTTS_DIR",
49                onnx.display()
50            );
51        }
52        if !voices.is_file() {
53            bail!("voices NPZ missing: {}", voices.display());
54        }
55        Ok(Self {
56            native_weights: find_native_weights(&dir),
57            dir,
58            config,
59            onnx,
60            voices,
61        })
62    }
63
64    /// Voice keys from the NPZ plus friendly alias names from config.
65    pub fn voice_names(&self) -> Result<Vec<String>> {
66        let raw = crate::load_npz(&self.voices)
67            .with_context(|| format!("load voices {}", self.voices.display()))?;
68        let mut names: Vec<String> = raw.into_keys().collect();
69        for alias in self.config.voice_aliases.keys() {
70            if !names.iter().any(|n| n == alias) {
71                names.push(alias.clone());
72            }
73        }
74        names.sort();
75        Ok(names)
76    }
77}
78
79/// Best-effort model directory for CLI and tests.
80pub fn default_model_dir() -> Result<PathBuf> {
81    for key in ["RLX_KITTENTTS_DIR", "KITTENTTS_MODEL_DIR"] {
82        if let Ok(raw) = std::env::var(key) {
83            let p = PathBuf::from(&raw);
84            if layout_exists(&p) {
85                return Ok(p);
86            }
87        }
88    }
89
90    let local = PathBuf::from(DEFAULT_LOCAL_DIR);
91    if layout_exists(&local) {
92        return Ok(local);
93    }
94
95    #[cfg(feature = "hf-download")]
96    if let Ok(p) = hf_snapshot_dir(DEFAULT_HF_REPO) {
97        return Ok(p);
98    }
99
100    if let Some(p) = model_dir_from_bundle_manifest() {
101        return Ok(p);
102    }
103
104    if let Some(p) = hf_hub_cache_snapshot(DEFAULT_HF_REPO) {
105        return Ok(p);
106    }
107
108    bail!(
109        "KittenTTS weights not found.\n\
110         Quick start:\n\
111           just fetch-kittentts\n\
112           just kittentts-demo\n\
113         Or set RLX_KITTENTTS_DIR to a directory containing config.json, \
114         the ONNX file, and voices.npz."
115    )
116}
117
/// Whether `dir` looks like a KittenTTS checkpoint directory.
///
/// This is a cheap probe: it only checks that `dir/config.json` is a regular
/// file (symlinks are followed). The ONNX and voices files are validated later
/// by `ModelLayout::resolve`.
pub fn layout_exists(dir: &Path) -> bool {
    let config_path = dir.join("config.json");
    config_path.is_file()
}
121
/// User home directory from `HOME`, falling back to `USERPROFILE`.
///
/// Falls back to `.` when neither variable is set to a valid UTF-8 value.
fn home_dir() -> PathBuf {
    let home = ["HOME", "USERPROFILE"]
        .iter()
        .find_map(|key| std::env::var(key).ok());
    home.map_or_else(|| PathBuf::from("."), PathBuf::from)
}
128
129pub fn hf_hub_root() -> PathBuf {
130    if let Ok(h) = std::env::var("HF_HOME") {
131        return PathBuf::from(h).join("hub");
132    }
133    if let Ok(h) = std::env::var("HUGGINGFACE_HUB_CACHE") {
134        return PathBuf::from(h);
135    }
136    home_dir().join(".cache").join("huggingface").join("hub")
137}
138
139/// Locate a cached HF snapshot without downloading (no `hf-hub` dependency).
140fn hf_hub_cache_snapshot(repo_id: &str) -> Option<PathBuf> {
141    let cache_name = format!("models--{}", repo_id.replace('/', "--"));
142    let snapshots = hf_hub_root().join(cache_name).join("snapshots");
143    let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
144        .ok()
145        .into_iter()
146        .flatten()
147        .flatten()
148        .map(|e| e.path())
149        .filter(|snap| layout_exists(snap))
150        .collect();
151    candidates.sort();
152    candidates.into_iter().last()
153}
154
155fn model_dir_from_bundle_manifest() -> Option<PathBuf> {
156    let manifest = Path::new(env!("CARGO_MANIFEST_DIR"))
157        .join("../kitten_tts_mini_rlx/weights/rlx_bundle/manifest.json");
158    let data = std::fs::read_to_string(&manifest).ok()?;
159    let v: serde_json::Value = serde_json::from_str(&data).ok()?;
160    let onnx = v.get("source_onnx")?.as_str()?;
161    let dir = PathBuf::from(onnx).parent()?.to_path_buf();
162    layout_exists(&dir).then_some(dir)
163}
164
/// RLX ONNX bundle (`graph.json` + weights) under a native weights directory.
///
/// Candidates are checked in order. The first directory containing a
/// `graph.json` file is returned:
///
/// 1. `$RLX_ONNX_BUNDLE`
/// 2. `$KITTEN_RLX_BUNDLE`
/// 3. `<weights_dir>/rlx_bundle`
///
/// Each environment override is validated on its own, so a stale
/// `RLX_ONNX_BUNDLE` does not hide a valid `KITTEN_RLX_BUNDLE`. This matches the
/// try-each-key behavior of `default_model_dir`.
pub fn find_rlx_bundle(weights_dir: &Path) -> Option<PathBuf> {
    let from_env = ["RLX_ONNX_BUNDLE", "KITTEN_RLX_BUNDLE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        .map(PathBuf::from);
    from_env
        .chain(std::iter::once(weights_dir.join("rlx_bundle")))
        .find(|dir| dir.join("graph.json").is_file())
}
181
182/// Workspace decomposed weights (`crates/kitten_tts_mini_rlx/weights`).
183pub fn default_native_weights_dir() -> Option<PathBuf> {
184    if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
185        let p = PathBuf::from(raw);
186        if p.join("model.safetensors").is_file() {
187            return Some(p);
188        }
189    }
190    let sibling = Path::new(env!("CARGO_MANIFEST_DIR")).join("../kitten_tts_mini_rlx/weights");
191    if sibling.join("model.safetensors").is_file() && find_rlx_bundle(&sibling).is_some() {
192        return Some(sibling);
193    }
194    None
195}
196
197/// Decomposed RLX weights (`model.safetensors`), if present.
198pub fn find_native_weights(model_dir: &Path) -> Option<PathBuf> {
199    if let Ok(raw) = std::env::var("KITTEN_RLX_WEIGHTS") {
200        let p = PathBuf::from(raw);
201        if p.join("model.safetensors").is_file() {
202            return Some(p);
203        }
204    }
205    if model_dir.join("model.safetensors").is_file() {
206        return Some(model_dir.to_path_buf());
207    }
208    default_native_weights_dir()
209}
210
/// Snapshot directory for `repo_id` via the `hf-hub` crate, using [`hf_hub_root`]
/// as the cache.
///
/// A bare model name is expanded with [`normalize_repo_id`]. The snapshot
/// directory is the parent of the resolved `config.json`.
///
/// NOTE(review): `ApiRepo::get` may fetch `config.json` over the network when it
/// is not cached. Only that file is guaranteed to be present afterwards, so
/// callers still need `ModelLayout::resolve` to validate the ONNX/voices files.
///
/// # Errors
///
/// Fails when the API cannot be built or `config.json` cannot be obtained.
#[cfg(feature = "hf-download")]
pub fn hf_snapshot_dir(repo_id: &str) -> Result<PathBuf> {
    let api = hf_hub::api::sync::ApiBuilder::new()
        .with_cache_dir(hf_hub_root())
        .build()
        .context("hf_hub ApiBuilder")?;
    let config_path = api
        .model(normalize_repo_id(repo_id))
        .get("config.json")
        .with_context(|| {
            format!(
                "locate {repo_id} in Hugging Face cache under {}\n\
                 Download once: `just fetch-kittentts`",
                hf_hub_root().display()
            )
        })?;
    config_path
        .parent()
        .map(|snapshot| snapshot.to_path_buf())
        .context("config.json has no parent (snapshot dir)")
}
230
231#[cfg(not(feature = "hf-download"))]
232pub fn hf_snapshot_dir(_repo_id: &str) -> Result<PathBuf> {
233    bail!("rebuild with `--features hf-download` on rlx-kittentts")
234}
235
/// Expand a bare model name to a full Hugging Face repo id.
///
/// Ids that already contain a `/` (e.g. `org/model`) are returned unchanged.
/// Anything else is placed under the `KittenML` organization.
pub fn normalize_repo_id(repo_id: &str) -> String {
    if !repo_id.contains('/') {
        return format!("KittenML/{repo_id}");
    }
    repo_id.to_owned()
}