// dynamic_expressions/traits.rs

use std::fmt;

use num_traits::Float;
2
3pub trait Operator<T: Float, const A: usize> {
4    const NAME: &'static str;
5    const DISPLAY: &'static str = Self::NAME;
6    const INFIX: Option<&'static str> = None;
7    const ALIASES: &'static [&'static str] = &[];
8    const COMMUTATIVE: bool = false;
9    const ASSOCIATIVE: bool = false;
10    const COMPLEXITY: u16 = 1;
11
12    fn eval(args: &[T; A]) -> T;
13    fn partial(args: &[T; A], idx: usize) -> T;
14}
15
/// Compact handle naming one operator of an operator set: its arity plus its
/// index among the set's operators of that arity.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct OpId {
    /// Number of arguments the operator takes.
    pub arity: u8,
    /// Index of the operator within its arity.
    pub id: u16,
}
21
/// Static metadata describing a single operator.
///
/// An operator-set implementation keeps one `&'static [OpMeta]` table per arity
/// and resolves an [`OpId`] with `META.get(op.id as usize)`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct OpMeta {
    /// Arity of the described operator.
    pub arity: u8,
    /// Position of this entry within its arity's table.
    pub id: u16,
    /// Canonical operator name.
    pub name: &'static str,
    /// Text used when rendering the operator.
    pub display: &'static str,
    /// Infix token, if the operator has one.
    pub infix: Option<&'static str>,
    /// Alternative tokens that also name the operator.
    pub aliases: &'static [&'static str],
    /// Commutativity flag.
    pub commutative: bool,
    /// Associativity flag.
    pub associative: bool,
    /// Complexity weight.
    pub complexity: u16,
}
36
/// Gives an operator marker type an explicit, fixed arity.
///
/// The compiler generally cannot infer `A` from a bound like `Tag: Operator<T, A>`,
/// so marker types declare their arity through this trait instead.
pub trait OpTag {
    /// Number of arguments taken by the tagged operator.
    const ARITY: u8;
}
43
44pub trait HasOp<Tag: OpTag> {
45    const ID: u16;
46
47    #[inline]
48    fn op_id() -> OpId {
49        OpId {
50            arity: Tag::ARITY,
51            id: Self::ID,
52        }
53    }
54}
55
/// Error returned when an operator token does not resolve to exactly one operator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LookupError {
    /// No operator matched the (trimmed) token.
    Unknown(String),
    /// More than one operator matched the (trimmed) token.
    Ambiguous {
        /// The trimmed token that was looked up.
        token: String,
        /// Canonical names of every operator that matched.
        candidates: Vec<&'static str>,
    },
}

impl fmt::Display for LookupError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Unknown(token) => write!(f, "unknown operator `{token}`"),
            Self::Ambiguous { token, candidates } => write!(
                f,
                "ambiguous operator `{token}`: matches {}",
                candidates.join(", ")
            ),
        }
    }
}

// Lets `LookupError` flow through `?` into `Box<dyn Error>` and other error plumbing.
impl std::error::Error for LookupError {}
64
65/// Operator-set abstraction.
66///
67/// - "What ops exist?" -> `ops_with_arity`
68/// - "What are their tokens / properties?" -> `meta` (plus default helpers)
69/// - "How do I eval/diff/grad?" -> `eval/diff/grad` (dispatch to kernels)
70pub trait OperatorSet: Sized {
71    type T: Float;
72
73    const MAX_ARITY: u8;
74
75    fn ops_with_arity(arity: u8) -> &'static [u16];
76
77    /// The only required metadata entrypoint.
78    fn meta(op: OpId) -> Option<&'static OpMeta>;
79
80    // ---- Convenience defaults derived from meta() ----
81
82    #[inline]
83    fn name(op: OpId) -> &'static str {
84        Self::meta(op).map(|m| m.name).unwrap_or("unknown_op")
85    }
86
87    #[inline]
88    fn display(op: OpId) -> &'static str {
89        Self::meta(op).map(|m| m.display).unwrap_or("unknown_op")
90    }
91
92    #[inline]
93    fn infix(op: OpId) -> Option<&'static str> {
94        Self::meta(op).and_then(|m| m.infix)
95    }
96
97    #[inline]
98    fn commutative(op: OpId) -> bool {
99        Self::meta(op).is_some_and(|m| m.commutative)
100    }
101
102    #[inline]
103    fn associative(op: OpId) -> bool {
104        Self::meta(op).is_some_and(|m| m.associative)
105    }
106
107    #[inline]
108    fn complexity(op: OpId) -> u16 {
109        Self::meta(op).map(|m| m.complexity).unwrap_or(0)
110    }
111
112    // ---- Kernel dispatch ----
113
114    fn eval(op: OpId, ctx: crate::dispatch::EvalKernelCtx<'_, '_, Self::T>) -> bool;
115    fn diff(op: OpId, ctx: crate::dispatch::DiffKernelCtx<'_, '_, Self::T>) -> bool;
116    fn grad(op: OpId, ctx: crate::dispatch::GradKernelCtx<'_, '_, Self::T>) -> bool;
117
118    // ---- Token lookup ----
119
120    #[inline]
121    fn matches_token(op: OpId, tok: &str) -> bool {
122        let t = tok.trim();
123        let Some(meta) = Self::meta(op) else {
124            return false;
125        };
126
127        t.eq_ignore_ascii_case(meta.name)
128            || t == meta.display
129            || meta.infix.is_some_and(|s| t == s)
130            || meta.aliases.iter().any(|a| t.eq_ignore_ascii_case(a))
131    }
132
133    #[inline]
134    fn for_each_op(mut f: impl FnMut(OpId)) {
135        for arity in 1..=Self::MAX_ARITY {
136            for &id in Self::ops_with_arity(arity) {
137                f(OpId { arity, id });
138            }
139        }
140    }
141
142    fn lookup_all(token: &str) -> Vec<OpId> {
143        let mut out = Vec::new();
144        Self::for_each_op(|op| {
145            if Self::matches_token(op, token) {
146                out.push(op);
147            }
148        });
149        out
150    }
151
152    fn lookup(token: &str) -> Result<OpId, LookupError> {
153        let matches = Self::lookup_all(token);
154        match matches.as_slice() {
155            [] => Err(LookupError::Unknown(token.trim().to_string())),
156            [single] => Ok(*single),
157            _ => Err(LookupError::Ambiguous {
158                token: token.trim().to_string(),
159                candidates: matches.iter().map(|op| Self::name(*op)).collect(),
160            }),
161        }
162    }
163
164    fn lookup_with_arity(token: &str, arity: u8) -> Result<OpId, LookupError> {
165        let mut matches = Vec::new();
166        for &id in Self::ops_with_arity(arity) {
167            let op = OpId { arity, id };
168            if Self::matches_token(op, token) {
169                matches.push(op);
170            }
171        }
172        match matches.as_slice() {
173            [] => Err(LookupError::Unknown(token.trim().to_string())),
174            [single] => Ok(*single),
175            _ => Err(LookupError::Ambiguous {
176                token: token.trim().to_string(),
177                candidates: matches.iter().map(|op| Self::name(*op)).collect(),
178            }),
179        }
180    }
181}