use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use std::path::Path;
/// Saves `array` to `path`, storing it under the variable name `var_name`.
///
/// This is a thin generic entry point: the on-disk representation is decided
/// entirely by the element type's [`MatWritable`] implementation.
///
/// # Errors
/// Returns whatever error [`MatWritable::write_to_mat`] reports for `T`.
pub fn write_mat<T, P>(array: &Array<T>, path: P, var_name: &str) -> Result<()>
where
    T: Clone + MatWritable,
    P: AsRef<Path>,
{
    let target: &Path = path.as_ref();
    <T as MatWritable>::write_to_mat(array, target, var_name)
}
/// Loads the array stored under `var_name` from `path`.
///
/// The element type `T` selects which [`MatReadable`] implementation parses
/// the file, so callers usually pin it with a type annotation or turbofish.
///
/// # Errors
/// Returns whatever error [`MatReadable::read_from_mat`] reports for `T`.
pub fn read_mat<T, P>(path: P, var_name: &str) -> Result<Array<T>>
where
    T: Clone + MatReadable,
    P: AsRef<Path>,
{
    let source: &Path = path.as_ref();
    <T as MatReadable>::read_from_mat(source, var_name)
}
/// Element types that can persist an [`Array`] of themselves to a MAT-style file.
///
/// In this module it is implemented for the primitive numeric types through the
/// `impl_mat_io!` macro.
pub trait MatWritable: Clone {
    /// Writes `array` to `path`, recording it under the variable name `var_name`.
    fn write_to_mat(array: &Array<Self>, path: &Path, var_name: &str) -> Result<()>;
}
/// Element types that can load an [`Array`] of themselves from a MAT-style file.
///
/// In this module it is implemented for the primitive numeric types through the
/// `impl_mat_io!` macro.
pub trait MatReadable: Clone {
    /// Reads the array stored at `path` under the variable name `var_name`.
    fn read_from_mat(path: &Path, var_name: &str) -> Result<Array<Self>>;
}
/// Returns the element count implied by `shape` and the byte length of that many
/// elements of `elem_size` bytes each, or `None` if either product overflows `usize`.
///
/// The shape comes from an on-disk metadata file, so it must be treated as untrusted:
/// unchecked multiplication could wrap and make a short payload look correctly sized.
fn checked_extent(shape: &[usize], elem_size: usize) -> Option<(usize, usize)> {
    let count = shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))?;
    let bytes = count.checked_mul(elem_size)?;
    Some((count, bytes))
}

