use crate::error::error_helpers::schema_error;
use crate::error::PoolResult;
use super::attrs::Attrs;
use super::content::ContentMatch;
use super::mark_definition::{MarkDefinition, MarkSpec};
use super::node_definition::{NodeDefinition, NodeSpec};
use crate::node_factory::NodeFactory;
use serde::Serialize;
use serde_json::Value;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// A compiled attribute declaration of a node or mark type.
///
/// When built through `Attribute::new`, `has_default` is exactly
/// `default.is_some()`; both fields are public, so code constructing the
/// struct directly is responsible for keeping them consistent.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
pub struct Attribute {
/// Whether a default value is available for this attribute.
pub has_default: bool,
/// The default value used when no explicit value is supplied.
pub default: Option<Value>,
}
impl Attribute {
pub(crate) fn new(options: AttributeSpec) -> Self {
Attribute {
has_default: options.default.is_some(),
default: options.default,
}
}
pub fn is_required(&self) -> bool {
!self.has_default
}
}
/// A schema: the node and mark type definitions a document may use.
///
/// `Schema::new` only stores the spec; `Schema::compile` fills in `nodes`,
/// `marks` and `top_node_type`.
#[derive(Clone, Debug)]
pub struct Schema {
/// The spec this schema was built from.
pub spec: SchemaSpec,
/// The top-level node type, resolved by `Schema::compile` from
/// `spec.top_node` (falling back to `"doc"`); `None` until compiled.
pub top_node_type: Option<NodeDefinition>,
/// Lock-protected map from string keys to type-erased values. It is created
/// empty, is shared (not deep-copied) between clones because it sits behind
/// an `Arc`, and is excluded from `Schema` equality.
/// NOTE(review): not read or written in this file — confirm its users.
pub cached: Arc<Mutex<HashMap<String, Arc<dyn Any + Send + Sync>>>>,
/// Compiled node types keyed by name (populated by `Schema::compile`).
pub(crate) nodes: HashMap<String, NodeDefinition>,
/// Compiled mark types keyed by name (populated by `Schema::compile`).
pub(crate) marks: HashMap<String, MarkDefinition>,
}
/// Structural equality over the spec, the top node type and the compiled
/// node/mark maps.
///
/// The `cached` field is intentionally left out: it stores type-erased
/// `dyn Any` values, which have no `PartialEq` implementation.
impl PartialEq for Schema {
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element, left to right, and stops
        // at the first mismatch.
        let mine = (&self.spec, &self.top_node_type, &self.nodes, &self.marks);
        let theirs = (&other.spec, &other.top_node_type, &other.nodes, &other.marks);
        mine == theirs
    }
}

