mf_model/
tree.rs

1use std::sync::Arc;
2use std::{ops::Index, num::NonZeroUsize};
3use std::hash::{Hash, Hasher};
4use std::collections::hash_map::DefaultHasher;
5use imbl::Vector;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use once_cell::sync::Lazy;
9use parking_lot::RwLock;
10use lru::LruCache;
11use std::fmt::{self, Debug};
12use crate::error::PoolResult;
13use crate::node_type::NodeEnum;
14use crate::{
15    error::error_helpers,
16    mark::Mark,
17    node::Node,
18    ops::{AttrsRef, MarkRef, NodeRef},
19    types::NodeId,
20};
21
22// 全局LRU缓存用于存储NodeId到分片索引的映射
23static SHARD_INDEX_CACHE: Lazy<RwLock<LruCache<NodeId, usize>>> =
24    Lazy::new(|| {
25        RwLock::new(LruCache::new(
26            NonZeroUsize::new(10000).expect("cache size > 0"),
27        ))
28    });
29
30#[derive(Clone, PartialEq, Serialize, Deserialize)]
31pub struct Tree {
32    pub root_id: NodeId,
33    pub nodes: Vector<imbl::HashMap<NodeId, Arc<Node>>>, // 分片存储节点数据
34    pub parent_map: imbl::HashMap<NodeId, NodeId>,
35    #[serde(skip)]
36    num_shards: usize, // 缓存分片数量,避免重复计算
37}
38impl Debug for Tree {
39    fn fmt(
40        &self,
41        f: &mut fmt::Formatter<'_>,
42    ) -> fmt::Result {
43        //输出的时候 过滤掉空的 nodes 节点
44        let nodes = self
45            .nodes
46            .iter()
47            .filter(|node| !node.is_empty())
48            .collect::<Vec<_>>();
49        f.debug_struct("Tree")
50            .field("root_id", &self.root_id)
51            .field("nodes", &nodes)
52            .field("parent_map", &self.parent_map)
53            .field("num_shards", &self.num_shards)
54            .finish()
55    }
56}
57
58impl Tree {
59    #[inline]
60    pub fn get_shard_index(
61        &self,
62        id: &NodeId,
63    ) -> usize {
64        // 先检查缓存
65        {
66            let cache = SHARD_INDEX_CACHE.read();
67            if let Some(&index) = cache.peek(id) {
68                return index;
69            }
70        }
71
72        // 缓存未命中,计算哈希值
73        let mut hasher = DefaultHasher::new();
74        id.hash(&mut hasher);
75        let index = (hasher.finish() as usize) % self.num_shards;
76
77        // 更新缓存
78        {
79            let mut cache = SHARD_INDEX_CACHE.write();
80            cache.put(id.clone(), index);
81        }
82
83        index
84    }
85
86    #[inline]
87    pub fn get_shard_indices(
88        &self,
89        ids: &[&NodeId],
90    ) -> Vec<usize> {
91        ids.iter().map(|id| self.get_shard_index(id)).collect()
92    }
93
94    // 为批量操作提供优化的哈希计算
95    #[inline]
96    pub fn get_shard_index_batch<'a>(
97        &self,
98        ids: &'a [&'a NodeId],
99    ) -> Vec<(usize, &'a NodeId)> {
100        let mut results = Vec::with_capacity(ids.len());
101        let mut cache_misses = Vec::new();
102
103        // 批量检查缓存
104        {
105            let cache = SHARD_INDEX_CACHE.read();
106            for &id in ids {
107                if let Some(&index) = cache.peek(id) {
108                    results.push((index, id));
109                } else {
110                    cache_misses.push(id);
111                }
112            }
113        }
114
115        // 批量计算缓存未命中的项
116        if !cache_misses.is_empty() {
117            let mut cache = SHARD_INDEX_CACHE.write();
118            for &id in &cache_misses {
119                let mut hasher = DefaultHasher::new();
120                id.hash(&mut hasher);
121                let index = (hasher.finish() as usize) % self.num_shards;
122                cache.put(id.clone(), index);
123                results.push((index, id));
124            }
125        }
126
127        results
128    }
129
130    // 清理缓存的方法,用于内存管理
131    pub fn clear_shard_cache() {
132        let mut cache = SHARD_INDEX_CACHE.write();
133        cache.clear();
134    }
135
136    pub fn contains_node(
137        &self,
138        id: &NodeId,
139    ) -> bool {
140        let shard_index = self.get_shard_index(id);
141        self.nodes[shard_index].contains_key(id)
142    }
143
144    pub fn get_node(
145        &self,
146        id: &NodeId,
147    ) -> Option<Arc<Node>> {
148        let shard_index = self.get_shard_index(id);
149        self.nodes[shard_index].get(id).cloned()
150    }
151
152    pub fn get_parent_node(
153        &self,
154        id: &NodeId,
155    ) -> Option<Arc<Node>> {
156        self.parent_map.get(id).and_then(|parent_id| {
157            let shard_index = self.get_shard_index(parent_id);
158            self.nodes[shard_index].get(parent_id).cloned()
159        })
160    }
161    pub fn from(nodes: NodeEnum) -> Self {
162        let num_shards = std::cmp::max(
163            std::thread::available_parallelism()
164                .map(NonZeroUsize::get)
165                .unwrap_or(2),
166            2,
167        );
168        let mut shards = Vector::from(vec![imbl::HashMap::new(); num_shards]);
169        let mut parent_map = imbl::HashMap::new();
170        let (root_node, children) = nodes.into_parts();
171        let root_id = root_node.id.clone();
172
173        let mut hasher = DefaultHasher::new();
174        root_id.hash(&mut hasher);
175        let shard_index = (hasher.finish() as usize) % num_shards;
176        shards[shard_index] =
177            shards[shard_index].update(root_id.clone(), Arc::new(root_node));
178
179        fn process_children(
180            children: Vec<NodeEnum>,
181            parent_id: &NodeId,
182            shards: &mut Vector<imbl::HashMap<NodeId, Arc<Node>>>,
183            parent_map: &mut imbl::HashMap<NodeId, NodeId>,
184            num_shards: usize,
185        ) {
186            for child in children {
187                let (node, grand_children) = child.into_parts();
188                let node_id = node.id.clone();
189                let mut hasher = DefaultHasher::new();
190                node_id.hash(&mut hasher);
191                let shard_index = (hasher.finish() as usize) % num_shards;
192                shards[shard_index] =
193                    shards[shard_index].update(node_id.clone(), Arc::new(node));
194                parent_map.insert(node_id.clone(), parent_id.clone());
195
196                // Recursively process grand children
197                process_children(
198                    grand_children,
199                    &node_id,
200                    shards,
201                    parent_map,
202                    num_shards,
203                );
204            }
205        }
206
207        process_children(
208            children,
209            &root_id,
210            &mut shards,
211            &mut parent_map,
212            num_shards,
213        );
214
215        Self { root_id, nodes: shards, parent_map, num_shards }
216    }
217
218    pub fn new(root: Node) -> Self {
219        let num_shards = std::cmp::max(
220            std::thread::available_parallelism()
221                .map(NonZeroUsize::get)
222                .unwrap_or(2),
223            2,
224        );
225        let mut nodes = Vector::from(vec![imbl::HashMap::new(); num_shards]);
226        let root_id = root.id.clone();
227        let mut hasher = DefaultHasher::new();
228        root_id.hash(&mut hasher);
229        let shard_index = (hasher.finish() as usize) % num_shards;
230        nodes[shard_index] =
231            nodes[shard_index].update(root_id.clone(), Arc::new(root));
232        Self { root_id, nodes, parent_map: imbl::HashMap::new(), num_shards }
233    }
234
235    pub fn update_attr(
236        &mut self,
237        id: &NodeId,
238        new_values: imbl::HashMap<String, Value>,
239    ) -> PoolResult<()> {
240        let shard_index = self.get_shard_index(id);
241        let node = self.nodes[shard_index]
242            .get(id)
243            .ok_or(error_helpers::node_not_found(id.clone()))?;
244        let new_node = node.as_ref().update_attr(new_values);
245        self.nodes[shard_index] =
246            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
247        Ok(())
248    }
249    pub fn update_node(
250        &mut self,
251        node: Node,
252    ) -> PoolResult<()> {
253        let shard_index = self.get_shard_index(&node.id);
254        self.nodes[shard_index] =
255            self.nodes[shard_index].update(node.id.clone(), Arc::new(node));
256        Ok(())
257    }
258
259    /// 向树中添加新的节点及其子节点
260    ///
261    /// # 参数
262    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
263    ///
264    /// # 返回值
265    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
266    ///
267    /// # 错误
268    /// * `PoolError::ParentNotFound` - 如果父节点不存在
269    pub fn add(
270        &mut self,
271        parent_id: &NodeId,
272        nodes: Vec<NodeEnum>,
273    ) -> PoolResult<()> {
274        // 检查父节点是否存在
275        let parent_shard_index = self.get_shard_index(&parent_id);
276        let parent_node = self.nodes[parent_shard_index]
277            .get(parent_id)
278            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
279        let mut new_parent = parent_node.as_ref().clone();
280
281        // 收集所有子节点的ID并添加到当前节点的content中
282        let zenliang: Vector<NodeId> =
283            nodes.iter().map(|n| n.0.id.clone()).collect();
284        // 需要判断 new_parent.content 中是否已经存在 zenliang 中的节点
285        let mut new_content = imbl::Vector::new();
286        for id in zenliang {
287            if !new_parent.content.contains(&id) {
288                new_content.push_back(id);
289            }
290        }
291        new_parent.content.extend(new_content);
292
293        // 更新当前节点
294        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
295            .update(parent_id.clone(), Arc::new(new_parent));
296
297        // 使用队列进行广度优先遍历,处理所有子节点
298        let mut node_queue = Vec::new();
299        node_queue.push((nodes, parent_id.clone()));
300        while let Some((current_children, current_parent_id)) = node_queue.pop()
301        {
302            for child in current_children {
303                // 处理每个子节点
304                let (mut child_node, grand_children) = child.into_parts();
305                let current_node_id = child_node.id.clone();
306
307                // 收集孙节点的ID并添加到子节点的content中
308                let grand_children_ids: Vector<NodeId> =
309                    grand_children.iter().map(|n| n.0.id.clone()).collect();
310                let mut new_content = imbl::Vector::new();
311                for id in grand_children_ids {
312                    if !child_node.content.contains(&id) {
313                        new_content.push_back(id);
314                    }
315                }
316                child_node.content.extend(new_content);
317
318                // 将当前节点存储到对应的分片中
319                let shard_index = self.get_shard_index(&current_node_id);
320                self.nodes[shard_index] = self.nodes[shard_index]
321                    .update(current_node_id.clone(), Arc::new(child_node));
322
323                // 更新父子关系映射
324                self.parent_map
325                    .insert(current_node_id.clone(), current_parent_id.clone());
326
327                // 将孙节点加入队列,以便后续处理
328                node_queue.push((grand_children, current_node_id.clone()));
329            }
330        }
331        Ok(())
332    }
333    // 添加到下标
334    pub fn add_at_index(
335        &mut self,
336        parent_id: &NodeId,
337        index: usize,
338        node: &Node,
339    ) -> PoolResult<()> {
340        //添加到节点到 parent_id 的 content 中
341        let parent_shard_index = self.get_shard_index(parent_id);
342        let parent = self.nodes[parent_shard_index]
343            .get(parent_id)
344            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
345        let new_parent =
346            parent.as_ref().insert_content_at_index(index, &node.id);
347        //更新父节点
348        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
349            .update(parent_id.clone(), Arc::new(new_parent));
350        //更新父子关系映射
351        self.parent_map.insert(node.id.clone(), parent_id.clone());
352        //更新子节点
353        let shard_index = self.get_shard_index(&node.id);
354        self.nodes[shard_index] = self.nodes[shard_index]
355            .update(node.id.clone(), Arc::new(node.clone()));
356        Ok(())
357    }
358    pub fn add_node(
359        &mut self,
360        parent_id: &NodeId,
361        nodes: &Vec<Node>,
362    ) -> PoolResult<()> {
363        let parent_shard_index = self.get_shard_index(parent_id);
364        let parent = self.nodes[parent_shard_index]
365            .get(parent_id)
366            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
367        let node_ids = nodes.iter().map(|n| n.id.clone()).collect();
368        // 更新父节点 - 添加所有节点的ID到content中
369        let new_parent = parent.as_ref().insert_contents(&node_ids);
370
371        // 更新父节点到分片中
372        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
373            .update(parent_id.clone(), Arc::new(new_parent));
374
375        // 更新所有子节点
376        for node in nodes {
377            // 设置当前节点的父子关系映射
378            self.parent_map.insert(node.id.clone(), parent_id.clone());
379
380            // 设置当前节点的子节点的父子关系映射
381            for child_id in &node.content {
382                self.parent_map.insert(child_id.clone(), node.id.clone());
383            }
384
385            // 将节点添加到对应的分片中
386            let shard_index = self.get_shard_index(&node.id);
387            self.nodes[shard_index] = self.nodes[shard_index]
388                .update(node.id.clone(), Arc::new(node.clone()));
389        }
390        Ok(())
391    }
392
393    pub fn node(
394        &mut self,
395        key: &str,
396    ) -> NodeRef<'_> {
397        NodeRef::new(self, key.into())
398    }
399    pub fn mark(
400        &mut self,
401        key: &str,
402    ) -> MarkRef<'_> {
403        MarkRef::new(self, key.into())
404    }
405    pub fn attrs(
406        &mut self,
407        key: &str,
408    ) -> AttrsRef<'_> {
409        AttrsRef::new(self, key.into())
410    }
411
412    pub fn children(
413        &self,
414        parent_id: &NodeId,
415    ) -> Option<imbl::Vector<NodeId>> {
416        self.get_node(parent_id).map(|n| n.content.clone())
417    }
418
419    pub fn children_node(
420        &self,
421        parent_id: &NodeId,
422    ) -> Option<imbl::Vector<Arc<Node>>> {
423        self.children(parent_id)
424            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
425    }
426    //递归获取所有子节点 封装成 NodeEnum 返回
427    pub fn all_children(
428        &self,
429        parent_id: &NodeId,
430        filter: Option<&dyn Fn(&Node) -> bool>,
431    ) -> Option<NodeEnum> {
432        if let Some(node) = self.get_node(parent_id) {
433            let mut child_enums = Vec::new();
434            for child_id in &node.content {
435                if let Some(child_node) = self.get_node(child_id) {
436                    // 检查子节点是否满足过滤条件
437                    if let Some(filter_fn) = filter {
438                        if !filter_fn(child_node.as_ref()) {
439                            continue; // 跳过不满足条件的子节点
440                        }
441                    }
442                    // 递归处理满足条件的子节点
443                    if let Some(child_enum) =
444                        self.all_children(child_id, filter)
445                    {
446                        child_enums.push(child_enum);
447                    }
448                }
449            }
450            Some(NodeEnum(node.as_ref().clone(), child_enums))
451        } else {
452            None
453        }
454    }
455
456    pub fn children_count(
457        &self,
458        parent_id: &NodeId,
459    ) -> usize {
460        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
461    }
462    pub fn remove_mark_by_name(
463        &mut self,
464        id: &NodeId,
465        mark_name: &str,
466    ) -> PoolResult<()> {
467        let shard_index = self.get_shard_index(id);
468        let node = self.nodes[shard_index]
469            .get(id)
470            .ok_or(error_helpers::node_not_found(id.clone()))?;
471        let new_node = node.as_ref().remove_mark_by_name(mark_name);
472        self.nodes[shard_index] =
473            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
474        Ok(())
475    }
476    pub fn get_marks(
477        &self,
478        id: &NodeId,
479    ) -> Option<imbl::Vector<Mark>> {
480        self.get_node(id).map(|n| n.marks.clone())
481    }
482
483    pub fn remove_mark(
484        &mut self,
485        id: &NodeId,
486        mark_types: &[String],
487    ) -> PoolResult<()> {
488        let shard_index = self.get_shard_index(id);
489        let node = self.nodes[shard_index]
490            .get(id)
491            .ok_or(error_helpers::node_not_found(id.clone()))?;
492        let new_node = node.as_ref().remove_mark(mark_types);
493        self.nodes[shard_index] =
494            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
495        Ok(())
496    }
497
498    pub fn add_mark(
499        &mut self,
500        id: &NodeId,
501        marks: &Vec<Mark>,
502    ) -> PoolResult<()> {
503        let shard_index = self.get_shard_index(id);
504        let node = self.nodes[shard_index]
505            .get(id)
506            .ok_or(error_helpers::node_not_found(id.clone()))?;
507        let new_node = node.as_ref().add_marks(marks);
508        self.nodes[shard_index] =
509            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
510        Ok(())
511    }
512
513    pub fn move_node(
514        &mut self,
515        source_parent_id: &NodeId,
516        target_parent_id: &NodeId,
517        node_id: &NodeId,
518        position: Option<usize>,
519    ) -> PoolResult<()> {
520        let source_shard_index = self.get_shard_index(source_parent_id);
521        let target_shard_index = self.get_shard_index(target_parent_id);
522        let node_shard_index = self.get_shard_index(node_id);
523        let source_parent = self.nodes[source_shard_index]
524            .get(source_parent_id)
525            .ok_or(error_helpers::parent_not_found(source_parent_id.clone()))?;
526        let target_parent = self.nodes[target_shard_index]
527            .get(target_parent_id)
528            .ok_or(error_helpers::parent_not_found(target_parent_id.clone()))?;
529        let _node = self.nodes[node_shard_index]
530            .get(node_id)
531            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
532        if !source_parent.content.contains(node_id) {
533            return Err(error_helpers::invalid_parenting(
534                node_id.clone(),
535                source_parent_id.clone(),
536            ));
537        }
538        let mut new_source_parent = source_parent.as_ref().clone();
539        new_source_parent.content = new_source_parent
540            .content
541            .iter()
542            .filter(|&id| id != node_id)
543            .cloned()
544            .collect();
545        let mut new_target_parent = target_parent.as_ref().clone();
546        if let Some(pos) = position {
547            // 确保position不超过当前content的长度
548            let insert_pos = pos.min(new_target_parent.content.len());
549
550            // 在指定位置插入节点
551            new_target_parent.content.insert(insert_pos, node_id.clone());
552        } else {
553            // 没有指定位置,添加到末尾
554            new_target_parent.content.push_back(node_id.clone());
555        }
556        self.nodes[source_shard_index] = self.nodes[source_shard_index]
557            .update(source_parent_id.clone(), Arc::new(new_source_parent));
558        self.nodes[target_shard_index] = self.nodes[target_shard_index]
559            .update(target_parent_id.clone(), Arc::new(new_target_parent));
560        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
561        Ok(())
562    }
563
564    pub fn remove_node(
565        &mut self,
566        parent_id: &NodeId,
567        nodes: Vec<NodeId>,
568    ) -> PoolResult<()> {
569        let parent_shard_index = self.get_shard_index(parent_id);
570        let parent = self.nodes[parent_shard_index]
571            .get(parent_id)
572            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
573        if nodes.contains(&self.root_id) {
574            return Err(error_helpers::cannot_remove_root());
575        }
576        for node_id in &nodes {
577            if !parent.content.contains(node_id) {
578                return Err(error_helpers::invalid_parenting(
579                    node_id.clone(),
580                    parent_id.clone(),
581                ));
582            }
583        }
584        let nodes_to_remove: std::collections::HashSet<_> =
585            nodes.iter().collect();
586        let filtered_children: imbl::Vector<NodeId> = parent
587            .as_ref()
588            .content
589            .iter()
590            .filter(|&id| !nodes_to_remove.contains(id))
591            .cloned()
592            .collect();
593        let mut parent_node = parent.as_ref().clone();
594        parent_node.content = filtered_children;
595        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
596            .update(parent_id.clone(), Arc::new(parent_node));
597        let mut remove_nodes = Vec::new();
598        for node_id in nodes {
599            self.remove_subtree(&node_id, &mut remove_nodes)?;
600        }
601        Ok(())
602    }
603    //=删除节点
604    pub fn remove_node_by_id(
605        &mut self,
606        node_id: &NodeId,
607    ) -> PoolResult<()> {
608        // 检查是否试图删除根节点
609        if node_id == &self.root_id {
610            return Err(error_helpers::cannot_remove_root());
611        }
612
613        let shard_index = self.get_shard_index(node_id);
614        let _ = self.nodes[shard_index]
615            .get(node_id)
616            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
617
618        // 从父节点的content中移除该节点
619        if let Some(parent_id) = self.parent_map.get(node_id).cloned() {
620            let parent_shard_index = self.get_shard_index(&parent_id);
621            if let Some(parent_node) =
622                self.nodes[parent_shard_index].get(&parent_id)
623            {
624                let mut new_parent = parent_node.as_ref().clone();
625                new_parent.content = new_parent
626                    .content
627                    .iter()
628                    .filter(|&id| id != node_id)
629                    .cloned()
630                    .collect();
631                self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
632                    .update(parent_id.clone(), Arc::new(new_parent));
633            }
634        }
635
636        // 删除子树(remove_subtree内部已经处理了节点的删除和parent_map的清理)
637        let mut remove_nodes = Vec::new();
638        self.remove_subtree(node_id, &mut remove_nodes)?;
639
640        // remove_subtree已经删除了所有节点,包括node_id本身,所以这里不需要再次删除
641        Ok(())
642    }
643
644    ///根据下标删除
645    pub fn remove_node_by_index(
646        &mut self,
647        parent_id: &NodeId,
648        index: usize,
649    ) -> PoolResult<()> {
650        let shard_index = self.get_shard_index(parent_id);
651        let parent = self.nodes[shard_index]
652            .get(parent_id)
653            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
654        let mut new_parent = parent.as_ref().clone();
655        let remove_node_id = new_parent.content.remove(index);
656        self.nodes[shard_index] = self.nodes[shard_index]
657            .update(parent_id.clone(), Arc::new(new_parent));
658        let mut remove_nodes = Vec::new();
659        self.remove_subtree(&remove_node_id, &mut remove_nodes)?;
660        Ok(())
661    }
662
663    //删除子树
664    fn remove_subtree(
665        &mut self,
666        node_id: &NodeId,
667        remove_nodes: &mut Vec<Node>,
668    ) -> PoolResult<()> {
669        if node_id == &self.root_id {
670            return Err(error_helpers::cannot_remove_root());
671        }
672        let shard_index = self.get_shard_index(node_id);
673        let _ = self.nodes[shard_index]
674            .get(node_id)
675            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
676        if let Some(children) = self.children(node_id) {
677            for child_id in children {
678                self.remove_subtree(&child_id, remove_nodes)?;
679            }
680        }
681        self.parent_map.remove(node_id);
682        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
683            remove_nodes.push(remove_node.as_ref().clone());
684        }
685        Ok(())
686    }
687}
688
689impl Index<&NodeId> for Tree {
690    type Output = Arc<Node>;
691    fn index(
692        &self,
693        index: &NodeId,
694    ) -> &Self::Output {
695        let shard_index = self.get_shard_index(index);
696        self.nodes[shard_index].get(index).expect("Node not found")
697    }
698}
699
700impl Index<&str> for Tree {
701    type Output = Arc<Node>;
702    fn index(
703        &self,
704        index: &str,
705    ) -> &Self::Output {
706        let node_id = NodeId::from(index);
707        let shard_index = self.get_shard_index(&node_id);
708        self.nodes[shard_index].get(&node_id).expect("Node not found")
709    }
710}
711
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attrs::Attrs;
    use crate::mark::Mark;
    use crate::node::Node;
    use crate::node_type::NodeEnum;
    use crate::types::NodeId;
    use imbl::HashMap;
    use serde_json::json;

    /// Builds a childless node of type "test" with default attrs and no marks.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    #[test]
    fn test_tree_creation() {
        let root = create_test_node("root");
        let tree = Tree::new(root.clone());
        assert_eq!(tree.root_id, root.id);
        assert!(tree.contains_node(&root.id));
    }

    #[test]
    fn test_add_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        assert!(tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
    }

    #[test]
    fn test_remove_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        tree.remove_node(&root.id, vec![child.id.clone()]).unwrap();
        assert!(!tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 0);
    }

    #[test]
    fn test_move_node() {
        // Two parent nodes.
        let parent1 = create_test_node("parent1");
        let parent2 = create_test_node("parent2");
        let mut tree = Tree::new(parent1.clone());

        // parent2 becomes a child of parent1.
        tree.add_node(&parent1.id, &vec![parent2.clone()]).unwrap();

        // Three further children.
        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        // All of them go under parent1.
        tree.add_node(&parent1.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child3.clone()]).unwrap();

        // Initial state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        assert_eq!(parent1_children.len(), 4); // parent2 + 3 children
        assert_eq!(parent1_children[0], parent2.id);
        assert_eq!(parent1_children[1], child1.id);
        assert_eq!(parent1_children[2], child2.id);
        assert_eq!(parent1_children[3], child3.id);

        // Move child1 under parent2.
        tree.move_node(&parent1.id, &parent2.id, &child1.id, None).unwrap();

        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 3); // parent2 + 2 children
        assert_eq!(parent2_children.len(), 1); // child1
        assert_eq!(parent2_children[0], child1.id);

        // Move child2 under parent2, after child1.
        tree.move_node(&parent1.id, &parent2.id, &child2.id, Some(1)).unwrap();

        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 2); // parent2 + 1 child
        assert_eq!(parent2_children.len(), 2); // child1 + child2
        assert_eq!(parent2_children[0], child1.id);
        assert_eq!(parent2_children[1], child2.id);

        // Parent links follow the moves.
        let child1_parent = tree.get_parent_node(&child1.id).unwrap();
        let child2_parent = tree.get_parent_node(&child2.id).unwrap();
        assert_eq!(child1_parent.id, parent2.id);
        assert_eq!(child2_parent.id, parent2.id);
    }

    #[test]
    fn test_update_attr() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mut attrs = HashMap::new();
        attrs.insert("key".to_string(), json!("value"));

        tree.update_attr(&root.id, attrs).unwrap();

        let node = tree.get_node(&root.id).unwrap();
        assert_eq!(node.attrs.get("key").unwrap(), &json!("value"));
    }

    #[test]
    fn test_add_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(node.marks.contains(&mark));
    }

    #[test]
    fn test_remove_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        tree.remove_mark(&root.id, &[mark.r#type.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(!node.marks.iter().any(|m| m.r#type == mark.r#type));
    }

    #[test]
    fn test_all_children() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        let all_children = tree.all_children(&root.id, None).unwrap();
        assert_eq!(all_children.1.len(), 2);
    }

    #[test]
    fn test_children_count() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        assert_eq!(tree.children_count(&root.id), 2);
    }

    #[test]
    fn test_remove_node_by_id_updates_parent() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        // The child was added.
        assert_eq!(tree.children_count(&root.id), 1);
        assert!(tree.contains_node(&child.id));

        // Remove it.
        tree.remove_node_by_id(&child.id).unwrap();

        // The child is gone and the parent's content was updated.
        assert_eq!(tree.children_count(&root.id), 0);
        assert!(!tree.contains_node(&child.id));
    }

    #[test]
    fn test_move_node_position_edge_cases() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let container = create_test_node("container");
        tree.add_node(&root.id, &vec![container.clone()]).unwrap();

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child3.clone()]).unwrap();

        // An out-of-range position appends at the end.
        tree.move_node(&root.id, &container.id, &child1.id, Some(100)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 1);
        assert_eq!(container_children[0], child1.id);

        // Position 0 inserts at the front.
        tree.move_node(&root.id, &container.id, &child2.id, Some(0)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 2);
        assert_eq!(container_children[0], child2.id);
        assert_eq!(container_children[1], child1.id);
    }

    #[test]
    fn test_cannot_remove_root_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        // Removing the root must fail.
        let result = tree.remove_node_by_id(&root.id);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_parent_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        let parent = tree.get_parent_node(&child.id).unwrap();
        assert_eq!(parent.id, root.id);
    }

    #[test]
    fn test_shard_index_batch_matches_single_lookup() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());
        let children: Vec<Node> =
            (0..8).map(|i| create_test_node(&format!("node{}", i))).collect();
        tree.add_node(&root.id, &children).unwrap();

        let ids: Vec<&NodeId> = children.iter().map(|n| &n.id).collect();
        let batch = tree.get_shard_index_batch(&ids);
        assert_eq!(batch.len(), ids.len());
        for (index, id) in batch {
            assert!(index < tree.nodes.len());
            assert_eq!(index, tree.get_shard_index(id));
            assert!(tree.contains_node(id));
        }
    }

    #[test]
    fn test_add_nested_node_enum() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());
        let branch = create_test_node("branch");
        let leaf = create_test_node("leaf");
        let subtree =
            NodeEnum(branch.clone(), vec![NodeEnum(leaf.clone(), vec![])]);

        tree.add(&root.id, vec![subtree]).unwrap();

        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
        assert_eq!(tree.children(&branch.id).unwrap().len(), 1);
        assert_eq!(tree.children(&branch.id).unwrap()[0], leaf.id);
        assert_eq!(tree.get_parent_node(&leaf.id).unwrap().id, branch.id);
        assert_eq!(tree.get_parent_node(&branch.id).unwrap().id, root.id);
    }

    #[test]
    fn test_add_does_not_duplicate_existing_child_ids() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());
        let child = create_test_node("child");

        tree.add(&root.id, vec![NodeEnum(child.clone(), vec![])]).unwrap();
        tree.add(&root.id, vec![NodeEnum(child.clone(), vec![])]).unwrap();

        assert_eq!(tree.children_count(&root.id), 1);
    }

    #[test]
    fn test_remove_node_by_index() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());
        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        tree.remove_node_by_index(&root.id, 0).unwrap();

        let children = tree.children(&root.id).unwrap();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0], child2.id);
        assert!(!tree.contains_node(&child1.id));
        assert!(tree.get_parent_node(&child1.id).is_none());
    }

    #[test]
    fn test_tree_from_node_enum() {
        let root = create_test_node("root");
        let child = create_test_node("child");

        let tree = Tree::from(NodeEnum(
            root.clone(),
            vec![NodeEnum(child.clone(), vec![])],
        ));

        assert_eq!(tree.root_id, root.id);
        assert!(tree.contains_node(&child.id));
        assert_eq!(tree.get_parent_node(&child.id).unwrap().id, root.id);
    }
}