// dynamic_expressions 0.10.0
//
// Fast batched evaluation + forward-mode derivatives for symbolic expressions (Rust port of DynamicExpressions.jl).
// Documentation
use std::cell::RefCell;

use crate::node::PNode;

/// Folds a postfix-encoded expression bottom-up into a single value.
///
/// Every leaf (`Var`/`Const`) is mapped with `f_leaf`. For every `Op` node, `f_branch`
/// produces a per-node value that is handed to `op` together with the already-reduced
/// values of that node's `arity` children (in the order they appear in `nodes`); the
/// result of `op` takes the children's place.
///
/// A fresh scratch stack is allocated on every call; use [`tree_mapreduce_with_stack`]
/// to reuse a buffer across calls.
///
/// # Panics
///
/// Panics if `nodes` is not well-formed postfix: an operator without enough operands,
/// or a sequence (including an empty one) that does not reduce to exactly one value.
pub fn tree_mapreduce<R, B>(
    nodes: &[PNode],
    f_leaf: impl FnMut(&PNode) -> R,
    f_branch: impl FnMut(&PNode) -> B,
    op: impl FnMut(B, &[R]) -> R,
) -> R {
    let no_shared_buffer: Option<&mut Vec<R>> = None;
    tree_mapreduce_with_stack(nodes, f_leaf, f_branch, op, no_shared_buffer)
}

/// Same fold as [`tree_mapreduce`], but with an optional caller-supplied scratch stack.
///
/// * `Some(buf)`: `buf` is cleared first (previous contents are discarded) and then used
///   as the evaluation stack, so its capacity is reused across calls. On successful
///   return it is empty again.
/// * `None`: a temporary stack with capacity `nodes.len()` is allocated for this call.
///
/// # Panics
///
/// Panics on malformed postfix input, exactly like [`tree_mapreduce`].
pub fn tree_mapreduce_with_stack<R, B>(
    nodes: &[PNode],
    mut f_leaf: impl FnMut(&PNode) -> R,
    mut f_branch: impl FnMut(&PNode) -> B,
    mut op: impl FnMut(B, &[R]) -> R,
    reusable_stack: Option<&mut Vec<R>>,
) -> R {
    // Only initialised on the `None` path; it must outlive `work` below.
    let mut local_buffer: Vec<R>;
    let work: &mut Vec<R> = if let Some(shared) = reusable_stack {
        shared.clear();
        shared
    } else {
        local_buffer = Vec::with_capacity(nodes.len());
        &mut local_buffer
    };
    tree_mapreduce_impl(nodes, work, &mut f_leaf, &mut f_branch, &mut op)
}

/// Stack-machine evaluation of a postfix node sequence.
///
/// Leaves push `f_leaf(node)`. An `Op` with arity `a` calls `f_branch(node)`, passes that
/// and the top `a` stack entries (oldest first) to `op`, pops those entries and pushes
/// the result. The single remaining entry is popped and returned, leaving `stack` empty.
///
/// `stack` is assumed to be empty on entry (the public wrappers guarantee this).
///
/// # Panics
///
/// * `"invalid postfix (stack underflow)"` if an operator has fewer operands than its arity.
/// * `"invalid postfix (did not reduce to one root)"` if the final stack size is not 1.
fn tree_mapreduce_impl<R, B>(
    nodes: &[PNode],
    stack: &mut Vec<R>,
    f_leaf: &mut impl FnMut(&PNode) -> R,
    f_branch: &mut impl FnMut(&PNode) -> B,
    op: &mut impl FnMut(B, &[R]) -> R,
) -> R {
    for node in nodes {
        let reduced = match *node {
            PNode::Var { .. } | PNode::Const { .. } => f_leaf(node),
            PNode::Op { arity, .. } => {
                // Children occupy the last `arity` slots of the stack.
                let first_child = stack
                    .len()
                    .checked_sub(arity as usize)
                    .expect("invalid postfix (stack underflow)");
                let payload = f_branch(node);
                let combined = op(payload, &stack[first_child..]);
                stack.truncate(first_child);
                combined
            }
        };
        stack.push(reduced);
    }
    assert_eq!(stack.len(), 1, "invalid postfix (did not reduce to one root)");
    stack.pop().expect("non-empty stack")
}

thread_local! {
    /// Per-thread scratch stack reused by [`count_depth`] so repeated calls do not
    /// reallocate the evaluation buffer.
    static COUNT_DEPTH_STACK: RefCell<Vec<usize>> = const { RefCell::new(Vec::new()) };
}

/// Height of the expression tree encoded by `nodes`.
///
/// A lone leaf has depth 1; an operator is one deeper than its deepest child (an
/// operator with no children therefore also has depth 1).
///
/// # Panics
///
/// Panics if `nodes` is not well-formed postfix (see [`tree_mapreduce`]).
pub fn count_depth(nodes: &[PNode]) -> usize {
    COUNT_DEPTH_STACK.with(|cell| {
        let mut scratch = cell.borrow_mut();
        let leaf_depth = |_: &PNode| 1usize;
        let unused_payload = |_: &PNode| 0usize;
        let one_deeper =
            |_: usize, kids: &[usize]| kids.iter().copied().max().unwrap_or(0) + 1;
        tree_mapreduce_with_stack(
            nodes,
            leaf_depth,
            unused_payload,
            one_deeper,
            Some(&mut *scratch),
        )
    })
}

/// Total number of nodes in the expression (each slot of the postfix slice is one node).
pub fn count_nodes(nodes: &[PNode]) -> usize {
    let total = nodes.iter().len();
    total
}

