1use std::sync::Arc;
2
3use crate::{transform_error, TransformResult};
4
5use super::{
6    step::{Step, StepResult},
7};
8use imbl::HashMap as ImHashMap;
9use mf_model::{schema::Schema, tree::Tree, types::NodeId};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value};
12#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct AttrStep {
15    pub id: NodeId,
16    pub values: ImHashMap<String, Value>,
17}
18
19impl AttrStep {
20    pub fn new(
21        id: NodeId,
22        values: ImHashMap<String, Value>,
23    ) -> Self {
24        AttrStep { id, values }
25    }
26}
27
28impl Step for AttrStep {
29    fn name(&self) -> String {
30        "attr_step".to_string()
31    }
32
33    fn apply(
34        &self,
35        dart: &mut Tree,
36        schema: Arc<Schema>,
37    ) -> TransformResult<StepResult> {
38        let _ = schema;
39        match dart.get_node(&self.id) {
40            Some(node) => {
41                let node_type = match schema.nodes.get(&node.r#type) {
43                    Some(nt) => nt,
44                    None => {
45                        return Err(transform_error(format!(
46                            "未知的节点类型: {}",
47                            node.r#type
48                        )));
49                    },
50                };
51                let attr = &node_type.attrs;
52                let mut new_values = self.values.clone();
54                for (key, _) in self.values.iter() {
55                    if !attr.contains_key(key) {
56                        new_values.remove(key);
57                    }
58                }
59                let result = dart.attrs(&self.id) + new_values;
60                match result {
61                    Ok(_) => Ok(StepResult::ok()),
62                    Err(e) => Err(transform_error(e.to_string())),
63                }
64            },
65            None => Err(transform_error("节点不存在".to_string())),
66        }
67    }
68
69    fn serialize(&self) -> Option<Vec<u8>> {
70        serde_json::to_vec(self).ok()
71    }
72
73    fn invert(
74        &self,
75        dart: &Arc<Tree>,
76    ) -> Option<Arc<dyn Step>> {
77        match dart.get_node(&self.id) {
78            Some(node) => {
79                let mut revert_values = imbl::hashmap!();
81                for (changed_key, _) in self.values.iter() {
82                    if let Some(old_val) = node.attrs.get_safe(changed_key) {
83                        revert_values
84                            .insert(changed_key.clone(), old_val.clone());
85                    }
86                    }
89                if revert_values.is_empty() {
90                    None
91                } else {
92                    Some(Arc::new(AttrStep::new(
93                        self.id.clone(),
94                        revert_values,
95                    )))
96                }
97            },
98            None => None,
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use mf_model::node::Node;
    use mf_model::attrs::Attrs;
    use mf_model::node_definition::NodeSpec;
    use mf_model::schema::{SchemaSpec, AttributeSpec};
    use std::collections::HashMap;
    use std::sync::Arc;
    use serde_json::json;

    /// Builds an attribute-less node with the given id and node type name.
    fn create_node_of_type(id: &str, node_type: &str) -> Node {
        Node::new(id, node_type.to_string(), Attrs::default(), vec![], vec![])
    }

    /// Builds an attribute-less node of the schema's `test` type.
    fn create_test_node(id: &str) -> Node {
        create_node_of_type(id, "test")
    }

    /// Schema with a single `test` node type declaring `name` and `age`.
    fn create_test_schema() -> Arc<Schema> {
        let mut nodes = HashMap::new();
        let mut attrs = HashMap::new();
        attrs.insert("name".to_string(), AttributeSpec { default: None });
        attrs.insert("age".to_string(), AttributeSpec { default: None });

        nodes.insert(
            "test".to_string(),
            NodeSpec {
                content: None,
                marks: None,
                group: None,
                desc: Some("Test node".to_string()),
                attrs: Some(attrs),
            },
        );

        let spec = SchemaSpec {
            nodes,
            marks: HashMap::new(),
            top_node: Some("test".to_string()),
        };

        Arc::new(Schema::compile(spec).expect("测试 Schema 编译失败"))
    }

    #[test]
    fn test_attr_step_creation() {
        let mut values = imbl::HashMap::new();
        values.insert("name".to_string(), json!("test"));
        values.insert("age".to_string(), json!(25));

        let step = AttrStep::new("node1".into(), values.clone());
        assert_eq!(step.id, "node1".into());
        assert_eq!(step.values, values);
    }

    #[test]
    fn test_attr_step_apply() {
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        values.insert("age".to_string(), json!(25));
        let step = AttrStep::new("node1".into(), values.into());

        let result = step.apply(&mut tree, schema.clone());
        assert!(result.is_ok());

        let updated_node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(updated_node.attrs.get("name").unwrap(), &json!("test"));
        assert_eq!(updated_node.attrs.get("age").unwrap(), &json!(25));
    }

    #[test]
    fn test_attr_step_apply_invalid_attrs() {
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("invalid_attr".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        let result = step.apply(&mut tree, schema.clone());
        assert!(result.is_ok());

        let updated_node = tree.get_node(&"node1".into()).unwrap();
        assert!(updated_node.attrs.get("invalid_attr").is_none());
    }

    #[test]
    fn test_attr_step_apply_nonexistent_node() {
        let node: Node = create_test_node("root");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("nonexistent".into(), values.into());

        let result = step.apply(&mut tree, schema);
        assert!(result.is_err());
    }

    #[test]
    fn test_attr_step_apply_unknown_node_type() {
        // The node exists in the tree, but its type is not registered in the schema.
        let node = create_node_of_type("node1", "unknown");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        let result = step.apply(&mut tree, schema);
        assert!(result.is_err());
    }

    #[test]
    fn test_attr_step_serialize() {
        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        let serialized = Step::serialize(&step);
        assert!(serialized.is_some());

        let deserialized: AttrStep =
            serde_json::from_slice(&serialized.unwrap()).unwrap();
        assert_eq!(deserialized.id, "node1".into());
        assert_eq!(deserialized.values.get("name").unwrap(), &json!("test"));
    }

    #[test]
    fn test_attr_step_invert() {
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("original_name"));
        values.insert("age".to_string(), json!(25));
        let step = AttrStep::new("node1".into(), values.into());
        step.apply(&mut tree, schema.clone()).unwrap();

        let mut new_values = HashMap::new();
        new_values.insert("name".to_string(), json!("modified_name"));
        new_values.insert("age".to_string(), json!(30));
        let new_step = AttrStep::new("node1".into(), new_values.into());

        // The inverse must be captured before the step is applied.
        let inverted = new_step.invert(&Arc::new(tree.clone()));
        assert!(inverted.is_some());

        new_step.apply(&mut tree, schema.clone()).unwrap();
        let node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("modified_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(30));

        let inverted_step = inverted.unwrap();
        inverted_step.apply(&mut tree, schema).unwrap();

        let node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("original_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(25));
    }

    #[test]
    fn test_attr_step_invert_nonexistent_node() {
        let tree = Tree::new(create_test_node("root"));

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("nonexistent".into(), values.into());

        assert!(step.invert(&Arc::new(tree)).is_none());
    }

    #[test]
    fn test_attr_step_invert_without_previous_values() {
        // The fresh test node has no attributes, so there is nothing to restore.
        let tree = Tree::new(create_test_node("node1"));

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        assert!(step.invert(&Arc::new(tree)).is_none());
    }
}