mf_model/
schema.rs

1use crate::error::error_helpers::schema_error;
2use crate::error::PoolResult;
3
4use super::attrs::Attrs;
5use super::content::ContentMatch;
6use super::mark_type::{MarkSpec, MarkType};
7use super::node_type::{NodeSpec, NodeType};
8use serde::Serialize;
9use serde_json::Value;
10use std::any::Any;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13/// 属性定义结构体
14/// 用于定义节点或标记的属性特征
15#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
16pub struct Attribute {
17    pub has_default: bool,
18    pub default: Option<Value>,
19}
20
21impl Attribute {
22    /// 从 AttributeSpec 创建新的 Attribute 实例
23    pub(crate) fn new(options: AttributeSpec) -> Self {
24        Attribute {
25            has_default: options.default.is_some(),
26            default: options.default,
27        }
28    }
29    /// 检查属性是否为必需的
30    /// 如果没有默认值,则属性为必需
31    pub fn is_required(&self) -> bool {
32        !self.has_default
33    }
34}
35/// Schema 结构体定义
36/// 用于管理文档模型的整体结构,包括节点和标记的类型定义
37#[derive(Clone, Debug)]
38pub struct Schema {
39    /// Schema 的规范定义
40    pub spec: SchemaSpec,
41    /// 顶级节点类型
42    pub top_node_type: Option<NodeType>,
43    /// 全局缓存
44    pub cached: Arc<Mutex<HashMap<String, Arc<dyn Any + Send + Sync>>>>,
45    /// 节点类型映射表
46    pub nodes: HashMap<String, NodeType>,
47    /// 标记类型映射表
48    pub marks: HashMap<String, MarkType>,
49}
50impl PartialEq for Schema {
51    fn eq(
52        &self,
53        other: &Self,
54    ) -> bool {
55        self.spec == other.spec
56            && self.top_node_type == other.top_node_type
57            && self.nodes == other.nodes
58            && self.marks == other.marks
59    }
60}
61impl Eq for Schema {}
62impl Schema {
63    /// 创建新的 Schema 实例
64    pub fn new(spec: SchemaSpec) -> Self {
65        let mut instance_spec = SchemaSpec {
66            nodes: HashMap::new(),
67            marks: HashMap::new(),
68            top_node: spec.top_node,
69        };
70        // 复制 spec 属性
71        for (key, value) in spec.nodes {
72            instance_spec.nodes.insert(key, value);
73        }
74        for (key, value) in spec.marks {
75            instance_spec.marks.insert(key, value);
76        }
77        Schema {
78            spec: instance_spec,
79            top_node_type: None,
80            cached: Arc::new(Mutex::new(HashMap::new())),
81            nodes: HashMap::new(),
82            marks: HashMap::new(),
83        }
84    }
85    /// 编译 Schema 定义
86    /// 处理节点和标记的定义,建立它们之间的关系
87    pub fn compile(instance_spec: SchemaSpec) -> PoolResult<Schema> {
88        let mut schema: Schema = Schema::new(instance_spec);
89        let nodes: HashMap<String, NodeType> =
90            NodeType::compile(schema.spec.nodes.clone());
91        let marks = MarkType::compile(schema.spec.marks.clone());
92        let mut content_expr_cache = HashMap::new();
93        let mut updated_nodes = HashMap::new();
94        for (prop, type_) in &nodes {
95            if marks.contains_key(prop) {
96                return Err(schema_error(&format!(
97                    "{} 不能既是节点又是标记",
98                    prop
99                )));
100            }
101
102            let content_expr = type_.spec.content.as_deref().unwrap_or("");
103            let mark_expr = type_.spec.marks.as_deref();
104
105            let content_expr_string = content_expr.to_string();
106            let content_match = content_expr_cache
107                .entry(content_expr_string.clone())
108                .or_insert_with(|| {
109                    ContentMatch::parse(content_expr_string, &nodes)
110                })
111                .clone();
112
113            let mark_set = match mark_expr {
114                Some("_") => None,
115                Some(expr) => {
116                    let marks_result =
117                        gather_marks(&marks, expr.split_whitespace().collect());
118                    match marks_result {
119                        Ok(marks) => Some(marks.into_iter().cloned().collect()), // Convert Vec<&MarkType> to Vec<MarkType>
120                        Err(e) => return Err(schema_error(&e)),
121                    }
122                },
123                None => None,
124            };
125
126            let mut node = type_.clone();
127            node.content_match = Some(content_match);
128            node.mark_set = mark_set;
129            updated_nodes.insert(prop.clone(), node);
130        }
131        schema.nodes = updated_nodes;
132        schema.marks = marks;
133        schema.top_node_type = match schema.nodes.get(
134            &schema.spec.top_node.clone().unwrap_or_else(|| "doc".to_string()),
135        ) {
136            Some(node) => Some(node.clone()),
137            None => {
138                return Err(schema_error("未找到顶级节点类型定义"));
139            },
140        };
141
142        Ok(schema)
143    }
144}
145/// Schema 规范定义
146/// 包含节点和标记的原始定义信息
147#[derive(Clone, PartialEq, Eq, Debug)]
148pub struct SchemaSpec {
149    pub nodes: HashMap<String, NodeSpec>,
150    pub marks: HashMap<String, MarkSpec>,
151    pub top_node: Option<String>,
152}
153
154// 其他辅助函数...
155/// 获取属性的默认值映射
156/// 如果所有属性都有默认值,返回包含所有默认值的映射
157/// 如果任一属性没有默认值,返回 None
158pub fn default_attrs(
159    attrs: &HashMap<String, Attribute>
160) -> Option<HashMap<String, Value>> {
161    let mut defaults = HashMap::new();
162
163    for (attr_name, attr) in attrs {
164        if let Some(default) = &attr.default {
165            defaults.insert(attr_name.clone(), default.clone());
166        } else {
167            return None;
168        }
169    }
170
171    Some(defaults)
172}
173/// 属性规范定义
174#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize)]
175pub struct AttributeSpec {
176    /// 属性的默认值
177    pub default: Option<Value>,
178}
179/// 收集标记类型
180/// 根据给定的标记名称列表,收集对应的标记类型
181fn gather_marks<'a>(
182    marks_map: &'a HashMap<String, MarkType>,
183    marks: Vec<&'a str>,
184) -> Result<Vec<&'a MarkType>, String> {
185    let mut found = Vec::new();
186
187    for name in marks {
188        if let Some(mark) = marks_map.get(name) {
189            found.push(mark);
190        } else if name == "_" {
191            // "_" 表示所有标记类型都被允许
192            found.extend(marks_map.values());
193        } else {
194            // 尝试通过组名匹配标记
195            let mut matched = false;
196            for mark_ref in marks_map.values() {
197                if mark_ref.spec.group.as_ref().is_some_and(|group| {
198                    group.split_whitespace().any(|g| g == name)
199                }) {
200                    found.push(mark_ref);
201                    matched = true;
202                    break;
203                }
204            }
205            if !matched {
206                return Err(format!("未知的标记类型: '{}'", name));
207            }
208        }
209    }
210    Ok(found)
211}
212/// 计算属性值
213/// 根据属性定义和提供的值计算最终的属性值
214pub fn compute_attrs(
215    attrs: &HashMap<String, Attribute>,
216    value: Option<&HashMap<String, Value>>,
217) -> Attrs {
218    let mut built = Attrs::default();
219
220    for (name, attr) in attrs {
221        let given = value.and_then(|v| v.get(name));
222
223        let given = match given {
224            Some(val) => val.clone(),
225            None => {
226                if attr.has_default {
227                    attr.default.clone().unwrap_or_else(|| {
228                        panic!("没有为属性提供默认值 {}", name)
229                    })
230                } else {
231                    Value::Null
232                }
233            },
234        };
235
236        built[&name] = given;
237    }
238
239    built
240}