1#![warn(missing_docs)]
3use fidget_core::{
4 Error,
5 eval::{BulkEvaluator, Function, Tape, TracingEvaluator},
6 types::Grad,
7 var::Var,
8};
9use std::collections::HashMap;
10
/// A single input to [`solve`], tagged by whether the solver may change it.
///
/// Both variants carry an `f32`; how that value is interpreted depends on the
/// variant.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Parameter {
    /// A variable the solver is allowed to adjust.
    ///
    /// The payload is the starting value used for the first iteration.
    Free(f32),
    /// A variable held at the given value for the whole solve.
    ///
    /// Fixed parameters are fed to the equations as constants and do not
    /// appear in the solution map.
    Fixed(f32),
}
19
20struct Solver<'a, F: Function> {
22 vars: &'a HashMap<Var, Parameter>,
24
25 grad_tapes: Vec<<F::GradSliceEval as BulkEvaluator>::Tape>,
27
28 point_tapes: Vec<<F::PointEval as TracingEvaluator>::Tape>,
30
31 grad_eval: F::GradSliceEval,
33
34 point_eval: F::PointEval,
36
37 input_grad: Vec<Vec<Grad>>,
39
40 input_point: Vec<f32>,
42
43 grad_index: HashMap<Var, usize>,
48}
49
50impl<'a, F: Function> Solver<'a, F> {
51 fn new(eqs: &'a [F], vars: &'a HashMap<Var, Parameter>) -> Self {
52 let grad_tapes = eqs
54 .iter()
55 .map(|f| f.grad_slice_tape(Default::default()))
56 .collect::<Vec<_>>();
57 let point_tapes = eqs
58 .iter()
59 .map(|f| f.point_tape(Default::default()))
60 .collect::<Vec<_>>();
61
62 let grad_index: HashMap<Var, usize> = vars
67 .iter()
68 .filter(|(_v, p)| matches!(p, Parameter::Free(..)))
69 .enumerate()
70 .map(|(i, (v, _p))| (*v, i))
71 .collect();
72
73 let input_grad =
76 vec![
77 vec![Grad::from(0f32); grad_index.len().div_ceil(3)];
78 vars.len()
79 ];
80 let input_point = vec![0f32; vars.len()];
81
82 Self {
83 vars,
84 grad_tapes,
85 point_tapes,
86 grad_eval: Default::default(),
87 point_eval: Default::default(),
88 grad_index,
89
90 input_grad,
91 input_point,
92 }
93 }
94
95 fn get_jacobian(
100 &mut self,
101 cur: &[f32],
102 jacobian: &mut nalgebra::DMatrix<f32>,
103 result: &mut nalgebra::DVector<f32>,
104 ) -> Result<(), Error> {
105 for (ti, tape) in self.grad_tapes.iter().enumerate() {
106 for (v, p) in self.vars {
108 let Some(i) = tape.vars().get(v) else {
109 continue;
110 };
111 let Some(slice) = self.input_grad.get_mut(i) else {
112 return Err(Error::BadVarIndex(i, self.input_grad.len()));
113 };
114 match p {
115 Parameter::Free(..) => {
116 let gi = self.grad_index[v];
117 for (j, v) in slice.iter_mut().enumerate() {
118 *v = Grad::new(
119 cur[gi],
120 if j * 3 == gi { 1.0 } else { 0.0 },
121 if j * 3 + 1 == gi { 1.0 } else { 0.0 },
122 if j * 3 + 2 == gi { 1.0 } else { 0.0 },
123 );
124 }
125 }
126 Parameter::Fixed(f) => {
127 slice.fill(Grad::new(*f, 0.0, 0.0, 0.0));
128 }
129 };
130 }
131 let out = self.grad_eval.eval(tape, &self.input_grad)?;
133
134 for gi in 0..self.grad_index.len() {
136 *jacobian.get_mut((ti, gi)).unwrap() = out[0][gi / 3].d(gi % 3);
137 }
138 result[ti] = out[0][0].v;
139 }
140 Ok(())
141 }
142
143 fn get_err(&mut self, cur: &[f32], delta: &[f32]) -> Result<f32, Error> {
144 let mut err = 0f32;
145 for tape in self.point_tapes.iter() {
146 for (v, p) in self.vars {
151 let Some(i) = tape.vars().get(v) else {
152 continue;
153 };
154 let Some(f) = self.input_point.get_mut(i) else {
155 return Err(Error::BadVarIndex(i, self.input_point.len()));
156 };
157 match p {
158 Parameter::Free(..) => {
159 let gi = self.grad_index[v];
160 *f = cur[gi] - delta[gi];
161 }
162 Parameter::Fixed(p) => {
163 *f = *p;
164 }
165 };
166 }
167 let (out, _t) = self.point_eval.eval(tape, &self.input_point)?;
169 err += out[0].powi(2); }
171 Ok(err)
172 }
173}
174
175pub fn solve<F: Function>(
188 eqs: &[F],
189 vars: &HashMap<Var, Parameter>,
190) -> Result<HashMap<Var, f32>, Error> {
191 let tapes = eqs
192 .iter()
193 .map(|f| f.grad_slice_tape(Default::default()))
194 .collect::<Vec<_>>();
195
196 let mut cur = HashMap::new();
198 for (v, p) in vars {
199 if let Parameter::Free(f) = *p {
200 cur.insert(*v, f);
201 }
202 }
203
204 let mut solver = Solver::new(eqs, vars);
205
206 let mut cur = vec![0f32; solver.grad_index.len()];
208 for (v, i) in &solver.grad_index {
209 let Parameter::Free(f) = vars[v] else {
210 unreachable!();
211 };
212 cur[*i] = f;
213 }
214
215 let mut jacobian = nalgebra::DMatrix::repeat(tapes.len(), cur.len(), 0f32);
217 let mut result = nalgebra::DVector::repeat(tapes.len(), 0f32);
218
219 let mut damping = 1.0;
220 let mut prev_err = f32::INFINITY;
221 let mut err_buf = [0f32; 4];
222 for i in 0.. {
223 solver.get_jacobian(&cur, &mut jacobian, &mut result)?;
224
225 if result.iter().all(|v| *v == 0.0) {
227 break;
228 }
229
230 let jt = jacobian.transpose();
231 let jt_j = &jt * &jacobian;
232
233 let jt_r = jt * &result;
234
235 let (err, step) = loop {
238 let adjusted = &jt_j
239 + damping * nalgebra::DMatrix::from_diagonal(&jt_j.diagonal());
240
241 let delta = adjusted
242 .svd(true, true)
243 .solve(&jt_r, f32::EPSILON)
244 .map_err(Error::SingularMatrix)?;
245
246 let err = solver.get_err(&cur, delta.as_slice())?;
247 if err > prev_err {
248 damping *= 1.5;
250 } else {
251 damping /= 3.0;
253 break (err, delta);
254 }
255 };
256
257 let mut changed = false;
262 for gi in 0..solver.grad_index.len() {
263 let prev = cur[gi];
264 cur[gi] -= step[gi];
265 changed |= prev != cur[gi];
266 }
267 err_buf[i % err_buf.len()] = err;
268 if !changed
269 || err == 0.0
270 || damping == 0.0
271 || err_buf.iter().all(|e| *e == err_buf[0])
272 {
273 break;
274 }
275 prev_err = err;
276 }
277
278 let out = solver
280 .grad_index
281 .into_iter()
282 .map(|(v, i)| (v, cur[i]))
283 .collect();
284 Ok(out)
285}
286
#[cfg(test)]
mod test {
    use super::*;
    use approx::{assert_relative_eq, relative_eq};
    use fidget_core::{
        context::{Context, Tree},
        eval::MathFunction,
        vm::VmFunction,
    };

