use crate::error::{Error, Result, check};
/// Element type of an MLX array.
///
/// Every variant corresponds one-to-one with an `mlxrs_sys::mlx_dtype__MLX_*` constant
/// (see the `TryFrom<mlx_dtype>` / `From<Dtype>` impls) and has a canonical lowercase name
/// returned by [`Dtype::as_str`]; `Display` prints that name and `FromStr` parses it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum Dtype {
/// Boolean (`"bool"`, `MLX_BOOL`); Rust element type `bool`.
Bool,
/// Unsigned 8-bit integer (`"uint8"`, `MLX_UINT8`); Rust element type `u8`.
U8,
/// Unsigned 16-bit integer (`"uint16"`, `MLX_UINT16`); Rust element type `u16`.
U16,
/// Unsigned 32-bit integer (`"uint32"`, `MLX_UINT32`); Rust element type `u32`.
U32,
/// Unsigned 64-bit integer (`"uint64"`, `MLX_UINT64`); Rust element type `u64`.
U64,
/// Signed 8-bit integer (`"int8"`, `MLX_INT8`); Rust element type `i8`.
I8,
/// Signed 16-bit integer (`"int16"`, `MLX_INT16`); Rust element type `i16`.
I16,
/// Signed 32-bit integer (`"int32"`, `MLX_INT32`); Rust element type `i32`.
I32,
/// Signed 64-bit integer (`"int64"`, `MLX_INT64`); Rust element type `i64`.
I64,
/// 16-bit float (`"float16"`, `MLX_FLOAT16`); Rust element type `half::f16`.
F16,
/// 32-bit float (`"float32"`, `MLX_FLOAT32`); Rust element type `f32`.
F32,
/// 64-bit float (`"float64"`, `MLX_FLOAT64`); Rust element type `f64`.
F64,
/// 16-bit brain float (`"bfloat16"`, `MLX_BFLOAT16`); Rust element type `half::bf16`.
BF16,
/// Complex number with `f32` real and imaginary parts (`"complex64"`, `MLX_COMPLEX64`);
/// Rust element type [`Complex64`].
Complex64,
}
impl Dtype {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Bool => "bool",
Self::U8 => "uint8",
Self::U16 => "uint16",
Self::U32 => "uint32",
Self::U64 => "uint64",
Self::I8 => "int8",
Self::I16 => "int16",
Self::I32 => "int32",
Self::I64 => "int64",
Self::F16 => "float16",
Self::F32 => "float32",
Self::F64 => "float64",
Self::BF16 => "bfloat16",
Self::Complex64 => "complex64",
}
}
}
impl std::str::FromStr for Dtype {
type Err = crate::error::Error;
fn from_str(s: &str) -> Result<Self> {
match s {
"bool" => Ok(Self::Bool),
"uint8" => Ok(Self::U8),
"uint16" => Ok(Self::U16),
"uint32" => Ok(Self::U32),
"uint64" => Ok(Self::U64),
"int8" => Ok(Self::I8),
"int16" => Ok(Self::I16),
"int32" => Ok(Self::I32),
"int64" => Ok(Self::I64),
"float16" => Ok(Self::F16),
"float32" => Ok(Self::F32),
"float64" => Ok(Self::F64),
"bfloat16" => Ok(Self::BF16),
"complex64" => Ok(Self::Complex64),
_ => Err(Error::UnknownEnumValue(
crate::error::UnknownEnumValuePayload::new(
"Dtype",
s,
&[
"bool",
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
"float64",
"bfloat16",
"complex64",
],
),
)),
}
}
}
impl TryFrom<mlxrs_sys::mlx_dtype> for Dtype {
type Error = Error;
fn try_from(raw: mlxrs_sys::mlx_dtype) -> Result<Self> {
match raw {
mlxrs_sys::mlx_dtype__MLX_BOOL => Ok(Self::Bool),
mlxrs_sys::mlx_dtype__MLX_UINT8 => Ok(Self::U8),
mlxrs_sys::mlx_dtype__MLX_UINT16 => Ok(Self::U16),
mlxrs_sys::mlx_dtype__MLX_UINT32 => Ok(Self::U32),
mlxrs_sys::mlx_dtype__MLX_UINT64 => Ok(Self::U64),
mlxrs_sys::mlx_dtype__MLX_INT8 => Ok(Self::I8),
mlxrs_sys::mlx_dtype__MLX_INT16 => Ok(Self::I16),
mlxrs_sys::mlx_dtype__MLX_INT32 => Ok(Self::I32),
mlxrs_sys::mlx_dtype__MLX_INT64 => Ok(Self::I64),
mlxrs_sys::mlx_dtype__MLX_FLOAT16 => Ok(Self::F16),
mlxrs_sys::mlx_dtype__MLX_FLOAT32 => Ok(Self::F32),
mlxrs_sys::mlx_dtype__MLX_FLOAT64 => Ok(Self::F64),
mlxrs_sys::mlx_dtype__MLX_BFLOAT16 => Ok(Self::BF16),
mlxrs_sys::mlx_dtype__MLX_COMPLEX64 => Ok(Self::Complex64),
other => Err(Error::UnknownDtype(other)),
}
}
}
impl From<Dtype> for mlxrs_sys::mlx_dtype {
fn from(d: Dtype) -> Self {
match d {
Dtype::Bool => mlxrs_sys::mlx_dtype__MLX_BOOL,
Dtype::U8 => mlxrs_sys::mlx_dtype__MLX_UINT8,
Dtype::U16 => mlxrs_sys::mlx_dtype__MLX_UINT16,
Dtype::U32 => mlxrs_sys::mlx_dtype__MLX_UINT32,
Dtype::U64 => mlxrs_sys::mlx_dtype__MLX_UINT64,
Dtype::I8 => mlxrs_sys::mlx_dtype__MLX_INT8,
Dtype::I16 => mlxrs_sys::mlx_dtype__MLX_INT16,
Dtype::I32 => mlxrs_sys::mlx_dtype__MLX_INT32,
Dtype::I64 => mlxrs_sys::mlx_dtype__MLX_INT64,
Dtype::F16 => mlxrs_sys::mlx_dtype__MLX_FLOAT16,
Dtype::F32 => mlxrs_sys::mlx_dtype__MLX_FLOAT32,
Dtype::F64 => mlxrs_sys::mlx_dtype__MLX_FLOAT64,
Dtype::BF16 => mlxrs_sys::mlx_dtype__MLX_BFLOAT16,
Dtype::Complex64 => mlxrs_sys::mlx_dtype__MLX_COMPLEX64,
}
}
}
/// Single-precision complex number: a `#[repr(C)]` pair of `f32`s, real part first.
///
/// The layout is intended to match MLX's `mlx_complex64_t` so buffers of that type can be
/// viewed as `Complex64`. This is the Rust element type for `Dtype::Complex64`.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Complex64 {
    /// Real component.
    re: f32,
    /// Imaginary component.
    im: f32,
}

