use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
/// Intermediate representation of an array that is handed to the MessagePack
/// encoder/decoder in this file.
///
/// NOTE(review): the field order below is presumably part of the encoded
/// format (the encoder may write struct fields positionally) — confirm before
/// reordering fields.
#[derive(Serialize, Deserialize)]
struct MessagePackArray<T> {
/// Dimensions of the array, taken from `Array::shape` when saving and passed
/// to `reshape` when loading.
shape: Vec<usize>,
/// Array elements, taken from `Array::to_vec` when saving and passed to
/// `Array::from_vec` when loading.
data: Vec<T>,
/// Element type name produced by `std::any::type_name::<T>()` when saving.
/// The loaders in this file do not read this field back.
dtype: String,
}
pub fn to_messagepack<T, P>(array: &Array<T>, path: P) -> Result<()>
where
T: Clone + Serialize,
P: AsRef<Path>,
{
let file = File::create(path.as_ref())
.map_err(|e| NumRs2Error::IOError(format!("Failed to create file: {}", e)))?;
let mut writer = BufWriter::new(file);
to_messagepack_writer(array, &mut writer)
}
/// Encodes `array` as MessagePack and writes the encoded bytes to `writer`.
///
/// The payload carries the array's shape, its elements as returned by
/// `Array::to_vec`, and the element type name from `std::any::type_name`.
///
/// # Errors
///
/// Returns [`NumRs2Error::SerializationError`] if the encoder reports a
/// failure.
pub fn to_messagepack_writer<T, W>(array: &Array<T>, writer: &mut W) -> Result<()>
where
    T: Clone + Serialize,
    W: Write,
{
    // Gather the three payload fields in the same order the struct declares them.
    let shape = array.shape();
    let data = array.to_vec();
    let dtype = std::any::type_name::<T>().to_string();
    let payload = MessagePackArray { shape, data, dtype };

    match rmp_serde::encode::write(writer, &payload) {
        Ok(()) => Ok(()),
        Err(err) => Err(NumRs2Error::SerializationError(format!(
            "MessagePack serialization error: {}",
            err
        ))),
    }
}
/// Encodes `array` as MessagePack and returns the encoded bytes.
///
/// The payload has the same layout as the one written by the file and writer
/// variants: shape, elements, and element type name.
///
/// # Errors
///
/// Returns [`NumRs2Error::SerializationError`] if the encoder reports a
/// failure.
pub fn to_messagepack_bytes<T>(array: &Array<T>) -> Result<Vec<u8>>
where
    T: Clone + Serialize,
{
    let shape = array.shape();
    let data = array.to_vec();
    let dtype = std::any::type_name::<T>().to_string();
    let payload = MessagePackArray { shape, data, dtype };

    match rmp_serde::to_vec(&payload) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(NumRs2Error::SerializationError(format!(
            "MessagePack serialization error: {}",
            err
        ))),
    }
}
/// Loads an array from the MessagePack file at `path`.
///
/// The file is read through a [`BufReader`] and decoded by
/// [`from_messagepack_reader`].
///
/// # Errors
///
/// Returns [`NumRs2Error::IOError`] if the file cannot be opened, and
/// propagates any error returned by [`from_messagepack_reader`].
pub fn from_messagepack<T, P>(path: P) -> Result<Array<T>>
where
    T: Clone + for<'de> Deserialize<'de>,
    P: AsRef<Path>,
{
    let handle = match File::open(path.as_ref()) {
        Ok(handle) => handle,
        Err(err) => {
            return Err(NumRs2Error::IOError(format!(
                "Failed to open file: {}",
                err
            )));
        }
    };

    let mut buffered = BufReader::new(handle);
    from_messagepack_reader(&mut buffered)
}
pub fn from_messagepack_reader<T, R>(reader: &mut R) -> Result<Array<T>>
where
T: Clone + for<'de> Deserialize<'de>,
R: Read,
{
let msgpack_array: MessagePackArray<T> = rmp_serde::from_read(reader).map_err(|e| {
NumRs2Error::DeserializationError(format!("MessagePack deserialization error: {}", e))
})?;
let array = Array::from_vec(msgpack_array.data).reshape(&msgpack_array.shape);
Ok(array)
}
pub fn from_messagepack_bytes<T>(bytes: &[u8]) -> Result<Array<T>>
where
T: Clone + for<'de> Deserialize<'de>,
{
let msgpack_array: MessagePackArray<T> = rmp_serde::from_slice(bytes).map_err(|e| {
NumRs2Error::DeserializationError(format!("MessagePack deserialization error: {}", e))
})?;
let array = Array::from_vec(msgpack_array.data).reshape(&msgpack_array.shape);
Ok(array)
}
#[cfg(test)]
mod tests {
    // Round-trip checks: every test serializes an array, loads it back, and
    // compares both the shape and the flattened elements.
    use super::*;
    use tempfile::NamedTempFile;

    /// A 2x3 integer array survives a save/load cycle through a file.
    #[test]
    fn test_messagepack_roundtrip_i32() {
        let original = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let scratch = NamedTempFile::new().expect("Failed to create temp file");

        to_messagepack(&original, scratch.path()).expect("Failed to write MessagePack");
        let restored: Array<i32> =
            from_messagepack(scratch.path()).expect("Failed to read MessagePack");

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }

    /// A 2x2 floating-point array survives a save/load cycle through a file.
    #[test]
    fn test_messagepack_roundtrip_f64() {
        let original = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let scratch = NamedTempFile::new().expect("Failed to create temp file");

        to_messagepack(&original, scratch.path()).expect("Failed to write MessagePack");
        let restored: Array<f64> =
            from_messagepack(scratch.path()).expect("Failed to read MessagePack");

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }

    /// The in-memory byte API round-trips a 2x2 floating-point array.
    #[test]
    fn test_messagepack_bytes() {
        let original = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);

        let encoded = to_messagepack_bytes(&original).expect("Failed to serialize to bytes");
        let restored: Array<f64> =
            from_messagepack_bytes(&encoded).expect("Failed to deserialize from bytes");

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }

    /// An array left in the shape produced by `from_vec` round-trips unchanged.
    #[test]
    fn test_messagepack_1d_array() {
        let original = Array::from_vec(vec![10, 20, 30, 40, 50]);

        let encoded = to_messagepack_bytes(&original).expect("Failed to serialize");
        let restored: Array<i32> =
            from_messagepack_bytes(&encoded).expect("Failed to deserialize");

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }

    /// A 2x3x4 array round-trips with its three-dimensional shape intact.
    #[test]
    fn test_messagepack_3d_array() {
        let original = Array::from_vec(vec![1.0; 24]).reshape(&[2, 3, 4]);

        let encoded = to_messagepack_bytes(&original).expect("Failed to serialize");
        let restored: Array<f64> =
            from_messagepack_bytes(&encoded).expect("Failed to deserialize");

        assert_eq!(original.shape(), restored.shape());
        assert_eq!(original.to_vec(), restored.to_vec());
    }
}