mf_model/
schema.rs

1use crate::error::error_helpers::schema_error;
2use crate::error::PoolResult;
3
4use super::attrs::Attrs;
5use super::content::ContentMatch;
6use super::mark_definition::{MarkDefinition, MarkSpec};
7use super::node_definition::{NodeDefinition, NodeSpec};
8use crate::node_factory::NodeFactory;
9use serde::Serialize;
10use serde_json::Value;
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
/// Compiled definition of a single attribute.
/// Describes the characteristics of an attribute that a node or mark may carry.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
pub struct Attribute {
    /// Whether a default value exists; `Attribute::new` derives it from
    /// `default.is_some()`.
    pub has_default: bool,
    /// The default value, or `None` when the attribute has no default.
    pub default: Option<Value>,
}
21
22impl Attribute {
23    /// 从 AttributeSpec 创建新的 Attribute 实例
24    pub(crate) fn new(options: AttributeSpec) -> Self {
25        Attribute {
26            has_default: options.default.is_some(),
27            default: options.default,
28        }
29    }
30    /// 检查属性是否为必需的
31    /// 如果没有默认值,则属性为必需
32    pub fn is_required(&self) -> bool {
33        !self.has_default
34    }
35}
/// Document schema.
/// Holds the overall structure of the document model: the specification it was
/// created from plus the node and mark definitions built from it.
#[derive(Clone, Debug)]
pub struct Schema {
    /// The specification this schema was created from.
    pub spec: SchemaSpec,
    /// Definition of the top-level node. `Schema::new` leaves it as `None`;
    /// `Schema::compile` fills it in.
    pub top_node_type: Option<NodeDefinition>,
    /// Global cache of arbitrary values keyed by string. It lives behind an
    /// `Arc`, so cloning the schema yields a handle to the same map.
    pub cached: Arc<Mutex<HashMap<String, Arc<dyn Any + Send + Sync>>>>,
    /// Node definitions keyed by node name.
    pub(crate) nodes: HashMap<String, NodeDefinition>,
    /// Mark definitions keyed by mark name.
    pub(crate) marks: HashMap<String, MarkDefinition>,
}
51impl PartialEq for Schema {
52    fn eq(
53        &self,
54        other: &Self,
55    ) -> bool {
56        self.spec == other.spec
57            && self.top_node_type == other.top_node_type
58            && self.nodes == other.nodes
59            && self.marks == other.marks
60    }
61}
62impl Eq for Schema {}
63impl Schema {
64    /// 创建新的 Schema 实例
65    #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(spec), fields(
66        crate_name = "model",
67        node_count = spec.nodes.len(),
68        mark_count = spec.marks.len()
69    )))]
70    pub fn new(spec: SchemaSpec) -> Self {
71        let mut instance_spec = SchemaSpec {
72            nodes: HashMap::new(),
73            marks: HashMap::new(),
74            top_node: spec.top_node,
75        };
76        // 复制 spec 属性
77        for (key, value) in spec.nodes {
78            instance_spec.nodes.insert(key, value);
79        }
80        for (key, value) in spec.marks {
81            instance_spec.marks.insert(key, value);
82        }
83        Schema {
84            spec: instance_spec,
85            top_node_type: None,
86            cached: Arc::new(Mutex::new(HashMap::new())),
87            nodes: HashMap::new(),
88            marks: HashMap::new(),
89        }
90    }
91    pub fn factory(&self) -> NodeFactory<'_> {
92        NodeFactory::new(self)
93    }
94    /// 编译 Schema 定义
95    /// 处理节点和标记的定义,建立它们之间的关系
96    #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(instance_spec), fields(
97        crate_name = "model",
98        node_count = instance_spec.nodes.len(),
99        mark_count = instance_spec.marks.len()
100    )))]
101    pub fn compile(instance_spec: SchemaSpec) -> PoolResult<Schema> {
102        let mut schema: Schema = Schema::new(instance_spec);
103        let nodes: HashMap<String, NodeDefinition> =
104            NodeDefinition::compile(schema.spec.nodes.clone());
105        let marks = MarkDefinition::compile(schema.spec.marks.clone());
106        let mut content_expr_cache = HashMap::new();
107        let mut updated_nodes = HashMap::new();
108        for (prop, type_) in &nodes {
109            if marks.contains_key(prop) {
110                return Err(schema_error(&format!(
111                    "{prop} 不能既是节点又是标记"
112                )));
113            }
114
115            let content_expr = type_.spec.content.as_deref().unwrap_or("");
116            let mark_expr = type_.spec.marks.as_deref();
117
118            let content_expr_string = content_expr.to_string();
119            let content_match = content_expr_cache
120                .entry(content_expr_string.clone())
121                .or_insert_with(|| {
122                    ContentMatch::parse(content_expr_string, &nodes)
123                })
124                .clone();
125
126            let mark_set = match mark_expr {
127                Some("_") => None,
128                Some(expr) => {
129                    let marks_result =
130                        gather_marks(&marks, expr.split_whitespace().collect());
131                    match marks_result {
132                        Ok(marks) => Some(marks.into_iter().cloned().collect()), // Convert Vec<&MarkType> to Vec<MarkType>
133                        Err(e) => return Err(schema_error(&e)),
134                    }
135                },
136                None => None,
137            };
138
139            let mut node = type_.clone();
140            node.content_match = Some(content_match);
141            node.mark_set = mark_set;
142            updated_nodes.insert(prop.clone(), node);
143        }
144        schema.nodes = updated_nodes;
145        schema.marks = marks;
146        schema.top_node_type = match schema.nodes.get(
147            &schema.spec.top_node.clone().unwrap_or_else(|| "doc".to_string()),
148        ) {
149            Some(node) => Some(node.clone()),
150            None => {
151                return Err(schema_error("未找到顶级节点类型定义"));
152            },
153        };
154
155        Ok(schema)
156    }
157}
/// Schema specification.
/// Holds the raw, uncompiled definitions of the nodes and marks.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct SchemaSpec {
    /// Node specifications keyed by node name.
    pub nodes: HashMap<String, NodeSpec>,
    /// Mark specifications keyed by mark name.
    pub marks: HashMap<String, MarkSpec>,
    /// Name of the top-level node; `Schema::compile` falls back to `"doc"`
    /// when this is `None`.
    pub top_node: Option<String>,
}
166
167// 其他辅助函数...
168/// 获取属性的默认值映射
169/// 如果所有属性都有默认值,返回包含所有默认值的映射
170/// 如果任一属性没有默认值,返回 None
171pub fn default_attrs(
172    attrs: &HashMap<String, Attribute>
173) -> Option<HashMap<String, Value>> {
174    let mut defaults = HashMap::new();
175
176    for (attr_name, attr) in attrs {
177        if let Some(default) = &attr.default {
178            defaults.insert(attr_name.clone(), default.clone());
179        } else {
180            return None;
181        }
182    }
183
184    Some(defaults)
185}
/// Attribute specification: the raw description an `Attribute` is built from.
#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize)]
pub struct AttributeSpec {
    /// Default value of the attribute; `None` means no default is declared.
    pub default: Option<Value>,
}
192/// 收集标记类型
193/// 根据给定的标记名称列表,收集对应的标记类型
194fn gather_marks<'a>(
195    marks_map: &'a HashMap<String, MarkDefinition>,
196    marks: Vec<&'a str>,
197) -> Result<Vec<&'a MarkDefinition>, String> {
198    let mut found = Vec::new();
199
200    for name in marks {
201        if let Some(mark) = marks_map.get(name) {
202            found.push(mark);
203        } else if name == "_" {
204            // "_" 表示所有标记类型都被允许
205            found.extend(marks_map.values());
206        } else {
207            // 尝试通过组名匹配标记
208            let mut matched = false;
209            for mark_ref in marks_map.values() {
210                if mark_ref.spec.group.as_ref().is_some_and(|group| {
211                    group.split_whitespace().any(|g| g == name)
212                }) {
213                    found.push(mark_ref);
214                    matched = true;
215                    break;
216                }
217            }
218            if !matched {
219                return Err(format!("未知的标记类型: '{name}'"));
220            }
221        }
222    }
223    Ok(found)
224}
225/// 计算属性值
226/// 根据属性定义和提供的值计算最终的属性值
227pub fn compute_attrs(
228    attrs: &HashMap<String, Attribute>,
229    value: Option<&HashMap<String, Value>>,
230) -> Attrs {
231    let mut built = Attrs::default();
232
233    for (name, attr) in attrs {
234        let given = value.and_then(|v| v.get(name));
235
236        let given = match given {
237            Some(val) => val.clone(),
238            None => {
239                if attr.has_default {
240                    attr.default.clone().unwrap_or_else(|| {
241                        panic!("没有为属性提供默认值 {name}")
242                    })
243                } else {
244                    Value::Null
245                }
246            },
247        };
248
249        built[name] = given;
250    }
251
252    built
253}