
Struct Tensor 

pub struct Tensor { /* private fields */ }

A tensor handle.

In the lazy graph model, a Tensor is a lightweight reference to a node in the computation graph. Operations only build up the graph; the actual computation happens when eval() is called (or implicitly via to_vec_f32()).
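
To make the build-then-evaluate model concrete, here is a toy scalar graph in plain Rust. It is a sketch under our own assumptions: Op, Graph, push, and the memoizing eval are hypothetical names and say nothing about how this crate represents its graph internally. Recording nodes does no arithmetic; work happens only when eval is called, and cached results are reused on later calls.

```rust
use std::collections::HashMap;

/// Hypothetical node kinds for a toy scalar graph (illustration only).
#[derive(Clone, Copy)]
enum Op {
    Const(f32),
    Add(usize, usize),
    Mul(usize, usize),
}

/// Hypothetical lazy graph: `push` only records a node, `eval` computes it.
#[derive(Default)]
struct Graph {
    nodes: Vec<Op>,
    cache: HashMap<usize, f32>,
    evaluated: usize, // how many nodes have actually been computed
}

impl Graph {
    /// Record an operation and return its handle (the node index). No math happens here.
    fn push(&mut self, op: Op) -> usize {
        self.nodes.push(op);
        self.nodes.len() - 1
    }

    /// Compute a node on demand, memoizing so shared subgraphs run once.
    fn eval(&mut self, id: usize) -> f32 {
        if let Some(&v) = self.cache.get(&id) {
            return v;
        }
        let op = self.nodes[id];
        let v = match op {
            Op::Const(c) => c,
            Op::Add(a, b) => self.eval(a) + self.eval(b),
            Op::Mul(a, b) => self.eval(a) * self.eval(b),
        };
        self.evaluated += 1;
        self.cache.insert(id, v);
        v
    }
}
```

Building (2 + 3) * (2 + 3) leaves evaluated at 0 until eval is called; the shared Add node is then computed once.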

Implementations

impl Tensor

pub fn zeros(shape: &Shape, dtype: DType, device: &Device) -> Result<Self>

Create a tensor filled with zeros.

pub fn ones(shape: &Shape, dtype: DType, device: &Device) -> Result<Self>

Create a tensor filled with ones.

pub fn from_f32(data: &[f32], shape: &Shape, device: &Device) -> Result<Self>

Create a tensor from f32 data.

pub fn from_f32_on_stream(data: &[f32], shape: &Shape, stream: &Arc<Stream>) -> Result<Self>

Create a tensor from f32 data on a specific stream.

pub fn from_data_with_dtype(data: Vec<f32>, shape: &Shape, dtype: DType, device: &Device) -> Result<Self>

Create a tensor from f32 data with a specified dtype.

The data is stored as f32 internally; dtype records the logical type (e.g. F16 weights that were converted to f32 on load).
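
If weights arrive as raw IEEE 754 half-precision bits, widening them to f32 on load is lossless, because every f16 value is exactly representable in f32. The decoder below is a hand-written sketch, not this crate's loader (which these docs do not describe); f16_bits_to_f32 and f16_slice_to_f32 are hypothetical helpers.

```rust
/// Hypothetical helper: decode one binary16 value (raw bits) into an f32.
fn f16_bits_to_f32(h: u16) -> f32 {
    let sign = ((h >> 15) & 0x1) as u32;
    let exp = ((h >> 10) & 0x1f) as u32;
    let mant = (h & 0x3ff) as u32;

    if exp == 0 {
        // Zero or subnormal: value = mant * 2^-24 (sign applied afterwards).
        let v = mant as f32 * 2f32.powi(-24);
        return if sign == 1 { -v } else { v };
    }
    let bits = if exp == 0x1f {
        // Infinity (mant == 0) or NaN: all-ones f32 exponent, mantissa shifted up.
        (sign << 31) | (0xff << 23) | (mant << 13)
    } else {
        // Normal number: rebias the exponent from 15 (f16) to 127 (f32).
        (sign << 31) | ((exp + 112) << 23) | (mant << 13)
    };
    f32::from_bits(bits)
}

/// Hypothetical helper: widen a buffer of raw f16 bits, e.g. read from a weights file.
fn f16_slice_to_f32(raw: &[u16]) -> Vec<f32> {
    raw.iter().map(|&h| f16_bits_to_f32(h)).collect()
}
```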

pub fn add(&self, rhs: &Tensor) -> Result<Tensor>

Element-wise addition.

pub fn sub(&self, rhs: &Tensor) -> Result<Tensor>

Element-wise subtraction.

pub fn mul(&self, rhs: &Tensor) -> Result<Tensor>

Element-wise multiplication.

pub fn div(&self, rhs: &Tensor) -> Result<Tensor>

Element-wise division.

pub fn neg(&self) -> Tensor

Element-wise negation.

pub fn sum_axis(&self, axis: i32) -> Result<Tensor>

Sum along an axis.

pub fn sum_all(&self) -> Result<Tensor>

Sum all elements to a scalar.

pub fn matmul(&self, rhs: &Tensor) -> Result<Tensor>

Matrix multiplication (2D only for now).
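
For reference, the 2D product is out[i][j] = Σ_p a[i][p]·b[p][j]. A naive row-major sketch using plain slices in place of Tensor; matmul_2d is a hypothetical helper, not the crate's kernel.

```rust
/// Hypothetical helper: a is [m, k], b is [k, n], result is [m, n], all row-major.
fn matmul_2d(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Result<Vec<f32>, String> {
    if a.len() != m * k || b.len() != k * n {
        return Err(format!("shape mismatch: a has {} elements, b has {}", a.len(), b.len()));
    }
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let aip = a[i * k + p];
            // i-p-j loop order walks rows of b and out contiguously.
            for j in 0..n {
                out[i * n + j] += aip * b[p * n + j];
            }
        }
    }
    Ok(out)
}
```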

pub fn reshape(&self, new_shape: &Shape) -> Result<Tensor>

Reshape the tensor.

pub fn transpose(&self, axes: Option<&[usize]>) -> Result<Tensor>

Transpose the tensor. With axes = None the axis order is reversed; otherwise axes gives the permutation to apply.
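
The permutation semantics can be checked with a small row-major sketch: output axis i takes its size and stride from input axis perm[i], and None reverses the axes. The transpose helper here is hypothetical, works on plain slices, and assumes perm is a valid permutation (it is not validated).

```rust
/// Hypothetical helper: permute the axes of a row-major array.
/// Returns the permuted data and the new shape.
fn transpose(data: &[f32], shape: &[usize], perm: Option<&[usize]>) -> (Vec<f32>, Vec<usize>) {
    let rank = shape.len();
    let perm: Vec<usize> = match perm {
        Some(p) => p.to_vec(),
        None => (0..rank).rev().collect(),
    };
    // Row-major strides of the input.
    let mut in_strides = vec![1usize; rank];
    for d in (0..rank.saturating_sub(1)).rev() {
        in_strides[d] = in_strides[d + 1] * shape[d + 1];
    }
    let out_shape: Vec<usize> = perm.iter().map(|&a| shape[a]).collect();
    let total: usize = shape.iter().product();
    let mut out = Vec::with_capacity(total);
    let mut idx = vec![0usize; rank]; // current multi-index into the output
    for _ in 0..total {
        // Output axis i walks input axis perm[i].
        let offset: usize = (0..rank).map(|i| idx[i] * in_strides[perm[i]]).sum();
        out.push(data[offset]);
        // Advance the output multi-index in row-major order.
        for d in (0..rank).rev() {
            idx[d] += 1;
            if idx[d] < out_shape[d] {
                break;
            }
            idx[d] = 0;
        }
    }
    (out, out_shape)
}
```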

pub fn softmax(&self, axis: i32) -> Result<Tensor>

Softmax along an axis.
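
A numerically stable softmax subtracts the maximum of each slice before exponentiating. The sketch works on a row-major buffer with an explicit shape. Treating a negative axis as counting from the end (-1 = last axis) is an assumption suggested by the i32 parameter type, not something these docs state; softmax and resolve_axis are hypothetical helpers.

```rust
/// Hypothetical helper: resolve a possibly negative axis (assumed numpy-style).
fn resolve_axis(axis: i32, rank: usize) -> Option<usize> {
    let a = if axis < 0 { axis + rank as i32 } else { axis };
    if a >= 0 && (a as usize) < rank { Some(a as usize) } else { None }
}

/// Hypothetical helper: stable softmax of a row-major array along `axis`.
fn softmax(data: &[f32], shape: &[usize], axis: i32) -> Option<Vec<f32>> {
    let ax = resolve_axis(axis, shape.len())?;
    let outer: usize = shape[..ax].iter().product();
    let len = shape[ax];
    let inner: usize = shape[ax + 1..].iter().product();
    let mut out = data.to_vec();
    for o in 0..outer {
        for i in 0..inner {
            // Elements of one softmax slice are `inner` apart in memory.
            let at = |j: usize| o * len * inner + j * inner + i;
            let max = (0..len).map(|j| data[at(j)]).fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0;
            for j in 0..len {
                let e = (data[at(j)] - max).exp();
                out[at(j)] = e;
                sum += e;
            }
            for j in 0..len {
                out[at(j)] /= sum;
            }
        }
    }
    Some(out)
}
```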

pub fn silu(&self) -> Tensor

SiLU (Sigmoid Linear Unit) activation.

pub fn gelu(&self) -> Tensor

GELU (Gaussian Error Linear Unit) activation.
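
Scalar reference formulas for the two activations. SiLU is x·σ(x). For GELU, the docs do not say whether the exact erf form or the common tanh approximation is used; the sketch uses the tanh approximation, which is an assumption (Rust's standard library has no erf). Both functions are hypothetical helpers.

```rust
/// Hypothetical helper: SiLU(x) = x * sigmoid(x) = x / (1 + e^(-x)).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Hypothetical helper: GELU via the tanh approximation (assumed form):
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f32) -> f32 {
    // sqrt(2/pi) = (2/sqrt(pi)) * (1/sqrt(2))
    let c = std::f32::consts::FRAC_2_SQRT_PI * std::f32::consts::FRAC_1_SQRT_2;
    0.5 * x * (1.0 + (c * (x + 0.044_715 * x * x * x)).tanh())
}
```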

pub fn layer_norm(&self, eps: f32) -> Tensor

Layer normalization over the last dimension.

pub fn rms_norm(&self, eps: f32) -> Tensor

RMS normalization over the last dimension.
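
Both normalizations act independently on each row of the last dimension. Because the documented signatures take only eps, the sketch applies no learned scale or bias. layer_norm and rms_norm are hypothetical slice-based helpers that assume row-major data with dim > 0.

```rust
/// Hypothetical helper: LayerNorm over the last dimension, (x - mean) / sqrt(var + eps).
fn layer_norm(x: &[f32], dim: usize, eps: f32) -> Vec<f32> {
    x.chunks(dim)
        .flat_map(|row| {
            let n = row.len() as f32;
            let mean = row.iter().sum::<f32>() / n;
            let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
            let inv = 1.0 / (var + eps).sqrt();
            row.iter().map(|v| (v - mean) * inv).collect::<Vec<_>>()
        })
        .collect()
}

/// Hypothetical helper: RMSNorm over the last dimension, x / sqrt(mean(x^2) + eps).
/// Unlike LayerNorm, the mean is not subtracted.
fn rms_norm(x: &[f32], dim: usize, eps: f32) -> Vec<f32> {
    x.chunks(dim)
        .flat_map(|row| {
            let ms = row.iter().map(|v| v * v).sum::<f32>() / row.len() as f32;
            let inv = 1.0 / (ms + eps).sqrt();
            row.iter().map(|v| v * inv).collect::<Vec<_>>()
        })
        .collect()
}
```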

pub fn rope(&self, rotary_dim: usize, pos_offset: usize, theta: f32) -> Tensor

Apply Rotary Positional Embeddings.
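
RoPE rotates feature pairs by a position-dependent angle, with pair i using frequency theta^(-2i / rotary_dim). The docs do not specify the layout, so the sketch assumes the "half-split" pairing (feature i with feature i + rotary_dim/2), that features beyond rotary_dim pass through unchanged, and that row t sits at position pos_offset + t. All three are assumptions, and rope here is a hypothetical slice-based helper.

```rust
/// Hypothetical helper: RoPE on a row-major [seq_len, dim] buffer (assumed conventions).
fn rope(x: &[f32], dim: usize, rotary_dim: usize, pos_offset: usize, theta: f32) -> Vec<f32> {
    assert!(rotary_dim % 2 == 0 && rotary_dim <= dim);
    let half = rotary_dim / 2;
    let mut out = x.to_vec();
    for (t, row) in out.chunks_mut(dim).enumerate() {
        let pos = (pos_offset + t) as f32;
        for i in 0..half {
            // Frequency for pair i: theta^(-2i / rotary_dim).
            let freq = theta.powf(-2.0 * i as f32 / rotary_dim as f32);
            let (sin, cos) = (pos * freq).sin_cos();
            let (a, b) = (row[i], row[i + half]);
            // 2D rotation of the pair (a, b) by angle pos * freq.
            row[i] = a * cos - b * sin;
            row[i + half] = a * sin + b * cos;
        }
    }
    out
}
```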

pub fn layer_norm_vjp(&self, input: &Tensor, eps: f32) -> Result<Tensor>

LayerNorm VJP: compute grad_input given grad_output (self) and the original input.

pub fn rms_norm_vjp(&self, input: &Tensor, eps: f32) -> Result<Tensor>

RmsNorm VJP: compute grad_input given grad_output (self) and the original input.
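
For a single row with no learned weight (matching the eps-only forward signatures), the backward formulas are as follows. For RmsNorm, with r = 1/sqrt(mean(x²) + eps), dx_k = r·g_k − r³·x_k·(Σ_i g_i·x_i)/n. For LayerNorm, with r = 1/sqrt(var + eps) and x̂ = (x − mean)·r, dx = r·(g − mean(g) − x̂·mean(g·x̂)). The sketch implements these on plain slices, with g playing the role of grad_output; all function names are hypothetical, and the forward helpers are included only so the gradients can be checked against finite differences.

```rust
/// Hypothetical forward helper: RMSNorm of one row.
fn rms_norm_row(x: &[f32], eps: f32) -> Vec<f32> {
    let r = 1.0 / (x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32 + eps).sqrt();
    x.iter().map(|v| v * r).collect()
}

/// Hypothetical forward helper: LayerNorm of one row.
fn layer_norm_row(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let r = 1.0 / (var + eps).sqrt();
    x.iter().map(|v| (v - mean) * r).collect()
}

/// Hypothetical RMSNorm VJP: dx_k = r*g_k - r^3 * x_k * (sum_i g_i x_i) / n.
fn rms_norm_vjp_row(g: &[f32], x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let r = 1.0 / (x.iter().map(|v| v * v).sum::<f32>() / n + eps).sqrt();
    let gx: f32 = g.iter().zip(x).map(|(a, b)| a * b).sum();
    x.iter().zip(g).map(|(xk, gk)| r * gk - r.powi(3) * xk * gx / n).collect()
}

/// Hypothetical LayerNorm VJP: dx = r * (g - mean(g) - xhat * mean(g * xhat)).
fn layer_norm_vjp_row(g: &[f32], x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let r = 1.0 / (var + eps).sqrt();
    let xhat: Vec<f32> = x.iter().map(|v| (v - mean) * r).collect();
    let g_mean = g.iter().sum::<f32>() / n;
    let gx_mean = g.iter().zip(&xhat).map(|(a, b)| a * b).sum::<f32>() / n;
    g.iter().zip(&xhat).map(|(gk, hk)| r * (gk - g_mean - hk * gx_mean)).collect()
}
```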

pub fn softmax_vjp(&self, softmax_output: &Tensor, axis: i32) -> Result<Tensor>

Softmax VJP: compute grad_input given grad_output (self) and softmax output.
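
For y = softmax(x), the VJP is dx_i = y_i·(g_i − Σ_j g_j·y_j). It needs only the forward output y, not the original input, which fits the softmax_output parameter in the signature. A one-row sketch on plain slices; softmax_vjp_row is a hypothetical helper.

```rust
/// Hypothetical helper: softmax VJP for one row, given upstream grad g and output y.
fn softmax_vjp_row(g: &[f32], y: &[f32]) -> Vec<f32> {
    // <g, y> is shared by every component of the result.
    let dot: f32 = g.iter().zip(y).map(|(a, b)| a * b).sum();
    g.iter().zip(y).map(|(gi, yi)| yi * (gi - dot)).collect()
}
```

A constant upstream gradient yields a zero input gradient, since adding a constant to every logit does not change the softmax.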

pub fn silu_vjp(&self, input: &Tensor) -> Result<Tensor>

SiLU VJP: compute grad_input given grad_output (self) and original input.

pub fn gelu_vjp(&self, input: &Tensor) -> Result<Tensor>

GELU VJP: compute grad_input given grad_output (self) and original input.
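
Elementwise derivatives for the two activations: d/dx[x·σ(x)] = σ(x)·(1 + x·(1 − σ(x))), and for the tanh-approximated GELU, 0.5·(1 + t) + 0.5·x·(1 − t²)·c·(1 + 3·0.044715·x²) with t = tanh(c·(x + 0.044715·x³)) and c = sqrt(2/π). Each VJP multiplies the derivative by grad_output elementwise. Whether the crate's GELU uses the tanh approximation is an assumption, and all helper names are hypothetical.

```rust
/// Logistic sigmoid 1 / (1 + e^(-x)).
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Hypothetical forward SiLU, included so the VJP can be checked numerically.
fn silu(x: f32) -> f32 {
    x * sigmoid(x)
}

/// Hypothetical forward GELU (tanh approximation, assumed).
fn gelu_tanh(x: f32) -> f32 {
    let c = std::f32::consts::FRAC_2_SQRT_PI * std::f32::consts::FRAC_1_SQRT_2; // sqrt(2/pi)
    0.5 * x * (1.0 + (c * (x + 0.044_715 * x * x * x)).tanh())
}

/// Hypothetical SiLU VJP: g * s * (1 + x * (1 - s)) with s = sigmoid(x).
fn silu_vjp(g: &[f32], x: &[f32]) -> Vec<f32> {
    g.iter()
        .zip(x)
        .map(|(gi, &xi)| {
            let s = sigmoid(xi);
            gi * s * (1.0 + xi * (1.0 - s))
        })
        .collect()
}

/// Hypothetical VJP of the tanh-approximated GELU.
fn gelu_tanh_vjp(g: &[f32], x: &[f32]) -> Vec<f32> {
    let c = std::f32::consts::FRAC_2_SQRT_PI * std::f32::consts::FRAC_1_SQRT_2;
    g.iter()
        .zip(x)
        .map(|(gi, &xi)| {
            let t = (c * (xi + 0.044_715 * xi * xi * xi)).tanh();
            // dt/dx = (1 - t^2) * c * (1 + 3 * 0.044715 * x^2)
            let dt = (1.0 - t * t) * c * (1.0 + 3.0 * 0.044_715 * xi * xi);
            gi * (0.5 * (1.0 + t) + 0.5 * xi * dt)
        })
        .collect()
}
```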

pub fn embedding_lookup(&self, indices: &Tensor) -> Result<Tensor>

Embedding lookup: gather rows from this weight matrix [vocab, dim] using indices [seq_len]. Returns [seq_len, dim].
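
Embedding lookup is a row gather: output row t is weight row indices[t]. A sketch on plain slices; taking indices as usize and returning an error for out-of-range ids are assumptions, since the docs do not say how index tensors are encoded or validated. embedding_lookup here is a hypothetical helper.

```rust
/// Hypothetical helper: gather rows of a row-major [vocab, dim] weight matrix.
fn embedding_lookup(weight: &[f32], dim: usize, indices: &[usize]) -> Result<Vec<f32>, String> {
    let vocab = weight.len() / dim;
    let mut out = Vec::with_capacity(indices.len() * dim);
    for &i in indices {
        if i >= vocab {
            return Err(format!("index {i} out of range for vocab size {vocab}"));
        }
        // Copy the whole contiguous row i.
        out.extend_from_slice(&weight[i * dim..(i + 1) * dim]);
    }
    Ok(out)
}
```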

pub fn narrow(&self, axis: i32, start: i64, length: i64) -> Result<Tensor>

Narrow (slice) along an axis: extract length consecutive elements starting at index start.

pub fn cat(tensors: &[&Tensor], axis: i32) -> Result<Tensor>

Concatenate tensors along an axis.
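
Both narrow and cat are easiest to see by viewing a row-major shape as (outer, axis length, inner): narrow copies one contiguous run of length × inner elements per outer index, and cat appends each input's run in turn. The sketch uses non-negative usize axes and offsets, whereas the documented API takes i32 axes and i64 start/length. split, narrow, and cat here are hypothetical helpers, and cat does not check that the other dimensions agree.

```rust
/// Hypothetical helper: view `shape` around `axis` as (outer, axis_len, inner) counts.
fn split(shape: &[usize], axis: usize) -> (usize, usize, usize) {
    (shape[..axis].iter().product(), shape[axis], shape[axis + 1..].iter().product())
}

/// Hypothetical narrow: keep `length` entries starting at `start` along `axis`.
fn narrow(data: &[f32], shape: &[usize], axis: usize, start: usize, length: usize) -> Vec<f32> {
    let (outer, n, inner) = split(shape, axis);
    assert!(start + length <= n, "narrow out of bounds");
    let mut out = Vec::with_capacity(outer * length * inner);
    for o in 0..outer {
        let base = (o * n + start) * inner;
        out.extend_from_slice(&data[base..base + length * inner]);
    }
    out
}

/// Hypothetical cat: concatenate (data, shape) parts along `axis`.
/// Assumes all non-`axis` dimensions match (not checked).
fn cat(parts: &[(&[f32], &[usize])], axis: usize) -> Vec<f32> {
    let outer: usize = parts[0].1[..axis].iter().product();
    let mut out = Vec::new();
    for o in 0..outer {
        // For each outer index, append every part's contiguous block in order.
        for (data, shape) in parts {
            let (_, n, inner) = split(shape, axis);
            let block = n * inner;
            out.extend_from_slice(&data[o * block..(o + 1) * block]);
        }
    }
    out
}
```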

pub fn attention(&self, k: &Tensor, v: &Tensor, scale: f32, causal: bool) -> Result<Tensor>

Single-head attention with self as Q: Q @ K^T * scale → causal mask (when causal is true) → softmax → @ V. Shapes: Q: [Tq, Dh], K: [Tk, Dh], V: [Tk, Dh] → output: [Tq, Dh].
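
The same pipeline written out on plain slices. How the causal mask aligns queries and keys when Tq ≠ Tk is not documented; the sketch assumes the queries are the last Tq positions, so query i may attend to keys j ≤ i + (Tk − Tq), which reduces to j ≤ i when Tq = Tk. The attention function here is a hypothetical helper.

```rust
/// Hypothetical helper: single-head attention on row-major q [tq, dh], k and v [tk, dh].
fn attention(
    q: &[f32], k: &[f32], v: &[f32],
    tq: usize, tk: usize, dh: usize,
    scale: f32, causal: bool,
) -> Vec<f32> {
    assert!(tk > 0 && (!causal || tq <= tk));
    let mut out = vec![0.0f32; tq * dh];
    for i in 0..tq {
        // Last key index query i may see (assumed causal alignment).
        let limit = if causal { i + (tk - tq) } else { tk - 1 };
        // Scaled scores q_i . k_j; masked positions get -inf.
        let scores: Vec<f32> = (0..tk)
            .map(|j| {
                if j > limit {
                    f32::NEG_INFINITY
                } else {
                    (0..dh).map(|d| q[i * dh + d] * k[j * dh + d]).sum::<f32>() * scale
                }
            })
            .collect();
        // Stable softmax; exp(-inf) = 0 removes masked keys.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let w: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f32 = w.iter().sum();
        // Weighted sum of value rows.
        for j in 0..tk {
            for d in 0..dh {
                out[i * dh + d] += w[j] / sum * v[j * dh + d];
            }
        }
    }
    out
}
```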

pub fn sqrt(&self) -> Tensor

Element-wise square root.

pub fn eval(&self) -> Result<()>

Materialize the tensor — triggers evaluation of the computation graph.

pub fn to_vec_f32(&self) -> Result<Vec<f32>>

Copy the data out as a Vec<f32>. Triggers evaluation if needed.

pub fn shape(&self) -> &Shape

Get the tensor shape.

pub fn dtype(&self) -> DType

Get the tensor dtype.

pub fn device(&self) -> &Device

Get the tensor device.

pub fn numel(&self) -> i64

Number of elements.

pub fn node_id(&self) -> NodeId

Get the graph node ID.

pub fn stream(&self) -> Arc<Stream>

Get the stream this tensor belongs to.

pub fn from_node_id(node_id: NodeId, shape: Shape, dtype: DType, device: Device, stream: Arc<Stream>) -> Self

Reconstruct a tensor handle from a node ID and metadata.

Used by autograd to create handles for graph introspection.

pub fn broadcast_to(&self, target: &Shape) -> Result<Tensor>

Broadcast this tensor to the target shape (numpy-style rules).
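
numpy-style broadcasting aligns shapes from the right: each source dimension must equal the target dimension or be 1, and the target may add leading dimensions. The sketch checks these rules and materializes the result by giving broadcast dimensions stride 0. broadcast_to here is a hypothetical slice-based helper; the crate may well keep broadcasts lazy rather than copying data.

```rust
/// Hypothetical helper: broadcast row-major `data` of shape `src` to `target`.
fn broadcast_to(data: &[f32], src: &[usize], target: &[usize]) -> Result<Vec<f32>, String> {
    if src.len() > target.len() {
        return Err("source has more dimensions than target".to_string());
    }
    let pad = target.len() - src.len();
    // Stride 0 (leading dims and size-1 dims) makes every index there read the same element.
    let mut strides = vec![0usize; target.len()];
    let mut acc = 1usize;
    for d in (0..src.len()).rev() {
        let (s, t) = (src[d], target[pad + d]);
        if s != t && s != 1 {
            return Err(format!("cannot broadcast dimension {s} to {t}"));
        }
        strides[pad + d] = if s == 1 { 0 } else { acc };
        acc *= s;
    }
    let total: usize = target.iter().product();
    let mut out = Vec::with_capacity(total);
    let mut idx = vec![0usize; target.len()];
    for _ in 0..total {
        out.push(data[idx.iter().zip(&strides).map(|(i, s)| i * s).sum::<usize>()]);
        // Advance the target multi-index in row-major order.
        for d in (0..target.len()).rev() {
            idx[d] += 1;
            if idx[d] < target[d] {
                break;
            }
            idx[d] = 0;
        }
    }
    Ok(out)
}
```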

Trait Implementations

impl Add for &Tensor

type Output = Result<Tensor, MlxError>

The resulting type after applying the + operator.

fn add(self, rhs: &Tensor) -> Self::Output

Performs the + operation.

impl Clone for Tensor

fn clone(&self) -> Tensor

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Stable since Rust 1.0.0.

impl Mul for &Tensor

type Output = Result<Tensor, MlxError>

The resulting type after applying the * operator.

fn mul(self, rhs: &Tensor) -> Self::Output

Performs the * operation.

impl Neg for &Tensor

type Output = Tensor

The resulting type after applying the - operator.

fn neg(self) -> Self::Output

Performs the unary - operation.

impl Sub for &Tensor

type Output = Result<Tensor, MlxError>

The resulting type after applying the - operator.

fn sub(self, rhs: &Tensor) -> Self::Output

Performs the - operation.
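
The operator impls return a Result from binary +, * and - because the operands' shapes may be incompatible, while unary negation cannot fail and returns a Tensor directly. The toy type below mirrors that pattern so it compiles without this crate; Vec1 and add_then_negate are hypothetical, and String stands in for MlxError.

```rust
use std::ops::{Add, Neg};

/// Hypothetical stand-in for a tensor handle (not this crate's type).
#[derive(Clone, Debug, PartialEq)]
struct Vec1(Vec<f32>);

// Binary op on references with a fallible Output, like `type Output = Result<Tensor, MlxError>`.
impl Add for &Vec1 {
    type Output = Result<Vec1, String>;
    fn add(self, rhs: Self) -> Self::Output {
        if self.0.len() != rhs.0.len() {
            return Err(format!("length mismatch: {} vs {}", self.0.len(), rhs.0.len()));
        }
        Ok(Vec1(self.0.iter().zip(&rhs.0).map(|(a, b)| a + b).collect()))
    }
}

// Negation cannot fail, so its Output is the plain type.
impl Neg for &Vec1 {
    type Output = Vec1;
    fn neg(self) -> Vec1 {
        Vec1(self.0.iter().map(|a| -a).collect())
    }
}

/// Hypothetical caller: `?` propagates the error produced by `+`.
fn add_then_negate(a: &Vec1, b: &Vec1) -> Result<Vec1, String> {
    let sum = (a + b)?;
    Ok(-&sum)
}
```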

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API (clone_to_uninit).
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.