pub trait GraphExt {
Show 36 methods
// Required methods
fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
fn gelu(&mut self, x: NodeId) -> NodeId;
fn gelu_approx(&mut self, x: NodeId) -> NodeId;
fn silu(&mut self, x: NodeId) -> NodeId;
fn relu(&mut self, x: NodeId) -> NodeId;
fn exp(&mut self, x: NodeId) -> NodeId;
fn sqrt(&mut self, x: NodeId) -> NodeId;
fn neg(&mut self, x: NodeId) -> NodeId;
fn tanh(&mut self, x: NodeId) -> NodeId;
fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
fn layer_norm2d(
&mut self,
x: NodeId,
gamma: NodeId,
beta: NodeId,
eps: f32,
) -> NodeId;
fn group_norm(
&mut self,
x: NodeId,
gamma: NodeId,
beta: NodeId,
num_groups: usize,
eps: f32,
) -> NodeId;
fn rms_norm(
&mut self,
x: NodeId,
gamma: NodeId,
beta: NodeId,
eps: f32,
) -> NodeId;
fn conv2d(
&mut self,
input: NodeId,
weight: NodeId,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
groups: usize,
) -> NodeId;
fn conv_transpose2d(
&mut self,
input: NodeId,
weight: NodeId,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_padding: [usize; 2],
groups: usize,
) -> NodeId;
fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
fn sm(&mut self, x: NodeId, axis: i32) -> NodeId;
fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId;
fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId;
fn narrow_(
&mut self,
x: NodeId,
axis: usize,
start: usize,
len: usize,
) -> NodeId;
fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId;
fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId;
fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
fn attention_(
&mut self,
q: NodeId,
k: NodeId,
v: NodeId,
mask: NodeId,
num_heads: usize,
head_dim: usize,
) -> NodeId;
fn rope(
&mut self,
x: NodeId,
cos: NodeId,
sin: NodeId,
head_dim: usize,
) -> NodeId;
fn rope_n(
&mut self,
x: NodeId,
cos: NodeId,
sin: NodeId,
head_dim: usize,
n_rot: usize,
) -> NodeId;
fn cast(&mut self, x: NodeId, to: DType) -> NodeId;
fn constant(&mut self, value: f64, dtype: DType) -> NodeId;
fn try_constant(
&mut self,
value: f64,
dtype: DType,
) -> Result<NodeId, String>;
fn stop_gradient(&mut self, x: NodeId) -> NodeId;
}Expand description
Extension trait for shape-inferred graph building.
Required Methods
fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn gelu(&mut self, x: NodeId) -> NodeId
fn gelu_approx(&mut self, x: NodeId) -> NodeId
Tanh-approximation GELU: the formula behind PyTorch’s
gelu(x, approximate='tanh') (PyTorch’s default is the exact erf form)
and candle’s Tensor::gelu. Use this when porting models whose
reference implementations use the tanh form, for numerical parity
(e.g. DINOv2, many ViTs).
fn silu(&mut self, x: NodeId) -> NodeId
fn relu(&mut self, x: NodeId) -> NodeId
fn exp(&mut self, x: NodeId) -> NodeId
fn sqrt(&mut self, x: NodeId) -> NodeId
fn neg(&mut self, x: NodeId) -> NodeId
fn tanh(&mut self, x: NodeId) -> NodeId
fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId
fn layer_norm2d( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId
fn group_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, num_groups: usize, eps: f32, ) -> NodeId
fn rms_norm( &mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32, ) -> NodeId
fn conv2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], groups: usize, ) -> NodeId
fn conv_transpose2d( &mut self, input: NodeId, weight: NodeId, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], output_padding: [usize; 2], groups: usize, ) -> NodeId
fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId
fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId
fn sm(&mut self, x: NodeId, axis: i32) -> NodeId
fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId
fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId
fn narrow_( &mut self, x: NodeId, axis: usize, start: usize, len: usize, ) -> NodeId
fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId
fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId
fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId
fn attention_( &mut self, q: NodeId, k: NodeId, v: NodeId, mask: NodeId, num_heads: usize, head_dim: usize, ) -> NodeId
fn rope( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, ) -> NodeId
fn rope_n( &mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize, n_rot: usize, ) -> NodeId
Partial RoPE: rotate only the first n_rot dims, using the NeoX-style pairing offset n_rot/2.
fn cast(&mut self, x: NodeId, to: DType) -> NodeId
fn constant(&mut self, value: f64, dtype: DType) -> NodeId
Rank-0 broadcastable scalar (Op::Constant). f16 / bf16
are lowered as f32 constant + cast.
fn try_constant(&mut self, value: f64, dtype: DType) -> Result<NodeId, String>
Fallible variant of GraphExt::constant. Returns an error when
value is out of range for dtype or when dtype cannot be encoded
directly (callers may lower f16 / bf16 via try_constant on
F32 plus cast).
fn stop_gradient(&mut self, x: NodeId) -> NodeId
Identity forward, zero backward. The reverse-mode autodiff rule
for Op::StopGradient returns no gradient contribution to the
input. Equivalent to PyTorch’s tensor.detach() /
jax.lax.stop_gradient / TF’s tf.stop_gradient.
Dyn Compatibility
This trait is dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".