1use crate::error::error_helpers::schema_error;
2use crate::error::PoolResult;
3
4use super::attrs::Attrs;
5use super::content::ContentMatch;
6use super::mark_definition::{MarkDefinition, MarkSpec};
7use super::node_definition::{NodeDefinition, NodeSpec};
8use crate::node_factory::NodeFactory;
9use serde::Serialize;
10use serde_json::Value;
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
17pub struct Attribute {
18 pub has_default: bool,
19 pub default: Option<Value>,
20}
21
22impl Attribute {
23 pub(crate) fn new(options: AttributeSpec) -> Self {
25 Attribute {
26 has_default: options.default.is_some(),
27 default: options.default,
28 }
29 }
30 pub fn is_required(&self) -> bool {
33 !self.has_default
34 }
35}
36#[derive(Clone, Debug)]
39pub struct Schema {
40 pub spec: SchemaSpec,
42 pub top_node_type: Option<NodeDefinition>,
44 pub cached: Arc<Mutex<HashMap<String, Arc<dyn Any + Send + Sync>>>>,
46 pub(crate) nodes: HashMap<String, NodeDefinition>,
48 pub(crate) marks: HashMap<String, MarkDefinition>,
50}
51impl PartialEq for Schema {
52 fn eq(
53 &self,
54 other: &Self,
55 ) -> bool {
56 self.spec == other.spec
57 && self.top_node_type == other.top_node_type
58 && self.nodes == other.nodes
59 && self.marks == other.marks
60 }
61}
62impl Eq for Schema {}
63impl Schema {
64 #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(spec), fields(
66 crate_name = "model",
67 node_count = spec.nodes.len(),
68 mark_count = spec.marks.len()
69 )))]
70 pub fn new(spec: SchemaSpec) -> Self {
71 let mut instance_spec = SchemaSpec {
72 nodes: HashMap::new(),
73 marks: HashMap::new(),
74 top_node: spec.top_node,
75 };
76 for (key, value) in spec.nodes {
78 instance_spec.nodes.insert(key, value);
79 }
80 for (key, value) in spec.marks {
81 instance_spec.marks.insert(key, value);
82 }
83 Schema {
84 spec: instance_spec,
85 top_node_type: None,
86 cached: Arc::new(Mutex::new(HashMap::new())),
87 nodes: HashMap::new(),
88 marks: HashMap::new(),
89 }
90 }
91 pub fn factory(&self) -> NodeFactory<'_> {
92 NodeFactory::new(self)
93 }
94 #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(instance_spec), fields(
97 crate_name = "model",
98 node_count = instance_spec.nodes.len(),
99 mark_count = instance_spec.marks.len()
100 )))]
101 pub fn compile(instance_spec: SchemaSpec) -> PoolResult<Schema> {
102 let mut schema: Schema = Schema::new(instance_spec);
103 let nodes: HashMap<String, NodeDefinition> =
104 NodeDefinition::compile(schema.spec.nodes.clone());
105 let marks = MarkDefinition::compile(schema.spec.marks.clone());
106 let mut content_expr_cache = HashMap::new();
107 let mut updated_nodes = HashMap::new();
108 for (prop, type_) in &nodes {
109 if marks.contains_key(prop) {
110 return Err(schema_error(&format!(
111 "{prop} 不能既是节点又是标记"
112 )));
113 }
114
115 let content_expr = type_.spec.content.as_deref().unwrap_or("");
116 let mark_expr = type_.spec.marks.as_deref();
117
118 let content_expr_string = content_expr.to_string();
119 let content_match = content_expr_cache
120 .entry(content_expr_string.clone())
121 .or_insert_with(|| {
122 ContentMatch::parse(content_expr_string, &nodes)
123 })
124 .clone();
125
126 let mark_set = match mark_expr {
127 Some("_") => None,
128 Some(expr) => {
129 let marks_result =
130 gather_marks(&marks, expr.split_whitespace().collect());
131 match marks_result {
132 Ok(marks) => Some(marks.into_iter().cloned().collect()), Err(e) => return Err(schema_error(&e)),
134 }
135 },
136 None => None,
137 };
138
139 let mut node = type_.clone();
140 node.content_match = Some(content_match);
141 node.mark_set = mark_set;
142 updated_nodes.insert(prop.clone(), node);
143 }
144 schema.nodes = updated_nodes;
145 schema.marks = marks;
146 schema.top_node_type = match schema.nodes.get(
147 &schema.spec.top_node.clone().unwrap_or_else(|| "doc".to_string()),
148 ) {
149 Some(node) => Some(node.clone()),
150 None => {
151 return Err(schema_error("未找到顶级节点类型定义"));
152 },
153 };
154
155 Ok(schema)
156 }
157}
/// Uncompiled schema description, consumed by [`Schema::new`] and [`Schema::compile`].
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct SchemaSpec {
    /// Node specs keyed by node name.
    pub nodes: HashMap<String, NodeSpec>,
    /// Mark specs keyed by mark name.
    pub marks: HashMap<String, MarkSpec>,
    /// Name of the top-level node type; `Schema::compile` uses `"doc"` when this is `None`.
    pub top_node: Option<String>,
}
166
167pub fn default_attrs(
172 attrs: &HashMap<String, Attribute>
173) -> Option<HashMap<String, Value>> {
174 let mut defaults = HashMap::new();
175
176 for (attr_name, attr) in attrs {
177 if let Some(default) = &attr.default {
178 defaults.insert(attr_name.clone(), default.clone());
179 } else {
180 return None;
181 }
182 }
183
184 Some(defaults)
185}
/// Specification of a single attribute; converted into an [`Attribute`] by `Attribute::new`.
#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize)]
pub struct AttributeSpec {
    /// Default value; when `Some`, the resulting `Attribute` has `has_default == true`.
    pub default: Option<Value>,
}
192fn gather_marks<'a>(
195 marks_map: &'a HashMap<String, MarkDefinition>,
196 marks: Vec<&'a str>,
197) -> Result<Vec<&'a MarkDefinition>, String> {
198 let mut found = Vec::new();
199
200 for name in marks {
201 if let Some(mark) = marks_map.get(name) {
202 found.push(mark);
203 } else if name == "_" {
204 found.extend(marks_map.values());
206 } else {
207 let mut matched = false;
209 for mark_ref in marks_map.values() {
210 if mark_ref.spec.group.as_ref().is_some_and(|group| {
211 group.split_whitespace().any(|g| g == name)
212 }) {
213 found.push(mark_ref);
214 matched = true;
215 break;
216 }
217 }
218 if !matched {
219 return Err(format!("未知的标记类型: '{name}'"));
220 }
221 }
222 }
223 Ok(found)
224}
225pub fn compute_attrs(
228 attrs: &HashMap<String, Attribute>,
229 value: Option<&HashMap<String, Value>>,
230) -> Attrs {
231 let mut built = Attrs::default();
232
233 for (name, attr) in attrs {
234 let given = value.and_then(|v| v.get(name));
235
236 let given = match given {
237 Some(val) => val.clone(),
238 None => {
239 if attr.has_default {
240 attr.default.clone().unwrap_or_else(|| {
241 panic!("没有为属性提供默认值 {name}")
242 })
243 } else {
244 Value::Null
245 }
246 },
247 };
248
249 built[name] = given;
250 }
251
252 built
253}