1use pyo3::prelude::*;
2use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyIterator, PyList, PyNone, PyString};
3use quick_xml::Writer;
4use quick_xml::events::{BytesCData, BytesDecl, BytesEnd, BytesPI, BytesStart, BytesText, Event};
5use std::borrow::Cow;
6use std::collections::{HashMap, HashSet};
7use std::fs::File;
8use std::io::{BufWriter, Cursor, Write};
9
10struct Config {
11 attr_prefix: String,
12 cdata_key: String,
13 default_func: Option<Py<PyAny>>,
14 item_name: String,
15 sort_attrs: bool,
16 namespaces: HashMap<String, String>,
17}
18
/// Output flavour selected by the `compat` argument of `unparse`.
#[derive(Copy, Clone, PartialEq)]
enum CompatMode {
    /// `None` values are written as a self-closing element (`<tag/>`).
    Native,
    /// `None` values are written as an explicit start/end pair (`<tag></tag>`).
    Obj2Xml,
}
24
25fn extract_str<'py>(value: &'py pyo3::Bound<'py, pyo3::PyAny>) -> PyResult<Cow<'py, str>> {
26 if let Ok(pystr) = value.cast::<PyString>() {
27 let slice = pystr.to_str()?;
28 Ok(Cow::Borrowed(slice))
29 } else {
30 Ok(Cow::Owned(value.to_string()))
31 }
32}
33
34struct PyWriter<'py> {
35 obj: Bound<'py, PyAny>,
36}
37
38impl<'py> Write for PyWriter<'py> {
39 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
40 self.obj
41 .call_method1("write", (buf,))
42 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
43 Ok(buf.len())
44 }
45 fn flush(&mut self) -> std::io::Result<()> {
46 if self.obj.hasattr("flush").unwrap_or(false) {
47 self.obj
48 .call_method0("flush")
49 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
50 }
51 Ok(())
52 }
53}
54
/// Bidirectional prefix <-> URI namespace table for the element being written.
///
/// The table is cloned per element so that bindings introduced by a child do
/// not leak to its siblings.
#[derive(Clone)]
struct NamespaceContext {
    /// Namespace URI -> prefix used when qualifying names.
    uri_to_prefix: HashMap<String, String>,
    /// Prefix -> namespace URI for every binding known so far.
    prefix_to_uri: HashMap<String, String>,
    /// Next number tried when generating an automatic `ns{N}` prefix.
    next_auto: usize,
}

impl NamespaceContext {
    /// Builds a context from predefined `prefix -> URI` bindings.
    ///
    /// When several prefixes are bound to the same URI the reverse mapping is
    /// chosen deterministically: a non-empty prefix wins over the empty
    /// (default-namespace) one, and among those the lexicographically smallest
    /// is used. Relying on `HashMap` iteration order would make the output
    /// differ from run to run.
    fn new(predefined: &HashMap<String, String>) -> Self {
        let mut uri_to_prefix: HashMap<String, String> =
            HashMap::with_capacity(predefined.len());
        let mut prefix_to_uri: HashMap<String, String> =
            HashMap::with_capacity(predefined.len());

        for (prefix, uri) in predefined {
            prefix_to_uri.insert(prefix.clone(), uri.clone());

            let is_better = match uri_to_prefix.get(uri) {
                None => true,
                Some(current) => (prefix.is_empty(), prefix) < (current.is_empty(), current),
            };
            if is_better {
                uri_to_prefix.insert(uri.clone(), prefix.clone());
            }
        }

        Self {
            uri_to_prefix,
            prefix_to_uri,
            next_auto: 0,
        }
    }

