use std::any::type_name;
use std::fmt::Debug;
use std::marker::PhantomData;
use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::array::ArrayView;
use crate::array::VTable;
use crate::matcher::Matcher;
/// A rule that may replace a single array of encoding `V` with a different array.
///
/// Implementations return `Ok(Some(replacement))` when they fire and `Ok(None)`
/// when they do not apply; [`ReduceRuleSet`] tries rules in order and keeps the
/// first `Some`.
pub trait ArrayReduceRule<V>: Debug + Send + Sync + 'static
where
    V: VTable,
{
    /// Attempts to rewrite `array`, returning `Ok(None)` if the rule does not apply.
    fn reduce(&self, array: ArrayView<'_, V>) -> VortexResult<Option<ArrayRef>>;
}
/// A typed rule that may rewrite a parent array, given one of its children of
/// encoding `V`.
///
/// `Parent` is a [`Matcher`] that selects the parents the rule understands; the
/// rule receives the matcher's typed view of that parent rather than a raw
/// [`ArrayRef`].
pub trait ArrayParentReduceRule<V>: Debug + Send + Sync + 'static
where
    V: VTable,
{
    /// Matcher used to recognise, and produce a typed view of, eligible parents.
    type Parent: Matcher;

    /// Attempts to rewrite `parent`, where `array` is its child at index `child_idx`.
    ///
    /// Returns `Ok(None)` when the rule does not apply.
    fn reduce_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: <Self::Parent as Matcher>::Match<'_>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Type-erased (dyn-compatible) form of a parent-reduce rule for child encoding `V`.
///
/// Parents are passed as plain [`ArrayRef`]s; callers are expected to consult
/// [`matches`](Self::matches) before calling [`reduce_parent`](Self::reduce_parent).
pub trait DynArrayParentReduceRule<V>: Debug + Send + Sync
where
    V: VTable,
{
    /// Returns `true` if this rule is applicable to `parent`.
    fn matches(&self, parent: &ArrayRef) -> bool;

    /// Attempts to rewrite `parent`, where `array` is its child at index `child_idx`.
    ///
    /// Returns `Ok(None)` when no rewrite is produced.
    fn reduce_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Adapts a typed [`ArrayParentReduceRule`] into a [`DynArrayParentReduceRule`].
///
/// `#[repr(transparent)]` guarantees this wrapper has exactly the layout (size and
/// alignment) of `R`, since `PhantomData<V>` contributes nothing. That guarantee is
/// what makes it sound for [`ParentRuleSet::lift`] to reinterpret a `&'static R` as a
/// `&'static ParentReduceRuleAdapter<V, R>`. Default `repr(Rust)` layout gives no
/// such guarantee.
#[repr(transparent)]
pub struct ParentReduceRuleAdapter<V, R> {
    /// The wrapped typed rule.
    rule: R,
    /// Ties the adapter to child encoding `V` without storing one.
    _phantom: PhantomData<V>,
}
impl<V: VTable, R: ArrayParentReduceRule<V>> Debug for ParentReduceRuleAdapter<V, R> {
    /// Formats the adapter under its real type name, together with the matcher type
    /// it uses for parents and the wrapped rule's own `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use the actual struct name; previously this reported a non-existent
        // `ArrayParentReduceRuleAdapter`, which made debug output misleading.
        f.debug_struct("ParentReduceRuleAdapter")
            .field("parent", &type_name::<R::Parent>())
            .field("rule", &self.rule)
            .finish()
    }
}
impl<V, R> DynArrayParentReduceRule<V> for ParentReduceRuleAdapter<V, R>
where
    V: VTable,
    R: ArrayParentReduceRule<V>,
{
    /// Delegates applicability to the wrapped rule's parent matcher.
    fn matches(&self, parent: &ArrayRef) -> bool {
        R::Parent::matches(parent)
    }

    /// Re-runs the parent matcher to obtain the typed view the wrapped rule expects,
    /// then forwards to it. A parent the matcher rejects yields `Ok(None)`.
    fn reduce_parent(
        &self,
        child: ArrayView<'_, V>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        if let Some(typed_parent) = R::Parent::try_match(parent) {
            self.rule.reduce_parent(child, typed_parent, child_idx)
        } else {
            Ok(None)
        }
    }
}
/// A static, ordered list of [`ArrayReduceRule`]s for arrays of encoding `V`.
///
/// Rules are tried in slice order; the first one that yields a replacement wins.
pub struct ReduceRuleSet<V>
where
    V: VTable,
{
    rules: &'static [&'static dyn ArrayReduceRule<V>],
}

