Skip to main content

rlx_kittentts/
npz.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Minimal NPZ / NPY loader.
//!
//! Supports the subset of the NumPy array format actually used by KittenTTS:
//!   - NPY format versions 1.0 and 2.0
//!   - `float32` dtype (`<f4`, `=f4`, `|f4`, and big-endian `>f4`)
//!   - C-contiguous (row-major) layout; Fortran-order arrays are rejected
//!   - Any number of dimensions (KittenTTS itself uses 2-D voice matrices)
//!
//! NPZ files are simply ZIP archives whose members are `.npy` files.
//! Each member's name, minus its `.npy` extension, becomes the array name.
26
27use anyhow::{Context, Result, bail};
28use std::{collections::HashMap, io::Read, path::Path};
29use zip::ZipArchive;
30
31// ─────────────────────────────────────────────────────────────────────────────
32// NPY header parser
33// ─────────────────────────────────────────────────────────────────────────────
34
35/// Parse a raw `.npy` byte buffer and return the f32 data as a flat `Vec<f32>`
36/// together with the shape.
37pub fn parse_npy(data: &[u8]) -> Result<(Vec<usize>, Vec<f32>)> {
38    // Magic: 6 bytes "\x93NUMPY"
39    if data.len() < 10 || &data[..6] != b"\x93NUMPY" {
40        bail!("Not a valid NPY file (bad magic)");
41    }
42
43    let major = data[6];
44    let minor = data[7];
45
46    // Header length: 2 bytes (v1) or 4 bytes (v2), little-endian.
47    let (header_len, header_start) = match (major, minor) {
48        (1, _) => {
49            let len = u16::from_le_bytes([data[8], data[9]]) as usize;
50            (len, 10)
51        }
52        (2, _) => {
53            if data.len() < 12 {
54                bail!("NPY v2 file too short");
55            }
56            let len = u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize;
57            (len, 12)
58        }
59        _ => bail!("Unsupported NPY version {}.{}", major, minor),
60    };
61
62    let header_end = header_start + header_len;
63    if data.len() < header_end {
64        bail!("NPY file truncated in header");
65    }
66    let header = std::str::from_utf8(&data[header_start..header_end])
67        .context("NPY header is not valid UTF-8")?;
68
69    // Parse dtype
70    let dtype = extract_header_field(header, "descr").context("NPY header missing 'descr'")?;
71    let dtype = dtype.trim().trim_matches('\'').trim_matches('"');
72
73    // We accept little-endian and native-endian float32 only.
74    let is_f32 = matches!(dtype, "<f4" | "=f4" | "|f4" | ">f4");
75    if !is_f32 {
76        bail!("Unsupported dtype '{}' — only float32 is supported", dtype);
77    }
78    let big_endian = dtype.starts_with('>');
79
80    // Parse fortran_order
81    let fortran = extract_header_field(header, "fortran_order")
82        .unwrap_or("False")
83        .trim()
84        .to_ascii_lowercase();
85    if fortran == "true" {
86        bail!("Fortran-order arrays are not supported");
87    }
88
89    // Parse shape — e.g. "(256, 512, )" or "(100,)"
90    let shape_str = extract_header_field(header, "shape").context("NPY header missing 'shape'")?;
91    let shape = parse_shape(shape_str.trim())?;
92
93    // Total number of elements
94    let n_elements: usize = shape.iter().product();
95
96    // Raw bytes start right after the header
97    let data_bytes = &data[header_end..];
98    if data_bytes.len() < n_elements * 4 {
99        bail!(
100            "NPY data section too short: expected {} bytes, got {}",
101            n_elements * 4,
102            data_bytes.len()
103        );
104    }
105
106    // Read f32 values
107    let values: Vec<f32> = data_bytes[..n_elements * 4]
108        .chunks_exact(4)
109        .map(|b| {
110            let arr = [b[0], b[1], b[2], b[3]];
111            if big_endian {
112                f32::from_be_bytes(arr)
113            } else {
114                f32::from_le_bytes(arr)
115            }
116        })
117        .collect();
118
119    Ok((shape, values))
120}
121
/// Look up `field` in an NPY header (a Python dict literal) and return the
/// raw text of its value.
///
/// Both `'field':` and `"field":` keys are recognised; the single-quoted form
/// is searched first. Quoted string values come back without their quotes,
/// tuple values keep their parentheses, and any other (bare) value runs up to
/// the next `,` or `}` and is trimmed. Returns `None` when the key is absent
/// or a quoted/tuple value is unterminated.
///
/// e.g. `extract_header_field("{'descr': '<f4', 'shape': (3,)}", "descr")`
/// returns `Some("<f4")`.
fn extract_header_field<'a>(header: &'a str, field: &str) -> Option<&'a str> {
    let single = format!("'{}':", field);
    let double = format!("\"{}\":", field);

    let value_start = match header.find(single.as_str()) {
        Some(pos) => pos + single.len(),
        None => header.find(double.as_str())? + double.len(),
    };
    let value = header[value_start..].trim_start();

    match value.chars().next() {
        // Tuple: shapes are never nested, so the first ')' closes it.
        Some('(') => value.find(')').map(|close| &value[..=close]),
        // Python string literal: return the text between the quotes.
        Some(q @ ('\'' | '"')) => {
            let body = &value[1..];
            body.find(q).map(|close| &body[..close])
        }
        // Bare word (True / False / number) up to the next separator.
        _ => {
            let stop = value.find([',', '}']).unwrap_or(value.len());
            Some(value[..stop].trim())
        }
    }
}
154
155/// Parse a Python-style shape tuple like `(256, 512, )` or `(100,)` or `()`.
156fn parse_shape(s: &str) -> Result<Vec<usize>> {
157    let inner = s.trim_start_matches('(').trim_end_matches(')');
158    if inner.trim().is_empty() {
159        return Ok(vec![]);
160    }
161    inner
162        .split(',')
163        .map(|t| t.trim())
164        .filter(|t| !t.is_empty())
165        .map(|t| {
166            t.parse::<usize>()
167                .with_context(|| format!("Bad shape dim: '{}'", t))
168        })
169        .collect()
170}
171
172// ─────────────────────────────────────────────────────────────────────────────
173// NPZ loader — returns flat f32 data per array name
174// ─────────────────────────────────────────────────────────────────────────────
175
/// A loaded NPZ entry: shape + flat f32 data in row-major (C) order.
#[derive(Debug, Clone, PartialEq)]
pub struct NpyArray {
    /// Dimensions, outermost first; empty for a 0-d (scalar) array.
    pub shape: Vec<usize>,
    /// Elements in row-major (C) order. When produced by `parse_npy` this holds
    /// exactly `shape.iter().product()` values.
    pub data: Vec<f32>,
}

