Struct candle_core::Var

source ·
pub struct Var(_);
Expand description

A variable is a wrapper around a tensor. Unlike tensors, which are immutable, variables can have their content modified.

Implementations§

source§

impl Var

source

pub fn zeros<S: Into<Shape>>( shape: S, dtype: DType, device: &Device ) -> Result<Self>

source

pub fn ones<S: Into<Shape>>( shape: S, dtype: DType, device: &Device ) -> Result<Self>

source

pub fn from_tensor(t: &Tensor) -> Result<Self>

source

pub fn rand_f64<S: Into<Shape>>( lo: f64, up: f64, s: S, dtype: DType, device: &Device ) -> Result<Self>

source

pub fn randn_f64<S: Into<Shape>>( mean: f64, std: f64, s: S, dtype: DType, device: &Device ) -> Result<Self>

source

pub fn rand<S: Into<Shape>, T: FloatDType>( lo: T, up: T, s: S, device: &Device ) -> Result<Self>

source

pub fn randn<S: Into<Shape>, T: FloatDType>( mean: T, std: T, s: S, device: &Device ) -> Result<Self>

source

pub fn new<A: NdArray>(array: A, device: &Device) -> Result<Self>

Creates a new tensor on the specified device using the content and shape of the input. This is similar to Tensor::new, but the resulting tensor is a variable.

source

pub fn from_vec<S: Into<Shape>, D: WithDType>( data: Vec<D>, shape: S, device: &Device ) -> Result<Self>

source

pub fn from_slice<S: Into<Shape>, D: WithDType>( array: &[D], shape: S, device: &Device ) -> Result<Self>

source

pub fn as_tensor(&self) -> &Tensor

source

pub fn into_inner(self) -> Tensor

Consumes this Var and returns the underlying tensor.

source

pub fn set(&self, src: &Tensor) -> Result<()>

Sets the content of the inner tensor. This does not require a mutable reference, as interior mutability is used.

Methods from Deref<Target = Tensor>§

source

pub fn backward(&self) -> Result<GradStore>

source

pub fn conv1d( &self, kernel: &Self, padding: usize, stride: usize, dilation: usize, groups: usize ) -> Result<Self>

Applies a 1D convolution over the input tensor.

source

pub fn conv2d( &self, kernel: &Self, padding: usize, stride: usize, dilation: usize, groups: usize ) -> Result<Self>

Applies a 2D convolution over the input tensor.

source

pub fn conv_transpose2d( &self, kernel: &Self, padding: usize, output_padding: usize, stride: usize, dilation: usize ) -> Result<Self>

Applies a 2D transposed convolution over the input tensor.

source

pub fn write_bytes<W: Write>(&self, f: &mut W) -> Result<()>

source

pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()>

Writes a multi-dimensional array in the npy format.

source

pub fn save_safetensors<P: AsRef<Path>>( &self, name: &str, filename: P ) -> Result<()>

source

pub fn dims0(&self) -> Result<()>

source

pub fn dims1(&self) -> Result<usize>

source

pub fn dims2(&self) -> Result<(usize, usize)>

source

pub fn dims3(&self) -> Result<(usize, usize, usize)>

source

pub fn dims4(&self) -> Result<(usize, usize, usize, usize)>

source

pub fn dims5(&self) -> Result<(usize, usize, usize, usize, usize)>

source

pub fn ones_like(&self) -> Result<Self>

Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.

use candle_core::{Tensor, DType, Device};
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = a.ones_like()?;
// b == a + 1
source

pub fn zeros_like(&self) -> Result<Self>

Creates a new tensor filled with zeros with the same shape, dtype, and device as the other tensor.

use candle_core::{Tensor, DType, Device};
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = a.zeros_like()?;
// b is on CPU f32.
source

pub fn rand_like(&self, lo: f64, up: f64) -> Result<Self>

source

pub fn randn_like(&self, mean: f64, stdev: f64) -> Result<Self>

source

pub fn add(&self, rhs: &Self) -> Result<Self>

source