impl Complex64 {
    /// Creates a complex number from its real and imaginary parts.
    #[inline]
    #[must_use]
    pub const fn new(re: f32, im: f32) -> Self {
        Self { re, im }
    }

    /// Returns the real component.
    #[inline]
    #[must_use]
    pub const fn re(&self) -> f32 {
        self.re
    }

    /// Returns the imaginary component.
    #[inline]
    #[must_use]
    pub const fn im(&self) -> f32 {
        self.im
    }

    /// Returns `(re, im)`.
    #[inline]
    #[must_use]
    pub const fn as_parts(&self) -> (f32, f32) {
        (self.re, self.im)
    }
}

/// Builds a `Complex64` from an `(re, im)` tuple.
impl From<(f32, f32)> for Complex64 {
    #[inline]
    fn from((re, im): (f32, f32)) -> Self {
        Self::new(re, im)
    }
}

/// Splits a `Complex64` into an `(re, im)` tuple.
impl From<Complex64> for (f32, f32) {
    #[inline]
    fn from(c: Complex64) -> Self {
        c.as_parts()
    }
}
/// A Rust scalar type that can be the element type of an MLX array.
///
/// The trait is sealed: its `sealed::Sealed` supertrait lives in a private module, so it
/// cannot be implemented outside the module that declares `sealed`. Each impl pairs the
/// Rust type with its [`Dtype`] tag and with the type-specific `mlx_array_item_*` /
/// `mlx_array_data_*` C accessors.
pub trait Element: sealed::Sealed + Copy + 'static {
/// The [`Dtype`] tag corresponding to `Self`.
const DTYPE: Dtype;
/// Extracts a scalar `Self` from `arr` through the matching `mlx_array_item_*` C function.
///
/// # Errors
///
/// Propagates the failure status returned by the C call (converted via `check`).
///
/// # Safety
///
/// `arr` must be a valid MLX array handle. NOTE(review): presumably its dtype must also
/// equal [`Self::DTYPE`]; confirm against the mlx-c `mlx_array_item_*` contract.
unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self>;
/// Returns the pointer from the matching `mlx_array_data_*` C function together with
/// `mlx_array_size(arr)`.
///
/// # Safety
///
/// `arr` must be a valid MLX array handle. NOTE(review): presumably its dtype must equal
/// [`Self::DTYPE`], and the returned pointer is only valid while `arr` is alive and its
/// data is materialized; confirm with the mlx-c documentation.
unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize);
/// Returns the address of a `static` value of `Self` (every impl in this file uses a
/// zero/false static), so the pointer is never null and is properly aligned.
/// NOTE(review): presumably used as a stand-in data pointer for empty arrays; confirm at
/// the call sites.
fn sentinel_ptr() -> *const Self;
}
/// Private home of the sealing supertrait for [`Element`]: because this module is not `pub`,
/// code outside the declaring module cannot name `sealed::Sealed` and therefore cannot
/// implement `Element`.
mod sealed {
pub trait Sealed {}
}
impl sealed::Sealed for bool {}
impl sealed::Sealed for u8 {}
impl sealed::Sealed for u16 {}
impl sealed::Sealed for u32 {}
impl sealed::Sealed for u64 {}
impl sealed::Sealed for i8 {}
impl sealed::Sealed for i16 {}
impl sealed::Sealed for i32 {}
impl sealed::Sealed for i64 {}
impl sealed::Sealed for f32 {}
impl sealed::Sealed for f64 {}
impl sealed::Sealed for half::f16 {}
impl sealed::Sealed for half::bf16 {}
impl sealed::Sealed for Complex64 {}
impl Element for bool {
    const DTYPE: Dtype = Dtype::Bool;