impl NpyArray {
    /// Number of rows (first dimension); `0` for a 0-d array.
    pub fn nrows(&self) -> usize {
        self.shape.first().copied().unwrap_or(0)
    }

    /// Number of columns (second dimension); `1` for arrays with fewer than
    /// two dimensions.
    pub fn ncols(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(1)
    }

    /// Number of elements in one row: the product of every dimension after
    /// the first. Equals [`ncols`](Self::ncols) for 0-, 1- and 2-D arrays.
    fn row_len(&self) -> usize {
        self.shape.iter().skip(1).product()
    }

    /// Get row `i` — the `i`-th sub-array along the first axis — as a flat
    /// slice of f32 values. For arrays of rank 3 or higher this is the whole
    /// trailing block, not just `ncols()` elements.
    ///
    /// # Panics
    ///
    /// Panics if row `i` extends past the end of `data`.
    pub fn row(&self, i: usize) -> &[f32] {
        let len = self.row_len();
        let start = i * len;
        &self.data[start..start + len]
    }
}
199
200/// Load an NPZ file and return all arrays indexed by name (`.npy` extension stripped).
201pub fn load_npz(path: &Path) -> Result<HashMap<String, NpyArray>> {
202    let file = std::fs::File::open(path)
203        .with_context(|| format!("Cannot open NPZ file: {}", path.display()))?;
204    let mut archive = ZipArchive::new(file)
205        .with_context(|| format!("Cannot open ZIP archive: {}", path.display()))?;
206
207    let mut arrays = HashMap::new();
208
209    for i in 0..archive.len() {
210        let mut entry = archive.by_index(i).context("Failed to read ZIP entry")?;
211        let name = entry.name().trim_end_matches(".npy").to_string();
212
213        let mut buf = Vec::with_capacity(entry.size() as usize);
214        entry
215            .read_to_end(&mut buf)
216            .context("Failed to read NPY entry")?;
217
218        let (shape, data) =
219            parse_npy(&buf).with_context(|| format!("Failed to parse NPY entry '{}'", name))?;
220
221        arrays.insert(name, NpyArray { shape, data });
222    }
223
224    Ok(arrays)
225}
226
227// ─────────────────────────────────────────────────────────────────────────────
228// Tests
229// ─────────────────────────────────────────────────────────────────────────────
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble an NPY buffer from a major version, a header dict literal and
    /// a raw payload.
    ///
    /// Follows the format spec: the dict is padded with spaces and terminated
    /// by `\n` so that magic + version + length field + header is a multiple
    /// of 64 bytes. Version 1 uses a u16 header length, later versions a u32.
    fn build_npy(major: u8, dict: &str, payload: &[u8]) -> Vec<u8> {
        let preamble_len: usize = if major == 1 { 10 } else { 12 };
        let unpadded = preamble_len + dict.len() + 1; // +1 for the trailing '\n'
        let padding = unpadded.div_ceil(64) * 64 - unpadded;

        let mut header = String::from(dict);
        header.push_str(&" ".repeat(padding));
        header.push('\n');

        let mut buf = Vec::with_capacity(preamble_len + header.len() + payload.len());
        buf.extend_from_slice(b"\x93NUMPY");
        buf.extend_from_slice(&[major, 0]);
        // Test headers are tiny, so these narrowing casts cannot truncate.
        if major == 1 {
            buf.extend_from_slice(&(header.len() as u16).to_le_bytes());
        } else {
            buf.extend_from_slice(&(header.len() as u32).to_le_bytes());
        }
        buf.extend_from_slice(header.as_bytes());
        buf.extend_from_slice(payload);
        buf
    }

    /// Render a shape the way NumPy writes it: `()`, `(3,)`, `(2, 3)`.
    fn shape_literal(shape: &[usize]) -> String {
        match shape {
            [only] => format!("({},)", only),
            dims => {
                let parts: Vec<String> = dims.iter().map(|d| d.to_string()).collect();
                format!("({})", parts.join(", "))
            }
        }
    }

    /// Concatenate `values` as little-endian f32 bytes.
    fn le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Build a minimal v1.0, little-endian float32, C-order NPY buffer.
    fn make_npy(shape: &[usize], values: &[f32]) -> Vec<u8> {
        let dict = format!(
            "{{'descr': '<f4', 'fortran_order': False, 'shape': {}, }}",
            shape_literal(shape)
        );
        build_npy(1, &dict, &le_bytes(values))
    }

    #[test]
    fn test_make_npy_header_is_64_byte_aligned() {
        let values = [1.0f32, 2.0, 3.0];
        let buf = make_npy(&[3], &values);
        assert_eq!((buf.len() - values.len() * 4) % 64, 0);
    }

    #[test]
    fn test_parse_npy_1d() {
        let values = vec![1.0f32, 2.0, 3.0];
        let buf = make_npy(&[3], &values);
        let (shape, data) = parse_npy(&buf).unwrap();
        assert_eq!(shape, vec![3]);
        assert_eq!(data, values);
    }

    #[test]
    fn test_parse_npy_2d() {
        let values: Vec<f32> = (0..6).map(|x| x as f32).collect();
        let buf = make_npy(&[2, 3], &values);
        let (shape, data) = parse_npy(&buf).unwrap();
        assert_eq!(shape, vec![2, 3]);
        assert_eq!(data, values);
    }

    #[test]
    fn test_parse_npy_scalar() {
        let buf = make_npy(&[], &[7.5]);
        let (shape, data) = parse_npy(&buf).unwrap();
        assert!(shape.is_empty());
        assert_eq!(data, vec![7.5]);
    }

    #[test]
    fn test_parse_npy_v2_header() {
        let dict = "{'descr': '<f4', 'fortran_order': False, 'shape': (2,), }";
        let buf = build_npy(2, dict, &le_bytes(&[0.5, -1.5]));
        let (shape, data) = parse_npy(&buf).unwrap();
        assert_eq!(shape, vec![2]);
        assert_eq!(data, vec![0.5, -1.5]);
    }

    #[test]
    fn test_parse_npy_big_endian() {
        let values = [1.0f32, -2.25];
        let payload: Vec<u8> = values.iter().flat_map(|v| v.to_be_bytes()).collect();
        let dict = "{'descr': '>f4', 'fortran_order': False, 'shape': (2,), }";
        let (_, data) = parse_npy(&build_npy(1, dict, &payload)).unwrap();
        assert_eq!(data, values);
    }

    #[test]
    fn test_rejects_fortran_order() {
        let dict = "{'descr': '<f4', 'fortran_order': True, 'shape': (2,), }";
        assert!(parse_npy(&build_npy(1, dict, &le_bytes(&[1.0, 2.0]))).is_err());
    }

    #[test]
    fn test_rejects_unsupported_dtype() {
        let dict = "{'descr': '<f8', 'fortran_order': False, 'shape': (1,), }";
        assert!(parse_npy(&build_npy(1, dict, &[0u8; 8])).is_err());
    }

    #[test]
    fn test_rejects_unsupported_version() {
        let dict = "{'descr': '<f4', 'fortran_order': False, 'shape': (1,), }";
        assert!(parse_npy(&build_npy(9, dict, &le_bytes(&[1.0]))).is_err());
    }

    #[test]
    fn test_rejects_truncated_data() {
        let buf = make_npy(&[3], &[1.0, 2.0]);
        assert!(parse_npy(&buf).is_err());
    }

    #[test]
    fn test_npy_array_row() {
        let values: Vec<f32> = (0..6).map(|x| x as f32).collect();
        let buf = make_npy(&[2, 3], &values);
        let (shape, data) = parse_npy(&buf).unwrap();
        let arr = NpyArray { shape, data };
        assert_eq!(arr.row(0), &[0.0, 1.0, 2.0]);
        assert_eq!(arr.row(1), &[3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_extract_header_field() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (3,), }";
        assert_eq!(extract_header_field(header, "descr"), Some("<f4"));
        assert_eq!(extract_header_field(header, "fortran_order"), Some("False"));
        assert_eq!(extract_header_field(header, "shape"), Some("(3,)"));
        assert_eq!(extract_header_field(header, "missing"), None);
    }

    #[test]
    fn test_bad_magic() {
        let result = parse_npy(b"NOTANPY");
        assert!(result.is_err());
    }
}