use dashmap::DashMap;
use mf_model::NodeId;
use mf_state::state::State;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::error::ForgeResult;
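
/// A thread-safe cache keyed by `NodeId`, backed by `DashMap`.
///
/// Cloning is cheap: clones share the same underlying map through an
/// `Arc`, so a clone handed to a spawned task observes the same entries.
///
/// A minimal usage sketch (marked `ignore` because constructing a real
/// `NodeId` depends on `mf_model`; `node_id` below is a hypothetical
/// placeholder):
///
/// ```ignore
/// let cache: ConcurrentCache<u64> = ConcurrentCache::new();
/// cache.insert(node_id.clone(), 42);
/// assert_eq!(cache.get(&node_id), Some(42));
/// assert!(cache.contains(&node_id));
/// cache.clear();
/// assert!(cache.get_all().is_empty());
/// ```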
#[derive(Clone)]
pub struct ConcurrentCache<T: Clone + Send + Sync> {
    inner: Arc<DashMap<NodeId, T>>,
}

impl<T: Clone + Send + Sync> ConcurrentCache<T> {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self { inner: Arc::new(DashMap::new()) }
    }

    /// Inserts `value` under `key`, overwriting any existing entry.
    pub fn insert(
        &self,
        key: NodeId,
        value: T,
    ) {
        self.inner.insert(key, value);
    }

    /// Returns a clone of the value stored under `key`, if any.
    pub fn get(
        &self,
        key: &NodeId,
    ) -> Option<T> {
        self.inner.get(key).map(|v| v.clone())
    }

    /// Snapshots every entry into a plain `HashMap`.
    pub fn get_all(&self) -> HashMap<NodeId, T> {
        self.inner
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().clone()))
            .collect()
    }

    /// Removes all entries.
    pub fn clear(&self) {
        self.inner.clear();
    }

    /// Returns `true` if an entry exists for `key`.
    pub fn contains(
        &self,
        key: &NodeId,
    ) -> bool {
        self.inner.contains_key(key)
    }
}

impl<T: Clone + Send + Sync> Default for ConcurrentCache<T> {
    fn default() -> Self {
        Self::new()
    }
}
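
/// A thread-safe counter built on `AtomicUsize`; clones share the same
/// underlying count.
///
/// A minimal sketch (marked `ignore` only because the import path depends
/// on this crate's name):
///
/// ```ignore
/// let counter = ConcurrentCounter::new();
/// assert_eq!(counter.increment(), 1);
/// assert_eq!(counter.increment(), 2);
/// assert_eq!(counter.get(), 2);
/// counter.reset();
/// assert_eq!(counter.get(), 0);
/// ```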
#[derive(Clone)]
pub struct ConcurrentCounter {
    count: Arc<AtomicUsize>,
}

impl ConcurrentCounter {
    /// Creates a counter starting at zero.
    pub fn new() -> Self {
        Self { count: Arc::new(AtomicUsize::new(0)) }
    }

    /// Atomically increments the counter and returns the new value.
    pub fn increment(&self) -> usize {
        self.count.fetch_add(1, Ordering::SeqCst) + 1
    }

    /// Returns the current value.
    pub fn get(&self) -> usize {
        self.count.load(Ordering::SeqCst)
    }

    /// Resets the counter to zero.
    pub fn reset(&self) {
        self.count.store(0, Ordering::SeqCst);
    }
}

impl Default for ConcurrentCounter {
    fn default() -> Self {
        Self::new()
    }
}
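
/// Strategy for computing a node's depth (level) within the tree.
///
/// Implementors may compute the level however they like; a stub that
/// treats every node as a root is enough for tests. A minimal sketch
/// (`RootOnlyStrategy` is a hypothetical name, not part of this crate):
///
/// ```ignore
/// struct RootOnlyStrategy;
///
/// impl LevelStrategy for RootOnlyStrategy {
///     fn get_level(&self, _node_id: &NodeId, _state: &Arc<State>) -> usize {
///         0 // pretend everything is at the root level
///     }
/// }
/// ```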
pub trait LevelStrategy: Send + Sync {
    /// Returns the level of `node_id`, where a root node is level 0.
    fn get_level(
        &self,
        node_id: &NodeId,
        state: &Arc<State>,
    ) -> usize;
}
/// Computes a node's level by walking parent links up to the root on
/// every call: O(depth) per query, with no caching.
pub struct DefaultLevelStrategy;

impl LevelStrategy for DefaultLevelStrategy {
    fn get_level(
        &self,
        node_id: &NodeId,
        state: &Arc<State>,
    ) -> usize {
        let node_pool = &state.node_pool;
        let mut level = 0;
        let mut current = node_id.clone();

        // Count how many parent links separate this node from the root.
        while let Some(parent_id) = node_pool.parent_id(&current) {
            level += 1;
            current = parent_id.clone();
        }

        level
    }
}
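
/// A memoizing variant of [`DefaultLevelStrategy`]: levels are computed
/// on first request and cached, and the upward walk stops early as soon
/// as it reaches an ancestor whose level is already known.
///
/// A minimal sketch (marked `ignore` because it needs a real `State` and
/// `NodeId` from the surrounding application):
///
/// ```ignore
/// let strategy = CachedLevelStrategy::new();
/// let level = strategy.get_level(&node_id, &state); // computed and cached
/// let again = strategy.get_level(&node_id, &state); // served from cache
/// assert_eq!(level, again);
/// strategy.clear_cache(); // invalidate after the tree structure changes
/// ```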
pub struct CachedLevelStrategy {
    cache: Arc<DashMap<NodeId, usize>>,
}

impl CachedLevelStrategy {
    /// Creates a strategy with an empty level cache.
    pub fn new() -> Self {
        Self { cache: Arc::new(DashMap::new()) }
    }

    /// Drops all cached levels; call this when the tree changes shape.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for CachedLevelStrategy {
    fn default() -> Self {
        Self::new()
    }
}

impl LevelStrategy for CachedLevelStrategy {
    fn get_level(
        &self,
        node_id: &NodeId,
        state: &Arc<State>,
    ) -> usize {
        // Fast path: the level for this node is already known.
        if let Some(level) = self.cache.get(node_id) {
            return *level;
        }

        let node_pool = &state.node_pool;
        let mut level = 0;
        let mut current = node_id.clone();

        while let Some(parent_id) = node_pool.parent_id(&current) {
            level += 1;
            current = parent_id.clone();

            // Stop early if an ancestor's level is cached: this node's
            // level is the steps walked so far plus the ancestor's level.
            if let Some(parent_level) = self.cache.get(&current) {
                level += *parent_level;
                break;
            }
        }

        // Only the queried node's level is cached here; intermediate
        // ancestors are filled in as they are queried themselves.
        self.cache.insert(node_id.clone(), level);
        level
    }
}
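
/// The per-node computation run by [`NodeAggregator`]: given a node, the
/// current state, and the shared result cache, it produces that node's
/// aggregated value as a boxed future.
///
/// A minimal sketch of building one by hand (usually you pass a plain
/// closure to `NodeAggregator::new` instead; `1usize` stands in for a
/// real per-node metric):
///
/// ```ignore
/// let processor: NodeProcessor<usize> = Arc::new(|node_id, _state, cache| {
///     Box::pin(async move {
///         // A node processed in an earlier (deeper) layer may already
///         // have a cached value we can build on.
///         let _ = cache.get(&node_id);
///         Ok(1usize)
///     })
/// });
/// ```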
pub type NodeProcessor<T> = Arc<
    dyn Fn(
            NodeId,
            Arc<State>,
            Arc<ConcurrentCache<T>>,
        ) -> Pin<Box<dyn Future<Output = ForgeResult<T>> + Send>>
        + Send
        + Sync,
>;

/// Bottom-up aggregation along a node's ancestor chain.
pub trait NodeAggregatorTrait<T: Clone + Send + Sync>: Send + Sync {
    /// Aggregates from `start_node` up to the root, returning the
    /// computed value for every node along the way.
    fn aggregate_up(
        &self,
        start_node: &NodeId,
        state: Arc<State>,
    ) -> impl Future<Output = ForgeResult<HashMap<NodeId, T>>> + Send;
}
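
/// Runs a [`NodeProcessor`] over `start_node` and each of its ancestors,
/// deepest level first, so every processor can read the cached results of
/// the nodes already processed below it.
///
/// A minimal usage sketch (marked `ignore` because it needs a real
/// `State` and a `NodeId` from the surrounding application; the `usize`
/// metric is a stand-in):
///
/// ```ignore
/// let aggregator = NodeAggregator::with_default_strategy(
///     |_node_id, _state, _cache| async move {
///         // hypothetical per-node metric
///         Ok(1usize)
///     },
/// );
/// let results = aggregator.aggregate_up(&leaf_id, state.clone()).await?;
/// assert!(results.contains_key(&leaf_id));
/// ```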
pub struct NodeAggregator<T: Clone + Send + Sync + 'static> {
    /// Shared cache of per-node results, visible to the processor.
    cache: Arc<ConcurrentCache<T>>,

    /// The per-node computation.
    processor: NodeProcessor<T>,

    /// Decides which level (depth) each node belongs to.
    level_strategy: Arc<dyn LevelStrategy>,
}

impl<T: Clone + Send + Sync + 'static> NodeAggregator<T> {
    /// Creates an aggregator from a per-node processor and a level
    /// strategy, boxing the processor into a [`NodeProcessor`].
    pub fn new<F, Fut>(
        processor: F,
        level_strategy: impl LevelStrategy + 'static,
    ) -> Self
    where
        F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = ForgeResult<T>> + Send + 'static,
    {
        let cache = Arc::new(ConcurrentCache::new());

        let processor_arc: NodeProcessor<T> =
            Arc::new(move |id, state, cache| {
                Box::pin(processor(id, state, cache))
            });

        Self {
            cache,
            processor: processor_arc,
            level_strategy: Arc::new(level_strategy),
        }
    }

    /// Convenience constructor using [`DefaultLevelStrategy`].
    pub fn with_default_strategy<F, Fut>(processor: F) -> Self
    where
        F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = ForgeResult<T>> + Send + 'static,
    {
        Self::new(processor, DefaultLevelStrategy)
    }

    /// Convenience constructor using [`CachedLevelStrategy`].
    pub fn with_cached_strategy<F, Fut>(processor: F) -> Self
    where
        F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = ForgeResult<T>> + Send + 'static,
    {
        Self::new(processor, CachedLevelStrategy::new())
    }

    /// Collects `start_node` and every ancestor up to the root.
    fn collect_ancestors(
        &self,
        start_node: &NodeId,
        state: &Arc<State>,
    ) -> Vec<NodeId> {
        let node_pool = &state.node_pool;
        let mut ancestors = vec![start_node.clone()];
        let mut current = start_node.clone();

        while let Some(parent_id) = node_pool.parent_id(&current) {
            ancestors.push(parent_id.clone());
            current = parent_id.clone();
        }

        ancestors
    }

    /// Buckets nodes by the level reported by the level strategy.
    fn group_by_level(
        &self,
        nodes: &[NodeId],
        state: &Arc<State>,
    ) -> HashMap<usize, Vec<NodeId>> {
        let mut groups: HashMap<usize, Vec<NodeId>> = HashMap::new();

        for node_id in nodes {
            let level = self.level_strategy.get_level(node_id, state);
            groups.entry(level).or_default().push(node_id.clone());
        }

        groups
    }

    /// Processes every node in one level concurrently: each node gets its
    /// own tokio task, and the layer completes only when all tasks do.
    async fn process_layer(
        &self,
        layer_nodes: &[NodeId],
        state: Arc<State>,
    ) -> ForgeResult<()> {
        let handles: Vec<_> = layer_nodes
            .iter()
            .map(|node_id| {
                let state = state.clone();
                let cache = self.cache.clone();
                let node_id = node_id.clone();
                let processor = self.processor.clone();

                tokio::spawn(async move {
                    let result =
                        processor(node_id.clone(), state, cache.clone())
                            .await?;
                    cache.insert(node_id, result);
                    Ok::<_, crate::error::ForgeError>(())
                })
            })
            .collect();

        // Propagate both join errors (panics, cancellation) and processor
        // errors from each task.
        for handle in handles {
            handle.await.map_err(|e| {
                crate::error::error_utils::engine_error(format!(
                    "task execution failed: {}",
                    e
                ))
            })??;
        }

        Ok(())
    }
}

impl<T: Clone + Send + Sync + 'static> NodeAggregatorTrait<T>
    for NodeAggregator<T>
{
    async fn aggregate_up(
        &self,
        start_node: &NodeId,
        state: Arc<State>,
    ) -> ForgeResult<HashMap<NodeId, T>> {
        // Start from a clean slate so stale results never leak between runs.
        self.cache.clear();

        let all_nodes = self.collect_ancestors(start_node, &state);
        let level_groups = self.group_by_level(&all_nodes, &state);

        // Process the deepest level first so that by the time a node runs,
        // the results of the nodes below it are already in the cache.
        let mut levels: Vec<usize> = level_groups.keys().copied().collect();
        levels.sort_by(|a, b| b.cmp(a));

        for level in levels {
            if let Some(layer_nodes) = level_groups.get(&level) {
                self.process_layer(layer_nodes, state.clone()).await?;
            }
        }

        Ok(self.cache.get_all())
    }
}