use std::{ffi::CStr, sync::Mutex};
use static_assertions::assert_impl_all;
use smol_str::format_smolstr;
use crate::error::{
Error, FfiNullHandlePayload, OutOfRangePayload, Result, check, ensure_handler_installed,
};
/// Serializes access to MLX's default device: `Device::current` and
/// `Device::set_default` both hold this lock around their FFI calls
/// (`mlx_get_default_device` / `mlx_set_default_device`). It guards `()`, so it
/// carries no data of its own.
static DEFAULT_DEVICE_LOCK: Mutex<()> = Mutex::new(());
/// Owns an `mlx_string` for the duration of a scope and releases it with
/// `mlx_string_free` when dropped.
struct StringGuard(mlxrs_sys::mlx_string);

impl Drop for StringGuard {
    fn drop(&mut self) {
        // SAFETY: the guard is the sole owner of `self.0` and frees it once, here.
        // The returned status is discarded because `drop` has no way to report it.
        let _free_status = unsafe { mlxrs_sys::mlx_string_free(self.0) };
    }
}
/// The kind of compute device MLX can target.
///
/// `Display` renders the lowercase name returned by `DeviceKind::as_str`
/// (`"cpu"` / `"gpu"`). Marked `#[non_exhaustive]` so further kinds can be added
/// without breaking downstream `match`es.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
#[non_exhaustive]
pub enum DeviceKind {
    /// CPU device; maps to `mlx_device_type__MLX_CPU` on the C side.
    Cpu,
    /// GPU device; maps to `mlx_device_type__MLX_GPU` on the C side.
    Gpu,
}
impl DeviceKind {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Cpu => "cpu",
Self::Gpu => "gpu",
}
}
#[inline]
fn to_raw(self) -> mlxrs_sys::mlx_device_type {
match self {
DeviceKind::Cpu => mlxrs_sys::mlx_device_type__MLX_CPU,
DeviceKind::Gpu => mlxrs_sys::mlx_device_type__MLX_GPU,
}
}
#[inline]
fn from_raw(raw: mlxrs_sys::mlx_device_type) -> Result<Self> {
match raw {
mlxrs_sys::mlx_device_type__MLX_CPU => Ok(DeviceKind::Cpu),
mlxrs_sys::mlx_device_type__MLX_GPU => Ok(DeviceKind::Gpu),
other => Err(Error::OutOfRange(OutOfRangePayload::new(
"DeviceKind::from_raw: mlx_device_type",
"must be MLX_CPU or MLX_GPU",
format_smolstr!("{other}"),
))),
}
}
pub fn count(self) -> Result<usize> {
ensure_handler_installed();
let mut n: i32 = 0;
check(unsafe { mlxrs_sys::mlx_device_count(&mut n, self.to_raw()) })?;
Ok(n.max(0) as usize)
}
}
/// Owned handle to an MLX device, wrapping the raw `mlxrs_sys::mlx_device`.
///
/// `#[repr(transparent)]` gives it the same layout as `mlx_device`. The handle
/// is released with `mlx_device_free` when this value is dropped.
#[repr(transparent)]
pub struct Device(pub(crate) mlxrs_sys::mlx_device);
// SAFETY: NOTE(review): these impls assume the underlying mlx device object may be
// moved to and shared between threads. This module only passes the handle by
// value to `mlx_device_*` query functions after construction, but that is not a
// proof. Confirm against mlx-c's threading guarantees.
unsafe impl Send for Device {}
unsafe impl Sync for Device {}
// Compile-time guard that the trait surface this module provides stays intact.
assert_impl_all!(Device: Send, Sync, std::hash::Hash, std::fmt::Display, std::fmt::Debug);
impl Drop for Device {
    /// Releases the owned `mlx_device` handle.
    fn drop(&mut self) {
        // SAFETY: `self.0` is owned by this `Device` and released once, here.
        // The status is discarded because `drop` cannot report failure.
        let _free_status = unsafe { mlxrs_sys::mlx_device_free(self.0) };
    }
}
impl Device {
    /// Acquires `DEFAULT_DEVICE_LOCK`, recovering the guard if a previous holder
    /// panicked. The mutex guards `()`, so poisoning leaves no inconsistent state
    /// worth propagating.
    fn lock_default_device() -> std::sync::MutexGuard<'static, ()> {
        DEFAULT_DEVICE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Wraps a fresh, not-yet-filled handle from `mlx_device_new`, to be
    /// populated by an out-parameter FFI call. If that call fails, the value is
    /// released through `Drop` like any other `Device`.
    fn unfilled() -> Self {
        // SAFETY: `mlx_device_new` takes no arguments. NOTE(review): relies on
        // `mlx_device_free` accepting a handle that was never filled, which is
        // what happens on the error paths of `try_clone` and `current`.
        Self(unsafe { mlxrs_sys::mlx_device_new() })
    }