    /// One equation (`x + y = 0`) with `y` fixed at -1, so `x` must be 1.
    #[test]
    fn basic_solver() {
        let eqn = Tree::x() + Tree::y();
        let mut ctx = Context::new();
        let root = ctx.import(&eqn);

        let f = VmFunction::new(&ctx, &[root]).unwrap();
        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Fixed(-1.0));
        let sol = solve(&[f], &values).unwrap();
        // Only the free variable is reported
        assert_eq!(sol.len(), 1);
        assert_relative_eq!(sol[&Var::X], 1.0);
    }

    /// One equation over four free variables, which is more than fit in the
    /// three derivative lanes of a single `Grad`.
    #[test]
    fn four_vars_at_once() {
        let vs = (0..4).map(|_| Var::new()).collect::<Vec<Var>>();
        let mut root = Tree::from(vs[0]);
        for v in &vs[1..] {
            root += Tree::from(*v);
        }
        let mut ctx = Context::new();
        let root = ctx.import(&root);

        let f = VmFunction::new(&ctx, &[root]).unwrap();
        let mut values = HashMap::new();
        for (i, &v) in vs.iter().enumerate() {
            values.insert(v, Parameter::Free(i as f32));
        }
        let sol = solve(&[f], &values).unwrap();
        assert_eq!(sol.len(), 4);
        // The system is underdetermined; only the sum is constrained
        let mut out = 0.0;
        for v in &vs {
            out += sol[v];
        }
        assert_relative_eq!(out, 0.0);
    }

    /// Four equations of the form `v_i - i = 0`, each using a single
    /// variable.
    #[test]
    fn four_vars_independent() {
        let vs = (0..4).map(|_| Var::new()).collect::<Vec<Var>>();
        let mut eqns = vec![];
        let mut ctx = Context::new();
        for (i, &v) in vs.iter().enumerate() {
            let eqn = Tree::from(v) - Tree::from(i as f32);
            let root = ctx.import(&eqn);
            let f = VmFunction::new(&ctx, &[root]).unwrap();
            eqns.push(f);
        }

        let mut values = HashMap::new();
        for (i, &v) in vs.iter().enumerate() {
            values.insert(v, Parameter::Free(i as f32 * 2.0));
        }
        let sol = solve(&eqns, &values).unwrap();
        assert_eq!(sol.len(), 4);
        for (i, v) in vs.iter().enumerate() {
            assert_relative_eq!(i as f32, sol[v]);
        }
    }

    /// Two equations in `x` and `y`, one of which is nonlinear.
    #[test]
    fn xy_nonlinear() {
        let constraints = vec![
            (Tree::x() * 2 + Tree::y() * 3) * (Tree::x() - Tree::y()) - 2,
            Tree::x() * 3 + Tree::y() - 5,
        ];
        let mut ctx = Context::new();
        let eqns = constraints
            .into_iter()
            .map(|c| {
                let root = ctx.import(&c);
                VmFunction::new(&ctx, &[root]).unwrap()
            })
            .collect::<Vec<_>>();

        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Free(0.0));
        let sol = solve(&eqns, &values).unwrap();

        let x = sol[&Var::X];
        let y = sol[&Var::Y];

        // Check the result against the constraints, written out by hand
        assert_relative_eq!((x * 2.0 + y * 3.0) * (x - y), 2.0);
        assert_relative_eq!(x * 3.0 + y, 5.0);
    }

    /// Two contradictory equations in one variable; the expected answer is
    /// the least-squares compromise halfway between them.
    #[test]
    fn one_var_no_solution() {
        // `x = 1` and `x = 2` cannot both hold
        let constraints = vec![Tree::x() - 1.0, Tree::x() - 2.0];

        let mut ctx = Context::new();
        let eqns = constraints
            .into_iter()
            .map(|c| {
                let root = ctx.import(&c);
                VmFunction::new(&ctx, &[root]).unwrap()
            })
            .collect::<Vec<_>>();

        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));

        let sol = solve(&eqns, &values).unwrap();

        let x = sol[&Var::X];
        assert_relative_eq!(x, 1.5);
    }

    /// Residuals `a - x` and `b * (y - x²)` with `a = 1`, `b = 100`, whose
    /// common root is `(1, 1)`.
    #[test]
    fn solve_banana() {
        let a = 1f32;
        let b = 100f32;
        let constraints = [a - Tree::x(), b * (Tree::y() - Tree::x().square())];

        let mut ctx = Context::new();
        let eqns = constraints
            .into_iter()
            .map(|c| {
                let root = ctx.import(&c);
                VmFunction::new(&ctx, &[root]).unwrap()
            })
            .collect::<Vec<_>>();

        // Starting away from the solution
        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Free(0.0));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 1.0);
        assert_relative_eq!(sol[&Var::Y], 1.0);

        // Starting exactly at the solution
        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(1.0));
        values.insert(Var::Y, Parameter::Free(1.0));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 1.0);
        assert_relative_eq!(sol[&Var::Y], 1.0);
    }

    /// Distance from the origin, `sqrt(x² + y²)`, which is zero only at
    /// `(0, 0)`.
    #[test]
    fn solve_circle() {
        let t = (Tree::x().square() + Tree::y().square()).sqrt();
        let mut ctx = Context::new();
        let root = ctx.import(&t);
        let eqn = VmFunction::new(&ctx, &[root]).unwrap();
        let eqns = [eqn];

        // Starting exactly at the solution
        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(0.0));
        values.insert(Var::Y, Parameter::Free(0.0));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 0.0);
        assert_relative_eq!(sol[&Var::Y], 0.0);

        // Starting away from the solution
        let mut values = HashMap::new();
        values.insert(Var::X, Parameter::Free(1.0));
        values.insert(Var::Y, Parameter::Free(1.5));
        let sol = solve(&eqns, &values).unwrap();
        assert_relative_eq!(sol[&Var::X], 0.0);
        assert_relative_eq!(sol[&Var::Y], 0.0);
    }

    /// Builds a random `n × n` linear system with a known solution, solves
    /// it, and checks that the solver's answer reproduces the right-hand
    /// side.
    fn one_linear(n: usize) {
        // Known solution vector
        let mut values = nalgebra::DVector::<f32>::zeros(n);
        for v in values.iter_mut() {
            *v = rand::random();
        }

        let vars = (0..n).map(|_| Var::new()).collect::<Vec<_>>();
        let trees = vars.iter().map(|v| Tree::from(*v)).collect::<Vec<_>>();

        // Random coefficient matrix
        let mut mat = nalgebra::DMatrix::<f32>::zeros(n, n);
        for v in mat.iter_mut() {
            *v = rand::random();
        }

        // Right-hand side produced by the known solution
        let sol = &mat * &values;

        // Equation for each row: `sum(mat[row, col] * var[col]) - sol[row]`
        let mut ctx = Context::new();
        let mut eqns = vec![];
        for row in 0..n {
            let mut out = Tree::from(-sol[row]);
            for (col, t) in trees.iter().enumerate() {
                out += *mat.get((row, col)).unwrap() * t.clone();
            }
            let root = ctx.import(&out);
            let f = VmFunction::new(&ctx, &[root]).unwrap();
            eqns.push(f);
        }

        let params: HashMap<Var, Parameter> =
            vars.iter().map(|v| (*v, Parameter::Free(0.0))).collect();
        let out = solve(&eqns, &params).unwrap();

        // Compare right-hand sides rather than solution vectors: this
        // checks that the equations hold without assuming the solver
        // landed on the very same vector that generated them.
        for (value, var) in values.iter_mut().zip(&vars) {
            *value = out[var];
        }
        let sol2 = &mat * &values;
        let err = (&sol - &sol2).norm_squared();
        assert!(err < 1e-3, "error {err} is too large");
        for (a, b) in sol.iter().zip(sol2.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-2);
        }
    }

    /// 1000 random 2 × 2 linear systems.
    #[test]
    fn small_linear() {
        for _ in 0..1000 {
            one_linear(2);
        }
    }

    /// 1000 random 10 × 10 linear systems.
    #[test]
    fn medium_linear() {
        for _ in 0..1000 {
            one_linear(10);
        }
    }

    /// 50 random 50 × 50 linear systems.
    #[test]
    fn big_linear() {
        for _ in 0..50 {
            one_linear(50);
        }
    }

    /// Builds `n` random quadratic equations in `n` variables with a known
    /// solution and reports whether the solver's answer satisfies them.
    ///
    /// Each equation is a random linear combination of the `n` variables
    /// and all `n²` pairwise products, minus the value that combination
    /// takes at the known solution.
    fn one_quadratic(n: usize) -> bool {
        // `n` linear terms followed by `n * n` product terms
        let m: usize = n * n + n;

        // Known solution vector
        let mut values = nalgebra::DVector::<f32>::zeros(n);
        for v in values.iter_mut() {
            *v = rand::random();
        }

        // Term vector: the values themselves, then every pairwise product
        let mut col = nalgebra::DVector::<f32>::zeros(m);
        col.rows_range_mut(..n).copy_from(&values);
        for i in 0..n {
            for j in 0..n {
                let index = i * n + j + n;
                col[index] = values[i] * values[j];
            }
        }

        let vars = (0..n).map(|_| Var::new()).collect::<Vec<_>>();
        let trees = vars.iter().map(|v| Tree::from(*v)).collect::<Vec<_>>();

        // Random coefficients: one row per equation, one column per term
        let mut mat = nalgebra::DMatrix::<f32>::zeros(n, m);
        for v in mat.iter_mut() {
            *v = rand::random();
        }

        // Right-hand side produced by the known solution
        let sol = &mat * &col;

        let mut ctx = Context::new();
        let mut eqns = vec![];
        for row in 0..n {
            let mut out = Tree::from(-sol[row]);
            for (c, t) in trees.iter().enumerate() {
                out += *mat.get((row, c)).unwrap() * t.clone();
            }
            for i in 0..n {
                for j in 0..n {
                    let index = i * n + j + n;
                    out += *mat.get((row, index)).unwrap()
                        * trees[i].clone()
                        * trees[j].clone();
                }
            }
            let root = ctx.import(&out);
            let f = VmFunction::new(&ctx, &[root]).unwrap();
            eqns.push(f);
        }

        let params: HashMap<Var, Parameter> =
            vars.iter().map(|v| (*v, Parameter::Free(0.5))).collect();
        let out = solve(&eqns, &params).unwrap();

        // Rebuild the term vector from the solver's answer and compare
        // right-hand sides, so any vector that satisfies the equations
        // counts as a success.
        for i in 0..n {
            col[i] = out[&vars[i]];
            for j in 0..n {
                let index = i * n + j + n;
                col[index] = out[&vars[i]] * out[&vars[j]];
            }
        }
        let sol2 = &mat * &col;
        let err = (&sol - &sol2).norm_squared();
        if err >= 1e-3 {
            return false;
        }
        for (a, b) in sol.iter().zip(sol2.iter()) {
            if !relative_eq!(a, b, epsilon = 1e-2) {
                return false;
            }
        }
        true
    }

    /// Runs `count` random quadratic systems of the given `size` and
    /// requires that at least 90% of them are solved.
    fn many_quadratic(size: usize, count: usize) {
        // Individual runs are allowed to fail, so failures are counted
        // rather than asserted on one by one
        let mut okay = 0;
        for _ in 0..count {
            if one_quadratic(size) {
                okay += 1;
            }
        }
        assert!(
            okay >= count * 9 / 10,
            "too many failures: {okay} / {count}"
        );
    }

    /// 1000 random quadratic systems in 2 variables.
    #[test]
    fn small_quadratic() {
        many_quadratic(2, 1000);
    }

    /// 100 random quadratic systems in 5 variables.
    #[test]
    fn medium_quadratic() {
        many_quadratic(5, 100);
    }

    /// 50 random quadratic systems in 10 variables.
    #[test]
    fn large_quadratic() {
        many_quadratic(10, 50);
    }
}