pub fn mul(&self, rhs: &Self) -> Result<Self>

source

pub fn sub(&self, rhs: &Self) -> Result<Self>

source

pub fn div(&self, rhs: &Self) -> Result<Self>

source

pub fn maximum(&self, rhs: &Self) -> Result<Self>

source

pub fn minimum(&self, rhs: &Self) -> Result<Self>

source

pub fn broadcast_add(&self, rhs: &Self) -> Result<Self>

source

pub fn broadcast_mul(&self, rhs: &Self) -> Result<Self>

source

pub fn broadcast_sub(&self, rhs: &Self) -> Result<Self>

source

pub fn broadcast_div(&self, rhs: &Self) -> Result<Self>

source

pub fn broadcast_maximum(&self, rhs: &Self) -> Result<Self>

source

pub fn broadcast_minimum(&self, rhs: &Self) -> Result<Self>

source

pub fn recip(&self) -> Result<Self>

source

pub fn neg(&self) -> Result<Self>

source

pub fn exp(&self) -> Result<Self>

source

pub fn log(&self) -> Result<Self>

source

pub fn sin(&self) -> Result<Self>

source

pub fn cos(&self) -> Result<Self>

source

pub fn abs(&self) -> Result<Self>

source

pub fn sqr(&self) -> Result<Self>

source

pub fn sqrt(&self) -> Result<Self>

source

pub fn gelu(&self) -> Result<Self>

source

pub fn relu(&self) -> Result<Self>

source

pub fn to_scalar<S: WithDType>(&self) -> Result<S>

Retrieves the single scalar value held in the tensor. If the tensor has one or more dimensions, an error is returned instead.

source

pub fn to_vec0<S: WithDType>(&self) -> Result<S>

An alias for to_scalar.

source

pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor>

Repeat this tensor along the specified dimensions.

source

pub fn affine(&self, mul: f64, add: f64) -> Result<Self>

This operation multiplies the input tensor by mul, then adds add and returns the result. The input values mul and add are cast to the appropriate type, so some rounding might be performed.

use candle_core::{Tensor, Device};
let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
let a = a.affine(4., -2.)?;
assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
source

pub fn elu(&self, alpha: f64) -> Result<Self>

Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.

source

pub fn powf(&self, e: f64) -> Result<Self>

Raise the tensor to some float exponent e.

source

pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>>

Splits a tensor into the specified number of chunks; this may return fewer chunks than specified.

source

pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self>

Returns a new tensor that is a narrowed version of the input, the dimension dim ranges from start to start + len.

source

pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self>

Returns the sum of the elements of the input tensor, computed over the dimensions listed in sum_dims.

The resulting tensor has a shape that is similar to the shape of the input tensor, except that the number of elements for each dimension index in sum_dims is 1.

use candle_core::{Tensor, Device};
let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
let s = a.sum_keepdim(0)?;
assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
let s = a.sum_keepdim(1)?;
assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
let s = a.sum_keepdim((0, 1))?;
assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
source

pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self>

Returns the sum of the elements of the input tensor, computed over the dimensions listed in sum_dims. Compared to sum_keepdim, these dimensions are squeezed rather than kept.

source

pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self>

Returns the mean of the elements of the input tensor, computed over the dimensions listed in mean_dims.

The resulting tensor has a shape that is similar to the shape of the input tensor, except that the number of elements for each dimension index in mean_dims is 1.

use candle_core::{Tensor, Device};
let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
let s = a.mean_keepdim(0)?;
assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
let s = a.mean_keepdim(1)?;
assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
let s = a.mean_keepdim((0, 1))?;
assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
source

pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self>

Returns the mean of the elements of the input tensor, computed over the dimensions listed in mean_dims. Compared to mean_keepdim, these dimensions are squeezed rather than kept.

source

pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn max<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn min<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self>

source

pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self>

source

pub fn eq(&self, rhs: &Self) -> Result<Self>

source

pub fn ne(&self, rhs: &Self) -> Result<Self>

source

pub fn lt(&self, rhs: &Self) -> Result<Self>

