moduforge_transform/
attr_step.rs

1use std::sync::Arc;
2
3use crate::{transform_error, TransformResult};
4
5use super::{
6    step::{Step, StepResult},
7};
8use im::HashMap as ImHashMap;
9use moduforge_model::{schema::Schema, tree::Tree, types::NodeId};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value};
12
/// A transform step that merges a set of attribute values into one node.
///
/// The step is serialized with serde (`Step::serialize` uses
/// `serde_json::to_vec`), so the field names below appear in the serialized
/// JSON.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AttrStep {
    /// Id of the node whose attributes are updated.
    id: NodeId,
    /// Attribute name -> new value. When the step is applied, keys that the
    /// node's type does not declare in the schema are dropped.
    values: ImHashMap<String, Value>,
}
18
19impl AttrStep {
20    pub fn new(
21        id: String,
22        values: ImHashMap<String, Value>,
23    ) -> Self {
24        AttrStep { id, values }
25    }
26}
27
28impl Step for AttrStep {
29    fn name(&self) -> String {
30        "attr_step".to_string()
31    }
32
33    fn apply(
34        &self,
35        dart: &mut Tree,
36        schema: Arc<Schema>,
37    ) -> TransformResult<StepResult> {
38        let _ = schema;
39        match dart.get_node(&self.id) {
40            Some(node) => {
41                let attr = &schema.nodes.get(&node.r#type).unwrap().attrs;
42                // 删除 self.values 中 attr中没有定义的属性
43                let mut new_values = self.values.clone();
44                for (key, _) in self.values.iter() {
45                    if !attr.contains_key(key) {
46                        new_values.remove(key);
47                    }
48                }
49                let result = dart.attrs(&self.id) + new_values;
50                match result {
51                    Ok(_) => Ok(StepResult::ok()),
52                    Err(e) => Err(transform_error(e.to_string())),
53                }
54            },
55            None => {
56                return Err(transform_error("节点不存在".to_string()));
57            },
58        }
59    }
60
61    fn serialize(&self) -> Option<Vec<u8>> {
62        serde_json::to_vec(self).ok()
63    }
64
65    fn invert(
66        &self,
67        dart: &Arc<Tree>,
68    ) -> Option<Arc<dyn Step>> {
69        match dart.get_node(&self.id) {
70            Some(node) => {
71                let mut new_values = im::hashmap!();
72                for (key, value) in node.attrs.attrs.iter() {
73                    new_values.insert(key.clone(), value.clone());
74                }
75                Some(Arc::new(AttrStep::new(self.id.clone(), new_values)))
76            },
77            None => {
78                return None;
79            },
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use moduforge_model::attrs::Attrs;
    use moduforge_model::node::Node;
    use moduforge_model::node_type::NodeSpec;
    use moduforge_model::schema::{AttributeSpec, SchemaSpec};
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds an attribute map from `(name, value)` pairs.
    fn attr_map(pairs: &[(&str, Value)]) -> ImHashMap<String, Value> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.clone()))
            .collect()
    }

    /// Creates a childless node of type `"test"` with default attributes.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    /// Compiles a schema whose only node type, `"test"`, declares the
    /// attributes `name` and `age`.
    fn create_test_schema() -> Arc<Schema> {
        let attrs: HashMap<_, _> = ["name", "age"]
            .iter()
            .map(|key| (key.to_string(), AttributeSpec { default: None }))
            .collect();

        let test_spec = NodeSpec {
            content: None,
            marks: None,
            group: None,
            desc: Some("Test node".to_string()),
            attrs: Some(attrs),
        };

        let mut nodes = HashMap::new();
        nodes.insert("test".to_string(), test_spec);

        let spec = SchemaSpec {
            nodes,
            marks: HashMap::new(),
            top_node: Some("test".to_string()),
        };
        Arc::new(Schema::compile(spec).unwrap())
    }

    #[test]
    fn test_attr_step_creation() {
        let values = attr_map(&[("name", json!("test")), ("age", json!(25))]);

        let step = AttrStep::new("node1".to_string(), values.clone());
        assert_eq!(step.id, "node1");
        assert_eq!(step.values, values);
    }

    #[test]
    fn test_attr_step_apply() {
        let mut tree = Tree::new(create_test_node("node1"));
        let schema = create_test_schema();
        let step = AttrStep::new(
            "node1".to_string(),
            attr_map(&[("name", json!("test")), ("age", json!(25))]),
        );

        assert!(step.apply(&mut tree, schema.clone()).is_ok());

        // Both attributes are declared by the schema, so both must be set.
        let updated = tree.get_node(&"node1".to_string()).unwrap();
        assert_eq!(updated.attrs.get("name").unwrap(), &json!("test"));
        assert_eq!(updated.attrs.get("age").unwrap(), &json!(25));
    }

    #[test]
    fn test_attr_step_apply_invalid_attrs() {
        let mut tree = Tree::new(create_test_node("node1"));
        let schema = create_test_schema();
        // `invalid_attr` is not declared by the "test" node type.
        let step = AttrStep::new(
            "node1".to_string(),
            attr_map(&[("invalid_attr", json!("test"))]),
        );

        // Undeclared attributes are dropped rather than rejected.
        assert!(step.apply(&mut tree, schema.clone()).is_ok());

        let updated = tree.get_node(&"node1".to_string()).unwrap();
        assert!(updated.attrs.get("invalid_attr").is_none());
    }

    #[test]
    fn test_attr_step_apply_nonexistent_node() {
        // The tree only contains "root", so "nonexistent" cannot be found.
        let mut tree = Tree::new(create_test_node("root"));
        let schema = create_test_schema();
        let step = AttrStep::new(
            "nonexistent".to_string(),
            attr_map(&[("name", json!("test"))]),
        );

        assert!(step.apply(&mut tree, schema).is_err());
    }

    #[test]
    fn test_attr_step_serialize() {
        let step =
            AttrStep::new("node1".to_string(), attr_map(&[("name", json!("test"))]));

        let bytes = Step::serialize(&step).expect("AttrStep should serialize");

        // The bytes must round-trip back into an equivalent step.
        let decoded: AttrStep = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(decoded.id, "node1");
        assert_eq!(decoded.values.get("name").unwrap(), &json!("test"));
    }

    #[test]
    fn test_attr_step_invert() {
        let mut tree = Tree::new(create_test_node("node1"));
        let schema = create_test_schema();
        let id = "node1".to_string();

        // Seed the node with its original attribute values.
        AttrStep::new(
            id.clone(),
            attr_map(&[("name", json!("original_name")), ("age", json!(25))]),
        )
        .apply(&mut tree, schema.clone())
        .unwrap();

        let modify = AttrStep::new(
            id.clone(),
            attr_map(&[("name", json!("modified_name")), ("age", json!(30))]),
        );

        // The inverse must be computed from the state *before* `modify` runs.
        let inverse = modify
            .invert(&Arc::new(tree.clone()))
            .expect("node1 exists, so an inverse step must be produced");

        modify.apply(&mut tree, schema.clone()).unwrap();
        let node = tree.get_node(&id).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("modified_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(30));

        // Undo, then check that the original values are back.
        inverse.apply(&mut tree, schema).unwrap();
        let node = tree.get_node(&id).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("original_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(25));
    }
}
263