xad-rs 0.2.0

Automatic differentiation library for Rust — forward/reverse mode AD, a Rust port of the C++ XAD library (https://github.com/auto-differentiation/xad)
Documentation
//! `LabeledFReal<T>` — labeled wrapper over `FReal<T>` (generic forward-mode).
//!
//! **Shape A (Phase 02.2):** the wrapper does NOT carry an `Arc<VarRegistry>`
//! field in any build. Its layout is a single `inner: FReal<T>` field plus,
//! under `#[cfg(debug_assertions)]` only, a `gen_id: u64` stamped by the
//! owning [`LabeledForwardTape`](crate::labeled::LabeledForwardTape) scope for
//! the cross-registry debug guard. In release builds the wrapper therefore
//! holds nothing but the `FReal<T>` and pays no atomic-refcount cost per
//! operator.
//!
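//! A fresh `LabeledFReal<T>` can only be created via
//! [`LabeledForwardTape::input_freal`](crate::labeled::LabeledForwardTape::input_freal)
//! or
//! [`LabeledForwardTape::constant_freal`](crate::labeled::LabeledForwardTape::constant_freal),
//! which stamp the current thread-local active generation into the wrapper.
//! Operators and math methods then derive further values from those.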
//!
//! `FReal<T>` carries ONE tangent direction, so
//! [`derivative`](LabeledFReal::derivative) returns the same value regardless
//! of which `name` is queried. In debug builds, the generation check performed
//! by wrapper-on-wrapper binary operators catches mixing values from different
//! `LabeledForwardTape` scopes.

use std::fmt;

use crate::freal::FReal;
use crate::math;
use crate::traits::Scalar;

/// Labeled wrapper around the positional [`FReal<T>`] type.
///
/// Values are created by a `LabeledForwardTape` (`input_freal` /
/// `constant_freal`). They are combined with the arithmetic operators `+ - * /`,
/// either between two wrappers or with a plain scalar on the right (any `T`)
/// or on the left (`f64` / `f32` only), and with unary `-`. The elementary
/// methods `exp`, `ln`, `sqrt`, `sin`, `cos` and `tan` are also available.
/// In debug builds, every derived value inherits an operand's generation
/// stamp. Binary operators between two wrappers compare both stamps via
/// `check_gen`.
///
/// # Example
///
/// ```
/// use xad_rs::labeled::{LabeledFReal, LabeledForwardTape};
///
/// let mut ft = LabeledForwardTape::new();
/// let x: LabeledFReal<f64> = ft.input_freal("x", 2.0);
/// let _registry = ft.freeze();
/// let f = &x * &x + &x; // f(x) = x^2 + x, f'(x) = 2x + 1 = 5
/// assert_eq!(f.value(), 6.0);
/// assert_eq!(f.derivative("x"), 5.0);
/// ```
#[derive(Clone)]
pub struct LabeledFReal<T: Scalar> {
    /// The positional forward-mode value: one primal plus one tangent.
    pub(super) inner: FReal<T>,
    // Named `gen_id` rather than `gen` because `gen` is a reserved keyword in
    // the Rust 2024 edition. The design notes (D-01/D-02) spell it `gen`.
    /// Debug-only generation stamp of the `LabeledForwardTape` scope this
    /// value came from. Binary operators between two wrappers pass it to
    /// `check_gen`.
    #[cfg(debug_assertions)]
    pub(super) gen_id: u64,
}

impl<T: Scalar> LabeledFReal<T> {
    /// Crate-internal constructor behind `LabeledForwardTape::input_freal`
    /// and `LabeledForwardTape::constant_freal`.
    ///
    /// Debug builds stamp the result with the thread-local active generation
    /// returned by `forward_tape::current_gen()`. Release builds just wrap
    /// `inner`. Not part of the public API.
    #[inline]
    pub(crate) fn __from_inner(inner: FReal<T>) -> Self {
        #[cfg(debug_assertions)]
        let gen_id = crate::labeled::forward_tape::current_gen();
        Self {
            inner,
            #[cfg(debug_assertions)]
            gen_id,
        }
    }

    /// Primal value carried by the wrapped `FReal<T>`.
    #[inline]
    pub fn value(&self) -> T {
        self.inner().value()
    }

    /// Tangent of the wrapped `FReal<T>`, addressed by label.
    ///
    /// `FReal<T>` holds exactly one tangent direction, so `_name` is not
    /// consulted and every name yields the same number. Callers should only
    /// query names registered on the tape that produced this value. Mixing
    /// values from different `LabeledForwardTape` scopes is caught by the
    /// debug generation guard in the binary operators, not here.
    #[inline]
    pub fn derivative(&self, _name: &str) -> T {
        self.inner().derivative()
    }

    /// Borrows the underlying positional `FReal<T>` (escape hatch).
    #[inline]
    pub fn inner(&self) -> &FReal<T> {
        &self.inner
    }

    // ---- Elementary functions ---------------------------------------------
    // Each method applies the matching `math::fwd::*` free function to the
    // inner `FReal<T>`. It then re-wraps the result with this value's own
    // generation stamp (debug builds), which avoids a thread-local read per call.

    /// Wraps `inner` so it carries the same generation stamp as `self`.
    #[inline]
    fn rewrap_same_gen(&self, inner: FReal<T>) -> Self {
        Self {
            inner,
            #[cfg(debug_assertions)]
            gen_id: self.gen_id,
        }
    }

    /// Natural exponential `e^x`. Keeps this value's generation stamp.
    #[inline]
    pub fn exp(&self) -> Self {
        self.rewrap_same_gen(math::fwd::exp(&self.inner))
    }

    /// Natural logarithm `ln(x)`. Keeps this value's generation stamp.
    #[inline]
    pub fn ln(&self) -> Self {
        self.rewrap_same_gen(math::fwd::ln(&self.inner))
    }

    /// Square root `sqrt(x)`. Keeps this value's generation stamp.
    #[inline]
    pub fn sqrt(&self) -> Self {
        self.rewrap_same_gen(math::fwd::sqrt(&self.inner))
    }

    /// Sine `sin(x)`. Keeps this value's generation stamp.
    #[inline]
    pub fn sin(&self) -> Self {
        self.rewrap_same_gen(math::fwd::sin(&self.inner))
    }

    /// Cosine `cos(x)`. Keeps this value's generation stamp.
    #[inline]
    pub fn cos(&self) -> Self {
        self.rewrap_same_gen(math::fwd::cos(&self.inner))
    }

    /// Tangent `tan(x)`. Keeps this value's generation stamp.
    #[inline]
    pub fn tan(&self) -> Self {
        self.rewrap_same_gen(math::fwd::tan(&self.inner))
    }
}

impl<T: Scalar> fmt::Debug for LabeledFReal<T> {
    /// Renders the value and its single tangent, e.g.
    /// `LabeledFReal { value: 6.0, derivative: 5.0 }`. The debug-build
    /// generation stamp is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (value, derivative) = (self.inner.value(), self.inner.derivative());
        let mut builder = f.debug_struct("LabeledFReal");
        builder.field("value", &value);
        builder.field("derivative", &derivative);
        builder.finish()
    }
}

impl<T: Scalar> fmt::Display for LabeledFReal<T> {
    /// Formats as `LabeledFReal(<value>)`.
    ///
    /// A precision in the format spec (e.g. `{:.3}`) is forwarded to the
    /// inner value. So `format!("{:.2}", x)` renders `LabeledFReal(6.00)` for
    /// an `f64` value of `6.0`, where previously the precision was silently
    /// ignored. Without a precision the output is unchanged. Width, fill and
    /// alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.inner.value();
        match f.precision() {
            Some(precision) => write!(f, "LabeledFReal({:.*})", precision, value),
            None => write!(f, "LabeledFReal({})", value),
        }
    }
}

// ============ Operator overloads — hand-written, Shape A ============
// The shared registry-stamping op macro is not reused here, because Shape A
// does not carry an `Arc` registry field on the per-value wrapper. Instead,
// a local `__lbl_freal_binop!` macro generates the four reference variants
// (owned/owned, ref/ref, owned/ref, ref/owned) plus the scalar-RHS variants.
// It is modelled on `__lbl_areal_binop!` in `src/labeled/areal.rs` but is
// generic over `<T: Scalar>`. Each wrapper-on-wrapper impl performs a
// debug-only `check_gen` between the two operands' generations. Every impl
// constructs its result with the LHS's generation stamp.

// Generates, for one arithmetic operator, all wrapper/wrapper and
// wrapper/scalar operator impls:
//   * `LabeledFReal<T>  op LabeledFReal<T>`   (owned / owned)
//   * `&LabeledFReal<T> op &LabeledFReal<T>`  (ref / ref)
//   * `LabeledFReal<T>  op &LabeledFReal<T>`  (owned / ref)
//   * `&LabeledFReal<T> op LabeledFReal<T>`   (ref / owned)
//   * `LabeledFReal<T> op T`, `&LabeledFReal<T> op T` (scalar RHS)
// `$trait` / `$method` name the `core::ops` trait and its method (e.g. `Add`,
// `add`). `$op` is the operator token applied to the inner `FReal<T>` values.
// Scalar-on-LHS forms are generated separately by `__lblfreal_scalar_lhs!`.
macro_rules! __lbl_freal_binop {
    ($trait:ident, $method:ident, $op:tt) => {
        // owned / owned
        impl<T: Scalar> ::core::ops::$trait<LabeledFReal<T>> for LabeledFReal<T> {
            type Output = LabeledFReal<T>;
            #[inline]
            fn $method(self, rhs: LabeledFReal<T>) -> LabeledFReal<T> {
                // Debug-only guard against mixing values from different
                // `LabeledForwardTape` scopes.
                #[cfg(debug_assertions)]
                crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
                LabeledFReal {
                    inner: self.inner $op rhs.inner,
                    // `gen_id` is a `u64`, so reading it after moving
                    // `self.inner` is fine; the result keeps the LHS stamp.
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
        // ref / ref
        impl<T: Scalar> ::core::ops::$trait<&LabeledFReal<T>> for &LabeledFReal<T> {
            type Output = LabeledFReal<T>;
            #[inline]
            fn $method(self, rhs: &LabeledFReal<T>) -> LabeledFReal<T> {
                #[cfg(debug_assertions)]
                crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
                LabeledFReal {
                    inner: &self.inner $op &rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
        // owned / ref
        impl<T: Scalar> ::core::ops::$trait<&LabeledFReal<T>> for LabeledFReal<T> {
            type Output = LabeledFReal<T>;
            #[inline]
            fn $method(self, rhs: &LabeledFReal<T>) -> LabeledFReal<T> {
                #[cfg(debug_assertions)]
                crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
                LabeledFReal {
                    inner: self.inner $op &rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
        // ref / owned
        impl<T: Scalar> ::core::ops::$trait<LabeledFReal<T>> for &LabeledFReal<T> {
            type Output = LabeledFReal<T>;
            #[inline]
            fn $method(self, rhs: LabeledFReal<T>) -> LabeledFReal<T> {
                #[cfg(debug_assertions)]
                crate::labeled::forward_tape::check_gen(self.gen_id, rhs.gen_id);
                LabeledFReal {
                    inner: &self.inner $op rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
        // owned / scalar: a plain `T` has no generation, so there is nothing
        // to check; the result keeps the wrapper's stamp.
        impl<T: Scalar> ::core::ops::$trait<T> for LabeledFReal<T> {
            type Output = LabeledFReal<T>;
            #[inline]
            fn $method(self, rhs: T) -> LabeledFReal<T> {
                LabeledFReal {
                    inner: self.inner $op rhs,
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
        // ref / scalar
        impl<T: Scalar> ::core::ops::$trait<T> for &LabeledFReal<T> {
            type Output = LabeledFReal<T>;
            #[inline]
            fn $method(self, rhs: T) -> LabeledFReal<T> {
                LabeledFReal {
                    inner: &self.inner $op rhs,
                    #[cfg(debug_assertions)]
                    gen_id: self.gen_id,
                }
            }
        }
    };
}

// Instantiate the four arithmetic operators.
__lbl_freal_binop!(Add, add, +);
__lbl_freal_binop!(Sub, sub, -);
__lbl_freal_binop!(Mul, mul, *);
__lbl_freal_binop!(Div, div, /);

// Unary negation for owned and borrowed wrappers. Only one operand is
// involved, so no generation guard runs; the result keeps the operand's
// stamp (debug builds only).
impl<T: Scalar> ::core::ops::Neg for LabeledFReal<T> {
    type Output = LabeledFReal<T>;
    #[inline]
    fn neg(self) -> Self::Output {
        #[cfg(debug_assertions)]
        let gen_id = self.gen_id;
        let negated = -self.inner;
        LabeledFReal {
            inner: negated,
            #[cfg(debug_assertions)]
            gen_id,
        }
    }
}
impl<T: Scalar> ::core::ops::Neg for &LabeledFReal<T> {
    type Output = LabeledFReal<T>;
    #[inline]
    fn neg(self) -> Self::Output {
        #[cfg(debug_assertions)]
        let gen_id = self.gen_id;
        let negated = -&self.inner;
        LabeledFReal {
            inner: negated,
            #[cfg(debug_assertions)]
            gen_id,
        }
    }
}

// ============ Scalar-on-LHS hand-written impls for f64 and f32 ============
// Two concrete types (f64, f32), two variants (owned, ref), four ops =
// 16 impls. Inner FReal<T> already provides `f64 op FReal<f64>` and
// `f32 op FReal<f32>` for owned + ref, so delegation is direct. Each
// result preserves the RHS's generation stamp (debug builds only).

macro_rules! __lblfreal_scalar_lhs {
    ($scalar:ty) => {
        impl ::core::ops::Add<LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn add(self, rhs: LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self + rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
        impl ::core::ops::Add<&LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn add(self, rhs: &LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self + &rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
        impl ::core::ops::Sub<LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn sub(self, rhs: LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self - rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
        impl ::core::ops::Sub<&LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn sub(self, rhs: &LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self - &rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
        impl ::core::ops::Mul<LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn mul(self, rhs: LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self * rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
        impl ::core::ops::Mul<&LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn mul(self, rhs: &LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self * &rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
        impl ::core::ops::Div<LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn div(self, rhs: LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self / rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
        impl ::core::ops::Div<&LabeledFReal<$scalar>> for $scalar {
            type Output = LabeledFReal<$scalar>;
            #[inline]
            fn div(self, rhs: &LabeledFReal<$scalar>) -> LabeledFReal<$scalar> {
                LabeledFReal {
                    inner: self / &rhs.inner,
                    #[cfg(debug_assertions)]
                    gen_id: rhs.gen_id,
                }
            }
        }
    };
}

__lblfreal_scalar_lhs!(f64);
__lblfreal_scalar_lhs!(f32);