    /// Reads a scalar `bool` out of `arr` via `mlx_array_item_bool`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `bool` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_bool(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `bool` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_bool(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static `false`; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: bool = false;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for i32 {
    const DTYPE: Dtype = Dtype::I32;

    /// Reads a scalar `i32` out of `arr` via `mlx_array_item_int32`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `i32` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_int32(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `i32` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_int32(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: i32 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for u32 {
    const DTYPE: Dtype = Dtype::U32;

    /// Reads a scalar `u32` out of `arr` via `mlx_array_item_uint32`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `u32` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_uint32(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `u32` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_uint32(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: u32 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for f32 {
    const DTYPE: Dtype = Dtype::F32;

    /// Reads a scalar `f32` out of `arr` via `mlx_array_item_float32`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `f32` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_float32(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `f32` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_float32(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static `0.0`; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: f32 = 0.0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for half::f16 {
    const DTYPE: Dtype = Dtype::F16;

    /// Reads a scalar `float16` out of `arr` via `mlx_array_item_float16` and reinterprets
    /// MLX's `float16_t` bit-for-bit as [`half::f16`].
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        // SAFETY: all-zero bits are assumed to be a valid `float16_t` (a plain 16-bit value);
        // the C call overwrites it on success anyway.
        let mut raw: mlxrs_sys::float16_t = unsafe { std::mem::zeroed() };
        // SAFETY: `raw` is a live, writable `float16_t` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        check(unsafe { mlxrs_sys::mlx_array_item_float16(&mut raw, arr) })?;
        // SAFETY: `transmute` refuses to compile unless both types have the same size, so
        // this can never read past `raw` (unlike `transmute_copy`). Every 16-bit pattern is
        // a valid `half::f16`.
        Ok(unsafe { std::mem::transmute::<mlxrs_sys::float16_t, half::f16>(raw) })
    }

    /// Returns `arr`'s `float16` buffer, viewed as [`half::f16`], and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // The cast below reinterprets MLX's element buffer; element stride is only preserved
        // if both types have the same size, so refuse to build otherwise.
        const _: () = assert!(
            std::mem::size_of::<mlxrs_sys::float16_t>() == std::mem::size_of::<half::f16>(),
            "mlxrs_sys::float16_t must have the same size as half::f16"
        );
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_float16(arr) }.cast::<half::f16>();
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static `half::f16::ZERO`; never null.
    fn sentinel_ptr() -> *const Self {
        static V: half::f16 = half::f16::ZERO;
        &V
    }
}
impl Element for u8 {
    const DTYPE: Dtype = Dtype::U8;

    /// Reads a scalar `u8` out of `arr` via `mlx_array_item_uint8`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `u8` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_uint8(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `u8` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_uint8(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: u8 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for u16 {
    const DTYPE: Dtype = Dtype::U16;

    /// Reads a scalar `u16` out of `arr` via `mlx_array_item_uint16`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `u16` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_uint16(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `u16` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_uint16(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: u16 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for u64 {
    const DTYPE: Dtype = Dtype::U64;

    /// Reads a scalar `u64` out of `arr` via `mlx_array_item_uint64`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `u64` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_uint64(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `u64` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_uint64(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: u64 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for i8 {
    const DTYPE: Dtype = Dtype::I8;

