use super::task_state::{Global, Shard};
use crate::{
core::{
entities::VID,
state::{
accumulator_id::AccId, agg::Accumulator, compute_state::ComputeState,
shuffle_state::ShuffleComputeState, StateType,
},
},
db::{api::view::StaticGraphViewOps, graph::node::NodeView},
prelude::GraphViewOps,
};
use std::{fmt::Debug, sync::Arc};
/// Type-erased merge step used by [`Context::run_merge`].
///
/// Called as `merge(target, other, ss)`: it folds the accumulator state held in `other` into
/// `target` for superstep `ss`. The closure is `Send + Sync`, so the shared `Arc` can be used
/// from several threads.
type MergeFn<CS> = Arc<
    dyn Fn(&mut ShuffleComputeState<CS>, &ShuffleComputeState<CS>, usize) + Send + Sync + 'static,
>;
/// Configuration and superstep bookkeeping for a task-based computation over a graph `G`.
///
/// Holds the registered merge steps used to combine shuffle states. It also holds the ids of
/// accumulators registered through the `*_reset` methods.
pub struct Context<G: StaticGraphViewOps, CS: ComputeState> {
    /// Current superstep counter.
    pub ss: usize,
    /// Graph the computation runs over.
    g: G,
    /// Merge steps, applied in registration order when two shuffle states are combined.
    merge_fns: Vec<MergeFn<CS>>,
    /// Ids of accumulators registered via `agg_reset`, `local_reset` or `global_agg_reset`.
    resetable_states: Vec<u32>,
}
impl<G, CS> Context<G, CS>
where
    G: StaticGraphViewOps,
    CS: ComputeState,
{
    /// Builds a vector with one entry for every node slot of the unfiltered graph.
    ///
    /// Entry `i` is `init_f(node)` when `self.g.node(VID(i))` yields a node. Otherwise it is
    /// `O::default()`, presumably for nodes filtered out of the view. Because of this, the
    /// result can be indexed directly by `VID`.
    pub fn new_local_state<O: Debug + Default, F: Fn(NodeView<'static, G>) -> O>(
        &self,
        init_f: F,
    ) -> Vec<O> {
        (0..self.g.unfiltered_num_nodes())
            .map(|i| self.g.node(VID(i)).map_or_else(O::default, &init_f))
            .collect()
    }

    /// Returns the current superstep.
    pub fn ss(&self) -> usize {
        self.ss
    }

    /// Returns a clone of the graph this context runs over.
    pub fn graph(&self) -> G {
        self.g.clone()
    }

    /// Advances the superstep counter by one.
    pub fn increment_ss(&mut self) {
        self.ss += 1;
    }

    /// Resets the superstep counter to 0.
    pub fn reset_ss(&mut self) {
        self.ss = 0;
    }

    /// Returns the ids of accumulators registered through `agg_reset`, `local_reset` or
    /// `global_agg_reset`, in registration order.
    pub fn resetable_states(&self) -> &[u32] {
        &self.resetable_states
    }

    /// Combines two shard states using every registered merge step (see [`Context::run_merge`]).
    pub(crate) fn run_merge_shard(&self, a: Shard<CS>, b: Shard<CS>) -> Shard<CS> {
        self.run_merge(a.unwrap(), b.unwrap()).into()
    }

    /// Combines two global states using every registered merge step (see [`Context::run_merge`]).
    pub(crate) fn run_merge_global(&self, a: Global<CS>, b: Global<CS>) -> Global<CS> {
        self.run_merge(a.unwrap(), b.unwrap()).into()
    }

    /// Combines two shuffle states into one by applying every registered merge step.
    ///
    /// The merge happens in place on whichever side is uniquely owned, so no copy is needed.
    /// If both `Arc`s are shared, `a` is cloned (copy-on-write via `Arc::make_mut`) and `b`
    /// is merged into the clone. This guarantees that `b`'s contributions are never dropped.
    pub(crate) fn run_merge(
        &self,
        mut a: Arc<ShuffleComputeState<CS>>,
        mut b: Arc<ShuffleComputeState<CS>>,
    ) -> Arc<ShuffleComputeState<CS>> {
        if let Some(left) = Arc::get_mut(&mut a) {
            self.merge_into(left, &b);
            a
        } else if let Some(right) = Arc::get_mut(&mut b) {
            self.merge_into(right, &a);
            b
        } else {
            // Both sides are shared: returning `a` untouched would lose everything in `b`.
            self.merge_into(Arc::make_mut(&mut a), &b);
            a
        }
    }

    /// Applies every registered merge step, in registration order, to fold `other` into
    /// `target` at the current superstep.
    fn merge_into(&self, target: &mut ShuffleComputeState<CS>, other: &ShuffleComputeState<CS>) {
        for merge_fn in &self.merge_fns {
            merge_fn(target, other, self.ss);
        }
    }

    /// Registers `id` so that [`Context::run_merge`] combines it via
    /// `ShuffleComputeState::merge_mut`.
    pub fn agg<A: StateType, IN: 'static, OUT: 'static, ACC: Accumulator<A, IN, OUT>>(
        &mut self,
        id: AccId<A, IN, OUT, ACC>,
    ) {
        let fn_merge: MergeFn<CS> = Arc::new(move |a, b, ss| a.merge_mut(b, id, ss));
        self.merge_fns.push(fn_merge);
    }

    /// Same as [`Context::agg`], and additionally records `id` as a resettable state.
    pub fn agg_reset<A: StateType, IN: 'static, OUT: 'static, ACC: Accumulator<A, IN, OUT>>(
        &mut self,
        id: AccId<A, IN, OUT, ACC>,
    ) {
        // `AccId` is `Copy`, so it is still usable after being passed to `agg`.
        self.agg(id);
        self.resetable_states.push(id.id());
    }

    /// Records `id` as a resettable state without registering any merge step for it.
    pub fn local_reset<A: StateType, IN: 'static, OUT: 'static, ACC: Accumulator<A, IN, OUT>>(
        &mut self,
        id: AccId<A, IN, OUT, ACC>,
    ) {
        self.resetable_states.push(id.id());
    }

    /// Registers `id` so that [`Context::run_merge`] combines it via
    /// `ShuffleComputeState::merge_mut_global`.
    pub fn global_agg<A: StateType, IN: 'static, OUT: 'static, ACC: Accumulator<A, IN, OUT>>(
        &mut self,
        id: AccId<A, IN, OUT, ACC>,
    ) {
        let fn_merge: MergeFn<CS> = Arc::new(move |a, b, ss| a.merge_mut_global(b, id, ss));
        self.merge_fns.push(fn_merge);
    }

    /// Same as [`Context::global_agg`], and additionally records `id` as a resettable state.
    pub fn global_agg_reset<
        A: StateType,
        IN: 'static,
        OUT: 'static,
        ACC: Accumulator<A, IN, OUT>,
    >(
        &mut self,
        id: AccId<A, IN, OUT, ACC>,
    ) {
        // `AccId` is `Copy`, so it is still usable after being passed to `global_agg`.
        self.global_agg(id);
        self.resetable_states.push(id.id());
    }
}
impl<G: StaticGraphViewOps, CS: ComputeState> From<&G> for Context<G, CS> {
fn from(g: &G) -> Self {
Self {
ss: 0,
g: g.clone(),
merge_fns: vec![],
resetable_states: vec![],
}
}
}
/// Read-only view over global accumulator state, pinned to a superstep.
pub struct GlobalState<CS>
where
    CS: ComputeState,
{
    /// Shared global compute state being read.
    state: Global<CS>,
    /// Superstep used by `read`; `read_prev` and `finalize` read at `ss + 1`.
    ss: usize,
}
impl<CS: ComputeState> GlobalState<CS> {
    /// Returns the value of the global accumulator `agg_def` read at superstep `self.ss + 1`.
    /// This is the same slot that [`GlobalState::read_prev`] uses. If no value is recorded,
    /// it returns `OUT::default()`.
    ///
    /// NOTE(review): the `+ 1` presumably compensates for the superstep counter having been
    /// advanced once more before finalization is called — confirm against the task runner.
    pub fn finalize<A, IN, OUT, ACC>(&self, agg_def: &AccId<A, IN, OUT, ACC>) -> OUT
    where
        OUT: StateType + Default,
        A: StateType + 'static,
        ACC: Accumulator<A, IN, OUT>,
    {
        self.state
            .inner()
            .read_global(self.ss + 1, agg_def)
            .unwrap_or_default()
    }

