mf_model/
node_factory.rs

1use std::collections::HashMap;
2
3use serde_json::Value;
4
5use crate::{
6    error::{error_helpers::schema_error, PoolResult},
7    id_generator::IdGenerator,
8    mark::Mark,
9    mark_definition::MarkDefinition,
10    node::Node,
11    node_definition::{NodeDefinition, NodeTree},
12    schema::Schema,
13    types::NodeId,
14};
15
16/// 工厂负责基于 [`Schema`] 生成各类节点,复用 [`NodeType`] 的编译信息。
17#[derive(Clone)]
18pub struct NodeFactory<'schema> {
19    schema: &'schema Schema,
20}
21
22impl<'schema> NodeFactory<'schema> {
23    /// 创建新的工厂实例,保存对 [`Schema`] 的只读引用。
24    pub fn new(schema: &'schema Schema) -> Self {
25        Self { schema }
26    }
27
28    /// 暴露内部引用,便于调用方读取原始 Schema。
29    pub fn schema(&self) -> &'schema Schema {
30        self.schema
31    }
32
33    /// 按类型名称创建单节点。
34    pub fn create_node(
35        &self,
36        type_name: &str,
37        id: Option<NodeId>,
38        attrs: Option<&HashMap<String, Value>>,
39        content: Vec<NodeId>,
40        marks: Option<Vec<Mark>>,
41    ) -> PoolResult<Node> {
42        let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
43            schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
44        })?;
45
46        Ok(Self::instantiate_node(node_type, id, attrs, content, marks))
47    }
48
49    /// 获取节点类型定义引用,便于上层读取配置。
50    pub fn node_definition(
51        &self,
52        type_name: &str,
53    ) -> Option<&NodeDefinition> {
54        self.schema.nodes.get(type_name)
55    }
56
57    /// 按类型名称创建标记。
58    pub fn create_mark(
59        &self,
60        type_name: &str,
61        attrs: Option<&HashMap<String, Value>>,
62    ) -> PoolResult<Mark> {
63        let mark_def = self.schema.marks.get(type_name).ok_or_else(|| {
64            schema_error(&format!("无法在 schema 中找到标记类型:{type_name}"))
65        })?;
66
67        Ok(Self::instantiate_mark(mark_def, attrs))
68    }
69
70    /// 获取所有节点类型名称,主要用于调试或提示。
71    pub fn node_names(&self) -> Vec<&'schema str> {
72        let mut names: Vec<&'schema str> =
73            self.schema.nodes.keys().map(|key| key.as_str()).collect();
74        names.sort();
75        names
76    }
77
78    /// 获取所有标记类型名称,主要用于调试或提示。
79    pub fn mark_names(&self) -> Vec<&'schema str> {
80        let mut names: Vec<&'schema str> =
81            self.schema.marks.keys().map(|key| key.as_str()).collect();
82        names.sort();
83        names
84    }
85
86    /// 获取指定节点类型,若不存在则返回带提示的错误。
87    pub fn ensure_node(
88        &self,
89        type_name: &str,
90    ) -> PoolResult<&NodeDefinition> {
91        match self.schema.nodes.get(type_name) {
92            Some(def) => Ok(def),
93            None => {
94                let mut available: Vec<&String> =
95                    self.schema.nodes.keys().collect();
96                available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
97                Err(schema_error(&self.missing_message(
98                    "节点类型",
99                    type_name,
100                    available,
101                )))
102            },
103        }
104    }
105
106    /// 获取指定标记类型,若不存在则返回带提示的错误。
107    pub fn ensure_mark(
108        &self,
109        type_name: &str,
110    ) -> PoolResult<&MarkDefinition> {
111        match self.schema.marks.get(type_name) {
112            Some(def) => Ok(def),
113            None => {
114                let mut available: Vec<&String> =
115                    self.schema.marks.keys().collect();
116                available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
117                Err(schema_error(&self.missing_message(
118                    "标记类型",
119                    type_name,
120                    available,
121                )))
122            },
123        }
124    }
125
126    fn missing_message(
127        &self,
128        kind: &str,
129        name: &str,
130        available: Vec<&String>,
131    ) -> String {
132        if available.is_empty() {
133            format!("未找到{kind} \"{name}\"。当前 Schema 中未声明任何{kind}。")
134        } else {
135            let preview: Vec<&str> =
136                available.iter().take(5).map(|s| s.as_str()).collect();
137            format!(
138                "未找到{kind} \"{name}\"。可用的{kind}示例:{}",
139                preview.join(", ")
140            )
141        }
142    }
143
144    /// 获取标记类型定义引用,若不存在则返回 `None`。
145    pub fn mark_definition(
146        &self,
147        type_name: &str,
148    ) -> Option<&MarkDefinition> {
149        self.schema.marks.get(type_name)
150    }
151
152    /// 获取整个 Schema 的节点与标记定义映射,便于上层做批量/调试读取。
153    pub fn definitions(
154        &self
155    ) -> (&HashMap<String, NodeDefinition>, &HashMap<String, MarkDefinition>)
156    {
157        (&self.schema.nodes, &self.schema.marks)
158    }
159
160    /// 以顶级节点为根构建整棵子树。
161    pub fn create_top_node(
162        &self,
163        id: Option<NodeId>,
164        attrs: Option<&HashMap<String, Value>>,
165        content: Vec<Node>,
166        marks: Option<Vec<Mark>>,
167    ) -> PoolResult<NodeTree> {
168        let top_node_type = self
169            .schema
170            .top_node_type
171            .as_ref()
172            .ok_or_else(|| schema_error("未找到顶级节点类型定义"))?;
173
174        self.create_tree_with_type(top_node_type, id, attrs, content, marks)
175    }
176
177    /// 为指定类型构建并填充子树。
178    pub fn create_tree(
179        &self,
180        type_name: &str,
181        id: Option<NodeId>,
182        attrs: Option<&HashMap<String, Value>>,
183        content: Vec<Node>,
184        marks: Option<Vec<Mark>>,
185    ) -> PoolResult<NodeTree> {
186        let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
187            schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
188        })?;
189
190        self.create_tree_with_type(node_type, id, attrs, content, marks)
191    }
192
193    /// 暴露给 [`NodeDefinition`] 的内部构建逻辑。
194    pub(crate) fn create_tree_with_type(
195        &self,
196        node_type: &NodeDefinition,
197        id: Option<NodeId>,
198        attrs: Option<&HashMap<String, Value>>,
199        content: Vec<Node>,
200        marks: Option<Vec<Mark>>,
201    ) -> PoolResult<NodeTree> {
202        let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
203        let computed_attrs = node_type.compute_attrs(attrs);
204        let computed_marks = node_type.compute_marks(marks);
205
206        let mut filled_nodes: Vec<NodeTree> = Vec::new();
207        let mut final_content_ids: Vec<NodeId> = Vec::new();
208
209        if let Some(content_match) = &node_type.content_match {
210            if let Some(matched) =
211                content_match.match_fragment(&content, self.schema)
212            {
213                if let Some(needed_type_names) =
214                    matched.fill(&content, true, self.schema)
215                {
216                    for type_name in needed_type_names {
217                        if let Some(existing_node) =
218                            content.iter().find(|n| n.r#type == type_name)
219                        {
220                            let attrs_map: HashMap<String, Value> =
221                                existing_node
222                                    .attrs
223                                    .attrs
224                                    .iter()
225                                    .map(|(k, v)| (k.clone(), v.clone()))
226                                    .collect();
227                            let marks_vec: Vec<Mark> =
228                                existing_node.marks.iter().cloned().collect();
229                            let child_type = self
230                                .schema
231                                .nodes
232                                .get(&type_name)
233                                .ok_or_else(|| {
234                                    schema_error(&format!(
235                                        "无法在 schema 中找到节点类型:{type_name}"
236                                    ))
237                                })?;
238
239                            let child_tree = self.create_tree_with_type(
240                                child_type,
241                                Some(existing_node.id.clone()),
242                                Some(&attrs_map),
243                                vec![],
244                                Some(marks_vec),
245                            )?;
246                            let child_id = child_tree.0.id.clone();
247                            final_content_ids.push(child_id);
248                            filled_nodes.push(child_tree);
249                        } else {
250                            let child_type = self
251                                .schema
252                                .nodes
253                                .get(&type_name)
254                                .ok_or_else(|| {
255                                    schema_error(&format!(
256                                        "无法在 schema 中找到节点类型:{type_name}"
257                                    ))
258                                })?;
259
260                            let child_tree = self.create_tree_with_type(
261                                child_type,
262                                None,
263                                None,
264                                vec![],
265                                None,
266                            )?;
267                            let child_id = child_tree.0.id.clone();
268                            final_content_ids.push(child_id);
269                            filled_nodes.push(child_tree);
270                        }
271                    }
272                }
273            }
274        }
275
276        let node = Node::new(
277            &id,
278            node_type.name.clone(),
279            computed_attrs,
280            final_content_ids,
281            computed_marks,
282        );
283
284        Ok(NodeTree(node, filled_nodes))
285    }
286
287    fn instantiate_node(
288        node_type: &NodeDefinition,
289        id: Option<NodeId>,
290        attrs: Option<&HashMap<String, Value>>,
291        content: Vec<NodeId>,
292        marks: Option<Vec<Mark>>,
293    ) -> Node {
294        let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
295        let attrs = node_type.compute_attrs(attrs);
296        let marks = node_type.compute_marks(marks);
297
298        Node::new(&id, node_type.name.clone(), attrs, content, marks)
299    }
300
301    pub(crate) fn instantiate_mark(
302        mark_def: &MarkDefinition,
303        attrs: Option<&HashMap<String, Value>>,
304    ) -> Mark {
305        Mark {
306            r#type: mark_def.name.clone(),
307            attrs: mark_def.compute_attrs(attrs),
308        }
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mark_definition::MarkSpec;
    use crate::node_definition::NodeSpec;
    use crate::schema::{Schema, SchemaSpec};

    /// Minimal fixture: `doc` (the top node) and `paragraph` node types plus
    /// a `bold` mark, all with default specs.
    fn build_schema() -> Schema {
        let nodes: HashMap<String, NodeSpec> = ["doc", "paragraph"]
            .iter()
            .map(|&name| (name.to_string(), NodeSpec::default()))
            .collect();
        let mut marks = HashMap::new();
        marks.insert("bold".to_string(), MarkSpec::default());

        Schema::compile(SchemaSpec {
            nodes,
            marks,
            top_node: Some("doc".to_string()),
        })
        .expect("schema should compile")
    }

    /// Asserts that `msg` contains every fragment in `needles`, echoing the
    /// full message on failure.
    fn assert_mentions(msg: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(msg.contains(needle), "actual: {msg}");
        }
    }

    #[test]
    fn ensure_node_returns_descriptive_error() {
        let schema = build_schema();
        let msg = NodeFactory::new(&schema)
            .ensure_node("unknown")
            .unwrap_err()
            .to_string();
        // Only the key fragments are asserted; the full wording is checked
        // manually at runtime.
        assert_mentions(&msg, &["未找到节点类型", "unknown", "doc"]);
    }

    #[test]
    fn ensure_mark_returns_descriptive_error() {
        let schema = build_schema();
        let msg = NodeFactory::new(&schema)
            .ensure_mark("italic")
            .unwrap_err()
            .to_string();
        assert_mentions(&msg, &["未找到标记类型", "italic", "bold"]);
    }

    #[test]
    fn node_and_mark_names_exposed() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        assert!(factory.node_names().contains(&"doc"));
        assert!(factory.mark_names().contains(&"bold"));
    }
}