_obj2xml_rs/
lib.rs

1use pyo3::prelude::*;
2use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyIterator, PyList, PyNone, PyString};
3use quick_xml::Writer;
4use quick_xml::events::{BytesCData, BytesDecl, BytesEnd, BytesPI, BytesStart, BytesText, Event};
5use std::borrow::Cow;
6use std::collections::{HashMap, HashSet};
7use std::fs::File;
8use std::io::{BufWriter, Cursor, Write};
9
10struct Config {
11    attr_prefix: String,
12    cdata_key: String,
13    default_func: Option<Py<PyAny>>,
14    item_name: String,
15    sort_attrs: bool,
16    namespaces: HashMap<String, String>,
17}
18
/// Output flavour selected by the `compat` argument of `unparse`.
///
/// The variants only differ in how a `None` value is rendered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CompatMode {
    /// `None` becomes a self-closing element (`<tag/>`).
    Native,
    /// `None` becomes an explicit start/end pair (`<tag></tag>`).
    Obj2Xml,
}
24
25fn extract_str<'py>(value: &'py pyo3::Bound<'py, pyo3::PyAny>) -> PyResult<Cow<'py, str>> {
26    if let Ok(pystr) = value.cast::<PyString>() {
27        let slice = pystr.to_str()?;
28        Ok(Cow::Borrowed(slice))
29    } else {
30        Ok(Cow::Owned(value.to_string()))
31    }
32}
33
34struct PyWriter<'py> {
35    obj: Bound<'py, PyAny>,
36}
37
38impl<'py> Write for PyWriter<'py> {
39    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
40        self.obj
41            .call_method1("write", (buf,))
42            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
43        Ok(buf.len())
44    }
45    fn flush(&mut self) -> std::io::Result<()> {
46        if self.obj.hasattr("flush").unwrap_or(false) {
47            self.obj
48                .call_method0("flush")
49                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
50        }
51        Ok(())
52    }
53}
54
/// Bidirectional prefix/URI table describing the namespaces known at a point in the tree.
#[derive(Clone)]
struct NamespaceContext {
    /// Namespace URI -> prefix used to refer to it.
    uri_to_prefix: HashMap<String, String>,
    /// Prefix -> namespace URI it is bound to.
    prefix_to_uri: HashMap<String, String>,
    /// Counter used to build the next automatically generated `nsN` prefix.
    next_auto: usize,
}

impl NamespaceContext {
    /// Builds a context from a `prefix -> URI` map.
    ///
    /// If several prefixes share one URI, the prefix kept for the reverse lookup depends on the
    /// iteration order of `predefined`.
    fn new(predefined: &HashMap<String, String>) -> Self {
        let mut uri_to_prefix = HashMap::with_capacity(predefined.len());
        let mut prefix_to_uri = HashMap::with_capacity(predefined.len());
        for (prefix, uri) in predefined {
            uri_to_prefix.insert(uri.clone(), prefix.clone());
            prefix_to_uri.insert(prefix.clone(), uri.clone());
        }
        Self {
            uri_to_prefix,
            prefix_to_uri,
            next_auto: 0,
        }
    }

