1use std::collections::HashMap;
2
3use serde_json::Value;
4
5use crate::{
6 error::{error_helpers::schema_error, PoolResult},
7 id_generator::IdGenerator,
8 mark::Mark,
9 mark_definition::MarkDefinition,
10 node::Node,
11 node_definition::{NodeDefinition, NodeTree},
12 schema::Schema,
13 types::NodeId,
14};
15
16#[derive(Clone)]
18pub struct NodeFactory<'schema> {
19 schema: &'schema Schema,
20}
21
22impl<'schema> NodeFactory<'schema> {
23 pub fn new(schema: &'schema Schema) -> Self {
25 Self { schema }
26 }
27
28 pub fn schema(&self) -> &'schema Schema {
30 self.schema
31 }
32
33 pub fn create_node(
35 &self,
36 type_name: &str,
37 id: Option<NodeId>,
38 attrs: Option<&HashMap<String, Value>>,
39 content: Vec<NodeId>,
40 marks: Option<Vec<Mark>>,
41 ) -> PoolResult<Node> {
42 let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
43 schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
44 })?;
45
46 Ok(Self::instantiate_node(node_type, id, attrs, content, marks))
47 }
48
49 pub fn node_definition(
51 &self,
52 type_name: &str,
53 ) -> Option<&NodeDefinition> {
54 self.schema.nodes.get(type_name)
55 }
56
57 pub fn create_mark(
59 &self,
60 type_name: &str,
61 attrs: Option<&HashMap<String, Value>>,
62 ) -> PoolResult<Mark> {
63 let mark_def = self.schema.marks.get(type_name).ok_or_else(|| {
64 schema_error(&format!("无法在 schema 中找到标记类型:{type_name}"))
65 })?;
66
67 Ok(Self::instantiate_mark(mark_def, attrs))
68 }
69
70 pub fn node_names(&self) -> Vec<&'schema str> {
72 let mut names: Vec<&'schema str> =
73 self.schema.nodes.keys().map(|key| key.as_str()).collect();
74 names.sort();
75 names
76 }
77
78 pub fn mark_names(&self) -> Vec<&'schema str> {
80 let mut names: Vec<&'schema str> =
81 self.schema.marks.keys().map(|key| key.as_str()).collect();
82 names.sort();
83 names
84 }
85
86 pub fn ensure_node(
88 &self,
89 type_name: &str,
90 ) -> PoolResult<&NodeDefinition> {
91 match self.schema.nodes.get(type_name) {
92 Some(def) => Ok(def),
93 None => {
94 let mut available: Vec<&String> =
95 self.schema.nodes.keys().collect();
96 available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
97 Err(schema_error(&self.missing_message(
98 "节点类型",
99 type_name,
100 available,
101 )))
102 },
103 }
104 }
105
106 pub fn ensure_mark(
108 &self,
109 type_name: &str,
110 ) -> PoolResult<&MarkDefinition> {
111 match self.schema.marks.get(type_name) {
112 Some(def) => Ok(def),
113 None => {
114 let mut available: Vec<&String> =
115 self.schema.marks.keys().collect();
116 available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
117 Err(schema_error(&self.missing_message(
118 "标记类型",
119 type_name,
120 available,
121 )))
122 },
123 }
124 }
125
126 fn missing_message(
127 &self,
128 kind: &str,
129 name: &str,
130 available: Vec<&String>,
131 ) -> String {
132 if available.is_empty() {
133 format!("未找到{kind} \"{name}\"。当前 Schema 中未声明任何{kind}。")
134 } else {
135 let preview: Vec<&str> =
136 available.iter().take(5).map(|s| s.as_str()).collect();
137 format!(
138 "未找到{kind} \"{name}\"。可用的{kind}示例:{}",
139 preview.join(", ")
140 )
141 }
142 }
143
144 pub fn mark_definition(
146 &self,
147 type_name: &str,
148 ) -> Option<&MarkDefinition> {
149 self.schema.marks.get(type_name)
150 }
151
152 pub fn definitions(
154 &self
155 ) -> (&HashMap<String, NodeDefinition>, &HashMap<String, MarkDefinition>)
156 {
157 (&self.schema.nodes, &self.schema.marks)
158 }
159
160 pub fn create_top_node(
162 &self,
163 id: Option<NodeId>,
164 attrs: Option<&HashMap<String, Value>>,
165 content: Vec<Node>,
166 marks: Option<Vec<Mark>>,
167 ) -> PoolResult<NodeTree> {
168 let top_node_type = self
169 .schema
170 .top_node_type
171 .as_ref()
172 .ok_or_else(|| schema_error("未找到顶级节点类型定义"))?;
173
174 self.create_tree_with_type(top_node_type, id, attrs, content, marks)
175 }
176
177 pub fn create_tree(
179 &self,
180 type_name: &str,
181 id: Option<NodeId>,
182 attrs: Option<&HashMap<String, Value>>,
183 content: Vec<Node>,
184 marks: Option<Vec<Mark>>,
185 ) -> PoolResult<NodeTree> {
186 let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
187 schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
188 })?;
189
190 self.create_tree_with_type(node_type, id, attrs, content, marks)
191 }
192
193 pub(crate) fn create_tree_with_type(
195 &self,
196 node_type: &NodeDefinition,
197 id: Option<NodeId>,
198 attrs: Option<&HashMap<String, Value>>,
199 content: Vec<Node>,
200 marks: Option<Vec<Mark>>,
201 ) -> PoolResult<NodeTree> {
202 let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
203 let computed_attrs = node_type.compute_attrs(attrs);
204 let computed_marks = node_type.compute_marks(marks);
205
206 let mut filled_nodes: Vec<NodeTree> = Vec::new();
207 let mut final_content_ids: Vec<NodeId> = Vec::new();
208
209 if let Some(content_match) = &node_type.content_match {
210 if let Some(matched) =
211 content_match.match_fragment(&content, self.schema)
212 {
213 if let Some(needed_type_names) =
214 matched.fill(&content, true, self.schema)
215 {
216 for type_name in needed_type_names {
217 if let Some(existing_node) =
218 content.iter().find(|n| n.r#type == type_name)
219 {
220 let attrs_map: HashMap<String, Value> =
221 existing_node
222 .attrs
223 .attrs
224 .iter()
225 .map(|(k, v)| (k.clone(), v.clone()))
226 .collect();
227 let marks_vec: Vec<Mark> =
228 existing_node.marks.iter().cloned().collect();
229 let child_type = self
230 .schema
231 .nodes
232 .get(&type_name)
233 .ok_or_else(|| {
234 schema_error(&format!(
235 "无法在 schema 中找到节点类型:{type_name}"
236 ))
237 })?;
238
239 let child_tree = self.create_tree_with_type(
240 child_type,
241 Some(existing_node.id.clone()),
242 Some(&attrs_map),
243 vec![],
244 Some(marks_vec),
245 )?;
246 let child_id = child_tree.0.id.clone();
247 final_content_ids.push(child_id);
248 filled_nodes.push(child_tree);
249 } else {
250 let child_type = self
251 .schema
252 .nodes
253 .get(&type_name)
254 .ok_or_else(|| {
255 schema_error(&format!(
256 "无法在 schema 中找到节点类型:{type_name}"
257 ))
258 })?;
259
260 let child_tree = self.create_tree_with_type(
261 child_type,
262 None,
263 None,
264 vec![],
265 None,
266 )?;
267 let child_id = child_tree.0.id.clone();
268 final_content_ids.push(child_id);
269 filled_nodes.push(child_tree);
270 }
271 }
272 }
273 }
274 }
275
276 let node = Node::new(
277 &id,
278 node_type.name.clone(),
279 computed_attrs,
280 final_content_ids,
281 computed_marks,
282 );
283
284 Ok(NodeTree(node, filled_nodes))
285 }
286
287 fn instantiate_node(
288 node_type: &NodeDefinition,
289 id: Option<NodeId>,
290 attrs: Option<&HashMap<String, Value>>,
291 content: Vec<NodeId>,
292 marks: Option<Vec<Mark>>,
293 ) -> Node {
294 let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
295 let attrs = node_type.compute_attrs(attrs);
296 let marks = node_type.compute_marks(marks);
297
298 Node::new(&id, node_type.name.clone(), attrs, content, marks)
299 }
300
301 pub(crate) fn instantiate_mark(
302 mark_def: &MarkDefinition,
303 attrs: Option<&HashMap<String, Value>>,
304 ) -> Mark {
305 Mark {
306 r#type: mark_def.name.clone(),
307 attrs: mark_def.compute_attrs(attrs),
308 }
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mark_definition::MarkSpec;
    use crate::node_definition::NodeSpec;
    use crate::schema::{Schema, SchemaSpec};

    /// Compiles a minimal schema: node types `doc` (the top node) and
    /// `paragraph`, plus a single `bold` mark, all with default specs.
    fn build_schema() -> Schema {
        let nodes = ["doc", "paragraph"]
            .iter()
            .map(|name| (name.to_string(), NodeSpec::default()))
            .collect();
        let marks = ["bold"]
            .iter()
            .map(|name| (name.to_string(), MarkSpec::default()))
            .collect();
        let spec = SchemaSpec {
            nodes,
            marks,
            top_node: Some("doc".to_string()),
        };
        Schema::compile(spec).expect("schema should compile")
    }

    /// Asserts that `msg` contains every fragment in `expected`.
    fn assert_mentions(msg: &str, expected: &[&str]) {
        for &fragment in expected {
            assert!(msg.contains(fragment), "actual: {msg}");
        }
    }

    #[test]
    fn ensure_node_returns_descriptive_error() {
        let schema = build_schema();
        let msg = NodeFactory::new(&schema)
            .ensure_node("unknown")
            .unwrap_err()
            .to_string();
        assert_mentions(&msg, &["未找到节点类型", "unknown", "doc"]);
    }

    #[test]
    fn ensure_mark_returns_descriptive_error() {
        let schema = build_schema();
        let msg = NodeFactory::new(&schema)
            .ensure_mark("italic")
            .unwrap_err()
            .to_string();
        assert_mentions(&msg, &["未找到标记类型", "italic", "bold"]);
    }

    #[test]
    fn node_and_mark_names_exposed() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        assert!(factory.node_names().contains(&"doc"));
        assert!(factory.mark_names().contains(&"bold"));
    }
}