1use nalgebra::Matrix3;
4use std::ops::{Add, Index, Mul, Sub};
5
6use crate::{motion::SpatialRotation, vector3d::Vector3D};
7
8#[cfg(feature = "python")]
9use numpy::{PyReadonlyArrayDyn, ToPyArray, ndarray::Array2};
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12
/// A symmetric 3x3 matrix of `f64` that stores only its six independent
/// coefficients.
///
/// Coefficients are packed as `[m11, m22, m33, m12, m13, m23]`: the diagonal
/// first, then the strictly-upper triangle row by row. The lower triangle is
/// implied by symmetry. `Default` yields the zero matrix.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Symmetric3 {
    /// Packed coefficients in the order `[m11, m22, m33, m12, m13, m23]`.
    data: [f64; 6],
}
22
23impl Symmetric3 {
24 #[must_use]
31 pub fn new(m11: f64, m22: f64, m33: f64, m12: f64, m13: f64, m23: f64) -> Self {
32 Self {
33 data: [m11, m22, m33, m12, m13, m23],
34 }
35 }
36
37 #[must_use]
48 pub fn get(&self, row: usize, col: usize) -> &f64 {
49 match (row, col) {
50 (0, 0) => &self.data[0],
51 (1, 1) => &self.data[1],
52 (2, 2) => &self.data[2],
53 (0, 1) | (1, 0) => &self.data[3],
54 (0, 2) | (2, 0) => &self.data[4],
55 (1, 2) | (2, 1) => &self.data[5],
56 _ => panic!("Index out of bounds"),
57 }
58 }
59
60 #[must_use]
62 pub fn zeros() -> Self {
63 Self { data: [0.0; 6] }
64 }
65
66 #[must_use]
68 pub fn identity() -> Self {
69 Self {
70 data: [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
71 }
72 }
73
74 #[must_use]
79 pub fn from_diagonal(diag: &[f64; 3]) -> Self {
80 Self {
81 data: [diag[0], diag[1], diag[2], 0.0, 0.0, 0.0],
82 }
83 }
84
85 #[must_use]
87 pub fn matrix(&self) -> Matrix3<f64> {
88 Matrix3::new(
89 self.data[0],
90 self.data[3],
91 self.data[4],
92 self.data[3],
93 self.data[1],
94 self.data[5],
95 self.data[4],
96 self.data[5],
97 self.data[2],
98 )
99 }
100
101 #[cfg(feature = "python")]
102 #[must_use]
103 pub fn to_numpy(&self, py: Python) -> Py<PyAny> {
105 let mat = self.matrix();
107 Array2::from_shape_fn((3, 3), |(i, j)| mat[(i, j)])
108 .to_pyarray(py)
109 .into_any()
110 .unbind()
111 }
112
113 #[must_use]
114 pub fn skew_square(v: &Vector3D) -> Symmetric3 {
116 let x = v.0[0];
117 let y = v.0[1];
118 let z = v.0[2];
119
120 Symmetric3::new(
121 -y * y - z * z,
122 -x * x - z * z,
123 -x * x - y * y,
124 x * y,
125 x * z,
126 y * z,
127 )
128 }
129
130 #[must_use]
138 pub fn rotate(&self, rotation: &SpatialRotation) -> Symmetric3 {
139 let r = &rotation.0;
141 let s = &self.matrix();
142 let rsrt = r * s * r.transpose();
143 Symmetric3::new(
144 rsrt[(0, 0)],
145 rsrt[(1, 1)],
146 rsrt[(2, 2)],
147 rsrt[(0, 1)],
148 rsrt[(0, 2)],
149 rsrt[(1, 2)],
150 )
151 }
152
153 #[cfg(feature = "python")]
154 pub fn from_pyarray(array: &PyReadonlyArrayDyn<f64>) -> Result<Self, PyErr> {
161 let array = array.as_array();
162 if array.shape() != [3, 3] {
163 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
164 "Input array must be of shape (3, 3)",
165 ));
166 }
167
168 for i in 0..3 {
170 for j in 0..3 {
171 if array[[i, j]] != array[[j, i]] {
172 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
173 "Input array must be symmetric",
174 ));
175 }
176 }
177 }
178
179 Ok(Symmetric3::new(
180 array[[0, 0]],
181 array[[1, 1]],
182 array[[2, 2]],
183 array[[0, 1]],
184 array[[0, 2]],
185 array[[1, 2]],
186 ))
187 }
188}
189
190impl Index<(usize, usize)> for Symmetric3 {
191 type Output = f64;
192
193 fn index(&self, index: (usize, usize)) -> &Self::Output {
194 self.get(index.0, index.1)
195 }
196}
197
198impl Add<Symmetric3> for Symmetric3 {
199 type Output = Symmetric3;
200
201 fn add(self, rhs: Symmetric3) -> Self::Output {
202 Symmetric3 {
203 data: [
204 self.data[0] + rhs.data[0],
205 self.data[1] + rhs.data[1],
206 self.data[2] + rhs.data[2],
207 self.data[3] + rhs.data[3],
208 self.data[4] + rhs.data[4],
209 self.data[5] + rhs.data[5],
210 ],
211 }
212 }
213}
214
215impl Sub<Symmetric3> for Symmetric3 {
216 type Output = Symmetric3;
217
218 fn sub(self, rhs: Symmetric3) -> Self::Output {
219 Symmetric3 {
220 data: [
221 self.data[0] - rhs.data[0],
222 self.data[1] - rhs.data[1],
223 self.data[2] - rhs.data[2],
224 self.data[3] - rhs.data[3],
225 self.data[4] - rhs.data[4],
226 self.data[5] - rhs.data[5],
227 ],
228 }
229 }
230}
231
232impl Mul<&Vector3D> for &Symmetric3 {
233 type Output = Vector3D;
234
235 fn mul(self, rhs: &Vector3D) -> Self::Output {
236 Vector3D::new(
237 self[(0, 0)] * rhs.0[0] + self[(0, 1)] * rhs.0[1] + self[(0, 2)] * rhs.0[2],
238 self[(1, 0)] * rhs.0[0] + self[(1, 1)] * rhs.0[1] + self[(1, 2)] * rhs.0[2],
239 self[(2, 0)] * rhs.0[0] + self[(2, 1)] * rhs.0[1] + self[(2, 2)] * rhs.0[2],
240 )
241 }
242}
243
244impl Mul<f64> for Symmetric3 {
245 type Output = Symmetric3;
246
247 fn mul(self, rhs: f64) -> Self::Output {
248 Symmetric3 {
249 data: [
250 self.data[0] * rhs,
251 self.data[1] * rhs,
252 self.data[2] * rhs,
253 self.data[3] * rhs,
254 self.data[4] * rhs,
255 self.data[5] * rhs,
256 ],
257 }
258 }
259}
260
261impl Mul<Symmetric3> for f64 {
262 type Output = Symmetric3;
263
264 fn mul(self, rhs: Symmetric3) -> Self::Output {
265 Symmetric3 {
266 data: [
267 rhs.data[0] * self,
268 rhs.data[1] * self,
269 rhs.data[2] * self,
270 rhs.data[3] * self,
271 rhs.data[4] * self,
272 rhs.data[5] * self,
273 ],
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the dense skew-symmetric cross-product matrix `[v]ₓ` for
    /// `v = (x, y, z)`, so that `[v]ₓ w == v × w`. This is an independent
    /// reference for `Symmetric3::skew_square`.
    fn skew(x: f64, y: f64, z: f64) -> Matrix3<f64> {
        Matrix3::new(0.0, -z, y, z, 0.0, -x, -y, x, 0.0)
    }

    #[test]
    fn test_symmetric3_creation() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3);
        assert_eq!(sym[(0, 0)], 1.0);
        assert_eq!(sym[(1, 1)], 2.0);
        assert_eq!(sym[(2, 2)], 3.0);
        assert_eq!(sym[(0, 1)], 0.1);
        assert_eq!(sym[(1, 0)], 0.1);
        assert_eq!(sym[(0, 2)], 0.2);
        assert_eq!(sym[(2, 0)], 0.2);
        assert_eq!(sym[(1, 2)], 0.3);
        assert_eq!(sym[(2, 1)], 0.3);
    }

    #[test]
    fn test_symmetric3_constructors() {
        assert_eq!(Symmetric3::zeros(), Symmetric3::default());
        assert_eq!(Symmetric3::identity().matrix(), Matrix3::<f64>::identity());
        assert_eq!(
            Symmetric3::from_diagonal(&[1.0, 2.0, 3.0]),
            Symmetric3::new(1.0, 2.0, 3.0, 0.0, 0.0, 0.0)
        );
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_symmetric3_index_out_of_bounds() {
        let sym = Symmetric3::identity();
        // Bind to a named variable so the indexing is actually evaluated.
        let _value: f64 = sym[(3, 0)];
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_symmetric3_get_out_of_bounds() {
        let sym = Symmetric3::identity();
        let _value: f64 = *sym.get(0, 3);
    }

    #[test]
    fn test_symmetric3_to_matrix() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3);
        let mat = sym.matrix();
        let expected = Matrix3::new(1.0, 0.1, 0.2, 0.1, 2.0, 0.3, 0.2, 0.3, 3.0);
        assert_relative_eq!(mat, expected);
    }

    #[test]
    fn test_symmetric3_arithmetic() {
        let a = Symmetric3::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        let b = Symmetric3::new(0.5, -2.0, 1.5, 0.25, 10.0, -6.0);
        assert_eq!(a + b, Symmetric3::new(1.5, 0.0, 4.5, 4.25, 15.0, 0.0));
        assert_eq!(a - b, Symmetric3::new(0.5, 4.0, 1.5, 3.75, -5.0, 12.0));
        assert_eq!(a * 2.0, Symmetric3::new(2.0, 4.0, 6.0, 8.0, 10.0, 12.0));
        assert_eq!(2.0_f64 * a, a * 2.0);
    }

    #[test]
    fn test_symmetric3_skew_square() {
        let (x, y, z) = (1.0, -2.0, 3.0);
        let s = skew(x, y, z);
        let sym = Symmetric3::skew_square(&Vector3D::new(x, y, z));
        assert_relative_eq!(sym.matrix(), s * s);
    }

    #[test]
    fn test_symmetric3_mul_vector3d() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.0, 0.0, 0.0);
        let vec = Vector3D::new(1.0, 2.0, 3.0);
        let result = &sym * &vec;
        let expected = Vector3D::new(1.0, 4.0, 9.0);
        assert_relative_eq!(result.0, expected.0);
    }

    #[test]
    fn test_symmetric3_rotate() {
        let full = Matrix3::new(1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 3.0, 5.0, 6.0);
        assert!(full == full.transpose());

        let sym = Symmetric3::new(
            full[(0, 0)],
            full[(1, 1)],
            full[(2, 2)],
            full[(0, 1)],
            full[(0, 2)],
            full[(1, 2)],
        );

        let rotation = SpatialRotation::from_axis_angle(
            &Vector3D::new(0.0, 0.0, 1.0),
            std::f64::consts::FRAC_PI_2,
        );
        let rotated_sym = sym.rotate(&rotation);

        let expected = rotation.0 * full * rotation.0.transpose();
        let expected_sym = Symmetric3::new(
            expected[(0, 0)],
            expected[(1, 1)],
            expected[(2, 2)],
            expected[(0, 1)],
            expected[(0, 2)],
            expected[(1, 2)],
        );
        assert_relative_eq!(rotated_sym.matrix(), expected_sym.matrix());
    }
}