redstone_ml/tensor/
constructors.rs1use std::rc::Rc;
2use crate::gradient_function::GradientFunction;
3use crate::ndarray::flags::NdArrayFlags;
4use crate::ndarray::NdArray;
5use crate::none_backwards::NoneBackwards;
6use crate::{Constructors, Tensor, TensorDataType};
7
8
9impl<'a, T: TensorDataType> Constructors<T> for Tensor<'a, T> {
10 unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self {
11 let array = NdArray::from_contiguous_owned_buffer(shape, data);
12 Self::from_array_and_flags(array, false, true)
13 }
14}
15
16impl<'a, T: TensorDataType> Tensor<'a, T> {
17 pub(crate) unsafe fn from_raw_parts(array: NdArray<'static, T>,
23 requires_grad: bool,
24 grad_fn: GradientFunction<T>) -> Self {
25 let mut flags = NdArrayFlags::empty();
26
27 if requires_grad {
28 flags |= NdArrayFlags::RequiresGrad;
29 }
30
31 Self {
32 array: Rc::new(array),
33 flags,
34 grad_fn,
35
36 _marker: Default::default(),
37 }
38 }
39
40 pub(crate) unsafe fn from_array_and_flags(array: NdArray<'static, T>,
48 requires_grad: bool,
49 user_created: bool) -> Self {
50 let mut flags = NdArrayFlags::empty();
51
52 if requires_grad {
53 flags |= NdArrayFlags::RequiresGrad;
54 }
55
56 if user_created {
57 flags |= NdArrayFlags::UserCreated;
58 }
59
60 Self {
61 array: Rc::new(array),
62 flags,
63 grad_fn: NoneBackwards::new(),
64
65 _marker: Default::default(),
66 }
67 }
68}