briny_ai 0.5.0

A tiny & efficient AI inference engine
Documentation
//! Flexible tensor backends based on features and targets.

mod optim;

#[cfg(feature = "alloc")]
mod vec;

use core::marker::PhantomData;

#[cfg(not(feature = "dyntensor"))]
pub use tensor_optim::ConstTensorOps;
pub use tensor_optim::TensorOps;

use crate::nn::TensorFloat;

pub use self::optim::{Flatten, StaticShape, Tensor, TensorGrad};
#[cfg(feature = "alloc")]
pub use vec::VecTensor;

/// A trait mainly for converting `Tensor`s to `WithGrad`.
pub trait IntoWithGrad<T>: TensorGrad<T> + Sized {
    /// Wraps the tensor with a zero-initialized gradient.
    fn with_grad(self) -> WithGrad<Self, T> {
        WithGrad::new(self)
    }

    /// Defines a `self` with a gradient of `grad`.
    fn grad_of(self, grad: Self) -> WithGrad<Self, T> {
        let mut w = WithGrad::new(self);
        w.set_grad(grad);
        w
    }
}

/// Blanket implementation: any type that implements `TensorGrad<U>`
/// automatically gains the `IntoWithGrad<U>` conversion methods.
impl<T, U> IntoWithGrad<U> for T
where
    T: TensorGrad<U>,
{
}

/// A container for tracking gradients of values (used in autograd).
///
/// Holds a `value` together with a `grad` of the same type `T`, which must
/// implement `TensorGrad<U>`. `U` is only a type-level marker (stored as
/// `PhantomData`) and defaults to `TensorFloat`.
///
/// The standard traits below are written by hand instead of derived.
/// `#[derive]` would also require `U` to implement each trait (e.g.
/// `U: Eq`, `U: Ord`), even though no `U` value is stored. When `U` is a
/// float type, that would make `Eq`/`Ord` unreachable no matter what `T` is.
/// The manual impls bound only `T` and otherwise match the derived
/// behaviour: they are field-wise, compare `value` before `grad`, and use the
/// same `Debug` layout.
pub struct WithGrad<T: TensorGrad<U>, U = TensorFloat> {
    value: T,
    grad: T,
    _marker: PhantomData<U>,
}

impl<T: TensorGrad<U> + core::fmt::Debug, U> core::fmt::Debug for WithGrad<T, U> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `PhantomData<U>: Debug` holds for every `U`, so the field can be kept
        // to reproduce the derived output.
        f.debug_struct("WithGrad")
            .field("value", &self.value)
            .field("grad", &self.grad)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: TensorGrad<U> + Clone, U> Clone for WithGrad<T, U> {
    fn clone(&self) -> Self {
        Self {
            value: self.value.clone(),
            grad: self.grad.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T: TensorGrad<U> + Copy, U> Copy for WithGrad<T, U> {}

impl<T: TensorGrad<U> + Default, U> Default for WithGrad<T, U> {
    /// Builds both the value and the gradient with `T::default()`.
    fn default() -> Self {
        Self {
            value: T::default(),
            grad: T::default(),
            _marker: PhantomData,
        }
    }
}

impl<T: TensorGrad<U> + PartialEq, U> PartialEq for WithGrad<T, U> {
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value && self.grad == other.grad
    }
}

impl<T: TensorGrad<U> + Eq, U> Eq for WithGrad<T, U> {}

impl<T: TensorGrad<U> + PartialOrd, U> PartialOrd for WithGrad<T, U> {
    /// Lexicographic: `value` first, `grad` breaks ties.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        match self.value.partial_cmp(&other.value) {
            Some(core::cmp::Ordering::Equal) => self.grad.partial_cmp(&other.grad),
            ordering => ordering,
        }
    }
}

impl<T: TensorGrad<U> + Ord, U> Ord for WithGrad<T, U> {
    /// Lexicographic: `value` first, `grad` breaks ties.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.value
            .cmp(&other.value)
            .then_with(|| self.grad.cmp(&other.grad))
    }
}

impl<T: TensorGrad<U>, U> WithGrad<T, U> {
    /// Creates a new `WithGrad`.
    ///
    /// The value is initialized with the data given, and the gradient is
    /// zeroed via `value.zeros_like()`.
    pub fn new(value: T) -> Self {
        let grad = value.zeros_like();
        Self {
            value,
            grad,
            _marker: PhantomData,
        }
    }

    /// Overwrites the gradient, dropping the previous one.
    pub fn set_grad(&mut self, grad: T) {
        self.grad = grad;
    }

    /// Get immutable references to the items.
    ///
    /// Returns `(value, grad)`, in that order; both are immutable.
    pub const fn split(&self) -> (&T, &T) {
        (&self.value, &self.grad)
    }

    /// Get mutable references to the items.
    ///
    /// Returns `(value, grad)`, in that order; both are mutable. The two
    /// borrows are disjoint, so they can be used at the same time.
    pub const fn split_mut(&mut self) -> (&mut T, &mut T) {
        (&mut self.value, &mut self.grad)
    }

    /// Immutably singles out the gradient.
    pub const fn get_grad(&self) -> &T {
        &self.grad
    }

    /// Immutably singles out the value.
    pub const fn get_value(&self) -> &T {
        &self.value
    }

    /// Mutably singles out the gradient.
    pub const fn get_grad_mut(&mut self) -> &mut T {
        &mut self.grad
    }

    /// Mutably singles out the value.
    pub const fn get_value_mut(&mut self) -> &mut T {
        &mut self.value
    }

    /// Discards the gradient and moves out of the value.
    pub fn into_value(self) -> T {
        self.value
    }

    /// Discards the value and moves out of the gradient.
    pub fn into_grad(self) -> T {
        self.grad
    }

    /// Maps the value of `self`, keeping the gradient unchanged.
    ///
    /// `f` is called exactly once, so any `FnOnce` closure is accepted,
    /// including closures that consume their captures.
    #[must_use]
    pub fn map_value<F>(self, f: F) -> Self
    where
        F: FnOnce(T) -> T,
    {
        Self {
            value: f(self.value),
            grad: self.grad,
            _marker: PhantomData,
        }
    }

    /// Maps the gradient of `self`, keeping the value unchanged.
    ///
    /// `f` is called exactly once, so any `FnOnce` closure is accepted,
    /// including closures that consume their captures.
    #[must_use]
    pub fn map_grad<F>(self, f: F) -> Self
    where
        F: FnOnce(T) -> T,
    {
        Self {
            value: self.value,
            grad: f(self.grad),
            _marker: PhantomData,
        }
    }
}