// sciimg/quaternion.rs

1use crate::{matrix::Matrix, vector::Vector};
/// A quaternion `q0 + q1·i + q2·j + q3·k`, used in this module to represent 3-D rotations.
///
/// `q0` is the scalar (real) part; `(q1, q2, q3)` is the vector part.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Quaternion {
    /// Scalar (real) component.
    q0: f64,
    /// Coefficient of `i` (x of the vector part).
    q1: f64,
    /// Coefficient of `j` (y of the vector part).
    q2: f64,
    /// Coefficient of `k` (z of the vector part).
    q3: f64,
}
11// Adapted from some old code I wrote years ago: https://github.com/kmgill/jdem846/blob/master/jdem846/source/base/src/us/wthr/jdem846/math/Quaternion.java
12// which itself was adapter from something, I just forget from where.... Probably three.js
13impl Quaternion {
14    pub fn default() -> Self {
15        Quaternion {
16            q0: 1.0,
17            q1: 0.0,
18            q2: 0.0,
19            q3: 0.0,
20        }
21    }
22
23    pub fn from_axis_and_angle(axis: &Vector, angle: f64) -> Self {
24        let half_theta = angle / 2.0;
25        let q0 = half_theta.cos();
26        let sin_half_theta = half_theta.sin();
27
28        let real_axis = axis.normalized();
29
30        Quaternion {
31            q0,
32            q1: real_axis.x * sin_half_theta,
33            q2: real_axis.y * sin_half_theta,
34            q3: real_axis.z * sin_half_theta,
35        }
36    }
37
38    pub fn from_pitch_roll_yaw(roll: f64, pitch: f64, yaw: f64) -> Self {
39        let roll_q = Quaternion::from_axis_and_angle(&Vector::new(1.0, 0.0, 0.0), roll);
40        let pitch_q = Quaternion::from_axis_and_angle(&Vector::new(0.0, 1.0, 0.0), pitch);
41        let yaw_q = Quaternion::from_axis_and_angle(&Vector::new(0.0, 0.0, 1.0), yaw);
42
43        yaw_q.times(&pitch_q).times(&roll_q)
44    }
45
46    pub fn from_matrix(mat: &Matrix) -> Quaternion {
47        let tr = mat.get(0, 0) + mat.get(1, 1) + mat.get(2, 2);
48
49        if tr > 0.0 {
50            let mut s = (tr + 1.0).sqrt();
51            let q0 = s * 0.5;
52            s = 0.5 / s;
53            Quaternion {
54                q0,
55                q1: (mat.get(2, 1) - mat.get(1, 2)) * s,
56                q2: (mat.get(0, 2) - mat.get(2, 0)) * s,
57                q3: (mat.get(1, 0) - mat.get(0, 1)) * s,
58            }
59        } else {
60            let mut i = if mat.get(1, 1) > mat.get(0, 0) { 1 } else { 0 };
61            if mat.get(2, 2) > mat.get(i, i) {
62                i = 2;
63            }
64            let j = (i + 1) % 3;
65            let k = (j + 1) % 3;
66            let mut s = ((mat.get(i, i) - (mat.get(j, j) + mat.get(k, k))) + 1.0).sqrt();
67            let mut q = Quaternion::default();
68            q.set_q(i + 1, s * 0.5);
69            s = 0.5 / s;
70            q.q0 = (mat.get(k, j) - mat.get(j, k)) * s;
71            q.set_q(j + 1, (mat.get(j, i) + mat.get(i, j)) * s);
72            q.set_q(k + 1, (mat.get(k, i) + mat.get(i, k)) * s);
73            q
74        }
75    }
76
77    pub fn within_epsilon(&self, other: &Quaternion, epsilon: f64) -> bool {
78        (self.q0 - other.q0).abs() < epsilon
79            && (self.q1 - other.q1).abs() < epsilon
80            && (self.q2 - other.q2).abs() < epsilon
81            && (self.q3 - other.q3).abs() < epsilon
82    }
83
84    pub fn get(&self) -> (Vector, f64) {
85        let retval = 2.0 * self.q0.acos();
86        let mut axis = Vector::new(self.q1, self.q2, self.q3);
87        let len = axis.len();
88        if len == 0.0 {
89            (Vector::new(0.0, 0.0, 1.0), retval)
90        } else {
91            axis = axis.scale(1.0 / len);
92            (axis, retval)
93        }
94    }
95
96    pub fn set_q(&mut self, i: usize, val: f64) {
97        match i {
98            0 => self.q0 = val,
99            1 => self.q1 = val,
100            2 => self.q2 = val,
101            3 => self.q3 = val,
102            _ => panic!("Invalid quaternion index"),
103        };
104    }
105
106    pub fn invert(&self) -> Quaternion {
107        Quaternion {
108            q0: self.q0,
109            q1: self.q1 * -1.0,
110            q2: self.q2 * -1.0,
111            q3: self.q3 * -1.0,
112        }
113    }
114
115    pub fn length(&self) -> f64 {
116        (self.q0 * self.q0 + self.q1 * self.q1 + self.q2 * self.q2 + self.q3 * self.q3).sqrt()
117    }
118
119    pub fn normalized(&self) -> Quaternion {
120        let len = self.length();
121        Quaternion {
122            q0: self.q0 / len,
123            q1: self.q1 / len,
124            q2: self.q2 / len,
125            q3: self.q3 / len,
126        }
127    }
128
129    pub fn times(&self, other: &Quaternion) -> Quaternion {
130        Quaternion::mul(self, other)
131    }
132
133    pub fn mul(a: &Quaternion, b: &Quaternion) -> Quaternion {
134        Quaternion {
135            q0: (a.q0 * b.q0 - a.q1 * b.q1 - a.q2 * b.q2 - a.q3 * b.q3),
136            q1: (a.q0 * b.q1 + a.q1 * b.q0 + a.q2 * b.q3 - a.q3 * b.q2),
137            q2: (a.q0 * b.q2 + a.q2 * b.q0 - a.q1 * b.q3 + a.q3 * b.q1),
138            q3: (a.q0 * b.q3 + a.q3 * b.q0 + a.q1 * b.q2 - a.q2 * b.q1),
139        }
140    }
141
142    pub fn to_matrix(&self) -> Matrix {
143        let q00 = self.q0 * self.q0;
144        let q11 = self.q1 * self.q1;
145        let q22 = self.q2 * self.q2;
146        let q33 = self.q3 * self.q3;
147
148        let mut m = Matrix::default();
149
150        m.set(0, 0, q00 + q11 - q22 - q33);
151        m.set(1, 1, q00 - q11 + q22 - q33);
152        m.set(2, 2, q00 - q11 - q22 + q33);
153
154        let q03 = self.q0 * self.q3;
155        let q12 = self.q1 * self.q2;
156        m.set(1, 0, 2.0 * (q12 - q03));
157        m.set(0, 1, 2.0 * (q03 + q12));
158
159        let q02 = self.q0 * self.q2;
160        let q13 = self.q1 * self.q3;
161        m.set(2, 0, 2.0 * (q02 + q13));
162        m.set(0, 2, 2.0 * (q13 - q02));
163
164        let q01 = self.q0 * self.q1;
165        let q23 = self.q2 * self.q3;
166        m.set(2, 1, 2.0 * (q23 - q01));
167        m.set(1, 2, 2.0 * (q01 + q23));
168
169        m
170    }
171
172    pub fn rotate_vector(&self, src: &Vector) -> Vector {
173        let qvec = Vector::new(self.q1, self.q2, self.q3);
174
175        let mut q_cross_x = qvec.cross_product(src);
176        let mut q_cross_x_cross_q = q_cross_x.cross_product(&qvec);
177
178        q_cross_x = q_cross_x.scale(2.0 * self.q0);
179        q_cross_x_cross_q = q_cross_x_cross_q.scale(-2.0);
180
181        src.add(&q_cross_x).add(&q_cross_x_cross_q)
182    }
183}