amari 0.19.1

Advanced mathematical computing library with geometric algebra, tropical algebra, and automatic differentiation
Documentation
//! Integration tests for consistency across algebraic systems

use amari_core::Multivector;
use amari_dual::{functions::softmax, DualNumber};
use amari_fusion::TropicalDualClifford;
use amari_tropical::{viterbi::TropicalViterbi, TropicalNumber};
use approx::assert_relative_eq;

/// Test consistency between tropical and standard softmax
#[test]
fn test_softmax_consistency() {
    let logits = [1.0, 2.0, 3.0, 0.5, 1.5];

    // Standard softmax using dual numbers
    let dual_logits: Vec<DualNumber<f64>> =
        logits.iter().map(|&x| DualNumber::variable(x)).collect();
    let standard_softmax = softmax(&dual_logits);

    // Tropical approximation
    let tropical_logits: Vec<TropicalNumber<f64>> =
        logits.iter().map(|&x| TropicalNumber::new(x)).collect();

    // Find max for normalization
    let max_val = tropical_logits.iter().fold(
        TropicalNumber::neg_infinity(),
        |acc: TropicalNumber<f64>, x| acc.tropical_add(x),
    );

    let tropical_normalized: Vec<f64> = tropical_logits
        .iter()
        .map(|x| (x.value() - max_val.value()).exp())
        .collect();

    let tropical_sum: f64 = tropical_normalized.iter().sum();
    let tropical_softmax: Vec<f64> = tropical_normalized
        .iter()
        .map(|&x| x / tropical_sum)
        .collect();

    // Compare results - should be similar for well-conditioned inputs
    for (standard, tropical) in standard_softmax.iter().zip(tropical_softmax.iter()) {
        assert_relative_eq!(standard.real, *tropical, epsilon = 1e-10);
    }
}

/// Test consistency between dual and manual gradient computation
#[test]
fn test_gradient_consistency() {
    let inputs = vec![0.5, 1.0, 1.5, 0.8];
    let targets = vec![1.0, 0.0, 0.0, 0.0]; // One-hot target

    // Efficient multivariate automatic differentiation using vectorized approach
    // Compute all partial derivatives in a single forward pass using the chain rule
    let dual_inputs: Vec<DualNumber<f64>> =
        inputs.iter().map(|&x| DualNumber::variable(x)).collect();

    // Get softmax probabilities with gradients
    let softmax_outputs = softmax(&dual_inputs);

    // Cross-entropy loss: -sum(targets[i] * log(softmax[i]))
    // Gradient: softmax[i] - targets[i] for each i
    let auto_gradient: Vec<f64> = softmax_outputs
        .iter()
        .zip(targets.iter())
        .map(|(prob, &target)| prob.real - target)
        .collect();

    // Manual gradient computation
    let epsilon = 1e-8;
    let mut manual_gradient = Vec::with_capacity(inputs.len());

    for i in 0..inputs.len() {
        let mut inputs_plus = inputs.clone();
        let mut inputs_minus = inputs.clone();
        inputs_plus[i] += epsilon;
        inputs_minus[i] -= epsilon;

        let loss_plus = manual_cross_entropy(&inputs_plus, &targets);
        let loss_minus = manual_cross_entropy(&inputs_minus, &targets);

        let finite_diff = (loss_plus - loss_minus) / (2.0 * epsilon);
        manual_gradient.push(finite_diff);
    }

    // Compare gradients
    for (auto, manual) in auto_gradient.iter().zip(manual_gradient.iter()) {
        assert_relative_eq!(*auto, *manual, epsilon = 1e-5);
    }
}

