ferray_strings/
serde_impl.rs1use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
10use serde::ser::{SerializeStruct, Serializer};
11use serde::{Deserialize, Serialize};
12
13use ferray_core::dimension::Dimension;
14
15use crate::string_array::StringArray;
16
17impl<D: Dimension> Serialize for StringArray<D> {
18 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
19 let mut state = serializer.serialize_struct("StringArray", 2)?;
20 state.serialize_field("shape", self.shape())?;
21 state.serialize_field("data", self.as_slice())?;
22 state.end()
23 }
24}
25
26impl<'de, D: Dimension> Deserialize<'de> for StringArray<D> {
27 fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
28 #[derive(Deserialize)]
29 #[serde(field_identifier, rename_all = "lowercase")]
30 enum Field {
31 Shape,
32 Data,
33 }
34
35 struct StringArrayVisitor<D>(std::marker::PhantomData<D>);
36
37 impl<'de, D: Dimension> Visitor<'de> for StringArrayVisitor<D> {
38 type Value = StringArray<D>;
39
40 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "a struct with 'shape' and 'data' fields")
42 }
43
44 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
45 let shape: Vec<usize> = seq
46 .next_element()?
47 .ok_or_else(|| de::Error::missing_field("shape"))?;
48 let data: Vec<String> = seq
49 .next_element()?
50 .ok_or_else(|| de::Error::missing_field("data"))?;
51 build_string_array(shape, data)
52 }
53
54 fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
55 let mut shape: Option<Vec<usize>> = None;
56 let mut data: Option<Vec<String>> = None;
57 while let Some(key) = map.next_key()? {
58 match key {
59 Field::Shape => {
60 if shape.is_some() {
61 return Err(de::Error::duplicate_field("shape"));
62 }
63 shape = Some(map.next_value()?);
64 }
65 Field::Data => {
66 if data.is_some() {
67 return Err(de::Error::duplicate_field("data"));
68 }
69 data = Some(map.next_value()?);
70 }
71 }
72 }
73 let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
74 let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
75 build_string_array(shape, data)
76 }
77 }
78
79 deserializer.deserialize_struct(
80 "StringArray",
81 &["shape", "data"],
82 StringArrayVisitor::<D>(std::marker::PhantomData),
83 )
84 }
85}
86
87fn build_string_array<D, E>(shape: Vec<usize>, data: Vec<String>) -> Result<StringArray<D>, E>
88where
89 D: Dimension,
90 E: de::Error,
91{
92 if let Some(expected) = D::NDIM
93 && shape.len() != expected
94 {
95 return Err(de::Error::custom(format!(
96 "expected {expected}D shape, got {}D ({shape:?})",
97 shape.len()
98 )));
99 }
100 let dim = D::from_dim_slice(&shape).ok_or_else(|| {
101 de::Error::custom(format!(
102 "shape {shape:?} is not valid for the dimension type"
103 ))
104 })?;
105 StringArray::from_vec(dim, data).map_err(|e| de::Error::custom(e.to_string()))
106}
107
#[cfg(test)]
mod tests {
    use ferray_core::dimension::{Ix1, Ix2, IxDyn};

    use crate::string_array::{StringArray, array};

    #[test]
    fn round_trip_1d() {
        let a = array(&["foo", "bar", "baz"]).unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<Ix1> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.shape(), restored.shape());
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn round_trip_2d() {
        let a = StringArray::<Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec!["a", "b", "c", "d", "e", "f"]
                .into_iter()
                .map(String::from)
                .collect(),
        )
        .unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<Ix2> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.shape(), restored.shape());
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn round_trip_dynamic() {
        let a = StringArray::<IxDyn>::from_vec_dyn(
            &[2, 2],
            vec!["one".into(), "two".into(), "three".into(), "four".into()],
        )
        .unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<IxDyn> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.shape(), restored.shape());
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn unicode_round_trip() {
        let a = array(&["こんにちは", "Здравствуйте", "🎉"]).unwrap();
        let json = serde_json::to_string(&a).unwrap();
        let restored: StringArray<Ix1> = serde_json::from_str(&json).unwrap();
        assert_eq!(a.as_slice(), restored.as_slice());
    }

    #[test]
    fn rank_mismatch_error() {
        let json = r#"{"shape":[2,2,2],"data":["a","b","c","d","e","f","g","h"]}"#;
        let err = serde_json::from_str::<StringArray<Ix2>>(json).unwrap_err();
        // The rank check runs before any size validation.
        assert!(err.to_string().contains("expected 2D shape"), "{err}");
    }

    #[test]
    fn size_mismatch_error() {
        let json = r#"{"shape":[2,3],"data":["a","b","c"]}"#;
        let result = serde_json::from_str::<StringArray<Ix2>>(json);
        assert!(result.is_err());
    }

    #[test]
    fn sequence_form_deserializes() {
        // `deserialize_struct` also accepts the positional `[shape, data]`
        // form, which goes through `visit_seq` instead of `visit_map`.
        let expected = array(&["x", "y"]).unwrap();
        let restored: StringArray<Ix1> = serde_json::from_str(r#"[[2],["x","y"]]"#).unwrap();
        assert_eq!(expected.shape(), restored.shape());
        assert_eq!(expected.as_slice(), restored.as_slice());
    }

    #[test]
    fn sequence_form_missing_data_error() {
        let err = serde_json::from_str::<StringArray<Ix1>>("[[2]]").unwrap_err();
        assert!(err.to_string().contains("missing field `data`"), "{err}");
    }

    #[test]
    fn missing_shape_error() {
        let err = serde_json::from_str::<StringArray<Ix1>>(r#"{"data":["a"]}"#).unwrap_err();
        assert!(err.to_string().contains("missing field `shape`"), "{err}");
    }

    #[test]
    fn missing_data_error() {
        let err = serde_json::from_str::<StringArray<Ix1>>(r#"{"shape":[1]}"#).unwrap_err();
        assert!(err.to_string().contains("missing field `data`"), "{err}");
    }

    #[test]
    fn duplicate_field_error() {
        let json = r#"{"shape":[1],"shape":[1],"data":["a"]}"#;
        let err = serde_json::from_str::<StringArray<Ix1>>(json).unwrap_err();
        assert!(err.to_string().contains("duplicate field `shape`"), "{err}");
    }
}