/// Returns `true` if at least one node is a `Const` leaf.
pub fn has_constants(nodes: &[PNode]) -> bool {
    for node in nodes {
        if let PNode::Const { .. } = node {
            return true;
        }
    }
    false
}

/// Number of `Const` leaves in the expression.
pub fn count_constant_nodes(nodes: &[PNode]) -> usize {
    nodes
        .iter()
        .map(|node| usize::from(matches!(node, PNode::Const { .. })))
        .sum()
}

/// Returns `true` if at least one node is an `Op` (i.e. the expression is not a bare leaf list).
pub fn has_operators(nodes: &[PNode]) -> bool {
    nodes.iter().any(|node| match node {
        PNode::Op { .. } => true,
        PNode::Var { .. } | PNode::Const { .. } => false,
    })
}

/// Returns `true` if at least one node is a `Var` leaf.
pub fn has_variables(nodes: &[PNode]) -> bool {
    let all_non_variables = nodes.iter().all(|node| !matches!(node, PNode::Var { .. }));
    !all_non_variables
}

/// Number of `Var` leaves in the expression (repeated uses of a feature count separately).
pub fn count_variable_nodes(nodes: &[PNode]) -> usize {
    nodes.iter().fold(0usize, |seen, node| {
        if matches!(node, PNode::Var { .. }) {
            seen + 1
        } else {
            seen
        }
    })
}

/// Number of `Op` nodes in the expression.
pub fn count_operator_nodes(nodes: &[PNode]) -> usize {
    let mut operators = 0;
    for node in nodes {
        if let PNode::Op { .. } = node {
            operators += 1;
        }
    }
    operators
}

/// Largest operator arity present in the expression, or `0` if there are no operators.
pub fn max_arity(nodes: &[PNode]) -> u8 {
    nodes.iter().fold(0u8, |widest, node| match *node {
        PNode::Op { arity, .. } => widest.max(arity),
        PNode::Var { .. } | PNode::Const { .. } => widest,
    })
}

/// Returns `true` when the expression is exactly one `Var` or one `Const` node.
pub fn is_leaf(nodes: &[PNode]) -> bool {
    matches!(nodes, [PNode::Var { .. } | PNode::Const { .. }])
}

/// Non-panicking check that `nodes` is a well-formed postfix expression.
///
/// Simulates the evaluation stack depth: leaves push one value, an operator of arity
/// `a` consumes `a` values and pushes one. Returns `false` as soon as an operator has
/// arity 0 or finds fewer than `a` values, and otherwise requires exactly one value at
/// the end (so an empty slice is invalid).
///
/// NOTE(review): arity-0 operators are rejected here, while `tree_mapreduce` and
/// `subtree_sizes` accept them — confirm which behavior is intended.
pub fn is_valid_postfix(nodes: &[PNode]) -> bool {
    let final_depth = nodes.iter().try_fold(0usize, |depth, node| match *node {
        PNode::Var { .. } | PNode::Const { .. } => Some(depth + 1),
        PNode::Op { arity: 0, .. } => None,
        PNode::Op { arity, .. } => depth.checked_sub(arity as usize).map(|rest| rest + 1),
    });
    final_depth == Some(1)
}

/// For every position `i`, the number of nodes in the subtree rooted at `nodes[i]`
/// (the node itself included).
///
/// Leaves have size 1; an operator's size is 1 plus the sizes of its `arity` children.
///
/// # Panics
///
/// * `"invalid postfix (stack underflow)"` if an operator has fewer operands than its arity.
/// * `"invalid postfix (did not reduce to one root)"` if the input does not form exactly
///   one tree (including empty input).
pub fn subtree_sizes(nodes: &[PNode]) -> Vec<usize> {
    // Sizes of completed subtrees that have not yet been attached to a parent.
    let mut open_subtrees: Vec<usize> = Vec::with_capacity(nodes.len());
    let sizes: Vec<usize> = nodes
        .iter()
        .map(|node| {
            let size = match *node {
                PNode::Var { .. } | PNode::Const { .. } => 1,
                PNode::Op { arity, .. } => {
                    let first_child = open_subtrees
                        .len()
                        .checked_sub(arity as usize)
                        .expect("invalid postfix (stack underflow)");
                    1 + open_subtrees.drain(first_child..).sum::<usize>()
                }
            };
            open_subtrees.push(size);
            size
        })
        .collect();

    assert_eq!(open_subtrees.len(), 1, "invalid postfix (did not reduce to one root)");
    sizes
}

/// Inclusive index range `(start, end)` of the subtree rooted at `root_idx`.
///
/// In postfix order a subtree is a contiguous run of nodes that ends at its root, so
/// `end == root_idx` and `start == root_idx + 1 - subtree_sizes[root_idx]`. Both ends
/// are inclusive: the subtree is `nodes[start..=end]`.
///
/// `subtree_sizes` is expected to come from `subtree_sizes(nodes)`.
///
/// # Panics
///
/// * If `root_idx` is out of bounds for `subtree_sizes`.
/// * If the recorded size exceeds `root_idx + 1`. This can only happen when
///   `subtree_sizes` does not match a valid postfix sequence. The subtraction is checked
///   explicitly so release builds fail loudly instead of wrapping to a bogus `start`.
pub fn subtree_range(subtree_sizes: &[usize], root_idx: usize) -> (usize, usize) {
    let sz = subtree_sizes[root_idx];
    let start = (root_idx + 1)
        .checked_sub(sz)
        .expect("invalid subtree size (larger than root index + 1)");
    (start, root_idx)
}