moduforge_transform/
attr_step.rs

1use std::sync::Arc;
2
3use crate::{transform_error, TransformResult};
4
5use super::{
6    step::{Step, StepResult},
7};
8use im::HashMap as ImHashMap;
9use moduforge_model::{schema::Schema, tree::Tree, types::NodeId};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value};
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct AttrStep {
15    id: NodeId,
16    values: ImHashMap<String, Value>,
17}
18
19impl AttrStep {
20    pub fn new(
21        id: String,
22        values: ImHashMap<String, Value>,
23    ) -> Self {
24        AttrStep { id, values }
25    }
26}
27
28impl Step for AttrStep {
29    fn name(&self) -> String {
30        "attr_step".to_string()
31    }
32
33    fn apply(
34        &self,
35        dart: &mut Tree,
36        schema: Arc<Schema>,
37    ) -> TransformResult<StepResult> {
38        let _ = schema;
39        match dart.get_node(&self.id) {
40            Some(node) => {
41                let attr = &schema.nodes.get(&node.r#type).unwrap().attrs;
42                // 删除 self.values 中 attr中没有定义的属性
43                let mut new_values = self.values.clone();
44                for (key, _) in self.values.iter() {
45                    if !attr.contains_key(key) {
46                        new_values.remove(key);
47                    }
48                }
49                let result = dart.attrs(&self.id) + new_values;
50                match result {
51                    Ok(_) => Ok(StepResult::ok()),
52                    Err(e) => Err(transform_error(e.to_string())),
53                }
54            },
55            None => {
56                return Err(transform_error("节点不存在".to_string()));
57            },
58        }
59    }
60
61    fn serialize(&self) -> Option<Vec<u8>> {
62        serde_json::to_vec(self).ok()
63    }
64
65    fn invert(
66        &self,
67        dart: &Arc<Tree>,
68    ) -> Option<Arc<dyn Step>> {
69        match dart.get_node(&self.id) {
70            Some(node) => {
71                let mut new_values = im::hashmap!();
72                for (key, value) in node.attrs.attrs.iter() {
73                    new_values.insert(key.clone(), value.clone());
74                }
75                Some(Arc::new(AttrStep::new(self.id.clone(), new_values)))
76            },
77            None => {
78                return None;
79            },
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use moduforge_model::node::Node;
    use moduforge_model::attrs::Attrs;
    use moduforge_model::node_type::NodeSpec;
    use moduforge_model::schema::{SchemaSpec, AttributeSpec};
    use std::collections::HashMap;
    use std::sync::Arc;
    use serde_json::json;

    /// Builds a node of type `"test"` with default attributes and the given id.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    /// Compiles a schema with a single `"test"` node type that declares the
    /// `name` and `age` attributes.
    fn create_test_schema() -> Arc<Schema> {
        let attrs: HashMap<_, _> = vec![
            ("name".to_string(), AttributeSpec { default: None }),
            ("age".to_string(), AttributeSpec { default: None }),
        ]
        .into_iter()
        .collect();

        let test_spec = NodeSpec {
            content: None,
            marks: None,
            group: None,
            desc: Some("Test node".to_string()),
            attrs: Some(attrs),
        };

        let mut nodes = HashMap::new();
        nodes.insert("test".to_string(), test_spec);

        let spec = SchemaSpec {
            nodes,
            marks: HashMap::new(),
            top_node: Some("test".to_string()),
        };

        Arc::new(Schema::compile(spec).unwrap())
    }

    /// Turns `(name, value)` pairs into the std map the tests convert into
    /// the step's value map.
    fn attr_values(pairs: Vec<(&str, Value)>) -> HashMap<String, Value> {
        pairs
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect()
    }

    #[test]
    fn test_attr_step_creation() {
        let values =
            attr_values(vec![("name", json!("test")), ("age", json!(25))]);

        let step = AttrStep::new("node1".to_string(), values.clone().into());
        assert_eq!(step.id, "node1");
        assert_eq!(step.values, values.into());
    }

    #[test]
    fn test_attr_step_apply() {
        // Build the tree and schema under test.
        let mut tree = Tree::new(create_test_node("node1"));
        let schema = create_test_schema();

        // Build the attribute step.
        let values =
            attr_values(vec![("name", json!("test")), ("age", json!(25))]);
        let step = AttrStep::new("node1".to_string(), values.into());

        // Apply it.
        let outcome = step.apply(&mut tree, schema.clone());
        assert!(outcome.is_ok());

        // Both attributes must now be present on the node.
        let updated = tree.get_node(&"node1".to_string()).unwrap();
        assert_eq!(updated.attrs.get("name").unwrap(), &json!("test"));
        assert_eq!(updated.attrs.get("age").unwrap(), &json!(25));
    }

    #[test]
    fn test_attr_step_apply_invalid_attrs() {
        // Build the tree and schema under test.
        let mut tree = Tree::new(create_test_node("node1"));
        let schema = create_test_schema();

        // A step carrying an attribute the schema does not declare.
        let values = attr_values(vec![("invalid_attr", json!("test"))]);
        let step = AttrStep::new("node1".to_string(), values.into());

        // Applying still succeeds...
        let outcome = step.apply(&mut tree, schema.clone());
        assert!(outcome.is_ok());

        // ...but the undeclared attribute has been filtered out.
        let updated = tree.get_node(&"node1".to_string()).unwrap();
        assert!(updated.attrs.get("invalid_attr").is_none());
    }

    #[test]
    fn test_attr_step_apply_nonexistent_node() {
        // The tree does not contain the node the step targets.
        let root: Node = create_test_node("root");
        let mut tree = Tree::new(root);
        let schema = create_test_schema();

        let values = attr_values(vec![("name", json!("test"))]);
        let step = AttrStep::new("nonexistent".to_string(), values.into());

        // Applying must report an error.
        let outcome = step.apply(&mut tree, schema);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_attr_step_serialize() {
        let values = attr_values(vec![("name", json!("test"))]);
        let step = AttrStep::new("node1".to_string(), values.into());

        let bytes = Step::serialize(&step);
        assert!(bytes.is_some());

        // The serialized bytes must round-trip through deserialization.
        let restored: AttrStep =
            serde_json::from_slice(&bytes.unwrap()).unwrap();
        assert_eq!(restored.id, "node1");
        assert_eq!(restored.values.get("name").unwrap(), &json!("test"));
    }

    #[test]
    fn test_attr_step_invert() {
        // Build the tree and schema under test.
        let mut tree = Tree::new(create_test_node("node1"));
        let schema = create_test_schema();

        // Seed the node with its initial attributes.
        let initial = attr_values(vec![
            ("name", json!("original_name")),
            ("age", json!(25)),
        ]);
        let seed_step = AttrStep::new("node1".to_string(), initial.into());
        seed_step.apply(&mut tree, schema.clone()).unwrap();

        // A second step that modifies those attributes.
        let modified = attr_values(vec![
            ("name", json!("modified_name")),
            ("age", json!(30)),
        ]);
        let new_step = AttrStep::new("node1".to_string(), modified.into());

        // Capture the inverse before the modification is applied.
        let inverted = new_step.invert(&Arc::new(tree.clone()));
        assert!(inverted.is_some());

        // Apply the modification.
        new_step.apply(&mut tree, schema.clone()).unwrap();
        let node = tree.get_node(&"node1".to_string()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("modified_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(30));

        // Apply the inverse.
        let inverted_step = inverted.unwrap();
        inverted_step.apply(&mut tree, schema).unwrap();

        // The attributes must be back at their original values.
        let node = tree.get_node(&"node1".to_string()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("original_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(25));
    }
}