use std::{
cell::Cell,
collections::HashMap,
ffi::{CStr, CString},
fs::File,
io::{Seek, SeekFrom, Write},
os::raw::{c_char, c_int, c_void},
path::Path,
};
use crate::{
array::Array,
error::{Error, FileIoPayload, FileOp, InteriorNulPayload, Result, check},
};
thread_local! {
    /// This thread's default CPU stream, created lazily by [`io_cpu_stream`] and then reused.
    static CPU_STREAM: Cell<Option<mlxrs_sys::mlx_stream>> = const { Cell::new(None) };
}

/// Returns the default CPU stream that this module passes to MLX file-I/O calls.
///
/// The first call on a thread creates the handle with `mlx_default_cpu_stream_new`.
/// Later calls on that thread return the cached copy. Nothing in this function frees it.
///
/// # Panics
/// Panics if MLX hands back a stream whose `ctx` is NULL. Before touching the cache,
/// this also calls `crate::stream::assert_streams_not_cleared()`, which presumably
/// panics when streams have been cleared — NOTE(review): confirm.
fn io_cpu_stream() -> mlxrs_sys::mlx_stream {
    crate::error::ensure_handler_installed();
    crate::stream::assert_streams_not_cleared();
    CPU_STREAM.with(|slot| match slot.get() {
        Some(cached) => cached,
        None => {
            // SAFETY: argument-less constructor; the returned handle is validated below.
            let fresh = unsafe { mlxrs_sys::mlx_default_cpu_stream_new() };
            if fresh.ctx.is_null() {
                panic!(
                    "mlxrs::io: mlx_default_cpu_stream_new returned NULL ctx — \
                     CPU stream initialization failed. Aborting."
                );
            }
            slot.set(Some(fresh));
            fresh
        }
    })
}
fn path_cstring(path: &Path) -> Result<CString> {
CString::new(path.as_os_str().as_encoded_bytes()).map_err(|_| {
let _ = path;
Error::InteriorNul(InteriorNulPayload::new("io::path_cstring", "path"))
})
}
/// Owns an `mlx_map_string_to_array` handle and frees it when dropped.
struct ArrayMapGuard(mlxrs_sys::mlx_map_string_to_array);
impl Drop for ArrayMapGuard {
    fn drop(&mut self) {
        // A failure status cannot be reported from `drop`, so it is discarded.
        // SAFETY: callers hand the guard ownership of the handle and do not free it themselves.
        let _status = unsafe { mlxrs_sys::mlx_map_string_to_array_free(self.0) };
    }
}
/// Owns an `mlx_map_string_to_string` handle and frees it when dropped.
struct StringMapGuard(mlxrs_sys::mlx_map_string_to_string);
impl Drop for StringMapGuard {
    fn drop(&mut self) {
        // A failure status cannot be reported from `drop`, so it is discarded.
        // SAFETY: callers hand the guard ownership of the handle and do not free it themselves.
        let _status = unsafe { mlxrs_sys::mlx_map_string_to_string_free(self.0) };
    }
}
/// Owns an `mlx_io_gguf` handle and frees it when dropped.
#[cfg(feature = "gguf")]
struct GgufGuard(mlxrs_sys::mlx_io_gguf);
#[cfg(feature = "gguf")]
impl Drop for GgufGuard {
    fn drop(&mut self) {
        // A failure status cannot be reported from `drop`, so it is discarded.
        // SAFETY: callers hand the guard ownership of the handle and do not free it themselves.
        let _status = unsafe { mlxrs_sys::mlx_io_gguf_free(self.0) };
    }
}
/// Owns an `mlx_vector_string` handle and frees it when dropped.
#[cfg(feature = "gguf")]
struct VectorStringGuard(mlxrs_sys::mlx_vector_string);
#[cfg(feature = "gguf")]
impl Drop for VectorStringGuard {
    fn drop(&mut self) {
        // A failure status cannot be reported from `drop`, so it is discarded.
        // SAFETY: callers hand the guard ownership of the handle and do not free it themselves.
        let _status = unsafe { mlxrs_sys::mlx_vector_string_free(self.0) };
    }
}
/// Owns an `mlx_string` handle and frees it when dropped.
#[cfg(feature = "gguf")]
struct StringGuard(mlxrs_sys::mlx_string);
#[cfg(feature = "gguf")]
impl Drop for StringGuard {
    fn drop(&mut self) {
        // A failure status cannot be reported from `drop`, so it is discarded.
        // SAFETY: callers hand the guard ownership of the handle and do not free it themselves.
        let _status = unsafe { mlxrs_sys::mlx_string_free(self.0) };
    }
}
/// Interprets the status/flag pair returned by an `mlx_io_gguf_has_metadata_*` probe.
///
/// - status `0`: the probe succeeded and `flag` is the answer.
/// - status `2`: treated as "not this kind of metadata" (`false`).
/// - anything else: forwarded to [`check`], whose error is propagated. If `check`
///   accepts the code, the answer is `false`.
#[cfg(feature = "gguf")]
fn gguf_has_meta(rc: std::os::raw::c_int, flag: bool) -> Result<bool> {
    if rc == 0 {
        return Ok(flag);
    }
    if rc != 2 {
        check(rc)?;
    }
    Ok(false)
}
fn drain_array_map(map: mlxrs_sys::mlx_map_string_to_array) -> HashMap<String, Array> {
let it = unsafe { mlxrs_sys::mlx_map_string_to_array_iterator_new(map) };
let mut out = HashMap::new();
loop {
let mut key: *const std::os::raw::c_char = std::ptr::null();
let mut value = Array(unsafe { mlxrs_sys::mlx_array_new() });
let rc =
unsafe { mlxrs_sys::mlx_map_string_to_array_iterator_next(&mut key, &mut value.0, it) };
if rc != 0 {
break;
}
let k = unsafe { CStr::from_ptr(key) }
.to_string_lossy()
.into_owned();
out.insert(k, value);
}
unsafe {
let _ = mlxrs_sys::mlx_map_string_to_array_iterator_free(it);
}
out
}
fn drain_string_map(map: mlxrs_sys::mlx_map_string_to_string) -> HashMap<String, String> {
let it = unsafe { mlxrs_sys::mlx_map_string_to_string_iterator_new(map) };
let mut out = HashMap::new();
loop {
let mut key: *const std::os::raw::c_char = std::ptr::null();
let mut value: *const std::os::raw::c_char = std::ptr::null();
let rc = unsafe { mlxrs_sys::mlx_map_string_to_string_iterator_next(&mut key, &mut value, it) };
if rc != 0 {
break;
}
let k = unsafe { CStr::from_ptr(key) }
.to_string_lossy()
.into_owned();
let v = unsafe { CStr::from_ptr(value) }
.to_string_lossy()
.into_owned();
out.insert(k, v);
}
unsafe {
let _ = mlxrs_sys::mlx_map_string_to_string_iterator_free(it);
}
out
}
fn build_array_map<'a, I>(arrays: I) -> Result<mlxrs_sys::mlx_map_string_to_array>
where
I: IntoIterator<Item = (&'a str, &'a Array)>,
{
crate::error::ensure_handler_installed();
let guard = ArrayMapGuard(unsafe { mlxrs_sys::mlx_map_string_to_array_new() });
if guard.0.ctx.is_null() {
let last = crate::error::take_last();
return Err(last.unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_map_string_to_array_new"),
)));
}
for (k, v) in arrays {
let ck = CString::new(k).map_err(|_| {
let _ = k;
Error::InteriorNul(InteriorNulPayload::new(
"io::map_arrays insert",
"array key",
))
})?;
check(unsafe { mlxrs_sys::mlx_map_string_to_array_insert(guard.0, ck.as_ptr(), v.0) })?;
}
let raw = guard.0;
std::mem::forget(guard);
Ok(raw)
}
fn build_string_map(meta: &HashMap<String, String>) -> Result<mlxrs_sys::mlx_map_string_to_string> {
crate::error::ensure_handler_installed();
let guard = StringMapGuard(unsafe { mlxrs_sys::mlx_map_string_to_string_new() });
if guard.0.ctx.is_null() {
let last = crate::error::take_last();
return Err(last.unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_map_string_to_string_new"),
)));
}
for (k, v) in meta {
let ck = CString::new(k.as_str()).map_err(|_| {
let _ = k;
Error::InteriorNul(InteriorNulPayload::new(
"io::map_meta insert",
"metadata key",
))
})?;
let cv = CString::new(v.as_str()).map_err(|_| {
let _ = v;
Error::InteriorNul(InteriorNulPayload::new(
"io::map_meta insert",
"metadata value",
))
})?;
check(unsafe {
mlxrs_sys::mlx_map_string_to_string_insert(guard.0, ck.as_ptr(), cv.as_ptr())
})?;
}
let raw = guard.0;
std::mem::forget(guard);
Ok(raw)
}
/// Loads every tensor from a `.safetensors` file, discarding the header metadata.
///
/// See [`load_safetensors_with_metadata`] for the error conditions.
pub fn load_safetensors(path: &Path) -> Result<HashMap<String, Array>> {
    let (tensors, _metadata) = load_safetensors_with_metadata(path)?;
    Ok(tensors)
}
/// Loads every tensor plus the string metadata header from a `.safetensors` file.
///
/// Loading runs on this thread's default CPU stream (see `io_cpu_stream`).
///
/// # Errors
/// - `Error::InteriorNul` if `path` contains a NUL byte.
/// - Any error MLX reports while reading or parsing the file.
pub fn load_safetensors_with_metadata(
    path: &Path,
) -> Result<(HashMap<String, Array>, HashMap<String, String>)> {
    let cpath = path_cstring(path)?;
    // MLX receives `&mut` to the *guarded* handles, so whatever handle it leaves
    // behind (even one it replaced) is the one the guards free. The guards free it
    // on both the error path and the success path.
    let mut arrays = ArrayMapGuard(unsafe { mlxrs_sys::mlx_map_string_to_array_new() });
    let mut meta = StringMapGuard(unsafe { mlxrs_sys::mlx_map_string_to_string_new() });
    // SAFETY: both out-pointers reference live handles owned by the guards; `cpath`
    // outlives the call.
    check(unsafe {
        mlxrs_sys::mlx_load_safetensors(
            &mut arrays.0,
            &mut meta.0,
            cpath.as_ptr(),
            io_cpu_stream(),
        )
    })?;
    // Drained entries are copied out, so they stay valid after the guards free the maps.
    Ok((drain_array_map(arrays.0), drain_string_map(meta.0)))
}
/// Writes `arrays` to `path` in safetensors format with an empty metadata header.
///
/// See [`save_safetensors_with_metadata`] for the error conditions.
pub fn save_safetensors(path: &Path, arrays: &HashMap<String, Array>) -> Result<()> {
    let no_metadata: HashMap<String, String> = HashMap::new();
    save_safetensors_with_metadata(path, arrays, &no_metadata)
}
/// Writes owned-key `arrays` and string `metadata` to `path` in safetensors format.
///
/// This borrows each entry and delegates to [`save_safetensors_view`].
pub fn save_safetensors_with_metadata(
    path: &Path,
    arrays: &HashMap<String, Array>,
    metadata: &HashMap<String, String>,
) -> Result<()> {
    let borrowed = arrays
        .iter()
        .map(|(name, array)| (name.as_str(), array));
    save_safetensors_view(path, borrowed, metadata)
}
/// Writes borrowed `(name, array)` pairs and string `metadata` to `path` in safetensors format.
///
/// # Errors
/// - `Error::InteriorNul` if the path, a tensor name, or a metadata key/value contains a NUL byte.
/// - Any error MLX reports while building the maps or writing the file.
pub fn save_safetensors_view<'a, I>(
    path: &Path,
    arrays: I,
    metadata: &HashMap<String, String>,
) -> Result<()>
where
    I: IntoIterator<Item = (&'a str, &'a Array)>,
{
    let cpath = path_cstring(path)?;
    let arrays_owner = ArrayMapGuard(build_array_map(arrays)?);
    let meta_owner = StringMapGuard(build_string_map(metadata)?);
    // SAFETY: both maps are live handles owned by the guards; `cpath` outlives the call.
    let status =
        unsafe { mlxrs_sys::mlx_save_safetensors(cpath.as_ptr(), arrays_owner.0, meta_owner.0) };
    check(status)?;
    drop(arrays_owner);
    drop(meta_owner);
    Ok(())
}
/// Serializes `arrays` and `metadata` in safetensors format into an already-open `file`.
///
/// The maps are built first, so bad names or values leave the file untouched. The file
/// is then rewound to byte 0 and truncated, and MLX writes through a callback-based
/// writer. The handle must therefore be writable and seekable.
///
/// # Errors
/// - `Error::InteriorNul` or an MLX error while building the key→array / key→string maps.
/// - `Error::FileIo` if rewinding or truncating fails, or if a write callback recorded an
///   I/O error. The callback error takes precedence over MLX's own status.
/// - The last MLX error (or `Error::FfiNullHandle`) if the writer cannot be created.
/// - Any MLX error from serialization.
pub fn save_safetensors_to_file<'a, I>(
    file: &mut File,
    arrays: I,
    metadata: &HashMap<String, String>,
) -> Result<()>
where
    I: IntoIterator<Item = (&'a str, &'a Array)>,
{
    let amap_guard = ArrayMapGuard(build_array_map(arrays)?);
    let mmap_guard = StringMapGuard(build_string_map(metadata)?);

    // Rewind and truncate through `file` *before* `WriterState` derives a raw pointer
    // from it. Using `file` after that point would invalidate the pointer under Rust's
    // aliasing model (Miri reports it), and the callbacks' later `&mut *state.file`
    // would be undefined behavior.
    file.seek(SeekFrom::Start(0)).map_err(|e| {
        Error::FileIo(FileIoPayload::new(
            "save_safetensors_to_file: seek to byte 0",
            FileOp::Other("seek"),
            std::path::PathBuf::new(),
            e,
        ))
    })?;
    file.set_len(0).map_err(|e| {
        Error::FileIo(FileIoPayload::new(
            "save_safetensors_to_file: truncate to 0",
            FileOp::Other("set_len"),
            std::path::PathBuf::new(),
            e,
        ))
    })?;

    // From here on, the file is reached only through `state.file`.
    let state = WriterState::new(file);
    crate::error::ensure_handler_installed();
    // SAFETY: `state` outlives the writer: the writer guard is dropped (explicitly below,
    // or first in reverse declaration order on early return) before `state`.
    let writer_guard =
        WriterGuard(unsafe { mlxrs_sys::mlx_io_writer_new(state.as_desc(), make_writer_vtable()) });
    if writer_guard.0.ctx.is_null() {
        let last = crate::error::take_last();
        return Err(last.unwrap_or(Error::FfiNullHandle(
            crate::error::FfiNullHandlePayload::new("mlx_io_writer_new"),
        )));
    }
    // SAFETY: writer and both maps are live handles owned by their guards.
    let rc = unsafe {
        mlxrs_sys::mlx_save_safetensors_writer(writer_guard.0, amap_guard.0, mmap_guard.0)
    };
    // Free the writer (MLX's only pointer to `state`) before consuming `state`.
    drop(writer_guard);
    drop(amap_guard);
    drop(mmap_guard);
    if let Some(e) = state.into_err() {
        return Err(Error::FileIo(FileIoPayload::new(
            "save_safetensors_to_file: write callback",
            FileOp::Write,
            std::path::PathBuf::new(),
            e,
        )));
    }
    check(rc)?;
    Ok(())
}
/// State shared with MLX's writer callbacks through the opaque `desc` pointer.
///
/// `file` is a raw pointer because the state crosses the C boundary as `*mut c_void`.
/// The callbacks only ever see `&WriterState`, so the first I/O error is parked in a
/// `Cell`.
struct WriterState {
    file: *mut File,
    err: Cell<Option<std::io::Error>>,
    label: &'static CStr,
}
impl WriterState {
    /// Wraps `file`. The caller must keep the file alive, and must not use it through
    /// any other path, while callbacks may run.
    fn new(file: &mut File) -> Self {
        let file_ptr: *mut File = file;
        WriterState {
            label: c"mlxrs::io::save_safetensors_to_file(&mut File)",
            err: Cell::new(None),
            file: file_ptr,
        }
    }

