use crate::tape::{Tape, TapeStorage};
use std::fmt;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// Active real number for reverse-mode (adjoint) automatic differentiation.
///
/// An `AReal` pairs a primal `value` with the tape `slot` that identifies it in a
/// recording. Values built with [`AReal::new`] carry the sentinel slot
/// `u32::MAX` and behave as constants. They become tape variables only when
/// registered (see [`AReal::register_input`]) or when produced by an operation
/// on at least one recorded operand while a tape is active.
#[derive(Clone)]
pub struct AReal<T: TapeStorage> {
    /// Primal (forward) value.
    value: T,
    /// Tape slot holding this variable's adjoint, or `INVALID_SLOT` when passive.
    pub(crate) slot: u32,
}

/// Sentinel slot marking a value that is not recorded on any tape.
const INVALID_SLOT: u32 = u32::MAX;

impl<T: TapeStorage> AReal<T> {
    /// Creates a passive value that is not registered on any tape.
    pub fn new(value: T) -> Self {
        AReal {
            value,
            slot: INVALID_SLOT,
        }
    }

    /// Returns the primal value.
    #[inline]
    pub fn value(&self) -> T {
        self.value
    }

    /// Overwrites the primal value in place.
    ///
    /// The tape slot is left untouched and nothing is recorded.
    #[inline]
    pub fn set_value(&mut self, v: T) {
        self.value = v;
    }

    /// Returns the raw tape slot (`u32::MAX` for a passive value).
    #[inline]
    pub fn slot(&self) -> u32 {
        self.slot
    }

    /// Returns `true` when this value owns a tape slot and so participates in recording.
    #[inline]
    pub fn should_record(&self) -> bool {
        self.slot != INVALID_SLOT
    }

    /// Returns this value's adjoint as stored in `tape`, or zero for a passive value.
    #[inline]
    pub fn adjoint(&self, tape: &Tape<T>) -> T {
        if self.should_record() {
            tape.derivative(self.slot)
        } else {
            T::zero()
        }
    }

    /// Stores `value` as this variable's adjoint in `tape`; does nothing for a passive value.
    #[inline]
    pub fn set_adjoint(&self, tape: &mut Tape<T>, value: T) {
        if self.should_record() {
            tape.set_derivative(self.slot, value);
        }
    }

    /// Gives every passive entry of `vars` a fresh tape slot.
    ///
    /// Entries that already have a slot are left unchanged.
    pub fn register_input(vars: &mut [AReal<T>], tape: &mut Tape<T>) {
        for v in vars.iter_mut().filter(|v| !v.should_record()) {
            v.slot = tape.register_variable();
        }
    }

    /// Gives every passive entry of `vars` a fresh tape slot, so its adjoint can be
    /// set and read.
    ///
    /// Entries that already have a slot are left unchanged.
    pub fn register_output(vars: &mut [AReal<T>], tape: &mut Tape<T>) {
        for v in vars.iter_mut().filter(|v| !v.should_record()) {
            v.slot = tape.register_variable();
        }
    }
}

/// Builds the result of a two-operand operation and records it on the active tape.
///
/// `lhs_mul` / `rhs_mul` are the partial derivatives of the result with respect to
/// the operands stored at `lhs_slot` / `rhs_slot`.
///
/// Operands without a slot (`INVALID_SLOT`) are constants and are never handed to
/// the tape:
/// - with exactly one recorded operand, the operation is recorded as a unary
///   statement on that operand;
/// - with none, or when no tape is active, the result is passive.
#[inline]
fn record_binary<T: TapeStorage>(
    result_value: T,
    lhs_slot: u32,
    lhs_mul: T,
    rhs_slot: u32,
    rhs_mul: T,
) -> AReal<T> {
    match (lhs_slot != INVALID_SLOT, rhs_slot != INVALID_SLOT) {
        (true, true) => match Tape::<T>::get_active() {
            Some(ptr) => {
                // SAFETY: assumes `Tape::get_active` only returns a pointer to a live,
                // currently active tape, and that no other reference to that tape is
                // held across this short borrow — TODO confirm against `Tape`.
                let tape = unsafe { &mut *ptr };
                let slot = tape.register_variable();
                tape.push_binary(slot, lhs_mul, lhs_slot, rhs_mul, rhs_slot);
                AReal {
                    value: result_value,
                    slot,
                }
            }
            None => AReal::new(result_value),
        },
        // A constant operand contributes no derivative, so only the active side is recorded.
        (true, false) => record_unary(result_value, lhs_slot, lhs_mul),
        (false, true) => record_unary(result_value, rhs_slot, rhs_mul),
        // Constant-only expression: nothing to differentiate.
        (false, false) => AReal::new(result_value),
    }
}