source

pub fn gt(&self, rhs: &Self) -> Result<Self>

source

pub fn ge(&self, rhs: &Self) -> Result<Self>

source

pub fn le(&self, rhs: &Self) -> Result<Self>

source

pub fn upsample_nearest2d( &self, target_h: usize, target_w: usize ) -> Result<Self>

source

pub fn avg_pool2d<T: ToUsize2>(&self, sz: T) -> Result<Self>

source

pub fn avg_pool2d_with_stride<T: ToUsize2>( &self, kernel_size: T, stride: T ) -> Result<Self>

source

pub fn max_pool2d<T: ToUsize2>(&self, sz: T) -> Result<Self>

source

pub fn max_pool2d_with_stride<T: ToUsize2>( &self, kernel_size: T, stride: T ) -> Result<Self>

source

pub fn matmul(&self, rhs: &Self) -> Result<Self>

Returns the matrix-multiplication of the input tensor with the other provided tensor.

Arguments
  • self - A tensor with dimensions b1, b2, ..., bi, m, k.
  • rhs - A tensor with dimensions b1, b2, ..., bi, k, n.

The resulting tensor has dimensions b1, b2, ..., bi, m, n.

source

pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self>

Matrix-multiplication with broadcasting support.

Compared to matmul, the two matrices are allowed to have different dimensions as long as they are compatible for broadcasting. E.g. if self has shape (j, 1, n, k) and rhs has shape (l, k, m), the output will have shape (j, l, n, m).

source

pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self>

Returns a tensor with the same shape as the input tensor, the values are taken from on_true if the input tensor value is not zero, and on_false at the positions where the input tensor is equal to zero.

source

pub fn embedding(&self, ids: &Self) -> Result<Self>

Returns a tensor with the values from the self tensor at the indexes corresponding to the values held in the ids tensor.

Arguments
  • self - A tensor with dimensions v, h.
  • ids - A tensor with dimensions s and with integer values between 0 and v (exclusive).

The resulting tensor has dimensions s, h. s is called the sequence length, v the vocabulary size, and h the hidden size.

use candle_core::{Tensor, Device};
let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
let emb = values.embedding(&ids)?;
assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
source

pub fn scatter_add<D: Dim>( &self, indexes: &Self, source: &Self, dim: D ) -> Result<Self>

source

pub fn index_add<D: Dim>( &self, indexes: &Self, source: &Self, dim: D ) -> Result<Self>

source

pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self>

source

pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self>

source

pub fn strided_index(&self) -> StridedIndex<'_>

Returns an iterator over the positions of the elements in the storage when ranging over the index tuples in lexicographic order.

source

pub fn strided_blocks(&self) -> StridedBlocks<'_>

Similar to strided_index but returns the position of the start of each contiguous block as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator will only return the start offset and the size would be the number of elements in the tensor.

source

pub fn to_vec1<S: WithDType>(&self) -> Result<Vec<S>>

Returns the data contained in a 1D tensor as a vector of scalar values.

source

pub fn to_vec2<S: WithDType>(&self) -> Result<Vec<Vec<S>>>

Returns the data contained in a 2D tensor as a vector of vector of scalar values.

source

pub fn to_vec3<S: WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>>

Returns the data contained in a 3D tensor.

source

pub fn dtype(&self) -> DType

The dtype for the elements stored in the input tensor.

source

pub fn device(&self) -> &Device

The device on which the input tensor is located.

source

pub fn shape(&self) -> &Shape

The tensor shape, i.e. dimension sizes on each axis.

source

pub fn dims(&self) -> &[usize]

The dimension size for this tensor on each axis.

source

pub fn dim<D: Dim>(&self, dim: D) -> Result<usize>

The dimension size for a specified dimension index.

source

pub fn layout(&self) -> &Layout

The layout of the input tensor, this stores both the shape of the tensor as well as the strides and the start offset to apply to the underlying storage.

source

pub fn stride(&self) -> &[usize]

source

pub fn rank(&self) -> usize

The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.

source

