Struct burn::prelude::Tensor

source ·
pub struct Tensor<B, const D: usize, K = Float>
where B: Backend, K: TensorKind<B>,
{ /* private fields */ }
Expand description

A tensor with a given backend, shape and data type.

Implementations§

source§

impl<const D: usize, B> Tensor<B, D>
where B: AutodiffBackend,

source

pub fn backward(&self) -> <B as AutodiffBackend>::Gradients

Backward pass of the tensor.

source

pub fn grad( &self, grads: &<B as AutodiffBackend>::Gradients ) -> Option<Tensor<<B as AutodiffBackend>::InnerBackend, D>>

Get the gradients of a tensor, if they exist.

Returns a new reference to the same tensor. Therefore the same grad tensor can be accessed multiple times. If you only need to get the gradients one time, consider using grad_remove for better performance.

source

pub fn grad_remove( &self, grads: &mut <B as AutodiffBackend>::Gradients ) -> Option<Tensor<<B as AutodiffBackend>::InnerBackend, D>>

Remove the grad tensor from the grads struct, returning the result.

source

pub fn grad_replace( &self, grads: &mut <B as AutodiffBackend>::Gradients, grad: Tensor<<B as AutodiffBackend>::InnerBackend, D> )

Replace the grad tensor in the grads struct with the provided gradient.

source§

impl<const D: usize, B, K> Tensor<B, D, K>

source

pub fn inner( self ) -> Tensor<<B as AutodiffBackend>::InnerBackend, D, <K as BasicAutodiffOps<B>>::InnerKind>

Returns the inner tensor without the autodiff information.

source

pub fn from_inner( inner: Tensor<<B as AutodiffBackend>::InnerBackend, D, <K as BasicAutodiffOps<B>>::InnerKind> ) -> Tensor<B, D, K>

Convert a tensor to the autodiff backend.

§Arguments
  • inner - The tensor to convert.
§Returns

The tensor converted to the autodiff backend.

source§

impl<B, const D: usize, K> Tensor<B, D, K>
where B: Backend, K: TensorKind<B>,

source

pub fn new(primitive: <K as TensorKind<B>>::Primitive<D>) -> Tensor<B, D, K>

Constructs a new Tensor.

source§

impl<B, const D: usize, K> Tensor<B, D, K>
where B: Backend, K: BasicOps<B>,

source

pub fn into_primitive(self) -> <K as TensorKind<B>>::Primitive<D>

Converts the tensor into a primitive tensor.

source

pub fn from_primitive( tensor: <K as TensorKind<B>>::Primitive<D> ) -> Tensor<B, D, K>

Converts from a primitive tensor into a tensor.

source

pub fn empty<S>(shape: S, device: &<B as Backend>::Device) -> Tensor<B, D, K>
where S: Into<Shape<D>>,

Create an empty tensor of the given shape.

source

pub fn dims(&self) -> [usize; D]

Returns the dimensions of the current tensor.

Equivalent to tensor.shape().dims.

source

pub fn shape(&self) -> Shape<D>

Returns the shape of the current tensor.

source

pub fn reshape<const D2: usize, S>(self, shape: S) -> Tensor<B, D2, K>
where S: ReshapeArgs<D2>,

Reshape the tensor to have the given shape.

A -1 in the shape is used to infer the remaining dimensions, e.g.: [2, -1] will reshape the tensor with [2, 3, 4] dimensions to [2, 12].

A 0 in the shape instructs to keep the current dimension from the original tensor, e.g.: [2, 0, 4] will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4]. This is useful when reshaping tensors with unknown dimensions and combining with -1 to infer the remaining dimensions, e.g. [0, -1] will reshape the tensor with [1, 3, 4] dimensions to [1, 12].

§Arguments
  • shape: The new shape of the tensor.
§Panics
  • If the shape contains more than one -1.
  • If the shape contains negative values other than -1.
  • If the shape does not match the number of elements of the original shape.
§Example
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
   let device = Default::default();
   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
   // Given a 3D tensor with dimensions (2, 3, 4), reshape it to (2, 12)
   let reshaped_tensor: Tensor::<B, 2> = tensor.reshape([2, -1]);
   // The resulting tensor will have dimensions (2, 12).
   println!("{:?}", reshaped_tensor.shape());
}
source

pub fn transpose(self) -> Tensor<B, D, K>

Transpose the tensor.

§Arguments
  • tensor - The tensor to transpose.
§Returns

The transposed tensor.

source

pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K>

Swaps two dimensions of a tensor.

§Arguments
  • tensor - The tensor to swap the dimensions of.
  • dim1 - The first dimension to swap.
  • dim2 - The second dimension to swap.
§Returns

The tensor with the dimensions swapped.

source

pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K>

Permute the dimensions of the tensor.

§Arguments
  • axes - The new order of the dimensions. The length of the axes must be equal to the number of dimensions of the tensor. The values must be unique and in the range of the number of dimensions. The values can be negative, in which case they are used as an offset from the end.
§Returns

The tensor with the dimensions permuted.

source

pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K>

Reverse the order of elements in the tensor along the given dimensions.

§Arguments
  • axes - The dimensions to reverse. The values must be unique and in the range of the number of dimensions. The values can be negative, in which case they are used as an offset from the end.
§Returns

The tensor with the axes flipped.

source

pub fn flatten<const D2: usize>( self, start_dim: usize, end_dim: usize ) -> Tensor<B, D2, K>

Flatten the tensor along a given range of dimensions.

This function collapses the specified range of dimensions into a single dimension, effectively flattening the tensor in that range.

§Arguments
  • start_dim: The starting dimension of the range to be flattened.
  • end_dim: The ending dimension of the range to be flattened (inclusive).
§Type Parameters
  • D2: The resulting number of dimensions in the flattened tensor.
§Returns

A new Tensor<B, D2, K> instance with the specified range of dimensions flattened.

§Example

use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);

    // Given a 3D tensor with dimensions (2, 3, 4), flatten the dimensions between indices 1 and 2:
    let flattened_tensor: Tensor::<B, 2> = tensor.flatten(1, 2);

    // The resulting tensor will have dimensions (2, 12).
   println!("{:?}", flattened_tensor.shape());
}
source

pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K>

Squeeze the tensor along the given dimension, removing the specified dimension of size one, and effectively reducing the rank of the tensor by one.

§Arguments
  • dim: The dimension to be squeezed.
§Type Parameters
  • D2: The resulting number of dimensions in the squeezed tensor.
§Returns

A new Tensor<B, D2, K> instance with the specified dimension removed.

§Example

use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 3>::ones(Shape::new([2, 1, 4]), &device);

    // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1
    let squeezed_tensor: Tensor::<B, 2> = tensor.squeeze(1);

    // Resulting tensor will have dimensions (2, 4)
    println!("{:?}", squeezed_tensor.shape());
}
source

pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K>

Unsqueeze the current tensor, creating new dimensions of size one to reach the output rank D2.

When D2 is higher than the rank of the current tensor, the new dimensions are inserted at the front, as shown in the example below.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    let tensor = tensor.unsqueeze::<4>();
    println!("{:?}", tensor.shape());
    // Shape { dims: [1, 1, 3, 3] }
}
source

pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K>

Creates a new tensor with a dimension of size one inserted at the specified position.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
    let tensor: Tensor<B, 3> = tensor.unsqueeze_dim(1);
    println!("{:?}", tensor.shape());
    // Shape { dims: [3, 1, 3] }
}
source

pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K>

Creates a new tensor with added dimensions of size one inserted at the specified indices. The indices can be negative, in which case they are counted from the last to the first dimension. The axes can contain duplicates, in which case the number of dimensions inserted at the index is the number of duplicates.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
    let tensor: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
    println!("{:?}", tensor.shape());
    // Shape { dims: [1, 3, 4, 5, 1, 1] }
}
source

pub fn slice<const D2: usize>( self, ranges: [Range<usize>; D2] ) -> Tensor<B, D, K>

Returns a tensor containing the elements selected from the given ranges.

§Panics

If a range exceeds the number of elements on a dimension.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = B::Device::default();
    // Create a tensor with a single dimension of ints between 0 and 11
    let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device);
    // Select elements 0, 1, 2, 3 from the first dimension
    let tensor_slices = tensor.clone().slice([0..4]);
    println!("\nexpecting [0,1,2,3] : {:?}", tensor_slices);
    println!("expecting [4] : {:?}", tensor_slices.dims());

    // Create a Tensor with 3 dimensions
    let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
    // This slice will select the element 0 on the first dimension,
    // elements 0,1,2 of the second dimension and element 1 of third dimension
    let tensor_slices = tensor.slice([0..1, 0..3, 1..2]);
    println!("expecting [1, 3, 1] : {:?}", tensor_slices.dims());

    // Create a tensor of ints from 0 to 11 and reshape it into three dimensions
    let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device);
    let tensor = tensor.reshape([1, 3, 4]);
    println!("\nexpecting [[[0,1,2,3],[4,5,6,7],[8,9,10,11]]] : {:?}", tensor);
    println!("expecting [1, 3, 4] : {:?}", tensor.dims());
    // Select element 0 of first dimension, elements 1,2 of second dimension
    // and element 1 of third dimension
    //
    // This is the equivalent of this pseudo code
    // let mut v = vec![[[]]];
    // v[0][0][0] = tensor[0][1][1];
    // v[0][1][0] = tensor[0][2][1];
    let tensor_slices = tensor.slice([0..1, 1..3, 1..2]);
    println!("\nexpecting [1, 2, 1] : {:?}", tensor_slices.dims());
    println!("expecting [[[5],[9]]] : {:?}", tensor_slices);
}
source

pub fn slice_assign<const D2: usize>( self, ranges: [Range<usize>; D2], values: Tensor<B, D, K> ) -> Tensor<B, D, K>

Returns a copy of the current tensor with the selected elements changed to the new ones at the selected indices.

§Panics
  • If a range exceeds the number of elements on a dimension.
  • If the given values don’t match the given ranges.
§Example
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    let device = B::Device::default();
    let tensor = Tensor::<B, 3>::ones([2, 3, 3], &device);
    let values = Tensor::<B, 3>::zeros([1, 1, 1], &device);
    let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
    println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
}
source

pub fn device(&self) -> <B as Backend>::Device

Returns the device of the current tensor.

source

pub fn to_device(self, device: &<B as Backend>::Device) -> Tensor<B, D, K>

Returns a new tensor on the given device.

source

pub fn into_data(self) -> Data<<K as BasicOps<B>>::Elem, D>

Returns the data of the current tensor.

source

pub fn to_data(&self) -> Data<<K as BasicOps<B>>::Elem, D>

Returns the data of the current tensor without taking ownership.

source

pub fn from_data<T>(data: T, device: &<B as Backend>::Device) -> Tensor<B, D, K>
where T: Into<Data<<K as BasicOps<B>>::Elem, D>>,

Create a tensor from the given data on the given device.

source

pub fn repeat(self, dim: usize, times: usize) -> Tensor<B, D, K>

Repeat the tensor along the given dimension.

§Panics

If the selected dimension has more than one item.

source

pub fn equal(self, other: Tensor<B, D, K>) -> Tensor<B, D, Bool>

Applies element-wise equal comparison and returns a boolean tensor.

§Panics

If the two tensors don’t have the same shape.

source

pub fn not_equal(self, other: Tensor<B, D, K>) -> Tensor<B, D, Bool>

Applies element-wise non-equality comparison and returns a boolean tensor.

§Panics

If the two tensors don’t have the same shape.

source

pub fn cat(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D, K>

Concatenates all tensors into a new one along the given dimension.

§Panics

If the tensors don’t all have the same shape, apart from the concatenation dimension.

source

pub fn stack<const D2: usize>( tensors: Vec<Tensor<B, D, K>>, dim: usize ) -> Tensor<B, D2, K>

Concatenates all tensors into a new one along a new dimension.

§Panics

If the tensors don’t all have the same shape, or if the given dimension is not within the range 0..D2.

source

pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K>

Iterate over slices of the tensor along a given dimension.

§Panics

If the given dimension is greater than or equal to the tensor rank.

§Returns

A tensor iterator.

source

pub fn narrow(self, dim: usize, start: usize, length: usize) -> Tensor<B, D, K>

Returns a new tensor with the given dimension narrowed to the given range.

§Panics
  • If the dimension is greater than or equal to the number of dimensions of the tensor.
  • If the given range exceeds the number of elements on the given dimension.
§Returns

A new tensor with the given dimension narrowed to the given range.

source

pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Tensor<B, D, K>>

Attempts to split the tensor along the given dimension into chunks. May return fewer chunks than requested if the tensor size is not divisible by the number of chunks.

When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size. Otherwise all chunks will be of equal size except for the last one.

§Panics

If the dimension is greater than or equal to the number of dimensions of the tensor.

§Returns

A vector of tensors.

source

pub fn any(self) -> Tensor<B, 1, Bool>

Tests if any element in the tensor evaluates to True.

§Arguments
  • tensor - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
§Returns

A boolean tensor Tensor<B, 1, Bool> containing a single element, True if any element in the input tensor evaluates to True, False otherwise.

source

pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool>

Tests if any element in the tensor evaluates to True along a given dimension dim.

§Arguments
  • tensor - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
  • dim - The axis along which to test.
§Returns

A boolean tensor Tensor<B, D, Bool> with the same size as the input tensor, except in the dim axis where the size is 1. The element in the dim axis is True if any element along this dim in the input evaluates to True, False otherwise.

source

pub fn all(self) -> Tensor<B, 1, Bool>

Tests if all elements in the tensor evaluate to True.

§Arguments
  • tensor - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
§Returns

A boolean tensor Tensor<B, 1, Bool> with a single element, True if all elements in the input tensor evaluate to True, False otherwise.

source

pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool>

Tests if all elements in the tensor evaluate to True along a given dimension dim.

§Arguments
  • tensor - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
  • dim - The axis along which to test.
§Returns

A boolean tensor Tensor<B, D, Bool> with the same size as the input tensor, except in the dim axis where the size is 1. The element in the dim axis is True if all elements along this dim in the input evaluate to True, False otherwise.

source

pub fn into_scalar(self) -> <K as BasicOps<B>>::Elem

Convert the tensor into a scalar.

§Panics

If the tensor doesn’t have exactly one element.

source

pub fn expand<const D2: usize, S>(self, shape: S) -> Tensor<B, D2, K>
where S: BroadcastArgs<D, D2>,

Broadcast the tensor to the given shape.

§Arguments
  • shape - The shape to broadcast the tensor to. Can contain -1 for dimensions that should be inferred. The number of elements in the shape must be greater than or equal to the number of dimensions of the tensor.
§Panics

If the tensor cannot be broadcast to the given shape.

§Returns

A new tensor with the given shape.

source§

impl<B, const D: usize> Tensor<B, D, Bool>
where B: Backend,

source

pub fn from_bool( data: Data<bool, D>, device: &<B as Backend>::Device ) -> Tensor<B, D, Bool>

Create a boolean tensor from data on the given device.

source

pub fn int(self) -> Tensor<B, D, Int>

Convert the bool tensor into an int tensor.

source

pub fn float(self) -> Tensor<B, D>

Convert the bool tensor into a float tensor.

source

pub fn bool_not(self) -> Tensor<B, D, Bool>

Inverts boolean values.

source

pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>>

Compute the indices of the elements that are non-zero.

§Returns

A vector of tensors, one for each dimension of the given tensor, containing the indices of the non-zero elements in that dimension.

source

pub fn argwhere(self) -> Tensor<B, 2, Int>

Compute the indices of the elements that are true, grouped by element.

§Returns

A tensor containing the indices of all non-zero elements of the given tensor. Each row in the result contains the indices of a non-zero element.

source

pub fn triu_mask<S>( shape: S, offset: i64, device: &<B as Backend>::Device ) -> Tensor<B, D, Bool>
where S: Into<Shape<D>>,

Creates a mask for the upper triangle of a matrix, which can be used to fill the specified area with a value.

This function generates a boolean tensor representing the mask of the upper triangle of a matrix.

§Arguments
  • shape: The shape of the matrix.
  • offset: The offset from the diagonal, where 0 means the diagonal, and positive values shift towards the upper triangle.
  • device: The device on which the tensor will be allocated.
§Returns

Returns a boolean tensor where true indicates the elements of the matrix that are part of the upper triangle taking into account the specified offset.

source

pub fn tril_mask<S>( shape: S, offset: i64, device: &<B as Backend>::Device ) -> Tensor<B, D, Bool>
where S: Into<Shape<D>>,

Creates a mask for the lower triangle of a matrix, which can be used to fill the specified area with a value.

This function generates a boolean tensor representing the mask of the lower triangle of a matrix.

§Arguments
  • shape: The shape of the matrix.
  • offset: The offset from the diagonal, where 0 means the diagonal, and negative values shift towards the lower triangle.
  • device: The device on which the tensor will be allocated.
§Returns

Returns a boolean tensor where true indicates the elements of the matrix that are part of the lower triangle taking into account the specified offset.

source

pub fn diag_mask<S>( shape: S, offset: i64, device: &<B as Backend>::Device ) -> Tensor<B, D, Bool>
where S: Into<Shape<D>>,

Creates a mask for the diagonal of a matrix, which can be used to fill the specified area with a value.

This function generates a boolean tensor representing the mask of the diagonal of a matrix.

§Arguments
  • shape: The shape of the matrix.
  • offset: The offset from the diagonal, where 0 means the main diagonal.
  • device: The device on which the tensor will be allocated.
§Returns

Returns a boolean tensor where true indicates the elements of the matrix that are part of the diagonal.

source§

impl<const D: usize, B> Tensor<B, D>
where B: Backend,

source

pub fn inplace<F>(&mut self, func: F)
where F: FnOnce(Tensor<B, D>) -> Tensor<B, D>,

Executes an operation on the tensor and modifies its value.

§Notes

This won’t necessarily reuse the same tensor data/buffer, but it should if there is no other reference pointing to the same tensor.

Wrapping operations with inplace is not an optimization, it’s mainly there if you want to mutate a tensor by using owned operations. A plausible usage would be to update the weights of a mutable model reference.

source

pub fn exp(self) -> Tensor<B, D>

Applies element wise exponential operation.

y = e^x

source

pub fn log(self) -> Tensor<B, D>

Applies element wise natural log operation ln.

y = log(x)

source

pub fn log1p(self) -> Tensor<B, D>

Applies the natural logarithm of one plus the input tensor, element-wise.

y = log(x+1)

source

pub fn erf(self) -> Tensor<B, D>

Applies the error function element wise.

y = erf(x)

source

pub fn recip(self) -> Tensor<B, D>

Applies element wise reciprocal operation.

source

pub fn sqrt(self) -> Tensor<B, D>

Applies element wise square root operation.

source

pub fn cos(self) -> Tensor<B, D>

Applies element wise cosine operation.

source

pub fn sin(self) -> Tensor<B, D>

Applies element wise sine operation.

source

pub fn tanh(self) -> Tensor<B, D>

Applies element wise hyperbolic tangent operation.

source

pub fn from_floats<A>( floats: A, device: &<B as Backend>::Device ) -> Tensor<B, D>
where A: Into<Data<f32, D>>,

Create a tensor from floats (f32) on a given device.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    let device = B::Device::default();
    let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
}
source

pub fn int(self) -> Tensor<B, D, Int>

Returns a new tensor with the same shape and device as the current tensor, with the data cast to integers.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    let device = Default::default();
    let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);
    let int_tensor = float_tensor.int();
}
source

pub fn zeros_like(&self) -> Tensor<B, D>

Returns a new tensor with the same shape and device as the current tensor filled with zeros.

source

pub fn ones_like(&self) -> Tensor<B, D>

Returns a new tensor with the same shape and device as the current tensor filled with ones.

source

pub fn random_like(&self, distribution: Distribution) -> Tensor<B, D>

Returns a new tensor with the same shape and device as the current tensor, filled with random values sampled from the given distribution.

source

pub fn one_hot( index: usize, num_classes: usize, device: &<B as Backend>::Device ) -> Tensor<B, D>

Create a one hot tensor.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn example<B: Backend>() {
    let device = Default::default();
    let one_hot = Tensor::<B, 1>::one_hot(2, 10, &device);
    println!("{}", one_hot.to_data());
    // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
}
source

pub fn matmul(self, other: Tensor<B, D>) -> Tensor<B, D>

Applies the matrix multiplication operation.

C = AB

§Panics

If the two tensors don’t have compatible shapes.

source

pub fn var(self, dim: usize) -> Tensor<B, D>

Calculate the variance along the given dimension.

source

pub fn var_bias(self, dim: usize) -> Tensor<B, D>

Calculate the variance along the given dimension without applying the Bessel’s correction.

source

pub fn var_mean(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D>)

Calculate the variance along the given dimension and also returns the mean.

source

pub fn var_mean_bias(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D>)

Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean.

source

pub fn into_full_precision( self ) -> Tensor<<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target, D>

Returns a tensor with full precision based on the selected backend.

source

pub fn from_full_precision( tensor: Tensor<<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target, D> ) -> Tensor<B, D>

Returns a tensor on the selected backend from a full precision tensor.

source

pub fn detach(self) -> Tensor<B, D>

Detach the current tensor from the autodiff graph.

This function does nothing when autodiff is not enabled. This can be used in batchers or elsewhere to ensure that previous operations are not considered in the autodiff graph.

source

pub fn require_grad(self) -> Tensor<B, D>

Mark the tensor to keep gradients during the backward pass.

This function does nothing when autodiff is not enabled.

source

pub fn is_require_grad(&self) -> bool

Returns true if the tensor requires gradients during the backward pass.

source

pub fn set_require_grad(self, require_grad: bool) -> Tensor<B, D>

Mark the tensor as tracked or untracked depending on the require grad argument. When tracked, the gradients will be available after the backward pass.

This function does nothing when autodiff is not enabled.

source

pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D>

Calculate the covariance matrix between different entries along a given dimension.

§Arguments
  • dim - The dimension along which the covariance is computed.
  • correction_factor - Is usually 1 for samples and 0 for population.
source§

impl<B> Tensor<B, 1, Int>
where B: Backend,

source

pub fn arange( range: Range<i64>, device: &<B as Backend>::Device ) -> Tensor<B, 1, Int>

Returns a new integer tensor on the specified device.

§Arguments
  • range - The range of values to generate.
  • device - The device to create the tensor on.
source

pub fn arange_step( range: Range<i64>, step: usize, device: &<B as Backend>::Device ) -> Tensor<B, 1, Int>

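Returns a new integer tensor on the specified device, containing the values of the given range taken every step elements.

§Arguments
  • range - The range of values to generate.
  • step - The step between each value.
  • device - The device to create the tensor on.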
source§

impl<const D: usize, B> Tensor<B, D, Int>
where B: Backend,

source

pub fn from_ints<A>( ints: A, device: &<B as Backend>::Device ) -> Tensor<B, D, Int>
where A: Into<Data<i32, D>>,

Create a tensor from integers (i32), placing it on a given device.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Int};

fn example<B: Backend>() {
    let device = B::Device::default();
    let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
    let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
}
source

pub fn float(self) -> Tensor<B, D>

Returns a new tensor with the same shape and device as the current tensor, with the data cast to floats.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn example<B: Backend>() {
    let device = Default::default();
    let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
    let float_tensor = int_tensor.float();
}
source§

impl<B, const D: usize, K> Tensor<B, D, K>
where B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

source

pub fn add(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Applies element wise addition operation.

y = x1 + x2, where x1 is self and x2 is other

source

pub fn add_scalar<E>(self, other: E) -> Tensor<B, D, K>

Applies element wise addition operation with a scalar.

y = x + s

source

pub fn sub(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Applies element wise subtraction operation.

y = x1 - x2, where x1 is self and x2 is other

source

pub fn sub_scalar<E>(self, other: E) -> Tensor<B, D, K>

Applies element wise subtraction operation with a scalar.

y = x - s

source

pub fn div(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Applies element wise division operation.

y = x1 / x2, where x1 is self and x2 is other

source

pub fn div_scalar<E>(self, other: E) -> Tensor<B, D, K>

Applies element wise division operation with a scalar.

y = x / s

source

pub fn mul(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Applies element wise multiplication operation.

y = x1 * x2, where x1 is self and x2 is other

source

pub fn mul_scalar<E>(self, other: E) -> Tensor<B, D, K>

Applies element wise multiplication operation with a scalar.

y = x * s

source

pub fn neg(self) -> Tensor<B, D, K>

Switch sign of each element in the tensor.

y = -x

source

pub fn sign(self) -> Tensor<B, D, K>

Returns the signs of the elements of the input tensor.

source

pub fn zeros<S>(shape: S, device: &<B as Backend>::Device) -> Tensor<B, D, K>
where S: Into<Shape<D>>,

Create a tensor of the given shape where each element is zero.

source

pub fn ones<S>(shape: S, device: &<B as Backend>::Device) -> Tensor<B, D, K>
where S: Into<Shape<D>>,

Create a tensor of the given shape where each element is one.

source

pub fn full<S, E>( shape: S, fill_value: E, device: &<B as Backend>::Device ) -> Tensor<B, D, K>
where S: Into<Shape<D>>, E: ElementConversion,

Create a tensor of the given shape where each element is equal to the provided value.

source

pub fn mean(self) -> Tensor<B, 1, K>

Aggregate all elements in the tensor with the mean operation.

source

pub fn sum(self) -> Tensor<B, 1, K>

Aggregate all elements in the tensor with the sum operation.

source

pub fn mean_dim(self, dim: usize) -> Tensor<B, D, K>

Aggregate all elements along the given dimension or axis in the tensor with the mean operation.

source

pub fn sum_dim(self, dim: usize) -> Tensor<B, D, K>

Aggregate all elements along the given dimension or axis in the tensor with the sum operation.

source

pub fn prod(self) -> Tensor<B, 1, K>

Aggregate all elements in the tensor with the product operation.

source

pub fn prod_dim(self, dim: usize) -> Tensor<B, D, K>

Aggregate all elements along the given dimension or axis in the tensor with the product operation.

source

pub fn equal_elem<E>(self, other: E) -> Tensor<B, D, Bool>
where E: Element,

Applies element wise equal comparison and returns a boolean tensor.

source

pub fn not_equal_elem<E>(self, other: E) -> Tensor<B, D, Bool>
where E: Element,

Applies element wise non-equality comparison and returns a boolean tensor.

source

pub fn greater(self, other: Tensor<B, D, K>) -> Tensor<B, D, Bool>

Applies element wise greater comparison and returns a boolean tensor.

§Panics

If the two tensors don’t have the same shape.

source

pub fn greater_equal(self, other: Tensor<B, D, K>) -> Tensor<B, D, Bool>

Applies element wise greater-equal comparison and returns a boolean tensor.

§Panics

If the two tensors don’t have the same shape.

source

pub fn lower(self, other: Tensor<B, D, K>) -> Tensor<B, D, Bool>

Applies element wise lower comparison and returns a boolean tensor.

§Panics

If the two tensors don’t have the same shape.

source

pub fn lower_equal(self, other: Tensor<B, D, K>) -> Tensor<B, D, Bool>

Applies element wise lower-equal comparison and returns a boolean tensor.

§Panics

If the two tensors don’t have the same shape.

source

pub fn greater_elem<E>(self, other: E) -> Tensor<B, D, Bool>

Applies element wise greater comparison and returns a boolean tensor.

source

pub fn greater_equal_elem<E>(self, other: E) -> Tensor<B, D, Bool>

Applies element wise greater-equal comparison and returns a boolean tensor.

source

pub fn lower_elem<E>(self, other: E) -> Tensor<B, D, Bool>

Applies element wise lower comparison and returns a boolean tensor.

source

pub fn lower_equal_elem<E>(self, other: E) -> Tensor<B, D, Bool>

Applies element wise lower-equal comparison and returns a boolean tensor.

source

pub fn mask_where( self, mask: Tensor<B, D, Bool>, value: Tensor<B, D, K> ) -> Tensor<B, D, K>

Update the given tensor with the value tensor where the mask is true.

This is similar to mask_fill, however the value is a tensor instead of a scalar.

source

pub fn mask_fill<E>(self, mask: Tensor<B, D, Bool>, value: E) -> Tensor<B, D, K>

Update the given tensor with the value where the mask is true.

This is similar to mask_where, however the value is a scalar instead of a tensor.

source

pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Tensor<B, D, K>

Gather tensor elements corresponding to the given indices from the specified dim.

Example using a 3D tensor:

output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0
output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1
output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2

§Notes

The index tensor should have the same shape as the original tensor except for the dim specified.

source

pub fn scatter( self, dim: usize, indices: Tensor<B, D, Int>, values: Tensor<B, D, K> ) -> Tensor<B, D, K>

Assign the gathered elements corresponding to the given indices along the specified dimension from the value tensor to the original tensor using sum reduction.

Example using a 3D tensor:

input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0
input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1
input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2

§Notes

The index tensor should have the same shape as the original tensor except for the specified dimension. The value and index tensors should have the same shape.

Other references to the input tensor will not be modified by this operation.

source

pub fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Tensor<B, D, K>

Select the tensor elements along the given dimension corresponding to the given indices.

Example using a 3D tensor:

output[i, j, k] = input[indices[i], j, k]; // dim = 0
output[i, j, k] = input[i, indices[j], k]; // dim = 1
output[i, j, k] = input[i, j, indices[k]]; // dim = 2

source

pub fn select_assign( self, dim: usize, indices: Tensor<B, 1, Int>, values: Tensor<B, D, K> ) -> Tensor<B, D, K>

Assign the selected elements along the given dimension corresponding to the given indices from the value tensor to the original tensor using sum reduction.

Example using a 3D tensor:

input[indices[i], j, k] += values[i, j, k]; // dim = 0
input[i, indices[j], k] += values[i, j, k]; // dim = 1
input[i, j, indices[k]] += values[i, j, k]; // dim = 2

source

pub fn argmax(self, dim: usize) -> Tensor<B, D, Int>

Applies the argmax function along the given dimension and returns an integer tensor.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = B::Device::default();
    let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
    let tensor = tensor.argmax(1);
    println!("{:?}", tensor.shape());
    // Shape { dims: [2, 1, 3] }
}
source

pub fn max(self) -> Tensor<B, 1, K>

Find the maximum value.

source

pub fn max_dim(self, dim: usize) -> Tensor<B, D, K>

Find the maximum value along the given dimension.

source

pub fn max_dim_with_indices( self, dim: usize ) -> (Tensor<B, D, K>, Tensor<B, D, Int>)

Find the maximum value along the given dimension.

Also returns the indices.

source

pub fn max_pair(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Finds the element-wise (pairwise) maximum of this tensor and another tensor.

§Arguments
  • other - Other tensor to find maximum elements with
§Returns

A tensor with the same shape as the input tensors, where each element is the larger of the two corresponding input elements.

source

pub fn argmin(self, dim: usize) -> Tensor<B, D, Int>

Applies the argmin function along the given dimension and returns an integer tensor.

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Tensor, Shape};

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
    let tensor = tensor.argmin(1);
    println!("{:?}", tensor.shape());
    // Shape { dims: [2, 1, 3] }
}
source

pub fn min(self) -> Tensor<B, 1, K>

Find the minimum value.

source

pub fn min_dim(self, dim: usize) -> Tensor<B, D, K>

Find the minimum value along the given dimension.

source

pub fn min_dim_with_indices( self, dim: usize ) -> (Tensor<B, D, K>, Tensor<B, D, Int>)

Find the minimum value along the given dimension.

Also returns the indices.

source

pub fn min_pair(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Finds the element-wise (pairwise) minimum of this tensor and another tensor.

§Arguments
  • other - Other tensor to find minimum elements with
§Returns

A tensor with the same shape as the input tensors, where each element is the smaller of the two corresponding input elements.

source

pub fn clamp<E>(self, min: E, max: E) -> Tensor<B, D, K>

Clamp the tensor between the given min and max values.

§Arguments
  • min - The minimum value.
  • max - The maximum value.
§Returns

A new tensor with the values clamped between the given min and max values.

source

pub fn clamp_min<E>(self, min: E) -> Tensor<B, D, K>

Clamps the tensor from below, so that no element is smaller than the given minimum value.

§Arguments
  • self - The tensor to clamp.
  • min - The minimum value.
§Returns

A new tensor in which every value smaller than min is replaced by min.

source

pub fn clamp_max<E>(self, max: E) -> Tensor<B, D, K>

Clamps the tensor from above, so that no element is greater than the given maximum value.

§Arguments
  • self - The tensor to clamp.
  • max - The maximum value.
§Returns

A new tensor in which every value greater than max is replaced by max.

source

pub fn abs(self) -> Tensor<B, D, K>

Applies the element-wise absolute value operation.

source

pub fn triu(self, diagonal: i64) -> Tensor<B, D, K>

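Returns the upper triangular part of a matrix (2-D tensor), or of each matrix in a batch; all other elements of the result are set to 0. Elements on and above the diagonal selected by `diagonal` are kept. A value of 0 selects the main diagonal, positive values select diagonals above it, and negative values select diagonals below it.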

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn example<B: Backend>() {
   let device = Default::default();
   let tensor = Tensor::<B, 2, Int>::from_ints(
       [
         [1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]
       ],
       &device
   );
   let tensor = tensor.triu(1);
   println!("{}", tensor);
   // Tensor { data: [
   //   [0, 2, 3],
   //   [0, 0, 6],
   //   [0, 0, 0]
   // ], ... }
}
source

pub fn tril(self, diagonal: i64) -> Tensor<B, D, K>

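Returns the lower triangular part of a matrix (2-D tensor), or of each matrix in a batch; all other elements of the result are set to 0. Elements on and below the diagonal selected by `diagonal` are kept. A value of 0 selects the main diagonal, positive values select diagonals above it, and negative values select diagonals below it.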

§Example
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn example<B: Backend>() {
   let device = Default::default();
   let tensor = Tensor::<B, 2, Int>::from_ints(
       [
         [1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]
       ],
       &device
   );

   let tensor = tensor.tril(-1);
   println!("{}", tensor);
   // Tensor { data: [
   //   [0, 0, 0],
   //   [4, 0, 0],
   //   [7, 8, 0]
   // ], ... }
}
source

pub fn powf(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Applies the element-wise power operation with a float tensor as the exponent.

source

pub fn powf_scalar<E>(self, other: E) -> Tensor<B, D, K>

Applies the element-wise power operation with a float scalar as the exponent.

source

pub fn powi(self, other: Tensor<B, D, K>) -> Tensor<B, D, K>

Applies the element-wise power operation with an integer tensor as the exponent.

source

pub fn powi_scalar<E>(self, other: E) -> Tensor<B, D, K>

Applies the element-wise power operation with an integer scalar as the exponent.

source

pub fn is_close( self, other: Tensor<B, D, K>, rtol: Option<f64>, atol: Option<f64> ) -> Tensor<B, D, Bool>

Checks element wise if the tensor is close to another tensor.

The tolerance is defined by the following equation:

abs(a - b) <= (atol + rtol * abs(b))

where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
and `atol` is the absolute tolerance.
§Arguments
  • other - The tensor to compare with.
  • rtol - Optional relative tolerance. Default is 1e-5.
  • atol - Optional absolute tolerance. Default is 1e-8.
§Returns

A boolean tensor with the same shape as the input tensors.

source

pub fn all_close( self, other: Tensor<B, D, K>, rtol: Option<f64>, atol: Option<f64> ) -> bool

Checks if all elements are close to another tensor.

The tolerance is defined by the following equation:


abs(a - b) <= (atol + rtol * abs(b))

where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
and `atol` is the absolute tolerance.

§Arguments
  • other - The tensor to compare with.
  • rtol - Optional relative tolerance. Default is 1e-5.
  • atol - Optional absolute tolerance. Default is 1e-8.
§Returns

A boolean scalar.

§Remarks

This method is only available for non-wasm targets or when the wasm-sync feature is enabled.

source

pub fn bool(self) -> Tensor<B, D, Bool>

Converts the tensor to a boolean tensor by checking if the elements are non-zero.

§Returns

A boolean tensor with the same shape as the input tensor.

source

pub fn random<S>( shape: S, distribution: Distribution, device: &<B as Backend>::Device ) -> Tensor<B, D, K>
where S: Into<Shape<D>>,

Create a random tensor of the given shape on the given device where each element is sampled from the given distribution.

source

pub fn sort(self, dim: usize) -> Tensor<B, D, K>

Sort the elements by value in ascending order along a given dimension.

This sort is unstable (i.e., may reorder equal elements).

source

pub fn sort_descending(self, dim: usize) -> Tensor<B, D, K>

Sort the elements by value in descending order along a given dimension.

This sort is unstable (i.e., may reorder equal elements).

source

pub fn sort_with_indices( self, dim: usize ) -> (Tensor<B, D, K>, Tensor<B, D, Int>)

Sort the elements by value in ascending order along a given dimension. Also returns the indices.

This sort is unstable (i.e., may reorder equal elements).

source

pub fn sort_descending_with_indices( self, dim: usize ) -> (Tensor<B, D, K>, Tensor<B, D, Int>)

Sort the elements by value in descending order along a given dimension. Also returns the indices.

This sort is unstable (i.e., may reorder equal elements).

source

pub fn argsort(self, dim: usize) -> Tensor<B, D, Int>

Returns the indices that sort the elements by value in ascending order along a given dimension.

This sort is unstable (i.e., may reorder equal elements).

source

pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int>

Returns the indices that sort the elements by value in descending order along a given dimension.

This sort is unstable (i.e., may reorder equal elements).

source

pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K>

Returns the k largest elements of the given input tensor along a given dimension.

source

pub fn topk_with_indices( self, k: usize, dim: usize ) -> (Tensor<B, D, K>, Tensor<B, D, Int>)

Returns the k largest elements of the given input tensor along a given dimension. Also returns the indices.

source

pub fn pad( self, padding: (usize, usize, usize, usize), value: <K as BasicOps<B>>::Elem ) -> Tensor<B, D, K>

Pad the tensor with the given value on the last two dimensions.

§Arguments
  • padding - A tuple of four integers representing the padding on the left, right, top, and bottom.
  • value - The value to pad the tensor with.
§Returns

A new tensor with the given padding.

source§

impl<B, K> Tensor<B, 2, K>
where B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

source

pub fn eye(size: usize, device: &<B as Backend>::Device) -> Tensor<B, 2, K>

Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.

§Arguments
  • size - The size of the square matrix.

Trait Implementations§

source§

impl<E, const D: usize, B, K> Add<E> for Tensor<B, D, K>
where E: ElementConversion, B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the + operator.
source§

fn add(self, other: E) -> Tensor<B, D, K>

Performs the + operation. Read more
source§

impl<B, const D: usize, K> Add for Tensor<B, D, K>
where B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the + operator.
source§

fn add(self, rhs: Tensor<B, D, K>) -> Tensor<B, D, K>

Performs the + operation. Read more
source§

impl<const D: usize, B, K> AutodiffModule<B> for Tensor<B, D, K>

§

type InnerModule = Tensor<<B as AutodiffBackend>::InnerBackend, D, <K as BasicAutodiffOps<B>>::InnerKind>

Inner module without auto-differentiation.
source§

fn valid(&self) -> <Tensor<B, D, K> as AutodiffModule<B>>::InnerModule

Get the same module, but on the inner backend without auto-differentiation.
source§

impl<B, const D: usize> BitXor<T> for Tensor<B, D>
where B: Backend,

§

type Output = Tensor<B, D>

The resulting type after applying the ^ operator.
source§

fn bitxor(self, _: T) -> <Tensor<B, D> as BitXor<T>>::Output

Performs the ^ operation. Read more
source§

impl<B, const D: usize, K> Clone for Tensor<B, D, K>
where B: Clone + Backend, K: Clone + TensorKind<B>, <K as TensorKind<B>>::Primitive<D>: Clone,

source§

fn clone(&self) -> Tensor<B, D, K>

Returns a copy of the value. Read more
1.0.0 · source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
source§

impl<B, const D: usize, K> Debug for Tensor<B, D, K>
where B: Debug + Backend, K: Debug + TensorKind<B>, <K as TensorKind<B>>::Primitive<D>: Debug,

source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter. Read more
source§

impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
where B: Backend, K: BasicOps<B>, <K as BasicOps<B>>::Elem: Debug + Copy + Deserialize<'de>,

source§

fn deserialize<De>( deserializer: De ) -> Result<Tensor<B, D, K>, <De as Deserializer<'de>>::Error>
where De: Deserializer<'de>,

Deserialize this value from the given Serde deserializer. Read more
source§

impl<B, const D: usize, K> Display for Tensor<B, D, K>
where B: Backend, <B as Backend>::IntElem: Display, K: BasicOps<B>, <K as BasicOps<B>>::Elem: Debug,

Pretty print tensors

source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter. Read more
source§

impl<E, const D: usize, B, K> Div<E> for Tensor<B, D, K>
where E: ElementConversion, B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the / operator.
source§

fn div(self, other: E) -> Tensor<B, D, K>

Performs the / operation. Read more
source§

impl<B, const D: usize, K> Div for Tensor<B, D, K>
where B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the / operator.
source§

fn div(self, rhs: Tensor<B, D, K>) -> Tensor<B, D, K>

Performs the / operation. Read more
source§

impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
where B: Backend, K: BasicOps<B>, T: Into<Data<<K as BasicOps<B>>::Elem, D>>,

source§

fn from(value: T) -> Tensor<B, D, K>

Converts to this type from the input type.
source§

impl<const D: usize, B, K> Module<B> for Tensor<B, D, K>
where B: Backend, K: BasicOps<B>,

§

type Record = ConstantRecord

Type to save and load the module.
source§

fn visit<V>(&self, _visitor: &mut V)
where V: ModuleVisitor<B>,

Visit each tensor parameter in the module with a visitor.
source§

fn map<M>(self, _mapper: &mut M) -> Tensor<B, D, K>
where M: ModuleMapper<B>,

Map each tensor parameter in the module with a mapper.
source§

fn into_record(self) -> <Tensor<B, D, K> as Module<B>>::Record

Convert the module into a record containing the state.
source§

fn load_record( self, _record: <Tensor<B, D, K> as Module<B>>::Record ) -> Tensor<B, D, K>

Load the module state from a record.
source§

fn to_device(self, device: &<B as Backend>::Device) -> Tensor<B, D, K>

Move the module and all of its sub-modules to the given device. Read more
source§

fn fork(self, device: &<B as Backend>::Device) -> Tensor<B, D, K>

Fork the module and all of its sub-modules to the given device. Read more
source§

fn collect_devices( &self, devices: Vec<<B as Backend>::Device> ) -> Vec<<B as Backend>::Device>

Return all the devices found in the underneath module tree added to the given vector without duplicates.
source§

fn devices(&self) -> Vec<<B as Backend>::Device>

Return all the devices found in the underneath module tree without duplicates.
source§

fn no_grad(self) -> Self

Each tensor in the module tree will not require grad. Read more
source§

fn num_params(&self) -> usize

Get the number of parameters the module has, including all of its sub-modules.
source§

fn save_file<FR, PB>( self, file_path: PB, recorder: &FR ) -> Result<(), RecorderError>
where FR: FileRecorder<B>, PB: Into<PathBuf>,

Save the module to a file using the provided file recorder. Read more
source§

fn load_file<FR, PB>( self, file_path: PB, recorder: &FR, device: &<B as Backend>::Device ) -> Result<Self, RecorderError>
where FR: FileRecorder<B>, PB: Into<PathBuf>,

Load the module from a file using the provided file recorder. Read more
source§

impl<E, const D: usize, B, K> Mul<E> for Tensor<B, D, K>
where E: ElementConversion, B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the * operator.
source§

fn mul(self, other: E) -> Tensor<B, D, K>

Performs the * operation. Read more
source§

impl<B, const D: usize, K> Mul for Tensor<B, D, K>
where B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the * operator.
source§

fn mul(self, rhs: Tensor<B, D, K>) -> Tensor<B, D, K>

Performs the * operation. Read more
source§

impl<B, const D: usize, K> Neg for Tensor<B, D, K>
where B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the - operator.
source§

fn neg(self) -> Tensor<B, D, K>

Performs the unary - operation. Read more
source§

impl<B, const D: usize> Parameter for Tensor<B, D>
where B: Backend,

§

type Device = <B as Backend>::Device

The device type to be used.
source§

fn device(&self) -> <Tensor<B, D> as Parameter>::Device

Fetch the device.
source§

fn is_require_grad(&self) -> bool

Fetch the gradient requirement.
source§

fn set_require_grad(self, require_grad: bool) -> Tensor<B, D>

Set the gradient requirement.
source§

impl<B, const D: usize> Parameter for Tensor<B, D, Bool>
where B: Backend,

§

type Device = <B as Backend>::Device

The device type to be used.
source§

fn device(&self) -> <Tensor<B, D, Bool> as Parameter>::Device

Fetch the device.
source§

fn is_require_grad(&self) -> bool

Fetch the gradient requirement.
source§

fn set_require_grad(self, _require_grad: bool) -> Tensor<B, D, Bool>

Set the gradient requirement.
source§

impl<B, const D: usize> Parameter for Tensor<B, D, Int>
where B: Backend,

§

type Device = <B as Backend>::Device

The device type to be used.
source§

fn device(&self) -> <Tensor<B, D, Int> as Parameter>::Device

Fetch the device.
source§

fn is_require_grad(&self) -> bool

Fetch the gradient requirement.
source§

fn set_require_grad(self, _require_grad: bool) -> Tensor<B, D, Int>

Set the gradient requirement.
source§

impl<B, const D: usize> Record<B> for Tensor<B, D>
where B: Backend,

§

type Item<S: PrecisionSettings> = FloatTensorSerde<S>

Type of the item that can be serialized and deserialized.
source§

fn into_item<S>(self) -> <Tensor<B, D> as Record<B>>::Item<S>

Convert the current record into the corresponding item that follows the given settings.
source§

fn from_item<S>( item: <Tensor<B, D> as Record<B>>::Item<S>, device: &<B as Backend>::Device ) -> Tensor<B, D>

Convert the given item into a record.
source§

impl<B, const D: usize> Record<B> for Tensor<B, D, Bool>
where B: Backend,

§

type Item<S: PrecisionSettings> = BoolTensorSerde

Type of the item that can be serialized and deserialized.
source§

fn into_item<S>(self) -> <Tensor<B, D, Bool> as Record<B>>::Item<S>

Convert the current record into the corresponding item that follows the given settings.
source§

fn from_item<S>( item: <Tensor<B, D, Bool> as Record<B>>::Item<S>, device: &<B as Backend>::Device ) -> Tensor<B, D, Bool>

Convert the given item into a record.
source§

impl<B, const D: usize> Record<B> for Tensor<B, D, Int>
where B: Backend,

§

type Item<S: PrecisionSettings> = IntTensorSerde<S>

Type of the item that can be serialized and deserialized.
source§

fn into_item<S>(self) -> <Tensor<B, D, Int> as Record<B>>::Item<S>

Convert the current record into the corresponding item that follows the given settings.
source§

fn from_item<S>( item: <Tensor<B, D, Int> as Record<B>>::Item<S>, device: &<B as Backend>::Device ) -> Tensor<B, D, Int>

Convert the given item into a record.
source§

impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where B: Backend, K: BasicOps<B>, <K as BasicOps<B>>::Elem: Debug + Copy + Serialize,

source§

fn serialize<S>( &self, serializer: S ) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
where S: Serializer,

Serialize this value into the given Serde serializer. Read more
source§

impl<E, const D: usize, B, K> Sub<E> for Tensor<B, D, K>
where E: ElementConversion, B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the - operator.
source§

fn sub(self, other: E) -> Tensor<B, D, K>

Performs the - operation. Read more
source§

impl<B, const D: usize, K> Sub for Tensor<B, D, K>
where B: Backend, K: Numeric<B>, <K as BasicOps<B>>::Elem: Element,

§

type Output = Tensor<B, D, K>

The resulting type after applying the - operator.
source§

fn sub(self, rhs: Tensor<B, D, K>) -> Tensor<B, D, K>

Performs the - operation. Read more

Auto Trait Implementations§

§

impl<B, const D: usize, K> Freeze for Tensor<B, D, K>
where <K as TensorKind<B>>::Primitive<D>: Freeze,

§

impl<B, const D: usize, K> RefUnwindSafe for Tensor<B, D, K>
where <K as TensorKind<B>>::Primitive<D>: RefUnwindSafe,

§

impl<B, const D: usize, K> Send for Tensor<B, D, K>

§

impl<B, const D: usize, K> Sync for Tensor<B, D, K>
where <K as TensorKind<B>>::Primitive<D>: Sync,

§

impl<B, const D: usize, K> Unpin for Tensor<B, D, K>
where <K as TensorKind<B>>::Primitive<D>: Unpin,

§

impl<B, const D: usize, K> UnwindSafe for Tensor<B, D, K>
where <K as TensorKind<B>>::Primitive<D>: UnwindSafe,

Blanket Implementations§

source§

impl<T> Any for T
where T: 'static + ?Sized,

source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
source§

impl<T> Borrow<T> for T
where T: ?Sized,

source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
source§

impl<T> Downcast<T> for T

source§

fn downcast(&self) -> &T

source§

impl<T> From<T> for T

source§

fn from(t: T) -> T

Returns the argument unchanged.

source§

impl<T> Instrument for T

source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
source§

impl<T, U> Into<U> for T
where U: From<T>,

source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

source§

impl<T> Pointable for T

source§

const ALIGN: usize = _

The alignment of pointer.
§

type Init = T

The type for initializers.
source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
source§

impl<T> Same for T

§

type Output = T

Should always be Self
source§

impl<T> ToOwned for T
where T: Clone,

§

type Owned = T

The resulting type after obtaining ownership.
source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
source§

impl<T> ToString for T
where T: Display + ?Sized,

source§

default fn to_string(&self) -> String

Converts the given value to a String. Read more
source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

§

type Error = Infallible

The type returned in the event of a conversion error.
source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
source§

impl<T> Upcast<T> for T

source§

fn upcast(&self) -> Option<&T>

source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

source§

fn vzip(self) -> V

source§

impl<T> WithSubscriber for T

source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more
source§

impl<T> DeserializeOwned for T
where T: for<'de> Deserialize<'de>,

source§

impl<T> ErasedDestructor for T
where T: 'static,

source§

impl<T> WasmNotSend for T
where T: Send,

source§

impl<T> WasmNotSync for T
where T: Sync,