Skip to main content

oxiphysics_io/numpy/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::types::NpyDtype;
6
7pub(super) const NPY_MAGIC: &[u8; 6] = b"\x93NUMPY";
8pub(super) const NPY_MAJOR: u8 = 1;
9pub(super) const NPY_MINOR: u8 = 0;
/// Validate that a shape and data length are consistent.
///
/// The element count implied by `shape` is the product of its dimensions; an
/// empty shape therefore implies exactly one element.
///
/// # Errors
///
/// Returns an error message when the product of `shape` differs from
/// `data_len`, or when that product does not fit in a `usize`.
pub fn validate_shape(shape: &[usize], data_len: usize) -> Result<(), String> {
    // A zero-length axis makes the array empty no matter how large the other
    // axes are, so settle that case before multiplying anything.
    let expected = if shape.contains(&0) {
        0
    } else {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| format!("shape {shape:?} element count overflows usize"))?
    };
    if expected == data_len {
        Ok(())
    } else {
        Err(format!(
            "shape {shape:?} requires {expected} elements but got {data_len}"
        ))
    }
}
/// Compute the flat index from multi-dimensional indices (row-major).
///
/// The last axis varies fastest: the stride of an axis is the product of the
/// sizes of all axes after it.
///
/// # Errors
///
/// Returns an error message when `indices` and `shape` have different lengths,
/// when any index is not smaller than the size of its axis, or when the
/// resulting flat index does not fit in a `usize`.
pub fn flat_index(indices: &[usize], shape: &[usize]) -> Result<usize, String> {
    if indices.len() != shape.len() {
        return Err(format!(
            "index dimensionality {} != shape dimensionality {}",
            indices.len(),
            shape.len()
        ));
    }
    let mut idx = 0usize;
    // `None` once the running stride no longer fits in a usize. That only
    // matters if a later (outer) axis is addressed with a non-zero index.
    let mut stride = Some(1usize);
    for (axis, (&index, &extent)) in indices.iter().zip(shape).enumerate().rev() {
        if index >= extent {
            return Err(format!(
                "index {} out of range for axis {} with size {}",
                index, axis, extent
            ));
        }
        if index != 0 {
            let term = stride
                .and_then(|s| s.checked_mul(index))
                .ok_or_else(|| format!("flat index overflows usize for shape {shape:?}"))?;
            idx = idx
                .checked_add(term)
                .ok_or_else(|| format!("flat index overflows usize for shape {shape:?}"))?;
        }
        stride = stride.and_then(|s| s.checked_mul(extent));
    }
    Ok(idx)
}
/// Compute multi-dimensional indices from a flat index (row-major).
///
/// This is the inverse of a row-major flattening: the last axis varies fastest.
///
/// # Errors
///
/// Returns an error message when `flat` is not smaller than the total number
/// of elements described by `shape` (which is always the case when any axis
/// has size zero).
pub fn unravel_index(flat: usize, shape: &[usize]) -> Result<Vec<usize>, String> {
    // `None` means the element count exceeds usize::MAX, in which case every
    // representable flat index is in range.
    let total = if shape.contains(&0) {
        Some(0)
    } else {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    };
    if let Some(total) = total {
        if flat >= total {
            return Err(format!("flat index {flat} out of range for total {total}"));
        }
    }
    let mut indices = vec![0usize; shape.len()];
    let mut remaining = flat;
    // No axis is zero past this point, so the divisions below are safe.
    for (slot, &extent) in indices.iter_mut().zip(shape).rev() {
        *slot = remaining % extent;
        remaining /= extent;
    }
    Ok(indices)
}
58/// Build the NPY v1.0 header string for the given dtype and shape.
59pub(super) fn build_npy_header(dtype_str: &str, shape: &[usize]) -> Vec<u8> {
60    let shape_str = if shape.is_empty() {
61        "()".to_string()
62    } else if shape.len() == 1 {
63        format!("({},)", shape[0])
64    } else {
65        let inner: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
66        format!("({})", inner.join(", "))
67    };
68    let dict = format!(
69        "{{'descr': '{}', 'fortran_order': False, 'shape': {}, }}",
70        dtype_str, shape_str
71    );
72    let mut header_bytes = dict.into_bytes();
73    header_bytes.push(b'\n');
74    while (10 + header_bytes.len()) % 64 != 0 {
75        let last = header_bytes.len() - 1;
76        header_bytes.insert(last, b' ');
77    }
78    header_bytes
79}
80/// Assemble a complete `.npy` v1.0 byte sequence for a `f64` array.
81pub fn write_npy_f64(shape: &[usize], data: &[f64]) -> Vec<u8> {
82    let header_bytes = build_npy_header("<f8", shape);
83    let header_len = header_bytes.len() as u16;
84    let mut out: Vec<u8> = Vec::new();
85    out.extend_from_slice(NPY_MAGIC);
86    out.push(NPY_MAJOR);
87    out.push(NPY_MINOR);
88    out.extend_from_slice(&header_len.to_le_bytes());
89    out.extend_from_slice(&header_bytes);
90    for &v in data {
91        out.extend_from_slice(&v.to_le_bytes());
92    }
93    out
94}
95/// Assemble a complete `.npy` v1.0 byte sequence for an `f32` array.
96pub fn write_npy_f32(shape: &[usize], data: &[f32]) -> Vec<u8> {
97    let header_bytes = build_npy_header("<f4", shape);
98    let header_len = header_bytes.len() as u16;
99    let mut out: Vec<u8> = Vec::new();
100    out.extend_from_slice(NPY_MAGIC);
101    out.push(NPY_MAJOR);
102    out.push(NPY_MINOR);
103    out.extend_from_slice(&header_len.to_le_bytes());
104    out.extend_from_slice(&header_bytes);
105    for &v in data {
106        out.extend_from_slice(&v.to_le_bytes());
107    }
108    out
109}
110/// Assemble a complete `.npy` v1.0 byte sequence for an `i32` array.
111pub fn write_npy_i32(shape: &[usize], data: &[i32]) -> Vec<u8> {
112    let header_bytes = build_npy_header("<i4", shape);
113    let header_len = header_bytes.len() as u16;
114    let mut out: Vec<u8> = Vec::new();
115    out.extend_from_slice(NPY_MAGIC);
116    out.push(NPY_MAJOR);
117    out.push(NPY_MINOR);
118    out.extend_from_slice(&header_len.to_le_bytes());
119    out.extend_from_slice(&header_bytes);
120    for &v in data {
121        out.extend_from_slice(&v.to_le_bytes());
122    }
123    out
124}
125/// Assemble a complete `.npy` v1.0 byte sequence for an `i64` array.
126pub fn write_npy_i64(shape: &[usize], data: &[i64]) -> Vec<u8> {
127    let header_bytes = build_npy_header("<i8", shape);
128    let header_len = header_bytes.len() as u16;
129    let mut out: Vec<u8> = Vec::new();
130    out.extend_from_slice(NPY_MAGIC);
131    out.push(NPY_MAJOR);
132    out.push(NPY_MINOR);
133    out.extend_from_slice(&header_len.to_le_bytes());
134    out.extend_from_slice(&header_bytes);
135    for &v in data {
136        out.extend_from_slice(&v.to_le_bytes());
137    }
138    out
139}
140/// Parse the NPY v1.0 header and return `(dtype_str, shape, data_offset)`.
141pub(super) fn parse_npy_header(bytes: &[u8]) -> Result<(String, Vec<usize>, usize), String> {
142    if bytes.len() < 10 {
143        return Err("npy data too short".to_string());
144    }
145    if &bytes[0..6] != NPY_MAGIC {
146        return Err(format!("bad npy magic: {:?}", &bytes[0..6]));
147    }
148    let major = bytes[6];
149    let minor = bytes[7];
150    if major != 1 || minor != 0 {
151        return Err(format!("unsupported npy version: {major}.{minor}"));
152    }
153    let header_len = u16::from_le_bytes([bytes[8], bytes[9]]) as usize;
154    let data_start = 10 + header_len;
155    if bytes.len() < data_start {
156        return Err("npy header truncated".to_string());
157    }
158    let header_str = std::str::from_utf8(&bytes[10..data_start])
159        .map_err(|e| format!("npy header not utf-8: {e}"))?
160        .trim();
161    let dtype_str = extract_dict_value(header_str, "descr")?;
162    let shape_str = extract_dict_value(header_str, "shape")?;
163    let shape = parse_shape_tuple(&shape_str)?;
164    Ok((dtype_str, shape, data_start))
165}
166/// Extract a value from a Python-style dict string for a given key.
167pub(super) fn extract_dict_value(header: &str, key: &str) -> Result<String, String> {
168    let search = format!("'{key}'");
169    let pos = header
170        .find(&search)
171        .ok_or_else(|| format!("key '{key}' not found in npy header"))?;
172    let rest = &header[pos + search.len()..];
173    let rest = rest.trim_start();
174    let rest = rest
175        .strip_prefix(':')
176        .ok_or("missing ':' after key")?
177        .trim_start();
178    if rest.starts_with('\'') {
179        let inner = rest.strip_prefix('\'').expect("prefix should be present");
180        let end = inner.find('\'').ok_or("unterminated string value")?;
181        Ok(inner[..end].to_string())
182    } else if rest.starts_with('(') {
183        let end = rest.find(')').ok_or("unterminated tuple value")? + 1;
184        Ok(rest[..end].to_string())
185    } else {
186        let end = rest.find([',', '}']).unwrap_or(rest.len());
187        Ok(rest[..end].trim().to_string())
188    }
189}
190/// Parse a Python tuple string like `"(3, 2)"` or `"(6,)"` or `"()"` into `Vec`usize`.
191pub(super) fn parse_shape_tuple(s: &str) -> Result<Vec<usize>, String> {
192    let inner = s.trim();
193    let inner = inner
194        .strip_prefix('(')
195        .ok_or("shape missing '('")?
196        .strip_suffix(')')
197        .ok_or("shape missing ')'")?;
198    if inner.trim().is_empty() {
199        return Ok(vec![]);
200    }
201    let mut dims = Vec::new();
202    for part in inner.split(',') {
203        let part = part.trim();
204        if part.is_empty() {
205            continue;
206        }
207        let d: usize = part
208            .parse()
209            .map_err(|e| format!("bad shape dimension '{part}': {e}"))?;
210        dims.push(d);
211    }
212    Ok(dims)
213}
214/// Parse a `.npy` byte buffer and return `(shape, f64 data)`.
215pub fn read_npy_f64(bytes: &[u8]) -> Result<(Vec<usize>, Vec<f64>), String> {
216    let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
217    if dtype_str != "<f8" {
218        return Err(format!("expected dtype '<f8', got '{dtype_str}'"));
219    }
220    let n_elems: usize = shape.iter().product();
221    let expected_bytes = data_start + n_elems * 8;
222    if bytes.len() < expected_bytes {
223        return Err(format!(
224            "data truncated: expected {expected_bytes} bytes, got {}",
225            bytes.len()
226        ));
227    }
228    let mut data = Vec::with_capacity(n_elems);
229    let mut pos = data_start;
230    for _ in 0..n_elems {
231        let v = f64::from_le_bytes(
232            bytes[pos..pos + 8]
233                .try_into()
234                .expect("slice length must match"),
235        );
236        pos += 8;
237        data.push(v);
238    }
239    Ok((shape, data))
240}
241/// Parse a `.npy` byte buffer and return `(shape, f32 data)`.
242pub fn read_npy_f32(bytes: &[u8]) -> Result<(Vec<usize>, Vec<f32>), String> {
243    let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
244    if dtype_str != "<f4" {
245        return Err(format!("expected dtype '<f4', got '{dtype_str}'"));
246    }
247    let n_elems: usize = shape.iter().product();
248    let expected_bytes = data_start + n_elems * 4;
249    if bytes.len() < expected_bytes {
250        return Err(format!(
251            "data truncated: expected {expected_bytes} bytes, got {}",
252            bytes.len()
253        ));
254    }
255    let mut data = Vec::with_capacity(n_elems);
256    let mut pos = data_start;
257    for _ in 0..n_elems {
258        let v = f32::from_le_bytes(
259            bytes[pos..pos + 4]
260                .try_into()
261                .expect("slice length must match"),
262        );
263        pos += 4;
264        data.push(v);
265    }
266    Ok((shape, data))
267}
268/// Parse a `.npy` byte buffer and return `(shape, i32 data)`.
269pub fn read_npy_i32(bytes: &[u8]) -> Result<(Vec<usize>, Vec<i32>), String> {
270    let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
271    if dtype_str != "<i4" {
272        return Err(format!("expected dtype '<i4', got '{dtype_str}'"));
273    }
274    let n_elems: usize = shape.iter().product();
275    let expected_bytes = data_start + n_elems * 4;
276    if bytes.len() < expected_bytes {
277        return Err(format!(
278            "data truncated: expected {expected_bytes} bytes, got {}",
279            bytes.len()
280        ));
281    }
282    let mut data = Vec::with_capacity(n_elems);
283    let mut pos = data_start;
284    for _ in 0..n_elems {
285        let v = i32::from_le_bytes(
286            bytes[pos..pos + 4]
287                .try_into()
288                .expect("slice length must match"),
289        );
290        pos += 4;
291        data.push(v);
292    }
293    Ok((shape, data))
294}
295/// Parse a `.npy` byte buffer and return `(shape, i64 data)`.
296pub fn read_npy_i64(bytes: &[u8]) -> Result<(Vec<usize>, Vec<i64>), String> {
297    let (dtype_str, shape, data_start) = parse_npy_header(bytes)?;
298    if dtype_str != "<i8" {
299        return Err(format!("expected dtype '<i8', got '{dtype_str}'"));
300    }
301    let n_elems: usize = shape.iter().product();
302    let expected_bytes = data_start + n_elems * 8;
303    if bytes.len() < expected_bytes {
304        return Err(format!(
305            "data truncated: expected {expected_bytes} bytes, got {}",
306            bytes.len()
307        ));
308    }
309    let mut data = Vec::with_capacity(n_elems);
310    let mut pos = data_start;
311    for _ in 0..n_elems {
312        let v = i64::from_le_bytes(
313            bytes[pos..pos + 8]
314                .try_into()
315                .expect("slice length must match"),
316        );
317        pos += 8;
318        data.push(v);
319    }
320    Ok((shape, data))
321}
322/// Auto-detect dtype from header and return the NpyDtype.
323pub fn detect_npy_dtype(bytes: &[u8]) -> Result<NpyDtype, String> {
324    let (dtype_str, _, _) = parse_npy_header(bytes)?;
325    NpyDtype::from_numpy_str(&dtype_str)
326}
327/// Auto-detect and return the shape from a .npy byte buffer.
328pub fn read_npy_shape(bytes: &[u8]) -> Result<Vec<usize>, String> {
329    let (_, shape, _) = parse_npy_header(bytes)?;
330    Ok(shape)
331}
332pub(super) fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
333    if *pos + 4 > data.len() {
334        return Err(format!("unexpected EOF reading u32 at offset {pos}"));
335    }
336    let v = u32::from_le_bytes(
337        data[*pos..*pos + 4]
338            .try_into()
339            .expect("slice length must match"),
340    );
341    *pos += 4;
342    Ok(v)
343}
/// Arithmetic mean of the values in `data`.
///
/// Returns `None` for an empty slice.
#[allow(dead_code)]
pub fn slice_mean(data: &[f64]) -> Option<f64> {
    match data.len() {
        0 => None,
        count => {
            let total: f64 = data.iter().sum();
            Some(total / count as f64)
        }
    }
}
354/// Compute the variance of a slice (population variance, ddof=0).
355#[allow(dead_code)]
356pub fn slice_var(data: &[f64]) -> Option<f64> {
357    let mean = slice_mean(data)?;
358    let var = data.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>() / data.len() as f64;
359    Some(var)
360}
361/// Compute the standard deviation (population, ddof=0).
362#[allow(dead_code)]
363pub fn slice_std(data: &[f64]) -> Option<f64> {
364    Some(slice_var(data)?.sqrt())
365}
/// Find the minimum and maximum of `data` together with their flat indices.
///
/// Returns `(min, min_index, max, max_index)`, or `None` for an empty slice.
/// Comparisons are strict, so on ties the earliest position is reported.
#[allow(dead_code)]
pub fn slice_min_max(data: &[f64]) -> Option<(f64, usize, f64, usize)> {
    let (&first, rest) = data.split_first()?;
    let seed = (first, 0usize, first, 0usize);
    let extremes = rest
        .iter()
        .enumerate()
        .fold(seed, |(lo, lo_at, hi, hi_at), (offset, &x)| {
            // `rest` starts at element 1 of `data`.
            let at = offset + 1;
            let (lo, lo_at) = if x < lo { (x, at) } else { (lo, lo_at) };
            let (hi, hi_at) = if x > hi { (x, at) } else { (hi, hi_at) };
            (lo, lo_at, hi, hi_at)
        });
    Some(extremes)
}
/// Compute the p-th percentile of a slice using linear interpolation.
///
/// The data is copied and sorted; the percentile sits at fractional rank
/// `p / 100 * (len - 1)` and is interpolated between the two neighbouring
/// sorted values.
///
/// # Errors
///
/// Returns an error message when `data` is empty or when `p` is outside the
/// closed interval `[0, 100]` (NaN included).
#[allow(dead_code)]
pub fn slice_percentile(data: &[f64], p: f64) -> std::result::Result<f64, String> {
    if data.is_empty() {
        return Err("slice_percentile: empty slice".to_string());
    }
    if p.is_nan() || p < 0.0 || p > 100.0 {
        return Err(format!("percentile p={p} not in [0,100]"));
    }
    let mut ordered = data.to_vec();
    ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let rank = p / 100.0 * (ordered.len() - 1) as f64;
    let below = rank.floor() as usize;
    let above = rank.ceil() as usize;
    if below == above {
        Ok(ordered[below])
    } else {
        let frac = rank - below as f64;
        Ok(ordered[below] * (1.0 - frac) + ordered[above] * frac)
    }
}
/// Clip every value of `data` to the closed interval `[lo, hi]`.
///
/// # Panics
///
/// Each element is passed to `f64::clamp`, which panics if `lo > hi` or if
/// either bound is NaN.
#[allow(dead_code)]
pub fn slice_clip(data: &[f64], lo: f64, hi: f64) -> Vec<f64> {
    let mut clipped = Vec::with_capacity(data.len());
    for &value in data {
        clipped.push(value.clamp(lo, hi));
    }
    clipped
}
/// Element-wise sum of two equal-length slices.
///
/// # Errors
///
/// Returns an error message when `a` and `b` differ in length.
#[allow(dead_code)]
pub fn slice_add(a: &[f64], b: &[f64]) -> std::result::Result<Vec<f64>, String> {
    if a.len() == b.len() {
        let mut sums = Vec::with_capacity(a.len());
        sums.extend(a.iter().zip(b).map(|(x, y)| x + y));
        Ok(sums)
    } else {
        Err(format!(
            "slice_add: length mismatch {} vs {}",
            a.len(),
            b.len()
        ))
    }
}
/// Element-wise product of two equal-length slices.
///
/// # Errors
///
/// Returns an error message when `a` and `b` differ in length.
#[allow(dead_code)]
pub fn slice_mul(a: &[f64], b: &[f64]) -> std::result::Result<Vec<f64>, String> {
    match (a.len(), b.len()) {
        (left, right) if left != right => Err(format!(
            "slice_mul: length mismatch {} vs {}",
            left, right
        )),
        _ => Ok(a.iter().zip(b).map(|(x, y)| x * y).collect()),
    }
}
/// Dot product of two equal-length slices.
///
/// Products are accumulated in index order without building an intermediate
/// vector.
///
/// # Errors
///
/// Returns an error message when `a` and `b` differ in length.
#[allow(dead_code)]
pub fn slice_dot(a: &[f64], b: &[f64]) -> std::result::Result<f64, String> {
    if a.len() != b.len() {
        return Err(format!(
            "slice_dot: length mismatch {} vs {}",
            a.len(),
            b.len()
        ));
    }
    Ok(a.iter().zip(b).map(|(x, y)| x * y).sum())
}
/// Generate `n` equally-spaced values from `start` to `stop` (inclusive).
///
/// Equivalent to `numpy.linspace(start, stop, num=n)`: `n == 0` yields an
/// empty vector, `n == 1` yields `[start]`, and for larger `n` the final
/// sample is exactly `stop`.
#[allow(dead_code)]
pub fn linspace(start: f64, stop: f64, n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let last = n - 1;
            let span = stop - start;
            let mut samples: Vec<f64> = Vec::with_capacity(n);
            samples.extend((0..last).map(|i| start + span * i as f64 / last as f64));
            // `start + (stop - start)` can differ from `stop` after rounding,
            // so the endpoint is stored directly instead of being recomputed.
            samples.push(stop);
            samples
        }
    }
}
/// Generate values from `start` to `stop` (exclusive) with step `step`.
///
/// Equivalent to `numpy.arange(start, stop, step)`: the result holds
/// `ceil((stop - start) / step)` values `start + i * step`, and is empty when
/// `step` points away from `stop`.
///
/// # Errors
///
/// Returns an error message when `step` is zero, or when the number of
/// elements is not finite (an infinite or NaN bound or step).
#[allow(dead_code)]
pub fn arange(start: f64, stop: f64, step: f64) -> std::result::Result<Vec<f64>, String> {
    if step == 0.0 {
        return Err("arange: step cannot be zero".to_string());
    }
    let count = ((stop - start) / step).ceil();
    // A float-to-usize cast saturates, so an infinite count would otherwise
    // turn into a request for usize::MAX elements.
    if !count.is_finite() {
        return Err(format!(
            "arange: cannot compute length for start={start}, stop={stop}, step={step}"
        ));
    }
    if count <= 0.0 {
        return Ok(Vec::new());
    }
    let n = count as usize;
    Ok((0..n).map(|i| start + i as f64 * step).collect())
}
474/// Generate `n` log-spaced values from `10^start` to `10^stop`.
475///
476/// Equivalent to `numpy.logspace(start, stop, num=n)`.
477#[allow(dead_code)]
478pub fn logspace(start: f64, stop: f64, n: usize) -> Vec<f64> {
479    linspace(start, stop, n)
480        .into_iter()
481        .map(|v| 10.0_f64.powf(v))
482        .collect()
483}
/// Transpose a 2-D row-major matrix stored as a flat `f64` slice.
///
/// `shape` must be `[nrows, ncols]`. Returns `(transposed_data, [ncols, nrows])`.
///
/// # Errors
///
/// Returns an error message when `shape` is not two-dimensional or when
/// `data.len()` differs from `nrows * ncols` (including when that product
/// does not fit in a `usize`).
#[allow(dead_code)]
pub fn transpose_2d(
    data: &[f64],
    shape: &[usize],
) -> std::result::Result<(Vec<f64>, Vec<usize>), String> {
    let (nrows, ncols) = match shape {
        [rows, cols] => (*rows, *cols),
        _ => {
            return Err(format!(
                "transpose_2d requires 2-D shape, got {}D",
                shape.len()
            ))
        }
    };
    // An overflowing product can never equal a real slice length, so it is
    // reported through the same mismatch error.
    if nrows.checked_mul(ncols) != Some(data.len()) {
        return Err(format!(
            "transpose_2d: data length {} != {}*{}",
            data.len(),
            nrows,
            ncols
        ));
    }
    let mut out = vec![0.0_f64; data.len()];
    // Non-empty data implies `ncols >= 1`, which `chunks_exact` requires;
    // empty data needs no work however large the other axis is.
    if !data.is_empty() {
        for (r, row) in data.chunks_exact(ncols).enumerate() {
            for (c, &value) in row.iter().enumerate() {
                out[c * nrows + r] = value;
            }
        }
    }
    Ok((out, vec![ncols, nrows]))
}
/// Compute the matrix product C = A * B where A is (m×k) and B is (k×n),
/// both stored as flat row-major `f64` slices.
///
/// Returns `(c, [m, n])` with `c` in row-major order.
///
/// # Errors
///
/// Returns an error message when either shape is not two-dimensional, when
/// the inner dimensions differ, when a data length does not match its shape,
/// or when the output would hold more elements than a `usize` can count.
#[allow(dead_code)]
pub fn matmul(
    a: &[f64],
    a_shape: &[usize],
    b: &[f64],
    b_shape: &[usize],
) -> std::result::Result<(Vec<f64>, Vec<usize>), String> {
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err("matmul: both inputs must be 2-D".to_string());
    }
    let (m, k_a) = (a_shape[0], a_shape[1]);
    let (k_b, n) = (b_shape[0], b_shape[1]);
    if k_a != k_b {
        return Err(format!(
            "matmul: inner dimensions mismatch ({k_a} vs {k_b})"
        ));
    }
    let k = k_a;
    // An overflowing product can never equal a real slice length.
    if m.checked_mul(k) != Some(a.len()) || k.checked_mul(n) != Some(b.len()) {
        return Err("matmul: data length does not match shape".to_string());
    }
    // With k == 0 both inputs are empty, so m and n are unconstrained by the
    // checks above and their product has to be validated separately.
    let out_len = m
        .checked_mul(n)
        .ok_or_else(|| format!("matmul: output shape [{m}, {n}] is too large"))?;
    if out_len == 0 {
        return Ok((Vec::new(), vec![m, n]));
    }
    let mut c = vec![0.0_f64; out_len];
    for i in 0..m {
        let a_row = &a[i * k..(i + 1) * k];
        for j in 0..n {
            let mut acc = 0.0_f64;
            for (kk, &a_ik) in a_row.iter().enumerate() {
                acc += a_ik * b[kk * n + j];
            }
            c[i * n + j] = acc;
        }
    }
    Ok((c, vec![m, n]))
}