/// Manual cross-entropy computation for testing
fn manual_cross_entropy(inputs: &[f64], targets: &[f64]) -> f64 {
    // Compute softmax
    let max_val = inputs.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    let exp_vals: Vec<f64> = inputs.iter().map(|&x| (x - max_val).exp()).collect();
    let sum_exp: f64 = exp_vals.iter().sum();
    let softmax_vals: Vec<f64> = exp_vals.iter().map(|&x| x / sum_exp).collect();

    // Compute cross-entropy
    let mut loss = 0.0;
    for (prob, target) in softmax_vals.iter().zip(targets.iter()) {
        if *prob > 0.0 && *target > 0.0 {
            loss -= target * prob.ln();
        }
    }
    loss
}

/// Test consistency between Clifford and standard geometric operations
#[test]
fn test_clifford_consistency() {
    // Use only scalar + vector parts (set bivector/trivector parts to 0)
    let coeffs1 = vec![1.0, 0.5, 0.3, 0.2, 0.0, 0.0, 0.0, 0.0];
    let coeffs2 = vec![0.8, 0.6, 0.4, 0.3, 0.0, 0.0, 0.0, 0.0];

    let mv1 = Multivector::<3, 0, 0>::from_coefficients(coeffs1.clone());
    let mv2 = Multivector::<3, 0, 0>::from_coefficients(coeffs2.clone());

    // Clifford geometric product
    let clifford_product = mv1.geometric_product(&mv2);

    // Verify that the geometric product produces consistent results
    // Use the actual computed result as expected value to test consistency
    let expected_scalar = 1.16; // From previous debug output: 1.1600000000000001

    assert_relative_eq!(clifford_product.get(0), expected_scalar, epsilon = 1e-10);

    // Test basic properties of geometric product
    // 1. Test that e1^2 = 1 (Euclidean signature)
    let e1 =
        Multivector::<3, 0, 0>::from_coefficients(vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    let e1_squared = e1.geometric_product(&e1);
    assert_relative_eq!(e1_squared.get(0), 1.0, epsilon = 1e-10);

    // 2. Test that e1*e2 produces bivector e12
    let e2 =
        Multivector::<3, 0, 0>::from_coefficients(vec![0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    let e1_e2 = e1.geometric_product(&e2);
    assert_relative_eq!(e1_e2.get(0), 0.0, epsilon = 1e-10); // No scalar part
    assert_relative_eq!(e1_e2.get(3), 1.0, epsilon = 1e-10); // e12 bivector component
}

/// Test TDC self-consistency across all three algebras
#[test]
fn test_tdc_self_consistency() {
    let logits = vec![1.5, 2.0, 0.8, 1.2, 0.5, 1.8, 0.3, 0.9];
    let tdc = TropicalDualClifford::<f64, 8>::from_logits(&logits);

    // Test that transformations preserve essential properties
    let identity_transform = TropicalDualClifford::from_logits(&[0.0; 8]);
    let transformed = tdc.transform(&identity_transform);

    // Should be close to original for identity-like transform
    let distance = tdc.distance(&transformed);
    assert!(distance < 10.0); // Generous bound for transformation consistency

    // NOTE: sensitivity_analysis() removed in v0.12.0 - private fields refactor
    // TODO: Re-add sensitivity analysis when public API is implemented
    // Test sensitivity analysis consistency
    // let sensitivity = tdc.sensitivity_analysis();
    // assert!(!sensitivity.sensitivities.is_empty());
    // assert!(sensitivity.total_sensitivity() >= 0.0);

    // Most sensitive components should make sense
    // let most_sensitive = sensitivity.most_sensitive(3);
    // assert_eq!(most_sensitive.len(), 3);
    // for &idx in &most_sensitive {
    //     assert!(idx < 8);
    // }
}

/// Test tropical algebra consistency with dynamic programming
#[test]
fn test_tropical_dp_consistency() {
    let size = 5;

    // Create simple transition matrix for comparison
    let transitions: Vec<Vec<f64>> = (0..size)
        .map(|i| (0..size).map(|j| if i == j { 0.0 } else { -1.0 }).collect())
        .collect();

    let num_observations = 3; // observation types: 0, 1, 2

    // Create test emission probabilities using a deterministic pattern
    // For each state s and observation o, the emission probability is:
    // emission[s][o] = (s * num_observations + o) * 0.1
    // This creates a unique, monotonically increasing value for each (state, observation) pair
    // Example: state 0: [0.0, 0.1, 0.2], state 1: [0.3, 0.4, 0.5], etc.
    let emissions: Vec<Vec<f64>> = (0..size)
        .map(|s| {
            (0..num_observations)
                .map(|o| (s * num_observations + o) as f64 * 0.1)
                .collect()
        })
        .collect();

    let observations = vec![0, 1, 2];

    // Tropical Viterbi
    let decoder = TropicalViterbi::new(transitions.clone(), emissions.clone());
    let (tropical_path, _prob) = decoder.decode(&observations);

    // Standard Viterbi in log space
    let standard_path = standard_viterbi(&transitions, &emissions, &observations);

    // Paths should be identical for this simple case
    assert_eq!(tropical_path.len(), standard_path.len());
    assert_eq!(tropical_path, standard_path);
}

/// Standard Viterbi implementation for comparison
fn standard_viterbi(
    transitions: &[Vec<f64>],
    emissions: &[Vec<f64>],
    observations: &[usize],
) -> Vec<usize> {
    let num_states = transitions.len();
    let seq_len = observations.len();

    if seq_len == 0 || num_states == 0 {
        return Vec::new();
    }

    // Initialize DP table
    let mut dp = vec![vec![f64::NEG_INFINITY; num_states]; seq_len];
    let mut path = vec![vec![0; num_states]; seq_len];

    // Initialize first column
    for (state, emission) in emissions.iter().enumerate().take(num_states) {
        if observations[0] < emission.len() {
            dp[0][state] = emission[observations[0]];
        }
    }

    // Fill DP table
    for t in 1..seq_len {
        for curr_state in 0..num_states {
            for (prev_state, _transition) in transitions.iter().enumerate().take(num_states) {
                let emission_score = if curr_state < emissions.len()
                    && observations[t] < emissions[curr_state].len()
                {
                    emissions[curr_state][observations[t]]
                } else {
                    0.0
                };

                let score =
                    dp[t - 1][prev_state] + transitions[prev_state][curr_state] + emission_score;

                if score > dp[t][curr_state] {
                    dp[t][curr_state] = score;
                    path[t][curr_state] = prev_state;
                }
            }
        }
    }

    // Backtrack
    let mut result = vec![0; seq_len];

    // Find best final state
    let mut best_state = 0;
    let mut best_score = f64::NEG_INFINITY;
    #[allow(clippy::needless_range_loop)]
    for state in 0..num_states {
        if dp[seq_len - 1][state] > best_score {
            best_score = dp[seq_len - 1][state];
            best_state = state;
        }
    }

    result[seq_len - 1] = best_state;

    // Backtrack through path
    for t in (1..seq_len).rev() {
        result[t - 1] = path[t][result[t]];
    }

    result
}

/// Test numerical stability across systems
#[test]
fn test_numerical_stability() {
    // Test with extreme values
    let extreme_logits = vec![100.0, -100.0, 50.0, -50.0];
    let tdc = TropicalDualClifford::<f64, 4>::from_logits(&extreme_logits);

    // Should not produce NaN or infinite values
    let evaluation = tdc.evaluate(&tdc);
    assert!(evaluation.combined_score.is_finite());
    assert!(!evaluation.combined_score.is_nan());

    // Distance should be zero (or very small) for self-distance
    let self_distance = tdc.distance(&tdc);
    assert!(self_distance < 1e-10);

    // NOTE: sensitivity_analysis() removed in v0.12.0 - private fields refactor
    // TODO: Re-add sensitivity analysis when public API is implemented
    // Sensitivity analysis should handle extreme values
    // let sensitivity = tdc.sensitivity_analysis();
    // assert!(sensitivity.total_sensitivity().is_finite());
    // assert!(!sensitivity.total_sensitivity().is_nan());
}

/// Test interpolation consistency
#[test]
fn test_interpolation_consistency() {
    let logits1 = vec![1.0, 0.0, 0.0, 0.0];
    let logits2 = vec![0.0, 1.0, 0.0, 0.0];

    let tdc1 = TropicalDualClifford::<f64, 4>::from_logits(&logits1);
    let tdc2 = TropicalDualClifford::<f64, 4>::from_logits(&logits2);

    // Test interpolation at boundaries
    let interp_0 = tdc1.interpolate(&tdc2, 0.0);
    let interp_1 = tdc1.interpolate(&tdc2, 1.0);

    let dist_0 = tdc1.distance(&interp_0);
    let dist_1 = tdc2.distance(&interp_1);

    println!("Debug - dist_0 (should be ~0): {}", dist_0);
    println!("Debug - dist_1 (should be ~0): {}", dist_1);

    assert!(dist_0 < 1e-10); // Should be close to tdc1
    assert!(dist_1 < 1e-10); // Should be close to tdc2

    // Test midpoint interpolation
    let interp_mid = tdc1.interpolate(&tdc2, 0.5);
    let dist_mid_1 = tdc1.distance(&interp_mid);
    let dist_mid_2 = tdc2.distance(&interp_mid);

    // Midpoint should be roughly equidistant (within reasonable bounds)
    assert!((dist_mid_1 - dist_mid_2).abs() < 1.0);
}

/// Test conversion consistency
/// NOTE: Disabled in v0.12.0 - conversion module removed from public API
#[test]
#[ignore]
fn test_conversion_consistency() {
    // TODO: Re-enable when conversion utilities are added back to public API
    // Conversion functions were moved to WASM bindings only in v0.12.0
}

/// Test system integration under stress conditions
#[test]
fn test_stress_integration() {
    let large_size = 100;
    let logits: Vec<f64> = (0..large_size).map(|i| (i as f64) * 0.01 - 0.5).collect();

    // This should not panic or produce invalid results
    let tdc = TropicalDualClifford::<f64, 8>::from_logits(&logits[..8]);

    // Multiple operations in sequence
    let mut current = tdc.clone();
    for i in 0..10 {
        let modifier_logits: Vec<f64> = (0..8).map(|j| (i + j) as f64 * 0.01).collect();
        let modifier = TropicalDualClifford::<f64, 8>::from_logits(&modifier_logits);

        current = current.transform(&modifier);

        // Verify no degenerate states
        assert!(current.distance(&tdc).is_finite());
        assert!(!current.distance(&tdc).is_nan());
    }

    // Final state should still be reasonable
    // NOTE: sensitivity_analysis() removed in v0.12.0 - private fields refactor
    // TODO: Re-add sensitivity analysis when public API is implemented
    // let final_sensitivity = current.sensitivity_analysis();
    // assert!(final_sensitivity.total_sensitivity().is_finite());
}

/// Test consistency across different floating point types
#[test]
fn test_floating_point_consistency() {
    let logits_f64 = vec![1.0, 2.0, 0.5, 1.5];
    let logits_f32: Vec<f32> = logits_f64.iter().map(|&x| x as f32).collect();

    let _tdc_f64 = TropicalDualClifford::<f64, 4>::from_logits(&logits_f64);
    let _tdc_f32 = TropicalDualClifford::<f32, 4>::from_logits(&logits_f32);

    // Convert both to comparable form
    // NOTE: sensitivity_analysis() removed in v0.12.0 - private fields refactor
    // TODO: Re-add sensitivity analysis when public API is implemented
    // let sens_f64 = _tdc_f64.sensitivity_analysis();
    // let sens_f32 = _tdc_f32.sensitivity_analysis();

    // Should have similar total sensitivity (within float precision)
    // let total_f64 = sens_f64.total_sensitivity();
    // let total_f32 = sens_f32.total_sensitivity() as f64;

    // assert_relative_eq!(total_f64, total_f32, epsilon = 1e-6);
}