1use crate::gauss::Rule;
7use crate::interpolation1d::{evaluate_legendre_basis, legendre_collocation_matrix};
8use crate::numeric::CustomNumeric;
9use mdarray::DTensor;
10use std::fmt::Debug;
11
/// Polynomial interpolant on a single rectangular cell
/// `[x_min, x_max] × [y_min, y_max]`, stored as a tensor-product Legendre expansion.
#[derive(Debug, Clone)]
pub struct Interpolate2D<T> {
    /// Lower bound of the cell along x.
    pub x_min: T,
    /// Upper bound of the cell along x.
    pub x_max: T,
    /// Lower bound of the cell along y.
    pub y_min: T,
    /// Upper bound of the cell along y.
    pub y_max: T,

    /// Expansion coefficients: entry `[i, j]` multiplies the i-th x basis value
    /// times the j-th y basis value (see `evaluate_2d_legendre_polynomial`).
    pub coeffs: DTensor<T, 2>,

    /// Gauss rule along x on the cell's own domain (not remapped); its endpoints
    /// are used to map query points onto the reference interval `[-1, 1]`.
    pub gauss_x: Rule<T>,
    /// Gauss rule along y on the cell's own domain (not remapped); its endpoints
    /// are used to map query points onto the reference interval `[-1, 1]`.
    pub gauss_y: Rule<T>,
}
31
32impl<T: CustomNumeric + Debug + 'static> Interpolate2D<T> {
33 pub fn new(values: &DTensor<T, 2>, gauss_x: &Rule<T>, gauss_y: &Rule<T>) -> Self {
43 let shape = *values.shape();
44 assert!(
45 shape.0 > 0 && shape.1 > 0,
46 "Cannot create interpolation from empty grid"
47 );
48 assert_eq!(
49 shape.0,
50 gauss_x.x.len(),
51 "Values height must match gauss_x length"
52 );
53 assert_eq!(
54 shape.1,
55 gauss_y.x.len(),
56 "Values width must match gauss_y length"
57 );
58
59 let normalized_gauss_x =
62 gauss_x.reseat(T::from_f64_unchecked(-1.0), T::from_f64_unchecked(1.0));
63 let normalized_gauss_y =
64 gauss_y.reseat(T::from_f64_unchecked(-1.0), T::from_f64_unchecked(1.0));
65
66 let coeffs = interpolate_2d_legendre(values, &normalized_gauss_x, &normalized_gauss_y);
67
68 Self {
69 x_min: gauss_x.a,
70 x_max: gauss_x.b,
71 y_min: gauss_y.a,
72 y_max: gauss_y.b,
73 coeffs,
74 gauss_x: gauss_x.clone(),
75 gauss_y: gauss_y.clone(),
76 }
77 }
78
79 pub fn interpolate(&self, x: T, y: T) -> T {
88 assert!(
89 x >= self.x_min && x <= self.x_max,
90 "x={} is outside cell bounds [{}, {}]",
91 x,
92 self.x_min,
93 self.x_max
94 );
95 assert!(
96 y >= self.y_min && y <= self.y_max,
97 "y={} is outside cell bounds [{}, {}]",
98 y,
99 self.y_min,
100 self.y_max
101 );
102
103 evaluate_2d_legendre_polynomial(x, y, &self.coeffs, &self.gauss_x, &self.gauss_y)
104 }
105
106 pub fn coefficients(&self) -> &DTensor<T, 2> {
108 &self.coeffs
109 }
110
111 pub fn bounds(&self) -> (T, T, T, T) {
113 (self.x_min, self.x_max, self.y_min, self.y_max)
114 }
115
116 pub fn domain(&self) -> (T, T, T, T) {
118 self.bounds()
119 }
120
121 pub fn n_points_x(&self) -> usize {
123 self.coeffs.shape().0
124 }
125
126 pub fn n_points_y(&self) -> usize {
128 self.coeffs.shape().1
129 }
130
131 pub fn evaluate(&self, x: T, y: T) -> T {
133 self.interpolate(x, y)
134 }
135}
136
137pub fn interpolate_2d_legendre<T: CustomNumeric + 'static>(
151 values: &DTensor<T, 2>,
152 gauss_x: &Rule<T>,
153 gauss_y: &Rule<T>,
154) -> DTensor<T, 2> {
155 let n_x = gauss_x.x.len();
156 let n_y = gauss_y.x.len();
157
158 let shape = *values.shape();
159 assert_eq!(shape.0, n_x, "Values matrix rows must match x grid points");
160 assert_eq!(shape.1, n_y, "Values matrix cols must match y grid points");
161
162 let collocation_x = legendre_collocation_matrix(gauss_x);
164 let collocation_y = legendre_collocation_matrix(gauss_y);
165
166 let mut temp = DTensor::<T, 2>::from_elem([n_x, n_y], T::zero());
169 for i in 0..n_x {
170 for j in 0..n_y {
171 for k in 0..n_x {
172 temp[[i, j]] = temp[[i, j]] + collocation_x[[i, k]] * values[[k, j]];
173 }
174 }
175 }
176
177 let mut coeffs = DTensor::<T, 2>::from_elem([n_x, n_y], T::zero());
178 for i in 0..n_x {
179 for j in 0..n_y {
180 for k in 0..n_y {
181 coeffs[[i, j]] = coeffs[[i, j]] + temp[[i, k]] * collocation_y[[j, k]];
182 }
183 }
184 }
185
186 coeffs
187}
188
189pub fn evaluate_2d_legendre_polynomial<T: CustomNumeric>(
191 x: T,
192 y: T,
193 coeffs: &DTensor<T, 2>,
194 gauss_x: &Rule<T>,
195 gauss_y: &Rule<T>,
196) -> T {
197 let shape = *coeffs.shape();
198 let n_x = shape.0;
199 let n_y = shape.1;
200
201 let x_norm = T::from_f64_unchecked(2.0) * (x - gauss_x.a) / (gauss_x.b - gauss_x.a)
203 - T::from_f64_unchecked(1.0);
204 let y_norm = T::from_f64_unchecked(2.0) * (y - gauss_y.a) / (gauss_y.b - gauss_y.a)
205 - T::from_f64_unchecked(1.0);
206
207 let p_x = evaluate_legendre_basis(x_norm, n_x);
209 let p_y = evaluate_legendre_basis(y_norm, n_y);
210
211 let mut result = T::zero();
213 for i in 0..n_x {
214 for j in 0..n_y {
215 result = result + coeffs[[i, j]] * p_x[i] * p_y[j];
216 }
217 }
218
219 result
220}
221
// Unit tests are kept in a separate file (`interpolation2d_tests.rs`) and are
// compiled only for test builds.
#[cfg(test)]
#[path = "interpolation2d_tests.rs"]
mod tests;