moduforge_model/
tree.rs

1use std::{ops::Index, sync::Arc, num::NonZeroUsize};
2use std::hash::{Hash, Hasher};
3use std::collections::hash_map::DefaultHasher;
4use im::Vector;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use once_cell::sync::Lazy;
8use parking_lot::RwLock;
9use lru::LruCache;
10
11use crate::error::PoolResult;
12use crate::node_type::NodeEnum;
13use crate::{
14    error::error_helpers,
15    mark::Mark,
16    node::Node,
17    ops::{AttrsRef, MarkRef, NodeRef},
18    types::NodeId
19};
20
21// 全局LRU缓存用于存储NodeId到分片索引的映射
22static SHARD_INDEX_CACHE: Lazy<RwLock<LruCache<String, usize>>> = 
23    Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(10000).unwrap())));
24
/// A document tree whose nodes are spread over several persistent hash maps ("shards").
///
/// Each node is stored in the shard selected by hashing its id (see
/// `Tree::get_shard_index`); `parent_map` links every non-root node to its parent id.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct Tree {
    /// Id of the root node.
    pub root_id: NodeId,
    /// Sharded node storage: each shard maps node id -> node.
    pub nodes: Vector<im::HashMap<NodeId, Arc<Node>>>,
    /// Child id -> parent id.
    pub parent_map: im::HashMap<NodeId, NodeId>,
    // Shard count recorded at construction time.
    // NOTE(review): `#[serde(skip)]` makes this `0` (the `usize` default) after
    // deserialization, and the derived `PartialEq` compares it, so a deserialized tree
    // is unequal to its source. `nodes.len()` is the reliable shard count.
    #[serde(skip)]
    num_shards: usize,
}
33
34impl Tree {
35    #[inline]
36    pub fn get_shard_index(
37        &self,
38        id: &NodeId,
39    ) -> usize {
40        // 先检查缓存
41        {
42            let cache = SHARD_INDEX_CACHE.read();
43            if let Some(&index) = cache.peek(id) {
44                return index;
45            }
46        }
47        
48        // 缓存未命中,计算哈希值
49        let mut hasher = DefaultHasher::new();
50        id.hash(&mut hasher);
51        let index = (hasher.finish() as usize) % self.num_shards;
52        
53        // 更新缓存
54        {
55            let mut cache = SHARD_INDEX_CACHE.write();
56            cache.put(id.clone(), index);
57        }
58        
59        index
60    }
61
62    #[inline]
63    pub fn get_shard_indices(&self, ids: &[&NodeId]) -> Vec<usize> {
64        ids.iter().map(|id| self.get_shard_index(id)).collect()
65    }
66
67    // 为批量操作提供优化的哈希计算
68    #[inline]
69    pub fn get_shard_index_batch<'a>(
70        &self,
71        ids: &'a [&'a NodeId],
72    ) -> Vec<(usize, &'a NodeId)> {
73        let mut results = Vec::with_capacity(ids.len());
74        let mut cache_misses = Vec::new();
75        
76        // 批量检查缓存
77        {
78            let cache = SHARD_INDEX_CACHE.read();
79            for &id in ids {
80                if let Some(&index) = cache.peek(id) {
81                    results.push((index, id));
82                } else {
83                    cache_misses.push(id);
84                }
85            }
86        }
87        
88        // 批量计算缓存未命中的项
89        if !cache_misses.is_empty() {
90            let mut cache = SHARD_INDEX_CACHE.write();
91            for &id in &cache_misses {
92                let mut hasher = DefaultHasher::new();
93                id.hash(&mut hasher);
94                let index = (hasher.finish() as usize) % self.num_shards;
95                cache.put(id.clone(), index);
96                results.push((index, id));
97            }
98        }
99        
100        results
101    }
102
103    // 清理缓存的方法,用于内存管理
104    pub fn clear_shard_cache() {
105        let mut cache = SHARD_INDEX_CACHE.write();
106        cache.clear();
107    }
108
109    pub fn contains_node(
110        &self,
111        id: &NodeId,
112    ) -> bool {
113        let shard_index = self.get_shard_index(id);
114        self.nodes[shard_index].contains_key(id)
115    }
116
117    pub fn get_node(
118        &self,
119        id: &NodeId,
120    ) -> Option<Arc<Node>> {
121        let shard_index = self.get_shard_index(id);
122        self.nodes[shard_index].get(id).cloned()
123    }
124
125    pub fn get_parent_node(
126        &self,
127        id: &NodeId,
128    ) -> Option<Arc<Node>> {
129        self.parent_map.get(id).and_then(|parent_id| {
130            let shard_index = self.get_shard_index(parent_id);
131            self.nodes[shard_index].get(parent_id).cloned()
132        })
133    }
134    pub fn from(nodes: NodeEnum) -> Self {
135        let num_shards = std::cmp::max(
136            std::thread::available_parallelism()
137                .map(NonZeroUsize::get)
138                .unwrap_or(2),
139            2,
140        );
141        let mut shards = Vector::from(vec![im::HashMap::new(); num_shards]);
142        let mut parent_map = im::HashMap::new();
143        let (root_node, children) = nodes.into_parts();
144        let root_id = root_node.id.clone();
145
146        let mut hasher = DefaultHasher::new();
147        root_id.hash(&mut hasher);
148        let shard_index = (hasher.finish() as usize) % num_shards;
149        shards[shard_index] =
150            shards[shard_index].update(root_id.clone(), Arc::new(root_node));
151
152        fn process_children(
153            children: Vec<NodeEnum>,
154            parent_id: &NodeId,
155            shards: &mut Vector<im::HashMap<NodeId, Arc<Node>>>,
156            parent_map: &mut im::HashMap<NodeId, NodeId>,
157            num_shards: usize,
158        ) {
159            for child in children {
160                let (node, grand_children) = child.into_parts();
161                let node_id = node.id.clone();
162                let mut hasher = DefaultHasher::new();
163                node_id.hash(&mut hasher);
164                let shard_index = (hasher.finish() as usize) % num_shards;
165                shards[shard_index] =
166                    shards[shard_index].update(node_id.clone(), Arc::new(node));
167                parent_map.insert(node_id.clone(), parent_id.clone());
168
169                // Recursively process grand children
170                process_children(
171                    grand_children,
172                    &node_id,
173                    shards,
174                    parent_map,
175                    num_shards,
176                );
177            }
178        }
179
180        process_children(
181            children,
182            &root_id,
183            &mut shards,
184            &mut parent_map,
185            num_shards,
186        );
187
188        Self { root_id, nodes: shards, parent_map, num_shards }
189    }
190
191    pub fn new(root: Node) -> Self {
192        let num_shards = std::cmp::max(
193            std::thread::available_parallelism()
194                .map(NonZeroUsize::get)
195                .unwrap_or(2),
196            2,
197        );
198        let mut nodes = Vector::from(vec![im::HashMap::new(); num_shards]);
199        let root_id = root.id.clone();
200        let mut hasher = DefaultHasher::new();
201        root_id.hash(&mut hasher);
202        let shard_index = (hasher.finish() as usize) % num_shards;
203        nodes[shard_index] =
204            nodes[shard_index].update(root_id.clone(), Arc::new(root));
205        Self { root_id, nodes, parent_map: im::HashMap::new(), num_shards }
206    }
207
208    pub fn update_attr(
209        &mut self,
210        id: &NodeId,
211        new_values: im::HashMap<String, Value>,
212    ) -> PoolResult<()> {
213        let shard_index = self.get_shard_index(id);
214        let node = self.nodes[shard_index]
215            .get(id)
216            .ok_or(error_helpers::node_not_found(id.clone()))?;
217        let old_values = node.attrs.clone();
218        let mut new_node = node.as_ref().clone();
219        let new_attrs = old_values.update(new_values);
220        new_node.attrs = new_attrs.clone();
221        self.nodes[shard_index] =
222            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
223        Ok(())
224    }
225    pub fn update_node(
226        &mut self,
227        node: Node,
228    ) -> PoolResult<()> {
229        let shard_index = self.get_shard_index(&node.id);
230        self.nodes[shard_index] = self.nodes[shard_index].update(node.id.clone(), Arc::new(node));
231        Ok(())
232    }
233
234    /// 向树中添加新的节点及其子节点
235    ///
236    /// # 参数
237    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
238    ///
239    /// # 返回值
240    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
241    ///
242    /// # 错误
243    /// * `PoolError::ParentNotFound` - 如果父节点不存在
244    pub fn add(
245        &mut self,
246        parent_id: &NodeId,
247        nodes: Vec<NodeEnum>,
248    ) -> PoolResult<()> {
249  
250
251        // 检查父节点是否存在
252        let parent_shard_index = self.get_shard_index(&parent_id);
253        let parent_node = self.nodes[parent_shard_index]
254            .get(parent_id)
255            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
256        let mut new_parent = parent_node.as_ref().clone();
257
258        // 收集所有子节点的ID并添加到当前节点的content中
259        let zenliang: Vector<String> =
260            nodes.iter().map(|n| n.0.id.clone()).collect();
261        new_parent.content.extend(zenliang);
262
263        // 更新当前节点
264        self.nodes[parent_shard_index] =
265            self.nodes[parent_shard_index].update(parent_id.clone(), Arc::new(new_parent));
266
267        // 使用队列进行广度优先遍历,处理所有子节点
268        let mut node_queue = Vec::new();
269        node_queue.push((nodes, parent_id.clone()));
270        while let Some((current_children, current_parent_id)) = node_queue.pop() {
271            for child in current_children {
272                // 处理每个子节点
273                let (mut child_node, grand_children) = child.into_parts();
274                let current_node_id = child_node.id.clone();
275
276                // 收集孙节点的ID并添加到子节点的content中
277                let grand_children_ids: Vector<String> =
278                    grand_children.iter().map(|n| n.0.id.clone()).collect();
279                child_node.content.extend(grand_children_ids);
280
281                // 将当前节点存储到对应的分片中
282                let shard_index = self.get_shard_index(&current_node_id);
283                self.nodes[shard_index] = self.nodes[shard_index]
284                    .update(current_node_id.clone(), Arc::new(child_node));
285
286                // 更新父子关系映射
287                self.parent_map
288                    .insert(current_node_id.clone(), current_parent_id.clone());
289
290                // 将孙节点加入队列,以便后续处理
291                node_queue.push((grand_children, current_node_id.clone()));
292            }
293        }
294        Ok(())
295    }
296    // 添加到下标
297    pub fn add_at_index(
298        &mut self,
299        parent_id: &NodeId,
300        index: usize,
301        node: &Node,
302    ) -> PoolResult<()> {
303        //添加到节点到 parent_id 的 content 中
304        let parent_shard_index = self.get_shard_index(parent_id);
305        let parent = self.nodes[parent_shard_index]
306            .get(parent_id)
307            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
308        let mut new_parent = parent.as_ref().clone();
309        new_parent.content.insert(index, node.id.clone());
310        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
311            .update(parent_id.clone(), Arc::new(new_parent));
312        self.parent_map.insert(node.id.clone(), parent_id.clone());
313        Ok(())
314    }
315    pub fn add_node(
316        &mut self,
317        parent_id: &NodeId,
318        nodes: &Vec<Node>,
319    ) -> PoolResult<()> {
320        let parent_shard_index = self.get_shard_index(parent_id);
321        let parent = self.nodes[parent_shard_index]
322            .get(parent_id)
323            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
324        let mut new_parent = parent.as_ref().clone();
325        new_parent.content.push_back(nodes[0].id.clone());
326        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
327            .update(parent_id.clone(), Arc::new(new_parent));
328        self.parent_map.insert(nodes[0].id.clone(), parent_id.clone());
329        for node in nodes {
330            let shard_index = self.get_shard_index(&node.id);
331            for child_id in &node.content {
332                self.parent_map.insert(child_id.clone(), node.id.clone());
333            }
334            self.nodes[shard_index] = self.nodes[shard_index]
335                .update(node.id.clone(), Arc::new(node.clone()));
336        }
337        Ok(())
338    }
339
340    pub fn node(
341        &mut self,
342        key: &str,
343    ) -> NodeRef<'_> {
344        NodeRef::new(self, key.to_string())
345    }
346    pub fn mark(
347        &mut self,
348        key: &str,
349    ) -> MarkRef<'_> {
350        MarkRef::new(self, key.to_string())
351    }
352    pub fn attrs(
353        &mut self,
354        key: &str,
355    ) -> AttrsRef<'_> {
356        AttrsRef::new(self, key.to_string())
357    }
358
359    pub fn children(
360        &self,
361        parent_id: &NodeId,
362    ) -> Option<im::Vector<NodeId>> {
363        self.get_node(parent_id).map(|n| n.content.clone())
364    }
365
366    pub fn children_node(
367        &self,
368        parent_id: &NodeId,
369    ) -> Option<im::Vector<Arc<Node>>> {
370        self.children(parent_id)
371            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
372    }
373    //递归获取所有子节点 封装成 NodeEnum 返回
374    pub fn all_children(
375        &self,
376        parent_id: &NodeId,
377        filter: Option<&dyn Fn(&Node) -> bool>,
378    ) -> Option<NodeEnum> {
379        if let Some(node) = self.get_node(parent_id) {
380            let mut child_enums = Vec::new();
381            for child_id in &node.content {
382                if let Some(child_node) = self.get_node(child_id) {
383                    // 检查子节点是否满足过滤条件
384                    if let Some(filter_fn) = filter {
385                        if !filter_fn(child_node.as_ref()) {
386                            continue; // 跳过不满足条件的子节点
387                        }
388                    }
389                    // 递归处理满足条件的子节点
390                    if let Some(child_enum) =
391                        self.all_children(child_id, filter)
392                    {
393                        child_enums.push(child_enum);
394                    }
395                }
396            }
397            Some(NodeEnum(node.as_ref().clone(), child_enums))
398        } else {
399            None
400        }
401    }
402
403    pub fn children_count(
404        &self,
405        parent_id: &NodeId,
406    ) -> usize {
407        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
408    }
409    pub fn remove_mark_by_name(
410        &mut self,
411        id: &NodeId,
412        mark_name: &str,
413    ) -> PoolResult<()> {
414        let shard_index = self.get_shard_index(id);
415        let node = self.nodes[shard_index]
416            .get(id)
417            .ok_or(error_helpers::node_not_found(id.clone()))?;
418        let mut new_node = node.as_ref().clone();
419        new_node.marks =
420            new_node.marks.iter().filter(|&m| m.r#type != mark_name).cloned().collect();
421        self.nodes[shard_index] =
422            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
423        Ok(())
424    }
425    pub fn get_marks(
426        &self,
427        id: &NodeId,
428    ) -> Option<im::Vector<Mark>> {
429        self.get_node(id).map(|n| n.marks.clone())
430    }
431    
432    pub fn remove_mark(
433        &mut self,
434        id: &NodeId,
435        mark: Mark,
436    ) -> PoolResult<()> {
437        let shard_index = self.get_shard_index(id);
438        let node = self.nodes[shard_index]
439            .get(id)
440            .ok_or(error_helpers::node_not_found(id.clone()))?;
441        let mut new_node = node.as_ref().clone();
442        new_node.marks =
443            new_node.marks.iter().filter(|&m| !m.eq(&mark)).cloned().collect();
444        self.nodes[shard_index] =
445            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
446        Ok(())
447    }
448
449    pub fn add_mark(
450        &mut self,
451        id: &NodeId,
452        marks: &Vec<Mark>,
453    ) -> PoolResult<()> {
454        let shard_index = self.get_shard_index(id);
455        let node = self.nodes[shard_index]
456            .get(id)
457            .ok_or(error_helpers::node_not_found(id.clone()))?;
458        let mut new_node = node.as_ref().clone();
459        new_node.marks.extend(marks.clone());
460        self.nodes[shard_index] =
461            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
462        Ok(())
463    }
464
465    pub fn move_node(
466        &mut self,
467        source_parent_id: &NodeId,
468        target_parent_id: &NodeId,
469        node_id: &NodeId,
470        position: Option<usize>,
471    ) -> PoolResult<()> {
472        let source_shard_index = self.get_shard_index(source_parent_id);
473        let target_shard_index = self.get_shard_index(target_parent_id);
474        let node_shard_index = self.get_shard_index(node_id);
475        let source_parent = self.nodes[source_shard_index]
476            .get(source_parent_id)
477            .ok_or(error_helpers::parent_not_found(source_parent_id.clone()))?;
478        let target_parent = self.nodes[target_shard_index]
479            .get(target_parent_id)
480            .ok_or(error_helpers::parent_not_found(target_parent_id.clone()))?;
481        let _node = self.nodes[node_shard_index]
482            .get(node_id)
483            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
484        if !source_parent.content.contains(node_id) {
485            return Err(error_helpers::invalid_parenting(
486                node_id.clone(),
487                source_parent_id.clone(),
488            ));
489        }
490        let mut new_source_parent = source_parent.as_ref().clone();
491        new_source_parent.content = new_source_parent
492            .content
493            .iter()
494            .filter(|&id| id != node_id)
495            .cloned()
496            .collect();
497        let mut new_target_parent = target_parent.as_ref().clone();
498        if let Some(pos) = position {
499            // 确保position不超过当前content的长度
500            let insert_pos = pos.min(new_target_parent.content.len());
501            
502            if insert_pos == new_target_parent.content.len() {
503                // 在末尾插入
504                new_target_parent.content.push_back(node_id.clone());
505            } else {
506                // 在指定位置插入
507                let mut new_content = im::Vector::new();
508                
509                // 添加position之前的所有元素
510                for (i, child_id) in new_target_parent.content.iter().enumerate() {
511                    if i == insert_pos {
512                        // 在这个位置插入新节点
513                        new_content.push_back(node_id.clone());
514                    }
515                    new_content.push_back(child_id.clone());
516                }
517                
518                new_target_parent.content = new_content;
519            }
520        } else {
521            // 没有指定位置,添加到末尾
522            new_target_parent.content.push_back(node_id.clone());
523        }
524        self.nodes[source_shard_index] = self.nodes[source_shard_index]
525            .update(source_parent_id.clone(), Arc::new(new_source_parent));
526        self.nodes[target_shard_index] = self.nodes[target_shard_index]
527            .update(target_parent_id.clone(), Arc::new(new_target_parent));
528        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
529        Ok(())
530    }
531
532    pub fn remove_node(
533        &mut self,
534        parent_id: &NodeId,
535        nodes: Vec<NodeId>,
536    ) -> PoolResult<()> {
537        let parent_shard_index = self.get_shard_index(parent_id);
538        let parent = self.nodes[parent_shard_index]
539            .get(parent_id)
540            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
541        if nodes.contains(&self.root_id) {
542            return Err(error_helpers::cannot_remove_root());
543        }
544        for node_id in &nodes {
545            if !parent.content.contains(node_id) {
546                return Err(error_helpers::invalid_parenting(
547                    node_id.clone(),
548                    parent_id.clone(),
549                ));
550            }
551        }
552        let nodes_to_remove: std::collections::HashSet<_> =
553            nodes.iter().collect();
554        let filtered_children: im::Vector<NodeId> = parent
555            .as_ref()
556            .content
557            .iter()
558            .filter(|&id| !nodes_to_remove.contains(id))
559            .cloned()
560            .collect();
561        let mut parent_node = parent.as_ref().clone();
562        parent_node.content = filtered_children;
563        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
564            .update(parent_id.clone(), Arc::new(parent_node));
565        let mut remove_nodes = Vec::new();
566        for node_id in nodes {
567            self.remove_subtree(&node_id, &mut remove_nodes)?;
568        }
569        Ok(())
570    }
571    //=删除节点
572    pub fn remove_node_by_id(
573        &mut self,
574        node_id: &NodeId,
575    ) -> PoolResult<()> {
576        // 检查是否试图删除根节点
577        if node_id == &self.root_id {
578            return Err(error_helpers::cannot_remove_root());
579        }
580
581        let shard_index = self.get_shard_index(node_id);
582        let _ = self.nodes[shard_index]
583            .get(node_id)
584            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
585
586        // 从父节点的content中移除该节点
587        if let Some(parent_id) = self.parent_map.get(node_id).cloned() {
588            let parent_shard_index = self.get_shard_index(&parent_id);
589            if let Some(parent_node) = self.nodes[parent_shard_index].get(&parent_id) {
590                let mut new_parent = parent_node.as_ref().clone();
591                new_parent.content = new_parent.content
592                    .iter()
593                    .filter(|&id| id != node_id)
594                    .cloned()
595                    .collect();
596                self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
597                    .update(parent_id.clone(), Arc::new(new_parent));
598            }
599        }
600
601        // 删除子树(remove_subtree内部已经处理了节点的删除和parent_map的清理)
602        let mut remove_nodes = Vec::new();
603        self.remove_subtree(node_id, &mut remove_nodes)?;
604        
605        // remove_subtree已经删除了所有节点,包括node_id本身,所以这里不需要再次删除
606        Ok(())
607    }
608
609    ///根据下标删除
610    pub fn remove_node_by_index(
611        &mut self,
612        parent_id: &NodeId,
613        index: usize,
614    ) -> PoolResult<()> {
615        let shard_index = self.get_shard_index(parent_id);
616        let parent = self.nodes[shard_index]
617            .get(parent_id)
618            .ok_or(error_helpers::parent_not_found(parent_id.clone()))?;
619        let mut new_parent = parent.as_ref().clone();
620        let remove_node_id = new_parent.content.remove(index);
621        self.nodes[shard_index] = self.nodes[shard_index]
622            .update(parent_id.clone(), Arc::new(new_parent));
623        let mut remove_nodes = Vec::new();
624        self.remove_subtree(&remove_node_id, &mut remove_nodes)?;
625        for node in remove_nodes {
626            let shard_index = self.get_shard_index(&node.id);
627            self.nodes[shard_index].remove(&node.id);
628            self.parent_map.remove(&node.id);
629        }
630        Ok(())
631    }
632
633    //删除子树
634    fn remove_subtree(
635        &mut self,
636        node_id: &NodeId,
637        remove_nodes: &mut Vec<Node>,
638    ) -> PoolResult<()> {
639        if node_id == &self.root_id {
640            return Err(error_helpers::cannot_remove_root());
641        }
642        let shard_index = self.get_shard_index(node_id);
643        let _ = self.nodes[shard_index]
644            .get(node_id)
645            .ok_or(error_helpers::node_not_found(node_id.clone()))?;
646        if let Some(children) = self.children(node_id) {
647            for child_id in children {
648                self.remove_subtree(&child_id, remove_nodes)?;
649            }
650        }
651        self.parent_map.remove(node_id);
652        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
653            remove_nodes.push(remove_node.as_ref().clone());
654        }
655        Ok(())
656    }
657}
658
659impl Index<&NodeId> for Tree {
660    type Output = Arc<Node>;
661    fn index(
662        &self,
663        index: &NodeId,
664    ) -> &Self::Output {
665        let shard_index = self.get_shard_index(index);
666        self.nodes[shard_index].get(index).expect("Node not found")
667    }
668}
669
670impl Index<&str> for Tree {
671    type Output = Arc<Node>;
672    fn index(
673        &self,
674        index: &str,
675    ) -> &Self::Output {
676        let node_id = NodeId::from(index);
677        let shard_index = self.get_shard_index(&node_id);
678        self.nodes[shard_index].get(&node_id).expect("Node not found")
679    }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attrs::Attrs;
    use crate::mark::Mark;
    use crate::node::Node;
    use crate::node_type::NodeEnum;
    use crate::types::NodeId;
    use im::HashMap;
    use serde_json::json;

    /// Builds a leaf node of type "test" with default attrs, no content and no marks.
    fn create_test_node(id: &str) -> Node {
        Node::new(
            id,
            "test".to_string(),
            Attrs::default(),
            vec![],
            vec![],
        )
    }

    #[test]
    fn test_tree_creation() {
        let root = create_test_node("root");
        let tree = Tree::new(root.clone());
        assert_eq!(tree.root_id, root.id);
        assert!(tree.contains_node(&root.id));
    }

    #[test]
    fn test_add_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        assert!(tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 1);
    }

    #[test]
    fn test_remove_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        let nodes = vec![child.clone()];

        tree.add_node(&root.id, &nodes).unwrap();
        tree.remove_node(&root.id, vec![child.id.clone()]).unwrap();
        assert!(!tree.contains_node(&child.id));
        assert_eq!(tree.children(&root.id).unwrap().len(), 0);
    }

    #[test]
    fn test_move_node() {
        // Two parent nodes.
        let parent1 = create_test_node("parent1");
        let parent2 = create_test_node("parent2");
        let mut tree = Tree::new(parent1.clone());

        // parent2 becomes a child of parent1.
        tree.add_node(&parent1.id, &vec![parent2.clone()]).unwrap();

        // Three children.
        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        // All children go under parent1.
        tree.add_node(&parent1.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&parent1.id, &vec![child3.clone()]).unwrap();

        // Initial state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        assert_eq!(parent1_children.len(), 4); // parent2 + 3 children
        assert_eq!(parent1_children[0], parent2.id);
        assert_eq!(parent1_children[1], child1.id);
        assert_eq!(parent1_children[2], child2.id);
        assert_eq!(parent1_children[3], child3.id);

        // Move child1 under parent2.
        tree.move_node(&parent1.id, &parent2.id, &child1.id, None).unwrap();

        // State after the first move.
        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 3); // parent2 + 2 children
        assert_eq!(parent2_children.len(), 1); // child1
        assert_eq!(parent2_children[0], child1.id);

        // Move child2 under parent2, after child1.
        tree.move_node(&parent1.id, &parent2.id, &child2.id, Some(1)).unwrap();

        // Final state.
        let parent1_children = tree.children(&parent1.id).unwrap();
        let parent2_children = tree.children(&parent2.id).unwrap();
        assert_eq!(parent1_children.len(), 2); // parent2 + 1 child
        assert_eq!(parent2_children.len(), 2); // child1 + child2
        assert_eq!(parent2_children[0], child1.id);
        assert_eq!(parent2_children[1], child2.id);

        // Parent links follow the moves.
        let child1_parent = tree.get_parent_node(&child1.id).unwrap();
        let child2_parent = tree.get_parent_node(&child2.id).unwrap();
        assert_eq!(child1_parent.id, parent2.id);
        assert_eq!(child2_parent.id, parent2.id);
    }

    #[test]
    fn test_update_attr() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mut attrs = HashMap::new();
        attrs.insert("key".to_string(), json!("value"));

        tree.update_attr(&root.id, attrs).unwrap();

        let node = tree.get_node(&root.id).unwrap();
        assert_eq!(node.attrs.get("key").unwrap(), &json!("value"));
    }

    #[test]
    fn test_add_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark {
            r#type: "test".to_string(),
            attrs: Attrs::default(),
        };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(node.marks.contains(&mark));
    }

    #[test]
    fn test_remove_mark() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let mark = Mark {
            r#type: "test".to_string(),
            attrs: Attrs::default(),
        };
        tree.add_mark(&root.id, &vec![mark.clone()]).unwrap();
        tree.remove_mark(&root.id, mark.clone()).unwrap();
        let node = tree.get_node(&root.id).unwrap();
        assert!(!node.marks.contains(&mark));
    }

    #[test]
    fn test_all_children() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        let all_children = tree.all_children(&root.id, None).unwrap();
        assert_eq!(all_children.1.len(), 2);
    }

    #[test]
    fn test_children_count() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();

        assert_eq!(tree.children_count(&root.id), 2);
    }

    #[test]
    fn test_remove_node_by_id_updates_parent() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        // The child was added.
        assert_eq!(tree.children_count(&root.id), 1);
        assert!(tree.contains_node(&child.id));

        // Remove the child.
        tree.remove_node_by_id(&child.id).unwrap();

        // The child is gone and the parent's content was updated.
        assert_eq!(tree.children_count(&root.id), 0);
        assert!(!tree.contains_node(&child.id));
    }

    #[test]
    fn test_move_node_position_edge_cases() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let container = create_test_node("container");
        tree.add_node(&root.id, &vec![container.clone()]).unwrap();

        let child1 = create_test_node("child1");
        let child2 = create_test_node("child2");
        let child3 = create_test_node("child3");

        tree.add_node(&root.id, &vec![child1.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child2.clone()]).unwrap();
        tree.add_node(&root.id, &vec![child3.clone()]).unwrap();

        // An out-of-range position appends at the end.
        tree.move_node(&root.id, &container.id, &child1.id, Some(100)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 1);
        assert_eq!(container_children[0], child1.id);

        // Position 0 inserts at the front.
        tree.move_node(&root.id, &container.id, &child2.id, Some(0)).unwrap();

        let container_children = tree.children(&container.id).unwrap();
        assert_eq!(container_children.len(), 2);
        assert_eq!(container_children[0], child2.id);
        assert_eq!(container_children[1], child1.id);
    }

    #[test]
    fn test_cannot_remove_root_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        // Removing the root must fail.
        let result = tree.remove_node_by_id(&root.id);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_parent_node() {
        let root = create_test_node("root");
        let mut tree = Tree::new(root.clone());

        let child = create_test_node("child");
        tree.add_node(&root.id, &vec![child.clone()]).unwrap();

        let parent = tree.get_parent_node(&child.id).unwrap();
        assert_eq!(parent.id, root.id);
    }
}