use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::vec::Vec;
/// Serializes a 2-D `f64` array as one flat sequence.
///
/// The emitted sequence is `[nrows, ncols, elements...]`: both dimensions are
/// written first (converted to `f64`), followed by every element in the order
/// produced by `array.iter()`. Designed for use with
/// `#[serde(serialize_with = "serialize_array2")]`.
#[allow(dead_code)]
pub fn serialize_array2<S>(array: &Array2<f64>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let (rows, cols) = array.dim();

    // Two header slots for the shape, then one slot per element.
    let mut flat: Vec<f64> = Vec::with_capacity(array.len() + 2);
    flat.extend([rows as f64, cols as f64].iter().copied());
    flat.extend(array.iter().copied());

    flat.serialize(serializer)
}
#[allow(dead_code)]
pub fn deserialize_array2<'de, D>(deserializer: D) -> Result<Array2<f64>, D::Error>
where
D: Deserializer<'de>,
{
let vec = Vec::<f64>::deserialize(deserializer)?;
if vec.len() < 2 {
return Err(serde::de::Error::custom("Invalid array2 serialization"));
}
let nrows = vec[0] as usize;
let ncols = vec[1] as usize;
if vec.len() != nrows * ncols + 2 {
return Err(serde::de::Error::custom("Invalid array2 serialization"));
}
let data = vec[2..].to_vec();
match Array2::from_shape_vec((nrows, ncols), data) {
Ok(array) => Ok(array),
Err(_) => Err(serde::de::Error::custom("Failed to reshape array2")),
}
}
#[allow(dead_code)]
pub fn serialize_array1<S>(array: &Array1<f64>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let vec = array.to_vec();
vec.serialize(serializer)
}
/// Deserializes a 1-D `f64` array from a plain sequence of numbers.
///
/// Any error reported by the underlying `Vec<f64>` deserialization is
/// propagated unchanged. Designed for use with
/// `#[serde(deserialize_with = "deserialize_array1")]`.
#[allow(dead_code)]
pub fn deserialize_array1<'de, D>(deserializer: D) -> Result<Array1<f64>, D::Error>
where
    D: Deserializer<'de>,
{
    Vec::<f64>::deserialize(deserializer).map(Array1::from)
}
/// Serde adapter for `Option<Array1<f64>>` fields, usable through
/// `#[serde(with = "optional_array1")]`.
///
/// `Some(array)` is encoded as an optional struct with a single `value` field
/// holding the element sequence; `None` is encoded as the format's "none".
pub mod optional_array1 {
    use super::*;

    /// Serializes an optional 1-D array.
    ///
    /// The `Some` case is routed through `serialize_some` so that formats
    /// which encode `Some` with an explicit marker emit what [`deserialize`]
    /// (which reads an `Option`) expects. Formats where `Some(x)` is written
    /// exactly like `x`, such as JSON, produce the same output as serializing
    /// the wrapper directly.
    pub fn serialize<S>(_arrayopt: &Option<Array1<f64>>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Borrowing wrapper that gives the array its `{ value: [...] }` shape.
        #[derive(Serialize)]
        struct Wrapper<'a> {
            #[serde(serialize_with = "super::serialize_array1")]
            value: &'a Array1<f64>,
        }

        match _arrayopt {
            Some(array) => serializer.serialize_some(&Wrapper { value: array }),
            None => serializer.serialize_none(),
        }
    }

    /// Deserializes an optional 1-D array written by [`serialize`].
    ///
    /// # Errors
    ///
    /// Propagates any error from deserializing the optional wrapper struct or
    /// its `value` sequence.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Array1<f64>>, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Owning counterpart of the wrapper used by `serialize`.
        #[derive(Deserialize)]
        struct Wrapper {
            #[serde(deserialize_with = "super::deserialize_array1")]
            value: Array1<f64>,
        }

        let wrapped = Option::<Wrapper>::deserialize(deserializer)?;
        Ok(wrapped.map(|wrapper| wrapper.value))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A matrix must survive the custom `[nrows, ncols, data...]` encoding,
    /// and the encoding itself must have that layout.
    #[test]
    fn test_array2_serialization_roundtrip() {
        let original = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let encoded = serialize_array2(&original, serde_json::value::Serializer)
            .expect("serializing a finite 2x3 array to a JSON value should succeed");
        let flat: Vec<f64> = serde_json::from_value(encoded.clone())
            .expect("encoded array2 should be a JSON array of numbers");
        assert_eq!(flat, vec![2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let reconstructed =
            deserialize_array2(encoded).expect("round-tripped array2 should deserialize");
        assert_eq!(original, reconstructed);
    }

    /// A vector must survive the plain-sequence encoding.
    #[test]
    fn test_array1_serialization_roundtrip() {
        let original = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let encoded = serialize_array1(&original, serde_json::value::Serializer)
            .expect("serializing a finite vector to a JSON value should succeed");
        let flat: Vec<f64> = serde_json::from_value(encoded.clone())
            .expect("encoded array1 should be a JSON array of numbers");
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let reconstructed =
            deserialize_array1(encoded).expect("round-tripped array1 should deserialize");
        assert_eq!(original, reconstructed);
    }

    /// Malformed array2 payloads must be rejected instead of producing an array.
    #[test]
    fn test_invalid_array2_deserialization() {
        // Header claims 2x3 but only a single element follows.
        let truncated = serde_json::json!([2.0, 3.0, 1.0]);
        assert!(deserialize_array2(truncated).is_err());

        // Too short to even contain the two-entry shape header.
        let headerless = serde_json::json!([2.0]);
        assert!(deserialize_array2(headerless).is_err());
    }

    /// Both `Some` and `None` must round-trip through the optional adapter.
    #[test]
    fn test_optional_array1_roundtrip() {
        let present = Some(array![1.0, 2.0]);
        let encoded = optional_array1::serialize(&present, serde_json::value::Serializer)
            .expect("serializing Some(array) to a JSON value should succeed");
        let decoded = optional_array1::deserialize(encoded)
            .expect("round-tripped Some(array) should deserialize");
        assert_eq!(decoded, present);

        let absent: Option<Array1<f64>> = None;
        let encoded = optional_array1::serialize(&absent, serde_json::value::Serializer)
            .expect("serializing None to a JSON value should succeed");
        assert!(encoded.is_null());
        let decoded =
            optional_array1::deserialize(encoded).expect("round-tripped None should deserialize");
        assert_eq!(decoded, None);
    }
}