Skip to main content

_obj2xml_rs/
lib.rs

1use pyo3::prelude::*;
2use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyIterator, PyList, PyNone, PyString};
3use quick_xml::Writer;
4use quick_xml::events::{BytesCData, BytesDecl, BytesEnd, BytesPI, BytesStart, BytesText, Event};
5use quick_xml::reader::Reader;
6use rustc_hash::{FxHashMap, FxHashSet};
7use std::borrow::Cow;
8use std::fs::File;
9use std::io::{BufWriter, Cursor, Write};
10use std::str;
11
12struct Config {
13    attr_prefix: String,
14    cdata_key: String,
15    default_func: Option<Py<PyAny>>,
16    item_name: String,
17    sort_attrs: bool,
18    namespaces: FxHashMap<String, String>,
19}
20
/// Selects how an element whose value is Python `None` is written.
#[derive(Clone, Copy, PartialEq)]
enum CompatMode {
    /// `None` becomes a single self-closing (empty) element event.
    Native,
    /// `None` becomes an explicit start tag immediately followed by an end tag.
    Obj2Xml,
}
26
27#[inline]
28fn extract_str<'py>(value: &'py pyo3::Bound<'py, pyo3::PyAny>) -> PyResult<Cow<'py, str>> {
29    if let Ok(pystr) = value.cast::<PyString>() {
30        let slice = pystr.to_str()?;
31        Ok(Cow::Borrowed(slice))
32    } else {
33        Ok(Cow::Owned(value.to_string()))
34    }
35}
36
37struct PyWriter<'py> {
38    obj: Bound<'py, PyAny>,
39}
40
41impl<'py> Write for PyWriter<'py> {
42    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
43        self.obj
44            .call_method1("write", (buf,))
45            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
46        Ok(buf.len())
47    }
48    fn flush(&mut self) -> std::io::Result<()> {
49        if self.obj.hasattr("flush").unwrap_or(false) {
50            self.obj
51                .call_method0("flush")
52                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
53        }
54        Ok(())
55    }
56}
57
58#[derive(Clone)]
59struct NamespaceContext {
60    uri_to_prefix: FxHashMap<String, String>,
61    prefix_to_uri: FxHashMap<String, String>,
62    next_auto: usize,
63}
64
65impl NamespaceContext {
66    fn new(predefined: &FxHashMap<String, String>) -> Self {
67        let mut uri_to_prefix = FxHashMap::default();
68        let mut prefix_to_uri = FxHashMap::default();
69        for (prefix, uri) in predefined {
70            uri_to_prefix.insert(uri.clone(), prefix.clone());
71            prefix_to_uri.insert(prefix.clone(), uri.clone());
72        }
73        Self {
74            uri_to_prefix,
75            prefix_to_uri,
76            next_auto: 0,
77        }
78    }
79
80    fn get_or_create_prefix(&mut self, uri: &str) -> String {
81        if let Some(p) = self.uri_to_prefix.get(uri) {
82            return p.clone();
83        }
84        let prefix = format!("ns{}", self.next_auto);
85        self.next_auto += 1;
86        self.uri_to_prefix.insert(uri.to_string(), prefix.clone());
87        self.prefix_to_uri.insert(prefix.clone(), uri.to_string());
88        prefix
89    }
90}
91
/// Renders a location for error messages: `"root"` when the stack is empty,
/// otherwise the stack segments joined with `/`.
fn format_path(stack: &[String]) -> String {
    match stack {
        [] => String::from("root"),
        segments => segments.join("/"),
    }
}
99
100#[inline(always)]
101fn wrap_err<T, E: std::fmt::Display>(res: Result<T, E>, stack: &[String]) -> PyResult<T> {
102    res.map_err(|e| {
103        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
104            "{} (at {})",
105            e,
106            format_path(stack)
107        ))
108    })
109}
110
111#[inline(always)]
112fn wrap_py_err<T>(res: PyResult<T>, stack: &[String]) -> PyResult<T> {
113    res.map_err(|e| {
114        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
115            "{} (at {})",
116            e,
117            format_path(stack)
118        ))
119    })
120}
121
122#[inline]
123fn py_value_to_cow<'py>(
124    py: Python<'py>,
125    value: &'py pyo3::Bound<'py, pyo3::PyAny>,
126    default_func: &Option<Py<PyAny>>,
127    path: &[String],
128) -> PyResult<Option<Cow<'py, str>>> {
129    if value.is_instance_of::<PyNone>() {
130        return Ok(None);
131    }
132    if value.is_instance_of::<PyBool>() {
133        return Ok(Some(if value.extract::<bool>()? {
134            "true".into()
135        } else {
136            "false".into()
137        }));
138    }
139    if let Ok(pystr) = value.cast::<PyString>() {
140        return Ok(Some(Cow::Borrowed(pystr.to_str()?)));
141    }
142    if value.is_instance_of::<PyInt>() || value.is_instance_of::<PyFloat>() {
143        return Ok(Some(Cow::Owned(value.to_string())));
144    }
145
146    if let Some(func) = default_func {
147        match func.call1(py, (value,)) {
148            Ok(serialized) => return Ok(Some(Cow::Owned(serialized.to_string()))),
149            Err(e) => {
150                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
151                    "Custom serialization failed: {} (at {})",
152                    e,
153                    format_path(path)
154                )));
155            }
156        }
157    }
158    Ok(Some(Cow::Owned(value.to_string())))
159}
160
161#[inline]
162fn qualify_tag<'a>(
163    tag: &'a str,
164    dict: Option<&Bound<'_, PyDict>>,
165    ns: &mut NamespaceContext,
166) -> Cow<'a, str> {
167    if tag.contains(':') {
168        return Cow::Borrowed(tag);
169    }
170    if let Some(d) = dict {
171        if let Ok(Some(ns_val)) = d.get_item("@ns") {
172            let ns_str = ns_val.to_string();
173            if ns.prefix_to_uri.contains_key(&ns_str) {
174                return Cow::Owned(format!("{}:{}", ns_str, tag));
175            }
176            let prefix = ns.get_or_create_prefix(&ns_str);
177            return Cow::Owned(format!("{}:{}", prefix, tag));
178        }
179    }
180    Cow::Borrowed(tag)
181}
182
/// Recursively writes `value` as XML element(s) named `tag_name`.
///
/// Dispatch by Python type:
/// * `str` / `bool` / `None` / `int` / `float` — written at the bottom of the
///   function as `<tag>text</tag>`; `None` becomes an empty element whose shape
///   depends on `compat`.
/// * `dict` — one element. Keys starting with `config.attr_prefix` become
///   attributes, `#comment` a comment, `?target` a processing instruction,
///   `config.cdata_key` text or CDATA, `#tail` text written after the element,
///   and every other key a child element (recursive call).
/// * `list` — every item is written as its own element with the same `tag_name`.
/// * any other iterable — consumed item by item like a list.
///
/// `path` is the breadcrumb used in error messages; it is pushed/popped around
/// every nested step. `visited` holds the addresses of the dicts and lists
/// currently being written so that reference cycles are detected.
///
/// # Errors
/// * `RecursionError` when a dict or list contains itself (directly or not).
/// * `ValueError` (via `wrap_err`) when the underlying writer fails.
/// * Whatever `py_value_to_cow` / `extract_str` return for bad values or keys.
fn process_node<W: Write>(
    writer: &mut Writer<W>,
    tag_name: &str,
    value: &Bound<'_, PyAny>,
    config: &Config,
    compat: CompatMode,
    ns: &mut NamespaceContext,
    is_root: bool,
    path: &mut Vec<String>,
    visited: &mut FxHashSet<usize>,
) -> PyResult<()> {
    let py = value.py();

    if value.is_instance_of::<PyString>()
        || value.is_instance_of::<PyBool>()
        || value.is_instance_of::<PyNone>()
        || value.is_instance_of::<PyInt>()
        || value.is_instance_of::<PyFloat>()
    {
        // Scalars: intentionally empty. This arm only keeps them out of the
        // container branches below (a str would otherwise be treated as an
        // iterable); they are written by the code after this if-chain.
    } else if let Ok(dict) = value.cast::<PyDict>() {
        // The object's address identifies it while it is on the recursion stack.
        let ptr = dict.as_ptr() as usize;
        if !visited.insert(ptr) {
            return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
                format!("Circular reference detected (at {})", format_path(path)),
            ));
        }

        // "#tail" is resolved first but written last, after the closing tag.
        let mut tail_text: Option<String> = None;
        if let Ok(Some(tail_val)) = dict.get_item("#tail") {
            path.push("#tail".to_string());
            if let Some(cow) = py_value_to_cow(py, &tail_val, &config.default_func, path)? {
                tail_text = Some(cow.into_owned());
            }
            path.pop();
        }
        // Work on a copy so prefixes registered for this element are visible to
        // its children (which receive `ns_local`) but not to the caller's context.
        let mut ns_local = ns.clone();
        let qualified = qualify_tag(tag_name, Some(&dict), &mut ns_local);
        let mut elem = BytesStart::new(qualified.as_ref());
        let mut attrs = Vec::new();
        let mut xmlns_attrs = Vec::new();

        // Predefined namespaces are declared on elements flagged as root.
        if is_root {
            for (prefix, uri) in &config.namespaces {
                if prefix.is_empty() {
                    xmlns_attrs.push(("xmlns".to_string(), uri.clone()));
                } else {
                    xmlns_attrs.push((format!("xmlns:{}", prefix), uri.clone()));
                }
            }
        }

        // First pass over the dict: collect namespace declarations and attributes.
        for (k, v) in dict {
            let k_cow = extract_str(&k)?;
            let k_str = k_cow.as_ref();
            if k_str == "#comment" || k_str == "#tail" || k_str.starts_with("?") {
                continue;
            }

            if k_str == "@xmlns" {
                // NOTE(review): a URI given under "@xmlns" is declared with a
                // looked-up or generated prefix (`xmlns:<prefix>`), not as a plain
                // default `xmlns` — confirm this is intended.
                let uri = v.to_string();
                if !config.namespaces.values().any(|u| u == &uri) {
                    let prefix = ns_local.get_or_create_prefix(&uri);
                    xmlns_attrs.push((format!("xmlns:{}", prefix), uri));
                }
            } else if let Some(p) = k_str.strip_prefix("@xmlns:") {
                // Explicit prefix declaration: register it for descendants and emit it.
                let uri = v.to_string();
                ns_local.uri_to_prefix.insert(uri.clone(), p.to_string());
                ns_local.prefix_to_uri.insert(p.to_string(), uri.clone());
                xmlns_attrs.push((format!("xmlns:{}", p), uri));
            } else if let Some(attr) = k_str.strip_prefix(&config.attr_prefix) {
                path.push(format!("@{}", attr));
                // A `None` attribute value yields no text and is therefore omitted.
                if let Some(val) = py_value_to_cow(py, &v, &config.default_func, path)? {
                    attrs.push((attr.to_string(), val.into_owned()));
                }
                path.pop();
            }
        }

        if config.sort_attrs {
            attrs.sort_by(|a, b| a.0.cmp(&b.0));
            xmlns_attrs.sort_by(|a, b| a.0.cmp(&b.0));
        }
        // Namespace declarations are always written before ordinary attributes.
        for (k, v) in xmlns_attrs {
            elem.push_attribute((k.as_str(), v.as_str()));
        }
        for (k, v) in attrs {
            elem.push_attribute((k.as_str(), v.as_str()));
        }

        // Any key that is neither an attribute nor "#tail" counts as content.
        // A non-string key is seen here as "" and so normally counts as content too.
        let has_children = dict.iter().any(|(k, _)| {
            let k_cow = k
                .cast::<PyString>()
                .map(|s| s.to_string_lossy())
                .unwrap_or_default();
            let ks = k_cow.as_ref();
            !ks.starts_with(&config.attr_prefix) && ks != "#tail"
        });

        if !has_children {
            // Attribute-only dict: a self-closing element, regardless of `compat`.
            wrap_err(writer.write_event(Event::Empty(elem)), path)?;
            visited.remove(&ptr);
            if let Some(text) = tail_text {
                wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
            }
            return Ok(());
        }

        wrap_err(writer.write_event(Event::Start(elem)), path)?;

        // Second pass over the dict: write the element's content in dict order.
        for (k, v) in dict {
            let k_cow = extract_str(&k)?;
            let k_str = k_cow.as_ref();

            if k_str.starts_with(&config.attr_prefix) {
                continue;
            }
            if k_str == "#tail" {
                continue;
            }

            path.push(k_str.to_string());

            if k_str == "#comment" {
                if let Some(comment_txt) = py_value_to_cow(py, &v, &config.default_func, path)? {
                    wrap_err(
                        writer.write_event(Event::Comment(BytesText::new(&comment_txt))),
                        path,
                    )?;
                }
            } else if k_str.starts_with("?") {
                // "?target": "content" becomes a processing instruction "target content".
                if let Some(content) = py_value_to_cow(py, &v, &config.default_func, path)? {
                    let target = k_str.strip_prefix("?").unwrap_or(&k_str);
                    let pi_content = format!("{} {}", target, content);
                    wrap_err(
                        writer.write_event(Event::PI(BytesPI::new(&pi_content))),
                        path,
                    )?;
                }
            } else if k_str == config.cdata_key {
                // A dict value carrying "__cdata__" is written as a CDATA section;
                // everything else under the text key is written as plain text.
                let mut written = false;
                if let Ok(inner) = v.cast::<PyDict>() {
                    if let Ok(Some(cdata)) = inner.get_item("__cdata__") {
                        let s = cdata.to_string();
                        wrap_err(writer.write_event(Event::CData(BytesCData::new(&s))), path)?;
                        written = true;
                    }
                }
                if !written {
                    if let Some(text) = py_value_to_cow(py, &v, &config.default_func, path)? {
                        wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
                    }
                }
            } else {
                // Ordinary key: a child element. Children are never root elements.
                process_node(
                    writer,
                    k_str,
                    &v,
                    config,
                    compat,
                    &mut ns_local,
                    false,
                    path,
                    visited,
                )?;
            }
            path.pop();
        }
        wrap_err(
            writer.write_event(Event::End(BytesEnd::new(qualified.as_ref()))),
            path,
        )?;
        // The dict has left the recursion stack; it may legally appear again elsewhere.
        visited.remove(&ptr);

        if let Some(text) = tail_text {
            wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
        }

        return Ok(());
    } else if let Ok(list) = value.cast::<PyList>() {
        let ptr = list.as_ptr() as usize;
        if !visited.insert(ptr) {
            return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
                format!("Circular ref (at {})", format_path(path)),
            ));
        }
        // A list adds no wrapper element: each item repeats the same tag name.
        for (i, item) in list.iter().enumerate() {
            path.push(format!("[{}]", i));
            process_node(
                writer, tag_name, &item, config, compat, ns, is_root, path, visited,
            )?;
            path.pop();
        }
        visited.remove(&ptr);
        return Ok(());
    } else if let Ok(iter) = PyIterator::from_object(value) {
        // Any other iterable is consumed like a list; it is not tracked in `visited`.
        let mut i = 0;
        for item in iter {
            let obj = item?;
            path.push(format!("[{}]", i));
            process_node(
                writer, tag_name, &obj, config, compat, ns, is_root, path, visited,
            )?;
            path.pop();
            i += 1;
        }
        return Ok(());
    }

    // Reached by the scalar types and by any object that is not dict/list/iterable.
    if value.is_instance_of::<PyNone>() {
        match compat {
            CompatMode::Native => {
                wrap_err(
                    writer.write_event(Event::Empty(BytesStart::new(tag_name))),
                    path,
                )?;
            }
            CompatMode::Obj2Xml => {
                wrap_err(
                    writer.write_event(Event::Start(BytesStart::new(tag_name))),
                    path,
                )?;
                wrap_err(
                    writer.write_event(Event::End(BytesEnd::new(tag_name))),
                    path,
                )?;
            }
        }
    } else if let Some(text) = py_value_to_cow(py, value, &config.default_func, path)? {
        let elem = BytesStart::new(tag_name);
        wrap_err(writer.write_event(Event::Start(elem)), path)?;
        wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
        wrap_err(
            writer.write_event(Event::End(BytesEnd::new(tag_name))),
            path,
        )?;
    }
    Ok(())
}
421
422fn generate_xml<W: Write>(
423    writer: W,
424    input: &Bound<'_, PyAny>,
425    encoding: &str,
426    full_document: bool,
427    pretty: bool,
428    indent: &str,
429    config: &Config,
430    compat: CompatMode,
431    attr_prefix: &str,
432) -> PyResult<W> {
433    let mut writer = if pretty {
434        Writer::new_with_indent(writer, b' ', indent.len())
435    } else {
436        Writer::new(writer)
437    };
438    if full_document {
439        writer
440            .write_event(Event::Decl(BytesDecl::new("1.0", Some(encoding), None)))
441            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
442    }
443    let mut ns = NamespaceContext::new(&config.namespaces);
444    let mut path = Vec::with_capacity(16);
445    let mut visited = FxHashSet::default();
446
447    if let Ok(dict) = input.cast::<PyDict>() {
448        let roots = dict
449            .iter()
450            .filter(|(k, _)| !k.to_string().starts_with(attr_prefix))
451            .count();
452        if full_document && roots != 1 {
453            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
454                "Document must have exactly one root",
455            ));
456        }
457        for (k, v) in dict {
458            let key_cow = extract_str(&k)?;
459            let key = key_cow.as_ref();
460            if key.starts_with(attr_prefix) {
461                continue;
462            }
463            path.push(key.to_string());
464            process_node(
465                &mut writer,
466                key,
467                &v,
468                config,
469                compat,
470                &mut ns,
471                true,
472                &mut path,
473                &mut visited,
474            )?;
475            path.pop();
476        }
477    } else {
478        let iter = PyIterator::from_object(input)?;
479        let mut i = 0;
480        for item in iter {
481            let obj = item?;
482            path.push(format!("[{}]", i));
483            if let Ok(d) = obj.cast::<PyDict>() {
484                for (k, v) in d {
485                    let k_cow = extract_str(&k)?;
486                    let k_str = k_cow.as_ref();
487                    path.push(k_str.to_string());
488                    process_node(
489                        &mut writer,
490                        k_str,
491                        &v,
492                        config,
493                        compat,
494                        &mut ns,
495                        true,
496                        &mut path,
497                        &mut visited,
498                    )?;
499                    path.pop();
500                }
501            } else {
502                process_node(
503                    &mut writer,
504                    &config.item_name,
505                    &obj,
506                    config,
507                    compat,
508                    &mut ns,
509                    true,
510                    &mut path,
511                    &mut visited,
512                )?;
513            }
514            path.pop();
515            i += 1;
516        }
517    }
518    Ok(writer.into_inner())
519}
520
521#[pyfunction]
522#[pyo3(signature = (
523    input, *,
524    output=None,
525    encoding="utf-8",
526    full_document=true,
527    attr_prefix="@",
528    cdata_key="#text",
529    pretty=false, indent="  ",
530    compat="native",
531    streaming=false,
532    default=None,
533    item_name="item",
534    sort_attributes=false,
535    namespaces=None
536))]
537fn unparse(
538    _py: Python<'_>,
539    input: &Bound<'_, PyAny>,
540    output: Option<&Bound<'_, PyAny>>,
541    encoding: &str,
542    full_document: bool,
543    attr_prefix: &str,
544    cdata_key: &str,
545    pretty: bool,
546    indent: &str,
547    compat: &str,
548    streaming: bool,
549    default: Option<Py<PyAny>>,
550    item_name: &str,
551    sort_attributes: bool,
552    namespaces: Option<FxHashMap<String, String>>,
553) -> PyResult<String> {
554    let config = Config {
555        attr_prefix: attr_prefix.to_string(),
556        cdata_key: cdata_key.to_string(),
557        default_func: default,
558        item_name: item_name.to_string(),
559        sort_attrs: sort_attributes,
560        namespaces: namespaces.unwrap_or_default(),
561    };
562    let compat_mode = if compat == "legacy" {
563        CompatMode::Obj2Xml
564    } else {
565        CompatMode::Native
566    };
567
568    if streaming {
569        let out_obj = output.ok_or_else(|| {
570            PyErr::new::<pyo3::exceptions::PyValueError, _>(
571                "streaming=True requires output argument",
572            )
573        })?;
574        let sink: Box<dyn Write> = if let Ok(path) = out_obj.extract::<String>() {
575            let f = File::create(path)?;
576            Box::new(BufWriter::new(f))
577        } else {
578            Box::new(PyWriter {
579                obj: out_obj.clone(),
580            })
581        };
582        wrap_err(
583            generate_xml(
584                sink,
585                input,
586                encoding,
587                full_document,
588                pretty,
589                indent,
590                &config,
591                compat_mode,
592                attr_prefix,
593            ),
594            &[],
595        )?;
596        return Ok(String::new());
597    }
598
599    let cursor = Cursor::new(Vec::with_capacity(32 * 1024));
600    let cursor = wrap_py_err(
601        generate_xml(
602            cursor,
603            input,
604            encoding,
605            full_document,
606            pretty,
607            indent,
608            &config,
609            compat_mode,
610            attr_prefix,
611        ),
612        &[],
613    )?;
614    let xml = String::from_utf8(cursor.into_inner())
615        .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
616    if let Some(out_obj) = output {
617        if let Ok(path) = out_obj.extract::<String>() {
618            std::fs::write(path, &xml)?;
619        } else {
620            out_obj.call_method1("write", (xml.as_bytes(),))?;
621        }
622        return Ok(String::new());
623    }
624    Ok(xml)
625}
626
627struct ParseConfig {
628    attr_prefix: String,
629    cdata_key: String,
630    force_cdata: bool,
631    process_namespaces: bool,
632    namespace_separator: String,
633    strip_whitespace: bool,
634    force_list: Option<FxHashSet<String>>,
635    process_comments: bool,
636}
637
638struct StackItem<'py> {
639    dict: Bound<'py, PyDict>,
640    tag_name: String,
641}
642
643fn parse_xml<'py>(
644    py: Python<'py>,
645    reader: &mut Reader<Box<dyn std::io::BufRead>>,
646    config: &ParseConfig,
647) -> PyResult<Py<PyAny>> {
648    let mut buf = Vec::new();
649    let mut stack: Vec<StackItem<'py>> = Vec::new();
650    let mut text_buffer: Option<String> = None;
651
652    let mut ns_stack: Vec<FxHashMap<String, String>> = Vec::new();
653    if config.process_namespaces {
654        ns_stack.push(FxHashMap::default());
655    }
656
657    loop {
658        match reader.read_event_into(&mut buf) {
659            Ok(Event::Start(e)) => {
660                let raw_name = String::from_utf8_lossy(e.name().as_ref()).into_owned();
661
662                let mut current_ns = if config.process_namespaces {
663                    ns_stack.last().cloned().unwrap_or_default()
664                } else {
665                    FxHashMap::default()
666                };
667
668                let mut attributes_vec = Vec::new();
669                for attr in e.attributes() {
670                    let attr = attr.map_err(|e| {
671                        PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string())
672                    })?;
673                    let key = String::from_utf8_lossy(attr.key.as_ref()).into_owned();
674                    let val = String::from_utf8_lossy(&attr.value).into_owned();
675
676                    if config.process_namespaces {
677                        if key == "xmlns" {
678                            current_ns.insert("".to_string(), val.clone());
679                        } else if let Some(prefix) = key.strip_prefix("xmlns:") {
680                            current_ns.insert(prefix.to_string(), val.clone());
681                        }
682                    }
683                    attributes_vec.push((key, val));
684                }
685
686                if config.process_namespaces {
687                    ns_stack.push(current_ns.clone());
688                }
689
690                let tag_name = if config.process_namespaces {
691                    resolve_name(&raw_name, &ns_stack, &config.namespace_separator, true)
692                } else {
693                    raw_name
694                };
695
696                let new_dict = PyDict::new(py);
697
698                for (key, val) in attributes_vec {
699                    let final_key = if config.process_namespaces && !key.starts_with("xmlns") {
700                        resolve_name(&key, &ns_stack, &config.namespace_separator, false)
701                    } else {
702                        key
703                    };
704
705                    let key_str = format!("{}{}", config.attr_prefix, final_key);
706                    new_dict.set_item(key_str, val)?;
707                }
708
709                if let Some(text) = text_buffer.take() {
710                    if let Some(parent) = stack.last() {
711                        if let Some(existing) = parent.dict.get_item(&config.cdata_key)? {
712                            let s = existing.extract::<String>()?;
713                            parent
714                                .dict
715                                .set_item(&config.cdata_key, format!("{}{}", s, text))?;
716                        } else {
717                            parent.dict.set_item(&config.cdata_key, text)?;
718                        }
719                    }
720                }
721
722                stack.push(StackItem {
723                    dict: new_dict,
724                    tag_name,
725                });
726            }
727            Ok(Event::End(_)) => {
728                let StackItem {
729                    dict: current_dict,
730                    tag_name,
731                } = stack.pop().ok_or_else(|| {
732                    PyErr::new::<pyo3::exceptions::PyValueError, _>("Unexpected closing tag")
733                })?;
734
735                if config.process_namespaces {
736                    ns_stack.pop();
737                }
738
739                if let Some(text) = text_buffer.take() {
740                    if current_dict.is_empty() {
741                        if config.force_cdata {
742                            current_dict.set_item(&config.cdata_key, text)?;
743                        } else {
744                            current_dict.set_item(&config.cdata_key, text)?;
745                        }
746                    } else {
747                        current_dict.set_item(&config.cdata_key, text)?;
748                    }
749                }
750
751                if let Some(parent) = stack.last() {
752                    let key = tag_name.as_str();
753
754                    let value_to_insert: Py<PyAny> = if current_dict.len() == 1
755                        && current_dict.contains(&config.cdata_key)?
756                        && !config.force_cdata
757                    {
758                        current_dict.get_item(&config.cdata_key)?.unwrap().into()
759                    } else if current_dict.is_empty() {
760                        py.None().into()
761                    } else {
762                        current_dict.into()
763                    };
764
765                    if let Some(existing) = parent.dict.get_item(key)? {
766                        if let Ok(list) = existing.cast::<PyList>() {
767                            list.append(value_to_insert)?;
768                        } else {
769                            let list = PyList::new(
770                                py,
771                                vec![
772                                    existing.into_pyobject(py)?.into_any(),
773                                    value_to_insert.into_pyobject(py)?.into_any(),
774                                ],
775                            )?;
776                            parent.dict.set_item(key, list)?;
777                        }
778                    } else {
779                        let force = if let Some(fl) = &config.force_list {
780                            fl.contains(key)
781                        } else {
782                            false
783                        };
784                        if force {
785                            let list = PyList::new(
786                                py,
787                                vec![value_to_insert.into_pyobject(py)?.into_any()],
788                            )?;
789                            parent.dict.set_item(key, list)?;
790                        } else {
791                            parent.dict.set_item(key, value_to_insert)?;
792                        }
793                    }
794                } else {
795                    let root_dict = PyDict::new(py);
796                    let value_to_insert: Py<PyAny> = if current_dict.len() == 1
797                        && current_dict.contains(&config.cdata_key)?
798                        && !config.force_cdata
799                    {
800                        current_dict.get_item(&config.cdata_key)?.unwrap().into()
801                    } else if current_dict.is_empty() {
802                        py.None().into()
803                    } else {
804                        current_dict.into()
805                    };
806                    root_dict.set_item(tag_name, value_to_insert)?;
807                    return Ok(root_dict.into());
808                }
809            }
810            Ok(Event::Empty(e)) => {
811                let raw_name = String::from_utf8_lossy(e.name().as_ref()).into_owned();
812
813                let mut current_ns = if config.process_namespaces {
814                    ns_stack.last().cloned().unwrap_or_default()
815                } else {
816                    FxHashMap::default()
817                };
818                let mut attributes_vec = Vec::new();
819                for attr in e.attributes() {
820                    let attr = attr.map_err(|e| {
821                        PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string())
822                    })?;
823                    let key = String::from_utf8_lossy(attr.key.as_ref()).into_owned();
824                    let val = String::from_utf8_lossy(&attr.value).into_owned();
825                    if config.process_namespaces {
826                        if key == "xmlns" {
827                            current_ns.insert("".to_string(), val.clone());
828                        } else if let Some(prefix) = key.strip_prefix("xmlns:") {
829                            current_ns.insert(prefix.to_string(), val.clone());
830                        }
831                    }
832                    attributes_vec.push((key, val));
833                }
834
835                if config.process_namespaces {
836                    ns_stack.push(current_ns);
837                }
838                let tag_name = if config.process_namespaces {
839                    resolve_name(&raw_name, &ns_stack, &config.namespace_separator, true)
840                } else {
841                    raw_name
842                };
843
844                let new_dict = PyDict::new(py);
845                for (key, val) in attributes_vec {
846                    let final_key = if config.process_namespaces && !key.starts_with("xmlns") {
847                        resolve_name(&key, &ns_stack, &config.namespace_separator, false)
848                    } else {
849                        key
850                    };
851                    new_dict.set_item(format!("{}{}", config.attr_prefix, final_key), val)?;
852                }
853
854                if config.process_namespaces {
855                    ns_stack.pop();
856                }
857
858                if let Some(parent) = stack.last() {
859                    let key = tag_name.as_str();
860                    let value_to_insert: Py<PyAny> = if new_dict.is_empty() {
861                        py.None().into()
862                    } else {
863                        new_dict.into()
864                    };
865
866                    if let Some(existing) = parent.dict.get_item(key)? {
867                        if let Ok(list) = existing.cast::<PyList>() {
868                            list.append(value_to_insert)?;
869                        } else {
870                            let list = PyList::new(
871                                py,
872                                vec![
873                                    existing.into_pyobject(py)?.into_any(),
874                                    value_to_insert.into_pyobject(py)?.into_any(),
875                                ],
876                            )?;
877                            parent.dict.set_item(key, list)?;
878                        }
879                    } else {
880                        let force = if let Some(fl) = &config.force_list {
881                            fl.contains(key)
882                        } else {
883                            false
884                        };
885                        if force {
886                            let list = PyList::new(
887                                py,
888                                vec![value_to_insert.into_pyobject(py)?.into_any()],
889                            )?;
890                            parent.dict.set_item(key, list)?;
891                        } else {
892                            parent.dict.set_item(key, value_to_insert)?;
893                        }
894                    }
895                } else {
896                    let root_dict = PyDict::new(py);
897                    root_dict.set_item(tag_name, py.None())?;
898                    return Ok(root_dict.into());
899                }
900            }
901            Ok(Event::Text(e)) => {
902                let text_cow = std::str::from_utf8(&e)
903                    .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
904                let unescaped = quick_xml::escape::unescape(text_cow)
905                    .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
906
907                let text = unescaped.as_ref();
908                let trimmed = if config.strip_whitespace {
909                    text.trim()
910                } else {
911                    text
912                };
913                if !trimmed.is_empty() {
914                    if let Some(ref mut buf) = text_buffer {
915                        buf.push_str(trimmed);
916                    } else {
917                        text_buffer = Some(trimmed.to_string());
918                    }
919                }
920            }
921            Ok(Event::CData(e)) => {
922                let text = String::from_utf8_lossy(&e);
923                let trimmed = if config.strip_whitespace {
924                    text.trim()
925                } else {
926                    &text
927                };
928                if !trimmed.is_empty() {
929                    if let Some(ref mut buf) = text_buffer {
930                        buf.push_str(trimmed);
931                    } else {
932                        text_buffer = Some(trimmed.to_string());
933                    }
934                }
935            }
936
937            Ok(Event::Comment(e)) => {
938                if config.process_comments {
939                    let comment = String::from_utf8_lossy(&e).into_owned();
940                    if let Some(parent) = stack.last() {
941                        if let Some(existing) = parent.dict.get_item("#comment")? {
942                            if let Ok(list) = existing.cast::<PyList>() {
943                                list.append(comment)?;
944                            } else {
945                                let list = PyList::new(
946                                    py,
947                                    vec![
948                                        existing.into_pyobject(py)?.into_any(),
949                                        comment.into_pyobject(py)?.into_any(),
950                                    ],
951                                )?;
952                                parent.dict.set_item("#comment", list)?;
953                            }
954                        } else {
955                            parent.dict.set_item("#comment", comment)?;
956                        }
957                    }
958                }
959            }
960
961            Ok(Event::PI(e)) => {
962                let content = String::from_utf8_lossy(&e);
963                let (target, value) = if let Some((t, v)) = content.split_once(' ') {
964                    (t, v)
965                } else {
966                    (content.as_ref(), "")
967                };
968                let key = format!("?{}", target);
969                let val = value.to_string();
970
971                if let Some(parent) = stack.last() {
972                    if let Some(existing) = parent.dict.get_item(&key)? {
973                        if let Ok(list) = existing.cast::<PyList>() {
974                            list.append(val)?;
975                        } else {
976                            let list = PyList::new(
977                                py,
978                                vec![
979                                    existing.into_pyobject(py)?.into_any(),
980                                    val.into_pyobject(py)?.into_any(),
981                                ],
982                            )?;
983                            parent.dict.set_item(&key, list)?;
984                        }
985                    } else {
986                        parent.dict.set_item(&key, val)?;
987                    }
988                }
989            }
990            Ok(Event::Eof) => break,
991            Err(e) => {
992                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
993                    "XML Parse Error: {}",
994                    e
995                )));
996            }
997            _ => {}
998        }
999        buf.clear();
1000    }
1001    Ok(PyDict::new(py).into())
1002}
1003
/// Expands an XML name into its namespace-qualified form.
///
/// * `name` - raw element or attribute name, optionally of the form `prefix:local`.
/// * `ns_stack` - namespace scopes, outermost first; each maps a prefix to a URI,
///   with the empty string as the key of the default namespace.
/// * `separator` - text placed between the namespace URI and the local name.
/// * `is_element` - when `true`, an unprefixed name picks up the default namespace;
///   unprefixed names with `is_element == false` are returned unchanged.
///
/// Returns `{uri}{separator}{local}` when a binding is found. The name is returned
/// unchanged when its prefix is not bound in any scope, or when the innermost
/// default-namespace declaration is the empty string (`xmlns=""`), which
/// un-declares the default namespace.
fn resolve_name(
    name: &str,
    ns_stack: &[FxHashMap<String, String>],
    separator: &str,
    is_element: bool,
) -> String {
    match name.split_once(':') {
        Some((prefix, local)) => {
            // Search from the innermost scope outwards; the closest binding wins.
            match ns_stack.iter().rev().find_map(|scope| scope.get(prefix)) {
                Some(uri) => format!("{}{}{}", uri, separator, local),
                None => name.to_string(),
            }
        }
        None if is_element => {
            match ns_stack.iter().rev().find_map(|scope| scope.get("")) {
                // An empty URI means the default namespace was un-declared, so the
                // element is in no namespace and must not gain a leading separator.
                Some(uri) if !uri.is_empty() => format!("{}{}{}", uri, separator, name),
                _ => name.to_string(),
            }
        }
        None => name.to_string(),
    }
}
1025
1026#[pyfunction]
1027#[pyo3(signature = (
1028    xml_input, *,
1029    _encoding=None,
1030    attr_prefix="@",
1031    cdata_key="#text",
1032    force_cdata=false,
1033    process_namespaces=false,
1034    namespace_separator=":",
1035    strip_whitespace=true,
1036    force_list=None,
1037    process_comments=false
1038))]
1039fn parse(
1040    py: Python<'_>,
1041    xml_input: &Bound<'_, PyAny>,
1042    _encoding: Option<&str>,
1043    attr_prefix: &str,
1044    cdata_key: &str,
1045    force_cdata: bool,
1046    process_namespaces: bool,
1047    namespace_separator: &str,
1048    strip_whitespace: bool,
1049    force_list: Option<Vec<String>>,
1050    process_comments: bool,
1051) -> PyResult<Py<PyAny>> {
1052    let config = ParseConfig {
1053        attr_prefix: attr_prefix.to_string(),
1054        cdata_key: cdata_key.to_string(),
1055        force_cdata,
1056        process_namespaces,
1057        namespace_separator: namespace_separator.to_string(),
1058        strip_whitespace,
1059        force_list: force_list.map(|v| v.into_iter().collect()),
1060        process_comments,
1061    };
1062
1063    if let Ok(s) = xml_input.extract::<String>() {
1064        let mut reader = Reader::from_str(&s);
1065        reader.config_mut().trim_text(strip_whitespace);
1066        let mut boxed_reader: Reader<Box<dyn std::io::BufRead>> =
1067            Reader::from_reader(Box::new(Cursor::new(s.into_bytes())));
1068        boxed_reader.config_mut().trim_text(strip_whitespace);
1069        boxed_reader.config_mut().expand_empty_elements = true;
1070        return parse_xml(py, &mut boxed_reader, &config);
1071    }
1072
1073    if let Ok(b) = xml_input.extract::<Vec<u8>>() {
1074        let mut boxed_reader =
1075            Reader::from_reader(Box::new(Cursor::new(b)) as Box<dyn std::io::BufRead>);
1076        boxed_reader.config_mut().trim_text(strip_whitespace);
1077        boxed_reader.config_mut().expand_empty_elements = true;
1078        return parse_xml(py, &mut boxed_reader, &config);
1079    }
1080
1081    if xml_input.hasattr("read")? {
1082        let bytes: Vec<u8> = xml_input.call_method0("read")?.extract()?;
1083        let mut boxed_reader =
1084            Reader::from_reader(Box::new(Cursor::new(bytes)) as Box<dyn std::io::BufRead>);
1085        boxed_reader.config_mut().trim_text(strip_whitespace);
1086        boxed_reader.config_mut().expand_empty_elements = true;
1087        return parse_xml(py, &mut boxed_reader, &config);
1088    }
1089
1090    Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
1091        "Input must be str, bytes, or file-like object",
1092    ))
1093}
1094
1095#[pymodule]
1096fn _obj2xml_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
1097    m.add_function(wrap_pyfunction!(unparse, m)?)?;
1098    m.add_function(wrap_pyfunction!(parse, m)?)?;
1099    Ok(())
1100}