extern crate alloc;
use alloc::{borrow::Cow, format, string::String, vec::Vec};
use std::collections::HashMap;
use std::io::Write;
use facet_core::{Def, Facet, ScalarType};
use facet_dom::{DomSerializeError, DomSerializer};
use facet_reflect::Peek;
use crate::escaping::EscapingWriter;
pub use facet_dom::FloatFormatter;
/// Writes the plain-text form of a scalar-like `value` to `out`.
///
/// The value is first reduced with `innermost_peek`; an `Option` is then
/// unwrapped recursively. Primitive scalars are written via `Display` (floats
/// go through `float_formatter` when one is supplied), strings and chars are
/// written verbatim, and the unit type is written as `null`. Values without a
/// scalar type fall back to their own `Display` output or, for unit enum
/// variants, to the variant's element name.
///
/// Returns `Ok(true)` when text was written and `Ok(false)` when the value is
/// a `None` option or has no text form handled here. I/O errors from `out`
/// (and from the float formatter) are propagated.
fn write_scalar_value(
    out: &mut dyn Write,
    value: Peek<'_, '_>,
    float_formatter: Option<FloatFormatter>,
) -> std::io::Result<bool> {
    /// Writes `item` through its `Display` implementation.
    fn emit<T: core::fmt::Display>(out: &mut dyn Write, item: T) -> std::io::Result<()> {
        write!(out, "{}", item)
    }

    /// Handles values that report no scalar type: displayable opaque scalars
    /// and unit enum variants.
    fn write_non_primitive(out: &mut dyn Write, value: Peek<'_, '_>) -> std::io::Result<bool> {
        if matches!(value.shape().def, Def::Scalar) && value.shape().vtable.has_display() {
            write!(out, "{}", value)?;
            return Ok(true);
        }
        if let Ok(enum_peek) = value.into_enum() {
            if let Ok(variant) = enum_peek.active_variant() {
                if variant.data.kind == facet_core::StructKind::Unit {
                    // An explicit rename wins; otherwise derive the element
                    // name from the Rust variant name.
                    let label = if variant.rename.is_none() {
                        facet_dom::naming::to_element_name(variant.name)
                    } else {
                        Cow::Borrowed(variant.effective_name())
                    };
                    out.write_all(label.as_bytes())?;
                    return Ok(true);
                }
            }
        }
        Ok(false)
    }

    let value = value.innermost_peek();

    if matches!(value.shape().def, Def::Option(_)) {
        if let Ok(opt) = value.into_option() {
            return match opt.value() {
                None => Ok(false),
                Some(inner) => write_scalar_value(out, inner, float_formatter),
            };
        }
    }

    let scalar_type = match value.scalar_type() {
        Some(kind) => kind,
        None => return write_non_primitive(out, value),
    };

    match scalar_type {
        ScalarType::Unit => out.write_all(b"null")?,
        ScalarType::Bool => {
            let flag = *value.get::<bool>().unwrap();
            let text: &[u8] = if flag { b"true" } else { b"false" };
            out.write_all(text)?;
        }
        ScalarType::Char => {
            let ch = value.get::<char>().unwrap();
            let mut utf8 = [0u8; 4];
            out.write_all(ch.encode_utf8(&mut utf8).as_bytes())?;
        }
        ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
            out.write_all(value.as_str().unwrap().as_bytes())?;
        }
        ScalarType::F32 => {
            let number = value.get::<f32>().unwrap();
            match float_formatter {
                Some(fmt) => {
                    fmt(*number as f64, out)?;
                }
                None => emit(out, number)?,
            }
        }
        ScalarType::F64 => {
            let number = value.get::<f64>().unwrap();
            match float_formatter {
                Some(fmt) => {
                    fmt(*number, out)?;
                }
                None => emit(out, number)?,
            }
        }
        ScalarType::U8 => emit(out, value.get::<u8>().unwrap())?,
        ScalarType::U16 => emit(out, value.get::<u16>().unwrap())?,
        ScalarType::U32 => emit(out, value.get::<u32>().unwrap())?,
        ScalarType::U64 => emit(out, value.get::<u64>().unwrap())?,
        ScalarType::U128 => emit(out, value.get::<u128>().unwrap())?,
        ScalarType::USize => emit(out, value.get::<usize>().unwrap())?,
        ScalarType::I8 => emit(out, value.get::<i8>().unwrap())?,
        ScalarType::I16 => emit(out, value.get::<i16>().unwrap())?,
        ScalarType::I32 => emit(out, value.get::<i32>().unwrap())?,
        ScalarType::I64 => emit(out, value.get::<i64>().unwrap())?,
        ScalarType::I128 => emit(out, value.get::<i128>().unwrap())?,
        ScalarType::ISize => emit(out, value.get::<isize>().unwrap())?,
        #[cfg(feature = "net")]
        ScalarType::IpAddr => emit(out, value.get::<core::net::IpAddr>().unwrap())?,
        #[cfg(feature = "net")]
        ScalarType::Ipv4Addr => emit(out, value.get::<core::net::Ipv4Addr>().unwrap())?,
        #[cfg(feature = "net")]
        ScalarType::Ipv6Addr => emit(out, value.get::<core::net::Ipv6Addr>().unwrap())?,
        #[cfg(feature = "net")]
        ScalarType::SocketAddr => emit(out, value.get::<core::net::SocketAddr>().unwrap())?,
        _ => return Ok(false),
    }
    Ok(true)
}
/// Options controlling how `XmlSerializer` renders its output.
#[derive(Clone)]
pub struct SerializeOptions {
/// When `true`, tags are indented and followed by newlines.
pub pretty: bool,
/// Text emitted once per nesting level in front of a tag when `pretty` is set.
pub indent: Cow<'static, str>,
/// Optional custom formatter for `f32`/`f64` values; when `None`, floats use
/// their `Display` output.
pub float_formatter: Option<FloatFormatter>,
/// When `true`, text content is escaped with `escape_preserving_entities`,
/// which leaves existing entity references intact instead of escaping their `&`.
pub preserve_entities: bool,
}
impl Default for SerializeOptions {
    /// Compact output: no pretty-printing, no custom float formatter, and no
    /// entity preservation.
    fn default() -> Self {
        let indent = Cow::Borrowed(" ");
        Self {
            indent,
            pretty: false,
            preserve_entities: false,
            float_formatter: None,
        }
    }
}
impl core::fmt::Debug for SerializeOptions {
    /// Formats every option; the float formatter is shown only as a
    /// `"..."` placeholder indicating whether one is set.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let formatter_marker = self.float_formatter.map(|_| "...");
        let mut builder = f.debug_struct("SerializeOptions");
        builder.field("pretty", &self.pretty);
        builder.field("indent", &self.indent);
        builder.field("float_formatter", &formatter_marker);
        builder.field("preserve_entities", &self.preserve_entities);
        builder.finish()
    }
}
impl SerializeOptions {
    /// Returns the default options; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Turns pretty-printing on.
    pub const fn pretty(self) -> Self {
        let mut updated = self;
        updated.pretty = true;
        updated
    }

