use serde::{Deserialize, Serialize};
use std::str::FromStr;
#[cfg(feature = "python")]
use pyo3::pyclass;
#[cfg(feature = "wasm")]
use tsify_next::Tsify;
/// Describes how a named item is represented in XML.
///
/// Values can be parsed from a compact string form via the `FromStr` impl in
/// this file: a leading `@` yields [`XMLType::Attribute`], a string containing
/// `/` yields [`XMLType::Wrapped`], and anything else yields
/// [`XMLType::Element`].
///
/// Every variant carries an `is_attr` flag next to its `name`. The string
/// parser sets it to `true` only for `Attribute`.
// NOTE(review): `is_attr` looks redundant with the variant itself; it appears
// to exist so the flat serialized record can carry the distinction - confirm.
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "wasm", derive(Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi))]
pub enum XMLType {
/// An XML attribute. When parsed from a string, `name` is the text that
/// follows the leading `@`.
Attribute { is_attr: bool, name: String },
/// A plain XML element identified by `name`.
Element { is_attr: bool, name: String },
/// An element reached through a path of enclosing names. When parsed from a
/// string, `name` is the text after the last `/` and `wrapped` holds the
/// whitespace-trimmed segments that precede it, in order.
Wrapped {
is_attr: bool,
name: String,
wrapped: Option<Vec<String>>,
},
}
/// Parses the compact string notation for an [`XMLType`].
///
/// * `@name` becomes an `Attribute` named `name` (checked first, so a `/`
///   after the `@` stays part of the attribute name).
/// * `a/b/name` becomes a `Wrapped` value named `name` whose `wrapped` path
///   holds the trimmed segments `a`, `b`.
/// * Anything else becomes an `Element` named after the whole input.
///
/// Parsing never fails; the `String` error type is never produced.
impl FromStr for XMLType {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Attribute notation takes precedence over path notation.
        if let Some(attr_name) = s.strip_prefix('@') {
            return Ok(XMLType::Attribute {
                is_attr: true,
                name: attr_name.to_owned(),
            });
        }

        // No path separator: the whole input is the element name.
        if !s.contains('/') {
            return Ok(XMLType::Element {
                is_attr: false,
                name: s.to_owned(),
            });
        }

        // The final segment is the element name; everything before it is the
        // wrapping path. Only the path segments are trimmed, not the name.
        let (path, name) = split_at_last(s, '/');
        let segments: Vec<String> = path
            .split('/')
            .map(|segment| segment.trim().to_string())
            .collect();

        Ok(XMLType::Wrapped {
            is_attr: false,
            name,
            wrapped: Some(segments),
        })
    }
}
impl<'de> Deserialize<'de> for XMLType {
fn deserialize<D>(deserializer: D) -> Result<XMLType, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct XMLTypeVisitor {
is_attr: bool,
name: String,
}
let value = XMLTypeVisitor::deserialize(deserializer)?;
if value.is_attr {
Ok(XMLType::Attribute {
is_attr: true,
name: value.name,
})
} else {
Ok(XMLType::Element {
is_attr: false,
name: value.name,
})
}
}
}
impl Serialize for XMLType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
#[derive(Serialize)]
struct XMLTypeVisitor {
is_attr: bool,
name: String,
wrapped: Option<Vec<String>>,
}
let visitor = match self {
XMLType::Attribute { is_attr, name } | XMLType::Element { is_attr, name } => {
XMLTypeVisitor {
is_attr: *is_attr,
name: name.to_string(),
wrapped: None,
}
}
XMLType::Wrapped {
is_attr,
name,
wrapped,
} => XMLTypeVisitor {
is_attr: *is_attr,
name: name.to_string(),
wrapped: wrapped.clone(),
},
};
visitor.serialize(serializer)
}
}
/// Splits `s` at the last occurrence of `c`.
///
/// Returns `(before, after)`, where `before` is everything preceding the last
/// `c` and `after` is everything following it; the separator itself is
/// dropped. When `c` does not occur in `s`, `before` is empty and `after` is
/// the whole input.
pub(crate) fn split_at_last(s: &str, c: char) -> (String, String) {
    match s.rsplit_once(c) {
        Some((before, after)) => (before.to_string(), after.to_string()),
        None => (String::new(), s.to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `@name` parses to an attribute, a bare name parses to an element.
    #[test]
    fn test_xmltype_from_str() {
        let attr = XMLType::Attribute {
            is_attr: true,
            name: "id".to_string(),
        };
        let elem = XMLType::Element {
            is_attr: false,
            name: "name".to_string(),
        };
        assert_eq!(XMLType::from_str("@id").unwrap(), attr);
        assert_eq!(XMLType::from_str("name").unwrap(), elem);
    }

    /// A `/`-separated input parses to a wrapped element whose path segments
    /// are trimmed and whose name is the final segment.
    #[test]
    fn test_xmltype_from_str_wrapped() {
        let wrapped = XMLType::Wrapped {
            is_attr: false,
            name: "item".to_string(),
            wrapped: Some(vec!["root".to_string(), "items".to_string()]),
        };
        assert_eq!(XMLType::from_str("root / items/item").unwrap(), wrapped);
    }

    /// The `@` prefix is checked before `/`, so the separator stays in the
    /// attribute name.
    #[test]
    fn test_xmltype_from_str_attribute_precedence() {
        let attr = XMLType::Attribute {
            is_attr: true,
            name: "ns/id".to_string(),
        };
        assert_eq!(XMLType::from_str("@ns/id").unwrap(), attr);
    }

    /// Records without a `wrapped` field map to attribute/element by `is_attr`.
    #[test]
    fn test_xmltype_deserialize() {
        let attr = XMLType::Attribute {
            is_attr: true,
            name: "id".to_string(),
        };
        let elem = XMLType::Element {
            is_attr: false,
            name: "name".to_string(),
        };
        let attr_json = r#"{"is_attr":true,"name":"id"}"#;
        let elem_json = r#"{"is_attr":false,"name":"name"}"#;
        assert_eq!(serde_json::from_str::<XMLType>(attr_json).unwrap(), attr);
        assert_eq!(serde_json::from_str::<XMLType>(elem_json).unwrap(), elem);
    }

    /// Every variant serializes to the flat `is_attr` / `name` / `wrapped`
    /// record.
    #[test]
    fn test_xmltype_serialize() {
        let elem = XMLType::Element {
            is_attr: false,
            name: "name".to_string(),
        };
        let wrapped = XMLType::Wrapped {
            is_attr: false,
            name: "item".to_string(),
            wrapped: Some(vec!["root".to_string(), "items".to_string()]),
        };
        assert_eq!(
            serde_json::to_string(&elem).unwrap(),
            r#"{"is_attr":false,"name":"name","wrapped":null}"#
        );
        assert_eq!(
            serde_json::to_string(&wrapped).unwrap(),
            r#"{"is_attr":false,"name":"item","wrapped":["root","items"]}"#
        );
    }

    /// The helper splits on the last separator only and tolerates its absence.
    #[test]
    fn test_split_at_last() {
        assert_eq!(
            split_at_last("a/b/c", '/'),
            ("a/b".to_string(), "c".to_string())
        );
        assert_eq!(
            split_at_last("name", '/'),
            (String::new(), "name".to_string())
        );
        assert_eq!(split_at_last("a/", '/'), ("a".to_string(), String::new()));
    }
}