//! Solver for systems of equations expressed as sets of [Function] objects
#![warn(missing_docs)]
use fidget_core::{
    Error,
    eval::{BulkEvaluator, Function, Tape, TracingEvaluator},
    types::Grad,
    var::Var,
};
use std::collections::HashMap;
10
/// Input parameter to the solver
///
/// Parameters are either free (with a starting position for the solver to
/// optimize from) or fixed (held constant during solving).
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Parameter {
    /// Free variable with the given starting position
    Free(f32),
    /// Fixed variable at the given value
    Fixed(f32),
}
19
/// Workspace for solvers
///
/// Holds per-constraint tapes, evaluators, and scratch input buffers so
/// that repeated Jacobian and error evaluations can reuse allocations.
struct Solver<'a, F: Function> {
    /// Input parameters
    vars: &'a HashMap<Var, Parameter>,

    /// Tapes for bulk gradient evaluation of each constraint
    grad_tapes: Vec<<F::GradSliceEval as BulkEvaluator>::Tape>,

    /// Tapes for single-point evaluation of each constraint
    point_tapes: Vec<<F::PointEval as TracingEvaluator>::Tape>,

    /// Bulk gradient evaluator, for use in computing the Jacobian
    grad_eval: F::GradSliceEval,

    /// Single-point evaluator, for use in checking our current error
    point_eval: F::PointEval,

    /// Input data for use when calling the gradient bulk evaluator
    ///
    /// One row per variable; each row holds `grad_index.len().div_ceil(3)`
    /// samples, since each [`Grad`] sample carries three partial derivatives.
    input_grad: Vec<Vec<Grad>>,

    /// Input data for use when calling the single-point evaluator
    input_point: Vec<f32>,

    /// Map from (free) variables to the index of their gradient
    ///
    /// We evaluate 3x gradients per sample, so for `grad_index = gi`, the
    /// relevant derivative will be `out[gi / 3].d(gi % 3)`
    grad_index: HashMap<Var, usize>,
}
49
impl<'a, F: Function> Solver<'a, F> {
    /// Builds a solver workspace for the given equations and parameters
    fn new(eqs: &'a [F], vars: &'a HashMap<Var, Parameter>) -> Self {
        // Build our per-constraint tapes: one gradient tape (for the
        // Jacobian) and one single-point tape (for error checks) per equation
        let grad_tapes = eqs
            .iter()
            .map(|f| f.grad_slice_tape(Default::default()))
            .collect::<Vec<_>>();
        let point_tapes = eqs
            .iter()
            .map(|f| f.point_tape(Default::default()))
            .collect::<Vec<_>>();

        // Build a map from *free* variable to index of its gradient, since
        // we'll be using tightly-packed Vec everywhere here
        //
        // (We ignore the gradient of fixed variables)
        let grad_index: HashMap<Var, usize> = vars
            .iter()
            .filter(|(_v, p)| matches!(p, Parameter::Free(..)))
            .enumerate()
            .map(|(i, (v, _p))| (*v, i))
            .collect();

        // Build a scratch array with rows for each variable, and enough columns
        // to simultaneously compute all of the gradients that we need
        // (each Grad sample carries 3 derivatives, hence div_ceil(3))
        let input_grad =
            vec![
                vec![Grad::from(0f32); grad_index.len().div_ceil(3)];
                vars.len()
            ];
        let input_point = vec![0f32; vars.len()];

        Self {
            vars,
            grad_tapes,
            point_tapes,
            grad_eval: Default::default(),
            point_eval: Default::default(),
            grad_index,

            input_grad,
            input_point,
        }
    }