    /// Sets the per-level indentation text and turns pretty-printing on.
    pub fn indent(self, indent: impl Into<Cow<'static, str>>) -> Self {
        Self {
            indent: indent.into(),
            pretty: true,
            ..self
        }
    }

    /// Installs a custom formatter used for floating-point values.
    pub fn float_formatter(self, formatter: FloatFormatter) -> Self {
        Self {
            float_formatter: Some(formatter),
            ..self
        }
    }

    /// Chooses whether existing entity references in text are kept as-is.
    pub const fn preserve_entities(self, preserve: bool) -> Self {
        let mut updated = self;
        updated.preserve_entities = preserve;
        updated
    }
}
/// Preferred prefixes for commonly used namespace URIs, stored as
/// `(namespace URI, prefix)` pairs.
///
/// `XmlSerializer::get_or_create_prefix` consults this table before falling
/// back to a generated `nsN` prefix.
// NOTE(review): the `dead_code` allowance looks unnecessary given that use —
// confirm no cfg gate removes the only caller before dropping it.
#[allow(dead_code)] const WELL_KNOWN_NAMESPACES: &[(&str, &str)] = &[
("http://www.w3.org/2001/XMLSchema-instance", "xsi"),
("http://www.w3.org/2001/XMLSchema", "xs"),
("http://www.w3.org/XML/1998/namespace", "xml"),
("http://www.w3.org/1999/xlink", "xlink"),
("http://www.w3.org/2000/svg", "svg"),
("http://www.w3.org/1999/xhtml", "xhtml"),
("http://schemas.xmlsoap.org/soap/envelope/", "soap"),
("http://www.w3.org/2003/05/soap-envelope", "soap12"),
("http://schemas.android.com/apk/res/android", "android"),
];
/// Error type reported by `XmlSerializer` through its `DomSerializer` impl.
#[derive(Debug)]
pub struct XmlSerializeError {
/// Human-readable description; this is exactly what `Display` prints.
msg: Cow<'static, str>,
}
impl core::fmt::Display for XmlSerializeError {
    /// Prints the stored message verbatim.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { msg } = self;
        f.write_str(msg)
    }
}
// Marker impl with no overridden methods: the trait's default methods are used.
impl std::error::Error for XmlSerializeError {}
/// Serializer that renders `DomSerializer` events as XML into an in-memory buffer.
pub struct XmlSerializer {
/// Output buffer; handed back by `finish`.
out: Vec<u8>,
/// Closing-tag names (including any prefix) of the elements currently open.
element_stack: Vec<String>,
/// Namespace URI -> prefix assignments made so far.
declared_namespaces: HashMap<String, String>,
/// Counter used to generate the next `nsN` prefix.
next_ns_index: usize,
/// Namespace URI most recently declared with `xmlns="..."`, if any.
current_default_ns: Option<String>,
/// Value of the `xml::ns_all` attribute seen by the latest `struct_metadata` call.
current_ns_all: Option<String>,
/// Set by `field_metadata`: the pending field is an attribute.
pending_is_attribute: bool,
/// Set by `field_metadata`: the pending field is text content.
pending_is_text: bool,
/// Set by `field_metadata`: the pending field carries the `xml::elements` attribute.
pending_is_elements: bool,
/// Set by `field_metadata`: the pending field carries the `xml::doctype` attribute.
pending_is_doctype: bool,
/// Set by `field_metadata`: the pending field carries the `xml::tag` attribute.
pending_is_tag: bool,
/// Namespace to apply to the pending field when none is given explicitly.
pending_namespace: Option<String>,
/// Rendering options supplied at construction.
options: SerializeOptions,
/// Current nesting depth, used for indentation.
depth: usize,
/// `true` between `element_start` and `children_start`, while attributes may be added.
collecting_attributes: bool,
/// When set, the next element whose namespace differs from the current default
/// declares it with `xmlns="..."` instead of using a prefix.
pending_establish_default_ns: bool,
}
impl XmlSerializer {
    /// Creates a serializer that uses [`SerializeOptions::default`].
    pub fn new() -> Self {
        let options = SerializeOptions::default();
        Self::with_options(options)
    }

