mf_model/
tree.rs

1use std::num::NonZeroUsize;
2use std::ops::Index;
3use std::hash::{Hash, Hasher};
4use rpds::VectorSync;
5use rpds::HashTrieMapSync;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use once_cell::sync::Lazy;
9use dashmap::DashMap;
10use ahash::{AHasher, RandomState};
11use std::fmt::{self, Debug};
12use crate::error::PoolResult;
13use crate::node_definition::NodeTree;
14use crate::{
15    error::error_helpers,
16    mark::Mark,
17    node::Node,
18    ops::{AttrsRef, MarkRef, NodeRef},
19    types::NodeId,
20};
21
22/// 全局分片索引缓存 - 使用 DashMap 实现无锁并发
23///
24/// # 性能优化
25///
26/// **旧实现 (RwLock + LruCache)**:
27/// - 读操作: ~100ns (需要读锁)
28/// - 写操作: ~500ns (需要写锁,阻塞所有读)
29/// - 高并发: 存在锁竞争
30///
31/// **新实现 (DashMap + AHash)**:
32/// - 读操作: ~20ns (无锁,分片并发)
33/// - 写操作: ~50ns (无锁,只锁单个分片)
34/// - 高并发: 完美扩展,零全局竞争
35///
36/// # 设计决策
37///
38/// 1. **DashMap vs RwLock<HashMap>**: 分片锁,减少竞争
39/// 2. **AHash vs DefaultHasher**: 速度快 3-5x
40/// 3. **无 LRU**: 分片索引计算成本低,缓存淘汰收益小
41static SHARD_INDEX_CACHE: Lazy<DashMap<NodeId, usize, RandomState>> =
42    Lazy::new(|| DashMap::with_capacity_and_hasher(10000, RandomState::new()));
43
44type TreeMap = HashTrieMapSync<NodeId, Node>;
45type TreeParentMap = HashTrieMapSync<NodeId, NodeId>;
46#[derive(Clone, PartialEq, Serialize, Deserialize)]
47pub struct Tree {
48    pub root_id: NodeId,
49    pub nodes: VectorSync<TreeMap>, // 分片存储节点数据
50    pub parent_map: TreeParentMap,
51    #[serde(skip)]
52    num_shards: usize, // 缓存分片数量,避免重复计算
53}
54impl Debug for Tree {
55    fn fmt(
56        &self,
57        f: &mut fmt::Formatter<'_>,
58    ) -> fmt::Result {
59        //输出的时候 过滤掉空的 nodes 节点
60        let nodes = self
61            .nodes
62            .iter()
63            .filter(|node| !node.is_empty())
64            .collect::<Vec<_>>();
65        f.debug_struct("Tree")
66            .field("root_id", &self.root_id)
67            .field("nodes", &nodes)
68            .field("parent_map", &self.parent_map)
69            .field("num_shards", &self.num_shards)
70            .finish()
71    }
72}
73
74impl Tree {
75    /// 计算分片索引 (内联,高性能)
76    ///
77    /// # 性能优化
78    ///
79    /// 1. **快速路径**: 缓存命中 ~20ns
80    /// 2. **慢速路径**: AHash 计算 ~50ns (vs DefaultHasher ~150ns)
81    /// 3. **无锁设计**: DashMap 分片锁,零全局竞争
82    ///
83    /// # 实现细节
84    ///
85    /// - 使用 AHash (ahash) 替代 DefaultHasher: 速度提升 3x
86    /// - 使用 DashMap 替代 RwLock: 并发性能提升 5-10x
87    /// - `#[inline(always)]`: 强制内联,消除函数调用开销
88    #[inline(always)]
89    pub fn get_shard_index(
90        &self,
91        id: &NodeId,
92    ) -> usize {
93        // 快速路径:缓存命中(无锁读取)
94        if let Some(index) = SHARD_INDEX_CACHE.get(id) {
95            return *index;
96        }
97
98        // 慢速路径:计算哈希并缓存
99        self.compute_and_cache_shard_index(id)
100    }
101
102    /// 计算并缓存分片索引 (慢速路径,不内联)
103    ///
104    /// 分离到独立函数,避免内联膨胀影响快速路径
105    #[cold]
106    #[inline(never)]
107    fn compute_and_cache_shard_index(
108        &self,
109        id: &NodeId,
110    ) -> usize {
111        // 使用 AHash 计算哈希值 (比 DefaultHasher 快 3x)
112        let mut hasher = AHasher::default();
113        id.hash(&mut hasher);
114        let index = (hasher.finish() as usize) % self.num_shards;
115
116        // 无锁插入缓存 (DashMap 自动处理并发)
117        SHARD_INDEX_CACHE.insert(id.clone(), index);
118
119        index
120    }
121
122    /// 批量获取分片索引
123    ///
124    /// # 性能优化
125    ///
126    /// - 预分配容量,减少重分配
127    /// - 并行友好,无全局锁
128    #[inline]
129    pub fn get_shard_indices(
130        &self,
131        ids: &[&NodeId],
132    ) -> Vec<usize> {
133        ids.iter().map(|id| self.get_shard_index(id)).collect()
134    }
135
136    /// 批量获取分片索引和ID对 (优化版本)
137    ///
138    /// # 性能优化
139    ///
140    /// **旧实现**: 两次锁操作 (读锁检查 + 写锁更新)
141    /// **新实现**: 零全局锁,DashMap 分片并发
142    ///
143    /// 100个ID的性能对比:
144    /// - 旧实现: ~50µs (锁竞争)
145    /// - 新实现: ~5µs (无锁)
146    #[inline]
147    pub fn get_shard_index_batch<'a>(
148        &self,
149        ids: &'a [&'a NodeId],
150    ) -> Vec<(usize, &'a NodeId)> {
151        ids.iter().map(|&id| (self.get_shard_index(id), id)).collect()
152    }
153
154    /// 清理分片缓存 (用于内存管理)
155    ///
156    /// # 注意
157    ///
158    /// 这个操作会清空整个缓存,应该谨慎使用。
159    /// 通常只在内存压力大或测试场景下调用。
160    pub fn clear_shard_cache() {
161        SHARD_INDEX_CACHE.clear();
162    }
163
164    /// 获取缓存统计信息
165    pub fn shard_cache_stats() -> (usize, usize) {
166        let len = SHARD_INDEX_CACHE.len();
167        let capacity = SHARD_INDEX_CACHE.capacity();
168        (len, capacity)
169    }
170
171    pub fn contains_node(
172        &self,
173        id: &NodeId,
174    ) -> bool {
175        let shard_index = self.get_shard_index(id);
176        self.nodes[shard_index].contains_key(id)
177    }
178
179    pub fn get_node(
180        &self,
181        id: &NodeId,
182    ) -> Option<&Node> {
183        let shard_index = self.get_shard_index(id);
184        self.nodes[shard_index].get(id)
185    }
186
187    pub fn get_parent_node(
188        &self,
189        id: &NodeId,
190    ) -> Option<&Node> {
191        self.parent_map.get(id).and_then(|parent_id| {
192            let shard_index = self.get_shard_index(parent_id);
193            self.nodes[shard_index].get(parent_id)
194        })
195    }
196    pub fn from(nodes: NodeTree) -> Self {
197        let num_shards = std::cmp::max(
198            std::thread::available_parallelism()
199                .map(NonZeroUsize::get)
200                .unwrap_or(2),
201            2,
202        );
203        let mut shards = VectorSync::new_sync(); //(vec![HashTrieMap::new(); num_shards]);
204        for _ in 0..num_shards {
205            shards.push_back_mut(HashTrieMapSync::new_sync());
206        }
207        let mut parent_map = HashTrieMapSync::new_sync();
208        let (root_node, children) = nodes.into_parts();
209        let root_id = root_node.id.clone();
210
211        let mut hasher = AHasher::default();
212        root_id.hash(&mut hasher);
213        let shard_index = (hasher.finish() as usize) % num_shards;
214
215        shards[shard_index] =
216            shards[shard_index].insert(root_id.clone(), root_node);
217
218        fn process_children(
219            children: Vec<NodeTree>,
220            parent_id: &NodeId,
221            shards: &mut VectorSync<TreeMap>,
222            parent_map: &mut TreeParentMap,
223            num_shards: usize,
224        ) {
225            for child in children {
226                let (node, grand_children) = child.into_parts();
227                let node_id = node.id.clone();
228                let mut hasher = AHasher::default();
229                node_id.hash(&mut hasher);
230                let shard_index = (hasher.finish() as usize) % num_shards;
231                shards[shard_index] =
232                    shards[shard_index].insert(node_id.clone(), node);
233                parent_map.insert_mut(node_id.clone(), parent_id.clone());
234
235                // Recursively process grand children
236                process_children(
237                    grand_children,
238                    &node_id,
239                    shards,
240                    parent_map,
241                    num_shards,
242                );
243            }
244        }
245
246        process_children(
247            children,
248            &root_id,
249            &mut shards,
250            &mut parent_map,
251            num_shards,
252        );
253
254        Self { root_id, nodes: shards, parent_map, num_shards }
255    }
256
257    pub fn new(root: Node) -> Self {
258        let num_shards = std::cmp::max(
259            std::thread::available_parallelism()
260                .map(NonZeroUsize::get)
261                .unwrap_or(2),
262            2,
263        );
264        let mut nodes = VectorSync::new_sync();
265        for _ in 0..num_shards {
266            nodes.push_back_mut(HashTrieMapSync::new_sync());
267        }
268        let root_id = root.id.clone();
269        let mut hasher = AHasher::default();
270        root_id.hash(&mut hasher);
271        let shard_index = (hasher.finish() as usize) % num_shards;
272        nodes[shard_index] = nodes[shard_index].insert(root_id.clone(), root);
273        Self {
274            root_id,
275            nodes,
276            parent_map: HashTrieMapSync::new_sync(),
277            num_shards,
278        }
279    }
280
281    pub fn update_attr(
282        &mut self,
283        id: &NodeId,
284        new_values: HashTrieMapSync<String, Value>,
285    ) -> PoolResult<()> {
286        let shard_index = self.get_shard_index(id);
287        let node = self.nodes[shard_index]
288            .get(id)
289            .ok_or(error_helpers::node_not_found(id.clone()))?;
290        let new_node = node.update_attr(new_values);
291        self.nodes[shard_index] =
292            self.nodes[shard_index].insert(id.clone(), new_node);
293        Ok(())
294    }
295    pub fn update_node(
296        &mut self,
297        node: Node,
298    ) -> PoolResult<()> {
299        let shard_index = self.get_shard_index(&node.id);
300        self.nodes[shard_index] =
301            self.nodes[shard_index].insert(node.id.clone(), node);
302        Ok(())
303    }
304
305    /// 向树中添加新的节点及其子节点
306    ///
307    /// # 参数
308    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
309    ///
310    /// # 返回值
311    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
312    ///
313    /// # 错误
314    /// * `PoolError::ParentNotFound` - 如果父节点不存在
315    pub fn add(
316        &mut self,
317        parent_id: &NodeId,
318        nodes: Vec<NodeTree>,
319    ) -> PoolResult<()> {
320        // 检查父节点是否存在
321        let parent_shard_index = self.get_shard_index(parent_id);
322        let parent_node = self.nodes[parent_shard_index]
323            .get(parent_id)
324            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
325        let mut new_parent = parent_node.clone();
326
327        // 收集所有子节点的ID并添加到当前节点的content中
328        let zenliang: VectorSync<NodeId> =
329            nodes.iter().map(|n| n.0.id.clone()).collect();
330        // 需要判断 new_parent.content 中是否已经存在 zenliang 中的节点
331        for id in zenliang.iter() {
332            if !new_parent.contains(id) {
333                new_parent.content = new_parent.content.push_back(id.clone());
334            }
335        }
336
337        // 更新当前节点
338        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
339            .insert(parent_id.clone(), new_parent);
340
341        // 使用队列进行广度优先遍历,处理所有子节点
342        let mut node_queue = Vec::new();
343        node_queue.push((nodes, parent_id.clone()));
344        while let Some((current_children, current_parent_id)) = node_queue.pop()
345        {
346            for child in current_children {
347                // 处理每个子节点
348                let (mut child_node, grand_children) = child.into_parts();
349                let current_node_id = child_node.id.clone();
350
351                // 收集孙节点的ID并添加到子节点的content中
352                let grand_children_ids: VectorSync<NodeId> =
353                    grand_children.iter().map(|n| n.0.id.clone()).collect();
354                for id in grand_children_ids.iter() {
355                    if !child_node.contains(id) {
356                        child_node.content =
357                            child_node.content.push_back(id.clone());
358                    }
359                }
360
361                // 将当前节点存储到对应的分片中
362                let shard_index = self.get_shard_index(&current_node_id);
363                self.nodes[shard_index] = self.nodes[shard_index]
364                    .insert(current_node_id.clone(), child_node);
365
366                // 更新父子关系映射
367                self.parent_map = self
368                    .parent_map
369                    .insert(current_node_id.clone(), current_parent_id.clone());
370
371                // 将孙节点加入队列,以便后续处理
372                node_queue.push((grand_children, current_node_id.clone()));
373            }
374        }
375        Ok(())
376    }
377    // 添加到下标
378    pub fn add_at_index(
379        &mut self,
380        parent_id: &NodeId,
381        index: usize,
382        node: &Node,
383    ) -> PoolResult<()> {
384        //添加到节点到 parent_id 的 content 中
385        let parent_shard_index = self.get_shard_index(parent_id);
386        let parent = self.nodes[parent_shard_index]
387            .get(parent_id)
388            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
389        let new_parent = parent.insert_content_at_index(index, &node.id);
390        //更新父节点
391        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
392            .insert(parent_id.clone(), new_parent);
393        //更新父子关系映射
394        self.parent_map =
395            self.parent_map.insert(node.id.clone(), parent_id.clone());
396        //更新子节点
397        let shard_index = self.get_shard_index(&node.id);
398        self.nodes[shard_index] =
399            self.nodes[shard_index].insert(node.id.clone(), node.clone());
400        Ok(())
401    }
402    pub fn add_node(
403        &mut self,
404        parent_id: &NodeId,
405        nodes: &Vec<Node>,
406    ) -> PoolResult<()> {
407        let parent_shard_index = self.get_shard_index(parent_id);
408        let parent = self.nodes[parent_shard_index]
409            .get(parent_id)
410            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
411        let node_ids = nodes.iter().map(|n| n.id.clone()).collect();
412        // 更新父节点 - 添加所有节点的ID到content中
413        let new_parent = parent.insert_contents(&node_ids);
414
415        // 更新父节点到分片中
416        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
417            .insert(parent_id.clone(), new_parent);
418
419        // 更新所有子节点
420        for node in nodes {
421            // 设置当前节点的父子关系映射
422            self.parent_map =
423                self.parent_map.insert(node.id.clone(), parent_id.clone());
424
425            // 设置当前节点的子节点的父子关系映射
426            for child_id in &node.content {
427                self.parent_map =
428                    self.parent_map.insert(child_id.clone(), node.id.clone());
429            }
430
431            // 将节点添加到对应的分片中
432            let shard_index = self.get_shard_index(&node.id);
433            self.nodes[shard_index] =
434                self.nodes[shard_index].insert(node.id.clone(), node.clone());
435        }
436        Ok(())
437    }
438
439    pub fn node(
440        &mut self,
441        key: &str,
442    ) -> NodeRef<'_> {
443        NodeRef::new(self, key.into())
444    }
445    pub fn mark(
446        &mut self,
447        key: &str,
448    ) -> MarkRef<'_> {
449        MarkRef::new(self, key.into())
450    }
451    pub fn attrs(
452        &mut self,
453        key: &str,
454    ) -> AttrsRef<'_> {
455        AttrsRef::new(self, key.into())
456    }
457
458    pub fn children(
459        &self,
460        parent_id: &NodeId,
461    ) -> Option<VectorSync<NodeId>> {
462        self.get_node(parent_id).map(|n| n.content.clone())
463    }
464
465    pub fn children_node(
466        &self,
467        parent_id: &NodeId,
468    ) -> Option<VectorSync<&Node>> {
469        self.children(parent_id)
470            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
471    }
472    //递归获取所有子节点 封装成 NodeTree 返回
473    pub fn all_children(
474        &self,
475        parent_id: &NodeId,
476        filter: Option<&dyn Fn(&Node) -> bool>,
477    ) -> Option<NodeTree> {
478        if let Some(node) = self.get_node(parent_id) {
479            let mut child_enums = Vec::new();
480            for child_id in &node.content {
481                if let Some(child_node) = self.get_node(child_id) {
482                    // 检查子节点是否满足过滤条件
483                    if let Some(filter_fn) = filter {
484                        if !filter_fn(child_node) {
485                            continue; // 跳过不满足条件的子节点
486                        }
487                    }
488                    // 递归处理满足条件的子节点
489                    if let Some(child_enum) =
490                        self.all_children(child_id, filter)
491                    {
492                        child_enums.push(child_enum);
493                    }
494                }
495            }
496            Some(NodeTree(node.clone(), child_enums))
497        } else {
498            None
499        }
500    }
501
502    pub fn children_count(
503        &self,
504        parent_id: &NodeId,
505    ) -> usize {
506        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
507    }
508    pub fn remove_mark_by_name(
509        &mut self,
510        id: &NodeId,
511        mark_name: &str,
512    ) -> PoolResult<()> {
513        let shard_index = self.get_shard_index(id);
514        let node = self.nodes[shard_index]
515            .get(id)
516            .ok_or(error_helpers::node_not_found(id.clone()))?;
517        let new_node = node.remove_mark_by_name(mark_name);
518        self.nodes[shard_index] =
519            self.nodes[shard_index].insert(id.clone(), new_node);
520        Ok(())
521    }
522    pub fn get_marks(
523        &self,
524        id: &NodeId,
525    ) -> Option<VectorSync<Mark>> {
526        self.get_node(id).map(|n| n.marks.clone())
527    }
528
529    pub fn remove_mark(
530        &mut self,
531        id: &NodeId,
532        mark_types: &[String],
533    ) -> PoolResult<()> {
534        let shard_index = self.get_shard_index(id);
535        let node = self.nodes[shard_index]
536            .get(id)
537            .ok_or(error_helpers::node_not_found(id.clone()))?;
538        let new_node = node.remove_mark(mark_types);
539        self.nodes[shard_index] =
540            self.nodes[shard_index].insert(id.clone(), new_node);
541        Ok(())
542    }
543
544    pub fn add_mark(
545        &mut self,
546        id: &NodeId,
547        marks: &[Mark],
548    ) -> PoolResult<()> {
549        let shard_index = self.get_shard_index(id);
550        let node = self.nodes[shard_index]
551            .get(id)
552            .ok_or(error_helpers::node_not_found(id.clone()))?;
553        let new_node = node.add_marks(marks);
554        self.nodes[shard_index] =
555            self.nodes[shard_index].insert(id.clone(), new_node);
556        Ok(())
557    }
558
559    pub fn move_node(
560        &mut self,
561        source_parent_id: &NodeId,
562        target_parent_id: &NodeId,
563        node_id: &NodeId,
564        position: Option<usize>,
565    ) -> PoolResult<()> {
566        let source_shard_index = self.get_shard_index(source_parent_id);
567        let target_shard_index = self.get_shard_index(target_parent_id);
568        let node_shard_index = self.get_shard_index(node_id);
569        let source_parent = self.nodes[source_shard_index]
570            .get(source_parent_id)
571            .ok_or(error_helpers::parent_not_found(source_parent_id.clone()))?;
572        let target_parent = self.nodes[target_shard_index]
573            .get(target_parent_id)
574            .ok_or(error_helpers::parent_not_found(target_parent_id.clone()))?;
575        let _node = self.nodes[node_shard_index]
576            .get(node_id)
577            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
578        if !source_parent.contains(node_id) {
579            return Err(error_helpers::invalid_parenting(
580                node_id.clone(),
581                source_parent_id.clone(),
582            ));
583        }
584        let mut new_source_parent = source_parent.clone();
585        new_source_parent.content = new_source_parent
586            .content
587            .iter()
588            .filter(|&id| id != node_id)
589            .cloned()
590            .collect();
591        let mut new_target_parent = target_parent.clone();
592        if let Some(pos) = position {
593            // 确保position不超过当前content的长度
594            let insert_pos = pos.min(new_target_parent.content.len());
595
596            // 在指定位置插入节点
597            new_target_parent =
598                new_target_parent.insert_content_at_index(insert_pos, node_id);
599        } else {
600            // 没有指定位置,添加到末尾
601            new_target_parent.content =
602                new_target_parent.content.push_back(node_id.clone());
603        }
604        self.nodes[source_shard_index] = self.nodes[source_shard_index]
605            .insert(source_parent_id.clone(), new_source_parent);
606        self.nodes[target_shard_index] = self.nodes[target_shard_index]
607            .insert(target_parent_id.clone(), new_target_parent);
608        self.parent_map =
609            self.parent_map.insert(node_id.clone(), target_parent_id.clone());
610        Ok(())
611    }
612
613    pub fn remove_node(
614        &mut self,
615        parent_id: &NodeId,
616        nodes: Vec<NodeId>,
617    ) -> PoolResult<()> {
618        let parent_shard_index = self.get_shard_index(parent_id);
619        let parent = self.nodes[parent_shard_index]
620            .get(parent_id)
621            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
622        if nodes.contains(&self.root_id) {
623            return Err(error_helpers::cannot_remove_root());
624        }
625        for node_id in &nodes {
626            if !parent.contains(node_id) {
627                return Err(error_helpers::invalid_parenting(
628                    node_id.clone(),
629                    parent_id.clone(),
630                ));
631            }
632        }
633        let nodes_to_remove: std::collections::HashSet<_> =
634            nodes.iter().collect();
635        let filtered_children: VectorSync<NodeId> = parent
636            .content
637            .iter()
638            .filter(|&id| !nodes_to_remove.contains(id))
639            .cloned()
640            .collect();
641        let mut parent_node = parent.clone();
642        parent_node.content = filtered_children;
643        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
644            .insert(parent_id.clone(), parent_node);
645        let mut remove_nodes = Vec::new();
646        for node_id in nodes {
647            self.remove_subtree(&node_id, &mut remove_nodes)?;
648        }
649        Ok(())
650    }
651    //=删除节点
652    pub fn remove_node_by_id(
653        &mut self,
654        node_id: &NodeId,
655    ) -> PoolResult<()> {
656        // 检查是否试图删除根节点
657        if node_id == &self.root_id {
658            return Err(error_helpers::cannot_remove_root());
659        }
660
661        let shard_index = self.get_shard_index(node_id);
662        let _ = self.nodes[shard_index]
663            .get(node_id)
664            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
665
666        // 从父节点的content中移除该节点
667        if let Some(parent_id) = self.parent_map.get(node_id).cloned() {
668            let parent_shard_index = self.get_shard_index(&parent_id);
669            if let Some(parent_node) =
670                self.nodes[parent_shard_index].get(&parent_id)
671            {
672                let mut new_parent = parent_node.clone();
673                new_parent.content = new_parent
674                    .content
675                    .iter()
676                    .filter(|&id| id != node_id)
677                    .cloned()
678                    .collect();
679                self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
680                    .insert(parent_id.clone(), new_parent);
681            }
682        }
683
684        // 删除子树(remove_subtree内部已经处理了节点的删除和parent_map的清理)
685        let mut remove_nodes = Vec::new();
686        self.remove_subtree(node_id, &mut remove_nodes)?;
687
688        // remove_subtree已经删除了所有节点,包括node_id本身,所以这里不需要再次删除
689        Ok(())
690    }
691
692    ///根据下标删除
693    pub fn remove_node_by_index(
694        &mut self,
695        parent_id: &NodeId,
696        index: usize,
697    ) -> PoolResult<()> {
698        let shard_index = self.get_shard_index(parent_id);
699        let parent = self.nodes[shard_index]
700            .get(parent_id)
701            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
702        let mut new_parent = parent.clone();
703        let remove_node_id = {
704            match new_parent.content.get(index) {
705                Some(id) => id.clone(),
706                None => return Err(anyhow::anyhow!("index out of bounds")),
707            }
708        };
709        new_parent = new_parent.remove_content(&remove_node_id);
710        self.nodes[shard_index] =
711            self.nodes[shard_index].insert(parent_id.clone(), new_parent);
712        let mut remove_nodes = Vec::new();
713        self.remove_subtree(&remove_node_id, &mut remove_nodes)?;
714
715        Ok(())
716    }
717
718    //删除子树
719    fn remove_subtree(
720        &mut self,
721        node_id: &NodeId,
722        remove_nodes: &mut Vec<Node>,
723    ) -> PoolResult<()> {
724        if node_id == &self.root_id {
725            return Err(error_helpers::cannot_remove_root());
726        }
727        let shard_index = self.get_shard_index(node_id);
728        let _ = self.nodes[shard_index]
729            .get(node_id)
730            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
731        if let Some(children) = self.children(node_id) {
732            for child_id in children.iter() {
733                self.remove_subtree(&child_id, remove_nodes)?;
734            }
735        }
736        self.parent_map = self.parent_map.remove(node_id);
737
738        if let Some(remove_node) = self.nodes[shard_index].get(node_id) {
739            remove_nodes.push(remove_node.clone());
740            self.nodes[shard_index] = self.nodes[shard_index].remove(node_id);
741        }
742        Ok(())
743    }
744}
745
746impl Index<&NodeId> for Tree {
747    type Output = Node;
748    fn index(
749        &self,
750        index: &NodeId,
751    ) -> &Self::Output {
752        let shard_index = self.get_shard_index(index);
753        self.nodes[shard_index].get(index).expect("Node not found")
754    }
755}
756
757impl Index<&str> for Tree {
758    type Output = Node;
759    fn index(
760        &self,
761        index: &str,
762    ) -> &Self::Output {
763        let node_id = NodeId::from(index);
764        let shard_index = self.get_shard_index(&node_id);
765        self.nodes[shard_index].get(&node_id).expect("Node not found")
766    }
767}
768
769#[cfg(test)]
770mod tests {
771    use super::*;
772    use crate::node::Node;
773    use crate::attrs::Attrs;
774    use crate::mark::Mark;
775    use serde_json::json;
776
777    fn create_test_node(id: &str) -> Node {
778        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
779    }
780
781    #[test]
782    fn test_tree_creation() {
783        let root = create_test_node("root");
784        let tree = Tree::new(root.clone());
785        assert_eq!(tree.root_id, root.id);
786        assert!(tree.contains_node(&root.id));
787    }
788
789    #[test]
790    fn test_add_node() {
791        let root = create_test_node("root");
792        let mut tree = Tree::new(root.clone());
793
794        let child = create_test_node("child");
795        let nodes = vec![child.clone()];
796
797        tree.add_node(&root.id, &nodes).unwrap();
798        #[cfg(feature = "debug-logs")]
799        dbg!(&tree);
800        assert!(tree.contains_node(&child.id));
801        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
802    }
803
804    #[test]
805    fn test_remove_node() {
806        let root = create_test_node("root");
807        let mut tree = Tree::new(root.clone());
808
809        let child = create_test_node("child");
810        let nodes = vec![child.clone()];
811
812        tree.add_node(&root.id, &nodes).unwrap();
813        #[cfg(feature = "debug-logs")]
814        dbg!(&tree);
815        tree.remove_node(&root.id, vec![child.id.clone()]).unwrap();
816        #[cfg(feature = "debug-logs")]
817        dbg!(&tree);
818        assert!(!tree.contains_node(&child.id));
819        assert_eq!(tree.children(&root.id).unwrap().len(), 0);
820    }
821
822    #[test]
823    fn test_move_node() {
824        // 创建两个父节点
825        let parent1 = create_test_node("parent1");
826        let parent2 = create_test_node("parent2");
827        let mut tree = Tree::new(parent1.clone());
828
829        // 将 parent2 添加为 parent1 的子节点
830        tree.add_node(&parent1.id, &vec![parent2.clone()]).unwrap();
831
832        // 创建三个子节点
833        let child1 = create_test_node("child1");
834        let child2 = create_test_node("child2");
835        let child3 = create_test_node("child3");
836
837        // 将所有子节点添加到 parent1 下
838        tree.add_node(&parent1.id, &vec![child1.clone()]).unwrap();
839        tree.add_node(&parent1.id, &vec![child2.clone()]).unwrap();
840        tree.add_node(&parent1.id, &vec![child3.clone()]).unwrap();
841
842        // 验证初始状态
843        let parent1_children = tree.children(&parent1.id).unwrap();
844        assert_eq!(parent1_children.len(), 4); // parent2 + 3 children
845        assert_eq!(parent1_children[0], parent2.id);
846        assert_eq!(parent1_children[1], child1.id);
847        assert_eq!(parent1_children[2], child2.id);
848        assert_eq!(parent1_children[3], child3.id);
849
850        // 将 child1 移动到 parent2 下
851        tree.move_node(&parent1.id, &parent2.id, &child1.id, None).unwrap();
852
853        // 验证移动后的状态
854        let parent1_children = tree.children(&parent1.id).unwrap();
855        let parent2_children = tree.children(&parent2.id).unwrap();
856        assert_eq!(parent1_children.len(), 3); // parent2 + 2 children
857        assert_eq!(parent2_children.len(), 1); // child1
858        assert_eq!(parent2_children[0], child1.id);
859
860        // 将 child2 移动到 parent2 下,放在 child1 后面
861        tree.move_node(&parent1.id, &parent2.id, &child2.id, Some(1)).unwrap();
862
863        // 验证最终状态
864        let parent1_children = tree.children(&parent1.id).unwrap();
865        let parent2_children = tree.children(&parent2.id).unwrap();
866        assert_eq!(parent1_children.len(), 2); // parent2 + 1 child
867        assert_eq!(parent2_children.len(), 2); // child1 + child2
868        assert_eq!(parent2_children[0], child1.id);
869        assert_eq!(parent2_children[1], child2.id);
870
871        // 验证父节点关系
872        let child1_parent = tree.get_parent_node(&child1.id).unwrap();
873        let child2_parent = tree.get_parent_node(&child2.id).unwrap();
874        assert_eq!(child1_parent.id, parent2.id);
875        assert_eq!(child2_parent.id, parent2.id);
876    }
877
878    #[test]
879    fn test_update_attr() {
880        let root = create_test_node("root");
881        let mut tree = Tree::new(root.clone());
882
883        let mut attrs = HashTrieMapSync::new_sync();
884        attrs = attrs.insert("key".to_string(), json!("value"));
885
886        tree.update_attr(&root.id, attrs).unwrap();
887
888        let node = tree.get_node(&root.id).unwrap();
889        #[cfg(feature = "debug-logs")]
890        dbg!(&node);
891        assert_eq!(node.attrs.get("key").unwrap(), &json!("value"));
892    }
893
894    #[test]
895    fn test_add_mark() {
896        let root = create_test_node("root");
897        let mut tree = Tree::new(root.clone());
898
899        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
900        tree.add_mark(&root.id, &[mark.clone()]).unwrap();
901        #[cfg(feature = "debug-logs")]
902        dbg!(&tree);
903    }
904
905    #[test]
906    fn test_remove_mark() {
907        let root = create_test_node("root");
908        let mut tree = Tree::new(root.clone());
909
910        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
911        tree.add_mark(&root.id, &[mark.clone()]).unwrap();
912        #[cfg(feature = "debug-logs")]
913        dbg!(&tree);
914        tree.remove_mark(&root.id, &[mark.r#type.clone()]).unwrap();
915        #[cfg(feature = "debug-logs")]
916        dbg!(&tree);
917        let node = tree.get_node(&root.id).unwrap();
918        assert!(!node.marks.iter().any(|m| m.r#type == mark.r#type));
919    }
920
921    #[test]
922    fn test_all_children() {
923        let root = create_test_node("root");
924        let mut tree = Tree::new(root.clone());
925
926        let child1 = create_test_node("child1");
927        let child2 = create_test_node("child2");
928
929        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
930        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
931        #[cfg(feature = "debug-logs")]
932        dbg!(&tree);
933        let all_children = tree.all_children(&root.id, None).unwrap();
934        assert_eq!(all_children.1.len(), 2);
935    }
936
937    #[test]
938    fn test_children_count() {
939        let root = create_test_node("root");
940        let mut tree = Tree::new(root.clone());
941
942        let child1 = create_test_node("child1");
943        let child2 = create_test_node("child2");
944
945        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
946        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
947
948        assert_eq!(tree.children_count(&root.id), 2);
949    }
950
951    #[test]
952    fn test_remove_node_by_id_updates_parent() {
953        let root = create_test_node("root");
954        let mut tree = Tree::new(root.clone());
955
956        let child = create_test_node("child");
957        tree.add_node(&root.id, &vec![child.clone()]).unwrap();
958
959        // 验证子节点被添加
960        assert_eq!(tree.children_count(&root.id), 1);
961        assert!(tree.contains_node(&child.id));
962
963        // 删除子节点
964        tree.remove_node_by_id(&child.id).unwrap();
965
966        // 验证子节点被删除且父节点的content被更新
967        assert_eq!(tree.children_count(&root.id), 0);
968        assert!(!tree.contains_node(&child.id));
969    }
970
971    #[test]
972    fn test_move_node_position_edge_cases() {
973        let root = create_test_node("root");
974        let mut tree = Tree::new(root.clone());
975
976        let container = create_test_node("container");
977        tree.add_node(&root.id, &vec![container.clone()]).unwrap();
978
979        let child1 = create_test_node("child1");
980        let child2 = create_test_node("child2");
981        let child3 = create_test_node("child3");
982
983        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
984        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
985        tree.add_node(&root.id, &vec![child3.clone()]).unwrap();
986
987        // 测试移动到超出范围的位置(应该插入到末尾)
988        tree.move_node(&root.id, &container.id, &child1.id, Some(100)).unwrap();
989
990        let container_children = tree.children(&container.id).unwrap();
991        assert_eq!(container_children.len(), 1);
992        assert_eq!(container_children[0], child1.id);
993
994        // 测试移动到位置0
995        tree.move_node(&root.id, &container.id, &child2.id, Some(0)).unwrap();
996
997        let container_children = tree.children(&container.id).unwrap();
998        assert_eq!(container_children.len(), 2);
999        assert_eq!(container_children[0], child2.id);
1000        assert_eq!(container_children[1], child1.id);
1001    }
1002
1003    #[test]
1004    fn test_cannot_remove_root_node() {
1005        let root = create_test_node("root");
1006        let mut tree = Tree::new(root.clone());
1007
1008        // 尝试删除根节点应该失败
1009        let result = tree.remove_node_by_id(&root.id);
1010        assert!(result.is_err());
1011    }
1012
1013    #[test]
1014    fn test_get_parent_node() {
1015        let root = create_test_node("root");
1016        let mut tree = Tree::new(root.clone());
1017
1018        let child = create_test_node("child");
1019        tree.add_node(&root.id, &vec![child.clone()]).unwrap();
1020
1021        let parent = tree.get_parent_node(&child.id).unwrap();
1022        assert_eq!(parent.id, root.id);
1023    }
1024}