use std::{marker::PhantomData, mem::take};
use dyn_clone::DynClone;
use crate::{
DBData, DBWeight,
algebra::Semigroup,
dynamic::{
DataTrait, DowncastTrait, DynOpt, DynUnit, Erase, Factory, WeightTrait, WithFactory,
},
trace::Cursor,
};
/// Object-safe, cloneable callable of shape `Fn(&mut A, &mut O)`.
///
/// Judging by the name, `A` is an accumulator and `O` is the slot an output
/// is written into; this file only fixes the call signature, so confirm the
/// exact contract against the call sites.
pub trait AggOutputFunc<A, O>: Fn(&mut A, &mut O) + DynClone
where
    A: ?Sized,
    O: ?Sized,
{
}

/// Every `Clone` closure or function with the matching signature is an
/// `AggOutputFunc`.
impl<A, O, F> AggOutputFunc<A, O> for F
where
    A: ?Sized,
    O: ?Sized,
    F: Fn(&mut A, &mut O) + Clone,
{
}

// Lets boxed `dyn AggOutputFunc<A, O>` trait objects be cloned.
dyn_clone::clone_trait_object! {<A: ?Sized, O: ?Sized> AggOutputFunc<A, O>}
pub trait AggCombineFunc<A: ?Sized>: Fn(&mut A, &A) + DynClone {}
impl<A: ?Sized, F> AggCombineFunc<A> for F where F: Fn(&mut A, &A) + Clone {}
/// Statically-typed aggregator: reduces the contents reachable through a
/// cursor to a single value.
///
/// Aggregation is split in two steps: [`aggregate`](Self::aggregate) produces
/// an intermediate [`Accumulator`](Self::Accumulator), and
/// [`finalize`](Self::finalize) turns that accumulator into the
/// [`Output`](Self::Output).
///
/// `K` and `R` are concrete types that erase to the cursor's key and weight
/// trait objects (see the `Erase` bounds on the methods); `T` is the cursor's
/// third type parameter (presumably the timestamp type — confirm against
/// `Cursor`).
pub trait Aggregator<K, T, R>: Clone + 'static {
    /// Intermediate value returned by [`aggregate`](Self::aggregate) and
    /// consumed by [`finalize`](Self::finalize).
    type Accumulator: DBData;

    /// Semigroup over [`Accumulator`](Self::Accumulator) values; the
    /// type-erased adapter in this file uses it to combine two accumulators.
    type Semigroup: Semigroup<Self::Accumulator>;

    /// Final aggregate type produced by [`finalize`](Self::finalize).
    type Output: DBData;

    /// Computes an accumulator from `cursor`.
    ///
    /// Returns `None` when there is no aggregate to report; the exact
    /// condition is up to the implementer.
    fn aggregate<KTrait, RTrait>(
        &self,
        cursor: &mut dyn Cursor<KTrait, DynUnit, T, RTrait>,
    ) -> Option<Self::Accumulator>
    where
        KTrait: DataTrait + ?Sized,
        RTrait: WeightTrait + ?Sized,
        K: Erase<KTrait>,
        R: Erase<RTrait>;

    /// Converts a finished accumulator into the final output value.
    fn finalize(&self, accumulator: Self::Accumulator) -> Self::Output;

    /// Runs [`aggregate`](Self::aggregate) on `cursor` and, if it yields an
    /// accumulator, passes it through [`finalize`](Self::finalize).
    ///
    /// Provided method; implementers may override it.
    fn aggregate_and_finalize<KTrait, RTrait>(
        &self,
        cursor: &mut dyn Cursor<KTrait, DynUnit, T, RTrait>,
    ) -> Option<Self::Output>
    where
        KTrait: DataTrait + ?Sized,
        RTrait: WeightTrait + ?Sized,
        K: Erase<KTrait>,
        R: Erase<RTrait>,
    {
        // `None` from `aggregate` propagates unchanged.
        let accumulator = self.aggregate(cursor)?;
        Some(self.finalize(accumulator))
    }
}
/// Aggregator adapter that applies a post-processing function to the output
/// of an inner aggregator.
///
/// Accumulation is delegated to the inner aggregator unchanged; only the
/// value returned by `finalize` is transformed.
#[derive(Clone)]
pub struct Postprocess<A, F> {
// Inner aggregator that performs the actual aggregation.
aggregator: A,
// Function applied (by reference) to the inner aggregator's output.
postprocess: F,
}
impl<A, F> Postprocess<A, F> {
pub fn new(aggregator: A, postprocess: F) -> Self {
Self {
aggregator,
postprocess,
}
}
}
/// `Postprocess` is itself an aggregator: it shares the inner aggregator's
/// accumulator and semigroup, and its output is whatever `postprocess`
/// returns.
impl<A, F, K, T, R, O> Aggregator<K, T, R> for Postprocess<A, F>
where
    A: Aggregator<K, T, R>,
    F: (Fn(&A::Output) -> O) + Clone + 'static,
    O: DBData,
{
    type Accumulator = A::Accumulator;
    type Semigroup = A::Semigroup;
    type Output = O;

    /// Forwards straight to the inner aggregator; post-processing happens
    /// only in `finalize`.
    fn aggregate<KTrait, RTrait>(
        &self,
        cursor: &mut dyn Cursor<KTrait, DynUnit, T, RTrait>,
    ) -> Option<Self::Accumulator>
    where
        KTrait: DataTrait + ?Sized,
        RTrait: WeightTrait + ?Sized,
        K: Erase<KTrait>,
        R: Erase<RTrait>,
    {
        let inner = &self.aggregator;
        inner.aggregate(cursor)
    }

    /// Finalizes with the inner aggregator, then maps the result through
    /// `postprocess`.
    fn finalize(&self, accumulator: Self::Accumulator) -> Self::Output {
        let inner_output = self.aggregator.finalize(accumulator);
        (self.postprocess)(&inner_output)
    }
}
/// Type-erased aggregator interface.
///
/// Unlike [`Aggregator`], every method works on trait-object types (`K`, `R`,
/// `Self::Accumulator`, `Self::Output`) and writes its result into a
/// caller-provided `&mut` slot instead of returning it by value.
pub trait DynAggregator<K: DataTrait + ?Sized, T, R: WeightTrait + ?Sized>:
    DynClone + 'static
{
    /// Type-erased accumulator.
    type Accumulator: DataTrait + ?Sized;

    /// Type-erased final output.
    type Output: DataTrait + ?Sized;

    /// Factory for optional accumulators, i.e. the slot type that
    /// [`aggregate`](Self::aggregate) writes into.
    fn opt_accumulator_factory(&self) -> &'static dyn Factory<DynOpt<Self::Accumulator>>;

    /// Factory for output values.
    fn output_factory(&self) -> &'static dyn Factory<Self::Output>;

    /// Returns a function taking `(&mut accumulator, &other)` that updates the
    /// first accumulator using the second.
    fn combine(&self) -> &dyn AggCombineFunc<Self::Accumulator>;

    /// Aggregates over `cursor` and stores the (optional) result in
    /// `accumulator`.
    fn aggregate(
        &self,
        cursor: &mut dyn Cursor<K, DynUnit, T, R>,
        accumulator: &mut DynOpt<Self::Accumulator>,
    );

    /// Converts `accumulator` into the final value and stores it in `output`.
    ///
    /// `accumulator` is taken by `&mut`, so implementations may move its
    /// contents out.
    fn finalize(&self, accumulator: &mut Self::Accumulator, output: &mut Self::Output);

    /// Aggregates over `cursor` and finalizes in one step, storing the
    /// (optional) final value in `output`.
    fn aggregate_and_finalize(
        &self,
        cursor: &mut dyn Cursor<K, DynUnit, T, R>,
        output: &mut DynOpt<Self::Output>,
    );
}
/// Adapter that exposes a statically-typed `Aggregator` (`A`) through the
/// type-erased `DynAggregator` interface.
///
/// Only the wrapped aggregator is stored; every other type parameter exists
/// to tie the concrete types to their type-erased counterparts.
pub struct DynAggregatorImpl<
// Type-erased key type seen by the cursor.
K: ?Sized,
// Concrete key type handed to the wrapped aggregator.
KType,
// Third cursor type parameter, passed through unchanged.
T: 'static,
// Type-erased weight type seen by the cursor.
R: ?Sized,
// Concrete weight type handed to the wrapped aggregator.
RType,
// The wrapped statically-typed aggregator.
A,
// Type-erased accumulator type.
Acc: ?Sized,
// Type-erased output type.
Out: ?Sized,
> {
// The aggregator all calls are forwarded to.
aggregator: A,
// Mentions the otherwise-unused type parameters; being a `fn(..)` pointer
// type, it stores no values of those types.
phantom: PhantomData<fn(&Acc, &Out, &K, &KType, &T, &R, &RType)>,
}
/// Hand-written `Clone`: only the wrapped aggregator has to be `Clone`; the
/// marker type parameters carry no such requirement.
impl<K, KType, T, R, RType, A, Acc, Out> Clone
    for DynAggregatorImpl<K, KType, T, R, RType, A, Acc, Out>
