fhirbolt_serde/element/internal/de.rs

1use std::{iter, mem, vec};
2
3use serde::{
4    de::{self, DeserializeSeed, Error, MapAccess, SeqAccess, Unexpected, Visitor},
5    forward_to_deserialize_any,
6};
7
8use fhirbolt_element::{Element, Primitive, Value};
9use fhirbolt_shared::{path::ElementPath, FhirRelease};
10
11use crate::{
12    context::{
13        de::{CurrentElement, DeserializationContext},
14        Format,
15    },
16    element::{self, Deserializer},
17    DeserializationMode,
18};
19
/// Map key under which `serde_json` hands a number to a visitor as its textual
/// representation (presumably only with serde_json's `arbitrary_precision`
/// feature enabled — verify). Such numbers are kept as decimal strings.
const SERDE_JSON_NUMBER_TOKEN: &str = "$serde_json::private::Number";
/// The only child fields an element of a FHIR primitive type may carry; checked
/// in strict deserialization mode.
const PRIMITIVE_CHILDREN: &[&str] = &["id", "extension", "value"];
22
23#[derive(Default, Debug)]
24pub struct InternalElement<const R: FhirRelease>(pub Element<R>);
25
26impl<const R: FhirRelease> InternalElement<R> {
27    pub fn into_element<'a, D>(
28        self,
29        deserialization_mode: DeserializationMode,
30        current_path: &mut ElementPath,
31    ) -> Result<Element<R>, D::Error>
32    where
33        D: de::Deserializer<'a>,
34    {
35        let mut element = self.0;
36
37        tri!(resolve_element_types::<D, R>(
38            &mut element,
39            deserialization_mode,
40            current_path
41        ));
42
43        Ok(element)
44    }
45}
46
47pub fn resolve_element_types<'a, D, const R: FhirRelease>(
48    element: &mut Element<R>,
49    deserialization_mode: DeserializationMode,
50    current_path: &mut ElementPath,
51) -> Result<(), D::Error>
52where
53    D: de::Deserializer<'a>,
54{
55    let mut is_resource = false;
56
57    for (key, value) in element.iter_mut() {
58        // in case of top-level resource: current_path is empty
59        if (current_path.current_element_is_resource() || current_path.is_empty())
60            && key == "resourceType"
61        {
62            if let Value::Primitive(Primitive::String(s)) = value {
63                current_path.push(s);
64                is_resource = true;
65            }
66        } else {
67            // check if field is valid at current path
68            if deserialization_mode == DeserializationMode::Strict {
69                tri!(validate_field_is_valid::<D>(current_path, key));
70            }
71
72            current_path.push(key);
73
74            tri!(resolve_value_types::<D, R>(
75                value,
76                deserialization_mode,
77                current_path
78            ));
79
80            if current_path.current_element_is_sequence() {
81                if let Value::Element(e) = value {
82                    *value = Value::Sequence(vec![mem::take(e)]);
83                }
84            }
85
86            current_path.pop();
87        }
88    }
89
90    if is_resource {
91        current_path.pop();
92    }
93
94    Ok(())
95}
96
97fn validate_field_is_valid<'a, D>(current_path: &ElementPath, field: &str) -> Result<(), D::Error>
98where
99    D: de::Deserializer<'a>,
100{
101    if current_path.current_element_is_primitive() {
102        if !PRIMITIVE_CHILDREN.contains(&field) {
103            return Err(Error::custom(format_args!(
104                "unknown field `{}`, expected one of primitive children {:?}",
105                field, PRIMITIVE_CHILDREN
106            )));
107        }
108    } else {
109        let fields = current_path.children();
110
111        if !fields.map(|s| s.contains(field)).unwrap_or(false) {
112            if let Some(expected_fields) = fields {
113                return Err(Error::custom(format_args!(
114                    "unknown field `{}`, expected one of {:?}",
115                    field,
116                    &expected_fields.iter().collect::<Vec<_>>()
117                )));
118            } else {
119                return Err(Error::custom(format_args!(
120                    "unknown field `{}`, there are no fields",
121                    field
122                )));
123            }
124        }
125    }
126
127    Ok(())
128}
129
130fn resolve_value_types<'a, D, const R: FhirRelease>(
131    value: &mut Value<R>,
132    deserialization_mode: DeserializationMode,
133    current_path: &mut ElementPath,
134) -> Result<(), D::Error>
135where
136    D: de::Deserializer<'a>,
137{
138    match value {
139        Value::Element(e) => {
140            tri!(resolve_element_types::<D, R>(
141                e,
142                deserialization_mode,
143                current_path
144            ));
145        }
146        Value::Sequence(s) => {
147            for element in s {
148                tri!(resolve_element_types::<D, R>(
149                    element,
150                    deserialization_mode,
151                    current_path
152                ));
153            }
154        }
155        Value::Primitive(p) => {
156            if current_path.parent_element_is_boolean() {
157                *p = tri!(map_bool::<D>(p))
158            } else if current_path.parent_element_is_integer()
159                || current_path.parent_element_is_positive_integer()
160                || current_path.parent_element_is_unsigned_integer()
161            {
162                *p = tri!(map_integer::<D>(p))
163            } else if current_path.parent_element_is_integer64() {
164                *p = tri!(map_integer64::<D>(p))
165            } else if current_path.parent_element_is_decimal() {
166                *p = tri!(map_decimal::<D>(p))
167            } else {
168                *p = tri!(map_string::<D>(p))
169            }
170        }
171    }
172
173    Ok(())
174}
175
176fn map_bool<'a, D>(primitive: &Primitive) -> Result<Primitive, D::Error>
177where
178    D: de::Deserializer<'a>,
179{
180    let expected = "a boolean";
181    match primitive {
182        Primitive::Bool(b) => Ok(Primitive::Bool(*b)),
183        Primitive::Integer(i) => Err(Error::invalid_type(
184            Unexpected::Signed((*i).into()),
185            &expected,
186        )),
187        Primitive::Integer64(i) => Err(Error::invalid_type(
188            Unexpected::Other(&format!("integer `{}`", i)),
189            &expected,
190        )),
191        Primitive::Decimal(s) => {
192            return Err(Error::invalid_type(
193                Unexpected::Other(&format!("decimal `{}`", s)),
194                &expected,
195            ))
196        }
197        Primitive::String(s) => Ok(Primitive::Bool(tri!(s
198            .parse()
199            .map_err(|_| { Error::invalid_value(Unexpected::Other(s), &expected) })))),
200    }
201}
202
203fn map_integer<'a, D>(primitive: &Primitive) -> Result<Primitive, D::Error>
204where
205    D: de::Deserializer<'a>,
206{
207    let expected = "an integer";
208    match primitive {
209        Primitive::Bool(b) => Err(Error::invalid_type(Unexpected::Bool(*b), &expected)),
210        Primitive::Integer(i) => Ok(Primitive::Integer(*i)),
211        Primitive::Integer64(i) => Ok(Primitive::Integer(*i as i32)),
212        Primitive::Decimal(s) => {
213            return Err(Error::invalid_type(
214                Unexpected::Other(&format!("decimal `{}`", s)),
215                &expected,
216            ))
217        }
218        Primitive::String(s) => Ok(Primitive::Integer(tri!(s
219            .parse()
220            .map_err(|_| { Error::invalid_value(Unexpected::Other(s), &expected) })))),
221    }
222}
223
224fn map_integer64<'a, D>(primitive: &Primitive) -> Result<Primitive, D::Error>
225where
226    D: de::Deserializer<'a>,
227{
228    let expected = "an integer64";
229    match primitive {
230        Primitive::Bool(b) => Err(Error::invalid_type(Unexpected::Bool(*b), &expected)),
231        Primitive::Integer(i) => Ok(Primitive::Integer64((*i).into())),
232        Primitive::Integer64(i) => Ok(Primitive::Integer64(*i)),
233        Primitive::Decimal(s) => {
234            return Err(Error::invalid_type(
235                Unexpected::Other(&format!("decimal `{}`", s)),
236                &expected,
237            ))
238        }
239        Primitive::String(s) => Ok(Primitive::Integer64(tri!(s
240            .parse()
241            .map_err(|_| { Error::invalid_value(Unexpected::Other(s), &expected) })))),
242    }
243}
244
245fn map_decimal<'a, D>(primitive: &mut Primitive) -> Result<Primitive, D::Error>
246where
247    D: de::Deserializer<'a>,
248{
249    let expected = "a decimal";
250    match primitive {
251        Primitive::Bool(b) => Err(Error::invalid_type(Unexpected::Bool(*b), &expected)),
252        Primitive::Integer(i) => Ok(Primitive::Decimal(i.to_string())),
253        Primitive::Integer64(i) => Ok(Primitive::Decimal(i.to_string())),
254        Primitive::Decimal(s) | Primitive::String(s) => Ok(Primitive::Decimal(mem::take(s))),
255    }
256}
257
258fn map_string<'a, D>(primitive: &mut Primitive) -> Result<Primitive, D::Error>
259where
260    D: de::Deserializer<'a>,
261{
262    let expected = "a string";
263    match primitive {
264        Primitive::Bool(b) => Err(Error::invalid_type(Unexpected::Bool(*b), &expected)),
265        Primitive::Integer(i) => Err(Error::invalid_type(
266            Unexpected::Signed((*i).into()),
267            &expected,
268        )),
269        Primitive::Integer64(i) => Err(Error::invalid_type(
270            Unexpected::Other(&format!("integer `{}`", i)),
271            &expected,
272        )),
273        Primitive::Decimal(s) => {
274            return Err(Error::invalid_type(
275                Unexpected::Other(&format!("decimal `{}`", s)),
276                &expected,
277            ))
278        }
279        Primitive::String(s) => Ok(Primitive::String(mem::take(s))),
280    }
281}
282
283impl<'de, const R: FhirRelease> DeserializeSeed<'de>
284    for &mut DeserializationContext<InternalElement<R>>
285{
286    type Value = InternalElement<R>;
287
288    #[inline]
289    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
290    where
291        D: de::Deserializer<'de>,
292    {
293        match tri!(deserializer.deserialize_any(ValueVisitor(self.transmute()))) {
294            Value::Element(e) => Ok(InternalElement(e)),
295            Value::Sequence(_) => Err(Error::invalid_type(Unexpected::Seq, &"an element")),
296            Value::Primitive(_) => Err(Error::invalid_type(
297                Unexpected::Other("primitive"),
298                &"an element",
299            )),
300        }
301    }
302}
303
304impl<'de, const R: FhirRelease> DeserializeSeed<'de> for &mut DeserializationContext<Value<R>> {
305    type Value = Value<R>;
306
307    #[inline]
308    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
309    where
310        D: de::Deserializer<'de>,
311    {
312        deserializer.deserialize_any(ValueVisitor(self))
313    }
314}
315
316fn merge_sequences<const R: FhirRelease>(
317    left: Vec<Element<R>>,
318    right: Vec<Element<R>>,
319) -> Vec<Element<R>> {
320    let left_iter = left.into_iter().map(Some).chain(iter::repeat(None));
321    let right_iter = right.into_iter().map(Some).chain(iter::repeat(None));
322
323    left_iter
324        .zip(right_iter)
325        .take_while(|(e, n)| e.is_some() || n.is_some())
326        .flat_map(|(e, n)| match (e, n) {
327            (Some(mut e), Some(n)) => {
328                e.extend(n);
329                Some(e)
330            }
331            (Some(e), None) => Some(e),
332            (None, Some(n)) => Some(n),
333            _ => None,
334        })
335        .collect()
336}
337
338struct ValueVisitor<'a, const R: FhirRelease>(&'a mut DeserializationContext<Value<R>>);
339
340impl<'a, 'de, const R: FhirRelease> Visitor<'de> for ValueVisitor<'a, R> {
341    type Value = Value<R>;
342
343    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
344        formatter.write_str("a map")
345    }
346
347    #[inline]
348    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
349    where
350        E: Error,
351    {
352        if self.0.from == Format::Json {
353            Ok(Value::Element(Element! {
354                "value" => Value::Primitive(Primitive::Bool(v)),
355            }))
356        } else {
357            Ok(Value::Primitive(Primitive::Bool(v)))
358        }
359    }
360
361    #[inline]
362    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
363    where
364        E: Error,
365    {
366        if self.0.from == Format::Json {
367            Ok(Value::Element(Element! {
368                "value" =>  Value::Primitive(Primitive::Integer64(v as i64)),
369            }))
370        } else {
371            Ok(Value::Primitive(Primitive::Integer64(v as i64)))
372        }
373    }
374
375    #[inline]
376    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
377    where
378        E: Error,
379    {
380        if self.0.from == Format::Json {
381            Ok(Value::Element(Element! {
382                "value" =>  Value::Primitive(Primitive::Integer64(v)),
383            }))
384        } else {
385            Ok(Value::Primitive(Primitive::Integer64(v)))
386        }
387    }
388
389    #[inline]
390    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
391    where
392        E: Error,
393    {
394        if self.0.from == Format::Json {
395            let number =
396                tri!(serde_json::Number::from_f64(v)
397                    .ok_or_else(|| Error::custom("not a JSON number")));
398
399            Ok(Value::Element(Element! {
400                "value" =>  Value::Primitive(Primitive::Decimal(number.to_string())),
401            }))
402        } else {
403            Ok(Value::Primitive(Primitive::Decimal(v.to_string())))
404        }
405    }
406
407    #[inline]
408    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
409    where
410        E: Error,
411    {
412        if self.0.from == Format::Json {
413            if self.0.current_element() == CurrentElement::Id
414                || self.0.current_element() == CurrentElement::ExtensionUrl
415            {
416                Ok(Value::Primitive(Primitive::String(v)))
417            } else {
418                Ok(Value::Element(Element! {
419                    "value" =>  Value::Primitive(Primitive::String(v)),
420                }))
421            }
422        } else {
423            Ok(Value::Primitive(Primitive::String(v)))
424        }
425    }
426
427    #[inline]
428    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
429    where
430        E: Error,
431    {
432        self.visit_string(v.to_string())
433    }
434
435    fn visit_map<V>(self, mut map_access: V) -> Result<Self::Value, V::Error>
436    where
437        V: MapAccess<'de>,
438    {
439        let mut element = Element::default();
440
441        while let Some(key) = tri!(map_access.next_key::<String>()) {
442            if key == SERDE_JSON_NUMBER_TOKEN {
443                return Ok(Value::Element(Element! {
444                    "value" =>  Value::Primitive(Primitive::Decimal(tri!(map_access.next_value()))),
445                }));
446            }
447
448            let key = if let Some(stripped) = key.strip_prefix('_') {
449                stripped.into()
450            } else {
451                key
452            };
453
454            if key == "resourceType"
455                && self.0.current_element() != CurrentElement::ExampleScenarioInstance
456                && self.0.current_element() != CurrentElement::ConsentProvision
457                && self.0.current_element() != CurrentElement::SubscriptionFilterBy
458            {
459                let value: String = tri!(map_access.next_value());
460
461                element.insert(key, Value::Primitive(Primitive::String(value)));
462            } else {
463                self.0.push_current_element(match key.as_str() {
464                    "id" => CurrentElement::Id,
465                    "instance" => CurrentElement::ExampleScenarioInstance,
466                    "provision" => CurrentElement::ConsentProvision,
467                    "filterBy" => CurrentElement::SubscriptionFilterBy,
468                    "extension" | "modifierExtension" => CurrentElement::Extension,
469                    "url" => {
470                        if self.0.current_element() == CurrentElement::Extension {
471                            CurrentElement::ExtensionUrl
472                        } else {
473                            CurrentElement::Other
474                        }
475                    }
476                    _ => CurrentElement::Other,
477                });
478
479                let value = tri!(map_access.next_value_seed(self.0.transmute::<Value<R>>()));
480                let existing = element.remove(&key);
481
482                let matched_value = match (existing, value) {
483                    (Some(Value::Element(mut e)), Value::Element(n)) => {
484                        if self.0.from == Format::Json {
485                            e.extend(n);
486                            Value::Element(e)
487                        } else {
488                            Value::Sequence(vec![e, n])
489                        }
490                    }
491                    (Some(Value::Sequence(ev)), Value::Sequence(nv)) => {
492                        Value::Sequence(merge_sequences(ev, nv))
493                    }
494                    (Some(Value::Sequence(mut es)), Value::Element(n)) => {
495                        es.push(n);
496                        Value::Sequence(es)
497                    }
498                    (_e, v) => v,
499                };
500
501                element.insert(key, matched_value);
502
503                self.0.pop_current_element();
504            }
505        }
506
507        fn embed_string_in_element<const R: FhirRelease>(value: &mut Value<R>) {
508            if let Value::Primitive(Primitive::String(s)) = value {
509                *value = Value::Element(Element! {
510                    "value" =>  Value::Primitive(Primitive::String(mem::take(s))),
511                });
512            }
513        }
514
515        if element.contains_key("resourceType") {
516            if let Some(id) = element.get_mut("id") {
517                embed_string_in_element(id)
518            }
519        }
520
521        Ok(Value::Element(element))
522    }
523
524    fn visit_seq<V>(self, mut seq_access: V) -> Result<Self::Value, V::Error>
525    where
526        V: SeqAccess<'de>,
527    {
528        let mut elements = Vec::new();
529
530        while let Some(value) =
531            tri!(seq_access.next_element_seed(self.0.transmute::<Option<Value<R>>>()))
532        {
533            match value {
534                Some(Value::Element(e)) => elements.push(e),
535                Some(Value::Sequence(_)) => {
536                    return Err(Error::invalid_type(Unexpected::Seq, &"a sequence element"))
537                }
538                Some(Value::Primitive(_)) => {
539                    return Err(Error::invalid_type(Unexpected::Seq, &"a sequence element"))
540                }
541                None => elements.push(Default::default()),
542            }
543        }
544
545        Ok(Value::Sequence(elements))
546    }
547}
548
549impl<'de, const R: FhirRelease> DeserializeSeed<'de>
550    for &mut DeserializationContext<Option<Value<R>>>
551{
552    type Value = Option<Value<R>>;
553
554    #[inline]
555    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
556    where
557        D: de::Deserializer<'de>,
558    {
559        deserializer.deserialize_option(SeqElementVisitor(self))
560    }
561}
562
563struct SeqElementVisitor<'a, const R: FhirRelease>(
564    &'a mut DeserializationContext<Option<Value<R>>>,
565);
566
567impl<'a, 'de, const R: FhirRelease> Visitor<'de> for SeqElementVisitor<'a, R> {
568    type Value = Option<Value<R>>;
569
570    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
571        formatter.write_str("a sequence element")
572    }
573
574    #[inline]
575    fn visit_map<V>(self, map_access: V) -> Result<Self::Value, V::Error>
576    where
577        V: MapAccess<'de>,
578    {
579        ValueVisitor(self.0.transmute())
580            .visit_map(map_access)
581            .map(Some)
582    }
583
584    #[inline]
585    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
586    where
587        D: de::Deserializer<'de>,
588    {
589        self.0
590            .transmute::<Value<R>>()
591            .deserialize(deserializer)
592            .map(Some)
593    }
594
595    #[inline]
596    fn visit_none<E>(self) -> Result<Self::Value, E>
597    where
598        E: Error,
599    {
600        Ok(None)
601    }
602}
603
604impl<'de, const R: FhirRelease> de::Deserializer<'de> for Deserializer<InternalElement<R>> {
605    type Error = element::error::Error;
606
607    #[inline]
608    fn deserialize_any<V>(self, visitor: V) -> element::error::Result<V::Value>
609    where
610        V: Visitor<'de>,
611    {
612        Deserializer(self.0 .0).deserialize_any(visitor)
613    }
614
615    forward_to_deserialize_any! {
616        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
617        bytes byte_buf option unit unit_struct newtype_struct tuple
618        tuple_struct map struct enum seq identifier ignored_any
619    }
620}