use ::serde::{Deserialize, Serialize};
use scirs2_core::ndarray::{Array, Array2, ArrayBase, IxDyn};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;
use crate::error::{IoError, Result};
use oxicode::{config, serde as oxicode_serde};
/// Returns the oxicode configuration used by every `Binary` encode/decode in
/// this module, so that writers and readers always agree on the wire format.
fn oxicode_cfg() -> impl oxicode::config::Config {
    oxicode::config::standard()
}
/// On-disk encoding selected by the (de)serialization helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
/// Compact binary encoding via `oxicode` (the configuration from `oxicode_cfg`).
Binary,
/// JSON via `serde_json`; the writers in this module emit it pretty-printed.
JSON,
/// MessagePack via `rmp_serde`.
MessagePack,
}
#[allow(dead_code)]
pub fn serialize_array<P, A, S>(
path: P,
array: &ArrayBase<S, IxDyn>,
format: SerializationFormat,
) -> Result<()>
where
P: AsRef<Path>,
A: Serialize + Clone,
S: scirs2_core::ndarray::Data<Elem = A>,
{
let shape = array.shape().to_vec();
let data: Vec<A> = array.iter().cloned().collect();
let serializable = SerializedArray {
metadata: ArrayMetadata {
shape,
dtype: std::any::type_name::<A>().to_string(),
order: 'C',
metadata: std::collections::HashMap::new(),
},
data,
};
let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut writer = BufWriter::new(file);
match format {
SerializationFormat::Binary => {
let cfg = oxicode_cfg();
let bytes = oxicode_serde::encode_to_vec(&serializable, cfg)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
let len = bytes.len() as u64;
writer
.write_all(&len.to_le_bytes())
.map_err(|e| IoError::FileError(e.to_string()))?;
writer
.write_all(&bytes)
.map_err(|e| IoError::FileError(e.to_string()))?;
}
SerializationFormat::JSON => {
serde_json::to_writer_pretty(&mut writer, &serializable)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
}
SerializationFormat::MessagePack => {
rmp_serde::encode::write(&mut writer, &serializable)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
}
}
writer
.flush()
.map_err(|e| IoError::FileError(e.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn deserialize_array<P, A>(path: P, format: SerializationFormat) -> Result<Array<A, IxDyn>>
where
P: AsRef<Path>,
A: for<'de> Deserialize<'de> + Clone,
{
let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut reader = BufReader::new(file);
let serialized: SerializedArray<A> = match format {
SerializationFormat::Binary => {
use std::io::Read;
let mut buf = Vec::new();
reader
.read_to_end(&mut buf)
.map_err(|e| IoError::FileError(e.to_string()))?;
if buf.len() >= 8 {
let mut len_bytes = [0u8; 8];
len_bytes.copy_from_slice(&buf[0..8]);
let declared = u64::from_le_bytes(len_bytes) as usize;
if declared <= buf.len() - 8 {
let data_slice = &buf[8..8 + declared];
let cfg = oxicode_cfg();
if let Ok((val, _consumed)) = oxicode_serde::decode_owned_from_slice::<
SerializedArray<A>,
_,
>(data_slice, cfg)
{
val
} else {
let cfg = oxicode_cfg();
let (val, _len): (SerializedArray<A>, usize) =
oxicode_serde::decode_owned_from_slice(&buf, cfg)
.map_err(|e| IoError::DeserializationError(e.to_string()))?;
val
}
} else {
let cfg = oxicode_cfg();
let (val, _len): (SerializedArray<A>, usize) =
oxicode_serde::decode_owned_from_slice(&buf, cfg)
.map_err(|e| IoError::DeserializationError(e.to_string()))?;
val
}
} else {
let cfg = oxicode_cfg();
let (val, _len): (SerializedArray<A>, usize) =
oxicode_serde::decode_owned_from_slice(&buf, cfg)
.map_err(|e| IoError::DeserializationError(e.to_string()))?;
val
}
}
SerializationFormat::JSON => serde_json::from_reader(reader)
.map_err(|e| IoError::DeserializationError(e.to_string()))?,
SerializationFormat::MessagePack => rmp_serde::from_read(reader)
.map_err(|e| IoError::DeserializationError(e.to_string()))?,
};
let array = Array::from_shape_vec(IxDyn(&serialized.metadata.shape), serialized.data)
.map_err(|e| IoError::FormatError(format!("Failed to reconstruct array: {}", e)))?;
Ok(array)
}
/// Descriptive header stored alongside serialized array data.
///
/// NOTE(review): the Binary (oxicode) encoding is presumably field-order
/// dependent — avoid reordering fields without a format migration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArrayMetadata {
/// Length of each axis; readers rebuild the array with `IxDyn(&shape)`.
pub shape: Vec<usize>,
/// Element type name; the writers in this module fill it with
/// `std::any::type_name::<A>()`. It is not checked when reading.
pub dtype: String,
/// Memory-order tag; the writers in this module always store `'C'`.
/// The readers here do not consult it.
pub order: char,
/// Free-form string key/value annotations supplied by the caller.
pub metadata: std::collections::HashMap<String, String>,
}
/// An array flattened for serialization: header plus element list.
///
/// `data` holds the elements in the array's logical iteration order (as
/// produced by `array.iter()` in the writers here), so
/// `Array::from_shape_vec(IxDyn(&metadata.shape), data)` restores it.
/// NOTE(review): field order presumably matters for the Binary encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedArray<A> {
/// Shape, dtype name, order tag and user annotations.
pub metadata: ArrayMetadata,
/// Flattened element values.
pub data: Vec<A>,
}
#[allow(dead_code)]
pub fn serialize_array_with_metadata<P, A, S>(
path: P,
array: &ArrayBase<S, IxDyn>,
metadata: std::collections::HashMap<String, String>,
format: SerializationFormat,
) -> Result<()>
where
P: AsRef<Path>,
A: Serialize + Clone,
S: scirs2_core::ndarray::Data<Elem = A>,
{
let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut writer = BufWriter::new(file);
let shape = array.shape().to_vec();
let dtype = std::any::type_name::<A>().to_string();
let array_metadata = ArrayMetadata {
shape,
dtype,
order: 'C', metadata,
};
let serialized = SerializedArray {
metadata: array_metadata,
data: array.iter().cloned().collect(),
};
match format {
SerializationFormat::Binary => {
let cfg = oxicode_cfg();
let bytes = oxicode_serde::encode_to_vec(&serialized, cfg)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
writer
.write_all(&bytes)
.map_err(|e| IoError::FileError(e.to_string()))?;
}
SerializationFormat::JSON => {
serde_json::to_writer_pretty(&mut writer, &serialized)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
}
SerializationFormat::MessagePack => {
rmp_serde::encode::write(&mut writer, &serialized)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
}
}
writer
.flush()
.map_err(|e| IoError::FileError(e.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn deserialize_array_with_metadata<P, A>(
path: P,
format: SerializationFormat,
) -> Result<(Array<A, IxDyn>, std::collections::HashMap<String, String>)>
where
P: AsRef<Path>,
A: for<'de> Deserialize<'de> + Clone,
{
let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut reader = BufReader::new(file);
let serialized: SerializedArray<A> = match format {
SerializationFormat::Binary => {
let cfg = oxicode_cfg();
let (val, _len): (SerializedArray<A>, usize) =
oxicode_serde::decode_from_std_read(&mut reader, cfg)
.map_err(|e| IoError::DeserializationError(e.to_string()))?;
val
}
SerializationFormat::JSON => serde_json::from_reader(reader)
.map_err(|e| IoError::DeserializationError(e.to_string()))?,
SerializationFormat::MessagePack => rmp_serde::from_read(reader)
.map_err(|e| IoError::DeserializationError(e.to_string()))?,
};
let shape = serialized.metadata.shape;
let data = serialized.data;
let array = Array::from_shape_vec(IxDyn(&shape), data)
.map_err(|e| IoError::FormatError(format!("Invalid shape: {:?}", e)))?;
Ok((array, serialized.metadata.metadata))
}
#[allow(dead_code)]
pub fn serialize_struct<P, T>(path: P, data: &T, format: SerializationFormat) -> Result<()>
where
P: AsRef<Path>,
T: Serialize,
{
let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut writer = BufWriter::new(file);
match format {
SerializationFormat::Binary => {
let cfg = oxicode_cfg();
let bytes = oxicode_serde::encode_to_vec(data, cfg)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
writer
.write_all(&bytes)
.map_err(|e| IoError::FileError(e.to_string()))?;
}
SerializationFormat::JSON => {
serde_json::to_writer_pretty(&mut writer, data)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
}
SerializationFormat::MessagePack => {
rmp_serde::encode::write(&mut writer, data)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
}
}
writer
.flush()
.map_err(|e| IoError::FileError(e.to_string()))?;
Ok(())
}
/// Reads a value written by [`serialize_struct`] from `path`.
///
/// The Binary format is decoded straight from the buffered file stream.
///
/// # Errors
/// - [`IoError::FileError`] if the file cannot be opened.
/// - [`IoError::DeserializationError`] if decoding fails.
#[allow(dead_code)]
pub fn deserialize_struct<P, T>(path: P, format: SerializationFormat) -> Result<T>
where
    P: AsRef<Path>,
    T: for<'de> Deserialize<'de>,
{
    let mut reader =
        BufReader::new(File::open(path).map_err(|e| IoError::FileError(e.to_string()))?);

    match format {
        SerializationFormat::Binary => {
            oxicode_serde::decode_from_std_read(&mut reader, oxicode_cfg())
                .map(|(value, _bytes_read): (T, usize)| value)
                .map_err(|e| IoError::DeserializationError(e.to_string()))
        }
        SerializationFormat::JSON => serde_json::from_reader(reader)
            .map_err(|e| IoError::DeserializationError(e.to_string())),
        SerializationFormat::MessagePack => rmp_serde::from_read(reader)
            .map_err(|e| IoError::DeserializationError(e.to_string())),
    }
}
/// Sparse matrix in coordinate (COO / triplet) form.
///
/// Entry `k` is `values[k]` stored at `(row_indices[k], col_indices[k])`; the
/// three vectors are parallel. Entries keep insertion order and duplicate
/// coordinates are not merged.
/// NOTE(review): the Binary (oxicode) encoding is presumably field-order
/// dependent — avoid reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseMatrixCOO<A> {
/// Number of rows (exclusive upper bound for `row_indices`).
pub rows: usize,
/// Number of columns (exclusive upper bound for `col_indices`).
pub cols: usize,
/// Row coordinate of each stored entry.
pub row_indices: Vec<usize>,
/// Column coordinate of each stored entry.
pub col_indices: Vec<usize>,
/// Value of each stored entry.
pub values: Vec<A>,
/// Free-form string annotations.
pub metadata: std::collections::HashMap<String, String>,
}
impl<A> SparseMatrixCOO<A> {
    /// Creates an empty `rows x cols` matrix with no entries and no metadata.
    pub fn new(rows: usize, cols: usize) -> Self {
        let metadata = std::collections::HashMap::new();
        Self {
            rows,
            cols,
            row_indices: vec![],
            col_indices: vec![],
            values: vec![],
            metadata,
        }
    }

    /// Appends the entry `(row, col) = value`.
    ///
    /// Coordinates outside `rows x cols` are silently ignored: nothing is
    /// stored and no error is reported. Duplicate coordinates are kept as
    /// separate entries.
    pub fn push(&mut self, row: usize, col: usize, value: A) {
        let out_of_bounds = row >= self.rows || col >= self.cols;
        if out_of_bounds {
            return;
        }
        self.row_indices.push(row);
        self.col_indices.push(col);
        self.values.push(value);
    }

    /// Number of stored entries (including duplicates and explicit zeros).
    pub fn nnz(&self) -> usize {
        self.values.len()
    }
}
/// Writes a COO sparse matrix to `path` via [`serialize_struct`].
#[allow(dead_code)]
pub fn serialize_sparse_matrix<P, A>(
    path: P,
    matrix: &SparseMatrixCOO<A>,
    format: SerializationFormat,
) -> Result<()>
where
    P: AsRef<Path>,
    A: Serialize,
{
    serialize_struct::<P, SparseMatrixCOO<A>>(path, matrix, format)
}

/// Reads a COO sparse matrix written by [`serialize_sparse_matrix`].
///
/// The decoded entries are not re-validated against `rows`/`cols`.
#[allow(dead_code)]
pub fn deserialize_sparse_matrix<P, A>(
    path: P,
    format: SerializationFormat,
) -> Result<SparseMatrixCOO<A>>
where
    P: AsRef<Path>,
    A: for<'de> Deserialize<'de>,
{
    deserialize_struct::<P, SparseMatrixCOO<A>>(path, format)
}
/// Storage-layout tag for [`SparseMatrix`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SparseFormat {
/// Coordinate (triplet) storage.
COO,
/// Compressed sparse row storage.
CSR,
/// Compressed sparse column storage.
CSC,
}
/// Sparse matrix backed by COO data with lazily built CSR/CSC caches.
///
/// `coo_data` is the source of truth. The CSR/CSC caches are
/// `#[serde(skip)]`, so they are never persisted and come back as `None`
/// after deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseMatrix<A> {
/// `(rows, cols)`; the constructors copy it from `coo_data.rows`/`cols`.
pub shape: (usize, usize),
/// Layout tag; the constructors in this file always set `COO`, and the
/// CSR/CSC conversion methods do not change it.
pub format: SparseFormat,
/// Entries in coordinate form.
pub coo_data: SparseMatrixCOO<A>,
/// Cached CSR conversion (not serialized).
#[serde(skip)]
pub csr_data: Option<SparseMatrixCSR<A>>,
/// Cached CSC conversion (not serialized).
#[serde(skip)]
pub csc_data: Option<SparseMatrixCSC<A>>,
/// Free-form string annotations, copied into CSR/CSC conversions.
pub metadata: HashMap<String, String>,
}
/// Sparse matrix in compressed sparse row form.
///
/// Row `i` occupies positions `row_ptrs[i]..row_ptrs[i + 1]` of
/// `col_indices`/`values`, so a well-formed matrix has
/// `row_ptrs.len() == rows + 1`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseMatrixCSR<A> {
/// Number of rows.
pub rows: usize,
/// Number of columns.
pub cols: usize,
/// Start offset of each row, plus a final end offset.
pub row_ptrs: Vec<usize>,
/// Column index of each stored entry.
pub col_indices: Vec<usize>,
/// Value of each stored entry.
pub values: Vec<A>,
/// Free-form string annotations.
pub metadata: HashMap<String, String>,
}
/// Sparse matrix in compressed sparse column form.
///
/// Column `j` occupies positions `col_ptrs[j]..col_ptrs[j + 1]` of
/// `row_indices`/`values`, so a well-formed matrix has
/// `col_ptrs.len() == cols + 1`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseMatrixCSC<A> {
/// Number of rows.
pub rows: usize,
/// Number of columns.
pub cols: usize,
/// Start offset of each column, plus a final end offset.
pub col_ptrs: Vec<usize>,
/// Row index of each stored entry.
pub row_indices: Vec<usize>,
/// Value of each stored entry.
pub values: Vec<A>,
/// Free-form string annotations.
pub metadata: HashMap<String, String>,
}
impl<A: Clone> SparseMatrix<A> {
pub fn from_coo(coo: SparseMatrixCOO<A>) -> Self {
let shape = (coo.rows, coo.cols);
Self {
shape,
format: SparseFormat::COO,
coo_data: coo,
csr_data: None,
csc_data: None,
metadata: HashMap::new(),
}
}
pub fn new(rows: usize, cols: usize) -> Self {
Self {
shape: (rows, cols),
format: SparseFormat::COO,
coo_data: SparseMatrixCOO::new(rows, cols),
csr_data: None,
csc_data: None,
metadata: HashMap::new(),
}
}
pub fn insert(&mut self, row: usize, col: usize, value: A) {
self.coo_data.push(row, col, value);
self.csr_data = None;
self.csc_data = None;
}
pub fn nnz(&self) -> usize {
self.coo_data.nnz()
}
pub fn shape(&self) -> (usize, usize) {
self.shape
}
pub fn to_csr(&mut self) -> Result<&SparseMatrixCSR<A>>
where
A: Clone + Default + PartialEq,
{
if self.csr_data.is_none() {
self.csr_data = Some(self.convert_to_csr()?);
}
Ok(self.csr_data.as_ref().expect("Operation failed"))
}
pub fn to_csc(&mut self) -> Result<&SparseMatrixCSC<A>>
where
A: Clone + Default + PartialEq,
{
if self.csc_data.is_none() {
self.csc_data = Some(self.convert_to_csc()?);
}
Ok(self.csc_data.as_ref().expect("Operation failed"))
}
fn convert_to_csr(&self) -> Result<SparseMatrixCSR<A>>
where
A: Clone + Default,
{
let nnz = self.coo_data.nnz();
let rows = self.shape.0;
if nnz == 0 {
return Ok(SparseMatrixCSR {
rows,
cols: self.shape.1,
row_ptrs: vec![0; rows + 1],
col_indices: Vec::new(),
values: Vec::new(),
metadata: self.metadata.clone(),
});
}
let mut triplets: Vec<(usize, usize, A)> = self
.coo_data
.row_indices
.iter()
.zip(self.coo_data.col_indices.iter())
.zip(self.coo_data.values.iter())
.map(|((&r, &c), v)| (r, c, v.clone()))
.collect();
triplets.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
let mut row_ptrs = vec![0; rows + 1];
let mut col_indices = Vec::with_capacity(nnz);
let mut values = Vec::with_capacity(nnz);
let mut current_row = 0;
for (i, (row, col, val)) in triplets.iter().enumerate() {
while current_row < *row {
current_row += 1;
row_ptrs[current_row] = i;
}
col_indices.push(*col);
values.push(val.clone());
}
while current_row < rows {
current_row += 1;
row_ptrs[current_row] = nnz;
}
Ok(SparseMatrixCSR {
rows,
cols: self.shape.1,
row_ptrs,
col_indices,
values,
metadata: self.metadata.clone(),
})
}
fn convert_to_csc(&self) -> Result<SparseMatrixCSC<A>>
where
A: Clone + Default,
{
let nnz = self.coo_data.nnz();
let cols = self.shape.1;
if nnz == 0 {
return Ok(SparseMatrixCSC {
rows: self.shape.0,
cols,
col_ptrs: vec![0; cols + 1],
row_indices: Vec::new(),
values: Vec::new(),
metadata: self.metadata.clone(),
});
}
let mut triplets: Vec<(usize, usize, A)> = self
.coo_data
.row_indices
.iter()
.zip(self.coo_data.col_indices.iter())
.zip(self.coo_data.values.iter())
.map(|((&r, &c), v)| (r, c, v.clone()))
.collect();
triplets.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
let mut col_ptrs = vec![0; cols + 1];
let mut row_indices = Vec::with_capacity(nnz);
let mut values = Vec::with_capacity(nnz);
let mut current_col = 0;
for (i, (row, col, val)) in triplets.iter().enumerate() {
while current_col < *col {
current_col += 1;
col_ptrs[current_col] = i;
}
row_indices.push(*row);
values.push(val.clone());
}
while current_col < cols {
current_col += 1;
col_ptrs[current_col] = nnz;
}
Ok(SparseMatrixCSC {
rows: self.shape.0,
cols,
col_ptrs,
row_indices,
values,
metadata: self.metadata.clone(),
})
}
pub fn to_dense(&self) -> Array2<A>
where
A: Clone + Default,
{
let mut dense = Array2::default(self.shape);
for ((row, col), value) in self
.coo_data
.row_indices
.iter()
.zip(self.coo_data.col_indices.iter())
.zip(self.coo_data.values.iter())
{
dense[[*row, *col]] = value.clone();
}
dense
}
pub fn sparsity(&self) -> f64 {
let total_elements = self.shape.0 * self.shape.1;
if total_elements == 0 {
0.0
} else {
1.0 - (self.nnz() as f64 / total_elements as f64)
}
}
pub fn memory_usage(&self) -> usize {
let coo_size = self.coo_data.values.len()
* (std::mem::size_of::<A>() + 2 * std::mem::size_of::<usize>());
let csr_size = if let Some(ref csr) = self.csr_data {
csr.values.len() * std::mem::size_of::<A>()
+ csr.col_indices.len() * std::mem::size_of::<usize>()
+ csr.row_ptrs.len() * std::mem::size_of::<usize>()
} else {
0
};
let csc_size = if let Some(ref csc) = self.csc_data {
csc.values.len() * std::mem::size_of::<A>()
+ csc.row_indices.len() * std::mem::size_of::<usize>()
+ csc.col_ptrs.len() * std::mem::size_of::<usize>()
} else {
0
};
coo_size + csr_size + csc_size
}
}
impl<A: Clone> SparseMatrixCSR<A> {
    /// Creates an empty `rows x cols` CSR matrix (all `rows + 1` pointers zero).
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            row_ptrs: vec![0; rows + 1],
            col_indices: Vec::new(),
            values: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// `(rows, cols)` of the matrix.
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Returns the column indices and values stored for `row`.
    ///
    /// Returns `None` if `row >= self.rows`, or if the row pointers do not
    /// describe a valid range of the stored arrays (a malformed matrix, e.g.
    /// loaded from disk). Such cases previously panicked.
    pub fn row(&self, row: usize) -> Option<(&[usize], &[A])> {
        if row >= self.rows {
            return None;
        }
        let start = *self.row_ptrs.get(row)?;
        let end = *self.row_ptrs.get(row + 1)?;
        Some((self.col_indices.get(start..end)?, self.values.get(start..end)?))
    }
}
impl<A: Clone> SparseMatrixCSC<A> {
    /// Creates an empty `rows x cols` CSC matrix (all `cols + 1` pointers zero).
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            col_ptrs: vec![0; cols + 1],
            row_indices: Vec::new(),
            values: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// `(rows, cols)` of the matrix.
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Returns the row indices and values stored for `col`.
    ///
    /// Returns `None` if `col >= self.cols`, or if the column pointers do not
    /// describe a valid range of the stored arrays (a malformed matrix, e.g.
    /// loaded from disk). Such cases previously panicked.
    pub fn column(&self, col: usize) -> Option<(&[usize], &[A])> {
        if col >= self.cols {
            return None;
        }
        let start = *self.col_ptrs.get(col)?;
        let end = *self.col_ptrs.get(col + 1)?;
        Some((self.row_indices.get(start..end)?, self.values.get(start..end)?))
    }
}
/// Writes an enhanced [`SparseMatrix`] to `path` via [`serialize_struct`].
///
/// Only the shape, format tag, COO data and metadata are persisted; the
/// cached CSR/CSC fields are `#[serde(skip)]`.
#[allow(dead_code)]
pub fn serialize_enhanced_sparse_matrix<P, A>(
    path: P,
    matrix: &SparseMatrix<A>,
    format: SerializationFormat,
) -> Result<()>
where
    P: AsRef<Path>,
    A: Serialize,
{
    serialize_struct::<P, SparseMatrix<A>>(path, matrix, format)
}

