1use std::fmt::Display;
2
3use approx::{AbsDiffEq, RelativeEq};
4use auto_ops::{impl_op_ex, impl_op_ex_commutative};
5use nalgebra::{Vector3, Vector4};
6
7use crate::Float;
8use serde::{Deserialize, Serialize};
9
/// A three-component Cartesian vector of [`Float`]s.
///
/// Throughout this module it doubles as a three-momentum. The `px`/`py`/`pz`
/// accessors read the components, and [`Vec3::with_mass`] /
/// [`Vec3::with_energy`] promote it to a [`Vec4`].
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Vec3 {
    /// The x-component.
    pub x: Float,
    /// The y-component.
    pub y: Float,
    /// The z-component.
    pub z: Float,
}
20
21impl Display for Vec3 {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 write!(f, "[{:6.3}, {:6.3}, {:6.3}]", self.x, self.y, self.z)
24 }
25}
26
27impl AbsDiffEq for Vec3 {
28 type Epsilon = <Float as approx::AbsDiffEq>::Epsilon;
29
30 fn default_epsilon() -> Self::Epsilon {
31 Float::default_epsilon()
32 }
33
34 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
35 Float::abs_diff_eq(&self.x, &other.x, epsilon)
36 && Float::abs_diff_eq(&self.y, &other.y, epsilon)
37 && Float::abs_diff_eq(&self.z, &other.z, epsilon)
38 }
39}
40impl RelativeEq for Vec3 {
41 fn default_max_relative() -> Self::Epsilon {
42 Float::default_max_relative()
43 }
44
45 fn relative_eq(
46 &self,
47 other: &Self,
48 epsilon: Self::Epsilon,
49 max_relative: Self::Epsilon,
50 ) -> bool {
51 Float::relative_eq(&self.x, &other.x, epsilon, max_relative)
52 && Float::relative_eq(&self.y, &other.y, epsilon, max_relative)
53 && Float::relative_eq(&self.z, &other.z, epsilon, max_relative)
54 }
55}
56
57impl From<Vec3> for Vector3<Float> {
58 fn from(value: Vec3) -> Self {
59 Vector3::new(value.x, value.y, value.z)
60 }
61}
62
63impl From<Vector3<Float>> for Vec3 {
64 fn from(value: Vector3<Float>) -> Self {
65 Vec3::new(value.x, value.y, value.z)
66 }
67}
68
69impl From<Vec<Float>> for Vec3 {
70 fn from(value: Vec<Float>) -> Self {
71 Self {
72 x: value[0],
73 y: value[1],
74 z: value[2],
75 }
76 }
77}
78
79impl From<Vec3> for Vec<Float> {
80 fn from(value: Vec3) -> Self {
81 vec![value.x, value.y, value.z]
82 }
83}
84
85impl From<[Float; 3]> for Vec3 {
86 fn from(value: [Float; 3]) -> Self {
87 Self {
88 x: value[0],
89 y: value[1],
90 z: value[2],
91 }
92 }
93}
94
95impl From<Vec3> for [Float; 3] {
96 fn from(value: Vec3) -> Self {
97 [value.x, value.y, value.z]
98 }
99}
100
101impl Default for Vec3 {
102 fn default() -> Self {
103 Vec3::zero()
104 }
105}
106
107impl Vec3 {
108 pub fn new(x: Float, y: Float, z: Float) -> Self {
110 Vec3 { x, y, z }
111 }
112
113 pub const fn zero() -> Self {
115 Vec3 {
116 x: 0.0,
117 y: 0.0,
118 z: 0.0,
119 }
120 }
121
122 pub const fn x() -> Self {
124 Vec3 {
125 x: 1.0,
126 y: 0.0,
127 z: 0.0,
128 }
129 }
130
131 pub const fn y() -> Self {
133 Vec3 {
134 x: 0.0,
135 y: 1.0,
136 z: 0.0,
137 }
138 }
139
140 pub const fn z() -> Self {
142 Vec3 {
143 x: 0.0,
144 y: 0.0,
145 z: 1.0,
146 }
147 }
148
149 pub fn px(&self) -> Float {
151 self.x
152 }
153
154 pub fn py(&self) -> Float {
156 self.y
157 }
158
159 pub fn pz(&self) -> Float {
161 self.z
162 }
163
164 pub fn with_mass(&self, mass: Float) -> Vec4 {
166 let e = Float::sqrt(mass.powi(2) + self.mag2());
167 Vec4::new(self.px(), self.py(), self.pz(), e)
168 }
169
170 pub fn with_energy(&self, energy: Float) -> Vec4 {
172 Vec4::new(self.px(), self.py(), self.pz(), energy)
173 }
174
175 pub fn dot(&self, other: &Vec3) -> Float {
177 self.x * other.x + self.y * other.y + self.z * other.z
178 }
179
180 pub fn cross(&self, other: &Vec3) -> Vec3 {
182 Vec3::new(
183 self.y * other.z - other.y * self.z,
184 self.z * other.x - other.z * self.x,
185 self.x * other.y - other.x * self.y,
186 )
187 }
188
189 pub fn mag(&self) -> Float {
191 Float::sqrt(self.mag2())
192 }
193
194 pub fn mag2(&self) -> Float {
196 self.dot(self)
197 }
198
199 pub fn costheta(&self) -> Float {
201 self.z / self.mag()
202 }
203
204 pub fn theta(&self) -> Float {
206 Float::acos(self.costheta())
207 }
208
209 pub fn phi(&self) -> Float {
211 Float::atan2(self.y, self.x)
212 }
213
214 pub fn unit(&self) -> Vec3 {
216 let mag = self.mag();
217 Vec3::new(self.x / mag, self.y / mag, self.z / mag)
218 }
219}
220
221impl<'a> std::iter::Sum<&'a Vec3> for Vec3 {
222 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
223 iter.fold(Self::zero(), |a, b| a + b)
224 }
225}
226impl std::iter::Sum<Vec3> for Vec3 {
227 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
228 iter.fold(Self::zero(), |a, b| a + b)
229 }
230}
231
232impl_op_ex!(+ |a: &Vec3, b: &Vec3| -> Vec3 { Vec3::new(a.x + b.x, a.y + b.y, a.z + b.z) });
233impl_op_ex!(-|a: &Vec3, b: &Vec3| -> Vec3 { Vec3::new(a.x - b.x, a.y - b.y, a.z - b.z) });
234impl_op_ex!(-|a: &Vec3| -> Vec3 { Vec3::new(-a.x, -a.y, -a.z) });
235impl_op_ex_commutative!(+ |a: &Vec3, b: &Float| -> Vec3 { Vec3::new(a.x + b, a.y + b, a.z + b) });
236impl_op_ex_commutative!(-|a: &Vec3, b: &Float| -> Vec3 { Vec3::new(a.x - b, a.y - b, a.z - b) });
237impl_op_ex_commutative!(*|a: &Vec3, b: &Float| -> Vec3 { Vec3::new(a.x * b, a.y * b, a.z * b) });
238impl_op_ex!(/ |a: &Vec3, b: &Float| -> Vec3 { Vec3::new(a.x / b, a.y / b, a.z / b) });
239
/// A four-component vector of [`Float`]s with spatial part `(x, y, z)` and
/// temporal part `t`.
///
/// Used in this module as a four-momentum. [`Vec4::e`] returns `t` as the
/// energy, and [`Vec4::mag2`] computes `t² - (x² + y² + z²)`.
///
/// The scalar operators (`+`, `-`, `*`, `/` with a `Float`) and unary negation
/// defined for this type act only on the spatial components and leave `t`
/// unchanged. Adding or subtracting two `Vec4`s is fully component-wise.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Vec4 {
    /// The x-component of the spatial part.
    pub x: Float,
    /// The y-component of the spatial part.
    pub y: Float,
    /// The z-component of the spatial part.
    pub z: Float,
    /// The temporal component (the energy when used as a four-momentum).
    pub t: Float,
}
252
253impl Display for Vec4 {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 write!(
256 f,
257 "[{:6.3}, {:6.3}, {:6.3}; {:6.3}]",
258 self.x, self.y, self.z, self.t
259 )
260 }
261}
262
263impl AbsDiffEq for Vec4 {
264 type Epsilon = <Float as approx::AbsDiffEq>::Epsilon;
265
266 fn default_epsilon() -> Self::Epsilon {
267 Float::default_epsilon()
268 }
269
270 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
271 Float::abs_diff_eq(&self.x, &other.x, epsilon)
272 && Float::abs_diff_eq(&self.y, &other.y, epsilon)
273 && Float::abs_diff_eq(&self.z, &other.z, epsilon)
274 && Float::abs_diff_eq(&self.t, &other.t, epsilon)
275 }
276}
277impl RelativeEq for Vec4 {
278 fn default_max_relative() -> Self::Epsilon {
279 Float::default_max_relative()
280 }
281
282 fn relative_eq(
283 &self,
284 other: &Self,
285 epsilon: Self::Epsilon,
286 max_relative: Self::Epsilon,
287 ) -> bool {
288 Float::relative_eq(&self.x, &other.x, epsilon, max_relative)
289 && Float::relative_eq(&self.y, &other.y, epsilon, max_relative)
290 && Float::relative_eq(&self.z, &other.z, epsilon, max_relative)
291 && Float::relative_eq(&self.t, &other.t, epsilon, max_relative)
292 }
293}
294
295impl From<Vec4> for Vector4<Float> {
296 fn from(value: Vec4) -> Self {
297 Vector4::new(value.x, value.y, value.z, value.t)
298 }
299}
300
301impl From<Vector4<Float>> for Vec4 {
302 fn from(value: Vector4<Float>) -> Self {
303 Vec4::new(value.x, value.y, value.z, value.w)
304 }
305}
306
307impl From<Vec<Float>> for Vec4 {
308 fn from(value: Vec<Float>) -> Self {
309 Self {
310 x: value[0],
311 y: value[1],
312 z: value[2],
313 t: value[3],
314 }
315 }
316}
317
318impl From<Vec4> for Vec<Float> {
319 fn from(value: Vec4) -> Self {
320 vec![value.x, value.y, value.z, value.t]
321 }
322}
323
324impl From<[Float; 4]> for Vec4 {
325 fn from(value: [Float; 4]) -> Self {
326 Self {
327 x: value[0],
328 y: value[1],
329 z: value[2],
330 t: value[3],
331 }
332 }
333}
334
335impl From<Vec4> for [Float; 4] {
336 fn from(value: Vec4) -> Self {
337 [value.x, value.y, value.z, value.t]
338 }
339}
340
341impl Vec4 {
342 pub fn new(x: Float, y: Float, z: Float, t: Float) -> Self {
344 Vec4 { x, y, z, t }
345 }
346
347 pub fn px(&self) -> Float {
349 self.x
350 }
351
352 pub fn py(&self) -> Float {
354 self.y
355 }
356
357 pub fn pz(&self) -> Float {
359 self.z
360 }
361
362 pub fn e(&self) -> Float {
364 self.t
365 }
366
367 pub fn momentum(&self) -> Vec3 {
369 self.vec3()
370 }
371
372 pub fn gamma(&self) -> Float {
374 let beta = self.beta();
375 let b2 = beta.dot(&beta);
376 1.0 / Float::sqrt(1.0 - b2)
377 }
378
379 pub fn beta(&self) -> Vec3 {
381 self.momentum() / self.e()
382 }
383
384 pub fn m(&self) -> Float {
386 self.mag()
387 }
388
389 pub fn m2(&self) -> Float {
391 self.mag2()
392 }
393
394 pub fn to_p4_string(&self) -> String {
396 format!(
397 "[e = {:.5}; p = ({:.5}, {:.5}, {:.5}); m = {:.5}]",
398 self.e(),
399 self.px(),
400 self.py(),
401 self.pz(),
402 self.m()
403 )
404 }
405
406 pub fn mag(&self) -> Float {
408 Float::sqrt(self.mag2())
409 }
410
411 pub fn mag2(&self) -> Float {
413 self.t * self.t - (self.x * self.x + self.y * self.y + self.z * self.z)
414 }
415
416 pub fn boost(&self, beta: &Vec3) -> Self {
418 let b2 = beta.dot(beta);
419 let gamma = 1.0 / Float::sqrt(1.0 - b2);
420 let p3 = self.vec3() + beta * ((gamma - 1.0) * self.vec3().dot(beta) / b2 + gamma * self.t);
421 Vec4::new(p3.x, p3.y, p3.z, gamma * (self.t + beta.dot(&self.vec3())))
422 }
423
424 pub fn vec3(&self) -> Vec3 {
426 Vec3 {
427 x: self.x,
428 y: self.y,
429 z: self.z,
430 }
431 }
432}
433
434impl_op_ex!(+ |a: &Vec4, b: &Vec4| -> Vec4 { Vec4::new(a.x + b.x, a.y + b.y, a.z + b.z, a.t + b.t) });
435impl_op_ex!(-|a: &Vec4, b: &Vec4| -> Vec4 {
436 Vec4::new(a.x - b.x, a.y - b.y, a.z - b.z, a.t - b.t)
437});
438impl_op_ex!(-|a: &Vec4| -> Vec4 { Vec4::new(-a.x, -a.y, -a.z, a.t) });
439impl_op_ex_commutative!(+ |a: &Vec4, b: &Float| -> Vec4 { Vec4::new(a.x + b, a.y + b, a.z + b, a.t) });
440impl_op_ex_commutative!(-|a: &Vec4, b: &Float| -> Vec4 {
441 Vec4::new(a.x - b, a.y - b, a.z - b, a.t)
442});
443impl_op_ex_commutative!(*|a: &Vec4, b: &Float| -> Vec4 {
444 Vec4::new(a.x * b, a.y * b, a.z * b, a.t)
445});
446impl_op_ex!(/ |a: &Vec4, b: &Float| -> Vec4 { Vec4::new(a.x / b, a.y / b, a.z / b, a.t) });
447
448impl<'a> std::iter::Sum<&'a Vec4> for Vec4 {
449 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
450 iter.fold(Self::new(0.0, 0.0, 0.0, 0.0), |a, b| a + b)
451 }
452}
453
454impl std::iter::Sum<Vec4> for Vec4 {
455 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
456 iter.fold(Self::new(0.0, 0.0, 0.0, 0.0), |a, b| a + b)
457 }
458}
459
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use nalgebra::{Vector3, Vector4};

    use super::*;

    #[test]
    fn test_display() {
        // `{:6.3}` right-aligns every component in a six-character field, so
        // five-character values such as `3.456` gain a leading space.
        let v3 = Vec3::new(1.2341, -2.3452, 3.4563);
        assert_eq!(format!("{}", v3), "[ 1.234, -2.345,  3.456]");
        let v4 = Vec4::new(1.2341, -2.3452, 3.4563, 4.5674);
        assert_eq!(format!("{}", v4), "[ 1.234, -2.345,  3.456;  4.567]");
    }

    #[test]
    fn test_vec_vector_conversion() {
        let v = Vec3::new(1.0, 2.0, 3.0);
        let vector3: Vec<Float> = v.into();
        assert_eq!(vector3[0], 1.0);
        assert_eq!(vector3[1], 2.0);
        assert_eq!(vector3[2], 3.0);

        let v_from_vec: Vec3 = vector3.into();
        assert_eq!(v_from_vec, v);

        let v = Vec4::new(1.0, 2.0, 3.0, 4.0);
        let vector4: Vec<Float> = v.into();
        assert_eq!(vector4[0], 1.0);
        assert_eq!(vector4[1], 2.0);
        assert_eq!(vector4[2], 3.0);
        assert_eq!(vector4[3], 4.0);

        let v_from_vec: Vec4 = vector4.into();
        assert_eq!(v_from_vec, v);
    }

    #[test]
    fn test_vec_array_conversion() {
        let arr = [1.0, 2.0, 3.0];
        let v: Vec3 = arr.into();
        assert_eq!(v, Vec3::new(1.0, 2.0, 3.0));

        let back_to_array: [Float; 3] = v.into();
        assert_eq!(back_to_array, arr);

        let arr = [1.0, 2.0, 3.0, 4.0];
        let v: Vec4 = arr.into();
        assert_eq!(v, Vec4::new(1.0, 2.0, 3.0, 4.0));

        let back_to_array: [Float; 4] = v.into();
        assert_eq!(back_to_array, arr);
    }

    #[test]
    fn test_vec_nalgebra_conversion() {
        let v = Vec3::new(1.0, 2.0, 3.0);
        let vector3: Vector3<Float> = v.into();
        assert_eq!(vector3.x, 1.0);
        assert_eq!(vector3.y, 2.0);
        assert_eq!(vector3.z, 3.0);

        let v_from_vec: Vec3 = vector3.into();
        assert_eq!(v_from_vec, v);

        let v = Vec4::new(1.0, 2.0, 3.0, 4.0);
        let vector4: Vector4<Float> = v.into();
        assert_eq!(vector4.x, 1.0);
        assert_eq!(vector4.y, 2.0);
        assert_eq!(vector4.z, 3.0);
        assert_eq!(vector4.w, 4.0);

        let v_from_vec: Vec4 = vector4.into();
        assert_eq!(v_from_vec, v);
    }

    #[test]
    fn test_vec_sums() {
        let vectors = [Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0)];
        let sum: Vec3 = vectors.iter().sum();
        assert_eq!(sum, Vec3::new(5.0, 7.0, 9.0));
        let sum: Vec3 = vectors.into_iter().sum();
        assert_eq!(sum, Vec3::new(5.0, 7.0, 9.0));

        let vectors = [Vec4::new(1.0, 2.0, 3.0, 4.0), Vec4::new(4.0, 5.0, 6.0, 7.0)];
        let sum: Vec4 = vectors.iter().sum();
        assert_eq!(sum, Vec4::new(5.0, 7.0, 9.0, 11.0));
        let sum: Vec4 = vectors.into_iter().sum();
        assert_eq!(sum, Vec4::new(5.0, 7.0, 9.0, 11.0));
    }

    #[test]
    fn test_three_to_four_momentum_conversion() {
        let p3 = Vec3::new(1.0, 2.0, 3.0);
        let target_p4 = Vec4::new(1.0, 2.0, 3.0, 10.0);
        let p4_from_mass = p3.with_mass(target_p4.m());
        // The energy goes through a sqrt/square round trip (`m()` then
        // `with_mass`), so compare it with a tolerance rather than exactly.
        assert_relative_eq!(target_p4.e(), p4_from_mass.e());
        assert_eq!(target_p4.px(), p4_from_mass.px());
        assert_eq!(target_p4.py(), p4_from_mass.py());
        assert_eq!(target_p4.pz(), p4_from_mass.pz());
        let p4_from_energy = p3.with_energy(target_p4.e());
        assert_eq!(target_p4.e(), p4_from_energy.e());
        assert_eq!(target_p4.px(), p4_from_energy.px());
        assert_eq!(target_p4.py(), p4_from_energy.py());
        assert_eq!(target_p4.pz(), p4_from_energy.pz());
    }

    #[test]
    fn test_four_momentum_basics() {
        let p = Vec4::new(3.0, 4.0, 5.0, 10.0);
        assert_eq!(p.e(), 10.0);
        assert_eq!(p.px(), 3.0);
        assert_eq!(p.py(), 4.0);
        assert_eq!(p.pz(), 5.0);
        assert_eq!(p.momentum().px(), 3.0);
        assert_eq!(p.momentum().py(), 4.0);
        assert_eq!(p.momentum().pz(), 5.0);
        assert_relative_eq!(p.beta().x, 0.3);
        assert_relative_eq!(p.beta().y, 0.4);
        assert_relative_eq!(p.beta().z, 0.5);
        assert_relative_eq!(p.m2(), 50.0);
        assert_relative_eq!(p.m(), Float::sqrt(50.0));
        assert_eq!(
            p.to_p4_string(),
            "[e = 10.00000; p = (3.00000, 4.00000, 5.00000); m = 7.07107]"
        );
        assert_relative_eq!(Vec3::x().x, 1.0);
        assert_relative_eq!(Vec3::x().y, 0.0);
        assert_relative_eq!(Vec3::x().z, 0.0);
        assert_relative_eq!(Vec3::y().x, 0.0);
        assert_relative_eq!(Vec3::y().y, 1.0);
        assert_relative_eq!(Vec3::y().z, 0.0);
        assert_relative_eq!(Vec3::z().x, 0.0);
        assert_relative_eq!(Vec3::z().y, 0.0);
        assert_relative_eq!(Vec3::z().z, 1.0);
        assert_relative_eq!(Vec3::default().x, 0.0);
        assert_relative_eq!(Vec3::default().y, 0.0);
        assert_relative_eq!(Vec3::default().z, 0.0);
    }

    #[test]
    fn test_three_momentum_basics() {
        let p = Vec4::new(3.0, 4.0, 5.0, 10.0);
        let q = Vec4::new(1.2, -3.4, 7.6, 0.0);
        let p3_view = p.momentum();
        let q3_view = q.momentum();
        assert_eq!(p3_view.px(), 3.0);
        assert_eq!(p3_view.py(), 4.0);
        assert_eq!(p3_view.pz(), 5.0);
        assert_relative_eq!(p3_view.mag2(), 50.0);
        assert_relative_eq!(p3_view.mag(), Float::sqrt(50.0));
        assert_relative_eq!(p3_view.costheta(), 5.0 / Float::sqrt(50.0));
        assert_relative_eq!(p3_view.theta(), Float::acos(5.0 / Float::sqrt(50.0)));
        assert_relative_eq!(p3_view.phi(), Float::atan2(4.0, 3.0));
        assert_relative_eq!(
            p3_view.unit(),
            Vec3::new(
                3.0 / Float::sqrt(50.0),
                4.0 / Float::sqrt(50.0),
                5.0 / Float::sqrt(50.0)
            )
        );
        assert_relative_eq!(p3_view.cross(&q3_view), Vec3::new(47.4, -16.8, -15.0));
    }

    #[test]
    fn test_vec_equality() {
        let p = Vec3::new(1.1, 2.2, 3.3);
        let p2 = Vec3::new(1.1 * 2.0, 2.2 * 2.0, 3.3 * 2.0);
        assert_abs_diff_eq!(p * 2.0, p2);
        assert_relative_eq!(p * 2.0, p2);
        // Scalar multiplication of a `Vec4` scales only the spatial part.
        let p = Vec4::new(1.1, 2.2, 3.3, 10.0);
        let p2 = Vec4::new(1.1 * 2.0, 2.2 * 2.0, 3.3 * 2.0, 10.0);
        assert_abs_diff_eq!(p * 2.0, p2);
        assert_relative_eq!(p * 2.0, p2);
    }

    #[test]
    fn test_boost_com() {
        let p = Vec4::new(3.0, 4.0, 5.0, 10.0);
        let zero = p.boost(&-p.beta()).momentum();
        assert_relative_eq!(zero, Vec3::zero());
    }

    #[test]
    fn test_boost() {
        let p0 = Vec4::new(0.0, 0.0, 0.0, 1.0);
        assert_relative_eq!(p0.gamma(), 1.0);
        let p0 = Vec4::new(Float::sqrt(3.0) / 2.0, 0.0, 0.0, 1.0);
        assert_relative_eq!(p0.gamma(), 2.0);
        let p1 = Vec4::new(3.0, 4.0, 5.0, 10.0);
        let p2 = Vec4::new(3.4, 2.3, 1.2, 9.0);
        let p1_boosted = p1.boost(&-p2.beta());
        assert_relative_eq!(
            p1_boosted.e(),
            8.157632144622882,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            p1_boosted.px(),
            -0.6489200627053444,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            p1_boosted.py(),
            1.5316128987581492,
            epsilon = Float::EPSILON.sqrt()
        );
        assert_relative_eq!(
            p1_boosted.pz(),
            3.712145860221643,
            epsilon = Float::EPSILON.sqrt()
        );
    }
}