/// Builds the result of a one-operand operation and records it on the active tape.
///
/// `multiplier` is the derivative of the result with respect to the operand at
/// `input_slot`. If the operand is passive (`INVALID_SLOT`) or no tape is active,
/// the result is a passive value and nothing is recorded.
#[inline]
fn record_unary<T: TapeStorage>(result_value: T, input_slot: u32, multiplier: T) -> AReal<T> {
    if input_slot == INVALID_SLOT {
        // Never hand the sentinel to the tape; a function of a constant is a constant.
        return AReal::new(result_value);
    }
    match Tape::<T>::get_active() {
        Some(ptr) => {
            // SAFETY: assumes `Tape::get_active` only returns a pointer to a live,
            // currently active tape, and that no other reference to that tape is
            // held across this short borrow — TODO confirm against `Tape`.
            let tape = unsafe { &mut *ptr };
            let slot = tape.register_variable();
            tape.push_unary(slot, multiplier, input_slot);
            AReal {
                value: result_value,
                slot,
            }
        }
        None => AReal::new(result_value),
    }
}

/// Crate-visible entry point to [`record_unary`], for operations defined outside this module.
pub(crate) fn record_unary_op<T: TapeStorage>(
    result_value: T,
    input_slot: u32,
    multiplier: T,
) -> AReal<T> {
    record_unary(result_value, input_slot, multiplier)
}

/// Crate-visible entry point to [`record_binary`], for operations defined outside this module.
pub(crate) fn record_binary_op<T: TapeStorage>(
    result_value: T,
    a_slot: u32,
    a_mul: T,
    b_slot: u32,
    b_mul: T,
) -> AReal<T> {
    record_binary(result_value, a_slot, a_mul, b_slot, b_mul)
}
/// Wraps a raw value as a passive (unrecorded) `AReal`.
impl<T: TapeStorage> From<T> for AReal<T> {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}

/// Converts an `i32` into a passive `AReal<f64>`; every `i32` is exactly representable.
impl From<i32> for AReal<f64> {
    fn from(value: i32) -> Self {
        Self::new(f64::from(value))
    }
}

