1use pyo3::prelude::*;
2use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyIterator, PyList, PyNone, PyString};
3use quick_xml::Writer;
4use quick_xml::events::{BytesCData, BytesDecl, BytesEnd, BytesPI, BytesStart, BytesText, Event};
5use rustc_hash::{FxHashMap, FxHashSet};
6use std::borrow::Cow;
7use std::fs::File;
8use std::io::{BufWriter, Cursor, Write};
9
10struct Config {
11 attr_prefix: String,
12 cdata_key: String,
13 default_func: Option<Py<PyAny>>,
14 item_name: String,
15 sort_attrs: bool,
16 namespaces: FxHashMap<String, String>,
17}
18
/// Output dialect selected by `unparse`'s `compat` argument.
#[derive(PartialEq, Copy, Clone)]
enum CompatMode {
    /// Default dialect: `None` values become self-closing elements (`<tag/>`).
    Native,
    /// `compat="legacy"`: `None` values become an explicit start/end pair
    /// (`<tag></tag>`).
    Obj2Xml,
}
24
25#[inline]
26fn extract_str<'py>(value: &'py pyo3::Bound<'py, pyo3::PyAny>) -> PyResult<Cow<'py, str>> {
27 if let Ok(pystr) = value.cast::<PyString>() {
28 let slice = pystr.to_str()?;
29 Ok(Cow::Borrowed(slice))
30 } else {
31 Ok(Cow::Owned(value.to_string()))
32 }
33}
34
35struct PyWriter<'py> {
36 obj: Bound<'py, PyAny>,
37}
38
39impl<'py> Write for PyWriter<'py> {
40 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
41 self.obj
42 .call_method1("write", (buf,))
43 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
44 Ok(buf.len())
45 }
46 fn flush(&mut self) -> std::io::Result<()> {
47 if self.obj.hasattr("flush").unwrap_or(false) {
48 self.obj
49 .call_method0("flush")
50 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
51 }
52 Ok(())
53 }
54}
55
56#[derive(Clone)]
57struct NamespaceContext {
58 uri_to_prefix: FxHashMap<String, String>,
59 prefix_to_uri: FxHashMap<String, String>,
60 next_auto: usize,
61}
62
63impl NamespaceContext {
64 fn new(predefined: &FxHashMap<String, String>) -> Self {
65 let mut uri_to_prefix = FxHashMap::default();
66 let mut prefix_to_uri = FxHashMap::default();
67 for (prefix, uri) in predefined {
68 uri_to_prefix.insert(uri.clone(), prefix.clone());
69 prefix_to_uri.insert(prefix.clone(), uri.clone());
70 }
71 Self {
72 uri_to_prefix,
73 prefix_to_uri,
74 next_auto: 0,
75 }
76 }
77
78 fn get_or_create_prefix(&mut self, uri: &str) -> String {
79 if let Some(p) = self.uri_to_prefix.get(uri) {
80 return p.clone();
81 }
82 let prefix = format!("ns{}", self.next_auto);
83 self.next_auto += 1;
84 self.uri_to_prefix.insert(uri.to_string(), prefix.clone());
85 self.prefix_to_uri.insert(prefix.clone(), uri.to_string());
86 prefix
87 }
88}
89
/// Renders the traversal path for error messages.
///
/// Segments are joined with `/`, e.g. `a/[0]/@id`. An empty path renders as
/// `root`.
fn format_path(stack: &[String]) -> String {
    match stack {
        [] => String::from("root"),
        segments => segments.join("/"),
    }
}
97
98#[inline(always)]
99fn wrap_err<T, E: std::fmt::Display>(res: Result<T, E>, stack: &[String]) -> PyResult<T> {
100 res.map_err(|e| {
101 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
102 "{} (at {})",
103 e,
104 format_path(stack)
105 ))
106 })
107}
108
109#[inline(always)]
110fn wrap_py_err<T>(res: PyResult<T>, stack: &[String]) -> PyResult<T> {
111 res.map_err(|e| {
112 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
113 "{} (at {})",
114 e,
115 format_path(stack)
116 ))
117 })
118}
119
120#[inline]
121fn py_value_to_cow<'py>(
122 py: Python<'py>,
123 value: &'py pyo3::Bound<'py, pyo3::PyAny>,
124 default_func: &Option<Py<PyAny>>,
125 path: &[String],
126) -> PyResult<Option<Cow<'py, str>>> {
127 if value.is_instance_of::<PyNone>() {
128 return Ok(None);
129 }
130 if value.is_instance_of::<PyBool>() {
131 return Ok(Some(if value.extract::<bool>()? {
132 "true".into()
133 } else {
134 "false".into()
135 }));
136 }
137 if let Ok(pystr) = value.cast::<PyString>() {
138 return Ok(Some(Cow::Borrowed(pystr.to_str()?)));
139 }
140 if value.is_instance_of::<PyInt>() || value.is_instance_of::<PyFloat>() {
141 return Ok(Some(Cow::Owned(value.to_string())));
142 }
143
144 if let Some(func) = default_func {
145 match func.call1(py, (value,)) {
146 Ok(serialized) => return Ok(Some(Cow::Owned(serialized.to_string()))),
147 Err(e) => {
148 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
149 "Custom serialization failed: {} (at {})",
150 e,
151 format_path(path)
152 )));
153 }
154 }
155 }
156 Ok(Some(Cow::Owned(value.to_string())))
157}
158
159#[inline]
160fn qualify_tag<'a>(
161 tag: &'a str,
162 dict: Option<&Bound<'_, PyDict>>,
163 ns: &mut NamespaceContext,
164) -> Cow<'a, str> {
165 if tag.contains(':') {
166 return Cow::Borrowed(tag);
167 }
168 if let Some(d) = dict {
169 if let Ok(Some(ns_val)) = d.get_item("@ns") {
170 let ns_str = ns_val.to_string();
171 if ns.prefix_to_uri.contains_key(&ns_str) {
172 return Cow::Owned(format!("{}:{}", ns_str, tag));
173 }
174 let prefix = ns.get_or_create_prefix(&ns_str);
175 return Cow::Owned(format!("{}:{}", prefix, tag));
176 }
177 }
178 Cow::Borrowed(tag)
179}
180
181fn process_node<W: Write>(
182 writer: &mut Writer<W>,
183 tag_name: &str,
184 value: &Bound<'_, PyAny>,
185 config: &Config,
186 compat: CompatMode,
187 ns: &mut NamespaceContext,
188 is_root: bool,
189 path: &mut Vec<String>,
190 visited: &mut FxHashSet<usize>,
191) -> PyResult<()> {
192 let py = value.py();
193
194 if value.is_instance_of::<PyString>()
195 || value.is_instance_of::<PyBool>()
196 || value.is_instance_of::<PyNone>()
197 || value.is_instance_of::<PyInt>()
198 || value.is_instance_of::<PyFloat>()
199 {
200 } else if let Ok(dict) = value.cast::<PyDict>() {
201 let ptr = dict.as_ptr() as usize;
202 if !visited.insert(ptr) {
203 return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
204 format!("Circular reference detected (at {})", format_path(path)),
205 ));
206 }
207
208 let mut tail_text: Option<String> = None;
209 if let Ok(Some(tail_val)) = dict.get_item("#tail") {
210 path.push("#tail".to_string());
211 if let Some(cow) = py_value_to_cow(py, &tail_val, &config.default_func, path)? {
212 tail_text = Some(cow.into_owned());
213 }
214 path.pop();
215 }
216 let mut ns_local = ns.clone();
217 let qualified = qualify_tag(tag_name, Some(&dict), &mut ns_local);
218 let mut elem = BytesStart::new(qualified.as_ref());
219 let mut attrs = Vec::new();
220 let mut xmlns_attrs = Vec::new();
221
222 if is_root {
223 for (prefix, uri) in &config.namespaces {
224 if prefix.is_empty() {
225 xmlns_attrs.push(("xmlns".to_string(), uri.clone()));
226 } else {
227 xmlns_attrs.push((format!("xmlns:{}", prefix), uri.clone()));
228 }
229 }
230 }
231
232 for (k, v) in dict {
233 let k_cow = extract_str(&k)?;
234 let k_str = k_cow.as_ref();
235 if k_str == "#comment" || k_str == "#tail" || k_str.starts_with("?") {
236 continue;
237 }
238
239 if k_str == "@xmlns" {
240 let uri = v.to_string();
241 if !config.namespaces.values().any(|u| u == &uri) {
242 let prefix = ns_local.get_or_create_prefix(&uri);
243 xmlns_attrs.push((format!("xmlns:{}", prefix), uri));
244 }
245 } else if let Some(p) = k_str.strip_prefix("@xmlns:") {
246 let uri = v.to_string();
247 ns_local.uri_to_prefix.insert(uri.clone(), p.to_string());
248 ns_local.prefix_to_uri.insert(p.to_string(), uri.clone());
249 xmlns_attrs.push((format!("xmlns:{}", p), uri));
250 } else if let Some(attr) = k_str.strip_prefix(&config.attr_prefix) {
251 path.push(format!("@{}", attr));
252 if let Some(val) = py_value_to_cow(py, &v, &config.default_func, path)? {
253 attrs.push((attr.to_string(), val.into_owned()));
254 }
255 path.pop();
256 }
257 }
258
259 if config.sort_attrs {
260 attrs.sort_by(|a, b| a.0.cmp(&b.0));
261 xmlns_attrs.sort_by(|a, b| a.0.cmp(&b.0));
262 }
263 for (k, v) in xmlns_attrs {
264 elem.push_attribute((k.as_str(), v.as_str()));
265 }
266 for (k, v) in attrs {
267 elem.push_attribute((k.as_str(), v.as_str()));
268 }
269
270 let has_children = dict.iter().any(|(k, _)| {
271 let k_cow = k
272 .cast::<PyString>()
273 .map(|s| s.to_string_lossy())
274 .unwrap_or_default();
275 let ks = k_cow.as_ref();
276 !ks.starts_with(&config.attr_prefix) && ks != "#tail"
277 });
278
279 if !has_children {
280 wrap_err(writer.write_event(Event::Empty(elem)), path)?;
281 visited.remove(&ptr);
282 if let Some(text) = tail_text {
283 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
284 }
285 return Ok(());
286 }
287
288 wrap_err(writer.write_event(Event::Start(elem)), path)?;
289
290 for (k, v) in dict {
291 let k_cow = extract_str(&k)?;
292 let k_str = k_cow.as_ref();
293
294 if k_str.starts_with(&config.attr_prefix) {
295 continue;
296 }
297 if k_str == "#tail" {
298 continue;
299 }
300
301 path.push(k_str.to_string());
302
303 if k_str == "#comment" {
304 if let Some(comment_txt) = py_value_to_cow(py, &v, &config.default_func, path)? {
305 wrap_err(
306 writer.write_event(Event::Comment(BytesText::new(&comment_txt))),
307 path,
308 )?;
309 }
310 } else if k_str.starts_with("?") {
311 if let Some(content) = py_value_to_cow(py, &v, &config.default_func, path)? {
312 let target = k_str.strip_prefix("?").unwrap_or(&k_str);
313 let pi_content = format!("{} {}", target, content);
314 wrap_err(
315 writer.write_event(Event::PI(BytesPI::new(&pi_content))),
316 path,
317 )?;
318 }
319 } else if k_str == config.cdata_key {
320 let mut written = false;
321 if let Ok(inner) = v.cast::<PyDict>() {
322 if let Ok(Some(cdata)) = inner.get_item("__cdata__") {
323 let s = cdata.to_string();
324 wrap_err(writer.write_event(Event::CData(BytesCData::new(&s))), path)?;
325 written = true;
326 }
327 }
328 if !written {
329 if let Some(text) = py_value_to_cow(py, &v, &config.default_func, path)? {
330 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
331 }
332 }
333 } else {
334 process_node(
335 writer,
336 k_str,
337 &v,
338 config,
339 compat,
340 &mut ns_local,
341 false,
342 path,
343 visited,
344 )?;
345 }
346 path.pop();
347 }
348 wrap_err(
349 writer.write_event(Event::End(BytesEnd::new(qualified.as_ref()))),
350 path,
351 )?;
352 visited.remove(&ptr);
353
354 if let Some(text) = tail_text {
355 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
356 }
357
358 return Ok(());
359 } else if let Ok(list) = value.cast::<PyList>() {
360 let ptr = list.as_ptr() as usize;
361 if !visited.insert(ptr) {
362 return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
363 format!("Circular ref (at {})", format_path(path)),
364 ));
365 }
366 for (i, item) in list.iter().enumerate() {
367 path.push(format!("[{}]", i));
368 process_node(
369 writer, tag_name, &item, config, compat, ns, is_root, path, visited,
370 )?;
371 path.pop();
372 }
373 visited.remove(&ptr);
374 return Ok(());
375 } else if let Ok(iter) = PyIterator::from_object(value) {
376 let mut i = 0;
377 for item in iter {
378 let obj = item?;
379 path.push(format!("[{}]", i));
380 process_node(
381 writer, tag_name, &obj, config, compat, ns, is_root, path, visited,
382 )?;
383 path.pop();
384 i += 1;
385 }
386 return Ok(());
387 }
388
389 if value.is_instance_of::<PyNone>() {
390 match compat {
391 CompatMode::Native => {
392 wrap_err(
393 writer.write_event(Event::Empty(BytesStart::new(tag_name))),
394 path,
395 )?;
396 }
397 CompatMode::Obj2Xml => {
398 wrap_err(
399 writer.write_event(Event::Start(BytesStart::new(tag_name))),
400 path,
401 )?;
402 wrap_err(
403 writer.write_event(Event::End(BytesEnd::new(tag_name))),
404 path,
405 )?;
406 }
407 }
408 } else if let Some(text) = py_value_to_cow(py, value, &config.default_func, path)? {
409 let elem = BytesStart::new(tag_name);
410 wrap_err(writer.write_event(Event::Start(elem)), path)?;
411 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
412 wrap_err(
413 writer.write_event(Event::End(BytesEnd::new(tag_name))),
414 path,
415 )?;
416 }
417 Ok(())
418}
419
420fn generate_xml<W: Write>(
421 writer: W,
422 input: &Bound<'_, PyAny>,
423 encoding: &str,
424 full_document: bool,
425 pretty: bool,
426 indent: &str,
427 config: &Config,
428 compat: CompatMode,
429 attr_prefix: &str,
430) -> PyResult<W> {
431 let mut writer = if pretty {
432 Writer::new_with_indent(writer, b' ', indent.len())
433 } else {
434 Writer::new(writer)
435 };
436 if full_document {
437 writer
438 .write_event(Event::Decl(BytesDecl::new("1.0", Some(encoding), None)))
439 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
440 }
441 let mut ns = NamespaceContext::new(&config.namespaces);
442 let mut path = Vec::with_capacity(16);
443 let mut visited = FxHashSet::default();
444
445 if let Ok(dict) = input.cast::<PyDict>() {
446 let roots = dict
447 .iter()
448 .filter(|(k, _)| !k.to_string().starts_with(attr_prefix))
449 .count();
450 if full_document && roots != 1 {
451 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
452 "Document must have exactly one root",
453 ));
454 }
455 for (k, v) in dict {
456 let key_cow = extract_str(&k)?;
457 let key = key_cow.as_ref();
458 if key.starts_with(attr_prefix) {
459 continue;
460 }
461 path.push(key.to_string());
462 process_node(
463 &mut writer,
464 key,
465 &v,
466 config,
467 compat,
468 &mut ns,
469 true,
470 &mut path,
471 &mut visited,
472 )?;
473 path.pop();
474 }
475 } else {
476 let iter = PyIterator::from_object(input)?;
477 let mut i = 0;
478 for item in iter {
479 let obj = item?;
480 path.push(format!("[{}]", i));
481 if let Ok(d) = obj.cast::<PyDict>() {
482 for (k, v) in d {
483 let k_cow = extract_str(&k)?;
484 let k_str = k_cow.as_ref();
485 path.push(k_str.to_string());
486 process_node(
487 &mut writer,
488 k_str,
489 &v,
490 config,
491 compat,
492 &mut ns,
493 true,
494 &mut path,
495 &mut visited,
496 )?;
497 path.pop();
498 }
499 } else {
500 process_node(
501 &mut writer,
502 &config.item_name,
503 &obj,
504 config,
505 compat,
506 &mut ns,
507 true,
508 &mut path,
509 &mut visited,
510 )?;
511 }
512 path.pop();
513 i += 1;
514 }
515 }
516 Ok(writer.into_inner())
517}
518
519#[pyfunction]
520#[pyo3(signature = (
521 input, *,
522 output=None,
523 encoding="utf-8",
524 full_document=true,
525 attr_prefix="@",
526 cdata_key="#text",
527 pretty=false, indent=" ",
528 compat="native",
529 streaming=false,
530 default=None,
531 item_name="item",
532 sort_attributes=false,
533 namespaces=None
534))]
535fn unparse(
536 _py: Python<'_>,
537 input: &Bound<'_, PyAny>,
538 output: Option<&Bound<'_, PyAny>>,
539 encoding: &str,
540 full_document: bool,
541 attr_prefix: &str,
542 cdata_key: &str,
543 pretty: bool,
544 indent: &str,
545 compat: &str,
546 streaming: bool,
547 default: Option<Py<PyAny>>,
548 item_name: &str,
549 sort_attributes: bool,
550 namespaces: Option<FxHashMap<String, String>>,
551) -> PyResult<String> {
552 let config = Config {
553 attr_prefix: attr_prefix.to_string(),
554 cdata_key: cdata_key.to_string(),
555 default_func: default,
556 item_name: item_name.to_string(),
557 sort_attrs: sort_attributes,
558 namespaces: namespaces.unwrap_or_default(),
559 };
560 let compat_mode = if compat == "legacy" {
561 CompatMode::Obj2Xml
562 } else {
563 CompatMode::Native
564 };
565
566 if streaming {
567 let out_obj = output.ok_or_else(|| {
568 PyErr::new::<pyo3::exceptions::PyValueError, _>(
569 "streaming=True requires output argument",
570 )
571 })?;
572 let sink: Box<dyn Write> = if let Ok(path) = out_obj.extract::<String>() {
573 let f = File::create(path)?;
574 Box::new(BufWriter::new(f))
575 } else {
576 Box::new(PyWriter {
577 obj: out_obj.clone(),
578 })
579 };
580 wrap_err(
581 generate_xml(
582 sink,
583 input,
584 encoding,
585 full_document,
586 pretty,
587 indent,
588 &config,
589 compat_mode,
590 attr_prefix,
591 ),
592 &[],
593 )?;
594 return Ok(String::new());
595 }
596
597 let cursor = Cursor::new(Vec::with_capacity(32 * 1024));
598 let cursor = wrap_py_err(
599 generate_xml(
600 cursor,
601 input,
602 encoding,
603 full_document,
604 pretty,
605 indent,
606 &config,
607 compat_mode,
608 attr_prefix,
609 ),
610 &[],
611 )?;
612 let xml = String::from_utf8(cursor.into_inner())
613 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
614 if let Some(out_obj) = output {
615 if let Ok(path) = out_obj.extract::<String>() {
616 std::fs::write(path, &xml)?;
617 } else {
618 out_obj.call_method1("write", (xml.as_bytes(),))?;
619 }
620 return Ok(String::new());
621 }
622 Ok(xml)
623}
624
625#[pymodule]
626fn _obj2xml_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
627 m.add_function(wrap_pyfunction!(unparse, m)?)?;
628 Ok(())
629}