ferray-io 0.3.1

NumPy-compatible file I/O (.npy, .npz, memory-mapped) for ferray
Documentation
// ferray-io: NumPy dtype string parsing
//
// Parses dtype descriptor strings like "<f8", "|b1", ">i4" into
// (DType, Endianness) pairs.

use ferray_core::dtype::DType;
use ferray_core::error::{FerrayError, FerrayResult};

/// Byte order declared by the prefix character of a `NumPy` dtype string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Endianness {
    /// Least-significant byte first (`<` prefix).
    Little,
    /// Most-significant byte first (`>` prefix).
    Big,
    /// No explicit byte order: the `|` ("not applicable" / single byte)
    /// prefix, or `=` for native.
    Native,
}

impl Endianness {
    /// Reports whether data stored in this byte order has to be byte-swapped
    /// on the platform the crate is compiled for.
    ///
    /// [`Endianness::Native`] never needs a swap. The two explicit orders
    /// need one exactly when they differ from the target's own byte order.
    #[inline]
    #[must_use]
    pub const fn needs_swap(self) -> bool {
        // `target_endian` is either "little" or "big", so one flag is enough.
        let host_is_little = cfg!(target_endian = "little");
        match self {
            Self::Native => false,
            Self::Little => !host_is_little,
            Self::Big => host_is_little,
        }
    }
}

/// Decode a `NumPy` dtype descriptor string into its `(DType, Endianness)` pair.
///
/// The first byte may be a byte-order marker: `<` (little), `>` (big), or
/// `|` / `=` (both reported as `Endianness::Native`). When no marker is
/// present, the whole string is taken as the type code and the byte order is
/// reported as native.
///
/// Examples:
/// - `"<f8"` -> `(DType::F64, Endianness::Little)`
/// - `"|b1"` -> `(DType::Bool, Endianness::Native)`
/// - `">i4"` -> `(DType::I32, Endianness::Big)`
///
/// # Errors
/// Returns an error built with `FerrayError::invalid_dtype` when the string
/// is shorter than two bytes, when the type code is not one of the supported
/// descriptors, or when the bracketed unit of a datetime64 / timedelta64
/// descriptor is rejected by `TimeUnit::from_descr_suffix`.
pub fn parse_dtype_str(s: &str) -> FerrayResult<(DType, Endianness)> {
    if s.len() < 2 {
        return Err(FerrayError::invalid_dtype(format!(
            "dtype string too short: '{s}'"
        )));
    }

    // Each recognised marker is a single ASCII character, so stripping it
    // removes exactly one byte; any other leading character is left in place
    // and the byte order defaults to native.
    let (endian, type_str) = if let Some(rest) = s.strip_prefix('<') {
        (Endianness::Little, rest)
    } else if let Some(rest) = s.strip_prefix('>') {
        (Endianness::Big, rest)
    } else if let Some(rest) = s.strip_prefix(|c: char| c == '|' || c == '=') {
        (Endianness::Native, rest)
    } else {
        (Endianness::Native, s)
    };

    let dtype = match type_str {
        "b1" => DType::Bool,
        "u1" => DType::U8,
        "u2" => DType::U16,
        "u4" => DType::U32,
        "u8" => DType::U64,
        "u16" => DType::U128,
        "i1" => DType::I8,
        "i2" => DType::I16,
        "i4" => DType::I32,
        "i8" => DType::I64,
        "i16" => DType::I128,
        // IEEE 754 binary16 — standard NumPy `np.float16` descriptor.
        #[cfg(feature = "f16")]
        "f2" => DType::F16,
        "f4" => DType::F32,
        "f8" => DType::F64,
        // bfloat16 has no standard NumPy descriptor; the ferray-specific tag
        // `"bf16"` keeps ferray round-trips lossless, but files saved this
        // way are NOT loadable by vanilla NumPy.
        #[cfg(feature = "bf16")]
        "bf16" => DType::BF16,
        "c8" => DType::Complex32,
        "c16" => DType::Complex64,
        other => {
            // datetime64 / timedelta64 carry an explicit unit, e.g. "M8[ns]"
            // or "m8[ns]": the leading 'M'/'m' is the type tag (M = datetime,
            // m = timedelta), '8' is the byte size, and the bracketed text is
            // the TimeUnit.
            if let Some(unit_str) = other
                .strip_prefix("M8[")
                .and_then(|tail| tail.strip_suffix(']'))
            {
                match ferray_core::dtype::TimeUnit::from_descr_suffix(unit_str) {
                    Some(unit) => DType::DateTime64(unit),
                    None => {
                        return Err(FerrayError::invalid_dtype(format!(
                            "unknown datetime64 unit: '{unit_str}'"
                        )));
                    }
                }
            } else if let Some(unit_str) = other
                .strip_prefix("m8[")
                .and_then(|tail| tail.strip_suffix(']'))
            {
                match ferray_core::dtype::TimeUnit::from_descr_suffix(unit_str) {
                    Some(unit) => DType::Timedelta64(unit),
                    None => {
                        return Err(FerrayError::invalid_dtype(format!(
                            "unknown timedelta64 unit: '{unit_str}'"
                        )));
                    }
                }
            } else {
                return Err(FerrayError::invalid_dtype(format!(
                    "unsupported dtype descriptor: '{s}'"
                )));
            }
        }
    };

    Ok((dtype, endian))
}

