#![allow(clippy::double_parens)]
mod container_record;
mod functions;
pub mod operations;
pub mod record_operations;
pub mod trace_operations;
pub mod usage;
pub use container_record::*;
use crate::numeric::{Numeric, NumericRef};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Marker trait used as a bound on the element type `T` of [`Trace`] and [`Record`].
///
/// It declares no methods. Its implementations are not part of this section of the
/// file (the tests here rely on one existing for `f64`).
pub trait Primitive {}
/// A value bundled with a derivative that is carried along with it.
///
/// `Trace::constant` starts the derivative at zero and `Trace::variable` starts it
/// at one; `Trace::unary` and `Trace::binary` produce a new `Trace` with both fields
/// updated together.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Trace<T: Primitive> {
/// The value itself.
pub number: T,
/// The derivative tracked alongside `number`.
pub derivative: T,
}
impl<T: Numeric + Primitive> Trace<T> {
    /// Wraps `c` as a constant: the tracked derivative is `T::zero()`.
    pub fn constant(c: T) -> Trace<T> {
        let derivative = T::zero();
        Self {
            number: c,
            derivative,
        }
    }

    /// Wraps `x` as a variable: the tracked derivative is `T::one()`.
    pub fn variable(x: T) -> Trace<T> {
        let derivative = T::one();
        Self {
            number: x,
            derivative,
        }
    }

    /// Evaluates `function` on `x` wrapped as a variable and returns the derivative
    /// field of the resulting `Trace`.
    pub fn derivative(function: impl FnOnce(Trace<T>) -> Trace<T>, x: T) -> T {
        let input = Trace::variable(x);
        let output = function(input);
        output.derivative
    }
}
impl<T: Numeric + Primitive> Trace<T>
where
    for<'a> &'a T: NumericRef<T>,
{
    /// Applies a one-argument function to this trace.
    ///
    /// `fx` computes the new number from the current one. `dfx_dx` is evaluated at
    /// the current number and multiplied into the current derivative to give the
    /// new derivative.
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Trace<T> {
        let number = fx(self.number.clone());
        let slope = dfx_dx(self.number.clone());
        let derivative = self.derivative.clone() * slope;
        Trace { number, derivative }
    }

    /// Applies a two-argument function to this trace and `rhs`.
    ///
    /// `fxy` computes the new number. The new derivative is the sum of each
    /// operand's derivative multiplied by the matching partial (`dfxy_dx` for
    /// `self`, `dfxy_dy` for `rhs`), each evaluated at the two current numbers.
    #[inline]
    pub fn binary(
        &self,
        rhs: &Trace<T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Trace<T> {
        // The closures are invoked in the order fxy, dfxy_dx, dfxy_dy.
        let number = fxy(self.number.clone(), rhs.number.clone());
        let from_lhs =
            self.derivative.clone() * dfxy_dx(self.number.clone(), rhs.number.clone());
        let from_rhs =
            rhs.derivative.clone() * dfxy_dy(self.number.clone(), rhs.number.clone());
        Trace {
            number,
            derivative: from_lhs + from_rhs,
        }
    }
}
use std::cell::RefCell;
/// Position of a recorded operation inside a `WengertList`'s vector of operations.
pub type Index = usize;
/// A growable list of recorded operations that `Record`s refer to by `Index`.
///
/// The operations sit behind a `RefCell` because the methods that append to or
/// clear the list take `&self`.
#[derive(Debug)]
pub struct WengertList<T> {
// Every operation recorded so far; an operation's `Index` is its position here.
operations: RefCell<Vec<Operation<T>>>,
}
/// A view over a list of operations through an already obtained mutable reference.
///
/// `WengertList` builds one of these from its `RefCell` borrow and appends through it.
struct BorrowedWengertList<'a, T> {
// The operation vector being appended to.
operations: &'a mut Vec<Operation<T>>,
}
/// One entry in a `WengertList`: two parent indexes and the derivative associated
/// with each parent.
///
/// Entries with fewer than two real parents fill the unused parent slot(s) with the
/// entry's own index and a zero derivative (see `BorrowedWengertList::append_nullary`
/// and `BorrowedWengertList::append_unary`).
#[derive(Debug)]
struct Operation<T> {
// Index of the first parent entry.
left_parent: Index,
// Index of the second parent entry.
right_parent: Index,
// Factor applied when propagating this entry's derivative to `left_parent`.
left_derivative: T,
// Factor applied when propagating this entry's derivative to `right_parent`.
right_derivative: T,
}
/// Derivatives indexed by position in a `WengertList`, as produced by
/// `Record::derivatives` / `Record::try_derivatives`.
///
/// `Clone` is derived and is available whenever `T: Clone`; this is the same impl
/// that was previously written out by hand.
#[derive(Debug, Clone)]
pub struct Derivatives<T> {
    // One derivative per entry of the list the derivatives were computed from.
    derivatives: Vec<T>,
}
impl<T: Clone + Primitive> Derivatives<T> {
    /// Returns a clone of the derivative stored at `input.index`.
    ///
    /// # Panics
    ///
    /// Panics if `input.index` is out of bounds for the stored derivatives.
    pub fn at(&self, input: &Record<T>) -> T {
        let stored: &T = &self.derivatives[input.index];
        stored.clone()
    }
}
impl<'a, T: Primitive> std::ops::Index<&Record<'a, T>> for Derivatives<T> {
    type Output = T;

