multivariate_optimization/
triangular.rs

1//! Triangular numbers and matrices.
2
3use rayon::prelude::*;
4
/// Calculates a triangular number, where `trinum(x) == x*(x+1)/2`.
///
/// One of `x` and `x + 1` is always even. That factor is halved before the
/// multiplication, so no intermediate value exceeds the result. Every `x`
/// whose triangular number fits into a `usize` is computed correctly.
///
/// # Panics
///
/// If the result itself does not fit into a `usize`, this panics when
/// overflow checks are enabled (e.g. debug builds) and wraps otherwise.
/// Use [`checked_trinum`] to detect overflow instead.
pub fn trinum(x: usize) -> usize {
    if x % 2 == 0 {
        // Even `x` is at most `usize::MAX - 1`, so `x + 1` cannot overflow.
        (x / 2) * (x + 1)
    } else {
        // For odd `x`, `(x + 1) / 2 == x / 2 + 1`. Writing it this way avoids
        // computing `x + 1`, which would overflow for `x == usize::MAX`.
        x * (x / 2 + 1)
    }
}
9
/// Same as [`trinum`], but returns [`None`] if the result does not fit into a
/// `usize`.
///
/// The even factor of `x * (x + 1)` is halved before multiplying. So this
/// returns `Some` exactly when the triangular number is representable, even
/// if the unhalved product `x * (x + 1)` would not be.
pub fn checked_trinum(x: usize) -> Option<usize> {
    if x % 2 == 0 {
        // Even `x` is at most `usize::MAX - 1`, so only the product can overflow.
        (x / 2).checked_mul(x + 1)
    } else {
        // Odd `x`: `(x + 1) / 2 == x / 2 + 1`, and that sum cannot overflow.
        x.checked_mul(x / 2 + 1)
    }
}
16
/// Square matrix that stores only its lower triangle, diagonal included.
#[derive(Debug, Clone)]
pub struct Triangular<T> {
    /// Number of rows, which is also the number of columns.
    dim: usize,
    /// The stored elements, kept in one contiguous slice.
    linear: Box<[T]>,
}
25
26impl<T> Triangular<T> {
27    /// Create triangular matrix with given dimension.
28    ///
29    /// Fills initial elements by calling `contents` with `(i, j)` as argument,
30    /// where `i < dim && j <= i`.
31    pub fn new<F>(dim: usize, mut contents: F) -> Triangular<T>
32    where
33        F: FnMut((usize, usize)) -> T,
34    {
35        let mut linear = Vec::with_capacity(trinum(dim));
36        for row in 0..dim {
37            for col in 0..=row {
38                linear.push(contents((row, col)));
39            }
40        }
41        let linear = linear.into_boxed_slice();
42        Triangular { dim, linear }
43    }
44    /// Same as [`Triangular::new`], but execute in parallel.
45    pub fn par_new<F>(dim: usize, contents: F) -> Triangular<T>
46    where
47        F: Sync + Fn((usize, usize)) -> T,
48        T: Send,
49    {
50        let contents = &contents;
51        // TODO: Collect directly into boxed slice, see:
52        // https://github.com/rayon-rs/rayon/pull/1061
53        let linear = (0..dim)
54            .into_par_iter()
55            .flat_map(|row| {
56                (0..=row)
57                    .into_par_iter()
58                    .map(move |col| (contents)((row, col)))
59            })
60            .collect::<Vec<_>>()
61            .into_boxed_slice();
62        Triangular { dim, linear }
63    }
64    /// Number of both rows and columns.
65    pub fn dim(&self) -> usize {
66        self.dim
67    }
68    fn linear_index(&self, (row, col): (usize, usize)) -> usize {
69        trinum(row) + col
70    }
71    fn checked_linear_index(&self, (row, col): (usize, usize)) -> Result<usize, &'static str> {
72        if !(row < self.dim) {
73            return Err("first index out of bounds");
74        }
75        if !(col <= row) {
76            return Err("second index larger than first index");
77        }
78        Ok(self.linear_index((row, col)))
79    }
80    /// Immutable unchecked indexing through `(i, j)`, where `j <= i < dim()`.
81    ///
82    /// # Safety
83    ///
84    /// * `row` (first tuple field) must be smaller than `self.dim()`
85    /// * `col` (second tuple field) must be equal to or smaller than `row`
86    pub unsafe fn get_unchecked(&self, (row, col): (usize, usize)) -> &T {
87        let idx = self.linear_index((row, col));
88        // SAFETY: `row` and `col` are valid
89        unsafe { self.linear.get_unchecked(idx) }
90    }
91    /// Mutable unchecked indexing through `(i, j)`, where `j <= i < dim()`.
92    ///
93    /// # Safety
94    ///
95    /// * `row` (first tuple field) must be smaller than `self.dim()`
96    /// * `col` (second tuple field) must be equal to or smaller than `row`
97    pub unsafe fn get_unchecked_mut(&mut self, (row, col): (usize, usize)) -> &mut T {
98        let idx = self.linear_index((row, col));
99        // SAFETY: `row` and `col` are valid
100        unsafe { self.linear.get_unchecked_mut(idx) }
101    }
102}
103
104/// Immutable indexing through `(i, j)`, where `j <= i < dim()`.
105impl<T> std::ops::Index<(usize, usize)> for Triangular<T> {
106    type Output = T;
107    fn index(&self, (row, col): (usize, usize)) -> &T {
108        let idx = match self.checked_linear_index((row, col)) {
109            Ok(x) => x,
110            Err(x) => panic!("invalid indices for triangular matrix: {x}"),
111        };
112        // SAFETY: `checked_linear_index` returns valid index on success
113        unsafe { self.linear.get_unchecked(idx) }
114    }
115}
116
117/// Mutable indexing through `(i, j)`, where `j <= i < dim()`.
118impl<T> std::ops::IndexMut<(usize, usize)> for Triangular<T> {
119    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T {
120        let idx = match self.checked_linear_index((row, col)) {
121            Ok(x) => x,
122            Err(x) => panic!("invalid indices when mutably indexing triangular matrix: {x}"),
123        };
124        // SAFETY: `checked_linear_index` returns valid index on success
125        unsafe { self.linear.get_unchecked_mut(idx) }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::{checked_trinum, trinum, Triangular};
    #[test]
    fn test_trinum() {
        assert_eq!(checked_trinum(0), Some(0));
        assert_eq!(checked_trinum(1), Some(1));
        assert_eq!(checked_trinum(2), Some(3));
        assert_eq!(checked_trinum(3), Some(6));
        assert_eq!(checked_trinum(4), Some(10));
        assert_eq!(checked_trinum(5), Some(15));
        assert_eq!(checked_trinum(6), Some(21));
        assert_eq!(checked_trinum(7), Some(28));
        assert_eq!(checked_trinum(8), Some(36));
        assert_eq!(checked_trinum(100), Some(5050));
        assert_eq!(checked_trinum(usize::MAX / 16), None);
        for i in 0..100 {
            assert_eq!(Some(trinum(i)), checked_trinum(i));
        }
    }
    #[test]
    fn test_triangular() {
        let calc = |(i, j)| (10 * i + j) as i16;
        let mut m = Triangular::<i16>::new(5, calc);
        assert_eq!(m.dim(), 5);
        for i in 0..5 {
            for j in 0..=i {
                assert_eq!(m[(i, j)], calc((i, j)));
            }
        }
        m[(0, 0)] = -1;
        m[(3, 0)] = -2;
        m[(4, 3)] = -3;
        m[(4, 4)] = -4;
        assert_eq!(m[(0, 0)], -1);
        assert_eq!(m[(3, 0)], -2);
        assert_eq!(m[(4, 3)], -3);
        assert_eq!(m[(4, 4)], -4);
    }
    /// `par_new` must produce the same layout as the sequential constructor.
    #[test]
    fn test_triangular_par_new_matches_new() {
        let calc = |(i, j): (usize, usize)| 10 * i + j;
        let seq = Triangular::new(7, calc);
        let par = Triangular::par_new(7, calc);
        assert_eq!(par.dim(), seq.dim());
        for i in 0..7 {
            for j in 0..=i {
                assert_eq!(par[(i, j)], seq[(i, j)]);
            }
        }
    }
    /// A zero-dimensional matrix never calls `contents`.
    #[test]
    fn test_triangular_empty() {
        let m = Triangular::<u8>::new(0, |_| unreachable!());
        assert_eq!(m.dim(), 0);
        let p = Triangular::<u8>::par_new(0, |_| unreachable!());
        assert_eq!(p.dim(), 0);
    }
    // The `expected` strings pin the panic to the index validation itself.
    // Without them, an unrelated panic (e.g. a slice bounds check) would also
    // satisfy these tests.
    #[test]
    #[should_panic(expected = "first index out of bounds")]
    fn test_triangular_index_too_large() {
        let m = Triangular::<()>::new(3, |_| ());
        let _ = &m[(3, 0)];
    }
    #[test]
    #[should_panic(expected = "second index larger than first index")]
    fn test_triangular_index_wrongly_ordered() {
        let m = Triangular::<()>::new(3, |_| ());
        let _ = &m[(1, 2)];
    }
    #[test]
    #[should_panic(expected = "first index out of bounds")]
    fn test_triangular_index_empty() {
        let m = Triangular::<()>::new(0, |_| ());
        let _ = &m[(0, 0)];
    }
    #[test]
    #[should_panic(expected = "second index larger than first index")]
    fn test_triangular_index_mut_wrongly_ordered() {
        let mut m = Triangular::<u8>::new(3, |_| 0);
        m[(0, 1)] = 1;
    }
}
180}