Skip to main content

GraphExt

Trait GraphExt 

Source
pub trait GraphExt {
    // Required methods (36)
    fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
    fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
    fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
    fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
    fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
    fn gelu(&mut self, x: NodeId) -> NodeId;
    fn gelu_approx(&mut self, x: NodeId) -> NodeId;
    fn silu(&mut self, x: NodeId) -> NodeId;
    fn relu(&mut self, x: NodeId) -> NodeId;
    fn exp(&mut self, x: NodeId) -> NodeId;
    fn sqrt(&mut self, x: NodeId) -> NodeId;
    fn neg(&mut self, x: NodeId) -> NodeId;
    fn tanh(&mut self, x: NodeId) -> NodeId;
    fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
    fn layer_norm2d(
        &mut self,
        x: NodeId,
        gamma: NodeId,
        beta: NodeId,
        eps: f32,
    ) -> NodeId;
    fn group_norm(
        &mut self,
        x: NodeId,
        gamma: NodeId,
        beta: NodeId,
        num_groups: usize,
        eps: f32,
    ) -> NodeId;
    fn rms_norm(
        &mut self,
        x: NodeId,
        gamma: NodeId,
        beta: NodeId,
        eps: f32,
    ) -> NodeId;
    fn conv2d(
        &mut self,
        input: NodeId,
        weight: NodeId,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        groups: usize,
    ) -> NodeId;
    fn conv_transpose2d(
        &mut self,
        input: NodeId,
        weight: NodeId,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_padding: [usize; 2],
        groups: usize,
    ) -> NodeId;
    fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
    fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
    fn sm(&mut self, x: NodeId, axis: i32) -> NodeId;
    fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId;
    fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId;
    fn narrow_(
        &mut self,
        x: NodeId,
        axis: usize,
        start: usize,
        len: usize,
    ) -> NodeId;
    fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId;
    fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId;
    fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
    fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
    fn attention_(
        &mut self,
        q: NodeId,
        k: NodeId,
        v: NodeId,
        mask: NodeId,
        num_heads: usize,
        head_dim: usize,
    ) -> NodeId;
    fn rope(
        &mut self,
        x: NodeId,
        cos: NodeId,
        sin: NodeId,
        head_dim: usize,
    ) -> NodeId;
    fn rope_n(
        &mut self,
        x: NodeId,
        cos: NodeId,
        sin: NodeId,
        head_dim: usize,
        n_rot: usize,
    ) -> NodeId;
    fn cast(&mut self, x: NodeId, to: DType) -> NodeId;
    fn constant(&mut self, value: f64, dtype: DType) -> NodeId;
    fn try_constant(
        &mut self,
        value: f64,
        dtype: DType,
    ) -> Result<NodeId, String>;
    fn stop_gradient(&mut self, x: NodeId) -> NodeId;
}
Expand description

Extension trait for shape-inferred graph building.

Required Methods§

Source

fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source

fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source

fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source

fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source

fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source

fn gelu(&mut self, x: NodeId) -> NodeId

Source

fn gelu_approx(&mut self, x: NodeId) -> NodeId

Tanh-approximation GELU: the form PyTorch computes with `gelu(x, approximate="tanh")` (PyTorch's default is the exact erf form), and the form candle's `Tensor::gelu` computes. Use this when porting models whose reference implementations use the tanh form, to keep numerical parity (e.g. DINOv2, many ViTs).

Source

fn silu(&mut self, x: NodeId) -> NodeId

Source

fn relu(&mut self, x: NodeId) -> NodeId

Source

fn exp(&mut self, x: NodeId) -> NodeId

Source

fn sqrt(&mut self, x: NodeId) -> NodeId

Source

fn neg(&mut self, x: NodeId) -> NodeId

Source

fn tanh(&mut self, x: NodeId) -> NodeId

Source

fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId

Source

fn layer_norm2d( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId

Source

fn group_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, num_groups: usize, eps: f32, ) -> NodeId

Source

fn rms_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId

Source

fn conv2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], groups: usize, ) -> NodeId

Source

fn conv_transpose2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], output_padding: [usize; 2], groups: usize, ) -> NodeId

Source

fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId

Source

fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId

Source

fn sm(&mut self, x: NodeId, axis: i32) -> NodeId

Source

fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId

Source

fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId

Source

fn narrow_( &mut self, x: NodeId, axis: usize, start: usize, len: usize, ) -> NodeId

Source

fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId

Source

fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId

Source

fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source

fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId

Source

fn attention_( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, ) -> NodeId

Source

fn rope( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, ) -> NodeId

Source

fn rope_n( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, n_rot: usize, ) -> NodeId

Partial RoPE: rotates only the first `n_rot` dimensions of each head, in NeoX style. Each rotated dimension `i` is paired with dimension `i + n_rot/2`.

Source

fn cast(&mut self, x: NodeId, to: DType) -> NodeId

Source

fn constant(&mut self, value: f64, dtype: DType) -> NodeId

Creates a rank-0, broadcastable scalar node (`Op::Constant`). `f16` and `bf16` values are lowered to an `f32` constant followed by a cast.

Source

fn try_constant(&mut self, value: f64, dtype: DType) -> Result<NodeId, String>

Fallible variant of `GraphExt::constant`. Returns an error when `value` is out of range for `dtype`, or when `dtype` cannot be encoded directly. For `f16`/`bf16`, callers can instead call `try_constant` with `F32` and then `cast` the result.

Source

fn stop_gradient(&mut self, x: NodeId) -> NodeId

Identity forward, zero backward. The reverse-mode autodiff rule for Op::StopGradient returns no gradient contribution to the input. Equivalent to PyTorch’s tensor.detach() / jax.lax.stop_gradient / TF’s tf.stop_gradient.

Dyn Compatibility§

This trait is dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementors§