1use crate::geometric::GA3;
15
16pub trait Projection: Send + Sync {
22 type Output;
24
25 fn project(&self, mv: &GA3) -> Self::Output;
27
28 fn name(&self) -> &str;
30}
31
32#[derive(Clone, Debug)]
34pub struct ScalarProjection;
35
36impl Projection for ScalarProjection {
37 type Output = f64;
38
39 fn project(&self, mv: &GA3) -> f64 {
40 mv.get(0)
41 }
42
43 fn name(&self) -> &str {
44 "scalar"
45 }
46}
47
48#[derive(Clone, Debug)]
50pub struct IntProjection;
51
52impl Projection for IntProjection {
53 type Output = i32;
54
55 fn project(&self, mv: &GA3) -> i32 {
56 mv.get(0) as i32
57 }
58
59 fn name(&self) -> &str {
60 "int"
61 }
62}
63
64#[derive(Clone, Debug)]
66pub struct BoolProjection;
67
68impl Projection for BoolProjection {
69 type Output = bool;
70
71 fn project(&self, mv: &GA3) -> bool {
72 mv.get(0) > 0.5
73 }
74
75 fn name(&self) -> &str {
76 "bool"
77 }
78}
79
80#[derive(Clone, Debug)]
82pub struct VectorProjection;
83
84impl Projection for VectorProjection {
85 type Output = (f64, f64, f64);
86
87 fn project(&self, mv: &GA3) -> (f64, f64, f64) {
88 (mv.get(1), mv.get(2), mv.get(4))
90 }
91
92 fn name(&self) -> &str {
93 "vector"
94 }
95}
96
97#[derive(Clone, Debug)]
99pub struct Position2DProjection;
100
101impl Projection for Position2DProjection {
102 type Output = (f64, f64);
103
104 fn project(&self, mv: &GA3) -> (f64, f64) {
105 (mv.get(1), mv.get(2))
106 }
107
108 fn name(&self) -> &str {
109 "position2d"
110 }
111}
112
113#[derive(Clone, Debug)]
115pub struct Position3DProjection;
116
117impl Projection for Position3DProjection {
118 type Output = (f64, f64, f64);
119
120 fn project(&self, mv: &GA3) -> (f64, f64, f64) {
121 (mv.get(1), mv.get(2), mv.get(4))
122 }
123
124 fn name(&self) -> &str {
125 "position3d"
126 }
127}
128
129#[derive(Clone, Debug)]
131pub struct BivectorProjection;
132
133impl Projection for BivectorProjection {
134 type Output = (f64, f64, f64);
135
136 fn project(&self, mv: &GA3) -> (f64, f64, f64) {
137 (mv.get(3), mv.get(5), mv.get(6))
139 }
140
141 fn name(&self) -> &str {
142 "bivector"
143 }
144}
145
146#[derive(Clone, Debug)]
150pub struct ColorProjection;
151
152impl Projection for ColorProjection {
153 type Output = (u8, u8, u8);
154
155 fn project(&self, mv: &GA3) -> (u8, u8, u8) {
156 let clamp = |v: f64| (v.clamp(0.0, 255.0)) as u8;
157 (clamp(mv.get(0)), clamp(mv.get(1)), clamp(mv.get(2)))
158 }
159
160 fn name(&self) -> &str {
161 "color"
162 }
163}
164
165#[derive(Clone, Debug)]
169pub struct ColorAlphaProjection;
170
171impl Projection for ColorAlphaProjection {
172 type Output = (u8, u8, u8, u8);
173
174 fn project(&self, mv: &GA3) -> (u8, u8, u8, u8) {
175 let clamp = |v: f64| (v.clamp(0.0, 255.0)) as u8;
176 (
177 clamp(mv.get(0)),
178 clamp(mv.get(1)),
179 clamp(mv.get(2)),
180 clamp(mv.get(4)),
181 )
182 }
183
184 fn name(&self) -> &str {
185 "color_alpha"
186 }
187}
188
189#[derive(Clone, Debug)]
191pub struct MagnitudeProjection;
192
193impl Projection for MagnitudeProjection {
194 type Output = f64;
195
196 fn project(&self, mv: &GA3) -> f64 {
197 mv.magnitude()
198 }
199
200 fn name(&self) -> &str {
201 "magnitude"
202 }
203}
204
205#[derive(Clone, Debug)]
209pub struct RotorAngleProjection;
210
211impl Projection for RotorAngleProjection {
212 type Output = f64;
213
214 fn project(&self, mv: &GA3) -> f64 {
215 let scalar = mv.get(0);
218 2.0 * scalar.clamp(-1.0, 1.0).acos()
219 }
220
221 fn name(&self) -> &str {
222 "rotor_angle"
223 }
224}
225
226pub struct MappedProjection<P, F, U>
228where
229 P: Projection,
230 F: Fn(P::Output) -> U + Send + Sync,
231{
232 inner: P,
233 map_fn: F,
234 name: String,
235}
236
237impl<P, F, U> MappedProjection<P, F, U>
238where
239 P: Projection,
240 F: Fn(P::Output) -> U + Send + Sync,
241{
242 pub fn new(inner: P, map_fn: F, name: impl Into<String>) -> Self {
244 Self {
245 inner,
246 map_fn,
247 name: name.into(),
248 }
249 }
250}
251
252impl<P, F, U> Projection for MappedProjection<P, F, U>
253where
254 P: Projection,
255 F: Fn(P::Output) -> U + Send + Sync,
256 U: Send + Sync,
257{
258 type Output = U;
259
260 fn project(&self, mv: &GA3) -> U {
261 (self.map_fn)(self.inner.project(mv))
262 }
263
264 fn name(&self) -> &str {
265 &self.name
266 }
267}
268
269pub struct CustomProjection<F, T>
271where
272 F: Fn(&GA3) -> T + Send + Sync,
273{
274 project_fn: F,
275 name: String,
276}
277
278impl<F, T> CustomProjection<F, T>
279where
280 F: Fn(&GA3) -> T + Send + Sync,
281{
282 pub fn new(project_fn: F, name: impl Into<String>) -> Self {
284 Self {
285 project_fn,
286 name: name.into(),
287 }
288 }
289}
290
291impl<F, T> Projection for CustomProjection<F, T>
292where
293 F: Fn(&GA3) -> T + Send + Sync,
294 T: Send + Sync,
295{
296 type Output = T;
297
298 fn project(&self, mv: &GA3) -> T {
299 (self.project_fn)(mv)
300 }
301
302 fn name(&self) -> &str {
303 &self.name
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use amari_core::Vector;
    use std::f64::consts::PI;

    const EPS: f64 = 1e-10;

    /// Builds a multivector with the given `(index, value)` coefficients, zero elsewhere.
    fn with_coefficients(pairs: &[(usize, f64)]) -> GA3 {
        let mut coeffs = vec![0.0; 8];
        for &(index, value) in pairs {
            coeffs[index] = value;
        }
        GA3::from_coefficients(coeffs)
    }

    /// Builds the multivector for the vector (1, 2, 3).
    fn vector_123() -> GA3 {
        let v = Vector::<3, 0, 0>::from_components(1.0, 2.0, 3.0);
        GA3::from_vector(&v)
    }

    #[test]
    fn test_scalar_projection() {
        let mv = GA3::scalar(42.0);
        let proj = ScalarProjection;
        assert!((proj.project(&mv) - 42.0).abs() < EPS);
    }

    #[test]
    fn test_int_projection() {
        let proj = IntProjection;
        assert_eq!(proj.project(&GA3::scalar(42.7)), 42);
        // `as` truncates toward zero, not toward negative infinity.
        assert_eq!(proj.project(&GA3::scalar(-42.7)), -42);
    }

    #[test]
    fn test_bool_projection() {
        let proj = BoolProjection;
        assert!(!proj.project(&GA3::scalar(0.0)));
        assert!(!proj.project(&GA3::scalar(0.4)));
        // The threshold is strict: exactly 0.5 is false.
        assert!(!proj.project(&GA3::scalar(0.5)));
        assert!(proj.project(&GA3::scalar(0.6)));
        assert!(proj.project(&GA3::scalar(1.0)));
    }

    #[test]
    fn test_vector_projection() {
        let mv = vector_123();
        let proj = VectorProjection;
        let (x, y, z) = proj.project(&mv);
        assert!((x - 1.0).abs() < EPS);
        assert!((y - 2.0).abs() < EPS);
        assert!((z - 3.0).abs() < EPS);
    }

    #[test]
    fn test_position_projections() {
        let mv = vector_123();

        let (x, y) = Position2DProjection.project(&mv);
        assert!((x - 1.0).abs() < EPS);
        assert!((y - 2.0).abs() < EPS);

        let (x, y, z) = Position3DProjection.project(&mv);
        assert!((x - 1.0).abs() < EPS);
        assert!((y - 2.0).abs() < EPS);
        assert!((z - 3.0).abs() < EPS);
    }

    #[test]
    fn test_bivector_projection() {
        let mv = with_coefficients(&[(3, 1.0), (5, 2.0), (6, 3.0)]);
        let (a, b, c) = BivectorProjection.project(&mv);
        assert!((a - 1.0).abs() < EPS);
        assert!((b - 2.0).abs() < EPS);
        assert!((c - 3.0).abs() < EPS);
    }

    #[test]
    fn test_color_projection() {
        let mv = with_coefficients(&[(0, 128.0), (1, 64.0), (2, 192.0)]);
        let proj = ColorProjection;
        let (r, g, b) = proj.project(&mv);
        assert_eq!(r, 128);
        assert_eq!(g, 64);
        assert_eq!(b, 192);
    }

    #[test]
    fn test_color_clamping() {
        let mv = with_coefficients(&[(0, 300.0), (1, -50.0), (2, 100.0)]);
        let proj = ColorProjection;
        let (r, g, b) = proj.project(&mv);
        assert_eq!(r, 255); // clamped down
        assert_eq!(g, 0); // clamped up
        assert_eq!(b, 100); // in range, unchanged
    }

    #[test]
    fn test_color_alpha_projection() {
        // Alpha comes from coefficient 4; coefficient 3 must be ignored.
        let mv = with_coefficients(&[(0, 10.0), (1, 20.0), (2, 30.0), (3, 99.0), (4, 300.0)]);
        assert_eq!(ColorAlphaProjection.project(&mv), (10, 20, 30, 255));
    }

    #[test]
    fn test_magnitude_projection() {
        let v = Vector::<3, 0, 0>::from_components(3.0, 4.0, 0.0);
        let mv = GA3::from_vector(&v);
        let proj = MagnitudeProjection;
        assert!((proj.project(&mv) - 5.0).abs() < EPS);
    }

    #[test]
    fn test_rotor_angle_projection() {
        let proj = RotorAngleProjection;
        assert!(proj.project(&GA3::scalar(1.0)).abs() < EPS);
        assert!((proj.project(&GA3::scalar(0.0)) - PI).abs() < EPS);
        // Out-of-range scalars are clamped before `acos`.
        assert!(proj.project(&GA3::scalar(2.0)).abs() < EPS);
    }

    #[test]
    fn test_projection_names() {
        assert_eq!(ScalarProjection.name(), "scalar");
        assert_eq!(IntProjection.name(), "int");
        assert_eq!(BoolProjection.name(), "bool");
        assert_eq!(VectorProjection.name(), "vector");
        assert_eq!(Position2DProjection.name(), "position2d");
        assert_eq!(Position3DProjection.name(), "position3d");
        assert_eq!(BivectorProjection.name(), "bivector");
        assert_eq!(ColorProjection.name(), "color");
        assert_eq!(ColorAlphaProjection.name(), "color_alpha");
        assert_eq!(MagnitudeProjection.name(), "magnitude");
        assert_eq!(RotorAngleProjection.name(), "rotor_angle");
    }

    #[test]
    fn test_custom_projection() {
        let proj = CustomProjection::new(|mv: &GA3| mv.get(0) * 2.0, "doubled");
        let mv = GA3::scalar(21.0);
        assert!((proj.project(&mv) - 42.0).abs() < EPS);
        assert_eq!(proj.name(), "doubled");
    }

    #[test]
    fn test_mapped_projection() {
        let proj = MappedProjection::new(ScalarProjection, |x| x as i32 * 2, "doubled_int");
        let mv = GA3::scalar(21.0);
        assert_eq!(proj.project(&mv), 42);
        assert_eq!(proj.name(), "doubled_int");
    }
}
399}