mf_model/
tree.rs

1use std::sync::Arc;
2use std::{ops::Index, num::NonZeroUsize};
3use std::hash::{Hash, Hasher};
4use std::collections::hash_map::DefaultHasher;
5use imbl::Vector;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use once_cell::sync::Lazy;
9use parking_lot::RwLock;
10use lru::LruCache;
11use std::fmt::{self, Debug};
12use crate::error::PoolResult;
13use crate::node_type::NodeEnum;
14use crate::{
15    error::error_helpers,
16    mark::Mark,
17    node::Node,
18    ops::{AttrsRef, MarkRef, NodeRef},
19    types::NodeId,
20};
21
22// 全局LRU缓存用于存储NodeId到分片索引的映射
23static SHARD_INDEX_CACHE: Lazy<RwLock<LruCache<String, usize>>> =
24    Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(10000).unwrap())));
25
26#[derive(Clone, PartialEq, Serialize, Deserialize)]
27pub struct Tree {
28    pub root_id: NodeId,
29    pub nodes: Vector<imbl::HashMap<NodeId, Arc<Node>>>, // 分片存储节点数据
30    pub parent_map: imbl::HashMap<NodeId, NodeId>,
31    #[serde(skip)]
32    num_shards: usize, // 缓存分片数量,避免重复计算
33}
34impl Debug for Tree {
35    fn fmt(
36        &self,
37        f: &mut fmt::Formatter<'_>,
38    ) -> fmt::Result {
39        
40        //输出的时候 过滤掉空的 nodes 节点
41        let nodes = self
42            .nodes
43            .iter()
44            .filter(|node| !node.is_empty())
45            .collect::<Vec<_>>();
46        f.debug_struct("Tree")
47            .field("root_id", &self.root_id)
48            .field("nodes", &nodes)
49            .field("parent_map", &self.parent_map)
50            .field("num_shards", &self.num_shards)
51            .finish()
52    }
53}
54
55impl Tree {
56    #[inline]
57    pub fn get_shard_index(
58        &self,
59        id: &NodeId,
60    ) -> usize {
61        // 先检查缓存
62        {
63            let cache = SHARD_INDEX_CACHE.read();
64            if let Some(&index) = cache.peek(id) {
65                return index;
66            }
67        }
68
69        // 缓存未命中,计算哈希值
70        let mut hasher = DefaultHasher::new();
71        id.hash(&mut hasher);
72        let index = (hasher.finish() as usize) % self.num_shards;
73
74        // 更新缓存
75        {
76            let mut cache = SHARD_INDEX_CACHE.write();
77            cache.put(id.clone(), index);
78        }
79
80        index
81    }
82
83    #[inline]
84    pub fn get_shard_indices(
85        &self,
86        ids: &[&NodeId],
87    ) -> Vec<usize> {
88        ids.iter().map(|id| self.get_shard_index(id)).collect()
89    }
90
91    // 为批量操作提供优化的哈希计算
92    #[inline]
93    pub fn get_shard_index_batch<'a>(
94        &self,
95        ids: &'a [&'a NodeId],
96    ) -> Vec<(usize, &'a NodeId)> {
97        let mut results = Vec::with_capacity(ids.len());
98        let mut cache_misses = Vec::new();
99
100        // 批量检查缓存
101        {
102            let cache = SHARD_INDEX_CACHE.read();
103            for &id in ids {
104                if let Some(&index) = cache.peek(id) {
105                    results.push((index, id));
106                } else {
107                    cache_misses.push(id);
108                }
109            }
110        }
111
112        // 批量计算缓存未命中的项
113        if !cache_misses.is_empty() {
114            let mut cache = SHARD_INDEX_CACHE.write();
115            for &id in &cache_misses {
116                let mut hasher = DefaultHasher::new();
117                id.hash(&mut hasher);
118                let index = (hasher.finish() as usize) % self.num_shards;
119                cache.put(id.clone(), index);
120                results.push((index, id));
121            }
122        }
123
124        results
125    }
126
127    // 清理缓存的方法,用于内存管理
128    pub fn clear_shard_cache() {
129        let mut cache = SHARD_INDEX_CACHE.write();
130        cache.clear();
131    }
132
133    pub fn contains_node(
134        &self,
135        id: &NodeId,
136    ) -> bool {
137        let shard_index = self.get_shard_index(id);
138        self.nodes[shard_index].contains_key(id)
139    }
140
141    pub fn get_node(
142        &self,
143        id: &NodeId,
144    ) -> Option<Arc<Node>> {
145        let shard_index = self.get_shard_index(id);
146        self.nodes[shard_index].get(id).cloned()
147    }
148
149    pub fn get_parent_node(
150        &self,
151        id: &NodeId,
152    ) -> Option<Arc<Node>> {
153        self.parent_map.get(id).and_then(|parent_id| {
154            let shard_index = self.get_shard_index(parent_id);
155            self.nodes[shard_index].get(parent_id).cloned()
156        })
157    }
158    pub fn from(nodes: NodeEnum) -> Self {
159        let num_shards = std::cmp::max(
160            std::thread::available_parallelism()
161                .map(NonZeroUsize::get)
162                .unwrap_or(2),
163            2,
164        );
165        let mut shards = Vector::from(vec![imbl::HashMap::new(); num_shards]);
166        let mut parent_map = imbl::HashMap::new();
167        let (root_node, children) = nodes.into_parts();
168        let root_id = root_node.id.clone();
169
170        let mut hasher = DefaultHasher::new();
171        root_id.hash(&mut hasher);
172        let shard_index = (hasher.finish() as usize) % num_shards;
173        shards[shard_index] =
174            shards[shard_index].update(root_id.clone(), Arc::new(root_node));
175
176        fn process_children(
177            children: Vec<NodeEnum>,
178            parent_id: &NodeId,
179            shards: &mut Vector<imbl::HashMap<NodeId, Arc<Node>>>,
180            parent_map: &mut imbl::HashMap<NodeId, NodeId>,
181            num_shards: usize,
182        ) {
183            for child in children {
184                let (node, grand_children) = child.into_parts();
185                let node_id = node.id.clone();
186                let mut hasher = DefaultHasher::new();
187                node_id.hash(&mut hasher);
188                let shard_index = (hasher.finish() as usize) % num_shards;
189                shards[shard_index] =
190                    shards[shard_index].update(node_id.clone(), Arc::new(node));
191                parent_map.insert(node_id.clone(), parent_id.clone());
192
193                // Recursively process grand children
194                process_children(
195                    grand_children,
196                    &node_id,
197                    shards,
198                    parent_map,
199                    num_shards,
200                );
201            }
202        }
203
204        process_children(
205            children,
206            &root_id,
207            &mut shards,
208            &mut parent_map,
209            num_shards,
210        );
211
212        Self { root_id, nodes: shards, parent_map, num_shards }
213    }
214
215    pub fn new(root: Node) -> Self {
216        let num_shards = std::cmp::max(
217            std::thread::available_parallelism()
218                .map(NonZeroUsize::get)
219                .unwrap_or(2),
220            2,
221        );
222        let mut nodes = Vector::from(vec![imbl::HashMap::new(); num_shards]);
223        let root_id = root.id.clone();
224        let mut hasher = DefaultHasher::new();
225        root_id.hash(&mut hasher);
226        let shard_index = (hasher.finish() as usize) % num_shards;
227        nodes[shard_index] =
228            nodes[shard_index].update(root_id.clone(), Arc::new(root));
229        Self { root_id, nodes, parent_map: imbl::HashMap::new(), num_shards }
230    }
231
232    pub fn update_attr(
233        &mut self,
234        id: &NodeId,
235        new_values: imbl::HashMap<String, Value>,
236    ) -> PoolResult<()> {
237        let shard_index = self.get_shard_index(id);
238        let node = self.nodes[shard_index]
239            .get(id)
240            .ok_or(error_helpers::node_not_found(id.clone()))?;
241        let new_node = node.as_ref().update_attr(new_values);
242        self.nodes[shard_index] =
243            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
244        Ok(())
245    }
246    pub fn update_node(
247        &mut self,
248        node: Node,
249    ) -> PoolResult<()> {
250        let shard_index = self.get_shard_index(&node.id);
251        self.nodes[shard_index] =
252            self.nodes[shard_index].update(node.id.clone(), Arc::new(node));
253        Ok(())
254    }
255
256    /// 向树中添加新的节点及其子节点
257    ///
258    /// # 参数
259    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
260    ///
261    /// # 返回值
262    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
263    ///
264    /// # 错误
265    /// * `PoolError::ParentNotFound` - 如果父节点不存在
266    pub fn add(
267        &mut self,
268        parent_id: &NodeId,
269        nodes: Vec<NodeEnum>,
270    ) -> PoolResult<()> {
271        // 检查父节点是否存在
272        let parent_shard_index = self.get_shard_index(&parent_id);
273        let parent_node = self.nodes[parent_shard_index]
274            .get(parent_id)
275            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
276        let mut new_parent = parent_node.as_ref().clone();
277
278        // 收集所有子节点的ID并添加到当前节点的content中
279        let zenliang: Vector<String> =
280            nodes.iter().map(|n| n.0.id.clone()).collect();
281        // 需要判断 new_parent.content 中是否已经存在 zenliang 中的节点
282        let mut new_content = imbl::Vector::new();
283        for id in zenliang {
284            if !new_parent.content.contains(&id) {
285                new_content.push_back(id);
286            }
287        }
288        new_parent.content.extend(new_content);
289
290        // 更新当前节点
291        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
292            .update(parent_id.clone(), Arc::new(new_parent));
293
294        // 使用队列进行广度优先遍历,处理所有子节点
295        let mut node_queue = Vec::new();
296        node_queue.push((nodes, parent_id.clone()));
297        while let Some((current_children, current_parent_id)) = node_queue.pop()
298        {
299            for child in current_children {
300                // 处理每个子节点
301                let (mut child_node, grand_children) = child.into_parts();
302                let current_node_id = child_node.id.clone();
303
304                // 收集孙节点的ID并添加到子节点的content中
305                let grand_children_ids: Vector<String> =
306                    grand_children.iter().map(|n| n.0.id.clone()).collect();
307                let mut new_content = imbl::Vector::new();
308                for id in grand_children_ids {
309                    if !child_node.content.contains(&id) {
310                        new_content.push_back(id);
311                    }
312                }
313                child_node.content.extend(new_content);
314
315                // 将当前节点存储到对应的分片中
316                let shard_index = self.get_shard_index(&current_node_id);
317                self.nodes[shard_index] = self.nodes[shard_index]
318                    .update(current_node_id.clone(), Arc::new(child_node));
319
320                // 更新父子关系映射
321                self.parent_map
322                    .insert(current_node_id.clone(), current_parent_id.clone());
323
324                // 将孙节点加入队列,以便后续处理
325                node_queue.push((grand_children, current_node_id.clone()));
326            }
327        }
328        Ok(())
329    }
330    // 添加到下标
331    pub fn add_at_index(
332        &mut self,
333        parent_id: &NodeId,
334        index: usize,
335        node: &Node,
336    ) -> PoolResult<()> {
337        //添加到节点到 parent_id 的 content 中
338        let parent_shard_index = self.get_shard_index(parent_id);
339        let parent = self.nodes[parent_shard_index]
340            .get(parent_id)
341            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
342        let  new_parent = parent.as_ref().insert_content_at_index(index, &node.id);
343        //更新父节点
344        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
345            .update(parent_id.clone(), Arc::new(new_parent));
346        //更新父子关系映射
347        self.parent_map.insert(node.id.clone(), parent_id.clone());
348        //更新子节点
349        let shard_index = self.get_shard_index(&node.id);
350        self.nodes[shard_index] = self.nodes[shard_index]
351            .update(node.id.clone(), Arc::new(node.clone()));
352        Ok(())
353    }
354    pub fn add_node(
355        &mut self,
356        parent_id: &NodeId,
357        nodes: &Vec<Node>,
358    ) -> PoolResult<()> {
359        let parent_shard_index = self.get_shard_index(parent_id);
360        let parent = self.nodes[parent_shard_index]
361            .get(parent_id)
362            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
363        let node_ids = nodes.iter().map(|n| n.id.clone()).collect();
364        // 更新父节点 - 添加所有节点的ID到content中
365        let new_parent = parent.as_ref().insert_contents(&node_ids);
366        
367        // 更新父节点到分片中
368        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
369            .update(parent_id.clone(), Arc::new(new_parent));
370        
371        // 更新所有子节点
372        for node in nodes {
373            // 设置当前节点的父子关系映射
374            self.parent_map.insert(node.id.clone(), parent_id.clone());
375            
376            // 设置当前节点的子节点的父子关系映射
377            for child_id in &node.content {
378                self.parent_map.insert(child_id.clone(), node.id.clone());
379            }
380            
381            // 将节点添加到对应的分片中
382            let shard_index = self.get_shard_index(&node.id);
383            self.nodes[shard_index] = self.nodes[shard_index]
384                .update(node.id.clone(), Arc::new(node.clone()));
385        }
386        Ok(())
387    }
388
389    pub fn node(
390        &mut self,
391        key: &str,
392    ) -> NodeRef<'_> {
393        NodeRef::new(self, key.to_string())
394    }
395    pub fn mark(
396        &mut self,
397        key: &str,
398    ) -> MarkRef<'_> {
399        MarkRef::new(self, key.to_string())
400    }
401    pub fn attrs(
402        &mut self,
403        key: &str,
404    ) -> AttrsRef<'_> {
405        AttrsRef::new(self, key.to_string())
406    }
407
408    pub fn children(
409        &self,
410        parent_id: &NodeId,
411    ) -> Option<imbl::Vector<NodeId>> {
412        self.get_node(parent_id).map(|n| n.content.clone())
413    }
414
415    pub fn children_node(
416        &self,
417        parent_id: &NodeId,
418    ) -> Option<imbl::Vector<Arc<Node>>> {
419        self.children(parent_id)
420            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
421    }
422    //递归获取所有子节点 封装成 NodeEnum 返回
423    pub fn all_children(
424        &self,
425        parent_id: &NodeId,
426        filter: Option<&dyn Fn(&Node) -> bool>,
427    ) -> Option<NodeEnum> {
428        if let Some(node) = self.get_node(parent_id) {
429            let mut child_enums = Vec::new();
430            for child_id in &node.content {
431                if let Some(child_node) = self.get_node(child_id) {
432                    // 检查子节点是否满足过滤条件
433                    if let Some(filter_fn) = filter {
434                        if !filter_fn(child_node.as_ref()) {
435                            continue; // 跳过不满足条件的子节点
436                        }
437                    }
438                    // 递归处理满足条件的子节点
439                    if let Some(child_enum) =
440                        self.all_children(child_id, filter)
441                    {
442                        child_enums.push(child_enum);
443                    }
444                }
445            }
446            Some(NodeEnum(node.as_ref().clone(), child_enums))
447        } else {
448            None
449        }
450    }
451
452    pub fn children_count(
453        &self,
454        parent_id: &NodeId,
455    ) -> usize {
456        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
457    }
458    pub fn remove_mark_by_name(
459        &mut self,
460        id: &NodeId,
461        mark_name: &str,
462    ) -> PoolResult<()> {
463        let shard_index = self.get_shard_index(id);
464        let node = self.nodes[shard_index]
465            .get(id)
466            .ok_or(error_helpers::node_not_found(id.clone()))?;
467        let new_node = node.as_ref().remove_mark_by_name(mark_name);
468        self.nodes[shard_index] =
469            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
470        Ok(())
471    }
472    pub fn get_marks(
473        &self,
474        id: &NodeId,
475    ) -> Option<imbl::Vector<Mark>> {
476        self.get_node(id).map(|n| n.marks.clone())
477    }
478
479    pub fn remove_mark(
480        &mut self,
481        id: &NodeId,
482        mark_types: &[String],
483    ) -> PoolResult<()> {
484        let shard_index = self.get_shard_index(id);
485        let node = self.nodes[shard_index]
486            .get(id)
487            .ok_or(error_helpers::node_not_found(id.clone()))?;
488        let new_node = node.as_ref().remove_mark(mark_types);
489        self.nodes[shard_index] =
490            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
491        Ok(())
492    }
493
494    pub fn add_mark(
495        &mut self,
496        id: &NodeId,
497        marks: &Vec<Mark>,
498    ) -> PoolResult<()> {
499        let shard_index = self.get_shard_index(id);
500        let node = self.nodes[shard_index]
501            .get(id)
502            .ok_or(error_helpers::node_not_found(id.clone()))?;
503        let new_node = node.as_ref().add_marks(marks);
504        self.nodes[shard_index] =
505            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
506        Ok(())
507    }
508
509    pub fn move_node(
510        &mut self,
511        source_parent_id: &NodeId,
512        target_parent_id: &NodeId,
513        node_id: &NodeId,
514        position: Option<usize>,
515    ) -> PoolResult<()> {
516        let source_shard_index = self.get_shard_index(source_parent_id);
517        let target_shard_index = self.get_shard_index(target_parent_id);
518        let node_shard_index = self.get_shard_index(node_id);
519        let source_parent = self.nodes[source_shard_index]
520            .get(source_parent_id)
521            .ok_or(error_helpers::parent_not_found(source_parent_id.clone()))?;
522        let target_parent = self.nodes[target_shard_index]
523            .get(target_parent_id)
524            .ok_or(error_helpers::parent_not_found(target_parent_id.clone()))?;
525        let _node = self.nodes[node_shard_index]
526            .get(node_id)
527            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
528        if !source_parent.content.contains(node_id) {
529            return Err(error_helpers::invalid_parenting(
530                node_id.clone(),
531                source_parent_id.clone(),
532            ));
533        }
534        let mut new_source_parent = source_parent.as_ref().clone();
535        new_source_parent.content = new_source_parent
536            .content
537            .iter()
538            .filter(|&id| id != node_id)
539            .cloned()
540            .collect();
541        let mut new_target_parent = target_parent.as_ref().clone();
542        if let Some(pos) = position {
543            // 确保position不超过当前content的长度
544            let insert_pos = pos.min(new_target_parent.content.len());
545
546            // 在指定位置插入节点
547            new_target_parent.content.insert(insert_pos, node_id.clone());
548        } else {
549            // 没有指定位置,添加到末尾
550            new_target_parent.content.push_back(node_id.clone());
551        }
552        self.nodes[source_shard_index] = self.nodes[source_shard_index]
553            .update(source_parent_id.clone(), Arc::new(new_source_parent));
554        self.nodes[target_shard_index] = self.nodes[target_shard_index]
555            .update(target_parent_id.clone(), Arc::new(new_target_parent));
556        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
557        Ok(())
558    }
559
560    pub fn remove_node(
561        &mut self,
562        parent_id: &NodeId,
563        nodes: Vec<NodeId>,
564    ) -> PoolResult<()> {
565        let parent_shard_index = self.get_shard_index(parent_id);
566        let parent = self.nodes[parent_shard_index]
567            .get(parent_id)
568            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
569        if nodes.contains(&self.root_id) {
570            return Err(error_helpers::cannot_remove_root());
571        }
572        for node_id in &nodes {
573            if !parent.content.contains(node_id) {
574                return Err(error_helpers::invalid_parenting(
575                    node_id.clone(),
576                    parent_id.clone(),
577                ));
578            }
579        }
580        let nodes_to_remove: std::collections::HashSet<_> =
581            nodes.iter().collect();
582        let filtered_children: imbl ::Vector<NodeId> = parent
583            .as_ref()
584            .content
585            .iter()
586            .filter(|&id| !nodes_to_remove.contains(id))
587            .cloned()
588            .collect();
589        let mut parent_node = parent.as_ref().clone();
590        parent_node.content = filtered_children;
591        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
592            .update(parent_id.clone(), Arc::new(parent_node));
593        let mut remove_nodes = Vec::new();
594        for node_id in nodes {
595            self.remove_subtree(&node_id, &mut remove_nodes)?;
596        }
597        Ok(())
598    }
599    //=删除节点
600    pub fn remove_node_by_id(
601        &mut self,
602        node_id: &NodeId,
603    ) -> PoolResult<()> {
604        // 检查是否试图删除根节点
605        if node_id == &self.root_id {
606            return Err(error_helpers::cannot_remove_root());
607        }
608
609        let shard_index = self.get_shard_index(node_id);
610        let _ = self.nodes[shard_index]
611            .get(node_id)
612            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
613
614        // 从父节点的content中移除该节点
615        if let Some(parent_id) = self.parent_map.get(node_id).cloned() {
616            let parent_shard_index = self.get_shard_index(&parent_id);
617            if let Some(parent_node) =
618                self.nodes[parent_shard_index].get(&parent_id)
619            {
620                let mut new_parent = parent_node.as_ref().clone();
621                new_parent.content = new_parent
622                    .content
623                    .iter()
624                    .filter(|&id| id != node_id)
625                    .cloned()
626                    .collect();
627                self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
628                    .update(parent_id.clone(), Arc::new(new_parent));
629            }
630        }
631
632        // 删除子树(remove_subtree内部已经处理了节点的删除和parent_map的清理)
633        let mut remove_nodes = Vec::new();
634        self.remove_subtree(node_id, &mut remove_nodes)?;
635
636        // remove_subtree已经删除了所有节点,包括node_id本身,所以这里不需要再次删除
637        Ok(())
638    }
639
640    ///根据下标删除
641    pub fn remove_node_by_index(
642        &mut self,
643        parent_id: &NodeId,
644        index: usize,
645    ) -> PoolResult<()> {
646        let shard_index = self.get_shard_index(parent_id);
647        let parent = self.nodes[shard_index]
648            .get(parent_id)
649            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
650        let mut new_parent = parent.as_ref().clone();
651        let remove_node_id = new_parent.content.remove(index);
652        self.nodes[shard_index] = self.nodes[shard_index]
653            .update(parent_id.clone(), Arc::new(new_parent));
654        let mut remove_nodes = Vec::new();
655        self.remove_subtree(&remove_node_id, &mut remove_nodes)?;
656        Ok(())
657    }
658
659    //删除子树
660    fn remove_subtree(
661        &mut self,
662        node_id: &NodeId,
663        remove_nodes: &mut Vec<Node>,
664    ) -> PoolResult<()> {
665        if node_id == &self.root_id {
666            return Err(error_helpers::cannot_remove_root());
667        }
668        let shard_index = self.get_shard_index(node_id);
669        let _ = self.nodes[shard_index]
670            .get(node_id)
671            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
672        if let Some(children) = self.children(node_id) {
673            for child_id in children {
674                self.remove_subtree(&child_id, remove_nodes)?;
675            }
676        }
677        self.parent_map.remove(node_id);
678        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
679            remove_nodes.push(remove_node.as_ref().clone());
680        }
681        Ok(())
682    }
683}
684
685impl Index<&NodeId> for Tree {
686    type Output = Arc<Node>;
687    fn index(
688        &self,
689        index: &NodeId,
690    ) -> &Self::Output {
691        let shard_index = self.get_shard_index(index);
692        self.nodes[shard_index].get(index).expect("Node not found")
693    }
694}
695
696impl Index<&str> for Tree {
697    type Output = Arc<Node>;
698    fn index(
699        &self,
700        index: &str,
701    ) -> &Self::Output {
702        let node_id = NodeId::from(index);
703        let shard_index = self.get_shard_index(&node_id);
704        self.nodes[shard_index].get(&node_id).expect("Node not found")
705    }
706}
707
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attrs::Attrs;
    use crate::mark::Mark;
    use crate::node::Node;
    use imbl::HashMap;
    use serde_json::json;

    /// Build a childless, mark-less node of type "test" with the given id.
    fn create_test_node(id: &str) -> Node {
        Node::new(id, "test".to_string(), Attrs::default(), vec![], vec![])
    }

    #[test]
    fn test_tree_creation() {
        let root = create_test_node("root");
        let tree = Tree::new(root.clone());
        assert_eq!(tree.root_id, root.id);
        assert!(tree.contains_node(&root.id));
    }

    #[test]
    fn test_add_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        assert!(tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
    }

    #[test]
    fn test_remove_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        tree.remove_node(&root.id, vec![child.id.clone()]).unwrap();
        assert!(!tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 0);
    }

    #[test]
    fn test_move_node() {
        // Two parent nodes.
        let parent1 = create_test_node("parent1");
        let parent2 = create_test_node("parent2");
        let mut tree = Tree::new(parent1.clone());

        // parent2 becomes a child of parent1.
        tree.add_node(&parent1.id, &vec![parent2.clone()]).unwrap();

        // Three more children.
        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        // All of them go under parent1.
        tree.add_node(&parent1.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child3.clone()]).unwrap();

        // Initial state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        assert_eq!(parent1_children.len(), 4); // parent2 + 3 children
        assert_eq!(parent1_children[0], parent2.id);
        assert_eq!(parent1_children[1], child1.id);
        assert_eq!(parent1_children[2], child2.id);
        assert_eq!(parent1_children[3], child3.id);

        // Move child1 under parent2.
        tree.move_node(&parent1.id, &parent2.id, &child1.id, None).unwrap();

        // State after the first move.
        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 3); // parent2 + 2 children
        assert_eq!(parent2_children.len(), 1); // child1
        assert_eq!(parent2_children[0], child1.id);

        // Move child2 under parent2, right after child1.
        tree.move_node(&parent1.id, &parent2.id, &child2.id, Some(1)).unwrap();

        // Final state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 2); // parent2 + 1 child
        assert_eq!(parent2_children.len(), 2); // child1 + child2
        assert_eq!(parent2_children[0], child1.id);
        assert_eq!(parent2_children[1], child2.id);

        // Parent relationships were updated.
        let child1_parent = tree.get_parent_node(&child1.id).unwrap();
        let child2_parent = tree.get_parent_node(&child2.id).unwrap();
        assert_eq!(child1_parent.id, parent2.id);
        assert_eq!(child2_parent.id, parent2.id);
    }

    #[test]
    fn test_update_attr() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mut attrs = HashMap::new();
        attrs.insert("key".to_string(), json!("value"));

        tree.update_attr(&root.id, attrs).unwrap();

        let node = tree.get_node(&root.id).unwrap();
        assert_eq!(node.attrs.get("key").unwrap(), &json!("value"));
    }

    #[test]
    fn test_add_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(node.marks.contains(&mark));
    }

    #[test]
    fn test_remove_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark { r#type: "test".to_string(), attrs: Attrs::default() };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        tree.remove_mark(&root.id, &[mark.r#type.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(!node.marks.iter().any(|m| m.r#type == mark.r#type));
    }

    #[test]
    fn test_all_children() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        let all_children = tree.all_children(&root.id, None).unwrap();
        assert_eq!(all_children.1.len(), 2);
    }

    #[test]
    fn test_children_count() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        assert_eq!(tree.children_count(&root.id), 2);
    }

    #[test]
    fn test_remove_node_by_id_updates_parent() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        // The child was added.
        assert_eq!(tree.children_count(&root.id), 1);
        assert!(tree.contains_node(&child.id));

        // Remove it.
        tree.remove_node_by_id(&child.id).unwrap();

        // The child is gone and the parent's content was updated.
        assert_eq!(tree.children_count(&root.id), 0);
        assert!(!tree.contains_node(&child.id));
    }

    #[test]
    fn test_move_node_position_edge_cases() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let container = create_test_node("container");
        tree.add_node(&root.id, &vec![container.clone()]).unwrap();

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child3.clone()]).unwrap();

        // An out-of-range position appends to the end.
        tree.move_node(&root.id, &container.id, &child1.id, Some(100)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 1);
        assert_eq!(container_children[0], child1.id);

        // Position 0 inserts at the front.
        tree.move_node(&root.id, &container.id, &child2.id, Some(0)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 2);
        assert_eq!(container_children[0], child2.id);
        assert_eq!(container_children[1], child1.id);
    }

    #[test]
    fn test_cannot_remove_root_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        // Removing the root must fail.
        let result = tree.remove_node_by_id(&root.id);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_parent_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        let parent = tree.get_parent_node(&child.id).unwrap();
        assert_eq!(parent.id, root.id);
    }

    #[test]
    fn test_add_nested_nodes() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let grand = create_test_node("grand");
        tree.add(
            &root.id,
            vec![NodeEnum(child.clone(), vec![NodeEnum(grand.clone(), vec![])])],
        )
        .unwrap();

        assert_eq!(tree.children_count(&root.id), 1);
        assert_eq!(tree.children(&child.id).unwrap()[0], grand.id);
        assert_eq!(tree.get_parent_node(&child.id).unwrap().id, root.id);
        assert_eq!(tree.get_parent_node(&grand.id).unwrap().id, child.id);
    }

    #[test]
    fn test_remove_node_removes_descendants() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let grand = create_test_node("grand");
        tree.add(
            &root.id,
            vec![NodeEnum(child.clone(), vec![NodeEnum(grand.clone(), vec![])])],
        )
        .unwrap();

        tree.remove_node(&root.id, vec![child.id.clone()]).unwrap();

        assert_eq!(tree.children_count(&root.id), 0);
        assert!(!tree.contains_node(&child.id));
        assert!(!tree.contains_node(&grand.id));
        assert!(tree.get_parent_node(&grand.id).is_none());
    }

    #[test]
    fn test_remove_node_by_index() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        tree.remove_node_by_index(&root.id, 0).unwrap();

        let remaining = tree.children(&root.id).unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0], child2.id);
        assert!(!tree.contains_node(&child1.id));
        assert!(tree.contains_node(&child2.id));
    }
}