// moduforge_model/tree.rs

1use std::sync::Arc;
2use std::{ops::Index, num::NonZeroUsize};
3use std::hash::{Hash, Hasher};
4use std::collections::hash_map::DefaultHasher;
5use im::Vector;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use once_cell::sync::Lazy;
9use parking_lot::RwLock;
10use lru::LruCache;
11use std::fmt::{self, Debug};
12use crate::error::PoolResult;
13use crate::node_type::NodeEnum;
14use crate::{
15    error::error_helpers,
16    mark::Mark,
17    node::Node,
18    ops::{AttrsRef, MarkRef, NodeRef},
19    types::NodeId,
20};
21
22// 全局LRU缓存用于存储NodeId到分片索引的映射
23static SHARD_INDEX_CACHE: Lazy<RwLock<LruCache<String, usize>>> =
24    Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(10000).unwrap())));
25
/// A node tree whose nodes are spread across hash-addressed shards.
///
/// Every node is stored in exactly one shard of `nodes`, selected by hashing
/// its `NodeId`; `parent_map` records parent links for child nodes.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct Tree {
    /// Id of the root node.
    pub root_id: NodeId,
    /// Sharded node storage: one persistent map per shard.
    pub nodes: Vector<im::HashMap<NodeId, Arc<Node>>>,
    /// Maps a child node's id to the id of its parent.
    pub parent_map: im::HashMap<NodeId, NodeId>,
    /// Cached shard count, kept to avoid recomputing it.
    ///
    /// NOTE: because of `#[serde(skip)]` this field is not serialized and is
    /// set to `usize::default()` (0) when a `Tree` is deserialized. The derived
    /// `PartialEq` still compares it, so a round-tripped tree can compare
    /// unequal to the original even when `nodes` and `parent_map` match.
    #[serde(skip)]
    num_shards: usize, // cached shard count, avoids recomputation
}
34impl Debug for Tree {
35    fn fmt(
36        &self,
37        f: &mut fmt::Formatter<'_>,
38    ) -> fmt::Result {
39        //输出的时候 过滤掉空的 nodes 节点
40        let nodes = self
41            .nodes
42            .iter()
43            .filter(|node| !node.is_empty())
44            .collect::<Vec<_>>();
45        f.debug_struct("Tree")
46            .field("root_id", &self.root_id)
47            .field("nodes", &nodes)
48            .field("parent_map", &self.parent_map)
49            .field("num_shards", &self.num_shards)
50            .finish()
51    }
52}
53
54impl Tree {
55    #[inline]
56    pub fn get_shard_index(
57        &self,
58        id: &NodeId,
59    ) -> usize {
60        // 先检查缓存
61        {
62            let cache = SHARD_INDEX_CACHE.read();
63            if let Some(&index) = cache.peek(id) {
64                return index;
65            }
66        }
67
68        // 缓存未命中,计算哈希值
69        let mut hasher = DefaultHasher::new();
70        id.hash(&mut hasher);
71        let index = (hasher.finish() as usize) % self.num_shards;
72
73        // 更新缓存
74        {
75            let mut cache = SHARD_INDEX_CACHE.write();
76            cache.put(id.clone(), index);
77        }
78
79        index
80    }
81
82    #[inline]
83    pub fn get_shard_indices(
84        &self,
85        ids: &[&NodeId],
86    ) -> Vec<usize> {
87        ids.iter().map(|id| self.get_shard_index(id)).collect()
88    }
89
90    // 为批量操作提供优化的哈希计算
91    #[inline]
92    pub fn get_shard_index_batch<'a>(
93        &self,
94        ids: &'a [&'a NodeId],
95    ) -> Vec<(usize, &'a NodeId)> {
96        let mut results = Vec::with_capacity(ids.len());
97        let mut cache_misses = Vec::new();
98
99        // 批量检查缓存
100        {
101            let cache = SHARD_INDEX_CACHE.read();
102            for &id in ids {
103                if let Some(&index) = cache.peek(id) {
104                    results.push((index, id));
105                } else {
106                    cache_misses.push(id);
107                }
108            }
109        }
110
111        // 批量计算缓存未命中的项
112        if !cache_misses.is_empty() {
113            let mut cache = SHARD_INDEX_CACHE.write();
114            for &id in &cache_misses {
115                let mut hasher = DefaultHasher::new();
116                id.hash(&mut hasher);
117                let index = (hasher.finish() as usize) % self.num_shards;
118                cache.put(id.clone(), index);
119                results.push((index, id));
120            }
121        }
122
123        results
124    }
125
126    // 清理缓存的方法,用于内存管理
127    pub fn clear_shard_cache() {
128        let mut cache = SHARD_INDEX_CACHE.write();
129        cache.clear();
130    }
131
132    pub fn contains_node(
133        &self,
134        id: &NodeId,
135    ) -> bool {
136        let shard_index = self.get_shard_index(id);
137        self.nodes[shard_index].contains_key(id)
138    }
139
140    pub fn get_node(
141        &self,
142        id: &NodeId,
143    ) -> Option<Arc<Node>> {
144        let shard_index = self.get_shard_index(id);
145        self.nodes[shard_index].get(id).cloned()
146    }
147
148    pub fn get_parent_node(
149        &self,
150        id: &NodeId,
151    ) -> Option<Arc<Node>> {
152        self.parent_map.get(id).and_then(|parent_id| {
153            let shard_index = self.get_shard_index(parent_id);
154            self.nodes[shard_index].get(parent_id).cloned()
155        })
156    }
157    pub fn from(nodes: NodeEnum) -> Self {
158        let num_shards = std::cmp::max(
159            std::thread::available_parallelism()
160                .map(NonZeroUsize::get)
161                .unwrap_or(2),
162            2,
163        );
164        let mut shards = Vector::from(vec![im::HashMap::new(); num_shards]);
165        let mut parent_map = im::HashMap::new();
166        let (root_node, children) = nodes.into_parts();
167        let root_id = root_node.id.clone();
168
169        let mut hasher = DefaultHasher::new();
170        root_id.hash(&mut hasher);
171        let shard_index = (hasher.finish() as usize) % num_shards;
172        shards[shard_index] =
173            shards[shard_index].update(root_id.clone(), Arc::new(root_node));
174
175        fn process_children(
176            children: Vec<NodeEnum>,
177            parent_id: &NodeId,
178            shards: &mut Vector<im::HashMap<NodeId, Arc<Node>>>,
179            parent_map: &mut im::HashMap<NodeId, NodeId>,
180            num_shards: usize,
181        ) {
182            for child in children {
183                let (node, grand_children) = child.into_parts();
184                let node_id = node.id.clone();
185                let mut hasher = DefaultHasher::new();
186                node_id.hash(&mut hasher);
187                let shard_index = (hasher.finish() as usize) % num_shards;
188                shards[shard_index] =
189                    shards[shard_index].update(node_id.clone(), Arc::new(node));
190                parent_map.insert(node_id.clone(), parent_id.clone());
191
192                // Recursively process grand children
193                process_children(
194                    grand_children,
195                    &node_id,
196                    shards,
197                    parent_map,
198                    num_shards,
199                );
200            }
201        }
202
203        process_children(
204            children,
205            &root_id,
206            &mut shards,
207            &mut parent_map,
208            num_shards,
209        );
210
211        Self { root_id, nodes: shards, parent_map, num_shards }
212    }
213
214    pub fn new(root: Node) -> Self {
215        let num_shards = std::cmp::max(
216            std::thread::available_parallelism()
217                .map(NonZeroUsize::get)
218                .unwrap_or(2),
219            2,
220        );
221        let mut nodes = Vector::from(vec![im::HashMap::new(); num_shards]);
222        let root_id = root.id.clone();
223        let mut hasher = DefaultHasher::new();
224        root_id.hash(&mut hasher);
225        let shard_index = (hasher.finish() as usize) % num_shards;
226        nodes[shard_index] =
227            nodes[shard_index].update(root_id.clone(), Arc::new(root));
228        Self { root_id, nodes, parent_map: im::HashMap::new(), num_shards }
229    }
230
231    pub fn update_attr(
232        &mut self,
233        id: &NodeId,
234        new_values: im::HashMap<String, Value>,
235    ) -> PoolResult<()> {
236        let shard_index = self.get_shard_index(id);
237        let node = self.nodes[shard_index]
238            .get(id)
239            .ok_or(error_helpers::node_not_found(id.clone()))?;
240        let old_values = node.attrs.clone();
241        let mut new_node = node.as_ref().clone();
242        let new_attrs = old_values.update(new_values);
243        new_node.attrs = new_attrs.clone();
244        self.nodes[shard_index] =
245            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
246        Ok(())
247    }
248    pub fn update_node(
249        &mut self,
250        node: Node,
251    ) -> PoolResult<()> {
252        let shard_index = self.get_shard_index(&node.id);
253        self.nodes[shard_index] =
254            self.nodes[shard_index].update(node.id.clone(), Arc::new(node));
255        Ok(())
256    }
257
258    /// 向树中添加新的节点及其子节点
259    ///
260    /// # 参数
261    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
262    ///
263    /// # 返回值
264    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
265    ///
266    /// # 错误
267    /// * `PoolError::ParentNotFound` - 如果父节点不存在
268    pub fn add(
269        &mut self,
270        parent_id: &NodeId,
271        nodes: Vec<NodeEnum>,
272    ) -> PoolResult<()> {
273        // 检查父节点是否存在
274        let parent_shard_index = self.get_shard_index(&parent_id);
275        let parent_node = self.nodes[parent_shard_index]
276            .get(parent_id)
277            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
278        let mut new_parent = parent_node.as_ref().clone();
279
280        // 收集所有子节点的ID并添加到当前节点的content中
281        let zenliang: Vector<String> =
282            nodes.iter().map(|n| n.0.id.clone()).collect();
283        // 需要判断 new_parent.content 中是否已经存在 zenliang 中的节点
284        let mut new_content = im::Vector::new();
285        for id in zenliang {
286            if !new_parent.content.contains(&id) {
287                new_content.push_back(id);
288            }
289        }
290        new_parent.content.extend(new_content);
291
292        // 更新当前节点
293        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
294            .update(parent_id.clone(), Arc::new(new_parent));
295
296        // 使用队列进行广度优先遍历,处理所有子节点
297        let mut node_queue = Vec::new();
298        node_queue.push((nodes, parent_id.clone()));
299        while let Some((current_children, current_parent_id)) = node_queue.pop()
300        {
301            for child in current_children {
302                // 处理每个子节点
303                let (mut child_node, grand_children) = child.into_parts();
304                let current_node_id = child_node.id.clone();
305
306                // 收集孙节点的ID并添加到子节点的content中
307                let grand_children_ids: Vector<String> =
308                    grand_children.iter().map(|n| n.0.id.clone()).collect();
309                let mut new_content = im::Vector::new();
310                for id in grand_children_ids {
311                    if !child_node.content.contains(&id) {
312                        new_content.push_back(id);
313                    }
314                }
315                child_node.content.extend(new_content);
316
317                // 将当前节点存储到对应的分片中
318                let shard_index = self.get_shard_index(&current_node_id);
319                self.nodes[shard_index] = self.nodes[shard_index]
320                    .update(current_node_id.clone(), Arc::new(child_node));
321
322                // 更新父子关系映射
323                self.parent_map
324                    .insert(current_node_id.clone(), current_parent_id.clone());
325
326                // 将孙节点加入队列,以便后续处理
327                node_queue.push((grand_children, current_node_id.clone()));
328            }
329        }
330        Ok(())
331    }
332    // 添加到下标
333    pub fn add_at_index(
334        &mut self,
335        parent_id: &NodeId,
336        index: usize,
337        node: &Node,
338    ) -> PoolResult<()> {
339        //添加到节点到 parent_id 的 content 中
340        let parent_shard_index = self.get_shard_index(parent_id);
341        let parent = self.nodes[parent_shard_index]
342            .get(parent_id)
343            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
344        let mut new_parent = parent.as_ref().clone();
345        new_parent.content.insert(index, node.id.clone());
346        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
347            .update(parent_id.clone(), Arc::new(new_parent));
348        self.parent_map.insert(node.id.clone(), parent_id.clone());
349        Ok(())
350    }
351    pub fn add_node(
352        &mut self,
353        parent_id: &NodeId,
354        nodes: &Vec<Node>,
355    ) -> PoolResult<()> {
356        let parent_shard_index = self.get_shard_index(parent_id);
357        let parent = self.nodes[parent_shard_index]
358            .get(parent_id)
359            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
360        let mut new_parent = parent.as_ref().clone();
361        new_parent.content.push_back(nodes[0].id.clone());
362        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
363            .update(parent_id.clone(), Arc::new(new_parent));
364        self.parent_map.insert(nodes[0].id.clone(), parent_id.clone());
365        for node in nodes {
366            let shard_index = self.get_shard_index(&node.id);
367            for child_id in &node.content {
368                self.parent_map.insert(child_id.clone(), node.id.clone());
369            }
370            self.nodes[shard_index] = self.nodes[shard_index]
371                .update(node.id.clone(), Arc::new(node.clone()));
372        }
373        Ok(())
374    }
375
376    pub fn node(
377        &mut self,
378        key: &str,
379    ) -> NodeRef<'_> {
380        NodeRef::new(self, key.to_string())
381    }
382    pub fn mark(
383        &mut self,
384        key: &str,
385    ) -> MarkRef<'_> {
386        MarkRef::new(self, key.to_string())
387    }
388    pub fn attrs(
389        &mut self,
390        key: &str,
391    ) -> AttrsRef<'_> {
392        AttrsRef::new(self, key.to_string())
393    }
394
395    pub fn children(
396        &self,
397        parent_id: &NodeId,
398    ) -> Option<im::Vector<NodeId>> {
399        self.get_node(parent_id).map(|n| n.content.clone())
400    }
401
402    pub fn children_node(
403        &self,
404        parent_id: &NodeId,
405    ) -> Option<im::Vector<Arc<Node>>> {
406        self.children(parent_id)
407            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
408    }
409    //递归获取所有子节点 封装成 NodeEnum 返回
410    pub fn all_children(
411        &self,
412        parent_id: &NodeId,
413        filter: Option<&dyn Fn(&Node) -> bool>,
414    ) -> Option<NodeEnum> {
415        if let Some(node) = self.get_node(parent_id) {
416            let mut child_enums = Vec::new();
417            for child_id in &node.content {
418                if let Some(child_node) = self.get_node(child_id) {
419                    // 检查子节点是否满足过滤条件
420                    if let Some(filter_fn) = filter {
421                        if !filter_fn(child_node.as_ref()) {
422                            continue; // 跳过不满足条件的子节点
423                        }
424                    }
425                    // 递归处理满足条件的子节点
426                    if let Some(child_enum) =
427                        self.all_children(child_id, filter)
428                    {
429                        child_enums.push(child_enum);
430                    }
431                }
432            }
433            Some(NodeEnum(node.as_ref().clone(), child_enums))
434        } else {
435            None
436        }
437    }
438
439    pub fn children_count(
440        &self,
441        parent_id: &NodeId,
442    ) -> usize {
443        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
444    }
445    pub fn remove_mark_by_name(
446        &mut self,
447        id: &NodeId,
448        mark_name: &str,
449    ) -> PoolResult<()> {
450        let shard_index = self.get_shard_index(id);
451        let node = self.nodes[shard_index]
452            .get(id)
453            .ok_or(error_helpers::node_not_found(id.clone()))?;
454        let mut new_node = node.as_ref().clone();
455        new_node.marks = new_node
456            .marks
457            .iter()
458            .filter(|&m| m.r#type != mark_name)
459            .cloned()
460            .collect();
461        self.nodes[shard_index] =
462            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
463        Ok(())
464    }
465    pub fn get_marks(
466        &self,
467        id: &NodeId,
468    ) -> Option<im::Vector<Mark>> {
469        self.get_node(id).map(|n| n.marks.clone())
470    }
471
472    pub fn remove_mark(
473        &mut self,
474        id: &NodeId,
475        mark_types: &[String],
476    ) -> PoolResult<()> {
477        let shard_index = self.get_shard_index(id);
478        let node = self.nodes[shard_index]
479            .get(id)
480            .ok_or(error_helpers::node_not_found(id.clone()))?;
481        let mut new_node = node.as_ref().clone();
482        new_node.marks = new_node
483            .marks
484            .iter()
485            .filter(|&m| !mark_types.contains(&m.r#type))
486            .cloned()
487            .collect();
488        self.nodes[shard_index] =
489            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
490        Ok(())
491    }
492
493    pub fn add_mark(
494        &mut self,
495        id: &NodeId,
496        marks: &Vec<Mark>,
497    ) -> PoolResult<()> {
498        let shard_index = self.get_shard_index(id);
499        let node = self.nodes[shard_index]
500            .get(id)
501            .ok_or(error_helpers::node_not_found(id.clone()))?;
502        let mut new_node = node.as_ref().clone();
503        //如果存在相同类型的mark,则覆盖
504        let mark_types =
505            marks.iter().map(|m| m.r#type.clone()).collect::<Vec<String>>();
506        new_node.marks = new_node
507            .marks
508            .iter()
509            .filter(|m| !mark_types.contains(&m.r#type))
510            .cloned()
511            .collect();
512        new_node.marks.extend(marks.iter().map(|m| m.clone()));
513        self.nodes[shard_index] =
514            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
515        Ok(())
516    }
517
518    pub fn move_node(
519        &mut self,
520        source_parent_id: &NodeId,
521        target_parent_id: &NodeId,
522        node_id: &NodeId,
523        position: Option<usize>,
524    ) -> PoolResult<()> {
525        let source_shard_index = self.get_shard_index(source_parent_id);
526        let target_shard_index = self.get_shard_index(target_parent_id);
527        let node_shard_index = self.get_shard_index(node_id);
528        let source_parent = self.nodes[source_shard_index]
529            .get(source_parent_id)
530            .ok_or(error_helpers::parent_not_found(source_parent_id.clone()))?;
531        let target_parent = self.nodes[target_shard_index]
532            .get(target_parent_id)
533            .ok_or(error_helpers::parent_not_found(target_parent_id.clone()))?;
534        let _node = self.nodes[node_shard_index]
535            .get(node_id)
536            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
537        if !source_parent.content.contains(node_id) {
538            return Err(error_helpers::invalid_parenting(
539                node_id.clone(),
540                source_parent_id.clone(),
541            ));
542        }
543        let mut new_source_parent = source_parent.as_ref().clone();
544        new_source_parent.content = new_source_parent
545            .content
546            .iter()
547            .filter(|&id| id != node_id)
548            .cloned()
549            .collect();
550        let mut new_target_parent = target_parent.as_ref().clone();
551        if let Some(pos) = position {
552            // 确保position不超过当前content的长度
553            let insert_pos = pos.min(new_target_parent.content.len());
554
555            if insert_pos == new_target_parent.content.len() {
556                // 在末尾插入
557                new_target_parent.content.push_back(node_id.clone());
558            } else {
559                // 在指定位置插入
560                let mut new_content = im::Vector::new();
561
562                // 添加position之前的所有元素
563                for (i, child_id) in
564                    new_target_parent.content.iter().enumerate()
565                {
566                    if i == insert_pos {
567                        // 在这个位置插入新节点
568                        new_content.push_back(node_id.clone());
569                    }
570                    new_content.push_back(child_id.clone());
571                }
572
573                new_target_parent.content = new_content;
574            }
575        } else {
576            // 没有指定位置,添加到末尾
577            new_target_parent.content.push_back(node_id.clone());
578        }
579        self.nodes[source_shard_index] = self.nodes[source_shard_index]
580            .update(source_parent_id.clone(), Arc::new(new_source_parent));
581        self.nodes[target_shard_index] = self.nodes[target_shard_index]
582            .update(target_parent_id.clone(), Arc::new(new_target_parent));
583        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
584        Ok(())
585    }
586
587    pub fn remove_node(
588        &mut self,
589        parent_id: &NodeId,
590        nodes: Vec<NodeId>,
591    ) -> PoolResult<()> {
592        let parent_shard_index = self.get_shard_index(parent_id);
593        let parent = self.nodes[parent_shard_index]
594            .get(parent_id)
595            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
596        if nodes.contains(&self.root_id) {
597            return Err(error_helpers::cannot_remove_root());
598        }
599        for node_id in &nodes {
600            if !parent.content.contains(node_id) {
601                return Err(error_helpers::invalid_parenting(
602                    node_id.clone(),
603                    parent_id.clone(),
604                ));
605            }
606        }
607        let nodes_to_remove: std::collections::HashSet<_> =
608            nodes.iter().collect();
609        let filtered_children: im::Vector<NodeId> = parent
610            .as_ref()
611            .content
612            .iter()
613            .filter(|&id| !nodes_to_remove.contains(id))
614            .cloned()
615            .collect();
616        let mut parent_node = parent.as_ref().clone();
617        parent_node.content = filtered_children;
618        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
619            .update(parent_id.clone(), Arc::new(parent_node));
620        let mut remove_nodes = Vec::new();
621        for node_id in nodes {
622            self.remove_subtree(&node_id, &mut remove_nodes)?;
623        }
624        Ok(())
625    }
626    //=删除节点
627    pub fn remove_node_by_id(
628        &mut self,
629        node_id: &NodeId,
630    ) -> PoolResult<()> {
631        // 检查是否试图删除根节点
632        if node_id == &self.root_id {
633            return Err(error_helpers::cannot_remove_root());
634        }
635
636        let shard_index = self.get_shard_index(node_id);
637        let _ = self.nodes[shard_index]
638            .get(node_id)
639            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
640
641        // 从父节点的content中移除该节点
642        if let Some(parent_id) = self.parent_map.get(node_id).cloned() {
643            let parent_shard_index = self.get_shard_index(&parent_id);
644            if let Some(parent_node) =
645                self.nodes[parent_shard_index].get(&parent_id)
646            {
647                let mut new_parent = parent_node.as_ref().clone();
648                new_parent.content = new_parent
649                    .content
650                    .iter()
651                    .filter(|&id| id != node_id)
652                    .cloned()
653                    .collect();
654                self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
655                    .update(parent_id.clone(), Arc::new(new_parent));
656            }
657        }
658
659        // 删除子树(remove_subtree内部已经处理了节点的删除和parent_map的清理)
660        let mut remove_nodes = Vec::new();
661        self.remove_subtree(node_id, &mut remove_nodes)?;
662
663        // remove_subtree已经删除了所有节点,包括node_id本身,所以这里不需要再次删除
664        Ok(())
665    }
666
667    ///根据下标删除
668    pub fn remove_node_by_index(
669        &mut self,
670        parent_id: &NodeId,
671        index: usize,
672    ) -> PoolResult<()> {
673        let shard_index = self.get_shard_index(parent_id);
674        let parent = self.nodes[shard_index]
675            .get(parent_id)
676            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
677        let mut new_parent = parent.as_ref().clone();
678        let remove_node_id = new_parent.content.remove(index);
679        self.nodes[shard_index] = self.nodes[shard_index]
680            .update(parent_id.clone(), Arc::new(new_parent));
681        let mut remove_nodes = Vec::new();
682        self.remove_subtree(&remove_node_id, &mut remove_nodes)?;
683        for node in remove_nodes {
684            let shard_index = self.get_shard_index(&node.id);
685            self.nodes[shard_index].remove(&node.id);
686            self.parent_map.remove(&node.id);
687        }
688        Ok(())
689    }
690
691    //删除子树
692    fn remove_subtree(
693        &mut self,
694        node_id: &NodeId,
695        remove_nodes: &mut Vec<Node>,
696    ) -> PoolResult<()> {
697        if node_id == &self.root_id {
698            return Err(error_helpers::cannot_remove_root());
699        }
700        let shard_index = self.get_shard_index(node_id);
701        let _ = self.nodes[shard_index]
702            .get(node_id)
703            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
704        if let Some(children) = self.children(node_id) {
705            for child_id in children {
706                self.remove_subtree(&child_id, remove_nodes)?;
707            }
708        }
709        self.parent_map.remove(node_id);
710        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
711            remove_nodes.push(remove_node.as_ref().clone());
712        }
713        Ok(())
714    }
715}
716
717impl Index<&NodeId> for Tree {
718    type Output = Arc<Node>;
719    fn index(
720        &self,
721        index: &NodeId,
722    ) -> &Self::Output {
723        let shard_index = self.get_shard_index(index);
724        self.nodes[shard_index].get(index).expect("Node not found")
725    }
726}
727
728impl Index<&str> for Tree {
729    type Output = Arc<Node>;
730    fn index(
731        &self,
732        index: &str,
733    ) -> &Self::Output {
734        let node_id = NodeId::from(index);
735        let shard_index = self.get_shard_index(&node_id);
736        self.nodes[shard_index].get(&node_id).expect("Node not found")
737    }
738}
739
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::Node;
    use crate::attrs::Attrs;
    use crate::mark::Mark;
    use im::HashMap;
    use serde_json::json;

    /// Builds a childless, markless node of type "test" with the given id.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    #[test]
    fn test_tree_creation() {
        let root = create_test_node("root");
        let tree = Tree::new(root.clone());
        assert_eq!(tree.root_id, root.id);
        assert!(tree.contains_node(&root.id));
    }

    #[test]
    fn test_add_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        assert!(tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
    }

    #[test]
    fn test_remove_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        tree.remove_node(&root.id, vec![child.id.clone()]).unwrap();
        assert!(!tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 0);
    }

    #[test]
    fn test_move_node() {
        // Two parent nodes.
        let parent1 = create_test_node("parent1");
        let parent2 = create_test_node("parent2");
        let mut tree = Tree::new(parent1.clone());

        // Attach parent2 as a child of parent1.
        tree.add_node(&parent1.id, &vec![parent2.clone()]).unwrap();

        // Three child nodes.
        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        // Attach all children under parent1.
        tree.add_node(&parent1.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child3.clone()]).unwrap();

        // Initial state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        assert_eq!(parent1_children.len(), 4); // parent2 + 3 children
        assert_eq!(parent1_children[0], parent2.id);
        assert_eq!(parent1_children[1], child1.id);
        assert_eq!(parent1_children[2], child2.id);
        assert_eq!(parent1_children[3], child3.id);

        // Move child1 under parent2.
        tree.move_node(&parent1.id, &parent2.id, &child1.id, None).unwrap();

        // State after the first move.
        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 3); // parent2 + 2 children
        assert_eq!(parent2_children.len(), 1); // child1
        assert_eq!(parent2_children[0], child1.id);

        // Move child2 under parent2, right after child1.
        tree.move_node(&parent1.id, &parent2.id, &child2.id, Some(1)).unwrap();

        // Final state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 2); // parent2 + 1 child
        assert_eq!(parent2_children.len(), 2); // child1 + child2
        assert_eq!(parent2_children[0], child1.id);
        assert_eq!(parent2_children[1], child2.id);

        // Parent links follow the moves.
        let child1_parent = tree.get_parent_node(&child1.id).unwrap();
        let child2_parent = tree.get_parent_node(&child2.id).unwrap();
        assert_eq!(child1_parent.id, parent2.id);
        assert_eq!(child2_parent.id, parent2.id);
    }

    #[test]
    fn test_update_attr() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mut attrs = HashMap::new();
        attrs.insert("key".to_string(), json!("value"));

        tree.update_attr(&root.id, attrs).unwrap();

        let node = tree.get_node(&root.id).unwrap();
        assert_eq!(node.attrs.get("key").unwrap(), &json!("value"));
    }

    #[test]
    fn test_add_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(node.marks.contains(&mark));
    }

    #[test]
    fn test_remove_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        tree.remove_mark(&root.id, &[mark.r#type.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(!node.marks.iter().any(|m| m.r#type == mark.r#type));
    }

    #[test]
    fn test_all_children() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        let all_children = tree.all_children(&root.id, None).unwrap();
        assert_eq!(all_children.1.len(), 2);
    }

    #[test]
    fn test_children_count() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        assert_eq!(tree.children_count(&root.id), 2);
    }

    #[test]
    fn test_remove_node_by_id_updates_parent() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        // The child was added.
        assert_eq!(tree.children_count(&root.id), 1);
        assert!(tree.contains_node(&child.id));

        // Remove the child.
        tree.remove_node_by_id(&child.id).unwrap();

        // The child is gone and the parent's content was updated.
        assert_eq!(tree.children_count(&root.id), 0);
        assert!(!tree.contains_node(&child.id));
    }

    #[test]
    fn test_move_node_position_edge_cases() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let container = create_test_node("container");
        tree.add_node(&root.id, &vec![container.clone()]).unwrap();

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child3.clone()]).unwrap();

        // An out-of-range position appends at the end.
        tree.move_node(&root.id, &container.id, &child1.id, Some(100)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 1);
        assert_eq!(container_children[0], child1.id);

        // Position 0 inserts at the front.
        tree.move_node(&root.id, &container.id, &child2.id, Some(0)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 2);
        assert_eq!(container_children[0], child2.id);
        assert_eq!(container_children[1], child1.id);
    }

    #[test]
    fn test_cannot_remove_root_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        // Removing the root must fail.
        let result = tree.remove_node_by_id(&root.id);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_parent_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        let parent = tree.get_parent_node(&child.id).unwrap();
        assert_eq!(parent.id, root.id);
    }
}