    /// Wraps `state` as a view pinned to superstep `ss`.
    pub(crate) fn new(state: Global<CS>, ss: usize) -> Self {
        Self { state, ss }
    }

    /// Returns the value of the global accumulator `acc_id` at the current superstep.
    /// If nothing is recorded, it returns `ACC::finish(&ACC::zero())`.
    pub fn read<A, IN, OUT, ACC: Accumulator<A, IN, OUT>>(
        &self,
        acc_id: &AccId<A, IN, OUT, ACC>,
    ) -> OUT
    where
        A: StateType,
        OUT: Debug,
    {
        self.read_at(self.ss, acc_id)
    }

    /// Returns the value of the global accumulator `acc_id` at superstep `self.ss + 1`.
    /// If nothing is recorded, it returns `ACC::finish(&ACC::zero())`.
    ///
    /// NOTE(review): the name suggests this slot holds the previous superstep's value
    /// (presumably the state alternates between two slots) — confirm.
    pub fn read_prev<A, IN, OUT, ACC: Accumulator<A, IN, OUT>>(
        &self,
        acc_id: &AccId<A, IN, OUT, ACC>,
    ) -> OUT
    where
        A: StateType,
        OUT: Debug,
    {
        self.read_at(self.ss + 1, acc_id)
    }

    /// Shared body of `read` and `read_prev`: reads `acc_id` at superstep `ss`.
    ///
    /// The fallback `ACC::finish(&ACC::zero())` is computed lazily, only when no value is
    /// recorded.
    fn read_at<A, IN, OUT, ACC: Accumulator<A, IN, OUT>>(
        &self,
        ss: usize,
        acc_id: &AccId<A, IN, OUT, ACC>,
    ) -> OUT
    where
        A: StateType,
        OUT: Debug,
    {
        self.state
            .inner()
            .read_global(ss, acc_id)
            .unwrap_or_else(|| ACC::finish(&ACC::zero()))
    }
}