    /// Creates an empty serializer configured with `options`.
    pub fn with_options(options: SerializeOptions) -> Self {
        Self {
            options,
            out: Vec::new(),
            depth: 0,
            element_stack: Vec::new(),
            collecting_attributes: false,
            declared_namespaces: HashMap::new(),
            next_ns_index: 0,
            current_default_ns: None,
            current_ns_all: None,
            pending_establish_default_ns: false,
            pending_namespace: None,
            pending_is_attribute: false,
            pending_is_text: false,
            pending_is_elements: false,
            pending_is_doctype: false,
            pending_is_tag: false,
        }
    }

    /// Consumes the serializer and returns everything written so far.
    pub fn finish(self) -> Vec<u8> {
        let Self { out, .. } = self;
        out
    }

    /// Appends each slice in `parts` to `out`, in order.
    fn push_parts(out: &mut Vec<u8>, parts: &[&[u8]]) {
        for part in parts {
            out.extend_from_slice(part);
        }
    }

    /// Writes `<name` (plus any namespace declaration) and remembers the
    /// matching closing-tag name on the element stack.
    ///
    /// The tag is left open so that attributes can follow.
    fn write_element_tag_start(&mut self, name: &str, namespace: Option<&str>) {
        self.write_indent();
        self.out.push(b'<');
        let close_tag = match namespace {
            // No namespace: plain tag name.
            None => {
                self.out.extend_from_slice(name.as_bytes());
                name.to_string()
            }
            // Already the default namespace: nothing to declare.
            Some(uri) if self.current_default_ns.as_deref() == Some(uri) => {
                self.out.extend_from_slice(name.as_bytes());
                name.to_string()
            }
            // A default namespace was requested: declare it on this element.
            Some(uri) if self.pending_establish_default_ns => {
                Self::push_parts(
                    &mut self.out,
                    &[name.as_bytes(), b" xmlns=\"", uri.as_bytes(), b"\""],
                );
                self.current_default_ns = Some(uri.to_string());
                self.pending_establish_default_ns = false;
                name.to_string()
            }
            // Otherwise qualify the tag with a prefix and declare that prefix.
            Some(uri) => {
                let prefix = self.get_or_create_prefix(uri);
                let qualified = format!("{}:{}", prefix, name);
                Self::push_parts(
                    &mut self.out,
                    &[
                        qualified.as_bytes(),
                        b" xmlns:",
                        prefix.as_bytes(),
                        b"=\"",
                        uri.as_bytes(),
                        b"\"",
                    ],
                );
                qualified
            }
        };
        self.element_stack.push(close_tag);
    }

