use dashmap::DashMap;
use mf_model::NodeId;
use mf_state::state::State;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::error::ForgeResult;
/// Thread-safe map from [`NodeId`] to `T` that is cheap to clone.
///
/// The storage sits behind an `Arc`, so every clone of a `ConcurrentCache`
/// refers to the same underlying [`DashMap`]: a write through one handle is
/// visible through all of them.
#[derive(Clone)]
pub struct ConcurrentCache<T: Clone + Send + Sync> {
    /// Shared storage; all clones of the cache point at this same map.
    inner: Arc<DashMap<NodeId, T>>,
}

impl<T: Clone + Send + Sync> ConcurrentCache<T> {
    /// Creates an empty cache.
    pub fn new() -> Self {
        let storage = DashMap::new();
        Self { inner: Arc::new(storage) }
    }

    /// Stores `value` under `key`, replacing any value already there.
    pub fn insert(&self, key: NodeId, value: T) {
        self.inner.insert(key, value);
    }

    /// Returns a clone of the value stored under `key`, or `None` if absent.
    pub fn get(&self, key: &NodeId) -> Option<T> {
        let entry = self.inner.get(key)?;
        Some(entry.value().clone())
    }

    /// Copies every entry into an owned `HashMap`.
    ///
    /// NOTE(review): DashMap iterates shard by shard, so this is presumably
    /// not an atomic snapshot if other threads write concurrently.
    pub fn get_all(&self) -> HashMap<NodeId, T> {
        let mut copied = HashMap::new();
        for entry in self.inner.iter() {
            copied.insert(entry.key().clone(), entry.value().clone());
        }
        copied
    }

    /// Removes every entry.
    pub fn clear(&self) {
        self.inner.clear();
    }

    /// Returns `true` if a value is stored under `key`.
    pub fn contains(&self, key: &NodeId) -> bool {
        self.inner.contains_key(key)
    }
}

impl<T: Clone + Send + Sync> Default for ConcurrentCache<T> {
    /// Equivalent to [`ConcurrentCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Thread-safe counter whose clones all observe and modify the same value.
#[derive(Clone)]
pub struct ConcurrentCounter {
    /// Shared atomic cell; every clone holds the same `Arc`.
    count: Arc<AtomicUsize>,
}

impl ConcurrentCounter {
    /// Creates a counter starting at zero.
    pub fn new() -> Self {
        // `AtomicUsize::default()` is 0, so this matches `AtomicUsize::new(0)`.
        Self {
            count: Arc::default(),
        }
    }

    /// Adds one and returns the value *after* the increment.
    pub fn increment(&self) -> usize {
        let before = self.count.fetch_add(1, Ordering::SeqCst);
        before + 1
    }

    /// Returns the current value.
    pub fn get(&self) -> usize {
        self.count.load(Ordering::SeqCst)
    }

    /// Sets the counter back to zero.
    pub fn reset(&self) {
        self.count.store(0, Ordering::SeqCst);
    }
}

impl Default for ConcurrentCounter {
    /// Equivalent to [`ConcurrentCounter::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Assigns a numeric level to a node of a [`State`].
///
/// `NodeAggregator` groups nodes by this level and processes the groups from
/// the highest level down to the lowest.
pub trait LevelStrategy: Send + Sync {
    /// Returns the level of `node_id` inside `state`.
    fn get_level(&self, node_id: &NodeId, state: &Arc<State>) -> usize;
}
/// Level strategy that recomputes a node's depth on every call.
///
/// The level is the number of parent hops from the node up to the first node
/// that has no parent, so a parentless node is at level 0. Every call walks
/// the whole parent chain; `CachedLevelStrategy` memoizes the same numbers.
pub struct DefaultLevelStrategy;

impl LevelStrategy for DefaultLevelStrategy {
    fn get_level(
        &self,
        node_id: &NodeId,
        state: &Arc<State>,
    ) -> usize {
        let node_pool = &state.node_pool;
        let mut level = 0;
        let mut current = node_id.clone();
        // Climb one parent at a time until a node without a parent is reached.
        while let Some(parent_id) = node_pool.parent_id(&current) {
            current = parent_id.clone();
            level += 1;
        }
        level
    }
}
/// Level strategy that memoizes levels per [`NodeId`].
///
/// It returns the same numbers as `DefaultLevelStrategy` (parent hops up to the
/// first parentless node). It also records the level of the queried node and
/// of every ancestor it walks through. Later queries for those nodes are O(1),
/// and queries for nodes below them stop at the first cached ancestor.
///
/// NOTE: entries are keyed by node id only; the `state` argument is not part of
/// the key and nothing here invalidates entries except
/// [`CachedLevelStrategy::clear_cache`]. If the tree structure can change
/// between calls, clear the cache first or levels may be stale.
pub struct CachedLevelStrategy {
    /// Memoized node id -> level.
    cache: Arc<DashMap<NodeId, usize>>,
}

impl CachedLevelStrategy {
    /// Creates a strategy with an empty cache.
    pub fn new() -> Self {
        Self { cache: Arc::new(DashMap::new()) }
    }

