Skip to main content

dynamics_spatial/
symmetric3.rs

1//! Defines **symmetric matrices** of size 3x3 and related operations.
2
3use nalgebra::Matrix3;
4use std::ops::{Add, Index, Mul, Sub};
5
6use crate::{motion::SpatialRotation, vector3d::Vector3D};
7
8#[cfg(feature = "python")]
9use numpy::{PyReadonlyArrayDyn, ToPyArray, ndarray::Array2};
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12
/// A symmetric 3x3 matrix.
///
/// The matrix is stored in a compact form, only keeping the unique elements:
/// the three diagonal entries and the three entries above the diagonal.
/// Entries below the diagonal are implied by symmetry and are never stored.
///
/// The derived `Default` yields the all-zero matrix, and the derived
/// `PartialEq` compares the six stored coefficients exactly.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Symmetric3 {
    /// The unique elements of the symmetric matrix, stored in the order:
    /// [m11, m22, m33, m12, m13, m23]
    /// (diagonal first, then the upper off-diagonal terms).
    data: [f64; 6],
}
22
23impl Symmetric3 {
24    /// Creates a new `Symmetric3` matrix from the given elements.
25    ///
26    /// # Arguments
27    ///
28    /// * `m11`, `m22`, `m33` - The diagonal elements.
29    /// * `m12`, `m13`, `m23` - The off-diagonal elements.
30    #[must_use]
31    pub fn new(m11: f64, m22: f64, m33: f64, m12: f64, m13: f64, m23: f64) -> Self {
32        Self {
33            data: [m11, m22, m33, m12, m13, m23],
34        }
35    }
36
37    /// Returns the element at the specified row and column.
38    ///
39    /// # Arguments
40    ///
41    /// * `row` - The row index (0-based).
42    /// * `col` - The column index (0-based).
43    ///
44    /// # Panics
45    ///
46    /// Panics if the row or column index is out of bounds.
47    #[must_use]
48    pub fn get(&self, row: usize, col: usize) -> &f64 {
49        match (row, col) {
50            (0, 0) => &self.data[0],
51            (1, 1) => &self.data[1],
52            (2, 2) => &self.data[2],
53            (0, 1) | (1, 0) => &self.data[3],
54            (0, 2) | (2, 0) => &self.data[4],
55            (1, 2) | (2, 1) => &self.data[5],
56            _ => panic!("Index out of bounds"),
57        }
58    }
59
60    /// Returns the zero symmetric matrix.
61    #[must_use]
62    pub fn zeros() -> Self {
63        Self { data: [0.0; 6] }
64    }
65
66    /// Returns the identity symmetric matrix.
67    #[must_use]
68    pub fn identity() -> Self {
69        Self {
70            data: [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
71        }
72    }
73
74    /// Creates a diagonal symmetric matrix from the given diagonal elements.
75    ///
76    /// # Arguments
77    /// * `diag` - A vector containing the diagonal elements [m11, m22, m33].
78    #[must_use]
79    pub fn from_diagonal(diag: &[f64; 3]) -> Self {
80        Self {
81            data: [diag[0], diag[1], diag[2], 0.0, 0.0, 0.0],
82        }
83    }
84
85    /// Convert the symmetric matrix to a full 3x3 matrix.
86    #[must_use]
87    pub fn matrix(&self) -> Matrix3<f64> {
88        Matrix3::new(
89            self.data[0],
90            self.data[3],
91            self.data[4],
92            self.data[3],
93            self.data[1],
94            self.data[5],
95            self.data[4],
96            self.data[5],
97            self.data[2],
98        )
99    }
100
101    #[cfg(feature = "python")]
102    #[must_use]
103    /// Converts the symmetric matrix to a NumPy array.
104    pub fn to_numpy(&self, py: Python) -> Py<PyAny> {
105        // FIXME: move this function to py_symmetric3.rs
106        let mat = self.matrix();
107        Array2::from_shape_fn((3, 3), |(i, j)| mat[(i, j)])
108            .to_pyarray(py)
109            .into_any()
110            .unbind()
111    }
112
113    #[must_use]
114    /// Constructs the skew symmetric matrix associated to the given 3D vector.
115    pub fn skew_square(v: &Vector3D) -> Symmetric3 {
116        let x = v.0[0];
117        let y = v.0[1];
118        let z = v.0[2];
119
120        Symmetric3::new(
121            -y * y - z * z,
122            -x * x - z * z,
123            -x * x - y * y,
124            x * y,
125            x * z,
126            y * z,
127        )
128    }
129
130    /// Computes the matrix product $RSR^\top$ where $R$ is a spatial rotation and $S$ is this symmetric matrix.
131    ///
132    /// # Arguments
133    /// * `rotation` - The spatial rotation to apply.
134    ///
135    /// # Returns
136    /// The rotated symmetric matrix.
137    #[must_use]
138    pub fn rotate(&self, rotation: &SpatialRotation) -> Symmetric3 {
139        // TODO: avoid constructing the full matrix
140        let r = &rotation.0;
141        let s = &self.matrix();
142        let rsrt = r * s * r.transpose();
143        Symmetric3::new(
144            rsrt[(0, 0)],
145            rsrt[(1, 1)],
146            rsrt[(2, 2)],
147            rsrt[(0, 1)],
148            rsrt[(0, 2)],
149            rsrt[(1, 2)],
150        )
151    }
152
153    #[cfg(feature = "python")]
154    /// Creates a `Symmetric3` from a NumPy array.
155    /// # Arguments
156    /// * `array` - A 3x3 NumPy array.
157    /// # Returns
158    /// * A `Symmetric3` instance if successful,
159    /// * otherwise raises a `ValueError` if the input is not symmetric or not of shape (3, 3).
160    pub fn from_pyarray(array: &PyReadonlyArrayDyn<f64>) -> Result<Self, PyErr> {
161        let array = array.as_array();
162        if array.shape() != [3, 3] {
163            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
164                "Input array must be of shape (3, 3)",
165            ));
166        }
167
168        // Check symmetry
169        for i in 0..3 {
170            for j in 0..3 {
171                if array[[i, j]] != array[[j, i]] {
172                    return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
173                        "Input array must be symmetric",
174                    ));
175                }
176            }
177        }
178
179        Ok(Symmetric3::new(
180            array[[0, 0]],
181            array[[1, 1]],
182            array[[2, 2]],
183            array[[0, 1]],
184            array[[0, 2]],
185            array[[1, 2]],
186        ))
187    }
188}
189
190impl Index<(usize, usize)> for Symmetric3 {
191    type Output = f64;
192
193    fn index(&self, index: (usize, usize)) -> &Self::Output {
194        self.get(index.0, index.1)
195    }
196}
197
198impl Add<Symmetric3> for Symmetric3 {
199    type Output = Symmetric3;
200
201    fn add(self, rhs: Symmetric3) -> Self::Output {
202        Symmetric3 {
203            data: [
204                self.data[0] + rhs.data[0],
205                self.data[1] + rhs.data[1],
206                self.data[2] + rhs.data[2],
207                self.data[3] + rhs.data[3],
208                self.data[4] + rhs.data[4],
209                self.data[5] + rhs.data[5],
210            ],
211        }
212    }
213}
214
215impl Sub<Symmetric3> for Symmetric3 {
216    type Output = Symmetric3;
217
218    fn sub(self, rhs: Symmetric3) -> Self::Output {
219        Symmetric3 {
220            data: [
221                self.data[0] - rhs.data[0],
222                self.data[1] - rhs.data[1],
223                self.data[2] - rhs.data[2],
224                self.data[3] - rhs.data[3],
225                self.data[4] - rhs.data[4],
226                self.data[5] - rhs.data[5],
227            ],
228        }
229    }
230}
231
232impl Mul<&Vector3D> for &Symmetric3 {
233    type Output = Vector3D;
234
235    fn mul(self, rhs: &Vector3D) -> Self::Output {
236        Vector3D::new(
237            self[(0, 0)] * rhs.0[0] + self[(0, 1)] * rhs.0[1] + self[(0, 2)] * rhs.0[2],
238            self[(1, 0)] * rhs.0[0] + self[(1, 1)] * rhs.0[1] + self[(1, 2)] * rhs.0[2],
239            self[(2, 0)] * rhs.0[0] + self[(2, 1)] * rhs.0[1] + self[(2, 2)] * rhs.0[2],
240        )
241    }
242}
243
244impl Mul<f64> for Symmetric3 {
245    type Output = Symmetric3;
246
247    fn mul(self, rhs: f64) -> Self::Output {
248        Symmetric3 {
249            data: [
250                self.data[0] * rhs,
251                self.data[1] * rhs,
252                self.data[2] * rhs,
253                self.data[3] * rhs,
254                self.data[4] * rhs,
255                self.data[5] * rhs,
256            ],
257        }
258    }
259}
260
261impl Mul<Symmetric3> for f64 {
262    type Output = Symmetric3;
263
264    fn mul(self, rhs: Symmetric3) -> Self::Output {
265        Symmetric3 {
266            data: [
267                rhs.data[0] * self,
268                rhs.data[1] * self,
269                rhs.data[2] * self,
270                rhs.data[3] * self,
271                rhs.data[4] * self,
272                rhs.data[5] * self,
273            ],
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Packs the diagonal and upper triangle of a dense matrix into a `Symmetric3`.
    fn pack(full: &Matrix3<f64>) -> Symmetric3 {
        Symmetric3::new(
            full[(0, 0)],
            full[(1, 1)],
            full[(2, 2)],
            full[(0, 1)],
            full[(0, 2)],
            full[(1, 2)],
        )
    }

    /// Every (row, col) position must read back the coefficient given to `new`,
    /// with mirrored positions sharing the same value.
    #[test]
    fn test_symmetric3_creation() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3);
        let expected = [
            ((0, 0), 1.0),
            ((1, 1), 2.0),
            ((2, 2), 3.0),
            ((0, 1), 0.1),
            ((1, 0), 0.1),
            ((0, 2), 0.2),
            ((2, 0), 0.2),
            ((1, 2), 0.3),
            ((2, 1), 0.3),
        ];
        for (position, value) in expected {
            assert_eq!(sym[position], value);
        }
    }

    /// The dense expansion must mirror the off-diagonal coefficients.
    #[test]
    fn test_symmetric3_to_matrix() {
        let dense = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3).matrix();
        #[rustfmt::skip]
        let expected = Matrix3::new(
            1.0, 0.1, 0.2,
            0.1, 2.0, 0.3,
            0.2, 0.3, 3.0,
        );
        assert_relative_eq!(dense, expected);
    }

    /// A diagonal matrix scales each vector component independently.
    #[test]
    fn test_symmetric3_mul_vector3d() {
        let diagonal = Symmetric3::new(1.0, 2.0, 3.0, 0.0, 0.0, 0.0);
        let input = Vector3D::new(1.0, 2.0, 3.0);
        let product = &diagonal * &input;
        let expected = Vector3D::new(1.0, 4.0, 9.0);
        assert_relative_eq!(product.0, expected.0);
    }

    /// `rotate` must agree with the dense product R * S * R^T.
    #[test]
    fn test_symmetric3_rotate() {
        let full = Matrix3::new(1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 3.0, 5.0, 6.0);
        assert!(full == full.transpose());

        let sym = pack(&full);

        // Quarter turn about the z axis.
        let z_axis = Vector3D::new(0.0, 0.0, 1.0);
        let rotation = SpatialRotation::from_axis_angle(&z_axis, std::f64::consts::FRAC_PI_2);
        let rotated_sym = sym.rotate(&rotation);

        let expected = rotation.0 * full * rotation.0.transpose();
        let expected_sym = pack(&expected);
        assert_relative_eq!(rotated_sym.matrix(), expected_sym.matrix());
    }
}