1use crate::polynomial::Polynomial;
2use nalgebra::{ComplexField, DMatrix, DVector, RealField, SVector};
3use num_traits::{FromPrimitive, One, Zero};
4
5pub fn linear_fit<N>(xs: &[N], ys: &[N]) -> Result<Polynomial<N>, String>
13where
14 N: ComplexField + FromPrimitive + Copy,
15 <N as ComplexField>::RealField: FromPrimitive + Copy,
16{
17 if xs.len() != ys.len() {
18 return Err("linear_fit: xs length does not match ys length".to_owned());
19 }
20
21 let mut sum_x = N::zero();
22 let mut sum_y = N::zero();
23 let mut sum_x_sq = N::zero();
24 let mut sum_y_sq = N::zero();
25 let mut sum_xy = N::zero();
26
27 for (ind, x) in xs.iter().enumerate() {
28 sum_x += *x;
29 sum_y += ys[ind];
30 sum_x_sq += x.powi(2);
31 sum_y_sq += ys[ind].powi(2);
32 sum_xy += ys[ind] * *x;
33 }
34
35 let m = N::from_usize(xs.len()).unwrap();
36 let denom = m * sum_x_sq - sum_x.powi(2);
37 let a = (m * sum_xy - sum_x * sum_y) / denom;
38 let b = (sum_x_sq * sum_y - sum_xy * sum_x) / denom;
39
40 Ok(polynomial![a, b])
41}
42
43fn jac_finite_differences<N, F, const V: usize>(
45 mut f: F,
46 xs: &[N],
47 params: &mut SVector<N, V>,
48 mat: &mut DMatrix<N>,
49 h: N::RealField,
50) where
51 N: ComplexField + FromPrimitive + Copy,
52 F: FnMut(N, &SVector<N, V>) -> N,
53 <N as ComplexField>::RealField: FromPrimitive + Copy,
54{
55 let h = N::from_real(h);
56 let denom = N::one() / (N::from_i32(2).unwrap() * h);
57 for row in 0..mat.column(0).len() {
58 for col in 0..mat.row(0).len() {
59 params[col] += h;
60 let above = f(xs[row], params);
61 params[col] -= h;
62 params[col] -= h;
63 let below = f(xs[row], params);
64 mat[(row, col)] = denom * (above + below);
65 params[col] += h;
66 }
67 }
68}
69
70fn jac_analytic<N, F, const V: usize>(
72 mut jac: F,
73 xs: &[N],
74 params: &mut SVector<N, V>,
75 mat: &mut DMatrix<N>,
76) where
77 N: ComplexField + Copy,
78 F: FnMut(N, &SVector<N, V>) -> SVector<N, V>,
79{
80 for row in 0..mat.column(0).len() {
81 let deriv = jac(xs[row], params);
82 for col in 0..mat.row(0).len() {
83 mat[(row, col)] = deriv[col];
84 }
85 }
86}
87
/// Tuning parameters for [`curve_fit`] and [`curve_fit_jac`].
///
/// The [`Default`] implementation uses `damping = 2.0`, `tolerance = 1e-5`, `h = 0.1`
/// and `damping_mult = 1.5`.
#[derive(Debug, Clone)]
pub struct CurveFitParams<N: ComplexField> {
    /// Initial Levenberg–Marquardt damping factor `λ`. When solving for a step, the diagonal of
    /// `JᵀJ` is scaled by `1 + λ`. The fitters reject a sign-negative value.
    pub damping: N::RealField,
    /// Convergence threshold: iteration stops once the sum of squared residuals changes by no
    /// more than this between iterations. The fitters reject a sign-negative value.
    pub tolerance: N::RealField,
    /// Step size of the central finite-difference Jacobian used by `curve_fit`. It is not used by
    /// `curve_fit_jac`. `curve_fit` rejects a sign-negative value.
    pub h: N::RealField,
    /// Factor by which `damping` is multiplied (to damp harder) or divided (to relax damping).
    pub damping_mult: N::RealField,
}
95
96impl<N: ComplexField + FromPrimitive> Default for CurveFitParams<N> {
97 fn default() -> Self {
98 CurveFitParams {
99 damping: N::from_f64(2.0).unwrap().real(),
100 tolerance: N::from_f64(1e-5).unwrap().real(),
101 h: N::from_f64(0.1).unwrap().real(),
102 damping_mult: N::from_f64(1.5).unwrap().real(),
103 }
104 }
105}
106
107#[allow(clippy::too_many_arguments)]
108fn initial_residuals<N, F, const V: usize>(
109 xs: &[N],
110 ys: &DVector<N>,
111 damping: &mut N::RealField,
112 damping_mult: N::RealField,
113 h: N::RealField,
114 mut f: F,
115 jac: &mut DMatrix<N>,
116 jac_transpose: &mut DMatrix<N>,
117 mut params: SVector<N, V>,
118) -> Result<(N::RealField, DVector<N>), String>
119where
120 N: ComplexField + Copy + FromPrimitive,
121 <N as ComplexField>::RealField: Copy + FromPrimitive,
122 F: FnMut(N, &SVector<N, V>) -> N,
123{
124 let mut resid = Vec::with_capacity(xs.len());
125 for (ind, &x) in xs.iter().enumerate() {
126 resid.push(ys[ind] - f(x, ¶ms));
127 }
128 let sum_sq_initial: N::RealField = resid
129 .iter()
130 .map(|&r| r.modulus_squared())
131 .fold(N::RealField::zero(), |acc, r| acc + r);
132
133 let mut sum_sq = sum_sq_initial + N::RealField::one();
135 let mut damping_tmp = *damping / damping_mult;
136 let mut j = 0;
137 let mut evaluation: DVector<N> =
138 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
139 while sum_sq > sum_sq_initial && j < 1000 {
140 damping_tmp *= damping_mult;
141 let diff = ys - &evaluation;
142 let mut b = jac_transpose as &DMatrix<N> * &diff;
143 let mut multiplied = jac_transpose as &DMatrix<N> * jac as &DMatrix<N>;
145 for i in 0..multiplied.row(0).len() {
146 multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
147 }
148 let lu = multiplied.clone().lu();
149 let solved = lu.solve_mut(&mut b);
150 if !solved {
151 return Err("curve_fit: unable to solve linear equation".to_owned());
152 }
153 params += &b;
154 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
155 let diff = ys - &evaluation;
156 sum_sq = diff
157 .iter()
158 .map(|&r| r.modulus_squared())
159 .fold(N::RealField::zero(), |acc, r| acc + r);
160 j += 1;
161 jac_finite_differences(&mut f, xs, &mut params, jac, h);
162 *jac_transpose = jac.transpose();
163 }
164 if j != 1000 {
165 *damping = damping_tmp;
166 }
167 Ok((sum_sq, evaluation))
168}
169
170#[allow(clippy::too_many_arguments)]
171fn initial_residuals_exact<N, F, G, const V: usize>(
172 xs: &[N],
173 ys: &DVector<N>,
174 damping: &mut N::RealField,
175 damping_mult: N::RealField,
176 mut f: F,
177 mut jacobian: G,
178 jac: &mut DMatrix<N>,
179 jac_transpose: &mut DMatrix<N>,
180 mut params: SVector<N, V>,
181) -> Result<(N::RealField, DVector<N>), String>
182where
183 N: ComplexField + Copy + FromPrimitive,
184 <N as ComplexField>::RealField: Copy + FromPrimitive,
185 F: FnMut(N, &SVector<N, V>) -> N,
186 G: FnMut(N, &SVector<N, V>) -> SVector<N, V>,
187{
188 let mut resid = Vec::with_capacity(xs.len());
190 for (ind, &x) in xs.iter().enumerate() {
191 resid.push(ys[ind] - f(x, ¶ms));
192 }
193 let sum_sq_initial: N::RealField = resid
194 .iter()
195 .map(|&r| r.modulus_squared())
196 .fold(N::RealField::zero(), |acc, r| acc + r);
197
198 let mut sum_sq = sum_sq_initial + N::RealField::one();
200 let mut damping_tmp = *damping / damping_mult;
201 let mut j = 0;
202 let mut evaluation: DVector<N> =
203 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
204 while sum_sq > sum_sq_initial && j < 1000 {
205 damping_tmp *= damping_mult;
206 let diff = ys - &evaluation;
207 let mut b = jac_transpose as &DMatrix<N> * &diff;
208 let mut multiplied = jac_transpose as &DMatrix<N> * jac as &DMatrix<N>;
210 for i in 0..multiplied.row(0).len() {
211 multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
212 }
213 let lu = multiplied.clone().lu();
214 let solved = lu.solve_mut(&mut b);
215 if !solved {
216 let lu = multiplied.clone().full_piv_lu();
217 let full_lu_solved = lu.solve_mut(&mut b);
218 if !full_lu_solved {
219 let qr = multiplied.qr();
220 let qr_solved = qr.solve_mut(&mut b);
221 if !qr_solved {
222 return Err("curve_fit_jac: unable to solve linear equation".to_owned());
223 }
224 }
225 }
226 params += &b;
227 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
228 let diff = ys - &evaluation;
229 sum_sq = diff
230 .iter()
231 .map(|&r| r.modulus_squared())
232 .fold(N::RealField::zero(), |acc, r| acc + r);
233 j += 1;
234 jac_analytic(&mut jacobian, xs, &mut params, jac);
235 *jac_transpose = jac.transpose();
236 }
237 if j != 1000 {
238 *damping = damping_tmp;
239 }
240
241 Ok((sum_sq, evaluation))
242}
243
244pub fn curve_fit<N, F, const V: usize>(
257 mut f: F,
258 xs: &[N],
259 ys: &[N],
260 initial: &[N],
261 params: &CurveFitParams<N>,
262) -> Result<SVector<N, V>, String>
263where
264 N: ComplexField + FromPrimitive + Copy,
265 <N as ComplexField>::RealField: FromPrimitive + Copy,
266 F: FnMut(N, &SVector<N, V>) -> N,
267{
268 let tol = params.tolerance;
269 let mut damping = params.damping;
270 let h = params.h;
271 let damping_mult = params.damping_mult;
272
273 if !tol.is_sign_positive() {
274 return Err("curve_fit: tol must be positive".to_owned());
275 }
276
277 if !h.is_sign_positive() {
278 return Err("curve_fit: h must be positive".to_owned());
279 }
280
281 if !damping.is_sign_positive() {
282 return Err("curve_fit: damping must be positive".to_owned());
283 }
284
285 if xs.len() != ys.len() {
286 return Err("curve_fit: xs length must match ys length".to_owned());
287 }
288
289 let mut params = SVector::<N, V>::from_column_slice(initial);
290 let ys = DVector::<N>::from_column_slice(ys);
291 let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
292 jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
293 let mut jac_transpose = jac.transpose();
294
295 let (mut sum_sq, mut evaluation) = initial_residuals(
297 xs,
298 &ys,
299 &mut damping,
300 damping_mult,
301 h,
302 &mut f,
303 &mut jac,
304 &mut jac_transpose,
305 params,
306 )?;
307
308 let mut last_sum_sq = sum_sq;
309 sum_sq += N::from_u8(2).unwrap().real() * tol;
310 while (last_sum_sq - sum_sq).abs() > tol {
311 last_sum_sq = sum_sq;
312 let diff = &ys - &evaluation;
314 let mut b = &jac_transpose * &diff;
315 let mut b_div = b.clone();
316 let mut multiplied = &jac_transpose * &jac;
318 let mut multiplied_div = multiplied.clone();
319 for i in 0..multiplied.row(0).len() {
320 multiplied[(i, i)] *= N::one() + N::from_real(damping);
321 }
322 let lu = multiplied.clone().lu();
324 let lu_solved = lu.solve_mut(&mut b);
325 if !lu_solved {
326 return Err("curve_fit: unable to solve linear equation".to_owned());
327 }
328 let new_params = params + &b;
329
330 for i in 0..multiplied_div.row(0).len() {
332 multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
333 }
334 let lu = multiplied_div.clone().lu();
335 let solved = lu.solve_mut(&mut b_div);
336 if !solved {
337 let lu = multiplied_div.clone().full_piv_lu();
338 let full_lu_solved = lu.solve_mut(&mut b_div);
339 if !full_lu_solved {
340 let qr = multiplied_div.qr();
341 let qr_solved = qr.solve_mut(&mut b_div);
342 if !qr_solved {
343 return Err("curve_fit: unable to solve linear equation".to_owned());
344 }
345 }
346 }
347 let new_params_div = params + &b_div;
348
349 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
351 let evaluation_div =
352 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
353 let diff = &ys - &evaluation;
354 let diff_div = &ys - &evaluation_div;
355
356 let resid: N::RealField = diff
357 .iter()
358 .map(|&r| r.modulus_squared())
359 .fold(N::RealField::zero(), |acc, r| acc + r);
360 let resid_div: N::RealField = diff_div
361 .iter()
362 .map(|&r| r.modulus_squared())
363 .fold(N::RealField::zero(), |acc, r| acc + r);
364
365 if resid_div < resid {
366 damping /= damping_mult;
367 evaluation = evaluation_div;
368 params = new_params_div;
369 sum_sq = resid_div;
370 } else {
371 params = new_params;
372 sum_sq = resid;
373 }
374
375 jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
376 jac_transpose = jac.transpose();
377 }
378
379 Ok(params)
380}
381
382pub fn curve_fit_jac<N, F, G, const V: usize>(
395 mut f: F,
396 xs: &[N],
397 ys: &[N],
398 initial: &[N],
399 mut jacobian: G,
400 params: &CurveFitParams<N>,
401) -> Result<SVector<N, V>, String>
402where
403 N: ComplexField + FromPrimitive + Copy,
404 <N as ComplexField>::RealField: FromPrimitive + Copy,
405 F: FnMut(N, &SVector<N, V>) -> N,
406 G: FnMut(N, &SVector<N, V>) -> SVector<N, V>,
407{
408 let tol = params.tolerance;
409 let mut damping = params.damping;
410 let damping_mult = params.damping_mult;
411
412 if !tol.is_sign_positive() {
413 return Err("curve_fit_jac: tol must be positive".to_owned());
414 }
415
416 if !damping.is_sign_positive() {
417 return Err("curve_fit_jac: damping must be positive".to_owned());
418 }
419
420 if xs.len() != ys.len() {
421 return Err("curve_fit_jac: xs length must match ys length".to_owned());
422 }
423
424 let mut params = SVector::<N, V>::from_column_slice(initial);
425 let ys = DVector::<N>::from_column_slice(ys);
426 let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
427 jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
428 let mut jac_transpose = jac.transpose();
429
430 let (mut sum_sq, mut evaluation) = initial_residuals_exact(
431 xs,
432 &ys,
433 &mut damping,
434 damping_mult,
435 &mut f,
436 &mut jacobian,
437 &mut jac,
438 &mut jac_transpose,
439 params,
440 )?;
441
442 let mut last_sum_sq = sum_sq;
443 sum_sq += N::from_u8(2).unwrap().real() * tol;
444 while (last_sum_sq - sum_sq).abs() > tol {
445 last_sum_sq = sum_sq;
446 let diff = &ys - &evaluation;
448 let mut b = &jac_transpose * &diff;
449 let mut b_div = b.clone();
450 let mut multiplied = &jac_transpose * &jac;
452 let mut multiplied_div = multiplied.clone();
453 for i in 0..multiplied.row(0).len() {
454 multiplied[(i, i)] *= N::one() + N::from_real(damping);
455 }
456 let lu = multiplied.clone().lu();
459 let lu_solved = lu.solve_mut(&mut b);
460 if !lu_solved {
461 let lu = multiplied.clone().full_piv_lu();
462 let full_lu_solved = lu.solve_mut(&mut b);
463 if !full_lu_solved {
464 let qr = multiplied.qr();
465 let qr_solved = qr.solve_mut(&mut b);
466 if !qr_solved {
467 return Err("curve_fit_jac: unable to solve linear equation".to_owned());
468 }
469 }
470 }
471 let new_params = params + &b;
472
473 for i in 0..multiplied_div.row(0).len() {
475 multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
476 }
477 let lu = multiplied_div.clone().lu();
478 let solved = lu.solve_mut(&mut b_div);
479 if !solved {
480 let lu = multiplied_div.clone().full_piv_lu();
481 let full_lu_solved = lu.solve_mut(&mut b_div);
482 if !full_lu_solved {
483 let qr = multiplied_div.qr();
484 let qr_solved = qr.solve_mut(&mut b_div);
485 if !qr_solved {
486 return Err("curve_fit_jac: unable to solve linear equation".to_owned());
487 }
488 }
489 }
490 let new_params_div = params + &b_div;
491
492 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
494 let evaluation_div =
495 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
496 let diff = &ys - &evaluation;
497 let diff_div = &ys - &evaluation_div;
498
499 let resid: N::RealField = diff
500 .iter()
501 .map(|&r| r.modulus_squared())
502 .fold(N::RealField::zero(), |acc, r| acc + r);
503 let resid_div: N::RealField = diff_div
504 .iter()
505 .map(|&r| r.modulus_squared())
506 .fold(N::RealField::zero(), |acc, r| acc + r);
507
508 if resid_div < resid {
509 damping /= damping_mult;
510 evaluation = evaluation_div;
511 params = new_params_div;
512 sum_sq = resid_div;
513 } else {
514 params = new_params;
515 sum_sq = resid;
516 }
517
518 jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
519 jac_transpose = jac.transpose();
520 }
521
522 Ok(params)
523}