mod ast;
mod numpy_str;
use std::collections::HashMap;
use crate::nd::dtype::numpy_str::{parse_numpy_dtype_str, DtypeParseError};
use crate::util::{f16, Complex};
/// Layout description of a single array element: what it consists of, an
/// optional sub-array shape, and its total size and alignment.
///
/// Sizes and alignments use the same unit as [`DtypeScalarKind::itemsize`]
/// (1 for `I8`, 2 for `I16`, ...), i.e. bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dtype {
/// Whether the element is a scalar or a struct of named fields.
kind: DtypeKind,
/// Sub-array dimensions; empty for an element that is not a sub-array.
shape: Vec<usize>,
/// Size of the whole element, including every entry of `shape`.
itemsize: usize,
/// Required alignment of the element.
alignment: usize,
}
/// The two families of element a [`Dtype`] can describe.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DtypeKind {
/// A single primitive value stored with an explicit byte order.
Scalar {
/// Which primitive the value is.
kind: DtypeScalarKind,
/// Byte order the value is stored in.
endianness: Endianness,
},
/// A record made of named fields.
Struct {
/// Field name mapped to `(offset, dtype)`; the offset is measured from the
/// start of the struct in the same unit as `Dtype::itemsize`.
fields: HashMap<String, (usize, Dtype)>,
},
}
/// Primitive element types supported by [`Dtype`].
///
/// Each variant corresponds to the Rust type it is paired with in the
/// `impl_dtyped_scalar!` invocations of this file (`I8` with `i8`, `F16` with
/// `f16`, `ComplexF32` with `Complex<f32>`, `Bool` with `bool`, and so on).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DtypeScalarKind {
// Signed integers.
I8,
I16,
I32,
I64,
// Unsigned integers.
U8,
U16,
U32,
U64,
// Floating point numbers.
F16,
F32,
F64,
// Complex numbers; the suffix names the component type.
ComplexF32,
ComplexF64,
// Boolean.
Bool,
}
/// Byte order of a scalar value, as stored in [`DtypeKind::Scalar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endianness {
/// Little-endian byte order.
Little,
/// Big-endian byte order.
Big,
}
impl Dtype {
    /// Creates the dtype of a single scalar of the given kind.
    ///
    /// The result uses the native byte order, has an empty shape, and takes its
    /// item size and alignment from `kind`.
    pub fn of_scalar(kind: DtypeScalarKind) -> Self {
        Self {
            kind: DtypeKind::Scalar {
                kind,
                endianness: Endianness::native(),
            },
            shape: Vec::new(),
            itemsize: kind.itemsize(),
            alignment: kind.alignment(),
        }
    }

    /// Creates a struct dtype from `name -> (offset, dtype)` fields, deriving the
    /// item size and alignment from the field offsets.
    ///
    /// Two layouts are recognised, tried in this order:
    ///
    /// 1. *Packed*: fields follow each other without gaps starting at offset 0.
    ///    The item size is the sum of the field sizes and the alignment is 1.
    /// 2. *Aligned*: every field starts at the previous field's end rounded up to
    ///    the field's own alignment. The alignment is the largest field alignment
    ///    and the item size is the end of the last field rounded up to it.
    ///
    /// # Errors
    ///
    /// Returns [`DtypeError::InvalidOffsets`] if the offsets match neither layout.
    pub fn of_struct(fields: HashMap<String, (usize, Dtype)>) -> Result<Self, DtypeError> {
        let layout = Self::field_layout_sorted(&fields);

        let (itemsize, alignment) = if let Some(extent) = Self::packed_extent(&layout) {
            (extent, 1)
        } else if let Some(extent) = Self::aligned_extent(&layout) {
            let max_alignment = Self::widest_field_alignment(&layout);
            (ceil_to_multiple(extent, max_alignment), max_alignment)
        } else {
            return Err(DtypeError::InvalidOffsets);
        };

        Ok(Self {
            kind: DtypeKind::Struct { fields },
            shape: Vec::new(),
            itemsize,
            alignment,
        })
    }

