use std::rc::Rc;
use crate::gradient_function::GradientFunction;
use crate::ndarray::flags::NdArrayFlags;
use crate::ndarray::NdArray;
use crate::none_backwards::NoneBackwards;
use crate::{Constructors, Tensor, TensorDataType};
impl<'a, T: TensorDataType> Constructors<T> for Tensor<'a, T> {
    /// Builds a tensor from an owned contiguous buffer plus its `shape`.
    ///
    /// The resulting tensor does not require gradients and is flagged as
    /// user-created (see `from_array_and_flags`).
    ///
    /// # Safety
    ///
    /// Delegates to the unsafe `NdArray::from_contiguous_owned_buffer`, so the
    /// caller must uphold that function's contract.
    /// NOTE(review): presumably `data.len()` must match the element count
    /// implied by `shape` — confirm against `NdArray`.
    unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self {
        // Tensors built directly from user data are leaves: no grad, user-created.
        let (requires_grad, user_created) = (false, true);
        Self::from_array_and_flags(
            NdArray::from_contiguous_owned_buffer(shape, data),
            requires_grad,
            user_created,
        )
    }
}
impl<'a, T: TensorDataType> Tensor<'a, T> {
    /// Wraps `array` in an `Rc` and attaches the supplied gradient function.
    ///
    /// `RequiresGrad` is set only when `requires_grad` is true. `UserCreated`
    /// is never set by this constructor.
    ///
    /// # Safety
    ///
    /// NOTE(review): the invariants callers must uphold are not visible here.
    /// Presumably `array` must be a valid `NdArray` and `grad_fn` must be
    /// consistent with it — confirm and document.
    pub(crate) unsafe fn from_raw_parts(array: NdArray<'static, T>,
                                        requires_grad: bool,
                                        grad_fn: GradientFunction<T>) -> Self {
        let flags = {
            let mut bits = NdArrayFlags::empty();
            if requires_grad {
                bits |= NdArrayFlags::RequiresGrad;
            }
            bits
        };

        Self {
            array: Rc::new(array),
            flags,
            grad_fn,
            _marker: Default::default(),
        }
    }

    /// Wraps `array` with no real backward pass (`NoneBackwards::new()`).
    ///
    /// `RequiresGrad` is set when `requires_grad` is true, and `UserCreated` is
    /// set when `user_created` is true.
    ///
    /// # Safety
    ///
    /// Same contract as `from_raw_parts`, to which this delegates.
    pub(crate) unsafe fn from_array_and_flags(array: NdArray<'static, T>,
                                              requires_grad: bool,
                                              user_created: bool) -> Self {
        // Reuse the shared construction path, then add the one extra flag
        // that only this constructor can set.
        let mut tensor = Self::from_raw_parts(array, requires_grad, NoneBackwards::new());
        if user_created {
            tensor.flags |= NdArrayFlags::UserCreated;
        }
        tensor
    }
}