    /// Returns the prefix bound to `uri`, generating a fresh `ns{N}` prefix
    /// when the URI is not known yet.
    ///
    /// Generated prefixes never reuse a name that is already bound, so an
    /// existing binding such as a user-supplied `ns0` is not overwritten.
    fn get_or_create_prefix(&mut self, uri: &str) -> String {
        if let Some(known) = self.uri_to_prefix.get(uri) {
            return known.clone();
        }

        let prefix = loop {
            let candidate = format!("ns{}", self.next_auto);
            self.next_auto += 1;
            if !self.prefix_to_uri.contains_key(&candidate) {
                break candidate;
            }
        };

        self.uri_to_prefix.insert(uri.to_string(), prefix.clone());
        self.prefix_to_uri.insert(prefix.clone(), uri.to_string());
        prefix
    }
}
88
/// Renders the current location in the input for error messages.
///
/// Segments are joined with `/`; an empty stack is reported as `root`.
fn format_path(stack: &[String]) -> String {
    match stack {
        [] => String::from("root"),
        segments => segments.join("/"),
    }
}
96
97fn wrap_err<T, E: std::fmt::Display>(res: Result<T, E>, stack: &[String]) -> PyResult<T> {
98 res.map_err(|e| {
99 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
100 "{} (at {})",
101 e,
102 format_path(stack)
103 ))
104 })
105}
106
107fn wrap_py_err<T>(res: PyResult<T>, stack: &[String]) -> PyResult<T> {
108 res.map_err(|e| {
109 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
110 "{} (at {})",
111 e,
112 format_path(stack)
113 ))
114 })
115}
116
117fn py_value_to_cow<'py>(
118 py: Python<'py>,
119 value: &'py pyo3::Bound<'py, pyo3::PyAny>,
120 default_func: &Option<Py<PyAny>>,
121 path: &[String],
122) -> PyResult<Option<Cow<'py, str>>> {
123 if value.is_instance_of::<PyNone>() {
124 return Ok(None);
125 }
126 if value.is_instance_of::<PyBool>() {
127 return Ok(Some(if value.extract::<bool>()? {
128 "true".into()
129 } else {
130 "false".into()
131 }));
132 }
133 if let Ok(pystr) = value.cast::<PyString>() {
134 return Ok(Some(Cow::Borrowed(pystr.to_str()?)));
135 }
136 if value.is_instance_of::<PyInt>() || value.is_instance_of::<PyFloat>() {
137 return Ok(Some(Cow::Owned(value.to_string())));
138 }
139
140 if let Some(func) = default_func {
141 match func.call1(py, (value,)) {
142 Ok(serialized) => return Ok(Some(Cow::Owned(serialized.to_string()))),
143 Err(e) => {
144 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
145 "Custom serialization failed: {} (at {})",
146 e,
147 format_path(path)
148 )));
149 }
150 }
151 }
152 Ok(Some(Cow::Owned(value.to_string())))
153}
154
155fn qualify_tag<'a>(
156 tag: &'a str,
157 dict: Option<&Bound<'_, PyDict>>,
158 ns: &mut NamespaceContext,
159) -> Cow<'a, str> {
160 if tag.contains(':') {
161 return Cow::Borrowed(tag);
162 }
163 if let Some(d) = dict {
164 if let Ok(Some(ns_val)) = d.get_item("@ns") {
165 let ns_str = ns_val.to_string();
166 if ns.prefix_to_uri.contains_key(&ns_str) {
167 return Cow::Owned(format!("{}:{}", ns_str, tag));
168 }
169 let prefix = ns.get_or_create_prefix(&ns_str);
170 return Cow::Owned(format!("{}:{}", prefix, tag));
171 }
172 }
173 Cow::Borrowed(tag)
174}
175
176fn process_node<W: Write>(
177 writer: &mut Writer<W>,
178 tag_name: &str,
179 value: &Bound<'_, PyAny>,
180 config: &Config,
181 compat: CompatMode,
182 ns: &mut NamespaceContext,
183 is_root: bool,
184 path: &mut Vec<String>,
185 visited: &mut HashSet<usize>,
186) -> PyResult<()> {
187 let py = value.py();
188
189 if value.is_instance_of::<PyString>()
190 || value.is_instance_of::<PyBool>()
191 || value.is_instance_of::<PyNone>()
192 || value.is_instance_of::<PyInt>()
193 || value.is_instance_of::<PyFloat>()
194 {
195 } else if let Ok(dict) = value.cast::<PyDict>() {
196 let ptr = dict.as_ptr() as usize;
197 if !visited.insert(ptr) {
198 return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
199 format!("Circular reference detected (at {})", format_path(path)),
200 ));
201 }
202
203 let mut tail_text: Option<String> = None;
204 if let Ok(Some(tail_val)) = dict.get_item("#tail") {
205 path.push("#tail".to_string());
206 if let Some(cow) = py_value_to_cow(py, &tail_val, &config.default_func, path)? {
207 tail_text = Some(cow.into_owned());
208 }
209 path.pop();
210 }
211 let mut ns_local = ns.clone();
212 let qualified = qualify_tag(tag_name, Some(&dict), &mut ns_local);
213 let mut elem = BytesStart::new(qualified.as_ref());
214 let mut attrs = Vec::new();
215 let mut xmlns_attrs = Vec::new();
216
217 if is_root {
218 for (prefix, uri) in &config.namespaces {
219 if prefix.is_empty() {
220 xmlns_attrs.push(("xmlns".to_string(), uri.clone()));
221 } else {
222 xmlns_attrs.push((format!("xmlns:{}", prefix), uri.clone()));
223 }
224 }
225 }
226
227 for (k, v) in dict {
228 let k_cow = extract_str(&k)?;
229 let k_str = k_cow.as_ref();
230 if k_str == "#comment" || k_str == "#tail" || k_str.starts_with("?") {
231 continue;
232 }
233
234 if k_str == "@xmlns" {
235 let uri = v.to_string();
236 if !config.namespaces.values().any(|u| u == &uri) {
237 let prefix = ns_local.get_or_create_prefix(&uri);
238 xmlns_attrs.push((format!("xmlns:{}", prefix), uri));
239 }
240 } else if let Some(p) = k_str.strip_prefix("@xmlns:") {
241 let uri = v.to_string();
242 ns_local.uri_to_prefix.insert(uri.clone(), p.to_string());
243 ns_local.prefix_to_uri.insert(p.to_string(), uri.clone());
244 xmlns_attrs.push((format!("xmlns:{}", p), uri));
245 } else if let Some(attr) = k_str.strip_prefix(&config.attr_prefix) {
246 path.push(format!("@{}", attr));
247 if let Some(val) = py_value_to_cow(py, &v, &config.default_func, path)? {
248 attrs.push((attr.to_string(), val.into_owned()));
249 }
250 path.pop();
251 }
252 }
253
254 if config.sort_attrs {
255 attrs.sort_by(|a, b| a.0.cmp(&b.0));
256 xmlns_attrs.sort_by(|a, b| a.0.cmp(&b.0));
257 }
258 for (k, v) in xmlns_attrs {
259 elem.push_attribute((k.as_str(), v.as_str()));
260 }
261 for (k, v) in attrs {
262 elem.push_attribute((k.as_str(), v.as_str()));
263 }
264
265 let has_children = dict.iter().any(|(k, _)| {
266 let k_cow = extract_str(&k).unwrap_or(Cow::Borrowed(""));
267 let ks = k_cow.as_ref();
268 !ks.starts_with(&config.attr_prefix) && ks != "#tail"
269 });
270
271 if !has_children {
272 wrap_err(writer.write_event(Event::Empty(elem)), path)?;
273 visited.remove(&ptr);
274 if let Some(text) = tail_text {
276 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
277 }
278 return Ok(());
279 }
280
281 wrap_err(writer.write_event(Event::Start(elem)), path)?;
282
283 for (k, v) in dict {
284 let k_cow = extract_str(&k)?;
285 let k_str = k_cow.as_ref();
286
287 if k_str.starts_with(&config.attr_prefix) {
288 continue;
289 }
290 if k_str == "#tail" {
291 continue;
292 }
293
294 path.push(k_str.to_string());
295
296 if k_str == "#comment" {
297 if let Some(comment_txt) = py_value_to_cow(py, &v, &config.default_func, path)? {
298 wrap_err(
299 writer.write_event(Event::Comment(BytesText::new(&comment_txt))),
300 path,
301 )?;
302 }
303 } else if k_str.starts_with("?") {
304 if let Some(content) = py_value_to_cow(py, &v, &config.default_func, path)? {
305 let target = k_str.strip_prefix("?").unwrap_or(&k_str);
306 let pi_content = format!("{} {}", target, content);
307 wrap_err(
308 writer.write_event(Event::PI(BytesPI::new(&pi_content))),
309 path,
310 )?;
311 }
312 } else if k_str == config.cdata_key {
313 let mut written = false;
314 if let Ok(inner) = v.cast::<PyDict>() {
315 if let Ok(Some(cdata)) = inner.get_item("__cdata__") {
316 let s = cdata.to_string();
317 wrap_err(writer.write_event(Event::CData(BytesCData::new(&s))), path)?;
318 written = true;
319 }
320 }
321 if !written {
322 if let Some(text) = py_value_to_cow(py, &v, &config.default_func, path)? {
323 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
324 }
325 }
326 } else {
327 process_node(
328 writer,
329 k_str,
330 &v,
331 config,
332 compat,
333 &mut ns_local,
334 false,
335 path,
336 visited,
337 )?;
338 }
339 path.pop();
340 }
341 wrap_err(
342 writer.write_event(Event::End(BytesEnd::new(qualified.as_ref()))),
343 path,
344 )?;
345 visited.remove(&ptr);
346
347 if let Some(text) = tail_text {
348 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
349 }
350
351 return Ok(());
352 } else if let Ok(list) = value.cast::<PyList>() {
353 let ptr = list.as_ptr() as usize;
354 if !visited.insert(ptr) {
355 return Err(PyErr::new::<pyo3::exceptions::PyRecursionError, _>(
356 format!("Circular ref (at {})", format_path(path)),
357 ));
358 }
359 for (i, item) in list.iter().enumerate() {
360 path.push(format!("[{}]", i));
361 process_node(
362 writer, tag_name, &item, config, compat, ns, is_root, path, visited,
363 )?;
364 path.pop();
365 }
366 visited.remove(&ptr);
367 return Ok(());
368 } else if let Ok(iter) = PyIterator::from_object(value) {
369 let mut i = 0;
370 for item in iter {
371 let obj = item?;
372 path.push(format!("[{}]", i));
373 process_node(
374 writer, tag_name, &obj, config, compat, ns, is_root, path, visited,
375 )?;
376 path.pop();
377 i += 1;
378 }
379 return Ok(());
380 }
381
382 if value.is_instance_of::<PyNone>() {
383 match compat {
384 CompatMode::Native => {
385 wrap_err(
386 writer.write_event(Event::Empty(BytesStart::new(tag_name))),
387 path,
388 )?;
389 }
390 CompatMode::Obj2Xml => {
391 wrap_err(
392 writer.write_event(Event::Start(BytesStart::new(tag_name))),
393 path,
394 )?;
395 wrap_err(
396 writer.write_event(Event::End(BytesEnd::new(tag_name))),
397 path,
398 )?;
399 }
400 }
401 } else if let Some(text) = py_value_to_cow(py, value, &config.default_func, path)? {
402 let elem = BytesStart::new(tag_name);
403 wrap_err(writer.write_event(Event::Start(elem)), path)?;
404 wrap_err(writer.write_event(Event::Text(BytesText::new(&text))), path)?;
405 wrap_err(
406 writer.write_event(Event::End(BytesEnd::new(tag_name))),
407 path,
408 )?;
409 }
410 Ok(())
411}
412
413fn generate_xml<W: Write>(
414 writer: W,
415 input: &Bound<'_, PyAny>,
416 encoding: &str,
417 full_document: bool,
418 pretty: bool,
419 indent: &str,
420 config: &Config,
421 compat: CompatMode,
422 attr_prefix: &str,
423) -> PyResult<W> {
424 let mut writer = if pretty {
425 Writer::new_with_indent(writer, b' ', indent.len())
426 } else {
427 Writer::new(writer)
428 };
429 if full_document {
430 writer
431 .write_event(Event::Decl(BytesDecl::new("1.0", Some(encoding), None)))
432 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
433 }
434 let mut ns = NamespaceContext::new(&config.namespaces);
435 let mut path = Vec::new();
436 let mut visited = HashSet::new();
437
438 if let Ok(dict) = input.cast::<PyDict>() {
439 let roots = dict
440 .iter()
441 .filter(|(k, _)| !k.to_string().starts_with(attr_prefix))
442 .count();
443 if full_document && roots != 1 {
444 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
445 "Document must have exactly one root",
446 ));
447 }
448 for (k, v) in dict {
449 let key_cow = extract_str(&k)?;
450 let key = key_cow.as_ref();
451 if key.starts_with(attr_prefix) {
452 continue;
453 }
454 path.push(key.to_string());
455 process_node(
456 &mut writer,
457 key,
458 &v,
459 config,
460 compat,
461 &mut ns,
462 true,
463 &mut path,
464 &mut visited,
465 )?;
466 path.pop();
467 }
468 } else {
469 let iter = PyIterator::from_object(input)?;
470 let mut i = 0;
471 for item in iter {
472 let obj = item?;
473 path.push(format!("[{}]", i));
474 if let Ok(d) = obj.cast::<PyDict>() {
475 for (k, v) in d {
476 let k_cow = extract_str(&k)?;
477 let k_str = k_cow.as_ref();
478 path.push(k_str.to_string());
479 process_node(
480 &mut writer,
481 k_str,
482 &v,
483 config,
484 compat,
485 &mut ns,
486 true,
487 &mut path,
488 &mut visited,
489 )?;
490 path.pop();
491 }
492 } else {
493 process_node(
494 &mut writer,
495 &config.item_name,
496 &obj,
497 config,
498 compat,
499 &mut ns,
500 true,
501 &mut path,
502 &mut visited,
503 )?;
504 }
505 path.pop();
506 i += 1;
507 }
508 }
509 Ok(writer.into_inner())
510}
511
512#[pyfunction]
513#[pyo3(signature = (
514 input, *,
515 output=None,
516 encoding="utf-8",
517 full_document=true,
518 attr_prefix="@",
519 cdata_key="#text",
520 pretty=false, indent=" ",
521 compat="native",
522 streaming=false,
523 default=None,
524 item_name="item",
525 sort_attributes=false,
526 namespaces=None
527))]
528fn unparse(
529 _py: Python<'_>,
530 input: &Bound<'_, PyAny>,
531 output: Option<&Bound<'_, PyAny>>,
532 encoding: &str,
533 full_document: bool,
534 attr_prefix: &str,
535 cdata_key: &str,
536 pretty: bool,
537 indent: &str,
538 compat: &str,
539 streaming: bool,
540 default: Option<Py<PyAny>>,
541 item_name: &str,
542 sort_attributes: bool,
543 namespaces: Option<HashMap<String, String>>,
544) -> PyResult<String> {
545 let config = Config {
546 attr_prefix: attr_prefix.to_string(),
547 cdata_key: cdata_key.to_string(),
548 default_func: default,
549 item_name: item_name.to_string(),
550 sort_attrs: sort_attributes,
551 namespaces: namespaces.unwrap_or_default(),
552 };
553 let compat_mode = if compat == "legacy" {
554 CompatMode::Obj2Xml
555 } else {
556 CompatMode::Native
557 };
558
559 if streaming {
560 let out_obj = output.ok_or_else(|| {
561 PyErr::new::<pyo3::exceptions::PyValueError, _>(
562 "streaming=True requires output argument",
563 )
564 })?;
565 let sink: Box<dyn Write> = if let Ok(path) = out_obj.extract::<String>() {
566 let f = File::create(path)?;
567 Box::new(BufWriter::new(f))
568 } else {
569 Box::new(PyWriter {
570 obj: out_obj.clone(),
571 })
572 };
573 wrap_err(
574 generate_xml(
575 sink,
576 input,
577 encoding,
578 full_document,
579 pretty,
580 indent,
581 &config,
582 compat_mode,
583 attr_prefix,
584 ),
585 &[],
586 )?;
587 return Ok(String::new());
588 }
589
590 let cursor = Cursor::new(Vec::with_capacity(1024));
591 let cursor = wrap_py_err(
592 generate_xml(
593 cursor,
594 input,
595 encoding,
596 full_document,
597 pretty,
598 indent,
599 &config,
600 compat_mode,
601 attr_prefix,
602 ),
603 &[],
604 )?;
605 let xml = String::from_utf8(cursor.into_inner())
606 .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
607 if let Some(out_obj) = output {
608 if let Ok(path) = out_obj.extract::<String>() {
609 std::fs::write(path, &xml)?;
610 } else {
611 out_obj.call_method1("write", (xml.as_bytes(),))?;
612 }
613 return Ok(String::new());
614 }
615 Ok(xml)
616}
617
618#[pymodule]
619fn _obj2xml_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
620 m.add_function(wrap_pyfunction!(unparse, m)?)?;
621 Ok(())
622}