1use crate::{Mat3, Vec3};
2use std::ops::{Add, Mul, Sub};
3use serde::{Deserialize, Serialize};
4
5#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
11pub struct SpatialVector {
12 pub w: Vec3, pub v: Vec3, }
15
16impl SpatialVector {
17 pub const ZERO: Self = Self {
18 w: Vec3::ZERO,
19 v: Vec3::ZERO,
20 };
21
22 pub fn new(w: Vec3, v: Vec3) -> Self {
23 Self { w, v }
24 }
25
26 pub fn dot(self, other: Self) -> f32 {
28 self.w.dot(other.w) + self.v.dot(other.v)
29 }
30
31 pub fn cross_motion(self, other: Self) -> Self {
33 Self {
34 w: self.w.cross(other.w),
35 v: self.w.cross(other.v) + self.v.cross(other.w),
36 }
37 }
38
39 pub fn cross_force(self, f: Self) -> Self {
41 Self {
42 w: self.w.cross(f.w) + self.v.cross(f.v),
43 v: self.w.cross(f.v),
44 }
45 }
46
47 pub fn outer_product(self, other: Self) -> SpatialMatrix {
49 let outer = |a: Vec3, b: Vec3| -> Mat3 { Mat3::from_cols(a * b.x, a * b.y, a * b.z) };
50 SpatialMatrix {
51 m00: outer(self.w, other.w),
52 m01: outer(self.w, other.v),
53 m10: outer(self.v, other.w),
54 m11: outer(self.v, other.v),
55 }
56 }
57}
58
59impl Add for SpatialVector {
60 type Output = Self;
61 fn add(self, rhs: Self) -> Self {
62 Self {
63 w: self.w + rhs.w,
64 v: self.v + rhs.v,
65 }
66 }
67}
68
69impl Sub for SpatialVector {
70 type Output = Self;
71 fn sub(self, rhs: Self) -> Self {
72 Self {
73 w: self.w - rhs.w,
74 v: self.v - rhs.v,
75 }
76 }
77}
78
79impl Mul<f32> for SpatialVector {
80 type Output = Self;
81 fn mul(self, rhs: f32) -> Self {
82 Self {
83 w: self.w * rhs,
84 v: self.v * rhs,
85 }
86 }
87}
88
89#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
91pub struct SpatialMatrix {
92 pub m00: Mat3,
93 pub m01: Mat3,
94 pub m10: Mat3,
95 pub m11: Mat3,
96}
97
98impl SpatialMatrix {
99 pub const ZERO: Self = Self {
100 m00: Mat3::ZERO,
101 m01: Mat3::ZERO,
102 m10: Mat3::ZERO,
103 m11: Mat3::ZERO,
104 };
105
106 pub fn mul_vec(self, v: SpatialVector) -> SpatialVector {
107 SpatialVector {
108 w: self.m00 * v.w + self.m01 * v.v,
109 v: self.m10 * v.w + self.m11 * v.v,
110 }
111 }
112
113 pub fn mul_scalar(self, scalar: f32) -> Self {
114 Self {
115 m00: self.m00 * scalar,
116 m01: self.m01 * scalar,
117 m10: self.m10 * scalar,
118 m11: self.m11 * scalar,
119 }
120 }
121}
122
123impl Add for SpatialMatrix {
124 type Output = Self;
125 fn add(self, rhs: Self) -> Self {
126 Self {
127 m00: self.m00 + rhs.m00,
128 m01: self.m01 + rhs.m01,
129 m10: self.m10 + rhs.m10,
130 m11: self.m11 + rhs.m11,
131 }
132 }
133}
134
135impl Sub for SpatialMatrix {
136 type Output = Self;
137 fn sub(self, rhs: Self) -> Self {
138 Self {
139 m00: self.m00 - rhs.m00,
140 m01: self.m01 - rhs.m01,
141 m10: self.m10 - rhs.m10,
142 m11: self.m11 - rhs.m11,
143 }
144 }
145}
146
147#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
150pub struct SpatialInertia {
151 pub rot: Mat3, pub mass: f32, pub com: Vec3, }
155
156impl SpatialInertia {
157 pub fn new(mass: f32, rot_inertia: Mat3, com_offset: Vec3) -> Self {
158 Self {
159 rot: rot_inertia,
160 mass,
161 com: com_offset,
162 }
163 }
164
165 pub fn from_mass_inertia(mass: f32, inertia: Mat3) -> Self {
166 Self {
167 rot: inertia,
168 mass,
169 com: Vec3::ZERO,
170 }
171 }
172
173 pub fn mul_vec(self, v: SpatialVector) -> SpatialVector {
175 let com_cross_v = self.com.cross(v.v);
179 let com_cross_w = self.com.cross(v.w);
180
181 let mut force_w = self.rot.mul_vec3(v.w) + com_cross_v * self.mass;
182 if self.com.length_squared() > 1e-12 {
184 force_w -= self.com.cross(com_cross_w) * self.mass;
185 }
186
187 let force_v = v.v * self.mass - com_cross_w * self.mass;
188
189 SpatialVector::new(force_w, force_v)
190 }
191
192}
193
194impl Add for SpatialInertia {
195 type Output = Self;
196
197 fn add(self, other: Self) -> Self {
198 let total_mass = self.mass + other.mass;
199 if total_mass == 0.0 {
200 return Self::from_mass_inertia(0.0, Mat3::ZERO);
201 }
202 let total_com = (self.com * self.mass + other.com * other.mass) * (1.0 / total_mass);
203
204 let shift = |inertia: &SpatialInertia| -> Mat3 {
206 let d = inertia.com - total_com; let d_sq = d.dot(d);
208 inertia.rot + Mat3::from_diagonal(Vec3::splat(inertia.mass * d_sq))
210 - Mat3::from_cols(d * d.x, d * d.y, d * d.z) * inertia.mass
211 };
212
213 Self {
214 mass: total_mass,
215 com: total_com,
216 rot: shift(&self) + shift(&other),
217 }
218 }
219}
220
221impl SpatialInertia {
222 pub fn to_matrix(self) -> SpatialMatrix {
224 let m = self.mass;
225 let c = self.com;
226
227 let c_cross = Mat3::from_cols(
228 Vec3::new(0.0, c.z, -c.y),
229 Vec3::new(-c.z, 0.0, c.x),
230 Vec3::new(c.y, -c.x, 0.0),
231 );
232 let mc_cross = c_cross * m;
233 let mc_cross_t = mc_cross.transpose();
234
235 let c_cross_c_cross = c_cross * c_cross;
236 let rot_shifted = self.rot - c_cross_c_cross * m;
237
238 SpatialMatrix {
239 m00: rot_shifted,
240 m01: mc_cross,
241 m10: mc_cross_t,
242 m11: Mat3::from_diagonal(Vec3::splat(m)),
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the approximate comparisons below.
    const EPS: f32 = 1e-5;

    /// True when `a` and `b` are within `EPS` in Euclidean distance.
    fn vec3_approx(a: Vec3, b: Vec3) -> bool {
        (a - b).length() < EPS
    }

    /// True when every column of `a - b` has length below `EPS`.
    fn mat3_approx(a: Mat3, b: Mat3) -> bool {
        let diff = a - b;
        diff.x_axis.length() < EPS && diff.y_axis.length() < EPS && diff.z_axis.length() < EPS
    }

    // ---------------------------------------------------------------------
    // SpatialVector
    // ---------------------------------------------------------------------

    #[test]
    fn spatial_vector_zero() {
        let z = SpatialVector::ZERO;
        assert_eq!(z.w, Vec3::ZERO);
        assert_eq!(z.v, Vec3::ZERO);
    }

    #[test]
    fn spatial_vector_new() {
        let w = Vec3::new(1.0, 2.0, 3.0);
        let v = Vec3::new(4.0, 5.0, 6.0);
        let sv = SpatialVector::new(w, v);
        assert_eq!(sv.w, w);
        assert_eq!(sv.v, v);
    }

    #[test]
    fn spatial_vector_add() {
        let a = SpatialVector::new(Vec3::new(1.0, 0.0, 0.0), Vec3::new(0.0, 1.0, 0.0));
        let b = SpatialVector::new(Vec3::new(0.0, 2.0, 0.0), Vec3::new(3.0, 0.0, 0.0));
        let c = a + b;
        assert!(vec3_approx(c.w, Vec3::new(1.0, 2.0, 0.0)));
        assert!(vec3_approx(c.v, Vec3::new(3.0, 1.0, 0.0)));
    }

    #[test]
    fn spatial_vector_sub() {
        let a = SpatialVector::new(Vec3::new(5.0, 4.0, 3.0), Vec3::new(2.0, 1.0, 0.0));
        let b = SpatialVector::new(Vec3::new(1.0, 1.0, 1.0), Vec3::new(1.0, 1.0, 0.0));
        let c = a - b;
        assert!(vec3_approx(c.w, Vec3::new(4.0, 3.0, 2.0)));
        assert!(vec3_approx(c.v, Vec3::new(1.0, 0.0, 0.0)));
    }

    #[test]
    fn spatial_vector_scalar_mul() {
        let a = SpatialVector::new(Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0));
        let b = a * 2.0;
        assert!(vec3_approx(b.w, Vec3::new(2.0, 4.0, 6.0)));
        assert!(vec3_approx(b.v, Vec3::new(8.0, 10.0, 12.0)));
    }

    #[test]
    fn spatial_vector_dot() {
        let a = SpatialVector::new(Vec3::new(1.0, 0.0, 0.0), Vec3::new(0.0, 1.0, 0.0));
        let b = SpatialVector::new(Vec3::new(3.0, 4.0, 0.0), Vec3::new(5.0, 6.0, 0.0));
        // w·w' = 3, v·v' = 6.
        assert!((a.dot(b) - 9.0).abs() < EPS);
    }

    #[test]
    fn spatial_vector_dot_self_is_length_squared() {
        let a = SpatialVector::new(Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0));
        let expected = 1.0 + 4.0 + 9.0 + 16.0 + 25.0 + 36.0;
        assert!((a.dot(a) - expected).abs() < EPS);
    }

    #[test]
    fn spatial_vector_cross_motion_self_is_zero() {
        let v = SpatialVector::new(Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0));
        let result = v.cross_motion(v);
        // w × w = 0, and w × v + v × w = 0 by anti-commutativity.
        assert!(vec3_approx(result.w, Vec3::ZERO));
        assert!(vec3_approx(result.v, Vec3::ZERO));
    }

    #[test]
    fn spatial_vector_cross_motion_basis() {
        let wz = SpatialVector::new(Vec3::Z, Vec3::ZERO);
        let wx = SpatialVector::new(Vec3::X, Vec3::ZERO);
        let result = wz.cross_motion(wx);
        // Z × X = Y; no linear parts, so the linear result is zero.
        assert!(vec3_approx(result.w, Vec3::Y));
        assert!(vec3_approx(result.v, Vec3::ZERO));
    }

    #[test]
    fn spatial_vector_cross_force_basis() {
        let wz = SpatialVector::new(Vec3::Z, Vec3::ZERO);
        let fx = SpatialVector::new(Vec3::ZERO, Vec3::X);
        let result = wz.cross_force(fx);
        // Upper: Z × 0 + 0 × X = 0. Lower: Z × X = Y.
        assert!(vec3_approx(result.w, Vec3::ZERO));
        assert!(vec3_approx(result.v, Vec3::Y));
    }

    #[test]
    fn spatial_vector_outer_product_structure() {
        let a = SpatialVector::new(Vec3::X, Vec3::ZERO);
        let b = SpatialVector::new(Vec3::Y, Vec3::ZERO);
        let op = a.outer_product(b);
        // X Yᵀ has a single 1 at row 0, column 1.
        assert!(vec3_approx(op.m00.x_axis, Vec3::ZERO));
        assert!(vec3_approx(op.m00.y_axis, Vec3::X));
        assert!(vec3_approx(op.m00.z_axis, Vec3::ZERO));
        assert!(mat3_approx(op.m01, Mat3::ZERO));
        assert!(mat3_approx(op.m10, Mat3::ZERO));
        assert!(mat3_approx(op.m11, Mat3::ZERO));
    }

    // ---------------------------------------------------------------------
    // SpatialMatrix
    // ---------------------------------------------------------------------

    #[test]
    fn spatial_matrix_zero() {
        let z = SpatialMatrix::ZERO;
        assert!(mat3_approx(z.m00, Mat3::ZERO));
        assert!(mat3_approx(z.m01, Mat3::ZERO));
        assert!(mat3_approx(z.m10, Mat3::ZERO));
        assert!(mat3_approx(z.m11, Mat3::ZERO));
    }

    #[test]
    fn spatial_matrix_mul_vec_zero() {
        let m = SpatialMatrix::ZERO;
        let v = SpatialVector::new(Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0));
        let result = m.mul_vec(v);
        assert!(vec3_approx(result.w, Vec3::ZERO));
        assert!(vec3_approx(result.v, Vec3::ZERO));
    }

    #[test]
    fn spatial_matrix_mul_vec_identity_like() {
        let m = SpatialMatrix {
            m00: Mat3::IDENTITY,
            m01: Mat3::ZERO,
            m10: Mat3::ZERO,
            m11: Mat3::IDENTITY,
        };
        let v = SpatialVector::new(Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0));
        let result = m.mul_vec(v);
        assert!(vec3_approx(result.w, v.w));
        assert!(vec3_approx(result.v, v.v));
    }

    #[test]
    fn spatial_matrix_mul_scalar() {
        let m = SpatialMatrix {
            m00: Mat3::IDENTITY,
            m01: Mat3::IDENTITY,
            m10: Mat3::IDENTITY,
            m11: Mat3::IDENTITY,
        };
        let scaled = m.mul_scalar(3.0);
        let expected = Mat3::from_diagonal(Vec3::splat(3.0));
        assert!(mat3_approx(scaled.m00, expected));
        assert!(mat3_approx(scaled.m01, expected));
        assert!(mat3_approx(scaled.m10, expected));
        assert!(mat3_approx(scaled.m11, expected));
    }

    #[test]
    fn spatial_matrix_add_sub() {
        let a = SpatialMatrix {
            m00: Mat3::IDENTITY,
            m01: Mat3::ZERO,
            m10: Mat3::ZERO,
            m11: Mat3::IDENTITY,
        };
        let b = a;
        let sum = a + b;
        let two = Mat3::from_diagonal(Vec3::splat(2.0));
        assert!(mat3_approx(sum.m00, two));
        assert!(mat3_approx(sum.m01, Mat3::ZERO));
        assert!(mat3_approx(sum.m10, Mat3::ZERO));
        assert!(mat3_approx(sum.m11, two));

        let diff = sum - a;
        assert!(mat3_approx(diff.m00, Mat3::IDENTITY));
        assert!(mat3_approx(diff.m11, Mat3::IDENTITY));
    }

    // ---------------------------------------------------------------------
    // SpatialInertia
    // ---------------------------------------------------------------------

    #[test]
    fn spatial_inertia_from_mass_inertia() {
        let si = SpatialInertia::from_mass_inertia(5.0, Mat3::IDENTITY);
        assert_eq!(si.mass, 5.0);
        assert!(mat3_approx(si.rot, Mat3::IDENTITY));
        assert!(vec3_approx(si.com, Vec3::ZERO));
    }

    #[test]
    fn spatial_inertia_mul_vec_zero_com() {
        let si = SpatialInertia::from_mass_inertia(2.0, Mat3::IDENTITY);
        let vel = SpatialVector::new(Vec3::new(1.0, 0.0, 0.0), Vec3::new(0.0, 3.0, 0.0));
        let result = si.mul_vec(vel);
        // With com = 0: upper = rot·w, lower = m v.
        assert!(vec3_approx(result.w, Vec3::new(1.0, 0.0, 0.0)));
        assert!(vec3_approx(result.v, Vec3::new(0.0, 6.0, 0.0)));
    }

    #[test]
    fn spatial_inertia_mul_vec_matches_to_matrix() {
        let si = SpatialInertia::new(
            3.0,
            Mat3::from_diagonal(Vec3::new(2.0, 4.0, 6.0)),
            Vec3::new(0.5, -0.3, 0.1),
        );
        let vel = SpatialVector::new(Vec3::new(1.0, -0.5, 0.8), Vec3::new(-0.2, 0.7, -0.4));

        let result_direct = si.mul_vec(vel);
        let result_matrix = si.to_matrix().mul_vec(vel);

        assert!(
            vec3_approx(result_direct.w, result_matrix.w),
            "mul_vec.w = {:?}, to_matrix().mul_vec.w = {:?}",
            result_direct.w,
            result_matrix.w
        );
        assert!(
            vec3_approx(result_direct.v, result_matrix.v),
            "mul_vec.v = {:?}, to_matrix().mul_vec.v = {:?}",
            result_direct.v,
            result_matrix.v
        );
    }

    #[test]
    fn spatial_inertia_mul_vec_matches_to_matrix_zero_com() {
        let si = SpatialInertia::from_mass_inertia(5.0, Mat3::from_diagonal(Vec3::new(1.0, 2.0, 3.0)));
        let vel = SpatialVector::new(Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0));

        let result_direct = si.mul_vec(vel);
        let result_matrix = si.to_matrix().mul_vec(vel);

        assert!(vec3_approx(result_direct.w, result_matrix.w));
        assert!(vec3_approx(result_direct.v, result_matrix.v));
    }

    #[test]
    fn spatial_inertia_add_same_com() {
        let a = SpatialInertia::from_mass_inertia(2.0, Mat3::IDENTITY);
        let b = SpatialInertia::from_mass_inertia(3.0, Mat3::from_diagonal(Vec3::splat(2.0)));
        let total = a + b;

        assert!((total.mass - 5.0).abs() < EPS);
        assert!(vec3_approx(total.com, Vec3::ZERO));
        // Same COM, so no parallel-axis contribution: I + 2I.
        let expected_rot = Mat3::from_diagonal(Vec3::splat(3.0));
        assert!(mat3_approx(total.rot, expected_rot));
    }

    #[test]
    fn spatial_inertia_add_different_com() {
        let a = SpatialInertia::new(1.0, Mat3::ZERO, Vec3::new(1.0, 0.0, 0.0));
        let b = SpatialInertia::new(1.0, Mat3::ZERO, Vec3::new(-1.0, 0.0, 0.0));
        let total = a + b;

        assert!((total.mass - 2.0).abs() < EPS);
        assert!(vec3_approx(total.com, Vec3::ZERO));
        // Two unit point masses at ±x: each contributes diag(0, 1, 1) by the
        // parallel-axis theorem, so the total is diag(0, 2, 2).
        let expected_rot = Mat3::from_diagonal(Vec3::new(0.0, 2.0, 2.0));
        assert!(
            mat3_approx(total.rot, expected_rot),
            "rot = {:?}, expected {:?}",
            total.rot,
            expected_rot
        );
    }

    #[test]
    fn spatial_inertia_add_zero_masses() {
        let a = SpatialInertia::from_mass_inertia(0.0, Mat3::ZERO);
        let b = SpatialInertia::from_mass_inertia(0.0, Mat3::ZERO);
        let total = a + b;
        assert!(total.mass.abs() < EPS);
        assert!(vec3_approx(total.com, Vec3::ZERO));
    }

    #[test]
    fn spatial_inertia_to_matrix_zero_com() {
        let si = SpatialInertia::from_mass_inertia(4.0, Mat3::from_diagonal(Vec3::new(1.0, 2.0, 3.0)));
        let mat = si.to_matrix();

        assert!(mat3_approx(mat.m00, Mat3::from_diagonal(Vec3::new(1.0, 2.0, 3.0))));
        assert!(mat3_approx(mat.m01, Mat3::ZERO));
        assert!(mat3_approx(mat.m10, Mat3::ZERO));
        assert!(mat3_approx(mat.m11, Mat3::from_diagonal(Vec3::splat(4.0))));
    }

    #[test]
    fn spatial_inertia_to_matrix_coupling_is_cross_product() {
        let com = Vec3::new(1.0, 2.0, 3.0);
        let si = SpatialInertia::new(2.0, Mat3::IDENTITY, com);
        let mat = si.to_matrix();

        // m01 = m [c]×, so m01 · x must equal m (c × x).
        let x = Vec3::new(0.3, -1.0, 2.0);
        assert!(vec3_approx(mat.m01 * x, com.cross(x) * 2.0));
    }

    #[test]
    fn spatial_inertia_to_matrix_skew_symmetric() {
        let si = SpatialInertia::new(2.0, Mat3::IDENTITY, Vec3::new(1.0, 2.0, 3.0));
        let mat = si.to_matrix();

        // The coupling block m [c]× is skew-symmetric.
        assert!(
            mat3_approx(mat.m01 + mat.m01.transpose(), Mat3::ZERO),
            "m01 should be skew-symmetric"
        );
        // The lower-left block is its transpose and the upper-left block is
        // symmetric, so the whole 6×6 matrix is symmetric.
        assert!(
            mat3_approx(mat.m01.transpose(), mat.m10),
            "m01^T should equal m10"
        );
        assert!(
            mat3_approx(mat.m00, mat.m00.transpose()),
            "m00 should be symmetric"
        );
    }
}
584