use crate::node::Node;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// Declares which node and mark types a document may use.
///
/// Built with the chainable `Schema::node` / `Schema::mark` methods or
/// deserialized from JSON via `Schema::from_json_str`. Either top-level key
/// may be omitted from the JSON; it then defaults to an empty map.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct Schema {
/// Node specs keyed by node type name. During validation a node whose type
/// is not a key here is reported as an unknown node type.
#[serde(default)]
pub nodes: HashMap<String, NodeSpec>,
/// Mark specs keyed by mark type name. During validation a mark whose type
/// is not a key here is reported as an unknown mark.
#[serde(default)]
pub marks: HashMap<String, MarkSpec>,
}
/// Constraints for one node type.
///
/// Every restriction is optional: a field left as `None` imposes no
/// restriction of that kind. `None` fields are omitted from serialized
/// output, as is an empty `required_attrs`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct NodeSpec {
/// Node types permitted as direct children. `None` permits any child type.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<HashSet<String>>,
/// Mark types permitted on this node. `None` permits any mark the schema
/// declares (undeclared marks are still reported as unknown).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub marks: Option<HashSet<String>>,
/// Attribute keys this node may carry. `None` accepts any key.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub attrs: Option<HashSet<String>>,
/// Attribute keys that must be present on the node.
#[serde(default, skip_serializing_if = "HashSet::is_empty")]
pub required_attrs: HashSet<String>,
}
/// Constraints for one mark type.
///
/// `attrs` is omitted from serialized output when `None`, and
/// `required_attrs` when empty.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct MarkSpec {
/// Attribute keys this mark may carry. `None` accepts any key.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub attrs: Option<HashSet<String>>,
/// Attribute keys that must be present on the mark.
#[serde(default, skip_serializing_if = "HashSet::is_empty")]
pub required_attrs: HashSet<String>,
}
/// Converts every item of `items` into an owned `String` and gathers them
/// into a set, so duplicates collapse to a single entry.
fn into_set<I, S>(items: I) -> HashSet<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut set = HashSet::new();
    set.extend(items.into_iter().map(S::into));
    set
}
impl Schema {
    /// Creates an empty schema that declares no node or mark types.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `spec` for `node_type` and returns the schema for chaining.
    ///
    /// An earlier spec registered under the same type name is replaced.
    // `#[must_use]`: this consumes `self`, so dropping the result silently
    // discards the whole schema built so far.
    #[must_use]
    pub fn node(mut self, node_type: impl Into<String>, spec: NodeSpec) -> Self {
        self.nodes.insert(node_type.into(), spec);
        self
    }

    /// Registers `spec` for `mark_type` and returns the schema for chaining.
    ///
    /// An earlier spec registered under the same type name is replaced.
    #[must_use]
    pub fn mark(mut self, mark_type: impl Into<String>, spec: MarkSpec) -> Self {
        self.marks.insert(mark_type.into(), spec);
        self
    }

    /// Parses a schema from a JSON string.
    ///
    /// # Errors
    ///
    /// Returns an error when `s` is not valid JSON or does not match the
    /// schema's serialized shape. The `serde_json` error is converted into
    /// the crate's error type by `?`.
    pub fn from_json_str(s: &str) -> crate::Result<Self> {
        Ok(serde_json::from_str(s)?)
    }
}
impl NodeSpec {
    /// Creates an unrestricted spec: any child types, any declared marks,
    /// any attribute keys, and no required attributes.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Restricts direct children to the given node types, replacing any set
    /// configured earlier.
    ///
    /// Passing an empty iterator rejects every child that has a type.
    // `#[must_use]`: the builder consumes `self`; ignoring the result loses it.
    #[must_use]
    pub fn content<I, S>(mut self, types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.content = Some(into_set(types));
        self
    }

    /// Restricts the marks allowed on this node to the given mark types,
    /// replacing any set configured earlier.
    #[must_use]
    pub fn marks<I, S>(mut self, types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.marks = Some(into_set(types));
        self
    }

    /// Restricts the attribute keys this node may carry, replacing any set
    /// configured earlier.
    #[must_use]
    pub fn attrs<I, S>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.attrs = Some(into_set(keys));
        self
    }

    /// Sets the attribute keys that must be present on this node, replacing
    /// any set configured earlier.
    #[must_use]
    pub fn required_attrs<I, S>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.required_attrs = into_set(keys);
        self
    }
}
impl MarkSpec {
    /// Creates an unrestricted mark spec: any attribute keys and no required
    /// attributes.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Restricts the attribute keys this mark may carry, replacing any set
    /// configured earlier.
    // `#[must_use]`: the builder consumes `self`; ignoring the result loses it.
    #[must_use]
    pub fn attrs<I, S>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.attrs = Some(into_set(keys));
        self
    }

    /// Sets the attribute keys that must be present on this mark, replacing
    /// any set configured earlier.
    #[must_use]
    pub fn required_attrs<I, S>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.required_attrs = into_set(keys);
        self
    }
}
/// A single schema violation found while validating a node tree.
///
/// Violations about a node's marks (including mark attributes) carry the
/// path of the node the mark is attached to.
// `Eq` is derived because every field is `Eq`; this lets callers use
// violations wherever full equivalence is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Violation {
    /// Child indices leading from the validated root to the offending node;
    /// empty for the root itself.
    pub path: Vec<usize>,
    /// What was wrong at that location.
    pub kind: ViolationKind,
}

