1use std::collections::HashMap;
2
3use serde_json::Value;
4
5use crate::{
6 error::{error_helpers::schema_error, PoolResult},
7 id_generator::IdGenerator,
8 mark::Mark,
9 mark_definition::MarkDefinition,
10 node::Node,
11 node_definition::{NodeDefinition, NodeTree},
12 schema::Schema,
13 types::NodeId,
14};
15
16#[derive(Clone)]
18pub struct NodeFactory<'schema> {
19 schema: &'schema Schema,
20}
21
22impl<'schema> NodeFactory<'schema> {
23 pub fn new(schema: &'schema Schema) -> Self {
25 Self { schema }
26 }
27
28 pub fn schema(&self) -> &'schema Schema {
30 self.schema
31 }
32
33 pub fn create_node(
35 &self,
36 type_name: &str,
37 id: Option<NodeId>,
38 attrs: Option<&HashMap<String, Value>>,
39 content: Vec<NodeId>,
40 marks: Option<Vec<Mark>>,
41 ) -> PoolResult<Node> {
42 let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
43 schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
44 })?;
45
46 Ok(Self::instantiate_node(node_type, id, attrs, content, marks))
47 }
48
49 pub fn node_definition(
51 &self,
52 type_name: &str,
53 ) -> Option<&NodeDefinition> {
54 self.schema.nodes.get(type_name)
55 }
56
57 pub fn create_mark(
59 &self,
60 type_name: &str,
61 attrs: Option<&HashMap<String, Value>>,
62 ) -> PoolResult<Mark> {
63 let mark_def = self.schema.marks.get(type_name).ok_or_else(|| {
64 schema_error(&format!("无法在 schema 中找到标记类型:{type_name}"))
65 })?;
66
67 Ok(Self::instantiate_mark(mark_def, attrs))
68 }
69
70 pub fn node_names(&self) -> Vec<&'schema str> {
72 let mut names: Vec<&'schema str> =
73 self.schema.nodes.keys().map(|key| key.as_str()).collect();
74 names.sort();
75 names
76 }
77
78 pub fn mark_names(&self) -> Vec<&'schema str> {
80 let mut names: Vec<&'schema str> =
81 self.schema.marks.keys().map(|key| key.as_str()).collect();
82 names.sort();
83 names
84 }
85
86 pub fn ensure_node(
88 &self,
89 type_name: &str,
90 ) -> PoolResult<&NodeDefinition> {
91 match self.schema.nodes.get(type_name) {
92 Some(def) => Ok(def),
93 None => {
94 let mut available: Vec<&String> =
95 self.schema.nodes.keys().collect();
96 available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
97 Err(schema_error(&self.missing_message(
98 "节点类型",
99 type_name,
100 available,
101 )))
102 },
103 }
104 }
105
106 pub fn ensure_mark(
108 &self,
109 type_name: &str,
110 ) -> PoolResult<&MarkDefinition> {
111 match self.schema.marks.get(type_name) {
112 Some(def) => Ok(def),
113 None => {
114 let mut available: Vec<&String> =
115 self.schema.marks.keys().collect();
116 available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
117 Err(schema_error(&self.missing_message(
118 "标记类型",
119 type_name,
120 available,
121 )))
122 },
123 }
124 }
125
126 fn missing_message(
127 &self,
128 kind: &str,
129 name: &str,
130 available: Vec<&String>,
131 ) -> String {
132 if available.is_empty() {
133 format!(
134 "未找到{kind} \"{name}\"。当前 Schema 中未声明任何{kind}。"
135 )
136 } else {
137 let preview: Vec<&str> =
138 available.iter().take(5).map(|s| s.as_str()).collect();
139 format!(
140 "未找到{kind} \"{name}\"。可用的{kind}示例:{}",
141 preview.join(", ")
142 )
143 }
144 }
145
146 pub fn mark_definition(
148 &self,
149 type_name: &str,
150 ) -> Option<&MarkDefinition> {
151 self.schema.marks.get(type_name)
152 }
153
154 pub fn definitions(
156 &self
157 ) -> (&HashMap<String, NodeDefinition>, &HashMap<String, MarkDefinition>)
158 {
159 (&self.schema.nodes, &self.schema.marks)
160 }
161
162 pub fn create_top_node(
164 &self,
165 id: Option<NodeId>,
166 attrs: Option<&HashMap<String, Value>>,
167 content: Vec<Node>,
168 marks: Option<Vec<Mark>>,
169 ) -> PoolResult<NodeTree> {
170 let top_node_type = self
171 .schema
172 .top_node_type
173 .as_ref()
174 .ok_or_else(|| schema_error("未找到顶级节点类型定义"))?;
175
176 self.create_tree_with_type(top_node_type, id, attrs, content, marks)
177 }
178
179 pub fn create_tree(
181 &self,
182 type_name: &str,
183 id: Option<NodeId>,
184 attrs: Option<&HashMap<String, Value>>,
185 content: Vec<Node>,
186 marks: Option<Vec<Mark>>,
187 ) -> PoolResult<NodeTree> {
188 let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
189 schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
190 })?;
191
192 self.create_tree_with_type(node_type, id, attrs, content, marks)
193 }
194
195 pub(crate) fn create_tree_with_type(
197 &self,
198 node_type: &NodeDefinition,
199 id: Option<NodeId>,
200 attrs: Option<&HashMap<String, Value>>,
201 content: Vec<Node>,
202 marks: Option<Vec<Mark>>,
203 ) -> PoolResult<NodeTree> {
204 let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
205 let computed_attrs = node_type.compute_attrs(attrs);
206 let computed_marks = node_type.compute_marks(marks);
207
208 let mut filled_nodes: Vec<NodeTree> = Vec::new();
209 let mut final_content_ids: Vec<NodeId> = Vec::new();
210
211 if let Some(content_match) = &node_type.content_match {
212 if let Some(matched) =
213 content_match.match_fragment(&content, self.schema)
214 {
215 if let Some(needed_type_names) =
216 matched.fill(&content, true, self.schema)
217 {
218 for type_name in needed_type_names {
219 if let Some(existing_node) =
220 content.iter().find(|n| n.r#type == type_name)
221 {
222 let attrs_map: HashMap<String, Value> =
223 existing_node
224 .attrs
225 .attrs
226 .iter()
227 .map(|(k, v)| (k.clone(), v.clone()))
228 .collect();
229 let marks_vec: Vec<Mark> =
230 existing_node.marks.iter().cloned().collect();
231 let child_type = self
232 .schema
233 .nodes
234 .get(&type_name)
235 .ok_or_else(|| {
236 schema_error(&format!(
237 "无法在 schema 中找到节点类型:{type_name}"
238 ))
239 })?;
240
241 let child_tree = self.create_tree_with_type(
242 child_type,
243 Some(existing_node.id.clone()),
244 Some(&attrs_map),
245 vec![],
246 Some(marks_vec),
247 )?;
248 let child_id = child_tree.0.id.clone();
249 final_content_ids.push(child_id);
250 filled_nodes.push(child_tree);
251 } else {
252 let child_type = self
253 .schema
254 .nodes
255 .get(&type_name)
256 .ok_or_else(|| {
257 schema_error(&format!(
258 "无法在 schema 中找到节点类型:{type_name}"
259 ))
260 })?;
261
262 let child_tree = self.create_tree_with_type(
263 child_type,
264 None,
265 None,
266 vec![],
267 None,
268 )?;
269 let child_id = child_tree.0.id.clone();
270 final_content_ids.push(child_id);
271 filled_nodes.push(child_tree);
272 }
273 }
274 }
275 }
276 }
277
278 let node = Node::new(
279 &id,
280 node_type.name.clone(),
281 computed_attrs,
282 final_content_ids,
283 computed_marks,
284 );
285
286 Ok(NodeTree(node, filled_nodes))
287 }
288
289 fn instantiate_node(
290 node_type: &NodeDefinition,
291 id: Option<NodeId>,
292 attrs: Option<&HashMap<String, Value>>,
293 content: Vec<NodeId>,
294 marks: Option<Vec<Mark>>,
295 ) -> Node {
296 let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
297 let attrs = node_type.compute_attrs(attrs);
298 let marks = node_type.compute_marks(marks);
299
300 Node::new(&id, node_type.name.clone(), attrs, content, marks)
301 }
302
303 pub(crate) fn instantiate_mark(
304 mark_def: &MarkDefinition,
305 attrs: Option<&HashMap<String, Value>>,
306 ) -> Mark {
307 Mark {
308 r#type: mark_def.name.clone(),
309 attrs: mark_def.compute_attrs(attrs),
310 }
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mark_definition::MarkSpec;
    use crate::node_definition::NodeSpec;
    use crate::schema::{Schema, SchemaSpec};

    /// Minimal schema: `doc` (top node) and `paragraph` nodes plus a `bold`
    /// mark, all declared with default specs.
    fn build_schema() -> Schema {
        let mut spec = SchemaSpec {
            nodes: HashMap::new(),
            marks: HashMap::new(),
            top_node: Some("doc".to_string()),
        };
        spec.nodes.insert("doc".to_string(), NodeSpec::default());
        spec.nodes.insert("paragraph".to_string(), NodeSpec::default());
        spec.marks.insert("bold".to_string(), MarkSpec::default());
        Schema::compile(spec).expect("schema should compile")
    }

    #[test]
    fn ensure_node_returns_descriptive_error() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        let err = factory.ensure_node("unknown").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("未找到节点类型"), "actual: {msg}");
        assert!(msg.contains("unknown"), "actual: {msg}");
        assert!(msg.contains("doc"), "actual: {msg}");
    }

    #[test]
    fn ensure_mark_returns_descriptive_error() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        let err = factory.ensure_mark("italic").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("未找到标记类型"), "actual: {msg}");
        assert!(msg.contains("italic"), "actual: {msg}");
        assert!(msg.contains("bold"), "actual: {msg}");
    }

    #[test]
    fn node_and_mark_names_exposed() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        let nodes = factory.node_names();
        assert!(nodes.contains(&"doc"));
        let marks = factory.mark_names();
        assert!(marks.contains(&"bold"));
    }

    #[test]
    fn node_and_mark_names_are_sorted() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        let nodes = factory.node_names();
        assert!(nodes.windows(2).all(|w| w[0] <= w[1]), "actual: {nodes:?}");
        let marks = factory.mark_names();
        assert!(marks.windows(2).all(|w| w[0] <= w[1]), "actual: {marks:?}");
    }

    #[test]
    fn definition_lookups_distinguish_known_and_unknown() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        assert!(factory.node_definition("paragraph").is_some());
        assert!(factory.node_definition("missing").is_none());
        assert!(factory.mark_definition("bold").is_some());
        assert!(factory.mark_definition("missing").is_none());
    }

    #[test]
    fn create_node_rejects_unknown_type() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        let err = match factory.create_node("unknown", None, None, vec![], None)
        {
            Ok(_) => panic!("unknown node type must be rejected"),
            Err(err) => err,
        };
        let msg = err.to_string();
        assert!(msg.contains("unknown"), "actual: {msg}");
    }

    #[test]
    fn create_mark_accepts_known_and_rejects_unknown_type() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        assert!(factory.create_mark("bold", None).is_ok());
        let err = match factory.create_mark("italic", None) {
            Ok(_) => panic!("unknown mark type must be rejected"),
            Err(err) => err,
        };
        let msg = err.to_string();
        assert!(msg.contains("italic"), "actual: {msg}");
    }

    #[test]
    fn create_top_node_uses_declared_top_node() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        let tree = match factory.create_top_node(None, None, vec![], None) {
            Ok(tree) => tree,
            Err(err) => panic!("top node should be created: {err}"),
        };
        assert_eq!(tree.0.r#type, "doc");
    }
}