    /// Computes the Jacobian and residuals at the position given by `cur`
    ///
    /// The Jacobian (one row per constraint, one column per free variable)
    /// is written to `jacobian`; each constraint's current value is written
    /// to `result`.
    ///
    /// # Panics
    /// If `jacobian` or `result` are an invalid size
    fn get_jacobian(
        &mut self,
        cur: &[f32],
        jacobian: &mut nalgebra::DMatrix<f32>,
        result: &mut nalgebra::DVector<f32>,
    ) -> Result<(), Error> {
        for (ti, tape) in self.grad_tapes.iter().enumerate() {
            // Update the values in the gradient evaluation array
            for (v, p) in self.vars {
                // Skip variables that this particular tape doesn't use
                let Some(i) = tape.vars().get(v) else {
                    continue;
                };
                let Some(slice) = self.input_grad.get_mut(i) else {
                    return Err(Error::BadVarIndex(i, self.input_grad.len()));
                };
                match p {
                    Parameter::Free(..) => {
                        let gi = self.grad_index[v];
                        // Every sample gets the same value; the derivative
                        // seed is 1.0 only in the (sample, component) slot
                        // assigned to this variable, i.e. (gi / 3, gi % 3)
                        for (j, v) in slice.iter_mut().enumerate() {
                            *v = Grad::new(
                                cur[gi],
                                if j * 3 == gi { 1.0 } else { 0.0 },
                                if j * 3 + 1 == gi { 1.0 } else { 0.0 },
                                if j * 3 + 2 == gi { 1.0 } else { 0.0 },
                            );
                        }
                    }
                    Parameter::Fixed(f) => {
                        // Fixed variables have a constant value and zero
                        // derivative in every sample
                        slice.fill(Grad::new(*f, 0.0, 0.0, 0.0));
                    }
                };
            }
            // Do the actual gradient evaluation
            let out = self.grad_eval.eval(tape, &self.input_grad)?;

            // Populate this row of the Jacobian
            for gi in 0..self.grad_index.len() {
                *jacobian.get_mut((ti, gi)).unwrap() = out[0][gi / 3].d(gi % 3);
            }
            // The value component is the same for every sample (only the
            // derivative seeds differ), so take it from the first sample
            result[ti] = out[0][0].v;
        }
        Ok(())
    }

    /// Computes the sum-of-squared-residuals error across all constraints,
    /// evaluated at the candidate position `cur - delta`
    fn get_err(&mut self, cur: &[f32], delta: &[f32]) -> Result<f32, Error> {
        let mut err = 0f32;
        for tape in self.point_tapes.iter() {
            // Update the values in the single-point evaluation array,
            // stepping each free variable to its candidate position
            for (v, p) in self.vars {
                // Skip variables that this particular tape doesn't use
                let Some(i) = tape.vars().get(v) else {
                    continue;
                };
                let Some(f) = self.input_point.get_mut(i) else {
                    return Err(Error::BadVarIndex(i, self.input_point.len()));
                };
                match p {
                    Parameter::Free(..) => {
                        let gi = self.grad_index[v];
                        *f = cur[gi] - delta[gi];
                    }
                    Parameter::Fixed(p) => {
                        *f = *p;
                    }
                };
            }
            // Do the actual single-point evaluation
            let (out, _t) = self.point_eval.eval(tape, &self.input_point)?;
            err += out[0].powi(2); // TODO: consolidate into a single tape
        }
        Ok(err)
    }
}
174
175/// Least-squares minimization on a set of functions
176///
177/// Returns a map from free variable to its final value
178///
179/// Minimization is accomplished using a relatively basic implementation of
180/// the [Levenberg-Marquardt algorithm](https://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm).
181///
182/// ## References
183/// - [The Levenberg-Marquardt Algorithm (Ranganathan 2004)](http://ananth.in/docs/lmtut.pdf)
184/// - [Basics on Continuous Optimization ยง Levenberg-Marquardt](https://www.brnt.eu/phd/node10.html#SECTION00622700000000000000)
185/// - [Improvements to the Levenberg-Marquardt algorithm for nonlinear
186///   least-squares minimization (Transtrum 2012)](https://arxiv.org/pdf/1201.5885)
187pub fn solve<F: Function>(
188    eqs: &[F],
189    vars: &HashMap<Var, Parameter>,
190) -> Result<HashMap<Var, f32>, Error> {
191    let tapes = eqs
192        .iter()
193        .map(|f| f.grad_slice_tape(Default::default()))
194        .collect::<Vec<_>>();
195
196    // Current values for free variables
197    let mut cur = HashMap::new();
198    for (v, p) in vars {
199        if let Parameter::Free(f) = *p {
200            cur.insert(*v, f);
201        }
202    }
203
204    let mut solver = Solver::new(eqs, vars);
205
206    // Build an array of current values for each free variable
207    let mut cur = vec![0f32; solver.grad_index.len()];
208    for (v, i) in &solver.grad_index {
209        let Parameter::Free(f) = vars[v] else {
210            unreachable!();
211        };
212        cur[*i] = f;
213    }
214
215    // Working arrays for the current Jacobian and result
216    let mut jacobian = nalgebra::DMatrix::repeat(tapes.len(), cur.len(), 0f32);
217    let mut result = nalgebra::DVector::repeat(tapes.len(), 0f32);
218
219    let mut damping = 1.0;
220    let mut prev_err = f32::INFINITY;
221    let mut err_buf = [0f32; 4];
222    for i in 0.. {
223        solver.get_jacobian(&cur, &mut jacobian, &mut result)?;
224
225        // Early exit if we're done
226        if result.iter().all(|v| *v == 0.0) {
227            break;
228        }
229
230        let jt = jacobian.transpose();
231        let jt_j = &jt * &jacobian;
232
233        let jt_r = jt * &result;
234
235        // TODO: be optimistic and evaluate the full gradient on the first
236        // attempt, since it should usually succeed?
237        let (err, step) = loop {
238            let adjusted = &jt_j
239                + damping * nalgebra::DMatrix::from_diagonal(&jt_j.diagonal());
240
241            let delta = adjusted
242                .svd(true, true)
243                .solve(&jt_r, f32::EPSILON)
244                .map_err(Error::SingularMatrix)?;
245
246            let err = solver.get_err(&cur, delta.as_slice())?;
247            if err > prev_err {
248                // Keep going in this inner loop, taking smaller steps
249                damping *= 1.5;
250            } else {
251                // We found a good step size, so reduce damping
252                damping /= 3.0;
253                break (err, delta);
254            }
255        };
256
257        // Update our current position, checking whether it actually changed
258        // (i.e. whether our steps are below the floating-point epsilon)
259        //
260        // TODO: improve exit criteria?
261        let mut changed = false;
262        for gi in 0..solver.grad_index.len() {
263            let prev = cur[gi];
264            cur[gi] -= step[gi];
265            changed |= prev != cur[gi];
266        }
267        err_buf[i % err_buf.len()] = err;
268        if !changed
269            || err == 0.0
270            || damping == 0.0
271            || err_buf.iter().all(|e| *e == err_buf[0])
272        {
273            break;
274        }
275        prev_err = err;
276    }
277
278    // Return the new "current" values, which are our optimized position
279    let out = solver
280        .grad_index
281        .into_iter()
282        .map(|(v, i)| (v, cur[i]))
283        .collect();
284    Ok(out)
285}
286
#[cfg(test)]
mod test {
    use super::*;
    use approx::{assert_relative_eq, relative_eq};
    use fidget_core::{
        context::{Context, Tree},
        eval::MathFunction,
        vm::VmFunction,
    };