/// Converts an `i32` into a passive `AReal<f32>`.
///
/// `f32` cannot represent every `i32`, so large magnitudes are rounded exactly as an
/// `as` cast rounds them.
impl From<i32> for AReal<f32> {
    fn from(value: i32) -> Self {
        Self::new(value as f32)
    }
}
impl<T: TapeStorage> Add for AReal<T> {
type Output = AReal<T>;
#[inline]
fn add(self, rhs: AReal<T>) -> AReal<T> {
record_binary(
self.value + rhs.value,
self.slot,
T::one(),
rhs.slot,
T::one(),
)
}
}
impl<T: TapeStorage> Add for &AReal<T> {
type Output = AReal<T>;
#[inline]
fn add(self, rhs: &AReal<T>) -> AReal<T> {
record_binary(
self.value + rhs.value,
self.slot,
T::one(),
rhs.slot,
T::one(),
)
}
}
impl<T: TapeStorage> Add<&AReal<T>> for AReal<T> {
type Output = AReal<T>;
#[inline]
fn add(self, rhs: &AReal<T>) -> AReal<T> {
record_binary(
self.value + rhs.value,
self.slot,
T::one(),
rhs.slot,
T::one(),
)
}
}
impl<T: TapeStorage> Add<AReal<T>> for &AReal<T> {
type Output = AReal<T>;
#[inline]
fn add(self, rhs: AReal<T>) -> AReal<T> {
record_binary(
self.value + rhs.value,
self.slot,
T::one(),
rhs.slot,
T::one(),
)
}
}
impl<T: TapeStorage> Add<T> for AReal<T> {
type Output = AReal<T>;
#[inline]
fn add(self, rhs: T) -> AReal<T> {
record_unary(self.value + rhs, self.slot, T::one())
}
}
impl<T: TapeStorage> Add<T> for &AReal<T> {
type Output = AReal<T>;
#[inline]
fn add(self, rhs: T) -> AReal<T> {
record_unary(self.value + rhs, self.slot, T::one())
}
}
impl Add<AReal<f64>> for f64 {
type Output = AReal<f64>;
#[inline]
fn add(self, rhs: AReal<f64>) -> AReal<f64> {
record_unary(self + rhs.value, rhs.slot, 1.0)
}
}
impl Add<&AReal<f64>> for f64 {
type Output = AReal<f64>;
#[inline]
fn add(self, rhs: &AReal<f64>) -> AReal<f64> {
record_unary(self + rhs.value, rhs.slot, 1.0)
}
}
impl Add<AReal<f32>> for f32 {
type Output = AReal<f32>;
#[inline]
fn add(self, rhs: AReal<f32>) -> AReal<f32> {
record_unary(self + rhs.value, rhs.slot, 1.0)
}
}
impl Add<&AReal<f32>> for f32 {
type Output = AReal<f32>;
#[inline]
fn add(self, rhs: &AReal<f32>) -> AReal<f32> {
record_unary(self + rhs.value, rhs.slot, 1.0)
}
}
impl<T: TapeStorage> Sub for AReal<T> {
type Output = AReal<T>;
#[inline]
fn sub(self, rhs: AReal<T>) -> AReal<T> {
record_binary(
self.value - rhs.value,
self.slot,
T::one(),
rhs.slot,
-T::one(),
)
}
}
impl<T: TapeStorage> Sub for &AReal<T> {
type Output = AReal<T>;
#[inline]
fn sub(self, rhs: &AReal<T>) -> AReal<T> {
record_binary(
self.value - rhs.value,
self.slot,
T::one(),
rhs.slot,
-T::one(),
)
}
}
impl<T: TapeStorage> Sub<&AReal<T>> for AReal<T> {
type Output = AReal<T>;
#[inline]
fn sub(self, rhs: &AReal<T>) -> AReal<T> {
record_binary(
self.value - rhs.value,
self.slot,
T::one(),
rhs.slot,
-T::one(),
)
}
}
impl<T: TapeStorage> Sub<AReal<T>> for &AReal<T> {
type Output = AReal<T>;
#[inline]
fn sub(self, rhs: AReal<T>) -> AReal<T> {
record_binary(
self.value - rhs.value,
self.slot,
T::one(),
rhs.slot,
-T::one(),
)
}
}
impl<T: TapeStorage> Sub<T> for AReal<T> {
type Output = AReal<T>;
#[inline]
fn sub(self, rhs: T) -> AReal<T> {
record_unary(self.value - rhs, self.slot, T::one())
}
}
impl<T: TapeStorage> Sub<T> for &AReal<T> {
type Output = AReal<T>;
#[inline]
fn sub(self, rhs: T) -> AReal<T> {
record_unary(self.value - rhs, self.slot, T::one())
}
}
impl Sub<AReal<f64>> for f64 {
type Output = AReal<f64>;
#[inline]
fn sub(self, rhs: AReal<f64>) -> AReal<f64> {
record_unary(self - rhs.value, rhs.slot, -1.0)
}
}
impl Sub<&AReal<f64>> for f64 {
type Output = AReal<f64>;
#[inline]
fn sub(self, rhs: &AReal<f64>) -> AReal<f64> {
record_unary(self - rhs.value, rhs.slot, -1.0)
}
}
impl Sub<AReal<f32>> for f32 {
type Output = AReal<f32>;
#[inline]
fn sub(self, rhs: AReal<f32>) -> AReal<f32> {
record_unary(self - rhs.value, rhs.slot, -1.0)
}
}
impl Sub<&AReal<f32>> for f32 {
type Output = AReal<f32>;
#[inline]
fn sub(self, rhs: &AReal<f32>) -> AReal<f32> {
record_unary(self - rhs.value, rhs.slot, -1.0)
}
}
/// Multiplies two `AReal`s; each operand's partial is the other operand's value.
#[inline]
fn areal_mul_areal<T: TapeStorage>(lhs: &AReal<T>, rhs: &AReal<T>) -> AReal<T> {
    let product = lhs.value * rhs.value;
    record_binary(product, lhs.slot, rhs.value, rhs.slot, lhs.value)
}