    /// Returns the address of `self` as the untyped descriptor handed to MLX.
    fn as_desc(&self) -> *mut c_void {
        let this: *const Self = self;
        this as *mut c_void
    }

    /// Consumes the state and returns the first recorded error, if any.
    fn into_err(self) -> Option<std::io::Error> {
        let WriterState { err, .. } = self;
        err.into_inner()
    }

    /// Records `e` unless an earlier error is already stored. Only the first error is kept.
    fn set_err(&self, e: std::io::Error) {
        let first = self.err.take().unwrap_or(e);
        self.err.set(Some(first));
    }
}
/// Owns an `mlx_io_writer` handle and frees it when dropped.
struct WriterGuard(mlxrs_sys::mlx_io_writer);
impl Drop for WriterGuard {
    fn drop(&mut self) {
        // A failure status cannot be reported from `drop`, so it is discarded.
        // SAFETY: the guard is the only place this handle is freed.
        let _status = unsafe { mlxrs_sys::mlx_io_writer_free(self.0) };
    }
}
/// Runs `f` with the [`WriterState`] behind `desc` and the `File` it points to.
/// Any panic is caught so it cannot unwind across the C callback boundary.
///
/// Returns `Some(result)` when `f` completes. If `f` panics, an error is recorded on
/// the state and `None` is returned so the caller can substitute a fallback value.
///
/// `desc` is expected to be a pointer produced by `WriterState::as_desc` for a state
/// that is still alive.
fn with_state<R>(desc: *mut c_void, f: impl FnOnce(&WriterState, &mut File) -> R) -> Option<R> {
    let state_ptr = desc as *const WriterState;
    let guarded = std::panic::AssertUnwindSafe(|| {
        // SAFETY: `desc` points at a live `WriterState` (see the function docs).
        let state = unsafe { &*state_ptr };
        // SAFETY: `state.file` came from a `&mut File` that is not otherwise used meanwhile.
        let file = unsafe { &mut *state.file };
        f(state, file)
    });
    match std::panic::catch_unwind(guarded) {
        Ok(value) => Some(value),
        Err(_) => {
            // SAFETY: same pointer as above; the panic did not free the state.
            let state = unsafe { &*state_ptr };
            state.set_err(std::io::Error::other(
                "mlxrs::io::save_safetensors_to_file callback panicked",
            ));
            None
        }
    }
}
/// `is_open` callback. The sink wraps an already-open `File`, so it always reports open.
unsafe extern "C" fn cb_is_open(_: *mut c_void) -> bool {
    true
}
/// `good` callback. Returns `true` while no callback has recorded an error, and
/// `false` if an error is stored or the inspection itself panicked.
unsafe extern "C" fn cb_good(desc: *mut c_void) -> bool {
    let healthy = with_state(desc, |state, _file| {
        // `Cell` cannot peek at non-`Copy` contents: swap out, inspect, swap back.
        let recorded = state.err.replace(None);
        let no_error = recorded.is_none();
        state.err.set(recorded);
        no_error
    });
    healthy.unwrap_or(false)
}
/// `tell` callback. Returns the file's current position.
///
/// On an OS error, the error is recorded and `0` is returned; a panic also yields `0`.
unsafe extern "C" fn cb_tell(desc: *mut c_void) -> usize {
    let position = with_state(desc, |state, file| {
        file.stream_position()
            .map(|p| p as usize)
            .unwrap_or_else(|err| {
                state.set_err(err);
                0
            })
    });
    position.unwrap_or(0)
}
/// `seek` callback. Repositions the underlying `File`.
///
/// `whence` is matched against libc's `SEEK_SET` / `SEEK_CUR` / `SEEK_END`. The C
/// signature has no error channel, so every failure is recorded on the writer state
/// instead: an unknown `whence`, a negative absolute offset, or an OS seek error.
unsafe extern "C" fn cb_seek(desc: *mut c_void, off: i64, whence: c_int) {
    with_state(desc, |state, file| {
        let pos = match whence {
            x if x == libc::SEEK_SET => match u64::try_from(off) {
                Ok(start) => SeekFrom::Start(start),
                // `off as u64` would wrap a negative offset into a huge position, and
                // whether that later fails depends on how std forwards it to the OS.
                // Reject it explicitly with a clear message instead.
                Err(_) => {
                    state.set_err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        format!("save_safetensors_to_file: negative absolute seek offset {off}"),
                    ));
                    return;
                }
            },
            x if x == libc::SEEK_CUR => SeekFrom::Current(off),
            x if x == libc::SEEK_END => SeekFrom::End(off),
            _ => {
                state.set_err(std::io::Error::other(format!(
                    "save_safetensors_to_file: unknown seek whence {whence}"
                )));
                return;
            }
        };
        if let Err(e) = file.seek(pos) {
            state.set_err(e);
        }
    });
}
/// `read` callback. It is never valid on this write-only sink, so it just records an error.
unsafe extern "C" fn cb_read(desc: *mut c_void, _data: *mut c_char, _n: usize) {
    let _ = with_state(desc, |state, _| {
        state.set_err(std::io::Error::other(
            "save_safetensors_to_file: writer.read called (writer-only sink)",
        ))
    });
}
/// `read_at_offset` callback. It is never valid on this write-only sink, so it just
/// records an error.
unsafe extern "C" fn cb_read_at_offset(
    desc: *mut c_void,
    _data: *mut c_char,
    _n: usize,
    _off: usize,
) {
    let _ = with_state(desc, |state, _| {
        state.set_err(std::io::Error::other(
            "save_safetensors_to_file: writer.read_at_offset called (writer-only sink)",
        ))
    });
}
/// `write` callback. Appends `n` bytes from `data` to the file.
///
/// A zero-length write is a no-op (even with NULL `data`). A NULL `data` with `n > 0`,
/// or an OS write failure, is recorded on the writer state.
unsafe extern "C" fn cb_write(desc: *mut c_void, data: *const c_char, n: usize) {
    with_state(desc, |state, file| match (n, data.is_null()) {
        (0, _) => {}
        (_, true) => state.set_err(std::io::Error::other(
            "save_safetensors_to_file: write callback received NULL data",
        )),
        (_, false) => {
            // SAFETY: MLX guarantees `data` points at `n` readable bytes for this call.
            let payload = unsafe { std::slice::from_raw_parts(data.cast::<u8>(), n) };
            file.write_all(payload).unwrap_or_else(|e| state.set_err(e));
        }
    });
}
/// `label` callback. Returns the state's static label, or `"<panic>"` if reading it panicked.
unsafe extern "C" fn cb_label(desc: *mut c_void) -> *const c_char {
    match with_state(desc, |state, _| state.label.as_ptr()) {
        Some(label) => label,
        None => c"<panic>".as_ptr(),
    }
}
/// `free` callback. Intentionally a no-op: the `WriterState` behind `desc` is a local
/// owned by the Rust caller, not memory handed over to MLX.
unsafe extern "C" fn cb_free(_: *mut c_void) {}
/// Builds the callback table that lets MLX drive a `WriterState` as its output sink.
///
/// The read callbacks are populated but only record an error, because the sink is
/// write-only.
fn make_writer_vtable() -> mlxrs_sys::mlx_io_vtable {
    mlxrs_sys::mlx_io_vtable {
        label: Some(cb_label),
        free: Some(cb_free),
        is_open: Some(cb_is_open),
        good: Some(cb_good),
        tell: Some(cb_tell),
        seek: Some(cb_seek),
        write: Some(cb_write),
        read: Some(cb_read),
        read_at_offset: Some(cb_read_at_offset),
    }
}
/// A single GGUF metadata value, as returned by `load_gguf` and accepted by `save_gguf`.
///
/// `Display` prints the value's kind name (the result of `as_str`), not its contents.
/// `derive_more` also provides `is_*`, `unwrap_*` and `try_unwrap_*` accessors, the
/// latter two in `ref` / `ref_mut` forms.
#[cfg(feature = "gguf")]
#[cfg_attr(docsrs, doc(cfg(feature = "gguf")))]
#[non_exhaustive]
#[derive(
    Debug, derive_more::Display, derive_more::IsVariant, derive_more::Unwrap, derive_more::TryUnwrap,
)]
#[display("{}", self.as_str())]
#[unwrap(ref, ref_mut)]
#[try_unwrap(ref, ref_mut)]
pub enum GgufMetadata {
    /// Array-valued metadata (read/written via the `*_metadata_array` mlx-c calls).
    Array(Array),
    /// A single string value (via the `*_metadata_string` mlx-c calls).
    String(String),
    /// A list of strings (via the `*_metadata_vector_string` mlx-c calls).
    StringList(Vec<String>),
}
#[cfg(feature = "gguf")]
impl GgufMetadata {
    /// Returns the kind of this value: `"array"`, `"string"`, or `"string_list"`.
    ///
    /// This is also what the `Display` implementation prints.
    pub fn as_str(&self) -> &'static str {
        match *self {
            GgufMetadata::Array(..) => "array",
            GgufMetadata::String(..) => "string",
            GgufMetadata::StringList(..) => "string_list",
        }
    }
}
/// Loads a GGUF file and splits its entries into tensors and metadata.
///
/// Each key reported by MLX is probed with the `has_metadata_*` calls. Keys that match
/// become [`GgufMetadata`] entries, and every other key is loaded as a tensor. Map keys
/// are the lossy UTF-8 form of the GGUF keys, but MLX lookups always use the exact
/// original key bytes.
///
/// # Errors
/// - `Error::InteriorNul` if `path` contains a NUL byte.
/// - Any error MLX reports while loading the file or reading an entry.
#[cfg(feature = "gguf")]
#[cfg_attr(docsrs, doc(cfg(feature = "gguf")))]
pub fn load_gguf(path: &Path) -> Result<(HashMap<String, Array>, HashMap<String, GgufMetadata>)> {
    let cpath = path_cstring(path)?;
    let mut guard = GgufGuard(mlxrs_sys::mlx_io_gguf {
        ctx: std::ptr::null_mut(),
    });
    check(unsafe { mlxrs_sys::mlx_load_gguf(&mut guard.0, cpath.as_ptr(), io_cpu_stream()) })?;
    let gguf = guard.0;

    // MLX writes into the guarded handle so the guard frees whatever handle MLX leaves.
    let mut keys_guard = VectorStringGuard(unsafe { mlxrs_sys::mlx_vector_string_new() });
    check(unsafe { mlxrs_sys::mlx_io_gguf_get_keys(&mut keys_guard.0, gguf) })?;
    let keys = keys_guard.0;
    let n = unsafe { mlxrs_sys::mlx_vector_string_size(keys) };

    let mut weights = HashMap::new();
    let mut metadata = HashMap::new();
    for i in 0..n {
        let mut raw: *mut c_char = std::ptr::null_mut();
        check(unsafe { mlxrs_sys::mlx_vector_string_get(&mut raw, keys, i) })?;
        // Keep the exact key bytes for every MLX lookup. A lossy UTF-8 round-trip would
        // turn a non-UTF-8 key into a different key that MLX cannot find.
        let ckey: CString = unsafe { CStr::from_ptr(raw) }.to_owned();
        let key = ckey.to_string_lossy().into_owned();

        let mut f_arr = false;
        let rc_arr =
            unsafe { mlxrs_sys::mlx_io_gguf_has_metadata_array(&mut f_arr, gguf, ckey.as_ptr()) };
        let is_meta_arr = gguf_has_meta(rc_arr, f_arr)?;
        let mut f_str = false;
        let rc_str =
            unsafe { mlxrs_sys::mlx_io_gguf_has_metadata_string(&mut f_str, gguf, ckey.as_ptr()) };
        let is_meta_str = gguf_has_meta(rc_str, f_str)?;
        let mut f_vstr = false;
        let rc_vstr = unsafe {
            mlxrs_sys::mlx_io_gguf_has_metadata_vector_string(&mut f_vstr, gguf, ckey.as_ptr())
        };
        let is_meta_vstr = gguf_has_meta(rc_vstr, f_vstr)?;

        if is_meta_arr {
            let mut arr = Array(unsafe { mlxrs_sys::mlx_array_new() });
            check(unsafe {
                mlxrs_sys::mlx_io_gguf_get_metadata_array(&mut arr.0, gguf, ckey.as_ptr())
            })?;
            metadata.insert(key, GgufMetadata::Array(arr));
        } else if is_meta_str {
            let mut s = StringGuard(unsafe { mlxrs_sys::mlx_string_new() });
            check(unsafe {
                mlxrs_sys::mlx_io_gguf_get_metadata_string(&mut s.0, gguf, ckey.as_ptr())
            })?;
            let v = unsafe { CStr::from_ptr(mlxrs_sys::mlx_string_data(s.0)) }
                .to_string_lossy()
                .into_owned();
            metadata.insert(key, GgufMetadata::String(v));
        } else if is_meta_vstr {
            let mut vstr_guard = VectorStringGuard(unsafe { mlxrs_sys::mlx_vector_string_new() });
            check(unsafe {
                mlxrs_sys::mlx_io_gguf_get_metadata_vector_string(
                    &mut vstr_guard.0,
                    gguf,
                    ckey.as_ptr(),
                )
            })?;
            let vstr = vstr_guard.0;
            let m = unsafe { mlxrs_sys::mlx_vector_string_size(vstr) };
            let mut list = Vec::with_capacity(m);
            for j in 0..m {
                let mut sp: *mut c_char = std::ptr::null_mut();
                check(unsafe { mlxrs_sys::mlx_vector_string_get(&mut sp, vstr, j) })?;
                list.push(unsafe { CStr::from_ptr(sp) }.to_string_lossy().into_owned());
            }
            metadata.insert(key, GgufMetadata::StringList(list));
        } else {
            let mut arr = Array(unsafe { mlxrs_sys::mlx_array_new() });
            check(unsafe { mlxrs_sys::mlx_io_gguf_get_array(&mut arr.0, gguf, ckey.as_ptr()) })?;
            weights.insert(key, arr);
        }
    }
    // `keys_guard` then `guard` are dropped here (reverse declaration order); the
    // collected values are independent copies.
    Ok((weights, metadata))
}
/// Writes tensors and metadata to `path` in GGUF format.
///
/// # Errors
/// - `Error::InteriorNul` if the path, a key, a string value, or a list entry contains
///   a NUL byte.
/// - Any error MLX reports while populating the GGUF container or writing the file.
#[cfg(feature = "gguf")]
#[cfg_attr(docsrs, doc(cfg(feature = "gguf")))]
pub fn save_gguf(
    path: &Path,
    weights: &HashMap<String, Array>,
    metadata: &HashMap<String, GgufMetadata>,
) -> Result<()> {
    let cpath = path_cstring(path)?;
    let handle = unsafe { mlxrs_sys::mlx_io_gguf_new() };
    let owner = GgufGuard(handle);

    for (name, array) in weights {
        let cname = CString::new(name.as_str()).map_err(|_| {
            Error::InteriorNul(InteriorNulPayload::new(
                "io::gguf_save: weights insert",
                "gguf weight key",
            ))
        })?;
        // SAFETY: `handle` is owned by `owner`; `cname` outlives the call.
        check(unsafe { mlxrs_sys::mlx_io_gguf_set_array(handle, cname.as_ptr(), array.0) })?;
    }

    for (name, value) in metadata {
        let cname = CString::new(name.as_str()).map_err(|_| {
            Error::InteriorNul(InteriorNulPayload::new(
                "io::gguf_save: metadata insert",
                "gguf metadata key",
            ))
        })?;
        match value {
            GgufMetadata::Array(array) => {
                check(unsafe {
                    mlxrs_sys::mlx_io_gguf_set_metadata_array(handle, cname.as_ptr(), array.0)
                })?;
            }
            GgufMetadata::String(text) => {
                let ctext = CString::new(text.as_str()).map_err(|_| {
                    Error::InteriorNul(InteriorNulPayload::new(
                        "io::gguf_save: metadata string insert",
                        "gguf metadata string value",
                    ))
                })?;
                check(unsafe {
                    mlxrs_sys::mlx_io_gguf_set_metadata_string(handle, cname.as_ptr(), ctext.as_ptr())
                })?;
            }
            GgufMetadata::StringList(items) => {
                let strings = unsafe { mlxrs_sys::mlx_vector_string_new() };
                let strings_owner = VectorStringGuard(strings);
                for item in items {
                    let citem = CString::new(item.as_str()).map_err(|_| {
                        Error::InteriorNul(InteriorNulPayload::new(
                            "io::gguf_save: metadata list-entry append",
                            "gguf metadata list entry",
                        ))
                    })?;
                    check(unsafe {
                        mlxrs_sys::mlx_vector_string_append_value(strings, citem.as_ptr())
                    })?;
                }
                check(unsafe {
                    mlxrs_sys::mlx_io_gguf_set_metadata_vector_string(handle, cname.as_ptr(), strings)
                })?;
                drop(strings_owner);
            }
        }
    }

    // SAFETY: `handle` is owned by `owner`; `cpath` outlives the call.
    check(unsafe { mlxrs_sys::mlx_save_gguf(cpath.as_ptr(), handle) })?;
    drop(owner);
    Ok(())
}
// Unit tests for this module; `mod tests;` means they live in a separate source file.
#[cfg(test)]
mod tests;