sembas/structs/
mod.rs

1pub mod boundary;
2pub mod error;
3pub mod report;
4pub mod sampling;
5
6pub use boundary::*;
7pub use error::*;
8pub use sampling::*;
9
10use core::fmt;
11
12use nalgebra::{Const, OMatrix, SVector};
13
14use crate::utils::vector_to_string;
15
16/// A 2-dimensional subspace of an N-dimensional input space, described by two
17/// orthonormal vectors, u and v.
18#[derive(Debug, Clone, PartialEq)]
19pub struct Span<const N: usize> {
20    u: SVector<f64, N>,
21    v: SVector<f64, N>,
22}
23
24/// An N-dimensional hyperrectangle that is defined by an lower and upper bound (low
25/// and high). Used to define a valid input region to sample from for a system under
26/// test.
27#[derive(Debug, Clone, PartialEq)]
28pub struct Domain<const N: usize> {
29    low: SVector<f64, N>,
30    high: SVector<f64, N>,
31}
32
33impl<const N: usize> Span<N> {
34    /// Constructs a Span across @u and @v. @u and @v are orthonormalized, where @v
35    /// is forced to be orthogonal to @u, and @u retains its directionality. Uses
36    /// Gramm Schmidt Orthonormalization
37    pub fn new(u: SVector<f64, N>, v: SVector<f64, N>) -> Self {
38        let u = u.normalize();
39        let v = v.normalize();
40        let v = (v - u * u.dot(&v)).normalize();
41        Span { u, v }
42    }
43
44    pub fn u(&self) -> SVector<f64, N> {
45        self.u
46    }
47    pub fn v(&self) -> SVector<f64, N> {
48        self.v
49    }
50
51    // Provides a rotater function rot(angle: f64) which returns a rotation matrix
52    // that rotates by an angle in radians along &self's span.
53    pub fn get_rotater(&self) -> impl Fn(f64) -> OMatrix<f64, Const<N>, Const<N>> {
54        let identity = OMatrix::<f64, Const<N>, Const<N>>::identity();
55
56        let a = self.u * self.v.transpose() - self.v * self.u.transpose();
57        let b = self.v * self.v.transpose() + self.u * self.u.transpose();
58
59        move |angle: f64| identity + a * angle.sin() + b * (angle.cos() - 1.0)
60    }
61}
62
63impl<const N: usize> Domain<N> {
64    /// Returns a domain bounded by the two points.
65    pub fn new(p1: SVector<f64, N>, p2: SVector<f64, N>) -> Self {
66        let low = p1.zip_map(&p2, |a, b| a.min(b));
67        let high = p1.zip_map(&p2, |a, b| a.max(b));
68
69        Domain { low, high }
70    }
71
72    /// Returns a domain with the provided bounds.
73    /// ## Safety
74    /// This function is unsafe because it doesn't do any checks to ensure that for
75    /// all dimensions, low < high. If this condition is not met, the Domain's
76    /// operations behavior is undefined.
77    pub unsafe fn new_from_bounds(low: SVector<f64, N>, high: SVector<f64, N>) -> Self {
78        Domain { low, high }
79    }
80
81    /// Returns a Domain bounded between 0 and 1 for all dimensions.
82    pub fn normalized() -> Self {
83        let low = SVector::<f64, N>::zeros();
84        let high = SVector::<f64, N>::repeat(1.0);
85        Domain { low, high }
86    }
87
88    /// The lower bound of the domain.
89    pub fn low(&self) -> &SVector<f64, N> {
90        &self.low
91    }
92
93    /// The upper bound of the domain.
94    pub fn high(&self) -> &SVector<f64, N> {
95        &self.high
96    }
97
98    /// Checks if the given vector is within the domain.
99    pub fn contains(&self, p: &SVector<f64, N>) -> bool {
100        let below_low = SVector::<bool, N>::from_fn(|i, _| p[i] < self.low[i]);
101        if below_low.iter().any(|&x| x) {
102            return false;
103        }
104
105        let above_high = SVector::<bool, N>::from_fn(|i, _| p[i] > self.high[i]);
106        if above_high.iter().any(|&x| x) {
107            return false;
108        }
109
110        true
111    }
112
113    /// Returns the size of each dimension as a vector.
114    pub fn dimensions(&self) -> SVector<f64, N> {
115        self.high - self.low
116    }
117
118    /// Projects a point from one domain to another.
119    /// Retains the relative position for all points within the source domain.
120    /// Useful for projecting an input from one domain to a normalized domain and vis
121    /// versa.
122    /// ## Arguments
123    /// * p: The point that is being projected
124    /// * from: The domain that the point is projecting from
125    /// * to: The domain that the point is projecting to
126    pub fn project_point_domains(
127        p: &SVector<f64, N>,
128        from: &Domain<N>,
129        to: &Domain<N>,
130    ) -> SVector<f64, N> {
131        ((p - from.low).component_div(&from.dimensions())).component_mul(&to.dimensions()) + to.low
132    }
133
134    /// Finds the distance between the edge of the domain from a point in the
135    /// direction of the provided vector. Useful for finding target/non-target
136    /// samples on the extremes of the input space.
137    /// * p: A point that the ray starts from
138    /// * v: The direction the ray travels
139    ////// ## Returns
140    /// * t: The linear distance between p and the edge of the domain in the
141    ///   direction v
142    pub fn distance_to_edge(&self, p: &SVector<f64, N>, v: &SVector<f64, N>) -> Result<f64> {
143        let t_lower = (self.low - p).component_div(v);
144        let t_upper = (self.high - p).component_div(v);
145
146        let l = t_lower
147            .iter()
148            .filter(|&&xi| xi >= 0.0)
149            .min_by(|a, b| a.partial_cmp(b).unwrap())
150            .cloned();
151
152        let u = t_upper
153            .iter()
154            .filter(|&&xi| xi >= 0.0)
155            .min_by(|a, b| a.partial_cmp(b).unwrap())
156            .cloned();
157
158        let t = match (l, u) {
159            (None, Some(t)) => t,
160            (Some(t), None) => t,
161            (Some(tl), Some(tu)) => tl.min(tu),
162            // OOB due to point falling outside of domain
163            (None, None) => return Err(SamplingError::OutOfBounds),
164        };
165
166        Ok(t)
167    }
168}
169
170impl<const N: usize> fmt::Display for Span<N> {
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        write!(
173            f,
174            "Span({:?}, {:?})",
175            vector_to_string(&self.u),
176            vector_to_string(&self.v)
177        )
178    }
179}
180
#[cfg(test)]
mod span_tests {
    use super::*;