pub fn elem_count(&self) -> usize

The number of elements stored in this tensor.

source

pub fn id(&self) -> TensorId

The unique identifier for this tensor.

source

pub fn is_variable(&self) -> bool

Whether this tensor is a variable or not. A variable is a tensor for which gradient is tracked and on which backpropagation can be performed.

source

pub fn sum_all(&self) -> Result<Tensor>

Computes the sum of all the elements in this tensor and returns a tensor holding this scalar with zero dimensions.

use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let tensor = tensor.sum_all()?;
assert_eq!(tensor.to_scalar::<f32>()?, 15.);
source

pub fn flatten<D1: Dim, D2: Dim>( &self, start_dim: D1, end_dim: D2 ) -> Result<Tensor>

Flattens the input tensor on the dimension indexes from start_dim to end_dim (both inclusive).

source

pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor>

Flattens the input tensor on the dimension indexes from 0 to end_dim (inclusive).

source

pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor>

Flattens the input tensor on the dimension indexes from start_dim (inclusive) to the last dimension.

source

pub fn flatten_all(&self) -> Result<Tensor>

Flattens the input tensor by reshaping it into a one dimension tensor.

use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let tensor = tensor.flatten_all()?;
assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
source

pub fn get(&self, i: usize) -> Result<Tensor>

Returns the sub-tensor fixing the index at i on the first dimension.

use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let t = tensor.get(0)?;
assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
let t = tensor.get(1)?;
assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
source

pub fn t(&self) -> Result<Tensor>

Returns a tensor that is a transposed version of the input, the two last dimensions of the input are swapped.

use candle_core::{Tensor, Device};
let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
let tensor = tensor.t()?;
assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
source

pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor>

Returns a tensor that is a transposed version of the input, the given dimensions are swapped.

source

pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor>

Returns a tensor with the same data as the input where the dimensions have been permuted. dims must be a permutation, i.e. include each dimension index exactly once.

use candle_core::{Tensor, Device};
let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
let tensor = tensor.permute((2, 3, 1, 0))?;
assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
source

pub fn is_contiguous(&self) -> bool

Returns true if the data is stored in a C contiguous (aka row major) way.

source

pub fn is_fortran_contiguous(&self) -> bool

Returns true if the data is stored in a Fortran contiguous (aka column major) way.

source

pub fn copy(&self) -> Result<Tensor>

Compared to clone, this copies the actual storage but may fail because of running out of memory.

source

pub fn detach(&self) -> Result<Tensor>

Returns a new tensor detached from the current graph; gradients are not propagated through this new node. The storage of this tensor is shared with the initial tensor.

source

pub fn to_device(&self, device: &Device) -> Result<Tensor>

Moves the tensor to the target device. If the target device is the same as the tensor device, only a shallow copy is performed.

source

pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self>

Returns a new tensor duplicating data from the original tensor. New dimensions are inserted on the left.

source

pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self>

Broadcast the input tensor to the target shape. This returns an error if the input shape is not compatible with the target shape.

If the input shape is i_1, i_2, ... i_k, the target shape has to have k dimensions or more and shape j_1, ..., j_l, t_1, t_2, ..., t_k. The dimensions j_1 to j_l can have any value, the dimension t_a must be equal to i_a if i_a is different from 1. If i_a is equal to 1, any value can be used.

source

pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self>

An alias for broadcast_as.

source

pub fn to_dtype(&self, dtype: DType) -> Result<Self>

Casts the input tensor to the target dtype.

use candle_core::{Tensor, Device};
let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
let tensor = tensor.to_dtype(candle_core::DType::F32)?;
assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
source

pub fn contiguous(&self) -> Result<Tensor>

Returns a tensor that is in row major order. This is the same as the original tensor if it was already contiguous, otherwise a copy is triggered.

source

pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor>

Reshape returns a tensor with the target shape provided that the number of elements of the original tensor is the same. If the input tensor is contiguous, this is a view on the original data. Otherwise this uses a new storage and copies the data over, the returned tensor is always contiguous.

let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;