where
    K: ?Sized,
    T: 'static,
    R: ?Sized,
    A: Clone,
    Acc: ?Sized,
    Out: ?Sized,
{
    fn clone(&self) -> Self {
        let aggregator = self.aggregator.clone();
        Self {
            aggregator,
            phantom: PhantomData,
        }
    }
}
impl<K, KType, T, R, RType, A, Acc, Out> DynAggregatorImpl<K, KType, T, R, RType, A, Acc, Out>
where
    K: ?Sized,
    T: 'static,
    R: ?Sized,
    Acc: ?Sized,
    Out: ?Sized,
{
    /// Wraps `aggregator`; the remaining type parameters are picked by the
    /// caller (or inferred) and are not backed by any stored data.
    pub fn new(aggregator: A) -> Self {
        let phantom = PhantomData;
        Self {
            aggregator,
            phantom,
        }
    }
}
/// Implements the type-erased `DynAggregator` interface by downcasting the
/// erased arguments to the wrapped aggregator's concrete types and forwarding
/// to it.
///
/// All `unsafe` downcasts below rely on callers passing values whose concrete
/// types are `A::Accumulator` / `A::Output` (or `Option` of them), e.g. values
/// created by the factories this impl returns. NOTE(review): that caller
/// contract is not enforced in this file — confirm against `DowncastTrait`.
impl<K, KType, T: 'static, R, RType, A, Acc, Out> DynAggregator<K, T, R>
    for DynAggregatorImpl<K, KType, T, R, RType, A, Acc, Out>