    /// Writes ` name="value"` for a scalar `value`, preceded by a prefix
    /// declaration when `namespace` is given.
    ///
    /// Returns `Ok(false)` without writing anything when the value has no
    /// scalar text form.
    fn write_attribute(
        &mut self,
        name: &str,
        value: Peek<'_, '_>,
        namespace: Option<&str>,
    ) -> std::io::Result<bool> {
        // Render into a scratch buffer first so nothing is emitted for
        // values that turn out not to be writable.
        let mut rendered = Vec::new();
        let wrote = {
            let mut escaper = EscapingWriter::attribute(&mut rendered);
            write_scalar_value(&mut escaper, value, self.options.float_formatter)?
        };
        if !wrote {
            return Ok(false);
        }

        self.out.push(b' ');
        if let Some(uri) = namespace {
            let prefix = self.get_or_create_prefix(uri);
            Self::push_parts(
                &mut self.out,
                &[
                    b"xmlns:",
                    prefix.as_bytes(),
                    b"=\"",
                    uri.as_bytes(),
                    b"\" ",
                    prefix.as_bytes(),
                    b":",
                ],
            );
        }
        Self::push_parts(
            &mut self.out,
            &[name.as_bytes(), b"=\"", rendered.as_slice(), b"\""],
        );
        Ok(true)
    }

    /// Closes the opening tag with `>` and descends one nesting level.
    fn write_element_tag_end(&mut self) {
        self.out.push(b'>');
        self.depth += 1;
        self.write_newline();
    }

    /// Ascends one nesting level and writes `</name>`.
    fn write_close_tag(&mut self, name: &str) {
        self.depth = self.depth.saturating_sub(1);
        self.write_indent();
        Self::push_parts(&mut self.out, &[b"</", name.as_bytes(), b">"]);
        self.write_newline();
    }

    /// Writes `text` as escaped character data, honouring
    /// `options.preserve_entities`.
    fn write_text_escaped(&mut self, text: &str) {
        use std::io::Write;
        if !self.options.preserve_entities {
            let _ = EscapingWriter::text(&mut self.out).write_all(text.as_bytes());
            return;
        }
        let escaped = escape_preserving_entities(text, false);
        self.out.extend_from_slice(escaped.as_bytes());
    }