let c = a.reshape((1, 6))?;
assert_eq!(c.shape().dims(), &[1, 6]);

let c = a.reshape((3, 2))?;
assert_eq!(c.shape().dims(), &[3, 2]);
source

pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self>

Creates a new tensor with the specified dimension removed if its size was one.

let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;

let c = a.squeeze(2)?;
assert_eq!(c.shape().dims(), &[2, 3]);

let c = a.squeeze(D::Minus1)?;
assert_eq!(c.shape().dims(), &[2, 3]);
source

pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self>

Creates a new tensor with a dimension of size one inserted at the specified position.

let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;

let c = a.unsqueeze(0)?;
assert_eq!(c.shape().dims(), &[1, 2, 3]);

let c = a.unsqueeze(D::Minus1)?;
assert_eq!(c.shape().dims(), &[2, 3, 1]);
source

pub fn pad_with_zeros<D: Dim>( &self, dim: D, left: usize, right: usize ) -> Result<Self>

source

pub fn apply<M: Module>(&self, m: &M) -> Result<Self>

source

pub fn storage_and_layout(&self) -> (RwLockReadGuard<'_, Storage>, &Layout)

The storage used by this tensor, together with the layout to use to access it safely.

source

pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self>

Applies a unary custom op without backward support.

source

pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self>

Applies a binary custom op without backward support.

source

pub fn apply_op3_no_bwd<C: CustomOp3>( &self, t2: &Self, t3: &Self, c: &C ) -> Result<Self>

Applies a ternary custom op without backward support.

source

pub fn apply_op1_arc( &self, c: Arc<Box<dyn CustomOp1 + Send + Sync>> ) -> Result<Self>

Applies a unary custom op.

source

pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>( &self, c: C ) -> Result<Self>

source

pub fn apply_op2_arc( &self, rhs: &Self, c: Arc<Box<dyn CustomOp2 + Send + Sync>> ) -> Result<Self>

Applies a binary custom op.

source

pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>( &self, r: &Self, c: C ) -> Result<Self>

source

pub fn apply_op3_arc( &self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3 + Send + Sync>> ) -> Result<Self>

Applies a ternary custom op.

source

pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>( &self, t2: &Self, t3: &Self, c: C ) -> Result<Self>

Trait Implementations§

source§

impl Clone for Var

source§

fn clone(&self) -> Var

Returns a copy of the value. Read more
1.0.0 · source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
source§

impl Debug for Var

source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
source§

impl Deref for Var

§

type Target = Tensor

The resulting type after dereferencing.
source§

fn deref(&self) -> &Self::Target

Dereferences the value.
source§

impl Display for Var

source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations§

§

impl !RefUnwindSafe for Var

§

impl Send for Var

§

impl Sync for Var

§

impl Unpin for Var

§

impl !UnwindSafe for Var

Blanket Implementations§

source§

impl<T> Any for T where T: 'static + ?Sized,

source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
source§

impl<T> Borrow<T> for T where T: ?Sized,

source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
source§

impl<T> BorrowMut<T> for T where T: ?Sized,

source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
source§

impl<T> From<T> for T

source§

fn from(t: T) -> T

Returns the argument unchanged.

source§

impl<T, U> Into<U> for T where U: From<T>,

source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

§

impl<T> Pointable for T

§

const ALIGN: usize = mem::align_of::<T>()

The alignment of pointer.
§

type Init = T

The type for initializers.
§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
source§

impl<T> ToOwned for T where T: Clone,

§

type Owned = T

The resulting type after obtaining ownership.
source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
source§

impl<T> ToString for T where T: Display + ?Sized,

source§

default fn to_string(&self) -> String

Converts the given value to a String. Read more
source§

impl<T, U> TryFrom<U> for T where U: Into<T>,

§

type Error = Infallible

The type returned in the event of a conversion error.
source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
source§

impl<T, U> TryInto<U> for T where U: TryFrom<T>,

§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
§

impl<V, T> VZip<V> for T where V: MultiLane<T>,

§

fn vzip(self) -> V