use ferray_core::dtype::DType;
use ferray_core::error::{FerrayError, FerrayResult};
/// Byte order attached to a dtype descriptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Endianness {
    /// Least-significant byte first.
    Little,
    /// Most-significant byte first.
    Big,
    /// Whatever the compilation target uses; never requires a swap.
    Native,
}

impl Endianness {
    /// Reports whether data stored in this byte order has to be byte-swapped
    /// to match the byte order of the compilation target.
    ///
    /// `Native` never needs a swap; `Little` and `Big` need one exactly when
    /// the target is the opposite endianness.
    #[inline]
    #[must_use]
    pub const fn needs_swap(self) -> bool {
        // `target_endian` is either "little" or "big", so one flag covers both.
        let host_is_little = cfg!(target_endian = "little");
        match self {
            Self::Native => false,
            Self::Little => !host_is_little,
            Self::Big => host_is_little,
        }
    }
}
/// Parses a dtype descriptor string such as `"<f8"`, `"|b1"` or `">M8[ns]"`
/// into its [`DType`] and [`Endianness`].
///
/// An optional leading character selects the byte order: `<` is little-endian,
/// `>` is big-endian, and `|` or `=` map to [`Endianness::Native`]. When no
/// such character is present the whole string is taken as the type code and
/// the byte order is [`Endianness::Native`].
///
/// # Errors
///
/// Returns `FerrayError::invalid_dtype` when the string is shorter than two
/// bytes, when the type code is not one of the recognised codes, or when a
/// `M8[..]` / `m8[..]` descriptor names a unit that
/// `TimeUnit::from_descr_suffix` does not accept.
pub fn parse_dtype_str(s: &str) -> FerrayResult<(DType, Endianness)> {
    // Returns the text between `tag_open` (which includes the `[`) and a
    // trailing `]`, or `None` when `descr` does not have that shape.
    fn bracketed_unit<'a>(descr: &'a str, tag_open: &str) -> Option<&'a str> {
        descr.strip_prefix(tag_open)?.strip_suffix(']')
    }

    if s.len() < 2 {
        return Err(FerrayError::invalid_dtype(format!(
            "dtype string too short: '{s}'"
        )));
    }

    // Peel off the byte-order marker, if there is one.
    let (endian, type_str) = if let Some(rest) = s.strip_prefix('<') {
        (Endianness::Little, rest)
    } else if let Some(rest) = s.strip_prefix('>') {
        (Endianness::Big, rest)
    } else if let Some(rest) = s.strip_prefix('|').or_else(|| s.strip_prefix('=')) {
        (Endianness::Native, rest)
    } else {
        (Endianness::Native, s)
    };

    let dtype = match type_str {
        "b1" => DType::Bool,
        "u1" => DType::U8,
        "u2" => DType::U16,
        "u4" => DType::U32,
        "u8" => DType::U64,
        "u16" => DType::U128,
        "i1" => DType::I8,
        "i2" => DType::I16,
        "i4" => DType::I32,
        "i8" => DType::I64,
        "i16" => DType::I128,
        #[cfg(feature = "f16")]
        "f2" => DType::F16,
        "f4" => DType::F32,
        "f8" => DType::F64,
        #[cfg(feature = "bf16")]
        "bf16" => DType::BF16,
        "c8" => DType::Complex32,
        "c16" => DType::Complex64,
        // Parameterised time types: `M8[unit]` (datetime) / `m8[unit]` (timedelta).
        other => {
            if let Some(unit_str) = bracketed_unit(other, "M8[") {
                let unit = ferray_core::dtype::TimeUnit::from_descr_suffix(unit_str)
                    .ok_or_else(|| {
                        FerrayError::invalid_dtype(format!(
                            "unknown datetime64 unit: '{unit_str}'"
                        ))
                    })?;
                DType::DateTime64(unit)
            } else if let Some(unit_str) = bracketed_unit(other, "m8[") {
                let unit = ferray_core::dtype::TimeUnit::from_descr_suffix(unit_str)
                    .ok_or_else(|| {
                        FerrayError::invalid_dtype(format!(
                            "unknown timedelta64 unit: '{unit_str}'"
                        ))
                    })?;
                DType::Timedelta64(unit)
            } else {
                return Err(FerrayError::invalid_dtype(format!(
                    "unsupported dtype descriptor: '{s}'"
                )));
            }
        }
    };

    Ok((dtype, endian))
}
/// Builds the descriptor string for `dtype` with the requested byte order.
///
/// The byte-order character is `<` for [`Endianness::Little`], `>` for
/// [`Endianness::Big`] and `|` for [`Endianness::Native`]. The single-byte
/// types `Bool`, `U8` and `I8` always get `|`, whatever `endian` is.
/// `DateTime64` / `Timedelta64` are rendered as `M8[unit]` / `m8[unit]` using
/// the unit's `descr_suffix()`.
///
/// # Errors
///
/// Returns `FerrayError::invalid_dtype` for any `DType` variant that has no
/// descriptor code in the table below.
pub fn dtype_to_descr(dtype: DType, endian: Endianness) -> FerrayResult<String> {
    let prefix = match endian {
        Endianness::Little => '<',
        Endianness::Big => '>',
        Endianness::Native => '|',
    };

    let type_str = match dtype {
        // Time types carry their unit in brackets and are finished right here.
        DType::DateTime64(unit) => {
            return Ok(format!("{prefix}M8[{}]", unit.descr_suffix()));
        }
        DType::Timedelta64(unit) => {
            return Ok(format!("{prefix}m8[{}]", unit.descr_suffix()));
        }
        DType::Bool => "b1",
        DType::U8 => "u1",
        DType::U16 => "u2",
        DType::U32 => "u4",
        DType::U64 => "u8",
        DType::U128 => "u16",
        DType::I8 => "i1",
        DType::I16 => "i2",
        DType::I32 => "i4",
        DType::I64 => "i8",
        DType::I128 => "i16",
        #[cfg(feature = "f16")]
        DType::F16 => "f2",
        DType::F32 => "f4",
        DType::F64 => "f8",
        #[cfg(feature = "bf16")]
        DType::BF16 => "bf16",
        DType::Complex32 => "c8",
        DType::Complex64 => "c16",
        _ => {
            return Err(FerrayError::invalid_dtype(format!(
                "unsupported dtype for descriptor: {dtype:?}"
            )));
        }
    };

    // Byte order is meaningless for one-byte elements, so they always use `|`.
    let actual_prefix = if matches!(dtype, DType::Bool | DType::U8 | DType::I8) {
        '|'
    } else {
        prefix
    };

    Ok(format!("{actual_prefix}{type_str}"))
}
/// Builds the descriptor string for `dtype` using the explicit byte order of
/// the compilation target (`Little` on little-endian targets, `Big` otherwise).
///
/// # Errors
///
/// Propagates any error returned by [`dtype_to_descr`].
pub fn dtype_to_native_descr(dtype: DType) -> FerrayResult<String> {
    #[cfg(target_endian = "little")]
    let host_order = Endianness::Little;
    #[cfg(not(target_endian = "little"))]
    let host_order = Endianness::Big;

    dtype_to_descr(dtype, host_order)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dtype::TimeUnit;

    /// Asserts that `descr` parses to exactly `(dtype, endian)`.
    #[track_caller]
    fn assert_parses_to(descr: &str, dtype: DType, endian: Endianness) {
        assert_eq!(parse_dtype_str(descr).unwrap(), (dtype, endian));
    }

    /// Emits a little-endian descriptor for every listed unit wrapped by
    /// `wrap`, parses it back, and checks the dtype survives the trip.
    #[track_caller]
    fn assert_time_roundtrip(wrap: fn(TimeUnit) -> DType) {
        for unit in [
            TimeUnit::Ns,
            TimeUnit::Us,
            TimeUnit::Ms,
            TimeUnit::S,
            TimeUnit::D,
        ] {
            let dt = wrap(unit);
            let s = dtype_to_descr(dt, Endianness::Little).unwrap();
            let (parsed, _) = parse_dtype_str(&s).unwrap();
            assert_eq!(parsed, dt, "roundtrip failed for {dt}");
        }
    }

    #[test]
    fn parse_common_dtypes() {
        let cases = [
            ("<f8", DType::F64, Endianness::Little),
            ("<f4", DType::F32, Endianness::Little),
            (">i4", DType::I32, Endianness::Big),
            ("<i8", DType::I64, Endianness::Little),
            ("|b1", DType::Bool, Endianness::Native),
            ("<u1", DType::U8, Endianness::Little),
            ("<c8", DType::Complex32, Endianness::Little),
            ("<c16", DType::Complex64, Endianness::Little),
        ];
        for (descr, dtype, endian) in cases {
            assert_parses_to(descr, dtype, endian);
        }
    }

    #[test]
    fn parse_unsigned_types() {
        let cases = [
            ("<u2", DType::U16),
            ("<u4", DType::U32),
            ("<u8", DType::U64),
        ];
        for (descr, dtype) in cases {
            assert_parses_to(descr, dtype, Endianness::Little);
        }
    }

    #[test]
    fn parse_128bit_types() {
        assert_parses_to("<i16", DType::I128, Endianness::Little);
        assert_parses_to("<u16", DType::U128, Endianness::Little);
    }

    #[test]
    fn parse_invalid() {
        // Too short, unknown type code, and empty input respectively.
        for bad in ["x", "<z4", ""] {
            assert!(parse_dtype_str(bad).is_err());
        }
    }

    #[test]
    fn roundtrip_descr() {
        let dtypes = [
            DType::Bool,
            DType::U8,
            DType::U16,
            DType::U32,
            DType::U64,
            DType::I8,
            DType::I16,
            DType::I32,
            DType::I64,
            DType::F32,
            DType::F64,
            DType::Complex32,
            DType::Complex64,
        ];
        for dt in dtypes {
            let descr = dtype_to_native_descr(dt).unwrap();
            let (parsed_dt, _) = parse_dtype_str(&descr).unwrap();
            assert_eq!(
                parsed_dt, dt,
                "roundtrip failed for {dt:?}: descr='{descr}'"
            );
        }
    }

    #[cfg(feature = "f16")]
    #[test]
    fn parse_f16_descriptor() {
        assert_parses_to("<f2", DType::F16, Endianness::Little);
        assert_parses_to(">f2", DType::F16, Endianness::Big);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn f16_roundtrip_descr() {
        let d = dtype_to_native_descr(DType::F16).unwrap();
        let (parsed, _) = parse_dtype_str(&d).unwrap();
        assert_eq!(parsed, DType::F16);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn parse_bf16_descriptor() {
        assert_parses_to("<bf16", DType::BF16, Endianness::Little);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn bf16_roundtrip_descr() {
        let d = dtype_to_native_descr(DType::BF16).unwrap();
        let (parsed, _) = parse_dtype_str(&d).unwrap();
        assert_eq!(parsed, DType::BF16);
    }

    #[test]
    fn endianness_swap() {
        let host_is_little = cfg!(target_endian = "little");
        assert_eq!(Endianness::Little.needs_swap(), !host_is_little);
        assert_eq!(Endianness::Big.needs_swap(), host_is_little);
        assert!(!Endianness::Native.needs_swap());
    }

    #[test]
    fn parse_datetime64_ns() {
        let (dt, e) = parse_dtype_str("<M8[ns]").unwrap();
        assert_eq!(dt, DType::DateTime64(TimeUnit::Ns));
        assert_eq!(e, Endianness::Little);
    }

    #[test]
    fn parse_datetime64_us() {
        let (dt, _) = parse_dtype_str("<M8[us]").unwrap();
        assert_eq!(dt, DType::DateTime64(TimeUnit::Us));
    }

    #[test]
    fn parse_timedelta64_ms() {
        let (dt, _) = parse_dtype_str("<m8[ms]").unwrap();
        assert_eq!(dt, DType::Timedelta64(TimeUnit::Ms));
    }

    #[test]
    fn datetime64_roundtrip_descr() {
        assert_time_roundtrip(DType::DateTime64);
    }

    #[test]
    fn timedelta64_roundtrip_descr() {
        assert_time_roundtrip(DType::Timedelta64);
    }

    #[test]
    fn unknown_datetime_unit_errors() {
        assert!(parse_dtype_str("<M8[foobar]").is_err());
    }
}