polars-plan 0.54.2

Lazy query engine for the Polars DataFrame library
pub mod storage;

use std::ops::ControlFlow;

use polars_core::prelude::{InitHashMaps as _, PlHashMap};
use polars_utils::arena::{Arena, Node};
use polars_utils::unique_id::UniqueId;
use polars_utils::{UnitVec, unitvec};

use crate::plans::IR;
use crate::plans::optimizer::ir_traversal::storage::IRTraversalStorage;
use crate::traversal::tree_traversal::{GraphVisitOrder, TreeTraversalImpl};
use crate::traversal::visitor::{FnVisitors, NodeVisitor, SubtreeVisit};

/// Per-cache-id bookkeeping shared by [`get_ir_cache_hits`] and [`ir_graph_traversal`].
///
/// One value is stored per `IR::Cache` id in the map returned by [`get_ir_cache_hits`].
#[derive(Debug)]
pub struct IRCacheNodeVisit<Key> {
    /// Counter for this cache id.
    ///
    /// [`get_ir_cache_hits`] sets it to 1 the first time the id is encountered and increments
    /// it on every later encounter. [`ir_graph_traversal`] decrements it each time its
    /// visit-order callback sees the cache node, and only visits the node once it reaches 0.
    pub hits: usize,
    /// Keys accumulated for this cache id.
    ///
    /// Starts out empty in [`get_ir_cache_hits`]. [`ir_graph_traversal`] pushes the edge index
    /// it is given (when there is one) and takes the whole list when `hits` reaches 0.
    pub output_keys: UnitVec<Key>,
}

/// Counts how often each `IR::Cache` id is encountered while traversing from `root`.
///
/// The traversal runs over `arena` using unit (`()`) edges. The first time a cache id is seen,
/// an entry with `hits == 1` and empty `output_keys` is created and the cache's subtree is
/// visited. Every later encounter of the same id only increments `hits` and skips the subtree.
/// Nodes that are not `IR::Cache` are always descended into.
///
/// `visit_stack` is handed to the traversal as its working stack.
///
/// # Panics
///
/// Panics if the traversal does not finish with `ControlFlow::Continue`.
pub fn get_ir_cache_hits<Key>(
    root: Node,
    mut arena: &Arena<IR>,
    visit_stack: &mut Vec<Node>,
) -> PlHashMap<UniqueId, IRCacheNodeVisit<Key>> {
    use hashbrown::hash_map::Entry;

    let mut visits_by_cache_id: PlHashMap<UniqueId, IRCacheNodeVisit<Key>> = PlHashMap::new();
    // A single unit edge serves as the root edge (index 0) for this counting pass.
    let mut unit_edges = vec![()];

    let traversal_result = TreeTraversalImpl {
        storage: &mut arena,
        visit_stack,
        edges: &mut unit_edges,
        persist_input_edge_idxs: None,
        graph_visit_order_fn: None,
        visitor: &mut FnVisitors::new(
            || (),
            |node, ir_arena: &mut &Arena<IR>, _| {
                // Anything that is not a cache node is simply descended into.
                let IR::Cache { id, .. } = ir_arena.get(node) else {
                    return ControlFlow::Continue(SubtreeVisit::Visit);
                };

                let decision = match visits_by_cache_id.entry(*id) {
                    // First sighting: register the id and walk below the cache once.
                    Entry::Vacant(slot) => {
                        slot.insert(IRCacheNodeVisit {
                            hits: 1,
                            output_keys: unitvec![],
                        });
                        SubtreeVisit::Visit
                    },
                    // Seen before: count the extra hit, the subtree was already walked.
                    Entry::Occupied(mut slot) => {
                        slot.get_mut().hits += 1;
                        SubtreeVisit::Skip
                    },
                };

                ControlFlow::Continue(decision)
            },
            |_, _, _| ControlFlow::<()>::Continue(()),
        ),
    }
    .traverse_rec(root, 0, false);

    traversal_result.continue_value().unwrap();

    visits_by_cache_id
}

