mf_core/helpers/
aggregation_helper.rs

1//! 节点聚合助手 - 并行的自下而上树形节点聚合计算
2//!
3//! 提供高性能的树形结构聚合计算,支持自定义聚合逻辑和并行处理。
4//!
5//! # 核心特性
6//!
7//! - **并行计算**: 使用 rayon 实现同层节点的并行处理
8//! - **层级策略**: 支持自定义层级计算策略,内置缓存优化
9//! - **类型安全**: 泛型设计支持任意聚合数据类型
10//! - **高性能**: 全局 tokio Runtime,避免重复创建开销
11//!
12//! # 示例
13//!
14//! ```rust,ignore
15//! use mf_core::helpers::aggregation_helper::NodeAggregator;
16//!
17//! // 定义聚合逻辑:求和
18//! let aggregator = NodeAggregator::new(
19//!     |node_id: NodeId, state: Arc<State>, cache: Arc<ConcurrentCache<i64>>| async move {
20//!         let node = state.get_node(&node_id)?;
21//!         let children: Vec<NodeId> = state.get_children(&node_id);
22//!
23//!         let sum: i64 = children.iter()
24//!             .filter_map(|child_id| cache.get(child_id))
25//!             .sum();
26//!
27//!         Ok(sum + node.get_value())
28//!     },
29//!     CachedLevelStrategy::new(state.clone()),
30//! );
31//!
32//! // 从叶子节点开始聚合
33//! let results = aggregator.aggregate_up(&leaf_node_id, state)?;
34//! ```
35
36use dashmap::DashMap;
37use mf_model::NodeId;
38use mf_state::state::State;
39use std::collections::HashMap;
40use std::future::Future;
41use std::pin::Pin;
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44
45use crate::error::ForgeResult;
46
47// ============================================================================
48// 并发安全的缓存结构
49// ============================================================================
50
51/// 并发安全的缓存,用于存储节点聚合结果
52///
53/// 内部使用 DashMap 提供无锁并发访问
54#[derive(Clone)]
55pub struct ConcurrentCache<T: Clone + Send + Sync> {
56    inner: Arc<DashMap<NodeId, T>>,
57}
58
59impl<T: Clone + Send + Sync> ConcurrentCache<T> {
60    /// 创建新的并发缓存
61    pub fn new() -> Self {
62        Self { inner: Arc::new(DashMap::new()) }
63    }
64
65    /// 插入或更新缓存值
66    pub fn insert(
67        &self,
68        key: NodeId,
69        value: T,
70    ) {
71        self.inner.insert(key, value);
72    }
73
74    /// 获取缓存值
75    pub fn get(
76        &self,
77        key: &NodeId,
78    ) -> Option<T> {
79        self.inner.get(key).map(|v| v.clone())
80    }
81
82    /// 批量获取缓存值
83    pub fn get_all(&self) -> HashMap<NodeId, T> {
84        self.inner
85            .iter()
86            .map(|entry| (entry.key().clone(), entry.value().clone()))
87            .collect()
88    }
89
90    /// 清空缓存
91    pub fn clear(&self) {
92        self.inner.clear();
93    }
94
95    /// 检查是否包含指定 key
96    pub fn contains(
97        &self,
98        key: &NodeId,
99    ) -> bool {
100        self.inner.contains_key(key)
101    }
102}
103
104impl<T: Clone + Send + Sync> Default for ConcurrentCache<T> {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110// ============================================================================
111// 并发安全的计数器
112// ============================================================================
113
114/// 并发安全的原子计数器
115#[derive(Clone)]
116pub struct ConcurrentCounter {
117    count: Arc<AtomicUsize>,
118}
119
120impl ConcurrentCounter {
121    /// 创建新的计数器
122    pub fn new() -> Self {
123        Self { count: Arc::new(AtomicUsize::new(0)) }
124    }
125
126    /// 增加计数
127    pub fn increment(&self) -> usize {
128        self.count.fetch_add(1, Ordering::SeqCst) + 1
129    }
130
131    /// 获取当前计数
132    pub fn get(&self) -> usize {
133        self.count.load(Ordering::SeqCst)
134    }
135
136    /// 重置计数器
137    pub fn reset(&self) {
138        self.count.store(0, Ordering::SeqCst);
139    }
140}
141
142impl Default for ConcurrentCounter {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148// ============================================================================
149// 层级策略 Trait
150// ============================================================================
151
152/// 层级计算策略 Trait
153///
154/// 定义如何计算节点在树中的层级(深度)
155pub trait LevelStrategy: Send + Sync {
156    /// 计算指定节点的层级
157    ///
158    /// # 参数
159    /// - `node_id`: 节点 ID
160    /// - `state`: 状态引用
161    ///
162    /// # 返回值
163    /// 节点层级,根节点为 0
164    fn get_level(
165        &self,
166        node_id: &NodeId,
167        state: &Arc<State>,
168    ) -> usize;
169}
170
171// ============================================================================
172// 默认层级策略(每次都计算)
173// ============================================================================
174
175/// 默认层级策略 - 每次都遍历父节点链计算层级
176///
177/// 时间复杂度: O(depth)
178pub struct DefaultLevelStrategy;
179
180impl LevelStrategy for DefaultLevelStrategy {
181    fn get_level(
182        &self,
183        node_id: &NodeId,
184        state: &Arc<State>,
185    ) -> usize {
186        let node_pool = &state.node_pool;
187        let mut level = 0;
188        let mut current = node_id.clone();
189
190        while let Some(parent_id) = node_pool.parent_id(&current) {
191            level += 1;
192            current = parent_id.clone();
193        }
194
195        level
196    }
197}
198
199// ============================================================================
200// 缓存层级策略(推荐使用)
201// ============================================================================
202
203/// 缓存层级策略 - 缓存已计算的层级结果
204///
205/// 时间复杂度: 首次 O(depth),后续 O(1)
206pub struct CachedLevelStrategy {
207    cache: Arc<DashMap<NodeId, usize>>,
208}
209
210impl CachedLevelStrategy {
211    /// 创建新的缓存层级策略
212    pub fn new() -> Self {
213        Self { cache: Arc::new(DashMap::new()) }
214    }
215
216    /// 清空缓存
217    pub fn clear_cache(&self) {
218        self.cache.clear();
219    }
220}
221
222impl Default for CachedLevelStrategy {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
228impl LevelStrategy for CachedLevelStrategy {
229    fn get_level(
230        &self,
231        node_id: &NodeId,
232        state: &Arc<State>,
233    ) -> usize {
234        // 先检查缓存
235        if let Some(level) = self.cache.get(node_id) {
236            return *level;
237        }
238
239        // 缓存未命中,计算层级
240        let node_pool = &state.node_pool;
241        let mut level = 0;
242        let mut current = node_id.clone();
243
244        while let Some(parent_id) = node_pool.parent_id(&current) {
245            level += 1;
246            current = parent_id.clone();
247
248            // 如果父节点已缓存,直接使用
249            if let Some(parent_level) = self.cache.get(&current) {
250                level += *parent_level;
251                break;
252            }
253        }
254
255        // 缓存结果
256        self.cache.insert(node_id.clone(), level);
257        level
258    }
259}
260
261// ============================================================================
262// 节点聚合器 Trait
263// ============================================================================
264
265/// 节点聚合处理器类型定义
266///
267/// 定义单个节点的聚合计算逻辑
268///
269/// 使用 Arc 包装以支持在多个异步任务间共享
270pub type NodeProcessor<T> = Arc<
271    dyn Fn(
272            NodeId,
273            Arc<State>,
274            Arc<ConcurrentCache<T>>,
275        ) -> Pin<Box<dyn Future<Output = ForgeResult<T>> + Send>>
276        + Send
277        + Sync,
278>;
279
280/// 节点聚合器 Trait
281///
282/// 定义树形结构的自下而上聚合计算接口
283pub trait NodeAggregatorTrait<T: Clone + Send + Sync>: Send + Sync {
284    /// 执行自下而上的聚合计算
285    ///
286    /// # 参数
287    /// - `start_node`: 起始节点 ID(通常是叶子节点)
288    /// - `state`: 状态引用
289    ///
290    /// # 返回值
291    /// 所有节点的聚合结果 HashMap
292    fn aggregate_up(
293        &self,
294        start_node: &NodeId,
295        state: Arc<State>,
296    ) -> impl Future<Output = ForgeResult<HashMap<NodeId, T>>> + Send;
297}
298
299// ============================================================================
300// 节点聚合器实现
301// ============================================================================
302
303/// 并行节点聚合器
304///
305/// 提供高性能的树形节点聚合计算,支持自定义聚合逻辑
306///
307/// # 类型参数
308/// - `T`: 聚合结果类型
309pub struct NodeAggregator<T: Clone + Send + Sync + 'static> {
310    /// 聚合结果缓存(修复:使用共享实例)
311    cache: Arc<ConcurrentCache<T>>,
312
313    /// 节点处理器
314    processor: NodeProcessor<T>,
315
316    /// 层级计算策略
317    level_strategy: Arc<dyn LevelStrategy>,
318}
319
320impl<T: Clone + Send + Sync + 'static> NodeAggregator<T> {
321    /// 创建新的节点聚合器
322    ///
323    /// # 参数
324    /// - `processor`: 节点聚合处理函数
325    /// - `level_strategy`: 层级计算策略
326    ///
327    /// # 示例
328    ///
329    /// ```rust,ignore
330    /// let aggregator = NodeAggregator::new(
331    ///     |node_id, state, cache| async move {
332    ///         // 自定义聚合逻辑
333    ///         Ok(result)
334    ///     },
335    ///     CachedLevelStrategy::new(),
336    /// );
337    /// ```
338    pub fn new<F, Fut>(
339        processor: F,
340        level_strategy: impl LevelStrategy + 'static,
341    ) -> Self
342    where
343        F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
344            + Send
345            + Sync
346            + 'static,
347        Fut: Future<Output = ForgeResult<T>> + Send + 'static,
348    {
349        // 创建共享缓存实例(修复 P0 问题:确保只有一个缓存实例)
350        let cache = Arc::new(ConcurrentCache::new());
351
352        // 创建处理器闭包(使用 Arc 包装以支持多任务共享)
353        let processor_arc: NodeProcessor<T> =
354            Arc::new(move |id, state, cache| {
355                Box::pin(processor(id, state, cache))
356            });
357
358        Self {
359            cache,
360            processor: processor_arc,
361            level_strategy: Arc::new(level_strategy),
362        }
363    }
364
365    /// 使用默认层级策略创建聚合器
366    pub fn with_default_strategy<F, Fut>(processor: F) -> Self
367    where
368        F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
369            + Send
370            + Sync
371            + 'static,
372        Fut: Future<Output = ForgeResult<T>> + Send + 'static,
373    {
374        Self::new(processor, DefaultLevelStrategy)
375    }
376
377    /// 使用缓存层级策略创建聚合器(推荐)
378    pub fn with_cached_strategy<F, Fut>(processor: F) -> Self
379    where
380        F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
381            + Send
382            + Sync
383            + 'static,
384        Fut: Future<Output = ForgeResult<T>> + Send + 'static,
385    {
386        Self::new(processor, CachedLevelStrategy::new())
387    }
388
389    /// 收集从起始节点到根节点的所有祖先
390    fn collect_ancestors(
391        &self,
392        start_node: &NodeId,
393        state: &Arc<State>,
394    ) -> Vec<NodeId> {
395        let node_pool = &state.node_pool;
396        let mut ancestors = vec![start_node.clone()];
397        let mut current = start_node.clone();
398
399        while let Some(parent_id) = node_pool.parent_id(&current) {
400            ancestors.push(parent_id.clone());
401            current = parent_id.clone();
402        }
403
404        ancestors
405    }
406
407    /// 按层级分组节点
408    ///
409    /// 返回: HashMap<层级, Vec<节点ID>>
410    fn group_by_level(
411        &self,
412        nodes: &[NodeId],
413        state: &Arc<State>,
414    ) -> HashMap<usize, Vec<NodeId>> {
415        let mut groups: HashMap<usize, Vec<NodeId>> = HashMap::new();
416
417        for node_id in nodes {
418            let level = self.level_strategy.get_level(node_id, state);
419            groups.entry(level).or_default().push(node_id.clone());
420        }
421
422        groups
423    }
424
425    /// 处理单层节点(并发执行)
426    ///
427    /// 使用 tokio::spawn 实现真正的异步并发
428    async fn process_layer(
429        &self,
430        layer_nodes: &[NodeId],
431        state: Arc<State>,
432    ) -> ForgeResult<()> {
433        // 为每个节点创建异步任务
434        let handles: Vec<_> = layer_nodes
435            .iter()
436            .map(|node_id| {
437                let state = state.clone();
438                let cache = self.cache.clone();
439                let node_id = node_id.clone();
440                let processor = self.processor.clone();
441
442                // 使用 tokio::spawn 并发执行
443                tokio::spawn(async move {
444                    let result =
445                        processor(node_id.clone(), state, cache.clone())
446                            .await?;
447                    cache.insert(node_id.clone(), result);
448                    Ok::<_, crate::error::ForgeError>(())
449                })
450            })
451            .collect();
452
453        // 等待所有任务完成
454        for handle in handles {
455            handle.await.map_err(|e| {
456                crate::error::error_utils::engine_error(format!(
457                    "任务执行失败: {}",
458                    e
459                ))
460            })??;
461        }
462
463        Ok(())
464    }
465}
466
467impl<T: Clone + Send + Sync + 'static> NodeAggregatorTrait<T>
468    for NodeAggregator<T>
469{
470    /// 执行自下而上的层级聚合
471    ///
472    /// # 算法流程
473    ///
474    /// 1. 收集从起始节点到根节点的所有祖先
475    /// 2. 按层级分组(叶子节点层级最大)
476    /// 3. 从最深层级开始,逐层向上处理
477    /// 4. 每层内部使用 tokio::spawn 并发处理
478    /// 5. 确保当前层完全处理完成后再处理上一层
479    ///
480    /// # 性能特性
481    ///
482    /// - 同层并发: 使用 tokio::spawn 真正异步并发
483    /// - 层间串行: 保证数据依赖正确性
484    /// - 零开销: 使用调用方现有的 tokio runtime
485    /// - 层级缓存: O(1) 层级查询
486    async fn aggregate_up(
487        &self,
488        start_node: &NodeId,
489        state: Arc<State>,
490    ) -> ForgeResult<HashMap<NodeId, T>> {
491        // 1. 清空缓存(每次聚合重新计算)
492        self.cache.clear();
493
494        // 2. 收集所有需要聚合的节点(从叶子到根)
495        let all_nodes = self.collect_ancestors(start_node, &state);
496
497        // 3. 按层级分组
498        let level_groups = self.group_by_level(&all_nodes, &state);
499
500        // 4. 获取层级列表并排序(从深到浅,即从叶子到根)
501        let mut levels: Vec<usize> = level_groups.keys().copied().collect();
502        levels.sort_by(|a, b| b.cmp(a)); // 降序排列
503
504        // 5. 逐层处理(层级完成后再处理下一层)
505        for level in levels {
506            if let Some(layer_nodes) = level_groups.get(&level) {
507                self.process_layer(layer_nodes, state.clone()).await?;
508            }
509        }
510
511        // 6. 返回所有聚合结果
512        Ok(self.cache.get_all())
513    }
514}