/// Multiplies an `AReal` by a scalar on the right; d(x * c)/dx = c.
#[inline]
fn areal_mul_scalar<T: TapeStorage>(lhs: &AReal<T>, rhs: T) -> AReal<T> {
    record_unary(lhs.value * rhs, lhs.slot, rhs)
}

/// Multiplies a scalar on the left by an `AReal`; d(c * x)/dx = c.
#[inline]
fn scalar_mul_areal<T: TapeStorage>(lhs: T, rhs: &AReal<T>) -> AReal<T> {
    record_unary(lhs * rhs.value, rhs.slot, lhs)
}

impl<T: TapeStorage> Mul for AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn mul(self, rhs: AReal<T>) -> AReal<T> {
        areal_mul_areal(&self, &rhs)
    }
}

impl<T: TapeStorage> Mul for &AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn mul(self, rhs: &AReal<T>) -> AReal<T> {
        areal_mul_areal(self, rhs)
    }
}

impl<T: TapeStorage> Mul<&AReal<T>> for AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn mul(self, rhs: &AReal<T>) -> AReal<T> {
        areal_mul_areal(&self, rhs)
    }
}

impl<T: TapeStorage> Mul<AReal<T>> for &AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn mul(self, rhs: AReal<T>) -> AReal<T> {
        areal_mul_areal(self, &rhs)
    }
}

impl<T: TapeStorage> Mul<T> for AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn mul(self, rhs: T) -> AReal<T> {
        areal_mul_scalar(&self, rhs)
    }
}

impl<T: TapeStorage> Mul<T> for &AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn mul(self, rhs: T) -> AReal<T> {
        areal_mul_scalar(self, rhs)
    }
}

impl Mul<AReal<f64>> for f64 {
    type Output = AReal<f64>;
    #[inline]
    fn mul(self, rhs: AReal<f64>) -> AReal<f64> {
        scalar_mul_areal(self, &rhs)
    }
}

impl Mul<&AReal<f64>> for f64 {
    type Output = AReal<f64>;
    #[inline]
    fn mul(self, rhs: &AReal<f64>) -> AReal<f64> {
        scalar_mul_areal(self, rhs)
    }
}

impl Mul<AReal<f32>> for f32 {
    type Output = AReal<f32>;
    #[inline]
    fn mul(self, rhs: AReal<f32>) -> AReal<f32> {
        scalar_mul_areal(self, &rhs)
    }
}

impl Mul<&AReal<f32>> for f32 {
    type Output = AReal<f32>;
    #[inline]
    fn mul(self, rhs: &AReal<f32>) -> AReal<f32> {
        scalar_mul_areal(self, rhs)
    }
}
// All division variants compute the primal value with a real division (`a / b`).
// Multiplying by a precomputed reciprocal (`a * (1 / b)`) rounds twice and can be
// off by one ulp. For example, `49.0 * (1.0 / 49.0)` is `0.9999999999999999`, so
// `x / x != 1` for some `x`. The reciprocal is still used for the partials, where
// d(a/b)/da = 1/b and d(a/b)/db = -a/b^2 = -(a/b) * (1/b).

