polars-plan 0.54.3

Lazy query engine for the Polars DataFrame library
use polars_core::error::PolarsResult;
use polars_core::prelude::PlHashSet;
use polars_utils::aliases::InitHashMaps;
use polars_utils::arena::{Arena, Node};
use polars_utils::pl_str::PlSmallStr;

use super::OptimizationRule;
use crate::prelude::{AExpr, IR};

pub struct FlattenMergeSortedRule {
    collected_inputs: Vec<Node>,
    optimized_nodes: PlHashSet<Node>,
    traversal_stack: Vec<Node>,
}

impl FlattenMergeSortedRule {
    pub fn new() -> Self {
        Self {
            collected_inputs: Vec::new(),
            optimized_nodes: PlHashSet::new(),
            traversal_stack: Vec::new(),
        }
    }
}

impl OptimizationRule for FlattenMergeSortedRule {
    fn optimize_plan(
        &mut self,
        lp_arena: &mut Arena<IR>,
        _expr_arena: &mut Arena<AExpr>,
        node: Node,
    ) -> PolarsResult<Option<IR>> {
        let (key, maintain_order) = match lp_arena.get(node) {
            IR::MergeSorted {
                key,
                maintain_order,
                ..
            } => (key.clone(), *maintain_order),
            _ => return Ok(None),
        };
        if !self.optimized_nodes.insert(node) {
            return Ok(None);
        }

        self.collected_inputs.clear();
        collect_merge_sorted_inputs(
            node,
            &key,
            maintain_order,
            lp_arena,
            &mut self.collected_inputs,
            &mut self.traversal_stack,
        );

        if self.collected_inputs.len() <= 2 {
            return Ok(None);
        }

        Ok(Some(build_merge_sorted_subtree(
            &self.collected_inputs,
            &key,
            maintain_order,
            lp_arena,
            &mut self.optimized_nodes,
        )))
    }
}

fn collect_merge_sorted_inputs(
    root: Node,
    key: &PlSmallStr,
    maintain_order: bool,
    lp_arena: &Arena<IR>,
    out: &mut Vec<Node>,
    traversal_stack: &mut Vec<Node>,
) {
    traversal_stack.clear();
    traversal_stack.push(root);

    while let Some(node) = traversal_stack.pop() {
        match lp_arena.get(node) {
            IR::MergeSorted {
                input_left,
                input_right,
                key: merge_key,
                maintain_order: merge_maintain_order,
            } if merge_key == key && *merge_maintain_order == maintain_order => {
                traversal_stack.push(*input_right);
                traversal_stack.push(*input_left);
            },
            _ => out.push(node),
        }
    }
}

fn build_merge_sorted_subtree(
    inputs: &[Node],
    key: &PlSmallStr,
    maintain_order: bool,
    lp_arena: &mut Arena<IR>,
    optimized_nodes: &mut PlHashSet<Node>,
) -> IR {
    debug_assert!(inputs.len() >= 2);

    let split = inputs.len().div_ceil(2);
    let input_left = build_merge_sorted_subtree_node(
        &inputs[..split],
        key,
        maintain_order,
        lp_arena,
        optimized_nodes,
    );
    let input_right = build_merge_sorted_subtree_node(
        &inputs[split..],
        key,
        maintain_order,
        lp_arena,
        optimized_nodes,
    );
    IR::MergeSorted {
        input_left,
        input_right,
        key: key.clone(),
        maintain_order,
    }
}

fn build_merge_sorted_subtree_node(
    inputs: &[Node],
    key: &PlSmallStr,
    maintain_order: bool,
    lp_arena: &mut Arena<IR>,
    optimized_nodes: &mut PlHashSet<Node>,
) -> Node {
    match inputs {
        [node] => *node,
        _ => {
            let subtree =
                build_merge_sorted_subtree(inputs, key, maintain_order, lp_arena, optimized_nodes);
            let node = lp_arena.add(subtree);
            optimized_nodes.insert(node);
            node
        },
    }
}