lookup_tables/
table2d.rs

1use crate::axis;
2use crate::bound;
3use crate::common;
4use crate::search;
5use crate::Error;
6use std::ops::{Add, Div, Mul, Sub};
7
8use ndarray::Array2;
9
10pub struct LookupTable2D<Axis1, Axis2, Dep>
11where
12    Axis1: axis::AxisImpl,
13    Axis2: axis::AxisImpl,
14{
15    indep1: Vec<<Axis1 as axis::AxisImpl>::Indep>,
16    search1: <Axis1 as axis::AxisImpl>::Search,
17    indep2: Vec<<Axis2 as axis::AxisImpl>::Indep>,
18    search2: <Axis2 as axis::AxisImpl>::Search,
19    dep: Array2<Dep>,
20}
21
22impl<Indep1, Search1, LowerBound1, UpperBound1, Indep2, Search2, LowerBound2, UpperBound2, Dep>
23    LookupTable2D<
24        axis::Axis<Indep1, Search1, LowerBound1, UpperBound1>,
25        axis::Axis<Indep2, Search2, LowerBound2, UpperBound2>,
26        Dep,
27    >
28where
29    Indep1: std::cmp::PartialOrd,
30    Indep2: std::cmp::PartialOrd,
31{
32    pub fn new(
33        mut indep1: Vec<Indep1>,
34        search1: Search1,
35        mut indep2: Vec<Indep2>,
36        search2: Search2,
37        mut dep: Array2<Dep>,
38    ) -> Result<Self, Error> {
39        match common::check_independent_variable(indep1.as_slice())? {
40            common::IndependentVariableOrdering::MonotonicallyIncreasing => {}
41            common::IndependentVariableOrdering::MonotonicallyDecreasing => {
42                indep1.reverse();
43                dep.invert_axis(ndarray::Axis(0));
44                dbg!("reversing");
45            }
46        }
47
48        match common::check_independent_variable(indep2.as_slice())? {
49            common::IndependentVariableOrdering::MonotonicallyIncreasing => {}
50            common::IndependentVariableOrdering::MonotonicallyDecreasing => {
51                indep2.reverse();
52                dep.invert_axis(ndarray::Axis(1));
53                dbg!("reversing");
54            }
55        }
56
57        common::check_lengths(indep1.len(), dep.len_of(ndarray::Axis(0)))?;
58        common::check_lengths(indep2.len(), dep.len_of(ndarray::Axis(1)))?;
59
60        Ok(Self {
61            indep1,
62            search1,
63            indep2,
64            search2,
65            dep,
66        })
67    }
68}
69
70impl<Indep1, Search1, LowerBound1, UpperBound1, Indep2, Search2, LowerBound2, UpperBound2, Dep>
71    LookupTable2D<
72        axis::Axis<Indep1, Search1, LowerBound1, UpperBound1>,
73        axis::Axis<Indep2, Search2, LowerBound2, UpperBound2>,
74        Dep,
75    >
76where
77    Search1: search::Search<Indep1>,
78    Search2: search::Search<Indep2>,
79    // TODO: HoldHigh / HoldLow does not require so many strict bounds
80    Dep: Copy
81        + Sub<Dep, Output = Dep>
82        + Div<Indep1, Output = Dep>
83        + Mul<Indep1, Output = Dep>
84        + Mul<Indep2, Output = Dep>
85        + Add<Dep, Output = Dep>
86        + std::fmt::Debug,
87    Indep1: Copy
88        + Sub<Indep1, Output = Indep1>
89        + std::cmp::PartialOrd
90        + Div<Indep1, Output = Indep1>
91        //
92        + std::fmt::Debug,
93    Indep2: Copy
94        + Sub<Indep2, Output = Indep2>
95        + std::cmp::PartialOrd
96        + Div<Indep2, Output = Indep2>
97        //
98        + std::fmt::Debug,
99    LowerBound1: bound::Bound<Indep1>,
100    UpperBound1: bound::Bound<Indep1>,
101    LowerBound2: bound::Bound<Indep2>,
102    UpperBound2: bound::Bound<Indep2>,
103{
104    pub fn lookup(&self, x: Indep1, y: Indep2) -> Dep {
105        let (idx_x_1, idx_x_2) = self.search1.search(x, self.indep1.as_slice());
106        let (idx_y_1, idx_y_2) = self.search2.search(y, self.indep2.as_slice());
107
108        let x_1: Indep1 = self.indep1[idx_x_1];
109        let x_2: Indep1 = self.indep1[idx_x_2];
110
111        let y_1: Indep2 = self.indep2[idx_y_1];
112        let y_2: Indep2 = self.indep2[idx_y_2];
113
114        let f_1_1: Dep = self.dep[[idx_x_1, idx_y_1]];
115        let f_1_2: Dep = self.dep[[idx_x_1, idx_y_2]];
116        let f_2_1: Dep = self.dep[[idx_x_2, idx_y_1]];
117        let f_2_2: Dep = self.dep[[idx_x_2, idx_y_2]];
118
119        // bound x acording to the axis we are interpolating on
120        // unwrap is safe here as we have checked its at least length 2
121        let x = LowerBound1::lower_bound(x, *self.indep1.first().unwrap());
122        let x = UpperBound1::upper_bound(x, *self.indep1.last().unwrap());
123
124        // bound y acording to the axis we are interpolating on
125        // unwrap is safe here as we have checked its at least length 2
126        let y = LowerBound2::lower_bound(y, *self.indep2.first().unwrap());
127        let y = UpperBound2::upper_bound(y, *self.indep2.last().unwrap());
128
129        let x_slope1 = (x_2 - x) / (x_2 - x_1);
130        let x_slope2 = (x - x_1) / (x_2 - x_1);
131        let y_slope1 = (y_2 - y) / (y_2 - y_1);
132        let y_slope2 = (y - y_1) / (y_2 - y_1);
133
134        let f_x_y1 = f_1_1 * x_slope1 + f_2_1 * x_slope2;
135        let f_x_y2 = f_1_2 * x_slope1 + f_2_2 * x_slope2;
136
137        f_x_y1 * y_slope1 + f_x_y2 * y_slope2
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1e-10;

    type LinearAxis = axis::Axis<f64, search::Linear, bound::Interp, bound::Interp>;
    type TableLinLin = LookupTable2D<LinearAxis, LinearAxis, f64>;

    /// Linear test surface; bilinear interpolation reproduces it exactly up to rounding.
    fn func(x: f64, y: f64) -> f64 {
        3. * x + y
    }

    //
    // Table Construction
    //

    /// A 100x100 grid over [0, 5] x [0, 5] with `f[[i, j]] = func(x[i], y[j])`.
    fn data() -> (Vec<f64>, Vec<f64>, Array2<f64>) {
        let len = 100;
        let x = ndarray::Array1::linspace(0., 5.0, len).to_vec();
        let y = x.clone();
        let f = Array2::from_shape_fn((x.len(), y.len()), |(i, j)| func(x[i], y[j]));

        (x, y, f)
    }

    fn linear_simple_table() -> TableLinLin {
        let (x, y, f) = data();
        let search1 = search::Linear::default();
        let search2 = search::Linear::default();
        LookupTable2D::new(x, search1, y, search2, f).unwrap()
    }

    /// Assert that `table` matches both `func` and the increasing-axis reference table at
    /// `(x_0, y_0)` and at an off-grid point.
    fn assert_matches_reference(table: &TableLinLin, x_0: f64, y_0: f64) {
        let reference = linear_simple_table();
        for &(xq, yq) in &[(x_0, y_0), (1.234, 4.321)] {
            let output = table.lookup(xq, yq);
            float_eq::assert_float_eq!(output, func(xq, yq), abs <= TOL);
            float_eq::assert_float_eq!(output, reference.lookup(xq, yq), abs <= TOL);
        }
    }

    //
    // Table Construction Tests
    //

    #[test]
    fn construct_table_repeated_entries_1() {
        let (mut x, y, f) = data();

        x[0] = 0.;
        x[1] = 0.;

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    #[test]
    fn construct_table_repeated_entries_2() {
        let (x, mut y, f) = data();

        y[0] = 0.;
        y[1] = 0.;

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    #[test]
    fn construct_table_repeated_mismatch_length_1() {
        let (mut x, y, f) = data();

        x.push(100.);

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    #[test]
    fn construct_table_repeated_mismatch_length_2() {
        let (x, mut y, f) = data();

        y.push(100.0);

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let output: Result<TableLinLin, _> = LookupTable2D::new(x, search1, y, search2, f);
        assert!(output.is_err());
    }

    /// Prove that supplying the x axis in decreasing order (with the rows of the data
    /// reversed to match) yields the same lookup results.
    #[test]
    fn construct_table_reversed_ax1() {
        let (mut x, y, mut f) = data();

        // Reverse the axis AND the matching data axis so the (x, f) pairing is preserved.
        x.reverse();
        f.invert_axis(ndarray::Axis(0));

        let y_0 = y[y.len() / 3];
        let x_0 = x[2 * x.len() / 3];

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let table: TableLinLin = LookupTable2D::new(x, search1, y, search2, f).unwrap();

        assert_matches_reference(&table, x_0, y_0);
    }

    /// Prove that supplying the y axis in decreasing order (with the columns of the data
    /// reversed to match) yields the same lookup results.
    #[test]
    fn construct_table_reversed_ax2() {
        let (x, mut y, mut f) = data();

        // Reverse the axis AND the matching data axis so the (y, f) pairing is preserved.
        y.reverse();
        f.invert_axis(ndarray::Axis(1));

        let y_0 = y[y.len() / 3];
        let x_0 = x[2 * x.len() / 3];

        let search1 = search::Linear::default();
        let search2 = search::Linear::default();

        let table: TableLinLin = LookupTable2D::new(x, search1, y, search2, f).unwrap();

        assert_matches_reference(&table, x_0, y_0);
    }

    //
    // Linear Tests
    //

    #[test]
    fn linear_1() {
        let table = linear_simple_table();
        let x = 0.5;
        let y = 2.5;
        let output = table.lookup(x, y);
        float_eq::assert_float_eq!(output, func(x, y), abs <= TOL);
    }

    #[test]
    fn linear_low_bound() {
        let table = linear_simple_table();
        let x = 0.;
        let y = 0.;
        let output = table.lookup(x, y);
        float_eq::assert_float_eq!(output, func(x, y), abs <= TOL);
    }

    #[test]
    fn linear_high_bound() {
        let table = linear_simple_table();
        let x = 5.0;
        let y = 5.0;
        let output = table.lookup(x, y);
        float_eq::assert_float_eq!(output, func(x, y), abs <= TOL);
    }
}