/// Divides two `AReal`s, recording both partial derivatives.
#[inline]
fn areal_div_areal<T: TapeStorage>(lhs: &AReal<T>, rhs: &AReal<T>) -> AReal<T> {
    let inv_b = T::one() / rhs.value;
    let quotient = lhs.value / rhs.value;
    record_binary(quotient, lhs.slot, inv_b, rhs.slot, -quotient * inv_b)
}

/// Divides an `AReal` by a plain scalar; d(x / c)/dx = 1/c.
#[inline]
fn areal_div_scalar<T: TapeStorage>(lhs: &AReal<T>, rhs: T) -> AReal<T> {
    let inv = T::one() / rhs;
    record_unary(lhs.value / rhs, lhs.slot, inv)
}

/// Divides a plain scalar by an `AReal`; d(c / x)/dx = -(c/x) * (1/x).
#[inline]
fn scalar_div_areal<T: TapeStorage>(lhs: T, rhs: &AReal<T>) -> AReal<T> {
    let inv = T::one() / rhs.value;
    let quotient = lhs / rhs.value;
    record_unary(quotient, rhs.slot, -quotient * inv)
}

impl<T: TapeStorage> Div for AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn div(self, rhs: AReal<T>) -> AReal<T> {
        areal_div_areal(&self, &rhs)
    }
}

impl<T: TapeStorage> Div for &AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn div(self, rhs: &AReal<T>) -> AReal<T> {
        areal_div_areal(self, rhs)
    }
}

impl<T: TapeStorage> Div<&AReal<T>> for AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn div(self, rhs: &AReal<T>) -> AReal<T> {
        areal_div_areal(&self, rhs)
    }
}

impl<T: TapeStorage> Div<AReal<T>> for &AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn div(self, rhs: AReal<T>) -> AReal<T> {
        areal_div_areal(self, &rhs)
    }
}

impl<T: TapeStorage> Div<T> for AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn div(self, rhs: T) -> AReal<T> {
        areal_div_scalar(&self, rhs)
    }
}

impl<T: TapeStorage> Div<T> for &AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn div(self, rhs: T) -> AReal<T> {
        areal_div_scalar(self, rhs)
    }
}

impl Div<AReal<f64>> for f64 {
    type Output = AReal<f64>;
    #[inline]
    fn div(self, rhs: AReal<f64>) -> AReal<f64> {
        scalar_div_areal(self, &rhs)
    }
}

impl Div<&AReal<f64>> for f64 {
    type Output = AReal<f64>;
    #[inline]
    fn div(self, rhs: &AReal<f64>) -> AReal<f64> {
        scalar_div_areal(self, rhs)
    }
}

impl Div<AReal<f32>> for f32 {
    type Output = AReal<f32>;
    #[inline]
    fn div(self, rhs: AReal<f32>) -> AReal<f32> {
        scalar_div_areal(self, &rhs)
    }
}

impl Div<&AReal<f32>> for f32 {
    type Output = AReal<f32>;
    #[inline]
    fn div(self, rhs: &AReal<f32>) -> AReal<f32> {
        scalar_div_areal(self, rhs)
    }
}
/// Negation consumes its operand by delegating to the by-reference impl below.
impl<T: TapeStorage> Neg for AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn neg(self) -> AReal<T> {
        -&self
    }
}

/// Negates an `AReal`; d(-x)/dx = -1.
impl<T: TapeStorage> Neg for &AReal<T> {
    type Output = AReal<T>;
    #[inline]
    fn neg(self) -> AReal<T> {
        let negated = -self.value;
        let slope = -T::one();
        record_unary(negated, self.slot, slope)
    }
}
// Compound assignments go through the by-reference binary operators. `&*self op rhs`
// builds the new value from a shared reborrow of `self`, and the result is then
// written back. Each one records the same tape statement as the matching binary
// operator.

