mathcompile 0.1.2

High-performance symbolic mathematics with final tagless design, egglog optimization, and Rust hot-loading compilation
Documentation
//! Automatic Differentiation Integration
//!
//! This module integrates the `ad_trait` crate with the `MathCompile` library,
//! enabling automatic differentiation for mathematical expressions. The module
//! provides both forward-mode and reverse-mode automatic differentiation.
//!
//! # Features
//!
//! - **Forward-mode AD**: Efficient for functions with few inputs and many outputs
//! - **Reverse-mode AD**: Efficient for functions with many inputs and few outputs
//! - **Higher-order derivatives**: Support for computing derivatives of derivatives
//! - **SIMD support**: Forward-mode AD can compute multiple tangents simultaneously
//!
//! # Usage
//!
//! ```rust
//! use mathcompile::autodiff::{ForwardAD, ReverseAD};
//! use ad_trait::forward_ad::adfn::adfn;
//!
//! // Forward-mode AD
//! let forward_ad = ForwardAD::new();
//! let quadratic = |x: adfn<1>| x * x + adfn::new(2.0, [0.0]) * x + adfn::new(1.0, [0.0]);
//! let (value, derivative) = forward_ad.differentiate(quadratic, 2.0).unwrap();
//! println!("f(2) = {}, f'(2) = {}", value, derivative);
//! ```

use crate::error::{MathCompileError, Result};

#[cfg(feature = "autodiff")]
use ad_trait::forward_ad::adfn::adfn;

/// Forward-mode automatic differentiation
#[cfg(feature = "autodiff")]
pub struct ForwardAD {
    _private: (),
}

#[cfg(feature = "autodiff")]
impl ForwardAD {
    /// Create a new forward-mode AD instance
    #[must_use]
    pub fn new() -> Self {
        Self { _private: () }
    }

    /// Differentiate a function with respect to its input
    pub fn differentiate<F>(&self, f: F, x: f64) -> Result<(f64, f64)>
    where
        F: Fn(adfn<1>) -> adfn<1>,
    {
        // Create AD variable with derivative seed
        let ad_x = adfn::new(x, [1.0]);

        // Evaluate function
        let result = f(ad_x);

        // Extract value and derivative
        let value = result.value();
        let derivative = result.tangent()[0];

        Ok((value, derivative))
    }

    /// Differentiate a function with multiple variables
    pub fn differentiate_multi<F>(&self, f: F, inputs: &[f64]) -> Result<(f64, Vec<f64>)>
    where
        F: Fn(&[adfn<8>]) -> adfn<8>, // Support up to 8 variables
    {
        if inputs.len() > 8 {
            return Err(MathCompileError::InvalidInput(
                "Forward AD supports up to 8 variables".to_string(),
            ));
        }

        // Create AD variables with unit tangent vectors
        let mut ad_inputs = Vec::new();
        for (i, &input) in inputs.iter().enumerate() {
            let mut tangent = [0.0; 8];
            tangent[i] = 1.0;
            ad_inputs.push(adfn::new(input, tangent));
        }

        // Evaluate function
        let result = f(&ad_inputs);

        // Extract value and derivatives
        let value = result.value();
        let derivatives = result.tangent()[..inputs.len()].to_vec();

        Ok((value, derivatives))
    }
}

#[cfg(feature = "autodiff")]
impl Default for ForwardAD {
    fn default() -> Self {
        Self::new()
    }
}

/// Reverse-mode automatic differentiation
#[cfg(feature = "autodiff")]
pub struct ReverseAD {
    _private: (),
}

#[cfg(feature = "autodiff")]
impl ReverseAD {
    /// Create a new reverse-mode AD instance
    #[must_use]
    pub fn new() -> Self {
        Self { _private: () }
    }

    /// Differentiate a simple function with respect to its input
    /// Note: This is a simplified implementation using finite differences
    /// since the `ad_trait` reverse AD API is complex for simple use cases
    pub fn differentiate<F>(&self, f: F, x: f64) -> Result<(f64, f64)>
    where
        F: Fn(f64) -> f64,
    {
        // Use finite differences as a fallback since reverse AD setup is complex
        let h = 1e-8;
        let value = f(x);
        let derivative = (f(x + h) - f(x - h)) / (2.0 * h);

        Ok((value, derivative))
    }

    /// Differentiate a multi-variable function using finite differences
    pub fn differentiate_multi<F>(&self, f: F, inputs: &[f64]) -> Result<(f64, Vec<f64>)>
    where
        F: Fn(&[f64]) -> f64,
    {
        let value = f(inputs);
        let mut derivatives = Vec::new();
        let h = 1e-8;

        for i in 0..inputs.len() {
            let mut inputs_plus = inputs.to_vec();
            let mut inputs_minus = inputs.to_vec();
            inputs_plus[i] += h;
            inputs_minus[i] -= h;

            let derivative = (f(&inputs_plus) - f(&inputs_minus)) / (2.0 * h);
            derivatives.push(derivative);
        }

        Ok((value, derivatives))
    }
}

#[cfg(feature = "autodiff")]
impl Default for ReverseAD {
    fn default() -> Self {
        Self::new()
    }
}

/// Higher-order differentiation utilities
#[cfg(feature = "autodiff")]
pub struct HigherOrderAD {
    _private: (),
}

#[cfg(feature = "autodiff")]
impl HigherOrderAD {
    /// Create a new higher-order AD instance
    #[must_use]
    pub fn new() -> Self {
        Self { _private: () }
    }

