use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_session::SessionExt;
use vortex_session::VortexSession;
use crate::ArrayRef;
use crate::optimizer::kernels::ArrayKernels;
pub mod kernels;
pub mod rules;
/// Applies reduction rules to arrays.
///
/// Each method returns the optimized array. When no rule applies, the `ArrayRef`
/// implementation in this module returns a clone of the input instead.
pub trait ArrayOptimizer {
    /// Reduces this array until no rule applies, without using any session-registered kernels.
    ///
    /// Only the array's own `reduce` rule and its children's `reduce_parent` rules are
    /// consulted. The children themselves are not optimized as independent arrays.
    ///
    /// # Errors
    ///
    /// Propagates failures from reduction rules. It also fails if reduction does not reach a
    /// fixed point within the iteration limit.
    fn optimize(&self) -> VortexResult<ArrayRef>;

    /// Like [`ArrayOptimizer::optimize`], but also consults `reduce_parent` kernels registered
    /// in `session`'s `ArrayKernels`, if present.
    ///
    /// # Errors
    ///
    /// Same as [`ArrayOptimizer::optimize`].
    fn optimize_ctx(&self, session: &VortexSession) -> VortexResult<ArrayRef>;

    /// Optimizes this array with `session` and then every child in its slots, recursively.
    ///
    /// If any child changed, the array is rebuilt around the optimized children.
    ///
    /// # Errors
    ///
    /// Same as [`ArrayOptimizer::optimize`]. It also propagates failures from rebuilding an
    /// array with new children.
    fn optimize_recursive(&self, session: &VortexSession) -> VortexResult<ArrayRef>;
}
impl ArrayOptimizer for ArrayRef {
    /// Top-level reduction with no session kernels; falls back to `self.clone()` when unchanged.
    fn optimize(&self) -> VortexResult<ArrayRef> {
        try_optimize(self, None).map(|reduced| reduced.unwrap_or_else(|| self.clone()))
    }

    /// Top-level reduction that also consults the session's kernels; falls back to
    /// `self.clone()` when unchanged.
    fn optimize_ctx(&self, session: &VortexSession) -> VortexResult<ArrayRef> {
        try_optimize(self, Some(session)).map(|reduced| reduced.unwrap_or_else(|| self.clone()))
    }

    /// Recursive reduction over this array and all of its children; falls back to
    /// `self.clone()` when nothing at any depth changed.
    fn optimize_recursive(&self, session: &VortexSession) -> VortexResult<ArrayRef> {
        try_optimize_recursive(self, session)
            .map(|reduced| reduced.unwrap_or_else(|| self.clone()))
    }
}
/// Applies reduction rules to `array` (non-recursively) until no rule fires.
///
/// Each iteration first tries the array's own `reduce` rule. If that does not apply, every
/// present child slot gets a chance to rewrite its parent. The child is offered:
/// 1. any `reduce_parent` plugins registered in the session's [`ArrayKernels`] for the
///    (parent encoding, child encoding) pair, then
/// 2. the child's built-in `reduce_parent`.
///
/// As soon as any rule produces a new array, the loop restarts on that new array.
///
/// Returns `Ok(Some(new))` if at least one rule fired. Returns `Ok(None)` if `array` was
/// already fully reduced.
///
/// # Errors
///
/// Propagates errors from any rule or plugin. Bails if no fixed point is reached within the
/// iteration limit, e.g. when two rules keep undoing each other.
fn try_optimize(
    array: &ArrayRef,
    session: Option<&VortexSession>,
) -> VortexResult<Option<ArrayRef>> {
    // Upper bound on rewrite rounds, guarding against rule sets that never converge.
    const MAX_ITERATIONS: usize = 100;

    let mut current_array = array.clone();
    let mut any_optimizations = false;
    // Session-registered parent-reduction kernels, when a session was supplied and has them.
    let kernels = session.and_then(|s| s.get_opt::<ArrayKernels>());

    let mut iterations: usize = 0;
    'outer: loop {
        if iterations > MAX_ITERATIONS {
            vortex_bail!("Exceeded maximum optimization iterations (possible infinite loop)");
        }
        iterations += 1;

        if let Some(new_array) = current_array.reduce()? {
            current_array = new_array;
            any_optimizations = true;
            continue;
        }

        for (slot_idx, slot) in current_array.slots().iter().enumerate() {
            let Some(child) = slot else { continue };

            // Session kernels take precedence over the child's built-in rule.
            if let Some(kernels) = &kernels
                && let Some(plugins) =
                    kernels.find_reduce_parent(current_array.encoding_id(), child.encoding_id())
            {
                for plugin in plugins.as_ref() {
                    if let Some(new_array) = plugin(child, &current_array, slot_idx)? {
                        current_array = new_array;
                        any_optimizations = true;
                        // The parent changed; restart on the rewritten array.
                        continue 'outer;
                    }
                }
            }

            if let Some(new_array) = child.reduce_parent(&current_array, slot_idx)? {
                current_array = new_array;
                any_optimizations = true;
                // The parent changed; restart on the rewritten array.
                continue 'outer;
            }
        }

        // No rule applied to the array or any of its children: fixed point reached.
        break;
    }

    Ok(any_optimizations.then_some(current_array))
}
/// Optimizes `array` with `session`, then recursively optimizes every child in its slots.
///
/// Steps:
/// 1. Reduce the array itself with `try_optimize` (using the session's kernels).
/// 2. Optimize each present child slot of the result recursively. Empty slots are kept as
///    `None`.
/// 3. If any child changed, rebuild the array via `with_slots` with the new children.
///
/// Returns `Ok(Some(new))` if anything changed at any depth, `Ok(None)` otherwise.
///
/// NOTE(review): the parent is not reduced again after its children are rewritten. Parent
/// rules that only become applicable with an optimized child are therefore not applied in
/// this pass. Confirm this is intended.
///
/// # Errors
///
/// Propagates errors from `try_optimize`, from optimizing any child, and from `with_slots`.
fn try_optimize_recursive(
    array: &ArrayRef,
    session: &VortexSession,
) -> VortexResult<Option<ArrayRef>> {
    let reduced = try_optimize(array, Some(session))?;
    let mut any_optimizations = reduced.is_some();
    let mut current_array = reduced.unwrap_or_else(|| array.clone());

    let mut new_slots = Vec::with_capacity(current_array.slots().len());
    let mut any_slot_optimized = false;
    for slot in current_array.slots() {
        let new_slot = match slot {
            Some(child) => match try_optimize_recursive(child, session)? {
                Some(new_child) => {
                    any_slot_optimized = true;
                    Some(new_child)
                }
                None => Some(child.clone()),
            },
            None => None,
        };
        new_slots.push(new_slot);
    }

    // Only rebuild the array when a child actually changed.
    if any_slot_optimized {
        current_array = current_array.with_slots(new_slots)?;
        any_optimizations = true;
    }

    Ok(any_optimizations.then_some(current_array))
}