1use nalgebra::Matrix3;
4use std::ops::{Add, Index, Mul, Sub};
5
6use crate::{motion::SpatialRotation, vector3d::Vector3D};
7
8#[cfg(feature = "python")]
9use numpy::{PyReadonlyArrayDyn, ToPyArray, ndarray::Array2};
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12
/// A symmetric 3x3 matrix of `f64` values, stored as its six independent entries.
///
/// The derived `Default` yields a value whose six stored entries are all `0.0`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Symmetric3 {
    /// Packed entries in the order `[m11, m22, m33, m12, m13, m23]`:
    /// the three diagonal entries first, then the three off-diagonal entries.
    data: [f64; 6],
}
22
23impl Symmetric3 {
24 #[must_use]
31 pub fn new(m11: f64, m22: f64, m33: f64, m12: f64, m13: f64, m23: f64) -> Self {
32 Self {
33 data: [m11, m22, m33, m12, m13, m23],
34 }
35 }
36
37 #[must_use]
48 pub fn get(&self, row: usize, col: usize) -> &f64 {
49 match (row, col) {
50 (0, 0) => &self.data[0],
51 (1, 1) => &self.data[1],
52 (2, 2) => &self.data[2],
53 (0, 1) | (1, 0) => &self.data[3],
54 (0, 2) | (2, 0) => &self.data[4],
55 (1, 2) | (2, 1) => &self.data[5],
56 _ => panic!("Index out of bounds"),
57 }
58 }
59
60 #[must_use]
62 pub fn zeros() -> Self {
63 Self { data: [0.0; 6] }
64 }
65
66 #[must_use]
68 pub fn identity() -> Self {
69 Self {
70 data: [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
71 }
72 }
73
74 #[must_use]
79 pub fn from_diagonal(diag: &[f64; 3]) -> Self {
80 Self {
81 data: [diag[0], diag[1], diag[2], 0.0, 0.0, 0.0],
82 }
83 }
84
85 #[must_use]
87 pub fn matrix(&self) -> Matrix3<f64> {
88 Matrix3::new(
89 self.data[0],
90 self.data[3],
91 self.data[4],
92 self.data[3],
93 self.data[1],
94 self.data[5],
95 self.data[4],
96 self.data[5],
97 self.data[2],
98 )
99 }
100
101 #[cfg(feature = "python")]
102 #[must_use]
103 pub fn to_numpy(&self, py: Python) -> Py<PyAny> {
105 let mat = self.matrix();
106 Array2::from_shape_fn((3, 3), |(i, j)| mat[(i, j)])
107 .to_pyarray(py)
108 .into_any()
109 .unbind()
110 }
111
112 #[must_use]
113 pub fn skew_square(v: Vector3D) -> Symmetric3 {
114 let x = v.0[0];
115 let y = v.0[1];
116 let z = v.0[2];
117
118 Symmetric3::new(
119 -y * y - z * z,
120 -x * x - z * z,
121 -x * x - y * y,
122 x * y,
123 x * z,
124 y * z,
125 )
126 }
127
128 #[must_use]
136 pub fn rotate(&self, rotation: &SpatialRotation) -> Symmetric3 {
137 let r = &rotation.0;
139 let s = &self.matrix();
140 let rsrt = r * s * r.transpose();
141 Symmetric3::new(
142 rsrt[(0, 0)],
143 rsrt[(1, 1)],
144 rsrt[(2, 2)],
145 rsrt[(0, 1)],
146 rsrt[(0, 2)],
147 rsrt[(1, 2)],
148 )
149 }
150
151 #[cfg(feature = "python")]
152 pub fn from_pyarray(array: &PyReadonlyArrayDyn<f64>) -> Result<Self, PyErr> {
159 let array = array.as_array();
160 if array.shape() != [3, 3] {
161 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
162 "Input array must be of shape (3, 3)",
163 ));
164 }
165
166 for i in 0..3 {
168 for j in 0..3 {
169 if array[[i, j]] != array[[j, i]] {
170 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
171 "Input array must be symmetric",
172 ));
173 }
174 }
175 }
176
177 Ok(Symmetric3::new(
178 array[[0, 0]],
179 array[[1, 1]],
180 array[[2, 2]],
181 array[[0, 1]],
182 array[[0, 2]],
183 array[[1, 2]],
184 ))
185 }
186}
187
188impl Index<(usize, usize)> for Symmetric3 {
189 type Output = f64;
190
191 fn index(&self, index: (usize, usize)) -> &Self::Output {
192 self.get(index.0, index.1)
193 }
194}
195
196impl Add<Symmetric3> for Symmetric3 {
197 type Output = Symmetric3;
198
199 fn add(self, rhs: Symmetric3) -> Self::Output {
200 Symmetric3 {
201 data: [
202 self.data[0] + rhs.data[0],
203 self.data[1] + rhs.data[1],
204 self.data[2] + rhs.data[2],
205 self.data[3] + rhs.data[3],
206 self.data[4] + rhs.data[4],
207 self.data[5] + rhs.data[5],
208 ],
209 }
210 }
211}
212
213impl Sub<Symmetric3> for Symmetric3 {
214 type Output = Symmetric3;
215
216 fn sub(self, rhs: Symmetric3) -> Self::Output {
217 Symmetric3 {
218 data: [
219 self.data[0] - rhs.data[0],
220 self.data[1] - rhs.data[1],
221 self.data[2] - rhs.data[2],
222 self.data[3] - rhs.data[3],
223 self.data[4] - rhs.data[4],
224 self.data[5] - rhs.data[5],
225 ],
226 }
227 }
228}
229
230impl Mul<&Vector3D> for &Symmetric3 {
231 type Output = Vector3D;
232
233 fn mul(self, rhs: &Vector3D) -> Self::Output {
234 Vector3D::new(
235 self[(0, 0)] * rhs.0[0] + self[(0, 1)] * rhs.0[1] + self[(0, 2)] * rhs.0[2],
236 self[(1, 0)] * rhs.0[0] + self[(1, 1)] * rhs.0[1] + self[(1, 2)] * rhs.0[2],
237 self[(2, 0)] * rhs.0[0] + self[(2, 1)] * rhs.0[1] + self[(2, 2)] * rhs.0[2],
238 )
239 }
240}
241
242impl Mul<f64> for Symmetric3 {
243 type Output = Symmetric3;
244
245 fn mul(self, rhs: f64) -> Self::Output {
246 Symmetric3 {
247 data: [
248 self.data[0] * rhs,
249 self.data[1] * rhs,
250 self.data[2] * rhs,
251 self.data[3] * rhs,
252 self.data[4] * rhs,
253 self.data[5] * rhs,
254 ],
255 }
256 }
257}
258
259impl Mul<Symmetric3> for f64 {
260 type Output = Symmetric3;
261
262 fn mul(self, rhs: Symmetric3) -> Self::Output {
263 Symmetric3 {
264 data: [
265 rhs.data[0] * self,
266 rhs.data[1] * self,
267 rhs.data[2] * self,
268 rhs.data[3] * self,
269 rhs.data[4] * self,
270 rhs.data[5] * self,
271 ],
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Packs the diagonal and upper triangle of a dense matrix into a `Symmetric3`.
    fn pack(dense: &Matrix3<f64>) -> Symmetric3 {
        Symmetric3::new(
            dense[(0, 0)],
            dense[(1, 1)],
            dense[(2, 2)],
            dense[(0, 1)],
            dense[(0, 2)],
            dense[(1, 2)],
        )
    }

    /// Every one of the nine positions reads back the value given to `new`.
    #[test]
    fn test_symmetric3_creation() {
        let sym = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3);
        let expected = [
            ((0, 0), 1.0),
            ((1, 1), 2.0),
            ((2, 2), 3.0),
            ((0, 1), 0.1),
            ((1, 0), 0.1),
            ((0, 2), 0.2),
            ((2, 0), 0.2),
            ((1, 2), 0.3),
            ((2, 1), 0.3),
        ];
        for (position, value) in expected {
            assert_eq!(sym[position], value);
        }
    }

    /// `matrix()` expands the packed entries into the full dense matrix.
    #[test]
    fn test_symmetric3_to_matrix() {
        let dense = Symmetric3::new(1.0, 2.0, 3.0, 0.1, 0.2, 0.3).matrix();
        let expected = Matrix3::new(1.0, 0.1, 0.2, 0.1, 2.0, 0.3, 0.2, 0.3, 3.0);
        assert_relative_eq!(dense, expected);
    }

    /// A diagonal matrix scales each vector component by its diagonal entry.
    #[test]
    fn test_symmetric3_mul_vector3d() {
        let diagonal = Symmetric3::new(1.0, 2.0, 3.0, 0.0, 0.0, 0.0);
        let input = Vector3D::new(1.0, 2.0, 3.0);
        let product = &diagonal * &input;
        assert_relative_eq!(product.0, Vector3D::new(1.0, 4.0, 9.0).0);
    }

    /// `rotate` agrees with the dense product `R * S * Rᵀ`.
    #[test]
    fn test_symmetric3_rotate() {
        let full = Matrix3::new(1.0, 2.0, 3.0, 2.0, 4.0, 5.0, 3.0, 5.0, 6.0);
        assert!(full == full.transpose());

        let z_axis = Vector3D::new(0.0, 0.0, 1.0);
        let quarter_turn =
            SpatialRotation::from_axis_angle(&z_axis, std::f64::consts::FRAC_PI_2);
        let rotated = pack(&full).rotate(&quarter_turn);

        let reference = quarter_turn.0 * full * quarter_turn.0.transpose();
        assert_relative_eq!(rotated.matrix(), pack(&reference).matrix());
    }
}