1use std::sync::Arc;
2
3use crate::{transform_error, TransformResult};
4
5use super::{
6 step::{Step, StepResult},
7};
8use imbl::HashMap as ImHashMap;
9use mf_model::{schema::Schema, tree::Tree, types::NodeId};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value};
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct AttrStep {
15 pub id: NodeId,
16 pub values: ImHashMap<String, Value>,
17}
18
19impl AttrStep {
20 pub fn new(
21 id: NodeId,
22 values: ImHashMap<String, Value>,
23 ) -> Self {
24 AttrStep { id, values }
25 }
26}
27
28impl Step for AttrStep {
29 fn name(&self) -> String {
30 "attr_step".to_string()
31 }
32
33 fn apply(
34 &self,
35 dart: &mut Tree,
36 schema: Arc<Schema>,
37 ) -> TransformResult<StepResult> {
38 let _ = schema;
39 match dart.get_node(&self.id) {
40 Some(node) => {
41 let node_type = match schema.nodes.get(&node.r#type) {
43 Some(nt) => nt,
44 None => {
45 return Err(transform_error(format!(
46 "未知的节点类型: {}",
47 node.r#type
48 )));
49 },
50 };
51 let attr = &node_type.attrs;
52 let mut new_values = self.values.clone();
54 for (key, _) in self.values.iter() {
55 if !attr.contains_key(key) {
56 new_values.remove(key);
57 }
58 }
59 let result = dart.attrs(&self.id) + new_values;
60 match result {
61 Ok(_) => Ok(StepResult::ok()),
62 Err(e) => Err(transform_error(e.to_string())),
63 }
64 },
65 None => {
66 return Err(transform_error("节点不存在".to_string()));
67 },
68 }
69 }
70
71 fn serialize(&self) -> Option<Vec<u8>> {
72 serde_json::to_vec(self).ok()
73 }
74
75 fn invert(
76 &self,
77 dart: &Arc<Tree>,
78 ) -> Option<Arc<dyn Step>> {
79 match dart.get_node(&self.id) {
80 Some(node) => {
81 let mut revert_values = imbl::hashmap!();
83 for (changed_key, _) in self.values.iter() {
84 if let Some(old_val) = node.attrs.get_safe(changed_key) {
85 revert_values
86 .insert(changed_key.clone(), old_val.clone());
87 }
88 }
91 if revert_values.is_empty() {
92 None
93 } else {
94 Some(Arc::new(AttrStep::new(
95 self.id.clone(),
96 revert_values,
97 )))
98 }
99 },
100 None => None,
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use mf_model::node::Node;
    use mf_model::attrs::Attrs;
    use mf_model::node_type::NodeSpec;
    use mf_model::schema::{SchemaSpec, AttributeSpec};
    use std::collections::HashMap;
    use std::sync::Arc;
    use serde_json::json;

    /// Builds a childless, mark-less node of type `"test"` with default attrs.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    /// Compiles a schema with a single `"test"` node type that declares the
    /// attributes `name` and `age` (neither has a default).
    fn create_test_schema() -> Arc<Schema> {
        let mut nodes = HashMap::new();
        let mut attrs = HashMap::new();
        attrs.insert("name".to_string(), AttributeSpec { default: None });
        attrs.insert("age".to_string(), AttributeSpec { default: None });

        nodes.insert(
            "test".to_string(),
            NodeSpec {
                content: None,
                marks: None,
                group: None,
                desc: Some("Test node".to_string()),
                attrs: Some(attrs),
            },
        );

        let spec = SchemaSpec {
            nodes,
            marks: HashMap::new(),
            top_node: Some("test".to_string()),
        };

        Arc::new(Schema::compile(spec).unwrap())
    }

    #[test]
    fn test_attr_step_creation() {
        let mut values = imbl::HashMap::new();
        values.insert("name".to_string(), json!("test"));
        values.insert("age".to_string(), json!(25));

        // `values` is already an `imbl::HashMap`; no conversion is needed.
        let step = AttrStep::new("node1".into(), values.clone());
        assert_eq!(step.id, "node1".into());
        assert_eq!(step.values, values);
    }

    #[test]
    fn test_attr_step_apply() {
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        values.insert("age".to_string(), json!(25));
        let step = AttrStep::new("node1".into(), values.into());

        let result = step.apply(&mut tree, schema.clone());
        assert!(result.is_ok());

        let updated_node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(updated_node.attrs.get("name").unwrap(), &json!("test"));
        assert_eq!(updated_node.attrs.get("age").unwrap(), &json!(25));
    }

    /// Attributes the schema does not declare are silently dropped.
    #[test]
    fn test_attr_step_apply_invalid_attrs() {
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("invalid_attr".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        let result = step.apply(&mut tree, schema.clone());
        assert!(result.is_ok());

        let updated_node = tree.get_node(&"node1".into()).unwrap();
        assert!(updated_node.attrs.get("invalid_attr").is_none());
    }

    #[test]
    fn test_attr_step_apply_nonexistent_node() {
        let node: Node = create_test_node("root");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("nonexistent".into(), values.into());

        let result = step.apply(&mut tree, schema);
        assert!(result.is_err());
    }

    #[test]
    fn test_attr_step_serialize() {
        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("node1".into(), values.into());

        let serialized = Step::serialize(&step);
        assert!(serialized.is_some());

        let deserialized: AttrStep =
            serde_json::from_slice(&serialized.unwrap()).unwrap();
        assert_eq!(deserialized.id, "node1".into());
        assert_eq!(deserialized.values.get("name").unwrap(), &json!("test"));
    }

    /// Applying the inverse of a step restores the pre-step attribute values.
    #[test]
    fn test_attr_step_invert() {
        let node = create_test_node("node1");
        let mut tree = Tree::new(node);

        let schema = create_test_schema();

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("original_name"));
        values.insert("age".to_string(), json!(25));
        let step = AttrStep::new("node1".into(), values.into());
        step.apply(&mut tree, schema.clone()).unwrap();

        let mut new_values = HashMap::new();
        new_values.insert("name".to_string(), json!("modified_name"));
        new_values.insert("age".to_string(), json!(30));
        let new_step = AttrStep::new("node1".into(), new_values.into());

        // The inverse must be computed against the tree *before* applying.
        let inverted = new_step.invert(&Arc::new(tree.clone()));
        assert!(inverted.is_some());

        new_step.apply(&mut tree, schema.clone()).unwrap();
        let node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("modified_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(30));

        let inverted_step = inverted.unwrap();
        inverted_step.apply(&mut tree, schema).unwrap();

        let node = tree.get_node(&"node1".into()).unwrap();
        assert_eq!(node.attrs.get("name").unwrap(), &json!("original_name"));
        assert_eq!(node.attrs.get("age").unwrap(), &json!(25));
    }

    /// Inverting a step whose target node is absent yields no inverse.
    #[test]
    fn test_attr_step_invert_nonexistent_node() {
        let tree = Tree::new(create_test_node("root"));

        let mut values = HashMap::new();
        values.insert("name".to_string(), json!("test"));
        let step = AttrStep::new("nonexistent".into(), values.into());

        assert!(step.invert(&Arc::new(tree)).is_none());
    }
}