geo_aid_math/
lib.rs

1mod compiler;
2pub mod shared;
3
/// Floating point type used by the whole crate: `f32` by default, `f64` when the `f64`
/// feature is enabled.
#[cfg(not(feature = "f64"))]
pub type Float = f32;
/// Floating point type used by the whole crate: `f64`, because the `f64` feature is enabled
/// (`f32` otherwise).
#[cfg(feature = "f64")]
pub type Float = f64;
9
/// Handle to an expression registered in a [`Context`].
///
/// It is a plain index into the context's expression arena, so it is cheap to copy, compare
/// and hash. It is only meaningful for the `Context` that produced it.
#[must_use]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Expr(usize);
14
15/// A comparison between two expressions.
16#[derive(Debug, Clone, Copy)]
17pub struct Comparison {
18    pub a: Expr,
19    pub b: Expr,
20    pub kind: ComparisonKind,
21}
22
/// The primitive relation tested by a [`Comparison`].
#[derive(Clone, Copy, Debug)]
pub enum ComparisonKind {
    /// `a = b`
    Eq,
    /// `a != b`
    Neq,
    /// `a > b`
    Gt,
    /// `a ≥ b`
    Gteq,
}
31
32/// A condition for ternary operators
33#[derive(Debug, Clone, Copy)]
34pub enum Condition {
35    Comparison(Comparison),
36}
37
38/// The kind of an expression. Internal use only.
39#[derive(Debug, Clone, Copy)]
40enum ExprKind {
41    /// A constant number
42    Constant(Float),
43    /// a + b
44    Add(Expr, Expr),
45    /// a - b
46    Sub(Expr, Expr),
47    /// a * b
48    Mul(Expr, Expr),
49    /// a / b
50    Div(Expr, Expr),
51    /// An input at an index
52    Input(usize),
53    /// sine of a value
54    Sin(Expr),
55    /// cosine of a value
56    Cos(Expr),
57    /// atan2 function
58    Atan2(Expr, Expr),
59    /// -expr
60    Neg(Expr),
61    /// If condition is true, returns first expression. Otherwise returns the second one.
62    Ternary(Condition, Expr, Expr),
63    /// Square root
64    Sqrt(Expr),
65    /// e^expr
66    Exp(Expr),
67    /// Natural logarithm
68    Log(Expr),
69}
70
71/// An entry in the expression record.
72#[derive(Debug, Clone, Copy)]
73struct Entry {
74    /// Kind of this expression
75    kind: ExprKind,
76    /// The index of derivatives of this expression to the `derivatives` vector.
77    derivatives: Option<usize>,
78}
79
80/// Compilation context. Necessary for any expression manipulation and compilation.
81#[derive(Debug, Clone)]
82pub struct Context {
83    /// Input count.
84    inputs: usize,
85    /// Expression vector.
86    exprs: Vec<Entry>,
87    /// Derivatives w.r.t. respective inputs. Every `inputs` next entries are a set
88    /// of derivatives.
89    derivatives: Vec<Expr>,
90}
91
92impl Context {
93    /// Create a new context prepared to handle a given amount of inputs.
94    #[must_use]
95    pub fn new(inputs: usize) -> Self {
96        let mut exprs = Vec::new();
97        let mut derivatives = Vec::new();
98
99        // Due to it being potentially useful and popular, zero is added automatically.
100        exprs.push(Entry {
101            kind: ExprKind::Constant(0.0),
102            derivatives: Some(0),
103        });
104        derivatives.extend([Expr(0)].repeat(inputs));
105
106        // For inputs, we also include a one.
107        exprs.push(Entry {
108            kind: ExprKind::Constant(1.0),
109            derivatives: Some(0),
110        });
111
112        // For the same reason. Input exprs is also included.
113        for i in 0..inputs {
114            exprs.push(Entry {
115                kind: ExprKind::Input(i),
116                derivatives: Some(derivatives.len()),
117            });
118            derivatives.extend((0..inputs).map(|j| if i == j { Expr(1) } else { Expr(0) }));
119        }
120
121        Self {
122            inputs,
123            exprs,
124            derivatives,
125        }
126    }
127
128    /// Print the expression as a string. Use only for debugging, brace for long output.
129    pub fn stringify(&self, expr: Expr) -> String {
130        self.stringify_kind(self.exprs[expr.0].kind)
131    }
132
133    /// Helper for the [`stringify`] function
134    fn stringify_kind(&self, expr_kind: ExprKind) -> String {
135        match expr_kind {
136            ExprKind::Constant(v) => format!("{v:.2}"),
137            ExprKind::Add(a, b) => {
138                format!("({} + {})", self.stringify(a), self.stringify(b))
139            }
140            ExprKind::Sub(a, b) => {
141                format!("({} - {})", self.stringify(a), self.stringify(b))
142            }
143            ExprKind::Mul(a, b) => {
144                format!("({} * {})", self.stringify(a), self.stringify(b))
145            }
146            ExprKind::Div(a, b) => {
147                format!("({} / {})", self.stringify(a), self.stringify(b))
148            }
149            ExprKind::Input(i) => format!("#{i}"),
150            ExprKind::Sin(v) => format!("sin({})", self.stringify(v)),
151            ExprKind::Cos(v) => format!("cos({})", self.stringify(v)),
152            ExprKind::Atan2(y, x) => format!("atan2({}, {})", self.stringify(y), self.stringify(x)),
153            ExprKind::Neg(v) => format!("-{}", self.stringify(v)),
154            ExprKind::Ternary(cond, then, else_) => format!(
155                "({} ? {} : {})",
156                self.stringify_condition(cond),
157                self.stringify(then),
158                self.stringify(else_)
159            ),
160            ExprKind::Sqrt(v) => format!("sqrt({})", self.stringify(v)),
161            ExprKind::Exp(v) => format!("e^{}", self.stringify(v)),
162            ExprKind::Log(v) => format!("ln({})", self.stringify(v)),
163        }
164    }
165
166    /// Helper for the [`stringify`] function.
167    fn stringify_condition(&self, condition: Condition) -> String {
168        match condition {
169            Condition::Comparison(cmp) => {
170                let a = self.stringify(cmp.a);
171                let b = self.stringify(cmp.b);
172                let sign = match cmp.kind {
173                    ComparisonKind::Eq => "=",
174                    ComparisonKind::Neq => "!=",
175                    ComparisonKind::Gt => ">",
176                    ComparisonKind::Gteq => "≥",
177                };
178
179                format!("{a} {sign} {b}")
180            }
181        }
182    }
183
184    /// Push an expression without derivatives.
185    fn push_expr_nodiff(&mut self, kind: ExprKind) -> Expr {
186        let id = self.exprs.len();
187        self.exprs.push(Entry {
188            kind,
189            derivatives: None,
190        });
191        Expr(id)
192    }
193
194    /// Push an expression with its derivatives.
195    fn push_expr(&mut self, kind: ExprKind, derivatives: Vec<Expr>) -> Expr {
196        assert_eq!(self.inputs, derivatives.len());
197        let id = self.exprs.len();
198        self.exprs.push(Entry {
199            kind,
200            derivatives: Some(self.derivatives.len()),
201        });
202        self.derivatives.extend(derivatives);
203        Expr(id)
204    }
205
206    /// Get the derivative of `expr` w.r.t. `input`
207    ///
208    /// # Panics
209    /// If the expression is not a user-defined one (does not have a derivative).
210    fn get_derivative(&self, expr: Expr, input: usize) -> Expr {
211        self.derivatives[self.exprs[expr.0].derivatives.unwrap() + input]
212    }
213
214    /// Creates a zero.
215    pub fn zero() -> Expr {
216        Expr(0)
217    }
218
219    /// Creates a one.
220    pub fn one() -> Expr {
221        Expr(1)
222    }
223
224    /// Creates a constant value.
225    pub fn constant(&mut self, value: Float) -> Expr {
226        let kind = ExprKind::Constant(value);
227        let id = self.exprs.len();
228        self.exprs.push(Entry {
229            kind,
230            derivatives: Some(0),
231        });
232        Expr(id)
233    }
234
235    /// Adds two values.
236    pub fn add(&mut self, a: Expr, b: Expr) -> Expr {
237        // The derivative of `a + b` is the derivative of `a` plus the one of `b`.
238        let derivatives = (0..self.inputs)
239            .map(|i| {
240                let a_d = self.get_derivative(a, i);
241                let b_d = self.get_derivative(b, i);
242                self.push_expr_nodiff(ExprKind::Add(a_d, b_d))
243            })
244            .collect();
245        self.push_expr(ExprKind::Add(a, b), derivatives)
246    }
247
248    /// Subtracts two values.
249    pub fn sub(&mut self, a: Expr, b: Expr) -> Expr {
250        // The derivative of `a - b` is the derivative of `a` minus the one of `b`.
251        let derivatives = (0..self.inputs)
252            .map(|i| {
253                let a_d = self.get_derivative(a, i);
254                let b_d = self.get_derivative(b, i);
255                self.push_expr_nodiff(ExprKind::Sub(a_d, b_d))
256            })
257            .collect();
258        self.push_expr(ExprKind::Sub(a, b), derivatives)
259    }
260
261    /// Multiplies two values.
262    pub fn mul(&mut self, a: Expr, b: Expr) -> Expr {
263        // `d(a*b) = da*b + a*db`
264        let derivatives = (0..self.inputs)
265            .map(|i| {
266                let a_d = self.get_derivative(a, i);
267                let b_d = self.get_derivative(b, i);
268                let first = self.push_expr_nodiff(ExprKind::Mul(a_d, b));
269                let second = self.push_expr_nodiff(ExprKind::Mul(a, b_d));
270                self.push_expr_nodiff(ExprKind::Add(first, second))
271            })
272            .collect();
273        self.push_expr(ExprKind::Mul(a, b), derivatives)
274    }
275
276    /// Divides two values.
277    pub fn div(&mut self, a: Expr, b: Expr) -> Expr {
278        // `d(a/b) = (da*b - a*db)/(b*b)`
279        let derivatives = (0..self.inputs)
280            .map(|i| {
281                let a_d = self.get_derivative(a, i);
282                let b_d = self.get_derivative(b, i);
283                let first = self.push_expr_nodiff(ExprKind::Mul(a_d, b));
284                let second = self.push_expr_nodiff(ExprKind::Mul(a, b_d));
285                let diff = self.push_expr_nodiff(ExprKind::Sub(first, second));
286                let b_squared = self.push_expr_nodiff(ExprKind::Mul(b, b));
287                self.push_expr_nodiff(ExprKind::Div(diff, b_squared))
288            })
289            .collect();
290        self.push_expr(ExprKind::Div(a, b), derivatives)
291    }
292
293    /// The i-th input.
294    ///
295    /// # Panics
296    /// If the input is out of bounds
297    pub fn input(&self, input: usize) -> Expr {
298        assert!(input < self.inputs);
299        Expr(2 + input)
300    }
301
302    /// Calculates the sine of a value.
303    pub fn sin(&mut self, v: Expr) -> Expr {
304        // `dsin(v) = cos(v) * dv`
305        let derivatives = (0..self.inputs)
306            .map(|i| {
307                let dv = self.get_derivative(v, i);
308                let cos = self.push_expr_nodiff(ExprKind::Cos(v));
309                self.push_expr_nodiff(ExprKind::Mul(cos, dv))
310            })
311            .collect();
312        self.push_expr(ExprKind::Sin(v), derivatives)
313    }
314
315    /// Calculates the cosine of a value.
316    pub fn cos(&mut self, v: Expr) -> Expr {
317        // `dcos(v) = -sin(v) * dv`
318        let derivatives = (0..self.inputs)
319            .map(|i| {
320                let dv = self.get_derivative(v, i);
321                let sin = self.push_expr_nodiff(ExprKind::Sin(v));
322                let minus_sin = self.push_expr_nodiff(ExprKind::Neg(sin));
323                self.push_expr_nodiff(ExprKind::Mul(minus_sin, dv))
324            })
325            .collect();
326        self.push_expr(ExprKind::Cos(v), derivatives)
327    }
328
329    /// Square root of a number
330    pub fn sqrt(&mut self, v: Expr) -> Expr {
331        // dsqrt(v) = dv / 2sqrt(v)
332        let derivatives = (0..self.inputs)
333            .map(|i| {
334                let dv = self.get_derivative(v, i);
335                let sqrt = self.push_expr_nodiff(ExprKind::Sqrt(v));
336                let two = self.push_expr_nodiff(ExprKind::Constant(2.0));
337                let two_sqrt = self.push_expr_nodiff(ExprKind::Mul(two, sqrt));
338                self.push_expr_nodiff(ExprKind::Div(dv, two_sqrt))
339            })
340            .collect();
341        self.push_expr(ExprKind::Sqrt(v), derivatives)
342    }
343
344    /// Calculates e^expr
345    pub fn exp(&mut self, v: Expr) -> Expr {
346        // dexp(v) = dv*exp(v)
347        let derivatives = (0..self.inputs)
348            .map(|i| {
349                let dv = self.get_derivative(v, i);
350                let expv = self.push_expr_nodiff(ExprKind::Exp(v));
351                self.push_expr_nodiff(ExprKind::Mul(expv, dv))
352            })
353            .collect();
354        self.push_expr(ExprKind::Exp(v), derivatives)
355    }
356
357    /// Natural logarithm of expression
358    pub fn log(&mut self, v: Expr) -> Expr {
359        // dlog(v) = dv/v
360        let derivatives = (0..self.inputs)
361            .map(|i| {
362                let dv = self.get_derivative(v, i);
363                self.push_expr_nodiff(ExprKind::Div(dv, v))
364            })
365            .collect();
366        self.push_expr(ExprKind::Log(v), derivatives)
367    }
368
369    /// Calculates the atan2 of two value.
370    pub fn atan2(&mut self, y: Expr, x: Expr) -> Expr {
371        // `datan2(v) = (x*dy - y*dx)/(y^2 + x^2)` In theory, at least. We'll see if that works.
372        let derivatives = (0..self.inputs)
373            .map(|i| {
374                let dy = self.get_derivative(y, i);
375                let dx = self.get_derivative(x, i);
376                let x_dy = self.push_expr_nodiff(ExprKind::Mul(x, dy));
377                let y_dx = self.push_expr_nodiff(ExprKind::Mul(y, dx));
378                let x2 = self.push_expr_nodiff(ExprKind::Mul(x, x));
379                let y2 = self.push_expr_nodiff(ExprKind::Mul(y, y));
380                let x2_plus_y2 = self.push_expr_nodiff(ExprKind::Add(x2, y2));
381                let xdy_minus_ydx = self.push_expr_nodiff(ExprKind::Sub(x_dy, y_dx));
382                self.push_expr_nodiff(ExprKind::Div(xdy_minus_ydx, x2_plus_y2))
383            })
384            .collect();
385        self.push_expr(ExprKind::Atan2(y, x), derivatives)
386    }
387
388    /// Negates the value.
389    pub fn neg(&mut self, v: Expr) -> Expr {
390        // `d(-v) = -dv`
391        let derivatives = (0..self.inputs)
392            .map(|i| {
393                let dv = self.get_derivative(v, i);
394                self.push_expr_nodiff(ExprKind::Neg(dv))
395            })
396            .collect();
397        self.push_expr(ExprKind::Neg(v), derivatives)
398    }
399
400    /// Gets the minimum value.
401    pub fn min(&mut self, a: Expr, b: Expr) -> Expr {
402        self.ternary(
403            Condition::Comparison(Comparison {
404                a: b,
405                b: a,
406                kind: ComparisonKind::Gt,
407            }),
408            a,
409            b,
410        )
411    }
412
413    /// Takes the absolute value.
414    pub fn abs(&mut self, v: Expr) -> Expr {
415        let cond = Condition::Comparison(Comparison {
416            a: v,
417            b: Self::zero(),
418            kind: ComparisonKind::Gt,
419        });
420
421        // `dabs(a) = if a > 0 da else -da`
422        let derivatives = (0..self.inputs)
423            .map(|i| {
424                let dv = self.get_derivative(v, i);
425                let minus_dv = self.push_expr_nodiff(ExprKind::Neg(dv));
426                self.push_expr_nodiff(ExprKind::Ternary(cond, dv, minus_dv))
427            })
428            .collect();
429        let minus_v = self.push_expr_nodiff(ExprKind::Neg(v));
430        self.push_expr(ExprKind::Ternary(cond, v, minus_v), derivatives)
431    }
432
433    /// A ternary expression
434    pub fn ternary(&mut self, condition: Condition, then: Expr, else_: Expr) -> Expr {
435        // dternary(then, else_) = if condition { dthen } else { delse_ }
436        let derivatives = (0..self.inputs)
437            .map(|i| {
438                let dthen = self.get_derivative(then, i);
439                let delse = self.get_derivative(else_, i);
440                self.push_expr_nodiff(ExprKind::Ternary(condition, dthen, delse))
441            })
442            .collect();
443        self.push_expr(ExprKind::Ternary(condition, then, else_), derivatives)
444    }
445}
446
447impl Context {
448    /// Returns a function computing the given expressions.
449    pub fn compute(&self, exprs: impl IntoIterator<Item = Expr>) -> Func {
450        Func {
451            func: compiler::compile(self, exprs),
452        }
453    }
454
455    /// Returns a function computing the gradient for the given expression.
456    pub fn compute_gradient(&self, expr: Expr) -> Func {
457        let func = compiler::compile(self, (0..self.inputs).map(|i| self.get_derivative(expr, i)));
458
459        Func { func }
460    }
461
462    /// Returns the gradient expressions, one for each input.
463    pub fn gradient(&self, expr: Expr) -> Vec<Expr> {
464        (0..self.inputs)
465            .map(|i| self.get_derivative(expr, i))
466            .collect()
467    }
468}
469
470/// A callable function accepting inputs and outputs (as mutable reference)
471#[derive(Clone, Copy)]
472pub struct Func {
473    func: fn(*const Float, *mut Float),
474}
475
476impl Func {
477    /// Call this function with `inputs` and collect the outputs into `dst`
478    pub fn call(&self, inputs: &[Float], dst: &mut [Float]) {
479        (self.func)(inputs.as_ptr(), dst.as_mut_ptr());
480    }
481}
482
483unsafe impl Send for Func {}
484unsafe impl Sync for Func {}