use std::marker::PhantomData;

use amari_core::{Bivector, Vector};

use crate::geometric::GA3;
16
17pub trait Projection: Send + Sync {
23 type Output;
25
26 fn project(&self, mv: &GA3) -> Self::Output;
28
29 fn name(&self) -> &str;
31}
32
33#[derive(Clone, Debug)]
35pub struct ScalarProjection;
36
37impl Projection for ScalarProjection {
38 type Output = f64;
39
40 fn project(&self, mv: &GA3) -> f64 {
41 mv.get(0)
42 }
43
44 fn name(&self) -> &str {
45 "scalar"
46 }
47}
48
49#[derive(Clone, Debug)]
51pub struct IntProjection;
52
53impl Projection for IntProjection {
54 type Output = i32;
55
56 fn project(&self, mv: &GA3) -> i32 {
57 mv.get(0) as i32
58 }
59
60 fn name(&self) -> &str {
61 "int"
62 }
63}
64
65#[derive(Clone, Debug)]
67pub struct BoolProjection;
68
69impl Projection for BoolProjection {
70 type Output = bool;
71
72 fn project(&self, mv: &GA3) -> bool {
73 mv.get(0) > 0.5
74 }
75
76 fn name(&self) -> &str {
77 "bool"
78 }
79}
80
81#[derive(Clone, Debug)]
83pub struct VectorProjection;
84
85impl Projection for VectorProjection {
86 type Output = (f64, f64, f64);
87
88 fn project(&self, mv: &GA3) -> (f64, f64, f64) {
89 (mv.get(1), mv.get(2), mv.get(4))
91 }
92
93 fn name(&self) -> &str {
94 "vector"
95 }
96}
97
98#[derive(Clone, Debug)]
100pub struct Position2DProjection;
101
102impl Projection for Position2DProjection {
103 type Output = (f64, f64);
104
105 fn project(&self, mv: &GA3) -> (f64, f64) {
106 (mv.get(1), mv.get(2))
107 }
108
109 fn name(&self) -> &str {
110 "position2d"
111 }
112}
113
114#[derive(Clone, Debug)]
116pub struct Position3DProjection;
117
118impl Projection for Position3DProjection {
119 type Output = (f64, f64, f64);
120
121 fn project(&self, mv: &GA3) -> (f64, f64, f64) {
122 (mv.get(1), mv.get(2), mv.get(4))
123 }
124
125 fn name(&self) -> &str {
126 "position3d"
127 }
128}
129
130#[derive(Clone, Debug)]
132pub struct BivectorProjection;
133
134impl Projection for BivectorProjection {
135 type Output = (f64, f64, f64);
136
137 fn project(&self, mv: &GA3) -> (f64, f64, f64) {
138 (mv.get(3), mv.get(5), mv.get(6))
140 }
141
142 fn name(&self) -> &str {
143 "bivector"
144 }
145}
146
147#[derive(Clone, Debug)]
149pub struct TypedVectorProjection;
150
151impl Projection for TypedVectorProjection {
152 type Output = Vector<3, 0, 0>;
153
154 fn project(&self, mv: &GA3) -> Vector<3, 0, 0> {
155 Vector::from_components(mv.get(1), mv.get(2), mv.get(4))
156 }
157
158 fn name(&self) -> &str {
159 "typed_vector"
160 }
161}
162
163#[derive(Clone, Debug)]
165pub struct TypedBivectorProjection;
166
167impl Projection for TypedBivectorProjection {
168 type Output = Bivector<3, 0, 0>;
169
170 fn project(&self, mv: &GA3) -> Bivector<3, 0, 0> {
171 Bivector::from_components(mv.get(3), mv.get(5), mv.get(6))
172 }
173
174 fn name(&self) -> &str {
175 "typed_bivector"
176 }
177}
178
179#[derive(Clone, Debug)]
183pub struct ColorProjection;
184
185impl Projection for ColorProjection {
186 type Output = (u8, u8, u8);
187
188 fn project(&self, mv: &GA3) -> (u8, u8, u8) {
189 let clamp = |v: f64| (v.clamp(0.0, 255.0)) as u8;
190 (clamp(mv.get(0)), clamp(mv.get(1)), clamp(mv.get(2)))
191 }
192
193 fn name(&self) -> &str {
194 "color"
195 }
196}
197
198#[derive(Clone, Debug)]
202pub struct ColorAlphaProjection;
203
204impl Projection for ColorAlphaProjection {
205 type Output = (u8, u8, u8, u8);
206
207 fn project(&self, mv: &GA3) -> (u8, u8, u8, u8) {
208 let clamp = |v: f64| (v.clamp(0.0, 255.0)) as u8;
209 (
210 clamp(mv.get(0)),
211 clamp(mv.get(1)),
212 clamp(mv.get(2)),
213 clamp(mv.get(4)),
214 )
215 }
216
217 fn name(&self) -> &str {
218 "color_alpha"
219 }
220}
221
222#[derive(Clone, Debug)]
224pub struct MagnitudeProjection;
225
226impl Projection for MagnitudeProjection {
227 type Output = f64;
228
229 fn project(&self, mv: &GA3) -> f64 {
230 mv.magnitude()
231 }
232
233 fn name(&self) -> &str {
234 "magnitude"
235 }
236}
237
238#[derive(Clone, Debug)]
242pub struct RotorAngleProjection;
243
244impl Projection for RotorAngleProjection {
245 type Output = f64;
246
247 fn project(&self, mv: &GA3) -> f64 {
248 let scalar = mv.get(0);
251 2.0 * scalar.clamp(-1.0, 1.0).acos()
252 }
253
254 fn name(&self) -> &str {
255 "rotor_angle"
256 }
257}
258
259pub struct MappedProjection<P, F, U>
261where
262 P: Projection,
263 F: Fn(P::Output) -> U + Send + Sync,
264{
265 inner: P,
266 map_fn: F,
267 name: String,
268}
269
270impl<P, F, U> MappedProjection<P, F, U>
271where
272 P: Projection,
273 F: Fn(P::Output) -> U + Send + Sync,
274{
275 pub fn new(inner: P, map_fn: F, name: impl Into<String>) -> Self {
277 Self {
278 inner,
279 map_fn,
280 name: name.into(),
281 }
282 }
283}
284
285impl<P, F, U> Projection for MappedProjection<P, F, U>
286where
287 P: Projection,
288 F: Fn(P::Output) -> U + Send + Sync,
289 U: Send + Sync,
290{
291 type Output = U;
292
293 fn project(&self, mv: &GA3) -> U {
294 (self.map_fn)(self.inner.project(mv))
295 }
296
297 fn name(&self) -> &str {
298 &self.name
299 }
300}
301
302pub struct CustomProjection<F, T>
304where
305 F: Fn(&GA3) -> T + Send + Sync,
306{
307 project_fn: F,
308 name: String,
309}
310
311impl<F, T> CustomProjection<F, T>
312where
313 F: Fn(&GA3) -> T + Send + Sync,
314{
315 pub fn new(project_fn: F, name: impl Into<String>) -> Self {
317 Self {
318 project_fn,
319 name: name.into(),
320 }
321 }
322}
323
324impl<F, T> Projection for CustomProjection<F, T>
325where
326 F: Fn(&GA3) -> T + Send + Sync,
327 T: Send + Sync,
328{
329 type Output = T;
330
331 fn project(&self, mv: &GA3) -> T {
332 (self.project_fn)(mv)
333 }
334
335 fn name(&self) -> &str {
336 &self.name
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use amari_core::Vector;

    const EPS: f64 = 1e-10;

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// Builds a `GA3` from eight coefficients that are all zero except the
    /// listed `(slot, value)` pairs.
    fn mv_with(components: &[(usize, f64)]) -> GA3 {
        let mut coeffs = vec![0.0; 8];
        for &(slot, value) in components {
            coeffs[slot] = value;
        }
        GA3::from_coefficients(coeffs)
    }

    #[test]
    fn test_scalar_projection() {
        let mv = GA3::scalar(42.0);
        assert_close(ScalarProjection.project(&mv), 42.0);
    }

    #[test]
    fn test_int_projection() {
        let proj = IntProjection;
        assert_eq!(proj.project(&GA3::scalar(42.7)), 42);
        // `as` truncates toward zero and maps NaN to 0.
        assert_eq!(proj.project(&GA3::scalar(-42.7)), -42);
        assert_eq!(proj.project(&GA3::scalar(f64::NAN)), 0);
    }

    #[test]
    fn test_bool_projection() {
        let proj = BoolProjection;
        assert!(!proj.project(&GA3::scalar(0.0)));
        assert!(!proj.project(&GA3::scalar(0.4)));
        // The threshold is strict: exactly 0.5 is false.
        assert!(!proj.project(&GA3::scalar(0.5)));
        assert!(proj.project(&GA3::scalar(0.6)));
        assert!(proj.project(&GA3::scalar(1.0)));
    }

    #[test]
    fn test_vector_projection() {
        let v = Vector::<3, 0, 0>::from_components(1.0, 2.0, 3.0);
        let mv = GA3::from_vector(&v);
        let (x, y, z) = VectorProjection.project(&mv);
        assert_close(x, 1.0);
        assert_close(y, 2.0);
        assert_close(z, 3.0);
    }

    #[test]
    fn test_position_projections() {
        let v = Vector::<3, 0, 0>::from_components(1.0, 2.0, 3.0);
        let mv = GA3::from_vector(&v);

        let (x, y) = Position2DProjection.project(&mv);
        assert_close(x, 1.0);
        assert_close(y, 2.0);

        let (x, y, z) = Position3DProjection.project(&mv);
        assert_close(x, 1.0);
        assert_close(y, 2.0);
        assert_close(z, 3.0);
    }

    #[test]
    fn test_bivector_projection() {
        let b = Bivector::<3, 0, 0>::from_components(0.5, 0.3, 0.1);
        let mv = GA3::from_bivector(&b);
        let (first, second, third) = BivectorProjection.project(&mv);
        assert_close(first, 0.5);
        assert_close(second, 0.3);
        assert_close(third, 0.1);
    }

    #[test]
    fn test_color_projection() {
        let mv = mv_with(&[(0, 128.0), (1, 64.0), (2, 192.0)]);
        assert_eq!(ColorProjection.project(&mv), (128, 64, 192));
    }

    #[test]
    fn test_color_clamping() {
        let mv = mv_with(&[(0, 300.0), (1, -50.0), (2, 100.0)]);
        // Over-range values saturate at 255; negative values clamp to 0.
        assert_eq!(ColorProjection.project(&mv), (255, 0, 100));
    }

    #[test]
    fn test_color_alpha_projection() {
        let mv = mv_with(&[(0, 10.0), (1, 20.0), (2, 30.0), (4, 400.0)]);
        assert_eq!(ColorAlphaProjection.project(&mv), (10, 20, 30, 255));
    }

    #[test]
    fn test_magnitude_projection() {
        let v = Vector::<3, 0, 0>::from_components(3.0, 4.0, 0.0);
        let mv = GA3::from_vector(&v);
        assert_close(MagnitudeProjection.project(&mv), 5.0);
    }

    #[test]
    fn test_rotor_angle_projection() {
        use std::f64::consts::PI;

        let proj = RotorAngleProjection;
        assert_close(proj.project(&GA3::scalar(1.0)), 0.0);
        assert_close(proj.project(&GA3::scalar(0.0)), PI);
        assert_close(proj.project(&GA3::scalar(-1.0)), 2.0 * PI);
        // Out-of-range scalars are clamped instead of producing NaN.
        assert_close(proj.project(&GA3::scalar(1.5)), 0.0);
    }

    #[test]
    fn test_custom_projection() {
        let proj = CustomProjection::new(|mv: &GA3| mv.get(0) * 2.0, "doubled");
        let mv = GA3::scalar(21.0);
        assert_close(proj.project(&mv), 42.0);
        assert_eq!(proj.name(), "doubled");
    }

    #[test]
    fn test_typed_vector_projection() {
        let v = Vector::<3, 0, 0>::from_components(1.0, 2.0, 3.0);
        let mv = GA3::from_vector(&v);
        let result = TypedVectorProjection.project(&mv);
        assert_close(result.mv.get(1), 1.0);
        assert_close(result.mv.get(2), 2.0);
        assert_close(result.mv.get(4), 3.0);
    }

    #[test]
    fn test_typed_bivector_projection() {
        let b = Bivector::<3, 0, 0>::from_components(0.5, 0.3, 0.1);
        let mv = GA3::from_bivector(&b);
        let result = TypedBivectorProjection.project(&mv);
        assert_close(result.get(0), 0.5);
        assert_close(result.get(1), 0.3);
        assert_close(result.get(2), 0.1);
    }

    #[test]
    fn test_mapped_projection() {
        let proj = MappedProjection::new(ScalarProjection, |x| x as i32 * 2, "doubled_int");
        let mv = GA3::scalar(21.0);
        assert_eq!(proj.project(&mv), 42);
        assert_eq!(proj.name(), "doubled_int");
    }

    #[test]
    fn test_builtin_names() {
        assert_eq!(ScalarProjection.name(), "scalar");
        assert_eq!(IntProjection.name(), "int");
        assert_eq!(BoolProjection.name(), "bool");
        assert_eq!(VectorProjection.name(), "vector");
        assert_eq!(Position2DProjection.name(), "position2d");
        assert_eq!(Position3DProjection.name(), "position3d");
        assert_eq!(BivectorProjection.name(), "bivector");
        assert_eq!(TypedVectorProjection.name(), "typed_vector");
        assert_eq!(TypedBivectorProjection.name(), "typed_bivector");
        assert_eq!(ColorProjection.name(), "color");
        assert_eq!(ColorAlphaProjection.name(), "color_alpha");
        assert_eq!(MagnitudeProjection.name(), "magnitude");
        assert_eq!(RotorAngleProjection.name(), "rotor_angle");
    }
}