1use crate::Error;
4use quick_xml::events::Event;
5use quick_xml::Decoder;
6use tokio::io::AsyncBufRead;
7use tracing::Instrument;
8
9mod impls;
10
11pub use impls::{FromStringVisitor, FromVisitor, OptionalVisitor, TryFromVisitor, XmlFromStr};
12
/// Shorthand for the `quick_xml` reader type used throughout this module.
pub type XmlReader<R> = quick_xml::Reader<R>;
15
16pub struct PeekingReader<B: AsyncBufRead> {
18 reader: XmlReader<B>,
19 peeked_event: Option<Event<'static>>,
20}
21
22impl<B: AsyncBufRead + Unpin> PeekingReader<B> {
23 pub fn from_buf(reader: B) -> Self {
25 let mut reader = XmlReader::from_reader(reader);
26 Self::set_reader_defaults(&mut reader);
27 Self {
28 reader,
29 peeked_event: None,
30 }
31 }
32
33 pub fn into_inner(self) -> B {
35 self.reader.into_inner()
36 }
37
38 fn set_reader_defaults(reader: &mut XmlReader<B>) {
39 reader.expand_empty_elements(true).trim_text(true);
40 }
41
42 pub async fn peek_event(&mut self) -> quick_xml::Result<&Event<'static>> {
44 if self.peeked_event.is_none() {
45 self.peeked_event = Some(self.next_event_internal().await?);
46 }
47 Ok(self.peeked_event.as_ref().unwrap())
48 }
49
50 pub async fn read_event(&mut self) -> quick_xml::Result<Event<'static>> {
54 if let Some(event) = self.peeked_event.take() {
55 return Ok(event);
56 }
57 self.next_event_internal().await
58 }
59
60 async fn next_event_internal(&mut self) -> quick_xml::Result<Event<'static>> {
61 let mut buf = Vec::new();
62 let event = self
63 .reader
64 .read_event_into_async(&mut buf)
65 .await?
66 .into_owned();
67 tracing::trace!("read XML event: {:?}", event);
68 Ok(event)
69 }
70
71 pub fn decoder(&self) -> Decoder {
73 self.reader.decoder()
74 }
75
76 pub async fn skip_element(&mut self) -> Result<(), Error> {
78 let dec = self.reader.decoder();
79 let start_tag;
80 match self.peek_event().await? {
81 Event::Start(start) => {
82 let name = start.local_name();
84 let name = dec.decode(name.as_ref())?;
85 tracing::debug!("Skipping over element <{}>", name);
86 start_tag = name.to_string();
88 self.read_event().await?;
90 }
91 _ => {
92 return Err(Error::MissingStart);
93 }
94 }
95 let mut depth = 0_usize;
96
97 loop {
98 match self.peek_event().await? {
99 Event::End(end) => {
100 let name = end.local_name();
101 let name = dec.decode(name.as_ref())?.to_string();
102 self.read_event().await?;
104 if name == start_tag && depth == 0 {
106 tracing::trace!("done skipping");
107 return Ok(());
108 }
109 depth -= 1;
110 tracing::trace!("ascending to depth {:?}", depth);
111 }
112 Event::Start(_) => {
113 self.read_event().await?;
114 depth += 1;
115 tracing::trace!("descending to depth {:?}", depth);
116 }
117 _ => {
118 self.read_event().await?;
119 }
120 }
121 }
122 }
123
124 pub async fn deserialize<T>(&mut self) -> Result<T, Error>
126 where
127 T: FromXml<B>,
128 {
129 let mut visitor = T::Visitor::default();
130 let dec = self.reader.decoder();
131
132 let start_tag;
133 let element_span;
134 match self.peek_event().await? {
135 Event::Start(start) => {
136 let name = start.local_name();
138 let name = dec.decode(name.as_ref())?;
139 tracing::debug!("deserializing XML element <{}>", name);
140 if let Some(expected_name) = T::Visitor::start_name() {
141 if name != expected_name {
142 return Err(Error::WrongStart(expected_name.into(), name.into()));
143 }
144 }
145 start_tag = name.to_string();
147 element_span = tracing::debug_span!("deserialize", element = start_tag);
148 let span_guard = element_span.enter();
149 visitor.visit_tag(&name)?;
150 for attr in start.attributes() {
152 let attr = attr?;
153 let attr_name = dec.decode(attr.key.as_ref())?;
154 let attr_value = dec.decode(attr.value.as_ref())?;
155 let attr_value = quick_xml::escape::unescape(&attr_value)?;
156 tracing::trace!("visiting attribute: {:?}", attr_name);
157 visitor.visit_attribute(&attr_name, &attr_value)?;
158 }
159 drop(span_guard);
161 self.read_event()
163 .instrument(element_span.clone().or_current())
164 .await?;
165 }
166 _ => {
167 return Err(Error::MissingStart);
168 }
169 }
170
171 async move {
172 loop {
173 match self.peek_event().await? {
174 Event::End(end) => {
175 let name = end.local_name();
176 let name = dec.decode(name.as_ref())?.to_string();
177 self.read_event().await?;
179 if name != start_tag {
181 return Err(Error::WrongEnd(start_tag, name));
182 }
183 tracing::trace!("finishing deserialization of XML element <{}>", name);
184 return visitor.build();
185 }
186 Event::Text(text) => {
187 let text = dec.decode(&text)?;
188 let text = quick_xml::escape::unescape(&text)?;
189 tracing::trace!("visiting element text");
190 visitor.visit_text(&text)?;
191 self.read_event().await?;
193 }
194 Event::Start(start) => {
195 let name = start.local_name();
197 let name = dec.decode(name.as_ref())?.to_string();
198 tracing::trace!("visiting child: {:?}", name);
199 visitor.visit_child(&name, self).await?;
200 }
201 _ => {
202 self.read_event().await?;
203 }
204 }
205 }
206 }
207 .instrument(element_span.or_current())
208 .await
209 }
210}
211
212impl<'r> PeekingReader<&'r [u8]> {
213 #[allow(clippy::should_implement_trait)]
215 pub fn from_str(str: &'r str) -> Self {
216 Self::from_buf(str.as_bytes())
217 }
218}
219
220pub trait FromXml<B: AsyncBufRead + Unpin> {
222 type Visitor: Visitor<B, Output = Self> + Default;
224}
225
226#[async_trait::async_trait(?Send)]
231pub trait Visitor<B: AsyncBufRead + Unpin> {
232 type Output;
234
235 fn start_name() -> Option<&'static str> {
237 None
238 }
239
240 #[allow(unused_variables)]
244 fn visit_tag(&mut self, name: &str) -> Result<(), Error> {
245 Ok(())
246 }
247
248 #[allow(unused_variables)]
250 fn visit_attribute(&mut self, name: &str, value: &str) -> Result<(), Error> {
251 Err(Error::UnexpectedAttribute(name.into()))
252 }
253
254 #[allow(unused_variables)]
260 async fn visit_child(
261 &mut self,
262 name: &str,
263 reader: &mut PeekingReader<B>,
264 ) -> Result<(), Error> {
265 Err(Error::UnexpectedChild(name.into()))
266 }
267
268 #[allow(unused_variables)]
272 fn visit_text(&mut self, text: &str) -> Result<(), Error> {
273 Err(Error::UnexpectedText)
274 }
275
276 fn build(self) -> Result<Self::Output, Error>;
278}