dynamic_expressions 0.10.0

Fast batched evaluation + forward-mode derivatives for symbolic expressions (Rust port of DynamicExpressions.jl).
Documentation
use std::collections::HashSet;
use std::fmt;

use crate::traits::{LookupError, OpId, OperatorSet};

/// A selection of operators grouped by arity.
///
/// Bucket `k` of `ops_by_arity` holds the operators whose arity is `k + 1`, so the array
/// covers arities `1..=D`.
#[derive(Debug, Clone)]
pub struct Operators<const D: usize> {
    /// One bucket per arity; index `k` stores operators of arity `k + 1`.
    pub ops_by_arity: [Vec<OpId>; D],
}

/// Reasons why [`Operators::from_names`] can fail to build an operator selection.
#[derive(Clone, Debug)]
pub enum OperatorSelectError {
    /// The operator set failed to resolve a token.
    Lookup(LookupError),
    /// A token resolved to an operator whose arity is larger than the container's `D`.
    ArityTooLarge {
        /// The token as supplied by the caller.
        token: String,
        /// Arity of the resolved operator.
        arity: u8,
        /// The container's maximum supported arity (`D`).
        max_arity: usize,
    },
    /// The same operator (same arity and id) was requested twice; holds the repeated token.
    Duplicate(String),
    /// The caller supplied no tokens at all.
    Empty,
}

impl fmt::Display for OperatorSelectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OperatorSelectError::Lookup(e) => write!(f, "{e:?}"),
            OperatorSelectError::ArityTooLarge {
                token,
                arity,
                max_arity,
            } => write!(
                f,
                "operator token {token:?} has arity={arity} which exceeds D={max_arity}"
            ),
            OperatorSelectError::Duplicate(tok) => write!(f, "duplicate operator token {tok:?}"),
            OperatorSelectError::Empty => write!(f, "no operators provided"),
        }
    }
}

impl std::error::Error for OperatorSelectError {}

impl<const D: usize> Operators<D> {
    /// Creates an empty selection with one (empty) bucket per arity `1..=D`.
    pub fn new() -> Self {
        Self {
            ops_by_arity: std::array::from_fn(|_| Vec::new()),
        }
    }

    /// Appends `op` to the bucket matching its arity.
    ///
    /// # Panics
    ///
    /// Panics if `op.arity` is not in `1..=D`.
    pub fn push(&mut self, op: OpId) {
        let arity = op.arity as usize;
        assert!(
            (1..=D).contains(&arity),
            "operator arity {arity} is out of range 1..={max}",
            max = D
        );
        self.ops_by_arity[arity - 1].push(op);
    }

    /// Returns how many operators of the given `arity` are in the selection.
    ///
    /// # Panics
    ///
    /// Panics if `arity` is not in `1..=D`. The explicit check replaces an opaque
    /// `arity - 1` underflow (or a wrapped out-of-bounds index in release builds) when
    /// `arity == 0`.
    pub fn nops(&self, arity: usize) -> usize {
        assert!(
            (1..=D).contains(&arity),
            "arity {arity} is out of range 1..={max}",
            max = D
        );
        self.ops_by_arity[arity - 1].len()
    }

    /// Builds a selection by resolving each token in `names` through `Ops::lookup`.
    ///
    /// Operators are pushed in iteration order into the bucket for their arity.
    ///
    /// # Errors
    ///
    /// - [`OperatorSelectError::Empty`] if `names` yields nothing.
    /// - [`OperatorSelectError::Lookup`] if `Ops::lookup` rejects a token.
    /// - [`OperatorSelectError::ArityTooLarge`] if a token resolves to an operator with
    ///   arity greater than `D`.
    /// - [`OperatorSelectError::Duplicate`] if two tokens resolve to the same
    ///   `(arity, id)` pair; the second token is reported.
    pub fn from_names<Ops, I>(names: I) -> Result<Self, OperatorSelectError>
    where
        Ops: OperatorSet,
        I: IntoIterator,
        I::Item: AsRef<str>,
    {
        // Peek so an empty input is reported as `Empty` instead of producing an empty set.
        let mut iter = names.into_iter().peekable();
        if iter.peek().is_none() {
            return Err(OperatorSelectError::Empty);
        }

        let mut out = Self::new();
        // An operator's identity is its (arity, id) pair; tokens resolving to the same pair
        // are treated as duplicates.
        let mut seen: HashSet<(u8, u16)> = HashSet::new();

        for tok in iter {
            let tok = tok.as_ref();
            let op = Ops::lookup(tok).map_err(OperatorSelectError::Lookup)?;
            // Reject here with a recoverable error instead of letting `push` panic.
            if (op.arity as usize) > D {
                return Err(OperatorSelectError::ArityTooLarge {
                    token: tok.to_string(),
                    arity: op.arity,
                    max_arity: D,
                });
            }
            let key = (op.arity, op.id);
            if !seen.insert(key) {
                return Err(OperatorSelectError::Duplicate(tok.to_string()));
            }
            out.push(op);
        }

        Ok(out)
    }
}

impl<const D: usize> Default for Operators<D> {
    fn default() -> Self {
        Self::new()
    }
}

// Unit tests for operator selection by name, using both the builtin f64 preset and a
// custom operator set declared with the crate's DSL macros.
#[cfg(test)]
mod tests {
    use crate::operator_enum::builtin;
    use crate::operator_enum::presets::BuiltinOpsF64;
    use crate::{HasOp, Operators};

    // Custom unary operator declared via the crate's `op!` DSL: evaluates to x * x and
    // reports 2 * x as its partial derivative (the index argument is ignored).
    crate::op!(Square for f64 {
        eval: |[x]| { x * x },
        partial: |[x], _idx| { 2.0 * x },
    });

    // Operator set containing only `Square`, declared via the `opset!` DSL.
    crate::opset! {
        V2Ops for f64 { Square }
    }

    // A bare "-" must be rejected (presumably because it is ambiguous between unary negation
    // and binary subtraction — NOTE(review): confirm against the lookup implementation).
    // Spelling out "neg" and "sub" must place Neg in the arity-1 bucket and Sub in the
    // arity-2 bucket.
    #[test]
    fn from_names_requires_disambiguation_for_dash() {
        let ops: Result<Operators<3>, _> = BuiltinOpsF64::from_names(["-"]);
        assert!(ops.is_err());

        let ops: Operators<3> = BuiltinOpsF64::from_names(["neg", "sub"]).unwrap();
        assert_eq!(ops.nops(1), 1);
        assert_eq!(ops.ops_by_arity[0][0], <BuiltinOpsF64 as HasOp<builtin::Neg>>::op_id());
        assert_eq!(ops.nops(2), 1);
        assert_eq!(ops.ops_by_arity[1][0], <BuiltinOpsF64 as HasOp<builtin::Sub>>::op_id());
    }

    // The DSL-declared set resolves the token "square" (presumably derived from the `Square`
    // type name by the macro — TODO confirm) to a single unary operator.
    #[test]
    fn v2_opset_dsl_builds_operator_set() {
        let ops: Operators<3> = V2Ops::from_names(["square"]).unwrap();
        assert_eq!(ops.nops(1), 1);
    }
}