1#![cfg_attr(not(feature = "std"), no_std)]
8
9extern crate alloc;
10use alloc::vec::Vec;
11use core::ops::{Add, Mul, Neg};
12use num_traits::Float;
13
14#[cfg(feature = "high-precision")]
16pub use amari_core::HighPrecisionFloat;
17pub use amari_core::{ExtendedFloat, PrecisionFloat, StandardFloat};
18
19pub mod error;
20pub mod polytope;
21pub mod viterbi;
22
23#[cfg(feature = "gpu")]
24pub mod gpu;
25
26pub use error::{TropicalError, TropicalResult};
28
29#[cfg(feature = "gpu")]
31pub use gpu::{
32 GpuParameter, GpuTropicalNumber, TropicalGpuAccelerated, TropicalGpuContext, TropicalGpuError,
33 TropicalGpuOps, TropicalGpuResult, TROPICAL_ATTENTION_SHADER, TROPICAL_MATRIX_MULTIPLY_SHADER,
34};
35
/// Tropical number backed by the `StandardFloat` precision type re-exported from `amari_core`.
pub type StandardTropical = TropicalNumber<StandardFloat>;

/// Tropical number backed by the `ExtendedFloat` precision type re-exported from `amari_core`.
pub type ExtendedTropical = TropicalNumber<ExtendedFloat>;
42
43#[cfg(feature = "formal-verification")]
45pub mod verified;
46
47#[cfg(feature = "formal-verification")]
48pub mod verified_contracts;
49
50#[cfg(feature = "formal-verification")]
52pub use amari_core::verified::VerifiedMultivector as CoreVerifiedMultivector;
53
/// A number in the max-plus tropical semiring, wrapping a single float.
///
/// Tropical addition is `max` and tropical multiplication is ordinary `+`
/// (see [`TropicalNumber::tropical_add`] and [`TropicalNumber::tropical_mul`]).
/// The additive identity ("tropical zero") is `-inf`; the multiplicative
/// identity ("tropical one") is `0`.
///
/// The derived `PartialEq`/`PartialOrd` compare the wrapped float values directly.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct TropicalNumber<T: Float>(pub T);
59
60impl<T: Float> TropicalNumber<T> {
61 pub fn neg_infinity() -> Self {
63 Self(T::neg_infinity())
64 }
65
66 pub fn zero() -> Self {
68 Self(T::zero())
69 }
70
71 pub fn tropical_zero() -> Self {
73 Self::neg_infinity()
74 }
75
76 pub fn tropical_one() -> Self {
78 Self::zero()
79 }
80
81 pub fn new(value: T) -> Self {
83 Self(value)
84 }
85
86 pub fn value(&self) -> T {
88 self.0
89 }
90
91 pub fn is_zero(&self) -> bool {
93 self.0.is_infinite() && self.0.is_sign_negative()
94 }
95
96 pub fn is_one(&self) -> bool {
98 self.0.is_zero()
99 }
100
101 pub fn is_infinity(&self) -> bool {
103 self.0.is_infinite()
104 }
105
106 pub fn tropical_add(self, other: Self) -> Self {
108 Self(self.0.max(other.0))
109 }
110
111 pub fn tropical_mul(self, other: Self) -> Self {
113 Self(self.0 + other.0)
114 }
115
116 pub fn tropical_pow(self, n: T) -> Self {
118 Self(self.0 * n)
119 }
120
121 pub fn from_log_prob(log_p: T) -> Self {
123 Self(log_p)
124 }
125
126 pub fn to_prob(self) -> T {
128 if self.is_zero() {
129 T::zero()
130 } else {
131 self.0.exp()
132 }
133 }
134}
135
136impl<T: Float> Add for TropicalNumber<T> {
141 type Output = Self;
142
143 fn add(self, other: Self) -> Self {
144 self.tropical_add(other)
145 }
146}
147
148impl<T: Float> Mul for TropicalNumber<T> {
149 type Output = Self;
150
151 fn mul(self, other: Self) -> Self {
152 self.tropical_mul(other)
153 }
154}
155
156impl<T: Float> Neg for TropicalNumber<T> {
157 type Output = Self;
158
159 fn neg(self) -> Self {
160 Self(-self.0)
161 }
162}
163
impl TropicalNumber<f64> {
    /// Tropical zero (additive identity): `-inf`.
    ///
    /// Note: this differs from `TropicalNumber::zero()`, which wraps `0.0` (the tropical one).
    pub const ZERO: Self = Self(f64::NEG_INFINITY);
    /// Tropical one (multiplicative identity): `0.0`.
    pub const ONE: Self = Self(0.0);
}