impl<V> ReduceRuleSet<V>
where
    V: VTable,
{
    /// Wraps a static slice of rules; usable in `const`/`static` initialisers.
    pub const fn new(rules: &'static [&'static dyn ArrayReduceRule<V>]) -> Self {
        Self { rules }
    }

    /// Runs the rules against `array` in order and returns the first replacement.
    ///
    /// Returns `Ok(None)` if no rule fires. An error from any rule is returned
    /// immediately, and the remaining rules are not tried.
    pub fn evaluate(&self, array: ArrayView<'_, V>) -> VortexResult<Option<ArrayRef>> {
        // `transpose` turns `Ok(None)` into "keep searching", and both `Ok(Some(_))`
        // and `Err(_)` into "stop here". The outer `transpose` restores the shape.
        self.rules
            .iter()
            .find_map(|candidate| candidate.reduce(array).transpose())
            .transpose()
    }
}
/// A static, ordered list of type-erased parent-reduce rules for child encoding `V`.
///
/// Rules are tried in slice order against a parent. The first rule whose
/// [`matches`](DynArrayParentReduceRule::matches) accepts the parent and whose
/// reduction returns `Some` wins.
pub struct ParentRuleSet<V: VTable> {
    rules: &'static [&'static dyn DynArrayParentReduceRule<V>],
}

impl<V: VTable> ParentRuleSet<V> {
    /// Wraps a static slice of rules; usable in `const`/`static` initialisers.
    pub const fn new(rules: &'static [&'static dyn DynArrayParentReduceRule<V>]) -> Self {
        Self { rules }
    }

    /// Converts a `&'static` typed rule into a `&'static dyn` rule without allocating,
    /// by reinterpreting it as a [`ParentReduceRuleAdapter`].
    ///
    /// `R` must be zero-sized. This requirement, and the adapter layout checks below,
    /// are enforced by an inline `const` block, so a violation is a compile-time
    /// error when `lift` is instantiated for that `R`.
    pub const fn lift<R: ArrayParentReduceRule<V>>(
        rule: &'static R,
    ) -> &'static dyn DynArrayParentReduceRule<V> {
        const {
            assert!(size_of::<R>() == 0, "Rule must be zero-sized to be lifted");
            // The pointer cast below needs the adapter's layout to be compatible with
            // `R`'s. Check that explicitly instead of relying on `repr(Rust)` details.
            assert!(
                size_of::<ParentReduceRuleAdapter<V, R>>() == 0,
                "Rule adapter must be zero-sized"
            );
            assert!(
                align_of::<ParentReduceRuleAdapter<V, R>>() <= align_of::<R>(),
                "Rule adapter must not be more strictly aligned than the rule"
            );
        }
        // SAFETY: the const block above proves that `R` and the adapter are both
        // zero-sized, and that the adapter's alignment does not exceed `R`'s.
        // Alignments are powers of two, so the address of the valid, non-null
        // `&'static R` is also suitably aligned for the adapter. A reference to a
        // zero-sized value reads no bytes, and the result borrows nothing shorter-lived
        // than `'static`.
        unsafe { &*(rule as *const R as *const ParentReduceRuleAdapter<V, R>) }
    }

    /// Tries each rule against `parent`, whose child at index `child_idx` is `child`.
    ///
    /// Rules whose `matches` rejects `parent` are skipped. The first `Some` replacement
    /// is returned, and the first error aborts the search. `Ok(None)` means no rule
    /// fired.
    ///
    /// In debug builds, the replacement's length and dtype are compared against
    /// `parent`'s with `vortex_ensure!`, naming the offending rule on mismatch.
    pub fn evaluate(
        &self,
        child: ArrayView<'_, V>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        for rule in self.rules {
            if !rule.matches(parent) {
                continue;
            }
            if let Some(reduced) = rule.reduce_parent(child, parent, child_idx)? {
                #[cfg(debug_assertions)]
                {
                    vortex_error::vortex_ensure!(
                        reduced.len() == parent.len(),
                        "Reduced array length mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
                        rule,
                        parent.encoding_id(),
                        reduced.encoding_id()
                    );
                    vortex_error::vortex_ensure!(
                        reduced.dtype() == parent.dtype(),
                        "Reduced array dtype mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
                        rule,
                        parent.encoding_id(),
                        reduced.encoding_id()
                    );
                }
                return Ok(Some(reduced));
            }
        }
        Ok(None)
    }
}