/// Render a `DType` as a `NumPy` dtype descriptor string for the given byte order.
///
/// The byte-order marker is `<` for `Endianness::Little`, `>` for
/// `Endianness::Big` and `|` for `Endianness::Native`. `Bool`, `U8` and `I8`
/// are always written with `|`, whatever `endian` is. datetime64 and
/// timedelta64 are written as `M8[unit]` / `m8[unit]` after the marker.
///
/// NOTE(review): `Endianness::Native` on a multi-byte type produces a `|`
/// marker, which does not record the actual byte order of the data — confirm
/// that callers writing portable files pass `Little` / `Big` instead.
///
/// # Errors
/// Returns `FerrayError::InvalidDtype` for unsupported dtype variants.
pub fn dtype_to_descr(dtype: DType, endian: Endianness) -> FerrayResult<String> {
    let prefix = match endian {
        Endianness::Little => '<',
        Endianness::Big => '>',
        Endianness::Native => '|',
    };

    let type_str = match dtype {
        // Unit-carrying types don't fit the plain `<marker><code>` shape, so
        // they are formatted and returned straight from their arms.
        DType::DateTime64(u) => return Ok(format!("{prefix}M8[{}]", u.descr_suffix())),
        DType::Timedelta64(u) => return Ok(format!("{prefix}m8[{}]", u.descr_suffix())),
        DType::Bool => "b1",
        DType::U8 => "u1",
        DType::U16 => "u2",
        DType::U32 => "u4",
        DType::U64 => "u8",
        DType::U128 => "u16",
        DType::I8 => "i1",
        DType::I16 => "i2",
        DType::I32 => "i4",
        DType::I64 => "i8",
        DType::I128 => "i16",
        #[cfg(feature = "f16")]
        DType::F16 => "f2",
        DType::F32 => "f4",
        DType::F64 => "f8",
        #[cfg(feature = "bf16")]
        DType::BF16 => "bf16",
        DType::Complex32 => "c8",
        DType::Complex64 => "c16",
        _ => {
            return Err(FerrayError::invalid_dtype(format!(
                "unsupported dtype for descriptor: {dtype:?}"
            )));
        }
    };

    // Bool and the one-byte integers have no byte order to speak of, so they
    // always carry the '|' marker.
    let is_single_byte = matches!(dtype, DType::Bool | DType::U8 | DType::I8);
    let actual_prefix = if is_single_byte { '|' } else { prefix };

    Ok(format!("{actual_prefix}{type_str}"))
}

