1use crate::axis;
2use crate::bound;
3use crate::common;
4use crate::search;
5use crate::Error;
6use std::ops::{Add, Div, Mul, Sub};
7
8use ndarray::Array2;
9
10pub struct LookupTable2D<Axis1, Axis2, Dep>
11where
12 Axis1: axis::AxisImpl,
13 Axis2: axis::AxisImpl,
14{
15 indep1: Vec<<Axis1 as axis::AxisImpl>::Indep>,
16 search1: <Axis1 as axis::AxisImpl>::Search,
17 indep2: Vec<<Axis2 as axis::AxisImpl>::Indep>,
18 search2: <Axis2 as axis::AxisImpl>::Search,
19 dep: Array2<Dep>,
20}
21
22impl<Indep1, Search1, LowerBound1, UpperBound1, Indep2, Search2, LowerBound2, UpperBound2, Dep>
23 LookupTable2D<
24 axis::Axis<Indep1, Search1, LowerBound1, UpperBound1>,
25 axis::Axis<Indep2, Search2, LowerBound2, UpperBound2>,
26 Dep,
27 >
28where
29 Indep1: std::cmp::PartialOrd,
30 Indep2: std::cmp::PartialOrd,
31{
32 pub fn new(
33 mut indep1: Vec<Indep1>,
34 search1: Search1,
35 mut indep2: Vec<Indep2>,
36 search2: Search2,
37 mut dep: Array2<Dep>,
38 ) -> Result<Self, Error> {
39 match common::check_independent_variable(indep1.as_slice())? {
40 common::IndependentVariableOrdering::MonotonicallyIncreasing => {}
41 common::IndependentVariableOrdering::MonotonicallyDecreasing => {
42 indep1.reverse();
43 dep.invert_axis(ndarray::Axis(0));
44 dbg!("reversing");
45 }
46 }
47
48 match common::check_independent_variable(indep2.as_slice())? {
49 common::IndependentVariableOrdering::MonotonicallyIncreasing => {}
50 common::IndependentVariableOrdering::MonotonicallyDecreasing => {
51 indep2.reverse();
52 dep.invert_axis(ndarray::Axis(1));
53 dbg!("reversing");
54 }
55 }
56
57 common::check_lengths(indep1.len(), dep.len_of(ndarray::Axis(0)))?;
58 common::check_lengths(indep2.len(), dep.len_of(ndarray::Axis(1)))?;
59
60 Ok(Self {
61 indep1,
62 search1,
63 indep2,
64 search2,
65 dep,
66 })
67 }
68}
69
70impl<Indep1, Search1, LowerBound1, UpperBound1, Indep2, Search2, LowerBound2, UpperBound2, Dep>
71 LookupTable2D<
72 axis::Axis<Indep1, Search1, LowerBound1, UpperBound1>,
73 axis::Axis<Indep2, Search2, LowerBound2, UpperBound2>,
74 Dep,
75 >
76where
77 Search1: search::Search<Indep1>,
78 Search2: search::Search<Indep2>,
79 Dep: Copy
81 + Sub<Dep, Output = Dep>
82 + Div<Indep1, Output = Dep>
83 + Mul<Indep1, Output = Dep>
84 + Mul<Indep2, Output = Dep>
85 + Add<Dep, Output = Dep>
86 + std::fmt::Debug,
87 Indep1: Copy
88 + Sub<Indep1, Output = Indep1>
89 + std::cmp::PartialOrd
90 + Div<Indep1, Output = Indep1>
91 + std::fmt::Debug,
93 Indep2: Copy
94 + Sub<Indep2, Output = Indep2>
95 + std::cmp::PartialOrd
96 + Div<Indep2, Output = Indep2>
97 + std::fmt::Debug,
99 LowerBound1: bound::Bound<Indep1>,
100 UpperBound1: bound::Bound<Indep1>,
101 LowerBound2: bound::Bound<Indep2>,
102 UpperBound2: bound::Bound<Indep2>,
103{
104 pub fn lookup(&self, x: Indep1, y: Indep2) -> Dep {
105 let (idx_x_1, idx_x_2) = self.search1.search(x, self.indep1.as_slice());
106 let (idx_y_1, idx_y_2) = self.search2.search(y, self.indep2.as_slice());
107
108 let x_1: Indep1 = self.indep1[idx_x_1];
109 let x_2: Indep1 = self.indep1[idx_x_2];
110
111 let y_1: Indep2 = self.indep2[idx_y_1];
112 let y_2: Indep2 = self.indep2[idx_y_2];
113
114 let f_1_1: Dep = self.dep[[idx_x_1, idx_y_1]];
115 let f_1_2: Dep = self.dep[[idx_x_1, idx_y_2]];
116 let f_2_1: Dep = self.dep[[idx_x_2, idx_y_1]];
117 let f_2_2: Dep = self.dep[[idx_x_2, idx_y_2]];
118
119 let x = LowerBound1::lower_bound(x, *self.indep1.first().unwrap());
122 let x = UpperBound1::upper_bound(x, *self.indep1.last().unwrap());
123
124 let y = LowerBound2::lower_bound(y, *self.indep2.first().unwrap());
127 let y = UpperBound2::upper_bound(y, *self.indep2.last().unwrap());
128
129 let x_slope1 = (x_2 - x) / (x_2 - x_1);
130 let x_slope2 = (x - x_1) / (x_2 - x_1);
131 let y_slope1 = (y_2 - y) / (y_2 - y_1);
132 let y_slope2 = (y - y_1) / (y_2 - y_1);
133
134 let f_x_y1 = f_1_1 * x_slope1 + f_2_1 * x_slope2;
135 let f_x_y2 = f_1_2 * x_slope1 + f_2_2 * x_slope2;
136
137 f_x_y1 * y_slope1 + f_x_y2 * y_slope2
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    type LinearAxis = axis::Axis<f64, search::Linear, bound::Interp, bound::Interp>;
    type TableLinLin = LookupTable2D<LinearAxis, LinearAxis, f64>;

    /// Reference function sampled into the tables. It is linear in both
    /// arguments, so bilinear interpolation should reproduce it to rounding.
    fn func(x: f64, y: f64) -> f64 {
        3. * x + y
    }

    /// Returns two identical ascending axes on `[0, 5]` and the grid
    /// `f[[i, j]] = func(x[i], y[j])`.
    fn data() -> (Vec<f64>, Vec<f64>, Array2<f64>) {
        let len = 100;
        let x: Vec<f64> = ndarray::Array1::linspace(0., 5.0, len).to_vec();
        let y = x.clone();
        let f = Array2::from_shape_fn((x.len(), y.len()), |(i, j)| func(x[i], y[j]));

        (x, y, f)
    }

    /// Table built from `data()` with linear search on both axes.
    fn linear_simple_table() -> TableLinLin {
        let (x, y, f) = data();
        let search1 = search::Linear::default();
        let search2 = search::Linear::default();
        LookupTable2D::new(x, search1, y, search2, f).unwrap()
    }

    /// A repeated value on the first axis must be rejected.
    #[test]
    fn construct_table_repeated_entries_1() {
        let (mut x, y, f) = data();

        x[0] = 0.;
        x[1] = 0.;

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    /// A repeated value on the second axis must be rejected.
    #[test]
    fn construct_table_repeated_entries_2() {
        let (x, mut y, f) = data();

        y[0] = 0.;
        y[1] = 0.;

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    /// A first axis longer than dimension 0 of the data must be rejected.
    #[test]
    fn construct_table_repeated_mismatch_length_1() {
        let (mut x, y, f) = data();

        x.push(100.);

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    /// A second axis longer than dimension 1 of the data must be rejected.
    #[test]
    fn construct_table_repeated_mismatch_length_2() {
        let (x, mut y, f) = data();

        y.push(100.0);

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    /// A descending first axis is accepted, and the data stays attached to
    /// the grid points it was supplied with.
    #[test]
    fn construct_table_reversed_ax1() {
        let (mut x, y, f) = data();

        x.reverse();

        let i_0 = 2 * x.len() / 3;
        let x_0 = x[i_0];
        let y_0 = y[y.len() / 3];

        // `f` was sampled on the ascending grid and is left untouched, so
        // pairing it with the descending axis mirrors the data: the row that
        // belongs to `x[i_0]` holds `func` of the mirrored grid point.
        let x_mirrored = x[x.len() - 1 - i_0];
        let f_unmirrored = func(x_0, y_0);
        let f_expected = func(x_mirrored, y_0);

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let table: TableLinLin = LookupTable2D::new(x, search1, y, search2, f).unwrap();

        let table_reversed_output = table.lookup(x_0, y_0);

        float_eq::assert_float_ne!(table_reversed_output, f_unmirrored, abs <= TOL);
        // Inequality alone would also accept garbage (or NaN); pin the value.
        float_eq::assert_float_eq!(table_reversed_output, f_expected, abs <= TOL);
    }

    /// A descending second axis is accepted, and the data stays attached to
    /// the grid points it was supplied with.
    #[test]
    fn construct_table_reversed_ax2() {
        let (x, mut y, f) = data();

        y.reverse();

        let j_0 = y.len() / 3;
        let y_0 = y[j_0];
        let x_0 = x[2 * x.len() / 3];

        // Same reasoning as for the first axis: the column that belongs to
        // `y[j_0]` holds `func` of the mirrored grid point.
        let y_mirrored = y[y.len() - 1 - j_0];
        let f_unmirrored = func(x_0, y_0);
        let f_expected = func(x_0, y_mirrored);

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let table: TableLinLin = LookupTable2D::new(x, search1, y, search2, f).unwrap();

        let table_reversed_output = table.lookup(x_0, y_0);

        float_eq::assert_float_ne!(table_reversed_output, f_unmirrored, abs <= TOL);
        // Inequality alone would also accept garbage (or NaN); pin the value.
        float_eq::assert_float_eq!(table_reversed_output, f_expected, abs <= TOL);
    }

    /// Interior point of the grid.
    #[test]
    fn linear_1() {
        let table = linear_simple_table();
        let x = 0.5;
        let y = 2.5;
        let output = table.lookup(x, y);
        float_eq::assert_float_eq!(output, func(x, y), abs <= TOL);
    }

    /// Lower corner of the grid.
    #[test]
    fn linear_low_bound() {
        let table = linear_simple_table();
        let x = 0.;
        let y = 0.;
        let output = table.lookup(x, y);
        float_eq::assert_float_eq!(output, func(x, y), abs <= TOL);
    }

    /// Upper corner of the grid.
    #[test]
    fn linear_high_bound() {
        let table = linear_simple_table();
        let x = 5.0;
        let y = 5.0;
        let output = table.lookup(x, y);
        float_eq::assert_float_eq!(output, func(x, y), abs <= TOL);
    }
}