impl<T: TapeStorage> AddAssign for AReal<T> {
    fn add_assign(&mut self, rhs: AReal<T>) {
        *self = &*self + rhs;
    }
}

impl<T: TapeStorage> AddAssign<&AReal<T>> for AReal<T> {
    fn add_assign(&mut self, rhs: &AReal<T>) {
        *self = &*self + rhs;
    }
}

impl<T: TapeStorage> AddAssign<T> for AReal<T> {
    fn add_assign(&mut self, rhs: T) {
        *self = &*self + rhs;
    }
}

impl<T: TapeStorage> SubAssign for AReal<T> {
    fn sub_assign(&mut self, rhs: AReal<T>) {
        *self = &*self - rhs;
    }
}

impl<T: TapeStorage> SubAssign<&AReal<T>> for AReal<T> {
    fn sub_assign(&mut self, rhs: &AReal<T>) {
        *self = &*self - rhs;
    }
}

impl<T: TapeStorage> SubAssign<T> for AReal<T> {
    fn sub_assign(&mut self, rhs: T) {
        *self = &*self - rhs;
    }
}

impl<T: TapeStorage> MulAssign for AReal<T> {
    fn mul_assign(&mut self, rhs: AReal<T>) {
        *self = &*self * rhs;
    }
}

impl<T: TapeStorage> MulAssign<&AReal<T>> for AReal<T> {
    fn mul_assign(&mut self, rhs: &AReal<T>) {
        *self = &*self * rhs;
    }
}

impl<T: TapeStorage> MulAssign<T> for AReal<T> {
    fn mul_assign(&mut self, rhs: T) {
        *self = &*self * rhs;
    }
}

impl<T: TapeStorage> DivAssign for AReal<T> {
    fn div_assign(&mut self, rhs: AReal<T>) {
        *self = &*self / rhs;
    }
}

impl<T: TapeStorage> DivAssign<&AReal<T>> for AReal<T> {
    fn div_assign(&mut self, rhs: &AReal<T>) {
        *self = &*self / rhs;
    }
}

impl<T: TapeStorage> DivAssign<T> for AReal<T> {
    fn div_assign(&mut self, rhs: T) {
        *self = &*self / rhs;
    }
}
/// Equality compares primal values only; tape slots are ignored and nothing is recorded.
impl<T: TapeStorage> PartialEq for AReal<T> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.value, &other.value)
    }
}

/// Compares the primal value against a plain scalar.
impl<T: TapeStorage> PartialEq<T> for AReal<T> {
    fn eq(&self, other: &T) -> bool {
        PartialEq::eq(&self.value, other)
    }
}

/// Ordering follows the primal values (so NaN yields `None`); tape slots are ignored.
impl<T: TapeStorage> PartialOrd for AReal<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.value, &other.value)
    }
}

/// Orders the primal value against a plain scalar.
impl<T: TapeStorage> PartialOrd<T> for AReal<T> {
    fn partial_cmp(&self, other: &T) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.value, other)
    }
}
/// Displays the primal value only.
///
/// The caller's formatting options (width, fill, alignment, sign, precision such as
/// `{:>10.3}`) are applied to the value.
impl<T: TapeStorage> fmt::Display for AReal<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward the formatter itself. Using `write!(f, "{}", ..)` would re-format
        // with an empty spec and silently drop the caller's width/precision flags.
        fmt::Display::fmt(&self.value, f)
    }
}

/// Debug output shows the primal value together with the raw tape slot.
impl<T: TapeStorage> fmt::Debug for AReal<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "AReal({}, slot={})", self.value, self.slot)
    }
}
/// The default `AReal` is a passive (unrecorded) zero.
impl<T: TapeStorage> Default for AReal<T> {
    fn default() -> Self {
        let zero = T::zero();
        Self::new(zero)
    }
}