    /// Writes one copy of the indent text per nesting level (pretty mode only).
    fn write_indent(&mut self) {
        if !self.options.pretty {
            return;
        }
        let unit = self.options.indent.as_bytes();
        for _level in 0..self.depth {
            self.out.extend_from_slice(unit);
        }
    }

    /// Writes a line break (pretty mode only).
    fn write_newline(&mut self) {
        if !self.options.pretty {
            return;
        }
        self.out.push(b'\n');
    }

    /// Returns the prefix assigned to `namespace_uri`, assigning one first if
    /// needed.
    ///
    /// A well-known prefix is preferred; a generated `nsN` prefix is used
    /// when the URI is not well known or its preferred prefix is taken.
    fn get_or_create_prefix(&mut self, namespace_uri: &str) -> String {
        /// Produces the next `nsN` prefix and advances the counter.
        fn generated(counter: &mut usize) -> String {
            let prefix = format!("ns{}", *counter);
            *counter += 1;
            prefix
        }

        if let Some(existing) = self.declared_namespaces.get(namespace_uri) {
            return existing.clone();
        }

        let well_known = WELL_KNOWN_NAMESPACES
            .iter()
            .find(|entry| entry.0 == namespace_uri)
            .map(|entry| entry.1.to_string());
        let candidate = match well_known {
            Some(prefix) => prefix,
            None => generated(&mut self.next_ns_index),
        };

        let taken = self
            .declared_namespaces
            .values()
            .any(|used| *used == candidate);
        let chosen = if taken {
            generated(&mut self.next_ns_index)
        } else {
            candidate
        };

        self.declared_namespaces
            .insert(namespace_uri.to_string(), chosen.clone());
        chosen
    }

