amari_tropical/
lib.rs

1//! Tropical (max-plus) algebra for efficient LLM operations
2//!
3//! Tropical algebra replaces traditional (+, ×) with (max, +), which converts
4//! expensive softmax operations into simple max operations. This is particularly
5//! useful for finding most likely sequences and optimization in neural networks.
6
7#![cfg_attr(not(feature = "std"), no_std)]
8
9extern crate alloc;
10use alloc::vec::Vec;
11use core::ops::{Add, Mul, Neg};
12use num_traits::Float;
13
14// Import precision types from amari-core
15#[cfg(feature = "high-precision")]
16pub use amari_core::HighPrecisionFloat;
17pub use amari_core::{ExtendedFloat, PrecisionFloat, StandardFloat};
18
19pub mod error;
20pub mod polytope;
21pub mod viterbi;
22
23#[cfg(feature = "gpu")]
24pub mod gpu;
25
26// Re-export error types
27pub use error::{TropicalError, TropicalResult};
28
29// Re-export GPU functionality when available
30#[cfg(feature = "gpu")]
31pub use gpu::{
32    GpuParameter, GpuTropicalNumber, TropicalGpuAccelerated, TropicalGpuContext, TropicalGpuError,
33    TropicalGpuOps, TropicalGpuResult, TROPICAL_ATTENTION_SHADER, TROPICAL_MATRIX_MULTIPLY_SHADER,
34};
35
36// Precision-aware type aliases for tropical numbers
37/// Standard precision tropical number using f64
38pub type StandardTropical = TropicalNumber<StandardFloat>;
39
40/// Extended precision tropical number - uses high precision when available
41pub type ExtendedTropical = TropicalNumber<ExtendedFloat>;
42
43// Phantom types and formal verification modules
44#[cfg(feature = "formal-verification")]
45pub mod verified;
46
47#[cfg(feature = "formal-verification")]
48pub mod verified_contracts;
49
50// Re-export phantom types for tropical algebra (when available)
51#[cfg(feature = "formal-verification")]
52pub use amari_core::verified::VerifiedMultivector as CoreVerifiedMultivector;
53
54/// A number in the tropical (max-plus) semiring
55///
56/// Tropical addition is max, tropical multiplication is addition
57#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
58pub struct TropicalNumber<T: Float>(pub T);
59
60impl<T: Float> TropicalNumber<T> {
61    /// Additive identity (negative infinity)
62    pub fn neg_infinity() -> Self {
63        Self(T::neg_infinity())
64    }
65
66    /// Multiplicative identity (zero in regular arithmetic)
67    pub fn zero() -> Self {
68        Self(T::zero())
69    }
70
71    /// Additive identity (same as neg_infinity for tropical)
72    pub fn tropical_zero() -> Self {
73        Self::neg_infinity()
74    }
75
76    /// Multiplicative identity (same as zero for tropical)
77    pub fn tropical_one() -> Self {
78        Self::zero()
79    }
80
81    /// Create from regular number
82    pub fn new(value: T) -> Self {
83        Self(value)
84    }
85
86    /// Get the underlying value
87    pub fn value(&self) -> T {
88        self.0
89    }
90
91    /// Check if this is the zero element (negative infinity)
92    pub fn is_zero(&self) -> bool {
93        self.0.is_infinite() && self.0.is_sign_negative()
94    }
95
96    /// Check if this is the one element (zero)
97    pub fn is_one(&self) -> bool {
98        self.0.is_zero()
99    }
100
101    /// Check if this is infinite (either positive or negative)
102    pub fn is_infinity(&self) -> bool {
103        self.0.is_infinite()
104    }
105
106    /// Tropical addition (max operation)
107    pub fn tropical_add(self, other: Self) -> Self {
108        Self(self.0.max(other.0))
109    }
110
111    /// Tropical multiplication (addition)
112    pub fn tropical_mul(self, other: Self) -> Self {
113        Self(self.0 + other.0)
114    }
115
116    /// Tropical power (scalar multiplication)
117    pub fn tropical_pow(self, n: T) -> Self {
118        Self(self.0 * n)
119    }
120
121    /// Convert from log-probability to tropical number
122    pub fn from_log_prob(log_p: T) -> Self {
123        Self(log_p)
124    }
125
126    /// Convert tropical number back to probability (via exp)
127    pub fn to_prob(self) -> T {
128        if self.is_zero() {
129            T::zero()
130        } else {
131            self.0.exp()
132        }
133    }
134}
135
136// Note: We don't implement Zero and One traits for TropicalNumber
137// because tropical zero is negative infinity and tropical one is zero,
138// which conflicts with the standard definitions
139
140impl<T: Float> Add for TropicalNumber<T> {
141    type Output = Self;
142
143    fn add(self, other: Self) -> Self {
144        self.tropical_add(other)
145    }
146}
147
148impl<T: Float> Mul for TropicalNumber<T> {
149    type Output = Self;
150
151    fn mul(self, other: Self) -> Self {
152        self.tropical_mul(other)
153    }
154}
155
156impl<T: Float> Neg for TropicalNumber<T> {
157    type Output = Self;
158
159    fn neg(self) -> Self {
160        Self(-self.0)
161    }
162}
163
164// Convenient constants for f64
165impl TropicalNumber<f64> {
166    pub const ZERO: Self = Self(f64::NEG_INFINITY);
167    pub const ONE: Self = Self(0.0);
168}
169
170// Convenient constants for f32 (for GPU compatibility)
171impl TropicalNumber<f32> {
172    pub const ZERO: Self = Self(f32::NEG_INFINITY);
173    pub const ONE: Self = Self(0.0);
174}
175
176// Removed duplicate Zero/One implementations
177
178/// Tropical multivector for geometric operations in tropical algebra
179#[derive(Clone, Debug)]
180pub struct TropicalMultivector<T: Float, const DIM: usize> {
181    coefficients: Vec<TropicalNumber<T>>,
182}
183
184impl<T: Float, const DIM: usize> TropicalMultivector<T, DIM> {
185    const BASIS_COUNT: usize = 1 << DIM;
186
187    /// Create zero tropical multivector
188    pub fn zero() -> Self {
189        Self {
190            coefficients: {
191                let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
192                for _ in 0..Self::BASIS_COUNT {
193                    coeffs.push(TropicalNumber::zero());
194                }
195                coeffs
196            },
197        }
198    }
199
200    /// Create from regular coefficients
201    pub fn from_coefficients(coeffs: Vec<T>) -> Self {
202        assert_eq!(coeffs.len(), Self::BASIS_COUNT);
203        Self {
204            coefficients: coeffs.into_iter().map(TropicalNumber::new).collect(),
205        }
206    }
207
208    /// Create from log probabilities
209    pub fn from_log_probs(log_probs: &[T]) -> Self {
210        let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
211        for i in 0..Self::BASIS_COUNT {
212            if i < log_probs.len() {
213                coeffs.push(TropicalNumber::from_log_prob(log_probs[i]));
214            } else {
215                coeffs.push(TropicalNumber::zero());
216            }
217        }
218        Self {
219            coefficients: coeffs,
220        }
221    }
222
223    /// Get coefficient at index
224    pub fn get(&self, index: usize) -> TropicalNumber<T> {
225        self.coefficients
226            .get(index)
227            .copied()
228            .unwrap_or(TropicalNumber::tropical_zero())
229    }
230
231    /// Set coefficient at index
232    pub fn set(&mut self, index: usize, value: TropicalNumber<T>) {
233        if index < self.coefficients.len() {
234            self.coefficients[index] = value;
235        }
236    }
237
238    /// Tropical geometric product
239    pub fn geometric_product(&self, other: &Self) -> Self {
240        let mut result = Self::zero();
241
242        // Simplified tropical geometric product
243        // In tropical algebra, we use max for addition and + for multiplication
244        for i in 0..Self::BASIS_COUNT {
245            for j in 0..Self::BASIS_COUNT {
246                let index = i ^ j; // XOR for basis blade combination
247                let product = self.coefficients[i] * other.coefficients[j];
248                result.coefficients[index] = result.coefficients[index] + product;
249            }
250        }
251
252        result
253    }
254
255    /// Find the maximum element (tropical sum)
256    pub fn max_element(&self) -> TropicalNumber<T> {
257        self.coefficients
258            .iter()
259            .copied()
260            .fold(TropicalNumber::zero(), |acc, x| acc + x)
261    }
262
263    /// Check if the multivector is zero
264    pub fn is_zero(&self) -> bool {
265        self.coefficients.iter().all(|c| c.is_zero())
266    }
267
268    /// Add two tropical multivectors
269    pub fn add(&self, other: &Self) -> Self {
270        let mut result = Self::zero();
271        for i in 0..Self::BASIS_COUNT {
272            result.coefficients[i] = self.coefficients[i] + other.coefficients[i];
273        }
274        result
275    }
276
277    /// Tropical addition (alias for add)
278    pub fn tropical_add(&self, other: &Self) -> Self {
279        self.add(other)
280    }
281
282    /// Tropical multiplication (element-wise multiplication of coefficients)
283    pub fn tropical_mul(&self, other: &Self) -> Self {
284        let mut result = Self::zero();
285        for i in 0..Self::BASIS_COUNT {
286            result.coefficients[i] = self.coefficients[i].tropical_mul(other.coefficients[i]);
287        }
288        result
289    }
290
291    /// Tropical scaling (multiply all coefficients by scalar)
292    pub fn tropical_scale(&self, scalar: T) -> Self {
293        let mut result = Self::zero();
294        let tropical_scalar = TropicalNumber::new(scalar);
295        for i in 0..Self::BASIS_COUNT {
296            result.coefficients[i] = self.coefficients[i].tropical_mul(tropical_scalar);
297        }
298        result
299    }
300
301    /// Scale by a regular scalar
302    pub fn scale(&self, factor: T) -> Self {
303        let mut result = Self::zero();
304        for i in 0..Self::BASIS_COUNT {
305            result.coefficients[i] = TropicalNumber(self.coefficients[i].0 + factor);
306        }
307        result
308    }
309
310    /// Get indices of non-zero (non-negative-infinity) elements
311    pub fn support(&self) -> Vec<usize> {
312        self.coefficients
313            .iter()
314            .enumerate()
315            .filter(|(_, &coeff)| !coeff.is_zero())
316            .map(|(i, _)| i)
317            .collect()
318    }
319
320    /// Tropical norm (maximum absolute value)
321    pub fn tropical_norm(&self) -> TropicalNumber<T> {
322        self.coefficients
323            .iter()
324            .copied()
325            .map(|x| TropicalNumber::new(x.value().abs()))
326            .fold(TropicalNumber::zero(), |acc, x| acc + x)
327    }
328
329    /// Create from logits (log probabilities)
330    pub fn from_logits(logits: &[T]) -> Self {
331        Self::from_log_probs(logits)
332    }
333
334    /// Viterbi algorithm using tropical algebra
335    /// Returns the most likely path through states
336    pub fn viterbi(
337        transitions: &TropicalMatrix<T>,
338        emissions: &TropicalMatrix<T>,
339        initial_probs: &[T],
340        sequence_length: usize,
341    ) -> Vec<usize> {
342        let num_states = initial_probs.len();
343        let mut path = Vec::with_capacity(sequence_length);
344
345        // Dynamic programming tables
346        let mut current_probs = Vec::with_capacity(num_states);
347        let mut prev_states = Vec::with_capacity(sequence_length);
348        for _ in 0..sequence_length {
349            prev_states.push(vec![0; num_states]);
350        }
351
352        // Initialize with first observation
353        #[allow(clippy::needless_range_loop)]
354        for i in 0..num_states {
355            let init_prob = TropicalNumber::from_log_prob(initial_probs[i]);
356            let emit_prob = emissions.data[i][0]; // First observation
357            current_probs.push(init_prob * emit_prob);
358        }
359
360        // Forward pass through sequence
361        #[allow(clippy::needless_range_loop)]
362        for t in 1..sequence_length {
363            let mut new_probs = Vec::with_capacity(num_states);
364
365            #[allow(clippy::needless_range_loop)]
366            for curr_state in 0..num_states {
367                let mut best_prob = TropicalNumber::zero(); // -infinity
368                let mut best_prev = 0;
369
370                for prev_state in 0..num_states {
371                    let transition_prob = transitions.data[prev_state][curr_state];
372                    let emission_prob = emissions.data[curr_state][t.min(emissions.cols - 1)];
373                    let total_prob = current_probs[prev_state] * transition_prob * emission_prob;
374
375                    if total_prob.value() > best_prob.value() {
376                        best_prob = total_prob;
377                        best_prev = prev_state;
378                    }
379                }
380
381                new_probs.push(best_prob);
382                prev_states[t][curr_state] = best_prev;
383            }
384
385            current_probs = new_probs;
386        }
387
388        // Find best final state
389        let mut best_final_state = 0;
390        let mut best_final_prob = current_probs[0];
391        for (i, &prob) in current_probs.iter().enumerate().skip(1) {
392            if prob.value() > best_final_prob.value() {
393                best_final_prob = prob;
394                best_final_state = i;
395            }
396        }
397
398        // Backtrack to reconstruct path
399        path.push(best_final_state);
400        let mut current_state = best_final_state;
401
402        for t in (1..sequence_length).rev() {
403            current_state = prev_states[t][current_state];
404            path.push(current_state);
405        }
406
407        path.reverse();
408        path
409    }
410}
411
412/// Tropical matrix operations for attention mechanisms
413#[derive(Clone, Debug)]
414pub struct TropicalMatrix<T: Float> {
415    data: Vec<Vec<TropicalNumber<T>>>,
416    rows: usize,
417    cols: usize,
418}
419
420impl<T: Float> TropicalMatrix<T> {
421    /// Get number of rows
422    pub fn rows(&self) -> usize {
423        self.rows
424    }
425
426    /// Get number of columns
427    pub fn cols(&self) -> usize {
428        self.cols
429    }
430
431    /// Get element at position (i, j)
432    pub fn get(&self, i: usize, j: usize) -> Option<TropicalNumber<T>> {
433        self.data.get(i).and_then(|row| row.get(j).copied())
434    }
435
436    /// Create new tropical matrix
437    pub fn new(rows: usize, cols: usize) -> Self {
438        let mut data = Vec::with_capacity(rows);
439        for _ in 0..rows {
440            let mut row = Vec::with_capacity(cols);
441            for _ in 0..cols {
442                row.push(TropicalNumber::zero());
443            }
444            data.push(row);
445        }
446
447        Self { data, rows, cols }
448    }
449
450    /// Create from log-probability matrix (common in attention)
451    pub fn from_log_probs(log_probs: &[Vec<T>]) -> Self {
452        let rows = log_probs.len();
453        let cols = log_probs[0].len();
454        let mut matrix = Self::new(rows, cols);
455
456        for (i, row) in log_probs.iter().enumerate() {
457            for (j, &value) in row.iter().enumerate() {
458                matrix.data[i][j] = TropicalNumber::from_log_prob(value);
459            }
460        }
461
462        matrix
463    }
464
465    /// Tropical matrix multiplication
466    pub fn mul(&self, other: &Self) -> Self {
467        assert_eq!(self.cols, other.rows);
468
469        let mut result = Self::new(self.rows, other.cols);
470
471        for i in 0..self.rows {
472            for j in 0..other.cols {
473                let mut sum = TropicalNumber::tropical_zero();
474                for k in 0..self.cols {
475                    // Tropical: (A*B)[i,j] = max_k(A[i,k] + B[k,j])
476                    sum = sum + (self.data[i][k] * other.data[k][j]);
477                }
478                result.data[i][j] = sum;
479            }
480        }
481
482        result
483    }
484
485    /// Tropical determinant (maximum over all permutations)
486    pub fn determinant(&self) -> TropicalNumber<T> {
487        assert_eq!(self.rows, self.cols);
488
489        if self.rows == 1 {
490            return self.data[0][0];
491        }
492
493        if self.rows == 2 {
494            return self.data[0][0] * self.data[1][1] + self.data[0][1] * self.data[1][0];
495        }
496
497        // For larger matrices, use recursive expansion
498        let mut det = TropicalNumber::zero();
499        for j in 0..self.cols {
500            let minor = self.minor(0, j);
501            let cofactor = self.data[0][j] * minor.determinant();
502            det = det + cofactor;
503        }
504
505        det
506    }
507
508    /// Extract minor matrix
509    fn minor(&self, row: usize, col: usize) -> Self {
510        let mut minor_data = Vec::new();
511
512        for i in 0..self.rows {
513            if i == row {
514                continue;
515            }
516            let mut minor_row = Vec::new();
517            for j in 0..self.cols {
518                if j == col {
519                    continue;
520                }
521                minor_row.push(self.data[i][j]);
522            }
523            minor_data.push(minor_row);
524        }
525
526        Self {
527            data: minor_data,
528            rows: self.rows - 1,
529            cols: self.cols - 1,
530        }
531    }
532
533    /// Convert to attention scores (softmax → max operation)
534    pub fn to_attention_scores(&self) -> Vec<Vec<T>> {
535        let mut scores = Vec::with_capacity(self.rows);
536
537        for row in &self.data {
538            let max_val = row
539                .iter()
540                .copied()
541                .fold(TropicalNumber::zero(), |acc, x| acc + x);
542            let row_scores: Vec<T> = row
543                .iter()
544                .map(|&val| {
545                    if max_val.is_zero() {
546                        T::zero()
547                    } else if val == max_val {
548                        T::one()
549                    } else {
550                        T::zero()
551                    }
552                })
553                .collect();
554            scores.push(row_scores);
555        }
556
557        scores
558    }
559}
560
561#[cfg(test)]
562mod tests {
563    use super::*;
564    use alloc::vec;
565    use approx::assert_relative_eq;
566
567    #[test]
568    fn test_tropical_number_operations() {
569        let a = TropicalNumber::new(2.0);
570        let b = TropicalNumber::new(3.0);
571
572        // Tropical addition is max
573        assert_eq!(a + b, TropicalNumber::new(3.0));
574        assert_eq!(b + a, TropicalNumber::new(3.0));
575
576        // Tropical multiplication is addition
577        assert_eq!(a * b, TropicalNumber::new(5.0));
578        assert_eq!(b * a, TropicalNumber::new(5.0));
579
580        // Identity elements
581        assert_eq!(a + TropicalNumber::<f64>::ZERO, a);
582        assert_eq!(a * TropicalNumber::<f64>::ONE, a);
583    }
584
585    #[test]
586    fn test_tropical_multivector() {
587        let mv1 = TropicalMultivector::<f64, 2>::from_coefficients(Vec::from([1.0, 2.0, 3.0, 4.0]));
588        let mv2 = TropicalMultivector::<f64, 2>::from_coefficients(Vec::from([0.5, 1.5, 2.5, 3.5]));
589
590        let product = mv1.geometric_product(&mv2);
591
592        // Verify the result has correct structure
593        assert!(!product.max_element().is_zero());
594
595        // Check support (non-zero elements)
596        let support = mv1.support();
597        assert!(!support.is_empty());
598    }
599
600    #[test]
601    fn test_tropical_matrix() {
602        let log_probs = vec![
603            Vec::from([0.0, -1.0, -2.0]),
604            Vec::from([-1.0, 0.0, -1.0]),
605            Vec::from([-2.0, -1.0, 0.0]),
606        ];
607
608        let matrix = TropicalMatrix::from_log_probs(&log_probs);
609        let det = matrix.determinant();
610
611        // Determinant should not be zero (negative infinity)
612        assert!(!det.is_zero());
613
614        // Test attention scores conversion
615        let scores = matrix.to_attention_scores();
616        assert_eq!(scores.len(), 3);
617        assert_eq!(scores[0].len(), 3);
618    }
619
620    #[test]
621    fn test_viterbi_equivalence() {
622        // Tropical multiplication chain should equal Viterbi path probability
623        let transitions = Vec::from([
624            TropicalNumber::from_log_prob(-0.5),
625            TropicalNumber::from_log_prob(-1.0),
626            TropicalNumber::from_log_prob(-0.3),
627        ]);
628
629        let path_prob = transitions
630            .into_iter()
631            .fold(TropicalNumber::<f64>::ONE, |acc, x| acc * x);
632
633        // Should equal sum of log probabilities
634        assert_relative_eq!(path_prob.value(), -1.8, epsilon = 1e-10);
635    }
636
637    // Comprehensive TropicalNumber tests
638    mod tropical_number_tests {
639        use super::*;
640
641        #[test]
642        fn test_tropical_number_constructors() {
643            let n1 = TropicalNumber::new(5.0);
644            assert_eq!(n1.value(), 5.0);
645
646            let zero = TropicalNumber::<f64>::neg_infinity();
647            assert!(zero.is_zero());
648            assert!(zero.is_infinity());
649
650            let one = TropicalNumber::<f64>::zero();
651            assert!(one.is_one());
652            assert!(!one.is_infinity());
653
654            let tropical_zero = TropicalNumber::<f64>::tropical_zero();
655            assert!(tropical_zero.is_zero());
656
657            let tropical_one = TropicalNumber::<f64>::tropical_one();
658            assert!(tropical_one.is_one());
659        }
660
661        #[test]
662        fn test_tropical_number_constants() {
663            assert!(TropicalNumber::<f64>::ZERO.is_zero());
664            assert!(TropicalNumber::<f64>::ONE.is_one());
665            assert_eq!(TropicalNumber::<f64>::ZERO.value(), f64::NEG_INFINITY);
666            assert_eq!(TropicalNumber::<f64>::ONE.value(), 0.0);
667        }
668
669        #[test]
670        fn test_tropical_predicates() {
671            let finite = TropicalNumber::new(3.0);
672            assert!(!finite.is_zero());
673            assert!(!finite.is_one());
674            assert!(!finite.is_infinity());
675
676            let zero = TropicalNumber::new(0.0);
677            assert!(zero.is_one());
678            assert!(!zero.is_zero());
679
680            let pos_inf = TropicalNumber::new(f64::INFINITY);
681            assert!(pos_inf.is_infinity());
682            assert!(!pos_inf.is_zero());
683
684            let neg_inf = TropicalNumber::new(f64::NEG_INFINITY);
685            assert!(neg_inf.is_zero());
686            assert!(neg_inf.is_infinity());
687        }
688
689        #[test]
690        fn test_tropical_arithmetic_properties() {
691            let a = TropicalNumber::new(2.0);
692            let b = TropicalNumber::new(3.0);
693            let c = TropicalNumber::new(1.0);
694
695            // Commutativity
696            assert_eq!(a + b, b + a);
697            assert_eq!(a * b, b * a);
698
699            // Associativity
700            assert_eq!((a + b) + c, a + (b + c));
701            assert_eq!((a * b) * c, a * (b * c));
702
703            // Identity elements
704            assert_eq!(a + TropicalNumber::<f64>::ZERO, a);
705            assert_eq!(TropicalNumber::<f64>::ZERO + a, a);
706            assert_eq!(a * TropicalNumber::<f64>::ONE, a);
707            assert_eq!(TropicalNumber::<f64>::ONE * a, a);
708
709            // Distributivity: a * (b + c) = (a * b) + (a * c)
710            let left = a * (b + c);
711            let right = (a * b) + (a * c);
712            assert_eq!(left, right);
713        }
714
715        #[test]
716        fn test_tropical_add_operation() {
717            let a = TropicalNumber::new(5.0);
718            let b = TropicalNumber::new(3.0);
719
720            // Tropical add is max
721            let result = a.tropical_add(b);
722            assert_eq!(result.value(), 5.0);
723
724            let result2 = b.tropical_add(a);
725            assert_eq!(result2.value(), 5.0);
726
727            // Test with infinity
728            let inf = TropicalNumber::new(f64::INFINITY);
729            let result3 = a.tropical_add(inf);
730            assert!(result3.value().is_infinite() && result3.value().is_sign_positive());
731        }
732
733        #[test]
734        fn test_tropical_mul_operation() {
735            let a = TropicalNumber::new(2.0);
736            let b = TropicalNumber::new(3.0);
737
738            // Tropical mul is addition
739            let result = a.tropical_mul(b);
740            assert_eq!(result.value(), 5.0);
741
742            // Test with zero (neg infinity)
743            let zero = TropicalNumber::<f64>::ZERO;
744            let result2 = a.tropical_mul(zero);
745            assert!(result2.is_zero());
746        }
747
748        #[test]
749        fn test_tropical_pow() {
750            let a = TropicalNumber::new(2.0);
751            let result = a.tropical_pow(3.0);
752            assert_eq!(result.value(), 6.0); // 2 * 3 = 6
753
754            let zero = TropicalNumber::<f64>::ZERO;
755            let result2 = zero.tropical_pow(5.0);
756            assert!(result2.is_zero());
757        }
758
759        #[test]
760        fn test_probability_conversion() {
761            let log_prob = -1.0;
762            let trop = TropicalNumber::from_log_prob(log_prob);
763            assert_eq!(trop.value(), -1.0);
764
765            let prob = trop.to_prob();
766            assert_relative_eq!(prob, (-1.0f64).exp(), epsilon = 1e-10);
767
768            // Test zero conversion
769            let zero = TropicalNumber::<f64>::ZERO;
770            assert_eq!(zero.to_prob(), 0.0);
771        }
772
773        #[test]
774        fn test_negation() {
775            let a = TropicalNumber::new(3.0);
776            let neg_a = -a;
777            assert_eq!(neg_a.value(), -3.0);
778
779            let zero = TropicalNumber::new(0.0);
780            let neg_zero = -zero;
781            assert_eq!(neg_zero.value(), 0.0);
782        }
783
784        #[test]
785        fn test_operator_overloads() {
786            let a = TropicalNumber::new(4.0);
787            let b = TropicalNumber::new(2.0);
788
789            // Addition operator (tropical add = max)
790            let sum = a + b;
791            assert_eq!(sum.value(), 4.0);
792
793            // Multiplication operator (tropical mul = add)
794            let product = a * b;
795            assert_eq!(product.value(), 6.0);
796
797            // Negation
798            let neg = -a;
799            assert_eq!(neg.value(), -4.0);
800        }
801
802        #[test]
803        fn test_edge_cases() {
804            let inf = TropicalNumber::new(f64::INFINITY);
805            let neg_inf = TropicalNumber::new(f64::NEG_INFINITY);
806            let finite = TropicalNumber::new(1.0);
807
808            // Infinity cases
809            assert_eq!((inf + finite).value(), f64::INFINITY);
810            assert_eq!((inf * finite).value(), f64::INFINITY);
811
812            // Negative infinity cases
813            assert_eq!((neg_inf + finite).value(), 1.0); // max(-∞, 1) = 1
814            assert_eq!((neg_inf * finite).value(), f64::NEG_INFINITY); // -∞ + 1 = -∞
815
816            // NaN handling
817            let nan = TropicalNumber::new(f64::NAN);
818            assert!(nan.value().is_nan());
819        }
820    }
821
822    // Comprehensive TropicalMultivector tests
823    mod tropical_multivector_tests {
824        use super::*;
825
826        #[test]
827        fn test_multivector_constructors() {
828            let zero = TropicalMultivector::<f64, 2>::zero();
829            // The zero() constructor creates multiplicative identities (0.0), not additive identities (-∞)
830            assert!(!zero.is_zero()); // is_zero() checks for all -∞, but zero() creates all 0.0
831
832            let coeffs = vec![1.0, 2.0, 3.0, 4.0];
833            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs.clone());
834
835            for (i, &coeff) in coeffs.iter().enumerate() {
836                assert_eq!(mv.get(i).value(), coeff);
837            }
838        }
839
840        #[test]
841        fn test_from_log_probs() {
842            let log_probs = vec![-1.0, -2.0, -0.5, -3.0];
843            let mv = TropicalMultivector::<f64, 2>::from_log_probs(&log_probs);
844
845            for (i, &log_prob) in log_probs.iter().enumerate() {
846                assert_eq!(mv.get(i).value(), log_prob);
847            }
848        }
849
850        #[test]
851        fn test_get_set_operations() {
852            let mut mv = TropicalMultivector::<f64, 2>::zero();
853
854            let val = TropicalNumber::new(5.0);
855            mv.set(1, val);
856            assert_eq!(mv.get(1), val);
857
858            // Test bounds
859            assert_eq!(mv.get(999).value(), f64::NEG_INFINITY); // Out of bounds returns zero
860        }
861
862        #[test]
863        fn test_max_element() {
864            let coeffs = vec![1.0, 5.0, 2.0, 3.0];
865            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);
866
867            let max_elem = mv.max_element();
868            assert_eq!(max_elem.value(), 5.0);
869        }
870
871        #[test]
872        fn test_tropical_operations() {
873            let coeffs1 = vec![1.0, 2.0, 3.0, 4.0];
874            let coeffs2 = vec![0.5, 3.0, 1.5, 2.0];
875
876            let mv1 = TropicalMultivector::<f64, 2>::from_coefficients(coeffs1);
877            let mv2 = TropicalMultivector::<f64, 2>::from_coefficients(coeffs2);
878
879            // Tropical addition (element-wise max)
880            let sum = mv1.tropical_add(&mv2);
881            assert_eq!(sum.get(0).value(), 1.0); // max(1.0, 0.5)
882            assert_eq!(sum.get(1).value(), 3.0); // max(2.0, 3.0)
883            assert_eq!(sum.get(2).value(), 3.0); // max(3.0, 1.5)
884            assert_eq!(sum.get(3).value(), 4.0); // max(4.0, 2.0)
885
886            // Regular addition (for comparison)
887            let add = mv1.add(&mv2);
888            assert_eq!(add.get(0).value(), 1.0); // Still max operation
889        }
890
891        #[test]
892        fn test_tropical_scaling() {
893            let coeffs = vec![1.0, 2.0, 3.0, 4.0];
894            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);
895
896            let scaled = mv.tropical_scale(2.0);
897            assert_eq!(scaled.get(0).value(), 3.0); // 1.0 + 2.0
898            assert_eq!(scaled.get(1).value(), 4.0); // 2.0 + 2.0
899            assert_eq!(scaled.get(2).value(), 5.0); // 3.0 + 2.0
900            assert_eq!(scaled.get(3).value(), 6.0); // 4.0 + 2.0
901
902            let regular_scaled = mv.scale(2.0);
903            assert_eq!(regular_scaled.get(0).value(), 3.0); // Same operation
904        }
905
906        #[test]
907        fn test_support() {
908            let coeffs = vec![f64::NEG_INFINITY, 1.0, f64::NEG_INFINITY, 2.0];
909            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);
910
911            let support = mv.support();
912            assert_eq!(support, vec![1, 3]); // Only non-negative-infinity elements
913        }
914
915        #[test]
916        fn test_tropical_norm() {
917            let coeffs = vec![1.0, 3.0, 2.0, 4.0];
918            let mv = TropicalMultivector::<f64, 2>::from_coefficients(coeffs);
919
920            let norm = mv.tropical_norm();
921            assert_eq!(norm.value(), 4.0); // Maximum element
922        }
923
924        #[test]
925        fn test_from_logits() {
926            let logits = vec![1.0, 2.0, 0.5, 3.0];
927            let mv = TropicalMultivector::<f64, 2>::from_logits(&logits);
928
929            // Should just copy the logits directly (from_logits calls from_log_probs)
930            assert_eq!(mv.get(0).value(), 1.0);
931            assert_eq!(mv.get(1).value(), 2.0);
932            assert_eq!(mv.get(2).value(), 0.5);
933            assert_eq!(mv.get(3).value(), 3.0);
934        }
935
936        #[test]
937        fn test_geometric_product() {
938            let mv1 = TropicalMultivector::<f64, 2>::from_coefficients(vec![1.0, 2.0, 3.0, 4.0]);
939            let mv2 = TropicalMultivector::<f64, 2>::from_coefficients(vec![0.5, 1.0, 1.5, 2.0]);
940
941            let product = mv1.geometric_product(&mv2);
942
943            // Verify it's not zero and has reasonable structure
944            assert!(!product.is_zero());
945            assert!(!product.max_element().is_zero());
946        }
947
948        #[test]
949        fn test_viterbi() {
950            // Test the static viterbi function exists and works
951            // Note: This is a simplified test since viterbi is complex
952            let transitions = TropicalMatrix::<f64>::new(2, 2);
953            let emissions = TropicalMatrix::<f64>::new(3, 2);
954            let initial_probs = vec![-0.5, -1.0];
955
956            let path =
957                TropicalMultivector::<f64, 2>::viterbi(&transitions, &emissions, &initial_probs, 3);
958            assert_eq!(path.len(), 3); // Should match sequence length
959        }
960
961        #[test]
962        fn test_zero_detection() {
963            // Create a true tropical zero (all negative infinity)
964            let true_zero_mv = TropicalMultivector::<f64, 2>::from_coefficients(vec![
965                f64::NEG_INFINITY,
966                f64::NEG_INFINITY,
967                f64::NEG_INFINITY,
968                f64::NEG_INFINITY,
969            ]);
970            assert!(true_zero_mv.is_zero());
971
972            let non_zero_mv = TropicalMultivector::<f64, 2>::from_coefficients(vec![
973                1.0,
974                f64::NEG_INFINITY,
975                f64::NEG_INFINITY,
976                f64::NEG_INFINITY,
977            ]);
978            assert!(!non_zero_mv.is_zero());
979        }
980    }
981
982    // Comprehensive TropicalMatrix tests
983    mod tropical_matrix_tests {
984        use super::*;
985
986        #[test]
987        fn test_matrix_constructor() {
988            let matrix = TropicalMatrix::<f64>::new(3, 3);
989            assert_eq!(matrix.rows, 3);
990            assert_eq!(matrix.cols, 3);
991            assert_eq!(matrix.data.len(), 3);
992
993            // Should be initialized with multiplicative identity (0.0)
994            for row in &matrix.data {
995                assert_eq!(row.len(), 3);
996                for &val in row {
997                    assert!(val.is_one()); // TropicalNumber::zero() creates multiplicative identity
998                }
999            }
1000        }
1001
1002        #[test]
1003        fn test_from_log_probs() {
1004            let log_probs = vec![
1005                vec![0.0, -1.0, -2.0],
1006                vec![-0.5, 0.0, -1.5],
1007                vec![-1.0, -0.5, 0.0],
1008            ];
1009
1010            let matrix = TropicalMatrix::from_log_probs(&log_probs);
1011            assert_eq!(matrix.rows, 3);
1012            assert_eq!(matrix.cols, 3);
1013
1014            // Check values are correctly set
1015            assert_eq!(matrix.data[0][0].value(), 0.0);
1016            assert_eq!(matrix.data[0][1].value(), -1.0);
1017            assert_eq!(matrix.data[1][0].value(), -0.5);
1018        }
1019
1020        #[test]
1021        fn test_matrix_multiplication() {
1022            let log_probs1 = vec![vec![0.0, -1.0], vec![-0.5, 0.0]];
1023
1024            let log_probs2 = vec![vec![-0.2, -0.8], vec![-0.3, 0.0]];
1025
1026            let m1 = TropicalMatrix::from_log_probs(&log_probs1);
1027            let m2 = TropicalMatrix::from_log_probs(&log_probs2);
1028
1029            let result = m1.mul(&m2);
1030            assert_eq!(result.rows, 2);
1031            assert_eq!(result.cols, 2);
1032
1033            // Check first element: max(0 + (-0.2), (-1) + (-0.3)) = max(-0.2, -1.3) = -0.2
1034            assert_relative_eq!(result.data[0][0].value(), -0.2, epsilon = 1e-10);
1035        }
1036
1037        #[test]
1038        fn test_determinant() {
1039            let log_probs = vec![
1040                vec![0.0, -1.0, -2.0],
1041                vec![-1.0, 0.0, -1.0],
1042                vec![-2.0, -1.0, 0.0],
1043            ];
1044
1045            let matrix = TropicalMatrix::from_log_probs(&log_probs);
1046            let det = matrix.determinant();
1047
1048            // Should not be zero (negative infinity)
1049            assert!(!det.is_zero());
1050
1051            // Test 2x2 determinant
1052            let small_probs = vec![vec![0.0, -1.0], vec![-0.5, 0.0]];
1053            let small_matrix = TropicalMatrix::from_log_probs(&small_probs);
1054            let small_det = small_matrix.determinant();
1055
1056            // det = max(0 + 0, (-1) + (-0.5)) = max(0, -1.5) = 0
1057            assert_eq!(small_det.value(), 0.0);
1058        }
1059
1060        #[test]
1061        fn test_attention_scores() {
1062            let log_probs = vec![
1063                vec![0.0, -1.0, -2.0],
1064                vec![-0.5, 0.0, -1.0],
1065                vec![-1.5, -0.5, 0.0],
1066            ];
1067
1068            let matrix = TropicalMatrix::from_log_probs(&log_probs);
1069            let scores = matrix.to_attention_scores();
1070
1071            assert_eq!(scores.len(), 3);
1072            for row in &scores {
1073                assert_eq!(row.len(), 3);
1074
1075                // Each row should sum to approximately 1.0 (attention weights)
1076                let sum: f64 = row.iter().sum();
1077                assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
1078            }
1079        }
1080
1081        #[test]
1082        fn test_matrix_edge_cases() {
1083            // Empty matrix
1084            let empty_matrix = TropicalMatrix::<f64>::new(0, 0);
1085            assert_eq!(empty_matrix.rows, 0);
1086            assert_eq!(empty_matrix.cols, 0);
1087
1088            // Single element matrix
1089            let single = TropicalMatrix::from_log_probs(&[vec![-0.5]]);
1090            let det = single.determinant();
1091            assert_eq!(det.value(), -0.5);
1092
1093            // Matrix with infinities
1094            let inf_probs = vec![vec![0.0, f64::NEG_INFINITY], vec![f64::NEG_INFINITY, 0.0]];
1095            let inf_matrix = TropicalMatrix::from_log_probs(&inf_probs);
1096            let inf_det = inf_matrix.determinant();
1097            assert_eq!(inf_det.value(), 0.0); // max(0, -∞) = 0
1098        }
1099
1100        #[test]
1101        fn test_matrix_properties() {
1102            let log_probs = vec![vec![0.0, -1.0], vec![-0.5, 0.0]];
1103
1104            let matrix = TropicalMatrix::from_log_probs(&log_probs);
1105
1106            // Test that matrix operations are consistent
1107            let identity_probs = vec![vec![0.0, f64::NEG_INFINITY], vec![f64::NEG_INFINITY, 0.0]];
1108            let identity = TropicalMatrix::from_log_probs(&identity_probs);
1109
1110            let result = matrix.mul(&identity);
1111
1112            // Result should be close to original matrix
1113            assert_relative_eq!(
1114                result.data[0][0].value(),
1115                matrix.data[0][0].value(),
1116                epsilon = 1e-10
1117            );
1118            assert_relative_eq!(
1119                result.data[1][1].value(),
1120                matrix.data[1][1].value(),
1121                epsilon = 1e-10
1122            );
1123        }
1124    }
1125}