1use pyo3::prelude::*;
2use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyIterator, PyList, PyNone, PyString};
3use quick_xml::Writer;
4use quick_xml::events::{BytesCData, BytesDecl, BytesEnd, BytesPI, BytesStart, BytesText, Event};
5use quick_xml::reader::Reader;
6use rustc_hash::{FxHashMap, FxHashSet};
7use std::borrow::Cow;
8use std::fs::File;
9use std::io::{BufWriter, Cursor, Write};
10use std::str;
11
12struct Config {
13 attr_prefix: String,
14 cdata_key: String,
15 default_func: Option<Py<PyAny>>,
16 item_name: String,
17 sort_attrs: bool,
18 namespaces: FxHashMap<String, String>,
19}
20
/// How an element whose value is Python `None` is rendered.
///
/// * `Native`  -> a self-closing tag, `<tag/>`.
/// * `Obj2Xml` -> an explicit start/end pair, `<tag></tag>`
///   (selected by `compat="legacy"` in `unparse`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompatMode {
    Native,
    Obj2Xml,
}
26
27#[inline]
28fn extract_str<'py>(value: &'py pyo3::Bound<'py, pyo3::PyAny>) -> PyResult<Cow<'py, str>> {
29 if let Ok(pystr) = value.cast::<PyString>() {
30 let slice = pystr.to_str()?;
31 Ok(Cow::Borrowed(slice))
32 } else {
33 Ok(Cow::Owned(value.to_string()))
34 }
35}
36
37struct PyWriter<'py> {
38 obj: Bound<'py, PyAny>,
39}
40
41impl<'py> Write for PyWriter<'py> {
42 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
43 self.obj
44 .call_method1("write", (buf,))
45 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
46 Ok(buf.len())
47 }
48 fn flush(&mut self) -> std::io::Result<()> {
49 if self.obj.hasattr("flush").unwrap_or(false) {
50 self.obj
51 .call_method0("flush")
52 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
53 }
54 Ok(())
55 }
56}
57
58#[derive(Clone)]
59struct NamespaceContext {
60 uri_to_prefix: FxHashMap<String, String>,
61 prefix_to_uri: FxHashMap<String, String>,
62 next_auto: usize,
63}
64
65impl NamespaceContext {
66 fn new(predefined: &FxHashMap<String, String>) -> Self {
67 let mut uri_to_prefix = FxHashMap::default();
68 let mut prefix_to_uri = FxHashMap::default();
69 for (prefix, uri) in predefined {
70 uri_to_prefix.insert(uri.clone(), prefix.clone());
71 prefix_to_uri.insert(prefix.clone(), uri.clone());
72 }
73 Self {
74 uri_to_prefix,
75 prefix_to_uri,
76 next_auto: 0,
77 }
78 }
79
80 fn get_or_create_prefix(&mut self, uri: &str) -> String {
81 if let Some(p) = self.uri_to_prefix.get(uri) {
82 return p.clone();
83 }
84 let prefix = format!("ns{}", self.next_auto);
85 self.next_auto += 1;
86 self.uri_to_prefix.insert(uri.to_string(), prefix.clone());
87 self.prefix_to_uri.insert(prefix.clone(), uri.to_string());
88 prefix
89 }
90}
91
/// Render a serialization key path for error messages.
///
/// An empty path renders as `"root"`; otherwise the segments are joined with
/// `/`, e.g. `["a", "[0]", "b"]` becomes `"a/[0]/b"`.
fn format_path(stack: &[String]) -> String {
    match stack {
        [] => String::from("root"),
        segments => segments.join("/"),
    }
}
99
100#[inline(always)]
101fn wrap_err<T, E: std::fmt::Display>(res: Result<T, E>, stack: &[String]) -> PyResult<T> {
102 res.map_err(|e| {
103 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
104 "{} (at {})",
105 e,
106 format_path(stack)
107 ))
108 })
109}
110
111#[inline(always)]
112fn wrap_py_err<T>(res: PyResult<T>, stack: &[String]) -> PyResult<T> {
113 res.map_err(|e| {
114 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
115 "{} (at {})",
116 e,
117 format_path(stack)
118 ))
119 })
120}
121
122#[inline]
123fn py_value_to_cow<'py>(
124 py: Python<'py>,
125 value: &'py pyo3::Bound<'py, pyo3::PyAny>,
126 default_func: &Option<Py<PyAny>>,
127 path: &[String],
128) -> PyResult<Option<Cow<'py, str>>> {
129 if value.is_instance_of::<PyNone>() {
130 return Ok(None);
131 }
132 if value.is_instance_of::<PyBool>() {
133 return Ok(Some(if value.extract::<bool>()? {
134 "true".into()
135 } else {
136 "false".into()
137 }));
138 }
139 if let Ok(pystr) = value.cast::<PyString>() {
140 return Ok(Some(Cow::Borrowed(pystr.to_str()?)));
141 }
142 if value.is_instance_of::<PyInt>() || value.is_instance_of::<PyFloat>() {
143 return Ok(Some(Cow::Owned(value.to_string())));
144 }
145
146 if let Some(func) = default_func {
147 match func.call1(py, (value,)) {
148 Ok(serialized) => return Ok(Some(Cow::Owned(serialized.to_string()))),
149 Err(e) => {
150 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
151 "Custom serialization failed: {} (at {})",
152 e,
153 format_path(path)
154 )));
155 }
156 }
157 }
158 Ok(Some(Cow::Owned(value.to_string())))
159}
160
161#[inline]
162fn qualify_tag<'a>(
163 tag: &'a str,
164 dict: Option<&Bound<'_, PyDict>>,
165 ns: &mut NamespaceContext,
166) -> Cow<'a, str> {
167 if tag.contains(':') {
168 return Cow::Borrowed(tag);
169 }
170 if let Some(d) = dict {
171 if let Ok(Some(ns_val)) = d.get_item("@ns") {
172 let ns_str = ns_val.to_string();
173 if ns.prefix_to_uri.contains_key(&ns_str) {
174 return Cow::Owned(format!("{}:{}", ns_str, tag));
175 }
176 let prefix = ns.get_or_create_prefix(&ns_str);
177 return Cow::Owned(format!("{}:{}", prefix, tag));
178 }
179 }
180 Cow::Borrowed(tag)
181}
182
183fn process_node<W: Write>(
184 writer: &mut Writer<W>,
185 tag_name: &str,
186 value: &Bound<'_, PyAny>,
187 config: &Config,
188 compat: CompatMode,
189 ns: &mut NamespaceContext,
190 is_root: bool,
191 path: &mut Vec<String>,
192 visited: &mut FxHashSet<usize>,
193) -> PyResult<()> {
194 let py = value.py();
195
196 if value.is_instance_of::<PyString>()
197 || value.is_instance_of::<PyBool>()
198 || value.is_instance_of::<PyNone>()
199 || value.is_instance_of::<PyInt>()
200 || value.is_instance_of::<PyFloat>()
201 {
202 } else if let Ok(dict) = value.cast::<PyDict>() {
203 let ptr = dict.as_ptr() as usize;
204 if !visited.insert(ptr) {
205 return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
206 format!("Circular reference detected (at {})", format_path(path)),
207 ));
208 }
209
210 let mut tail_text: Option<String> = None;
211 if let Ok(Some(tail_val)) = dict.get_item("#tail") {
212 path.push("#tail".to_string());
213 if let Some(cow) = py_value_to_cow(py, &tail_val, &config.default_func, path)? {
214 tail_text = Some(cow.into_owned());
215 }
216 path.pop();
217 }
218 let mut ns_local = ns.clone();
219 let qualified = qualify_tag(tag_name, Some(&dict), &mut ns_local);
220 let mut elem = BytesStart::new(qualified.as_ref());
221 let mut attrs = Vec::new();
222 let mut xmlns_attrs = Vec::new();
223
224 if is_root {
225 for (prefix, uri) in &config.namespaces {
226 if prefix.is_empty() {
227 xmlns_attrs.push(("xmlns".to_string(), uri.clone()));
228 } else {
229 xmlns_attrs.push((format!("xmlns:{}", prefix), uri.clone()));
230 }
231 }
232 }
233
234 for (k, v) in dict {
235 let k_cow = extract_str(&k)?;
236 let k_str = k_cow.as_ref();
237 if k_str == "#comment" || k_str == "#tail" || k_str.starts_with("?") {
238 continue;
239 }
240
241 if k_str == "@xmlns" {
242 let uri = v.to_string();
243 if !config.namespaces.values().any(|u| u == &uri) {
244 let prefix = ns_local.get_or_create_prefix(&uri);
245 xmlns_attrs.push((format!("xmlns:{}", prefix), uri));
246 }
247 } else if let Some(p) = k_str.strip_prefix("@xmlns:") {
248 let uri = v.to_string();
249 ns_local.uri_to_prefix.insert(uri.clone(), p.to_string());
250 ns_local.prefix_to_uri.insert(p.to_string(), uri.clone());
251 xmlns_attrs.push((format!("xmlns:{}", p), uri));
252 } else if let Some(attr) = k_str.strip_prefix(&config.attr_prefix) {
253 path.push(format!("@{}", attr));
254 if let Some(val) = py_value_to_cow(py, &v, &config.default_func, path)? {
255 attrs.push((attr.to_string(), val.into_owned()));
256 }
257 path.pop();
258 }
259 }
260
261 if config.sort_attrs {
262 attrs.sort_by(|a, b| a.0.cmp(&b.0));
263 xmlns_attrs.sort_by(|a, b| a.0.cmp(&b.0));
264 }
265 for (k, v) in xmlns_attrs {
266 elem.push_attribute((k.as_str(), v.as_str()));
267 }
268 for (k, v) in attrs {
269 elem.push_attribute((k.as_str(), v.as_str()));
270 }
271
272 let has_children = dict.iter().any(|(k, _)| {
273 let k_cow = k
274 .cast::<PyString>()
275 .map(|s| s.to_string_lossy())
276 .unwrap_or_default();
277 let ks = k_cow.as_ref();
278 !ks.starts_with(&config.attr_prefix) && ks != "#tail"
279 });
280
281 if !has_children {
282 wrap_err(writer.write_event(Event::Empty(elem)), path)?;
283 visited.remove(&ptr);
284 if let Some(text) = tail_text {
285 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
286 }
287 return Ok(());
288 }
289
290 wrap_err(writer.write_event(Event::Start(elem)), path)?;
291
292 for (k, v) in dict {
293 let k_cow = extract_str(&k)?;
294 let k_str = k_cow.as_ref();
295
296 if k_str.starts_with(&config.attr_prefix) {
297 continue;
298 }
299 if k_str == "#tail" {
300 continue;
301 }
302
303 path.push(k_str.to_string());
304
305 if k_str == "#comment" {
306 if let Some(comment_txt) = py_value_to_cow(py, &v, &config.default_func, path)? {
307 wrap_err(
308 writer.write_event(Event::Comment(BytesText::new(&comment_txt))),
309 path,
310 )?;
311 }
312 } else if k_str.starts_with("?") {
313 if let Some(content) = py_value_to_cow(py, &v, &config.default_func, path)? {
314 let target = k_str.strip_prefix("?").unwrap_or(&k_str);
315 let pi_content = format!("{} {}", target, content);
316 wrap_err(
317 writer.write_event(Event::PI(BytesPI::new(&pi_content))),
318 path,
319 )?;
320 }
321 } else if k_str == config.cdata_key {
322 let mut written = false;
323 if let Ok(inner) = v.cast::<PyDict>() {
324 if let Ok(Some(cdata)) = inner.get_item("__cdata__") {
325 let s = cdata.to_string();
326 wrap_err(writer.write_event(Event::CData(BytesCData::new(&s))), path)?;
327 written = true;
328 }
329 }
330 if !written {
331 if let Some(text) = py_value_to_cow(py, &v, &config.default_func, path)? {
332 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
333 }
334 }
335 } else {
336 process_node(
337 writer,
338 k_str,
339 &v,
340 config,
341 compat,
342 &mut ns_local,
343 false,
344 path,
345 visited,
346 )?;
347 }
348 path.pop();
349 }
350 wrap_err(
351 writer.write_event(Event::End(BytesEnd::new(qualified.as_ref()))),
352 path,
353 )?;
354 visited.remove(&ptr);
355
356 if let Some(text) = tail_text {
357 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
358 }
359
360 return Ok(());
361 } else if let Ok(list) = value.cast::<PyList>() {
362 let ptr = list.as_ptr() as usize;
363 if !visited.insert(ptr) {
364 return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
365 format!("Circular ref (at {})", format_path(path)),
366 ));
367 }
368 for (i, item) in list.iter().enumerate() {
369 path.push(format!("[{}]", i));
370 process_node(
371 writer, tag_name, &item, config, compat, ns, is_root, path, visited,
372 )?;
373 path.pop();
374 }
375 visited.remove(&ptr);
376 return Ok(());
377 } else if let Ok(iter) = PyIterator::from_object(value) {
378 let mut i = 0;
379 for item in iter {
380 let obj = item?;
381 path.push(format!("[{}]", i));
382 process_node(
383 writer, tag_name, &obj, config, compat, ns, is_root, path, visited,
384 )?;
385 path.pop();
386 i += 1;
387 }
388 return Ok(());
389 }
390
391 if value.is_instance_of::<PyNone>() {
392 match compat {
393 CompatMode::Native => {
394 wrap_err(
395 writer.write_event(Event::Empty(BytesStart::new(tag_name))),
396 path,
397 )?;
398 }
399 CompatMode::Obj2Xml => {
400 wrap_err(
401 writer.write_event(Event::Start(BytesStart::new(tag_name))),
402 path,
403 )?;
404 wrap_err(
405 writer.write_event(Event::End(BytesEnd::new(tag_name))),
406 path,
407 )?;
408 }
409 }
410 } else if let Some(text) = py_value_to_cow(py, value, &config.default_func, path)? {
411 let elem = BytesStart::new(tag_name);
412 wrap_err(writer.write_event(Event::Start(elem)), path)?;
413 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
414 wrap_err(
415 writer.write_event(Event::End(BytesEnd::new(tag_name))),
416 path,
417 )?;
418 }
419 Ok(())
420}
421
422fn generate_xml<W: Write>(
423 writer: W,
424 input: &Bound<'_, PyAny>,
425 encoding: &str,
426 full_document: bool,
427 pretty: bool,
428 indent: &str,
429 config: &Config,
430 compat: CompatMode,
431 attr_prefix: &str,
432) -> PyResult<W> {
433 let mut writer = if pretty {
434 Writer::new_with_indent(writer, b' ', indent.len())
435 } else {
436 Writer::new(writer)
437 };
438 if full_document {
439 writer
440 .write_event(Event::Decl(BytesDecl::new("1.0", Some(encoding), None)))
441 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
442 }
443 let mut ns = NamespaceContext::new(&config.namespaces);
444 let mut path = Vec::with_capacity(16);
445 let mut visited = FxHashSet::default();
446
447 if let Ok(dict) = input.cast::<PyDict>() {
448 let roots = dict
449 .iter()
450 .filter(|(k, _)| !k.to_string().starts_with(attr_prefix))
451 .count();
452 if full_document && roots != 1 {
453 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
454 "Document must have exactly one root",
455 ));
456 }
457 for (k, v) in dict {
458 let key_cow = extract_str(&k)?;
459 let key = key_cow.as_ref();
460 if key.starts_with(attr_prefix) {
461 continue;
462 }
463 path.push(key.to_string());
464 process_node(
465 &mut writer,
466 key,
467 &v,
468 config,
469 compat,
470 &mut ns,
471 true,
472 &mut path,
473 &mut visited,
474 )?;
475 path.pop();
476 }
477 } else {
478 let iter = PyIterator::from_object(input)?;
479 let mut i = 0;
480 for item in iter {
481 let obj = item?;
482 path.push(format!("[{}]", i));
483 if let Ok(d) = obj.cast::<PyDict>() {
484 for (k, v) in d {
485 let k_cow = extract_str(&k)?;
486 let k_str = k_cow.as_ref();
487 path.push(k_str.to_string());
488 process_node(
489 &mut writer,
490 k_str,
491 &v,
492 config,
493 compat,
494 &mut ns,
495 true,
496 &mut path,
497 &mut visited,
498 )?;
499 path.pop();
500 }
501 } else {
502 process_node(
503 &mut writer,
504 &config.item_name,
505 &obj,
506 config,
507 compat,
508 &mut ns,
509 true,
510 &mut path,
511 &mut visited,
512 )?;
513 }
514 path.pop();
515 i += 1;
516 }
517 }
518 Ok(writer.into_inner())
519}
520
521#[pyfunction]
522#[pyo3(signature = (
523 input, *,
524 output=None,
525 encoding="utf-8",
526 full_document=true,
527 attr_prefix="@",
528 cdata_key="#text",
529 pretty=false, indent=" ",
530 compat="native",
531 streaming=false,
532 default=None,
533 item_name="item",
534 sort_attributes=false,
535 namespaces=None
536))]
537fn unparse(
538 _py: Python<'_>,
539 input: &Bound<'_, PyAny>,
540 output: Option<&Bound<'_, PyAny>>,
541 encoding: &str,
542 full_document: bool,
543 attr_prefix: &str,
544 cdata_key: &str,
545 pretty: bool,
546 indent: &str,
547 compat: &str,
548 streaming: bool,
549 default: Option<Py<PyAny>>,
550 item_name: &str,
551 sort_attributes: bool,
552 namespaces: Option<FxHashMap<String, String>>,
553) -> PyResult<String> {
554 let config = Config {
555 attr_prefix: attr_prefix.to_string(),
556 cdata_key: cdata_key.to_string(),
557 default_func: default,
558 item_name: item_name.to_string(),
559 sort_attrs: sort_attributes,
560 namespaces: namespaces.unwrap_or_default(),
561 };
562 let compat_mode = if compat == "legacy" {
563 CompatMode::Obj2Xml
564 } else {
565 CompatMode::Native
566 };
567
568 if streaming {
569 let out_obj = output.ok_or_else(|| {
570 PyErr::new::<pyo3::exceptions::PyValueError, _>(
571 "streaming=True requires output argument",
572 )
573 })?;
574 let sink: Box<dyn Write> = if let Ok(path) = out_obj.extract::<String>() {
575 let f = File::create(path)?;
576 Box::new(BufWriter::new(f))
577 } else {
578 Box::new(PyWriter {
579 obj: out_obj.clone(),
580 })
581 };
582 wrap_err(
583 generate_xml(
584 sink,
585 input,
586 encoding,
587 full_document,
588 pretty,
589 indent,
590 &config,
591 compat_mode,
592 attr_prefix,
593 ),
594 &[],
595 )?;
596 return Ok(String::new());
597 }
598
599 let cursor = Cursor::new(Vec::with_capacity(32 * 1024));
600 let cursor = wrap_py_err(
601 generate_xml(
602 cursor,
603 input,
604 encoding,
605 full_document,
606 pretty,
607 indent,
608 &config,
609 compat_mode,
610 attr_prefix,
611 ),
612 &[],
613 )?;
614 let xml = String::from_utf8(cursor.into_inner())
615 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
616 if let Some(out_obj) = output {
617 if let Ok(path) = out_obj.extract::<String>() {
618 std::fs::write(path, &xml)?;
619 } else {
620 out_obj.call_method1("write", (xml.as_bytes(),))?;
621 }
622 return Ok(String::new());
623 }
624 Ok(xml)
625}
626
627struct ParseConfig {
628 attr_prefix: String,
629 cdata_key: String,
630 force_cdata: bool,
631 process_namespaces: bool,
632 namespace_separator: String,
633 strip_whitespace: bool,
634 force_list: Option<FxHashSet<String>>,
635 process_comments: bool,
636}
637
638struct StackItem<'py> {
639 dict: Bound<'py, PyDict>,
640 tag_name: String,
641}
642
643fn parse_xml<'py>(
644 py: Python<'py>,
645 reader: &mut Reader<Box<dyn std::io::BufRead>>,
646 config: &ParseConfig,
647) -> PyResult<Py<PyAny>> {
648 let mut buf = Vec::new();
649 let mut stack: Vec<StackItem<'py>> = Vec::new();
650 let mut text_buffer: Option<String> = None;
651
652 let mut ns_stack: Vec<FxHashMap<String, String>> = Vec::new();
653 if config.process_namespaces {
654 ns_stack.push(FxHashMap::default());
655 }
656
657 loop {
658 match reader.read_event_into(&mut buf) {
659 Ok(Event::Start(e)) => {
660 let raw_name = String::from_utf8_lossy(e.name().as_ref()).into_owned();
661
662 let mut current_ns = if config.process_namespaces {
663 ns_stack.last().cloned().unwrap_or_default()
664 } else {
665 FxHashMap::default()
666 };
667
668 let mut attributes_vec = Vec::new();
669 for attr in e.attributes() {
670 let attr = attr.map_err(|e| {
671 PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string())
672 })?;
673 let key = String::from_utf8_lossy(attr.key.as_ref()).into_owned();
674 let val = String::from_utf8_lossy(&attr.value).into_owned();
675
676 if config.process_namespaces {
677 if key == "xmlns" {
678 current_ns.insert("".to_string(), val.clone());
679 } else if let Some(prefix) = key.strip_prefix("xmlns:") {
680 current_ns.insert(prefix.to_string(), val.clone());
681 }
682 }
683 attributes_vec.push((key, val));
684 }
685
686 if config.process_namespaces {
687 ns_stack.push(current_ns.clone());
688 }
689
690 let tag_name = if config.process_namespaces {
691 resolve_name(&raw_name, &ns_stack, &config.namespace_separator, true)
692 } else {
693 raw_name
694 };
695
696 let new_dict = PyDict::new(py);
697
698 for (key, val) in attributes_vec {
699 let final_key = if config.process_namespaces && !key.starts_with("xmlns") {
700 resolve_name(&key, &ns_stack, &config.namespace_separator, false)
701 } else {
702 key
703 };
704
705 let key_str = format!("{}{}", config.attr_prefix, final_key);
706 new_dict.set_item(key_str, val)?;
707 }
708
709 if let Some(text) = text_buffer.take() {
710 if let Some(parent) = stack.last() {
711 if let Some(existing) = parent.dict.get_item(&config.cdata_key)? {
712 let s = existing.extract::<String>()?;
713 parent
714 .dict
715 .set_item(&config.cdata_key, format!("{}{}", s, text))?;
716 } else {
717 parent.dict.set_item(&config.cdata_key, text)?;
718 }
719 }
720 }
721
722 stack.push(StackItem {
723 dict: new_dict,
724 tag_name,
725 });
726 }
727 Ok(Event::End(_)) => {
728 let StackItem {
729 dict: current_dict,
730 tag_name,
731 } = stack.pop().ok_or_else(|| {
732 PyErr::new::<pyo3::exceptions::PyValueError, _>("Unexpected closing tag")
733 })?;
734
735 if config.process_namespaces {
736 ns_stack.pop();
737 }
738
739 if let Some(text) = text_buffer.take() {
740 if current_dict.is_empty() {
741 if config.force_cdata {
742 current_dict.set_item(&config.cdata_key, text)?;
743 } else {
744 current_dict.set_item(&config.cdata_key, text)?;
745 }
746 } else {
747 current_dict.set_item(&config.cdata_key, text)?;
748 }
749 }
750
751 if let Some(parent) = stack.last() {
752 let key = tag_name.as_str();
753
754 let value_to_insert: Py<PyAny> = if current_dict.len() == 1
755 && current_dict.contains(&config.cdata_key)?
756 && !config.force_cdata
757 {
758 current_dict.get_item(&config.cdata_key)?.unwrap().into()
759 } else if current_dict.is_empty() {
760 py.None().into()
761 } else {
762 current_dict.into()
763 };
764
765 if let Some(existing) = parent.dict.get_item(key)? {
766 if let Ok(list) = existing.cast::<PyList>() {
767 list.append(value_to_insert)?;
768 } else {
769 let list = PyList::new(
770 py,
771 vec![
772 existing.into_pyobject(py)?.into_any(),
773 value_to_insert.into_pyobject(py)?.into_any(),
774 ],
775 )?;
776 parent.dict.set_item(key, list)?;
777 }
778 } else {
779 let force = if let Some(fl) = &config.force_list {
780 fl.contains(key)
781 } else {
782 false
783 };
784 if force {
785 let list = PyList::new(
786 py,
787 vec![value_to_insert.into_pyobject(py)?.into_any()],
788 )?;
789 parent.dict.set_item(key, list)?;
790 } else {
791 parent.dict.set_item(key, value_to_insert)?;
792 }
793 }
794 } else {
795 let root_dict = PyDict::new(py);
796 let value_to_insert: Py<PyAny> = if current_dict.len() == 1
797 && current_dict.contains(&config.cdata_key)?
798 && !config.force_cdata
799 {
800 current_dict.get_item(&config.cdata_key)?.unwrap().into()
801 } else if current_dict.is_empty() {
802 py.None().into()
803 } else {
804 current_dict.into()
805 };
806 root_dict.set_item(tag_name, value_to_insert)?;
807 return Ok(root_dict.into());
808 }
809 }
810 Ok(Event::Empty(e)) => {
811 let raw_name = String::from_utf8_lossy(e.name().as_ref()).into_owned();
812
813 let mut current_ns = if config.process_namespaces {
814 ns_stack.last().cloned().unwrap_or_default()
815 } else {
816 FxHashMap::default()
817 };
818 let mut attributes_vec = Vec::new();
819 for attr in e.attributes() {
820 let attr = attr.map_err(|e| {
821 PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string())
822 })?;
823 let key = String::from_utf8_lossy(attr.key.as_ref()).into_owned();
824 let val = String::from_utf8_lossy(&attr.value).into_owned();
825 if config.process_namespaces {
826 if key == "xmlns" {
827 current_ns.insert("".to_string(), val.clone());
828 } else if let Some(prefix) = key.strip_prefix("xmlns:") {
829 current_ns.insert(prefix.to_string(), val.clone());
830 }
831 }
832 attributes_vec.push((key, val));
833 }
834
835 if config.process_namespaces {
836 ns_stack.push(current_ns);
837 }
838 let tag_name = if config.process_namespaces {
839 resolve_name(&raw_name, &ns_stack, &config.namespace_separator, true)
840 } else {
841 raw_name
842 };
843
844 let new_dict = PyDict::new(py);
845 for (key, val) in attributes_vec {
846 let final_key = if config.process_namespaces && !key.starts_with("xmlns") {
847 resolve_name(&key, &ns_stack, &config.namespace_separator, false)
848 } else {
849 key
850 };
851 new_dict.set_item(format!("{}{}", config.attr_prefix, final_key), val)?;
852 }
853
854 if config.process_namespaces {
855 ns_stack.pop();
856 }
857
858 if let Some(parent) = stack.last() {
859 let key = tag_name.as_str();
860 let value_to_insert: Py<PyAny> = if new_dict.is_empty() {
861 py.None().into()
862 } else {
863 new_dict.into()
864 };
865
866 if let Some(existing) = parent.dict.get_item(key)? {
867 if let Ok(list) = existing.cast::<PyList>() {
868 list.append(value_to_insert)?;
869 } else {
870 let list = PyList::new(
871 py,
872 vec![
873 existing.into_pyobject(py)?.into_any(),
874 value_to_insert.into_pyobject(py)?.into_any(),
875 ],
876 )?;
877 parent.dict.set_item(key, list)?;
878 }
879 } else {
880 let force = if let Some(fl) = &config.force_list {
881 fl.contains(key)
882 } else {
883 false
884 };
885 if force {
886 let list = PyList::new(
887 py,
888 vec![value_to_insert.into_pyobject(py)?.into_any()],
889 )?;
890 parent.dict.set_item(key, list)?;
891 } else {
892 parent.dict.set_item(key, value_to_insert)?;
893 }
894 }
895 } else {
896 let root_dict = PyDict::new(py);
897 root_dict.set_item(tag_name, py.None())?;
898 return Ok(root_dict.into());
899 }
900 }
901 Ok(Event::Text(e)) => {
902 let text_cow = std::str::from_utf8(&e)
903 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
904 let unescaped = quick_xml::escape::unescape(text_cow)
905 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
906
907 let text = unescaped.as_ref();
908 let trimmed = if config.strip_whitespace {
909 text.trim()
910 } else {
911 text
912 };
913 if !trimmed.is_empty() {
914 if let Some(ref mut buf) = text_buffer {
915 buf.push_str(trimmed);
916 } else {
917 text_buffer = Some(trimmed.to_string());
918 }
919 }
920 }
921 Ok(Event::CData(e)) => {
922 let text = String::from_utf8_lossy(&e);
923 let trimmed = if config.strip_whitespace {
924 text.trim()
925 } else {
926 &text
927 };
928 if !trimmed.is_empty() {
929 if let Some(ref mut buf) = text_buffer {
930 buf.push_str(trimmed);
931 } else {
932 text_buffer = Some(trimmed.to_string());
933 }
934 }
935 }
936
937 Ok(Event::Comment(e)) => {
938 if config.process_comments {
939 let comment = String::from_utf8_lossy(&e).into_owned();
940 if let Some(parent) = stack.last() {
941 if let Some(existing) = parent.dict.get_item("#comment")? {
942 if let Ok(list) = existing.cast::<PyList>() {
943 list.append(comment)?;
944 } else {
945 let list = PyList::new(
946 py,
947 vec![
948 existing.into_pyobject(py)?.into_any(),
949 comment.into_pyobject(py)?.into_any(),
950 ],
951 )?;
952 parent.dict.set_item("#comment", list)?;
953 }
954 } else {
955 parent.dict.set_item("#comment", comment)?;
956 }
957 }
958 }
959 }
960
961 Ok(Event::PI(e)) => {
962 let content = String::from_utf8_lossy(&e);
963 let (target, value) = if let Some((t, v)) = content.split_once(' ') {
964 (t, v)
965 } else {
966 (content.as_ref(), "")
967 };
968 let key = format!("?{}", target);
969 let val = value.to_string();
970
971 if let Some(parent) = stack.last() {
972 if let Some(existing) = parent.dict.get_item(&key)? {
973 if let Ok(list) = existing.cast::<PyList>() {
974 list.append(val)?;
975 } else {
976 let list = PyList::new(
977 py,
978 vec![
979 existing.into_pyobject(py)?.into_any(),
980 val.into_pyobject(py)?.into_any(),
981 ],
982 )?;
983 parent.dict.set_item(&key, list)?;
984 }
985 } else {
986 parent.dict.set_item(&key, val)?;
987 }
988 }
989 }
990 Ok(Event::Eof) => break,
991 Err(e) => {
992 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
993 "XML Parse Error: {}",
994 e
995 )));
996 }
997 _ => {}
998 }
999 buf.clear();
1000 }
1001 Ok(PyDict::new(py).into())
1002}
1003
1004fn resolve_name(
1005 name: &str,
1006 ns_stack: &[FxHashMap<String, String>],
1007 separator: &str,
1008 is_element: bool,
1009) -> String {
1010 if let Some((prefix, local)) = name.split_once(':') {
1011 for scope in ns_stack.iter().rev() {
1012 if let Some(uri) = scope.get(prefix) {
1013 return format!("{}{}{}", uri, separator, local);
1014 }
1015 }
1016 } else if is_element {
1017 for scope in ns_stack.iter().rev() {
1018 if let Some(uri) = scope.get("") {
1019 return format!("{}{}{}", uri, separator, name);
1020 }
1021 }
1022 }
1023 name.to_string()
1024}
1025
1026#[pyfunction]
1027#[pyo3(signature = (
1028 xml_input, *,
1029 _encoding=None,
1030 attr_prefix="@",
1031 cdata_key="#text",
1032 force_cdata=false,
1033 process_namespaces=false,
1034 namespace_separator=":",
1035 strip_whitespace=true,
1036 force_list=None,
1037 process_comments=false
1038))]
1039fn parse(
1040 py: Python<'_>,
1041 xml_input: &Bound<'_, PyAny>,
1042 _encoding: Option<&str>,
1043 attr_prefix: &str,
1044 cdata_key: &str,
1045 force_cdata: bool,
1046 process_namespaces: bool,
1047 namespace_separator: &str,
1048 strip_whitespace: bool,
1049 force_list: Option<Vec<String>>,
1050 process_comments: bool,
1051) -> PyResult<Py<PyAny>> {
1052 let config = ParseConfig {
1053 attr_prefix: attr_prefix.to_string(),
1054 cdata_key: cdata_key.to_string(),
1055 force_cdata,
1056 process_namespaces,
1057 namespace_separator: namespace_separator.to_string(),
1058 strip_whitespace,
1059 force_list: force_list.map(|v| v.into_iter().collect()),
1060 process_comments,
1061 };
1062
1063 if let Ok(s) = xml_input.extract::<String>() {
1064 let mut reader = Reader::from_str(&s);
1065 reader.config_mut().trim_text(strip_whitespace);
1066 let mut boxed_reader: Reader<Box<dyn std::io::BufRead>> =
1067 Reader::from_reader(Box::new(Cursor::new(s.into_bytes())));
1068 boxed_reader.config_mut().trim_text(strip_whitespace);
1069 boxed_reader.config_mut().expand_empty_elements = true;
1070 return parse_xml(py, &mut boxed_reader, &config);
1071 }
1072
1073 if let Ok(b) = xml_input.extract::<Vec<u8>>() {
1074 let mut boxed_reader =
1075 Reader::from_reader(Box::new(Cursor::new(b)) as Box<dyn std::io::BufRead>);
1076 boxed_reader.config_mut().trim_text(strip_whitespace);
1077 boxed_reader.config_mut().expand_empty_elements = true;
1078 return parse_xml(py, &mut boxed_reader, &config);
1079 }
1080
1081 if xml_input.hasattr("read")? {
1082 let bytes: Vec<u8> = xml_input.call_method0("read")?.extract()?;
1083 let mut boxed_reader =
1084 Reader::from_reader(Box::new(Cursor::new(bytes)) as Box<dyn std::io::BufRead>);
1085 boxed_reader.config_mut().trim_text(strip_whitespace);
1086 boxed_reader.config_mut().expand_empty_elements = true;
1087 return parse_xml(py, &mut boxed_reader, &config);
1088 }
1089
1090 Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
1091 "Input must be str, bytes, or file-like object",
1092 ))
1093}
1094
1095#[pymodule]
1096fn _obj2xml_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
1097 m.add_function(wrap_pyfunction!(unparse, m)?)?;
1098 m.add_function(wrap_pyfunction!(parse, m)?)?;
1099 Ok(())
1100}