async_xml/
reader.rs

//! Asynchronous XML deserialization: a [`PeekingReader`] with one-event lookahead drives [`Visitor`] implementations.
2
3use crate::Error;
4use quick_xml::events::Event;
5use quick_xml::Decoder;
6use tokio::io::AsyncBufRead;
7use tracing::Instrument;
8
9mod impls;
10
11pub use impls::{FromStringVisitor, FromVisitor, OptionalVisitor, TryFromVisitor, XmlFromStr};
12
13/// Type alias for the underlying reader
14pub type XmlReader<R> = quick_xml::Reader<R>;
15
16/// A wrapper around a [`XmlReader`] that supports peeking XML events without consuming them
17pub struct PeekingReader<B: AsyncBufRead> {
18    reader: XmlReader<B>,
19    peeked_event: Option<Event<'static>>,
20}
21
22impl<B: AsyncBufRead + Unpin> PeekingReader<B> {
23    /// Create a new [`PeekingReader`] from a buffered reader
24    pub fn from_buf(reader: B) -> Self {
25        let mut reader = XmlReader::from_reader(reader);
26        Self::set_reader_defaults(&mut reader);
27        Self {
28            reader,
29            peeked_event: None,
30        }
31    }
32
33    /// Consume this [`PeekingReader`] and returns the underlying buffered reader
34    pub fn into_inner(self) -> B {
35        self.reader.into_inner()
36    }
37
38    fn set_reader_defaults(reader: &mut XmlReader<B>) {
39        reader.expand_empty_elements(true).trim_text(true);
40    }
41
42    /// Peek a single event without consuming it
43    pub async fn peek_event(&mut self) -> quick_xml::Result<&Event<'static>> {
44        if self.peeked_event.is_none() {
45            self.peeked_event = Some(self.next_event_internal().await?);
46        }
47        Ok(self.peeked_event.as_ref().unwrap())
48    }
49
50    /// Read an event, consuming it
51    ///
52    /// If an event has been peeked but not yet consumed, the previously peeked event will be returned.
53    pub async fn read_event(&mut self) -> quick_xml::Result<Event<'static>> {
54        if let Some(event) = self.peeked_event.take() {
55            return Ok(event);
56        }
57        self.next_event_internal().await
58    }
59
60    async fn next_event_internal(&mut self) -> quick_xml::Result<Event<'static>> {
61        let mut buf = Vec::new();
62        let event = self
63            .reader
64            .read_event_into_async(&mut buf)
65            .await?
66            .into_owned();
67        tracing::trace!("read XML event: {:?}", event);
68        Ok(event)
69    }
70
71    /// Get the underlying XML decoder
72    pub fn decoder(&self) -> Decoder {
73        self.reader.decoder()
74    }
75
76    /// Consume and discard the next element including all of its child elements
77    pub async fn skip_element(&mut self) -> Result<(), Error> {
78        let dec = self.reader.decoder();
79        let start_tag;
80        match self.peek_event().await? {
81            Event::Start(start) => {
82                // check for start element name
83                let name = start.local_name();
84                let name = dec.decode(name.as_ref())?;
85                tracing::debug!("Skipping over element <{}>", name);
86                // store name to match expected end element
87                start_tag = name.to_string();
88                // remove peeked start event
89                self.read_event().await?;
90            }
91            _ => {
92                return Err(Error::MissingStart);
93            }
94        }
95        let mut depth = 0_usize;
96
97        loop {
98            match self.peek_event().await? {
99                Event::End(end) => {
100                    let name = end.local_name();
101                    let name = dec.decode(name.as_ref())?.to_string();
102                    // remove peeked end event
103                    self.read_event().await?;
104                    // check for name
105                    if name == start_tag && depth == 0 {
106                        tracing::trace!("done skipping");
107                        return Ok(());
108                    }
109                    depth -= 1;
110                    tracing::trace!("ascending to depth {:?}", depth);
111                }
112                Event::Start(_) => {
113                    self.read_event().await?;
114                    depth += 1;
115                    tracing::trace!("descending to depth {:?}", depth);
116                }
117                _ => {
118                    self.read_event().await?;
119                }
120            }
121        }
122    }
123
124    /// Read a single element from the XML input and deserialize it into a `T`
125    pub async fn deserialize<T>(&mut self) -> Result<T, Error>
126    where
127        T: FromXml<B>,
128    {
129        let mut visitor = T::Visitor::default();
130        let dec = self.reader.decoder();
131
132        let start_tag;
133        let element_span;
134        match self.peek_event().await? {
135            Event::Start(start) => {
136                // check for start element name
137                let name = start.local_name();
138                let name = dec.decode(name.as_ref())?;
139                tracing::debug!("deserializing XML element <{}>", name);
140                if let Some(expected_name) = T::Visitor::start_name() {
141                    if name != expected_name {
142                        return Err(Error::WrongStart(expected_name.into(), name.into()));
143                    }
144                }
145                // store name to match expected end element
146                start_tag = name.to_string();
147                element_span = tracing::debug_span!("deserialize", element = start_tag);
148                let span_guard = element_span.enter();
149                visitor.visit_tag(&name)?;
150                // read attributes
151                for attr in start.attributes() {
152                    let attr = attr?;
153                    let attr_name = dec.decode(attr.key.as_ref())?;
154                    let attr_value = dec.decode(attr.value.as_ref())?;
155                    let attr_value = quick_xml::escape::unescape(&attr_value)?;
156                    tracing::trace!("visiting attribute: {:?}", attr_name);
157                    visitor.visit_attribute(&attr_name, &attr_value)?;
158                }
159                // drop here because async
160                drop(span_guard);
161                // remove peeked start event
162                self.read_event()
163                    .instrument(element_span.clone().or_current())
164                    .await?;
165            }
166            _ => {
167                return Err(Error::MissingStart);
168            }
169        }
170
171        async move {
172            loop {
173                match self.peek_event().await? {
174                    Event::End(end) => {
175                        let name = end.local_name();
176                        let name = dec.decode(name.as_ref())?.to_string();
177                        // remove peeked end event
178                        self.read_event().await?;
179                        // check for name
180                        if name != start_tag {
181                            return Err(Error::WrongEnd(start_tag, name));
182                        }
183                        tracing::trace!("finishing deserialization of XML element <{}>", name);
184                        return visitor.build();
185                    }
186                    Event::Text(text) => {
187                        let text = dec.decode(&text)?;
188                        let text = quick_xml::escape::unescape(&text)?;
189                        tracing::trace!("visiting element text");
190                        visitor.visit_text(&text)?;
191                        // remove peeked event
192                        self.read_event().await?;
193                    }
194                    Event::Start(start) => {
195                        // peeked child start element -> find name and call into sub-element
196                        let name = start.local_name();
197                        let name = dec.decode(name.as_ref())?.to_string();
198                        tracing::trace!("visiting child: {:?}", name);
199                        visitor.visit_child(&name, self).await?;
200                    }
201                    _ => {
202                        self.read_event().await?;
203                    }
204                }
205            }
206        }
207        .instrument(element_span.or_current())
208        .await
209    }
210}
211
212impl<'r> PeekingReader<&'r [u8]> {
213    /// Create a new [`PeekingReader`] reading XML event from a [`str`].
214    #[allow(clippy::should_implement_trait)]
215    pub fn from_str(str: &'r str) -> Self {
216        Self::from_buf(str.as_bytes())
217    }
218}
219
220/// Marks a type as being deserializable from XML
221pub trait FromXml<B: AsyncBufRead + Unpin> {
222    /// The visitor to use to deserialize this type
223    type Visitor: Visitor<B, Output = Self> + Default;
224}
225
226/// A trait for building up instances of types during deserialization
227///
228/// As [`XmlReader::read_event_into_async()`](quick_xml::Reader::read_event_into_async) does not return a `Send`
229/// future, this entire trait must be `?Send`.
230#[async_trait::async_trait(?Send)]
231pub trait Visitor<B: AsyncBufRead + Unpin> {
232    /// Output type this [`Visitor`] returns
233    type Output;
234
235    /// Should return the expected starting tag name, if any
236    fn start_name() -> Option<&'static str> {
237        None
238    }
239
240    /// Visit the starting tag with the given name
241    ///
242    /// This is called exactly once during deserialization and will be called before any other `visit_*` methods.
243    #[allow(unused_variables)]
244    fn visit_tag(&mut self, name: &str) -> Result<(), Error> {
245        Ok(())
246    }
247
248    /// Visit an attribute with the given name and value
249    #[allow(unused_variables)]
250    fn visit_attribute(&mut self, name: &str, value: &str) -> Result<(), Error> {
251        Err(Error::UnexpectedAttribute(name.into()))
252    }
253
254    /// Visit a child element with the given tag name
255    ///
256    /// Implementations must make sure the child element is read in some way. Most likely this will be either a
257    /// [`reader.skip_element()`](PeekingReader::skip_element) or [`reader.deserialize()`](PeekingReader::deserialize)
258    /// call.
259    #[allow(unused_variables)]
260    async fn visit_child(
261        &mut self,
262        name: &str,
263        reader: &mut PeekingReader<B>,
264    ) -> Result<(), Error> {
265        Err(Error::UnexpectedChild(name.into()))
266    }
267
268    /// Visit any plain text contained in the element
269    ///
270    /// May be called multiple times.
271    #[allow(unused_variables)]
272    fn visit_text(&mut self, text: &str) -> Result<(), Error> {
273        Err(Error::UnexpectedText)
274    }
275
276    /// Validate and build the output type
277    fn build(self) -> Result<Self::Output, Error>;
278}