    /// Resets every per-field flag and the pending namespace.
    fn clear_field_state_impl(&mut self) {
        self.pending_namespace = None;
        self.pending_is_tag = false;
        self.pending_is_doctype = false;
        self.pending_is_elements = false;
        self.pending_is_text = false;
        self.pending_is_attribute = false;
    }
}
impl Default for XmlSerializer {
fn default() -> Self {
Self::new()
}
}
impl DomSerializer for XmlSerializer {
type Error = XmlSerializeError;
fn element_start(&mut self, tag: &str, namespace: Option<&str>) -> Result<(), Self::Error> {
let ns = namespace
.map(|s| s.to_string())
.or_else(|| self.pending_namespace.take())
.or_else(|| self.current_ns_all.clone());
self.write_element_tag_start(tag, ns.as_deref());
self.collecting_attributes = true;
Ok(())
}
fn attribute(
&mut self,
name: &str,
value: Peek<'_, '_>,
namespace: Option<&str>,
) -> Result<(), Self::Error> {
if !self.collecting_attributes {
return Err(XmlSerializeError {
msg: Cow::Borrowed("attribute() called after children_start()"),
});
}
let ns: Option<String> = match namespace {
Some(ns) => Some(ns.to_string()),
None => self.pending_namespace.clone(),
};
self.write_attribute(name, value, ns.as_deref())
.map_err(|e| XmlSerializeError {
msg: Cow::Owned(format!("write error: {}", e)),
})?;
Ok(())
}
fn children_start(&mut self) -> Result<(), Self::Error> {
self.write_element_tag_end();
self.collecting_attributes = false;
Ok(())
}
fn children_end(&mut self) -> Result<(), Self::Error> {
Ok(())
}
fn element_end(&mut self, _tag: &str) -> Result<(), Self::Error> {
if let Some(close_tag) = self.element_stack.pop() {
self.write_close_tag(&close_tag);
}
Ok(())
}
fn text(&mut self, content: &str) -> Result<(), Self::Error> {
self.write_text_escaped(content);
Ok(())
}
fn struct_metadata(&mut self, shape: &facet_core::Shape) -> Result<(), Self::Error> {
self.current_ns_all = shape
.attributes
.iter()
.find(|attr| attr.ns == Some("xml") && attr.key == "ns_all")
.and_then(|attr| attr.get_as::<&str>().copied())
.map(String::from);
self.pending_establish_default_ns = self.current_ns_all.is_some();
Ok(())
}
fn field_metadata(&mut self, field: &facet_reflect::FieldItem) -> Result<(), Self::Error> {
let Some(field_def) = field.field else {
self.pending_is_attribute = true;
self.pending_is_text = false;
self.pending_is_elements = false;
self.pending_is_doctype = false;
self.pending_is_tag = false;
return Ok(());
};
self.pending_is_attribute = field_def.get_attr(Some("xml"), "attribute").is_some();
self.pending_is_text = field_def.get_attr(Some("xml"), "text").is_some();
self.pending_is_elements = field_def.get_attr(Some("xml"), "elements").is_some();
self.pending_is_doctype = field_def.get_attr(Some("xml"), "doctype").is_some();
self.pending_is_tag = field_def.get_attr(Some("xml"), "tag").is_some();
if let Some(ns_attr) = field_def.get_attr(Some("xml"), "ns")
&& let Some(ns_uri) = ns_attr.get_as::<&str>().copied()
{
self.pending_namespace = Some(ns_uri.to_string());
} else if !self.pending_is_attribute && !self.pending_is_text {
self.pending_namespace = self.current_ns_all.clone();
} else {
self.pending_namespace = None;
}
Ok(())
}
fn variant_metadata(
&mut self,
_variant: &'static facet_core::Variant,
) -> Result<(), Self::Error> {
Ok(())
}
fn is_attribute_field(&self) -> bool {
self.pending_is_attribute
}
fn is_text_field(&self) -> bool {
self.pending_is_text
}
fn is_elements_field(&self) -> bool {
self.pending_is_elements
}
fn is_doctype_field(&self) -> bool {
self.pending_is_doctype
}
fn is_tag_field(&self) -> bool {
self.pending_is_tag
}
fn doctype(&mut self, content: &str) -> Result<(), Self::Error> {
self.out.write_all(b"<!DOCTYPE ").unwrap();
self.out.write_all(content.as_bytes()).unwrap();
self.out.write_all(b">").unwrap();
if self.options.pretty {
self.out.write_all(b"\n").unwrap();
}
Ok(())
}
fn clear_field_state(&mut self) {
self.clear_field_state_impl();
}
fn format_float(&self, value: f64) -> String {
if let Some(formatter) = self.options.float_formatter {
let mut buf = Vec::new();
if formatter(value, &mut buf).is_ok()
&& let Ok(s) = String::from_utf8(buf)
{
return s;
}
}
value.to_string()
}
fn serialize_none(&mut self) -> Result<(), Self::Error> {
Ok(())
}
fn format_namespace(&self) -> Option<&'static str> {
Some("xml")
}
}
/// Serializes `value` to XML bytes using the default options.
///
/// # Errors
/// Returns whatever error [`to_vec_with_options`] reports.
pub fn to_vec<'facet, T>(value: &'_ T) -> Result<Vec<u8>, DomSerializeError<XmlSerializeError>>
where
    T: Facet<'facet> + ?Sized,
{
    let defaults = SerializeOptions::default();
    to_vec_with_options(value, &defaults)
}
/// Serializes `value` to XML bytes using the given `options`.
///
/// # Errors
/// Propagates any error returned by `facet_dom::serialize`.
pub fn to_vec_with_options<'facet, T>(
    value: &'_ T,
    options: &SerializeOptions,
) -> Result<Vec<u8>, DomSerializeError<XmlSerializeError>>
where
    T: Facet<'facet> + ?Sized,
{
    let mut xml = XmlSerializer::with_options(options.clone());
    let root = Peek::new(value);
    facet_dom::serialize(&mut xml, root)?;
    let bytes = xml.finish();
    Ok(bytes)
}
/// Serializes `value` to an XML string using the default options.
///
/// # Errors
/// Returns whatever error [`to_vec`] reports.
///
/// # Panics
/// Panics if the serializer output is not valid UTF-8.
pub fn to_string<'facet, T>(value: &'_ T) -> Result<String, DomSerializeError<XmlSerializeError>>
where
    T: Facet<'facet> + ?Sized,
{
    to_vec(value)
        .map(|bytes| String::from_utf8(bytes).expect("XmlSerializer produces valid UTF-8"))
}
/// Serializes `value` to an XML string with pretty-printing enabled.
///
/// # Errors
/// Returns whatever error [`to_string_with_options`] reports.
pub fn to_string_pretty<'facet, T>(
    value: &'_ T,
) -> Result<String, DomSerializeError<XmlSerializeError>>
where
    T: Facet<'facet> + ?Sized,
{
    let pretty_options = SerializeOptions::default().pretty();
    to_string_with_options(value, &pretty_options)
}
/// Serializes `value` to an XML string using the given `options`.
///
/// # Errors
/// Returns whatever error [`to_vec_with_options`] reports.
///
/// # Panics
/// Panics if the serializer output is not valid UTF-8.
pub fn to_string_with_options<'facet, T>(
    value: &'_ T,
    options: &SerializeOptions,
) -> Result<String, DomSerializeError<XmlSerializeError>>
where
    T: Facet<'facet> + ?Sized,
{
    match to_vec_with_options(value, options) {
        Ok(bytes) => Ok(String::from_utf8(bytes).expect("XmlSerializer produces valid UTF-8")),
        Err(err) => Err(err),
    }
}
fn escape_preserving_entities(s: &str, is_attribute: bool) -> String {
let mut result = String::with_capacity(s.len());
let chars: Vec<char> = s.chars().collect();
let mut i = 0;
while i < chars.len() {
let c = chars[i];
match c {
'<' => result.push_str("<"),
'>' => result.push_str(">"),
'"' if is_attribute => result.push_str("""),
'&' => {
if let Some(entity_len) = try_parse_entity_reference(&chars[i..]) {
for j in 0..entity_len {
result.push(chars[i + j]);
}
i += entity_len;
continue;
} else {
result.push_str("&");
}
}
_ => result.push(c),
}
i += 1;
}
result
}
/// Checks whether `chars` begins with an entity or character reference.
///
/// Recognised forms are `&name;` (an ASCII letter or `_`, followed by ASCII
/// alphanumerics, `_`, `-` or `.`), `&#123;` (decimal digits) and `&#x1F;` /
/// `&#X1F;` (hexadecimal digits).
///
/// Returns the length of the reference in characters, including the leading
/// `&` and the trailing `;`, or `None` if `chars` does not start with one.
fn try_parse_entity_reference(chars: &[char]) -> Option<usize> {
    // The shortest possible reference, `&x;`, is three characters long.
    if chars.len() < 3 || chars[0] != '&' {
        return None;
    }

    let body_end = if chars[1] == '#' {
        // Numeric reference: hexadecimal after `x`/`X`, decimal otherwise.
        let (digits_start, is_digit): (usize, fn(&char) -> bool) =
            if matches!(chars[2], 'x' | 'X') {
                (3, char::is_ascii_hexdigit)
            } else {
                (2, char::is_ascii_digit)
            };
        let digit_count = chars[digits_start..]
            .iter()
            .take_while(|c| is_digit(*c))
            .count();
        if digit_count == 0 {
            return None;
        }
        digits_start + digit_count
    } else {
        // Named reference: validate the first character, then scan the rest.
        let first = chars[1];
        if !(first.is_ascii_alphabetic() || first == '_') {
            return None;
        }
        let tail_count = chars[2..]
            .iter()
            .take_while(|c| c.is_ascii_alphanumeric() || matches!(**c, '_' | '-' | '.'))
            .count();
        2 + tail_count
    };

    // The reference must be terminated by a semicolon.
    if chars.get(body_end) == Some(&';') {
        Some(body_end + 1)
    } else {
        None
    }
}