1use super::types::NpyDtype;
6
/// Magic string that opens every `.npy` file.
pub(super) const NPY_MAGIC: &[u8; 6] = b"\x93NUMPY";
/// Major format version emitted by the `write_npy_*` functions in this module.
pub(super) const NPY_MAJOR: u8 = 1;
/// Minor format version emitted by the `write_npy_*` functions in this module.
pub(super) const NPY_MINOR: u8 = 0;
/// Checks that `data_len` elements exactly fill an array of the given `shape`.
///
/// An empty `shape` denotes a 0-d (scalar) array and therefore requires one element.
///
/// # Errors
/// Returns an error if the element count implied by `shape` differs from `data_len`,
/// or if that count does not fit in `usize`.
pub fn validate_shape(shape: &[usize], data_len: usize) -> Result<(), String> {
    // Any zero-length axis makes the array empty no matter how large the other
    // axes are, so settle that before a checked product could report overflow.
    let expected = if shape.contains(&0) {
        Some(0)
    } else {
        shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    };
    match expected {
        Some(expected) if expected == data_len => Ok(()),
        Some(expected) => Err(format!(
            "shape {shape:?} requires {expected} elements but got {data_len}"
        )),
        None => Err(format!(
            "shape {shape:?} has more elements than fit in usize (got {data_len})"
        )),
    }
}
/// Converts a multi-dimensional index into the flat offset of a C-order (row-major) array.
///
/// # Errors
/// Returns an error when `indices` and `shape` have different lengths, when an index
/// is out of range for its axis (axes are checked from last to first, so the last
/// offending axis is the one reported), or when the flat offset does not fit in `usize`.
pub fn flat_index(indices: &[usize], shape: &[usize]) -> Result<usize, String> {
    if indices.len() != shape.len() {
        return Err(format!(
            "index dimensionality {} != shape dimensionality {}",
            indices.len(),
            shape.len()
        ));
    }
    for (axis, (&i, &dim)) in indices.iter().zip(shape).enumerate().rev() {
        if i >= dim {
            return Err(format!(
                "index {} out of range for axis {} with size {}",
                i, axis, dim
            ));
        }
    }
    // Horner's scheme: ((i0 * d1 + i1) * d2 + i2) * ... Every partial result is
    // bounded by the final offset (all dims are >= 1 after validation), so this
    // only fails when the offset itself overflows -- unlike accumulating strides,
    // whose final multiply by the leading dimension is never used yet can overflow.
    indices
        .iter()
        .zip(shape)
        .try_fold(0usize, |acc, (&i, &dim)| acc.checked_mul(dim)?.checked_add(i))
        .ok_or_else(|| format!("flat index of {indices:?} in shape {shape:?} overflows usize"))
}
/// Converts a flat C-order (row-major) offset back into a multi-dimensional index.
///
/// An empty `shape` (0-d array) accepts only `flat == 0` and yields an empty index.
///
/// # Errors
/// Returns an error when `flat` is not smaller than the number of elements in `shape`.
pub fn unravel_index(flat: usize, shape: &[usize]) -> Result<Vec<usize>, String> {
    // A zero-length axis means there are no elements at all. Otherwise a product
    // that overflows usize means every usize offset is in range (`None` below).
    let total = if shape.contains(&0) {
        Some(0)
    } else {
        shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    };
    if let Some(total) = total {
        if flat >= total {
            return Err(format!("flat index {flat} out of range for total {total}"));
        }
    }
    // Past the check above no dimension is zero, so the divisions cannot panic.
    let mut indices = vec![0usize; shape.len()];
    let mut remaining = flat;
    for (slot, &dim) in indices.iter_mut().zip(shape).rev() {
        *slot = remaining % dim;
        remaining /= dim;
    }
    Ok(indices)
}
58pub(super) fn build_npy_header(dtype_str: &str, shape: &[usize]) -> Vec<u8> {
60 let shape_str = if shape.is_empty() {
61 "()".to_string()
62 } else if shape.len() == 1 {
63 format!("({},)", shape[0])
64 } else {
65 let inner: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
66 format!("({})", inner.join(", "))
67 };
68 let dict = format!(
69 "{{'descr': '{}', 'fortran_order': False, 'shape': {}, }}",
70 dtype_str, shape_str
71 );
72 let mut header_bytes = dict.into_bytes();
73 header_bytes.push(b'\n');
74 while (10 + header_bytes.len()) % 64 != 0 {
75 let last = header_bytes.len() - 1;
76 header_bytes.insert(last, b' ');
77 }
78 header_bytes
79}
80pub fn write_npy_f64(shape: &[usize], data: &[f64]) -> Vec<u8> {
82 let header_bytes = build_npy_header("<f8", shape);
83 let header_len = header_bytes.len() as u16;
84 let mut out: Vec<u8> = Vec::new();
85 out.extend_from_slice(NPY_MAGIC);
86 out.push(NPY_MAJOR);
87 out.push(NPY_MINOR);
88 out.extend_from_slice(&header_len.to_le_bytes());
89 out.extend_from_slice(&header_bytes);
90 for &v in data {
91 out.extend_from_slice(&v.to_le_bytes());
92 }
93 out
94}
95pub fn write_npy_f32(shape: &[usize], data: &[f32]) -> Vec<u8> {
97 let header_bytes = build_npy_header("<f4", shape);
98 let header_len = header_bytes.len() as u16;
99 let mut out: Vec<u8> = Vec::new();
100 out.extend_from_slice(NPY_MAGIC);
101 out.push(NPY_MAJOR);
102 out.push(NPY_MINOR);
103 out.extend_from_slice(&header_len.to_le_bytes());
104 out.extend_from_slice(&header_bytes);
105 for &v in data {
106 out.extend_from_slice(&v.to_le_bytes());
107 }
108 out
109}
110pub fn write_npy_i32(shape: &[usize], data: &[i32]) -> Vec<u8> {
112 let header_bytes = build_npy_header("<i4", shape);
113 let header_len = header_bytes.len() as u16;
114 let mut out: Vec<u8> = Vec::new();
115 out.extend_from_slice(NPY_MAGIC);
116 out.push(NPY_MAJOR);
117 out.push(NPY_MINOR);
118 out.extend_from_slice(&header_len.to_le_bytes());
119 out.extend_from_slice(&header_bytes);
120 for &v in data {
121 out.extend_from_slice(&v.to_le_bytes());
122 }
123 out
124}
125pub fn write_npy_i64(shape: &[usize], data: &[i64]) -> Vec<u8> {
127 let header_bytes = build_npy_header("<i8", shape);
128 let header_len = header_bytes.len() as u16;
129 let mut out: Vec<u8> = Vec::new();
130 out.extend_from_slice(NPY_MAGIC);
131 out.push(NPY_MAJOR);
132 out.push(NPY_MINOR);
133 out.extend_from_slice(&header_len.to_le_bytes());
134 out.extend_from_slice(&header_bytes);
135 for &v in data {
136 out.extend_from_slice(&v.to_le_bytes());
137 }
138 out
139}
140pub(super) fn parse_npy_header(bytes: &[u8]) -> Result<(String, Vec<usize>, usize), String> {
142 if bytes.len() < 10 {
143 return Err("npy data too short".to_string());
144 }
145 if &bytes[0..6] != NPY_MAGIC {
146 return Err(format!("bad npy magic: {:?}", &bytes[0..6]));
147 }
148 let major = bytes[6];
149 let minor = bytes[7];
150 if major != 1 || minor != 0 {
151 return Err(format!("unsupported npy version: {major}.{minor}"));
152 }
153 let header_len = u16::from_le_bytes([bytes[8], bytes[9]]) as usize;
154 let data_start = 10 + header_len;
155 if bytes.len() < data_start {
156 return Err("npy header truncated".to_string());
157 }
158 let header_str = std::str::from_utf8(&bytes[10..data_start])
159 .map_err(|e| format!("npy header not utf-8: {e}"))?
160 .trim();
161 let dtype_str = extract_dict_value(header_str, "descr")?;
162 let shape_str = extract_dict_value(header_str, "shape")?;
163 let shape = parse_shape_tuple(&shape_str)?;
164 Ok((dtype_str, shape, data_start))
165}
166pub(super) fn extract_dict_value(header: &str, key: &str) -> Result<String, String> {
168 let search = format!("'{key}'");
169 let pos = header
170 .find(&search)
171 .ok_or_else(|| format!("key '{key}' not found in npy header"))?;
172 let rest = &header[pos + search.len()..];
173 let rest = rest.trim_start();
174 let rest = rest
175 .strip_prefix(':')
176 .ok_or("missing ':' after key")?
177 .trim_start();
178 if rest.starts_with('\'') {
179 let inner = rest.strip_prefix('\'').expect("prefix should be present");
180 let end = inner.find('\'').ok_or("unterminated string value")?;
181 Ok(inner[..end].to_string())
182 } else if rest.starts_with('(') {
183 let end = rest.find(')').ok_or("unterminated tuple value")? + 1;
184 Ok(rest[..end].to_string())
185 } else {
186 let end = rest.find([',', '}']).unwrap_or(rest.len());
187 Ok(rest[..end].trim().to_string())
188 }
189}
190pub(super) fn parse_shape_tuple(s: &str) -> Result<Vec<usize>, String> {
192 let inner = s.trim();
193 let inner = inner
194 .strip_prefix('(')
195 .ok_or("shape missing '('")?
196 .strip_suffix(')')
197 .ok_or("shape missing ')'")?;
198 if inner.trim().is_empty() {
199 return Ok(vec![]);
200 }
201 let mut dims = Vec::new();
202 for part in inner.split(',') {
203 let part = part.trim();
204 if part.is_empty() {
205 continue;
206 }
207 let d: usize = part
208 .parse()
209 .map_err(|e| format!("bad shape dimension '{part}': {e}"))?;
210 dims.push(d);
211 }
212 Ok(dims)
213}
214pub fn read_npy_f64(bytes: &[u8]) -> Result<(Vec<usize>, Vec<f64>), String> {
216 let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
217 if dtype_str != "<f8" {
218 return Err(format!("expected dtype '<f8', got '{dtype_str}'"));
219 }
220 let n_elems: usize = shape.iter().product();
221 let expected_bytes = data_start + n_elems * 8;
222 if bytes.len() < expected_bytes {
223 return Err(format!(
224 "data truncated: expected {expected_bytes} bytes, got {}",
225 bytes.len()
226 ));
227 }
228 let mut data = Vec::with_capacity(n_elems);
229 let mut pos = data_start;
230 for _ in 0..n_elems {
231 let v = f64::from_le_bytes(
232 bytes[pos..pos + 8]
233 .try_into()
234 .expect("slice length must match"),
235 );
236 pos += 8;
237 data.push(v);
238 }
239 Ok((shape, data))
240}
241pub fn read_npy_f32(bytes: &[u8]) -> Result<(Vec<usize>, Vec<f32>), String> {
243 let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
244 if dtype_str != "<f4" {
245 return Err(format!("expected dtype '<f4', got '{dtype_str}'"));
246 }
247 let n_elems: usize = shape.iter().product();
248 let expected_bytes = data_start + n_elems * 4;
249 if bytes.len() < expected_bytes {
250 return Err(format!(
251 "data truncated: expected {expected_bytes} bytes, got {}",
252 bytes.len()
253 ));
254 }
255 let mut data = Vec::with_capacity(n_elems);
256 let mut pos = data_start;
257 for _ in 0..n_elems {
258 let v = f32::from_le_bytes(
259 bytes[pos..pos + 4]
260 .try_into()
261 .expect("slice length must match"),
262 );
263 pos += 4;
264 data.push(v);
265 }
266 Ok((shape, data))
267}
268pub fn read_npy_i32(bytes: &[u8]) -> Result<(Vec<usize>, Vec<i32>), String> {
270 let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
271 if dtype_str != "<i4" {
272 return Err(format!("expected dtype '<i4', got '{dtype_str}'"));
273 }
274 let n_elems: usize = shape.iter().product();
275 let expected_bytes = data_start + n_elems * 4;
276 if bytes.len() < expected_bytes {
277 return Err(format!(
278 "data truncated: expected {expected_bytes} bytes, got {}",
279 bytes.len()
280 ));
281 }
282 let mut data = Vec::with_capacity(n_elems);
283 let mut pos = data_start;
284 for _ in 0..n_elems {
285 let v = i32::from_le_bytes(
286 bytes[pos..pos + 4]
287 .try_into()
288 .expect("slice length must match"),
289 );
290 pos += 4;
291 data.push(v);
292 }
293 Ok((shape, data))
294}
295pub fn read_npy_i64(bytes: &[u8]) -> Result<(Vec<usize>, Vec<i64>), String> {
297 let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
298 if dtype_str != "<i8" {
299 return Err(format!("expected dtype '<i8', got '{dtype_str}'"));
300 }
301 let n_elems: usize = shape.iter().product();
302 let expected_bytes = data_start + n_elems * 8;
303 if bytes.len() < expected_bytes {
304 return Err(format!(
305 "data truncated: expected {expected_bytes} bytes, got {}",
306 bytes.len()
307 ));
308 }
309 let mut data = Vec::with_capacity(n_elems);
310 let mut pos = data_start;
311 for _ in 0..n_elems {
312 let v = i64::from_le_bytes(
313 bytes[pos..pos + 8]
314 .try_into()
315 .expect("slice length must match"),
316 );
317 pos += 8;
318 data.push(v);
319 }
320 Ok((shape, data))
321}
322pub fn detect_npy_dtype(bytes: &[u8]) -> Result<NpyDtype, String> {
324 let (dtype_str, _, _) = parse_npy_header(bytes)?;
325 NpyDtype::from_numpy_str(&dtype_str)
326}
327pub fn read_npy_shape(bytes: &[u8]) -> Result<Vec<usize>, String> {
329 let (_, shape, _) = parse_npy_header(bytes)?;
330 Ok(shape)
331}
332pub(super) fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
333 if *pos + 4 > data.len() {
334 return Err(format!("unexpected EOF reading u32 at offset {pos}"));
335 }
336 let v = u32::from_le_bytes(
337 data[*pos..*pos + 4]
338 .try_into()
339 .expect("slice length must match"),
340 );
341 *pos += 4;
342 Ok(v)
343}
/// Arithmetic mean of `data`, or `None` when `data` is empty.
#[allow(dead_code)]
pub fn slice_mean(data: &[f64]) -> Option<f64> {
    match data.len() {
        0 => None,
        len => {
            let total: f64 = data.iter().sum();
            Some(total / len as f64)
        }
    }
}
354#[allow(dead_code)]
356pub fn slice_var(data: &[f64]) -> Option<f64> {
357 let mean = slice_mean(data)?;
358 let var = data.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>() / data.len() as f64;
359 Some(var)
360}
361#[allow(dead_code)]
363pub fn slice_std(data: &[f64]) -> Option<f64> {
364 Some(slice_var(data)?.sqrt())
365}
/// Returns `(min, min_index, max, max_index)` of `data`, or `None` when empty.
///
/// Comparisons are strict, so on ties the first occurrence wins. NaN elements never
/// replace the current extremes, and a NaN in the first position is never replaced.
#[allow(dead_code)]
pub fn slice_min_max(data: &[f64]) -> Option<(f64, usize, f64, usize)> {
    let (&first, rest) = data.split_first()?;
    let seed = (first, 0usize, first, 0usize);
    Some(
        rest.iter()
            .enumerate()
            .fold(seed, |(lo, lo_at, hi, hi_at), (offset, &v)| {
                let at = offset + 1;
                let (lo, lo_at) = if v < lo { (v, at) } else { (lo, lo_at) };
                let (hi, hi_at) = if v > hi { (v, at) } else { (hi, hi_at) };
                (lo, lo_at, hi, hi_at)
            }),
    )
}
/// Returns the `p`-th percentile (`0..=100`) of `data`, interpolating linearly
/// between the two closest ranks (NumPy's default `method="linear"`).
///
/// If `data` contains a NaN the result is NaN, matching `numpy.percentile`.
///
/// # Errors
/// Returns an error when `data` is empty or `p` is outside `[0, 100]` (NaN `p` included).
#[allow(dead_code)]
pub fn slice_percentile(data: &[f64], p: f64) -> std::result::Result<f64, String> {
    if data.is_empty() {
        return Err("slice_percentile: empty slice".to_string());
    }
    if !(0.0..=100.0).contains(&p) {
        return Err(format!("percentile p={p} not in [0,100]"));
    }
    // NaN has no rank. Treating it as "equal to everything" is not a total order:
    // `sort_by` may panic on such a comparator, and the result would be arbitrary.
    if data.iter().any(|v| v.is_nan()) {
        return Ok(f64::NAN);
    }
    let mut sorted = data.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).expect("NaN values were rejected above"));
    let rank = p / 100.0 * (sorted.len() - 1) as f64;
    let lo = rank.floor() as usize;
    let hi = rank.ceil() as usize;
    if lo == hi {
        return Ok(sorted[lo]);
    }
    let frac = rank - lo as f64;
    Ok(sorted[lo] * (1.0 - frac) + sorted[hi] * frac)
}
/// Limits every element of `data` to `[lo, hi]`, following `numpy.clip` semantics.
///
/// Unlike `f64::clamp`, this never panics:
/// - if `lo > hi`, every element becomes `hi` (NumPy computes `min(max(v, lo), hi)`);
/// - if `lo` or `hi` is NaN, every element becomes NaN;
/// - NaN elements pass through unchanged.
#[allow(dead_code)]
pub fn slice_clip(data: &[f64], lo: f64, hi: f64) -> Vec<f64> {
    if lo.is_nan() || hi.is_nan() {
        return vec![f64::NAN; data.len()];
    }
    data.iter()
        .map(|&v| {
            // The same comparisons `f64::clamp` performs, minus its `lo <= hi` assertion.
            let v = if v < lo { lo } else { v };
            if v > hi {
                hi
            } else {
                v
            }
        })
        .collect()
}
/// Element-wise sum of two equally long slices.
///
/// # Errors
/// Returns an error naming both lengths when `a` and `b` differ in length.
#[allow(dead_code)]
pub fn slice_add(a: &[f64], b: &[f64]) -> std::result::Result<Vec<f64>, String> {
    if a.len() == b.len() {
        let mut sums = Vec::with_capacity(a.len());
        for (x, y) in a.iter().zip(b) {
            sums.push(x + y);
        }
        Ok(sums)
    } else {
        Err(format!(
            "slice_add: length mismatch {} vs {}",
            a.len(),
            b.len()
        ))
    }
}
/// Element-wise product of two equally long slices.
///
/// # Errors
/// Returns an error naming both lengths when `a` and `b` differ in length.
#[allow(dead_code)]
pub fn slice_mul(a: &[f64], b: &[f64]) -> std::result::Result<Vec<f64>, String> {
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(format!("slice_mul: length mismatch {} vs {}", len_a, len_b));
    }
    let products: Vec<f64> = a.iter().zip(b).map(|(x, y)| x * y).collect();
    Ok(products)
}
/// Dot product (sum of element-wise products) of two equally long slices.
///
/// Products are summed in index order without building an intermediate vector.
///
/// # Errors
/// Returns an error naming both lengths when `a` and `b` differ in length.
#[allow(dead_code)]
pub fn slice_dot(a: &[f64], b: &[f64]) -> std::result::Result<f64, String> {
    if a.len() != b.len() {
        // Report under this function's own name rather than `slice_mul`'s.
        return Err(format!(
            "slice_dot: length mismatch {} vs {}",
            a.len(),
            b.len()
        ));
    }
    Ok(a.iter().zip(b).map(|(&x, &y)| x * y).sum())
}
/// Returns `n` evenly spaced values from `start` to `stop`, both endpoints included
/// (like `numpy.linspace` with its default `endpoint=True`).
///
/// `n == 0` yields an empty vector and `n == 1` yields `[start]`. The last element is
/// set to `stop` exactly. Evaluating `start + (stop - start) * (n - 1) / (n - 1)` can be
/// off by an ulp: `linspace(0.0, 0.1, 4)` would otherwise end at `0.10000000000000002`.
#[allow(dead_code)]
pub fn linspace(start: f64, stop: f64, n: usize) -> Vec<f64> {
    if n == 0 {
        return Vec::new();
    }
    if n == 1 {
        return vec![start];
    }
    let span = stop - start;
    let denom = (n - 1) as f64;
    let mut values: Vec<f64> = (0..n).map(|i| start + span * i as f64 / denom).collect();
    values[n - 1] = stop;
    values
}
/// Returns values from `start` up to (excluding) `stop` in increments of `step`,
/// like `numpy.arange`. The element count is `ceil((stop - start) / step)`.
///
/// A range that runs opposite to the sign of `step` yields an empty vector.
///
/// # Errors
/// Returns an error when `step` is zero, or when the element count cannot be
/// computed because it is infinite or NaN (non-finite bounds or `step`). NumPy raises
/// in the same situations.
#[allow(dead_code)]
pub fn arange(start: f64, stop: f64, step: f64) -> std::result::Result<Vec<f64>, String> {
    if step == 0.0 {
        return Err("arange: step cannot be zero".to_string());
    }
    let span = (stop - start) / step;
    if span < 0.0 {
        return Ok(Vec::new());
    }
    // `as usize` would saturate an infinite span to usize::MAX (a capacity-overflow
    // panic on collect) and silently map NaN to 0.
    if !span.is_finite() {
        return Err("arange: cannot compute length of a non-finite range".to_string());
    }
    let n = span.ceil() as usize;
    Ok((0..n).map(|i| start + i as f64 * step).collect())
}
474#[allow(dead_code)]
478pub fn logspace(start: f64, stop: f64, n: usize) -> Vec<f64> {
479 linspace(start, stop, n)
480 .into_iter()
481 .map(|v| 10.0_f64.powf(v))
482 .collect()
483}
/// Transposes a row-major 2-D array, returning the new data and the swapped shape
/// `[ncols, nrows]`.
///
/// # Errors
/// Returns an error when `shape` is not 2-D, or when `data.len()` is not exactly
/// `nrows * ncols` (including when that product overflows `usize`).
#[allow(dead_code)]
pub fn transpose_2d(
    data: &[f64],
    shape: &[usize],
) -> std::result::Result<(Vec<f64>, Vec<usize>), String> {
    if shape.len() != 2 {
        return Err(format!(
            "transpose_2d requires 2-D shape, got {}D",
            shape.len()
        ));
    }
    let nrows = shape[0];
    let ncols = shape[1];
    // Checked: a wrapped `nrows * ncols` could otherwise equal `data.len()` and
    // lead to out-of-bounds indexing below.
    if nrows.checked_mul(ncols) != Some(data.len()) {
        return Err(format!(
            "transpose_2d: data length {} != {}*{}",
            data.len(),
            nrows,
            ncols
        ));
    }
    let mut out = vec![0.0_f64; data.len()];
    // With ncols == 0 there is nothing to move (and `chunks_exact(0)` would panic).
    if ncols > 0 {
        for (r, row) in data.chunks_exact(ncols).enumerate() {
            for (c, &value) in row.iter().enumerate() {
                out[c * nrows + r] = value;
            }
        }
    }
    Ok((out, vec![ncols, nrows]))
}
/// Multiplies two row-major 2-D matrices `a` (`m x k`) and `b` (`k x n`), returning
/// the `m x n` product and its shape.
///
/// Each output element accumulates `a[i][kk] * b[kk][j]` in increasing `kk` order,
/// starting from `0.0`.
///
/// # Errors
/// Returns an error when either input is not 2-D, the inner dimensions differ, a data
/// length does not match its shape (including shapes whose size overflows `usize`), or
/// the output size overflows `usize`.
#[allow(dead_code)]
pub fn matmul(
    a: &[f64],
    a_shape: &[usize],
    b: &[f64],
    b_shape: &[usize],
) -> std::result::Result<(Vec<f64>, Vec<usize>), String> {
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err("matmul: both inputs must be 2-D".to_string());
    }
    let (m, k_a) = (a_shape[0], a_shape[1]);
    let (k_b, n) = (b_shape[0], b_shape[1]);
    if k_a != k_b {
        return Err(format!(
            "matmul: inner dimensions mismatch ({k_a} vs {k_b})"
        ));
    }
    // Checked products: a wrapped size could otherwise match the data length and
    // lead to out-of-bounds indexing.
    if m.checked_mul(k_a) != Some(a.len()) || k_b.checked_mul(n) != Some(b.len()) {
        return Err("matmul: data length does not match shape".to_string());
    }
    let out_len = m
        .checked_mul(n)
        .ok_or_else(|| format!("matmul: output shape [{m}, {n}] overflows usize"))?;
    let mut c = vec![0.0_f64; out_len];
    // i-k-j loop order walks rows of `b` and `c` contiguously. Each c[i][j] still
    // receives its terms in increasing `kk` order, so results match the i-j-k
    // formulation bit for bit. With k == 0 or n == 0 the zero-filled output is
    // already final (and `chunks_exact(0)` would panic).
    if k_a > 0 && n > 0 {
        for (a_row, c_row) in a.chunks_exact(k_a).zip(c.chunks_exact_mut(n)) {
            for (&a_ik, b_row) in a_row.iter().zip(b.chunks_exact(n)) {
                for (c_ij, &b_kj) in c_row.iter_mut().zip(b_row) {
                    *c_ij += a_ik * b_kj;
                }
            }
        }
    }
    Ok((c, vec![m, n]))
}