1use std::{collections::HashMap, fmt};
2
3use serde::{
4 de::{MapAccess, SeqAccess, Visitor},
5 Deserialize, Deserializer,
6};
7use serde_json::Value;
8
9use super::*;
10
11fn to_primitive(v: &str) -> Option<Schema> {
12 use Schema::*;
13 Some(match v {
14 "null" => Null,
15 "boolean" => Boolean,
16 "bytes" => Bytes(None),
17 "string" => String(None),
18 "int" => Int(None),
19 "long" => Long(None),
20 "float" => Float,
21 "double" => Double,
22 _ => return None,
23 })
24}
25
26fn get_type<E: serde::de::Error>(map: &mut HashMap<String, Value>) -> Result<String, E> {
27 if let Some(v) = map.remove("type") {
28 if let Value::String(v) = v {
29 Ok(v)
30 } else if let Value::Null = v {
31 Ok("null".to_string())
32 } else {
33 Err(serde::de::Error::custom("type must be a string"))
34 }
35 } else {
36 Err(serde::de::Error::missing_field("type"))
37 }
38}
39
40fn as_string<E: serde::de::Error>(v: Value, helper: &str) -> Result<String, E> {
41 if let Value::String(v) = v {
42 Ok(v)
43 } else {
44 Err(serde::de::Error::custom(format!(
45 "{} must be a string",
46 helper
47 )))
48 }
49}
50
51fn remove_string<E: serde::de::Error>(
52 data: &mut HashMap<String, Value>,
53 key: &str,
54) -> Result<Option<String>, E> {
55 match data.remove(key) {
56 Some(s) => as_string(s, key).map(Some),
57 None => Ok(None),
58 }
59}
60
61fn remove_usize<E: serde::de::Error>(
62 data: &mut HashMap<String, Value>,
63 key: &str,
64) -> Result<Option<usize>, E> {
65 data.remove(key)
66 .map(|x| serde_json::from_value::<usize>(x).map_err(serde::de::Error::custom))
67 .transpose()
68}
69
70fn remove_vec_string<E: serde::de::Error>(
71 data: &mut HashMap<String, Value>,
72 key: &str,
73) -> Result<Vec<String>, E> {
74 match data.remove(key) {
75 Some(s) => {
76 if let Value::Array(x) = s {
77 x.into_iter().map(|x| as_string(x, key)).collect()
78 } else {
79 Err(serde::de::Error::custom(format!(
80 "{} must be a string",
81 key
82 )))
83 }
84 }
85 None => Ok(vec![]),
86 }
87}
88
89fn to_enum<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
90 Ok(Schema::Enum(Enum {
91 name: remove_string(data, "name")?
92 .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
93 namespace: remove_string(data, "namespace")?,
94 aliases: remove_vec_string(data, "aliases")?,
95 doc: remove_string(data, "doc")?,
96 symbols: remove_vec_string(data, "symbols")?,
97 default: remove_string(data, "default")?,
98 }))
99}
100
101fn to_map<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
102 let item = data
103 .remove("values")
104 .ok_or_else(|| serde::de::Error::custom("values is required in a map"))?;
105 let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?;
106 Ok(Schema::Map(Box::new(schema)))
107}
108
109fn to_schema<E: serde::de::Error>(
110 data: &mut HashMap<String, Value>,
111 key: &str,
112) -> Result<Option<Schema>, E> {
113 let schema = data.remove(key);
114 schema
115 .map(|schema| serde_json::from_value(schema).map_err(serde::de::Error::custom))
116 .transpose()
117}
118
119fn to_array<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
120 let schema =
121 to_schema(data, "items")?.ok_or_else(|| E::custom("items is required in an array"))?;
122 Ok(Schema::Array(Box::new(schema)))
123}
124
125fn to_field<E: serde::de::Error>(data: Value) -> Result<Field, E> {
126 serde_json::from_value(data).map_err(E::custom)
127}
128
129fn to_vec_fields<E: serde::de::Error>(
130 data: &mut HashMap<String, Value>,
131 key: &str,
132) -> Result<Vec<Field>, E> {
133 match data.remove(key) {
134 Some(s) => {
135 if let Value::Array(x) = s {
136 x.into_iter().map(to_field).collect()
137 } else {
138 Err(E::custom(format!("{} must be a string", key)))
139 }
140 }
141 None => Ok(vec![]),
142 }
143}
144
145fn to_record<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
146 Ok(Schema::Record(Record {
147 name: remove_string(data, "name")?
148 .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
149 namespace: remove_string(data, "namespace")?,
150 aliases: remove_vec_string(data, "aliases")?,
151 doc: remove_string(data, "doc")?,
152 fields: to_vec_fields(data, "fields")?,
153 }))
154}
155
156fn to_fixed<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
157 let size = remove_usize(data, "size")?
158 .ok_or_else(|| serde::de::Error::custom("size is required in fixed"))?;
159
160 let logical = remove_string(data, "logicalType")?.unwrap_or_default();
161 let logical = match logical.as_ref() {
162 "decimal" => {
163 let precision = remove_usize(data, "precision")?;
164 let scale = remove_usize(data, "scale")?.unwrap_or_default();
165 precision.map(|p| FixedLogical::Decimal(p, scale))
166 }
167 "duration" => Some(FixedLogical::Duration),
168 _ => None,
169 };
170
171 Ok(Schema::Fixed(Fixed {
172 name: remove_string(data, "name")?
173 .ok_or_else(|| serde::de::Error::custom("name is required in fixed"))?,
174 namespace: remove_string(data, "namespace")?,
175 aliases: remove_vec_string(data, "aliases")?,
176 doc: remove_string(data, "doc")?,
177 size,
178 logical,
179 }))
180}
181
182fn to_order<E: serde::de::Error>(
183 data: &mut HashMap<String, Value>,
184 key: &str,
185) -> Result<Option<Order>, E> {
186 remove_string(data, key)?
187 .map(|x| {
188 Ok(match x.as_ref() {
189 "ascending" => Order::Ascending,
190 "descending" => Order::Descending,
191 "ignore" => Order::Ignore,
192 _ => {
193 return Err(serde::de::Error::custom(
194 "order can only be one of {ascending, descending, ignore}",
195 ))
196 }
197 })
198 })
199 .transpose()
200}
201
202struct SchemaVisitor {}
203
204impl<'de> Visitor<'de> for SchemaVisitor {
205 type Value = Schema;
206
207 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
209 formatter.write_str("a null, string, array or map describing an Avro schema")
210 }
211
212 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
213 where
214 D: Deserializer<'de>,
215 {
216 deserializer.deserialize_any(SchemaVisitor {})
217 }
218
219 fn visit_none<E>(self) -> Result<Self::Value, E>
220 where
221 E: serde::de::Error,
222 {
223 Ok(Schema::Null)
224 }
225
226 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
227 where
228 E: serde::de::Error,
229 {
230 to_primitive(v)
231 .ok_or_else(|| serde::de::Error::custom("string must be a valid primitive Schema"))
232 }
233
234 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
235 where
236 A: SeqAccess<'de>,
237 {
238 let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(0));
239 while let Some(item) = seq.next_element::<Value>()? {
240 let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?;
241 vec.push(schema)
242 }
243 Ok(Schema::Union(vec))
244 }
245
246 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
247 where
248 M: MapAccess<'de>,
249 {
250 let mut map = HashMap::<String, Value>::with_capacity(access.size_hint().unwrap_or(0));
251
252 while let Some((key, value)) = access.next_entry()? {
255 map.insert(key, value);
256 }
257
258 let (schema, type_) = get_type(&mut map).map(|x| (to_primitive(&x), x))?;
259
260 if let Some(schema) = schema {
261 Ok(match type_.as_ref() {
262 "string" => {
263 let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
264 match logical.as_ref() {
265 "uuid" => Schema::String(Some(StringLogical::Uuid)),
266 _ => schema,
267 }
268 }
269 "int" => {
270 let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
271 match logical.as_ref() {
272 "date" => Schema::Int(Some(IntLogical::Date)),
273 "time-millis" => Schema::Int(Some(IntLogical::Time)),
274 _ => schema,
275 }
276 }
277 "long" => {
278 let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
279 match logical.as_ref() {
280 "time-micros" => Schema::Long(Some(LongLogical::Time)),
281 "timestamp-millis" => Schema::Long(Some(LongLogical::TimestampMillis)),
282 "timestamp-micros" => Schema::Long(Some(LongLogical::TimestampMicros)),
283 "local-timestamp-millis" => {
284 Schema::Long(Some(LongLogical::LocalTimestampMillis))
285 }
286 "local-timestamp-micros" => {
287 Schema::Long(Some(LongLogical::LocalTimestampMicros))
288 }
289 _ => schema,
290 }
291 }
292 "bytes" => {
293 let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
294 match logical.as_ref() {
295 "decimal" => {
296 let precision = remove_usize(&mut map, "precision")?;
297 let scale = remove_usize(&mut map, "scale")?.unwrap_or_default();
298 Schema::Bytes(precision.map(|p| BytesLogical::Decimal(p, scale)))
299 }
300 _ => schema,
301 }
302 }
303 _ => schema,
304 })
305 } else {
306 match type_.as_ref() {
307 "enum" => to_enum(&mut map),
308 "map" => to_map(&mut map),
309 "array" => to_array(&mut map),
310 "record" => to_record(&mut map),
311 "fixed" => to_fixed(&mut map),
312 other => todo!("{}", other),
313 }
314 }
315 }
316}
317
318impl<'de> Deserialize<'de> for Schema {
319 fn deserialize<D>(deserializer: D) -> Result<Schema, D::Error>
320 where
321 D: Deserializer<'de>,
322 {
323 deserializer.deserialize_option(SchemaVisitor {})
324 }
325}
326
327struct FieldVisitor {}
328
329impl<'de> Visitor<'de> for FieldVisitor {
330 type Value = Field;
331
332 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
333 formatter.write_str("a map describing an Avro field")
334 }
335
336 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
337 where
338 M: MapAccess<'de>,
339 {
340 let mut map = HashMap::<String, Value>::with_capacity(access.size_hint().unwrap_or(0));
341
342 while let Some((key, value)) = access.next_entry()? {
345 map.insert(key, value);
346 }
347
348 Ok(Field {
349 name: remove_string(&mut map, "name")?
350 .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
351 doc: remove_string(&mut map, "doc")?,
352 schema: to_schema(&mut map, "type")?
353 .ok_or_else(|| serde::de::Error::custom("type is required in Field"))?,
354 default: to_schema(&mut map, "default")?,
355 order: to_order(&mut map, "order")?,
356 aliases: remove_vec_string(&mut map, "aliases")?,
357 })
358 }
359}
360
361impl<'de> Deserialize<'de> for Field {
362 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
363 where
364 D: Deserializer<'de>,
365 {
366 deserializer.deserialize_map(FieldVisitor {})
367 }
368}