1use crate::polynomial::Polynomial;
2use nalgebra::{
3 allocator::Allocator, ComplexField, DMatrix, DVector, DefaultAllocator, DimName, RealField,
4 VectorN,
5};
6use num_traits::{FromPrimitive, One, Zero};
7
8pub fn linear_fit<N: ComplexField>(xs: &[N], ys: &[N]) -> Result<Polynomial<N>, String> {
10 if xs.len() != ys.len() {
11 return Err("linear_fit: xs length does not match ys length".to_owned());
12 }
13
14 let mut sum_x = N::zero();
15 let mut sum_y = N::zero();
16 let mut sum_x_sq = N::zero();
17 let mut sum_y_sq = N::zero();
18 let mut sum_xy = N::zero();
19
20 for (ind, x) in xs.iter().enumerate() {
21 sum_x += *x;
22 sum_y += ys[ind];
23 sum_x_sq += x.powi(2);
24 sum_y_sq += ys[ind].powi(2);
25 sum_xy += ys[ind] * *x;
26 }
27
28 let m = N::from_usize(xs.len()).unwrap();
29 let denom = m * sum_x_sq - sum_x.powi(2);
30 let a = (m * sum_xy - sum_x * sum_y) / denom;
31 let b = (sum_x_sq * sum_y - sum_xy * sum_x) / denom;
32
33 Ok(polynomial![a, b])
34}
35
36fn jac_finite_differences<N: ComplexField, V: DimName, F: FnMut(N, &VectorN<N, V>) -> N>(
38 mut f: F,
39 xs: &[N],
40 params: &mut VectorN<N, V>,
41 mat: &mut DMatrix<N>,
42 h: N::RealField,
43) where
44 DefaultAllocator: Allocator<N, V>,
45{
46 let h = N::from_real(h);
47 let denom = N::one() / (N::from_i32(2).unwrap() * h);
48 for row in 0..mat.column(0).len() {
49 for col in 0..mat.row(0).len() {
50 params[col] += h;
51 let above = f(xs[row], ¶ms);
52 params[col] -= h;
53 params[col] -= h;
54 let below = f(xs[row], ¶ms);
55 mat[(row, col)] = denom * (above + below);
56 params[col] += h;
57 }
58 }
59}
60
61fn jac_analytic<N: ComplexField, V: DimName, F: FnMut(N, &VectorN<N, V>) -> VectorN<N, V>>(
63 mut jac: F,
64 xs: &[N],
65 params: &mut VectorN<N, V>,
66 mat: &mut DMatrix<N>,
67) where
68 DefaultAllocator: Allocator<N, V>,
69{
70 for row in 0..mat.column(0).len() {
71 let deriv = jac(xs[row], ¶ms);
72 for col in 0..mat.row(0).len() {
73 mat[(row, col)] = deriv[col];
74 }
75 }
76}
77
/// Tuning parameters for `curve_fit` and `curve_fit_jac`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct CurveFitParams<N: ComplexField> {
    /// Initial damping factor. The solvers scale each diagonal entry of
    /// `JᵀJ` by `1 + damping`. Rejected by the solvers unless its sign is
    /// positive.
    pub damping: N::RealField,
    /// Convergence tolerance: iteration stops once the sum of squared
    /// residuals changes by no more than this between consecutive steps.
    /// Rejected by the solvers unless its sign is positive.
    pub tolerance: N::RealField,
    /// Step size used for the finite-difference Jacobian in `curve_fit`.
    /// Not read by `curve_fit_jac`.
    pub h: N::RealField,
    /// Factor by which the damping is multiplied or divided when the solvers
    /// adjust it.
    pub damping_mult: N::RealField,
}
86
87impl<N: ComplexField> Default for CurveFitParams<N> {
88 fn default() -> Self {
89 CurveFitParams {
90 damping: N::RealField::from_f64(2.0).unwrap(),
91 tolerance: N::RealField::from_f64(1e-5).unwrap(),
92 h: N::RealField::from_f64(0.1).unwrap(),
93 damping_mult: N::RealField::from_f64(1.5).unwrap(),
94 }
95 }
96}
97
98pub fn curve_fit<N: ComplexField, V: DimName, F: FnMut(N, &VectorN<N, V>) -> N>(
105 mut f: F,
106 xs: &[N],
107 ys: &[N],
108 initial: &[N],
109 params: &CurveFitParams<N>,
110) -> Result<VectorN<N, V>, String>
111where
112 DefaultAllocator: Allocator<N, V>,
113{
114 let tol = params.tolerance;
115 let mut damping = params.damping;
116 let h = params.h;
117 let damping_mult = params.damping_mult;
118
119 if !tol.is_sign_positive() {
120 return Err("curve_fit: tol must be positive".to_owned());
121 }
122
123 if !h.is_sign_positive() {
124 return Err("curve_fit: h must be positive".to_owned());
125 }
126
127 if !damping.is_sign_positive() {
128 return Err("curve_fit: damping must be positive".to_owned());
129 }
130
131 if xs.len() != ys.len() {
132 return Err("curve_fit: xs length must match ys length".to_owned());
133 }
134
135 let mut params = VectorN::<N, V>::from_column_slice(initial);
136 let ys = DVector::<N>::from_column_slice(ys);
137 let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
138 jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
139 let mut jac_transpose = jac.transpose();
140
141 let mut resid = Vec::with_capacity(xs.len());
143 for (ind, &x) in xs.iter().enumerate() {
144 resid.push(ys[ind] - f(x, ¶ms));
145 }
146 let sum_sq_initial: N::RealField = resid
147 .iter()
148 .map(|&r| r.modulus_squared())
149 .fold(N::RealField::zero(), |acc, r| acc + r);
150
151 let mut sum_sq = sum_sq_initial + N::RealField::one();
153 let mut damping_tmp = damping / damping_mult;
154 let mut j = 0;
155 let mut evaluation: DVector<N> =
156 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
157 while sum_sq > sum_sq_initial && j < 1000 {
158 damping_tmp *= damping_mult;
159 let diff = &ys - &evaluation;
160 let mut b = &jac_transpose * &diff;
161 let mut multiplied = &jac_transpose * &jac;
163 for i in 0..multiplied.row(0).len() {
164 multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
165 }
166 let lu = multiplied.clone().lu();
167 let solved = lu.solve_mut(&mut b);
168 if !solved {
169 return Err("curve_fit: unable to solve linear equation".to_owned());
170 }
171 params += &b;
172 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
173 let diff = &ys - &evaluation;
174 sum_sq = diff
175 .iter()
176 .map(|&r| r.modulus_squared())
177 .fold(N::RealField::zero(), |acc, r| acc + r);
178 j += 1;
179 jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
180 jac_transpose = jac.transpose();
181 }
182 if j != 1000 {
183 damping = damping_tmp;
184 }
185
186 let mut last_sum_sq = sum_sq;
187 sum_sq += N::RealField::from_i32(2).unwrap() * tol;
188 while (last_sum_sq - sum_sq).abs() > tol {
189 last_sum_sq = sum_sq;
190 let diff = &ys - &evaluation;
192 let mut b = &jac_transpose * &diff;
193 let mut b_div = b.clone();
194 let mut multiplied = &jac_transpose * &jac;
196 let mut multiplied_div = multiplied.clone();
197 for i in 0..multiplied.row(0).len() {
198 multiplied[(i, i)] *= N::one() + N::from_real(damping);
199 }
200 let lu = multiplied.clone().lu();
202 let solved = lu.solve_mut(&mut b);
203 if !solved {
204 return Err("curve_fit: unable to solve linear equation".to_owned());
205 }
206 let new_params = ¶ms + &b;
207
208 for i in 0..multiplied_div.row(0).len() {
210 multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
211 }
212 let lu = multiplied_div.clone().lu();
213 let solved = lu.solve_mut(&mut b_div);
214 if !solved {
215 let lu = multiplied_div.clone().full_piv_lu();
216 let solved = lu.solve_mut(&mut b_div);
217 if !solved {
218 let qr = multiplied_div.qr();
219 let solved = qr.solve_mut(&mut b_div);
220 if !solved {
221 return Err("curve_fit: unable to solve linear equation".to_owned());
222 }
223 }
224 }
225 let new_params_div = ¶ms + &b_div;
226
227 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
229 let evaluation_div =
230 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
231 let diff = &ys - &evaluation;
232 let diff_div = &ys - &evaluation_div;
233
234 let resid: N::RealField = diff
235 .iter()
236 .map(|&r| r.modulus_squared())
237 .fold(N::RealField::zero(), |acc, r| acc + r);
238 let resid_div: N::RealField = diff_div
239 .iter()
240 .map(|&r| r.modulus_squared())
241 .fold(N::RealField::zero(), |acc, r| acc + r);
242
243 if resid_div < resid {
244 damping /= damping_mult;
245 evaluation = evaluation_div;
246 params = new_params_div;
247 sum_sq = resid_div;
248 } else {
249 params = new_params;
250 sum_sq = resid;
251 }
252
253 jac_finite_differences(&mut f, xs, &mut params, &mut jac, h);
254 jac_transpose = jac.transpose();
255 }
256
257 Ok(params)
258}
259
260pub fn curve_fit_jac<
267 N: ComplexField,
268 V: DimName,
269 F: FnMut(N, &VectorN<N, V>) -> N,
270 G: FnMut(N, &VectorN<N, V>) -> VectorN<N, V>,
271>(
272 mut f: F,
273 xs: &[N],
274 ys: &[N],
275 initial: &[N],
276 mut jacobian: G,
277 params: &CurveFitParams<N>,
278) -> Result<VectorN<N, V>, String>
279where
280 DefaultAllocator: Allocator<N, V>,
281{
282 let tol = params.tolerance;
283 let mut damping = params.damping;
284 let damping_mult = params.damping_mult;
285
286 if !tol.is_sign_positive() {
287 return Err("curve_fit_jac: tol must be positive".to_owned());
288 }
289
290 if !damping.is_sign_positive() {
291 return Err("curve_fit_jac: damping must be positive".to_owned());
292 }
293
294 if xs.len() != ys.len() {
295 return Err("curve_fit_jac: xs length must match ys length".to_owned());
296 }
297
298 let mut params = VectorN::<N, V>::from_column_slice(initial);
299 let ys = DVector::<N>::from_column_slice(ys);
300 let mut jac: DMatrix<N> = DMatrix::identity(xs.len(), params.len());
301 jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
302 let mut jac_transpose = jac.transpose();
303
304 let mut resid = Vec::with_capacity(xs.len());
306 for (ind, &x) in xs.iter().enumerate() {
307 resid.push(ys[ind] - f(x, ¶ms));
308 }
309 let sum_sq_initial: N::RealField = resid
310 .iter()
311 .map(|&r| r.modulus_squared())
312 .fold(N::RealField::zero(), |acc, r| acc + r);
313
314 let mut sum_sq = sum_sq_initial + N::RealField::one();
316 let mut damping_tmp = damping / damping_mult;
317 let mut j = 0;
318 let mut evaluation: DVector<N> =
319 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
320 while sum_sq > sum_sq_initial && j < 1000 {
321 damping_tmp *= damping_mult;
322 let diff = &ys - &evaluation;
323 let mut b = &jac_transpose * &diff;
324 let mut multiplied = &jac_transpose * &jac;
326 for i in 0..multiplied.row(0).len() {
327 multiplied[(i, i)] *= N::one() + N::from_real(damping_tmp);
328 }
329 let lu = multiplied.clone().lu();
330 let solved = lu.solve_mut(&mut b);
331 if !solved {
332 let lu = multiplied.clone().full_piv_lu();
333 let solved = lu.solve_mut(&mut b);
334 if !solved {
335 let qr = multiplied.qr();
336 let solved = qr.solve_mut(&mut b);
337 if !solved {
338 return Err("curve_fit_jac: unable to solve linear equation".to_owned());
339 }
340 }
341 }
342 params += &b;
343 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, ¶ms)));
344 let diff = &ys - &evaluation;
345 sum_sq = diff
346 .iter()
347 .map(|&r| r.modulus_squared())
348 .fold(N::RealField::zero(), |acc, r| acc + r);
349 j += 1;
350 jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
351 jac_transpose = jac.transpose();
352 }
353 if j != 1000 {
354 damping = damping_tmp;
355 }
356
357 let mut last_sum_sq = sum_sq;
358 sum_sq += N::RealField::from_i32(2).unwrap() * tol;
359 while (last_sum_sq - sum_sq).abs() > tol {
360 last_sum_sq = sum_sq;
361 let diff = &ys - &evaluation;
363 let mut b = &jac_transpose * &diff;
364 let mut b_div = b.clone();
365 let mut multiplied = &jac_transpose * &jac;
367 let mut multiplied_div = multiplied.clone();
368 for i in 0..multiplied.row(0).len() {
369 multiplied[(i, i)] *= N::one() + N::from_real(damping);
370 }
371 let lu = multiplied.clone().lu();
374 let solved = lu.solve_mut(&mut b);
375 if !solved {
376 let lu = multiplied.clone().full_piv_lu();
377 let solved = lu.solve_mut(&mut b);
378 if !solved {
379 let qr = multiplied.qr();
380 let solved = qr.solve_mut(&mut b);
381 if !solved {
382 return Err("curve_fit_jac: unable to solve linear equation".to_owned());
383 }
384 }
385 }
386 let new_params = ¶ms + &b;
387
388 for i in 0..multiplied_div.row(0).len() {
390 multiplied_div[(i, i)] *= N::one() + N::from_real(damping / damping_mult);
391 }
392 let lu = multiplied_div.clone().lu();
393 let solved = lu.solve_mut(&mut b_div);
394 if !solved {
395 let lu = multiplied_div.clone().full_piv_lu();
396 let solved = lu.solve_mut(&mut b_div);
397 if !solved {
398 let qr = multiplied_div.qr();
399 let solved = qr.solve_mut(&mut b_div);
400 if !solved {
401 return Err("curve_fit_jac: unable to solve linear equation".to_owned());
402 }
403 }
404 }
405 let new_params_div = ¶ms + &b_div;
406
407 evaluation = DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params)));
409 let evaluation_div =
410 DVector::from_iterator(xs.len(), xs.iter().map(|&x| f(x, &new_params_div)));
411 let diff = &ys - &evaluation;
412 let diff_div = &ys - &evaluation_div;
413
414 let resid: N::RealField = diff
415 .iter()
416 .map(|&r| r.modulus_squared())
417 .fold(N::RealField::zero(), |acc, r| acc + r);
418 let resid_div: N::RealField = diff_div
419 .iter()
420 .map(|&r| r.modulus_squared())
421 .fold(N::RealField::zero(), |acc, r| acc + r);
422
423 if resid_div < resid {
424 damping /= damping_mult;
425 evaluation = evaluation_div;
426 params = new_params_div;
427 sum_sq = resid_div;
428 } else {
429 params = new_params;
430 sum_sq = resid;
431 }
432
433 jac_analytic(&mut jacobian, xs, &mut params, &mut jac);
434 jac_transpose = jac.transpose();
435 }
436
437 Ok(params)
438}