    /// Returns the CPU device with index 0.
    ///
    /// # Errors
    ///
    /// Same as [`Device::with_index`].
    pub fn cpu() -> Result<Self> {
        Self::with_index(DeviceKind::Cpu, 0)
    }

    /// Returns the GPU device with index 0.
    ///
    /// # Errors
    ///
    /// Same as [`Device::with_index`].
    pub fn gpu() -> Result<Self> {
        Self::with_index(DeviceKind::Gpu, 0)
    }

    /// Creates a handle for the device of the given `kind` and `index` via
    /// `mlx_device_new_type`.
    ///
    /// `index` is passed to MLX unchanged. It is not checked here against
    /// [`DeviceKind::count`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::FfiNullHandle`] if MLX returns a null handle.
    pub fn with_index(kind: DeviceKind, index: i32) -> Result<Self> {
        ensure_handler_installed();
        // SAFETY: plain-value arguments; the returned handle is null-checked below.
        let raw = unsafe { mlxrs_sys::mlx_device_new_type(kind.to_raw(), index) };
        if raw.ctx.is_null() {
            return Err(Error::FfiNullHandle(FfiNullHandlePayload::new(
                "mlx_device_new_type",
            )));
        }
        Ok(Self(raw))
    }

    /// Creates a new, independently owned handle referring to the same device,
    /// copied with `mlx_device_set`.
    ///
    /// # Errors
    ///
    /// Propagates a failing status from `mlx_device_set`.
    pub fn try_clone(&self) -> Result<Self> {
        ensure_handler_installed();
        let mut out = Self::unfilled();
        // SAFETY: `out.0` is a handle we own and may overwrite; `self.0` is the
        // handle owned by `self`.
        check(unsafe { mlxrs_sys::mlx_device_set(&mut out.0, self.0) })?;
        Ok(out)
    }

    /// Returns the default device reported by `mlx_get_default_device`.
    ///
    /// The query runs while holding `DEFAULT_DEVICE_LOCK`, so it is serialized
    /// with [`Device::set_default`].
    ///
    /// # Errors
    ///
    /// Propagates a failing status from `mlx_get_default_device`.
    pub fn current() -> Result<Self> {
        ensure_handler_installed();
        let _guard = Self::lock_default_device();
        let mut out = Self::unfilled();
        // SAFETY: `out.0` is a handle we own and may overwrite.
        check(unsafe { mlxrs_sys::mlx_get_default_device(&mut out.0) })?;
        Ok(out)
    }

    /// Alias for [`Device::current`].
    ///
    /// # Errors
    ///
    /// Same as [`Device::current`].
    #[inline]
    pub fn get_default() -> Result<Self> {
        Self::current()
    }

    /// Makes this device MLX's default via `mlx_set_default_device`.
    ///
    /// The call runs while holding `DEFAULT_DEVICE_LOCK`, so it is serialized
    /// with [`Device::current`].
    ///
    /// # Errors
    ///
    /// Propagates a failing status from `mlx_set_default_device`.
    pub fn set_default(&self) -> Result<()> {
        ensure_handler_installed();
        let _guard = Self::lock_default_device();
        // SAFETY: `self.0` is the handle owned by `self`.
        check(unsafe { mlxrs_sys::mlx_set_default_device(self.0) })
    }

    /// Returns whether this is a CPU or GPU device.
    ///
    /// # Errors
    ///
    /// Propagates a failing status from `mlx_device_get_type`. Returns
    /// [`Error::OutOfRange`] if MLX reports a type other than CPU/GPU.
    pub fn kind(&self) -> Result<DeviceKind> {
        ensure_handler_installed();
        let mut raw: mlxrs_sys::mlx_device_type = 0;
        // SAFETY: `raw` is a valid out-parameter; `self.0` is owned by `self`.
        check(unsafe { mlxrs_sys::mlx_device_get_type(&mut raw, self.0) })?;
        DeviceKind::from_raw(raw)
    }