    /// Creates a dtype from explicit parts, validating that they are consistent.
    ///
    /// `itemsize` is the size of the whole element including all entries of
    /// `shape`; the size of one sub-array entry is `itemsize / product(shape)`.
    ///
    /// # Errors
    ///
    /// * [`DtypeError::InvalidShape`] if the product of `shape` is zero or does
    ///   not fit in `usize`.
    /// * [`DtypeError::InvalidItemsize`] if `itemsize` is not a multiple of the
    ///   shape product, or the per-entry size does not match the scalar kind or
    ///   the struct layout.
    /// * [`DtypeError::InvalidAlignment`] if `alignment` differs from the scalar
    ///   kind's alignment, or (for structs with `alignment != 1`) from the
    ///   largest field alignment.
    /// * [`DtypeError::InvalidOffsets`] if struct fields are not packed
    ///   (`alignment == 1`) or not aligned (any other `alignment`).
    pub fn new(
        kind: DtypeKind,
        shape: Vec<usize>,
        itemsize: usize,
        alignment: usize,
    ) -> Result<Self, DtypeError> {
        let shape_prod = Self::nonzero_shape_product(&shape)?;
        if itemsize % shape_prod != 0 {
            return Err(DtypeError::InvalidItemsize);
        }
        let element_itemsize = itemsize / shape_prod;

        match &kind {
            DtypeKind::Scalar {
                kind: scalar,
                endianness: _,
            } => {
                if scalar.alignment() != alignment {
                    return Err(DtypeError::InvalidAlignment);
                }
                if scalar.itemsize() != element_itemsize {
                    return Err(DtypeError::InvalidItemsize);
                }
            }
            DtypeKind::Struct { fields } => {
                let layout = Self::field_layout_sorted(fields);
                let expected_itemsize = if alignment == 1 {
                    // An alignment of 1 declares the struct as packed.
                    Self::packed_extent(&layout).ok_or(DtypeError::InvalidOffsets)?
                } else {
                    let max_alignment = Self::widest_field_alignment(&layout);
                    if alignment != max_alignment {
                        return Err(DtypeError::InvalidAlignment);
                    }
                    let extent =
                        Self::aligned_extent(&layout).ok_or(DtypeError::InvalidOffsets)?;
                    ceil_to_multiple(extent, max_alignment)
                };
                if expected_itemsize != element_itemsize {
                    return Err(DtypeError::InvalidItemsize);
                }
            }
        }

        Ok(Self {
            kind,
            shape,
            itemsize,
            alignment,
        })
    }

    /// Returns whether this is a scalar or a struct, with the details of either.
    pub fn kind(&self) -> &DtypeKind {
        &self.kind
    }

    /// Returns the sub-array shape (empty when the element is not a sub-array).
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the size of the whole element, including the sub-array shape.
    pub fn itemsize(&self) -> usize {
        self.itemsize
    }

    /// Returns the required alignment of the element.
    pub fn alignment(&self) -> usize {
        self.alignment
    }

    /// Replaces the sub-array shape, rescaling the item size accordingly.
    ///
    /// The current shape is discarded, not extended: the per-entry size is
    /// recovered as `itemsize / product(current shape)` and multiplied by the
    /// product of the new `shape`. Kind and alignment are kept.
    ///
    /// # Errors
    ///
    /// Returns [`DtypeError::InvalidShape`] if the product of `shape` is zero, or
    /// if that product or the resulting item size does not fit in `usize`.
    pub fn with_shape(self, shape: Vec<usize>) -> Result<Self, DtypeError> {
        let current_shape_prod = self.shape.iter().product::<usize>();
        debug_assert!(current_shape_prod > 0);
        debug_assert_eq!(self.itemsize % current_shape_prod, 0);
        let base_itemsize = self.itemsize / current_shape_prod;

        let shape_prod = Self::nonzero_shape_product(&shape)?;
        // A wrapped item size would silently describe the wrong amount of memory.
        let itemsize = base_itemsize
            .checked_mul(shape_prod)
            .ok_or(DtypeError::InvalidShape)?;

        Ok(Self {
            kind: self.kind,
            shape,
            itemsize,
            alignment: self.alignment,
        })
    }