impl Eq for Schema {}
impl Schema {
    /// Creates an uncompiled schema that owns `spec`.
    ///
    /// Only the spec is stored: `nodes`, `marks` and `top_node_type` start
    /// empty/`None`, and the `cached` map starts empty.
    /// [`Schema::compile`] builds on this and fills those fields in.
    #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(spec), fields(
        crate_name = "model",
        node_count = spec.nodes.len(),
        mark_count = spec.marks.len()
    )))]
    pub fn new(spec: SchemaSpec) -> Self {
        // `spec` is already owned, so it is stored as-is; rebuilding its maps
        // entry by entry would only produce an identical copy.
        Schema {
            spec,
            top_node_type: None,
            cached: Arc::new(Mutex::new(HashMap::new())),
            nodes: HashMap::new(),
            marks: HashMap::new(),
        }
    }

    /// Returns a [`NodeFactory`] that borrows this schema.
    pub fn factory(&self) -> NodeFactory<'_> {
        NodeFactory::new(self)
    }

    /// Compiles `instance_spec` into a ready-to-use schema.
    ///
    /// Steps:
    /// 1. Node and mark definitions are compiled from the spec.
    /// 2. Each node type gets its content expression parsed into a
    ///    `ContentMatch`. Node types with the same expression share one parse.
    /// 3. Each node type's `marks` expression is resolved into its mark set.
    /// 4. The top node type is looked up by `spec.top_node`, defaulting to
    ///    `"doc"`.
    ///
    /// # Errors
    ///
    /// Returns a schema error when:
    /// * a name is defined both as a node type and as a mark type;
    /// * a node's `marks` expression names an unknown mark or group;
    /// * the top node type is not among the compiled node types.
    #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(instance_spec), fields(
        crate_name = "model",
        node_count = instance_spec.nodes.len(),
        mark_count = instance_spec.marks.len()
    )))]
    pub fn compile(instance_spec: SchemaSpec) -> PoolResult<Schema> {
        let mut schema = Schema::new(instance_spec);
        let nodes: HashMap<String, NodeDefinition> =
            NodeDefinition::compile(schema.spec.nodes.clone());
        let marks = MarkDefinition::compile(schema.spec.marks.clone());

        // A name must identify either a node type or a mark type, never both.
        // Validate this before doing any content-expression parsing.
        if let Some(name) = nodes.keys().find(|name| marks.contains_key(*name)) {
            return Err(schema_error(&format!("{name} 不能既是节点又是标记")));
        }

        // Parsed content matches are cached by expression so that node types
        // sharing an expression reuse a single parse.
        let mut content_expr_cache = HashMap::new();
        let mut updated_nodes = HashMap::with_capacity(nodes.len());
        for (prop, type_) in &nodes {
            let content_expr = type_.spec.content.as_deref().unwrap_or("");
            let content_match = content_expr_cache
                .entry(content_expr.to_string())
                .or_insert_with_key(|expr: &String| ContentMatch::parse(expr.clone(), &nodes))
                .clone();

            // `_` and a missing expression both leave `mark_set` as `None`; any
            // other expression (including an empty one) is resolved to a list.
            let mark_set = match type_.spec.marks.as_deref() {
                Some("_") | None => None,
                Some(expr) => {
                    let gathered = gather_marks(&marks, expr.split_whitespace().collect())
                        .map_err(|e| schema_error(&e))?;
                    Some(gathered.into_iter().cloned().collect())
                },
            };

            let mut node = type_.clone();
            node.content_match = Some(content_match);
            node.mark_set = mark_set;
            updated_nodes.insert(prop.clone(), node);
        }
        schema.nodes = updated_nodes;
        schema.marks = marks;

        let top_name = schema.spec.top_node.as_deref().unwrap_or("doc");
        let top_node = schema
            .nodes
            .get(top_name)
            .cloned()
            .ok_or_else(|| schema_error("未找到顶级节点类型定义"))?;
        schema.top_node_type = Some(top_node);
        Ok(schema)
    }
}
/// The uncompiled description of a schema, as passed to `Schema::new` /
/// `Schema::compile`.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct SchemaSpec {
/// Node type specs keyed by node type name.
pub nodes: HashMap<String, NodeSpec>,
/// Mark type specs keyed by mark type name.
pub marks: HashMap<String, MarkSpec>,
/// Name of the top-level node type; `Schema::compile` uses `"doc"` when this
/// is `None`.
pub top_node: Option<String>,
}
/// Collects the default value of every attribute in `attrs`.
///
/// Returns `Some(map)` of attribute name to default value only when every
/// attribute has a default; if any attribute lacks one, returns `None`.
/// An empty `attrs` yields `Some` of an empty map.
pub fn default_attrs(
    attrs: &HashMap<String, Attribute>
) -> Option<HashMap<String, Value>> {
    // Collecting an iterator of `Option`s into `Option<HashMap>` stops at the
    // first `None`, i.e. the first attribute without a default.
    attrs
        .iter()
        .map(|(attr_name, attr)| {
            attr.default
                .as_ref()
                .map(|default| (attr_name.clone(), default.clone()))
        })
        .collect()
}
/// The declaration of a single attribute as written in a node or mark spec.
#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize)]
pub struct AttributeSpec {
/// Default value for the attribute. `None` makes the compiled `Attribute`
/// report `is_required() == true`.
pub default: Option<Value>,
}
/// Resolves mark names, mark group names and the `_` wildcard into mark
/// definitions.
///
/// Each entry of `marks` is resolved in this order:
/// * an exact mark name yields that mark;
/// * `_` yields every mark in `marks_map`;
/// * anything else is treated as a group name and yields **every** mark whose
///   whitespace-separated `spec.group` list contains it.
///
/// Marks found via `_` or a group appear in `marks_map` iteration order.
/// Duplicates are not removed.
///
/// # Errors
///
/// Returns a message naming the first entry that matches neither a mark name
/// nor any mark group.
fn gather_marks<'a>(
    marks_map: &'a HashMap<String, MarkDefinition>,
    marks: Vec<&'a str>,
) -> Result<Vec<&'a MarkDefinition>, String> {
    let mut found = Vec::new();
    for name in marks {
        if let Some(mark) = marks_map.get(name) {
            found.push(mark);
        } else if name == "_" {
            found.extend(marks_map.values());
        } else {
            // A group reference includes all marks of that group. Scanning must
            // not stop at the first hit, which would keep an arbitrary single
            // member given HashMap iteration order.
            let before = found.len();
            found.extend(marks_map.values().filter(|mark_ref| {
                mark_ref.spec.group.as_ref().is_some_and(|group| {
                    group.split_whitespace().any(|g| g == name)
                })
            }));
            if found.len() == before {
                return Err(format!("未知的标记类型: '{name}'"));
            }
        }
    }
    Ok(found)
}
/// Builds an [`Attrs`] holding one entry per declared attribute in `attrs`.
///
/// For each declared attribute, the value is chosen as follows:
/// * a value supplied in `value` is cloned and wins;
/// * otherwise, if the attribute has a default, that default is used;
/// * otherwise the entry is set to `Value::Null`.
///
/// Keys in `value` that are not declared in `attrs` are ignored.
///
/// # Panics
///
/// Panics if an attribute has `has_default == true` but `default == None`.
pub fn compute_attrs(
    attrs: &HashMap<String, Attribute>,
    value: Option<&HashMap<String, Value>>,
) -> Attrs {
    let mut built = Attrs::default();
    for (name, attr) in attrs {
        let supplied = value.and_then(|provided| provided.get(name)).cloned();
        let resolved = supplied.unwrap_or_else(|| {
            if attr.has_default {
                attr.default.clone().unwrap_or_else(|| {
                    panic!("没有为属性提供默认值 {name}")
                })
            } else {
                Value::Null
            }
        });
        built[name] = resolved;
    }
    built
}