use crate::Matrix;
use crate::traits::Scalar;
use serde::ser::{Serialize, SerializeTuple, Serializer};
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
use core::fmt;
use core::marker::PhantomData;
/// Serializes a fixed-size matrix as a tuple of `M` entries.
///
/// * When `N == 1` the entries are the `M` scalars of the single column,
///   giving a flat array.
/// * Otherwise each entry is one row, itself emitted as a tuple of `N`
///   scalars, giving a nested row-major array.
///
/// The backing `data` array is indexed `[column][row]` here, so rows are
/// gathered element by element before being handed to the serializer.
impl<T: Serialize + Copy, const M: usize, const N: usize> Serialize for Matrix<T, M, N> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Both layouts open with an outer tuple of length `M`.
        let mut outer = serializer.serialize_tuple(M)?;
        if N == 1 {
            // Single column: emit its scalars directly.
            for row in 0..M {
                outer.serialize_element(&self.data[0][row])?;
            }
        } else {
            for row in 0..M {
                // Copy the row out of the per-column storage.
                let gathered = RowRef::<T, N> {
                    data: core::array::from_fn(|col| self.data[col][row]),
                };
                outer.serialize_element(&gathered)?;
            }
        }
        outer.end()
    }
}
/// Holds the `N` elements of one matrix row by value so the row can be
/// passed to a serializer as a single element.
struct RowRef<T, const N: usize> {
    data: [T; N],
}

/// Emits the row as a tuple of exactly `N` elements, in index order.
impl<T: Serialize, const N: usize> Serialize for RowRef<T, N> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut tuple = serializer.serialize_tuple(N)?;
        // Stop at the first element the serializer rejects.
        self.data
            .iter()
            .try_for_each(|element| tuple.serialize_element(element))?;
        tuple.end()
    }
}
impl<'de, T: Scalar + Deserialize<'de>, const M: usize, const N: usize> Deserialize<'de>
for Matrix<T, M, N>
{
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct MatVisitor<T, const M: usize, const N: usize>(PhantomData<T>);
impl<'de, T: Scalar + Deserialize<'de>, const M: usize, const N: usize> Visitor<'de>
for MatVisitor<T, M, N>
{
type Value = Matrix<T, M, N>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
if N == 1 {
write!(f, "an array of {} elements", M)
} else {
write!(f, "a {}x{} matrix as row-major array", M, N)
}
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
if N == 1 {
let mut mat = Matrix::<T, M, N>::zeros();
for i in 0..M {
mat.data[0][i] = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
}
Ok(mat)
} else {
let mut mat = Matrix::<T, M, N>::zeros();
for i in 0..M {
let row: RowDeserialize<T, N> = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
for j in 0..N {
mat.data[j][i] = row.data[j];
}
}
Ok(mat)
}
}
}
deserializer.deserialize_tuple(M, MatVisitor::<T, M, N>(PhantomData))
}
}
struct RowDeserialize<T, const N: usize> {
data: [T; N],
}
impl<'de, T: Scalar + Deserialize<'de>, const N: usize> Deserialize<'de>
for RowDeserialize<T, N>
{
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct RowVisitor<T, const N: usize>(PhantomData<T>);
impl<'de, T: Scalar + Deserialize<'de>, const N: usize> Visitor<'de> for RowVisitor<T, N> {
type Value = RowDeserialize<T, N>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "an array of {} elements", N)
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut data = [T::zero(); N];
for j in 0..N {
data[j] = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(j, &self))?;
}
Ok(RowDeserialize { data })
}
}
deserializer.deserialize_tuple(N, RowVisitor::<T, N>(PhantomData))
}
}
#[cfg(feature = "alloc")]
mod dyn_serde {
use crate::dynmatrix::{DynMatrix, DynVector};
use crate::traits::Scalar;
use alloc::format;
use alloc::vec::Vec;
use serde::ser::{Serialize, SerializeSeq, SerializeStruct, Serializer};
use serde::de::{self, Deserialize, Deserializer, MapAccess, Visitor};
use core::fmt;
use core::marker::PhantomData;
/// Serializes a `DynMatrix` as a three-field struct named `"DynMatrix"`:
/// `nrows`, `ncols`, and `data`, where `data` is a list of rows and each row
/// is a list of that row's elements.
impl<T: Serialize + Scalar> Serialize for DynMatrix<T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut out = serializer.serialize_struct("DynMatrix", 3)?;
        out.serialize_field("nrows", &self.nrows())?;
        out.serialize_field("ncols", &self.ncols())?;

        // Materialise the elements row by row through `(row, col)` indexing.
        let nested: Vec<Vec<T>> = (0..self.nrows())
            .map(|r| {
                let mut line = Vec::new();
                line.extend((0..self.ncols()).map(|c| self[(r, c)]));
                line
            })
            .collect();

