1extern crate alloc;
2
3use alloc::{borrow::Cow, format, string::String, vec::Vec};
4use std::collections::HashMap;
5use std::io::Write;
6
7use facet_core::{Def, Facet, ScalarType};
8use facet_dom::{DomSerializeError, DomSerializer};
9use facet_reflect::Peek;
10
11use crate::escaping::EscapingWriter;
12
13pub use facet_dom::FloatFormatter;
14
15fn write_scalar_value(
19 out: &mut dyn Write,
20 value: Peek<'_, '_>,
21 float_formatter: Option<FloatFormatter>,
22) -> std::io::Result<bool> {
23 let value = value.innermost_peek();
25
26 if let Def::Option(_) = &value.shape().def
28 && let Ok(opt) = value.into_option()
29 {
30 return match opt.value() {
31 Some(inner) => write_scalar_value(out, inner, float_formatter),
32 None => Ok(false),
33 };
34 }
35
36 let Some(scalar_type) = value.scalar_type() else {
37 if matches!(value.shape().def, Def::Scalar) && value.shape().vtable.has_display() {
39 write!(out, "{}", value)?;
40 return Ok(true);
41 }
42 return Ok(false);
43 };
44
45 match scalar_type {
46 ScalarType::Unit => {
47 out.write_all(b"null")?;
48 }
49 ScalarType::Bool => {
50 let b = value.get::<bool>().unwrap();
51 out.write_all(if *b { b"true" } else { b"false" })?;
52 }
53 ScalarType::Char => {
54 let c = value.get::<char>().unwrap();
55 let mut buf = [0u8; 4];
56 let s = c.encode_utf8(&mut buf);
57 out.write_all(s.as_bytes())?;
58 }
59 ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
60 let s = value.as_str().unwrap();
61 out.write_all(s.as_bytes())?;
62 }
63 ScalarType::F32 => {
64 let v = value.get::<f32>().unwrap();
65 if let Some(fmt) = float_formatter {
66 fmt(*v as f64, out)?;
67 } else {
68 write!(out, "{}", v)?;
69 }
70 }
71 ScalarType::F64 => {
72 let v = value.get::<f64>().unwrap();
73 if let Some(fmt) = float_formatter {
74 fmt(*v, out)?;
75 } else {
76 write!(out, "{}", v)?;
77 }
78 }
79 ScalarType::U8 => write!(out, "{}", value.get::<u8>().unwrap())?,
80 ScalarType::U16 => write!(out, "{}", value.get::<u16>().unwrap())?,
81 ScalarType::U32 => write!(out, "{}", value.get::<u32>().unwrap())?,
82 ScalarType::U64 => write!(out, "{}", value.get::<u64>().unwrap())?,
83 ScalarType::U128 => write!(out, "{}", value.get::<u128>().unwrap())?,
84 ScalarType::USize => write!(out, "{}", value.get::<usize>().unwrap())?,
85 ScalarType::I8 => write!(out, "{}", value.get::<i8>().unwrap())?,
86 ScalarType::I16 => write!(out, "{}", value.get::<i16>().unwrap())?,
87 ScalarType::I32 => write!(out, "{}", value.get::<i32>().unwrap())?,
88 ScalarType::I64 => write!(out, "{}", value.get::<i64>().unwrap())?,
89 ScalarType::I128 => write!(out, "{}", value.get::<i128>().unwrap())?,
90 ScalarType::ISize => write!(out, "{}", value.get::<isize>().unwrap())?,
91 #[cfg(feature = "net")]
92 ScalarType::IpAddr => write!(out, "{}", value.get::<core::net::IpAddr>().unwrap())?,
93 #[cfg(feature = "net")]
94 ScalarType::Ipv4Addr => write!(out, "{}", value.get::<core::net::Ipv4Addr>().unwrap())?,
95 #[cfg(feature = "net")]
96 ScalarType::Ipv6Addr => write!(out, "{}", value.get::<core::net::Ipv6Addr>().unwrap())?,
97 #[cfg(feature = "net")]
98 ScalarType::SocketAddr => write!(out, "{}", value.get::<core::net::SocketAddr>().unwrap())?,
99 _ => return Ok(false),
100 }
101 Ok(true)
102}
103
104#[derive(Clone)]
106pub struct SerializeOptions {
107 pub pretty: bool,
109 pub indent: Cow<'static, str>,
111 pub float_formatter: Option<FloatFormatter>,
114 pub preserve_entities: bool,
122}
123
124impl Default for SerializeOptions {
125 fn default() -> Self {
126 Self {
127 pretty: false,
128 indent: Cow::Borrowed(" "),
129 float_formatter: None,
130 preserve_entities: false,
131 }
132 }
133}
134
135impl core::fmt::Debug for SerializeOptions {
136 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137 f.debug_struct("SerializeOptions")
138 .field("pretty", &self.pretty)
139 .field("indent", &self.indent)
140 .field("float_formatter", &self.float_formatter.map(|_| "..."))
141 .field("preserve_entities", &self.preserve_entities)
142 .finish()
143 }
144}
145
146impl SerializeOptions {
147 pub fn new() -> Self {
149 Self::default()
150 }
151
152 pub const fn pretty(mut self) -> Self {
154 self.pretty = true;
155 self
156 }
157
158 pub fn indent(mut self, indent: impl Into<Cow<'static, str>>) -> Self {
160 self.indent = indent.into();
161 self.pretty = true;
162 self
163 }
164
165 pub fn float_formatter(mut self, formatter: FloatFormatter) -> Self {
199 self.float_formatter = Some(formatter);
200 self
201 }
202
203 pub const fn preserve_entities(mut self, preserve: bool) -> Self {
211 self.preserve_entities = preserve;
212 self
213 }
214}
215
/// Namespace URIs paired with the prefix conventionally used for them, as
/// `(namespace URI, prefix)` entries.
#[allow(dead_code)]
const WELL_KNOWN_NAMESPACES: &[(&str, &str)] = &[
    // XML Schema
    ("http://www.w3.org/2001/XMLSchema-instance", "xsi"),
    ("http://www.w3.org/2001/XMLSchema", "xs"),
    // Core XML and linking
    ("http://www.w3.org/XML/1998/namespace", "xml"),
    ("http://www.w3.org/1999/xlink", "xlink"),
    // Document vocabularies
    ("http://www.w3.org/2000/svg", "svg"),
    ("http://www.w3.org/1999/xhtml", "xhtml"),
    // SOAP 1.1 and 1.2 envelopes
    ("http://schemas.xmlsoap.org/soap/envelope/", "soap"),
    ("http://www.w3.org/2003/05/soap-envelope", "soap12"),
    // Android resources
    ("http://schemas.android.com/apk/res/android", "android"),
];
229
/// Error produced by the XML serializer; carries a human-readable message.
#[derive(Debug)]
pub struct XmlSerializeError {
    /// Description of what went wrong.
    msg: Cow<'static, str>,
}