where
    A: Aggregator<KType, T, RType>,
    A::Accumulator: Erase<Acc>,
    A::Output: Erase<Out>,
    K: DataTrait + ?Sized,
    KType: DBData + Erase<K>,
    R: WeightTrait + ?Sized,
    RType: DBWeight + Erase<R>,
    Acc: DataTrait + ?Sized,
    Out: DataTrait + ?Sized,
{
    type Accumulator = Acc;
    type Output = Out;

    /// Factory for `Option<A::Accumulator>`, viewed as `DynOpt<Acc>`.
    fn opt_accumulator_factory(&self) -> &'static dyn Factory<DynOpt<Self::Accumulator>> {
        WithFactory::<Option<A::Accumulator>>::FACTORY
    }

    /// Factory for `A::Output`, viewed as `Out`.
    fn output_factory(&self) -> &'static dyn Factory<Self::Output> {
        WithFactory::<A::Output>::FACTORY
    }

    /// Returns a function that replaces `acc` with
    /// `A::Semigroup::combine(acc, val)`.
    fn combine(&self) -> &dyn AggCombineFunc<Self::Accumulator> {
        // The closure captures nothing, so a reference to it can be returned
        // here without borrowing from a local.
        &|acc, val| {
            // SAFETY: both arguments are expected to be `A::Accumulator`
            // values behind the erased `Acc` type (see the impl-level note).
            let acc: &mut A::Accumulator = unsafe { acc.downcast_mut::<A::Accumulator>() };
            let val: &A::Accumulator = unsafe { val.downcast::<A::Accumulator>() };
            *acc = A::Semigroup::combine(acc, val);
        }
    }

    /// Runs the wrapped aggregator over `cursor` and overwrites `acc` with
    /// the result (`None` included).
    fn aggregate(
        &self,
        cursor: &mut dyn Cursor<K, DynUnit, T, R>,
        acc: &mut DynOpt<Self::Accumulator>,
    ) {
        // SAFETY: `acc` is expected to be an `Option<A::Accumulator>`, the
        // type produced by `opt_accumulator_factory`.
        let acc = unsafe { acc.downcast_mut::<Option<A::Accumulator>>() };
        *acc = self.aggregator.aggregate(cursor);
    }

    /// Finalizes `acc` into `output`.
    ///
    /// The accumulator is moved out with `mem::take`, so `acc` is left
    /// holding its `Default` value afterwards.
    fn finalize(&self, acc: &mut Self::Accumulator, output: &mut Self::Output) {
        // SAFETY: `acc` / `output` are expected to be `A::Accumulator` /
        // `A::Output` values behind their erased types.
        let acc: &mut A::Accumulator = unsafe { acc.downcast_mut::<A::Accumulator>() };
        let output: &mut A::Output = unsafe { output.downcast_mut::<A::Output>() };
        *output = self.aggregator.finalize(take(acc));
    }

    /// Aggregates over `cursor` and finalizes in one step, overwriting
    /// `output` with the result (`None` included).
    fn aggregate_and_finalize(
        &self,
        cursor: &mut dyn Cursor<K, DynUnit, T, R>,
        output: &mut DynOpt<Self::Output>,
    ) {
        // SAFETY: `output` is expected to be an `Option<A::Output>` behind
        // the erased `DynOpt<Out>` type.
        let output = unsafe { output.downcast_mut::<Option<A::Output>>() };
        // Delegate to the aggregator's own `aggregate_and_finalize` rather
        // than re-implementing it as `aggregate` + `finalize`, so that an
        // aggregator overriding the provided method is honored.
        *output = self.aggregator.aggregate_and_finalize(cursor);
    }
}