// mf_model/tree.rs

1use std::num::NonZeroUsize;
2use std::sync::Arc;
3use std::ops::Index;
4use std::hash::{Hash, Hasher};
5use imbl::Vector;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use once_cell::sync::Lazy;
9use dashmap::DashMap;
10use ahash::{AHasher, RandomState};
11use std::fmt::{self, Debug};
12use crate::error::PoolResult;
13use crate::node_definition::NodeTree;
14use crate::{
15    error::error_helpers,
16    mark::Mark,
17    node::Node,
18    ops::{AttrsRef, MarkRef, NodeRef},
19    types::NodeId,
20};
21
22/// 全局分片索引缓存 - 使用 DashMap 实现无锁并发
23///
24/// # 性能优化
25///
26/// **旧实现 (RwLock + LruCache)**:
27/// - 读操作: ~100ns (需要读锁)
28/// - 写操作: ~500ns (需要写锁,阻塞所有读)
29/// - 高并发: 存在锁竞争
30///
31/// **新实现 (DashMap + AHash)**:
32/// - 读操作: ~20ns (无锁,分片并发)
33/// - 写操作: ~50ns (无锁,只锁单个分片)
34/// - 高并发: 完美扩展,零全局竞争
35///
36/// # 设计决策
37///
38/// 1. **DashMap vs RwLock<HashMap>**: 分片锁,减少竞争
39/// 2. **AHash vs DefaultHasher**: 速度快 3-5x
40/// 3. **无 LRU**: 分片索引计算成本低,缓存淘汰收益小
41static SHARD_INDEX_CACHE: Lazy<DashMap<NodeId, usize, RandomState>> =
42    Lazy::new(|| DashMap::with_capacity_and_hasher(10000, RandomState::new()));
43
/// Node tree whose nodes are spread across several hash-map shards.
///
/// Each node is stored in exactly one shard of `nodes`, selected by hashing its id
/// (see `Tree::get_shard_index`). Parent links are kept separately in `parent_map`.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct Tree {
    /// Id of the root node.
    pub root_id: NodeId,
    /// Sharded node storage: one `imbl::HashMap` per shard, keyed by node id.
    pub nodes: Vector<imbl::HashMap<NodeId, Arc<Node>>>, // sharded node storage
    /// Maps a node id to its parent's id. The constructors never add an entry for the root.
    pub parent_map: imbl::HashMap<NodeId, NodeId>,
    /// Number of shards chosen when the tree was constructed.
    ///
    /// Not serialized (`#[serde(skip)]`), so a deserialized tree holds
    /// `usize::default()` (0) here rather than the real shard count.
    #[serde(skip)]
    num_shards: usize, // shard count cached at construction time
}
52impl Debug for Tree {
53    fn fmt(
54        &self,
55        f: &mut fmt::Formatter<'_>,
56    ) -> fmt::Result {
57        //输出的时候 过滤掉空的 nodes 节点
58        let nodes = self
59            .nodes
60            .iter()
61            .filter(|node| !node.is_empty())
62            .collect::<Vec<_>>();
63        f.debug_struct("Tree")
64            .field("root_id", &self.root_id)
65            .field("nodes", &nodes)
66            .field("parent_map", &self.parent_map)
67            .field("num_shards", &self.num_shards)
68            .finish()
69    }
70}
71
72impl Tree {
73    /// 计算分片索引 (内联,高性能)
74    ///
75    /// # 性能优化
76    ///
77    /// 1. **快速路径**: 缓存命中 ~20ns
78    /// 2. **慢速路径**: AHash 计算 ~50ns (vs DefaultHasher ~150ns)
79    /// 3. **无锁设计**: DashMap 分片锁,零全局竞争
80    ///
81    /// # 实现细节
82    ///
83    /// - 使用 AHash (ahash) 替代 DefaultHasher: 速度提升 3x
84    /// - 使用 DashMap 替代 RwLock: 并发性能提升 5-10x
85    /// - `#[inline(always)]`: 强制内联,消除函数调用开销
86    #[inline(always)]
87    pub fn get_shard_index(
88        &self,
89        id: &NodeId,
90    ) -> usize {
91        // 快速路径:缓存命中(无锁读取)
92        if let Some(index) = SHARD_INDEX_CACHE.get(id) {
93            return *index;
94        }
95
96        // 慢速路径:计算哈希并缓存
97        self.compute_and_cache_shard_index(id)
98    }
99
100    /// 计算并缓存分片索引 (慢速路径,不内联)
101    ///
102    /// 分离到独立函数,避免内联膨胀影响快速路径
103    #[cold]
104    #[inline(never)]
105    fn compute_and_cache_shard_index(
106        &self,
107        id: &NodeId,
108    ) -> usize {
109        // 使用 AHash 计算哈希值 (比 DefaultHasher 快 3x)
110        let mut hasher = AHasher::default();
111        id.hash(&mut hasher);
112        let index = (hasher.finish() as usize) % self.num_shards;
113
114        // 无锁插入缓存 (DashMap 自动处理并发)
115        SHARD_INDEX_CACHE.insert(id.clone(), index);
116
117        index
118    }
119
120    /// 批量获取分片索引
121    ///
122    /// # 性能优化
123    ///
124    /// - 预分配容量,减少重分配
125    /// - 并行友好,无全局锁
126    #[inline]
127    pub fn get_shard_indices(
128        &self,
129        ids: &[&NodeId],
130    ) -> Vec<usize> {
131        ids.iter().map(|id| self.get_shard_index(id)).collect()
132    }
133
134    /// 批量获取分片索引和ID对 (优化版本)
135    ///
136    /// # 性能优化
137    ///
138    /// **旧实现**: 两次锁操作 (读锁检查 + 写锁更新)
139    /// **新实现**: 零全局锁,DashMap 分片并发
140    ///
141    /// 100个ID的性能对比:
142    /// - 旧实现: ~50µs (锁竞争)
143    /// - 新实现: ~5µs (无锁)
144    #[inline]
145    pub fn get_shard_index_batch<'a>(
146        &self,
147        ids: &'a [&'a NodeId],
148    ) -> Vec<(usize, &'a NodeId)> {
149        ids.iter().map(|&id| (self.get_shard_index(id), id)).collect()
150    }
151
152    /// 清理分片缓存 (用于内存管理)
153    ///
154    /// # 注意
155    ///
156    /// 这个操作会清空整个缓存,应该谨慎使用。
157    /// 通常只在内存压力大或测试场景下调用。
158    pub fn clear_shard_cache() {
159        SHARD_INDEX_CACHE.clear();
160    }
161
162    /// 获取缓存统计信息
163    pub fn shard_cache_stats() -> (usize, usize) {
164        let len = SHARD_INDEX_CACHE.len();
165        let capacity = SHARD_INDEX_CACHE.capacity();
166        (len, capacity)
167    }
168
169    pub fn contains_node(
170        &self,
171        id: &NodeId,
172    ) -> bool {
173        let shard_index = self.get_shard_index(id);
174        self.nodes[shard_index].contains_key(id)
175    }
176
177    pub fn get_node(
178        &self,
179        id: &NodeId,
180    ) -> Option<Arc<Node>> {
181        let shard_index = self.get_shard_index(id);
182        self.nodes[shard_index].get(id).cloned()
183    }
184
185    pub fn get_parent_node(
186        &self,
187        id: &NodeId,
188    ) -> Option<Arc<Node>> {
189        self.parent_map.get(id).and_then(|parent_id| {
190            let shard_index = self.get_shard_index(parent_id);
191            self.nodes[shard_index].get(parent_id).cloned()
192        })
193    }
194    pub fn from(nodes: NodeTree) -> Self {
195        let num_shards = std::cmp::max(
196            std::thread::available_parallelism()
197                .map(NonZeroUsize::get)
198                .unwrap_or(2),
199            2,
200        );
201        let mut shards = Vector::from(vec![imbl::HashMap::new(); num_shards]);
202        let mut parent_map = imbl::HashMap::new();
203        let (root_node, children) = nodes.into_parts();
204        let root_id = root_node.id.clone();
205
206        let mut hasher = AHasher::default();
207        root_id.hash(&mut hasher);
208        let shard_index = (hasher.finish() as usize) % num_shards;
209        shards[shard_index] =
210            shards[shard_index].update(root_id.clone(), Arc::new(root_node));
211
212        fn process_children(
213            children: Vec<NodeTree>,
214            parent_id: &NodeId,
215            shards: &mut Vector<imbl::HashMap<NodeId, Arc<Node>>>,
216            parent_map: &mut imbl::HashMap<NodeId, NodeId>,
217            num_shards: usize,
218        ) {
219            for child in children {
220                let (node, grand_children) = child.into_parts();
221                let node_id = node.id.clone();
222                let mut hasher = AHasher::default();
223                node_id.hash(&mut hasher);
224                let shard_index = (hasher.finish() as usize) % num_shards;
225                shards[shard_index] =
226                    shards[shard_index].update(node_id.clone(), Arc::new(node));
227                parent_map.insert(node_id.clone(), parent_id.clone());
228
229                // Recursively process grand children
230                process_children(
231                    grand_children,
232                    &node_id,
233                    shards,
234                    parent_map,
235                    num_shards,
236                );
237            }
238        }
239
240        process_children(
241            children,
242            &root_id,
243            &mut shards,
244            &mut parent_map,
245            num_shards,
246        );
247
248        Self { root_id, nodes: shards, parent_map, num_shards }
249    }
250
251    pub fn new(root: Node) -> Self {
252        let num_shards = std::cmp::max(
253            std::thread::available_parallelism()
254                .map(NonZeroUsize::get)
255                .unwrap_or(2),
256            2,
257        );
258        let mut nodes = Vector::from(vec![imbl::HashMap::new(); num_shards]);
259        let root_id = root.id.clone();
260        let mut hasher = AHasher::default();
261        root_id.hash(&mut hasher);
262        let shard_index = (hasher.finish() as usize) % num_shards;
263        nodes[shard_index] =
264            nodes[shard_index].update(root_id.clone(), Arc::new(root));
265        Self { root_id, nodes, parent_map: imbl::HashMap::new(), num_shards }
266    }
267
268    pub fn update_attr(
269        &mut self,
270        id: &NodeId,
271        new_values: imbl::HashMap<String, Value>,
272    ) -> PoolResult<()> {
273        let shard_index = self.get_shard_index(id);
274        let node = self.nodes[shard_index]
275            .get(id)
276            .ok_or(error_helpers::node_not_found(id.clone()))?;
277        let new_node = node.as_ref().update_attr(new_values);
278        self.nodes[shard_index] =
279            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
280        Ok(())
281    }
282    pub fn update_node(
283        &mut self,
284        node: Node,
285    ) -> PoolResult<()> {
286        let shard_index = self.get_shard_index(&node.id);
287        self.nodes[shard_index] =
288            self.nodes[shard_index].update(node.id.clone(), Arc::new(node));
289        Ok(())
290    }
291
292    /// 向树中添加新的节点及其子节点
293    ///
294    /// # 参数
295    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
296    ///
297    /// # 返回值
298    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
299    ///
300    /// # 错误
301    /// * `PoolError::ParentNotFound` - 如果父节点不存在
302    pub fn add(
303        &mut self,
304        parent_id: &NodeId,
305        nodes: Vec<NodeTree>,
306    ) -> PoolResult<()> {
307        // 检查父节点是否存在
308        let parent_shard_index = self.get_shard_index(parent_id);
309        let parent_node = self.nodes[parent_shard_index]
310            .get(parent_id)
311            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
312        let mut new_parent = parent_node.as_ref().clone();
313
314        // 收集所有子节点的ID并添加到当前节点的content中
315        let zenliang: Vector<NodeId> =
316            nodes.iter().map(|n| n.0.id.clone()).collect();
317        // 需要判断 new_parent.content 中是否已经存在 zenliang 中的节点
318        let mut new_content = imbl::Vector::new();
319        for id in zenliang {
320            if !new_parent.content.contains(&id) {
321                new_content.push_back(id);
322            }
323        }
324        new_parent.content.extend(new_content);
325
326        // 更新当前节点
327        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
328            .update(parent_id.clone(), Arc::new(new_parent));
329
330        // 使用队列进行广度优先遍历,处理所有子节点
331        let mut node_queue = Vec::new();
332        node_queue.push((nodes, parent_id.clone()));
333        while let Some((current_children, current_parent_id)) = node_queue.pop()
334        {
335            for child in current_children {
336                // 处理每个子节点
337                let (mut child_node, grand_children) = child.into_parts();
338                let current_node_id = child_node.id.clone();
339
340                // 收集孙节点的ID并添加到子节点的content中
341                let grand_children_ids: Vector<NodeId> =
342                    grand_children.iter().map(|n| n.0.id.clone()).collect();
343                let mut new_content = imbl::Vector::new();
344                for id in grand_children_ids {
345                    if !child_node.content.contains(&id) {
346                        new_content.push_back(id);
347                    }
348                }
349                child_node.content.extend(new_content);
350
351                // 将当前节点存储到对应的分片中
352                let shard_index = self.get_shard_index(&current_node_id);
353                self.nodes[shard_index] = self.nodes[shard_index]
354                    .update(current_node_id.clone(), Arc::new(child_node));
355
356                // 更新父子关系映射
357                self.parent_map
358                    .insert(current_node_id.clone(), current_parent_id.clone());
359
360                // 将孙节点加入队列,以便后续处理
361                node_queue.push((grand_children, current_node_id.clone()));
362            }
363        }
364        Ok(())
365    }
366    // 添加到下标
367    pub fn add_at_index(
368        &mut self,
369        parent_id: &NodeId,
370        index: usize,
371        node: &Node,
372    ) -> PoolResult<()> {
373        //添加到节点到 parent_id 的 content 中
374        let parent_shard_index = self.get_shard_index(parent_id);
375        let parent = self.nodes[parent_shard_index]
376            .get(parent_id)
377            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
378        let new_parent =
379            parent.as_ref().insert_content_at_index(index, &node.id);
380        //更新父节点
381        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
382            .update(parent_id.clone(), Arc::new(new_parent));
383        //更新父子关系映射
384        self.parent_map.insert(node.id.clone(), parent_id.clone());
385        //更新子节点
386        let shard_index = self.get_shard_index(&node.id);
387        self.nodes[shard_index] = self.nodes[shard_index]
388            .update(node.id.clone(), Arc::new(node.clone()));
389        Ok(())
390    }
391    pub fn add_node(
392        &mut self,
393        parent_id: &NodeId,
394        nodes: &Vec<Node>,
395    ) -> PoolResult<()> {
396        let parent_shard_index = self.get_shard_index(parent_id);
397        let parent = self.nodes[parent_shard_index]
398            .get(parent_id)
399            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
400        let node_ids = nodes.iter().map(|n| n.id.clone()).collect();
401        // 更新父节点 - 添加所有节点的ID到content中
402        let new_parent = parent.as_ref().insert_contents(&node_ids);
403
404        // 更新父节点到分片中
405        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
406            .update(parent_id.clone(), Arc::new(new_parent));
407
408        // 更新所有子节点
409        for node in nodes {
410            // 设置当前节点的父子关系映射
411            self.parent_map.insert(node.id.clone(), parent_id.clone());
412
413            // 设置当前节点的子节点的父子关系映射
414            for child_id in &node.content {
415                self.parent_map.insert(child_id.clone(), node.id.clone());
416            }
417
418            // 将节点添加到对应的分片中
419            let shard_index = self.get_shard_index(&node.id);
420            self.nodes[shard_index] = self.nodes[shard_index]
421                .update(node.id.clone(), Arc::new(node.clone()));
422        }
423        Ok(())
424    }
425
426    pub fn node(
427        &mut self,
428        key: &str,
429    ) -> NodeRef<'_> {
430        NodeRef::new(self, key.into())
431    }
432    pub fn mark(
433        &mut self,
434        key: &str,
435    ) -> MarkRef<'_> {
436        MarkRef::new(self, key.into())
437    }
438    pub fn attrs(
439        &mut self,
440        key: &str,
441    ) -> AttrsRef<'_> {
442        AttrsRef::new(self, key.into())
443    }
444
445    pub fn children(
446        &self,
447        parent_id: &NodeId,
448    ) -> Option<imbl::Vector<NodeId>> {
449        self.get_node(parent_id).map(|n| n.content.clone())
450    }
451
452    pub fn children_node(
453        &self,
454        parent_id: &NodeId,
455    ) -> Option<imbl::Vector<Arc<Node>>> {
456        self.children(parent_id)
457            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
458    }
459    //递归获取所有子节点 封装成 NodeTree 返回
460    pub fn all_children(
461        &self,
462        parent_id: &NodeId,
463        filter: Option<&dyn Fn(&Node) -> bool>,
464    ) -> Option<NodeTree> {
465        if let Some(node) = self.get_node(parent_id) {
466            let mut child_enums = Vec::new();
467            for child_id in &node.content {
468                if let Some(child_node) = self.get_node(child_id) {
469                    // 检查子节点是否满足过滤条件
470                    if let Some(filter_fn) = filter {
471                        if !filter_fn(child_node.as_ref()) {
472                            continue; // 跳过不满足条件的子节点
473                        }
474                    }
475                    // 递归处理满足条件的子节点
476                    if let Some(child_enum) =
477                        self.all_children(child_id, filter)
478                    {
479                        child_enums.push(child_enum);
480                    }
481                }
482            }
483            Some(NodeTree(node.as_ref().clone(), child_enums))
484        } else {
485            None
486        }
487    }
488
489    pub fn children_count(
490        &self,
491        parent_id: &NodeId,
492    ) -> usize {
493        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
494    }
495    pub fn remove_mark_by_name(
496        &mut self,
497        id: &NodeId,
498        mark_name: &str,
499    ) -> PoolResult<()> {
500        let shard_index = self.get_shard_index(id);
501        let node = self.nodes[shard_index]
502            .get(id)
503            .ok_or(error_helpers::node_not_found(id.clone()))?;
504        let new_node = node.as_ref().remove_mark_by_name(mark_name);
505        self.nodes[shard_index] =
506            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
507        Ok(())
508    }
509    pub fn get_marks(
510        &self,
511        id: &NodeId,
512    ) -> Option<imbl::Vector<Mark>> {
513        self.get_node(id).map(|n| n.marks.clone())
514    }
515
516    pub fn remove_mark(
517        &mut self,
518        id: &NodeId,
519        mark_types: &[String],
520    ) -> PoolResult<()> {
521        let shard_index = self.get_shard_index(id);
522        let node = self.nodes[shard_index]
523            .get(id)
524            .ok_or(error_helpers::node_not_found(id.clone()))?;
525        let new_node = node.as_ref().remove_mark(mark_types);
526        self.nodes[shard_index] =
527            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
528        Ok(())
529    }
530
531    pub fn add_mark(
532        &mut self,
533        id: &NodeId,
534        marks: &[Mark],
535    ) -> PoolResult<()> {
536        let shard_index = self.get_shard_index(id);
537        let node = self.nodes[shard_index]
538            .get(id)
539            .ok_or(error_helpers::node_not_found(id.clone()))?;
540        let new_node = node.as_ref().add_marks(marks);
541        self.nodes[shard_index] =
542            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
543        Ok(())
544    }
545
546    pub fn move_node(
547        &mut self,
548        source_parent_id: &NodeId,
549        target_parent_id: &NodeId,
550        node_id: &NodeId,
551        position: Option<usize>,
552    ) -> PoolResult<()> {
553        let source_shard_index = self.get_shard_index(source_parent_id);
554        let target_shard_index = self.get_shard_index(target_parent_id);
555        let node_shard_index = self.get_shard_index(node_id);
556        let source_parent = self.nodes[source_shard_index]
557            .get(source_parent_id)
558            .ok_or(error_helpers::parent_not_found(source_parent_id.clone()))?;
559        let target_parent = self.nodes[target_shard_index]
560            .get(target_parent_id)
561            .ok_or(error_helpers::parent_not_found(target_parent_id.clone()))?;
562        let _node = self.nodes[node_shard_index]
563            .get(node_id)
564            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
565        if !source_parent.content.contains(node_id) {
566            return Err(error_helpers::invalid_parenting(
567                node_id.clone(),
568                source_parent_id.clone(),
569            ));
570        }
571        let mut new_source_parent = source_parent.as_ref().clone();
572        new_source_parent.content = new_source_parent
573            .content
574            .iter()
575            .filter(|&id| id != node_id)
576            .cloned()
577            .collect();
578        let mut new_target_parent = target_parent.as_ref().clone();
579        if let Some(pos) = position {
580            // 确保position不超过当前content的长度
581            let insert_pos = pos.min(new_target_parent.content.len());
582
583            // 在指定位置插入节点
584            new_target_parent.content.insert(insert_pos, node_id.clone());
585        } else {
586            // 没有指定位置,添加到末尾
587            new_target_parent.content.push_back(node_id.clone());
588        }
589        self.nodes[source_shard_index] = self.nodes[source_shard_index]
590            .update(source_parent_id.clone(), Arc::new(new_source_parent));
591        self.nodes[target_shard_index] = self.nodes[target_shard_index]
592            .update(target_parent_id.clone(), Arc::new(new_target_parent));
593        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
594        Ok(())
595    }
596
597    pub fn remove_node(
598        &mut self,
599        parent_id: &NodeId,
600        nodes: Vec<NodeId>,
601    ) -> PoolResult<()> {
602        let parent_shard_index = self.get_shard_index(parent_id);
603        let parent = self.nodes[parent_shard_index]
604            .get(parent_id)
605            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
606        if nodes.contains(&self.root_id) {
607            return Err(error_helpers::cannot_remove_root());
608        }
609        for node_id in &nodes {
610            if !parent.content.contains(node_id) {
611                return Err(error_helpers::invalid_parenting(
612                    node_id.clone(),
613                    parent_id.clone(),
614                ));
615            }
616        }
617        let nodes_to_remove: std::collections::HashSet<_> =
618            nodes.iter().collect();
619        let filtered_children: imbl::Vector<NodeId> = parent
620            .as_ref()
621            .content
622            .iter()
623            .filter(|&id| !nodes_to_remove.contains(id))
624            .cloned()
625            .collect();
626        let mut parent_node = parent.as_ref().clone();
627        parent_node.content = filtered_children;
628        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
629            .update(parent_id.clone(), Arc::new(parent_node));
630        let mut remove_nodes = Vec::new();
631        for node_id in nodes {
632            self.remove_subtree(&node_id, &mut remove_nodes)?;
633        }
634        Ok(())
635    }
636    //=删除节点
637    pub fn remove_node_by_id(
638        &mut self,
639        node_id: &NodeId,
640    ) -> PoolResult<()> {
641        // 检查是否试图删除根节点
642        if node_id == &self.root_id {
643            return Err(error_helpers::cannot_remove_root());
644        }
645
646        let shard_index = self.get_shard_index(node_id);
647        let _ = self.nodes[shard_index]
648            .get(node_id)
649            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
650
651        // 从父节点的content中移除该节点
652        if let Some(parent_id) = self.parent_map.get(node_id).cloned() {
653            let parent_shard_index = self.get_shard_index(&parent_id);
654            if let Some(parent_node) =
655                self.nodes[parent_shard_index].get(&parent_id)
656            {
657                let mut new_parent = parent_node.as_ref().clone();
658                new_parent.content = new_parent
659                    .content
660                    .iter()
661                    .filter(|&id| id != node_id)
662                    .cloned()
663                    .collect();
664                self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
665                    .update(parent_id.clone(), Arc::new(new_parent));
666            }
667        }
668
669        // 删除子树(remove_subtree内部已经处理了节点的删除和parent_map的清理)
670        let mut remove_nodes = Vec::new();
671        self.remove_subtree(node_id, &mut remove_nodes)?;
672
673        // remove_subtree已经删除了所有节点,包括node_id本身,所以这里不需要再次删除
674        Ok(())
675    }
676
677    ///根据下标删除
678    pub fn remove_node_by_index(
679        &mut self,
680        parent_id: &NodeId,
681        index: usize,
682    ) -> PoolResult<()> {
683        let shard_index = self.get_shard_index(parent_id);
684        let parent = self.nodes[shard_index]
685            .get(parent_id)
686            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
687        let mut new_parent = parent.as_ref().clone();
688        let remove_node_id = new_parent.content.remove(index);
689        self.nodes[shard_index] = self.nodes[shard_index]
690            .update(parent_id.clone(), Arc::new(new_parent));
691        let mut remove_nodes = Vec::new();
692        self.remove_subtree(&remove_node_id, &mut remove_nodes)?;
693        Ok(())
694    }
695
696    //删除子树
697    fn remove_subtree(
698        &mut self,
699        node_id: &NodeId,
700        remove_nodes: &mut Vec<Node>,
701    ) -> PoolResult<()> {
702        if node_id == &self.root_id {
703            return Err(error_helpers::cannot_remove_root());
704        }
705        let shard_index = self.get_shard_index(node_id);
706        let _ = self.nodes[shard_index]
707            .get(node_id)
708            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
709        if let Some(children) = self.children(node_id) {
710            for child_id in children {
711                self.remove_subtree(&child_id, remove_nodes)?;
712            }
713        }
714        self.parent_map.remove(node_id);
715        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
716            remove_nodes.push(remove_node.as_ref().clone());
717        }
718        Ok(())
719    }
720}
721
722impl Index<&NodeId> for Tree {
723    type Output = Arc<Node>;
724    fn index(
725        &self,
726        index: &NodeId,
727    ) -> &Self::Output {
728        let shard_index = self.get_shard_index(index);
729        self.nodes[shard_index].get(index).expect("Node not found")
730    }
731}
732
733impl Index<&str> for Tree {
734    type Output = Arc<Node>;
735    fn index(
736        &self,
737        index: &str,
738    ) -> &Self::Output {
739        let node_id = NodeId::from(index);
740        let shard_index = self.get_shard_index(&node_id);
741        self.nodes[shard_index].get(&node_id).expect("Node not found")
742    }
743}
744
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::Node;
    use crate::attrs::Attrs;
    use crate::mark::Mark;
    use imbl::HashMap;
    use serde_json::json;

    /// Leaf node of type "test" with default attrs and no content or marks.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    /// A tree holding only a root node with id "root", plus a copy of that root.
    fn tree_with_root() -> (Node, Tree) {
        let root = create_test_node("root");
        let tree = Tree::new(root.clone());
        (root, tree)
    }

    /// Mark of type "test" with default attrs.
    fn sample_mark() -> Mark {
        Mark { r#type: "test".to_string(), attrs: Attrs::default() }
    }

    #[test]
    fn test_tree_creation() {
        let (root, tree) = tree_with_root();
        assert_eq!(tree.root_id, root.id);
        assert!(tree.contains_node(&root.id));
    }

    #[test]
    fn test_add_node() {
        let (root, mut tree) = tree_with_root();
        let leaf = create_test_node("child");

        tree.add_node(&root.id, &vec![leaf.clone()]).unwrap();
        #[cfg(feature = "debug-logs")]
        dbg!(&tree);

        assert!(tree.contains_node(&leaf.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
    }

    #[test]
    fn test_remove_node() {
        let (root, mut tree) = tree_with_root();
        let leaf = create_test_node("child");

        tree.add_node(&root.id, &vec![leaf.clone()]).unwrap();
        #[cfg(feature = "debug-logs")]
        dbg!(&tree);
        tree.remove_node(&root.id, vec![leaf.id.clone()]).unwrap();
        #[cfg(feature = "debug-logs")]
        dbg!(&tree);

        assert!(!tree.contains_node(&leaf.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 0);
    }

    #[test]
    fn test_move_node() {
        // parent1 is the root; parent2 becomes its first child.
        let parent1 = create_test_node("parent1");
        let parent2 = create_test_node("parent2");
        let mut tree = Tree::new(parent1.clone());
        tree.add_node(&parent1.id, &vec![parent2.clone()]).unwrap();

        // Three more children of parent1, added one call at a time.
        let (child1, child2, child3) = (
            create_test_node("child1"),
            create_test_node("child2"),
            create_test_node("child3"),
        );
        for kid in [&child1, &child2, &child3] {
            tree.add_node(&parent1.id, &vec![kid.clone()]).unwrap();
        }

        // Initial layout: parent2 followed by the three children, in insertion order.
        let before = tree.children(&parent1.id).unwrap();
        assert_eq!(before.len(), 4);
        assert_eq!(before[0], parent2.id);
        assert_eq!(before[1], child1.id);
        assert_eq!(before[2], child2.id);
        assert_eq!(before[3], child3.id);

        // child1 goes to the end of parent2.
        tree.move_node(&parent1.id, &parent2.id, &child1.id, None).unwrap();
        let p1_kids = tree.children(&parent1.id).unwrap();
        let p2_kids = tree.children(&parent2.id).unwrap();
        assert_eq!(p1_kids.len(), 3); // parent2 + 2 children
        assert_eq!(p2_kids.len(), 1);
        assert_eq!(p2_kids[0], child1.id);

        // child2 goes to parent2 at position 1, i.e. right after child1.
        tree.move_node(&parent1.id, &parent2.id, &child2.id, Some(1)).unwrap();
        let p1_kids = tree.children(&parent1.id).unwrap();
        let p2_kids = tree.children(&parent2.id).unwrap();
        assert_eq!(p1_kids.len(), 2); // parent2 + 1 child
        assert_eq!(p2_kids.len(), 2);
        assert_eq!(p2_kids[0], child1.id);
        assert_eq!(p2_kids[1], child2.id);

        // Both moved children now report parent2 as their parent.
        for moved in [&child1, &child2] {
            assert_eq!(tree.get_parent_node(&moved.id).unwrap().id, parent2.id);
        }
    }

    #[test]
    fn test_update_attr() {
        let (root, mut tree) = tree_with_root();
        let mut values = HashMap::new();
        values.insert("key".to_string(), json!("value"));

        tree.update_attr(&root.id, values).unwrap();

        let node = tree.get_node(&root.id).unwrap();
        #[cfg(feature = "debug-logs")]
        dbg!(&node);
        assert_eq!(node.attrs.get("key").unwrap(), &json!("value"));
    }

    #[test]
    fn test_add_mark() {
        let (root, mut tree) = tree_with_root();
        let mark = sample_mark();

        tree.add_mark(&root.id, &[mark.clone()]).unwrap();
        #[cfg(feature = "debug-logs")]
        dbg!(&tree);

        assert!(tree.get_node(&root.id).unwrap().marks.contains(&mark));
    }

    #[test]
    fn test_remove_mark() {
        let (root, mut tree) = tree_with_root();
        let mark = sample_mark();

        tree.add_mark(&root.id, &[mark.clone()]).unwrap();
        #[cfg(feature = "debug-logs")]
        dbg!(&tree);
        tree.remove_mark(&root.id, &[mark.r#type.clone()]).unwrap();
        #[cfg(feature = "debug-logs")]
        dbg!(&tree);

        let node = tree.get_node(&root.id).unwrap();
        assert!(node.marks.iter().all(|m| m.r#type != mark.r#type));
    }

    #[test]
    fn test_all_children() {
        let (root, mut tree) = tree_with_root();
        for id in ["child1", "child2"] {
            tree.add_node(&root.id, &vec![create_test_node(id)]).unwrap();
        }
        #[cfg(feature = "debug-logs")]
        dbg!(&tree);

        let subtree = tree.all_children(&root.id, None).unwrap();
        assert_eq!(subtree.1.len(), 2);
    }

    #[test]
    fn test_children_count() {
        let (root, mut tree) = tree_with_root();
        for id in ["child1", "child2"] {
            tree.add_node(&root.id, &vec![create_test_node(id)]).unwrap();
        }

        assert_eq!(tree.children_count(&root.id), 2);
    }

    #[test]
    fn test_remove_node_by_id_updates_parent() {
        let (root, mut tree) = tree_with_root();
        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        // The child is attached before removal...
        assert_eq!(tree.children_count(&root.id), 1);
        assert!(tree.contains_node(&child.id));

        tree.remove_node_by_id(&child.id).unwrap();

        // ...and both the node and the parent's reference to it are gone afterwards.
        assert_eq!(tree.children_count(&root.id), 0);
        assert!(!tree.contains_node(&child.id));
    }

    #[test]
    fn test_move_node_position_edge_cases() {
        let (root, mut tree) = tree_with_root();
        let container = create_test_node("container");
        tree.add_node(&root.id, &vec![container.clone()]).unwrap();

        let (child1, child2, child3) = (
            create_test_node("child1"),
            create_test_node("child2"),
            create_test_node("child3"),
        );
        for kid in [&child1, &child2, &child3] {
            tree.add_node(&root.id, &vec![kid.clone()]).unwrap();
        }

        // An out-of-range position is clamped to the end.
        tree.move_node(&root.id, &container.id, &child1.id, Some(100)).unwrap();
        let inside = tree.children(&container.id).unwrap();
        assert_eq!(inside.len(), 1);
        assert_eq!(inside[0], child1.id);

        // Position 0 inserts at the front.
        tree.move_node(&root.id, &container.id, &child2.id, Some(0)).unwrap();
        let inside = tree.children(&container.id).unwrap();
        assert_eq!(inside.len(), 2);
        assert_eq!(inside[0], child2.id);
        assert_eq!(inside[1], child1.id);
    }

    #[test]
    fn test_cannot_remove_root_node() {
        let (root, mut tree) = tree_with_root();
        assert!(tree.remove_node_by_id(&root.id).is_err());
    }

    #[test]
    fn test_get_parent_node() {
        let (root, mut tree) = tree_with_root();
        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        assert_eq!(tree.get_parent_node(&child.id).unwrap().id, root.id);
    }
}