// redstone_ml/tensor/constructors.rs

1use std::rc::Rc;
2use crate::gradient_function::GradientFunction;
3use crate::ndarray::flags::NdArrayFlags;
4use crate::ndarray::NdArray;
5use crate::none_backwards::NoneBackwards;
6use crate::{Constructors, Tensor, TensorDataType};
7
8
9impl<'a, T: TensorDataType> Constructors<T> for Tensor<'a, T> {
10    unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self {
11        let array = NdArray::from_contiguous_owned_buffer(shape, data);
12        Self::from_array_and_flags(array, false, true)
13    }
14}
15
16impl<'a, T: TensorDataType> Tensor<'a, T> {
17    /// Constructs a new tensor from the given array, gradient function, and metadata
18    ///
19    /// # Parameters
20    /// - `requires_grad`: If gradients need to be computed for this tensor
21    /// - `grad_fn`: The gradient function used on the backwards pass
22    pub(crate) unsafe fn from_raw_parts(array: NdArray<'static, T>,
23                                        requires_grad: bool,
24                                        grad_fn: GradientFunction<T>) -> Self {
25        let mut flags = NdArrayFlags::empty();
26
27        if requires_grad {
28            flags |= NdArrayFlags::RequiresGrad;
29        }
30
31        Self {
32            array: Rc::new(array),
33            flags,
34            grad_fn,
35            
36            _marker: Default::default(),
37        }
38    }
39
40    /// Constructs a new tensor from the given array
41    ///
42    /// # Parameters
43    /// - `requires_grad`: If gradients need to be computed for this tensor
44    ///
45    /// # Safety
46    /// - `user_created` must be set only if the Tensor was generated by the user outside this crate
47    pub(crate) unsafe fn from_array_and_flags(array: NdArray<'static, T>,
48                                              requires_grad: bool,
49                                              user_created: bool) -> Self {
50        let mut flags = NdArrayFlags::empty();
51
52        if requires_grad {
53            flags |= NdArrayFlags::RequiresGrad;
54        }
55
56        if user_created {
57            flags |= NdArrayFlags::UserCreated;
58        }
59
60        Self {
61            array: Rc::new(array),
62            flags,
63            grad_fn: NoneBackwards::new(),
64
65            _marker: Default::default(),
66        }
67    }
68}