    /// Parses a dtype from a string by delegating to `parse_numpy_dtype_str`.
    pub fn from_numpy_str(s: &str) -> Result<Self, DtypeParseError> {
        parse_numpy_dtype_str(s)
    }

    /// Multiplies the entries of `shape` with overflow checking.
    ///
    /// An empty shape has product 1. A zero or overflowing product is reported
    /// as [`DtypeError::InvalidShape`].
    fn nonzero_shape_product(shape: &[usize]) -> Result<usize, DtypeError> {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .filter(|&prod| prod > 0)
            .ok_or(DtypeError::InvalidShape)
    }

    /// Collects `(offset, itemsize, alignment)` for every field, ordered by offset.
    ///
    /// The whole tuple is used as the sort key so the order does not depend on
    /// the map's iteration order when several fields share an offset (zero-sized
    /// fields then come first).
    fn field_layout_sorted(fields: &HashMap<String, (usize, Dtype)>) -> Vec<(usize, usize, usize)> {
        let mut layout: Vec<_> = fields
            .values()
            .map(|(offset, dtype)| (*offset, dtype.itemsize, dtype.alignment))
            .collect();
        layout.sort_unstable();
        layout
    }

    /// Returns the end offset of the last field if the fields are laid out
    /// back-to-back from offset 0, and `None` otherwise (or on overflow).
    fn packed_extent(layout: &[(usize, usize, usize)]) -> Option<usize> {
        layout
            .iter()
            .try_fold(0usize, |end, &(offset, itemsize, _alignment)| {
                if offset == end {
                    end.checked_add(itemsize)
                } else {
                    None
                }
            })
    }

    /// Returns the end offset of the last field if every field starts at the
    /// previous end rounded up to its own alignment, and `None` otherwise (or on
    /// overflow). The result is not yet padded to the struct alignment.
    fn aligned_extent(layout: &[(usize, usize, usize)]) -> Option<usize> {
        layout
            .iter()
            .try_fold(0usize, |end, &(offset, itemsize, alignment)| {
                let start = ceil_to_multiple(end, alignment);
                if offset == start {
                    start.checked_add(itemsize)
                } else {
                    None
                }
            })
    }

    /// Returns the largest field alignment, or 1 when there are no fields.
    fn widest_field_alignment(layout: &[(usize, usize, usize)]) -> usize {
        layout
            .iter()
            .map(|&(_offset, _itemsize, alignment)| alignment)
            .max()
            .unwrap_or(1)
    }
}
/// Reasons a [`Dtype`] cannot be constructed from the given parts.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DtypeError {
/// Struct field offsets match neither the packed nor the aligned layout.
InvalidOffsets,
/// The item size is not a multiple of the shape product, or does not match
/// the size implied by the scalar kind or the struct fields.
InvalidItemsize,
/// The alignment does not match the scalar kind or the struct fields.
InvalidAlignment,
/// The product of the shape dimensions is zero.
InvalidShape,
}
/// Human-readable, single-line description of each [`DtypeError`] variant.
impl std::fmt::Display for DtypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed message first, then emit it with a single write.
        let message = match self {
            Self::InvalidOffsets => "Invalid field offsets",
            Self::InvalidItemsize => "Invalid itemsize",
            Self::InvalidAlignment => "Invalid alignment",
            Self::InvalidShape => "Invalid shape",
        };
        f.write_str(message)
    }
}
impl DtypeScalarKind {
    /// Size of one value of this kind, in bytes.
    pub fn itemsize(&self) -> usize {
        match self {
            Self::I8 | Self::U8 | Self::Bool => 1,
            Self::I16 | Self::U16 | Self::F16 => 2,
            Self::I32 | Self::U32 | Self::F32 => 4,
            Self::I64 | Self::U64 | Self::F64 | Self::ComplexF32 => 8,
            Self::ComplexF64 => 16,
        }
    }