    /// Returns this device's index as reported by `mlx_device_get_index`.
    ///
    /// # Errors
    ///
    /// Propagates a failing status from `mlx_device_get_index`.
    pub fn index(&self) -> Result<i32> {
        ensure_handler_installed();
        let mut idx: i32 = 0;
        // SAFETY: `idx` is a valid out-parameter; `self.0` is owned by `self`.
        check(unsafe { mlxrs_sys::mlx_device_get_index(&mut idx, self.0) })?;
        Ok(idx)
    }

    /// Returns whether MLX reports this device as available
    /// (`mlx_device_is_available`).
    ///
    /// # Errors
    ///
    /// Propagates a failing status from `mlx_device_is_available`.
    pub fn is_available(&self) -> Result<bool> {
        ensure_handler_installed();
        let mut avail = false;
        // SAFETY: `avail` is a valid out-parameter; `self.0` is owned by `self`.
        check(unsafe { mlxrs_sys::mlx_device_is_available(&mut avail, self.0) })?;
        Ok(avail)
    }

    /// Compares two devices using `mlx_device_equal`.
    #[inline]
    pub fn equal(&self, other: &Device) -> bool {
        // SAFETY: both handles are owned by live `Device` values.
        unsafe { mlxrs_sys::mlx_device_equal(self.0, other.0) }
    }

    /// Returns the raw `mlx_device` handle without transferring ownership.
    ///
    /// # Safety
    ///
    /// The returned value is a bitwise copy of the handle that this `Device`
    /// owns and frees (via `mlx_device_free`) when dropped. The caller must not
    /// free it and must not use it after this `Device` has been dropped.
    #[inline]
    pub unsafe fn as_raw(&self) -> mlxrs_sys::mlx_device {
        self.0
    }
}
impl PartialEq for Device {
fn eq(&self, other: &Self) -> bool {
self.equal(other)
}
}
impl Eq for Device {}
/// Hashes the `(kind, index)` pair, consistent with `PartialEq` via
/// `mlx_device_equal`. If either query fails, a fixed sentinel (`i32::MIN`) is
/// hashed instead, so hashing never panics.
impl std::hash::Hash for Device {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Both queries are always evaluated, kind first, before inspecting either.
        if let (Ok(device_kind), Ok(device_index)) = (self.kind(), self.index()) {
            device_kind.hash(state);
            device_index.hash(state);
        } else {
            i32::MIN.hash(state);
        }
    }
}
/// Shared formatter behind `Device`'s `Debug` and `Display` impls.
///
/// Asks MLX for the device's text via `mlx_device_tostring`. If that call
/// returns a non-zero status or yields a null data pointer, `<unprintable>` is
/// used in place of the text. Invalid UTF-8 is replaced lossily.
///
/// - `wrap = Some(prefix)`: writes `prefix(text)`.
/// - `wrap = None`: writes the bare text through `Formatter::pad`, so width,
///   fill/alignment and precision flags (e.g. `{:>12}`) are honoured.
fn fmt_device(
    dev: &Device,
    f: &mut std::fmt::Formatter<'_>,
    wrap: Option<&str>,
) -> std::fmt::Result {
    crate::error::ensure_handler_installed();
    // The guard owns the mlx_string until this function returns, which keeps the
    // borrowed `text` below valid for the whole formatting step.
    let mut guard = StringGuard(unsafe { mlxrs_sys::mlx_string_new() });
    // SAFETY: `guard.0` is an mlx_string we own; `dev.0` is owned by `dev`.
    let rc = unsafe { mlxrs_sys::mlx_device_tostring(&mut guard.0, dev.0) };
    let text = if rc == 0 {
        // SAFETY: `guard.0` was just filled by a successful `mlx_device_tostring`.
        let p = unsafe { mlxrs_sys::mlx_string_data(guard.0) };
        if p.is_null() {
            None
        } else {
            // SAFETY: non-null pointer from `mlx_string_data`. NOTE(review): assumes
            // it is NUL-terminated and lives as long as `guard`.
            Some(unsafe { CStr::from_ptr(p) }.to_string_lossy())
        }
    } else {
        None
    };
    let shown = text.as_deref().unwrap_or("<unprintable>");
    match wrap {
        Some(prefix) => write!(f, "{prefix}({shown})"),
        None => f.pad(shown),
    }
}
impl std::fmt::Debug for Device {
    /// Renders as `Device(<text reported by MLX>)`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_label = "Device";
        fmt_device(self, formatter, Some(type_label))
    }
}
impl std::fmt::Display for Device {
    /// Renders the bare text MLX reports for this device, with no wrapper.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let no_wrapper: Option<&str> = None;
        fmt_device(self, out, no_wrapper)
    }
}