/// Implements [`MatWritable`] and [`MatReadable`] for a primitive numeric type.
///
/// `$type` is the element type. `$type_name` is the dtype string recorded in the
/// metadata sidecar, and it is checked against that sidecar on read.
///
/// Files produced by the generated impls:
/// - `path` holds the element bytes in the order returned by `Array::to_vec`, each
///   encoded in native byte order (so files are not portable across endianness).
/// - `path.with_extension("mat.meta")` holds a JSON object with the keys
///   `variable_name`, `shape`, `dtype` and `format`.
macro_rules! impl_mat_io {
    ($type:ty, $type_name:expr) => {
        impl MatWritable for $type {
            fn write_to_mat(array: &Array<Self>, path: &Path, var_name: &str) -> Result<()> {
                let shape = array.shape();
                let data = array.to_vec();
                let metadata = serde_json::json!({
                    "variable_name": var_name,
                    "shape": shape,
                    "dtype": $type_name,
                    "format": "NumRS2-MAT",
                });
                let meta_path = path.with_extension("mat.meta");
                std::fs::write(&meta_path, metadata.to_string())
                    .map_err(|e| NumRs2Error::IOError(format!("Failed to write metadata: {}", e)))?;
                // Encode each element with `to_ne_bytes`, which is the byte layout the
                // reader decodes. This avoids reinterpreting the buffer through raw pointers.
                let mut data_bytes =
                    Vec::with_capacity(data.len() * std::mem::size_of::<$type>());
                for value in &data {
                    data_bytes.extend_from_slice(&value.to_ne_bytes());
                }
                std::fs::write(path, &data_bytes)
                    .map_err(|e| NumRs2Error::IOError(format!("Failed to write data: {}", e)))?;
                Ok(())
            }
        }
        impl MatReadable for $type {
            fn read_from_mat(path: &Path, var_name: &str) -> Result<Array<Self>> {
                let meta_path = path.with_extension("mat.meta");
                let metadata_str = std::fs::read_to_string(&meta_path)
                    .map_err(|e| NumRs2Error::IOError(format!("Failed to read metadata: {}", e)))?;
                let metadata: serde_json::Value = serde_json::from_str(&metadata_str)
                    .map_err(|e| NumRs2Error::DeserializationError(format!("Invalid metadata: {}", e)))?;
                let stored_var = metadata["variable_name"]
                    .as_str()
                    .ok_or_else(|| NumRs2Error::DeserializationError("Missing variable name".to_string()))?;
                if stored_var != var_name {
                    return Err(NumRs2Error::DeserializationError(format!(
                        "Variable name mismatch: expected {}, found {}",
                        var_name, stored_var
                    )));
                }
                // The writer always records a dtype. Reject a mismatch so that, for
                // example, f32 data is not silently reinterpreted as i32. Metadata
                // without a dtype key is still accepted.
                if let Some(stored_dtype) = metadata["dtype"].as_str() {
                    if stored_dtype != $type_name {
                        return Err(NumRs2Error::DeserializationError(format!(
                            "Data type mismatch: expected {}, found {}",
                            $type_name, stored_dtype
                        )));
                    }
                }
                let shape: Vec<usize> = metadata["shape"]
                    .as_array()
                    .ok_or_else(|| NumRs2Error::DeserializationError("Missing shape".to_string()))?
                    .iter()
                    .map(|v| {
                        v.as_u64()
                            // Reject dimensions that do not fit in `usize` (e.g. on 32-bit
                            // targets) instead of silently truncating them with `as`.
                            .and_then(|x| <usize as std::convert::TryFrom<u64>>::try_from(x).ok())
                            .ok_or_else(|| NumRs2Error::DeserializationError("Invalid shape value".to_string()))
                    })
                    .collect::<Result<Vec<_>>>()?;
                let data_bytes = std::fs::read(path)
                    .map_err(|e| NumRs2Error::IOError(format!("Failed to read data: {}", e)))?;
                let elem_size = std::mem::size_of::<$type>();
                let (num_elements, expected_len) =
                    checked_extent(&shape, elem_size).ok_or_else(|| {
                        NumRs2Error::DeserializationError(format!(
                            "Shape {:?} is too large",
                            shape
                        ))
                    })?;
                if data_bytes.len() != expected_len {
                    return Err(NumRs2Error::DeserializationError(format!(
                        "Data size mismatch: expected {} bytes, got {}",
                        expected_len,
                        data_bytes.len()
                    )));
                }
                // Decode element by element. A `Vec<u8>` buffer has no alignment
                // guarantee for `$type`, so casting its pointer to `*const $type` would be UB.
                let mut data: Vec<$type> = Vec::with_capacity(num_elements);
                data.extend(data_bytes.chunks_exact(elem_size).map(|chunk| {
                    let mut raw = [0u8; std::mem::size_of::<$type>()];
                    raw.copy_from_slice(chunk);
                    <$type>::from_ne_bytes(raw)
                }));
                Ok(Array::from_vec(data).reshape(&shape))
            }
        }
    };
}
// Floating-point element types.
impl_mat_io!(f32, "f32");
impl_mat_io!(f64, "f64");
// Signed integer element types.
impl_mat_io!(i64, "i64");
impl_mat_io!(i32, "i32");
// Unsigned integer element types.
impl_mat_io!(u64, "u64");
impl_mat_io!(u32, "u32");
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `array` as `file_name` inside a fresh temporary directory, reads it
    /// back under the same `var_name`, and returns the loaded array.
    fn roundtrip<T>(array: &Array<T>, file_name: &str, var_name: &str) -> Array<T>
    where
        T: MatWritable + MatReadable,
    {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let path = temp_dir.path().join(file_name);
        write_mat(array, &path, var_name).expect("Failed to write .mat file");
        read_mat(&path, var_name).expect("Failed to read .mat file")
    }

    #[test]
    fn test_mat_roundtrip_f64() {
        let array = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let loaded: Array<f64> = roundtrip(&array, "test.mat", "test_var");
        assert_eq!(array.shape(), loaded.shape());
        assert_eq!(array.to_vec(), loaded.to_vec());
    }

    #[test]
    fn test_mat_roundtrip_i32() {
        let array = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let loaded: Array<i32> = roundtrip(&array, "test_int.mat", "integers");
        assert_eq!(array.shape(), loaded.shape());
        assert_eq!(array.to_vec(), loaded.to_vec());
    }

    #[test]
    fn test_mat_roundtrip_f32_and_u64() {
        let floats = Array::from_vec(vec![0.5f32, -1.25, 3.0]);
        let loaded_floats: Array<f32> = roundtrip(&floats, "test_f32.mat", "floats");
        assert_eq!(floats.shape(), loaded_floats.shape());
        assert_eq!(floats.to_vec(), loaded_floats.to_vec());

        // u64::MAX exercises every byte of the element encoding.
        let wide = Array::from_vec(vec![0u64, 1, u64::MAX]);
        let loaded_wide: Array<u64> = roundtrip(&wide, "test_u64.mat", "wide");
        assert_eq!(wide.shape(), loaded_wide.shape());
        assert_eq!(wide.to_vec(), loaded_wide.to_vec());
    }

    #[test]
    fn test_mat_multidimensional() {
        let array = Array::from_vec(vec![1.0; 24]).reshape(&[2, 3, 4]);
        let loaded: Array<f64> = roundtrip(&array, "test_3d.mat", "data3d");
        assert_eq!(array.shape(), loaded.shape());
        assert_eq!(array.to_vec(), loaded.to_vec());
    }

    #[test]
    fn test_mat_1d_array() {
        let array = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let loaded: Array<f64> = roundtrip(&array, "test_1d.mat", "vector");
        assert_eq!(array.shape(), loaded.shape());
        assert_eq!(array.to_vec(), loaded.to_vec());
    }

    #[test]
    fn test_mat_variable_name_mismatch_is_rejected() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let path = temp_dir.path().join("named.mat");
        let array = Array::from_vec(vec![1.0f64, 2.0]);
        write_mat(&array, &path, "stored_name").expect("Failed to write .mat file");
        let result = read_mat::<f64, _>(&path, "other_name");
        assert!(matches!(result, Err(NumRs2Error::DeserializationError(_))));
    }

    #[test]
    fn test_mat_truncated_data_is_rejected() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let path = temp_dir.path().join("truncated.mat");
        let array = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        write_mat(&array, &path, "v").expect("Failed to write .mat file");
        // Overwrite the payload with only three elements' worth of bytes.
        std::fs::write(&path, vec![0u8; 3 * std::mem::size_of::<f64>()])
            .expect("Failed to truncate data file");
        let result = read_mat::<f64, _>(&path, "v");
        assert!(matches!(result, Err(NumRs2Error::DeserializationError(_))));
    }

    #[test]
    fn test_mat_missing_file_is_io_error() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let path = temp_dir.path().join("missing.mat");
        let result = read_mat::<f64, _>(&path, "anything");
        assert!(matches!(result, Err(NumRs2Error::IOError(_))));
    }
}