1use crate::geometric::GA3;
25use crate::projection::Projection;
26use crate::transforms::{Rotor, Transform, Translation, Versor};
27use amari_core::{Bivector, Vector};
28use std::sync::{Arc, Mutex};
29
/// Shared, mutex-guarded list of change callbacks.
///
/// Each callback receives a reference to a `GA3` value and must be
/// `Send + Sync` so the list can be shared between threads.
type SubscriberList = Arc<Mutex<Vec<Box<dyn Fn(&GA3) + Send + Sync>>>>;
32
/// A shared, observable `GA3` multivector.
///
/// Both fields are reference-counted, so cloning a `GeometricState` yields a
/// handle to the *same* underlying value and the *same* subscriber list
/// rather than an independent copy.
#[derive(Clone)]
pub struct GeometricState {
    /// The current multivector value, guarded by a mutex.
    inner: Arc<Mutex<GA3>>,
    /// Callbacks invoked after the value is replaced.
    subscribers: SubscriberList,
}
48
49impl GeometricState {
50 pub fn new(mv: GA3) -> Self {
52 Self {
53 inner: Arc::new(Mutex::new(mv)),
54 subscribers: Arc::new(Mutex::new(Vec::new())),
55 }
56 }
57
58 pub fn from_scalar(value: f64) -> Self {
60 Self::new(GA3::scalar(value))
61 }
62
63 pub fn from_vector(x: f64, y: f64, z: f64) -> Self {
65 let v = Vector::<3, 0, 0>::from_components(x, y, z);
66 Self::new(GA3::from_vector(&v))
67 }
68
69 pub fn from_bivector(xy: f64, xz: f64, yz: f64) -> Self {
71 let b = Bivector::<3, 0, 0>::from_components(xy, xz, yz);
72 Self::new(GA3::from_bivector(&b))
73 }
74
75 pub fn from_coefficients(coeffs: Vec<f64>) -> Self {
77 Self::new(GA3::from_coefficients(coeffs))
78 }
79
80 pub fn zero() -> Self {
82 Self::new(GA3::zero())
83 }
84
85 pub fn identity() -> Self {
87 Self::new(GA3::scalar(1.0))
88 }
89
90 pub fn multivector(&self) -> GA3 {
92 self.inner.lock().unwrap().clone()
93 }
94
95 pub fn get(&self, index: usize) -> f64 {
97 self.inner.lock().unwrap().get(index)
98 }
99
100 pub fn scalar(&self) -> f64 {
102 self.get(0)
103 }
104
105 pub fn as_vector(&self) -> (f64, f64, f64) {
107 let mv = self.inner.lock().unwrap();
108 (mv.get(1), mv.get(2), mv.get(4))
109 }
110
111 pub fn as_bivector(&self) -> (f64, f64, f64) {
113 let mv = self.inner.lock().unwrap();
114 (mv.get(3), mv.get(5), mv.get(6))
115 }
116
117 pub fn magnitude(&self) -> f64 {
119 self.inner.lock().unwrap().magnitude()
120 }
121
122 pub fn project<P: Projection>(&self, projection: &P) -> P::Output {
124 let mv = self.inner.lock().unwrap();
125 projection.project(&mv)
126 }
127
128 pub fn set(&self, mv: GA3) {
130 {
131 let mut inner = self.inner.lock().unwrap();
132 *inner = mv;
133 }
134 self.notify_subscribers();
135 }
136
137 pub fn set_scalar(&self, value: f64) {
139 self.set(GA3::scalar(value));
140 }
141
142 pub fn set_vector(&self, x: f64, y: f64, z: f64) {
144 let v = Vector::<3, 0, 0>::from_components(x, y, z);
145 self.set(GA3::from_vector(&v));
146 }
147
148 pub fn update<F>(&self, f: F)
150 where
151 F: FnOnce(&GA3) -> GA3,
152 {
153 {
154 let mut inner = self.inner.lock().unwrap();
155 *inner = f(&inner);
156 }
157 self.notify_subscribers();
158 }
159
160 pub fn apply_rotor(&self, rotor: &Rotor) -> GeometricState {
164 let mv = self.inner.lock().unwrap();
165 let transformed = rotor.transform(&mv);
166 GeometricState::new(transformed)
167 }
168
169 pub fn apply_rotor_mut(&self, rotor: &Rotor) {
171 self.update(|mv| rotor.transform(mv));
172 }
173
174 pub fn apply_translation(&self, translation: &Translation) -> GeometricState {
178 let mv = self.inner.lock().unwrap();
179 let transformed = translation.transform(&mv);
180 GeometricState::new(transformed)
181 }
182
183 pub fn apply_translation_mut(&self, translation: &Translation) {
185 self.update(|mv| translation.transform(mv));
186 }
187
188 pub fn apply_versor(&self, versor: &Versor) -> GeometricState {
192 let mv = self.inner.lock().unwrap();
193 let transformed = versor.transform(&mv);
194 GeometricState::new(transformed)
195 }
196
197 pub fn apply_versor_mut(&self, versor: &Versor) {
199 self.update(|mv| versor.transform(mv));
200 }
201
202 pub fn apply_transform(&self, transform: &Transform) -> GeometricState {
206 let mv = self.inner.lock().unwrap();
207 let transformed = transform.transform(&mv);
208 GeometricState::new(transformed)
209 }
210
211 pub fn apply_transform_mut(&self, transform: &Transform) {
213 self.update(|mv| transform.transform(mv));
214 }
215
216 pub fn add(&self, other: &GeometricState) -> GeometricState {
218 let a = self.inner.lock().unwrap();
219 let b = other.inner.lock().unwrap();
220 GeometricState::new(&*a + &*b)
221 }
222
223 pub fn sub(&self, other: &GeometricState) -> GeometricState {
225 let a = self.inner.lock().unwrap();
226 let b = other.inner.lock().unwrap();
227 GeometricState::new(&*a - &*b)
228 }
229
230 pub fn scale(&self, factor: f64) -> GeometricState {
232 let mv = self.inner.lock().unwrap();
233 GeometricState::new(&*mv * factor)
234 }
235
236 pub fn geometric_product(&self, other: &GeometricState) -> GeometricState {
238 let a = self.inner.lock().unwrap();
239 let b = other.inner.lock().unwrap();
240 GeometricState::new(a.geometric_product(&b))
241 }
242
243 pub fn normalize(&self) -> Option<GeometricState> {
245 let mv = self.inner.lock().unwrap();
246 mv.normalize().map(GeometricState::new)
247 }
248
249 pub fn normalize_mut(&self) -> bool {
251 let mut inner = self.inner.lock().unwrap();
252 match inner.normalize() {
253 Some(normalized) => {
254 *inner = normalized;
255 drop(inner);
256 self.notify_subscribers();
257 true
258 }
259 None => false,
260 }
261 }
262
263 pub fn reverse(&self) -> GeometricState {
265 let mv = self.inner.lock().unwrap();
266 GeometricState::new(mv.reverse())
267 }
268
269 pub fn lerp(&self, other: &GeometricState, t: f64) -> GeometricState {
271 let a = self.inner.lock().unwrap();
272 let b = other.inner.lock().unwrap();
273
274 let diff = &*b - &*a;
276 let interpolated = &*a + &(&diff * t);
277 GeometricState::new(interpolated)
278 }
279
280 pub fn slerp(&self, other: &GeometricState, t: f64) -> GeometricState {
284 let a_mv = self.inner.lock().unwrap();
286 let b_mv = other.inner.lock().unwrap();
287
288 let rotor_a = Rotor::from_multivector(a_mv.clone());
289 let rotor_b = Rotor::from_multivector(b_mv.clone());
290
291 let interpolated = rotor_a.slerp_to(&rotor_b, t);
292 GeometricState::new(interpolated.as_multivector().clone())
293 }
294
295 pub fn subscribe<F>(&self, callback: F) -> GeometricSubscription
297 where
298 F: Fn(&GA3) + Send + Sync + 'static,
299 {
300 let mut subs = self.subscribers.lock().unwrap();
301 let id = subs.len();
302 subs.push(Box::new(callback));
303
304 GeometricSubscription {
305 id,
306 subscribers: self.subscribers.clone(),
307 }
308 }
309
310 fn notify_subscribers(&self) {
312 let mv = self.inner.lock().unwrap();
313 let subs = self.subscribers.lock().unwrap();
314 for callback in subs.iter() {
315 callback(&mv);
316 }
317 }
318}
319
320impl std::fmt::Debug for GeometricState {
321 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322 let mv = self.inner.lock().unwrap();
323 write!(f, "GeometricState({:?})", mv)
324 }
325}
326
/// Handle returned by `GeometricState::subscribe`.
pub struct GeometricSubscription {
    /// Identifier assigned when the callback was registered.
    id: usize,
    /// The subscriber list this handle was created from (shared with the state).
    subscribers: SubscriberList,
}
332
333impl GeometricSubscription {
334 pub fn unsubscribe(self) {
340 let _ = (self.id, self.subscribers);
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::projection::{IntProjection, ScalarProjection, VectorProjection};
    use std::f64::consts::PI;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Asserts that `actual` is within `1e-10` of `expected`.
    #[track_caller]
    fn assert_near(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    /// Asserts that each component of `actual` is within `1e-10` of `expected`.
    #[track_caller]
    fn assert_vec_near(actual: (f64, f64, f64), expected: (f64, f64, f64)) {
        assert_near(actual.0, expected.0);
        assert_near(actual.1, expected.1);
        assert_near(actual.2, expected.2);
    }

    // A scalar state reports its value through `scalar()`.
    #[test]
    fn test_from_scalar() {
        assert_near(GeometricState::from_scalar(42.0).scalar(), 42.0);
    }

    // A vector state round-trips its components through `as_vector()`.
    #[test]
    fn test_from_vector() {
        let state = GeometricState::from_vector(1.0, 2.0, 3.0);
        assert_vec_near(state.as_vector(), (1.0, 2.0, 3.0));
    }

    // Rotating the x axis by a quarter turn in the xy plane yields the y axis.
    #[test]
    fn test_apply_rotor() {
        let quarter_turn = Rotor::xy(PI / 2.0);
        let rotated = GeometricState::from_vector(1.0, 0.0, 0.0).apply_rotor(&quarter_turn);

        let (x, y, z) = rotated.as_vector();
        assert!(x.abs() < 1e-10, "x should be ~0, got {}", x);
        assert!((y - 1.0).abs() < 1e-10, "y should be ~1, got {}", y);
        assert!(z.abs() < 1e-10, "z should be ~0, got {}", z);
    }

    // Translating the origin lands on the translation offset.
    #[test]
    fn test_apply_translation() {
        let origin = GeometricState::from_vector(0.0, 0.0, 0.0);
        let offset = Translation::new(1.0, 2.0, 3.0);

        let moved = origin.apply_translation(&offset);
        assert_vec_near(moved.as_vector(), (1.0, 2.0, 3.0));
    }

    // A combined rotor + translation maps (1, 0, 0) to (1, 1, 0).
    #[test]
    fn test_apply_transform() {
        let transform = Transform::new(Rotor::xy(PI / 2.0), Translation::new(1.0, 0.0, 0.0));
        let start = GeometricState::from_vector(1.0, 0.0, 0.0);

        let (x, y, z) = start.apply_transform(&transform).as_vector();

        assert!((x - 1.0).abs() < 1e-10, "x should be ~1, got {}", x);
        assert!((y - 1.0).abs() < 1e-10, "y should be ~1, got {}", y);
        assert!(z.abs() < 1e-10);
    }

    // Linear interpolation between scalars 0 and 10.
    #[test]
    fn test_lerp() {
        let start = GeometricState::from_scalar(0.0);
        let end = GeometricState::from_scalar(10.0);

        let midpoint = start.lerp(&end, 0.5);
        assert_near(midpoint.scalar(), 5.0);

        let first_quarter = start.lerp(&end, 0.25);
        assert_near(first_quarter.scalar(), 2.5);
    }

    // Scalar and integer projections of a scalar state.
    #[test]
    fn test_projection() {
        let state = GeometricState::from_scalar(42.7);

        let as_scalar = state.project(&ScalarProjection);
        assert!((as_scalar - 42.7).abs() < 1e-10);

        let as_int = state.project(&IntProjection);
        assert_eq!(as_int, 42);
    }

    // Vector projection returns the vector components.
    #[test]
    fn test_vector_projection() {
        let state = GeometricState::from_vector(1.0, 2.0, 3.0);
        let (x, y, z) = state.project(&VectorProjection);
        assert!((x - 1.0).abs() < 1e-10);
        assert!((y - 2.0).abs() < 1e-10);
        assert!((z - 3.0).abs() < 1e-10);
    }

    // Each `set_scalar` call triggers the subscriber exactly once.
    #[test]
    fn test_subscribe() {
        let state = GeometricState::from_scalar(0.0);
        let hits = Arc::new(AtomicUsize::new(0));
        let hits_in_callback = hits.clone();

        let _sub = state.subscribe(move |_mv| {
            hits_in_callback.fetch_add(1, Ordering::SeqCst);
        });

        state.set_scalar(1.0);
        state.set_scalar(2.0);

        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    // Scaling multiplies every vector component by the factor.
    #[test]
    fn test_scale() {
        let doubled = GeometricState::from_vector(1.0, 2.0, 3.0).scale(2.0);
        assert_vec_near(doubled.as_vector(), (2.0, 4.0, 6.0));
    }

    // Normalizing (3, 4, 0) gives a unit-magnitude (0.6, 0.8, _).
    #[test]
    fn test_normalize() {
        let unit = GeometricState::from_vector(3.0, 4.0, 0.0)
            .normalize()
            .unwrap();

        assert_near(unit.magnitude(), 1.0);

        let (x, y, _) = unit.as_vector();
        assert_near(x, 0.6);
        assert_near(y, 0.8);
    }
}