Skip to main content

ferray_strings/
serde_impl.rs

//! Serialize/Deserialize for `StringArray<D>` (#279).
//!
//! Mirrors `ferray_core::array::serde_impl`: emits `{ "shape": [...],
//! "data": [...] }` with strings in row-major order. Both deserialization
//! forms (map and sequence) funnel into one validation step, which checks
//! the rank against `D::NDIM` (for fixed ranks) and rebuilds the dimension
//! via the same ndarray-based construction used for the numeric `Array`.
8
9use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
10use serde::ser::{SerializeStruct, Serializer};
11use serde::{Deserialize, Serialize};
12
13use ferray_core::dimension::Dimension;
14
15use crate::string_array::StringArray;
16
17impl<D: Dimension> Serialize for StringArray<D> {
18    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
19        let mut state = serializer.serialize_struct("StringArray", 2)?;
20        state.serialize_field("shape", self.shape())?;
21        state.serialize_field("data", self.as_slice())?;
22        state.end()
23    }
24}
25
26impl<'de, D: Dimension> Deserialize<'de> for StringArray<D> {
27    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
28        #[derive(Deserialize)]
29        #[serde(field_identifier, rename_all = "lowercase")]
30        enum Field {
31            Shape,
32            Data,
33        }
34
35        struct StringArrayVisitor<D>(std::marker::PhantomData<D>);
36
37        impl<'de, D: Dimension> Visitor<'de> for StringArrayVisitor<D> {
38            type Value = StringArray<D>;
39
40            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41                write!(f, "a struct with 'shape' and 'data' fields")
42            }
43
44            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
45                let shape: Vec<usize> = seq
46                    .next_element()?
47                    .ok_or_else(|| de::Error::missing_field("shape"))?;
48                let data: Vec<String> = seq
49                    .next_element()?
50                    .ok_or_else(|| de::Error::missing_field("data"))?;
51                build_string_array(shape, data)
52            }
53
54            fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
55                let mut shape: Option<Vec<usize>> = None;
56                let mut data: Option<Vec<String>> = None;
57                while let Some(key) = map.next_key()? {
58                    match key {
59                        Field::Shape => {
60                            if shape.is_some() {
61                                return Err(de::Error::duplicate_field("shape"));
62                            }
63                            shape = Some(map.next_value()?);
64                        }
65                        Field::Data => {
66                            if data.is_some() {
67                                return Err(de::Error::duplicate_field("data"));
68                            }
69                            data = Some(map.next_value()?);
70                        }
71                    }
72                }
73                let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
74                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
75                build_string_array(shape, data)
76            }
77        }
78
79        deserializer.deserialize_struct(
80            "StringArray",
81            &["shape", "data"],
82            StringArrayVisitor::<D>(std::marker::PhantomData),
83        )
84    }
85}
86
87fn build_string_array<D, E>(shape: Vec<usize>, data: Vec<String>) -> Result<StringArray<D>, E>
88where
89    D: Dimension,
90    E: de::Error,
91{
92    if let Some(expected) = D::NDIM
93        && shape.len() != expected
94    {
95        return Err(de::Error::custom(format!(
96            "expected {expected}D shape, got {}D ({shape:?})",
97            shape.len()
98        )));
99    }
100    let dim = D::from_dim_slice(&shape).ok_or_else(|| {
101        de::Error::custom(format!(
102            "shape {shape:?} is not valid for the dimension type"
103        ))
104    })?;
105    StringArray::from_vec(dim, data).map_err(|e| de::Error::custom(e.to_string()))
106}
107
#[cfg(test)]
mod tests {
    use ferray_core::dimension::{Ix1, Ix2, IxDyn};

    use crate::string_array::{StringArray, array};

    #[test]
    fn round_trip_1d() {
        let a = array(&["foo", "bar", "baz"]).unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<Ix1> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.shape(), restored.shape());
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn round_trip_2d() {
        let a = StringArray::<Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec!["a", "b", "c", "d", "e", "f"]
                .into_iter()
                .map(String::from)
                .collect(),
        )
        .unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<Ix2> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.shape(), restored.shape());
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn round_trip_dynamic() {
        let a = StringArray::<IxDyn>::from_vec_dyn(
            &[2, 2],
            vec!["one".into(), "two".into(), "three".into(), "four".into()],
        )
        .unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<IxDyn> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.shape(), restored.shape());
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn unicode_round_trip() {
        // Multi-byte UTF-8 strings must round-trip exactly.
        let a = array(&["こんにちは", "Здравствуйте", "🎉"]).unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<Ix1> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn sequence_form_deserializes() {
        // A JSON array drives the positional `visit_seq` path rather than
        // `visit_map`; it must produce the same array as the map form.
        let expected = array(&["x", "y"]).unwrap();
        let restored: StringArray<Ix1> = serde_json::from_str(r#"[[2],["x","y"]]"#).unwrap();
        assert_eq!(expected.shape(), restored.shape());
        assert_eq!(expected.as_slice(), restored.as_slice());
    }

    #[test]
    fn sequence_too_short_error() {
        let result = serde_json::from_str::<StringArray<Ix1>>(r#"[[2]]"#);
        assert!(result.is_err());
    }

    #[test]
    fn missing_field_error() {
        let err = serde_json::from_str::<StringArray<Ix1>>(r#"{"shape":[1]}"#)
            .err()
            .expect("a map without `data` must be rejected");
        assert!(err.to_string().contains("data"), "{err}");
    }

    #[test]
    fn duplicate_field_error() {
        let json = r#"{"shape":[1],"shape":[1],"data":["a"]}"#;
        let err = serde_json::from_str::<StringArray<Ix1>>(json)
            .err()
            .expect("a repeated `shape` key must be rejected");
        assert!(err.to_string().contains("duplicate field"), "{err}");
    }

    #[test]
    fn rank_mismatch_error() {
        let json = r#"{"shape":[2,2,2],"data":["a","b","c","d","e","f","g","h"]}"#;
        let err = serde_json::from_str::<StringArray<Ix2>>(json)
            .err()
            .expect("a 3-D shape must be rejected for Ix2");
        assert!(err.to_string().contains("expected 2D shape"), "{err}");
    }

    #[test]
    fn size_mismatch_error() {
        let json = r#"{"shape":[2,3],"data":["a","b","c"]}"#;
        let result = serde_json::from_str::<StringArray<Ix2>>(json);
        assert!(result.is_err());
    }
}