mod array;
mod rules;
mod slice_;
mod vtable;
use std::ops::Range;
pub use array::SliceArrayExt;
pub use array::SliceData;
pub use array::SliceDataParts;
use vortex_error::VortexResult;
pub use vtable::*;
use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::array::VTable;
use crate::kernel::ExecuteParentKernel;
use crate::matcher::Matcher;
use crate::optimizer::rules::ArrayParentReduceRule;
/// Encoding-specific slicing that needs no execution context.
///
/// An implementation returns `Ok(Some(sliced))` with an array covering `range` of `array`.
/// `Ok(None)` presumably means "no specialized slice for this encoding";
/// NOTE(review): confirm against how the optimizer treats a `None` from a reduce rule.
///
/// When this is reached through [`SliceReduceAdaptor`], two kinds of range are answered
/// before this method runs: a range spanning the whole array, and an empty range
/// (`start == end`).
pub trait SliceReduce: VTable {
    /// Slices `array` down to `range` without access to an [`ExecutionCtx`].
    fn slice(array: ArrayView<'_, Self>, range: Range<usize>) -> VortexResult<Option<ArrayRef>>;
}
/// Encoding-specific slicing that runs with a mutable [`ExecutionCtx`].
///
/// This is the execution-time counterpart of [`SliceReduce`]. The only difference in
/// signature is the extra `ctx` argument. As with `SliceReduce`, `Ok(None)` presumably
/// signals that no specialized slice is available (NOTE(review): confirm).
///
/// When this is reached through [`SliceExecuteAdaptor`], two kinds of range are handled
/// before this method runs: a range spanning the whole array, and an empty range
/// (`start == end`).
pub trait SliceKernel: VTable {
    /// Slices `array` down to `range`, with `ctx` available to the implementation.
    fn slice(
        array: ArrayView<'_, Self>,
        range: Range<usize>,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Answers slices whose result does not depend on the array's encoding.
///
/// Returns:
/// - `Some(array)`, i.e. the viewed array via `array.array().clone()`, when `range` is
///   exactly `0..array.len()`. The slice is then the identity.
/// - `Some(empty)`, a canonical empty array with `array`'s dtype, when `range.start ==
///   range.end`.
/// - `None` otherwise, so the caller falls back to the encoding-specific slice.
///
/// The full-range check runs first. Slicing a zero-length array with `0..0` therefore
/// returns the array itself, not a fresh canonical empty array.
///
/// Bounds are not validated here. An empty range that starts past `len` still yields an
/// empty array, and a reversed range (`start > end`) yields `None`.
/// NOTE(review): presumably `Slice` validates bounds upstream; confirm.
fn precondition<V: VTable>(array: ArrayView<'_, V>, range: &Range<usize>) -> Option<ArrayRef> {
    // A slice spanning the entire array is the identity.
    if range.start == 0 && range.end == array.len() {
        return Some(array.array().clone());
    }
    // Compare the endpoints directly rather than calling `Range::is_empty`.
    // `is_empty` would also treat reversed ranges (start > end) as empty.
    if range.start == range.end {
        return Some(Canonical::empty(array.dtype()).into_array());
    }
    None
}
/// Exposes a [`SliceReduce`] implementation as an [`ArrayParentReduceRule`] whose parent
/// is `Slice`.
///
/// The shared `precondition` check answers two kinds of range without calling into `V`:
/// a range covering the whole array, and an empty range.
#[derive(Default, Debug)]
pub struct SliceReduceAdaptor<V>(pub V);

impl<V: SliceReduce> ArrayParentReduceRule<V> for SliceReduceAdaptor<V> {
    type Parent = Slice;

    /// Tries to fold the `Slice` parent into `array`.
    ///
    /// # Panics
    ///
    /// Panics if `child_idx != 0`.
    fn reduce_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: <Self::Parent as Matcher>::Match<'_>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        assert_eq!(child_idx, 0);
        match precondition(array, &parent.range) {
            // Full-range or empty slice: no encoding-specific work is needed.
            Some(shortcut) => Ok(Some(shortcut)),
            // Fully qualified so the call cannot be confused with a same-named method
            // on another trait bound.
            None => <V as SliceReduce>::slice(array, parent.range.clone()),
        }
    }
}
/// Exposes a [`SliceKernel`] implementation as an [`ExecuteParentKernel`] whose parent
/// is `Slice`.
///
/// The shared `precondition` check answers two kinds of range without calling into `V`:
/// a range covering the whole array, and an empty range.
#[derive(Default, Debug)]
pub struct SliceExecuteAdaptor<V>(pub V);

impl<V: SliceKernel> ExecuteParentKernel<V> for SliceExecuteAdaptor<V> {
    type Parent = Slice;

    /// Executes the `Slice` parent against `array`, forwarding `ctx` to the kernel.
    ///
    /// # Panics
    ///
    /// Panics if `child_idx != 0`.
    fn execute_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: <Self::Parent as Matcher>::Match<'_>,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        assert_eq!(child_idx, 0);
        match precondition(array, &parent.range) {
            // Full-range or empty slice: no encoding-specific work is needed.
            Some(shortcut) => Ok(Some(shortcut)),
            // Fully qualified so the call cannot be confused with a same-named method
            // on another trait bound.
            None => <V as SliceKernel>::slice(array, parent.range.clone(), ctx),
        }
    }
}