        out.serialize_field("data", &nested)?;
        out.end()
    }
}
impl<'de, T: Scalar + Deserialize<'de>> Deserialize<'de> for DynMatrix<T> {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
#[derive(serde::Deserialize)]
#[serde(field_identifier, rename_all = "lowercase")]
enum Field {
Nrows,
Ncols,
Data,
}
struct DynMatVisitor<T>(PhantomData<T>);
impl<'de, T: Scalar + Deserialize<'de>> Visitor<'de> for DynMatVisitor<T> {
type Value = DynMatrix<T>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "a DynMatrix with nrows, ncols, and row-major data")
}
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
let mut nrows: Option<usize> = None;
let mut ncols: Option<usize> = None;
let mut data: Option<Vec<Vec<T>>> = None;
while let Some(key) = map.next_key()? {
match key {
Field::Nrows => nrows = Some(map.next_value()?),
Field::Ncols => ncols = Some(map.next_value()?),
Field::Data => data = Some(map.next_value()?),
}
}
let nrows = nrows.ok_or_else(|| de::Error::missing_field("nrows"))?;
let ncols = ncols.ok_or_else(|| de::Error::missing_field("ncols"))?;
let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
if data.len() != nrows {
return Err(de::Error::custom(format!(
"expected {} rows, got {}",
nrows,
data.len()
)));
}
let flat: Vec<T> = data.into_iter().flat_map(|row| row.into_iter()).collect();
if flat.len() != nrows * ncols {
return Err(de::Error::custom("row lengths inconsistent with ncols"));
}
Ok(DynMatrix::from_rows(nrows, ncols, &flat))
}
}
deserializer.deserialize_struct(
"DynMatrix",
&["nrows", "ncols", "data"],
DynMatVisitor::<T>(PhantomData),
)
}
}
/// Serializes a `DynVector` as a plain sequence of its elements, announcing
/// the element count to the serializer up front.
impl<T: Serialize + Scalar> Serialize for DynVector<T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let count = self.len();
        let mut out = serializer.serialize_seq(Some(count))?;
        // Emit elements in index order, stopping at the first failure.
        (0..count).try_for_each(|idx| out.serialize_element(&self[idx]))?;
        out.end()
    }
}
/// Deserializes a `DynVector` by reading a `Vec<T>` and handing it to
/// `DynVector::from_vec`.
impl<'de, T: Scalar + Deserialize<'de>> Deserialize<'de> for DynVector<T> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        Vec::<T>::deserialize(deserializer).map(DynVector::from_vec)
    }
}
}
#[cfg(test)]
mod tests {
use crate::Matrix;
use crate::matrix::vector::Vector;
/// A 2x2 float matrix serializes to a nested row-major JSON array and
/// deserializes back to an equal matrix.
#[test]
fn matrix_roundtrip_json() {
    let original = Matrix::new([[1.0_f64, 2.0], [3.0, 4.0]]);
    let encoded = serde_json::to_string(&original).unwrap();
    assert_eq!(encoded, "[[1.0,2.0],[3.0,4.0]]");
    let decoded = serde_json::from_str::<Matrix<f64, 2, 2>>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// A 3-element vector serializes to a flat JSON array and deserializes back
/// to an equal vector.
#[test]
fn vector_roundtrip_json() {
    let original = Vector::from_array([1.0_f64, 2.0, 3.0]);
    let encoded = serde_json::to_string(&original).unwrap();
    assert_eq!(encoded, "[1.0,2.0,3.0]");
    let decoded = serde_json::from_str::<Vector<f64, 3>>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// A 2x3 matrix keeps its row/column orientation through a JSON round trip.
#[test]
fn nonsquare_matrix_roundtrip() {
    let original = Matrix::new([[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let encoded = serde_json::to_string(&original).unwrap();
    assert_eq!(encoded, "[[1.0,2.0,3.0],[4.0,5.0,6.0]]");
    let decoded = serde_json::from_str::<Matrix<f64, 2, 3>>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// An integer matrix uses the same nested row-major layout as the float
/// matrices and survives a JSON round trip.
#[test]
fn integer_matrix_roundtrip() {
    let m = Matrix::new([[1_i32, 2], [3, 4]]);
    let json = serde_json::to_string(&m).unwrap();
    // Pin the wire format as the sibling tests do: a round trip alone would
    // still pass if serialization and deserialization shared a layout error.
    assert_eq!(json, "[[1,2],[3,4]]");
    let m2: Matrix<i32, 2, 2> = serde_json::from_str(&json).unwrap();
    assert_eq!(m, m2);
}
/// A 1x1 matrix takes the flat-array path: it serializes to a one-element
/// JSON array and deserializes back to an equal matrix.
#[test]
fn scalar_matrix_roundtrip() {
    let original = Matrix::new([[42.0_f64]]);
    let encoded = serde_json::to_string(&original).unwrap();
    assert_eq!(encoded, "[42.0]");
    let decoded = serde_json::from_str::<Matrix<f64, 1, 1>>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// A quaternion serializes with a named `w` field and survives a JSON
/// round trip.
#[test]
fn quaternion_roundtrip_json() {
    use crate::Quaternion;

    let original = Quaternion::new(1.0_f64, 0.0, 0.0, 0.0);
    let encoded = serde_json::to_string(&original).unwrap();
    assert!(encoded.contains("\"w\":1.0"));
    let decoded = serde_json::from_str::<Quaternion<f64>>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
#[cfg(feature = "alloc")]
mod dyn_tests {
use crate::dynmatrix::{DynMatrix, DynVector};
/// A 2x3 `DynMatrix` records its dimensions in the JSON output and
/// deserializes back to an equal matrix.
#[test]
fn dynmatrix_roundtrip_json() {
    let original = DynMatrix::from_rows(2, 3, &[1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    let encoded = serde_json::to_string(&original).unwrap();
    assert!(encoded.contains("\"nrows\":2"));
    assert!(encoded.contains("\"ncols\":3"));
    let decoded = serde_json::from_str::<DynMatrix<f64>>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// A `DynVector` serializes to a flat JSON array and deserializes back to an
/// equal vector.
#[test]
fn dynvector_roundtrip_json() {
    let original = DynVector::from_slice(&[1.0_f64, 2.0, 3.0]);
    let encoded = serde_json::to_string(&original).unwrap();
    assert_eq!(encoded, "[1.0,2.0,3.0]");
    let decoded = serde_json::from_str::<DynVector<f64>>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
}
}