1#[repr(C)]
39#[derive(Copy, Clone, Debug, PartialEq)]
40#[cfg_attr(
41 feature = "rkyv",
42 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
43)]
44#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
45#[cfg_attr(feature = "bytemuck", derive(bytemuck::Pod, bytemuck::Zeroable))]
46#[cfg_attr(feature = "zerocopy", derive(zerocopy::AsBytes, zerocopy::FromBytes))]
47pub struct EncodedUnitVector3([f32; 2]);
48
49impl EncodedUnitVector3 {
50 pub fn encode(unit_vector: [f32; 3]) -> Self {
52 let mut n = unit_vector;
53
54 debug_assert!(
55 (length_2(n) - 1.0).abs() < 0.0001,
56 "Argument must be normalized"
57 );
58
59 let inv_sum = 1.0 / (n[0].abs() + n[1].abs() + n[2].abs());
60 n[0] *= inv_sum;
61 n[1] *= inv_sum;
62 if n[2] < 0.0 {
63 let x = n[0];
64 n[0] = if n[0] >= 0.0 { 1.0 } else { -1.0 } * (1.0 - n[1].abs());
65 n[1] = if n[1] >= 0.0 { 1.0 } else { -1.0 } * (1.0 - x.abs());
66 }
67 Self([n[0], n[1]])
68 }
69
70 pub fn decode(&self) -> [f32; 3] {
72 let x = self.0[0];
73 let y = self.0[1];
74 let z = 1.0 - x.abs() - y.abs();
75 let t = (-z).max(0.0);
76 normalize([
77 x + if x >= 0.0 { -t } else { t },
78 y + if y >= 0.0 { -t } else { t },
79 z,
80 ])
81 }
82
83 pub fn from_raw(raw: [f32; 2]) -> Self {
85 Self(raw)
86 }
87
88 pub fn raw(&self) -> [f32; 2] {
90 self.0
91 }
92}
93
#[cfg(feature = "half")]
mod float16 {
    use half::f16;

    use crate::EncodedUnitVector3;