impl core::fmt::Display for XmlSerializeError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message: &str = self.msg.as_ref();
        f.write_str(message)
    }
}

impl std::error::Error for XmlSerializeError {}
242
243pub struct XmlSerializer {
250 out: Vec<u8>,
251 element_stack: Vec<String>,
253 declared_namespaces: HashMap<String, String>,
255 next_ns_index: usize,
257 current_default_ns: Option<String>,
260 current_ns_all: Option<String>,
262 pending_is_attribute: bool,
264 pending_is_text: bool,
266 pending_is_elements: bool,
268 pending_is_doctype: bool,
270 pending_namespace: Option<String>,
272 options: SerializeOptions,
274 depth: usize,
276 collecting_attributes: bool,
278 pending_establish_default_ns: bool,
280}
281
282impl XmlSerializer {
283 pub fn new() -> Self {
285 Self::with_options(SerializeOptions::default())
286 }
287
288 pub fn with_options(options: SerializeOptions) -> Self {
290 Self {
291 out: Vec::new(),
292 element_stack: Vec::new(),
293 declared_namespaces: HashMap::new(),
294 next_ns_index: 0,
295 current_default_ns: None,
296 current_ns_all: None,
297 pending_is_attribute: false,
298 pending_is_text: false,
299 pending_is_elements: false,
300 pending_is_doctype: false,
301 pending_namespace: None,
302 options,
303 depth: 0,
304 collecting_attributes: false,
305 pending_establish_default_ns: false,
306 }
307 }
308
309 pub fn finish(self) -> Vec<u8> {
310 self.out
311 }
312
313 fn write_element_tag_start(&mut self, name: &str, namespace: Option<&str>) {
316 self.write_indent();
317 self.out.push(b'<');
318
319 let close_tag: String;
321
322 if let Some(ns_uri) = namespace {
324 if self.current_default_ns.as_deref() == Some(ns_uri) {
325 self.out.extend_from_slice(name.as_bytes());
327 close_tag = name.to_string();
328 } else if self.pending_establish_default_ns {
329 self.out.extend_from_slice(name.as_bytes());
331 self.out.extend_from_slice(b" xmlns=\"");
332 self.out.extend_from_slice(ns_uri.as_bytes());
333 self.out.push(b'"');
334 self.current_default_ns = Some(ns_uri.to_string());
335 self.pending_establish_default_ns = false;
336 close_tag = name.to_string();
337 } else {
338 let prefix = self.get_or_create_prefix(ns_uri);
340 self.out.extend_from_slice(prefix.as_bytes());
341 self.out.push(b':');
342 self.out.extend_from_slice(name.as_bytes());
343 self.out.extend_from_slice(b" xmlns:");
345 self.out.extend_from_slice(prefix.as_bytes());
346 self.out.extend_from_slice(b"=\"");
347 self.out.extend_from_slice(ns_uri.as_bytes());
348 self.out.push(b'"');
349 close_tag = format!("{}:{}", prefix, name);
350 }
351 } else {
352 self.out.extend_from_slice(name.as_bytes());
353 close_tag = name.to_string();
354 }
355
356 self.element_stack.push(close_tag);
358 }
359
360 fn write_attribute(
363 &mut self,
364 name: &str,
365 value: Peek<'_, '_>,
366 namespace: Option<&str>,
367 ) -> std::io::Result<bool> {
368 let mut value_buf = Vec::new();
370 let written = write_scalar_value(
371 &mut EscapingWriter::attribute(&mut value_buf),
372 value,
373 self.options.float_formatter,
374 )?;
375
376 if !written {
377 return Ok(false);
379 }
380
381 self.out.push(b' ');
383 if let Some(ns_uri) = namespace {
384 let prefix = self.get_or_create_prefix(ns_uri);
385 self.out.extend_from_slice(b"xmlns:");
387 self.out.extend_from_slice(prefix.as_bytes());
388 self.out.extend_from_slice(b"=\"");
389 self.out.extend_from_slice(ns_uri.as_bytes());
390 self.out.extend_from_slice(b"\" ");
391 self.out.extend_from_slice(prefix.as_bytes());
393 self.out.push(b':');
394 }
395 self.out.extend_from_slice(name.as_bytes());
396 self.out.extend_from_slice(b"=\"");
397 self.out.extend_from_slice(&value_buf);
398 self.out.push(b'"');
399 Ok(true)
400 }
401
402 fn write_element_tag_end(&mut self) {
404 self.out.push(b'>');
405 self.write_newline();
406 self.depth += 1;
407 }
408
409 fn write_close_tag(&mut self, name: &str) {
410 self.depth = self.depth.saturating_sub(1);
411 self.write_indent();
412 self.out.extend_from_slice(b"</");
413 self.out.extend_from_slice(name.as_bytes());
414 self.out.push(b'>');
415 self.write_newline();
416 }
417
418 fn write_text_escaped(&mut self, text: &str) {
419 use std::io::Write;
420 if self.options.preserve_entities {
421 let escaped = escape_preserving_entities(text, false);
422 self.out.extend_from_slice(escaped.as_bytes());
423 } else {
424 let _ = EscapingWriter::text(&mut self.out).write_all(text.as_bytes());
426 }
427 }
428
429 fn write_indent(&mut self) {
431 if self.options.pretty {
432 for _ in 0..self.depth {
433 self.out.extend_from_slice(self.options.indent.as_bytes());
434 }
435 }
436 }
437
438 fn write_newline(&mut self) {
440 if self.options.pretty {
441 self.out.push(b'\n');
442 }
443 }
444
445 fn get_or_create_prefix(&mut self, namespace_uri: &str) -> String {
447 if let Some(prefix) = self.declared_namespaces.get(namespace_uri) {
449 return prefix.clone();
450 }
451
452 let prefix = WELL_KNOWN_NAMESPACES
454 .iter()
455 .find(|(uri, _)| *uri == namespace_uri)
456 .map(|(_, prefix)| (*prefix).to_string())
457 .unwrap_or_else(|| {
458 let prefix = format!("ns{}", self.next_ns_index);
460 self.next_ns_index += 1;
461 prefix
462 });
463
464 let final_prefix = if self.declared_namespaces.values().any(|p| p == &prefix) {
466 let prefix = format!("ns{}", self.next_ns_index);
467 self.next_ns_index += 1;
468 prefix
469 } else {
470 prefix
471 };
472
473 self.declared_namespaces
474 .insert(namespace_uri.to_string(), final_prefix.clone());
475 final_prefix
476 }
477
478 fn clear_field_state_impl(&mut self) {
479 self.pending_is_attribute = false;
480 self.pending_is_text = false;
481 self.pending_is_elements = false;
482 self.pending_is_doctype = false;
483 self.pending_namespace = None;
484 }
485}
486
487impl Default for XmlSerializer {
488 fn default() -> Self {
489 Self::new()
490 }
491}
492
493impl DomSerializer for XmlSerializer {
494 type Error = XmlSerializeError;
495
496 fn element_start(&mut self, tag: &str, namespace: Option<&str>) -> Result<(), Self::Error> {
497 let ns = namespace
499 .map(|s| s.to_string())
500 .or_else(|| self.pending_namespace.take())
501 .or_else(|| self.current_ns_all.clone());
502
503 self.write_element_tag_start(tag, ns.as_deref());
505 self.collecting_attributes = true;
506
507 Ok(())
508 }
509
510 fn attribute(
511 &mut self,
512 name: &str,
513 value: Peek<'_, '_>,
514 namespace: Option<&str>,
515 ) -> Result<(), Self::Error> {
516 if !self.collecting_attributes {
518 return Err(XmlSerializeError {
519 msg: Cow::Borrowed("attribute() called after children_start()"),
520 });
521 }
522
523 let ns: Option<String> = match namespace {
525 Some(ns) => Some(ns.to_string()),
526 None => self.pending_namespace.clone(),
527 };
528
529 self.write_attribute(name, value, ns.as_deref())
531 .map_err(|e| XmlSerializeError {
532 msg: Cow::Owned(format!("write error: {}", e)),
533 })?;
534 Ok(())
535 }
536
537 fn children_start(&mut self) -> Result<(), Self::Error> {
538 self.write_element_tag_end();
540 self.collecting_attributes = false;
541 Ok(())
542 }
543
544 fn children_end(&mut self) -> Result<(), Self::Error> {
545 Ok(())
546 }
547
548 fn element_end(&mut self, _tag: &str) -> Result<(), Self::Error> {
549 if let Some(close_tag) = self.element_stack.pop() {
550 self.write_close_tag(&close_tag);
551 }
552 Ok(())
553 }
554
555 fn text(&mut self, content: &str) -> Result<(), Self::Error> {
556 self.write_text_escaped(content);
557 Ok(())
558 }
559
560 fn struct_metadata(&mut self, shape: &facet_core::Shape) -> Result<(), Self::Error> {
561 self.current_ns_all = shape
563 .attributes
564 .iter()
565 .find(|attr| attr.ns == Some("xml") && attr.key == "ns_all")
566 .and_then(|attr| attr.get_as::<&str>().copied())
567 .map(String::from);
568
569 self.pending_establish_default_ns = self.current_ns_all.is_some();
571
572 Ok(())
573 }
574
575 fn field_metadata(&mut self, field: &facet_reflect::FieldItem) -> Result<(), Self::Error> {
576 let Some(field_def) = field.field else {
577 self.pending_is_attribute = true;
579 self.pending_is_text = false;
580 self.pending_is_elements = false;
581 self.pending_is_doctype = false;
582 return Ok(());
583 };
584
585 self.pending_is_attribute = field_def.get_attr(Some("xml"), "attribute").is_some();
587 self.pending_is_text = field_def.get_attr(Some("xml"), "text").is_some();
589 self.pending_is_elements = field_def.get_attr(Some("xml"), "elements").is_some();
591 self.pending_is_doctype = field_def.get_attr(Some("xml"), "doctype").is_some();
593
594 if let Some(ns_attr) = field_def.get_attr(Some("xml"), "ns")
596 && let Some(ns_uri) = ns_attr.get_as::<&str>().copied()
597 {
598 self.pending_namespace = Some(ns_uri.to_string());
599 } else if !self.pending_is_attribute && !self.pending_is_text {
600 self.pending_namespace = self.current_ns_all.clone();
602 } else {
603 self.pending_namespace = None;
605 }
606
607 Ok(())
608 }
609
610 fn variant_metadata(
611 &mut self,
612 _variant: &'static facet_core::Variant,
613 ) -> Result<(), Self::Error> {
614 Ok(())
615 }
616
617 fn is_attribute_field(&self) -> bool {
618 self.pending_is_attribute
619 }
620
621 fn is_text_field(&self) -> bool {
622 self.pending_is_text
623 }
624
625 fn is_elements_field(&self) -> bool {
626 self.pending_is_elements
627 }
628
629 fn is_doctype_field(&self) -> bool {
630 self.pending_is_doctype
631 }
632
633 fn doctype(&mut self, content: &str) -> Result<(), Self::Error> {
634 self.out.write_all(b"<!DOCTYPE ").unwrap();
636 self.out.write_all(content.as_bytes()).unwrap();
637 self.out.write_all(b">").unwrap();
638 if self.options.pretty {
639 self.out.write_all(b"\n").unwrap();
640 }
641 Ok(())
642 }
643
644 fn clear_field_state(&mut self) {
645 self.clear_field_state_impl();
646 }
647
648 fn format_float(&self, value: f64) -> String {
649 if let Some(formatter) = self.options.float_formatter {
650 let mut buf = Vec::new();
651 if formatter(value, &mut buf).is_ok()
653 && let Ok(s) = String::from_utf8(buf)
654 {
655 return s;
656 }
657 }
658 value.to_string()
659 }
660
661 fn serialize_none(&mut self) -> Result<(), Self::Error> {
662 Ok(())
664 }
665
666 fn format_namespace(&self) -> Option<&'static str> {
667 Some("xml")
668 }
669}
670
671pub fn to_vec<'facet, T>(value: &'_ T) -> Result<Vec<u8>, DomSerializeError<XmlSerializeError>>
673where
674 T: Facet<'facet> + ?Sized,
675{
676 to_vec_with_options(value, &SerializeOptions::default())
677}
678
679pub fn to_vec_with_options<'facet, T>(
681 value: &'_ T,
682 options: &SerializeOptions,
683) -> Result<Vec<u8>, DomSerializeError<XmlSerializeError>>
684where
685 T: Facet<'facet> + ?Sized,
686{
687 let mut serializer = XmlSerializer::with_options(options.clone());
688 facet_dom::serialize(&mut serializer, Peek::new(value))?;
689 Ok(serializer.finish())
690}
691
692pub fn to_string<'facet, T>(value: &'_ T) -> Result<String, DomSerializeError<XmlSerializeError>>
694where
695 T: Facet<'facet> + ?Sized,
696{
697 let bytes = to_vec(value)?;
698 Ok(String::from_utf8(bytes).expect("XmlSerializer produces valid UTF-8"))
700}
701
702pub fn to_string_pretty<'facet, T>(
704 value: &'_ T,
705) -> Result<String, DomSerializeError<XmlSerializeError>>
706where
707 T: Facet<'facet> + ?Sized,
708{
709 to_string_with_options(value, &SerializeOptions::default().pretty())
710}
711
712pub fn to_string_with_options<'facet, T>(
714 value: &'_ T,
715 options: &SerializeOptions,
716) -> Result<String, DomSerializeError<XmlSerializeError>>
717where
718 T: Facet<'facet> + ?Sized,
719{
720 let bytes = to_vec_with_options(value, options)?;
721 Ok(String::from_utf8(bytes).expect("XmlSerializer produces valid UTF-8"))
723}
724
/// Escapes `s` for XML output while leaving existing entity references untouched.
///
/// `<` and `>` are always replaced with `&lt;` and `&gt;`. `"` is replaced with `&quot;`
/// only when `is_attribute` is `true`. An `&` that begins something shaped like an entity
/// reference (see [`try_parse_entity_reference`]) is copied through together with the rest
/// of that reference; any other `&` becomes `&amp;`.
fn escape_preserving_entities(s: &str, is_attribute: bool) -> String {
    let mut result = String::with_capacity(s.len());
    // Collected up front so the entity parser can look ahead by character index.
    let chars: Vec<char> = s.chars().collect();
    let mut i = 0;

    while i < chars.len() {
        match chars[i] {
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '"' if is_attribute => result.push_str("&quot;"),
            '&' => {
                if let Some(entity_len) = try_parse_entity_reference(&chars[i..]) {
                    // Copy the whole reference verbatim and resume right after its `;`.
                    result.extend(&chars[i..i + entity_len]);
                    i += entity_len;
                    continue;
                }
                result.push_str("&amp;");
            }
            other => result.push(other),
        }
        i += 1;
    }

    result
}