    /// Alignment of one value of this kind, in bytes.
    ///
    /// Equal to the item size for every kind except the complex ones, whose
    /// alignment is half their item size.
    pub fn alignment(&self) -> usize {
        match self {
            Self::I8 | Self::U8 | Self::Bool => 1,
            Self::I16 | Self::U16 | Self::F16 => 2,
            Self::I32 | Self::U32 | Self::F32 | Self::ComplexF32 => 4,
            Self::I64 | Self::U64 | Self::F64 | Self::ComplexF64 => 8,
        }
    }
}
impl Endianness {
pub fn native() -> Self {
if cfg!(target_endian = "little") {
Endianness::Little
} else {
Endianness::Big
}
}
}
/// Rust types that have an associated [`Dtype`].
///
/// # Safety
///
/// NOTE(review): this file does not spell out the contract that makes the trait
/// `unsafe`. Presumably implementors must guarantee that the returned dtype
/// matches the in-memory layout of `Self` exactly — confirm and document it here.
pub unsafe trait Dtyped: Copy + 'static {
/// Returns the dtype describing `Self`.
fn dtype() -> Dtype;
}
/// Implements [`Dtyped`] for the type `$ty` by mapping it to the
/// [`DtypeScalarKind`] variant named `$kind`: the generated `dtype()` returns
/// `Dtype::of_scalar(DtypeScalarKind::$kind)`.
// NOTE(review): every expansion is an `unsafe impl`; the invoker is vouching
// that `$ty` really has the layout of `$kind` — confirm for each new pairing.
macro_rules! impl_dtyped_scalar {
($ty:ty, $kind:ident) => {
unsafe impl Dtyped for $ty {
fn dtype() -> Dtype {
Dtype::of_scalar(DtypeScalarKind::$kind)
}
}
};
}
// Scalar types with a built-in dtype mapping, one per `DtypeScalarKind` variant.
impl_dtyped_scalar!(i8, I8);
impl_dtyped_scalar!(i16, I16);
impl_dtyped_scalar!(i32, I32);
impl_dtyped_scalar!(i64, I64);
impl_dtyped_scalar!(u8, U8);
impl_dtyped_scalar!(u16, U16);
impl_dtyped_scalar!(u32, U32);
impl_dtyped_scalar!(u64, U64);
impl_dtyped_scalar!(f16, F16);
impl_dtyped_scalar!(f32, F32);
impl_dtyped_scalar!(f64, F64);
impl_dtyped_scalar!(Complex<f32>, ComplexF32);
impl_dtyped_scalar!(Complex<f64>, ComplexF64);
impl_dtyped_scalar!(bool, Bool);
/// Fixed-size arrays map to the element dtype with `N` prepended to its shape,
/// so `[[T; 3]; 2]` has shape `[2, 3]` and covers all six elements.
///
/// # Panics
///
/// `dtype()` panics for `N == 0` (a dtype shape may not contain zero) and if the
/// resulting item size does not fit in `usize`.
// SAFETY: NOTE(review) — relies on `T`'s own `Dtyped` impl being correct; the
// array only adds a leading dimension of length `N` to it.
unsafe impl<T: Dtyped, const N: usize> Dtyped for [T; N] {
    fn dtype() -> Dtype {
        let element = T::dtype();
        // `with_shape` replaces the shape, so the element's own dimensions
        // must be carried over explicitly or nested arrays lose them.
        let mut shape = Vec::with_capacity(element.shape().len() + 1);
        shape.push(N);
        shape.extend_from_slice(element.shape());
        element
            .with_shape(shape)
            .expect("array dtype requires N > 0 and an item size that fits in usize")
    }
}
/// Rounds `x` up to the nearest multiple of `m` (returns `x` if it already is one).
///
/// # Panics
///
/// Panics if `m` is zero.
fn ceil_to_multiple(x: usize, m: usize) -> usize {
    assert!(m > 0);
    x.next_multiple_of(m)
}