use std::ffi::CString;
use crate::errors::{check_status, check_ptr};
use crate::{TensorMap, Error};
use super::{realloc_vec, create_ndarray};
/// Load a serialized [`TensorMap`] from the file at `path`.
///
/// Data arrays are created through the `create_ndarray` callback handed to
/// the C API.
///
/// # Errors
///
/// Returns the error reported through `check_ptr` when the C library fails
/// to load the tensor map.
///
/// # Panics
///
/// Panics if `path` is not valid UTF-8, or if it contains an interior NUL
/// byte (it must be converted to a C string).
pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
    let path = path.as_ref().to_str().expect("this path is not valid UTF8");
    let path = CString::new(path).expect("this path contains a NULL byte");

    // SAFETY: `path` is a NUL-terminated C string that lives until the end of
    // this function, and `create_ndarray` has the callback type the C API
    // expects (checked by the compiler).
    let ptr = unsafe {
        crate::c_api::mts_tensormap_load(path.as_ptr(), Some(create_ndarray))
    };
    check_ptr(ptr)?;

    // SAFETY: `check_ptr` accepted `ptr` (assumed: it rejects the null
    // failure value), so it is the handle returned by `mts_tensormap_load`
    // and the `TensorMap` wrapper takes ownership of it.
    Ok(unsafe { TensorMap::from_raw(ptr) })
}
/// Load a serialized [`TensorMap`] from the in-memory bytes in `buffer`.
///
/// Data arrays are created through the `create_ndarray` callback handed to
/// the C API.
///
/// # Errors
///
/// Returns the error reported through `check_ptr` when the C library fails
/// to load the tensor map from `buffer`.
pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
    // SAFETY: `buffer.as_ptr()` / `buffer.len()` describe a readable slice
    // that stays borrowed, and therefore valid, for the whole call.
    let ptr = unsafe {
        crate::c_api::mts_tensormap_load_buffer(
            buffer.as_ptr(),
            buffer.len(),
            Some(create_ndarray),
        )
    };
    check_ptr(ptr)?;

    // SAFETY: `check_ptr` accepted `ptr` (assumed: it rejects the null
    // failure value), so it is the handle returned by
    // `mts_tensormap_load_buffer` and the `TensorMap` wrapper takes ownership.
    Ok(unsafe { TensorMap::from_raw(ptr) })
}
/// Serialize `tensor` and write it to the file at `path`.
///
/// # Errors
///
/// Returns the error converted by `check_status` from the status code of
/// `mts_tensormap_save`.
///
/// # Panics
///
/// Panics if `path` is not valid UTF-8, or if it contains an interior NUL
/// byte (it must be converted to a C string).
pub fn save(path: impl AsRef<std::path::Path>, tensor: &TensorMap) -> Result<(), Error> {
    let utf8_path = path.as_ref().to_str().expect("this path is not valid UTF8");
    let c_path = CString::new(utf8_path).expect("this path contains a NULL byte");

    // SAFETY: `c_path` is a NUL-terminated C string alive for the whole call,
    // and `tensor.ptr` is the handle owned by a live `TensorMap`.
    let status = unsafe { crate::c_api::mts_tensormap_save(c_path.as_ptr(), tensor.ptr) };

    check_status(status)
}
/// Serialize `tensor` into the in-memory `buffer`.
///
/// The C library receives `buffer`'s current data pointer and length, plus
/// the `realloc_vec` callback with `buffer` itself as user data, so it can
/// grow the vector when needed. On success, `buffer` is resized to the byte
/// count reported back by the C side.
///
/// # Errors
///
/// Returns the error converted by `check_status` from the status code of
/// `mts_tensormap_save_buffer`. In that case the final resize is skipped,
/// and `buffer` may already have been modified through `realloc_vec`.
pub fn save_buffer(tensor: &TensorMap, buffer: &mut Vec<u8>) -> Result<(), Error> {
    let mut data = buffer.as_mut_ptr();
    let mut written = buffer.len();
    // Taken last: no other use of `buffer` happens between here and the call.
    let vec_handle = buffer as *mut Vec<u8>;

    // SAFETY: `data`/`written` describe `buffer`'s current allocation, and
    // `vec_handle` points to the live `Vec<u8>` that `realloc_vec` expects as
    // its user data; `tensor.ptr` is the handle owned by a live `TensorMap`.
    let status = unsafe {
        crate::c_api::mts_tensormap_save_buffer(
            &mut data,
            &mut written,
            vec_handle.cast(),
            Some(realloc_vec),
            tensor.ptr,
        )
    };
    check_status(status)?;

    // Trim (or extend) the vector to the number of bytes the C side reported.
    buffer.resize(written, 0);
    Ok(())
}