    /// Absolute tolerance for scalar comparisons.
    const ATOL: f64 = 1e-10;

    /// True when `a` and `b` differ by strictly less than `atol`.
    fn approx(a: f64, b: f64, atol: f64) -> bool {
        let diff = (a - b).abs();
        diff < atol
    }

    /// Deliberately non-orthonormal inputs shared by every test.
    fn raw_inputs() -> (SVector<f64, 5>, SVector<f64, 5>) {
        let u = nalgebra::vector![0.5, 0.5, 0.1, 0.4, 1.0];
        let v = nalgebra::vector![1.1, -0.2, -0.5, 0.1, 0.8];
        (u, v)
    }

    /// Rotates the raw `v` input by `degrees` inside the span built from the raw
    /// inputs, then checks the angle between the vector before and after.
    fn check_rotation_by(degrees: f64) {
        let (u, v) = raw_inputs();
        let span = Span::new(u, v);
        let radians = degrees.to_radians();

        let rotate = span.get_rotater();
        let turned = rotate(radians) * v;

        assert!(approx(v.angle(&turned), radians, ATOL));
    }

    #[test]
    fn is_orthogonal() {
        let (u, v) = raw_inputs();
        let span = Span::new(u, v);

        let degrees = span.u.angle(&span.v).to_degrees();
        assert!(approx(degrees, 90.0, ATOL));
    }

    #[test]
    fn is_normal() {
        let (u, v) = raw_inputs();
        let span = Span::new(u, v);

        assert!(approx(span.u.norm(), 1.0, ATOL));
        assert!(approx(span.v.norm(), 1.0, ATOL));
    }

    #[test]
    fn rotater_90() {
        check_rotation_by(90.0);
    }

    #[test]
    fn rotater_25() {
        check_rotation_by(25.0);
    }
}
242
#[cfg(test)]
mod domain_tests {
    use nalgebra::vector;

    use super::*;

    /// Absolute tolerance for vector comparisons.
    const ATOL: f64 = 1e-10;

    /// True when `a` and `b` are within `atol` of each other (Euclidean norm).
    fn is_near<const N: usize>(a: &SVector<f64, N>, b: &SVector<f64, N>, atol: f64) -> bool {
        let gap = (a - b).norm();
        gap <= atol
    }

    /// The normalized domain and an arbitrary target domain, for projection tests.
    fn unit_and_target() -> (Domain<3>, Domain<3>) {
        let unit = Domain::<3>::normalized();
        let target = Domain::<3>::new(vector![1.0, 2.5, 3.5], vector![4.0, 5.0, 6.0]);
        (unit, target)
    }

    /// A domain built from corner points whose coordinates are not pre-sorted.
    fn unsorted_corners() -> Domain<3> {
        Domain::<3>::new(vector![4.0, 2.5, 6.0], vector![1.0, 5.0, 3.5])
    }

    /// Centre point of a domain.
    fn midpoint(d: &Domain<3>) -> SVector<f64, 3> {
        d.low() + d.dimensions() / 2.0
    }

    #[test]
    fn point_translation_low_to_low() {
        let (src, dst) = unit_and_target();
        let projected = Domain::project_point_domains(src.low(), &src, &dst);

        assert!(is_near(&projected, dst.low(), ATOL))
    }

    #[test]
    fn point_translation_high_to_high() {
        let (src, dst) = unit_and_target();
        let projected = Domain::project_point_domains(src.high(), &src, &dst);

        assert!(is_near(&projected, dst.high(), ATOL))
    }

    #[test]
    fn point_translation_mid_to_mid() {
        let (src, dst) = unit_and_target();
        let src_mid = midpoint(&src);
        let dst_mid = midpoint(&dst);

        let projected = Domain::project_point_domains(&src_mid, &src, &dst);

        assert!(is_near(&projected, &dst_mid, ATOL))
    }

    #[test]
    fn low_is_below_high() {
        let d = unsorted_corners();
        let ordered = d.low().iter().zip(d.high.iter()).all(|(lo, hi)| lo < hi);

        assert!(ordered);
    }

    #[test]
    fn contains_false_when_below_low() {
        let d = unsorted_corners();
        let outside = d.low() - vector![0.01, 0.01, 0.01];

        assert!(!d.contains(&outside))
    }

    #[test]
    fn contains_true_when_on_low() {
        let d = unsorted_corners();

        assert!(d.contains(d.low()))
    }

    #[test]
    fn contains_false_when_above_high() {
        let d = unsorted_corners();
        let outside = d.high() + vector![0.01, 0.01, 0.01];

        assert!(!d.contains(&outside))
    }

    #[test]
    fn contains_true_when_on_high() {
        let d = unsorted_corners();

        assert!(d.contains(d.high()))
    }
}