dynamic_expressions/
traits.rs1use num_traits::Float;
2
3pub trait Operator<T: Float, const A: usize> {
4 const NAME: &'static str;
5 const DISPLAY: &'static str = Self::NAME;
6 const INFIX: Option<&'static str> = None;
7 const ALIASES: &'static [&'static str] = &[];
8 const COMMUTATIVE: bool = false;
9 const ASSOCIATIVE: bool = false;
10 const COMPLEXITY: u16 = 1;
11
12 fn eval(args: &[T; A]) -> T;
13 fn partial(args: &[T; A], idx: usize) -> T;
14}
15
/// Compact operator identifier: an arity plus an id within that arity's
/// group of operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct OpId {
    /// Number of operands the operator takes.
    pub arity: u8,
    /// Identifier of the operator among the operators sharing `arity`.
    pub id: u16,
}
21
/// Runtime metadata describing a single operator.
///
/// Besides the `arity`/`id` pair, the fields correspond one-to-one to the
/// associated constants of the `Operator` trait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpMeta {
    /// Number of operands the operator takes.
    pub arity: u8,
    /// Identifier of the operator within its arity group.
    pub id: u16,
    /// Canonical name.
    pub name: &'static str,
    /// Text used when rendering the operator.
    pub display: &'static str,
    /// Infix symbol, if the operator has one.
    pub infix: Option<&'static str>,
    /// Additional accepted spellings.
    pub aliases: &'static [&'static str],
    /// Whether the operator is commutative.
    pub commutative: bool,
    /// Whether the operator is associative.
    pub associative: bool,
    /// Complexity weight attached to the operator.
    pub complexity: u16,
}
36
/// Tag type naming an operator; the tag fixes that operator's arity.
pub trait OpTag {
    /// Arity of the operator this tag identifies.
    const ARITY: u8;
}
43
44pub trait HasOp<Tag: OpTag> {
45 const ID: u16;
46
47 #[inline]
48 fn op_id() -> OpId {
49 OpId {
50 arity: Tag::ARITY,
51 id: Self::ID,
52 }
53 }
54}
55
/// Error returned when an operator token cannot be resolved to exactly one
/// operator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LookupError {
    /// No operator matched the token (stored trimmed).
    Unknown(String),
    /// More than one operator matched the token.
    Ambiguous {
        /// The token that was looked up, trimmed.
        token: String,
        /// Names of every operator that matched.
        candidates: Vec<&'static str>,
    },
}

// Public error type: implement `Display` + `Error` so it can be printed for
// users and propagated with `?` into `Box<dyn Error>` / other error enums.
impl std::fmt::Display for LookupError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LookupError::Unknown(token) => write!(f, "unknown operator `{}`", token),
            LookupError::Ambiguous { token, candidates } => write!(
                f,
                "ambiguous operator `{}` (candidates: {})",
                token,
                candidates.join(", ")
            ),
        }
    }
}

impl std::error::Error for LookupError {}
64
65pub trait OperatorSet: Sized {
71 type T: Float;
72
73 const MAX_ARITY: u8;
74
75 fn ops_with_arity(arity: u8) -> &'static [u16];
76
77 fn meta(op: OpId) -> Option<&'static OpMeta>;
79
80 #[inline]
83 fn name(op: OpId) -> &'static str {
84 Self::meta(op).map(|m| m.name).unwrap_or("unknown_op")
85 }
86
87 #[inline]
88 fn display(op: OpId) -> &'static str {
89 Self::meta(op).map(|m| m.display).unwrap_or("unknown_op")
90 }
91
92 #[inline]
93 fn infix(op: OpId) -> Option<&'static str> {
94 Self::meta(op).and_then(|m| m.infix)
95 }
96
97 #[inline]
98 fn commutative(op: OpId) -> bool {
99 Self::meta(op).is_some_and(|m| m.commutative)
100 }
101
102 #[inline]
103 fn associative(op: OpId) -> bool {
104 Self::meta(op).is_some_and(|m| m.associative)
105 }
106
107 #[inline]
108 fn complexity(op: OpId) -> u16 {
109 Self::meta(op).map(|m| m.complexity).unwrap_or(0)
110 }
111
112 fn eval(op: OpId, ctx: crate::dispatch::EvalKernelCtx<'_, '_, Self::T>) -> bool;
115 fn diff(op: OpId, ctx: crate::dispatch::DiffKernelCtx<'_, '_, Self::T>) -> bool;
116 fn grad(op: OpId, ctx: crate::dispatch::GradKernelCtx<'_, '_, Self::T>) -> bool;
117
118 #[inline]
121 fn matches_token(op: OpId, tok: &str) -> bool {
122 let t = tok.trim();
123 let Some(meta) = Self::meta(op) else {
124 return false;
125 };
126
127 t.eq_ignore_ascii_case(meta.name)
128 || t == meta.display
129 || meta.infix.is_some_and(|s| t == s)
130 || meta.aliases.iter().any(|a| t.eq_ignore_ascii_case(a))
131 }
132
133 #[inline]
134 fn for_each_op(mut f: impl FnMut(OpId)) {
135 for arity in 1..=Self::MAX_ARITY {
136 for &id in Self::ops_with_arity(arity) {
137 f(OpId { arity, id });
138 }
139 }
140 }
141
142 fn lookup_all(token: &str) -> Vec<OpId> {
143 let mut out = Vec::new();
144 Self::for_each_op(|op| {
145 if Self::matches_token(op, token) {
146 out.push(op);
147 }
148 });
149 out
150 }
151
152 fn lookup(token: &str) -> Result<OpId, LookupError> {
153 let matches = Self::lookup_all(token);
154 match matches.as_slice() {
155 [] => Err(LookupError::Unknown(token.trim().to_string())),
156 [single] => Ok(*single),
157 _ => Err(LookupError::Ambiguous {
158 token: token.trim().to_string(),
159 candidates: matches.iter().map(|op| Self::name(*op)).collect(),
160 }),
161 }
162 }
163
164 fn lookup_with_arity(token: &str, arity: u8) -> Result<OpId, LookupError> {
165 let mut matches = Vec::new();
166 for &id in Self::ops_with_arity(arity) {
167 let op = OpId { arity, id };
168 if Self::matches_token(op, token) {
169 matches.push(op);
170 }
171 }
172 match matches.as_slice() {
173 [] => Err(LookupError::Unknown(token.trim().to_string())),
174 [single] => Ok(*single),
175 _ => Err(LookupError::Ambiguous {
176 token: token.trim().to_string(),
177 candidates: matches.iter().map(|op| Self::name(*op)).collect(),
178 }),
179 }
180 }
181}