mf_transform/
attr_step.rs1use std::sync::Arc;
2
3use crate::{transform_error, TransformResult};
4
5use super::{
6 step::{Step, StepResult},
7};
8
9use mf_model::{schema::Schema, tree::Tree, types::NodeId};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value};
12use mf_model::rpds::HashTrieMapSync;
13
14#[derive(Debug, Serialize, Deserialize, Clone)]
16pub struct AttrStep {
17 pub id: NodeId,
18 pub values: HashTrieMapSync<String, Value>,
19}
20
21impl AttrStep {
22 pub fn new(
23 id: NodeId,
24 values: HashTrieMapSync<String, Value>,
25 ) -> Self {
26 AttrStep { id, values }
27 }
28}
29
30impl Step for AttrStep {
31 fn name(&self) -> String {
32 "attr_step".to_string()
33 }
34
35 fn apply(
36 &self,
37 dart: &mut Tree,
38 schema: Arc<Schema>,
39 ) -> TransformResult<StepResult> {
40 let factory = schema.factory();
41 match dart.get_node(&self.id) {
42 Some(node) => {
43 let node_type = match factory.node_definition(&node.r#type) {
45 Some(nt) => nt,
46 None => {
47 return Err(transform_error(format!(
48 "未知的节点类型: {}",
49 node.r#type
50 )));
51 },
52 };
53 let attr = &node_type.attrs;
54 let mut new_values = self.values.clone();
56 for (key, _) in self.values.iter() {
57 if !attr.contains_key(key) {
58 new_values.remove_mut(key);
59 }
60 }
61 let result = dart.attrs(&self.id) + new_values;
62 match result {
63 Ok(_) => Ok(StepResult::ok()),
64 Err(e) => Err(transform_error(e.to_string())),
65 }
66 },
67 None => Err(transform_error("节点不存在".to_string())),
68 }
69 }
70
71 fn serialize(&self) -> Option<Vec<u8>> {
72 serde_json::to_vec(self).ok()
73 }
74
75 fn invert(
76 &self,
77 dart: &Arc<Tree>,
78 ) -> Option<Arc<dyn Step>> {
79 match dart.get_node(&self.id) {
80 Some(node) => {
81 let mut revert_values = HashTrieMapSync::new_sync();
83 for (changed_key, _) in self.values.iter() {
84 if let Some(old_val) = node.attrs.get_safe(changed_key) {
85 revert_values
86 .insert_mut(changed_key.clone(), old_val.clone());
87 }
88 }
91 if revert_values.is_empty() {
92 None
93 } else {
94 Some(Arc::new(AttrStep::new(
95 self.id.clone(),
96 revert_values,
97 )))
98 }
99 },
100 None => None,
101 }
102 }
103}