moduforge_model/
node_pool.rs

1use super::{error::PoolError, node::Node, types::NodeId};
2use im::HashMap;
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5/// 节点池内部数据结构,实现结构共享和高效克隆
6///
7/// # 字段
8///
9/// * `root_id` - 根节点标识符
10/// * `nodes` - 节点存储的不可变哈希表(使用结构共享)
11/// * `parent_map` - 父子关系映射表
12#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
13pub struct NodePoolInner {
14    pub root_id: NodeId,
15    pub nodes: im::HashMap<NodeId, Arc<Node>>, // 节点数据共享
16    pub parent_map: im::HashMap<NodeId, NodeId>,
17}
18impl NodePoolInner {
19    /// 更新节点属性(创建新版本的数据结构)
20    ///
21    /// # 参数
22    ///
23    /// * `id` - 目标节点ID
24    /// * `values` - 要更新的属性键值对
25    ///
26    /// # 返回值
27    ///
28    /// 返回包含新节点属性的新版本 `NodePoolInner`
29    ///
30    /// # 错误
31    ///
32    /// 当节点不存在时返回 [`PoolError::NodeNotFound`]
33    pub fn update_attr(
34        &self,
35        id: &NodeId,
36        values: &HashMap<String, String>,
37    ) -> Result<Self, PoolError> {
38        if !self.nodes.contains_key(id) {
39            return Err(PoolError::NodeNotFound(id.clone()));
40        }
41        let node = self.nodes.get(id).unwrap();
42
43        let mut cope_node = node.clone().as_ref().clone();
44        cope_node.attrs.extend(values.clone());
45        let nodes = self.nodes.update(id.clone(), Arc::new(cope_node));
46        Ok(NodePoolInner {
47            nodes,
48            parent_map: self.parent_map.clone(),
49            root_id: self.root_id.clone(),
50        })
51    }
52}
53/// 线程安全的节点池封装
54///
55/// 使用 [`Arc`] 实现快速克隆,内部使用不可变数据结构保证线程安全
56#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
57pub struct NodePool {
58    // 使用 Arc 包裹内部结构,实现快速克隆
59    pub inner: Arc<NodePoolInner>,
60}
61unsafe impl Send for NodePool {}
62unsafe impl Sync for NodePool {}
63
64impl NodePool {
65    /// 获取节点池中节点总数
66    pub fn size(&self) -> usize {
67        self.inner.nodes.len()
68    }
69
70    /// 从节点列表构建节点池
71    ///
72    /// # 参数
73    ///
74    /// * `nodes` - 初始节点列表
75    /// * `root_id` - 指定根节点ID
76    ///
77    /// # 注意
78    ///
79    /// 会自动构建父子关系映射表
80    pub fn from(
81        nodes: Vec<Node>,
82        root_id: NodeId,
83    ) -> Self {
84        let mut nodes_ref = HashMap::new();
85        let mut parent_map_ref = HashMap::new();
86        for node in nodes.into_iter() {
87            for child_id in &node.content {
88                parent_map_ref.insert(child_id.clone(), node.id.clone());
89            }
90            nodes_ref.insert(node.id.clone(), Arc::new(node));
91        }
92
93        NodePool {
94            inner: Arc::new(NodePoolInner {
95                nodes: nodes_ref,
96                parent_map: parent_map_ref,
97                root_id,
98            }),
99        }
100    }
101
102    // -- 核心查询方法 --
103
104    /// 根据ID获取节点(immutable)
105    pub fn get_node(
106        &self,
107        id: &NodeId,
108    ) -> Option<&Arc<Node>> {
109        self.inner.nodes.get(id)
110    }
111
112    /// 检查节点是否存在
113    pub fn contains_node(
114        &self,
115        id: &NodeId,
116    ) -> bool {
117        self.inner.nodes.contains_key(id)
118    }
119
120    // -- 层级关系操作 --
121
122    /// 获取直接子节点列表
123    pub fn children(
124        &self,
125        parent_id: &NodeId,
126    ) -> Option<&im::Vector<NodeId>> {
127        self.get_node(parent_id).map(|n| &n.content)
128    }
129
130    /// 递归获取所有子节点(深度优先)
131    pub fn descendants(
132        &self,
133        parent_id: &NodeId,
134    ) -> Vec<&Node> {
135        let mut result: Vec<&Node> = Vec::new();
136        self._collect_descendants(parent_id, &mut result);
137        result
138    }
139
140    fn _collect_descendants<'a>(
141        &'a self,
142        parent_id: &NodeId,
143        result: &mut Vec<&'a Node>,
144    ) {
145        if let Some(children) = self.children(parent_id) {
146            for child_id in children {
147                if let Some(child) = self.get_node(child_id) {
148                    result.push(child);
149                    self._collect_descendants(child_id, result);
150                }
151            }
152        }
153    }
154
155    /// 获取父节点ID
156    pub fn parent_id(
157        &self,
158        child_id: &NodeId,
159    ) -> Option<&NodeId> {
160        self.inner.parent_map.get(child_id)
161    }
162
163    /// 获取完整祖先链
164    pub fn ancestors(
165        &self,
166        child_id: &NodeId,
167    ) -> Vec<&Arc<Node>> {
168        let mut chain = Vec::new();
169        let mut current_id = child_id;
170        while let Some(parent_id) = self.parent_id(current_id) {
171            if let Some(parent) = self.get_node(parent_id) {
172                chain.push(parent);
173                current_id = parent_id;
174            } else {
175                break;
176            }
177        }
178        chain
179    }
180
181    /// 验证父子关系一致性
182    pub fn validate_hierarchy(&self) -> Result<(), PoolError> {
183        for (child_id, parent_id) in &self.inner.parent_map {
184            // 验证父节点存在
185            if !self.contains_node(parent_id) {
186                return Err(PoolError::OrphanNode(child_id.clone()));
187            }
188
189            // 验证父节点确实包含该子节点
190            if let Some(children) = self.children(parent_id) {
191                if !children.contains(child_id) {
192                    return Err(PoolError::InvalidParenting {
193                        child: child_id.clone(),
194                        alleged_parent: parent_id.clone(),
195                    });
196                }
197            }
198        }
199        Ok(())
200    }
201
202    // -- 高级查询 --
203    /// 根据类型筛选节点
204    pub fn filter_nodes<P>(
205        &self,
206        predicate: P,
207    ) -> Vec<&Arc<Node>>
208    where
209        P: Fn(&Node) -> bool,
210    {
211        self.inner.nodes.values().filter(|n| predicate(n)).collect()
212    }
213    /// 查找第一个匹配节点
214    pub fn find_node<P>(
215        &self,
216        predicate: P,
217    ) -> Option<&Arc<Node>>
218    where
219        P: Fn(&Node) -> bool,
220    {
221        self.inner.nodes.values().find(|n| predicate(n))
222    }
223
224    /// 获取节点在树中的深度
225    ///
226    /// # 参数
227    ///
228    /// * `node_id` - 目标节点ID
229    ///
230    /// # 返回值
231    ///
232    /// 返回节点的深度,根节点深度为0
233    pub fn get_node_depth(
234        &self,
235        node_id: &NodeId,
236    ) -> Option<usize> {
237        let mut depth = 0;
238        let mut current_id = node_id;
239
240        while let Some(parent_id) = self.parent_id(current_id) {
241            depth += 1;
242            current_id = parent_id;
243        }
244
245        Some(depth)
246    }
247
248    /// 获取从根节点到目标节点的完整路径
249    ///
250    /// # 参数
251    ///
252    /// * `node_id` - 目标节点ID
253    ///
254    /// # 返回值
255    ///
256    /// 返回从根节点到目标节点的节点ID路径
257    pub fn get_node_path(
258        &self,
259        node_id: &NodeId,
260    ) -> Vec<NodeId> {
261        let mut path = Vec::new();
262        let mut current_id = node_id;
263
264        while let Some(parent_id) = self.parent_id(current_id) {
265            path.push(current_id.clone());
266            current_id = parent_id;
267        }
268        path.push(current_id.clone());
269        path.reverse();
270
271        path
272    }
273
274    /// 检查节点是否为叶子节点
275    ///
276    /// # 参数
277    ///
278    /// * `node_id` - 目标节点ID
279    ///
280    /// # 返回值
281    ///
282    /// 如果节点不存在或没有子节点则返回 true
283    pub fn is_leaf(
284        &self,
285        node_id: &NodeId,
286    ) -> bool {
287        if let Some(children) = self.children(node_id) {
288            children.is_empty()
289        } else {
290            true
291        }
292    }
293
294    /// 获取节点的同级节点(具有相同父节点的节点)
295    ///
296    /// # 参数
297    ///
298    /// * `node_id` - 目标节点ID
299    ///
300    /// # 返回值
301    ///
302    /// 返回同级节点的ID列表
303    pub fn get_siblings(
304        &self,
305        node_id: &NodeId,
306    ) -> Vec<NodeId> {
307        if let Some(parent_id) = self.parent_id(node_id) {
308            if let Some(children) = self.children(parent_id) {
309                return children
310                    .iter()
311                    .filter(|&id| id != node_id)
312                    .cloned()
313                    .collect();
314            }
315        }
316        Vec::new()
317    }
318
319    /// 获取节点的所有兄弟节点(包括自身)
320    ///
321    /// # 参数
322    ///
323    /// * `node_id` - 目标节点ID
324    ///
325    /// # 返回值
326    ///
327    /// 返回所有兄弟节点的ID列表(包括自身)
328    pub fn get_all_siblings(
329        &self,
330        node_id: &NodeId,
331    ) -> Vec<NodeId> {
332        if let Some(parent_id) = self.parent_id(node_id) {
333            if let Some(children) = self.children(parent_id) {
334                return children.iter().cloned().collect();
335            }
336        }
337        Vec::new()
338    }
339
340    /// 获取节点的子树大小(包括自身和所有子节点)
341    ///
342    /// # 参数
343    ///
344    /// * `node_id` - 目标节点ID
345    ///
346    /// # 返回值
347    ///
348    /// 返回子树中的节点总数
349    pub fn get_subtree_size(
350        &self,
351        node_id: &NodeId,
352    ) -> usize {
353        let mut size = 1; // 包含自身
354        if let Some(children) = self.children(node_id) {
355            for child_id in children {
356                size += self.get_subtree_size(child_id);
357            }
358        }
359        size
360    }
361
362    /// 检查一个节点是否是另一个节点的祖先
363    ///
364    /// # 参数
365    ///
366    /// * `ancestor_id` - 可能的祖先节点ID
367    /// * `descendant_id` - 可能的后代节点ID
368    ///
369    /// # 返回值
370    ///
371    /// 如果 ancestor_id 是 descendant_id 的祖先则返回 true
372    pub fn is_ancestor(
373        &self,
374        ancestor_id: &NodeId,
375        descendant_id: &NodeId,
376    ) -> bool {
377        let mut current_id = descendant_id;
378        while let Some(parent_id) = self.parent_id(current_id) {
379            if parent_id == ancestor_id {
380                return true;
381            }
382            current_id = parent_id;
383        }
384        false
385    }
386
387    /// 获取两个节点的最近公共祖先
388    ///
389    /// # 参数
390    ///
391    /// * `node1_id` - 第一个节点ID
392    /// * `node2_id` - 第二个节点ID
393    ///
394    /// # 返回值
395    ///
396    /// 返回两个节点的最近公共祖先ID
397    pub fn get_lowest_common_ancestor(
398        &self,
399        node1_id: &NodeId,
400        node2_id: &NodeId,
401    ) -> Option<NodeId> {
402        let path1 = self.get_node_path(node1_id);
403        let path2 = self.get_node_path(node2_id);
404
405        for ancestor_id in path1.iter().rev() {
406            if path2.contains(ancestor_id) {
407                return Some(ancestor_id.clone());
408            }
409        }
410        None
411    }
412}