/// Traverses the IR graph rooted at `root` with `visitor` and returns the root's edge.
///
/// The function first calls [`get_ir_cache_hits`] to count, per `IR::Cache` id, how many times
/// that id is encountered from `root`. During the main traversal the visit-order callback
/// uses those counts:
///
/// * For an `IR::Cache` node, the given edge index (if any) is recorded and the count is
///   decremented. While the count is still non-zero the callback answers
///   `GraphVisitOrder::HasUnvisitedOutputs`; once it reaches zero it answers
///   `GraphVisitOrder::Visit` carrying every edge index recorded for that id.
/// * For any other node it answers `GraphVisitOrder::Visit` with the single given edge index,
///   or with no keys if none was given.
///
/// The root edge is created via `visitor.default_edge(root, None)` and pushed onto `edges`.
/// `visit_stack` is cleared before the main traversal starts.
///
/// # Returns
///
/// * `ControlFlow::Continue(edge)` with the root edge. In this case `edges` is truncated back
///   to the length it had on entry.
/// * `ControlFlow::Break(value)` as soon as the traversal breaks. In this case `edges` is left
///   as the traversal left it (no truncation is performed).
///
/// # Panics
///
/// * If the visit-order callback sees an `IR::Cache` id that [`get_ir_cache_hits`] did not
///   record.
/// * If, after the traversal, `edges` no longer contains the root edge.
pub fn ir_graph_traversal<'storage, Edge, BreakValue>(
    root: Node,
    visitor: &mut dyn NodeVisitor<
        Key = Node,
        Storage = IRTraversalStorage<'storage>,
        Edge = Edge,
        BreakValue = BreakValue,
    >,
    visit_stack: &mut Vec<Node>,
    edges: &mut Vec<Edge>,
    mut storage: IRTraversalStorage<'storage>,
) -> ControlFlow<BreakValue, Edge> {
    let mut cache_out_edge_keys_map = get_ir_cache_hits::<usize>(root, &storage, visit_stack);

    visit_stack.clear();
    // The root edge lives at the current end of `edges`; everything before it belongs to the
    // caller and must be left intact.
    let root_edge_idx = edges.len();
    edges.push(visitor.default_edge(root, None));
    let root_edge_deleted = visitor.is_deleted_edge(&edges[root_edge_idx]) == Some(true);

    TreeTraversalImpl {
        storage: &mut storage,
        visit_stack,
        edges,
        persist_input_edge_idxs: None,
        graph_visit_order_fn: Some(&mut |key, storage, edge_key_idx| {
            if let IR::Cache { id, .. } = storage.get(key) {
                let cache_node_visit = cache_out_edge_keys_map.get_mut(id).unwrap();

                if let Some(idx) = edge_key_idx {
                    cache_node_visit.output_keys.push(idx);
                }

                cache_node_visit.hits -= 1;

                if cache_node_visit.hits == 0 {
                    // Every counted encounter has been seen: visit the cache node now and
                    // hand over all edge indices collected for it.
                    GraphVisitOrder::Visit {
                        output_keys: std::mem::take(&mut cache_node_visit.output_keys),
                    }
                } else {
                    GraphVisitOrder::HasUnvisitedOutputs
                }
            } else {
                GraphVisitOrder::Visit {
                    output_keys: if let Some(idx) = edge_key_idx {
                        unitvec![idx]
                    } else {
                        unitvec![]
                    },
                }
            }
        }),
        visitor,
    }
    .traverse_rec(root, root_edge_idx, root_edge_deleted)?;

    // The root edge must still be present. With a non-strict `>=` check, a vector of length
    // `root_edge_idx` would pass, `truncate` would be a no-op, and `pop` would hand back one of
    // the caller's pre-existing edges instead of the root edge.
    assert!(edges.len() > root_edge_idx);
    edges.truncate(root_edge_idx + 1);

    ControlFlow::Continue(edges.pop().unwrap())
}