use super::*;
use num::traits::{Float, Num, NumCast, One, ToPrimitive, Zero};
pub type ASingle = AFloat<f32>;
pub type ADouble = AFloat<f64>;
#[derive(Clone, Copy, Debug)]
pub struct AFloat<S: Float> {
v: S,
dv: S,
ctx: Option<(usize, usize)>,
}
impl<S: Float> AFloat<S> {
pub fn new(v: S, dv: S) -> Self {
Self { v, dv, ctx: None }
}
pub fn value(&self) -> S {
self.v
}
pub fn value_mut(&mut self) -> &mut S {
&mut self.v
}
pub fn dvalue(&self) -> S {
self.dv
}
pub fn dvalue_mut(&mut self) -> &mut S {
&mut self.dv
}
fn from_op(opcode: OpCode, mut arg1: Self, mut arg2: Option<Self>) -> Self {
let v = zero_order_value(opcode, arg1.v, arg2.map(|x| x.v));
let dv = first_order_value(
opcode,
arg1.v,
arg2.map(|x| x.v),
arg1.dv,
arg2.map(|x| x.dv),
);
let mut this = Self::new(v, dv);
let mut cid = None;
if let Some((cid1, _)) = arg1.context() {
cid = Some(cid1);
}
if let Some(arg2) = arg2 {
if let Some((cid2, _)) = arg2.context() {
if let Some(cid) = cid {
assert_eq!(cid2, cid);
} else {
cid = Some(cid2);
}
}
}
if let Some(cid) = cid {
if let Some(mut ctx) = AContext::from_cid(cid) {
if arg1.context().is_none() {
let vid = ctx.record(OpCode::Const, arg1.value(), None, None);
arg1.ctx = Some((cid, vid));
}
if let Some(ref mut arg2) = &mut arg2 {
if arg2.context().is_none() {
let vid = ctx.record(OpCode::Const, arg2.value(), None, None);
arg2.ctx = Some((cid, vid));
}
}
let arg1_vid = Some(arg1.context().unwrap().1);
let arg2_vid = arg2.map(|arg2| arg2.context().unwrap().1);
let vid = ctx.record(opcode, v, arg1_vid, arg2_vid);
this.ctx = Some((cid, vid));
}
}
this
}
pub fn cast<T: Float>(x: AFloat<T>) -> Self {
Self {
v: S::from(x.v).unwrap(),
dv: S::from(x.dv).unwrap(),
ctx: x.ctx,
}
}
pub(crate) fn set_context(&mut self, ctx_id: usize, val_id: usize) {
self.ctx = Some((ctx_id, val_id));
}
pub(crate) fn context(&self) -> Option<(usize, usize)> {
self.ctx
}
}
impl<S: Float> std::cmp::PartialEq<AFloat<S>> for AFloat<S> {
fn eq(&self, other: &Self) -> bool {
self.v.eq(&other.v)
}
}
impl<S: Float> std::cmp::PartialOrd<AFloat<S>> for AFloat<S> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.v.partial_cmp(&other.v)
}
}
impl<S: Float> From<S> for AFloat<S> {
fn from(scalar: S) -> Self {
AFloat::new(scalar, S::zero())
}
}
impl<S: Float> std::ops::Neg for AFloat<S> {
type Output = AFloat<S>;
fn neg(self) -> Self {
(Self::zero() - Self::one()) * self
}
}
macro_rules! binary_op {
($op:ident, $method:ident) => {
impl<S: Float> std::ops::$op<AFloat<S>> for AFloat<S> {
type Output = AFloat<S>;
fn $method(self, rhs: Self) -> Self {
Self::from_op(OpCode::$op, self, Some(rhs))
}
}
impl<S: Float> std::ops::$op<f64> for AFloat<S> {
type Output = AFloat<S>;
fn $method(self, rhs: f64) -> Self {
Self::from_op(OpCode::$op, self, Some(NumCast::from(rhs).unwrap()))
}
}
impl<S: Float> std::ops::$op<AFloat<S>> for f64 {
type Output = AFloat<S>;
fn $method(self, rhs: AFloat<S>) -> AFloat<S> {
AFloat::<S>::from_op(OpCode::$op, NumCast::from(self).unwrap(), Some(rhs))
}
}
};
}
binary_op!(Add, add);
binary_op!(Sub, sub);
binary_op!(Mul, mul);
binary_op!(Div, div);
macro_rules! assign_op {
($op:ident, $method:ident, $optoken:tt) => {
impl<S: Float> std::ops::$op<AFloat<S>> for AFloat<S> {
fn $method(&mut self, rhs: AFloat<S>) {
let result = *self $optoken rhs;
*self = result;
}
}
impl<S: Float> std::ops::$op<f64> for AFloat<S> {
fn $method(&mut self, rhs: f64) {
let result = *self $optoken rhs;
*self = result;
}
}
}
}
assign_op!(AddAssign, add_assign, +);
assign_op!(SubAssign, sub_assign, -);
assign_op!(MulAssign, mul_assign, *);
assign_op!(DivAssign, div_assign, /);
impl<S: Float> std::ops::Rem<AFloat<S>> for AFloat<S> {
type Output = AFloat<S>;
fn rem(self, _rhs: Self) -> Self {
panic!("Operation '%' unsupported on AFloat");
}
}
impl<S: Float> ToPrimitive for AFloat<S> {
fn to_i64(&self) -> Option<i64> {
self.v.to_i64()
}
fn to_u64(&self) -> Option<u64> {
self.v.to_u64()
}
}
impl<S: Float> NumCast for AFloat<S> {
fn from<T>(n: T) -> Option<Self>
where
T: ToPrimitive,
{
S::from(n).map(|n| Self::new(n, S::zero()))
}
}
impl<S: Float> Zero for AFloat<S> {
fn zero() -> Self {
Self::new(S::zero(), S::zero())
}
fn is_zero(&self) -> bool {
self.v.is_zero()
}
}
impl<S: Float> One for AFloat<S> {
fn one() -> Self {
Self::new(S::one(), S::zero())
}
fn is_one(&self) -> bool {
self.v.is_one()
}
}
impl<S: Float> Num for AFloat<S> {
type FromStrRadixErr = S::FromStrRadixErr;
fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
Ok(Self::new(S::from_str_radix(str, radix)?, S::zero()))
}
}
macro_rules! float_constant {
($method:ident) => {
fn $method() -> Self {
Self::new(S::$method(), S::zero())
}
};
}
macro_rules! float_passthrough {
($type:ty, $method:ident) => {
fn $method(self) -> $type {
self.v.$method()
}
}
}
macro_rules! float_unsupported {
($type:ty, $method:ident $(, $arg_type:ty)*) => {
fn $method(self, $(_: $arg_type)*) -> $type {
panic!(concat!("Operation '", stringify!($method), "' unsupported on AFloat"));
}
}
}
macro_rules! float_elemental {
($method:ident, $opcode:ident) => {
fn $method(self) -> Self {
Self::from_op(OpCode::$opcode, self, None)
}
}
}
macro_rules! float_elemental2 {
($method:ident, $opcode:ident) => {
fn $method(self, other: Self) -> Self {
Self::from_op(OpCode::$opcode, self, Some(other))
}
}
}
impl<S: Float> Float for AFloat<S> {
float_constant!(nan);
float_constant!(infinity);
float_constant!(neg_infinity);
float_constant!(neg_zero);
float_constant!(min_value);
float_constant!(min_positive_value);
float_constant!(max_value);
float_passthrough!(bool, is_nan);
float_passthrough!(bool, is_infinite);
float_passthrough!(bool, is_finite);
float_passthrough!(bool, is_normal);
float_passthrough!(bool, is_sign_positive);
float_passthrough!(bool, is_sign_negative);
float_passthrough!(std::num::FpCategory, classify);
float_unsupported!(Self, floor);
float_unsupported!(Self, ceil);
float_unsupported!(Self, round);
float_unsupported!(Self, trunc);
float_unsupported!(Self, fract);
float_unsupported!(Self, signum);
float_unsupported!(Self, exp_m1);
float_unsupported!(Self, ln_1p);
float_unsupported!(Self, sinh);
float_unsupported!(Self, cosh);
float_unsupported!(Self, tanh);
float_unsupported!(Self, asinh);
float_unsupported!(Self, acosh);
float_unsupported!(Self, atanh);
float_unsupported!(Self, atan2, Self);
float_elemental!(abs, Abs);
float_elemental!(exp, Exp);
float_elemental!(ln, Ln);
float_elemental!(sin, Sin);
float_elemental!(cos, Cos);
float_elemental!(tan, Tan);
float_elemental!(asin, Asin);
float_elemental!(acos, Acos);
float_elemental!(atan, Atan);
float_elemental2!(powf, Powf);
fn mul_add(self, a: Self, b: Self) -> Self {
(self * a) + b
}
fn recip(self) -> Self {
Self::one() / self
}
fn powi(self, n: i32) -> Self {
self.powf(<Self as NumCast>::from(n).unwrap())
}
fn sqrt(self) -> Self {
self.powf(<Self as NumCast>::from(0.5).unwrap())
}
fn exp2(self) -> Self {
<Self as NumCast>::from(2.0).unwrap().powf(self)
}
fn log(self, base: Self) -> Self {
self.ln() / base.ln()
}
fn log2(self) -> Self {
self.log(<Self as NumCast>::from(2.0).unwrap())
}
fn log10(self) -> Self {
self.log(<Self as NumCast>::from(10.0).unwrap())
}
fn max(self, other: Self) -> Self {
<Self as NumCast>::from(0.5).unwrap() * (self + other + (self - other).abs())
}
fn min(self, other: Self) -> Self {
<Self as NumCast>::from(0.5).unwrap() * (self + other - (self - other).abs())
}
fn abs_sub(self, other: Self) -> Self {
(self - other).abs()
}
fn cbrt(self) -> Self {
self.powf(Self::one() / <Self as NumCast>::from(3.0).unwrap())
}
fn hypot(self, other: Self) -> Self {
(self.powi(2) + other.powi(2)).sqrt()
}
fn sin_cos(self) -> (Self, Self) {
(self.sin(), self.cos())
}
fn integer_decode(self) -> (u64, i16, i8) {
self.v.integer_decode()
}
}
#[cfg(test)]
mod tests {
use super::*;
const EPS: f64 = 1e-5;
#[allow(clippy::let_and_return)]
fn test_function<S: Float>(x: S) -> S {
let v1 = x + x - x * x / x;
let v2 = (v1 + S::from(2.0).unwrap() - S::from(2.0).unwrap()) * S::from(2.0).unwrap()
/ S::from(2.0).unwrap();
let v3 = -(S::from(2.0).unwrap() - (S::from(2.0).unwrap() + v2));
let v4 = S::from(2.0).unwrap() / (S::from(2.0).unwrap() * v3);
let v5 = S::one() / v4;
v5
}
#[test]
fn afloat_consistency() {
let x = AFloat::<f64>::new(2.0, 1.0);
let y = test_function(x);
assert!((y.value() - x.value()).abs() < std::f64::EPSILON);
assert!((y.dvalue() - 1.0).abs() < std::f64::EPSILON);
}
#[test]
fn afloat_nonlinear_functions() {
macro_rules! test_case {
($func:ident, $dy:expr) => {
let x = AFloat::<f64>::new(0.5, 1.0);
let y = x.$func();
println!("{}, {}, {}", stringify!($func), y.dvalue(), $dy(x.value()));
assert!(y.dvalue() - $dy(x.value()) < EPS)
};
}
test_case!(sin, |x: f64| x.cos());
test_case!(cos, |x: f64| -x.sin());
test_case!(tan, |x: f64| (1.0 / x.cos()).powi(2));
test_case!(exp, |x: f64| x.exp());
test_case!(ln, |x: f64| 1.0 / x);
test_case!(sqrt, |x: f64| 0.5 / x.sqrt());
test_case!(asin, |x: f64| 1.0 / (1.0 - x.powi(2)).sqrt());
test_case!(acos, |x: f64| -1.0 / (1.0 - x.powi(2)).sqrt());
test_case!(atan, |x: f64| 1.0 / (1.0 + x.powi(2)));
}
}