// mf_model/node_factory.rs

1use std::collections::HashMap;
2
3use serde_json::Value;
4
5use crate::{
6    error::{error_helpers::schema_error, PoolResult},
7    id_generator::IdGenerator,
8    mark::Mark,
9    mark_definition::MarkDefinition,
10    node::Node,
11    node_definition::{NodeDefinition, NodeTree},
12    schema::Schema,
13    types::NodeId,
14};
15
16/// 工厂负责基于 [`Schema`] 生成各类节点,复用 [`NodeType`] 的编译信息。
17#[derive(Clone)]
18pub struct NodeFactory<'schema> {
19    schema: &'schema Schema,
20}
21
22impl<'schema> NodeFactory<'schema> {
23    /// 创建新的工厂实例,保存对 [`Schema`] 的只读引用。
24    pub fn new(schema: &'schema Schema) -> Self {
25        Self { schema }
26    }
27
28    /// 暴露内部引用,便于调用方读取原始 Schema。
29    pub fn schema(&self) -> &'schema Schema {
30        self.schema
31    }
32
33    /// 按类型名称创建单节点。
34    pub fn create_node(
35        &self,
36        type_name: &str,
37        id: Option<NodeId>,
38        attrs: Option<&HashMap<String, Value>>,
39        content: Vec<NodeId>,
40        marks: Option<Vec<Mark>>,
41    ) -> PoolResult<Node> {
42        let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
43            schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
44        })?;
45
46        Ok(Self::instantiate_node(node_type, id, attrs, content, marks))
47    }
48
49    /// 获取节点类型定义引用,便于上层读取配置。
50    pub fn node_definition(
51        &self,
52        type_name: &str,
53    ) -> Option<&NodeDefinition> {
54        self.schema.nodes.get(type_name)
55    }
56
57    /// 按类型名称创建标记。
58    pub fn create_mark(
59        &self,
60        type_name: &str,
61        attrs: Option<&HashMap<String, Value>>,
62    ) -> PoolResult<Mark> {
63        let mark_def = self.schema.marks.get(type_name).ok_or_else(|| {
64            schema_error(&format!("无法在 schema 中找到标记类型:{type_name}"))
65        })?;
66
67        Ok(Self::instantiate_mark(mark_def, attrs))
68    }
69
70    /// 获取所有节点类型名称,主要用于调试或提示。
71    pub fn node_names(&self) -> Vec<&'schema str> {
72        let mut names: Vec<&'schema str> =
73            self.schema.nodes.keys().map(|key| key.as_str()).collect();
74        names.sort();
75        names
76    }
77
78    /// 获取所有标记类型名称,主要用于调试或提示。
79    pub fn mark_names(&self) -> Vec<&'schema str> {
80        let mut names: Vec<&'schema str> =
81            self.schema.marks.keys().map(|key| key.as_str()).collect();
82        names.sort();
83        names
84    }
85
86    /// 获取指定节点类型,若不存在则返回带提示的错误。
87    pub fn ensure_node(
88        &self,
89        type_name: &str,
90    ) -> PoolResult<&NodeDefinition> {
91        match self.schema.nodes.get(type_name) {
92            Some(def) => Ok(def),
93            None => {
94                let mut available: Vec<&String> =
95                    self.schema.nodes.keys().collect();
96                available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
97                Err(schema_error(&self.missing_message(
98                    "节点类型",
99                    type_name,
100                    available,
101                )))
102            },
103        }
104    }
105
106    /// 获取指定标记类型,若不存在则返回带提示的错误。
107    pub fn ensure_mark(
108        &self,
109        type_name: &str,
110    ) -> PoolResult<&MarkDefinition> {
111        match self.schema.marks.get(type_name) {
112            Some(def) => Ok(def),
113            None => {
114                let mut available: Vec<&String> =
115                    self.schema.marks.keys().collect();
116                available.sort_by(|a, b| a.as_str().cmp(b.as_str()));
117                Err(schema_error(&self.missing_message(
118                    "标记类型",
119                    type_name,
120                    available,
121                )))
122            },
123        }
124    }
125
126    fn missing_message(
127        &self,
128        kind: &str,
129        name: &str,
130        available: Vec<&String>,
131    ) -> String {
132        if available.is_empty() {
133            format!(
134                "未找到{kind} \"{name}\"。当前 Schema 中未声明任何{kind}。"
135            )
136        } else {
137            let preview: Vec<&str> =
138                available.iter().take(5).map(|s| s.as_str()).collect();
139            format!(
140                "未找到{kind} \"{name}\"。可用的{kind}示例:{}",
141                preview.join(", ")
142            )
143        }
144    }
145
146    /// 获取标记类型定义引用,若不存在则返回 `None`。
147    pub fn mark_definition(
148        &self,
149        type_name: &str,
150    ) -> Option<&MarkDefinition> {
151        self.schema.marks.get(type_name)
152    }
153
154    /// 获取整个 Schema 的节点与标记定义映射,便于上层做批量/调试读取。
155    pub fn definitions(
156        &self
157    ) -> (&HashMap<String, NodeDefinition>, &HashMap<String, MarkDefinition>)
158    {
159        (&self.schema.nodes, &self.schema.marks)
160    }
161
162    /// 以顶级节点为根构建整棵子树。
163    pub fn create_top_node(
164        &self,
165        id: Option<NodeId>,
166        attrs: Option<&HashMap<String, Value>>,
167        content: Vec<Node>,
168        marks: Option<Vec<Mark>>,
169    ) -> PoolResult<NodeTree> {
170        let top_node_type = self
171            .schema
172            .top_node_type
173            .as_ref()
174            .ok_or_else(|| schema_error("未找到顶级节点类型定义"))?;
175
176        self.create_tree_with_type(top_node_type, id, attrs, content, marks)
177    }
178
179    /// 为指定类型构建并填充子树。
180    pub fn create_tree(
181        &self,
182        type_name: &str,
183        id: Option<NodeId>,
184        attrs: Option<&HashMap<String, Value>>,
185        content: Vec<Node>,
186        marks: Option<Vec<Mark>>,
187    ) -> PoolResult<NodeTree> {
188        let node_type = self.schema.nodes.get(type_name).ok_or_else(|| {
189            schema_error(&format!("无法在 schema 中找到节点类型:{type_name}"))
190        })?;
191
192        self.create_tree_with_type(node_type, id, attrs, content, marks)
193    }
194
195    /// 暴露给 [`NodeDefinition`] 的内部构建逻辑。
196    pub(crate) fn create_tree_with_type(
197        &self,
198        node_type: &NodeDefinition,
199        id: Option<NodeId>,
200        attrs: Option<&HashMap<String, Value>>,
201        content: Vec<Node>,
202        marks: Option<Vec<Mark>>,
203    ) -> PoolResult<NodeTree> {
204        let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
205        let computed_attrs = node_type.compute_attrs(attrs);
206        let computed_marks = node_type.compute_marks(marks);
207
208        let mut filled_nodes: Vec<NodeTree> = Vec::new();
209        let mut final_content_ids: Vec<NodeId> = Vec::new();
210
211        if let Some(content_match) = &node_type.content_match {
212            if let Some(matched) =
213                content_match.match_fragment(&content, self.schema)
214            {
215                if let Some(needed_type_names) =
216                    matched.fill(&content, true, self.schema)
217                {
218                    for type_name in needed_type_names {
219                        if let Some(existing_node) =
220                            content.iter().find(|n| n.r#type == type_name)
221                        {
222                            let attrs_map: HashMap<String, Value> =
223                                existing_node
224                                    .attrs
225                                    .attrs
226                                    .iter()
227                                    .map(|(k, v)| (k.clone(), v.clone()))
228                                    .collect();
229                            let marks_vec: Vec<Mark> =
230                                existing_node.marks.iter().cloned().collect();
231                            let child_type = self
232                                .schema
233                                .nodes
234                                .get(&type_name)
235                                .ok_or_else(|| {
236                                    schema_error(&format!(
237                                        "无法在 schema 中找到节点类型:{type_name}"
238                                    ))
239                                })?;
240
241                            let child_tree = self.create_tree_with_type(
242                                child_type,
243                                Some(existing_node.id.clone()),
244                                Some(&attrs_map),
245                                vec![],
246                                Some(marks_vec),
247                            )?;
248                            let child_id = child_tree.0.id.clone();
249                            final_content_ids.push(child_id);
250                            filled_nodes.push(child_tree);
251                        } else {
252                            let child_type = self
253                                .schema
254                                .nodes
255                                .get(&type_name)
256                                .ok_or_else(|| {
257                                    schema_error(&format!(
258                                        "无法在 schema 中找到节点类型:{type_name}"
259                                    ))
260                                })?;
261
262                            let child_tree = self.create_tree_with_type(
263                                child_type,
264                                None,
265                                None,
266                                vec![],
267                                None,
268                            )?;
269                            let child_id = child_tree.0.id.clone();
270                            final_content_ids.push(child_id);
271                            filled_nodes.push(child_tree);
272                        }
273                    }
274                }
275            }
276        }
277
278        let node = Node::new(
279            &id,
280            node_type.name.clone(),
281            computed_attrs,
282            final_content_ids,
283            computed_marks,
284        );
285
286        Ok(NodeTree(node, filled_nodes))
287    }
288
289    fn instantiate_node(
290        node_type: &NodeDefinition,
291        id: Option<NodeId>,
292        attrs: Option<&HashMap<String, Value>>,
293        content: Vec<NodeId>,
294        marks: Option<Vec<Mark>>,
295    ) -> Node {
296        let id: NodeId = id.unwrap_or_else(IdGenerator::get_id);
297        let attrs = node_type.compute_attrs(attrs);
298        let marks = node_type.compute_marks(marks);
299
300        Node::new(&id, node_type.name.clone(), attrs, content, marks)
301    }
302
303    pub(crate) fn instantiate_mark(
304        mark_def: &MarkDefinition,
305        attrs: Option<&HashMap<String, Value>>,
306    ) -> Mark {
307        Mark {
308            r#type: mark_def.name.clone(),
309            attrs: mark_def.compute_attrs(attrs),
310        }
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mark_definition::MarkSpec;
    use crate::node_definition::NodeSpec;
    use crate::schema::{Schema, SchemaSpec};

    /// Compiles a minimal schema: node types `doc` (the top node) and
    /// `paragraph`, plus a single `bold` mark, all using default specs.
    fn build_schema() -> Schema {
        let nodes: HashMap<String, NodeSpec> = ["doc", "paragraph"]
            .iter()
            .map(|name| (name.to_string(), NodeSpec::default()))
            .collect();

        let mut marks = HashMap::new();
        marks.insert("bold".to_string(), MarkSpec::default());

        let spec = SchemaSpec {
            nodes,
            marks,
            top_node: Some("doc".to_string()),
        };
        Schema::compile(spec).expect("schema should compile")
    }

    /// Asserts that `msg` contains every fragment, echoing `msg` on failure.
    fn assert_mentions(msg: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(msg.contains(fragment), "actual: {msg}");
        }
    }

    #[test]
    fn ensure_node_returns_descriptive_error() {
        let schema = build_schema();
        let msg = NodeFactory::new(&schema)
            .ensure_node("unknown")
            .unwrap_err()
            .to_string();
        // Only the key fragments are pinned; the full wording is reviewed by hand.
        assert_mentions(&msg, &["未找到节点类型", "unknown", "doc"]);
    }

    #[test]
    fn ensure_mark_returns_descriptive_error() {
        let schema = build_schema();
        let msg = NodeFactory::new(&schema)
            .ensure_mark("italic")
            .unwrap_err()
            .to_string();
        assert_mentions(&msg, &["未找到标记类型", "italic", "bold"]);
    }

    #[test]
    fn node_and_mark_names_exposed() {
        let schema = build_schema();
        let factory = NodeFactory::new(&schema);
        assert!(factory.node_names().iter().any(|&name| name == "doc"));
        assert!(factory.mark_names().iter().any(|&name| name == "bold"));
    }
}