    /// Reads a scalar `i8` out of `arr` via `mlx_array_item_int8`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `i8` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_int8(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `i8` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_int8(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: i8 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for i16 {
    const DTYPE: Dtype = Dtype::I16;

    /// Reads a scalar `i16` out of `arr` via `mlx_array_item_int16`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `i16` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_int16(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `i16` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_int16(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: i16 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for i64 {
    const DTYPE: Dtype = Dtype::I64;

    /// Reads a scalar `i64` out of `arr` via `mlx_array_item_int64`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `i64` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_int64(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `i64` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_int64(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static zero; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: i64 = 0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for f64 {
    const DTYPE: Dtype = Dtype::F64;

    /// Reads a scalar `f64` out of `arr` via `mlx_array_item_float64`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut value = Self::default();
        // SAFETY: `value` is a live, writable `f64` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        let status = unsafe { mlxrs_sys::mlx_array_item_float64(&mut value, arr) };
        check(status)?;
        Ok(value)
    }

    /// Returns `arr`'s `f64` buffer and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_float64(arr) };
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static `0.0`; never null.
    fn sentinel_ptr() -> *const Self {
        static SENTINEL: f64 = 0.0;
        std::ptr::addr_of!(SENTINEL)
    }
}
impl Element for half::bf16 {
    const DTYPE: Dtype = Dtype::BF16;

    /// Reads a scalar `bfloat16` out of `arr` via `mlx_array_item_bfloat16` and reinterprets
    /// MLX's `bfloat16_t` bit-for-bit as [`half::bf16`].
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        // SAFETY: all-zero bits are assumed to be a valid `bfloat16_t` (a plain 16-bit value);
        // the C call overwrites it on success anyway.
        let mut raw: mlxrs_sys::bfloat16_t = unsafe { std::mem::zeroed() };
        // SAFETY: `raw` is a live, writable `bfloat16_t` for the whole call; the validity of
        // `arr` is this `unsafe fn`'s precondition.
        check(unsafe { mlxrs_sys::mlx_array_item_bfloat16(&mut raw, arr) })?;
        // SAFETY: `transmute` refuses to compile unless both types have the same size, so
        // this can never read past `raw` (unlike `transmute_copy`). Every 16-bit pattern is
        // a valid `half::bf16`.
        Ok(unsafe { std::mem::transmute::<mlxrs_sys::bfloat16_t, half::bf16>(raw) })
    }

    /// Returns `arr`'s `bfloat16` buffer, viewed as [`half::bf16`], and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // The cast below reinterprets MLX's element buffer; element stride is only preserved
        // if both types have the same size, so refuse to build otherwise.
        const _: () = assert!(
            std::mem::size_of::<mlxrs_sys::bfloat16_t>() == std::mem::size_of::<half::bf16>(),
            "mlxrs_sys::bfloat16_t must have the same size as half::bf16"
        );
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_bfloat16(arr) }.cast::<half::bf16>();
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static `half::bf16::ZERO`; never null.
    fn sentinel_ptr() -> *const Self {
        static V: half::bf16 = half::bf16::ZERO;
        &V
    }
}
impl Element for Complex64 {
    const DTYPE: Dtype = Dtype::Complex64;

    /// Reads a scalar complex value out of `arr` via `mlx_array_item_complex64`.
    unsafe fn item(arr: mlxrs_sys::mlx_array) -> Result<Self> {
        let mut raw = mlxrs_sys::mlx_complex64_t { re: 0.0, im: 0.0 };
        // SAFETY: `raw` is a live, writable `mlx_complex64_t` for the whole call; the validity
        // of `arr` is this `unsafe fn`'s precondition.
        check(unsafe { mlxrs_sys::mlx_array_item_complex64(&mut raw, arr) })?;
        Ok(Self::new(raw.re, raw.im))
    }

    /// Returns `arr`'s `mlx_complex64_t` buffer, viewed as [`Complex64`], and its element count.
    unsafe fn data(arr: mlxrs_sys::mlx_array) -> (*const Self, usize) {
        // The cast below reinterprets MLX's buffer as `Complex64`; enforce at compile time
        // (not just in a unit test) that the two types share size and alignment.
        const _: () = {
            assert!(
                std::mem::size_of::<mlxrs_sys::mlx_complex64_t>() == std::mem::size_of::<Complex64>(),
                "mlx_complex64_t and Complex64 must have the same size"
            );
            assert!(
                std::mem::align_of::<mlxrs_sys::mlx_complex64_t>()
                    == std::mem::align_of::<Complex64>(),
                "mlx_complex64_t and Complex64 must have the same alignment"
            );
        };
        // SAFETY: the validity of `arr` is this `unsafe fn`'s precondition.
        let ptr = unsafe { mlxrs_sys::mlx_array_data_complex64(arr) }.cast::<Complex64>();
        let len = unsafe { mlxrs_sys::mlx_array_size(arr) };
        (ptr, len)
    }