impl TropicalNumber<f32> {
    /// Tropical zero (additive identity): `-inf`.
    ///
    /// Note: this differs from `TropicalNumber::zero()`, which wraps `0.0` (the tropical one).
    pub const ZERO: Self = Self(f32::NEG_INFINITY);
    /// Tropical one (multiplicative identity): `0.0`.
    pub const ONE: Self = Self(0.0);
}
175
/// A multivector over `DIM` basis vectors whose coefficients are tropical numbers.
///
/// One coefficient is stored per basis blade (`2^DIM` in total). Indices presumably
/// encode blades as bitmasks over the basis vectors, given that `geometric_product`
/// combines blades `i` and `j` into `i ^ j`.
#[derive(Clone, Debug)]
pub struct TropicalMultivector<T: Float, const DIM: usize> {
    // One coefficient per basis blade; the constructors in this module always
    // produce exactly `2^DIM` entries.
    coefficients: Vec<TropicalNumber<T>>,
}
183
184impl<T: Float, const DIM: usize> TropicalMultivector<T, DIM> {
185 const BASIS_COUNT: usize = 1 << DIM;
186
187 pub fn zero() -> Self {
189 Self {
190 coefficients: {
191 let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
192 for _ in 0..Self::BASIS_COUNT {
193 coeffs.push(TropicalNumber::zero());
194 }
195 coeffs
196 },
197 }
198 }
199
200 pub fn from_coefficients(coeffs: Vec<T>) -> Self {
202 assert_eq!(coeffs.len(), Self::BASIS_COUNT);
203 Self {
204 coefficients: coeffs.into_iter().map(TropicalNumber::new).collect(),
205 }
206 }
207
208 pub fn from_log_probs(log_probs: &[T]) -> Self {
210 let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
211 for i in 0..Self::BASIS_COUNT {
212 if i < log_probs.len() {
213 coeffs.push(TropicalNumber::from_log_prob(log_probs[i]));
214 } else {
215 coeffs.push(TropicalNumber::zero());
216 }
217 }
218 Self {
219 coefficients: coeffs,
220 }
221 }
222
223 pub fn get(&self, index: usize) -> TropicalNumber<T> {
225 self.coefficients
226 .get(index)
227 .copied()
228 .unwrap_or(TropicalNumber::tropical_zero())
229 }
230
231 pub fn set(&mut self, index: usize, value: TropicalNumber<T>) {
233 if index < self.coefficients.len() {
234 self.coefficients[index] = value;
235 }
236 }
237
238 pub fn geometric_product(&self, other: &Self) -> Self {
240 let mut result = Self::zero();
241
242 for i in 0..Self::BASIS_COUNT {
245 for j in 0..Self::BASIS_COUNT {
246 let index = i ^ j; let product = self.coefficients[i] * other.coefficients[j];
248 result.coefficients[index] = result.coefficients[index] + product;
249 }
250 }
251
252 result
253 }
254
255 pub fn max_element(&self) -> TropicalNumber<T> {
257 self.coefficients
258 .iter()
259 .copied()
260 .fold(TropicalNumber::zero(), |acc, x| acc + x)
261 }
262
263 pub fn is_zero(&self) -> bool {
265 self.coefficients.iter().all(|c| c.is_zero())
266 }
267
268 pub fn add(&self, other: &Self) -> Self {
270 let mut result = Self::zero();
271 for i in 0..Self::BASIS_COUNT {
272 result.coefficients[i] = self.coefficients[i] + other.coefficients[i];
273 }
274 result
275 }
276
277 pub fn tropical_add(&self, other: &Self) -> Self {
279 self.add(other)
280 }
281
282 pub fn tropical_mul(&self, other: &Self) -> Self {
284 let mut result = Self::zero();
285 for i in 0..Self::BASIS_COUNT {
286 result.coefficients[i] = self.coefficients[i].tropical_mul(other.coefficients[i]);
287 }
288 result
289 }
290
291 pub fn tropical_scale(&self, scalar: T) -> Self {
293 let mut result = Self::zero();
294 let tropical_scalar = TropicalNumber::new(scalar);
295 for i in 0..Self::BASIS_COUNT {
296 result.coefficients[i] = self.coefficients[i].tropical_mul(tropical_scalar);
297 }
298 result
299 }
300
301 pub fn scale(&self, factor: T) -> Self {
303 let mut result = Self::zero();
304 for i in 0..Self::BASIS_COUNT {
305 result.coefficients[i] = TropicalNumber(self.coefficients[i].0 + factor);
306 }
307 result
308 }
309
310 pub fn support(&self) -> Vec<usize> {
312 self.coefficients
313 .iter()
314 .enumerate()
315 .filter(|(_, &coeff)| !coeff.is_zero())
316 .map(|(i, _)| i)
317 .collect()
318 }
319
320 pub fn tropical_norm(&self) -> TropicalNumber<T> {
322 self.coefficients
323 .iter()
324 .copied()
325 .map(|x| TropicalNumber::new(x.value().abs()))
326 .fold(TropicalNumber::zero(), |acc, x| acc + x)
327 }
328
329 pub fn from_logits(logits: &[T]) -> Self {
331 Self::from_log_probs(logits)
332 }
333
334 pub fn viterbi(
337 transitions: &TropicalMatrix<T>,
338 emissions: &TropicalMatrix<T>,
339 initial_probs: &[T],
340 sequence_length: usize,
341 ) -> Vec<usize> {
342 let num_states = initial_probs.len();
343 let mut path = Vec::with_capacity(sequence_length);
344
345 let mut current_probs = Vec::with_capacity(num_states);
347 let mut prev_states = Vec::with_capacity(sequence_length);
348 for _ in 0..sequence_length {
349 prev_states.push(vec![0; num_states]);
350 }
351
352 #[allow(clippy::needless_range_loop)]
354 for i in 0..num_states {
355 let init_prob = TropicalNumber::from_log_prob(initial_probs[i]);
356 let emit_prob = emissions.data[i][0]; current_probs.push(init_prob * emit_prob);
358 }
359
360 #[allow(clippy::needless_range_loop)]
362 for t in 1..sequence_length {
363 let mut new_probs = Vec::with_capacity(num_states);
364
365 #[allow(clippy::needless_range_loop)]
366 for curr_state in 0..num_states {
367 let mut best_prob = TropicalNumber::zero(); let mut best_prev = 0;
369
370 for prev_state in 0..num_states {
371 let transition_prob = transitions.data[prev_state][curr_state];
372 let emission_prob = emissions.data[curr_state][t.min(emissions.cols - 1)];
373 let total_prob = current_probs[prev_state] * transition_prob * emission_prob;
374
375 if total_prob.value() > best_prob.value() {
376 best_prob = total_prob;
377 best_prev = prev_state;
378 }
379 }
380
381 new_probs.push(best_prob);
382 prev_states[t][curr_state] = best_prev;
383 }
384
385 current_probs = new_probs;
386 }
387
388 let mut best_final_state = 0;
390 let mut best_final_prob = current_probs[0];
391 for (i, &prob) in current_probs.iter().enumerate().skip(1) {
392 if prob.value() > best_final_prob.value() {
393 best_final_prob = prob;
394 best_final_state = i;
395 }
396 }
397
398 path.push(best_final_state);
400 let mut current_state = best_final_state;
401
402 for t in (1..sequence_length).rev() {
403 current_state = prev_states[t][current_state];
404 path.push(current_state);
405 }
406
407 path.reverse();
408 path
409 }
410}
411
/// A dense matrix of tropical numbers, stored as a vector of rows.
#[derive(Clone, Debug)]
pub struct TropicalMatrix<T: Float> {
    // Row storage: `data[i][j]` is the entry at row `i`, column `j`.
    data: Vec<Vec<TropicalNumber<T>>>,
    // Number of rows; matches `data.len()` for matrices built by this module's constructors.
    rows: usize,
    // Number of columns in each row.
    cols: usize,
}
419
420impl<T: Float> TropicalMatrix<T> {
421 pub fn rows(&self) -> usize {
423 self.rows
424 }
425
426 pub fn cols(&self) -> usize {
428 self.cols
429 }
430
431 pub fn get(&self, i: usize, j: usize) -> Option<TropicalNumber<T>> {
433 self.data.get(i).and_then(|row| row.get(j).copied())
434 }
435
436 pub fn new(rows: usize, cols: usize) -> Self {
438 let mut data = Vec::with_capacity(rows);
439 for _ in 0..rows {
440 let mut row = Vec::with_capacity(cols);
441 for _ in 0..cols {
442 row.push(TropicalNumber::zero());
443 }
444 data.push(row);
445 }
446
447 Self { data, rows, cols }
448 }
449
450 pub fn from_log_probs(log_probs: &[Vec<T>]) -> Self {
452 let rows = log_probs.len();
453 let cols = log_probs[0].len();
454 let mut matrix = Self::new(rows, cols);
455
456 for (i, row) in log_probs.iter().enumerate() {
457 for (j, &value) in row.iter().enumerate() {
458 matrix.data[i][j] = TropicalNumber::from_log_prob(value);
459 }
460 }
461
462 matrix
463 }
464
465 pub fn mul(&self, other: &Self) -> Self {
467 assert_eq!(self.cols, other.rows);
468
469 let mut result = Self::new(self.rows, other.cols);
470
471 for i in 0..self.rows {
472 for j in 0..other.cols {
473 let mut sum = TropicalNumber::tropical_zero();
474 for k in 0..self.cols {
475 sum = sum + (self.data[i][k] * other.data[k][j]);
477 }
478 result.data[i][j] = sum;
479 }
480 }
481
482 result
483 }
484
485 pub fn determinant(&self) -> TropicalNumber<T> {
487 assert_eq!(self.rows, self.cols);
488
489 if self.rows == 1 {
490 return self.data[0][0];
491 }
492
493 if self.rows == 2 {
494 return self.data[0][0] * self.data[1][1] + self.data[0][1] * self.data[1][0];
495 }
496
497 let mut det = TropicalNumber::zero();
499 for j in 0..self.cols {
500 let minor = self.minor(0, j);
501 let cofactor = self.data[0][j] * minor.determinant();
502 det = det + cofactor;
503 }
504
505 det
506 }
507
508 fn minor(&self, row: usize, col: usize) -> Self {
510 let mut minor_data = Vec::new();
511
512 for i in 0..self.rows {
513 if i == row {
514 continue;
515 }
516 let mut minor_row = Vec::new();
517 for j in 0..self.cols {
518 if j == col {
519 continue;
520 }
521 minor_row.push(self.data[i][j]);
522 }
523 minor_data.push(minor_row);
524 }
525
526 Self {
527 data: minor_data,
528 rows: self.rows - 1,
529 cols: self.cols - 1,
530 }
531 }
532
533 pub fn to_attention_scores(&self) -> Vec<Vec<T>> {
535 let mut scores = Vec::with_capacity(self.rows);
536
537 for row in &self.data {
538 let max_val = row
539 .iter()
540 .copied()
541 .fold(TropicalNumber::zero(), |acc, x| acc + x);
542 let row_scores: Vec<T> = row
543 .iter()
544 .map(|&val| {
545 if max_val.is_zero() {
546 T::zero()
547 } else if val == max_val {
548 T::one()
549 } else {
550 T::zero()
551 }
552 })
553 .collect();
554 scores.push(row_scores);
555 }
556
557 scores
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    // Brings the `vec!` macro into scope; it is not available by default in `no_std` builds.
    use alloc::vec;
    use approx::assert_relative_eq;

    // Tropical `+` is max and tropical `*` is ordinary addition; `ZERO`/`ONE` are
    // their identities.
    #[test]
    fn test_tropical_number_operations() {
        let a = TropicalNumber::new(2.0);
        let b = TropicalNumber::new(3.0);

        // max(2, 3) = 3 in either order.
        assert_eq!(a + b, TropicalNumber::new(3.0));
        assert_eq!(b + a, TropicalNumber::new(3.0));

        // 2 + 3 = 5 in either order.
        assert_eq!(a * b, TropicalNumber::new(5.0));
        assert_eq!(b * a, TropicalNumber::new(5.0));

        // Identity elements.
        assert_eq!(a + TropicalNumber::<f64>::ZERO, a);
        assert_eq!(a * TropicalNumber::<f64>::ONE, a);
    }

    // Smoke test: geometric product of two finite multivectors is non-zero and the
    // support of a finite multivector is non-empty.
    #[test]
    fn test_tropical_multivector() {
        let mv1 = TropicalMultivector::<f64, 2>::from_coefficients(Vec::from([1.0, 2.0, 3.0, 4.0]));
        let mv2 = TropicalMultivector::<f64, 2>::from_coefficients(Vec::from([0.5, 1.5, 2.5, 3.5]));

        let product = mv1.geometric_product(&mv2);

        assert!(!product.max_element().is_zero());

        let support = mv1.support();
        assert!(!support.is_empty());
    }

    // Smoke test for determinant and attention-score shapes on a 3x3 matrix.
    #[test]
    fn test_tropical_matrix() {
        let log_probs = vec![
            Vec::from([0.0, -1.0, -2.0]),
            Vec::from([-1.0, 0.0, -1.0]),
            Vec::from([-2.0, -1.0, 0.0]),
        ];

        let matrix = TropicalMatrix::from_log_probs(&log_probs);
        let det = matrix.determinant();

        assert!(!det.is_zero());

        let scores = matrix.to_attention_scores();
        assert_eq!(scores.len(), 3);
        assert_eq!(scores[0].len(), 3);
    }

    // The tropical product of log-probabilities along a path is their ordinary sum.
    #[test]
    fn test_viterbi_equivalence() {
        let transitions = Vec::from([
            TropicalNumber::from_log_prob(-0.5),
            TropicalNumber::from_log_prob(-1.0),
            TropicalNumber::from_log_prob(-0.3),
        ]);

        let path_prob = transitions
            .into_iter()
            .fold(TropicalNumber::<f64>::ONE, |acc, x| acc * x);

        assert_relative_eq!(path_prob.value(), -1.8, epsilon = 1e-10);
    }

    mod tropical_number_tests {
        use super::*;

        #[test]
        fn test_tropical_number_constructors() {
            let n1 = TropicalNumber::new(5.0);
            assert_eq!(n1.value(), 5.0);

            // `neg_infinity()` is the tropical zero (and is infinite).
            let zero = TropicalNumber::<f64>::neg_infinity();
            assert!(zero.is_zero());
            assert!(zero.is_infinity());

            // `zero()` wraps the float 0.0, which is the tropical one.
            let one = TropicalNumber::<f64>::zero();
            assert!(one.is_one());
            assert!(!one.is_infinity());

            let tropical_zero = TropicalNumber::<f64>::tropical_zero();
            assert!(tropical_zero.is_zero());

            let tropical_one = TropicalNumber::<f64>::tropical_one();
            assert!(tropical_one.is_one());
        }

        #[test]
        fn test_tropical_number_constants() {
            assert!(TropicalNumber::<f64>::ZERO.is_zero());
            assert!(TropicalNumber::<f64>::ONE.is_one());
            assert_eq!(TropicalNumber::<f64>::ZERO.value(), f64::NEG_INFINITY);
            assert_eq!(TropicalNumber::<f64>::ONE.value(), 0.0);
        }

        #[test]
        fn test_tropical_predicates() {
            let finite = TropicalNumber::new(3.0);
            assert!(!finite.is_zero());
            assert!(!finite.is_one());
            assert!(!finite.is_infinity());

            let zero = TropicalNumber::new(0.0);
            assert!(zero.is_one());
            assert!(!zero.is_zero());

            // +inf is infinite but is not the tropical zero.
            let pos_inf = TropicalNumber::new(f64::INFINITY);
            assert!(pos_inf.is_infinity());
            assert!(!pos_inf.is_zero());

            let neg_inf = TropicalNumber::new(f64::NEG_INFINITY);
            assert!(neg_inf.is_zero());
            assert!(neg_inf.is_infinity());
        }

        // Semiring laws on finite values.
        #[test]
        fn test_tropical_arithmetic_properties() {
            let a = TropicalNumber::new(2.0);
            let b = TropicalNumber::new(3.0);
            let c = TropicalNumber::new(1.0);

            // Commutativity.
            assert_eq!(a + b, b + a);
            assert_eq!(a * b, b * a);

            // Associativity.
            assert_eq!((a + b) + c, a + (b + c));
            assert_eq!((a * b) * c, a * (b * c));

            // Identities on both sides.
            assert_eq!(a + TropicalNumber::<f64>::ZERO, a);
            assert_eq!(TropicalNumber::<f64>::ZERO + a, a);
            assert_eq!(a * TropicalNumber::<f64>::ONE, a);
            assert_eq!(TropicalNumber::<f64>::ONE * a, a);

            // Distributivity of `*` over `+`.
            let left = a * (b + c);
            let right = (a * b) + (a * c);
            assert_eq!(left, right);
        }

        #[test]
        fn test_tropical_add_operation() {
            let a = TropicalNumber::new(5.0);
            let b = TropicalNumber::new(3.0);

            let result = a.tropical_add(b);
            assert_eq!(result.value(), 5.0);

            let result2 = b.tropical_add(a);
            assert_eq!(result2.value(), 5.0);

            // max with +inf is +inf.
            let inf = TropicalNumber::new(f64::INFINITY);
            let result3 = a.tropical_add(inf);
            assert!(result3.value().is_infinite() && result3.value().is_sign_positive());
        }

        #[test]
        fn test_tropical_mul_operation() {
            let a = TropicalNumber::new(2.0);
            let b = TropicalNumber::new(3.0);

            let result = a.tropical_mul(b);
            assert_eq!(result.value(), 5.0);

            // The tropical zero absorbs a finite value under multiplication.
            let zero = TropicalNumber::<f64>::ZERO;
            let result2 = a.tropical_mul(zero);
            assert!(result2.is_zero());
        }

        #[test]
        fn test_tropical_pow() {
            // a^n is n * a.
            let a = TropicalNumber::new(2.0);
            let result = a.tropical_pow(3.0);
            assert_eq!(result.value(), 6.0);
            let zero = TropicalNumber::<f64>::ZERO;
            let result2 = zero.tropical_pow(5.0);
            assert!(result2.is_zero());
        }

        #[test]
        fn test_probability_conversion() {
            let log_prob = -1.0;
            let trop = TropicalNumber::from_log_prob(log_prob);
            assert_eq!(trop.value(), -1.0);

            let prob = trop.to_prob();
            assert_relative_eq!(prob, (-1.0f64).exp(), epsilon = 1e-10);

            // The tropical zero corresponds to probability 0.
            let zero = TropicalNumber::<f64>::ZERO;
            assert_eq!(zero.to_prob(), 0.0);
        }

        #[test]
        fn test_negation() {
            let a = TropicalNumber::new(3.0);
            let neg_a = -a;
            assert_eq!(neg_a.value(), -3.0);

            // -0.0 compares equal to 0.0.
            let zero = TropicalNumber::new(0.0);
            let neg_zero = -zero;
            assert_eq!(neg_zero.value(), 0.0);
        }

        #[test]
        fn test_operator_overloads() {
            let a = TropicalNumber::new(4.0);
            let b = TropicalNumber::new(2.0);

            // `+` is max.
            let sum = a + b;
            assert_eq!(sum.value(), 4.0);

            // `*` is ordinary addition.
            let product = a * b;
            assert_eq!(product.value(), 6.0);

            // Unary `-` negates the wrapped value.
            let neg = -a;
            assert_eq!(neg.value(), -4.0);
        }

        #[test]
        fn test_edge_cases() {
            let inf = TropicalNumber::new(f64::INFINITY);
            let neg_inf = TropicalNumber::new(f64::NEG_INFINITY);
            let finite = TropicalNumber::new(1.0);

            // +inf dominates both operations.
            assert_eq!((inf + finite).value(), f64::INFINITY);
            assert_eq!((inf * finite).value(), f64::INFINITY);

            // -inf is the identity for `+` and absorbing for `*` with a finite value.
            assert_eq!((neg_inf + finite).value(), 1.0);
            assert_eq!((neg_inf * finite).value(), f64::NEG_INFINITY);
            let nan = TropicalNumber::new(f64::NAN);
            assert!(nan.value().is_nan());
        }
    }

    mod tropical_multivector_tests {
        use super::*;

        #[test]
        fn test_multivector_constructors() {
            // `zero()` fills with the float 0.0 (tropical one), so it is not `is_zero`.
            let zero = TropicalMultivector::<f64, 2>::zero();
            assert!(!zero.is_zero());
            let coeffs = vec![1.0, 2.0, 3.0, 4.0];
            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs.clone());

            for (i, &coeff) in coeffs.iter().enumerate() {
                assert_eq!(mv.get(i).value(), coeff);
            }
        }

        #[test]
        fn test_from_log_probs() {
            let log_probs = vec![-1.0, -2.0, -0.5, -3.0];
            let mv = TropicalMultivector::<f64, 2>::from_log_probs(&log_probs);

            for (i, &log_prob) in log_probs.iter().enumerate() {
                assert_eq!(mv.get(i).value(), log_prob);
            }
        }

        #[test]
        fn test_get_set_operations() {
            let mut mv = TropicalMultivector::<f64, 2>::zero();

            let val = TropicalNumber::new(5.0);
            mv.set(1, val);
            assert_eq!(mv.get(1), val);

            // Out-of-range reads return the tropical zero.
            assert_eq!(mv.get(999).value(), f64::NEG_INFINITY);
        }

        #[test]
        fn test_max_element() {
            let coeffs = vec![1.0, 5.0, 2.0, 3.0];
            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);

            let max_elem = mv.max_element();
            assert_eq!(max_elem.value(), 5.0);
        }

        #[test]
        fn test_tropical_operations() {
            let coeffs1 = vec![1.0, 2.0, 3.0, 4.0];
            let coeffs2 = vec![0.5, 3.0, 1.5, 2.0];

            let mv1 = TropicalMultivector::<f64, 2>::from_coefficients(coeffs1);
            let mv2 = TropicalMultivector::<f64, 2>::from_coefficients(coeffs2);

            // Element-wise maximum.
            let sum = mv1.tropical_add(&mv2);
            assert_eq!(sum.get(0).value(), 1.0);
            assert_eq!(sum.get(1).value(), 3.0);
            assert_eq!(sum.get(2).value(), 3.0);
            assert_eq!(sum.get(3).value(), 4.0);
            let add = mv1.add(&mv2);
            assert_eq!(add.get(0).value(), 1.0);
        }

        #[test]
        fn test_tropical_scaling() {
            let coeffs = vec![1.0, 2.0, 3.0, 4.0];
            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);

            // Tropical scaling adds the scalar to every coefficient.
            let scaled = mv.tropical_scale(2.0);
            assert_eq!(scaled.get(0).value(), 3.0);
            assert_eq!(scaled.get(1).value(), 4.0);
            assert_eq!(scaled.get(2).value(), 5.0);
            assert_eq!(scaled.get(3).value(), 6.0);
            let regular_scaled = mv.scale(2.0);
            assert_eq!(regular_scaled.get(0).value(), 3.0);
        }

        #[test]
        fn test_support() {
            // -inf coefficients are excluded from the support.
            let coeffs = vec![f64::NEG_INFINITY, 1.0, f64::NEG_INFINITY, 2.0];
            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);

            let support = mv.support();
            assert_eq!(support, vec![1, 3]);
        }

        #[test]
        fn test_tropical_norm() {
            let coeffs = vec![1.0, 3.0, 2.0, 4.0];
            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);

            let norm = mv.tropical_norm();
            assert_eq!(norm.value(), 4.0);
        }

        #[test]
        fn test_from_logits() {
            let logits = vec![1.0, 2.0, 0.5, 3.0];
            let mv = TropicalMultivector::<f64, 2>::from_logits(&logits);

            assert_eq!(mv.get(0).value(), 1.0);
            assert_eq!(mv.get(1).value(), 2.0);
            assert_eq!(mv.get(2).value(), 0.5);
            assert_eq!(mv.get(3).value(), 3.0);
        }

        #[test]
        fn test_geometric_product() {
            let mv1 = TropicalMultivector::<f64, 2>::from_coefficients(vec![1.0, 2.0, 3.0, 4.0]);
            let mv2 = TropicalMultivector::<f64, 2>::from_coefficients(vec![0.5, 1.0, 1.5, 2.0]);

            let product = mv1.geometric_product(&mv2);

            assert!(!product.is_zero());
            assert!(!product.max_element().is_zero());
        }

        // Shape-only check: both matrices are filled with 0.0, and `emissions` has more
        // rows than there are states (only the first two rows are read).
        #[test]
        fn test_viterbi() {
            let transitions = TropicalMatrix::<f64>::new(2, 2);
            let emissions = TropicalMatrix::<f64>::new(3, 2);
            let initial_probs = vec![-0.5, -1.0];

            let path =
                TropicalMultivector::<f64, 2>::viterbi(&transitions, &emissions, &initial_probs, 3);
            assert_eq!(path.len(), 3);
        }

        #[test]
        fn test_zero_detection() {
            let true_zero_mv = TropicalMultivector::<f64, 2>::from_coefficients(vec![
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
            ]);
            assert!(true_zero_mv.is_zero());

            let non_zero_mv = TropicalMultivector::<f64, 2>::from_coefficients(vec![
                1.0,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
            ]);
            assert!(!non_zero_mv.is_zero());
        }
    }

    mod tropical_matrix_tests {
        use super::*;

        #[test]
        fn test_matrix_constructor() {
            let matrix = TropicalMatrix::<f64>::new(3, 3);
            assert_eq!(matrix.rows, 3);
            assert_eq!(matrix.cols, 3);
            assert_eq!(matrix.data.len(), 3);

            // `new` fills with 0.0, the tropical one.
            for row in &matrix.data {
                assert_eq!(row.len(), 3);
                for &val in row {
                    assert!(val.is_one());
                }
            }
        }

        #[test]
        fn test_from_log_probs() {
            let log_probs = vec![
                vec![0.0, -1.0, -2.0],
                vec![-0.5, 0.0, -1.5],
                vec![-1.0, -0.5, 0.0],
            ];

            let matrix = TropicalMatrix::from_log_probs(&log_probs);
            assert_eq!(matrix.rows, 3);
            assert_eq!(matrix.cols, 3);

            assert_eq!(matrix.data[0][0].value(), 0.0);
            assert_eq!(matrix.data[0][1].value(), -1.0);
            assert_eq!(matrix.data[1][0].value(), -0.5);
        }

        #[test]
        fn test_matrix_multiplication() {
            let log_probs1 = vec![vec![0.0, -1.0], vec![-0.5, 0.0]];

            let log_probs2 = vec![vec![-0.2, -0.8], vec![-0.3, 0.0]];

            let m1 = TropicalMatrix::from_log_probs(&log_probs1);
            let m2 = TropicalMatrix::from_log_probs(&log_probs2);

            let result = m1.mul(&m2);
            assert_eq!(result.rows, 2);
            assert_eq!(result.cols, 2);

            // max(0.0 + -0.2, -1.0 + -0.3) = -0.2
            assert_relative_eq!(result.data[0][0].value(), -0.2, epsilon = 1e-10);
        }

        #[test]
        fn test_determinant() {
            let log_probs = vec![
                vec![0.0, -1.0, -2.0],
                vec![-1.0, 0.0, -1.0],
                vec![-2.0, -1.0, 0.0],
            ];

            let matrix = TropicalMatrix::from_log_probs(&log_probs);
            let det = matrix.determinant();

            assert!(!det.is_zero());

            // max(0.0 + 0.0, -1.0 + -0.5) = 0.0
            let small_probs = vec![vec![0.0, -1.0], vec![-0.5, 0.0]];
            let small_matrix = TropicalMatrix::from_log_probs(&small_probs);
            let small_det = small_matrix.determinant();

            assert_eq!(small_det.value(), 0.0);
        }

        // Each row has a unique maximum, so each row's scores sum to exactly 1.
        #[test]
        fn test_attention_scores() {
            let log_probs = vec![
                vec![0.0, -1.0, -2.0],
                vec![-0.5, 0.0, -1.0],
                vec![-1.5, -0.5, 0.0],
            ];

            let matrix = TropicalMatrix::from_log_probs(&log_probs);
            let scores = matrix.to_attention_scores();

            assert_eq!(scores.len(), 3);
            for row in &scores {
                assert_eq!(row.len(), 3);

                let sum: f64 = row.iter().sum();
                assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
            }
        }

        #[test]
        fn test_matrix_edge_cases() {
            let empty_matrix = TropicalMatrix::<f64>::new(0, 0);
            assert_eq!(empty_matrix.rows, 0);
            assert_eq!(empty_matrix.cols, 0);

            let single = TropicalMatrix::from_log_probs(&[vec![-0.5]]);
            let det = single.determinant();
            assert_eq!(det.value(), -0.5);

            // Tropical identity matrix: max(0 + 0, -inf + -inf) = 0.
            let inf_probs = vec![vec![0.0, f64::NEG_INFINITY], vec![f64::NEG_INFINITY, 0.0]];
            let inf_matrix = TropicalMatrix::from_log_probs(&inf_probs);
            let inf_det = inf_matrix.determinant();
            assert_eq!(inf_det.value(), 0.0);
        }

        // Multiplying by the tropical identity (0 on the diagonal, -inf elsewhere)
        // leaves the diagonal unchanged.
        #[test]
        fn test_matrix_properties() {
            let log_probs = vec![vec![0.0, -1.0], vec![-0.5, 0.0]];

            let matrix = TropicalMatrix::from_log_probs(&log_probs);

            let identity_probs = vec![vec![0.0, f64::NEG_INFINITY], vec![f64::NEG_INFINITY, 0.0]];
            let identity = TropicalMatrix::from_log_probs(&identity_probs);

            let result = matrix.mul(&identity);

            assert_relative_eq!(
                result.data[0][0].value(),
                matrix.data[0][0].value(),
                epsilon = 1e-10
            );
            assert_relative_eq!(
                result.data[1][1].value(),
                matrix.data[1][1].value(),
                epsilon = 1e-10
            );
        }
    }
}