1use alloc::string::{String, ToString};
10use alloc::vec::Vec;
11
/// A single item produced while walking an XML document.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Event {
    /// An opening tag, also produced for a self-closing tag such as `<a/>`.
    StartElement {
        /// Tag name exactly as written in the document.
        name: String,
        /// Attribute `(name, value)` pairs in document order; values have
        /// their entity references already decoded.
        attrs: Vec<(String, String)>,
    },
    /// A closing tag carrying the element name. A self-closing tag produces
    /// one of these right after its `StartElement`.
    EndElement(String),
    /// Character data with entity references decoded. Runs consisting only of
    /// whitespace are not reported.
    Text(String),
    /// The verbatim contents of a `<![CDATA[ ... ]]>` section.
    CData(String),
    /// The trimmed text found between `<?` and `?>`.
    Declaration(String),
}
31
/// Reasons the parser can reject its input.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ParseError {
    /// The input ended inside a construct, or while elements were still open.
    UnexpectedEof,
    /// A closing tag did not name the element that is currently open.
    TagMismatch {
        /// Name of the innermost open element; empty when no element was open.
        expected: String,
        /// Name found in the closing tag.
        got: String,
    },
    /// A tag was syntactically invalid; the payload describes the problem.
    MalformedTag(String),
    /// An `&...;` reference could not be resolved. The payload is the entity
    /// name or, when the terminating `;` is missing, the text after the `&`.
    UnknownEntity(String),
}
50
51impl core::fmt::Display for ParseError {
52 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53 match self {
54 Self::UnexpectedEof => f.write_str("unexpected EOF"),
55 Self::TagMismatch { expected, got } => {
56 write!(f, "tag mismatch: expected </{expected}>, got </{got}>")
57 }
58 Self::MalformedTag(s) => write!(f, "malformed tag: {s}"),
59 Self::UnknownEntity(s) => write!(f, "unknown entity: &{s};"),
60 }
61 }
62}
63
// `ParseError` is a standard error type only when the crate's `std` feature
// is enabled. NOTE(review): the `alloc` imports in this file suggest the crate
// otherwise builds without `std` — confirm against the crate root.
#[cfg(feature = "std")]
impl std::error::Error for ParseError {}
66
/// A pull parser that walks a borrowed XML string and yields [`Event`]s.
///
/// All state lives in this struct; the input itself is never copied as a
/// whole, only the pieces that end up inside events.
#[derive(Debug)]
pub struct XmlParser<'a> {
    /// The complete document being parsed.
    input: &'a str,
    /// Byte offset into `input` of the next unread character.
    pos: usize,
    /// Names of the elements that are currently open, innermost last.
    stack: Vec<String>,
    /// Once set, `next_event` returns `None` on every later call.
    finished: bool,
    /// Name of a self-closing element whose `EndElement` is still owed.
    pending_end: Option<String>,
}
77
78impl<'a> XmlParser<'a> {
79 #[must_use]
81 pub fn new(input: &'a str) -> Self {
82 Self {
83 input,
84 pos: 0,
85 stack: Vec::new(),
86 finished: false,
87 pending_end: None,
88 }
89 }
90
91 fn peek_char(&self) -> Option<char> {
92 self.input[self.pos..].chars().next()
93 }
94
95 fn advance(&mut self, n: usize) {
96 self.pos += n;
97 }
98
99 fn skip_ws(&mut self) {
100 while let Some(c) = self.peek_char() {
101 if c.is_whitespace() {
102 self.advance(c.len_utf8());
103 } else {
104 break;
105 }
106 }
107 }
108
109 fn next_event(&mut self) -> Option<Result<Event, ParseError>> {
110 if self.finished {
111 return None;
112 }
113 if let Some(name) = self.pending_end.take() {
114 return Some(Ok(Event::EndElement(name)));
115 }
116 if self.pos >= self.input.len() {
117 self.finished = true;
118 if !self.stack.is_empty() {
119 return Some(Err(ParseError::UnexpectedEof));
120 }
121 return None;
122 }
123
124 if self.peek_char() == Some('<') {
125 self.parse_tag()
126 } else {
127 self.parse_text()
128 }
129 }
130
131 fn parse_tag(&mut self) -> Option<Result<Event, ParseError>> {
132 let rest = &self.input[self.pos..];
134 if rest.starts_with("<?") {
135 return Some(self.parse_pi());
136 }
137 if rest.starts_with("<![CDATA[") {
138 return Some(self.parse_cdata());
139 }
140 if rest.starts_with("<!--") {
141 return Some(self.skip_comment());
142 }
143 if rest.starts_with("</") {
144 return Some(self.parse_end_tag());
145 }
146 Some(self.parse_start_tag())
147 }
148
149 fn parse_pi(&mut self) -> Result<Event, ParseError> {
150 let close = self.input[self.pos..]
152 .find("?>")
153 .ok_or(ParseError::UnexpectedEof)?;
154 let body = &self.input[self.pos + 2..self.pos + close];
155 self.pos += close + 2;
156 Ok(Event::Declaration(body.trim().to_string()))
157 }
158
159 fn parse_cdata(&mut self) -> Result<Event, ParseError> {
160 let start = self.pos + "<![CDATA[".len();
161 let close = self.input[start..]
162 .find("]]>")
163 .ok_or(ParseError::UnexpectedEof)?;
164 let body = &self.input[start..start + close];
165 self.pos = start + close + "]]>".len();
166 Ok(Event::CData(body.to_string()))
167 }
168
169 fn skip_comment(&mut self) -> Result<Event, ParseError> {
170 let close = self.input[self.pos..]
172 .find("-->")
173 .ok_or(ParseError::UnexpectedEof)?;
174 self.pos += close + "-->".len();
175 match self.next_event() {
176 Some(r) => r,
177 None => Err(ParseError::UnexpectedEof),
178 }
179 }
180
181 fn parse_start_tag(&mut self) -> Result<Event, ParseError> {
182 self.advance(1); let name_end = self.input[self.pos..]
185 .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
186 .ok_or(ParseError::UnexpectedEof)?;
187 let name = self.input[self.pos..self.pos + name_end].to_string();
188 if name.is_empty() {
189 return Err(ParseError::MalformedTag("empty tag name".into()));
190 }
191 self.pos += name_end;
192
193 let mut attrs = Vec::new();
194 loop {
195 self.skip_ws();
196 match self.peek_char() {
197 Some('>') => {
198 self.advance(1);
199 self.stack.push(name.clone());
200 return Ok(Event::StartElement { name, attrs });
201 }
202 Some('/') => {
203 self.advance(1);
204 if self.peek_char() != Some('>') {
205 return Err(ParseError::MalformedTag("expected > after /".into()));
206 }
207 self.advance(1);
208 self.pending_end = Some(name.clone());
210 return Ok(Event::StartElement { name, attrs });
211 }
212 Some(_) => {
213 let (n, v) = self.parse_attr()?;
214 attrs.push((n, v));
215 }
216 None => return Err(ParseError::UnexpectedEof),
217 }
218 }
219 }
220
221 fn parse_attr(&mut self) -> Result<(String, String), ParseError> {
222 let name_end = self.input[self.pos..]
223 .find('=')
224 .ok_or(ParseError::UnexpectedEof)?;
225 let name = self.input[self.pos..self.pos + name_end].trim().to_string();
226 self.pos += name_end + 1;
227 let quote = self.peek_char().ok_or(ParseError::UnexpectedEof)?;
228 if quote != '"' && quote != '\'' {
229 return Err(ParseError::MalformedTag("attribute without quotes".into()));
230 }
231 self.advance(1);
232 let close = self.input[self.pos..]
233 .find(quote)
234 .ok_or(ParseError::UnexpectedEof)?;
235 let raw = &self.input[self.pos..self.pos + close];
236 let value = decode_entities(raw)?;
237 self.pos += close + 1;
238 Ok((name, value))
239 }
240
241 fn parse_end_tag(&mut self) -> Result<Event, ParseError> {
242 self.advance(2); let name_end = self.input[self.pos..]
244 .find('>')
245 .ok_or(ParseError::UnexpectedEof)?;
246 let name = self.input[self.pos..self.pos + name_end].trim().to_string();
247 self.pos += name_end + 1;
248 let expected = self.stack.pop().ok_or_else(|| ParseError::TagMismatch {
249 expected: String::new(),
250 got: name.clone(),
251 })?;
252 if expected != name {
253 return Err(ParseError::TagMismatch {
254 expected,
255 got: name,
256 });
257 }
258 Ok(Event::EndElement(name))
259 }
260
261 fn parse_text(&mut self) -> Option<Result<Event, ParseError>> {
262 let next_lt = self.input[self.pos..]
263 .find('<')
264 .unwrap_or(self.input.len() - self.pos);
265 let raw = &self.input[self.pos..self.pos + next_lt];
266 self.pos += next_lt;
267 if raw.trim().is_empty() {
268 return self.next_event();
269 }
270 Some(decode_entities(raw).map(Event::Text))
271 }
272}
273
274impl Iterator for XmlParser<'_> {
275 type Item = Result<Event, ParseError>;
276 fn next(&mut self) -> Option<Self::Item> {
277 self.next_event()
278 }
279}
280
281fn decode_entities(s: &str) -> Result<String, ParseError> {
282 let mut out = String::with_capacity(s.len());
283 let mut chars = s.char_indices();
284 while let Some((_, c)) = chars.next() {
285 if c == '&' {
286 let rest = chars.as_str();
287 let semi = rest
288 .find(';')
289 .ok_or_else(|| ParseError::UnknownEntity(rest.to_string()))?;
290 let entity = &rest[..semi];
291 let resolved = match entity {
292 "amp" => '&',
293 "lt" => '<',
294 "gt" => '>',
295 "quot" => '"',
296 "apos" => '\'',
297 _ => return Err(ParseError::UnknownEntity(entity.to_string())),
298 };
299 out.push(resolved);
300 for _ in 0..semi + 1 {
302 chars.next();
303 }
304 } else {
305 out.push(c);
306 }
307 }
308 Ok(out)
309}
310
#[cfg(test)]
// Tests are expected to panic on failure, so the panicking helpers are fine here.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Parses `xml` to completion, panicking on the first parse error.
    fn parse_all(xml: &str) -> Vec<Event> {
        XmlParser::new(xml).collect::<Result<_, _>>().unwrap()
    }

    #[test]
    fn parses_simple_element() {
        let events = parse_all("<a>text</a>");
        assert_eq!(events.len(), 3);
        match &events[0] {
            Event::StartElement { name, .. } => assert_eq!(name, "a"),
            e => panic!("expected Start, got {e:?}"),
        }
        match &events[1] {
            Event::Text(s) => assert_eq!(s, "text"),
            e => panic!("expected Text, got {e:?}"),
        }
        match &events[2] {
            Event::EndElement(n) => assert_eq!(n, "a"),
            e => panic!("expected End, got {e:?}"),
        }
    }

    #[test]
    fn parses_attributes() {
        let events = parse_all(r#"<elem foo="bar" baz="qux"></elem>"#);
        match &events[0] {
            Event::StartElement { attrs, .. } => {
                assert_eq!(attrs.len(), 2);
                assert_eq!(attrs[0], ("foo".into(), "bar".into()));
                assert_eq!(attrs[1], ("baz".into(), "qux".into()));
            }
            e => panic!("expected Start, got {e:?}"),
        }
    }

    #[test]
    fn parses_xml_declaration() {
        let events = parse_all(r#"<?xml version="1.0"?><a/>"#);
        match &events[0] {
            Event::Declaration(s) => assert!(s.contains("version=\"1.0\"")),
            e => panic!("got {e:?}"),
        }
    }

    #[test]
    fn parses_cdata() {
        let events = parse_all("<a><![CDATA[<raw>]]></a>");
        let cdata = events.iter().find_map(|e| match e {
            Event::CData(s) => Some(s.as_str()),
            _ => None,
        });
        assert_eq!(cdata, Some("<raw>"));
    }

    #[test]
    fn skips_comments() {
        let events = parse_all("<a><!-- comment -->text</a>");
        let texts: Vec<_> = events
            .iter()
            .filter_map(|e| match e {
                Event::Text(s) => Some(s.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(texts, alloc::vec!["text"]);
    }

    #[test]
    fn decodes_entity_references() {
        // The input must hold the entity references themselves; literal
        // `&`, `<` and `>` characters are not valid element content.
        let events = parse_all("<a>&amp;&lt;&gt;&quot;&apos;</a>");
        let text = events.iter().find_map(|e| match e {
            Event::Text(s) => Some(s.as_str()),
            _ => None,
        });
        assert_eq!(text, Some("&<>\"'"));
    }

    #[test]
    fn rejects_tag_mismatch() {
        let r: Result<Vec<_>, _> = XmlParser::new("<a></b>").collect();
        assert!(matches!(r, Err(ParseError::TagMismatch { .. })));
    }

    #[test]
    fn rejects_unknown_entity() {
        let r: Result<Vec<_>, _> = XmlParser::new("<a>&xyz;</a>").collect();
        assert!(matches!(r, Err(ParseError::UnknownEntity(_))));
    }

    #[test]
    fn rejects_unclosed_element() {
        let r: Result<Vec<_>, _> = XmlParser::new("<a>").collect();
        assert_eq!(r, Err(ParseError::UnexpectedEof));
    }

    #[test]
    fn self_closing_element_emits_start_and_end() {
        let events = parse_all("<a/>");
        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], Event::StartElement { name, .. } if name == "a"));
        assert_eq!(events[1], Event::EndElement("a".into()));
    }

    #[test]
    fn nested_elements_work() {
        let events = parse_all("<a><b><c/></b></a>");
        let starts: Vec<_> = events
            .iter()
            .filter_map(|e| match e {
                Event::StartElement { name, .. } => Some(name.as_str()),
                _ => None,
            })
            .collect();
        let ends: Vec<_> = events
            .iter()
            .filter_map(|e| match e {
                Event::EndElement(name) => Some(name.as_str()),
                _ => None,
            })
            .collect();
        // Every element, including the self-closing one, opens and closes
        // exactly once, and the closes come innermost first.
        assert_eq!(starts, ["a", "b", "c"]);
        assert_eq!(ends, ["c", "b", "a"]);
    }
}