pub struct Tensor<'a, T: TensorDataType> { /* private fields */ }
Implementations§
Source§impl<'a, T: TensorDataType> Tensor<'a, T>
impl<'a, T: TensorDataType> Tensor<'a, T>
Sourcepub fn value(&self) -> T
pub fn value(&self) -> T
Returns the single value contained in a tensor that has exactly one element.
§Panics
If the tensor contains more than one element (i.e., it is neither a scalar nor a single-element tensor).
§Example
let tensor = Tensor::scalar(50.0);
let value = tensor.value();
assert_eq!(value, 50.0);
§Notes
This function is meant only for tensors that are guaranteed to contain exactly one element. For tensors with multiple elements, use methods that access individual elements or slices instead.
Sourcepub fn ndarray(&self) -> &NdArray<'a, T>
pub fn ndarray(&self) -> &NdArray<'a, T>
Returns a reference to the underlying NdArray of the tensor
Sourcepub fn get_ndarray(&self) -> Rc<NdArray<'static, T>>
pub fn get_ndarray(&self) -> Rc<NdArray<'static, T>>
Returns a reference-counted pointer to the underlying NdArray of the tensor
Sourcepub fn into_ndarray(self) -> NdArray<'static, T>
pub fn into_ndarray(self) -> NdArray<'static, T>
Consumes the tensor and converts it into an NdArray.
Source§impl<'a, T: TensorDataType> Tensor<'a, T>
impl<'a, T: TensorDataType> Tensor<'a, T>
Sourcepub fn is_leaf(&self) -> bool
pub fn is_leaf(&self) -> bool
Checks if the tensor is a leaf.
A tensor is a leaf if requires_grad = false, or if requires_grad = true and the tensor
was created explicitly by the user rather than produced by an operation on other tensors.
§Examples
let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
tensor.set_requires_grad(true);
assert!(tensor.is_leaf());
let tensor2 = -tensor;
assert!(!tensor2.is_leaf());
Sourcepub fn requires_grad(&self) -> bool
pub fn requires_grad(&self) -> bool
Returns whether gradients must be computed for this tensor.
A tensor has the requires_grad flag set if the user enabled it explicitly
through the set_requires_grad() method, or if the tensor was produced by
operations on other tensors that have requires_grad set.
§Examples
let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
tensor.set_requires_grad(true);
let tensor2 = -tensor;
assert!(tensor2.requires_grad());
Sourcepub fn set_requires_grad(&mut self, requires_grad: bool) -> &mut Self
pub fn set_requires_grad(&mut self, requires_grad: bool) -> &mut Self
Sets whether gradients must be computed for this tensor.
Sourcepub fn gradient(&'a self) -> Option<NdArray<'a, T>>
pub fn gradient(&'a self) -> Option<NdArray<'a, T>>
Returns the gradient of the differentiated tensor with respect to self.
This method returns a view into the gradient.
§Examples
let mut a = Tensor::scalar(2.0f32);
let b = Tensor::scalar(3.0);
a.set_requires_grad(true);
let c = &a * &b;
c.backward();
// dc/da = b
assert_eq!(a.gradient().unwrap(), b);
Sourcepub fn zero_gradient(&self)
pub fn zero_gradient(&self)
Sets the gradient of this tensor to zero.
§Examples
let mut a = Tensor::scalar(2.0f32);
let b = Tensor::scalar(3.0);
a.set_requires_grad(true);
let c = &a * &b;
c.backward();
a.zero_gradient();
assert_eq!(a.gradient().unwrap(), Tensor::scalar(0.0));
Sourcepub fn backward_with(&self, gradient: impl AsRef<NdArray<'a, T>>)
pub fn backward_with(&self, gradient: impl AsRef<NdArray<'a, T>>)
Computes the gradients of self with respect to its leaf tensors, starting from the supplied gradient of self.
§Parameters
gradient: the gradient of the tensor being differentiated with respect to self.
§Examples
let mut a = Tensor::full(2.0, [3]); // [2, 2, 2]
let b = Tensor::new([3.0, 1.0, -1.0]);
a.set_requires_grad(true);
let c = &a * &b;
c.backward_with(NdArray::new([2.0, 1.0, 1.0]));
// a.gradient() = gradient * dc/da = [2, 1, 1] * b
assert_eq!(a.gradient().unwrap(), Tensor::new([6.0, 1.0, -1.0]));
Sourcepub fn backward(&self)
pub fn backward(&self)
Computes the gradients of self with respect to its leaf tensors.
§Examples
let mut a = Tensor::full(2.0, [3]); // [2, 2, 2]
let b = Tensor::new([3.0, 1.0, -1.0]);
a.set_requires_grad(true);
let c = &a * &b;
c.backward();
// dc/da = b
assert_eq!(a.gradient().unwrap(), Tensor::new([3.0, 1.0, -1.0]));
Source§impl<'a, T: TensorDataType> Tensor<'a, T>
impl<'a, T: TensorDataType> Tensor<'a, T>
Sourcepub fn dot<'b, 'r>(&self, other: impl AsRef<Tensor<'b, T>>) -> Tensor<'r, T>
pub fn dot<'b, 'r>(&self, other: impl AsRef<Tensor<'b, T>>) -> Tensor<'r, T>
Calculates the dot product of two 1D tensors.
§Panics
- Panics if either tensor is not 1D
- Panics if the lengths of the two tensors are not equal
§Examples
let tensor1 = Tensor::new([1.0, 2.0, 3.0]);
let tensor2 = Tensor::new([4.0, 5.0, 6.0]);
let result = tensor1.dot(tensor2);
assert_eq!(result.value(), 32.0); // 1*4 + 2*5 + 3*6 = 32
Sourcepub fn matmul<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T>
pub fn matmul<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T>
Calculates the matrix product of two tensors.
- If both tensors are 1D, then their dot product is returned.
- If both tensors are 2D, then their matrix product is returned.
- If the first tensor is 2D and the second tensor is 1D, then the matrix-vector product is returned.
§Panics
- If the dimensions/shape of the tensors are incompatible
§Example
let a = Tensor::new(vec![
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
]);
let b = Tensor::new(vec![
[7.0, 8.0],
[9.0, 10.0],
[11.0, 12.0],
]);
let result = a.matmul(&b);
assert_eq!(result, Tensor::new([
[58.0, 64.0],
[139.0, 154.0],
]));
Sourcepub fn bmm<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T>
pub fn bmm<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T>
Performs batch matrix multiplication on 3D tensors.
The shape of the resulting tensor will be [batch_size, self.shape()[1], other.shape()[2]],
where batch_size is the shared first dimension of both input tensors.
§Panics
- If either tensor is not 3D
- If the tensors do not have dimensions compatible for batch matrix multiplication.
§Example
let arr1 = Tensor::<f32>::rand([3, 2, 4]); // 3 batches of 2x4 matrices
let arr2 = Tensor::<f32>::rand([3, 4, 5]); // 3 batches of 4x5 matrices
let result = arr1.bmm(&arr2);
assert_eq!(result.shape(), [3, 2, 5]); // result is 3 batches of 2x5 matrices
Trait Implementations§
Source§impl<T: TensorDataType> Add<T> for &Tensor<'_, T>
impl<T: TensorDataType> Add<T> for &Tensor<'_, T>
Source§impl<T: TensorDataType> Add<T> for Tensor<'_, T>
impl<T: TensorDataType> Add<T> for Tensor<'_, T>
Source§impl<'a, T: TensorDataType> Constructors<T> for Tensor<'a, T>
impl<'a, T: TensorDataType> Constructors<T> for Tensor<'a, T>
Source§unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self
unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self
Source§fn new<const D: usize>(data: impl Flatten<T> + Shape + Nested<D>) -> Self
fn new<const D: usize>(data: impl Flatten<T> + Shape + Nested<D>) -> Self
Creates a new tensor from input data such as a vector or array. Read more
Source§fn full(n: T, shape: impl ToVec<usize>) -> Self
fn full(n: T, shape: impl ToVec<usize>) -> Self
Source§fn zeros(shape: impl ToVec<usize>) -> Self
fn zeros(shape: impl ToVec<usize>) -> Self
Source§fn ones(shape: impl ToVec<usize>) -> Self
fn ones(shape: impl ToVec<usize>) -> Self
Source§fn scalar(n: T) -> Self
fn scalar(n: T) -> Self
Source§fn arange(start: T, stop: T) -> Selfwhere
T: NumericDataType,
fn arange(start: T, stop: T) -> Selfwhere
T: NumericDataType,
Source§fn arange_with_step(start: T, stop: T, step: T) -> Selfwhere
T: NumericDataType,
fn arange_with_step(start: T, stop: T, step: T) -> Selfwhere
T: NumericDataType,
Source§fn linspace(start: T, stop: T, num: usize) -> Selfwhere
T: FloatDataType,
fn linspace(start: T, stop: T, num: usize) -> Selfwhere
T: FloatDataType,
Returns a tensor of num evenly spaced values between start and stop
(inclusive). Read more
Source§fn linspace_exclusive(start: T, stop: T, num: usize) -> Self where
T: FloatDataType,
fn linspace_exclusive(start: T, stop: T, num: usize) -> Selfwhere
T: FloatDataType,
Returns a tensor of num evenly spaced values between start and stop
(exclusive). Read more
Source§impl<T: TensorDataType> Debug for Tensor<'_, T>
impl<T: TensorDataType> Debug for Tensor<'_, T>
Source§impl<T: TensorDataType> Div<T> for &Tensor<'_, T>
impl<T: TensorDataType> Div<T> for &Tensor<'_, T>
Source§impl<T: TensorDataType> Div<T> for Tensor<'_, T>
impl<T: TensorDataType> Div<T> for Tensor<'_, T>
Source§impl<T: TensorDataType> Mul<T> for &Tensor<'_, T>
impl<T: TensorDataType> Mul<T> for &Tensor<'_, T>
Source§impl<T: TensorDataType> Mul<T> for Tensor<'_, T>
impl<T: TensorDataType> Mul<T> for Tensor<'_, T>
Source§impl<T: TensorDataType> Neg for &Tensor<'_, T>
impl<T: TensorDataType> Neg for &Tensor<'_, T>
Source§impl<T: TensorDataType> Neg for Tensor<'_, T>
impl<T: TensorDataType> Neg for Tensor<'_, T>
Source§impl<'a, T: TensorDataType> RandomConstructors<T> for Tensor<'a, T>
impl<'a, T: TensorDataType> RandomConstructors<T> for Tensor<'a, T>
Source§fn randn(shape: impl ToVec<usize>) -> Selfwhere
T: FloatDataType,
fn randn(shape: impl ToVec<usize>) -> Selfwhere
T: FloatDataType,
Returns a tensor with the specified shape, with values drawn
from a standard normal distribution (zero mean, unit standard deviation). Read more
Source§fn rand(shape: impl ToVec<usize>) -> Self where
T: FloatDataType,
fn rand(shape: impl ToVec<usize>) -> Selfwhere
T: FloatDataType,
Returns a tensor with the specified shape,
with values uniformly distributed in [0, 1). Read more
Source§impl<'a, T: TensorDataType> Reshape<T> for &'a Tensor<'a, T>
impl<'a, T: TensorDataType> Reshape<T> for &'a Tensor<'a, T>
Source§unsafe fn reshaped_view(
self,
shape: Vec<usize>,
stride: Vec<usize>,
) -> Self::Output
unsafe fn reshaped_view( self, shape: Vec<usize>, stride: Vec<usize>, ) -> Self::Output
Provides a non-owning view of the tensor with the specified shape and stride. The data pointed to by the view is shared with the original tensor.
§Safety
- Ensure that the memory layout described by
shape and stride is valid and owned by the original tensor.
Source§fn view(self) -> Self::Output
fn view(self) -> Self::Output
Provides a non-owning view of the tensor that shares its data with the original tensor.
§Example
let tensor = Tensor::new([1.0, 2.0, 3.0, 4.0]);
let view = (&tensor).view();
assert!(view.is_view());
Source§fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output
fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output
Returns a transposed version of the tensor, swapping the specified axes.
§Panics
- If axis1 or axis2 is out of bounds
§Examples
let array = Tensor::new([[2.0, 3.0, 4.0], [10.0, 20.0, 30.0]]);
let transposed = array.transpose(0, 1);
assert_eq!(transposed, Tensor::new([[2.0, 10.0], [3.0, 20.0], [4.0, 30.0]]));
type Output = Tensor<'a, T>
Source§fn reshape(self, new_shape: impl ToVec<usize>) -> Self::Output
fn reshape(self, new_shape: impl ToVec<usize>) -> Self::Output
Source§fn squeeze(self) -> Self::Output
fn squeeze(self) -> Self::Output
Source§impl<T: TensorDataType> Reshape<T> for Tensor<'_, T>
impl<T: TensorDataType> Reshape<T> for Tensor<'_, T>
Source§unsafe fn reshaped_view(
self,
shape: Vec<usize>,
stride: Vec<usize>,
) -> Self::Output
unsafe fn reshaped_view( self, shape: Vec<usize>, stride: Vec<usize>, ) -> Self::Output
Provides a non-owning view of the tensor with the specified shape and stride. The data pointed to by the view is shared with the original tensor.
§Safety
- Ensure that the memory layout described by
shape and stride is valid and owned by the original tensor.
Source§fn view(self) -> Self::Output
fn view(self) -> Self::Output
Provides a non-owning view of the tensor that shares its data with the original tensor.
§Example
let tensor = Tensor::new([1.0, 2.0, 3.0, 4.0]);
let view = (&tensor).view();
assert!(view.is_view());
Source§fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output
fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output
Returns a transposed version of the tensor, swapping the specified axes.
§Panics
- If axis1 or axis2 is out of bounds
§Examples
let array = Tensor::new([[2.0, 3.0, 4.0], [10.0, 20.0, 30.0]]);
let transposed = array.transpose(0, 1);
assert_eq!(transposed, Tensor::new([[2.0, 10.0], [3.0, 20.0], [4.0, 30.0]]));
type Output = Tensor<'static, T>
Source§fn reshape(self, new_shape: impl ToVec<usize>) -> Self::Output
fn reshape(self, new_shape: impl ToVec<usize>) -> Self::Output
Source§fn squeeze(self) -> Self::Output
fn squeeze(self) -> Self::Output
Source§impl<T: TensorDataType> StridedMemory for &Tensor<'_, T>
impl<T: TensorDataType> StridedMemory for &Tensor<'_, T>
Source§fn shape(&self) -> &[usize]
fn shape(&self) -> &[usize]
Returns the dimensions of the tensor along each axis.
let a = Tensor::new([3.0, 4.0, 5.0]);
assert_eq!(a.shape(), &[3]);
let b = Tensor::new([[3.0], [5.0]]);
assert_eq!(b.shape(), &[2, 1]);
let c = Tensor::scalar(0.0);
assert_eq!(c.shape(), &[]);
Source§fn stride(&self) -> &[usize]
fn stride(&self) -> &[usize]
Returns the stride of the tensor.
The stride gives, for each axis, the number of elements to step over in memory to move one position along that axis.
let a = Tensor::new([[3.0, 4.0], [5.0, 6.0]]);
assert_eq!(a.stride(), &[2, 1]);
Source§fn flags(&self) -> NdArrayFlags
fn flags(&self) -> NdArrayFlags
Returns flags containing information about various tensor metadata.
Source§fn len(&self) -> usize
fn len(&self) -> usize
Source§fn is_contiguous(&self) -> bool
fn is_contiguous(&self) -> bool
Source§fn is_uniformly_strided(&self) -> bool
fn is_uniformly_strided(&self) -> bool
Source§impl<'a, T: TensorDataType> StridedMemory for Tensor<'a, T>
impl<'a, T: TensorDataType> StridedMemory for Tensor<'a, T>
Source§fn shape(&self) -> &[usize]
fn shape(&self) -> &[usize]
Returns the dimensions of the tensor along each axis.
let a = Tensor::new([3.0, 4.0, 5.0]);
assert_eq!(a.shape(), &[3]);
let b = Tensor::new([[3.0], [5.0]]);
assert_eq!(b.shape(), &[2, 1]);
let c = Tensor::scalar(0.0);
assert_eq!(c.shape(), &[]);
Source§fn stride(&self) -> &[usize]
fn stride(&self) -> &[usize]
Returns the stride of the tensor.
The stride gives, for each axis, the number of elements to step over in memory to move one position along that axis.
let a = Tensor::new([[3.0, 4.0], [5.0, 6.0]]);
assert_eq!(a.stride(), &[2, 1]);
Source§fn flags(&self) -> NdArrayFlags
fn flags(&self) -> NdArrayFlags
Returns flags containing information about various tensor metadata.