use serde::{Deserialize, Serialize};
/// Enumerates the kinds of host node, one unit variant per kind.
///
/// Through serde each variant is written and read as its snake_case name
/// (for example `ListItem` becomes `"list_item"`). The enum is marked
/// `#[non_exhaustive]`, so code in other crates that matches on it must
/// include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum HostNodeKind {
    /// Serialized as `"document"`.
    Document,
    /// Serialized as `"session"`.
    Session,
    /// Serialized as `"definition"`.
    Definition,
    /// Serialized as `"paragraph"`.
    Paragraph,
    /// Serialized as `"list"`.
    List,
    /// Serialized as `"list_item"`.
    ListItem,
    /// Serialized as `"verbatim"`.
    Verbatim,
    /// Serialized as `"table"`.
    Table,
    /// Serialized as `"annotation"`.
    Annotation,
}
impl HostNodeKind {
pub const ALL: &'static [HostNodeKind] = &[
HostNodeKind::Document,
HostNodeKind::Session,
HostNodeKind::Definition,
HostNodeKind::Paragraph,
HostNodeKind::List,
HostNodeKind::ListItem,
HostNodeKind::Verbatim,
HostNodeKind::Table,
HostNodeKind::Annotation,
];
pub const fn as_str(self) -> &'static str {
match self {
HostNodeKind::Document => "document",
HostNodeKind::Session => "session",
HostNodeKind::Definition => "definition",
HostNodeKind::Paragraph => "paragraph",
HostNodeKind::List => "list",
HostNodeKind::ListItem => "list_item",
HostNodeKind::Verbatim => "verbatim",
HostNodeKind::Table => "table",
HostNodeKind::Annotation => "annotation",
}
}
pub fn parse(s: &str) -> Option<HostNodeKind> {
Self::ALL.iter().copied().find(|k| k.as_str() == s)
}
pub fn allowed_list() -> String {
Self::ALL
.iter()
.map(|k| k.as_str())
.collect::<Vec<_>>()
.join(", ")
}
}
/// Formats the kind as its canonical name, i.e. the same text that
/// [`HostNodeKind::as_str`] returns.
///
/// Width, fill, alignment and precision in the format spec are honoured, so
/// `format!("{:>10}", kind)` yields a right-aligned, padded name.
impl std::fmt::Display for HostNodeKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's width/fill/alignment/precision, whereas
        // `write_str` would emit the raw name and silently ignore them.
        f.pad(self.as_str())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant's canonical name must parse back to the same variant.
    #[test]
    fn round_trip_through_str() {
        for &kind in HostNodeKind::ALL {
            assert_eq!(HostNodeKind::parse(kind.as_str()), Some(kind));
        }
    }

    /// Unknown, empty and wrongly-cased names are all rejected.
    #[test]
    fn parse_unknown_returns_none() {
        assert_eq!(HostNodeKind::parse("fragment"), None);
        assert_eq!(HostNodeKind::parse(""), None);
        // Matching is case-sensitive: the variant identifier is not a valid name.
        assert_eq!(HostNodeKind::parse("Document"), None);
    }

    /// The serde representation is the snake_case name for every variant.
    #[test]
    fn serialises_as_snake_case_string() {
        // Pinned literal for the one multi-word variant.
        let k = HostNodeKind::ListItem;
        let s = serde_json::to_string(&k).unwrap();
        assert_eq!(s, r#""list_item""#);
        let back: HostNodeKind = serde_json::from_str(&s).unwrap();
        assert_eq!(back, k);

        // The serde names come from `rename_all`, while `as_str` is written by
        // hand; check every variant so the two cannot drift apart unnoticed.
        for &kind in HostNodeKind::ALL {
            let json = serde_json::to_string(&kind).unwrap();
            assert_eq!(
                json,
                format!("\"{}\"", kind.as_str()),
                "serde name differs from as_str for {kind:?}"
            );
            let decoded: HostNodeKind = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, kind);
        }
    }

    /// `allowed_list` mentions every variant exactly once.
    #[test]
    fn allowed_list_includes_every_variant() {
        let list = HostNodeKind::allowed_list();
        let tokens: std::collections::HashSet<&str> = list.split(", ").collect();
        for k in HostNodeKind::ALL {
            assert!(
                tokens.contains(k.as_str()),
                "allowed_list missing variant `{}`: {list}",
                k.as_str()
            );
        }
        // Equal sizes rule out duplicates and stray extra tokens.
        assert_eq!(tokens.len(), HostNodeKind::ALL.len());
    }

    /// `Display` renders exactly what `as_str` returns, for every variant.
    #[test]
    fn display_matches_as_str() {
        assert_eq!(HostNodeKind::Paragraph.to_string(), "paragraph");
        for &kind in HostNodeKind::ALL {
            assert_eq!(kind.to_string(), kind.as_str());
        }
    }
}