Skip to main content

dynamics_spatial/
symmetric3.rs

1//! Defines **symmetric matrices** of size 3x3 and related operations.
2
3use nalgebra::Matrix3;
4use std::ops::{Add, Index, Mul, Sub};
5
6use crate::{motion::SpatialRotation, vector3d::Vector3D};
7
8#[cfg(feature = "python")]
9use numpy::{PyReadonlyArrayDyn, ToPyArray, ndarray::Array2};
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12
/// A 3x3 symmetric matrix stored compactly.
///
/// Only the six independent entries are kept; the lower triangle is implied by
/// symmetry. The derived `Default` yields the all-zero matrix.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Symmetric3 {
    /// Independent entries: the diagonal first, then the upper triangle row by row,
    /// i.e. `[m11, m22, m33, m12, m13, m23]`.
    data: [f64; 6],
}
22
23impl Symmetric3 {
24    /// Creates a new `Symmetric3` matrix from the given elements.
25    ///
26    /// # Arguments
27    ///
28    /// * `m11`, `m22`, `m33` - The diagonal elements.
29    /// * `m12`, `m13`, `m23` - The off-diagonal elements.
30    #[must_use]
31    pub fn new(m11: f64, m22: f64, m33: f64, m12: f64, m13: f64, m23: f64) -> Self {
32        Self {
33            data: [m11, m22, m33, m12, m13, m23],
34        }
35    }
36
37    /// Returns the element at the specified row and column.
38    ///
39    /// # Arguments
40    ///
41    /// * `row` - The row index (0-based).
42    /// * `col` - The column index (0-based).
43    ///
44    /// # Panics
45    ///
46    /// Panics if the row or column index is out of bounds.
47    #[must_use]
48    pub fn get(&self, row: usize, col: usize) -> &f64 {
49        match (row, col) {
50            (0, 0) => &self.data[0],
51            (1, 1) => &self.data[1],
52            (2, 2) => &self.data[2],
53            (0, 1) | (1, 0) => &self.data[3],
54            (0, 2) | (2, 0) => &self.data[4],
55            (1, 2) | (2, 1) => &self.data[5],
56            _ => panic!("Index out of bounds"),
57        }
58    }
59
60    /// Returns the zero symmetric matrix.
61    #[must_use]
62    pub fn zeros() -> Self {
63        Self { data: [0.0; 6] }
64    }
65
66    /// Returns the identity symmetric matrix.
67    #[must_use]
68    pub fn identity() -> Self {
69        Self {
70            data: [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
71        }
72    }
73
74    /// Creates a diagonal symmetric matrix from the given diagonal elements.
75    ///
76    /// # Arguments
77    /// * `diag` - A vector containing the diagonal elements [m11, m22, m33].
78    #[must_use]
79    pub fn from_diagonal(diag: &[f64; 3]) -> Self {
80        Self {
81            data: [diag[0], diag[1], diag[2], 0.0, 0.0, 0.0],
82        }
83    }
84
85    /// Convert the symmetric matrix to a full 3x3 matrix.
86    #[must_use]
87    pub fn matrix(&self) -> Matrix3<f64> {
88        Matrix3::new(
89            self.data[0],
90            self.data[3],
91            self.data[4],
92            self.data[3],
93            self.data[1],
94            self.data[5],
95            self.data[4],
96            self.data[5],
97            self.data[2],
98        )
99    }
100
101    #[cfg(feature = "python")]
102    #[must_use]
103    /// Converts the symmetric matrix to a NumPy array.
104    pub fn to_numpy(&self, py: Python) -> Py<PyAny> {
105        let mat = self.matrix();
106        Array2::from_shape_fn((3, 3), |(i, j)| mat[(i, j)])
107            .to_pyarray(py)
108            .into_any()
109            .unbind()
110    }
111
112    #[must_use]
113    pub fn skew_square(v: Vector3D) -> Symmetric3 {
114        let x = v.0[0];
115        let y = v.0[1];
116        let z = v.0[2];
117
118        Symmetric3::new(
119            -y * y - z * z,
120            -x * x - z * z,
121            -x * x - y * y,
122            x * y,
123            x * z,
124            y * z,
125        )
126    }
127
128    /// Computes the matrix product $RSR^\top$ where $R$ is a spatial rotation and $S$ is this symmetric matrix.
129    ///
130    /// # Arguments
131    /// * `rotation` - The spatial rotation to apply.
132    ///
133    /// # Returns
134    /// The rotated symmetric matrix.
135    #[must_use]
136    pub fn rotate(&self, rotation: &SpatialRotation) -> Symmetric3 {
137        // TODO: avoid constructing the full matrix
138        let r = &rotation.0;
139        let s = &self.matrix();
140        let rsrt = r * s * r.transpose();
141        Symmetric3::new(
142            rsrt[(0, 0)],
143            rsrt[(1, 1)],
144            rsrt[(2, 2)],
145            rsrt[(0, 1)],
146            rsrt[(0, 2)],
147            rsrt[(1, 2)],
148        )
149    }
150
151    #[cfg(feature = "python")]
152    /// Creates a `Symmetric3` from a NumPy array.
153    /// # Arguments
154    /// * `array` - A 3x3 NumPy array.
155    /// # Returns
156    /// * A `Symmetric3` instance if successful,
157    /// * otherwise raises a `ValueError` if the input is not symmetric or not of shape (3, 3).
158    pub fn from_pyarray(array: &PyReadonlyArrayDyn<f64>) -> Result<Self, PyErr> {
159        let array = array.as_array();
160        if array.shape() != [3, 3] {
161            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
162                "Input array must be of shape (3, 3)",
163            ));
164        }
165
166        // Check symmetry
167        for i in 0..3 {
168            for j in 0..3 {
169                if array[[i, j]] != array[[j, i]] {
170                    return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
171                        "Input array must be symmetric",
172                    ));
173                }
174            }
175        }
176
177        Ok(Symmetric3::new(
178            array[[0, 0]],
179            array[[1, 1]],
180            array[[2, 2]],
181            array[[0, 1]],
182            array[[0, 2]],
183            array[[1, 2]],
184        ))
185    }
186}
187
188impl Index<(usize, usize)> for Symmetric3 {
189    type Output = f64;
190
191    fn index(&self, index: (usize, usize)) -> &Self::Output {
192        self.get(index.0, index.1)
193    }
194}
195
196impl Add<Symmetric3> for Symmetric3 {
197    type Output = Symmetric3;
198
199    fn add(self, rhs: Symmetric3) -> Self::Output {
200        Symmetric3 {
201            data: [
202                self.data[0] + rhs.data[0],
203                self.data[1] + rhs.data[1],
204                self.data[2] + rhs.data[2],
205                self.data[3] + rhs.data[3],
206                self.data[4] + rhs.data[4],
207                self.data[5] + rhs.data[5],
208            ],
209        }
210    }
211}
212
213impl Sub<Symmetric3> for Symmetric3 {
214    type Output = Symmetric3;
215
216    fn sub(self, rhs: Symmetric3) -> Self::Output {
217        Symmetric3 {
218            data: [
219                self.data[0] - rhs.data[0],
220                self.data[1] - rhs.data[1],
221                self.data[2] - rhs.data[2],
222                self.data[3] - rhs.data[3],
223                self.data[4] - rhs.data[4],
224                self.data[5] - rhs.data[5],
225            ],
226        }
227    }
228}
229
230impl Mul<&Vector3D> for &Symmetric3 {
231    type Output = Vector3D;
232
233    fn mul(self, rhs: &Vector3D) -> Self::Output {
234        Vector3D::new(
235            self[(0, 0)] * rhs.0[0] + self[(0, 1)] * rhs.0[1] + self[(0, 2)] * rhs.0[2],
236            self[(1, 0)] * rhs.0[0] + self[(1, 1)] * rhs.0[1] + self[(1, 2)] * rhs.0[2],
237            self[(2, 0)] * rhs.0[0] + self[(2, 1)] * rhs.0[1] + self[(2, 2)] * rhs.0[2],
238        )
239    }
240}
241
242impl Mul<f64> for Symmetric3 {
243    type Output = Symmetric3;
244
245    fn mul(self, rhs: f64) -> Self::Output {
246        Symmetric3 {
247            data: [
248                self.data[0] * rhs,
249                self.data[1] * rhs,
250                self.data[2] * rhs,
251                self.data[3] * rhs,
252                self.data[4] * rhs,
253                self.data[5] * rhs,
254            ],
255        }
256    }
257}
258
259impl Mul<Symmetric3> for f64 {
260    type Output = Symmetric3;
261
262    fn mul(self, rhs: Symmetric3) -> Self::Output {
263        Symmetric3 {
264            data: [
265                rhs.data[0] * self,
266                rhs.data[1] * self,
267                rhs.data[2] * self,
268                rhs.data[3] * self,
269                rhs.data[4] * self,
270                rhs.data[5] * self,
271            ],
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_symmetric3_creation() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3);
        assert_eq!(sym[(0, 0)], 1.0);
        assert_eq!(sym[(1, 1)], 2.0);
        assert_eq!(sym[(2, 2)], 3.0);
        assert_eq!(sym[(0, 1)], 0.1);
        assert_eq!(sym[(1, 0)], 0.1);
        assert_eq!(sym[(0, 2)], 0.2);
        assert_eq!(sym[(2, 0)], 0.2);
        assert_eq!(sym[(1, 2)], 0.3);
        assert_eq!(sym[(2, 1)], 0.3);
    }

    #[test]
    #[should_panic]
    fn test_symmetric3_index_out_of_bounds() {
        let sym = Symmetric3::identity();
        // Row 3 does not exist in a 3x3 matrix.
        let _value: f64 = sym[(3, 0)];
    }

    #[test]
    fn test_symmetric3_constructors() {
        assert_relative_eq!(Symmetric3::identity().matrix(), Matrix3::<f64>::identity());
        assert_relative_eq!(Symmetric3::zeros().matrix(), Matrix3::<f64>::zeros());
        assert_eq!(Symmetric3::zeros(), Symmetric3::default());
        assert_eq!(
            Symmetric3::from_diagonal(&[1.0, 2.0, 3.0]),
            Symmetric3::new(1.0, 2.0, 3.0, 0.0, 0.0, 0.0)
        );
    }

    #[test]
    fn test_symmetric3_to_matrix() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3);
        let mat = sym.matrix();
        let expected = Matrix3::new(1.0, 0.1, 0.2, 0.1, 2.0, 0.3, 0.2, 0.3, 3.0);
        assert_relative_eq!(mat, expected);
    }

    #[test]
    fn test_symmetric3_arithmetic() {
        let a = Symmetric3::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        let b = Symmetric3::new(0.5, 1.5, 2.5, 3.5, 4.5, 5.5);
        assert_eq!(a + b, Symmetric3::new(1.5, 3.5, 5.5, 7.5, 9.5, 11.5));
        assert_eq!(a - b, Symmetric3::new(0.5, 0.5, 0.5, 0.5, 0.5, 0.5));
        assert_eq!(a * 2.0, Symmetric3::new(2.0, 4.0, 6.0, 8.0, 10.0, 12.0));
        assert_eq!(2.0 * a, a * 2.0);
    }

    #[test]
    fn test_symmetric3_mul_vector3d() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.0, 0.0, 0.0);
        let vec = Vector3D::new(1.0, 2.0, 3.0);
        let result = &sym * &vec;
        let expected = Vector3D::new(1.0, 4.0, 9.0);
        assert_relative_eq!(result.0, expected.0);
    }

    #[test]
    fn test_symmetric3_mul_vector3d_off_diagonal() {
        // Dense form: [[1, 4, 5], [4, 2, 6], [5, 6, 3]].
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        let vec = Vector3D::new(1.0, -1.0, 2.0);
        let result = &sym * &vec;
        let expected = Vector3D::new(7.0, 14.0, 5.0);
        assert_relative_eq!(result.0, expected.0);
    }

    #[test]
    fn test_symmetric3_skew_square() {
        let (x, y, z) = (1.0, 2.0, 3.0);
        // Cross-product matrix [v]x such that [v]x * w = v x w.
        let skew = Matrix3::new(0.0, -z, y, z, 0.0, -x, -y, x, 0.0);
        let sym = Symmetric3::skew_square(Vector3D::new(x, y, z));
        assert_relative_eq!(sym.matrix(), skew * skew);
    }

    #[test]
    fn test_symmetric3_rotate() {
        let full = Matrix3::new(1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 3.0, 5.0, 6.0);
        assert!(full == full.transpose());

        let sym = Symmetric3::new(
            full[(0, 0)],
            full[(1, 1)],
            full[(2, 2)],
            full[(0, 1)],
            full[(0, 2)],
            full[(1, 2)],
        );

        let rotation = SpatialRotation::from_axis_angle(
            &Vector3D::new(0.0, 0.0, 1.0),
            std::f64::consts::FRAC_PI_2,
        );
        let rotated_sym = sym.rotate(&rotation);

        let expected = rotation.0 * full * rotation.0.transpose();
        let expected_sym = Symmetric3::new(
            expected[(0, 0)],
            expected[(1, 1)],
            expected[(2, 2)],
            expected[(0, 1)],
            expected[(0, 2)],
            expected[(1, 2)],
        );
        assert_relative_eq!(rotated_sym.matrix(), expected_sym.matrix());
    }
}