use std::any::Any;
use std::sync::Arc;
use vortex_session::Ref;
use vortex_session::SessionExt;
use vortex_session::SessionVar;
use crate::aggregate_fn::AggregateFnId;
use crate::aggregate_fn::AggregateFnPluginRef;
use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::fns::all_nan::AllNan;
use crate::aggregate_fn::fns::all_non_distinct::AllNonDistinct;
use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
use crate::aggregate_fn::fns::all_non_null::AllNonNull;
use crate::aggregate_fn::fns::all_null::AllNull;
use crate::aggregate_fn::fns::bounded_max::BoundedMax;
use crate::aggregate_fn::fns::bounded_min::BoundedMin;
use crate::aggregate_fn::fns::count::Count;
use crate::aggregate_fn::fns::count::CountGroupedKernel;
use crate::aggregate_fn::fns::first::First;
use crate::aggregate_fn::fns::is_constant::IsConstant;
use crate::aggregate_fn::fns::is_sorted::IsSorted;
use crate::aggregate_fn::fns::last::Last;
use crate::aggregate_fn::fns::max::Max;
use crate::aggregate_fn::fns::min::Min;
use crate::aggregate_fn::fns::min_max::MinMax;
use crate::aggregate_fn::fns::nan_count::NanCount;
use crate::aggregate_fn::fns::null_count::NullCount;
use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumEncodingKernel;
use crate::aggregate_fn::fns::sum::Sum;
use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes;
use crate::aggregate_fn::kernels::DynAggregateKernel;
use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
use crate::arc_swap_map::ArcSwapMap;
use crate::array::ArrayId;
use crate::array::VTable;
use crate::arrays::Chunked;
use crate::arrays::Dict;
use crate::arrays::Primitive;
use crate::arrays::chunked::compute::aggregate::ChunkedArrayAggregate;
use crate::arrays::dict::compute::is_constant::DictIsConstantKernel;
use crate::arrays::dict::compute::is_sorted::DictIsSortedKernel;
use crate::arrays::dict::compute::min_max::DictMinMaxKernel;
/// Session state holding the registered aggregate functions and their kernels.
///
/// Four independent maps are kept: one for aggregate function plugins and three
/// for kernels (per-encoding aggregate kernels, grouped kernels keyed by function,
/// and grouped kernels keyed by encoding and function). All registration methods
/// on this type take `&self` and insert into these maps.
#[derive(Debug)]
pub struct AggregateFnSession {
/// Aggregate function plugins, keyed by aggregate function id.
registry: ArcSwapMap<AggregateFnId, AggregateFnPluginRef>,
/// Aggregate kernels, keyed by array id and an optional aggregate function id.
kernels: ArcSwapMap<AggregateKernelKey, &'static dyn DynAggregateKernel>,
/// Grouped aggregate kernels, keyed by aggregate function id only.
grouped_kernels: ArcSwapMap<AggregateFnId, &'static dyn DynGroupedAggregateKernel>,
/// Grouped aggregate kernels, keyed by array id and aggregate function id.
grouped_encoding_kernels:
ArcSwapMap<GroupedEncodingKernelKey, &'static dyn DynGroupedAggregateKernel>,
}
/// Allows an [`AggregateFnSession`] to be stored as a session variable.
impl SessionVar for AggregateFnSession {
/// Returns this value as a shared `dyn Any` reference.
// NOTE(review): presumably used by the session to downcast back to the concrete type — confirm.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns this value as an exclusive `dyn Any` reference.
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}
/// Key of the aggregate kernel map: an array id plus an optional aggregate function id.
///
/// `find_aggregate_kernel` first looks up the `Some(fn_id)` entry and falls back to the
/// `None` entry for the same array id.
type AggregateKernelKey = (ArrayId, Option<AggregateFnId>);
/// Key of the grouped encoding kernel map: an array id paired with an aggregate function id.
type GroupedEncodingKernelKey = (ArrayId, AggregateFnId);
impl Default for AggregateFnSession {
    /// Builds a session pre-populated with the built-in aggregate functions and kernels.
    ///
    /// Every built-in aggregate function is registered as a plugin, followed by the
    /// built-in aggregate kernels, the grouped kernel for `Count`, and the grouped
    /// encoding kernel for `Sum` over `Primitive` arrays.
    fn default() -> Self {
        let this = Self {
            registry: ArcSwapMap::default(),
            kernels: ArcSwapMap::default(),
            grouped_kernels: ArcSwapMap::default(),
            grouped_encoding_kernels: ArcSwapMap::default(),
        };

        // Built-in aggregate function plugins.
        this.register(AllNonDistinct);
        this.register(AllNonNan);
        this.register(AllNonNull);
        this.register(AllNan);
        this.register(AllNull);
        this.register(BoundedMax);
        this.register(BoundedMin);
        // `Count` has a grouped kernel registered below, so it must also be
        // discoverable through `find_plugin` like every other built-in.
        this.register(Count);
        this.register(First);
        this.register(IsConstant);
        this.register(IsSorted);
        this.register(Last);
        this.register(Max);
        this.register(Min);
        this.register(MinMax);
        this.register(NanCount);
        this.register(NullCount);
        this.register(Sum);
        this.register(UncompressedSizeInBytes);

        // Built-in aggregate kernels. The `None` entry is the fallback used for
        // `Chunked` arrays when no function-specific kernel is registered.
        this.register_aggregate_kernel(Chunked.id(), None::<AggregateFnId>, &ChunkedArrayAggregate);
        this.register_aggregate_kernel(Dict.id(), Some(MinMax.id()), &DictMinMaxKernel);
        this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel);
        this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel);

        // Built-in grouped kernels.
        this.register_grouped_kernel(Count.id(), &CountGroupedKernel);
        this.register_grouped_encoding_kernel(
            Primitive.id(),
            Sum.id(),
            &PrimitiveGroupedSumEncodingKernel,
        );

        this
    }
}
impl AggregateFnSession {
pub fn find_plugin(&self, id: &AggregateFnId) -> Option<AggregateFnPluginRef> {
self.registry.get(id)
}
pub fn register<V: AggregateFnVTable>(&self, vtable: V) {
let id = vtable.id();
let pluginref = Arc::new(vtable) as AggregateFnPluginRef;
self.registry.insert(id, pluginref);
}
pub fn find_aggregate_kernel(
&self,
array_id: impl Into<ArrayId>,
agg_fn_id: impl Into<AggregateFnId>,
) -> Option<&'static dyn DynAggregateKernel> {
let id = array_id.into();
let fn_id = agg_fn_id.into();
self.kernels.read(|kernels| {
kernels
.get(&(id, Some(fn_id)))
.or_else(|| kernels.get(&(id, None)))
.copied()
})
}
pub fn register_aggregate_kernel(
&self,
array_id: impl Into<ArrayId>,
agg_fn_id: Option<impl Into<AggregateFnId>>,
kernel: &'static dyn DynAggregateKernel,
) {
let id = (array_id.into(), agg_fn_id.map(|id| id.into()));
self.kernels.insert(id, kernel);
}
pub fn find_grouped_kernel(
&self,
agg_fn_id: impl Into<AggregateFnId>,
) -> Option<&'static dyn DynGroupedAggregateKernel> {
let fn_id = agg_fn_id.into();
self.grouped_kernels
.read(|kernels| kernels.get(&fn_id).copied())
}
pub fn register_grouped_kernel(
&self,
agg_fn_id: impl Into<AggregateFnId>,
kernel: &'static dyn DynGroupedAggregateKernel,
) {
let fn_id = agg_fn_id.into();
self.grouped_kernels.insert(fn_id, kernel)
}
pub fn find_grouped_encoding_kernel(
&self,
array_id: impl Into<ArrayId>,
agg_fn_id: impl Into<AggregateFnId>,
) -> Option<&'static dyn DynGroupedAggregateKernel> {
let id = array_id.into();
let fn_id = agg_fn_id.into();
self.grouped_encoding_kernels
.read(|kernels| kernels.get(&(id, fn_id)).copied())
}
pub fn register_grouped_encoding_kernel(
&self,
array_id: impl Into<ArrayId>,
agg_fn_id: impl Into<AggregateFnId>,
kernel: &'static dyn DynGroupedAggregateKernel,
) {
let id = array_id.into();
let fn_id = agg_fn_id.into();
self.grouped_encoding_kernels.insert((id, fn_id), kernel)
}
}
/// Extension trait adding access to the [`AggregateFnSession`] on any [`SessionExt`] type.
pub trait AggregateFnSessionExt: SessionExt {
/// Returns a [`Ref`] to the [`AggregateFnSession`] obtained from this session via `get`.
fn aggregate_fns(&self) -> Ref<'_, AggregateFnSession> {
self.get::<AggregateFnSession>()
}
}
/// Blanket implementation: every `SessionExt` implementor gets `aggregate_fns`.
impl<S: SessionExt> AggregateFnSessionExt for S {}