jaime 0.1.0

Jaime's Artificial Intelligence and Machine learning Engine is an ergonomic, all-purpose backpropagation engine.
Documentation
use crate::simd_arr::SimdArr;

use super::Dual;

/// Unary math operations shared by the numeric types of this module.
///
/// Every operation comes in two flavours: a consuming one (`op(self) -> Self`) and an in-place
/// one (`op_on_mut(&mut self)`). In this file the trait is implemented for `f32` and for `Dual`.
pub trait ExtendedArithmetic {
    /// Returns the square root of `self`.
    fn sqrt(self) -> Self;
    /// In-place counterpart of `sqrt`.
    fn sqrt_on_mut(&mut self);

    /// Returns `self` with its sign flipped.
    fn neg(self) -> Self;
    /// In-place counterpart of `neg`.
    fn neg_on_mut(&mut self);

    /// Returns `e` raised to the power of `self`.
    fn exp(self) -> Self;
    /// In-place counterpart of `exp`.
    fn exp_on_mut(&mut self);

    /// Returns `self` multiplied by itself.
    fn pow2(self) -> Self;
    /// In-place counterpart of `pow2`.
    fn pow2_on_mut(&mut self);

    /// Returns the absolute value of `self`.
    fn abs(self) -> Self;
    /// In-place counterpart of `abs`.
    fn abs_on_mut(&mut self);

    /// Returns the logistic sigmoid of `self`, i.e. `1 / (1 + exp(-self))`.
    fn sigmoid(self) -> Self;
    /// In-place counterpart of `sigmoid`.
    fn sigmoid_on_mut(&mut self);

    /// Returns the rectified linear unit of `self`, i.e. `self` clamped below at zero.
    fn relu(self) -> Self;
    /// In-place counterpart of `relu`.
    fn relu_on_mut(&mut self);

    /// Adds `x` into `self`.
    fn accumulate(&mut self, x: &Self);
}

/// [`ExtendedArithmetic`] for `Dual`: every operation writes the function value into `real` and
/// rescales `sigma` by the function's derivative evaluated at the value `real` had on entry
/// (chain rule).
///
/// The consuming methods are thin wrappers: they run the matching `*_on_mut` method on the
/// owned value and hand it back.
impl<const P: usize, S: SimdArr<P>> ExtendedArithmetic for Dual<P, S> {
    /// Consuming form of `sqrt_on_mut`.
    fn sqrt(mut self) -> Self {
        self.sqrt_on_mut();
        self
    }

    /// Consuming form of `neg_on_mut`.
    fn neg(mut self) -> Self {
        self.neg_on_mut();
        self
    }

    /// Consuming form of `exp_on_mut`.
    fn exp(mut self) -> Self {
        self.exp_on_mut();
        self
    }

    /// Consuming form of `pow2_on_mut`.
    fn pow2(mut self) -> Self {
        self.pow2_on_mut();
        self
    }

    /// Consuming form of `abs_on_mut`.
    fn abs(mut self) -> Self {
        self.abs_on_mut();
        self
    }

    /// Consuming form of `relu_on_mut`.
    fn relu(mut self) -> Self {
        self.relu_on_mut();
        self
    }

    /// Consuming form of `sigmoid_on_mut`.
    fn sigmoid(mut self) -> Self {
        self.sigmoid_on_mut();
        self
    }

    /// Replaces `real` with its square root and scales `sigma` by `1 / (2 * sqrt(x))`, where `x`
    /// is the value of `real` on entry. Calls `check_nan` afterwards.
    fn sqrt_on_mut(&mut self) {
        // d/dx sqrt(x) = 1 / (2 * sqrt(x)). `root` already *is* sqrt(x), so it must be used
        // directly; taking the square root of the updated `real` again would yield x^(1/4).
        let root = self.real.sqrt();
        self.real = root;
        self.sigma.multiply(1. / (2. * root));
        self.check_nan();
    }

    /// Replaces `real` with `exp(real)` and scales `sigma` by that same value.
    fn exp_on_mut(&mut self) {
        // d/dx exp(x) = exp(x), so the freshly computed value doubles as the derivative.
        // NOTE(review): this is the only operation here that does not call `check_nan` —
        // confirm whether that is intentional.
        self.real = self.real.exp();
        self.sigma.multiply(self.real);
    }

    /// Flips the sign of `real` and of `sigma`. Calls `check_nan` afterwards.
    fn neg_on_mut(&mut self) {
        self.real = -self.real;
        self.sigma.multiply(-1.);
        self.check_nan();
    }

    /// Replaces `real` with its square and scales `sigma` by `2 * x`, where `x` is the value of
    /// `real` on entry. Calls `check_nan` afterwards.
    fn pow2_on_mut(&mut self) {
        // d/dx x^2 = 2x: capture x before squaring so the derivative is taken at the input,
        // not at the already squared value.
        let x = self.real;
        self.real = x * x;
        self.sigma.multiply(x * 2.);
        self.check_nan();
    }

    /// Makes `real` non-negative. When `real` is below zero both `real` and `sigma` are negated
    /// and `check_nan` is called; otherwise nothing changes.
    fn abs_on_mut(&mut self) {
        if self.real < 0. {
            self.real = -self.real;
            self.sigma.neg();
            self.check_nan();
        }
    }

    /// Replaces `real` with its sigmoid `s` and scales `sigma` by `s * (1 - s)`. Calls
    /// `check_nan` afterwards.
    fn sigmoid_on_mut(&mut self) {
        self.real = self.real.sigmoid();

        // The sigmoid derivative is expressed through its own output: s' = s * (1 - s).
        self.sigma.multiply(self.real * (1. - self.real));
        self.check_nan();
    }

    /// Clamps negative values: when `real` is below zero, `real` becomes `0` and `sigma` is
    /// reset to `S::zero()`, then `check_nan` is called; otherwise nothing changes.
    fn relu_on_mut(&mut self) {
        if self.real < 0. {
            self.real = 0.;
            self.sigma = S::zero();
            self.check_nan();
        }
    }

    /// Adds `x.real` to `real` and accumulates `x.sigma` into `sigma`. Calls `check_nan`
    /// afterwards.
    fn accumulate(&mut self, x: &Dual<P, S>) {
        self.real += x.real;
        self.sigma.acumulate(&x.sigma);
        self.check_nan();
    }
}

impl ExtendedArithmetic for f32 {
    fn sqrt(self) -> Self {
        self.sqrt()
    }

    fn neg(self) -> Self {
        -self
    }

    fn exp(self) -> Self {
        self.exp()
    }

    fn pow2(self) -> Self {
        self * self
    }

    fn abs(self) -> Self {
        self.abs()
    }

    fn relu(self) -> Self {
        self.max(0.)
    }

    fn sigmoid(mut self) -> Self {
        self.sigmoid_on_mut();
        self
    }

    fn sqrt_on_mut(&mut self) {
        *self = self.sqrt()
    }

    fn neg_on_mut(&mut self) {
        *self = -*self;
    }

    fn exp_on_mut(&mut self) {
        *self = self.exp()
    }

    fn pow2_on_mut(&mut self) {
        *self = *self * *self;
    }

    fn abs_on_mut(&mut self) {
        *self = self.abs();
    }

    fn sigmoid_on_mut(&mut self) {
        *self = 1. / (1. + (-*self).exp());
    }

    fn relu_on_mut(&mut self) {
        *self = self.max(0.);
    }

    fn accumulate(&mut self, x: &f32) {
        *self += x;
    }
}