moduforge_transform/
draft.rs

1use std::sync::Arc;
2use im::HashMap;
3use serde_json::Value;
4use moduforge_model::{
5    error::PoolError,
6    mark::Mark,
7    node::Node,
8    node_pool::{NodePool, NodePoolInner},
9    types::NodeId,
10};
11
12use crate::step::StepResult;
13
14use super::patch::Patch;
15
16/// 草稿修改上下文,用于安全地修改节点池
17///
18/// 跟踪以下信息:
19///
20/// * 基础版本节点池
21/// * 当前修改的中间状态
22/// * 生成的修改补丁
23/// * 当前操作路径(用于嵌套数据结构)
24#[derive(Debug, Clone)]
25pub struct Draft {
26    pub base: Arc<NodePool>,
27    pub inner: NodePoolInner,
28    pub patches: im::Vector<Patch>,
29    pub current_path: im::Vector<String>,
30    pub skip_record: bool,
31    pub begin: bool,
32}
33
34impl Draft {
35    /// 创建基于现有节点池的草稿
36    ///
37    /// # 参数
38    ///
39    /// * `base` - 基础版本节点池的引用
40    pub fn new(base: Arc<NodePool>) -> Self {
41        Draft {
42            inner: base.inner.as_ref().clone(),
43            base,
44            patches: im::Vector::new(),
45            current_path: im::Vector::new(),
46            skip_record: false,
47            begin: false,
48        }
49    }
50
51    /// 进入嵌套路径(用于记录结构化修改)
52    ///
53    /// # 参数
54    ///
55    /// * `key` - Map 类型的字段名称
56    ///
57    /// # 示例
58    ///
59    /// ```
60    /// draft.enter_map("content").enter_list(0);
61    /// ```
62    pub fn enter_map(
63        &mut self,
64        key: &str,
65    ) -> &mut Self {
66        self.current_path.push_back(key.to_string());
67        self
68    }
69
70    /// 进入嵌套路径(List类型索引)
71    pub fn enter_list(
72        &mut self,
73        index: usize,
74    ) -> &mut Self {
75        self.current_path.push_back(index.to_string());
76        self
77    }
78
79    /// 退出当前路径层级
80    pub fn exit(&mut self) -> &mut Self {
81        if !self.current_path.is_empty() {
82            self.current_path =
83                self.current_path.take(self.current_path.len() - 1);
84        }
85        self
86    }
87
88    /// 提交属性修改并记录补丁
89    ///
90    /// # 参数
91    ///
92    /// * `id` - 目标节点ID
93    /// * `new_values` - 新属性集合
94    ///
95    /// # 错误
96    ///
97    /// 当节点不存在时返回 [`PoolError::NodeNotFound`]
98    pub fn update_attr(
99        &mut self,
100        id: &NodeId,
101        new_values: HashMap<String, Value>,
102    ) -> Result<(), PoolError> {
103        let node =
104            self.get_node(id).ok_or(PoolError::NodeNotFound(id.clone()))?;
105        let old_values = node.attrs.clone();
106
107        // 更新节点属性
108        let mut new_node = node.as_ref().clone();
109        new_node.attrs = new_values.clone();
110        self.inner.nodes =
111            self.inner.nodes.update(id.clone(), Arc::new(new_node));
112        // 记录补丁
113        self.record_patch(Patch::UpdateAttr {
114            path: self.current_path.iter().cloned().collect(),
115            id: id.clone(),
116            old: old_values.into_iter().collect(),
117            new: new_values.into_iter().collect(),
118        });
119        Ok(())
120    }
121    /// 从节点中移除指定标记
122    ///
123    /// # 参数
124    ///
125    /// * `id` - 目标节点ID
126    /// * `mark` - 要移除的标记
127    ///
128    /// # 错误
129    ///
130    /// 当节点不存在时返回 [`PoolError::NodeNotFound`]
131    pub fn remove_mark(
132        &mut self,
133        id: &NodeId,
134        mark: Mark,
135    ) -> Result<(), PoolError> {
136        let mut node = self
137            .get_node(id)
138            .ok_or(PoolError::NodeNotFound(id.clone()))?
139            .as_ref()
140            .clone();
141        node.marks =
142            node.marks.iter().filter(|&m| !m.eq(&mark)).cloned().collect();
143        self.inner.nodes.insert(id.clone(), Arc::new(node));
144        // 记录补丁
145        self.record_patch(Patch::RemoveMark {
146            path: self.current_path.iter().cloned().collect(),
147            parent_id: id.clone(),
148            marks: vec![mark],
149        });
150        Ok(())
151    }
152    /// 为节点添加标记
153    ///
154    /// # 参数
155    ///
156    /// * `id` - 目标节点ID
157    /// * `marks` - 要添加的标记列表
158    ///
159    /// # 错误
160    ///
161    /// 当节点不存在时返回 [`PoolError::NodeNotFound`]
162    pub fn add_mark(
163        &mut self,
164        id: &NodeId,
165        marks: &Vec<Mark>,
166    ) -> Result<(), PoolError> {
167        let mut node = self
168            .get_node(id)
169            .ok_or(PoolError::NodeNotFound(id.clone()))?
170            .as_ref()
171            .clone();
172        node.marks.extend(marks.clone());
173        self.inner.nodes.insert(id.clone(), Arc::new(node));
174        // 记录补丁
175        self.record_patch(Patch::AddMark {
176            path: self.current_path.iter().cloned().collect(),
177            node_id: id.clone(),
178            marks: marks.clone(),
179        });
180        Ok(())
181    }
182    /// 对节点的子节点进行排序
183    ///
184    /// # 参数
185    ///
186    /// * `parent_id` - 父节点ID
187    /// * `compare` - 排序比较函数
188    ///
189    /// # 错误
190    ///
191    /// 当父节点不存在时返回 [`PoolError::ParentNotFound`]
192    pub fn sort_children<
193        F: FnMut(
194            &(NodeId, &Arc<Node>),
195            &(NodeId, &Arc<Node>),
196        ) -> std::cmp::Ordering,
197    >(
198        &mut self,
199        parent_id: &NodeId,
200        compare: F,
201    ) -> Result<(), PoolError> {
202        // 检查父节点是否存在
203        let parent = self
204            .get_node(parent_id)
205            .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
206
207        // 获取所有子节点
208        let children_ids = parent.content.clone();
209        if children_ids.is_empty() {
210            return Ok(()); // 没有子节点,无需排序
211        }
212        let mut children: Vec<(NodeId, &Arc<Node>)> = Vec::new();
213        for child_id in &children_ids {
214            if let Some(node) = self.get_node(child_id) {
215                children.push((child_id.clone(), node));
216            }
217        }
218        children.sort_by(compare);
219        // 创建排序后的子节点ID列表
220        let sorted_children: im::Vector<NodeId> =
221            children.into_iter().map(|(id, _)| id).collect();
222        // 更新父节点
223        let mut new_parent = parent.as_ref().clone();
224        new_parent.content = sorted_children.clone();
225
226        // 记录补丁
227        self.record_patch(Patch::SortChildren {
228            path: self.current_path.iter().cloned().collect(),
229            parent_id: parent_id.clone(),
230            old_children: children_ids.iter().cloned().collect(),
231            new_children: sorted_children.iter().cloned().collect(),
232        });
233
234        self.inner.nodes.insert(parent_id.clone(), Arc::new(new_parent));
235        Ok(())
236    }
237
238    /// 添加子节点
239    ///
240    /// # 参数
241    ///
242    /// * `parent_id` - 父节点ID
243    /// * `nodes` - 要添加的子节点列表
244    pub fn add_node(
245        &mut self,
246        parent_id: &NodeId,
247        nodes: &Vec<Node>,
248    ) -> Result<(), PoolError> {
249        let parent = self
250            .get_node(parent_id)
251            .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
252        let mut new_parent = parent.as_ref().clone();
253        new_parent.content.push_back(nodes[0].id.clone());
254        self.inner.nodes.insert(parent_id.clone(), Arc::new(new_parent));
255        self.inner.parent_map.insert(nodes[0].id.clone(), parent_id.clone());
256        let mut new_nodes = vec![];
257        for node in nodes.into_iter() {
258            new_nodes.push(node.clone());
259            // 更新父节点映射
260            for child_id in &node.content {
261                self.inner.parent_map.insert(child_id.clone(), node.id.clone());
262            }
263            // 更新节点池
264            self.inner.nodes.insert(node.id.clone(), Arc::new(node.clone()));
265        }
266        // 记录补丁
267        self.record_patch(Patch::AddNode {
268            path: self.current_path.iter().cloned().collect(),
269            parent_id: parent_id.clone(),
270            nodes: new_nodes,
271        });
272        Ok(())
273    }
274
275    pub fn replace_node(
276        &mut self,
277        node_id: NodeId,
278        nodes: &Vec<Node>,
279    ) -> Result<(), PoolError> {
280        // 检查节点是否存在
281        let old_node = self
282            .get_node(&node_id)
283            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
284        // 确保新节点ID与原节点ID一致
285        if nodes[0].id != node_id {
286            return Err(PoolError::InvalidNodeId {
287                nodeid: node_id,
288                new_node_id: nodes[0].id.clone(),
289            });
290        }
291        let _ = self.remove_node(
292            &node_id,
293            old_node.content.iter().map(|id| id.clone()).collect(),
294        )?;
295        let _ = self.add_node(&node_id, nodes)?;
296        Ok(())
297    }
298    /// 移动节点
299    pub fn move_node(
300        &mut self,
301        source_parent_id: &NodeId,
302        target_parent_id: &NodeId,
303        node_id: &NodeId,
304        position: Option<usize>,
305    ) -> Result<(), PoolError> {
306        // 检查源父节点是否存在
307        let source_parent = self
308            .get_node(source_parent_id)
309            .ok_or(PoolError::ParentNotFound(source_parent_id.clone()))?;
310        // 检查目标父节点是否存在
311        let target_parent = self
312            .get_node(target_parent_id)
313            .ok_or(PoolError::ParentNotFound(target_parent_id.clone()))?;
314        // 检查要移动的节点是否存在
315        let _node = self
316            .get_node(node_id)
317            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
318        // 检查节点是否是源父节点的子节点
319        if !source_parent.content.contains(node_id) {
320            return Err(PoolError::InvalidParenting {
321                child: node_id.clone(),
322                alleged_parent: source_parent_id.clone(),
323            });
324        }
325        // 从源父节点中移除该节点
326        let mut new_source_parent = source_parent.as_ref().clone();
327        new_source_parent.content = new_source_parent
328            .content
329            .iter()
330            .filter(|&id| id != node_id)
331            .cloned()
332            .collect();
333
334        // 准备将节点添加到目标父节点
335        let mut new_target_parent = target_parent.as_ref().clone();
336        // 根据指定位置插入节点
337        if let Some(pos) = position {
338            if pos <= new_target_parent.content.len() {
339                // 在指定位置插入
340                let mut new_content = im::Vector::new();
341                for (i, child_id) in
342                    new_target_parent.content.iter().enumerate()
343                {
344                    if i == pos {
345                        new_content.push_back(node_id.clone());
346                    }
347                    new_content.push_back(child_id.clone());
348                }
349                // 如果位置是在最后,需要额外处理
350                if pos == new_target_parent.content.len() {
351                    new_content.push_back(node_id.clone());
352                }
353                new_target_parent.content = new_content;
354            } else {
355                // 如果位置超出范围,添加到末尾
356                new_target_parent.content.push_back(node_id.clone());
357            }
358        } else {
359            // 默认添加到末尾
360            new_target_parent.content.push_back(node_id.clone());
361        }
362
363        self.inner
364            .nodes
365            .insert(source_parent_id.clone(), Arc::new(new_source_parent));
366        self.inner
367            .nodes
368            .insert(target_parent_id.clone(), Arc::new(new_target_parent));
369        // 更新父子关系映射
370        self.inner.parent_map.insert(node_id.clone(), target_parent_id.clone());
371        // 记录移动节点的补丁
372        self.record_patch(Patch::MoveNode {
373            path: self.current_path.iter().cloned().collect(),
374            node_id: node_id.clone(),
375            source_parent_id: source_parent_id.clone(),
376            target_parent_id: target_parent_id.clone(),
377            position,
378        });
379        Ok(())
380    }
381    pub fn get_node(
382        &self,
383        id: &NodeId,
384    ) -> Option<&Arc<Node>> {
385        self.inner.nodes.get(id)
386    }
387    pub fn children(
388        &self,
389        parent_id: &NodeId,
390    ) -> Option<&im::Vector<NodeId>> {
391        self.get_node(parent_id).map(|n| &n.content)
392    }
393    /// 移除子节点    
394    ///
395    /// # 参数
396    ///
397    /// * `parent_id` - 父节点ID
398    /// * `nodes` - 要移除的子节点ID列表
399    ///
400    /// # 错误
401    ///
402    /// 当父节点不存在时返回 [`PoolError::ParentNotFound`]
403    /// 当尝试删除根节点时返回 [`PoolError::CannotRemoveRoot`]
404    /// 当要删除的节点不是父节点的直接子节点时返回 [`PoolError::InvalidParenting`]
405    pub fn remove_node(
406        &mut self,
407        parent_id: &NodeId,
408        nodes: Vec<NodeId>,
409    ) -> Result<(), PoolError> {
410        // 检查父节点是否存在
411        let parent = self
412            .get_node(parent_id)
413            .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
414
415        // 检查是否尝试删除根节点
416        if nodes.contains(&self.inner.root_id) {
417            return Err(PoolError::CannotRemoveRoot);
418        }
419
420        // 验证所有要删除的节点都是父节点的直接子节点
421        for node_id in &nodes {
422            if !parent.content.contains(node_id) {
423                return Err(PoolError::InvalidParenting {
424                    child: node_id.clone(),
425                    alleged_parent: parent_id.clone(),
426                });
427            }
428        }
429
430        // 使用 HashSet 优化查找性能
431        let nodes_to_remove: std::collections::HashSet<_> =
432            nodes.iter().collect();
433
434        // 过滤保留的子节点
435        let filtered_children: im::Vector<NodeId> = parent
436            .as_ref()
437            .content
438            .iter()
439            .filter(|&id| !nodes_to_remove.contains(id))
440            .cloned()
441            .collect();
442
443        // 更新父节点
444        let mut parent_node = parent.as_ref().clone();
445        parent_node.content = filtered_children;
446        self.inner.nodes.insert(parent_id.clone(), Arc::new(parent_node));
447        let mut remove_nodes = Vec::new();
448        // 递归删除所有子节点
449        for node_id in nodes {
450            self.remove_subtree(&node_id, &mut remove_nodes)?;
451        }
452        self.record_patch(Patch::RemoveNode {
453            path: self.current_path.iter().cloned().collect(),
454            parent_id: parent_id.clone(),
455            nodes: remove_nodes,
456        });
457        Ok(())
458    }
459
460    /// 递归删除子树
461    ///
462    /// # 参数
463    ///
464    /// * `parent_id` - 父节点ID
465    /// * `node_id` - 要删除的节点ID
466    ///
467    /// # 错误
468    ///
469    /// 当节点不存在时返回 [`PoolError::NodeNotFound`]
470    /// 当尝试删除根节点时返回 [`PoolError::CannotRemoveRoot`]
471    fn remove_subtree(
472        &mut self,
473        node_id: &NodeId,
474        remove_nodes: &mut Vec<Node>,
475    ) -> Result<(), PoolError> {
476        // 检查是否是根节点
477        if node_id == &self.inner.root_id {
478            return Err(PoolError::CannotRemoveRoot);
479        }
480
481        // 获取要删除的节点
482        let _ = self
483            .get_node(node_id)
484            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
485
486        // 递归删除所有子节点
487        if let Some(children) = self.children(node_id).cloned() {
488            for child_id in children {
489                self.remove_subtree(&child_id, remove_nodes)?;
490            }
491        }
492
493        // 从父节点映射中移除
494        self.inner.parent_map.remove(node_id);
495
496        // 从节点池中移除并记录补丁
497        if let Some(remove_node) = self.inner.nodes.remove(node_id) {
498            remove_nodes.push(remove_node.as_ref().clone());
499        }
500        Ok(())
501    }
502
503    /// 应用补丁集合并更新节点池
504    ///
505    /// # 参数
506    ///
507    /// * `patches` - 要应用的补丁集合
508    ///
509    /// # 注意
510    ///
511    /// 应用过程中会临时禁用补丁记录
512    pub fn apply_patches(
513        &mut self,
514        patches: &Vec<Patch>,
515    ) -> Result<(), PoolError> {
516        //跳过记录
517        self.skip_record = true;
518        for patch in patches {
519            match patch {
520                Patch::UpdateAttr { path: _, id, old: _, new } => {
521                    self.update_attr(id, new.clone().into())?;
522                },
523                Patch::AddNode { path: _, parent_id, nodes } => {
524                    self.add_node(parent_id, nodes)?;
525                },
526                Patch::AddMark { path: _, node_id, marks } => {
527                    self.add_mark(node_id, marks)?;
528                },
529                Patch::RemoveNode { path: _, parent_id, nodes } => {
530                    self.remove_node(
531                        parent_id,
532                        nodes.iter().map(|n| n.id.clone()).collect(),
533                    )?;
534                },
535                Patch::RemoveMark { path: _, parent_id, marks } => {
536                    for mark in marks {
537                        self.remove_mark(&parent_id, mark.clone())?;
538                    }
539                },
540                Patch::MoveNode {
541                    path: _,
542                    node_id,
543                    source_parent_id,
544                    target_parent_id,
545                    position,
546                } => {
547                    self.move_node(
548                        source_parent_id,
549                        target_parent_id,
550                        node_id,
551                        position.clone(),
552                    )?;
553                },
554                Patch::SortChildren {
555                    path: _,
556                    parent_id,
557                    old_children: _,
558                    new_children,
559                } => {
560                    let parent = self
561                        .get_node(parent_id)
562                        .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
563                    let mut new_parent = parent.as_ref().clone();
564                    new_parent.content = new_children.iter().cloned().collect();
565                    self.inner
566                        .nodes
567                        .insert(parent_id.clone(), Arc::new(new_parent));
568                },
569            }
570        }
571        self.skip_record = false;
572        Ok(())
573    }
574    /// 翻转补丁集合并应用到节点池
575    pub fn reverse_patches(
576        &mut self,
577        patches: Vec<Patch>,
578    ) -> Result<(), PoolError> {
579        //跳过记录
580        self.skip_record = true;
581        for patch in patches {
582            match patch {
583                Patch::UpdateAttr { path: _, id, old, new: _ } => {
584                    self.update_attr(&id, old.clone().into())?;
585                },
586                Patch::AddNode { path: _, parent_id, nodes } => {
587                    self.remove_node(
588                        &parent_id,
589                        nodes.iter().map(|f| f.id.clone()).collect(),
590                    )?;
591                },
592                Patch::AddMark { path: _, node_id, marks } => {
593                    self.remove_mark(&node_id, marks[0].clone())?;
594                },
595                Patch::RemoveNode { path: _, parent_id, nodes } => {
596                    self.add_node(&parent_id, &nodes)?;
597                },
598                Patch::RemoveMark { path: _, parent_id, marks } => {
599                    self.add_mark(&parent_id, &marks)?;
600                },
601                Patch::MoveNode {
602                    path: _,
603                    node_id,
604                    source_parent_id,
605                    target_parent_id,
606                    position,
607                } => {
608                    self.move_node(
609                        &target_parent_id,
610                        &source_parent_id,
611                        &node_id,
612                        position.clone(),
613                    )?;
614                },
615                Patch::SortChildren {
616                    path: _,
617                    parent_id,
618                    old_children,
619                    new_children: _,
620                } => {
621                    let parent = self
622                        .get_node(&parent_id)
623                        .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
624                    let mut new_parent = parent.as_ref().clone();
625                    new_parent.content = old_children.iter().cloned().collect();
626                    self.inner
627                        .nodes
628                        .insert(parent_id.clone(), Arc::new(new_parent));
629                },
630            }
631        }
632        self.skip_record = false;
633        Ok(())
634    }
635
636    fn record_patch(
637        &mut self,
638        patch: Patch,
639    ) {
640        if !self.skip_record {
641            self.patches.push_back(patch);
642        }
643    }
644    /// 提交修改,生成新 NodePool 和补丁列表
645    pub fn commit(&self) -> StepResult {
646        match self.begin {
647            true => StepResult {
648                doc: None,
649                failed: Some("事务操作".to_string()),
650                patches: Vec::new(),
651            },
652            false => {
653                let new_pool = NodePool { inner: Arc::new(self.inner.clone()) };
654                StepResult::ok(
655                    Arc::new(new_pool),
656                    self.patches.iter().cloned().collect(),
657                )
658            },
659        }
660    }
661}