Skip to main content

laddu_core/vectors/
vec4.rs

1use std::fmt::Display;
2
3use approx::{AbsDiffEq, RelativeEq};
4use auto_ops::{impl_op_ex, impl_op_ex_commutative};
5use nalgebra::Vector4;
6use serde::{Deserialize, Serialize};
7
8use super::vec3::Vec3;
9
10/// A four-vector (Lorentz vector) whose last component stores the energy.
11///
12/// # Examples
13/// ```rust
14/// use laddu_core::vectors::{Vec3, Vec4};
15///
16/// let momentum = Vec3::new(1.0, 0.0, 0.0);
17/// let four_vector = momentum.with_mass(2.0);
18/// assert!((four_vector.m2() - 4.0).abs() < 1e-12);
19/// ```
20#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
21pub struct Vec4 {
22    /// The x-component of the vector
23    pub x: f64,
24    /// The y-component of the vector
25    pub y: f64,
26    /// The z-component of the vector
27    pub z: f64,
28    /// The t-component of the vector
29    pub t: f64,
30}
31
32impl Display for Vec4 {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(
35            f,
36            "[{:6.3}, {:6.3}, {:6.3}; {:6.3}]",
37            self.x, self.y, self.z, self.t
38        )
39    }
40}
41
42impl AbsDiffEq for Vec4 {
43    type Epsilon = <f64 as approx::AbsDiffEq>::Epsilon;
44
45    fn default_epsilon() -> Self::Epsilon {
46        f64::default_epsilon()
47    }
48
49    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
50        f64::abs_diff_eq(&self.x, &other.x, epsilon)
51            && f64::abs_diff_eq(&self.y, &other.y, epsilon)
52            && f64::abs_diff_eq(&self.z, &other.z, epsilon)
53            && f64::abs_diff_eq(&self.t, &other.t, epsilon)
54    }
55}
56impl RelativeEq for Vec4 {
57    fn default_max_relative() -> Self::Epsilon {
58        f64::default_max_relative()
59    }
60
61    fn relative_eq(
62        &self,
63        other: &Self,
64        epsilon: Self::Epsilon,
65        max_relative: Self::Epsilon,
66    ) -> bool {
67        f64::relative_eq(&self.x, &other.x, epsilon, max_relative)
68            && f64::relative_eq(&self.y, &other.y, epsilon, max_relative)
69            && f64::relative_eq(&self.z, &other.z, epsilon, max_relative)
70            && f64::relative_eq(&self.t, &other.t, epsilon, max_relative)
71    }
72}
73
74impl From<Vec4> for Vector4<f64> {
75    fn from(value: Vec4) -> Self {
76        Vector4::new(value.x, value.y, value.z, value.t)
77    }
78}
79
80impl From<Vector4<f64>> for Vec4 {
81    fn from(value: Vector4<f64>) -> Self {
82        Vec4::new(value.x, value.y, value.z, value.w)
83    }
84}
85
86impl From<Vec<f64>> for Vec4 {
87    fn from(value: Vec<f64>) -> Self {
88        Self {
89            x: value[0],
90            y: value[1],
91            z: value[2],
92            t: value[3],
93        }
94    }
95}
96
97impl From<Vec4> for Vec<f64> {
98    fn from(value: Vec4) -> Self {
99        vec![value.x, value.y, value.z, value.t]
100    }
101}
102
103impl From<[f64; 4]> for Vec4 {
104    fn from(value: [f64; 4]) -> Self {
105        Self {
106            x: value[0],
107            y: value[1],
108            z: value[2],
109            t: value[3],
110        }
111    }
112}
113
114impl From<Vec4> for [f64; 4] {
115    fn from(value: Vec4) -> Self {
116        [value.x, value.y, value.z, value.t]
117    }
118}
119
120impl Vec4 {
121    /// Create a new 4-vector from its components
122    pub fn new(x: f64, y: f64, z: f64, t: f64) -> Self {
123        Vec4 { x, y, z, t }
124    }
125
126    /// Momentum in the x-direction
127    pub fn px(&self) -> f64 {
128        self.x
129    }
130
131    /// Momentum in the y-direction
132    pub fn py(&self) -> f64 {
133        self.y
134    }
135
136    /// Momentum in the z-direction
137    pub fn pz(&self) -> f64 {
138        self.z
139    }
140
141    /// The energy of the 4-vector
142    pub fn e(&self) -> f64 {
143        self.t
144    }
145
146    /// The 3-momentum
147    pub fn momentum(&self) -> Vec3 {
148        self.vec3()
149    }
150
151    /// The $`\gamma`$ factor $`\frac{1}{\sqrt{1 - \beta^2}}`$.
152    pub fn gamma(&self) -> f64 {
153        let beta = self.beta();
154        let b2 = beta.dot(&beta);
155        1.0 / f64::sqrt(1.0 - b2)
156    }
157
158    /// The $`\vec{\beta}`$ vector $`\frac{\vec{p}}{E}`$.
159    pub fn beta(&self) -> Vec3 {
160        self.momentum() / self.e()
161    }
162
163    /// The invariant mass corresponding to this 4-momentum
164    pub fn m(&self) -> f64 {
165        self.mag()
166    }
167
168    /// The squared invariant mass corresponding to this 4-momentum
169    pub fn m2(&self) -> f64 {
170        self.mag2()
171    }
172
173    /// Pretty-prints the four-momentum.
174    pub fn to_p4_string(&self) -> String {
175        format!(
176            "[e = {:.5}; p = ({:.5}, {:.5}, {:.5}); m = {:.5}]",
177            self.e(),
178            self.px(),
179            self.py(),
180            self.pz(),
181            self.m()
182        )
183    }
184
185    /// The magnitude of the vector (with $`---+`$ signature).
186    pub fn mag(&self) -> f64 {
187        f64::sqrt(self.mag2())
188    }
189
190    /// The squared magnitude of the vector (with $`---+`$ signature).
191    pub fn mag2(&self) -> f64 {
192        self.t * self.t - (self.x * self.x + self.y * self.y + self.z * self.z)
193    }
194
195    /// Gives the vector boosted along a $`\vec{\beta}`$ vector.
196    pub fn boost(&self, beta: &Vec3) -> Self {
197        let b2 = beta.dot(beta);
198        if b2 == 0.0 {
199            return *self;
200        }
201        let gamma = 1.0 / f64::sqrt(1.0 - b2);
202        let p3 = self.vec3() + beta * ((gamma - 1.0) * self.vec3().dot(beta) / b2 + gamma * self.t);
203        Vec4::new(p3.x, p3.y, p3.z, gamma * (self.t + beta.dot(&self.vec3())))
204    }
205
206    /// The 3-vector contained in this 4-vector
207    pub fn vec3(&self) -> Vec3 {
208        Vec3 {
209            x: self.x,
210            y: self.y,
211            z: self.z,
212        }
213    }
214}
215
216impl_op_ex!(+ |a: &Vec4, b: &Vec4| -> Vec4 { Vec4::new(a.x + b.x, a.y + b.y, a.z + b.z, a.t + b.t) });
217impl_op_ex!(-|a: &Vec4, b: &Vec4| -> Vec4 {
218    Vec4::new(a.x - b.x, a.y - b.y, a.z - b.z, a.t - b.t)
219});
220impl_op_ex!(-|a: &Vec4| -> Vec4 { Vec4::new(-a.x, -a.y, -a.z, a.t) });
221impl_op_ex_commutative!(+ |a: &Vec4, b: &f64| -> Vec4 { Vec4::new(a.x + b, a.y + b, a.z + b, a.t) });
222impl_op_ex_commutative!(-|a: &Vec4, b: &f64| -> Vec4 { Vec4::new(a.x - b, a.y - b, a.z - b, a.t) });
223impl_op_ex_commutative!(*|a: &Vec4, b: &f64| -> Vec4 { Vec4::new(a.x * b, a.y * b, a.z * b, a.t) });
224impl_op_ex!(/ |a: &Vec4, b: &f64| -> Vec4 { Vec4::new(a.x / b, a.y / b, a.z / b, a.t) });
225
226impl<'a> std::iter::Sum<&'a Vec4> for Vec4 {
227    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
228        iter.fold(Self::new(0.0, 0.0, 0.0, 0.0), |a, b| a + b)
229    }
230}
231
232impl std::iter::Sum<Vec4> for Vec4 {
233    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
234        iter.fold(Self::new(0.0, 0.0, 0.0, 0.0), |a, b| a + b)
235    }
236}
237
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use nalgebra::{Vector3, Vector4};

    use super::*;

    #[test]
    fn test_display() {
        let v3 = Vec3::new(1.2341, -2.3452, 3.4563);
        assert_eq!(format!("{}", v3), "[ 1.234, -2.345,  3.456]");
        let v4 = Vec4::new(1.2341, -2.3452, 3.4563, 4.5674);
        assert_eq!(format!("{}", v4), "[ 1.234, -2.345,  3.456;  4.567]");
    }

    #[test]
    fn test_vec_vector_conversion() {
        let v = Vec3::new(1.0, 2.0, 3.0);
        let vector3: Vec<f64> = v.into();
        assert_eq!(vector3[0], 1.0);
        assert_eq!(vector3[1], 2.0);
        assert_eq!(vector3[2], 3.0);

        let v_from_vec: Vec3 = vector3.into();
        assert_eq!(v_from_vec, v);

        let v = Vec4::new(1.0, 2.0, 3.0, 4.0);
        let vector4: Vec<f64> = v.into();
        assert_eq!(vector4[0], 1.0);
        assert_eq!(vector4[1], 2.0);
        assert_eq!(vector4[2], 3.0);
        assert_eq!(vector4[3], 4.0);

        let v_from_vec: Vec4 = vector4.into();
        assert_eq!(v_from_vec, v);
    }

    #[test]
    fn test_vec_array_conversion() {
        let arr = [1.0, 2.0, 3.0];
        let v: Vec3 = arr.into();
        assert_eq!(v, Vec3::new(1.0, 2.0, 3.0));

        let back_to_array: [f64; 3] = v.into();
        assert_eq!(back_to_array, arr);

        let arr = [1.0, 2.0, 3.0, 4.0];
        let v: Vec4 = arr.into();
        assert_eq!(v, Vec4::new(1.0, 2.0, 3.0, 4.0));

        let back_to_array: [f64; 4] = v.into();
        assert_eq!(back_to_array, arr);
    }

    #[test]
    fn test_vec_nalgebra_conversion() {
        let v = Vec3::new(1.0, 2.0, 3.0);
        let vector3: Vector3<f64> = v.into();
        assert_eq!(vector3.x, 1.0);
        assert_eq!(vector3.y, 2.0);
        assert_eq!(vector3.z, 3.0);

        let v_from_vec: Vec3 = vector3.into();
        assert_eq!(v_from_vec, v);

        let v = Vec4::new(1.0, 2.0, 3.0, 4.0);
        let vector4: Vector4<f64> = v.into();
        assert_eq!(vector4.x, 1.0);
        assert_eq!(vector4.y, 2.0);
        assert_eq!(vector4.z, 3.0);
        assert_eq!(vector4.w, 4.0);

        let v_from_vec: Vec4 = vector4.into();
        assert_eq!(v_from_vec, v);
    }

    #[test]
    fn test_vec_sums() {
        let vectors = [Vec3::new(1.0, 2.0, 3.0), Vec3::new(4.0, 5.0, 6.0)];
        let sum: Vec3 = vectors.iter().sum();
        assert_eq!(sum, Vec3::new(5.0, 7.0, 9.0));
        let sum: Vec3 = vectors.into_iter().sum();
        assert_eq!(sum, Vec3::new(5.0, 7.0, 9.0));

        let vectors = [Vec4::new(1.0, 2.0, 3.0, 4.0), Vec4::new(4.0, 5.0, 6.0, 7.0)];
        let sum: Vec4 = vectors.iter().sum();
        assert_eq!(sum, Vec4::new(5.0, 7.0, 9.0, 11.0));
        let sum: Vec4 = vectors.into_iter().sum();
        assert_eq!(sum, Vec4::new(5.0, 7.0, 9.0, 11.0));
    }

    #[test]
    fn test_three_to_four_momentum_conversion() {
        let p3 = Vec3::new(1.0, 2.0, 3.0);
        let target_p4 = Vec4::new(1.0, 2.0, 3.0, 10.0);
        // `target_p4.m()` is a square root, so any energy reconstructed from it is only exact
        // up to rounding: compare it with a relative tolerance rather than bit-for-bit.
        let p4_from_mass = p3.with_mass(target_p4.m());
        assert_relative_eq!(target_p4.e(), p4_from_mass.e());
        assert_eq!(target_p4.px(), p4_from_mass.px());
        assert_eq!(target_p4.py(), p4_from_mass.py());
        assert_eq!(target_p4.pz(), p4_from_mass.pz());
        let p4_from_energy = p3.with_energy(target_p4.e());
        assert_eq!(target_p4.e(), p4_from_energy.e());
        assert_eq!(target_p4.px(), p4_from_energy.px());
        assert_eq!(target_p4.py(), p4_from_energy.py());
        assert_eq!(target_p4.pz(), p4_from_energy.pz());
    }

    #[test]
    fn test_four_momentum_basics() {
        let p = Vec4::new(3.0, 4.0, 5.0, 10.0);
        assert_eq!(p.e(), 10.0);
        assert_eq!(p.px(), 3.0);
        assert_eq!(p.py(), 4.0);
        assert_eq!(p.pz(), 5.0);
        assert_eq!(p.momentum().px(), 3.0);
        assert_eq!(p.momentum().py(), 4.0);
        assert_eq!(p.momentum().pz(), 5.0);
        assert_relative_eq!(p.beta().x, 0.3);
        assert_relative_eq!(p.beta().y, 0.4);
        assert_relative_eq!(p.beta().z, 0.5);
        assert_relative_eq!(p.m2(), 50.0);
        assert_relative_eq!(p.m(), f64::sqrt(50.0));
        assert_eq!(
            p.to_p4_string(),
            "[e = 10.00000; p = (3.00000, 4.00000, 5.00000); m = 7.07107]"
        );
        assert_relative_eq!(Vec3::x().x, 1.0);
        assert_relative_eq!(Vec3::x().y, 0.0);
        assert_relative_eq!(Vec3::x().z, 0.0);
        assert_relative_eq!(Vec3::y().x, 0.0);
        assert_relative_eq!(Vec3::y().y, 1.0);
        assert_relative_eq!(Vec3::y().z, 0.0);
        assert_relative_eq!(Vec3::z().x, 0.0);
        assert_relative_eq!(Vec3::z().y, 0.0);
        assert_relative_eq!(Vec3::z().z, 1.0);
        assert_relative_eq!(Vec3::default().x, 0.0);
        assert_relative_eq!(Vec3::default().y, 0.0);
        assert_relative_eq!(Vec3::default().z, 0.0);
    }

    #[test]
    fn test_three_momentum_basics() {
        let p = Vec4::new(3.0, 4.0, 5.0, 10.0);
        let q = Vec4::new(1.2, -3.4, 7.6, 0.0);
        let p3_view = p.momentum();
        let q3_view = q.momentum();
        assert_eq!(p3_view.px(), 3.0);
        assert_eq!(p3_view.py(), 4.0);
        assert_eq!(p3_view.pz(), 5.0);
        assert_relative_eq!(p3_view.mag2(), 50.0);
        assert_relative_eq!(p3_view.mag(), f64::sqrt(50.0));
        assert_relative_eq!(p3_view.costheta(), 5.0 / f64::sqrt(50.0));
        assert_relative_eq!(p3_view.theta(), f64::acos(5.0 / f64::sqrt(50.0)));
        assert_relative_eq!(p3_view.phi(), f64::atan2(4.0, 3.0));
        assert_relative_eq!(
            p3_view.unit(),
            Vec3::new(
                3.0 / f64::sqrt(50.0),
                4.0 / f64::sqrt(50.0),
                5.0 / f64::sqrt(50.0)
            )
        );
        assert_relative_eq!(p3_view.cross(&q3_view), Vec3::new(47.4, -16.8, -15.0));
    }

    #[test]
    fn test_vec_equality() {
        let p = Vec3::new(1.1, 2.2, 3.3);
        let p2 = Vec3::new(1.1 * 2.0, 2.2 * 2.0, 3.3 * 2.0);
        assert_abs_diff_eq!(p * 2.0, p2);
        assert_relative_eq!(p * 2.0, p2);
        // Scalar multiplication of a `Vec4` scales only the spatial part; `t` stays at 10.0.
        let p = Vec4::new(1.1, 2.2, 3.3, 10.0);
        let p2 = Vec4::new(1.1 * 2.0, 2.2 * 2.0, 3.3 * 2.0, 10.0);
        assert_abs_diff_eq!(p * 2.0, p2);
        assert_relative_eq!(p * 2.0, p2);
    }

    #[test]
    fn test_boost_com() {
        let p = Vec4::new(3.0, 4.0, 5.0, 10.0);
        let zero = p.boost(&-p.beta()).momentum();
        assert_relative_eq!(zero, Vec3::zero());
    }

    #[test]
    fn test_boost() {
        let p0 = Vec4::new(0.0, 0.0, 0.0, 1.0);
        assert_relative_eq!(p0.gamma(), 1.0);
        let p0 = Vec4::new(f64::sqrt(3.0) / 2.0, 0.0, 0.0, 1.0);
        assert_relative_eq!(p0.gamma(), 2.0);
        let p1 = Vec4::new(3.0, 4.0, 5.0, 10.0);
        let p2 = Vec4::new(3.4, 2.3, 1.2, 9.0);
        let p1_boosted = p1.boost(&-p2.beta());
        assert_relative_eq!(p1_boosted.e(), 8.157632144622882);
        assert_relative_eq!(p1_boosted.px(), -0.6489200627053444);
        assert_relative_eq!(p1_boosted.py(), 1.5316128987581492);
        assert_relative_eq!(p1_boosted.pz(), 3.712145860221643);
    }
}