redstone-ml 0.0.0

High-performance Machine Learning, Auto-Differentiation and Tensor Algebra crate for Rust
Documentation
use std::rc::Rc;
use crate::gradient_function::GradientFunction;
use crate::ndarray::flags::NdArrayFlags;
use crate::ndarray::NdArray;
use crate::none_backwards::NoneBackwards;
use crate::{Constructors, Tensor, TensorDataType};


impl<'a, T: TensorDataType> Constructors<T> for Tensor<'a, T> {
    /// Builds a tensor that takes ownership of `data`, interpreted with the given `shape`.
    ///
    /// The resulting tensor does not require gradients and carries the `UserCreated` flag,
    /// since this constructor is the entry point for tensors built from caller-supplied data.
    ///
    /// # Safety
    /// Inherits the contract of `NdArray::from_contiguous_owned_buffer`.
    /// NOTE(review): presumably `data.len()` must match the element count implied by
    /// `shape` — confirm against that function's documentation.
    unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self {
        let requires_grad = false;
        let user_created = true;

        // SAFETY: the caller upholds the buffer/shape contract of this trait method, which is
        // forwarded unchanged; the tensor really is built from data supplied by the user.
        unsafe {
            let owned = NdArray::from_contiguous_owned_buffer(shape, data);
            Self::from_array_and_flags(owned, requires_grad, user_created)
        }
    }
}

impl<'a, T: TensorDataType> Tensor<'a, T> {
    /// Constructs a new tensor from the given array, gradient function, and metadata.
    ///
    /// The tensor is never marked as user-created here; `from_array_and_flags` adds that
    /// flag when requested.
    ///
    /// # Parameters
    /// - `array`: The underlying data, moved into a freshly allocated `Rc`
    /// - `requires_grad`: If gradients need to be computed for this tensor
    /// - `grad_fn`: The gradient function used on the backwards pass
    ///
    /// # Safety
    /// NOTE(review): the exact contract is not visible in this file. Presumably `grad_fn`
    /// must correctly describe how `array` was produced so the backwards pass is sound —
    /// confirm against the `GradientFunction` implementations before relying on this.
    pub(crate) unsafe fn from_raw_parts(array: NdArray<'static, T>,
                                        requires_grad: bool,
                                        grad_fn: GradientFunction<T>) -> Self {
        let mut flags = NdArrayFlags::empty();

        if requires_grad {
            flags |= NdArrayFlags::RequiresGrad;
        }

        Self {
            array: Rc::new(array),
            flags,
            grad_fn,

            _marker: Default::default(),
        }
    }

    /// Constructs a new tensor from the given array, with a `NoneBackwards` gradient function.
    ///
    /// # Parameters
    /// - `array`: The underlying data, moved into a freshly allocated `Rc`
    /// - `requires_grad`: If gradients need to be computed for this tensor
    /// - `user_created`: Whether to set the `UserCreated` flag on the tensor
    ///
    /// # Safety
    /// - `user_created` must be set only if the Tensor was generated by the user outside this crate
    pub(crate) unsafe fn from_array_and_flags(array: NdArray<'static, T>,
                                              requires_grad: bool,
                                              user_created: bool) -> Self {
        // Delegate to `from_raw_parts` so the `RequiresGrad` flag handling and the struct
        // construction live in a single place; only the `UserCreated` flag is specific here.
        // SAFETY: the array and `requires_grad` are forwarded unchanged, paired with
        // `NoneBackwards` exactly as this constructor has always done.
        let mut tensor =
            unsafe { Self::from_raw_parts(array, requires_grad, NoneBackwards::new()) };

        if user_created {
            tensor.flags |= NdArrayFlags::UserCreated;
        }

        tensor
    }
}