avro_schema/schema/
de.rs

1use std::{collections::HashMap, fmt};
2
3use serde::{
4    de::{MapAccess, SeqAccess, Visitor},
5    Deserialize, Deserializer,
6};
7use serde_json::Value;
8
9use super::*;
10
11fn to_primitive(v: &str) -> Option<Schema> {
12    use Schema::*;
13    Some(match v {
14        "null" => Null,
15        "boolean" => Boolean,
16        "bytes" => Bytes(None),
17        "string" => String(None),
18        "int" => Int(None),
19        "long" => Long(None),
20        "float" => Float,
21        "double" => Double,
22        _ => return None,
23    })
24}
25
26fn get_type<E: serde::de::Error>(map: &mut HashMap<String, Value>) -> Result<String, E> {
27    if let Some(v) = map.remove("type") {
28        if let Value::String(v) = v {
29            Ok(v)
30        } else if let Value::Null = v {
31            Ok("null".to_string())
32        } else {
33            Err(serde::de::Error::custom("type must be a string"))
34        }
35    } else {
36        Err(serde::de::Error::missing_field("type"))
37    }
38}
39
40fn as_string<E: serde::de::Error>(v: Value, helper: &str) -> Result<String, E> {
41    if let Value::String(v) = v {
42        Ok(v)
43    } else {
44        Err(serde::de::Error::custom(format!(
45            "{} must be a string",
46            helper
47        )))
48    }
49}
50
51fn remove_string<E: serde::de::Error>(
52    data: &mut HashMap<String, Value>,
53    key: &str,
54) -> Result<Option<String>, E> {
55    match data.remove(key) {
56        Some(s) => as_string(s, key).map(Some),
57        None => Ok(None),
58    }
59}
60
61fn remove_usize<E: serde::de::Error>(
62    data: &mut HashMap<String, Value>,
63    key: &str,
64) -> Result<Option<usize>, E> {
65    data.remove(key)
66        .map(|x| serde_json::from_value::<usize>(x).map_err(serde::de::Error::custom))
67        .transpose()
68}
69
70fn remove_vec_string<E: serde::de::Error>(
71    data: &mut HashMap<String, Value>,
72    key: &str,
73) -> Result<Vec<String>, E> {
74    match data.remove(key) {
75        Some(s) => {
76            if let Value::Array(x) = s {
77                x.into_iter().map(|x| as_string(x, key)).collect()
78            } else {
79                Err(serde::de::Error::custom(format!(
80                    "{} must be a string",
81                    key
82                )))
83            }
84        }
85        None => Ok(vec![]),
86    }
87}
88
89fn to_enum<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
90    Ok(Schema::Enum(Enum {
91        name: remove_string(data, "name")?
92            .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
93        namespace: remove_string(data, "namespace")?,
94        aliases: remove_vec_string(data, "aliases")?,
95        doc: remove_string(data, "doc")?,
96        symbols: remove_vec_string(data, "symbols")?,
97        default: remove_string(data, "default")?,
98    }))
99}
100
101fn to_map<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
102    let item = data
103        .remove("values")
104        .ok_or_else(|| serde::de::Error::custom("values is required in a map"))?;
105    let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?;
106    Ok(Schema::Map(Box::new(schema)))
107}
108
109fn to_schema<E: serde::de::Error>(
110    data: &mut HashMap<String, Value>,
111    key: &str,
112) -> Result<Option<Schema>, E> {
113    let schema = data.remove(key);
114    schema
115        .map(|schema| serde_json::from_value(schema).map_err(serde::de::Error::custom))
116        .transpose()
117}
118
119fn to_array<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
120    let schema =
121        to_schema(data, "items")?.ok_or_else(|| E::custom("items is required in an array"))?;
122    Ok(Schema::Array(Box::new(schema)))
123}
124
125fn to_field<E: serde::de::Error>(data: Value) -> Result<Field, E> {
126    serde_json::from_value(data).map_err(E::custom)
127}
128
129fn to_vec_fields<E: serde::de::Error>(
130    data: &mut HashMap<String, Value>,
131    key: &str,
132) -> Result<Vec<Field>, E> {
133    match data.remove(key) {
134        Some(s) => {
135            if let Value::Array(x) = s {
136                x.into_iter().map(to_field).collect()
137            } else {
138                Err(E::custom(format!("{} must be a string", key)))
139            }
140        }
141        None => Ok(vec![]),
142    }
143}
144
145fn to_record<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
146    Ok(Schema::Record(Record {
147        name: remove_string(data, "name")?
148            .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
149        namespace: remove_string(data, "namespace")?,
150        aliases: remove_vec_string(data, "aliases")?,
151        doc: remove_string(data, "doc")?,
152        fields: to_vec_fields(data, "fields")?,
153    }))
154}
155
156fn to_fixed<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
157    let size = remove_usize(data, "size")?
158        .ok_or_else(|| serde::de::Error::custom("size is required in fixed"))?;
159
160    let logical = remove_string(data, "logicalType")?.unwrap_or_default();
161    let logical = match logical.as_ref() {
162        "decimal" => {
163            let precision = remove_usize(data, "precision")?;
164            let scale = remove_usize(data, "scale")?.unwrap_or_default();
165            precision.map(|p| FixedLogical::Decimal(p, scale))
166        }
167        "duration" => Some(FixedLogical::Duration),
168        _ => None,
169    };
170
171    Ok(Schema::Fixed(Fixed {
172        name: remove_string(data, "name")?
173            .ok_or_else(|| serde::de::Error::custom("name is required in fixed"))?,
174        namespace: remove_string(data, "namespace")?,
175        aliases: remove_vec_string(data, "aliases")?,
176        doc: remove_string(data, "doc")?,
177        size,
178        logical,
179    }))
180}
181
182fn to_order<E: serde::de::Error>(
183    data: &mut HashMap<String, Value>,
184    key: &str,
185) -> Result<Option<Order>, E> {
186    remove_string(data, key)?
187        .map(|x| {
188            Ok(match x.as_ref() {
189                "ascending" => Order::Ascending,
190                "descending" => Order::Descending,
191                "ignore" => Order::Ignore,
192                _ => {
193                    return Err(serde::de::Error::custom(
194                        "order can only be one of {ascending, descending, ignore}",
195                    ))
196                }
197            })
198        })
199        .transpose()
200}
201
202struct SchemaVisitor {}
203
204impl<'de> Visitor<'de> for SchemaVisitor {
205    type Value = Schema;
206
207    // Format a message stating what data this Visitor expects to receive.
208    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
209        formatter.write_str("a null, string, array or map describing an Avro schema")
210    }
211
212    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
213    where
214        D: Deserializer<'de>,
215    {
216        deserializer.deserialize_any(SchemaVisitor {})
217    }
218
219    fn visit_none<E>(self) -> Result<Self::Value, E>
220    where
221        E: serde::de::Error,
222    {
223        Ok(Schema::Null)
224    }
225
226    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
227    where
228        E: serde::de::Error,
229    {
230        to_primitive(v)
231            .ok_or_else(|| serde::de::Error::custom("string must be a valid primitive Schema"))
232    }
233
234    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
235    where
236        A: SeqAccess<'de>,
237    {
238        let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(0));
239        while let Some(item) = seq.next_element::<Value>()? {
240            let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?;
241            vec.push(schema)
242        }
243        Ok(Schema::Union(vec))
244    }
245
246    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
247    where
248        M: MapAccess<'de>,
249    {
250        let mut map = HashMap::<String, Value>::with_capacity(access.size_hint().unwrap_or(0));
251
252        // While there are entries remaining in the input, add them
253        // into our map.
254        while let Some((key, value)) = access.next_entry()? {
255            map.insert(key, value);
256        }
257
258        let (schema, type_) = get_type(&mut map).map(|x| (to_primitive(&x), x))?;
259
260        if let Some(schema) = schema {
261            Ok(match type_.as_ref() {
262                "string" => {
263                    let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
264                    match logical.as_ref() {
265                        "uuid" => Schema::String(Some(StringLogical::Uuid)),
266                        _ => schema,
267                    }
268                }
269                "int" => {
270                    let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
271                    match logical.as_ref() {
272                        "date" => Schema::Int(Some(IntLogical::Date)),
273                        "time-millis" => Schema::Int(Some(IntLogical::Time)),
274                        _ => schema,
275                    }
276                }
277                "long" => {
278                    let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
279                    match logical.as_ref() {
280                        "time-micros" => Schema::Long(Some(LongLogical::Time)),
281                        "timestamp-millis" => Schema::Long(Some(LongLogical::TimestampMillis)),
282                        "timestamp-micros" => Schema::Long(Some(LongLogical::TimestampMicros)),
283                        "local-timestamp-millis" => {
284                            Schema::Long(Some(LongLogical::LocalTimestampMillis))
285                        }
286                        "local-timestamp-micros" => {
287                            Schema::Long(Some(LongLogical::LocalTimestampMicros))
288                        }
289                        _ => schema,
290                    }
291                }
292                "bytes" => {
293                    let logical = remove_string(&mut map, "logicalType")?.unwrap_or_default();
294                    match logical.as_ref() {
295                        "decimal" => {
296                            let precision = remove_usize(&mut map, "precision")?;
297                            let scale = remove_usize(&mut map, "scale")?.unwrap_or_default();
298                            Schema::Bytes(precision.map(|p| BytesLogical::Decimal(p, scale)))
299                        }
300                        _ => schema,
301                    }
302                }
303                _ => schema,
304            })
305        } else {
306            match type_.as_ref() {
307                "enum" => to_enum(&mut map),
308                "map" => to_map(&mut map),
309                "array" => to_array(&mut map),
310                "record" => to_record(&mut map),
311                "fixed" => to_fixed(&mut map),
312                other => todo!("{}", other),
313            }
314        }
315    }
316}
317
318impl<'de> Deserialize<'de> for Schema {
319    fn deserialize<D>(deserializer: D) -> Result<Schema, D::Error>
320    where
321        D: Deserializer<'de>,
322    {
323        deserializer.deserialize_option(SchemaVisitor {})
324    }
325}
326
327struct FieldVisitor {}
328
329impl<'de> Visitor<'de> for FieldVisitor {
330    type Value = Field;
331
332    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
333        formatter.write_str("a map describing an Avro field")
334    }
335
336    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
337    where
338        M: MapAccess<'de>,
339    {
340        let mut map = HashMap::<String, Value>::with_capacity(access.size_hint().unwrap_or(0));
341
342        // While there are entries remaining in the input, add them
343        // into our map.
344        while let Some((key, value)) = access.next_entry()? {
345            map.insert(key, value);
346        }
347
348        Ok(Field {
349            name: remove_string(&mut map, "name")?
350                .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
351            doc: remove_string(&mut map, "doc")?,
352            schema: to_schema(&mut map, "type")?
353                .ok_or_else(|| serde::de::Error::custom("type is required in Field"))?,
354            default: to_schema(&mut map, "default")?,
355            order: to_order(&mut map, "order")?,
356            aliases: remove_vec_string(&mut map, "aliases")?,
357        })
358    }
359}
360
361impl<'de> Deserialize<'de> for Field {
362    fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
363    where
364        D: Deserializer<'de>,
365    {
366        deserializer.deserialize_map(FieldVisitor {})
367    }
368}