1use crate::{matrix::Matrix, vector::Vector};
2
/// A quaternion with one scalar and three vector components, used to
/// represent 3D rotations.
///
/// For a rotation by `angle` about a unit axis `(x, y, z)` the components are
/// `q0 = cos(angle / 2)` and `(q1, q2, q3) = (x, y, z) * sin(angle / 2)`.
/// Nothing in the type enforces unit length.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quaternion {
    /// Scalar (real) component.
    q0: f64,
    /// Vector component along X.
    q1: f64,
    /// Vector component along Y.
    q2: f64,
    /// Vector component along Z.
    q3: f64,
}
10
11impl Quaternion {
14 pub fn default() -> Self {
15 Quaternion {
16 q0: 1.0,
17 q1: 0.0,
18 q2: 0.0,
19 q3: 0.0,
20 }
21 }
22
23 pub fn from_axis_and_angle(axis: &Vector, angle: f64) -> Self {
24 let half_theta = angle / 2.0;
25 let q0 = half_theta.cos();
26 let sin_half_theta = half_theta.sin();
27
28 let real_axis = axis.normalized();
29
30 Quaternion {
31 q0,
32 q1: real_axis.x * sin_half_theta,
33 q2: real_axis.y * sin_half_theta,
34 q3: real_axis.z * sin_half_theta,
35 }
36 }
37
38 pub fn from_pitch_roll_yaw(roll: f64, pitch: f64, yaw: f64) -> Self {
39 let roll_q = Quaternion::from_axis_and_angle(&Vector::new(1.0, 0.0, 0.0), roll);
40 let pitch_q = Quaternion::from_axis_and_angle(&Vector::new(0.0, 1.0, 0.0), pitch);
41 let yaw_q = Quaternion::from_axis_and_angle(&Vector::new(0.0, 0.0, 1.0), yaw);
42
43 yaw_q.times(&pitch_q).times(&roll_q)
44 }
45
46 pub fn from_matrix(mat: &Matrix) -> Quaternion {
47 let tr = mat.get(0, 0) + mat.get(1, 1) + mat.get(2, 2);
48
49 if tr > 0.0 {
50 let mut s = (tr + 1.0).sqrt();
51 let q0 = s * 0.5;
52 s = 0.5 / s;
53 Quaternion {
54 q0,
55 q1: (mat.get(2, 1) - mat.get(1, 2)) * s,
56 q2: (mat.get(0, 2) - mat.get(2, 0)) * s,
57 q3: (mat.get(1, 0) - mat.get(0, 1)) * s,
58 }
59 } else {
60 let mut i = if mat.get(1, 1) > mat.get(0, 0) { 1 } else { 0 };
61 if mat.get(2, 2) > mat.get(i, i) {
62 i = 2;
63 }
64 let j = (i + 1) % 3;
65 let k = (j + 1) % 3;
66 let mut s = ((mat.get(i, i) - (mat.get(j, j) + mat.get(k, k))) + 1.0).sqrt();
67 let mut q = Quaternion::default();
68 q.set_q(i + 1, s * 0.5);
69 s = 0.5 / s;
70 q.q0 = (mat.get(k, j) - mat.get(j, k)) * s;
71 q.set_q(j + 1, (mat.get(j, i) + mat.get(i, j)) * s);
72 q.set_q(k + 1, (mat.get(k, i) + mat.get(i, k)) * s);
73 q
74 }
75 }
76
77 pub fn within_epsilon(&self, other: &Quaternion, epsilon: f64) -> bool {
78 (self.q0 - other.q0).abs() < epsilon
79 && (self.q1 - other.q1).abs() < epsilon
80 && (self.q2 - other.q2).abs() < epsilon
81 && (self.q3 - other.q3).abs() < epsilon
82 }
83
84 pub fn get(&self) -> (Vector, f64) {
85 let retval = 2.0 * self.q0.acos();
86 let mut axis = Vector::new(self.q1, self.q2, self.q3);
87 let len = axis.len();
88 if len == 0.0 {
89 (Vector::new(0.0, 0.0, 1.0), retval)
90 } else {
91 axis = axis.scale(1.0 / len);
92 (axis, retval)
93 }
94 }
95
96 pub fn set_q(&mut self, i: usize, val: f64) {
97 match i {
98 0 => self.q0 = val,
99 1 => self.q1 = val,
100 2 => self.q2 = val,
101 3 => self.q3 = val,
102 _ => panic!("Invalid quaternion index"),
103 };
104 }
105
106 pub fn invert(&self) -> Quaternion {
107 Quaternion {
108 q0: self.q0,
109 q1: self.q1 * -1.0,
110 q2: self.q2 * -1.0,
111 q3: self.q3 * -1.0,
112 }
113 }
114
115 pub fn length(&self) -> f64 {
116 (self.q0 * self.q0 + self.q1 * self.q1 + self.q2 * self.q2 + self.q3 * self.q3).sqrt()
117 }
118
119 pub fn normalized(&self) -> Quaternion {
120 let len = self.length();
121 Quaternion {
122 q0: self.q0 / len,
123 q1: self.q1 / len,
124 q2: self.q2 / len,
125 q3: self.q3 / len,
126 }
127 }
128
129 pub fn times(&self, other: &Quaternion) -> Quaternion {
130 Quaternion::mul(self, other)
131 }
132
133 pub fn mul(a: &Quaternion, b: &Quaternion) -> Quaternion {
134 Quaternion {
135 q0: (a.q0 * b.q0 - a.q1 * b.q1 - a.q2 * b.q2 - a.q3 * b.q3),
136 q1: (a.q0 * b.q1 + a.q1 * b.q0 + a.q2 * b.q3 - a.q3 * b.q2),
137 q2: (a.q0 * b.q2 + a.q2 * b.q0 - a.q1 * b.q3 + a.q3 * b.q1),
138 q3: (a.q0 * b.q3 + a.q3 * b.q0 + a.q1 * b.q2 - a.q2 * b.q1),
139 }
140 }
141
142 pub fn to_matrix(&self) -> Matrix {
143 let q00 = self.q0 * self.q0;
144 let q11 = self.q1 * self.q1;
145 let q22 = self.q2 * self.q2;
146 let q33 = self.q3 * self.q3;
147
148 let mut m = Matrix::default();
149
150 m.set(0, 0, q00 + q11 - q22 - q33);
151 m.set(1, 1, q00 - q11 + q22 - q33);
152 m.set(2, 2, q00 - q11 - q22 + q33);
153
154 let q03 = self.q0 * self.q3;
155 let q12 = self.q1 * self.q2;
156 m.set(1, 0, 2.0 * (q12 - q03));
157 m.set(0, 1, 2.0 * (q03 + q12));
158
159 let q02 = self.q0 * self.q2;
160 let q13 = self.q1 * self.q3;
161 m.set(2, 0, 2.0 * (q02 + q13));
162 m.set(0, 2, 2.0 * (q13 - q02));
163
164 let q01 = self.q0 * self.q1;
165 let q23 = self.q2 * self.q3;
166 m.set(2, 1, 2.0 * (q23 - q01));
167 m.set(1, 2, 2.0 * (q01 + q23));
168
169 m
170 }
171
172 pub fn rotate_vector(&self, src: &Vector) -> Vector {
173 let qvec = Vector::new(self.q1, self.q2, self.q3);
174
175 let mut q_cross_x = qvec.cross_product(src);
176 let mut q_cross_x_cross_q = q_cross_x.cross_product(&qvec);
177
178 q_cross_x = q_cross_x.scale(2.0 * self.q0);
179 q_cross_x_cross_q = q_cross_x_cross_q.scale(-2.0);
180
181 src.add(&q_cross_x).add(&q_cross_x_cross_q)
182 }
183}