    /// Returns the prefix bound to `uri`, generating and registering a fresh `nsN` prefix when
    /// the URI is not known yet.
    ///
    /// Generated prefixes never reuse a prefix that is already bound, so a predefined prefix
    /// such as `ns0` keeps its original URI.
    fn get_or_create_prefix(&mut self, uri: &str) -> String {
        if let Some(existing) = self.uri_to_prefix.get(uri) {
            return existing.clone();
        }
        let prefix = loop {
            let candidate = format!("ns{}", self.next_auto);
            self.next_auto += 1;
            // Skip names that are already taken instead of silently rebinding them.
            if !self.prefix_to_uri.contains_key(&candidate) {
                break candidate;
            }
        };
        self.uri_to_prefix.insert(uri.to_string(), prefix.clone());
        self.prefix_to_uri.insert(prefix.clone(), uri.to_string());
        prefix
    }
}
88
/// Renders the current location in the input as a `/`-separated path for error messages.
///
/// An empty stack is reported as `"root"`.
fn format_path(stack: &[String]) -> String {
    match stack {
        [] => String::from("root"),
        segments => segments.join("/"),
    }
}
96
97fn wrap_err<T, E: std::fmt::Display>(res: Result<T, E>, stack: &[String]) -> PyResult<T> {
98    res.map_err(|e| {
99        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
100            "{} (at {})",
101            e,
102            format_path(stack)
103        ))
104    })
105}
106
107fn wrap_py_err<T>(res: PyResult<T>, stack: &[String]) -> PyResult<T> {
108    res.map_err(|e| {
109        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
110            "{} (at {})",
111            e,
112            format_path(stack)
113        ))
114    })
115}
116
117fn py_value_to_cow<'py>(
118    py: Python<'py>,
119    value: &'py pyo3::Bound<'py, pyo3::PyAny>,
120    default_func: &Option<Py<PyAny>>,
121    path: &[String],
122) -> PyResult<Option<Cow<'py, str>>> {
123    if value.is_instance_of::<PyNone>() {
124        return Ok(None);
125    }
126    if value.is_instance_of::<PyBool>() {
127        return Ok(Some(if value.extract::<bool>()? {
128            "true".into()
129        } else {
130            "false".into()
131        }));
132    }
133    if let Ok(pystr) = value.cast::<PyString>() {
134        return Ok(Some(Cow::Borrowed(pystr.to_str()?)));
135    }
136    if value.is_instance_of::<PyInt>() || value.is_instance_of::<PyFloat>() {
137        return Ok(Some(Cow::Owned(value.to_string())));
138    }
139
140    if let Some(func) = default_func {
141        match func.call1(py, (value,)) {
142            Ok(serialized) => return Ok(Some(Cow::Owned(serialized.to_string()))),
143            Err(e) => {
144                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
145                    "Custom serialization failed: {} (at {})",
146                    e,
147                    format_path(path)
148                )));
149            }
150        }
151    }
152    Ok(Some(Cow::Owned(value.to_string())))
153}
154
155fn qualify_tag<'a>(
156    tag: &'a str,
157    dict: Option<&Bound<'_, PyDict>>,
158    ns: &mut NamespaceContext,
159) -> Cow<'a, str> {
160    if tag.contains(':') {
161        return Cow::Borrowed(tag);
162    }
163    if let Some(d) = dict {
164        if let Ok(Some(ns_val)) = d.get_item("@ns") {
165            let ns_str = ns_val.to_string();
166            if ns.prefix_to_uri.contains_key(&ns_str) {
167                return Cow::Owned(format!("{}:{}", ns_str, tag));
168            }
169            let prefix = ns.get_or_create_prefix(&ns_str);
170            return Cow::Owned(format!("{}:{}", prefix, tag));
171        }
172    }
173    Cow::Borrowed(tag)
174}
175
176fn process_node<W: Write>(
177    writer: &mut Writer<W>,
178    tag_name: &str,
179    value: &Bound<'_, PyAny>,
180    config: &Config,
181    compat: CompatMode,
182    ns: &mut NamespaceContext,
183    is_root: bool,
184    path: &mut Vec<String>,
185    visited: &mut HashSet<usize>,
186) -> PyResult<()> {
187    let py = value.py();
188
189    if value.is_instance_of::<PyString>()
190        || value.is_instance_of::<PyBool>()
191        || value.is_instance_of::<PyNone>()
192        || value.is_instance_of::<PyInt>()
193        || value.is_instance_of::<PyFloat>()
194    {
195    } else if let Ok(dict) = value.cast::<PyDict>() {
196        let ptr = dict.as_ptr() as usize;
197        if !visited.insert(ptr) {
198            return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
199                format!("Circular reference detected (at {})", format_path(path)),
200            ));
201        }
202
203        let mut tail_text: Option<String> = None;
204        if let Ok(Some(tail_val)) = dict.get_item("#tail") {
205            path.push("#tail".to_string());
206            if let Some(cow) = py_value_to_cow(py, &tail_val, &config.default_func, path)? {
207                tail_text = Some(cow.into_owned());
208            }
209            path.pop();
210        }
211        let mut ns_local = ns.clone();
212        let qualified = qualify_tag(tag_name, Some(&dict), &mut ns_local);
213        let mut elem = BytesStart::new(qualified.as_ref());
214        let mut attrs = Vec::new();
215        let mut xmlns_attrs = Vec::new();
216
217        if is_root {
218            for (prefix, uri) in &config.namespaces {
219                if prefix.is_empty() {
220                    xmlns_attrs.push(("xmlns".to_string(), uri.clone()));
221                } else {
222                    xmlns_attrs.push((format!("xmlns:{}", prefix), uri.clone()));
223                }
224            }
225        }
226
227        for (k, v) in dict {
228            let k_cow = extract_str(&k)?;
229            let k_str = k_cow.as_ref();
230            if k_str == "#comment" || k_str == "#tail" || k_str.starts_with("?") {
231                continue;
232            }
233
234            if k_str == "@xmlns" {
235                let uri = v.to_string();
236                if !config.namespaces.values().any(|u| u == &uri) {
237                    let prefix = ns_local.get_or_create_prefix(&uri);
238                    xmlns_attrs.push((format!("xmlns:{}", prefix), uri));
239                }
240            } else if let Some(p) = k_str.strip_prefix("@xmlns:") {
241                let uri = v.to_string();
242                ns_local.uri_to_prefix.insert(uri.clone(), p.to_string());
243                ns_local.prefix_to_uri.insert(p.to_string(), uri.clone());
244                xmlns_attrs.push((format!("xmlns:{}", p), uri));
245            } else if let Some(attr) = k_str.strip_prefix(&config.attr_prefix) {
246                path.push(format!("@{}", attr));
247                if let Some(val) = py_value_to_cow(py, &v, &config.default_func, path)? {
248                    attrs.push((attr.to_string(), val.into_owned()));
249                }
250                path.pop();
251            }
252        }
253
254        if config.sort_attrs {
255            attrs.sort_by(|a, b| a.0.cmp(&b.0));
256            xmlns_attrs.sort_by(|a, b| a.0.cmp(&b.0));
257        }
258        for (k, v) in xmlns_attrs {
259            elem.push_attribute((k.as_str(), v.as_str()));
260        }
261        for (k, v) in attrs {
262            elem.push_attribute((k.as_str(), v.as_str()));
263        }
264
265        let has_children = dict.iter().any(|(k, _)| {
266            let k_cow = extract_str(&k).unwrap_or(Cow::Borrowed(""));
267            let ks = k_cow.as_ref();
268            !ks.starts_with(&config.attr_prefix) && ks != "#tail"
269        });
270
271        if !has_children {
272            wrap_err(writer.write_event(Event::Empty(elem)), path)?;
273            visited.remove(&ptr);
274            // Write tail after self-closing
275            if let Some(text) = tail_text {
276                wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
277            }
278            return Ok(());
279        }
280
281        wrap_err(writer.write_event(Event::Start(elem)), path)?;
282
283        for (k, v) in dict {
284            let k_cow = extract_str(&k)?;
285            let k_str = k_cow.as_ref();
286
287            if k_str.starts_with(&config.attr_prefix) {
288                continue;
289            }
290            if k_str == "#tail" {
291                continue;
292            }
293
294            path.push(k_str.to_string());
295
296            if k_str == "#comment" {
297                if let Some(comment_txt) = py_value_to_cow(py, &v, &config.default_func, path)? {
298                    wrap_err(
299                        writer.write_event(Event::Comment(BytesText::new(&comment_txt))),
300                        path,
301                    )?;
302                }
303            } else if k_str.starts_with("?") {
304                if let Some(content) = py_value_to_cow(py, &v, &config.default_func, path)? {
305                    let target = k_str.strip_prefix("?").unwrap_or(&k_str);
306                    let pi_content = format!("{} {}", target, content);
307                    wrap_err(
308                        writer.write_event(Event::PI(BytesPI::new(&pi_content))),
309                        path,
310                    )?;
311                }
312            } else if k_str == config.cdata_key {
313                let mut written = false;
314                if let Ok(inner) = v.cast::<PyDict>() {
315                    if let Ok(Some(cdata)) = inner.get_item("__cdata__") {
316                        let s = cdata.to_string();
317                        wrap_err(writer.write_event(Event::CData(BytesCData::new(&s))), path)?;
318                        written = true;
319                    }
320                }
321                if !written {
322                    if let Some(text) = py_value_to_cow(py, &v, &config.default_func, path)? {
323                        wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
324                    }
325                }
326            } else {
327                process_node(
328                    writer,
329                    k_str,
330                    &v,
331                    config,
332                    compat,
333                    &mut ns_local,
334                    false,
335                    path,
336                    visited,
337                )?;
338            }
339            path.pop();
340        }
341        wrap_err(
342            writer.write_event(Event::End(BytesEnd::new(qualified.as_ref()))),
343            path,
344        )?;
345        visited.remove(&ptr);
346
347        if let Some(text) = tail_text {
348            wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
349        }
350
351        return Ok(());
352    } else if let Ok(list) = value.cast::<PyList>() {
353        let ptr = list.as_ptr() as usize;
354        if !visited.insert(ptr) {
355            return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
356                format!("Circular ref (at {})", format_path(path)),
357            ));
358        }
359        for (i, item) in list.iter().enumerate() {
360            path.push(format!("[{}]", i));
361            process_node(
362                writer, tag_name, &item, config, compat, ns, is_root, path, visited,
363            )?;
364            path.pop();
365        }
366        visited.remove(&ptr);
367        return Ok(());
368    } else if let Ok(iter) = PyIterator::from_object(value) {
369        let mut i = 0;
370        for item in iter {
371            let obj = item?;
372            path.push(format!("[{}]", i));
373            process_node(
374                writer, tag_name, &obj, config, compat, ns, is_root, path, visited,
375            )?;
376            path.pop();
377            i += 1;
378        }
379        return Ok(());
380    }
381
382    if value.is_instance_of::<PyNone>() {
383        match compat {
384            CompatMode::Native => {
385                wrap_err(
386                    writer.write_event(Event::Empty(BytesStart::new(tag_name))),
387                    path,
388                )?;
389            }
390            CompatMode::Obj2Xml => {
391                wrap_err(
392                    writer.write_event(Event::Start(BytesStart::new(tag_name))),
393                    path,
394                )?;
395                wrap_err(
396                    writer.write_event(Event::End(BytesEnd::new(tag_name))),
397                    path,
398                )?;
399            }
400        }
401    } else if let Some(text) = py_value_to_cow(py, value, &config.default_func, path)? {
402        let elem = BytesStart::new(tag_name);
403        wrap_err(writer.write_event(Event::Start(elem)), path)?;
404        wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
405        wrap_err(
406            writer.write_event(Event::End(BytesEnd::new(tag_name))),
407            path,
408        )?;
409    }
410    Ok(())
411}
412
413fn generate_xml<W: Write>(
414    writer: W,
415    input: &Bound<'_, PyAny>,
416    encoding: &str,
417    full_document: bool,
418    pretty: bool,
419    indent: &str,
420    config: &Config,
421    compat: CompatMode,
422    attr_prefix: &str,
423) -> PyResult<W> {
424    let mut writer = if pretty {
425        Writer::new_with_indent(writer, b' ', indent.len())
426    } else {
427        Writer::new(writer)
428    };
429    if full_document {
430        writer
431            .write_event(Event::Decl(BytesDecl::new("1.0", Some(encoding), None)))
432            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
433    }
434    let mut ns = NamespaceContext::new(&config.namespaces);
435    let mut path = Vec::new();
436    let mut visited = HashSet::new();
437
438    if let Ok(dict) = input.cast::<PyDict>() {
439        let roots = dict
440            .iter()
441            .filter(|(k, _)| !k.to_string().starts_with(attr_prefix))
442            .count();
443        if full_document && roots != 1 {
444            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
445                "Document must have exactly one root",
446            ));
447        }
448        for (k, v) in dict {
449            let key_cow = extract_str(&k)?;
450            let key = key_cow.as_ref();
451            if key.starts_with(attr_prefix) {
452                continue;
453            }
454            path.push(key.to_string());
455            process_node(
456                &mut writer,
457                key,
458                &v,
459                config,
460                compat,
461                &mut ns,
462                true,
463                &mut path,
464                &mut visited,
465            )?;
466            path.pop();
467        }
468    } else {
469        let iter = PyIterator::from_object(input)?;
470        let mut i = 0;
471        for item in iter {
472            let obj = item?;
473            path.push(format!("[{}]", i));
474            if let Ok(d) = obj.cast::<PyDict>() {
475                for (k, v) in d {
476                    let k_cow = extract_str(&k)?;
477                    let k_str = k_cow.as_ref();
478                    path.push(k_str.to_string());
479                    process_node(
480                        &mut writer,
481                        k_str,
482                        &v,
483                        config,
484                        compat,
485                        &mut ns,
486                        true,
487                        &mut path,
488                        &mut visited,
489                    )?;
490                    path.pop();
491                }
492            } else {
493                process_node(
494                    &mut writer,
495                    &config.item_name,
496                    &obj,
497                    config,
498                    compat,
499                    &mut ns,
500                    true,
501                    &mut path,
502                    &mut visited,
503                )?;
504            }
505            path.pop();
506            i += 1;
507        }
508    }
509    Ok(writer.into_inner())
510}
511
512#[pyfunction]
513#[pyo3(signature = (
514    input, *,
515    output=None,
516    encoding="utf-8",
517    full_document=true,
518    attr_prefix="@",
519    cdata_key="#text",
520    pretty=false, indent="  ",
521    compat="native",
522    streaming=false,
523    default=None,
524    item_name="item",
525    sort_attributes=false,
526    namespaces=None
527))]
528fn unparse(
529    _py: Python<'_>,
530    input: &Bound<'_, PyAny>,
531    output: Option<&Bound<'_, PyAny>>,
532    encoding: &str,
533    full_document: bool,
534    attr_prefix: &str,
535    cdata_key: &str,
536    pretty: bool,
537    indent: &str,
538    compat: &str,
539    streaming: bool,
540    default: Option<Py<PyAny>>,
541    item_name: &str,
542    sort_attributes: bool,
543    namespaces: Option<HashMap<String, String>>,
544) -> PyResult<String> {
545    let config = Config {
546        attr_prefix: attr_prefix.to_string(),
547        cdata_key: cdata_key.to_string(),
548        default_func: default,
549        item_name: item_name.to_string(),
550        sort_attrs: sort_attributes,
551        namespaces: namespaces.unwrap_or_default(),
552    };
553    let compat_mode = if compat == "legacy" {
554        CompatMode::Obj2Xml
555    } else {
556        CompatMode::Native
557    };
558
559    if streaming {
560        let out_obj = output.ok_or_else(|| {
561            PyErr::new::<pyo3::exceptions::PyValueError, _>(
562                "streaming=True requires output argument",
563            )
564        })?;
565        let sink: Box<dyn Write> = if let Ok(path) = out_obj.extract::<String>() {
566            let f = File::create(path)?;
567            Box::new(BufWriter::new(f))
568        } else {
569            Box::new(PyWriter {
570                obj: out_obj.clone(),
571            })
572        };
573        wrap_err(
574            generate_xml(
575                sink,
576                input,
577                encoding,
578                full_document,
579                pretty,
580                indent,
581                &config,
582                compat_mode,
583                attr_prefix,
584            ),
585            &[],
586        )?;
587        return Ok(String::new());
588    }
589
590    let cursor = Cursor::new(Vec::with_capacity(1024));
591    let cursor = wrap_py_err(
592        generate_xml(
593            cursor,
594            input,
595            encoding,
596            full_document,
597            pretty,
598            indent,
599            &config,
600            compat_mode,
601            attr_prefix,
602        ),
603        &[],
604    )?;
605    let xml = String::from_utf8(cursor.into_inner())
606        .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
607    if let Some(out_obj) = output {
608        if let Ok(path) = out_obj.extract::<String>() {
609            std::fs::write(path, &xml)?;
610        } else {
611            out_obj.call_method1("write", (xml.as_bytes(),))?;
612        }
613        return Ok(String::new());
614    }
615    Ok(xml)
616}
617
618#[pymodule]
619fn _obj2xml_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
620    m.add_function(wrap_pyfunction!(unparse, m)?)?;
621    Ok(())
622}