/// The kind of rule a [`Violation`] broke.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ViolationKind {
    /// The node has no type.
    MissingNodeType,
    /// The node's type is not declared in the schema.
    UnknownNodeType(String),
    /// A direct child's type is not among the parent spec's permitted content.
    DisallowedChild { parent: String, child: String },
    /// A mark's type is not declared in the schema.
    UnknownMark(String),
    /// A declared mark type is not permitted on this node type.
    DisallowedMark { node: String, mark: String },
    /// A required attribute key is absent (on the node or one of its marks).
    MissingAttr { key: String },
    /// An attribute key is not permitted (on the node or one of its marks).
    UnknownAttr { key: String },
}

impl fmt::Display for ViolationKind {
    /// Writes a short, human-readable description of the violation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingNodeType => f.write_str("node has no type"),
            Self::UnknownNodeType(t) => write!(f, "unknown node type `{t}`"),
            Self::DisallowedChild { parent, child } => {
                write!(f, "node type `{child}` not allowed inside `{parent}`")
            }
            Self::UnknownMark(m) => write!(f, "unknown mark type `{m}`"),
            Self::DisallowedMark { node, mark } => {
                write!(f, "mark `{mark}` not allowed on `{node}`")
            }
            Self::MissingAttr { key } => write!(f, "missing required attribute `{key}`"),
            Self::UnknownAttr { key } => write!(f, "unknown attribute `{key}`"),
        }
    }
}

impl fmt::Display for Violation {
    /// Writes the violation as `at <path>: <description>`, with the path
    /// rendered in `Debug` form (e.g. `[0, 2]`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "at {:?}: {}", self.path, self.kind)
    }
}
impl Node {
    /// Validates this node and its whole subtree against `schema`.
    ///
    /// Returns every violation found, in traversal order; an empty vector
    /// means the tree conforms. Each violation's `path` lists child indices
    /// starting from this node.
    pub fn validate(&self, schema: &Schema) -> Vec<Violation> {
        let mut violations = Vec::new();
        validate_node(self, schema, &mut Vec::new(), &mut violations);
        violations
    }

    /// Returns `true` when [`Node::validate`] reports no violations.
    pub fn is_valid(&self, schema: &Schema) -> bool {
        let violations = self.validate(schema);
        violations.is_empty()
    }
}
fn validate_node(node: &Node, schema: &Schema, path: &mut Vec<usize>, out: &mut Vec<Violation>) {
let push = |out: &mut Vec<Violation>, path: &[usize], kind: ViolationKind| {
out.push(Violation {
path: path.to_vec(),
kind,
});
};
let spec = match &node.node_type {
None => {
push(out, path, ViolationKind::MissingNodeType);
None
}
Some(t) => match schema.nodes.get(t) {
Some(spec) => Some(spec),
None => {
push(out, path, ViolationKind::UnknownNodeType(t.clone()));
None
}
},
};
if let Some(spec) = spec {
check_attrs(
node.attrs.as_ref(),
spec.attrs.as_ref(),
&spec.required_attrs,
path,
out,
);
if let Some(marks) = &node.marks {
let node_type = node.node_type.as_deref().unwrap_or_default();
for mark in marks {
match schema.marks.get(&mark.mark_type) {
None => push(
out,
path,
ViolationKind::UnknownMark(mark.mark_type.clone()),
),
Some(mark_spec) => {
if let Some(allowed) = &spec.marks {
if !allowed.contains(&mark.mark_type) {
push(
out,
path,
ViolationKind::DisallowedMark {
node: node_type.to_string(),
mark: mark.mark_type.clone(),
},
);
}
}
check_attrs(
mark.attrs.as_ref(),
mark_spec.attrs.as_ref(),
&mark_spec.required_attrs,
path,
out,
);
}
}
}
}
if let (Some(allowed), Some(children)) = (&spec.content, &node.content) {
let parent = node.node_type.as_deref().unwrap_or_default();
for child in children {
if let Some(ct) = &child.node_type {
if !allowed.contains(ct) {
push(
out,
path,
ViolationKind::DisallowedChild {
parent: parent.to_string(),
child: ct.clone(),
},
);
}
}
}
}
}
if let Some(children) = &node.content {
for (i, child) in children.iter().enumerate() {
path.push(i);
validate_node(child, schema, path, out);
path.pop();
}
}
}
/// Checks one attribute map against a spec's allowed and required keys,
/// appending violations located at `path` to `out`.
///
/// - Every key in `required` that is absent from `attrs` (or all of them,
///   when `attrs` is `None`) yields `MissingAttr`. These are emitted sorted
///   by key, so the output order is deterministic.
/// - When both `allowed` and `attrs` are present, every key in `attrs` that
///   is neither allowed nor required yields `UnknownAttr`, in the map's own
///   key order. `allowed == None` accepts any key.
fn check_attrs(
    attrs: Option<&serde_json::Map<String, serde_json::Value>>,
    allowed: Option<&HashSet<String>>,
    required: &HashSet<String>,
    path: &[usize],
    out: &mut Vec<Violation>,
) {
    // `HashSet` iteration order is randomized per process; sorting keeps the
    // reported `MissingAttr` order stable across runs.
    let mut missing: Vec<&String> = required
        .iter()
        .filter(|key| !attrs.is_some_and(|m| m.contains_key(key.as_str())))
        .collect();
    missing.sort_unstable();
    for key in missing {
        out.push(Violation {
            path: path.to_vec(),
            kind: ViolationKind::MissingAttr { key: key.clone() },
        });
    }

    let (Some(allowed), Some(attrs)) = (allowed, attrs) else {
        return;
    };
    for key in attrs.keys() {
        // A required key is implicitly allowed. Otherwise a spec that lists a
        // key as required but omits it from `allowed` could never be
        // satisfied: it would be reported as missing when absent and unknown
        // when present.
        if !allowed.contains(key) && !required.contains(key) {
            out.push(Violation {
                path: path.to_vec(),
                kind: ViolationKind::UnknownAttr { key: key.clone() },
            });
        }
    }
}