_obj2xml_rs/
lib.rs

1use pyo3::prelude::*;
2use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyIterator, PyList, PyNone, PyString};
3use quick_xml::Writer;
4use quick_xml::events::{BytesCData, BytesDecl, BytesEnd, BytesPI, BytesStart, BytesText, Event};
5use rustc_hash::{FxHashMap, FxHashSet};
6use std::borrow::Cow;
7use std::fs::File;
8use std::io::{BufWriter, Cursor, Write};
9
10struct Config {
11    attr_prefix: String,
12    cdata_key: String,
13    default_func: Option<Py<PyAny>>,
14    item_name: String,
15    sort_attrs: bool,
16    namespaces: FxHashMap<String, String>,
17}
18
/// Selects how a Python `None` value is rendered as an element.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompatMode {
    /// `None` becomes a self-closing element: `<tag/>`.
    Native,
    /// `None` becomes an explicit start/end pair: `<tag></tag>` (chosen by `compat="legacy"`).
    Obj2Xml,
}
24
25#[inline]
26fn extract_str<'py>(value: &'py pyo3::Bound<'py, pyo3::PyAny>) -> PyResult<Cow<'py, str>> {
27    if let Ok(pystr) = value.cast::<PyString>() {
28        let slice = pystr.to_str()?;
29        Ok(Cow::Borrowed(slice))
30    } else {
31        Ok(Cow::Owned(value.to_string()))
32    }
33}
34
35struct PyWriter<'py> {
36    obj: Bound<'py, PyAny>,
37}
38
39impl<'py> Write for PyWriter<'py> {
40    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
41        self.obj
42            .call_method1("write", (buf,))
43            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
44        Ok(buf.len())
45    }
46    fn flush(&mut self) -> std::io::Result<()> {
47        if self.obj.hasattr("flush").unwrap_or(false) {
48            self.obj
49                .call_method0("flush")
50                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
51        }
52        Ok(())
53    }
54}
55
56#[derive(Clone)]
57struct NamespaceContext {
58    uri_to_prefix: FxHashMap<String, String>,
59    prefix_to_uri: FxHashMap<String, String>,
60    next_auto: usize,
61}
62
63impl NamespaceContext {
64    fn new(predefined: &FxHashMap<String, String>) -> Self {
65        let mut uri_to_prefix = FxHashMap::default();
66        let mut prefix_to_uri = FxHashMap::default();
67        for (prefix, uri) in predefined {
68            uri_to_prefix.insert(uri.clone(), prefix.clone());
69            prefix_to_uri.insert(prefix.clone(), uri.clone());
70        }
71        Self {
72            uri_to_prefix,
73            prefix_to_uri,
74            next_auto: 0,
75        }
76    }
77
78    fn get_or_create_prefix(&mut self, uri: &str) -> String {
79        if let Some(p) = self.uri_to_prefix.get(uri) {
80            return p.clone();
81        }
82        let prefix = format!("ns{}", self.next_auto);
83        self.next_auto += 1;
84        self.uri_to_prefix.insert(uri.to_string(), prefix.clone());
85        self.prefix_to_uri.insert(prefix.clone(), uri.to_string());
86        prefix
87    }
88}
89
/// Render the breadcrumb stack used in error messages.
///
/// Segments are joined with `/`; an empty stack denotes the document root.
fn format_path(stack: &[String]) -> String {
    match stack {
        [] => String::from("root"),
        segments => segments.join("/"),
    }
}
97
98#[inline(always)]
99fn wrap_err<T, E: std::fmt::Display>(res: Result<T, E>, stack: &[String]) -> PyResult<T> {
100    res.map_err(|e| {
101        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
102            "{} (at {})",
103            e,
104            format_path(stack)
105        ))
106    })
107}
108
109#[inline(always)]
110fn wrap_py_err<T>(res: PyResult<T>, stack: &[String]) -> PyResult<T> {
111    res.map_err(|e| {
112        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
113            "{} (at {})",
114            e,
115            format_path(stack)
116        ))
117    })
118}
119
120#[inline]
121fn py_value_to_cow<'py>(
122    py: Python<'py>,
123    value: &'py pyo3::Bound<'py, pyo3::PyAny>,
124    default_func: &Option<Py<PyAny>>,
125    path: &[String],
126) -> PyResult<Option<Cow<'py, str>>> {
127    if value.is_instance_of::<PyNone>() {
128        return Ok(None);
129    }
130    if value.is_instance_of::<PyBool>() {
131        return Ok(Some(if value.extract::<bool>()? {
132            "true".into()
133        } else {
134            "false".into()
135        }));
136    }
137    if let Ok(pystr) = value.cast::<PyString>() {
138        return Ok(Some(Cow::Borrowed(pystr.to_str()?)));
139    }
140    if value.is_instance_of::<PyInt>() || value.is_instance_of::<PyFloat>() {
141        return Ok(Some(Cow::Owned(value.to_string())));
142    }
143
144    if let Some(func) = default_func {
145        match func.call1(py, (value,)) {
146            Ok(serialized) => return Ok(Some(Cow::Owned(serialized.to_string()))),
147            Err(e) => {
148                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
149                    "Custom serialization failed: {} (at {})",
150                    e,
151                    format_path(path)
152                )));
153            }
154        }
155    }
156    Ok(Some(Cow::Owned(value.to_string())))
157}
158
159#[inline]
160fn qualify_tag<'a>(
161    tag: &'a str,
162    dict: Option<&Bound<'_, PyDict>>,
163    ns: &mut NamespaceContext,
164) -> Cow<'a, str> {
165    if tag.contains(':') {
166        return Cow::Borrowed(tag);
167    }
168    if let Some(d) = dict {
169        if let Ok(Some(ns_val)) = d.get_item("@ns") {
170            let ns_str = ns_val.to_string();
171            if ns.prefix_to_uri.contains_key(&ns_str) {
172                return Cow::Owned(format!("{}:{}", ns_str, tag));
173            }
174            let prefix = ns.get_or_create_prefix(&ns_str);
175            return Cow::Owned(format!("{}:{}", prefix, tag));
176        }
177    }
178    Cow::Borrowed(tag)
179}
180
181fn process_node<W: Write>(
182    writer: &mut Writer<W>,
183    tag_name: &str,
184    value: &Bound<'_, PyAny>,
185    config: &Config,
186    compat: CompatMode,
187    ns: &mut NamespaceContext,
188    is_root: bool,
189    path: &mut Vec<String>,
190    visited: &mut FxHashSet<usize>,
191) -> PyResult<()> {
192    let py = value.py();
193
194    if value.is_instance_of::<PyString>()
195        || value.is_instance_of::<PyBool>()
196        || value.is_instance_of::<PyNone>()
197        || value.is_instance_of::<PyInt>()
198        || value.is_instance_of::<PyFloat>()
199    {
200    } else if let Ok(dict) = value.cast::<PyDict>() {
201        let ptr = dict.as_ptr() as usize;
202        if !visited.insert(ptr) {
203            return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
204                format!("Circular reference detected (at {})", format_path(path)),
205            ));
206        }
207
208        let mut tail_text: Option<String> = None;
209        if let Ok(Some(tail_val)) = dict.get_item("#tail") {
210            path.push("#tail".to_string());
211            if let Some(cow) = py_value_to_cow(py, &tail_val, &config.default_func, path)? {
212                tail_text = Some(cow.into_owned());
213            }
214            path.pop();
215        }
216        let mut ns_local = ns.clone();
217        let qualified = qualify_tag(tag_name, Some(&dict), &mut ns_local);
218        let mut elem = BytesStart::new(qualified.as_ref());
219        let mut attrs = Vec::new();
220        let mut xmlns_attrs = Vec::new();
221
222        if is_root {
223            for (prefix, uri) in &config.namespaces {
224                if prefix.is_empty() {
225                    xmlns_attrs.push(("xmlns".to_string(), uri.clone()));
226                } else {
227                    xmlns_attrs.push((format!("xmlns:{}", prefix), uri.clone()));
228                }
229            }
230        }
231
232        for (k, v) in dict {
233            let k_cow = extract_str(&k)?;
234            let k_str = k_cow.as_ref();
235            if k_str == "#comment" || k_str == "#tail" || k_str.starts_with("?") {
236                continue;
237            }
238
239            if k_str == "@xmlns" {
240                let uri = v.to_string();
241                if !config.namespaces.values().any(|u| u == &uri) {
242                    let prefix = ns_local.get_or_create_prefix(&uri);
243                    xmlns_attrs.push((format!("xmlns:{}", prefix), uri));
244                }
245            } else if let Some(p) = k_str.strip_prefix("@xmlns:") {
246                let uri = v.to_string();
247                ns_local.uri_to_prefix.insert(uri.clone(), p.to_string());
248                ns_local.prefix_to_uri.insert(p.to_string(), uri.clone());
249                xmlns_attrs.push((format!("xmlns:{}", p), uri));
250            } else if let Some(attr) = k_str.strip_prefix(&config.attr_prefix) {
251                path.push(format!("@{}", attr));
252                if let Some(val) = py_value_to_cow(py, &v, &config.default_func, path)? {
253                    attrs.push((attr.to_string(), val.into_owned()));
254                }
255                path.pop();
256            }
257        }
258
259        if config.sort_attrs {
260            attrs.sort_by(|a, b| a.0.cmp(&b.0));
261            xmlns_attrs.sort_by(|a, b| a.0.cmp(&b.0));
262        }
263        for (k, v) in xmlns_attrs {
264            elem.push_attribute((k.as_str(), v.as_str()));
265        }
266        for (k, v) in attrs {
267            elem.push_attribute((k.as_str(), v.as_str()));
268        }
269
270        let has_children = dict.iter().any(|(k, _)| {
271            let k_cow = k
272                .cast::<PyString>()
273                .map(|s| s.to_string_lossy())
274                .unwrap_or_default();
275            let ks = k_cow.as_ref();
276            !ks.starts_with(&config.attr_prefix) && ks != "#tail"
277        });
278
279        if !has_children {
280            wrap_err(writer.write_event(Event::Empty(elem)), path)?;
281            visited.remove(&ptr);
282            if let Some(text) = tail_text {
283                wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
284            }
285            return Ok(());
286        }
287
288        wrap_err(writer.write_event(Event::Start(elem)), path)?;
289
290        for (k, v) in dict {
291            let k_cow = extract_str(&k)?;
292            let k_str = k_cow.as_ref();
293
294            if k_str.starts_with(&config.attr_prefix) {
295                continue;
296            }
297            if k_str == "#tail" {
298                continue;
299            }
300
301            path.push(k_str.to_string());
302
303            if k_str == "#comment" {
304                if let Some(comment_txt) = py_value_to_cow(py, &v, &config.default_func, path)? {
305                    wrap_err(
306                        writer.write_event(Event::Comment(BytesText::new(&comment_txt))),
307                        path,
308                    )?;
309                }
310            } else if k_str.starts_with("?") {
311                if let Some(content) = py_value_to_cow(py, &v, &config.default_func, path)? {
312                    let target = k_str.strip_prefix("?").unwrap_or(&k_str);
313                    let pi_content = format!("{} {}", target, content);
314                    wrap_err(
315                        writer.write_event(Event::PI(BytesPI::new(&pi_content))),
316                        path,
317                    )?;
318                }
319            } else if k_str == config.cdata_key {
320                let mut written = false;
321                if let Ok(inner) = v.cast::<PyDict>() {
322                    if let Ok(Some(cdata)) = inner.get_item("__cdata__") {
323                        let s = cdata.to_string();
324                        wrap_err(writer.write_event(Event::CData(BytesCData::new(&s))), path)?;
325                        written = true;
326                    }
327                }
328                if !written {
329                    if let Some(text) = py_value_to_cow(py, &v, &config.default_func, path)? {
330                        wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
331                    }
332                }
333            } else {
334                process_node(
335                    writer,
336                    k_str,
337                    &v,
338                    config,
339                    compat,
340                    &mut ns_local,
341                    false,
342                    path,
343                    visited,
344                )?;
345            }
346            path.pop();
347        }
348        wrap_err(
349            writer.write_event(Event::End(BytesEnd::new(qualified.as_ref()))),
350            path,
351        )?;
352        visited.remove(&ptr);
353
354        if let Some(text) = tail_text {
355            wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
356        }
357
358        return Ok(());
359    } else if let Ok(list) = value.cast::<PyList>() {
360        let ptr = list.as_ptr() as usize;
361        if !visited.insert(ptr) {
362            return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
363                format!("Circular ref (at {})", format_path(path)),
364            ));
365        }
366        for (i, item) in list.iter().enumerate() {
367            path.push(format!("[{}]", i));
368            process_node(
369                writer, tag_name, &item, config, compat, ns, is_root, path, visited,
370            )?;
371            path.pop();
372        }
373        visited.remove(&ptr);
374        return Ok(());
375    } else if let Ok(iter) = PyIterator::from_object(value) {
376        let mut i = 0;
377        for item in iter {
378            let obj = item?;
379            path.push(format!("[{}]", i));
380            process_node(
381                writer, tag_name, &obj, config, compat, ns, is_root, path, visited,
382            )?;
383            path.pop();
384            i += 1;
385        }
386        return Ok(());
387    }
388
389    if value.is_instance_of::<PyNone>() {
390        match compat {
391            CompatMode::Native => {
392                wrap_err(
393                    writer.write_event(Event::Empty(BytesStart::new(tag_name))),
394                    path,
395                )?;
396            }
397            CompatMode::Obj2Xml => {
398                wrap_err(
399                    writer.write_event(Event::Start(BytesStart::new(tag_name))),
400                    path,
401                )?;
402                wrap_err(
403                    writer.write_event(Event::End(BytesEnd::new(tag_name))),
404                    path,
405                )?;
406            }
407        }
408    } else if let Some(text) = py_value_to_cow(py, value, &config.default_func, path)? {
409        let elem = BytesStart::new(tag_name);
410        wrap_err(writer.write_event(Event::Start(elem)), path)?;
411        wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
412        wrap_err(
413            writer.write_event(Event::End(BytesEnd::new(tag_name))),
414            path,
415        )?;
416    }
417    Ok(())
418}
419
420fn generate_xml<W: Write>(
421    writer: W,
422    input: &Bound<'_, PyAny>,
423    encoding: &str,
424    full_document: bool,
425    pretty: bool,
426    indent: &str,
427    config: &Config,
428    compat: CompatMode,
429    attr_prefix: &str,
430) -> PyResult<W> {
431    let mut writer = if pretty {
432        Writer::new_with_indent(writer, b' ', indent.len())
433    } else {
434        Writer::new(writer)
435    };
436    if full_document {
437        writer
438            .write_event(Event::Decl(BytesDecl::new("1.0", Some(encoding), None)))
439            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
440    }
441    let mut ns = NamespaceContext::new(&config.namespaces);
442    let mut path = Vec::with_capacity(16);
443    let mut visited = FxHashSet::default();
444
445    if let Ok(dict) = input.cast::<PyDict>() {
446        let roots = dict
447            .iter()
448            .filter(|(k, _)| !k.to_string().starts_with(attr_prefix))
449            .count();
450        if full_document && roots != 1 {
451            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
452                "Document must have exactly one root",
453            ));
454        }
455        for (k, v) in dict {
456            let key_cow = extract_str(&k)?;
457            let key = key_cow.as_ref();
458            if key.starts_with(attr_prefix) {
459                continue;
460            }
461            path.push(key.to_string());
462            process_node(
463                &mut writer,
464                key,
465                &v,
466                config,
467                compat,
468                &mut ns,
469                true,
470                &mut path,
471                &mut visited,
472            )?;
473            path.pop();
474        }
475    } else {
476        let iter = PyIterator::from_object(input)?;
477        let mut i = 0;
478        for item in iter {
479            let obj = item?;
480            path.push(format!("[{}]", i));
481            if let Ok(d) = obj.cast::<PyDict>() {
482                for (k, v) in d {
483                    let k_cow = extract_str(&k)?;
484                    let k_str = k_cow.as_ref();
485                    path.push(k_str.to_string());
486                    process_node(
487                        &mut writer,
488                        k_str,
489                        &v,
490                        config,
491                        compat,
492                        &mut ns,
493                        true,
494                        &mut path,
495                        &mut visited,
496                    )?;
497                    path.pop();
498                }
499            } else {
500                process_node(
501                    &mut writer,
502                    &config.item_name,
503                    &obj,
504                    config,
505                    compat,
506                    &mut ns,
507                    true,
508                    &mut path,
509                    &mut visited,
510                )?;
511            }
512            path.pop();
513            i += 1;
514        }
515    }
516    Ok(writer.into_inner())
517}
518
519#[pyfunction]
520#[pyo3(signature = (
521    input, *,
522    output=None,
523    encoding="utf-8",
524    full_document=true,
525    attr_prefix="@",
526    cdata_key="#text",
527    pretty=false, indent="  ",
528    compat="native",
529    streaming=false,
530    default=None,
531    item_name="item",
532    sort_attributes=false,
533    namespaces=None
534))]
535fn unparse(
536    _py: Python<'_>,
537    input: &Bound<'_, PyAny>,
538    output: Option<&Bound<'_, PyAny>>,
539    encoding: &str,
540    full_document: bool,
541    attr_prefix: &str,
542    cdata_key: &str,
543    pretty: bool,
544    indent: &str,
545    compat: &str,
546    streaming: bool,
547    default: Option<Py<PyAny>>,
548    item_name: &str,
549    sort_attributes: bool,
550    namespaces: Option<FxHashMap<String, String>>,
551) -> PyResult<String> {
552    let config = Config {
553        attr_prefix: attr_prefix.to_string(),
554        cdata_key: cdata_key.to_string(),
555        default_func: default,
556        item_name: item_name.to_string(),
557        sort_attrs: sort_attributes,
558        namespaces: namespaces.unwrap_or_default(),
559    };
560    let compat_mode = if compat == "legacy" {
561        CompatMode::Obj2Xml
562    } else {
563        CompatMode::Native
564    };
565
566    if streaming {
567        let out_obj = output.ok_or_else(|| {
568            PyErr::new::<pyo3::exceptions::PyValueError, _>(
569                "streaming=True requires output argument",
570            )
571        })?;
572        let sink: Box<dyn Write> = if let Ok(path) = out_obj.extract::<String>() {
573            let f = File::create(path)?;
574            Box::new(BufWriter::new(f))
575        } else {
576            Box::new(PyWriter {
577                obj: out_obj.clone(),
578            })
579        };
580        wrap_err(
581            generate_xml(
582                sink,
583                input,
584                encoding,
585                full_document,
586                pretty,
587                indent,
588                &config,
589                compat_mode,
590                attr_prefix,
591            ),
592            &[],
593        )?;
594        return Ok(String::new());
595    }
596
597    let cursor = Cursor::new(Vec::with_capacity(32 * 1024));
598    let cursor = wrap_py_err(
599        generate_xml(
600            cursor,
601            input,
602            encoding,
603            full_document,
604            pretty,
605            indent,
606            &config,
607            compat_mode,
608            attr_prefix,
609        ),
610        &[],
611    )?;
612    let xml = String::from_utf8(cursor.into_inner())
613        .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
614    if let Some(out_obj) = output {
615        if let Ok(path) = out_obj.extract::<String>() {
616            std::fs::write(path, &xml)?;
617        } else {
618            out_obj.call_method1("write", (xml.as_bytes(),))?;
619        }
620        return Ok(String::new());
621    }
622    Ok(xml)
623}
624
625#[pymodule]
626fn _obj2xml_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
627    m.add_function(wrap_pyfunction!(unparse, m)?)?;
628    Ok(())
629}