moduforge_model/
tree.rs

1use std::{ops::Index, sync::Arc, num::NonZeroUsize};
2use std::hash::{Hash, Hasher};
3use std::collections::hash_map::DefaultHasher;
4use im::Vector;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use once_cell::sync::Lazy;
8use parking_lot::RwLock;
9use lru::LruCache;
10use std::fmt::{self, Debug};
11use crate::error::PoolResult;
12use crate::node_type::NodeEnum;
13use crate::{
14    error::error_helpers,
15    mark::Mark,
16    node::Node,
17    ops::{AttrsRef, MarkRef, NodeRef},
18    types::NodeId,
19};
20
21// 全局LRU缓存用于存储NodeId到分片索引的映射
22static SHARD_INDEX_CACHE: Lazy<RwLock<LruCache<String, usize>>> =
23    Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(10000).unwrap())));
24
25#[derive(Clone, PartialEq, Serialize, Deserialize)]
26pub struct Tree {
27    pub root_id: NodeId,
28    pub nodes: Vector<im::HashMap<NodeId, Arc<Node>>>, // 分片存储节点数据
29    pub parent_map: im::HashMap<NodeId, NodeId>,
30    #[serde(skip)]
31    num_shards: usize, // 缓存分片数量,避免重复计算
32}
33impl Debug for Tree {
34    fn fmt(
35        &self,
36        f: &mut fmt::Formatter<'_>,
37    ) -> fmt::Result {
38        //输出的时候 过滤掉空的 nodes 节点
39        let nodes = self
40            .nodes
41            .iter()
42            .filter(|node| !node.is_empty())
43            .collect::<Vec<_>>();
44        f.debug_struct("Tree")
45            .field("root_id", &self.root_id)
46            .field("nodes", &nodes)
47            .field("parent_map", &self.parent_map)
48            .field("num_shards", &self.num_shards)
49            .finish()
50    }
51}
52
53impl Tree {
54    #[inline]
55    pub fn get_shard_index(
56        &self,
57        id: &NodeId,
58    ) -> usize {
59        // 先检查缓存
60        {
61            let cache = SHARD_INDEX_CACHE.read();
62            if let Some(&index) = cache.peek(id) {
63                return index;
64            }
65        }
66
67        // 缓存未命中,计算哈希值
68        let mut hasher = DefaultHasher::new();
69        id.hash(&mut hasher);
70        let index = (hasher.finish() as usize) % self.num_shards;
71
72        // 更新缓存
73        {
74            let mut cache = SHARD_INDEX_CACHE.write();
75            cache.put(id.clone(), index);
76        }
77
78        index
79    }
80
81    #[inline]
82    pub fn get_shard_indices(
83        &self,
84        ids: &[&NodeId],
85    ) -> Vec<usize> {
86        ids.iter().map(|id| self.get_shard_index(id)).collect()
87    }
88
89    // 为批量操作提供优化的哈希计算
90    #[inline]
91    pub fn get_shard_index_batch<'a>(
92        &self,
93        ids: &'a [&'a NodeId],
94    ) -> Vec<(usize, &'a NodeId)> {
95        let mut results = Vec::with_capacity(ids.len());
96        let mut cache_misses = Vec::new();
97
98        // 批量检查缓存
99        {
100            let cache = SHARD_INDEX_CACHE.read();
101            for &id in ids {
102                if let Some(&index) = cache.peek(id) {
103                    results.push((index, id));
104                } else {
105                    cache_misses.push(id);
106                }
107            }
108        }
109
110        // 批量计算缓存未命中的项
111        if !cache_misses.is_empty() {
112            let mut cache = SHARD_INDEX_CACHE.write();
113            for &id in &cache_misses {
114                let mut hasher = DefaultHasher::new();
115                id.hash(&mut hasher);
116                let index = (hasher.finish() as usize) % self.num_shards;
117                cache.put(id.clone(), index);
118                results.push((index, id));
119            }
120        }
121
122        results
123    }
124
125    // 清理缓存的方法,用于内存管理
126    pub fn clear_shard_cache() {
127        let mut cache = SHARD_INDEX_CACHE.write();
128        cache.clear();
129    }
130
131    pub fn contains_node(
132        &self,
133        id: &NodeId,
134    ) -> bool {
135        let shard_index = self.get_shard_index(id);
136        self.nodes[shard_index].contains_key(id)
137    }
138
139    pub fn get_node(
140        &self,
141        id: &NodeId,
142    ) -> Option<Arc<Node>> {
143        let shard_index = self.get_shard_index(id);
144        self.nodes[shard_index].get(id).cloned()
145    }
146
147    pub fn get_parent_node(
148        &self,
149        id: &NodeId,
150    ) -> Option<Arc<Node>> {
151        self.parent_map.get(id).and_then(|parent_id| {
152            let shard_index = self.get_shard_index(parent_id);
153            self.nodes[shard_index].get(parent_id).cloned()
154        })
155    }
156    pub fn from(nodes: NodeEnum) -> Self {
157        let num_shards = std::cmp::max(
158            std::thread::available_parallelism()
159                .map(NonZeroUsize::get)
160                .unwrap_or(2),
161            2,
162        );
163        let mut shards = Vector::from(vec![im::HashMap::new(); num_shards]);
164        let mut parent_map = im::HashMap::new();
165        let (root_node, children) = nodes.into_parts();
166        let root_id = root_node.id.clone();
167
168        let mut hasher = DefaultHasher::new();
169        root_id.hash(&mut hasher);
170        let shard_index = (hasher.finish() as usize) % num_shards;
171        shards[shard_index] =
172            shards[shard_index].update(root_id.clone(), Arc::new(root_node));
173
174        fn process_children(
175            children: Vec<NodeEnum>,
176            parent_id: &NodeId,
177            shards: &mut Vector<im::HashMap<NodeId, Arc<Node>>>,
178            parent_map: &mut im::HashMap<NodeId, NodeId>,
179            num_shards: usize,
180        ) {
181            for child in children {
182                let (node, grand_children) = child.into_parts();
183                let node_id = node.id.clone();
184                let mut hasher = DefaultHasher::new();
185                node_id.hash(&mut hasher);
186                let shard_index = (hasher.finish() as usize) % num_shards;
187                shards[shard_index] =
188                    shards[shard_index].update(node_id.clone(), Arc::new(node));
189                parent_map.insert(node_id.clone(), parent_id.clone());
190
191                // Recursively process grand children
192                process_children(
193                    grand_children,
194                    &node_id,
195                    shards,
196                    parent_map,
197                    num_shards,
198                );
199            }
200        }
201
202        process_children(
203            children,
204            &root_id,
205            &mut shards,
206            &mut parent_map,
207            num_shards,
208        );
209
210        Self { root_id, nodes: shards, parent_map, num_shards }
211    }
212
213    pub fn new(root: Node) -> Self {
214        let num_shards = std::cmp::max(
215            std::thread::available_parallelism()
216                .map(NonZeroUsize::get)
217                .unwrap_or(2),
218            2,
219        );
220        let mut nodes = Vector::from(vec![im::HashMap::new(); num_shards]);
221        let root_id = root.id.clone();
222        let mut hasher = DefaultHasher::new();
223        root_id.hash(&mut hasher);
224        let shard_index = (hasher.finish() as usize) % num_shards;
225        nodes[shard_index] =
226            nodes[shard_index].update(root_id.clone(), Arc::new(root));
227        Self { root_id, nodes, parent_map: im::HashMap::new(), num_shards }
228    }
229
230    pub fn update_attr(
231        &mut self,
232        id: &NodeId,
233        new_values: im::HashMap<String, Value>,
234    ) -> PoolResult<()> {
235        let shard_index = self.get_shard_index(id);
236        let node = self.nodes[shard_index]
237            .get(id)
238            .ok_or(error_helpers::node_not_found(id.clone()))?;
239        let old_values = node.attrs.clone();
240        let mut new_node = node.as_ref().clone();
241        let new_attrs = old_values.update(new_values);
242        new_node.attrs = new_attrs.clone();
243        self.nodes[shard_index] =
244            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
245        Ok(())
246    }
247    pub fn update_node(
248        &mut self,
249        node: Node,
250    ) -> PoolResult<()> {
251        let shard_index = self.get_shard_index(&node.id);
252        self.nodes[shard_index] =
253            self.nodes[shard_index].update(node.id.clone(), Arc::new(node));
254        Ok(())
255    }
256
257    /// 向树中添加新的节点及其子节点
258    ///
259    /// # 参数
260    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
261    ///
262    /// # 返回值
263    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
264    ///
265    /// # 错误
266    /// * `PoolError::ParentNotFound` - 如果父节点不存在
267    pub fn add(
268        &mut self,
269        parent_id: &NodeId,
270        nodes: Vec<NodeEnum>,
271    ) -> PoolResult<()> {
272        // 检查父节点是否存在
273        let parent_shard_index = self.get_shard_index(&parent_id);
274        let parent_node = self.nodes[parent_shard_index]
275            .get(parent_id)
276            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
277        let mut new_parent = parent_node.as_ref().clone();
278
279        // 收集所有子节点的ID并添加到当前节点的content中
280        let zenliang: Vector<String> =
281            nodes.iter().map(|n| n.0.id.clone()).collect();
282        // 需要判断 new_parent.content 中是否已经存在 zenliang 中的节点
283        let mut new_content = im::Vector::new();
284        for id in zenliang {
285            if !new_parent.content.contains(&id) {
286                new_content.push_back(id);
287            }
288        }
289        new_parent.content.extend(new_content);
290
291        // 更新当前节点
292        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
293            .update(parent_id.clone(), Arc::new(new_parent));
294
295        // 使用队列进行广度优先遍历,处理所有子节点
296        let mut node_queue = Vec::new();
297        node_queue.push((nodes, parent_id.clone()));
298        while let Some((current_children, current_parent_id)) = node_queue.pop()
299        {
300            for child in current_children {
301                // 处理每个子节点
302                let (mut child_node, grand_children) = child.into_parts();
303                let current_node_id = child_node.id.clone();
304
305                // 收集孙节点的ID并添加到子节点的content中
306                let grand_children_ids: Vector<String> =
307                    grand_children.iter().map(|n| n.0.id.clone()).collect();
308                let mut new_content = im::Vector::new();
309                for id in grand_children_ids {
310                    if !child_node.content.contains(&id) {
311                        new_content.push_back(id);
312                    }
313                }
314                child_node.content.extend(new_content);
315
316                // 将当前节点存储到对应的分片中
317                let shard_index = self.get_shard_index(&current_node_id);
318                self.nodes[shard_index] = self.nodes[shard_index]
319                    .update(current_node_id.clone(), Arc::new(child_node));
320
321                // 更新父子关系映射
322                self.parent_map
323                    .insert(current_node_id.clone(), current_parent_id.clone());
324
325                // 将孙节点加入队列,以便后续处理
326                node_queue.push((grand_children, current_node_id.clone()));
327            }
328        }
329        Ok(())
330    }
331    // 添加到下标
332    pub fn add_at_index(
333        &mut self,
334        parent_id: &NodeId,
335        index: usize,
336        node: &Node,
337    ) -> PoolResult<()> {
338        //添加到节点到 parent_id 的 content 中
339        let parent_shard_index = self.get_shard_index(parent_id);
340        let parent = self.nodes[parent_shard_index]
341            .get(parent_id)
342            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
343        let mut new_parent = parent.as_ref().clone();
344        new_parent.content.insert(index, node.id.clone());
345        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
346            .update(parent_id.clone(), Arc::new(new_parent));
347        self.parent_map.insert(node.id.clone(), parent_id.clone());
348        Ok(())
349    }
350    pub fn add_node(
351        &mut self,
352        parent_id: &NodeId,
353        nodes: &Vec<Node>,
354    ) -> PoolResult<()> {
355        let parent_shard_index = self.get_shard_index(parent_id);
356        let parent = self.nodes[parent_shard_index]
357            .get(parent_id)
358            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
359        let mut new_parent = parent.as_ref().clone();
360        new_parent.content.push_back(nodes[0].id.clone());
361        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
362            .update(parent_id.clone(), Arc::new(new_parent));
363        self.parent_map.insert(nodes[0].id.clone(), parent_id.clone());
364        for node in nodes {
365            let shard_index = self.get_shard_index(&node.id);
366            for child_id in &node.content {
367                self.parent_map.insert(child_id.clone(), node.id.clone());
368            }
369            self.nodes[shard_index] = self.nodes[shard_index]
370                .update(node.id.clone(), Arc::new(node.clone()));
371        }
372        Ok(())
373    }
374
375    pub fn node(
376        &mut self,
377        key: &str,
378    ) -> NodeRef<'_> {
379        NodeRef::new(self, key.to_string())
380    }
381    pub fn mark(
382        &mut self,
383        key: &str,
384    ) -> MarkRef<'_> {
385        MarkRef::new(self, key.to_string())
386    }
387    pub fn attrs(
388        &mut self,
389        key: &str,
390    ) -> AttrsRef<'_> {
391        AttrsRef::new(self, key.to_string())
392    }
393
394    pub fn children(
395        &self,
396        parent_id: &NodeId,
397    ) -> Option<im::Vector<NodeId>> {
398        self.get_node(parent_id).map(|n| n.content.clone())
399    }
400
401    pub fn children_node(
402        &self,
403        parent_id: &NodeId,
404    ) -> Option<im::Vector<Arc<Node>>> {
405        self.children(parent_id)
406            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
407    }
408    //递归获取所有子节点 封装成 NodeEnum 返回
409    pub fn all_children(
410        &self,
411        parent_id: &NodeId,
412        filter: Option<&dyn Fn(&Node) -> bool>,
413    ) -> Option<NodeEnum> {
414        if let Some(node) = self.get_node(parent_id) {
415            let mut child_enums = Vec::new();
416            for child_id in &node.content {
417                if let Some(child_node) = self.get_node(child_id) {
418                    // 检查子节点是否满足过滤条件
419                    if let Some(filter_fn) = filter {
420                        if !filter_fn(child_node.as_ref()) {
421                            continue; // 跳过不满足条件的子节点
422                        }
423                    }
424                    // 递归处理满足条件的子节点
425                    if let Some(child_enum) =
426                        self.all_children(child_id, filter)
427                    {
428                        child_enums.push(child_enum);
429                    }
430                }
431            }
432            Some(NodeEnum(node.as_ref().clone(), child_enums))
433        } else {
434            None
435        }
436    }
437
438    pub fn children_count(
439        &self,
440        parent_id: &NodeId,
441    ) -> usize {
442        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
443    }
444    pub fn remove_mark_by_name(
445        &mut self,
446        id: &NodeId,
447        mark_name: &str,
448    ) -> PoolResult<()> {
449        let shard_index = self.get_shard_index(id);
450        let node = self.nodes[shard_index]
451            .get(id)
452            .ok_or(error_helpers::node_not_found(id.clone()))?;
453        let mut new_node = node.as_ref().clone();
454        new_node.marks = new_node
455            .marks
456            .iter()
457            .filter(|&m| m.r#type != mark_name)
458            .cloned()
459            .collect();
460        self.nodes[shard_index] =
461            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
462        Ok(())
463    }
464    pub fn get_marks(
465        &self,
466        id: &NodeId,
467    ) -> Option<im::Vector<Mark>> {
468        self.get_node(id).map(|n| n.marks.clone())
469    }
470
471    pub fn remove_mark(
472        &mut self,
473        id: &NodeId,
474        mark_types: &[String],
475    ) -> PoolResult<()> {
476        let shard_index = self.get_shard_index(id);
477        let node = self.nodes[shard_index]
478            .get(id)
479            .ok_or(error_helpers::node_not_found(id.clone()))?;
480        let mut new_node = node.as_ref().clone();
481        new_node.marks =
482            new_node.marks.iter().filter(|&m| !mark_types.contains(&m.r#type)).cloned().collect();
483        self.nodes[shard_index] =
484            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
485        Ok(())
486    }
487
488    pub fn add_mark(
489        &mut self,
490        id: &NodeId,
491        marks: &Vec<Mark>,
492    ) -> PoolResult<()> {
493        let shard_index = self.get_shard_index(id);
494        let node = self.nodes[shard_index]
495            .get(id)
496            .ok_or(error_helpers::node_not_found(id.clone()))?;
497        let mut new_node = node.as_ref().clone();
498        //如果存在相同类型的mark,则覆盖
499        for mark in marks {
500            if new_node.marks.iter().any(|m| m.r#type == mark.r#type) {
501                new_node.marks = new_node.marks.iter().filter(|m| m.r#type != mark.r#type).cloned().collect();
502            } else {
503                new_node.marks.push_back(mark.clone());
504            }
505        }
506        self.nodes[shard_index] =
507            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
508        Ok(())
509    }
510
511    pub fn move_node(
512        &mut self,
513        source_parent_id: &NodeId,
514        target_parent_id: &NodeId,
515        node_id: &NodeId,
516        position: Option<usize>,
517    ) -> PoolResult<()> {
518        let source_shard_index = self.get_shard_index(source_parent_id);
519        let target_shard_index = self.get_shard_index(target_parent_id);
520        let node_shard_index = self.get_shard_index(node_id);
521        let source_parent = self.nodes[source_shard_index]
522            .get(source_parent_id)
523            .ok_or(error_helpers::parent_not_found(source_parent_id.clone()))?;
524        let target_parent = self.nodes[target_shard_index]
525            .get(target_parent_id)
526            .ok_or(error_helpers::parent_not_found(target_parent_id.clone()))?;
527        let _node = self.nodes[node_shard_index]
528            .get(node_id)
529            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
530        if !source_parent.content.contains(node_id) {
531            return Err(error_helpers::invalid_parenting(
532                node_id.clone(),
533                source_parent_id.clone(),
534            ));
535        }
536        let mut new_source_parent = source_parent.as_ref().clone();
537        new_source_parent.content = new_source_parent
538            .content
539            .iter()
540            .filter(|&id| id != node_id)
541            .cloned()
542            .collect();
543        let mut new_target_parent = target_parent.as_ref().clone();
544        if let Some(pos) = position {
545            // 确保position不超过当前content的长度
546            let insert_pos = pos.min(new_target_parent.content.len());
547
548            if insert_pos == new_target_parent.content.len() {
549                // 在末尾插入
550                new_target_parent.content.push_back(node_id.clone());
551            } else {
552                // 在指定位置插入
553                let mut new_content = im::Vector::new();
554
555                // 添加position之前的所有元素
556                for (i, child_id) in
557                    new_target_parent.content.iter().enumerate()
558                {
559                    if i == insert_pos {
560                        // 在这个位置插入新节点
561                        new_content.push_back(node_id.clone());
562                    }
563                    new_content.push_back(child_id.clone());
564                }
565
566                new_target_parent.content = new_content;
567            }
568        } else {
569            // 没有指定位置,添加到末尾
570            new_target_parent.content.push_back(node_id.clone());
571        }
572        self.nodes[source_shard_index] = self.nodes[source_shard_index]
573            .update(source_parent_id.clone(), Arc::new(new_source_parent));
574        self.nodes[target_shard_index] = self.nodes[target_shard_index]
575            .update(target_parent_id.clone(), Arc::new(new_target_parent));
576        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
577        Ok(())
578    }
579
580    pub fn remove_node(
581        &mut self,
582        parent_id: &NodeId,
583        nodes: Vec<NodeId>,
584    ) -> PoolResult<()> {
585        let parent_shard_index = self.get_shard_index(parent_id);
586        let parent = self.nodes[parent_shard_index]
587            .get(parent_id)
588            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
589        if nodes.contains(&self.root_id) {
590            return Err(error_helpers::cannot_remove_root());
591        }
592        for node_id in &nodes {
593            if !parent.content.contains(node_id) {
594                return Err(error_helpers::invalid_parenting(
595                    node_id.clone(),
596                    parent_id.clone(),
597                ));
598            }
599        }
600        let nodes_to_remove: std::collections::HashSet<_> =
601            nodes.iter().collect();
602        let filtered_children: im::Vector<NodeId> = parent
603            .as_ref()
604            .content
605            .iter()
606            .filter(|&id| !nodes_to_remove.contains(id))
607            .cloned()
608            .collect();
609        let mut parent_node = parent.as_ref().clone();
610        parent_node.content = filtered_children;
611        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
612            .update(parent_id.clone(), Arc::new(parent_node));
613        let mut remove_nodes = Vec::new();
614        for node_id in nodes {
615            self.remove_subtree(&node_id, &mut remove_nodes)?;
616        }
617        Ok(())
618    }
619    //=删除节点
620    pub fn remove_node_by_id(
621        &mut self,
622        node_id: &NodeId,
623    ) -> PoolResult<()> {
624        // 检查是否试图删除根节点
625        if node_id == &self.root_id {
626            return Err(error_helpers::cannot_remove_root());
627        }
628
629        let shard_index = self.get_shard_index(node_id);
630        let _ = self.nodes[shard_index]
631            .get(node_id)
632            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
633
634        // 从父节点的content中移除该节点
635        if let Some(parent_id) = self.parent_map.get(node_id).cloned() {
636            let parent_shard_index = self.get_shard_index(&parent_id);
637            if let Some(parent_node) =
638                self.nodes[parent_shard_index].get(&parent_id)
639            {
640                let mut new_parent = parent_node.as_ref().clone();
641                new_parent.content = new_parent
642                    .content
643                    .iter()
644                    .filter(|&id| id != node_id)
645                    .cloned()
646                    .collect();
647                self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
648                    .update(parent_id.clone(), Arc::new(new_parent));
649            }
650        }
651
652        // 删除子树(remove_subtree内部已经处理了节点的删除和parent_map的清理)
653        let mut remove_nodes = Vec::new();
654        self.remove_subtree(node_id, &mut remove_nodes)?;
655
656        // remove_subtree已经删除了所有节点,包括node_id本身,所以这里不需要再次删除
657        Ok(())
658    }
659
660    ///根据下标删除
661    pub fn remove_node_by_index(
662        &mut self,
663        parent_id: &NodeId,
664        index: usize,
665    ) -> PoolResult<()> {
666        let shard_index = self.get_shard_index(parent_id);
667        let parent = self.nodes[shard_index]
668            .get(parent_id)
669            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
670        let mut new_parent = parent.as_ref().clone();
671        let remove_node_id = new_parent.content.remove(index);
672        self.nodes[shard_index] = self.nodes[shard_index]
673            .update(parent_id.clone(), Arc::new(new_parent));
674        let mut remove_nodes = Vec::new();
675        self.remove_subtree(&remove_node_id, &mut remove_nodes)?;
676        for node in remove_nodes {
677            let shard_index = self.get_shard_index(&node.id);
678            self.nodes[shard_index].remove(&node.id);
679            self.parent_map.remove(&node.id);
680        }
681        Ok(())
682    }
683
684    //删除子树
685    fn remove_subtree(
686        &mut self,
687        node_id: &NodeId,
688        remove_nodes: &mut Vec<Node>,
689    ) -> PoolResult<()> {
690        if node_id == &self.root_id {
691            return Err(error_helpers::cannot_remove_root());
692        }
693        let shard_index = self.get_shard_index(node_id);
694        let _ = self.nodes[shard_index]
695            .get(node_id)
696            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
697        if let Some(children) = self.children(node_id) {
698            for child_id in children {
699                self.remove_subtree(&child_id, remove_nodes)?;
700            }
701        }
702        self.parent_map.remove(node_id);
703        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
704            remove_nodes.push(remove_node.as_ref().clone());
705        }
706        Ok(())
707    }
708}
709
710impl Index<&NodeId> for Tree {
711    type Output = Arc<Node>;
712    fn index(
713        &self,
714        index: &NodeId,
715    ) -> &Self::Output {
716        let shard_index = self.get_shard_index(index);
717        self.nodes[shard_index].get(index).expect("Node not found")
718    }
719}
720
721impl Index<&str> for Tree {
722    type Output = Arc<Node>;
723    fn index(
724        &self,
725        index: &str,
726    ) -> &Self::Output {
727        let node_id = NodeId::from(index);
728        let shard_index = self.get_shard_index(&node_id);
729        self.nodes[shard_index].get(&node_id).expect("Node not found")
730    }
731}
732
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attrs::Attrs;
    use crate::mark::Mark;
    use crate::node::Node;
    use crate::types::NodeId;
    use im::HashMap;
    use serde_json::json;

    /// A leaf node of type "test" with default attrs and no children or marks.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    #[test]
    fn test_tree_creation() {
        let root = create_test_node("root");
        let tree = Tree::new(root.clone());
        assert_eq!(tree.root_id, root.id);
        assert!(tree.contains_node(&root.id));
    }

    #[test]
    fn test_add_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        assert!(tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
    }

    #[test]
    fn test_remove_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        tree.remove_node(&root.id, vec![child.id.clone()]).unwrap();
        assert!(!tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 0);
    }

    #[test]
    fn test_move_node() {
        // Two parent nodes.
        let parent1 = create_test_node("parent1");
        let parent2 = create_test_node("parent2");
        let mut tree = Tree::new(parent1.clone());

        // parent2 becomes a child of parent1.
        tree.add_node(&parent1.id, &vec![parent2.clone()]).unwrap();

        // Three children.
        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        // All children go under parent1.
        tree.add_node(&parent1.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child3.clone()]).unwrap();

        // Initial state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        assert_eq!(parent1_children.len(), 4); // parent2 + 3 children
        assert_eq!(parent1_children[0], parent2.id);
        assert_eq!(parent1_children[1], child1.id);
        assert_eq!(parent1_children[2], child2.id);
        assert_eq!(parent1_children[3], child3.id);

        // Move child1 under parent2.
        tree.move_node(&parent1.id, &parent2.id, &child1.id, None).unwrap();

        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 3); // parent2 + 2 children
        assert_eq!(parent2_children.len(), 1); // child1
        assert_eq!(parent2_children[0], child1.id);

        // Move child2 under parent2, after child1.
        tree.move_node(&parent1.id, &parent2.id, &child2.id, Some(1)).unwrap();

        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 2); // parent2 + 1 child
        assert_eq!(parent2_children.len(), 2); // child1 + child2
        assert_eq!(parent2_children[0], child1.id);
        assert_eq!(parent2_children[1], child2.id);

        // Parent links follow the moves.
        let child1_parent = tree.get_parent_node(&child1.id).unwrap();
        let child2_parent = tree.get_parent_node(&child2.id).unwrap();
        assert_eq!(child1_parent.id, parent2.id);
        assert_eq!(child2_parent.id, parent2.id);
    }

    #[test]
    fn test_update_attr() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mut attrs = HashMap::new();
        attrs.insert("key".to_string(), json!("value"));

        tree.update_attr(&root.id, attrs).unwrap();

        let node = tree.get_node(&root.id).unwrap();
        assert_eq!(node.attrs.get("key").unwrap(), &json!("value"));
    }

    #[test]
    fn test_add_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(node.marks.contains(&mark));
    }

    #[test]
    fn test_remove_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        tree.remove_mark(&root.id, &[mark.r#type.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(!node.marks.iter().any(|m| m.r#type == mark.r#type));
    }

    #[test]
    fn test_all_children() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        let all_children = tree.all_children(&root.id, None).unwrap();
        assert_eq!(all_children.1.len(), 2);
    }

    #[test]
    fn test_children_count() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        assert_eq!(tree.children_count(&root.id), 2);
    }

    #[test]
    fn test_remove_node_by_id_updates_parent() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        // The child was added.
        assert_eq!(tree.children_count(&root.id), 1);
        assert!(tree.contains_node(&child.id));

        // Remove it by id.
        tree.remove_node_by_id(&child.id).unwrap();

        // The child is gone and the parent's content was updated.
        assert_eq!(tree.children_count(&root.id), 0);
        assert!(!tree.contains_node(&child.id));
    }

    #[test]
    fn test_remove_node_by_index() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        tree.remove_node_by_index(&root.id, 0).unwrap();

        assert!(!tree.contains_node(&child1.id));
        assert!(tree.get_parent_node(&child1.id).is_none());
        assert!(tree.contains_node(&child2.id));
        let remaining = tree.children(&root.id).unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0], child2.id);
    }

    #[test]
    fn test_shard_index_batch_matches_single_lookup() {
        let tree = Tree::new(create_test_node("root"));
        let ids: Vec<NodeId> =
            (0..20).map(|i| NodeId::from(format!("node-{}", i))).collect();
        let refs: Vec<&NodeId> = ids.iter().collect();

        let single = tree.get_shard_indices(&refs);
        assert_eq!(single.len(), refs.len());
        for (idx, id) in single.iter().zip(&refs) {
            assert_eq!(*idx, tree.get_shard_index(id));
        }

        let batch = tree.get_shard_index_batch(&refs);
        assert_eq!(batch.len(), refs.len());
        for (idx, id) in batch {
            assert_eq!(idx, tree.get_shard_index(id));
            assert!(idx < tree.nodes.len());
        }
    }

    #[test]
    fn test_move_node_position_edge_cases() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let container = create_test_node("container");
        tree.add_node(&root.id, &vec![container.clone()]).unwrap();

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child3.clone()]).unwrap();

        // An out-of-range position appends at the end.
        tree.move_node(&root.id, &container.id, &child1.id, Some(100)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 1);
        assert_eq!(container_children[0], child1.id);

        // Position 0 inserts at the front.
        tree.move_node(&root.id, &container.id, &child2.id, Some(0)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 2);
        assert_eq!(container_children[0], child2.id);
        assert_eq!(container_children[1], child1.id);
    }

    #[test]
    fn test_cannot_remove_root_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        // Removing the root must fail.
        let result = tree.remove_node_by_id(&root.id);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_parent_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        let parent = tree.get_parent_node(&child.id).unwrap();
        assert_eq!(parent.id, root.id);
    }
}