    /// A unit vector stored as two half-precision (`half::f16`) octahedral coordinates.
    ///
    /// Encoding runs the `f32` codec ([`EncodedUnitVector3`]) and narrows each coordinate
    /// to `f16`. Decoding widens the coordinates back to `f32` and runs the `f32` decoder.
    #[repr(C)]
    #[derive(Copy, Clone, Debug, PartialEq)]
    #[cfg_attr(
        feature = "rkyv",
        derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
    )]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    #[cfg_attr(feature = "bytemuck", derive(bytemuck::Pod, bytemuck::Zeroable))]
    #[cfg_attr(feature = "zerocopy", derive(zerocopy::AsBytes, zerocopy::FromBytes))]
    pub struct EncodedUnitVector3F16([f16; 2]);

    impl EncodedUnitVector3F16 {
        /// Encodes a normalized 3D vector, then rounds both coordinates to `f16`.
        ///
        /// Inherits the debug-build normalization check of [`EncodedUnitVector3::encode`].
        pub fn encode(unit_vector: [f32; 3]) -> Self {
            let [u, v] = EncodedUnitVector3::encode(unit_vector).raw();
            Self([f16::from_f32(u), f16::from_f32(v)])
        }

        /// Widens the stored coordinates to `f32` and decodes them into a normalized vector.
        pub fn decode(&self) -> [f32; 3] {
            let [u, v] = self.0;
            let widened = EncodedUnitVector3::from_raw([u.to_f32(), v.to_f32()]);
            widened.decode()
        }

        /// Wraps already-encoded `f16` coordinates without any validation.
        pub fn from_raw(raw: [f16; 2]) -> Self {
            Self(raw)
        }

        /// Returns the stored `f16` coordinates.
        pub fn raw(&self) -> [f16; 2] {
            let Self(coords) = *self;
            coords
        }
    }
}
140
141#[cfg(feature = "half")]
142pub use float16::EncodedUnitVector3F16;
143
144#[repr(C)]
148#[derive(Copy, Clone, Debug, PartialEq)]
149#[cfg_attr(
150 feature = "rkyv",
151 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
152)]
153#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
154#[cfg_attr(feature = "bytemuck", derive(bytemuck::Pod, bytemuck::Zeroable))]
155#[cfg_attr(feature = "zerocopy", derive(zerocopy::AsBytes, zerocopy::FromBytes))]
156pub struct EncodedUnitVector3U8([u8; 2]);
157
158impl EncodedUnitVector3U8 {
159 pub fn encode(unit_vector: [f32; 3]) -> Self {
161 let encoded_f32 = EncodedUnitVector3::encode(unit_vector);
162 Self([Self::to_u8(encoded_f32.0[0]), Self::to_u8(encoded_f32.0[1])])
163 }
164
165 pub fn decode(&self) -> [f32; 3] {
167 EncodedUnitVector3([Self::to_f32(self.0[0]), Self::to_f32(self.0[1])]).decode()
168 }
169
170 pub fn from_raw(raw: [u8; 2]) -> Self {
172 Self(raw)
173 }
174
175 pub fn raw(&self) -> [u8; 2] {
177 self.0
178 }
179
180 #[inline]
181 fn to_u8(from: f32) -> u8 {
182 (((from + 1.0) * 0.5) * 255.0) as u8
183 }
184
185 #[inline]
186 fn to_f32(from: u8) -> f32 {
187 (from as f32 / 255.0) * 2.0 - 1.0
188 }
189}
190
/// Returns the squared Euclidean length of `v`, summed in x, y, z order.
#[inline]
fn length_2(v: [f32; 3]) -> f32 {
    v.iter().fold(0.0, |acc, &component| acc + component * component)
}
195
196#[inline]
197fn normalize(v: [f32; 3]) -> [f32; 3] {
198 let inv_length = 1.0 / length_2(v).sqrt();
199 [v[0] * inv_length, v[1] * inv_length, v[2] * inv_length]
200}
201
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};

    use super::{length_2, normalize};

    /// Relative tolerance for comparing measured error statistics against the recorded
    /// reference values.
    ///
    /// Exact `assert_eq!` on these floats made the tests depend on the platform's `acos`
    /// rounding and on f32 summation over 100k samples. The small relative slack absorbs
    /// last-bit differences while keeping the regression check tight.
    const RELATIVE_TOLERANCE: f32 = 1e-3;

    #[test]
    fn test_error_rate_f32() {
        let expected_avg_error = 1.0932611e-5;
        let expected_max_error = 0.00048828125;
        test_error_rate_impl(
            |unit_vector| crate::EncodedUnitVector3::encode(unit_vector).decode(),
            expected_avg_error,
            expected_max_error,
        );
    }

    #[test]
    #[cfg(feature = "half")]
    fn test_error_rate_f16() {
        let expected_avg_error = 0.00013977697;
        let expected_max_error = 0.001035801;
        test_error_rate_impl(
            |unit_vector| crate::EncodedUnitVector3F16::encode(unit_vector).decode(),
            expected_avg_error,
            expected_max_error,
        );
    }

    #[test]
    fn test_error_rate_u8() {
        let expected_avg_error = 0.01223357;
        let expected_max_error = 0.03299934;
        test_error_rate_impl(
            |unit_vector| crate::EncodedUnitVector3U8::encode(unit_vector).decode(),
            expected_avg_error,
            expected_max_error,
        );
    }

    /// Round-trips a fixed, seeded set of random unit vectors through `codec`.
    ///
    /// Asserts that the average and maximum angular errors (radians) match the expected
    /// values within [`RELATIVE_TOLERANCE`].
    fn test_error_rate_impl<F>(codec: F, expected_avg_error: f32, expected_max_error: f32)
    where
        F: Fn([f32; 3]) -> [f32; 3],
    {
        let sample_size = 100_000;
        let unit_vectors = generate_unit_vectors(sample_size);

        // Accumulate in f64: summing 100k small f32 errors in f32 loses low-order bits.
        let mut acc_error: f64 = 0.0;
        let mut max_error: f32 = 0.0;

        for unit_vector in unit_vectors {
            let error = angle_between(unit_vector, codec(unit_vector));
            acc_error += f64::from(error);
            max_error = max_error.max(error);
        }

        let avg_error = (acc_error / sample_size as f64) as f32;

        assert_close("max error", max_error, expected_max_error);
        assert_close("average error", avg_error, expected_avg_error);
    }

    /// Asserts that `actual` is within `RELATIVE_TOLERANCE` of `expected`.
    ///
    /// A NaN `actual` always fails.
    fn assert_close(label: &str, actual: f32, expected: f32) {
        let tolerance = expected.abs() * RELATIVE_TOLERANCE;
        assert!(
            (actual - expected).abs() <= tolerance,
            "{} = {} differs from expected {} by more than {}",
            label,
            actual,
            expected,
            tolerance
        );
    }

    /// Generates `count` normalized vectors from a fixed-seed ChaCha8 RNG.
    ///
    /// Components are drawn uniformly from `[-1, 1)` before normalizing, so results are
    /// reproducible across runs.
    fn generate_unit_vectors(count: usize) -> Vec<[f32; 3]> {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([0; 32]);

        (0..count)
            .map(|_| {
                normalize([
                    rng.gen::<f32>() * 2.0 - 1.0,
                    rng.gen::<f32>() * 2.0 - 1.0,
                    rng.gen::<f32>() * 2.0 - 1.0,
                ])
            })
            .collect()
    }

    /// Returns the angle in radians between `v1` and `v2`.
    ///
    /// The cosine is clamped to `[-1, 1]` so rounding cannot push `acos` out of its domain.
    ///
    /// NOTE(review): `acos` of an f32 dot product cannot resolve angles much below ~3e-4
    /// rad, because the dot product rounds to 1.0 or to a neighbouring float. The f32
    /// expected max error (exactly 2^-11) appears to be this limit, not the codec's true
    /// error. `atan2(|v1 x v2|, v1 . v2)` would measure small angles accurately, but the
    /// recorded expected values would then need regenerating.
    fn angle_between(v1: [f32; 3], v2: [f32; 3]) -> f32 {
        ((v1[0] * v2[0] + v1[1] * v2[1] + v1[2] * v2[2])
            / (length_2(v1).sqrt() * length_2(v2).sqrt()))
        .clamp(-1.0, 1.0)
        .acos()
    }
}