mf_transform/
attr_step.rs

1use std::sync::Arc;
2
3use crate::{transform_error, TransformResult};
4
5use super::{
6    step::{Step, StepResult},
7};
8use imbl::HashMap as ImHashMap;
9use mf_model::{schema::Schema, tree::Tree, types::NodeId};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value};
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct AttrStep {
15    pub id: NodeId,
16    pub values: ImHashMap<String, Value>,
17}
18
19impl AttrStep {
20    pub fn new(
21        id: NodeId,
22        values: ImHashMap<String, Value>,
23    ) -> Self {
24        AttrStep { id, values }
25    }
26}
27
28impl Step for AttrStep {
29    fn name(&self) -> String {
30        "attr_step".to_string()
31    }
32
33    fn apply(
34        &self,
35        dart: &mut Tree,
36        schema: Arc<Schema>,
37    ) -> TransformResult<StepResult> {
38        let _ = schema;
39        match dart.get_node(&self.id) {
40            Some(node) => {
41                // 获取节点类型定义,若缺失则返回错误而非 panic
42                let node_type = match schema.nodes.get(&node.r#type) {
43                    Some(nt) => nt,
44                    None => {
45                        return Err(transform_error(format!(
46                            "未知的节点类型: {}",
47                            node.r#type
48                        )));
49                    },
50                };
51                let attr = &node_type.attrs;
52                // 删除 self.values 中 attr中没有定义的属性
53                let mut new_values = self.values.clone();
54                for (key, _) in self.values.iter() {
55                    if !attr.contains_key(key) {
56                        new_values.remove(key);
57                    }
58                }
59                let result = dart.attrs(&self.id) + new_values;
60                match result {
61                    Ok(_) => Ok(StepResult::ok()),
62                    Err(e) => Err(transform_error(e.to_string())),
63                }
64            },
65            None => {
66                return Err(transform_error("节点不存在".to_string()));
67            },
68        }
69    }
70
71    fn serialize(&self) -> Option<Vec<u8>> {
72        serde_json::to_vec(self).ok()
73    }
74
75    fn invert(
76        &self,
77        dart: &Arc<Tree>,
78    ) -> Option<Arc<dyn Step>> {
79        match dart.get_node(&self.id) {
80            Some(node) => {
81                // 仅对本次修改过的键生成反向值,避免覆盖无关属性
82                let mut revert_values = imbl::hashmap!();
83                for (changed_key, _) in self.values.iter() {
84                    if let Some(old_val) = node.attrs.get_safe(changed_key) {
85                        revert_values
86                            .insert(changed_key.clone(), old_val.clone());
87                    }
88                    // 若原先不存在该键,这里不设置(缺少删除语义);
89                    // 如需彻底还原,可扩展支持 unset 语义
90                }
91                if revert_values.is_empty() {
92                    None
93                } else {
94                    Some(Arc::new(AttrStep::new(
95                        self.id.clone(),
96                        revert_values,
97                    )))
98                }
99            },
100            None => None,
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use mf_model::node::Node;
    use mf_model::attrs::Attrs;
    use mf_model::node_type::NodeSpec;
    use mf_model::schema::{SchemaSpec, AttributeSpec};
    use std::collections::HashMap;
    use std::sync::Arc;
    use serde_json::json;

    /// Builds a node of type `"test"` with `Attrs::default()` attributes and
    /// empty vectors for the remaining constructor arguments.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    /// Builds a schema whose only node type, `"test"` (also the top node),
    /// declares the attributes `name` and `age` without defaults.
    fn create_test_schema() -> Arc<Schema> {
        let mut nodes = HashMap::new();
        let mut attrs = HashMap::new();
        attrs.insert("name".to_string(), AttributeSpec { default: None });
        attrs.insert("age".to_string(), AttributeSpec { default: None });

        nodes.insert(
            "test".to_string(),
            NodeSpec {
                content: None,
                marks: None,
                group: None,
                desc: Some("Test node".to_string()),
                attrs: Some(attrs),
            },
        );

        let spec = SchemaSpec {
            nodes,
            marks: HashMap::new(),
            top_node: Some("test".to_string()),
        };

        Arc::new(Schema::compile(spec).unwrap())
    }

    #[test]
    fn test_attr_step_creation() {
        let mut values = imbl::HashMap::new();
        values.insert("name".to_string(), json!("test"));
        values.insert("age".to_string(), json!(25));

        // `values` already has the exact field type, so no conversion is needed.
        let step = AttrStep::new("node1".into(), values.clone());
        assert_eq!(step.id, "node1".into());
        assert_eq!(step.values, values);
    }

    #[test]
    fn test_attr_step_apply() {
        // Build the test node and tree.
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        // Build the test schema.
        let schema = create_test_schema();

        // Build the attribute step.
        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        values.insert("age".to_string(), json!(25));
        let step = AttrStep::new("node1".into(), values.into());

        // Apply the step.
        let result = step.apply(&mut tree, schema.clone());
        assert!(result.is_ok());

        // Check that the attributes were set.
        let updated_node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(updated_node.attrs.get("name").unwrap(), &json!("test"));
        assert_eq!(updated_node.attrs.get("age").unwrap(), &json!(25));
    }

    #[test]
    fn test_attr_step_apply_invalid_attrs() {
        // Build the test node and tree.
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        // Build the test schema.
        let schema = create_test_schema();

        // Build a step containing only an attribute the schema does not declare.
        let mut values = HashMap::new();
        values.insert("invalid_attr".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        // Apply the step.
        let result = step.apply(&mut tree, schema.clone());
        assert!(result.is_ok());

        // The undeclared attribute must have been filtered out.
        let updated_node = tree.get_node(&"node1".into()).unwrap();
        assert!(updated_node.attrs.get("invalid_attr").is_none());
    }

    #[test]
    fn test_attr_step_apply_mixed_attrs() {
        // Declared and undeclared keys in the same step: only the declared
        // key should reach the node.
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);
        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("kept"));
        values.insert("invalid_attr".to_string(), json!("dropped"));
        let step = AttrStep::new("node1".into(), values.into());

        step.apply(&mut tree, schema).unwrap();

        let updated_node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(updated_node.attrs.get("name").unwrap(), &json!("kept"));
        assert!(updated_node.attrs.get("invalid_attr").is_none());
    }

    #[test]
    fn test_attr_step_apply_nonexistent_node() {
        // Build a tree that does not contain the target node.
        let node: Node = create_test_node("root");
        let mut tree = Tree::new(node);

        // Build the test schema.
        let schema = create_test_schema();

        // Build the attribute step.
        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("nonexistent".into(), values.into());

        // Applying must fail because the node is missing.
        let result = step.apply(&mut tree, schema);
        assert!(result.is_err());
    }

    #[test]
    fn test_attr_step_serialize() {
        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        let serialized = Step::serialize(&step);
        assert!(serialized.is_some());

        // The serialized bytes must round-trip through deserialization.
        let deserialized: AttrStep =
            serde_json::from_slice(&serialized.unwrap()).unwrap();
        assert_eq!(deserialized.id, "node1".into());
        assert_eq!(deserialized.values.get("name").unwrap(), &json!("test"));
    }

    #[test]
    fn test_attr_step_invert() {
        // Build the test node and tree.
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        // Build the test schema.
        let schema = create_test_schema();

        // Set the initial attributes.
        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("original_name"));
        values.insert("age".to_string(), json!(25));
        let step = AttrStep::new("node1".into(), values.into());
        step.apply(&mut tree, schema.clone()).unwrap();

        // Build a second step that modifies the attributes.
        let mut new_values = HashMap::new();
        new_values.insert("name".to_string(), json!("modified_name"));
        new_values.insert("age".to_string(), json!(30));
        let new_step = AttrStep::new("node1".into(), new_values.into());

        // Compute the inverse before applying the modification.
        let inverted = new_step.invert(&Arc::new(tree.clone()));
        assert!(inverted.is_some());

        // Apply the modifying step.
        new_step.apply(&mut tree, schema.clone()).unwrap();
        let node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("modified_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(30));

        // Apply the inverse step.
        let inverted_step = inverted.unwrap();
        inverted_step.apply(&mut tree, schema).unwrap();

        // The attributes must be back to their original values.
        let node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("original_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(25));
    }

    #[test]
    fn test_attr_step_invert_nonexistent_node() {
        // A step targeting a node absent from the tree has nothing to revert.
        let tree = Tree::new(create_test_node("root"));

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("nonexistent".into(), values.into());

        assert!(step.invert(&Arc::new(tree)).is_none());
    }
}