Struct candle_core::Tensor
source · pub struct Tensor(_);
The core struct for manipulating tensors.
use candle_core::{Tensor, DType, Device};
let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
let c = a.matmul(&b)?;
Tensors are reference counted with Arc so cloning them is cheap.
Implementations§
source§impl Tensor
impl Tensor
sourcepub fn conv1d(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize
) -> Result<Self>
pub fn conv1d( &self, kernel: &Self, padding: usize, stride: usize, dilation: usize, groups: usize ) -> Result<Self>
Applies a 1D convolution over the input tensor.
source§impl Tensor
impl Tensor
sourcepub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self>
pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self>
Reads an npy file and returns the stored multi-dimensional array as a tensor.
sourcepub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>>
pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>>
Reads an npz file and returns the stored multi-dimensional arrays together with their names.
sourcepub fn read_npz_by_name<T: AsRef<Path>>(
path: T,
names: &[&str]
) -> Result<Vec<Self>>
pub fn read_npz_by_name<T: AsRef<Path>>( path: T, names: &[&str] ) -> Result<Vec<Self>>
Reads an npz file and returns the stored multi-dimensional arrays for the specified names.
source§impl Tensor
impl Tensor
sourcepub fn ones<S: Into<Shape>>(
shape: S,
dtype: DType,
device: &Device
) -> Result<Self>
pub fn ones<S: Into<Shape>>( shape: S, dtype: DType, device: &Device ) -> Result<Self>
Creates a new tensor filled with ones.
use candle_core::{Tensor, DType, Device};
let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
// a == b
sourcepub fn ones_like(&self) -> Result<Self>
pub fn ones_like(&self) -> Result<Self>
Creates a new tensor filled with ones with the same shape, dtype, and device as the input tensor.
use candle_core::{Tensor, DType, Device};
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = a.ones_like()?;
// b == a + 1
sourcepub fn zeros<S: Into<Shape>>(
shape: S,
dtype: DType,
device: &Device
) -> Result<Self>
pub fn zeros<S: Into<Shape>>( shape: S, dtype: DType, device: &Device ) -> Result<Self>
Creates a new tensor filled with zeros.
use candle_core::{Tensor, DType, Device};
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
// a == b
sourcepub fn zeros_like(&self) -> Result<Self>
pub fn zeros_like(&self) -> Result<Self>
Creates a new tensor filled with zeros with the same shape, dtype, and device as the input tensor.
use candle_core::{Tensor, DType, Device};
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = a.zeros_like()?;
// b is on CPU f32.
sourcepub fn rand<S: Into<Shape>, T: FloatDType>(
lo: T,
up: T,
s: S,
device: &Device
) -> Result<Self>
pub fn rand<S: Into<Shape>, T: FloatDType>( lo: T, up: T, s: S, device: &Device ) -> Result<Self>
Creates a new tensor initialized with values sampled uniformly between lo and up.
pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self>
pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self>
sourcepub fn randn<S: Into<Shape>, T: FloatDType>(
mean: T,
std: T,
s: S,
device: &Device
) -> Result<Self>
pub fn randn<S: Into<Shape>, T: FloatDType>( mean: T, std: T, s: S, device: &Device ) -> Result<Self>
Creates a new tensor initialized with values sampled from a normal distribution with the
specified mean and standard deviation std.
sourcepub fn new<A: NdArray>(array: A, device: &Device) -> Result<Self>
pub fn new<A: NdArray>(array: A, device: &Device) -> Result<Self>
Creates a new tensor on the specified device using the content and shape of the input.
sourcepub fn from_iter<D: WithDType>(
iter: impl IntoIterator<Item = D>,
device: &Device
) -> Result<Self>
pub fn from_iter<D: WithDType>( iter: impl IntoIterator<Item = D>, device: &Device ) -> Result<Self>
Creates a new 1D tensor from an iterator.
sourcepub fn arange<D: WithDType>(start: D, end: D, device: &Device) -> Result<Self>
pub fn arange<D: WithDType>(start: D, end: D, device: &Device) -> Result<Self>
Creates a new 1D tensor with values from the interval [start, end) taken with a common
difference 1 from start.
sourcepub fn arange_step<D: WithDType>(
start: D,
end: D,
step: D,
device: &Device
) -> Result<Self>
pub fn arange_step<D: WithDType>( start: D, end: D, step: D, device: &Device ) -> Result<Self>
Creates a new 1D tensor with values from the interval [start, end) taken with a common
difference step from start.
sourcepub fn from_vec<S: Into<Shape>, D: WithDType>(
data: Vec<D>,
shape: S,
device: &Device
) -> Result<Self>
pub fn from_vec<S: Into<Shape>, D: WithDType>( data: Vec<D>, shape: S, device: &Device ) -> Result<Self>
Creates a new tensor initialized with values from the input vector. The number of elements in this vector must be the same as the number of elements defined by the shape. If the device is the CPU, no data copy is made.
sourcepub fn from_slice<S: Into<Shape>, D: WithDType>(
array: &[D],
shape: S,
device: &Device
) -> Result<Self>
pub fn from_slice<S: Into<Shape>, D: WithDType>( array: &[D], shape: S, device: &Device ) -> Result<Self>
Creates a new tensor initialized with values from the input slice. The number of elements in this slice must be the same as the number of elements defined by the shape.
pub fn add(&self, rhs: &Self) -> Result<Self>
pub fn mul(&self, rhs: &Self) -> Result<Self>
pub fn sub(&self, rhs: &Self) -> Result<Self>
pub fn div(&self, rhs: &Self) -> Result<Self>
pub fn maximum(&self, rhs: &Self) -> Result<Self>
pub fn minimum(&self, rhs: &Self) -> Result<Self>
pub fn broadcast_add(&self, rhs: &Self) -> Result<Self>
pub fn broadcast_mul(&self, rhs: &Self) -> Result<Self>
pub fn broadcast_sub(&self, rhs: &Self) -> Result<Self>
pub fn broadcast_div(&self, rhs: &Self) -> Result<Self>
pub fn broadcast_maximum(&self, rhs: &Self) -> Result<Self>
pub fn broadcast_minimum(&self, rhs: &Self) -> Result<Self>
pub fn recip(&self) -> Result<Self>
pub fn neg(&self) -> Result<Self>
pub fn exp(&self) -> Result<Self>
pub fn log(&self) -> Result<Self>
pub fn sin(&self) -> Result<Self>
pub fn cos(&self) -> Result<Self>
pub fn abs(&self) -> Result<Self>
pub fn sqr(&self) -> Result<Self>
pub fn sqrt(&self) -> Result<Self>
pub fn gelu(&self) -> Result<Self>
pub fn relu(&self) -> Result<Self>
sourcepub fn to_scalar<S: WithDType>(&self) -> Result<S>
pub fn to_scalar<S: WithDType>(&self) -> Result<S>
Retrieves the single scalar value held in the tensor. If the tensor is not zero-dimensional (i.e. its rank is not 0), an error is returned instead.
sourcepub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor>
pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor>
Repeat this tensor along the specified dimensions.
sourcepub fn affine(&self, mul: f64, add: f64) -> Result<Self>
pub fn affine(&self, mul: f64, add: f64) -> Result<Self>
This operation multiplies the input tensor by mul, then adds add, and returns the result.
The input values mul and add are cast to the appropriate type, so some rounding might
be performed.
use candle_core::{Tensor, Device};
let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
let a = a.affine(4., -2.)?;
assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
sourcepub fn elu(&self, alpha: f64) -> Result<Self>
pub fn elu(&self, alpha: f64) -> Result<Self>
Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
sourcepub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>>
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>>
Splits a tensor into the specified number of chunks along dim; this may return fewer chunks than specified.
sourcepub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self>
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self>
Returns a new tensor that is a narrowed version of the input, the dimension dim
ranges from start to start + len.
sourcepub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self>
pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self>
Returns the sum of the elements of the input tensor over the dimensions listed in sum_dims.
The resulting tensor has a shape that is similar to the shape of the input tensor, except
that the number of elements for each dimension index in sum_dims is 1.
use candle_core::{Tensor, Device};
let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
let s = a.sum_keepdim(0)?;
assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
let s = a.sum_keepdim(1)?;
assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
let s = a.sum_keepdim((0, 1))?;
assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
sourcepub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self>
pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self>
Returns the sum of the elements of the input tensor over the dimensions listed in
sum_dims. Compared to sum_keepdim, these dimensions are squeezed rather than kept.
sourcepub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self>
pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self>
Returns the mean of the elements of the input tensor over the dimensions listed in mean_dims.
The resulting tensor has a shape that is similar to the shape of the input tensor, except
that the number of elements for each dimension index in mean_dims is 1.
use candle_core::{Tensor, Device};
let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
let s = a.mean_keepdim(0)?;
assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
let s = a.mean_keepdim(1)?;
assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
let s = a.mean_keepdim((0, 1))?;
assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
sourcepub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self>
pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self>
Returns the mean of the elements of the input tensor over the dimensions listed in
mean_dims. Compared to mean_keepdim, these dimensions are squeezed rather than kept.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self>
pub fn max<D: Dim>(&self, dim: D) -> Result<Self>
pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self>
pub fn min<D: Dim>(&self, dim: D) -> Result<Self>
pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self>
pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self>
pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self>
pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self>
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self>
pub fn eq(&self, rhs: &Self) -> Result<Self>
pub fn ne(&self, rhs: &Self) -> Result<Self>
pub fn lt(&self, rhs: &Self) -> Result<Self>
pub fn gt(&self, rhs: &Self) -> Result<Self>
pub fn ge(&self, rhs: &Self) -> Result<Self>
pub fn le(&self, rhs: &Self) -> Result<Self>
pub fn upsample_nearest2d( &self, target_h: usize, target_w: usize ) -> Result<Self>
pub fn avg_pool2d<T: ToUsize2>(&self, sz: T) -> Result<Self>
pub fn avg_pool2d_with_stride<T: ToUsize2>( &self, kernel_size: T, stride: T ) -> Result<Self>
pub fn max_pool2d<T: ToUsize2>(&self, sz: T) -> Result<Self>
pub fn max_pool2d_with_stride<T: ToUsize2>( &self, kernel_size: T, stride: T ) -> Result<Self>
sourcepub fn matmul(&self, rhs: &Self) -> Result<Self>
pub fn matmul(&self, rhs: &Self) -> Result<Self>
Returns the matrix-multiplication of the input tensor with the other provided tensor.
Arguments
- self: A tensor with dimensions b1, b2, ..., bi, m, k.
- rhs: A tensor with dimensions b1, b2, ..., bi, k, n.
The resulting tensor has dimensions b1, b2, ..., bi, m, n.
sourcepub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self>
pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self>
Matrix-multiplication with broadcasting support.
Compared to matmul, the two matrices are allowed to have different dimensions as long as
they are compatible for broadcasting. E.g. if self has shape (j, 1, n, k) and rhs has
shape (l, k, m), the output will have shape (j, l, n, m).
sourcepub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self>
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self>
Returns a tensor with the same shape as the input tensor, the values are taken from
on_true if the input tensor value is not zero, and on_false at the positions where the
input tensor is equal to zero.
sourcepub fn embedding(&self, ids: &Self) -> Result<Self>
pub fn embedding(&self, ids: &Self) -> Result<Self>
Returns a tensor with the values from the self tensor at the index corresponding to the
values held in the ids tensor.
Arguments
- self: A tensor with dimensions v, h.
- ids: A tensor with dimensions s and with integer values between 0 and v (exclusive).
The resulting tensor has dimensions s, h. s is called the sequence length, v the
vocabulary size, and h the hidden size.
use candle_core::{Tensor, Device};
let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
let emb = values.embedding(&ids)?;
assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
pub fn scatter_add<D: Dim>( &self, indexes: &Self, source: &Self, dim: D ) -> Result<Self>
pub fn index_add<D: Dim>( &self, indexes: &Self, source: &Self, dim: D ) -> Result<Self>
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self>
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self>
sourcepub fn strided_index(&self) -> StridedIndex<'_> ⓘ
pub fn strided_index(&self) -> StridedIndex<'_> ⓘ
Returns an iterator over the positions of the elements in the storage when ranging over the index tuples in lexicographic order.
sourcepub fn strided_blocks(&self) -> StridedBlocks<'_>
pub fn strided_blocks(&self) -> StridedBlocks<'_>
Similar to strided_index, but returns the position of the start of each contiguous block
as well as the length of the contiguous blocks. For a contiguous tensor, the iterator
yields a single block whose start is the tensor's start offset and whose length is the
number of elements in the tensor.
sourcepub fn to_vec1<S: WithDType>(&self) -> Result<Vec<S>>
pub fn to_vec1<S: WithDType>(&self) -> Result<Vec<S>>
Returns the data contained in a 1D tensor as a vector of scalar values.
sourcepub fn to_vec2<S: WithDType>(&self) -> Result<Vec<Vec<S>>>
pub fn to_vec2<S: WithDType>(&self) -> Result<Vec<Vec<S>>>
Returns the data contained in a 2D tensor as a vector of vectors of scalar values.
sourcepub fn to_vec3<S: WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>>
pub fn to_vec3<S: WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>>
Returns the data contained in a 3D tensor as a vector of vectors of vectors of scalar values.
sourcepub fn dim<D: Dim>(&self, dim: D) -> Result<usize>
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize>
The dimension size for a specified dimension index.
sourcepub fn layout(&self) -> &Layout
pub fn layout(&self) -> &Layout
The layout of the input tensor, this stores both the shape of the tensor as well as the strides and the start offset to apply to the underlying storage.
pub fn stride(&self) -> &[usize]
sourcepub fn rank(&self) -> usize
pub fn rank(&self) -> usize
The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.
sourcepub fn elem_count(&self) -> usize
pub fn elem_count(&self) -> usize
The number of elements stored in this tensor.
sourcepub fn is_variable(&self) -> bool
pub fn is_variable(&self) -> bool
Whether this tensor is a variable or not. A variable is a tensor for which gradient is tracked and on which backpropagation can be performed.
sourcepub fn sum_all(&self) -> Result<Tensor>
pub fn sum_all(&self) -> Result<Tensor>
Computes the sum of all the elements in this tensor and returns a tensor holding this scalar with zero dimensions.
use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let tensor = tensor.sum_all()?;
assert_eq!(tensor.to_scalar::<f32>()?, 15.);
sourcepub fn flatten<D1: Dim, D2: Dim>(
&self,
start_dim: D1,
end_dim: D2
) -> Result<Tensor>
pub fn flatten<D1: Dim, D2: Dim>( &self, start_dim: D1, end_dim: D2 ) -> Result<Tensor>
Flattens the input tensor on the dimension indexes from start_dim to end_dim (both
inclusive).
sourcepub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor>
pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor>
Flattens the input tensor on the dimension indexes from 0 to end_dim (inclusive).
sourcepub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor>
pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor>
Flattens the input tensor on the dimension indexes from start_dim (inclusive) to the last
dimension.
sourcepub fn flatten_all(&self) -> Result<Tensor>
pub fn flatten_all(&self) -> Result<Tensor>
Flattens the input tensor by reshaping it into a one dimension tensor.
use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let tensor = tensor.flatten_all()?;
assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
sourcepub fn get(&self, i: usize) -> Result<Tensor>
pub fn get(&self, i: usize) -> Result<Tensor>
Returns the sub-tensor fixing the index at i on the first dimension.
use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let t = tensor.get(0)?;
assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
let t = tensor.get(1)?;
assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
sourcepub fn t(&self) -> Result<Tensor>
pub fn t(&self) -> Result<Tensor>
Returns a tensor that is a transposed version of the input, the two last dimensions of the input are swapped.
use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let tensor = tensor.t()?;
assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
sourcepub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor>
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor>
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
sourcepub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor>
pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor>
Returns a tensor with the same data as the input where the dimensions have been permuted. dims must be a permutation, i.e. include each dimension index exactly once.
use candle_core::{Tensor, Device};
let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
let tensor = tensor.permute((2, 3, 1, 0))?;
assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
sourcepub fn is_contiguous(&self) -> bool
pub fn is_contiguous(&self) -> bool
Returns true if the data is stored in a C contiguous (aka row major) way.
sourcepub fn is_fortran_contiguous(&self) -> bool
pub fn is_fortran_contiguous(&self) -> bool
Returns true if the data is stored in a Fortran contiguous (aka column major) way.
sourcepub fn copy(&self) -> Result<Tensor>
pub fn copy(&self) -> Result<Tensor>
Unlike clone, which only bumps the Arc reference count, this copies the actual storage, so it may fail by running out of memory.
sourcepub fn detach(&self) -> Result<Tensor>
pub fn detach(&self) -> Result<Tensor>
Returns a new tensor detached from the current graph; gradients are not propagated through this new node. The storage of this tensor is shared with the initial tensor.
sourcepub fn to_device(&self, device: &Device) -> Result<Tensor>
pub fn to_device(&self, device: &Device) -> Result<Tensor>
If the target device is the same as the tensor device, only a shallow copy is performed.
sourcepub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self>
pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self>
Returns a new tensor duplicating data from the original tensor. New dimensions are inserted on the left.
sourcepub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self>
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self>
Broadcast the input tensor to the target shape. This returns an error if the input shape is not compatible with the target shape.
If the input shape is i_1, i_2, ... i_k, the target shape has to have k dimensions or
more and shape j_1, ..., j_l, t_1, t_2, ..., t_k. The dimensions j_1 to j_l can have
any value, the dimension t_a must be equal to i_a if i_a is different from 1. If
i_a is equal to 1, any value can be used.
sourcepub fn to_dtype(&self, dtype: DType) -> Result<Self>
pub fn to_dtype(&self, dtype: DType) -> Result<Self>
Casts the input tensor to the target dtype.
use candle_core::{Tensor, Device};
let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
let tensor = tensor.to_dtype(candle_core::DType::F32)?;
assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
sourcepub fn contiguous(&self) -> Result<Tensor>
pub fn contiguous(&self) -> Result<Tensor>
Returns a tensor that is in row major order. This is the same as the original tensor if it was already contiguous, otherwise a copy is triggered.
sourcepub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor>
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor>
Reshape returns a tensor with the target shape provided that the number of elements of the original tensor is the same. If the input tensor is contiguous, this is a view on the original data. Otherwise this uses a new storage and copies the data over, the returned tensor is always contiguous.
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let c = a.reshape((1, 6))?;
assert_eq!(c.shape().dims(), &[1, 6]);
let c = a.reshape((3, 2))?;
assert_eq!(c.shape().dims(), &[3, 2]);
sourcepub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self>
pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self>
Creates a new tensor with the specified dimension removed if its size was one.
let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
let c = a.squeeze(2)?;
assert_eq!(c.shape().dims(), &[2, 3]);
let c = a.squeeze(D::Minus1)?;
assert_eq!(c.shape().dims(), &[2, 3]);
sourcepub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self>
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self>
Creates a new tensor with a dimension of size one inserted at the specified position.
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let c = a.unsqueeze(0)?;
assert_eq!(c.shape().dims(), &[1, 2, 3]);
let c = a.unsqueeze(D::Minus1)?;
assert_eq!(c.shape().dims(), &[2, 3, 1]);
sourcepub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self>
pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self>
Stacks two or more tensors along a particular dimension.
All tensors must have the same rank, and the output has one more dimension than the inputs.
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let c = Tensor::stack(&[&a, &b], 0)?;
assert_eq!(c.shape().dims(), &[2, 2, 3]);
let c = Tensor::stack(&[&a, &b], 2)?;
assert_eq!(c.shape().dims(), &[2, 3, 2]);
sourcepub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self>
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self>
Concatenates two or more tensors along a particular dimension.
All tensors must have the same rank, and the output will have the same rank.
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let c = Tensor::cat(&[&a, &b], 0)?;
assert_eq!(c.shape().dims(), &[4, 3]);
let c = Tensor::cat(&[&a, &b], 1)?;
assert_eq!(c.shape().dims(), &[2, 6]);
pub fn pad_with_zeros<D: Dim>( &self, dim: D, left: usize, right: usize ) -> Result<Self>
pub fn apply<M: Module>(&self, m: &M) -> Result<Self>
sourcepub fn storage_and_layout(&self) -> (RwLockReadGuard<'_, Storage>, &Layout)
pub fn storage_and_layout(&self) -> (RwLockReadGuard<'_, Storage>, &Layout)
The storage used by this tensor, together with the layout to use to access it safely.
sourcepub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self>
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self>
Applies a unary custom op without backward support.
sourcepub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self>
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self>
Applies a binary custom op without backward support.
sourcepub fn apply_op3_no_bwd<C: CustomOp3>(
&self,
t2: &Self,
t3: &Self,
c: &C
) -> Result<Self>
pub fn apply_op3_no_bwd<C: CustomOp3>( &self, t2: &Self, t3: &Self, c: &C ) -> Result<Self>
Applies a ternary custom op without backward support.
sourcepub fn apply_op1_arc(
&self,
c: Arc<Box<dyn CustomOp1 + Send + Sync>>
) -> Result<Self>
pub fn apply_op1_arc( &self, c: Arc<Box<dyn CustomOp1 + Send + Sync>> ) -> Result<Self>
Applies a unary custom op.
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>( &self, c: C ) -> Result<Self>
sourcepub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>
) -> Result<Self>
pub fn apply_op2_arc( &self, rhs: &Self, c: Arc<Box<dyn CustomOp2 + Send + Sync>> ) -> Result<Self>
Applies a binary custom op.
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>( &self, r: &Self, c: C ) -> Result<Self>
sourcepub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>
) -> Result<Self>
pub fn apply_op3_arc( &self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3 + Send + Sync>> ) -> Result<Self>
Applies a ternary custom op.