1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
12use scirs2_core::numeric::{Float, FromPrimitive};
13use std::fmt::Debug;
14use std::marker::PhantomData;
15
16use crate::error::{InterpolateError, InterpolateResult};
17
/// Kernel used to weight each training point by its scaled distance `r = distance / bandwidth`
/// from the query point.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum WeightFunction {
    /// Gaussian kernel `exp(-r^2)`; never exactly zero, so every point contributes.
    Gaussian,

    /// Wendland C2 kernel `(1 - r)^4 (4r + 1)` for `r < 1`, zero otherwise (compact support).
    WendlandC2,

    /// Inverse-distance kernel `1 / (epsilon + r^2)`, where `epsilon` is the interpolator's
    /// regularization term.
    InverseDistance,

    /// Piecewise cubic spline kernel with compact support: zero for `r >= 1`.
    CubicSpline,
}
37
/// Polynomial space fitted locally around each query point.
///
/// For `d` spatial dimensions the number of basis terms is `1`, `d + 1` and
/// `(d + 1)(d + 2) / 2` respectively.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PolynomialBasis {
    /// Only the constant term `1` (a locally weighted average).
    Constant,

    /// Constant plus one linear term per dimension: `1, x_1, ..., x_d`.
    Linear,

    /// Linear terms plus every product `x_j * x_k` with `j <= k`.
    Quadratic,
}
50
51#[derive(Debug, Clone)]
90pub struct MovingLeastSquares<F>
91where
92 F: Float + FromPrimitive + Debug + 'static + std::cmp::PartialOrd,
93{
94 points: Array2<F>,
96
97 values: Array1<F>,
99
100 weight_fn: WeightFunction,
102
103 basis: PolynomialBasis,
105
106 bandwidth: F,
108
109 epsilon: F,
111
112 max_points: Option<usize>,
114
115 _phantom: PhantomData<F>,
117}
118
119impl<F> MovingLeastSquares<F>
120where
121 F: Float + FromPrimitive + Debug + 'static + std::cmp::PartialOrd,
122{
123 pub fn new(
137 points: Array2<F>,
138 values: Array1<F>,
139 weight_fn: WeightFunction,
140 basis: PolynomialBasis,
141 bandwidth: F,
142 ) -> InterpolateResult<Self> {
143 if points.shape()[0] != values.len() {
145 return Err(InterpolateError::DimensionMismatch(
146 "Number of points must match number of values".to_string(),
147 ));
148 }
149
150 if points.shape()[0] < 2 {
151 return Err(InterpolateError::InvalidValue(
152 "At least 2 points are required for MLS interpolation".to_string(),
153 ));
154 }
155
156 if bandwidth <= F::zero() {
157 return Err(InterpolateError::InvalidValue(
158 "Bandwidth parameter must be positive".to_string(),
159 ));
160 }
161
162 Ok(Self {
163 points,
164 values,
165 weight_fn,
166 basis,
167 bandwidth,
168 epsilon: F::from_f64(1e-10).unwrap(),
169 max_points: None,
170 _phantom: PhantomData,
171 })
172 }
173
174 pub fn with_max_points(mut self, maxpoints: usize) -> Self {
187 self.max_points = Some(maxpoints);
188 self
189 }
190
191 pub fn with_epsilon(mut self, epsilon: F) -> Self {
201 self.epsilon = epsilon;
202 self
203 }
204
205 pub fn evaluate(&self, x: &ArrayView1<F>) -> InterpolateResult<F> {
215 if x.len() != self.points.shape()[1] {
217 return Err(InterpolateError::DimensionMismatch(
218 "Query point dimension must match training points".to_string(),
219 ));
220 }
221
222 let (indices, distances) = self.find_relevant_points(x)?;
224
225 if indices.is_empty() {
226 return Err(InterpolateError::invalid_input(
227 "No points found within effective range".to_string(),
228 ));
229 }
230
231 let weights = self.compute_weights(&distances)?;
233
234 let basis_functions = self.create_basis_functions(&indices, x)?;
236
237 let result = self.solve_weighted_least_squares(&indices, &weights, &basis_functions, x)?;
239
240 Ok(result)
241 }
242
243 pub fn evaluate_multi(&self, points: &ArrayView2<F>) -> InterpolateResult<Array1<F>> {
253 if points.shape()[1] != self.points.shape()[1] {
255 return Err(InterpolateError::DimensionMismatch(
256 "Query points dimension must match training points".to_string(),
257 ));
258 }
259
260 let n_points = points.shape()[0];
261 let mut results = Array1::zeros(n_points);
262
263 for i in 0..n_points {
265 let point = points.slice(scirs2_core::ndarray::s![i, ..]);
266 results[i] = self.evaluate(&point)?;
267 }
268
269 Ok(results)
270 }
271
272 fn find_relevant_points(&self, x: &ArrayView1<F>) -> InterpolateResult<(Vec<usize>, Vec<F>)> {
276 let n_points = self.points.shape()[0];
277 let n_dims = self.points.shape()[1];
278
279 let mut distances = Vec::with_capacity(n_points);
281 for i in 0..n_points {
282 let mut d_squared = F::zero();
283 for j in 0..n_dims {
284 let diff = x[j] - self.points[[i, j]];
285 d_squared = d_squared + diff * diff;
286 }
287 let dist = d_squared.sqrt();
288 distances.push((i, dist));
289 }
290
291 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
293
294 let limit = match self.max_points {
296 Some(limit) => std::cmp::min(limit, n_points),
297 None => n_points,
298 };
299
300 let effective_radius = match self.weight_fn {
302 WeightFunction::WendlandC2 | WeightFunction::CubicSpline => self.bandwidth,
303 _ => F::infinity(),
304 };
305
306 let mut indices = Vec::new();
307 let mut dist_values = Vec::new();
308
309 for &(idx, dist) in distances.iter().take(limit) {
310 if dist <= effective_radius {
311 indices.push(idx);
312 dist_values.push(dist);
313 }
314 }
315
316 let min_points = match self.basis {
318 PolynomialBasis::Constant => 1,
319 PolynomialBasis::Linear => n_dims + 1,
320 PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
321 };
322
323 if indices.len() < min_points {
324 indices = distances
326 .iter()
327 .take(min_points)
328 .map(|&(idx, _)| idx)
329 .collect();
330 dist_values = distances
331 .iter()
332 .take(min_points)
333 .map(|&(_, dist)| dist)
334 .collect();
335 }
336
337 Ok((indices, dist_values))
338 }
339
340 fn compute_weights(&self, distances: &[F]) -> InterpolateResult<Array1<F>> {
342 let n = distances.len();
343 let mut weights = Array1::zeros(n);
344
345 for (i, &d) in distances.iter().enumerate() {
346 let r = d / self.bandwidth;
348
349 let weight = match self.weight_fn {
351 WeightFunction::Gaussian => (-r * r).exp(),
352 WeightFunction::WendlandC2 => {
353 if r < F::one() {
354 let t = F::one() - r;
355 let factor = F::from_f64(4.0).unwrap() * r + F::one();
356 t.powi(4) * factor
357 } else {
358 F::zero()
359 }
360 }
361 WeightFunction::InverseDistance => F::one() / (self.epsilon + r * r),
362 WeightFunction::CubicSpline => {
363 if r < F::from_f64(1.0 / 3.0).unwrap() {
364 let r2 = r * r;
365 let r3 = r2 * r;
366 F::from_f64(2.0 / 3.0).unwrap() - F::from_f64(9.0).unwrap() * r2
367 + F::from_f64(19.0).unwrap() * r3
368 } else if r < F::one() {
369 let t = F::from_f64(2.0).unwrap() - F::from_f64(3.0).unwrap() * r;
370 F::from_f64(1.0 / 3.0).unwrap() * t.powi(3)
371 } else {
372 F::zero()
373 }
374 }
375 };
376
377 weights[i] = weight;
378 }
379
380 let sum = weights.sum();
382 if sum > F::zero() {
383 weights.mapv_inplace(|w| w / sum);
384 } else {
385 weights.fill(F::from_f64(1.0 / n as f64).unwrap());
387 }
388
389 Ok(weights)
390 }
391
392 fn create_basis_functions(
394 &self,
395 indices: &[usize],
396 x: &ArrayView1<F>,
397 ) -> InterpolateResult<Array2<F>> {
398 let n_points = indices.len();
399 let n_dims = x.len();
400
401 let n_basis = match self.basis {
403 PolynomialBasis::Constant => 1,
404 PolynomialBasis::Linear => n_dims + 1,
405 PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
406 };
407
408 let mut basis = Array2::zeros((n_points, n_basis));
409
410 for (i, &idx) in indices.iter().enumerate() {
412 let point = self.points.row(idx);
413 let mut col = 0;
414
415 basis[[i, col]] = F::one();
417 col += 1;
418
419 if self.basis == PolynomialBasis::Linear || self.basis == PolynomialBasis::Quadratic {
420 for j in 0..n_dims {
422 basis[[i, col]] = point[j];
423 col += 1;
424 }
425 }
426
427 if self.basis == PolynomialBasis::Quadratic {
428 for j in 0..n_dims {
430 for k in j..n_dims {
431 basis[[i, col]] = point[j] * point[k];
432 col += 1;
433 }
434 }
435 }
436 }
437
438 Ok(basis)
439 }
440
441 fn create_query_basis(&self, x: &ArrayView1<F>) -> InterpolateResult<Array1<F>> {
443 let n_dims = x.len();
444
445 let n_basis = match self.basis {
447 PolynomialBasis::Constant => 1,
448 PolynomialBasis::Linear => n_dims + 1,
449 PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
450 };
451
452 let mut basis = Array1::zeros(n_basis);
453 let mut col = 0;
454
455 basis[col] = F::one();
457 col += 1;
458
459 if self.basis == PolynomialBasis::Linear || self.basis == PolynomialBasis::Quadratic {
460 for j in 0..n_dims {
462 basis[col] = x[j];
463 col += 1;
464 }
465 }
466
467 if self.basis == PolynomialBasis::Quadratic {
468 for j in 0..n_dims {
470 for k in j..n_dims {
471 basis[col] = x[j] * x[k];
472 col += 1;
473 }
474 }
475 }
476
477 Ok(basis)
478 }
479
480 fn solve_weighted_least_squares(
482 &self,
483 indices: &[usize],
484 weights: &Array1<F>,
485 basis: &Array2<F>,
486 x: &ArrayView1<F>,
487 ) -> InterpolateResult<F> {
488 let n_points = indices.len();
489 let n_basis = basis.shape()[1];
490
491 let mut w_basis = Array2::zeros((n_points, n_basis));
493 let mut w_values = Array1::zeros(n_points);
494
495 for i in 0..n_points {
496 let sqrt_w = weights[i].sqrt();
497 for j in 0..n_basis {
498 w_basis[[i, j]] = basis[[i, j]] * sqrt_w;
499 }
500 w_values[i] = self.values[indices[i]] * sqrt_w;
501 }
502
503 #[cfg(feature = "linalg")]
505 let btb = w_basis.t().dot(&w_basis);
506 #[cfg(not(feature = "linalg"))]
507 let _btb = w_basis.t().dot(&w_basis);
508 #[allow(unused_variables)]
509 let bty = w_basis.t().dot(&w_values);
510
511 #[cfg(feature = "linalg")]
513 let coeffs = {
514 use scirs2_linalg::solve;
515 let btb_f64 = btb.mapv(|x| x.to_f64().unwrap());
516 let bty_f64 = bty.mapv(|x| x.to_f64().unwrap());
517 match solve(&btb_f64.view(), &bty_f64.view(), None) {
518 Ok(c) => c.mapv(|x| F::from_f64(x).unwrap()),
519 Err(_) => {
520 let mut mean = F::zero();
522 let mut sum_weights = F::zero();
523 for (i, &idx) in indices.iter().enumerate() {
524 mean = mean + weights[i] * self.values[idx];
525 sum_weights = sum_weights + weights[i];
526 }
527
528 if sum_weights > F::zero() {
529 let mut fallback_coeffs = Array1::zeros(bty.len());
532 fallback_coeffs[0] = mean / sum_weights;
533 fallback_coeffs
534 } else {
535 return Err(InterpolateError::ComputationError(
536 "Failed to solve weighted least squares system".to_string(),
537 ));
538 }
539 }
540 }
541 };
542
543 #[cfg(not(feature = "linalg"))]
544 let coeffs = {
545 let mut result = Array1::zeros(bty.len());
548
549 let mut mean = F::zero();
551 let mut sum_weights = F::zero();
552 for (i, &idx) in indices.iter().enumerate() {
553 mean = mean + weights[i] * self.values[idx];
554 sum_weights = sum_weights + weights[i];
555 }
556
557 if sum_weights > F::zero() {
558 result[0] = mean / sum_weights;
559 }
560
561 result
562 };
563
564 let query_basis = self.create_query_basis(x)?;
566 let result = query_basis.dot(&coeffs);
567
568 Ok(result)
569 }
570
571 pub fn weight_fn(&self) -> WeightFunction {
573 self.weight_fn
574 }
575
576 pub fn bandwidth(&self) -> F {
578 self.bandwidth
579 }
580
581 pub fn points(&self) -> &Array2<F> {
583 &self.points
584 }
585
586 pub fn values(&self) -> &Array1<F> {
588 &self.values
589 }
590
591 pub fn basis(&self) -> PolynomialBasis {
593 self.basis
594 }
595
596 pub fn max_points(&self) -> Option<usize> {
598 self.max_points
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// A constant basis yields a weighted average; at the centre of the unit square the
    /// symmetric corner values 0, 1, 1, 2 average to 1.
    #[test]
    fn test_mls_constant_basis() {
        let points =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();

        // f(x, y) = x + y sampled at the corners.
        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);

        let mls = MovingLeastSquares::new(
            points,
            values,
            WeightFunction::Gaussian,
            PolynomialBasis::Constant,
            0.5,
        )
        .unwrap();

        let center = array![0.5, 0.5];
        // Previously written as a mis-encoded `¢er`, which did not compile.
        let val = mls.evaluate(&center.view()).unwrap();

        assert_abs_diff_eq!(val, 1.0, epsilon = 0.1);
    }

    /// A linear basis should approximate the linear field f(x, y) = x + y.
    #[test]
    fn test_mls_linear_basis() {
        let points =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();

        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);

        let mls = MovingLeastSquares::new(
            points,
            values,
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            1.0,
        )
        .unwrap();

        let test_points = Array2::from_shape_vec(
            (5, 2),
            vec![
                0.5, 0.5, // centre
                0.25, 0.25, //
                0.75, 0.25, //
                0.25, 0.75, //
                0.75, 0.75, //
            ],
        )
        .unwrap();

        let expected = Array1::from_vec(vec![1.0, 0.5, 1.0, 1.0, 1.5]);
        let results = mls.evaluate_multi(&test_points.view()).unwrap();

        for (result, expect) in results.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(result, expect, epsilon = 0.5);
        }
    }

    /// Non-compact kernels with a linear basis should all recover f(0.5, 0.5) = 1.
    #[test]
    fn test_different_weight_functions() {
        let points = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.3, 0.3, 0.7, 0.7],
        )
        .unwrap();

        // f(x, y) = x + y at each point.
        let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0, 0.6, 1.4]);

        let weight_fns = [WeightFunction::Gaussian, WeightFunction::InverseDistance];

        let query = array![0.5, 0.5];
        let expected = 0.5 + 0.5;

        for &weight_fn in &weight_fns {
            let mls = MovingLeastSquares::new(
                points.clone(),
                values.clone(),
                weight_fn,
                PolynomialBasis::Linear,
                2.0,
            )
            .unwrap();

            let result = mls.evaluate(&query.view());

            match result {
                Ok(val) => {
                    if val.is_finite() {
                        assert!(
                            (val - expected).abs() < 0.5,
                            "Weight function {:?}: result {:.6} differs too much from expected {:.6}",
                            weight_fn,
                            val,
                            expected
                        );
                    } else {
                        panic!(
                            "Weight function {:?} produced non-finite result: {}",
                            weight_fn, val
                        );
                    }
                }
                Err(e) => {
                    panic!("Weight function {:?} failed with error: {}", weight_fn, e);
                }
            }
        }
    }
}