1mod compiler;
2pub mod shared;
3
/// Floating-point type used for constants and for the buffers passed to
/// compiled functions. This is `f64` when the `f64` cargo feature is enabled.
#[cfg(feature = "f64")]
pub type Float = f64;
/// Floating-point type used for constants and for the buffers passed to
/// compiled functions. This is `f32` unless the `f64` cargo feature is enabled.
#[cfg(not(feature = "f64"))]
pub type Float = f32;

/// Opaque handle to an expression owned by a `Context`.
///
/// The wrapped value is the index of the expression in the context's
/// expression table. Handles are cheap to copy and compare by index.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Expr(usize);

15#[derive(Debug, Clone, Copy)]
17pub struct Comparison {
18 pub a: Expr,
19 pub b: Expr,
20 pub kind: ComparisonKind,
21}
22
/// Relation tested by a `Comparison` (`a <kind> b`).
#[derive(Debug, Clone, Copy)]
pub enum ComparisonKind {
    /// `a = b`
    Eq,
    /// `a != b`
    Neq,
    /// `a > b`
    Gt,
    /// `a ≥ b`
    Gteq,
}

32#[derive(Debug, Clone, Copy)]
34pub enum Condition {
35 Comparison(Comparison),
36}
37
38#[derive(Debug, Clone, Copy)]
40enum ExprKind {
41 Constant(Float),
43 Add(Expr, Expr),
45 Sub(Expr, Expr),
47 Mul(Expr, Expr),
49 Div(Expr, Expr),
51 Input(usize),
53 Sin(Expr),
55 Cos(Expr),
57 Atan2(Expr, Expr),
59 Neg(Expr),
61 Ternary(Condition, Expr, Expr),
63 Sqrt(Expr),
65 Exp(Expr),
67 Log(Expr),
69}
70
71#[derive(Debug, Clone, Copy)]
73struct Entry {
74 kind: ExprKind,
76 derivatives: Option<usize>,
78}
79
80#[derive(Debug, Clone)]
82pub struct Context {
83 inputs: usize,
85 exprs: Vec<Entry>,
87 derivatives: Vec<Expr>,
90}
91
92impl Context {
93 #[must_use]
95 pub fn new(inputs: usize) -> Self {
96 let mut exprs = Vec::new();
97 let mut derivatives = Vec::new();
98
99 exprs.push(Entry {
101 kind: ExprKind::Constant(0.0),
102 derivatives: Some(0),
103 });
104 derivatives.extend([Expr(0)].repeat(inputs));
105
106 exprs.push(Entry {
108 kind: ExprKind::Constant(1.0),
109 derivatives: Some(0),
110 });
111
112 for i in 0..inputs {
114 exprs.push(Entry {
115 kind: ExprKind::Input(i),
116 derivatives: Some(derivatives.len()),
117 });
118 derivatives.extend((0..inputs).map(|j| if i == j { Expr(1) } else { Expr(0) }));
119 }
120
121 Self {
122 inputs,
123 exprs,
124 derivatives,
125 }
126 }
127
128 pub fn stringify(&self, expr: Expr) -> String {
130 self.stringify_kind(self.exprs[expr.0].kind)
131 }
132
133 fn stringify_kind(&self, expr_kind: ExprKind) -> String {
135 match expr_kind {
136 ExprKind::Constant(v) => format!("{v:.2}"),
137 ExprKind::Add(a, b) => {
138 format!("({} + {})", self.stringify(a), self.stringify(b))
139 }
140 ExprKind::Sub(a, b) => {
141 format!("({} - {})", self.stringify(a), self.stringify(b))
142 }
143 ExprKind::Mul(a, b) => {
144 format!("({} * {})", self.stringify(a), self.stringify(b))
145 }
146 ExprKind::Div(a, b) => {
147 format!("({} / {})", self.stringify(a), self.stringify(b))
148 }
149 ExprKind::Input(i) => format!("#{i}"),
150 ExprKind::Sin(v) => format!("sin({})", self.stringify(v)),
151 ExprKind::Cos(v) => format!("cos({})", self.stringify(v)),
152 ExprKind::Atan2(y, x) => format!("atan2({}, {})", self.stringify(y), self.stringify(x)),
153 ExprKind::Neg(v) => format!("-{}", self.stringify(v)),
154 ExprKind::Ternary(cond, then, else_) => format!(
155 "({} ? {} : {})",
156 self.stringify_condition(cond),
157 self.stringify(then),
158 self.stringify(else_)
159 ),
160 ExprKind::Sqrt(v) => format!("sqrt({})", self.stringify(v)),
161 ExprKind::Exp(v) => format!("e^{}", self.stringify(v)),
162 ExprKind::Log(v) => format!("ln({})", self.stringify(v)),
163 }
164 }
165
166 fn stringify_condition(&self, condition: Condition) -> String {
168 match condition {
169 Condition::Comparison(cmp) => {
170 let a = self.stringify(cmp.a);
171 let b = self.stringify(cmp.b);
172 let sign = match cmp.kind {
173 ComparisonKind::Eq => "=",
174 ComparisonKind::Neq => "!=",
175 ComparisonKind::Gt => ">",
176 ComparisonKind::Gteq => "≥",
177 };
178
179 format!("{a} {sign} {b}")
180 }
181 }
182 }
183
184 fn push_expr_nodiff(&mut self, kind: ExprKind) -> Expr {
186 let id = self.exprs.len();
187 self.exprs.push(Entry {
188 kind,
189 derivatives: None,
190 });
191 Expr(id)
192 }
193
194 fn push_expr(&mut self, kind: ExprKind, derivatives: Vec<Expr>) -> Expr {
196 assert_eq!(self.inputs, derivatives.len());
197 let id = self.exprs.len();
198 self.exprs.push(Entry {
199 kind,
200 derivatives: Some(self.derivatives.len()),
201 });
202 self.derivatives.extend(derivatives);
203 Expr(id)
204 }
205
206 fn get_derivative(&self, expr: Expr, input: usize) -> Expr {
211 self.derivatives[self.exprs[expr.0].derivatives.unwrap() + input]
212 }
213
214 pub fn zero() -> Expr {
216 Expr(0)
217 }
218
219 pub fn one() -> Expr {
221 Expr(1)
222 }
223
224 pub fn constant(&mut self, value: Float) -> Expr {
226 let kind = ExprKind::Constant(value);
227 let id = self.exprs.len();
228 self.exprs.push(Entry {
229 kind,
230 derivatives: Some(0),
231 });
232 Expr(id)
233 }
234
235 pub fn add(&mut self, a: Expr, b: Expr) -> Expr {
237 let derivatives = (0..self.inputs)
239 .map(|i| {
240 let a_d = self.get_derivative(a, i);
241 let b_d = self.get_derivative(b, i);
242 self.push_expr_nodiff(ExprKind::Add(a_d, b_d))
243 })
244 .collect();
245 self.push_expr(ExprKind::Add(a, b), derivatives)
246 }
247
248 pub fn sub(&mut self, a: Expr, b: Expr) -> Expr {
250 let derivatives = (0..self.inputs)
252 .map(|i| {
253 let a_d = self.get_derivative(a, i);
254 let b_d = self.get_derivative(b, i);
255 self.push_expr_nodiff(ExprKind::Sub(a_d, b_d))
256 })
257 .collect();
258 self.push_expr(ExprKind::Sub(a, b), derivatives)
259 }
260
261 pub fn mul(&mut self, a: Expr, b: Expr) -> Expr {
263 let derivatives = (0..self.inputs)
265 .map(|i| {
266 let a_d = self.get_derivative(a, i);
267 let b_d = self.get_derivative(b, i);
268 let first = self.push_expr_nodiff(ExprKind::Mul(a_d, b));
269 let second = self.push_expr_nodiff(ExprKind::Mul(a, b_d));
270 self.push_expr_nodiff(ExprKind::Add(first, second))
271 })
272 .collect();
273 self.push_expr(ExprKind::Mul(a, b), derivatives)
274 }
275
276 pub fn div(&mut self, a: Expr, b: Expr) -> Expr {
278 let derivatives = (0..self.inputs)
280 .map(|i| {
281 let a_d = self.get_derivative(a, i);
282 let b_d = self.get_derivative(b, i);
283 let first = self.push_expr_nodiff(ExprKind::Mul(a_d, b));
284 let second = self.push_expr_nodiff(ExprKind::Mul(a, b_d));
285 let diff = self.push_expr_nodiff(ExprKind::Sub(first, second));
286 let b_squared = self.push_expr_nodiff(ExprKind::Mul(b, b));
287 self.push_expr_nodiff(ExprKind::Div(diff, b_squared))
288 })
289 .collect();
290 self.push_expr(ExprKind::Div(a, b), derivatives)
291 }
292
293 pub fn input(&self, input: usize) -> Expr {
298 assert!(input < self.inputs);
299 Expr(2 + input)
300 }
301
302 pub fn sin(&mut self, v: Expr) -> Expr {
304 let derivatives = (0..self.inputs)
306 .map(|i| {
307 let dv = self.get_derivative(v, i);
308 let cos = self.push_expr_nodiff(ExprKind::Cos(v));
309 self.push_expr_nodiff(ExprKind::Mul(cos, dv))
310 })
311 .collect();
312 self.push_expr(ExprKind::Sin(v), derivatives)
313 }
314
315 pub fn cos(&mut self, v: Expr) -> Expr {
317 let derivatives = (0..self.inputs)
319 .map(|i| {
320 let dv = self.get_derivative(v, i);
321 let sin = self.push_expr_nodiff(ExprKind::Sin(v));
322 let minus_sin = self.push_expr_nodiff(ExprKind::Neg(sin));
323 self.push_expr_nodiff(ExprKind::Mul(minus_sin, dv))
324 })
325 .collect();
326 self.push_expr(ExprKind::Cos(v), derivatives)
327 }
328
329 pub fn sqrt(&mut self, v: Expr) -> Expr {
331 let derivatives = (0..self.inputs)
333 .map(|i| {
334 let dv = self.get_derivative(v, i);
335 let sqrt = self.push_expr_nodiff(ExprKind::Sqrt(v));
336 let two = self.push_expr_nodiff(ExprKind::Constant(2.0));
337 let two_sqrt = self.push_expr_nodiff(ExprKind::Mul(two, sqrt));
338 self.push_expr_nodiff(ExprKind::Div(dv, two_sqrt))
339 })
340 .collect();
341 self.push_expr(ExprKind::Sqrt(v), derivatives)
342 }
343
344 pub fn exp(&mut self, v: Expr) -> Expr {
346 let derivatives = (0..self.inputs)
348 .map(|i| {
349 let dv = self.get_derivative(v, i);
350 let expv = self.push_expr_nodiff(ExprKind::Exp(v));
351 self.push_expr_nodiff(ExprKind::Mul(expv, dv))
352 })
353 .collect();
354 self.push_expr(ExprKind::Exp(v), derivatives)
355 }
356
357 pub fn log(&mut self, v: Expr) -> Expr {
359 let derivatives = (0..self.inputs)
361 .map(|i| {
362 let dv = self.get_derivative(v, i);
363 self.push_expr_nodiff(ExprKind::Div(dv, v))
364 })
365 .collect();
366 self.push_expr(ExprKind::Log(v), derivatives)
367 }
368
369 pub fn atan2(&mut self, y: Expr, x: Expr) -> Expr {
371 let derivatives = (0..self.inputs)
373 .map(|i| {
374 let dy = self.get_derivative(y, i);
375 let dx = self.get_derivative(x, i);
376 let x_dy = self.push_expr_nodiff(ExprKind::Mul(x, dy));
377 let y_dx = self.push_expr_nodiff(ExprKind::Mul(y, dx));
378 let x2 = self.push_expr_nodiff(ExprKind::Mul(x, x));
379 let y2 = self.push_expr_nodiff(ExprKind::Mul(y, y));
380 let x2_plus_y2 = self.push_expr_nodiff(ExprKind::Add(x2, y2));
381 let xdy_minus_ydx = self.push_expr_nodiff(ExprKind::Sub(x_dy, y_dx));
382 self.push_expr_nodiff(ExprKind::Div(xdy_minus_ydx, x2_plus_y2))
383 })
384 .collect();
385 self.push_expr(ExprKind::Atan2(y, x), derivatives)
386 }
387
388 pub fn neg(&mut self, v: Expr) -> Expr {
390 let derivatives = (0..self.inputs)
392 .map(|i| {
393 let dv = self.get_derivative(v, i);
394 self.push_expr_nodiff(ExprKind::Neg(dv))
395 })
396 .collect();
397 self.push_expr(ExprKind::Neg(v), derivatives)
398 }
399
400 pub fn min(&mut self, a: Expr, b: Expr) -> Expr {
402 self.ternary(
403 Condition::Comparison(Comparison {
404 a: b,
405 b: a,
406 kind: ComparisonKind::Gt,
407 }),
408 a,
409 b,
410 )
411 }
412
413 pub fn abs(&mut self, v: Expr) -> Expr {
415 let cond = Condition::Comparison(Comparison {
416 a: v,
417 b: Self::zero(),
418 kind: ComparisonKind::Gt,
419 });
420
421 let derivatives = (0..self.inputs)
423 .map(|i| {
424 let dv = self.get_derivative(v, i);
425 let minus_dv = self.push_expr_nodiff(ExprKind::Neg(dv));
426 self.push_expr_nodiff(ExprKind::Ternary(cond, dv, minus_dv))
427 })
428 .collect();
429 let minus_v = self.push_expr_nodiff(ExprKind::Neg(v));
430 self.push_expr(ExprKind::Ternary(cond, v, minus_v), derivatives)
431 }
432
433 pub fn ternary(&mut self, condition: Condition, then: Expr, else_: Expr) -> Expr {
435 let derivatives = (0..self.inputs)
437 .map(|i| {
438 let dthen = self.get_derivative(then, i);
439 let delse = self.get_derivative(else_, i);
440 self.push_expr_nodiff(ExprKind::Ternary(condition, dthen, delse))
441 })
442 .collect();
443 self.push_expr(ExprKind::Ternary(condition, then, else_), derivatives)
444 }
445}
446
447impl Context {
448 pub fn compute(&self, exprs: impl IntoIterator<Item = Expr>) -> Func {
450 Func {
451 func: compiler::compile(self, exprs),
452 }
453 }
454
455 pub fn compute_gradient(&self, expr: Expr) -> Func {
457 let func = compiler::compile(self, (0..self.inputs).map(|i| self.get_derivative(expr, i)));
458
459 Func { func }
460 }
461
462 pub fn gradient(&self, expr: Expr) -> Vec<Expr> {
464 (0..self.inputs)
465 .map(|i| self.get_derivative(expr, i))
466 .collect()
467 }
468}
469
470#[derive(Clone, Copy)]
472pub struct Func {
473 func: fn(*const Float, *mut Float),
474}
475
476impl Func {
477 pub fn call(&self, inputs: &[Float], dst: &mut [Float]) {
479 (self.func)(inputs.as_ptr(), dst.as_mut_ptr());
480 }
481}
482
483unsafe impl Send for Func {}
484unsafe impl Sync for Func {}