/// Checks whether `chars` starts with something shaped like an XML entity reference.
///
/// Recognized forms are a named reference (`&name;`, where the name starts with an ASCII
/// letter or `_` and continues with ASCII letters, digits, `_`, `-` or `.`), a decimal
/// character reference (`&#123;`) and a hexadecimal one (`&#x1F;`).
///
/// Returns the length in characters of the reference, including the leading `&` and the
/// trailing `;`, or `None` when `chars` does not start with one. Only the shape is checked,
/// not whether the entity is actually defined.
fn try_parse_entity_reference(chars: &[char]) -> Option<usize> {
    // The shortest possible reference is three characters long, e.g. `&a;`.
    if chars.len() < 3 || chars[0] != '&' {
        return None;
    }

    // Index of the character that follows the reference body (where `;` is expected).
    let body_end = if chars[1] == '#' {
        // NOTE(review): an uppercase `X` marker is accepted here, although the XML
        // grammar only defines the lowercase `&#x` form; kept for compatibility.
        let (digits_start, is_digit): (usize, fn(&char) -> bool) = match chars[2] {
            'x' | 'X' => (3, char::is_ascii_hexdigit),
            _ => (2, char::is_ascii_digit),
        };
        let digit_count = chars[digits_start..]
            .iter()
            .take_while(|c| is_digit(c))
            .count();
        // At least one digit is required.
        if digit_count == 0 {
            return None;
        }
        digits_start + digit_count
    } else {
        let first = chars[1];
        if !(first.is_ascii_alphabetic() || first == '_') {
            return None;
        }
        let rest_count = chars[2..]
            .iter()
            .take_while(|&&c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.'))
            .count();
        2 + rest_count
    };

    // The reference must be terminated by `;`.
    (chars.get(body_end) == Some(&';')).then_some(body_end + 1)
}