    /// Borrows the derivative stored at `input.index`.
    ///
    /// Panics if `input.index` is out of bounds for the stored derivatives.
    fn index(&self, input: &Record<'a, T>) -> &Self::Output {
        let position = input.index;
        &self.derivatives[position]
    }
}
impl<T> std::convert::From<Derivatives<T>> for Vec<T> {
fn from(derivatives: Derivatives<T>) -> Self {
derivatives.derivatives
}
}
impl<T: Clone + Primitive> Clone for Operation<T> {
    /// Copies both parent indexes and clones both derivatives.
    fn clone(&self) -> Self {
        let Operation {
            left_parent,
            right_parent,
            left_derivative,
            right_derivative,
        } = self;
        Operation {
            left_parent: *left_parent,
            right_parent: *right_parent,
            left_derivative: left_derivative.clone(),
            right_derivative: right_derivative.clone(),
        }
    }
}
/// A value together with an optional `WengertList` and the value's `Index` in it.
///
/// `history` is `None` for records made by `Record::constant`, which use index 0.
#[derive(Debug)]
pub struct Record<'a, T: Primitive> {
/// The value itself.
pub number: T,
// The list this record's operations are appended to, if any.
history: Option<&'a WengertList<T>>,
/// Index of this record's entry in `history`.
pub index: Index,
}
impl<'a, T: Numeric + Primitive> Record<'a, T> {
    /// Creates a record with no history; its index is 0.
    pub fn constant(c: T) -> Record<'a, T> {
        Record {
            history: None,
            index: 0,
            number: c,
        }
    }

    /// Creates a record attached to `history`, appending a new nullary entry to
    /// the list and using that entry's index.
    pub fn variable(x: T, history: &'a WengertList<T>) -> Record<'a, T> {
        let index = history.append_nullary();
        Record {
            number: x,
            history: Some(history),
            index,
        }
    }

    /// Builds a record from a `(value, index)` pair and an optional history without
    /// appending anything to the list.
    pub fn from_existing(number: (T, Index), history: Option<&'a WengertList<T>>) -> Record<'a, T> {
        let (value, index) = number;
        Record {
            number: value,
            history,
            index,
        }
    }

    /// Points this record at a freshly appended nullary entry of its history.
    ///
    /// Does nothing when the record has no history.
    pub fn reset(&mut self) {
        if let Some(history) = self.history {
            self.index = history.append_nullary();
        }
    }

    /// By-value form of [`Record::reset`]: resets `x` and hands it back.
    pub fn do_reset(mut x: Record<T>) -> Record<T> {
        x.reset();
        x
    }

    /// Returns the history this record is attached to, if any.
    pub fn history(&self) -> Option<&'a WengertList<T>> {
        let attached = self.history;
        attached
    }
}
impl<'a, T: Numeric + Primitive> Record<'a, T>
where
for<'t> &'t T: NumericRef<T>,
{
#[track_caller]
pub fn derivatives(&self) -> Derivatives<T> {
match self.try_derivatives() {
None => panic!("Record has no WengertList to find derivatives from"),
Some(d) => d,
}
}
pub fn try_derivatives(&self) -> Option<Derivatives<T>> {
let history = self.history?;
let operations = history.operations.borrow();
let mut derivatives = vec![T::zero(); operations.len()];
derivatives[self.index] = T::one();
for i in (0..operations.len()).rev() {
let operation = operations[i].clone();
let derivative = derivatives[i].clone();
derivatives[operation.left_parent] = derivatives[operation.left_parent].clone()
+ derivative.clone() * operation.left_derivative;
derivatives[operation.right_parent] = derivatives[operation.right_parent].clone()
+ derivative * operation.right_derivative;
}
Some(Derivatives { derivatives })
}
}
impl<T: Primitive> WengertList<T> {
    /// Creates a list with no recorded operations.
    pub fn new() -> WengertList<T> {
        let operations = RefCell::new(Vec::new());
        WengertList { operations }
    }
}
impl<T: Primitive> Default for WengertList<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> WengertList<T> {
    /// Removes every recorded operation from the list.
    ///
    /// Indexes handed out before the call no longer refer to any entry afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the operations are currently borrowed (`RefCell::borrow_mut`).
    pub fn clear(&self) {
        let mut operations = self.operations.borrow_mut();
        operations.clear();
    }
}
impl<T: Numeric + Primitive> WengertList<T> {
    /// Creates a record for `x` attached to this list, backed by a newly appended
    /// nullary entry.
    pub fn variable(&self, x: T) -> Record<'_, T> {
        let index = self.append_nullary();
        Record {
            number: x,
            history: Some(self),
            index,
        }
    }

    /// Appends an entry with no real parents and returns its index.
    fn append_nullary(&self) -> Index {
        let mut guard = self.operations.borrow_mut();
        BorrowedWengertList::new(&mut *guard).append_nullary()
    }

    /// Appends `values` nullary entries under a single borrow and returns the index
    /// of the first one (the list length before appending).
    fn append_nullary_repeating(&self, values: usize) -> Index {
        let mut operations = self.operations.borrow_mut();
        let starting_index = operations.len();
        operations.extend((0..values).map(|offset| {
            // Each entry names itself as both parents, with zero derivatives.
            let index = starting_index + offset;
            Operation {
                left_parent: index,
                right_parent: index,
                left_derivative: T::zero(),
                right_derivative: T::zero(),
            }
        }));
        starting_index
    }

    /// Appends an entry with one parent and returns its index.
    fn append_unary(&self, parent: Index, derivative: T) -> Index {
        let mut guard = self.operations.borrow_mut();
        BorrowedWengertList::new(&mut *guard).append_unary(parent, derivative)
    }

    /// Appends an entry with two parents and returns its index.
    fn append_binary(
        &self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        let mut guard = self.operations.borrow_mut();
        let mut view = BorrowedWengertList::new(&mut *guard);
        view.append_binary(left_parent, left_derivative, right_parent, right_derivative)
    }

    /// Mutably borrows the operations once and runs `op` with a
    /// `BorrowedWengertList` over them.
    fn borrow<F>(&self, op: F)
    where
        F: FnOnce(&mut BorrowedWengertList<T>),
    {
        let mut guard = self.operations.borrow_mut();
        let mut view = BorrowedWengertList::new(&mut *guard);
        op(&mut view);
    }
}
impl<T: Clone + Primitive> Clone for WengertList<T> {
fn clone(&self) -> Self {
WengertList {
operations: RefCell::new(self.operations.borrow().clone()),
}
}
}
impl<'a, T: Numeric + Primitive> BorrowedWengertList<'a, T> {
    /// Wraps a mutable reference to a vector of operations.
    fn new(operations: &mut Vec<Operation<T>>) -> BorrowedWengertList<'_, T> {
        BorrowedWengertList { operations }
    }

    /// Appends an entry with no real parents and returns its index.
    ///
    /// Both parent slots hold the entry's own index and both derivatives are zero.
    fn append_nullary(&mut self) -> Index {
        let own_index = self.operations.len();
        self.append_unary(own_index, T::zero())
    }

    /// Appends an entry whose left parent is `parent` with the given derivative and
    /// returns its index.
    ///
    /// The unused right slot holds the entry's own index and a zero derivative.
    fn append_unary(&mut self, parent: Index, derivative: T) -> Index {
        let own_index = self.operations.len();
        self.append_binary(parent, derivative, own_index, T::zero())
    }

    /// Appends an entry with two parents and their derivatives and returns its index.
    fn append_binary(
        &mut self,
        left_parent: Index,
        left_derivative: T,
        right_parent: Index,
        right_derivative: T,
    ) -> Index {
        // The new entry's index is the length of the vector before the push.
        let index = self.operations.len();
        let entry = Operation {
            left_parent,
            right_parent,
            left_derivative,
            right_derivative,
        };
        self.operations.push(entry);
        index
    }
}
impl<'a, T: Numeric + Primitive> Record<'a, T>
where
    for<'t> &'t T: NumericRef<T>,
{
    /// Creates a new record by applying a one-argument function to this one.
    ///
    /// `fx` computes the new number from the current one. When this record has a
    /// history, `dfx_dx` is evaluated at the current number and appended to the
    /// history as a unary entry whose parent is this record; the new record uses
    /// that entry's index. Without a history, `dfx_dx` is not called and the result
    /// is a history-less record with index 0.
    ///
    /// The result is tied to the history's lifetime `'a`, not to the borrow of
    /// `self`, so it may outlive the record it was computed from.
    #[inline]
    pub fn unary(&self, fx: impl Fn(T) -> T, dfx_dx: impl Fn(T) -> T) -> Record<'a, T> {
        let number = fx(self.number.clone());
        let index = match self.history {
            None => 0,
            Some(history) => history.append_unary(self.index, dfx_dx(self.number.clone())),
        };
        Record {
            number,
            history: self.history,
            index,
        }
    }

    /// Creates a new record by applying a two-argument function to this record and
    /// `rhs`.
    ///
    /// `fxy` computes the new number. What is appended depends on which operands
    /// have a history:
    ///
    /// - neither: nothing is appended; the result has no history and index 0;
    /// - only `self`: a unary entry with parent `self` and derivative `dfxy_dx`;
    /// - only `rhs`: a unary entry with parent `rhs` and derivative `dfxy_dy`;
    /// - both: a binary entry with both parents and both partial derivatives,
    ///   appended to `self`'s history.
    ///
    /// The partials are evaluated at the two current numbers, and only the ones
    /// that get recorded are called.
    ///
    /// The result is tied to the history's lifetime `'a`, not to the borrow of
    /// `self`, so it may outlive both operands.
    ///
    /// # Panics
    ///
    /// Panics if `record_operations::same_list(self, rhs)` is false.
    #[inline]
    #[track_caller]
    pub fn binary(
        &self,
        rhs: &Record<'a, T>,
        fxy: impl Fn(T, T) -> T,
        dfxy_dx: impl Fn(T, T) -> T,
        dfxy_dy: impl Fn(T, T) -> T,
    ) -> Record<'a, T> {
        assert!(
            record_operations::same_list(self, rhs),
            "Records must be using the same WengertList"
        );
        let number = fxy(self.number.clone(), rhs.number.clone());
        let (history, index) = match (self.history, rhs.history) {
            (None, None) => (None, 0),
            (Some(history), None) => (
                Some(history),
                history.append_unary(
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                ),
            ),
            (None, Some(history)) => (
                Some(history),
                history.append_unary(
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            ),
            (Some(history), Some(_)) => (
                Some(history),
                history.append_binary(
                    self.index,
                    dfxy_dx(self.number.clone(), rhs.number.clone()),
                    rhs.index,
                    dfxy_dy(self.number.clone(), rhs.number.clone()),
                ),
            ),
        };
        Record {
            number,
            history,
            index,
        }
    }
}
// A record without a history has nothing to differentiate against, so
// `derivatives` must panic. The expected message is pinned so that an unrelated
// panic cannot make this test pass.
#[cfg(test)]
#[test]
#[should_panic(expected = "Record has no WengertList to find derivatives from")]
fn test_record_derivatives_when_no_history() {
    let record = Record::constant(1.0);
    record.derivatives();
}
// Compile-time check that `Trace<f64>` is `Sync`.
#[test]
fn test_sync() {
    fn requires_sync<S: Sync>() {}
    requires_sync::<Trace<f64>>();
}
// Compile-time check that `Trace<f64>` is `Send`.
#[test]
fn test_send() {
    fn requires_send<S: Send>() {}
    requires_send::<Trace<f64>>();
}