use std::any::Any;
use std::hash::BuildHasher;
use std::sync::Arc;
use std::sync::LazyLock;
use arc_swap::ArcSwap;
use vortex_error::VortexResult;
use vortex_session::Ref;
use vortex_session::SessionExt;
use vortex_session::SessionVar;
use vortex_session::registry::Id;
use vortex_utils::aliases::DefaultHashBuilder;
use vortex_utils::aliases::hash_map::HashMap;
use crate::ArrayRef;
/// Hash builder used to turn a `(parent, child)` [`Id`] pair into the `u64` key of the
/// reduce-parent registry.
///
/// A single process-wide instance is shared by registration and lookup so both hash a given
/// pair to the same key (this matters if `DefaultHashBuilder` is randomly seeded — TODO confirm).
static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
/// Signature of a reduce-parent kernel.
///
/// The kernel receives the `child` array, the `parent` array, and `child_idx` (presumably the
/// position of `child` among `parent`'s children — confirm against callers). It returns
/// `Ok(Some(array))` with a result array, `Ok(None)` when it produced nothing (presumably "kernel
/// does not apply" — confirm), or an error.
pub type ReduceParentFn =
fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
/// Registry of array kernels, stored as a session variable.
///
/// Currently holds reduce-parent kernels keyed by a hash of the `(parent, child)` [`Id`] pair.
/// The map sits behind an [`ArcSwap`]: lookups read the current shared snapshot, while
/// registrations build a modified copy of the map and publish it in place of the old one.
#[derive(Debug, Default)]
pub struct ArrayKernels {
// Reduce-parent kernels, keyed by the `u64` produced by hashing `(parent, child)` with
// `FN_HASHER`; each entry lists the kernels in registration order.
reduce_parent: ArcSwap<HashMap<u64, Arc<[ReduceParentFn]>>>,
}
impl ArrayKernels {
    /// Creates a registry with no kernels registered.
    pub fn empty() -> Self {
        Self::default()
    }

    /// Registers `fns` as reduce-parent kernels for the `(parent, child)` pair.
    ///
    /// The new functions are appended after any functions already registered for the same
    /// pair, so earlier registrations keep their position in the slice returned by
    /// [`Self::find_reduce_parent`].
    ///
    /// The update is published with [`ArcSwap::rcu`], so concurrent calls do not overwrite each
    /// other. If another writer swaps in a new map between this call's read and its write, the
    /// update is recomputed against the newer map instead of discarding that writer's kernels.
    pub fn register_reduce_parent<I: IntoIterator<Item = ReduceParentFn>>(
        &self,
        parent: Id,
        child: Id,
        fns: I,
    ) {
        let id = self.hash_fn_ids(parent, child);
        // `rcu` may invoke its closure more than once, but `fns` can only be iterated once,
        // so materialize it up front.
        let new_fns: Vec<ReduceParentFn> = fns.into_iter().collect();

        self.reduce_parent.rcu(|current| {
            let mut updated = (**current).clone();
            let merged: Arc<[ReduceParentFn]> = match updated.get(&id) {
                Some(existing) => existing.iter().chain(new_fns.iter()).copied().collect(),
                None => new_fns.iter().copied().collect(),
            };
            updated.insert(id, merged);
            updated
        });
    }

    /// Returns the reduce-parent kernels registered for `(parent, child)`, or `None` if there
    /// are none.
    ///
    /// The returned slice is a reference-counted handle taken from the current snapshot.
    /// Later registrations publish a new map and do not modify it.
    pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option<Arc<[ReduceParentFn]>> {
        let id = self.hash_fn_ids(parent, child);
        let snapshot = self.reduce_parent.load();
        snapshot.get(&id).map(Arc::clone)
    }

    /// Hashes a `(parent, child)` pair into the registry key using the shared `FN_HASHER`.
    ///
    /// NOTE(review): the registry is keyed by this 64-bit hash alone. Two distinct pairs whose
    /// hashes collide would therefore share one kernel list. Keying the map by `(Id, Id)`
    /// directly would rule this out if `Id: Eq + Hash + Debug` — confirm before changing.
    fn hash_fn_ids(&self, parent: Id, child: Id) -> u64 {
        FN_HASHER.hash_one((parent, child))
    }
}
// Exposes the registry as `dyn Any` (shared and mutable), presumably so the session can
// downcast it back to `ArrayKernels` by type.
impl SessionVar for ArrayKernels {
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }
}
/// Extension trait giving session types access to their [`ArrayKernels`] registry.
pub trait ArrayKernelsExt: SessionExt {
    /// Returns the session's [`ArrayKernels`] by calling `SessionExt::get`.
    ///
    /// NOTE(review): whether `get` creates a default registry on first access, or fails when
    /// none is present, is defined by `SessionExt`, not here — confirm there.
    fn kernels(&self) -> Ref<'_, ArrayKernels> {
        self.get::<ArrayKernels>()
    }
}

// Blanket impl: every `SessionExt` implementor gets `kernels()` without further code.
impl<S: SessionExt> ArrayKernelsExt for S {}