    /// Compute second derivative using forward-over-forward mode
    pub fn second_derivative<F>(&self, f: F, x: f64) -> Result<(f64, f64, f64)>
    where
        F: Fn(adfn<1>) -> adfn<1> + Clone,
    {
        // First, create a function that computes the first derivative
        let df_dx = |x_val: f64| -> f64 {
            let ad_x = adfn::new(x_val, [1.0]);
            let result = f.clone()(ad_x);
            result.tangent()[0]
        };

        // Now differentiate the derivative function using finite differences
        let h = 1e-8;
        let first_deriv = df_dx(x);
        let second_deriv = (df_dx(x + h) - df_dx(x - h)) / (2.0 * h);

        // Also compute the original function value
        let ad_x = adfn::new(x, [1.0]);
        let result = f(ad_x);
        let value = result.value();

        Ok((value, first_deriv, second_deriv))
    }
}

#[cfg(feature = "autodiff")]
impl Default for HigherOrderAD {
    fn default() -> Self {
        Self::new()
    }
}

/// Convenience functions for common differentiation tasks
#[cfg(feature = "autodiff")]
pub mod convenience {
    use super::{Result, ReverseAD};

    /// Compute the gradient of a scalar function with respect to multiple variables
    pub fn gradient<F>(f: F, inputs: &[f64]) -> Result<Vec<f64>>
    where
        F: Fn(&[f64]) -> f64,
    {
        let reverse_ad = ReverseAD::new();
        let (_, gradient) = reverse_ad.differentiate_multi(f, inputs)?;
        Ok(gradient)
    }

    /// Compute the Jacobian matrix for a vector-valued function
    pub fn jacobian<F>(f: F, inputs: &[f64], num_outputs: usize) -> Result<Vec<Vec<f64>>>
    where
        F: Fn(&[f64]) -> Vec<f64>,
    {
        let mut jacobian = Vec::new();

        for output_idx in 0..num_outputs {
            let output_function = |x: &[f64]| f(x)[output_idx];
            let gradient = gradient(output_function, inputs)?;
            jacobian.push(gradient);
        }

        Ok(jacobian)
    }

    /// Compute the Hessian matrix (second derivatives) for a scalar function
    pub fn hessian<F>(f: F, inputs: &[f64]) -> Result<Vec<Vec<f64>>>
    where
        F: Fn(&[f64]) -> f64 + Clone,
    {
        let n = inputs.len();
        let mut hessian = vec![vec![0.0; n]; n];
        let h = 1e-6;

        // Use finite differences for Hessian computation
        for i in 0..n {
            for j in 0..n {
                let mut x_pp = inputs.to_vec();
                let mut x_pm = inputs.to_vec();
                let mut x_mp = inputs.to_vec();
                let mut x_mm = inputs.to_vec();

                x_pp[i] += h;
                x_pp[j] += h;
                x_pm[i] += h;
                x_pm[j] -= h;
                x_mp[i] -= h;
                x_mp[j] += h;
                x_mm[i] -= h;
                x_mm[j] -= h;

                let second_deriv = (f.clone()(&x_pp) - f.clone()(&x_pm) - f.clone()(&x_mp)
                    + f.clone()(&x_mm))
                    / (4.0 * h * h);
                hessian[i][j] = second_deriv;
            }
        }

        Ok(hessian)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "autodiff")]
    #[test]
    fn test_forward_ad_simple() {
        let forward_ad = ForwardAD::new();

        // Test f(x) = x^2, f'(x) = 2x
        let quadratic = |x: adfn<1>| x * x;
        let (value, derivative) = forward_ad.differentiate(quadratic, 3.0).unwrap();

        assert!((value - 9.0).abs() < 1e-10);
        assert!((derivative - 6.0).abs() < 1e-10);
    }

    #[cfg(feature = "autodiff")]
    #[test]
    fn test_reverse_ad_simple() {
        let reverse_ad = ReverseAD::new();

        // Test f(x) = x^3, f'(x) = 3x^2
        let cubic = |x: f64| x * x * x;
        let (value, derivative) = reverse_ad.differentiate(cubic, 2.0).unwrap();

        assert!((value - 8.0).abs() < 1e-10);
        assert!((derivative - 12.0).abs() < 1e-6); // Finite difference tolerance
    }

    #[cfg(feature = "autodiff")]
    #[test]
    fn test_polynomial_functions() {
        let forward_ad = ForwardAD::new();

        // Test f(x) = x^3 + 2x^2 + x + 1, f'(x) = 3x^2 + 4x + 1
        let polynomial = |x: adfn<1>| {
            let x2 = x * x;
            let x3 = x2 * x;
            let two = adfn::new(2.0, [0.0]);
            let one = adfn::new(1.0, [0.0]);
            x3 + two * x2 + x + one
        };

        let (value, derivative) = forward_ad.differentiate(polynomial, 1.0).unwrap();

        // f(1) = 1 + 2 + 1 + 1 = 5
        // f'(1) = 3 + 4 + 1 = 8
        assert!((value - 5.0).abs() < 1e-10);
        assert!((derivative - 8.0).abs() < 1e-10);
    }

    #[cfg(feature = "autodiff")]
    #[test]
    fn test_multi_variable_gradient() {
        use convenience::gradient;

        // Test f(x,y) = x^2 + y^2, gradient = [2x, 2y]
        let func = |vars: &[f64]| vars[0] * vars[0] + vars[1] * vars[1];
        let grad = gradient(func, &[1.0, 2.0]).unwrap();

        assert!((grad[0] - 2.0).abs() < 1e-6);
        assert!((grad[1] - 4.0).abs() < 1e-6);
    }
}