    /// Solves `x + y = 0` with `y` fixed at -1, expecting `x = 1`
    #[test]
    fn basic_solver() {
        let eqn = Tree::x() + Tree::y();
        let mut ctx = Context::new();
        let root = ctx.import(&eqn);

        let f = VmFunction::new(&ctx, &[root]).unwrap();
        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Fixed(-1.0));
        let sol = solve(&[f], &values).unwrap();
        assert_eq!(sol.len(), 1);
        assert_relative_eq!(sol[&Var::X], 1.0);
    }

    /// Drives the sum of four free variables to zero in a single equation
    #[test]
    fn four_vars_at_once() {
        let vs = (0..4).map(|_| Var::new()).collect::<Vec<Var>>();
        let mut root = Tree::from(vs[0]);
        for v in &vs[1..] {
            root += Tree::from(*v);
        }
        let mut ctx = Context::new();
        let root = ctx.import(&root);

        let f = VmFunction::new(&ctx, &[root]).unwrap();
        let mut values = HashMap::new();
        for (i, &v) in vs.iter().enumerate() {
            values.insert(v, Parameter::Free(i as f32));
        }
        let sol = solve(&[f], &values).unwrap();
        assert_eq!(sol.len(), 4);
        let mut out = 0.0;
        for v in &vs {
            out += sol[v];
        }
        assert_relative_eq!(out, 0.0);
    }

    /// Solves four independent constraints `v_i - i = 0` simultaneously
    #[test]
    fn four_vars_independent() {
        let vs = (0..4).map(|_| Var::new()).collect::<Vec<Var>>();
        let mut eqns = vec![];
        let mut ctx = Context::new();
        for (i, &v) in vs.iter().enumerate() {
            let eqn = Tree::from(v) - Tree::from(i as f32);
            let root = ctx.import(&eqn);
            let f = VmFunction::new(&ctx, &[root]).unwrap();
            eqns.push(f);
        }

        let mut values = HashMap::new();
        for (i, &v) in vs.iter().enumerate() {
            // Start each variable away from its solution (at 2x the target)
            values.insert(v, Parameter::Free(i as f32 * 2.0));
        }
        let sol = solve(&eqns, &values).unwrap();
        assert_eq!(sol.len(), 4);
        for (i, v) in vs.iter().enumerate() {
            assert_relative_eq!(i as f32, sol[v]);
        }
    }

    /// Solves a nonlinear system of two equations in x and y, then checks
    /// the solution against the original constraints
    #[test]
    fn xy_nonlinear() {
        let constraints = vec![
            (Tree::x() * 2 + Tree::y() * 3) * (Tree::x() - Tree::y()) - 2,
            Tree::x() * 3 + Tree::y() - 5,
        ];
        let mut ctx = Context::new();
        let eqns = constraints
            .into_iter()
            .map(|c| {
                let root = ctx.import(&c);
                VmFunction::new(&ctx, &[root]).unwrap()
            })
            .collect::<Vec<_>>();

        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Free(0.0));
        let sol = solve(&eqns, &values).unwrap();

        let x = sol[&Var::X];
        let y = sol[&Var::Y];

        assert_relative_eq!((x * 2.0 + y * 3.0) * (x - y), 2.0);
        assert_relative_eq!(x * 3.0 + y, 5.0);
    }

    /// An over-constrained system converges to the least-squares optimum
    #[test]
    fn one_var_no_solution() {
        // Solve for X == 1 and X == 2 simultaneously
        let constraints = vec![Tree::x() - 1.0, Tree::x() - 2.0];

        let mut ctx = Context::new();
        let eqns = constraints
            .into_iter()
            .map(|c| {
                let root = ctx.import(&c);
                VmFunction::new(&ctx, &[root]).unwrap()
            })
            .collect::<Vec<_>>();

        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));

        let sol = solve(&eqns, &values).unwrap();

        // The least-squares compromise between x == 1 and x == 2 is 1.5
        let x = sol[&Var::X];
        assert_relative_eq!(x, 1.5);
    }

    /// Solves the Rosenbrock "banana" function, both from the origin and
    /// starting at the known minimum (1, 1)
    #[test]
    fn solve_banana() {
        // See https://en.wikipedia.org/wiki/Rosenbrock_function
        let a = 1f32;
        let b = 100f32;
        let constraints = [a - Tree::x(), b * (Tree::y() - Tree::x().square())];

        let mut ctx = Context::new();
        let eqns = constraints
            .into_iter()
            .map(|c| {
                let root = ctx.import(&c);
                VmFunction::new(&ctx, &[root]).unwrap()
            })
            .collect::<Vec<_>>();

        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Free(0.0));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 1.0);
        assert_relative_eq!(sol[&Var::Y], 1.0);

        // Starting at the minimum itself should also converge (trivially)
        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(1.0));
        values.insert(Var::Y, Parameter::Free(1.0));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 1.0);
        assert_relative_eq!(sol[&Var::Y], 1.0);
    }

    /// Minimizes distance-to-origin, from the origin and from (1, 1.5)
    #[test]
    fn solve_circle() {
        let t = (Tree::x().square() + Tree::y().square()).sqrt();
        let mut ctx = Context::new();
        let root = ctx.import(&t);
        let eqn = VmFunction::new(&ctx, &[root]).unwrap();
        let eqns = [eqn];

        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Free(0.0));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 0.0);
        assert_relative_eq!(sol[&Var::Y], 0.0);

        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(1.0));
        values.insert(Var::Y, Parameter::Free(1.5));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 0.0);
        assert_relative_eq!(sol[&Var::Y], 0.0);
    }

    /// Builds and solves one random `n x n` linear system, checking the
    /// solution against the original equations
    fn one_linear(n: usize) {
        // Build a random matrix of our solutions
        let mut values = nalgebra::DVector::<f32>::zeros(n);
        for v in values.iter_mut() {
            *v = rand::random();
        }

        let vars = (0..n).map(|_| Var::new()).collect::<Vec<_>>();
        let trees = vars.iter().map(|v| Tree::from(*v)).collect::<Vec<_>>();

        let mut mat = nalgebra::DMatrix::<f32>::zeros(n, n);
        for v in mat.iter_mut() {
            *v = rand::random();
        }

        let sol = &mat * &values;

        let mut ctx = Context::new();
        let mut eqns = vec![];
        for row in 0..n {
            let mut out = Tree::from(-sol[row]);
            for (col, t) in trees.iter().enumerate() {
                out += *mat.get((row, col)).unwrap() * t.clone();
            }
            let root = ctx.import(&out);
            let f = VmFunction::new(&ctx, &[root]).unwrap();
            eqns.push(f);
        }

        let params = vars.iter().map(|v| (*v, Parameter::Free(0.0))).collect();
        let out = solve(&eqns, &params).unwrap();

        // It's possible for there to be multiple solutions here, so we'll check
        // the actual equations.
        for i in 0..n {
            values[i] = out[&vars[i]];
        }
        let sol2 = &mat * &values;
        let err = (&sol - &sol2).norm_squared();
        assert!(err < 1e-3, "error {err} is too large");
        for (a, b) in sol.iter().zip(sol2.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-2);
        }
    }

    #[test]
    fn small_linear() {
        for _ in 0..1000 {
            one_linear(2);
        }
    }

    #[test]
    fn medium_linear() {
        for _ in 0..1000 {
            one_linear(10);
        }
    }

    #[test]
    fn big_linear() {
        for _ in 0..50 {
            one_linear(50);
        }
    }

    /// Builds and solves one random quadratic system, returning `true` if
    /// the solver converged to a valid solution
    fn one_quadratic(n: usize) -> bool {
        // Total number of terms: n linear plus n * n quadratic
        let m: usize = n * n + n;

        // Build a random matrix of our solutions
        let mut values = nalgebra::DVector::<f32>::zeros(n);
        for v in values.iter_mut() {
            *v = rand::random();
        }

        // Build a column vector of [a b c ... aa ab ac ... ba bb ...]^T
        let mut col = nalgebra::DVector::<f32>::zeros(m);
        col.rows_range_mut(..n).copy_from(&values);
        for i in 0..n {
            for j in 0..n {
                let index = i * n + j + n;
                col[index] = values[i] * values[j];
            }
        }

        let vars = (0..n).map(|_| Var::new()).collect::<Vec<_>>();
        let trees = vars.iter().map(|v| Tree::from(*v)).collect::<Vec<_>>();

        let mut mat = nalgebra::DMatrix::<f32>::zeros(n, m);
        for v in mat.iter_mut() {
            *v = rand::random();
        }

        let sol = &mat * &col;

        // Build one equation per row: linear terms plus quadratic terms
        let mut ctx = Context::new();
        let mut eqns = vec![];
        for row in 0..n {
            let mut out = Tree::from(-sol[row]);
            for (col, t) in trees.iter().enumerate() {
                out += *mat.get((row, col)).unwrap() * t.clone();
            }
            for i in 0..n {
                for j in 0..n {
                    let index = i * n + j + n;
                    out += *mat.get((row, index)).unwrap()
                        * trees[i].clone()
                        * trees[j].clone();
                }
            }
            let root = ctx.import(&out);
            let f = VmFunction::new(&ctx, &[root]).unwrap();
            eqns.push(f);
        }

        let params = vars.iter().map(|v| (*v, Parameter::Free(0.5))).collect();
        let out = solve(&eqns, &params).unwrap();

        // It's possible for there to be multiple solutions here, so we'll check
        // the actual equations.
        for i in 0..n {
            col[i] = out[&vars[i]];
            for j in 0..n {
                let index = i * n + j + n;
                col[index] = out[&vars[i]] * out[&vars[j]];
            }
        }
        let sol2 = &mat * &col;
        let err = (&sol - &sol2).norm_squared();
        if err >= 1e-3 {
            return false;
        }
        for (a, b) in sol.iter().zip(sol2.iter()) {
            if !relative_eq!(a, b, epsilon = 1e-2) {
                return false;
            }
        }
        true
    }

    // Quadratic functions can get trapped in local minima, so we only require a
    // certain percent to succeed (hard-coded to 90% right now)
    fn many_quadratic(size: usize, count: usize) {
        let mut okay = 0;
        for _ in 0..count {
            if one_quadratic(size) {
                okay += 1;
            }
        }
        assert!(
            okay >= count * 9 / 10,
            "too many failures: {okay} / {count}"
        );
    }

    #[test]
    fn small_quadratic() {
        many_quadratic(2, 1000);
    }

    #[test]
    fn medium_quadratic() {
        many_quadratic(5, 100);
    }

    #[test]
    fn large_quadratic() {
        many_quadratic(10, 50);
    }
}