moduforge_model/
tree.rs

1use std::{ops::Index, sync::Arc, num::NonZeroUsize};
2use std::hash::{Hash, Hasher};
3use std::collections::hash_map::DefaultHasher;
4use im::Vector;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::node_type::NodeEnum;
9use crate::{
10    error::PoolError,
11    mark::Mark,
12    node::Node,
13    ops::{AttrsRef, MarkRef, NodeRef},
14    types::NodeId,
15};
16
17#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
18pub struct Tree {
19    pub root_id: NodeId,
20    pub nodes: Vector<im::HashMap<NodeId, Arc<Node>>>, // 分片存储节点数据
21    pub parent_map: im::HashMap<NodeId, NodeId>,
22}
23
24impl Tree {
25    pub fn get_shard_index(
26        &self,
27        id: &NodeId,
28    ) -> usize {
29        let mut hasher = DefaultHasher::new();
30        id.hash(&mut hasher);
31        (hasher.finish() as usize) % self.nodes.len()
32    }
33
34    pub fn contains_node(
35        &self,
36        id: &NodeId,
37    ) -> bool {
38        let shard_index = self.get_shard_index(id);
39        self.nodes[shard_index].contains_key(id)
40    }
41
42    pub fn get_node(
43        &self,
44        id: &NodeId,
45    ) -> Option<Arc<Node>> {
46        let shard_index = self.get_shard_index(id);
47        self.nodes[shard_index].get(id).cloned()
48    }
49
50    pub fn get_parent_node(
51        &self,
52        id: &NodeId,
53    ) -> Option<Arc<Node>> {
54        self.parent_map.get(id).and_then(|parent_id| {
55            let shard_index = self.get_shard_index(parent_id);
56            self.nodes[shard_index].get(parent_id).cloned()
57        })
58    }
59    pub fn from(nodes: NodeEnum) -> Self {
60        let num_shards = std::cmp::max(
61            std::thread::available_parallelism()
62                .map(NonZeroUsize::get)
63                .unwrap_or(2),
64            2,
65        );
66        let mut shards = Vector::from(vec![im::HashMap::new(); num_shards]);
67        let mut parent_map = im::HashMap::new();
68        let (root_node, children) = nodes.into_parts();
69        let root_id = root_node.id.clone();
70
71        let mut hasher = DefaultHasher::new();
72        root_id.hash(&mut hasher);
73        let shard_index = (hasher.finish() as usize) % num_shards;
74        shards[shard_index] =
75            shards[shard_index].update(root_id.clone(), Arc::new(root_node));
76
77        fn process_children(
78            children: Vec<NodeEnum>,
79            parent_id: &NodeId,
80            shards: &mut Vector<im::HashMap<NodeId, Arc<Node>>>,
81            parent_map: &mut im::HashMap<NodeId, NodeId>,
82            num_shards: usize,
83        ) {
84            for child in children {
85                let (node, grand_children) = child.into_parts();
86                let node_id = node.id.clone();
87                let mut hasher = DefaultHasher::new();
88                node_id.hash(&mut hasher);
89                let shard_index = (hasher.finish() as usize) % num_shards;
90                shards[shard_index] =
91                    shards[shard_index].update(node_id.clone(), Arc::new(node));
92                parent_map.insert(node_id.clone(), parent_id.clone());
93
94                // Recursively process grand children
95                process_children(
96                    grand_children,
97                    &node_id,
98                    shards,
99                    parent_map,
100                    num_shards,
101                );
102            }
103        }
104
105        process_children(
106            children,
107            &root_id,
108            &mut shards,
109            &mut parent_map,
110            num_shards,
111        );
112
113        Self { root_id, nodes: shards, parent_map }
114    }
115
116    pub fn new(root: Node) -> Self {
117        let num_shards = std::cmp::max(
118            std::thread::available_parallelism()
119                .map(NonZeroUsize::get)
120                .unwrap_or(2),
121            2,
122        );
123        let mut nodes = Vector::from(vec![im::HashMap::new(); num_shards]);
124        let root_id = root.id.clone();
125        let mut hasher = DefaultHasher::new();
126        root_id.hash(&mut hasher);
127        let shard_index = (hasher.finish() as usize) % num_shards;
128        nodes[shard_index] =
129            nodes[shard_index].update(root_id.clone(), Arc::new(root));
130        Self { root_id, nodes, parent_map: im::HashMap::new() }
131    }
132
133    pub fn update_attr(
134        &mut self,
135        id: &NodeId,
136        new_values: im::HashMap<String, Value>,
137    ) -> Result<(), PoolError> {
138        let shard_index = self.get_shard_index(id);
139        let node = self.nodes[shard_index]
140            .get(id)
141            .ok_or(PoolError::NodeNotFound(id.clone()))?;
142        let old_values = node.attrs.clone();
143        let mut new_node = node.as_ref().clone();
144        let new_attrs = old_values.update(new_values);
145        new_node.attrs = new_attrs.clone();
146        self.nodes[shard_index] =
147            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
148        Ok(())
149    }
150
151    /// 向树中添加新的节点及其子节点
152    ///
153    /// # 参数
154    /// * `nodes` - 要添加的节点枚举,包含节点本身及其子节点
155    ///
156    /// # 返回值
157    /// * `Result<(), PoolError>` - 如果添加成功返回 Ok(()), 否则返回错误
158    ///
159    /// # 错误
160    /// * `PoolError::ParentNotFound` - 如果父节点不存在
161    pub fn add(
162        &mut self,
163        nodes: NodeEnum,
164    ) -> Result<(), PoolError> {
165        // 将节点枚举分解为当前节点和子节点
166        let (mut node, children) = nodes.into_parts();
167        let node_id = node.id.clone();
168
169        // 检查父节点是否存在
170        let parent_shard_index = self.get_shard_index(&node_id);
171        let _ = self.nodes[parent_shard_index]
172            .get(&node_id)
173            .ok_or(PoolError::ParentNotFound(node_id.clone()))?;
174
175        // 收集所有子节点的ID并添加到当前节点的content中
176        let zenliang: Vector<String> =
177            children.iter().map(|n| n.0.id.clone()).collect();
178        node.content.extend(zenliang);
179
180        // 更新当前节点
181        let shard_index = self.get_shard_index(&node_id);
182        self.nodes[shard_index] =
183            self.nodes[shard_index].update(node_id.clone(), Arc::new(node));
184
185        // 使用队列进行广度优先遍历,处理所有子节点
186        let mut node_queue = Vec::new();
187        node_queue.push((children, node_id.clone()));
188        while let Some((current_children, parent_id)) = node_queue.pop() {
189            for child in current_children {
190                // 处理每个子节点
191                let (mut child_node, grand_children) = child.into_parts();
192                let current_node_id = child_node.id.clone();
193
194                // 收集孙节点的ID并添加到子节点的content中
195                let zenliang: Vector<String> =
196                    grand_children.iter().map(|n| n.0.id.clone()).collect();
197                child_node.content.extend(zenliang);
198
199                // 更新子节点
200                let shard_index = self.get_shard_index(&current_node_id);
201                self.nodes[shard_index] = self.nodes[shard_index]
202                    .update(current_node_id.clone(), Arc::new(child_node));
203
204                // 更新父子关系映射
205                self.parent_map
206                    .insert(current_node_id.clone(), parent_id.clone());
207
208                // 将孙节点加入队列,以便后续处理
209                node_queue.push((grand_children, current_node_id.clone()));
210            }
211        }
212        Ok(())
213    }
214
215    pub fn add_node(
216        &mut self,
217        parent_id: &NodeId,
218        nodes: &Vec<Node>,
219    ) -> Result<(), PoolError> {
220        let parent_shard_index = self.get_shard_index(parent_id);
221        let parent = self.nodes[parent_shard_index]
222            .get(parent_id)
223            .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
224        let mut new_parent = parent.as_ref().clone();
225        new_parent.content.push_back(nodes[0].id.clone());
226        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
227            .update(parent_id.clone(), Arc::new(new_parent));
228        self.parent_map.insert(nodes[0].id.clone(), parent_id.clone());
229        for node in nodes {
230            let shard_index = self.get_shard_index(&node.id);
231            for child_id in &node.content {
232                self.parent_map.insert(child_id.clone(), node.id.clone());
233            }
234            self.nodes[shard_index] = self.nodes[shard_index]
235                .update(node.id.clone(), Arc::new(node.clone()));
236        }
237        Ok(())
238    }
239
240    pub fn node(
241        &mut self,
242        key: &str,
243    ) -> NodeRef<'_> {
244        NodeRef::new(self, key.to_string())
245    }
246    pub fn mark(
247        &mut self,
248        key: &str,
249    ) -> MarkRef<'_> {
250        MarkRef::new(self, key.to_string())
251    }
252    pub fn attrs(
253        &mut self,
254        key: &str,
255    ) -> AttrsRef<'_> {
256        AttrsRef::new(self, key.to_string())
257    }
258
259    pub fn children(
260        &self,
261        parent_id: &NodeId,
262    ) -> Option<im::Vector<NodeId>> {
263        self.get_node(parent_id).map(|n| n.content.clone())
264    }
265
266    pub fn children_node(
267        &self,
268        parent_id: &NodeId,
269    ) -> Option<im::Vector<Arc<Node>>> {
270        self.children(parent_id)
271            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
272    }
273
274    pub fn children_count(
275        &self,
276        parent_id: &NodeId,
277    ) -> usize {
278        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
279    }
280
281    pub fn remove_mark(
282        &mut self,
283        id: &NodeId,
284        mark: Mark,
285    ) -> Result<(), PoolError> {
286        let shard_index = self.get_shard_index(id);
287        let node = self.nodes[shard_index]
288            .get(id)
289            .ok_or(PoolError::NodeNotFound(id.clone()))?;
290        let mut new_node = node.as_ref().clone();
291        new_node.marks =
292            new_node.marks.iter().filter(|&m| !m.eq(&mark)).cloned().collect();
293        self.nodes[shard_index] =
294            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
295        Ok(())
296    }
297
298    pub fn add_mark(
299        &mut self,
300        id: &NodeId,
301        marks: &Vec<Mark>,
302    ) -> Result<(), PoolError> {
303        let shard_index = self.get_shard_index(id);
304        let node = self.nodes[shard_index]
305            .get(id)
306            .ok_or(PoolError::NodeNotFound(id.clone()))?;
307        let mut new_node = node.as_ref().clone();
308        new_node.marks.extend(marks.clone());
309        self.nodes[shard_index] =
310            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
311        Ok(())
312    }
313    pub fn replace_node(
314        &mut self,
315        node_id: NodeId,
316        nodes: &Vec<Node>,
317    ) -> Result<(), PoolError> {
318        let shard_index = self.get_shard_index(&node_id);
319        let _ = self.nodes[shard_index]
320            .get(&node_id)
321            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
322        if nodes[0].id != node_id {
323            return Err(PoolError::InvalidNodeId {
324                nodeid: node_id,
325                new_node_id: nodes[0].id.clone(),
326            });
327        }
328        let _ = self.add_node(&node_id, nodes)?;
329        Ok(())
330    }
331
332    pub fn move_node(
333        &mut self,
334        source_parent_id: &NodeId,
335        target_parent_id: &NodeId,
336        node_id: &NodeId,
337        position: Option<usize>,
338    ) -> Result<(), PoolError> {
339        let source_shard_index = self.get_shard_index(source_parent_id);
340        let target_shard_index = self.get_shard_index(target_parent_id);
341        let node_shard_index = self.get_shard_index(node_id);
342        let source_parent = self.nodes[source_shard_index]
343            .get(source_parent_id)
344            .ok_or(PoolError::ParentNotFound(source_parent_id.clone()))?;
345        let target_parent = self.nodes[target_shard_index]
346            .get(target_parent_id)
347            .ok_or(PoolError::ParentNotFound(target_parent_id.clone()))?;
348        let _node = self.nodes[node_shard_index]
349            .get(node_id)
350            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
351        if !source_parent.content.contains(node_id) {
352            return Err(PoolError::InvalidParenting {
353                child: node_id.clone(),
354                alleged_parent: source_parent_id.clone(),
355            });
356        }
357        let mut new_source_parent = source_parent.as_ref().clone();
358        new_source_parent.content = new_source_parent
359            .content
360            .iter()
361            .filter(|&id| id != node_id)
362            .cloned()
363            .collect();
364        let mut new_target_parent = target_parent.as_ref().clone();
365        if let Some(pos) = position {
366            if pos <= new_target_parent.content.len() {
367                let mut new_content = im::Vector::new();
368                for (i, child_id) in
369                    new_target_parent.content.iter().enumerate()
370                {
371                    if i == pos {
372                        new_content.push_back(node_id.clone());
373                    }
374                    new_content.push_back(child_id.clone());
375                }
376                if pos == new_target_parent.content.len() {
377                    new_content.push_back(node_id.clone());
378                }
379                new_target_parent.content = new_content;
380            } else {
381                new_target_parent.content.push_back(node_id.clone());
382            }
383        } else {
384            new_target_parent.content.push_back(node_id.clone());
385        }
386        self.nodes[source_shard_index] = self.nodes[source_shard_index]
387            .update(source_parent_id.clone(), Arc::new(new_source_parent));
388        self.nodes[target_shard_index] = self.nodes[target_shard_index]
389            .update(target_parent_id.clone(), Arc::new(new_target_parent));
390        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
391        Ok(())
392    }
393
394    pub fn remove_node(
395        &mut self,
396        parent_id: &NodeId,
397        nodes: Vec<NodeId>,
398    ) -> Result<(), PoolError> {
399        let parent_shard_index = self.get_shard_index(parent_id);
400        let parent = self.nodes[parent_shard_index]
401            .get(parent_id)
402            .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
403        if nodes.contains(&self.root_id) {
404            return Err(PoolError::CannotRemoveRoot);
405        }
406        for node_id in &nodes {
407            if !parent.content.contains(node_id) {
408                return Err(PoolError::InvalidParenting {
409                    child: node_id.clone(),
410                    alleged_parent: parent_id.clone(),
411                });
412            }
413        }
414        let nodes_to_remove: std::collections::HashSet<_> =
415            nodes.iter().collect();
416        let filtered_children: im::Vector<NodeId> = parent
417            .as_ref()
418            .content
419            .iter()
420            .filter(|&id| !nodes_to_remove.contains(id))
421            .cloned()
422            .collect();
423        let mut parent_node = parent.as_ref().clone();
424        parent_node.content = filtered_children;
425        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
426            .update(parent_id.clone(), Arc::new(parent_node));
427        let mut remove_nodes = Vec::new();
428        for node_id in nodes {
429            self.remove_subtree(&node_id, &mut remove_nodes)?;
430        }
431        Ok(())
432    }
433
434    fn remove_subtree(
435        &mut self,
436        node_id: &NodeId,
437        remove_nodes: &mut Vec<Node>,
438    ) -> Result<(), PoolError> {
439        if node_id == &self.root_id {
440            return Err(PoolError::CannotRemoveRoot);
441        }
442        let shard_index = self.get_shard_index(node_id);
443        let _ = self.nodes[shard_index]
444            .get(node_id)
445            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
446        if let Some(children) = self.children(node_id) {
447            for child_id in children {
448                self.remove_subtree(&child_id, remove_nodes)?;
449            }
450        }
451        self.parent_map.remove(node_id);
452        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
453            remove_nodes.push(remove_node.as_ref().clone());
454        }
455        Ok(())
456    }
457}
458
459impl Index<&NodeId> for Tree {
460    type Output = Arc<Node>;
461    fn index(
462        &self,
463        index: &NodeId,
464    ) -> &Self::Output {
465        let shard_index = self.get_shard_index(index);
466        self.nodes[shard_index].get(index).expect("Node not found")
467    }
468}
469
470impl Index<&str> for Tree {
471    type Output = Arc<Node>;
472    fn index(
473        &self,
474        index: &str,
475    ) -> &Self::Output {
476        let node_id = NodeId::from(index);
477        let shard_index = self.get_shard_index(&node_id);
478        self.nodes[shard_index].get(&node_id).expect("Node not found")
479    }
480}