/// Reads an enhanced [`SparseMatrix`] written by
/// [`serialize_enhanced_sparse_matrix`].
///
/// The CSR/CSC caches come back empty and are rebuilt on demand.
#[allow(dead_code)]
pub fn deserialize_enhanced_sparse_matrix<P, A>(
    path: P,
    format: SerializationFormat,
) -> Result<SparseMatrix<A>>
where
    P: AsRef<Path>,
    A: for<'de> Deserialize<'de> + Default,
{
    deserialize_struct::<P, SparseMatrix<A>>(path, format)
}
/// Converts a Matrix Market matrix into an enhanced [`SparseMatrix`].
///
/// Entries are pushed in file order; entries outside `rows x cols` are
/// dropped by [`SparseMatrixCOO::push`]. The header's format, data type and
/// symmetry are recorded in `metadata` as their `Debug` strings, together
/// with `source = "Matrix Market"`.
#[allow(dead_code)]
pub fn from_matrix_market<A>(mm_matrix: &crate::matrix_market::MMSparseMatrix<A>) -> SparseMatrix<A>
where
    A: Clone,
{
    let coo = mm_matrix.entries.iter().fold(
        SparseMatrixCOO::new(mm_matrix.rows, mm_matrix.cols),
        |mut acc, entry| {
            acc.push(entry.row, entry.col, entry.value.clone());
            acc
        },
    );

    let header = &mm_matrix.header;
    let provenance = [
        ("source", "Matrix Market".to_string()),
        ("format", format!("{:?}", header.format)),
        ("data_type", format!("{:?}", header.data_type)),
        ("symmetry", format!("{:?}", header.symmetry)),
    ];

    let mut sparse = SparseMatrix::from_coo(coo);
    for (key, value) in provenance {
        sparse.metadata.insert(key.to_string(), value);
    }
    sparse
}
#[allow(dead_code)]
pub fn to_matrix_market<A>(sparse: &SparseMatrix<A>) -> crate::matrix_market::MMSparseMatrix<A>
where
A: Clone,
{
let header = crate::matrix_market::MMHeader {
object: "matrix".to_string(),
format: crate::matrix_market::MMFormat::Coordinate,
data_type: crate::matrix_market::MMDataType::Real, symmetry: crate::matrix_market::MMSymmetry::General, comments: vec!["Converted from enhanced _sparse matrix".to_string()],
};
let entries = sparse
.coo_data
.row_indices
.iter()
.zip(sparse.coo_data.col_indices.iter())
.zip(sparse.coo_data.values.iter())
.map(|((&row, &col), value)| crate::matrix_market::SparseEntry {
row,
col,
value: value.clone(),
})
.collect();
crate::matrix_market::MMSparseMatrix {
header,
rows: sparse.shape.0,
cols: sparse.shape.1,
nnz: sparse.nnz(),
entries,
}
}
pub mod sparse_ops {
use super::*;
pub fn add_coo<A>(a: &SparseMatrixCOO<A>, b: &SparseMatrixCOO<A>) -> Result<SparseMatrixCOO<A>>
where
A: Clone + std::ops::Add<Output = A> + Default + PartialEq,
{
if a.rows != b.rows || a.cols != b.cols {
return Err(IoError::ValidationError(
"Matrix dimensions must match".to_string(),
));
}
let mut result = SparseMatrixCOO::new(a.rows, a.cols);
let mut indices_map: HashMap<(usize, usize), A> = HashMap::new();
for ((row, col), value) in a
.row_indices
.iter()
.zip(a.col_indices.iter())
.zip(a.values.iter())
{
indices_map.insert((*row, *col), value.clone());
}
for ((row, col), value) in b
.row_indices
.iter()
.zip(b.col_indices.iter())
.zip(b.values.iter())
{
let key = (*row, *col);
if let Some(existing) = indices_map.get(&key) {
indices_map.insert(key, existing.clone() + value.clone());
} else {
indices_map.insert(key, value.clone());
}
}
for ((row, col), value) in indices_map {
if value != A::default() {
result.push(row, col, value);
}
}
Ok(result)
}
pub fn csr_matvec<A>(matrix: &SparseMatrixCSR<A>, vector: &[A]) -> Result<Vec<A>>
where
A: Clone + std::ops::Add<Output = A> + std::ops::Mul<Output = A> + Default,
{
if vector.len() != matrix.cols {
return Err(IoError::ValidationError(
"Vector dimension must match _matrix columns".to_string(),
));
}
let mut result = vec![A::default(); matrix.rows];
for (row, result_elem) in result.iter_mut().enumerate() {
let start = matrix.row_ptrs[row];
let end = matrix.row_ptrs[row + 1];
let mut sum = A::default();
for i in start..end {
let col = matrix.col_indices[i];
let val = matrix.values[i].clone();
sum = sum + (val * vector[col].clone());
}
*result_elem = sum;
}
Ok(result)
}
pub fn transpose_coo<A>(matrix: &SparseMatrixCOO<A>) -> SparseMatrixCOO<A>
where
A: Clone,
{
let mut result = SparseMatrixCOO::new(matrix.cols, matrix.rows);
for ((row, col), value) in matrix
.row_indices
.iter()
.zip(matrix.col_indices.iter())
.zip(matrix.values.iter())
{
result.push(*col, *row, value.clone());
}
result
}
}
#[allow(dead_code)]
pub fn write_array_json<P, A, S>(path: P, array: &ArrayBase<S, IxDyn>) -> Result<()>
where
P: AsRef<Path>,
A: Serialize + Clone,
S: scirs2_core::ndarray::Data<Elem = A>,
{
serialize_array::<P, A, S>(path, array, SerializationFormat::JSON)
}
#[allow(dead_code)]
pub fn read_array_json<P, A>(path: P) -> Result<Array<A, IxDyn>>
where
P: AsRef<Path>,
A: for<'de> Deserialize<'de> + Clone,
{
deserialize_array(path, SerializationFormat::JSON)
}
#[allow(dead_code)]
pub fn write_array_binary<P, A, S>(path: P, array: &ArrayBase<S, IxDyn>) -> Result<()>
where
P: AsRef<Path>,
A: Serialize + Clone,
S: scirs2_core::ndarray::Data<Elem = A>,
{
serialize_array::<P, A, S>(path, array, SerializationFormat::Binary)
}
#[allow(dead_code)]
pub fn read_array_binary<P, A>(path: P) -> Result<Array<A, IxDyn>>
where
P: AsRef<Path>,
A: for<'de> Deserialize<'de> + Clone,
{
deserialize_array(path, SerializationFormat::Binary)
}
#[allow(dead_code)]
pub fn write_array_messagepack<P, A, S>(path: P, array: &ArrayBase<S, IxDyn>) -> Result<()>
where
P: AsRef<Path>,
A: Serialize + Clone,
S: scirs2_core::ndarray::Data<Elem = A>,
{
serialize_array::<P, A, S>(path, array, SerializationFormat::MessagePack)
}
#[allow(dead_code)]
pub fn read_array_messagepack<P, A>(path: P) -> Result<Array<A, IxDyn>>
where
P: AsRef<Path>,
A: for<'de> Deserialize<'de> + Clone,
{
deserialize_array(path, SerializationFormat::MessagePack)
}
#[allow(dead_code)]
pub fn serialize_array_zero_copy<P, A, S>(
path: P,
array: &ArrayBase<S, IxDyn>,
format: SerializationFormat,
) -> Result<()>
where
P: AsRef<Path>,
A: Serialize + bytemuck::Pod,
S: scirs2_core::ndarray::Data<Elem = A>,
{
if !array.is_standard_layout() {
return Err(IoError::FormatError(
"Array must be in standard layout for zero-copy serialization".to_string(),
));
}
let file = File::create(&path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut writer = BufWriter::new(file);
let shape = array.shape().to_vec();
let metadata = ArrayMetadata {
shape: shape.clone(),
dtype: std::any::type_name::<A>().to_string(),
order: 'C',
metadata: HashMap::new(),
};
match format {
SerializationFormat::Binary => {
let cfg = oxicode_cfg();
let bytes = oxicode_serde::encode_to_vec(&metadata, cfg)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
writer
.write_all(&bytes)
.map_err(|e| IoError::FileError(e.to_string()))?;
if let Some(slice) = array.as_slice() {
let bytes = bytemuck::cast_slice(slice);
writer
.write_all(bytes)
.map_err(|e| IoError::FileError(e.to_string()))?;
}
}
_ => {
return serialize_array(path, array, format);
}
}
writer
.flush()
.map_err(|e| IoError::FileError(e.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn deserialize_array_zero_copy<P>(path: P) -> Result<(ArrayMetadata, memmap2::Mmap)>
where
P: AsRef<Path>,
{
use std::io::Read;
let mut file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut size_buf = [0u8; 8];
file.read_exact(&mut size_buf)
.map_err(|e| IoError::FileError(e.to_string()))?;
let metadata_size = u64::from_le_bytes(size_buf) as usize;
let mut metadata_buf = vec![0u8; metadata_size];
file.read_exact(&mut metadata_buf)
.map_err(|e| IoError::FileError(e.to_string()))?;
let cfg = oxicode_cfg();
let (metadata, _len): (ArrayMetadata, usize) =
oxicode_serde::decode_owned_from_slice(&metadata_buf, cfg)
.map_err(|e| IoError::DeserializationError(e.to_string()))?;
let mmap = unsafe {
memmap2::MmapOptions::new()
.offset(8 + metadata_size as u64)
.map(&file)
.map_err(|e| IoError::FileError(e.to_string()))?
};
Ok((metadata, mmap))
}