mf_transform/
mark_step.rs

1use std::{sync::Arc};
2
3use mf_model::{mark::Mark, schema::Schema, tree::Tree, types::NodeId};
4
5use crate::{transform_error, TransformResult};
6
7use super::{
8    step::{Step, StepResult},
9};
10use serde::{Deserialize, Serialize};
11#[derive(Debug, Serialize, Deserialize, Clone)]
12pub struct AddMarkStep {
13    pub id: NodeId,
14    pub marks: Vec<Mark>,
15}
16impl AddMarkStep {
17    pub fn new(
18        id: NodeId,
19        marks: Vec<Mark>,
20    ) -> Self {
21        AddMarkStep { id, marks }
22    }
23}
24impl Step for AddMarkStep {
25    fn name(&self) -> String {
26        "add_mark_step".to_string()
27    }
28    fn apply(
29        &self,
30        dart: &mut Tree,
31        schema: Arc<Schema>,
32    ) -> TransformResult<StepResult> {
33        let _ = schema;
34        let result = dart.mark(&self.id) + self.marks.clone();
35        match result {
36            Ok(_) => Ok(StepResult::ok()),
37            Err(e) => Err(transform_error(e.to_string())),
38        }
39    }
40    fn serialize(&self) -> Option<Vec<u8>> {
41        serde_json::to_vec(self).ok()
42    }
43
44    fn invert(
45        &self,
46        dart: &Arc<Tree>,
47    ) -> Option<Arc<dyn Step>> {
48        match dart.get_node(&self.id) {
49            Some(_) => Some(Arc::new(RemoveMarkStep::new(
50                self.id.clone(),
51                self.marks.clone().iter().map(|m| m.r#type.clone()).collect(),
52            ))),
53            None => None,
54        }
55    }
56}
57
58#[derive(Debug, Serialize, Deserialize, Clone)]
59pub struct RemoveMarkStep {
60    pub id: NodeId,
61    pub mark_types: Vec<String>,
62}
63impl RemoveMarkStep {
64    pub fn new(
65        id: NodeId,
66        mark_types: Vec<String>,
67    ) -> Self {
68        RemoveMarkStep { id, mark_types }
69    }
70}
71impl Step for RemoveMarkStep {
72    fn name(&self) -> String {
73        "remove_mark_step".to_string()
74    }
75    fn apply(
76        &self,
77        dart: &mut Tree,
78        schema: Arc<Schema>,
79    ) -> TransformResult<StepResult> {
80        let _ = schema;
81        let result = dart.mark(&self.id) - self.mark_types.clone();
82        match result {
83            Ok(_) => Ok(StepResult::ok()),
84            Err(e) => Err(transform_error(e.to_string())),
85        }
86    }
87    fn serialize(&self) -> Option<Vec<u8>> {
88        serde_json::to_vec(self).ok()
89    }
90
91    fn invert(
92        &self,
93        dart: &Arc<Tree>,
94    ) -> Option<Arc<dyn Step>> {
95        match dart.get_node(&self.id) {
96            Some(node) => {
97                // 仅恢复被移除的 mark 类型,避免把未移除的也加回
98                let removed_types = &self.mark_types;
99                let to_restore: Vec<Mark> = node
100                    .marks
101                    .iter()
102                    .filter(|m| removed_types.contains(&m.r#type))
103                    .cloned()
104                    .collect();
105                if to_restore.is_empty() {
106                    None
107                } else {
108                    Some(Arc::new(AddMarkStep::new(
109                        self.id.clone(),
110                        to_restore,
111                    )))
112                }
113            },
114            None => None,
115        }
116    }
117}