use super::{AggOutputFunc, IncAggregateLinearFactories};
use crate::{
Circuit, DBData, DBWeight, DynZWeight, Stream, Timestamp, ZWeight,
algebra::{
AddAssignByRef, AddByRef, HasZero, IndexedZSet, IndexedZSetReader, MulByRef, NegByRef,
OrdIndexedZSet,
},
declare_trait_object,
dynamic::{ClonableTrait, DataTrait, Erase, Factory, Weight, WeightTrait, WithFactory},
trace::Deserializable,
};
use feldera_macros::IsNone;
use rkyv::{Archive, Deserialize, Serialize};
use size_of::SizeOf;
use std::{
fmt::Debug,
hash::Hash,
mem::take,
ops::{Div, Neg},
};
/// Factories required by [`Stream::dyn_average`].
///
/// Bundles the factories for the underlying incremental linear aggregation
/// (whose per-key accumulator is a [`DynAverage`] and whose output is an
/// [`OrdIndexedZSet`] keyed like the input `Z`) together with a factory for
/// the type `W` in which the average is computed.
pub struct AvgFactories<Z, A, W, T>
where
    Z: IndexedZSetReader,
    A: DataTrait + ?Sized,
    W: DataTrait + ?Sized,
    T: Timestamp,
{
    /// Factories for the incremental linear aggregate that accumulates
    /// `DynAverage<W, DynZWeight>` values per key and produces values of type `A`.
    aggregate_factories:
        IncAggregateLinearFactories<Z, DynAverage<W, DynZWeight>, OrdIndexedZSet<Z::Key, A>, T>,
    /// Factory for `W`. `dyn_average` uses it (via `Factory::with`) to obtain a
    /// temporary `W` that receives the computed average before it is handed to
    /// the output function.
    weight_factory: &'static dyn Factory<W>,
}
impl<Z, A, W, T> AvgFactories<Z, A, W, T>
where
    Z: IndexedZSet,
    A: DataTrait + ?Sized,
    W: DataTrait + ?Sized,
    W: WeightTrait,
    DynAverage<W, DynZWeight>: WeightTrait,
    T: Timestamp,
{
    /// Builds the factories from the concrete types that the erased types stand for.
    ///
    /// * `KType`: concrete key type, erased to `Z::Key`.
    /// * `AType`: concrete output (aggregate) type, erased to `A`.
    /// * `WType`: concrete type the average is computed in, erased to `W`. It must
    ///   be constructible from a `ZWeight` count and divisible by itself.
    ///
    /// The accumulator is `Avg<WType, ZWeight>`: a `WType` sum paired with a
    /// `ZWeight` count.
    pub fn new<KType, AType, WType>() -> Self
    where
        KType: DBData + Erase<Z::Key>,
        <KType as Deserializable>::ArchivedDeser: Ord,
        WType: DBWeight + Erase<W> + From<ZWeight> + Div<Output = WType>,
        AType: DBWeight + Erase<A>,
    {
        Self {
            weight_factory: WithFactory::<WType>::FACTORY,
            aggregate_factories: IncAggregateLinearFactories::new::<
                KType,
                Avg<WType, ZWeight>,
                AType,
            >(),
        }
    }
}
/// Accumulator used to compute averages: a running `sum` paired with a running
/// `count`.
///
/// All arithmetic on `Avg` (`AddByRef`, `AddAssignByRef`, `Neg`, `NegByRef`,
/// `MulByRef`) is applied component-wise. This lets accumulators for
/// individual records be combined by addition.
///
/// Field order is significant. The derived `Ord`/`PartialOrd` compare `sum`
/// first and then `count`, and the rkyv archived representation mirrors the
/// fields. Reordering them would change both.
#[derive(
    Debug,
    Default,
    Clone,
    Eq,
    Hash,
    PartialEq,
    Ord,
    PartialOrd,
    SizeOf,
    Archive,
    Serialize,
    Deserialize,
    IsNone,
)]
#[archive_attr(derive(Ord, Eq, PartialEq, PartialOrd))]
#[archive(bound(archive = "<T as Archive>::Archived: Ord, <R as Archive>::Archived: Ord"))]
#[archive(compare(PartialEq, PartialOrd))]
pub struct Avg<T, R> {
    /// Accumulated sum of the aggregated values.
    sum: T,
    /// Accumulated count. In `dyn_average`, each record contributes its Z-set weight.
    count: R,
}
impl<T, R> Avg<T, R> {
pub const fn new(sum: T, count: R) -> Self {
Self { sum, count }
}
pub fn sum(&self) -> T
where
T: Clone,
{
self.sum.clone()
}
pub fn count(&self) -> R
where
R: Clone,
{
self.count.clone()
}
pub fn compute_avg(&self) -> Option<T>
where
R: Clone + HasZero,
T: From<R> + Div<Output = T> + Clone,
{
if self.count.is_zero() {
None
} else {
Some(self.sum.clone() / T::from(self.count.clone()))
}
}
}
impl<T, R> HasZero for Avg<T, R>
where
    T: HasZero,
    R: HasZero,
{
    /// An accumulator is zero only when both its sum and its count are zero.
    fn is_zero(&self) -> bool {
        let Self { sum, count } = self;
        sum.is_zero() && count.is_zero()
    }

    /// The accumulator with a zero sum and a zero count.
    fn zero() -> Self {
        Self {
            sum: T::zero(),
            count: R::zero(),
        }
    }
}
impl<T, R> AddByRef for Avg<T, R>
where
    T: AddByRef,
    R: AddByRef,
{
    /// Component-wise addition: sums are added to sums and counts to counts.
    fn add_by_ref(&self, other: &Self) -> Self {
        let sum = self.sum.add_by_ref(&other.sum);
        let count = self.count.add_by_ref(&other.count);
        Self { sum, count }
    }
}
impl<T, R> AddAssignByRef for Avg<T, R>
where
    T: AddAssignByRef,
    R: AddAssignByRef,
{
    /// In-place component-wise addition of `rhs` into `self`.
    fn add_assign_by_ref(&mut self, rhs: &Self) {
        let Self { sum, count } = self;
        sum.add_assign_by_ref(&rhs.sum);
        count.add_assign_by_ref(&rhs.count);
    }
}
impl<T, R> Neg for Avg<T, R>
where
    T: Neg<Output = T>,
    R: Neg<Output = R>,
{
    type Output = Self;

    /// Negates both the sum and the count, consuming `self`.
    fn neg(self) -> Self {
        let Self { sum, count } = self;
        Self {
            sum: -sum,
            count: -count,
        }
    }
}
impl<T, R> NegByRef for Avg<T, R>
where
    T: NegByRef,
    R: NegByRef,
{
    /// Returns a new accumulator whose sum and count are both negated.
    fn neg_by_ref(&self) -> Self {
        Self {
            sum: self.sum.neg_by_ref(),
            count: self.count.neg_by_ref(),
        }
    }
}
impl<T, R> MulByRef<R> for Avg<T, R>
where
T: MulByRef<Output = T>,
T: From<R>,
R: MulByRef<Output = R> + Clone,
{
type Output = Avg<T, R>;
fn mul_by_ref(&self, rhs: &R) -> Avg<T, R> {
Self::new(
self.sum.mul_by_ref(&T::from(rhs.clone())),
self.count.mul_by_ref(rhs),
)
}
}
/// Type-erased interface to an average accumulator that holds a sum of type `T`
/// and a count of type `R`.
///
/// [`DynAverage`] is the trait-object form of this trait. It lets the average
/// aggregate manipulate accumulators without knowing their concrete types.
pub trait Average<T: DataTrait + ?Sized, R: WeightTrait + ?Sized>: Weight {
    /// Returns a reference to the accumulated sum.
    fn sum(&self) -> &T;
    /// Returns a reference to the accumulated count.
    fn count(&self) -> &R;
    /// Returns mutable references to the sum and the count at the same time.
    fn split_mut(&mut self) -> (&mut T, &mut R);
    /// Overwrites `self` with clones of `sum` and `count`.
    // `from_*` take `&mut self` because they overwrite an existing accumulator
    // in place rather than constructing a new one, which clippy's naming lint flags.
    #[allow(clippy::wrong_self_convention)]
    fn from_refs(&mut self, sum: &T, count: &R);
    /// Overwrites `self` by moving the values out of `sum` and `count`.
    /// The [`Avg`] implementation leaves default values in their place.
    // See `from_refs` for why the naming lint is silenced.
    #[allow(clippy::wrong_self_convention)]
    fn from_vals(&mut self, sum: &mut T, count: &mut R);
    /// Writes `sum / count` into `avg`.
    ///
    /// The [`Avg`] implementation panics if the count is zero.
    fn compute_avg(&self, avg: &mut T);
}
/// Bridges the concrete [`Avg`] accumulator to the type-erased [`Average`]
/// interface.
///
/// `T1Type` and `T2Type` are the concrete sum and count types. `T1` and `T2`
/// are the trait objects they erase to.
impl<T1Type, T2Type, T1, T2> Average<T1, T2> for Avg<T1Type, T2Type>
where
    T1Type: DBWeight + Erase<T1>,
    T2Type: DBWeight + Erase<T2>,
    T1Type: From<T2Type> + Div<Output = T1Type>,
    T1: DataTrait + ?Sized,
    T2: WeightTrait + ?Sized,
{
    fn sum(&self) -> &T1 {
        self.sum.erase()
    }

    fn count(&self) -> &T2 {
        self.count.erase()
    }

    fn split_mut(&mut self) -> (&mut T1, &mut T2) {
        (self.sum.erase_mut(), self.count.erase_mut())
    }

    fn from_refs(&mut self, sum: &T1, count: &T2) {
        // SAFETY: `downcast` is an unsafe, unchecked conversion. We rely on the
        // caller passing trait objects that actually hold `T1Type`/`T2Type`,
        // i.e. values created by the same factories as `self`.
        // NOTE(review): confirm every caller upholds this.
        let sum: &T1Type = unsafe { sum.downcast::<T1Type>() };
        let count: &T2Type = unsafe { count.downcast::<T2Type>() };
        // `clone_from` is equivalent to `*x = y.clone()`, but lets heap-backed
        // types reuse the storage `self` already owns. The fully qualified form
        // avoids any clash with same-named methods on the dynamic traits.
        Clone::clone_from(&mut self.sum, sum);
        Clone::clone_from(&mut self.count, count);
    }

    fn from_vals(&mut self, sum: &mut T1, count: &mut T2) {
        // SAFETY: the same unchecked-downcast requirement as in `from_refs` applies.
        let sum: &mut T1Type = unsafe { sum.downcast_mut::<T1Type>() };
        let count: &mut T2Type = unsafe { count.downcast_mut::<T2Type>() };
        // Move the values out, leaving `Default` values in the sources.
        self.sum = take(sum);
        self.count = take(count);
    }

    fn compute_avg(&self, avg: &mut T1) {
        // SAFETY: `avg` must hold a `T1Type`, as it does for a `W` produced by
        // the weight factory built for the same concrete type.
        // NOTE(review): confirm callers always pass such a value.
        let avg: &mut T1Type = unsafe { avg.downcast_mut::<T1Type>() };
        // An empty group has no average. Fail with a descriptive message rather
        // than a bare `unwrap` panic.
        *avg = Avg::compute_avg(self)
            .expect("cannot compute the average of a group whose count is zero");
    }
}
// `DynAverage<T, R>` is the trait-object form of `Average<T, R>` (`dyn Average<T, R>`).
// It is the accumulator type used by `AvgFactories` and `Stream::dyn_average`.
declare_trait_object!(DynAverage<T, R> = dyn Average<T, R>
where
    T: DataTrait + ?Sized,
    R: WeightTrait + ?Sized
);
impl<C, Z> Stream<C, Z>
where
    C: Circuit,
    Z: Clone + 'static,
{
    /// Incrementally computes a per-key average over an indexed Z-set stream.
    ///
    /// For each `(key, value, weight)` record, the record's accumulator takes
    /// `weight` as its count. `f` is then invoked with the record and its weight
    /// and writes into the accumulator's sum. The per-record accumulators are
    /// aggregated by `dyn_aggregate_linear_generic`. For each key, the average is
    /// computed into a temporary `W` obtained from the weight factory, and
    /// `out_func` converts it into the output value `A`.
    ///
    /// # Panics
    ///
    /// Panics if asked to produce the average for a group whose count is zero.
    #[track_caller]
    pub fn dyn_average<A, W>(
        &self,
        persistent_id: Option<&str>,
        factories: &AvgFactories<Z, A, W, C::Time>,
        f: Box<dyn Fn(&Z::Key, &Z::Val, &DynZWeight, &mut W)>,
        out_func: Box<dyn AggOutputFunc<W, A>>,
    ) -> Stream<C, OrdIndexedZSet<Z::Key, A>>
    where
        A: DataTrait + ?Sized,
        W: DataTrait + ?Sized,
        Z: IndexedZSet,
    {
        let avg_factory = factories.weight_factory;
        let linear_factories = &factories.aggregate_factories;

        self.dyn_aggregate_linear_generic(
            persistent_id,
            linear_factories,
            Box::new(
                move |key: &Z::Key,
                      val: &Z::Val,
                      weight: &Z::R,
                      acc: &mut DynAverage<W, Z::R>| {
                    let (acc_sum, acc_count) = acc.split_mut();
                    weight.clone_to(acc_count);
                    f(key, val, weight, acc_sum);
                },
            ),
            Box::new(move |acc, output| {
                avg_factory.with(&mut |avg_value| {
                    acc.compute_avg(avg_value);
                    out_func(avg_value, output);
                })
            }),
        )
    }
}
#[cfg(test)]
mod tests {
    use crate::operator::Avg;
    use rkyv::Deserialize;

    /// Round-trips `Avg` values, including numeric extremes, through rkyv and
    /// checks that deserialization reproduces the original value.
    #[test]
    fn avg_decode_encode() {
        type Type = Avg<u64, i64>;

        let samples: [Type; 3] = [
            Avg::new(0, 0),
            Avg::new(u64::MAX, i64::MAX),
            Avg::new(u64::MIN, i64::MIN),
        ];

        for expected in samples {
            let bytes = rkyv::to_bytes::<_, 256>(&expected).unwrap();
            // SAFETY: `bytes` was just produced by `rkyv::to_bytes` for a `Type`,
            // so it contains a valid archived root of that type.
            let archived = unsafe { rkyv::archived_root::<Type>(&bytes[..]) };
            let actual: Type = archived.deserialize(&mut rkyv::Infallible).unwrap();
            assert_eq!(actual, expected);
        }
    }
}