Skip to main content

lift_tensor/
ops.rs

1use serde::{Serialize, Deserialize};
2
3/// FP8 quantisation format variants.
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
5pub enum Fp8Format {
6    E4M3,
7    E5M2,
8}
9
10/// Aggregation type for GNN message passing and pooling.
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum AggregationType {
13    Sum,
14    Mean,
15    Max,
16    Min,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub enum TensorOp {
21    // ── Arithmetic ──
22    Add,
23    Sub,
24    Mul,
25    Div,
26    Neg,
27    MatMul,
28    Linear,
29    Conv2D,
30    Embedding,
31
32    // ── Activations ──
33    ReLU,
34    GeLU,
35    SiLU,
36    Sigmoid,
37    Softmax,
38    Tanh,
39    LeakyReLU,
40    ELU,
41    Mish,
42    HardSwish,
43    HardSigmoid,
44
45    // ── Normalisation ──
46    LayerNorm,
47    RMSNorm,
48    BatchNorm,
49    GroupNorm,
50    InstanceNorm,
51
52    // ── Shape operations ──
53    Reshape,
54    Transpose,
55    Concat,
56    Split,
57    Gather,
58    Scatter,
59    Squeeze,
60    Unsqueeze,
61    Permute,
62    Expand,
63    Slice,
64    Pad,
65    Tile,
66
67    // ── Constants ──
68    Constant,
69    Zeros,
70    Ones,
71    Arange,
72    Full,
73
74    // ── Attention variants ──
75    Attention,
76    MultiHeadAttention,
77    MultiQueryAttention,
78    GroupedQueryAttention,
79    FlashAttention,
80    SlidingWindowAttention,
81    CrossAttention,
82    PagedAttention,
83
84    // ── MoE (Mixture of Experts) ──
85    MoEDispatch,
86    MoECombine,
87
88    // ── Convolution variants ──
89    Conv1D,
90    Conv3D,
91    ConvTranspose2D,
92    DepthwiseConv2D,
93    DilatedConv2D,
94
95    // ── Pooling ──
96    MaxPool2D,
97    AvgPool2D,
98    AdaptiveAvgPool2D,
99    GlobalAvgPool,
100
101    // ── Recurrent ──
102    LSTMCell,
103    GRUCell,
104    RNNCell,
105
106    // ── Advanced math ──
107    Einsum,
108    FFT,
109    IFFT,
110    SVD,
111    Eig,
112    Solve,
113    TopK,
114    Sort,
115    Cumsum,
116    Where,
117    Clamp,
118
119    // ── Sparse ──
120    SparseMatMul,
121    SparseEmbedding,
122
123    // ── Quantisation ──
124    Quantize,
125    Dequantize,
126    QuantizeInt4,
127    DequantizeInt4,
128    QuantizeFp8,
129    DequantizeFp8,
130
131    // ── Diffusion / Generative ──
132    UNetDownBlock,
133    UNetUpBlock,
134    TimestepEmbedding,
135
136    // ── GNN (Graph Neural Networks) ──
137    GNNMessagePassing,
138    GNNGlobalPooling,
139
140    // ── Memory management ──
141    Checkpoint,
142    Offload,
143    GradAccumulate,
144
145    // ── Gradient operations ──
146    GradMatMul,
147    GradReLU,
148    GradSoftmax,
149    GradLayerNorm,
150    GradAttention,
151    GradConv2D,
152    GradLinear,
153    GradGeLU,
154
155    // ── Parallelism ──
156    ParallelSplit,
157    ParallelAllReduce,
158    PipelineSend,
159    PipelineReceive,
160
161    // ── Fused operations ──
162    FusedMatMulBiasReLU,
163    FusedMatMulBias,
164    FusedLinearGeLU,
165    FusedAttentionLayerNorm,
166    FusedLinearSiLU,
167    FusedConvBatchNormReLU,
168}
169
170impl TensorOp {
171    pub fn name(&self) -> &'static str {
172        match self {
173            // Arithmetic
174            Self::Add => "tensor.add",
175            Self::Sub => "tensor.sub",
176            Self::Mul => "tensor.mul",
177            Self::Div => "tensor.div",
178            Self::Neg => "tensor.neg",
179            Self::MatMul => "tensor.matmul",
180            Self::Linear => "tensor.linear",
181            Self::Conv2D => "tensor.conv2d",
182            Self::Embedding => "tensor.embedding",
183            // Activations
184            Self::ReLU => "tensor.relu",
185            Self::GeLU => "tensor.gelu",
186            Self::SiLU => "tensor.silu",
187            Self::Sigmoid => "tensor.sigmoid",
188            Self::Softmax => "tensor.softmax",
189            Self::Tanh => "tensor.tanh",
190            Self::LeakyReLU => "tensor.leaky_relu",
191            Self::ELU => "tensor.elu",
192            Self::Mish => "tensor.mish",
193            Self::HardSwish => "tensor.hard_swish",
194            Self::HardSigmoid => "tensor.hard_sigmoid",
195            // Normalisation
196            Self::LayerNorm => "tensor.layernorm",
197            Self::RMSNorm => "tensor.rmsnorm",
198            Self::BatchNorm => "tensor.batchnorm",
199            Self::GroupNorm => "tensor.groupnorm",
200            Self::InstanceNorm => "tensor.instancenorm",
201            // Shape
202            Self::Reshape => "tensor.reshape",
203            Self::Transpose => "tensor.transpose",
204            Self::Concat => "tensor.concat",
205            Self::Split => "tensor.split",
206            Self::Gather => "tensor.gather",
207            Self::Scatter => "tensor.scatter",
208            Self::Squeeze => "tensor.squeeze",
209            Self::Unsqueeze => "tensor.unsqueeze",
210            Self::Permute => "tensor.permute",
211            Self::Expand => "tensor.expand",
212            Self::Slice => "tensor.slice",
213            Self::Pad => "tensor.pad",
214            Self::Tile => "tensor.tile",
215            // Constants
216            Self::Constant => "tensor.constant",
217            Self::Zeros => "tensor.zeros",
218            Self::Ones => "tensor.ones",
219            Self::Arange => "tensor.arange",
220            Self::Full => "tensor.full",
221            // Attention variants
222            Self::Attention => "tensor.attention",
223            Self::MultiHeadAttention => "tensor.multi_head_attention",
224            Self::MultiQueryAttention => "tensor.multi_query_attention",
225            Self::GroupedQueryAttention => "tensor.grouped_query_attention",
226            Self::FlashAttention => "tensor.flash_attention",
227            Self::SlidingWindowAttention => "tensor.sliding_window_attention",
228            Self::CrossAttention => "tensor.cross_attention",
229            Self::PagedAttention => "tensor.paged_attention",
230            // MoE
231            Self::MoEDispatch => "tensor.moe_dispatch",
232            Self::MoECombine => "tensor.moe_combine",
233            // Conv variants
234            Self::Conv1D => "tensor.conv1d",
235            Self::Conv3D => "tensor.conv3d",
236            Self::ConvTranspose2D => "tensor.conv_transpose2d",
237            Self::DepthwiseConv2D => "tensor.depthwise_conv2d",
238            Self::DilatedConv2D => "tensor.dilated_conv2d",
239            // Pooling
240            Self::MaxPool2D => "tensor.maxpool2d",
241            Self::AvgPool2D => "tensor.avgpool2d",
242            Self::AdaptiveAvgPool2D => "tensor.adaptive_avgpool2d",
243            Self::GlobalAvgPool => "tensor.global_avgpool",
244            // Recurrent
245            Self::LSTMCell => "tensor.lstm_cell",
246            Self::GRUCell => "tensor.gru_cell",
247            Self::RNNCell => "tensor.rnn_cell",
248            // Advanced math
249            Self::Einsum => "tensor.einsum",
250            Self::FFT => "tensor.fft",
251            Self::IFFT => "tensor.ifft",
252            Self::SVD => "tensor.svd",
253            Self::Eig => "tensor.eig",
254            Self::Solve => "tensor.solve",
255            Self::TopK => "tensor.topk",
256            Self::Sort => "tensor.sort",
257            Self::Cumsum => "tensor.cumsum",
258            Self::Where => "tensor.where",
259            Self::Clamp => "tensor.clamp",
260            // Sparse
261            Self::SparseMatMul => "tensor.sparse_matmul",
262            Self::SparseEmbedding => "tensor.sparse_embedding",
263            // Quantisation
264            Self::Quantize => "tensor.quantize",
265            Self::Dequantize => "tensor.dequantize",
266            Self::QuantizeInt4 => "tensor.quantize_int4",
267            Self::DequantizeInt4 => "tensor.dequantize_int4",
268            Self::QuantizeFp8 => "tensor.quantize_fp8",
269            Self::DequantizeFp8 => "tensor.dequantize_fp8",
270            // Diffusion / Generative
271            Self::UNetDownBlock => "tensor.unet_down_block",
272            Self::UNetUpBlock => "tensor.unet_up_block",
273            Self::TimestepEmbedding => "tensor.timestep_embedding",
274            // GNN
275            Self::GNNMessagePassing => "tensor.gnn_message_passing",
276            Self::GNNGlobalPooling => "tensor.gnn_global_pooling",
277            // Memory management
278            Self::Checkpoint => "tensor.checkpoint",
279            Self::Offload => "tensor.offload",
280            Self::GradAccumulate => "tensor.grad_accumulate",
281            // Gradient operations
282            Self::GradMatMul => "tensor.grad_matmul",
283            Self::GradReLU => "tensor.grad_relu",
284            Self::GradSoftmax => "tensor.grad_softmax",
285            Self::GradLayerNorm => "tensor.grad_layernorm",
286            Self::GradAttention => "tensor.grad_attention",
287            Self::GradConv2D => "tensor.grad_conv2d",
288            Self::GradLinear => "tensor.grad_linear",
289            Self::GradGeLU => "tensor.grad_gelu",
290            // Parallelism
291            Self::ParallelSplit => "tensor.parallel_split",
292            Self::ParallelAllReduce => "tensor.parallel_allreduce",
293            Self::PipelineSend => "tensor.pipeline_send",
294            Self::PipelineReceive => "tensor.pipeline_receive",
295            // Fused operations
296            Self::FusedMatMulBiasReLU => "tensor.fused_matmul_bias_relu",
297            Self::FusedMatMulBias => "tensor.fused_matmul_bias",
298            Self::FusedLinearGeLU => "tensor.fused_linear_gelu",
299            Self::FusedAttentionLayerNorm => "tensor.fused_attention_layernorm",
300            Self::FusedLinearSiLU => "tensor.fused_linear_silu",
301            Self::FusedConvBatchNormReLU => "tensor.fused_conv_batchnorm_relu",
302        }
303    }
304
305    pub fn from_name(name: &str) -> Option<Self> {
306        match name {
307            "tensor.add" => Some(Self::Add),
308            "tensor.sub" => Some(Self::Sub),
309            "tensor.mul" => Some(Self::Mul),
310            "tensor.div" => Some(Self::Div),
311            "tensor.neg" => Some(Self::Neg),
312            "tensor.matmul" => Some(Self::MatMul),
313            "tensor.linear" => Some(Self::Linear),
314            "tensor.conv2d" => Some(Self::Conv2D),
315            "tensor.embedding" => Some(Self::Embedding),
316            "tensor.relu" => Some(Self::ReLU),
317            "tensor.gelu" => Some(Self::GeLU),
318            "tensor.silu" => Some(Self::SiLU),
319            "tensor.sigmoid" => Some(Self::Sigmoid),
320            "tensor.softmax" => Some(Self::Softmax),
321            "tensor.tanh" => Some(Self::Tanh),
322            "tensor.leaky_relu" => Some(Self::LeakyReLU),
323            "tensor.elu" => Some(Self::ELU),
324            "tensor.mish" => Some(Self::Mish),
325            "tensor.hard_swish" => Some(Self::HardSwish),
326            "tensor.hard_sigmoid" => Some(Self::HardSigmoid),
327            "tensor.layernorm" => Some(Self::LayerNorm),
328            "tensor.rmsnorm" => Some(Self::RMSNorm),
329            "tensor.batchnorm" => Some(Self::BatchNorm),
330            "tensor.groupnorm" => Some(Self::GroupNorm),
331            "tensor.instancenorm" => Some(Self::InstanceNorm),
332            "tensor.reshape" => Some(Self::Reshape),
333            "tensor.transpose" => Some(Self::Transpose),
334            "tensor.concat" => Some(Self::Concat),
335            "tensor.split" => Some(Self::Split),
336            "tensor.gather" => Some(Self::Gather),
337            "tensor.scatter" => Some(Self::Scatter),
338            "tensor.squeeze" => Some(Self::Squeeze),
339            "tensor.unsqueeze" => Some(Self::Unsqueeze),
340            "tensor.permute" => Some(Self::Permute),
341            "tensor.expand" => Some(Self::Expand),
342            "tensor.slice" => Some(Self::Slice),
343            "tensor.pad" => Some(Self::Pad),
344            "tensor.tile" => Some(Self::Tile),
345            "tensor.constant" => Some(Self::Constant),
346            "tensor.zeros" => Some(Self::Zeros),
347            "tensor.ones" => Some(Self::Ones),
348            "tensor.arange" => Some(Self::Arange),
349            "tensor.full" => Some(Self::Full),
350            "tensor.attention" => Some(Self::Attention),
351            "tensor.multi_head_attention" => Some(Self::MultiHeadAttention),
352            "tensor.multi_query_attention" => Some(Self::MultiQueryAttention),
353            "tensor.grouped_query_attention" => Some(Self::GroupedQueryAttention),
354            "tensor.flash_attention" => Some(Self::FlashAttention),
355            "tensor.sliding_window_attention" => Some(Self::SlidingWindowAttention),
356            "tensor.cross_attention" => Some(Self::CrossAttention),
357            "tensor.paged_attention" => Some(Self::PagedAttention),
358            "tensor.moe_dispatch" => Some(Self::MoEDispatch),
359            "tensor.moe_combine" => Some(Self::MoECombine),
360            "tensor.conv1d" => Some(Self::Conv1D),
361            "tensor.conv3d" => Some(Self::Conv3D),
362            "tensor.conv_transpose2d" => Some(Self::ConvTranspose2D),
363            "tensor.depthwise_conv2d" => Some(Self::DepthwiseConv2D),
364            "tensor.dilated_conv2d" => Some(Self::DilatedConv2D),
365            "tensor.maxpool2d" => Some(Self::MaxPool2D),
366            "tensor.avgpool2d" => Some(Self::AvgPool2D),
367            "tensor.adaptive_avgpool2d" => Some(Self::AdaptiveAvgPool2D),
368            "tensor.global_avgpool" => Some(Self::GlobalAvgPool),
369            "tensor.lstm_cell" => Some(Self::LSTMCell),
370            "tensor.gru_cell" => Some(Self::GRUCell),
371            "tensor.rnn_cell" => Some(Self::RNNCell),
372            "tensor.einsum" => Some(Self::Einsum),
373            "tensor.fft" => Some(Self::FFT),
374            "tensor.ifft" => Some(Self::IFFT),
375            "tensor.svd" => Some(Self::SVD),
376            "tensor.eig" => Some(Self::Eig),
377            "tensor.solve" => Some(Self::Solve),
378            "tensor.topk" => Some(Self::TopK),
379            "tensor.sort" => Some(Self::Sort),
380            "tensor.cumsum" => Some(Self::Cumsum),
381            "tensor.where" => Some(Self::Where),
382            "tensor.clamp" => Some(Self::Clamp),
383            "tensor.sparse_matmul" => Some(Self::SparseMatMul),
384            "tensor.sparse_embedding" => Some(Self::SparseEmbedding),
385            "tensor.quantize" => Some(Self::Quantize),
386            "tensor.dequantize" => Some(Self::Dequantize),
387            "tensor.quantize_int4" => Some(Self::QuantizeInt4),
388            "tensor.dequantize_int4" => Some(Self::DequantizeInt4),
389            "tensor.quantize_fp8" => Some(Self::QuantizeFp8),
390            "tensor.dequantize_fp8" => Some(Self::DequantizeFp8),
391            "tensor.unet_down_block" => Some(Self::UNetDownBlock),
392            "tensor.unet_up_block" => Some(Self::UNetUpBlock),
393            "tensor.timestep_embedding" => Some(Self::TimestepEmbedding),
394            "tensor.gnn_message_passing" => Some(Self::GNNMessagePassing),
395            "tensor.gnn_global_pooling" => Some(Self::GNNGlobalPooling),
396            "tensor.checkpoint" => Some(Self::Checkpoint),
397            "tensor.offload" => Some(Self::Offload),
398            "tensor.grad_accumulate" => Some(Self::GradAccumulate),
399            "tensor.grad_matmul" => Some(Self::GradMatMul),
400            "tensor.grad_relu" => Some(Self::GradReLU),
401            "tensor.grad_softmax" => Some(Self::GradSoftmax),
402            "tensor.grad_layernorm" => Some(Self::GradLayerNorm),
403            "tensor.grad_attention" => Some(Self::GradAttention),
404            "tensor.grad_conv2d" => Some(Self::GradConv2D),
405            "tensor.grad_linear" => Some(Self::GradLinear),
406            "tensor.grad_gelu" => Some(Self::GradGeLU),
407            "tensor.parallel_split" => Some(Self::ParallelSplit),
408            "tensor.parallel_allreduce" => Some(Self::ParallelAllReduce),
409            "tensor.pipeline_send" => Some(Self::PipelineSend),
410            "tensor.pipeline_receive" => Some(Self::PipelineReceive),
411            "tensor.fused_matmul_bias_relu" => Some(Self::FusedMatMulBiasReLU),
412            "tensor.fused_matmul_bias" => Some(Self::FusedMatMulBias),
413            "tensor.fused_linear_gelu" => Some(Self::FusedLinearGeLU),
414            "tensor.fused_attention_layernorm" => Some(Self::FusedAttentionLayerNorm),
415            "tensor.fused_linear_silu" => Some(Self::FusedLinearSiLU),
416            "tensor.fused_conv_batchnorm_relu" => Some(Self::FusedConvBatchNormReLU),
417            _ => None,
418        }
419    }
420
421    pub fn num_inputs(&self) -> (usize, usize) {
422        match self {
423            // Unary (1 input)
424            Self::Neg | Self::ReLU | Self::GeLU | Self::SiLU |
425            Self::Sigmoid | Self::Tanh | Self::LeakyReLU | Self::ELU |
426            Self::Mish | Self::HardSwish | Self::HardSigmoid |
427            Self::Reshape | Self::Transpose | Self::Squeeze | Self::Unsqueeze |
428            Self::Permute | Self::Expand | Self::Slice | Self::Pad | Self::Tile |
429            Self::Quantize | Self::Dequantize |
430            Self::QuantizeInt4 | Self::DequantizeInt4 |
431            Self::QuantizeFp8 | Self::DequantizeFp8 |
432            Self::Offload | Self::Checkpoint |
433            Self::GradReLU | Self::GradGeLU |
434            Self::Softmax | Self::Cumsum | Self::Sort | Self::TopK |
435            Self::FFT | Self::IFFT | Self::SVD | Self::Eig |
436            Self::GlobalAvgPool | Self::AdaptiveAvgPool2D |
437            Self::GNNGlobalPooling => (1, 1),
438
439            // Binary (2 inputs)
440            Self::Add | Self::Sub | Self::Mul | Self::Div |
441            Self::MatMul | Self::SparseMatMul |
442            Self::GradMatMul | Self::Embedding | Self::SparseEmbedding |
443            Self::Conv2D | Self::Conv1D | Self::Conv3D |
444            Self::ConvTranspose2D | Self::DepthwiseConv2D | Self::DilatedConv2D |
445            Self::MaxPool2D | Self::AvgPool2D |
446            Self::Solve | Self::GradConv2D | Self::Concat => (2, 2),
447
448            // Ternary (3 inputs)
449            Self::Linear | Self::FusedMatMulBias | Self::FusedLinearGeLU |
450            Self::FusedMatMulBiasReLU | Self::FusedLinearSiLU |
451            Self::Where | Self::Clamp |
452            Self::GradLinear => (3, 3),
453
454            // Attention (3-4 inputs: Q, K, V, optional mask)
455            Self::Attention | Self::MultiHeadAttention |
456            Self::MultiQueryAttention | Self::GroupedQueryAttention |
457            Self::FlashAttention | Self::SlidingWindowAttention |
458            Self::CrossAttention | Self::GradAttention => (3, 4),
459            Self::PagedAttention => (3, 5),
460            Self::FusedAttentionLayerNorm => (3, 5),
461
462            // Normalisation (variable: input + scale + bias)
463            Self::LayerNorm | Self::RMSNorm | Self::GroupNorm |
464            Self::InstanceNorm | Self::GradLayerNorm => (2, 3),
465            Self::BatchNorm | Self::FusedConvBatchNormReLU => (3, 5),
466
467            // Recurrent (2 inputs: input, hidden state)
468            Self::LSTMCell | Self::GRUCell | Self::RNNCell => (2, 2),
469
470            // GNN (2 inputs: node features, edge index)
471            Self::GNNMessagePassing => (2, 3),
472
473            // Diffusion blocks (2-3 inputs)
474            Self::UNetDownBlock | Self::UNetUpBlock => (2, 3),
475            Self::TimestepEmbedding => (1, 1),
476
477            // MoE
478            Self::MoEDispatch => (2, 3),
479            Self::MoECombine => (2, 3),
480
481            // Constants (0 inputs)
482            Self::Constant | Self::Zeros | Self::Ones |
483            Self::Arange | Self::Full => (0, 0),
484
485            // Einsum (variable)
486            Self::Einsum => (1, usize::MAX),
487
488            // Parallelism / memory
489            Self::GradAccumulate | Self::GradSoftmax |
490            Self::ParallelSplit | Self::ParallelAllReduce |
491            Self::PipelineSend | Self::PipelineReceive |
492            Self::Gather | Self::Scatter | Self::Split => (1, usize::MAX),
493        }
494    }
495
496    /// Returns the asymptotic FLOPs formula as a human-readable string.
497    pub fn flops_formula(&self) -> &'static str {
498        match self {
499            Self::MatMul | Self::SparseMatMul => "2*M*N*K",
500            Self::Linear => "2*M*N*K + N (bias)",
501            Self::Add | Self::Sub | Self::Mul | Self::Div => "N (element count)",
502            Self::ReLU | Self::Sigmoid | Self::Tanh |
503            Self::LeakyReLU | Self::ELU | Self::HardSigmoid => "N",
504            Self::GeLU | Self::SiLU | Self::Mish | Self::HardSwish => "~8*N",
505            Self::Softmax => "5*N (exp + sum + div)",
506            Self::LayerNorm | Self::RMSNorm |
507            Self::GroupNorm | Self::InstanceNorm => "7*N",
508            Self::BatchNorm => "5*N",
509            Self::Conv2D | Self::DepthwiseConv2D | Self::DilatedConv2D => "2*Cout*Cin*Kh*Kw*Oh*Ow",
510            Self::Conv1D => "2*Cout*Cin*K*Oout",
511            Self::Conv3D => "2*Cout*Cin*Kd*Kh*Kw*Od*Oh*Ow",
512            Self::Attention | Self::MultiHeadAttention |
513            Self::GroupedQueryAttention | Self::MultiQueryAttention |
514            Self::FlashAttention | Self::SlidingWindowAttention |
515            Self::CrossAttention => "2*B*H*(S^2*D + S*D^2)",
516            Self::LSTMCell => "4*(input_size+hidden)*hidden*2",
517            Self::GRUCell => "3*(input_size+hidden)*hidden*2",
518            Self::RNNCell => "(input_size+hidden)*hidden*2",
519            Self::FFT | Self::IFFT => "5*N*log2(N)",
520            Self::Einsum => "depends on equation",
521            Self::MaxPool2D | Self::AvgPool2D | Self::AdaptiveAvgPool2D |
522            Self::GlobalAvgPool => "N (comparisons or additions)",
523            Self::Reshape | Self::Transpose | Self::Squeeze | Self::Unsqueeze |
524            Self::Permute | Self::Expand | Self::Slice | Self::Pad | Self::Tile |
525            Self::Concat | Self::Split | Self::Gather | Self::Scatter => "0 (no compute)",
526            _ => "varies",
527        }
528    }
529
530    /// Returns `true` if this op performs no arithmetic (zero FLOPs).
531    #[inline]
532    pub fn is_zero_flop(&self) -> bool {
533        matches!(self,
534            Self::Reshape | Self::Transpose | Self::Squeeze | Self::Unsqueeze |
535            Self::Permute | Self::Expand | Self::Slice | Self::Pad | Self::Tile |
536            Self::Concat | Self::Split | Self::Gather | Self::Scatter |
537            Self::Constant | Self::Zeros | Self::Ones | Self::Arange | Self::Full |
538            Self::Checkpoint | Self::Offload |
539            Self::PipelineSend | Self::PipelineReceive |
540            Self::ParallelSplit | Self::ParallelAllReduce
541        )
542    }
543
544    /// Returns `true` if this is an element-wise (unary or binary) activation.
545    #[inline]
546    pub fn is_activation(&self) -> bool {
547        matches!(self,
548            Self::ReLU | Self::GeLU | Self::SiLU | Self::Sigmoid | Self::Tanh |
549            Self::LeakyReLU | Self::ELU | Self::Mish |
550            Self::HardSwish | Self::HardSigmoid
551        )
552    }
553
554    /// Returns `true` if this is an attention variant.
555    #[inline]
556    pub fn is_attention(&self) -> bool {
557        matches!(self,
558            Self::Attention | Self::MultiHeadAttention | Self::MultiQueryAttention |
559            Self::GroupedQueryAttention | Self::FlashAttention |
560            Self::SlidingWindowAttention | Self::CrossAttention | Self::PagedAttention
561        )
562    }
563
564    /// Returns `true` if this is a convolution variant.
565    #[inline]
566    pub fn is_convolution(&self) -> bool {
567        matches!(self,
568            Self::Conv1D | Self::Conv2D | Self::Conv3D |
569            Self::ConvTranspose2D | Self::DepthwiseConv2D | Self::DilatedConv2D
570        )
571    }
572
573    /// Returns `true` if this is a normalisation op.
574    #[inline]
575    pub fn is_normalisation(&self) -> bool {
576        matches!(self,
577            Self::LayerNorm | Self::RMSNorm | Self::BatchNorm |
578            Self::GroupNorm | Self::InstanceNorm
579        )
580    }
581
582    /// Returns `true` if this is a fused operation.
583    #[inline]
584    pub fn is_fused(&self) -> bool {
585        matches!(self,
586            Self::FusedMatMulBiasReLU | Self::FusedMatMulBias |
587            Self::FusedLinearGeLU | Self::FusedAttentionLayerNorm |
588            Self::FusedLinearSiLU | Self::FusedConvBatchNormReLU
589        )
590    }
591
592    /// Returns `true` if this is a gradient (backward) operation.
593    #[inline]
594    pub fn is_gradient(&self) -> bool {
595        matches!(self,
596            Self::GradMatMul | Self::GradReLU | Self::GradSoftmax |
597            Self::GradLayerNorm | Self::GradAttention |
598            Self::GradConv2D | Self::GradLinear | Self::GradGeLU
599        )
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `TensorOp` variant, in declaration order.
    ///
    /// Rust cannot enumerate enum variants without a helper crate, so this list
    /// must be extended by hand whenever a variant is added to the enum.
    const ALL_OPS: &[TensorOp] = &[
        // Arithmetic
        TensorOp::Add,
        TensorOp::Sub,
        TensorOp::Mul,
        TensorOp::Div,
        TensorOp::Neg,
        TensorOp::MatMul,
        TensorOp::Linear,
        TensorOp::Conv2D,
        TensorOp::Embedding,
        // Activations
        TensorOp::ReLU,
        TensorOp::GeLU,
        TensorOp::SiLU,
        TensorOp::Sigmoid,
        TensorOp::Softmax,
        TensorOp::Tanh,
        TensorOp::LeakyReLU,
        TensorOp::ELU,
        TensorOp::Mish,
        TensorOp::HardSwish,
        TensorOp::HardSigmoid,
        // Normalisation
        TensorOp::LayerNorm,
        TensorOp::RMSNorm,
        TensorOp::BatchNorm,
        TensorOp::GroupNorm,
        TensorOp::InstanceNorm,
        // Shape
        TensorOp::Reshape,
        TensorOp::Transpose,
        TensorOp::Concat,
        TensorOp::Split,
        TensorOp::Gather,
        TensorOp::Scatter,
        TensorOp::Squeeze,
        TensorOp::Unsqueeze,
        TensorOp::Permute,
        TensorOp::Expand,
        TensorOp::Slice,
        TensorOp::Pad,
        TensorOp::Tile,
        // Constants
        TensorOp::Constant,
        TensorOp::Zeros,
        TensorOp::Ones,
        TensorOp::Arange,
        TensorOp::Full,
        // Attention variants
        TensorOp::Attention,
        TensorOp::MultiHeadAttention,
        TensorOp::MultiQueryAttention,
        TensorOp::GroupedQueryAttention,
        TensorOp::FlashAttention,
        TensorOp::SlidingWindowAttention,
        TensorOp::CrossAttention,
        TensorOp::PagedAttention,
        // MoE
        TensorOp::MoEDispatch,
        TensorOp::MoECombine,
        // Convolution variants
        TensorOp::Conv1D,
        TensorOp::Conv3D,
        TensorOp::ConvTranspose2D,
        TensorOp::DepthwiseConv2D,
        TensorOp::DilatedConv2D,
        // Pooling
        TensorOp::MaxPool2D,
        TensorOp::AvgPool2D,
        TensorOp::AdaptiveAvgPool2D,
        TensorOp::GlobalAvgPool,
        // Recurrent
        TensorOp::LSTMCell,
        TensorOp::GRUCell,
        TensorOp::RNNCell,
        // Advanced math
        TensorOp::Einsum,
        TensorOp::FFT,
        TensorOp::IFFT,
        TensorOp::SVD,
        TensorOp::Eig,
        TensorOp::Solve,
        TensorOp::TopK,
        TensorOp::Sort,
        TensorOp::Cumsum,
        TensorOp::Where,
        TensorOp::Clamp,
        // Sparse
        TensorOp::SparseMatMul,
        TensorOp::SparseEmbedding,
        // Quantisation
        TensorOp::Quantize,
        TensorOp::Dequantize,
        TensorOp::QuantizeInt4,
        TensorOp::DequantizeInt4,
        TensorOp::QuantizeFp8,
        TensorOp::DequantizeFp8,
        // Diffusion / Generative
        TensorOp::UNetDownBlock,
        TensorOp::UNetUpBlock,
        TensorOp::TimestepEmbedding,
        // GNN
        TensorOp::GNNMessagePassing,
        TensorOp::GNNGlobalPooling,
        // Memory management
        TensorOp::Checkpoint,
        TensorOp::Offload,
        TensorOp::GradAccumulate,
        // Gradient operations
        TensorOp::GradMatMul,
        TensorOp::GradReLU,
        TensorOp::GradSoftmax,
        TensorOp::GradLayerNorm,
        TensorOp::GradAttention,
        TensorOp::GradConv2D,
        TensorOp::GradLinear,
        TensorOp::GradGeLU,
        // Parallelism
        TensorOp::ParallelSplit,
        TensorOp::ParallelAllReduce,
        TensorOp::PipelineSend,
        TensorOp::PipelineReceive,
        // Fused operations
        TensorOp::FusedMatMulBiasReLU,
        TensorOp::FusedMatMulBias,
        TensorOp::FusedLinearGeLU,
        TensorOp::FusedAttentionLayerNorm,
        TensorOp::FusedLinearSiLU,
        TensorOp::FusedConvBatchNormReLU,
    ];

    /// `from_name(op.name())` must give back the same op for every variant.
    #[test]
    fn test_op_name_roundtrip() {
        for op in ALL_OPS {
            let name = op.name();
            let recovered = TensorOp::from_name(name);
            assert_eq!(
                recovered.as_ref(),
                Some(op),
                "round-trip failed for {}",
                name
            );
        }
    }

    /// Every variant has a non-empty, `tensor.`-prefixed and unique name.
    #[test]
    fn test_all_ops_have_names() {
        let mut seen = HashSet::new();
        for op in ALL_OPS {
            let name = op.name();
            assert!(!name.is_empty(), "{:?} has an empty name", op);
            assert!(
                name.starts_with("tensor."),
                "{:?} has an unprefixed name: {}",
                op,
                name
            );
            // A repeated name would make `from_name` ambiguous; this also
            // catches an op accidentally listed twice in `ALL_OPS`.
            assert!(seen.insert(name), "duplicate op name: {}", name);
            assert!(TensorOp::from_name(name).is_some());
        }
    }

    /// Strings that are not exact op identifiers are rejected.
    #[test]
    fn test_unknown_names_are_rejected() {
        assert_eq!(TensorOp::from_name(""), None);
        assert_eq!(TensorOp::from_name("tensor."), None);
        assert_eq!(TensorOp::from_name("tensor.unknown"), None);
        assert_eq!(TensorOp::from_name("matmul"), None);
        assert_eq!(TensorOp::from_name("TENSOR.ADD"), None);
        assert_eq!(TensorOp::from_name(" tensor.add"), None);
    }

    /// The `(min, max)` input bounds are well-formed for every variant.
    #[test]
    fn test_num_inputs_bounds_are_ordered() {
        for op in ALL_OPS {
            let (min, max) = op.num_inputs();
            assert!(min <= max, "{:?} reports min {} > max {}", op, min, max);
        }
    }
}