/// Return the native-endian dtype descriptor for a given `DType`.
///
/// # Errors
/// Returns `FerrayError::InvalidDtype` for unsupported dtype variants.
pub fn dtype_to_native_descr(dtype: DType) -> FerrayResult<String> {
    let endian = if cfg!(target_endian = "little") {
        Endianness::Little
    } else {
        Endianness::Big
    };
    dtype_to_descr(dtype, endian)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dtype::TimeUnit;

    /// Time units exercised by the datetime64 / timedelta64 round-trip tests.
    const ROUNDTRIP_UNITS: [TimeUnit; 5] = [
        TimeUnit::Ns,
        TimeUnit::Us,
        TimeUnit::Ms,
        TimeUnit::S,
        TimeUnit::D,
    ];

    /// Check that each descriptor parses to exactly the paired dtype and
    /// byte order, in table order.
    #[track_caller]
    fn assert_parses(cases: &[(&str, DType, Endianness)]) {
        for &(text, expected_dtype, expected_endian) in cases {
            assert_eq!(
                parse_dtype_str(text).unwrap(),
                (expected_dtype, expected_endian)
            );
        }
    }

    #[test]
    fn parse_common_dtypes() {
        assert_parses(&[
            ("<f8", DType::F64, Endianness::Little),
            ("<f4", DType::F32, Endianness::Little),
            (">i4", DType::I32, Endianness::Big),
            ("<i8", DType::I64, Endianness::Little),
            ("|b1", DType::Bool, Endianness::Native),
            ("<u1", DType::U8, Endianness::Little),
            ("<c8", DType::Complex32, Endianness::Little),
            ("<c16", DType::Complex64, Endianness::Little),
        ]);
    }

    #[test]
    fn parse_unsigned_types() {
        assert_parses(&[
            ("<u2", DType::U16, Endianness::Little),
            ("<u4", DType::U32, Endianness::Little),
            ("<u8", DType::U64, Endianness::Little),
        ]);
    }

    #[test]
    fn parse_128bit_types() {
        assert_parses(&[
            ("<i16", DType::I128, Endianness::Little),
            ("<u16", DType::U128, Endianness::Little),
        ]);
    }

    #[test]
    fn parse_invalid() {
        // Too short, unknown type code, and empty input respectively.
        for bad in ["x", "<z4", ""] {
            assert!(parse_dtype_str(bad).is_err());
        }
    }

    #[test]
    fn roundtrip_descr() {
        for dt in [
            DType::Bool,
            DType::U8,
            DType::U16,
            DType::U32,
            DType::U64,
            DType::I8,
            DType::I16,
            DType::I32,
            DType::I64,
            DType::F32,
            DType::F64,
            DType::Complex32,
            DType::Complex64,
        ] {
            let descr = dtype_to_native_descr(dt).unwrap();
            let (parsed_dt, _) = parse_dtype_str(&descr).unwrap();
            assert_eq!(
                parsed_dt, dt,
                "roundtrip failed for {dt:?}: descr='{descr}'"
            );
        }
    }

    #[cfg(feature = "f16")]
    #[test]
    fn parse_f16_descriptor() {
        assert_parses(&[
            ("<f2", DType::F16, Endianness::Little),
            (">f2", DType::F16, Endianness::Big),
        ]);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn f16_roundtrip_descr() {
        let descr = dtype_to_native_descr(DType::F16).unwrap();
        let (parsed, _) = parse_dtype_str(&descr).unwrap();
        assert_eq!(parsed, DType::F16);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn parse_bf16_descriptor() {
        assert_parses(&[("<bf16", DType::BF16, Endianness::Little)]);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn bf16_roundtrip_descr() {
        let descr = dtype_to_native_descr(DType::BF16).unwrap();
        let (parsed, _) = parse_dtype_str(&descr).unwrap();
        assert_eq!(parsed, DType::BF16);
    }

    #[test]
    fn endianness_swap() {
        // An explicit order needs a swap exactly when it differs from the host.
        let on_little_host = cfg!(target_endian = "little");
        assert_eq!(Endianness::Little.needs_swap(), !on_little_host);
        assert_eq!(Endianness::Big.needs_swap(), on_little_host);
        assert!(!Endianness::Native.needs_swap());
    }

    // -- datetime64 / timedelta64 descriptor parsing --

    #[test]
    fn parse_datetime64_ns() {
        let (dt, e) = parse_dtype_str("<M8[ns]").unwrap();
        assert_eq!(dt, DType::DateTime64(TimeUnit::Ns));
        assert_eq!(e, Endianness::Little);
    }

    #[test]
    fn parse_datetime64_us() {
        let (dt, _) = parse_dtype_str("<M8[us]").unwrap();
        assert_eq!(dt, DType::DateTime64(TimeUnit::Us));
    }

    #[test]
    fn parse_timedelta64_ms() {
        let (dt, _) = parse_dtype_str("<m8[ms]").unwrap();
        assert_eq!(dt, DType::Timedelta64(TimeUnit::Ms));
    }

    #[test]
    fn datetime64_roundtrip_descr() {
        for dt in ROUNDTRIP_UNITS.map(DType::DateTime64) {
            let s = dtype_to_descr(dt, Endianness::Little).unwrap();
            let (parsed, _) = parse_dtype_str(&s).unwrap();
            assert_eq!(parsed, dt, "roundtrip failed for {dt}");
        }
    }

    #[test]
    fn timedelta64_roundtrip_descr() {
        for dt in ROUNDTRIP_UNITS.map(DType::Timedelta64) {
            let s = dtype_to_descr(dt, Endianness::Little).unwrap();
            let (parsed, _) = parse_dtype_str(&s).unwrap();
            assert_eq!(parsed, dt, "roundtrip failed for {dt}");
        }
    }

    #[test]
    fn unknown_datetime_unit_errors() {
        assert!(parse_dtype_str("<M8[foobar]").is_err());
    }
}