    /// Address of a static `0 + 0i`; never null.
    fn sentinel_ptr() -> *const Self {
        static V: Complex64 = Complex64::new(0.0, 0.0);
        &V
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::str::FromStr;

    /// Every [`Dtype`] variant, in declaration order.
    const ALL_DTYPES: &[Dtype] = &[
        Dtype::Bool,
        Dtype::U8,
        Dtype::U16,
        Dtype::U32,
        Dtype::U64,
        Dtype::I8,
        Dtype::I16,
        Dtype::I32,
        Dtype::I64,
        Dtype::F16,
        Dtype::F32,
        Dtype::F64,
        Dtype::BF16,
        Dtype::Complex64,
    ];

    #[test]
    fn complex64_parts_round_trip() {
        let c = Complex64::new(1.5, -2.25);
        assert_eq!(c.re(), 1.5);
        assert_eq!(c.im(), -2.25);
        assert_eq!(c.as_parts(), (1.5, -2.25));
    }

    #[test]
    fn complex64_tuple_conversions() {
        let c: Complex64 = (3.0_f32, 4.0_f32).into();
        assert_eq!(c, Complex64::new(3.0, 4.0));
        let parts: (f32, f32) = c.into();
        assert_eq!(parts, (3.0, 4.0));
    }

    #[test]
    fn complex64_default_is_zero() {
        assert_eq!(Complex64::default(), Complex64::new(0.0, 0.0));
    }

    #[test]
    fn complex64_layout_matches_mlx_complex64_t() {
        use std::mem::{align_of, size_of};
        assert_eq!(
            size_of::<Complex64>(),
            size_of::<mlxrs_sys::mlx_complex64_t>(),
            "Complex64 size must match mlx_complex64_t"
        );
        assert_eq!(
            align_of::<Complex64>(),
            align_of::<mlxrs_sys::mlx_complex64_t>(),
            "Complex64 align must match mlx_complex64_t"
        );
        // Field order must match too: `re` first, `im` second.
        let raw = mlxrs_sys::mlx_complex64_t { re: 7.0, im: 9.0 };
        let as_c: Complex64 = unsafe { std::mem::transmute(raw) };
        assert_eq!(as_c, Complex64::new(7.0, 9.0));
    }

    #[test]
    fn complex64_dtype_is_complex64() {
        assert_eq!(<Complex64 as Element>::DTYPE, Dtype::Complex64);
    }

    #[test]
    fn dtype_hash_consistent_with_eq() {
        let mut set = HashSet::new();
        set.insert(Dtype::F32);
        set.insert(Dtype::Complex64);
        // Re-inserting an equal value must not grow the set.
        set.insert(Dtype::F32);
        assert_eq!(set.len(), 2);
        assert!(set.contains(&Dtype::F32));
        assert!(set.contains(&Dtype::Complex64));
        assert!(!set.contains(&Dtype::I32));
    }

    #[test]
    fn dtype_from_str_round_trips_every_variant() {
        for &d in ALL_DTYPES {
            assert_eq!(
                Dtype::from_str(d.as_str()).expect("as_str output must parse"),
                d,
                "round-trip failed for {d:?}"
            );
        }
        let err = Dtype::from_str("not_a_dtype").unwrap_err();
        assert!(
            matches!(err, Error::UnknownEnumValue(_)),
            "unknown dtype name must yield UnknownEnumValue, got {err:?}"
        );
    }

    #[test]
    fn dtype_names_are_unique() {
        let names: HashSet<&str> = ALL_DTYPES.iter().map(|d| d.as_str()).collect();
        assert_eq!(names.len(), ALL_DTYPES.len(), "every dtype needs a distinct name");
    }

    #[test]
    fn dtype_display_matches_as_str() {
        for &d in ALL_DTYPES {
            assert_eq!(d.to_string(), d.as_str(), "Display mismatch for {d:?}");
        }
    }

    #[test]
    fn dtype_ffi_round_trips_every_variant() {
        for &d in ALL_DTYPES {
            let raw: mlxrs_sys::mlx_dtype = d.into();
            let back = Dtype::try_from(raw).expect("raw tag produced by From<Dtype> must convert back");
            assert_eq!(back, d, "FFI round-trip failed for {d:?}");
        }
    }
}