    /// Drops every memoized level.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for CachedLevelStrategy {
    /// Equivalent to [`CachedLevelStrategy::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl LevelStrategy for CachedLevelStrategy {
    fn get_level(
        &self,
        node_id: &NodeId,
        state: &Arc<State>,
    ) -> usize {
        // Copy values out of the map immediately so no DashMap read guard is
        // still alive when we insert below (holding one while writing can deadlock).
        if let Some(level) = self.cache.get(node_id).map(|l| *l) {
            return level;
        }

        let node_pool = &state.node_pool;
        // Uncached nodes walked so far: `node_id` first, then its ancestors.
        let mut path = vec![node_id.clone()];
        let mut current = node_id.clone();
        // Level of the first cached ancestor above `path`, if one was found;
        // `None` means the walk ended at a node with no parent.
        let mut anchor: Option<usize> = None;
        while let Some(parent_id) = node_pool.parent_id(&current) {
            let parent_id = parent_id.clone();
            if let Some(parent_level) = self.cache.get(&parent_id).map(|l| *l) {
                anchor = Some(parent_level);
                break;
            }
            path.push(parent_id.clone());
            current = parent_id;
        }

        // The topmost node of `path` is either parentless (level 0) or sits one
        // below the cached anchor; every step down `path` adds one level.
        let top_index = path.len() - 1;
        let top_level = anchor.map_or(0, |l| l + 1);
        for (index, id) in path.into_iter().enumerate() {
            self.cache.insert(id, top_level + (top_index - index));
        }
        top_level + top_index
    }
}
/// Type-erased, shareable async function that computes one node's value.
///
/// It is called with the node id, the state, and the shared result cache, and
/// returns a boxed `Send` future yielding the node's result.
pub type NodeProcessor<T> = Arc<
    dyn Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>)
            -> Pin<Box<dyn Future<Output = ForgeResult<T>> + Send>>
        + Send
        + Sync,
>;
/// An aggregation that starts at a node and produces a value per node id.
///
/// `NodeAggregator` in this module implements it by processing `start_node`
/// together with its ancestors.
pub trait NodeAggregatorTrait<T: Clone + Send + Sync>: Send + Sync {
    /// Runs the aggregation from `start_node` against `state` and returns the
    /// computed values keyed by node id.
    fn aggregate_up(
        &self,
        start_node: &NodeId,
        state: Arc<State>,
    ) -> impl Future<Output = ForgeResult<HashMap<NodeId, T>>> + Send;
}
/// Computes per-node values for a node and its ancestors, one level at a time.
///
/// Nodes are grouped by the configured [`LevelStrategy`]. Groups run from the
/// highest level down, and the nodes of one group run concurrently as
/// separate Tokio tasks.
pub struct NodeAggregator<T: Clone + Send + Sync + 'static> {
    /// Results of the current run; also handed to the processor.
    cache: Arc<ConcurrentCache<T>>,
    /// Type-erased async function computing one node's value.
    processor: NodeProcessor<T>,
    /// Decides which level (group) each node belongs to.
    level_strategy: Arc<dyn LevelStrategy>,
}
impl<T: Clone + Send + Sync + 'static> NodeAggregator<T> {
pub fn new<F, Fut>(
processor: F,
level_strategy: impl LevelStrategy + 'static,
) -> Self
where
F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
+ Send
+ Sync
+ 'static,
Fut: Future<Output = ForgeResult<T>> + Send + 'static,
{
let cache = Arc::new(ConcurrentCache::new());
let processor_arc: NodeProcessor<T> =
Arc::new(move |id, state, cache| {
Box::pin(processor(id, state, cache))
});
Self {
cache,
processor: processor_arc,
level_strategy: Arc::new(level_strategy),
}
}
pub fn with_default_strategy<F, Fut>(processor: F) -> Self
where
F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
+ Send
+ Sync
+ 'static,
Fut: Future<Output = ForgeResult<T>> + Send + 'static,
{
Self::new(processor, DefaultLevelStrategy)
}
pub fn with_cached_strategy<F, Fut>(processor: F) -> Self
where
F: Fn(NodeId, Arc<State>, Arc<ConcurrentCache<T>>) -> Fut
+ Send
+ Sync
+ 'static,
Fut: Future<Output = ForgeResult<T>> + Send + 'static,
{
Self::new(processor, CachedLevelStrategy::new())
}
fn collect_ancestors(
&self,
start_node: &NodeId,
state: &Arc<State>,
) -> Vec<NodeId> {
let node_pool = &state.node_pool;
let mut ancestors = vec![start_node.clone()];
let mut current = start_node.clone();
while let Some(parent_id) = node_pool.parent_id(¤t) {
ancestors.push(parent_id.clone());
current = parent_id.clone();
}
ancestors
}
fn group_by_level(
&self,
nodes: &[NodeId],
state: &Arc<State>,
) -> HashMap<usize, Vec<NodeId>> {
let mut groups: HashMap<usize, Vec<NodeId>> = HashMap::new();
for node_id in nodes {
let level = self.level_strategy.get_level(node_id, state);
groups.entry(level).or_default().push(node_id.clone());
}
groups
}
async fn process_layer(
&self,
layer_nodes: &[NodeId],
state: Arc<State>,
) -> ForgeResult<()> {
let handles: Vec<_> = layer_nodes
.iter()
.map(|node_id| {
let state = state.clone();
let cache = self.cache.clone();
let node_id = node_id.clone();
let processor = self.processor.clone();
tokio::spawn(async move {
let result =
processor(node_id.clone(), state, cache.clone())
.await?;
cache.insert(node_id.clone(), result);
Ok::<_, crate::error::ForgeError>(())
})
})
.collect();
for handle in handles {
handle.await.map_err(|e| {
crate::error::error_utils::engine_error(format!(
"任务执行失败: {}",
e
))
})??;
}
Ok(())
}
}
impl<T: Clone + Send + Sync + 'static> NodeAggregatorTrait<T>
    for NodeAggregator<T>
{
    /// Computes values for `start_node` and all of its ancestors.
    ///
    /// The ancestor chain is grouped by level and the groups are processed
    /// from the highest level down. When a node runs, the results of the
    /// chain nodes in higher-level groups are already in the shared cache.
    /// Returns every cached result at the end.
    ///
    /// NOTE: this starts by clearing the aggregator's shared cache, so two
    /// concurrent `aggregate_up` calls on the same aggregator interfere with
    /// each other.
    async fn aggregate_up(
        &self,
        start_node: &NodeId,
        state: Arc<State>,
    ) -> ForgeResult<HashMap<NodeId, T>> {
        // Results from a previous run must not leak into this one.
        self.cache.clear();

        let chain = self.collect_ancestors(start_node, &state);
        let mut layers: Vec<(usize, Vec<NodeId>)> =
            self.group_by_level(&chain, &state).into_iter().collect();
        // Highest level first; levels are unique keys, so an unstable sort is fine.
        layers.sort_unstable_by(|(a, _), (b, _)| b.cmp(a));

        for (_level, nodes) in &layers {
            self.process_layer(nodes, Arc::clone(&state)).await?;
        }
        Ok(self.cache.get_all())
    }
}