1use serde::{Serialize, Deserialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
5pub enum Fp8Format {
6 E4M3,
7 E5M2,
8}
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum AggregationType {
13 Sum,
14 Mean,
15 Max,
16 Min,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub enum TensorOp {
21 Add,
23 Sub,
24 Mul,
25 Div,
26 Neg,
27 MatMul,
28 Linear,
29 Conv2D,
30 Embedding,
31
32 ReLU,
34 GeLU,
35 SiLU,
36 Sigmoid,
37 Softmax,
38 Tanh,
39 LeakyReLU,
40 ELU,
41 Mish,
42 HardSwish,
43 HardSigmoid,
44
45 LayerNorm,
47 RMSNorm,
48 BatchNorm,
49 GroupNorm,
50 InstanceNorm,
51
52 Reshape,
54 Transpose,
55 Concat,
56 Split,
57 Gather,
58 Scatter,
59 Squeeze,
60 Unsqueeze,
61 Permute,
62 Expand,
63 Slice,
64 Pad,
65 Tile,
66
67 Constant,
69 Zeros,
70 Ones,
71 Arange,
72 Full,
73
74 Attention,
76 MultiHeadAttention,
77 MultiQueryAttention,
78 GroupedQueryAttention,
79 FlashAttention,
80 SlidingWindowAttention,
81 CrossAttention,
82 PagedAttention,
83
84 MoEDispatch,
86 MoECombine,
87
88 Conv1D,
90 Conv3D,
91 ConvTranspose2D,
92 DepthwiseConv2D,
93 DilatedConv2D,
94
95 MaxPool2D,
97 AvgPool2D,
98 AdaptiveAvgPool2D,
99 GlobalAvgPool,
100
101 LSTMCell,
103 GRUCell,
104 RNNCell,
105
106 Einsum,
108 FFT,
109 IFFT,
110 SVD,
111 Eig,
112 Solve,
113 TopK,
114 Sort,
115 Cumsum,
116 Where,
117 Clamp,
118
119 SparseMatMul,
121 SparseEmbedding,
122
123 Quantize,
125 Dequantize,
126 QuantizeInt4,
127 DequantizeInt4,
128 QuantizeFp8,
129 DequantizeFp8,
130
131 UNetDownBlock,
133 UNetUpBlock,
134 TimestepEmbedding,
135
136 GNNMessagePassing,
138 GNNGlobalPooling,
139
140 Checkpoint,
142 Offload,
143 GradAccumulate,
144
145 GradMatMul,
147 GradReLU,
148 GradSoftmax,
149 GradLayerNorm,
150 GradAttention,
151 GradConv2D,
152 GradLinear,
153 GradGeLU,
154
155 ParallelSplit,
157 ParallelAllReduce,
158 PipelineSend,
159 PipelineReceive,
160
161 FusedMatMulBiasReLU,
163 FusedMatMulBias,
164 FusedLinearGeLU,
165 FusedAttentionLayerNorm,
166 FusedLinearSiLU,
167 FusedConvBatchNormReLU,
168}
169
170impl TensorOp {
171 pub fn name(&self) -> &'static str {
172 match self {
173 Self::Add => "tensor.add",
175 Self::Sub => "tensor.sub",
176 Self::Mul => "tensor.mul",
177 Self::Div => "tensor.div",
178 Self::Neg => "tensor.neg",
179 Self::MatMul => "tensor.matmul",
180 Self::Linear => "tensor.linear",
181 Self::Conv2D => "tensor.conv2d",
182 Self::Embedding => "tensor.embedding",
183 Self::ReLU => "tensor.relu",
185 Self::GeLU => "tensor.gelu",
186 Self::SiLU => "tensor.silu",
187 Self::Sigmoid => "tensor.sigmoid",
188 Self::Softmax => "tensor.softmax",
189 Self::Tanh => "tensor.tanh",
190 Self::LeakyReLU => "tensor.leaky_relu",
191 Self::ELU => "tensor.elu",
192 Self::Mish => "tensor.mish",
193 Self::HardSwish => "tensor.hard_swish",
194 Self::HardSigmoid => "tensor.hard_sigmoid",
195 Self::LayerNorm => "tensor.layernorm",
197 Self::RMSNorm => "tensor.rmsnorm",
198 Self::BatchNorm => "tensor.batchnorm",
199 Self::GroupNorm => "tensor.groupnorm",
200 Self::InstanceNorm => "tensor.instancenorm",
201 Self::Reshape => "tensor.reshape",
203 Self::Transpose => "tensor.transpose",
204 Self::Concat => "tensor.concat",
205 Self::Split => "tensor.split",
206 Self::Gather => "tensor.gather",
207 Self::Scatter => "tensor.scatter",
208 Self::Squeeze => "tensor.squeeze",
209 Self::Unsqueeze => "tensor.unsqueeze",
210 Self::Permute => "tensor.permute",
211 Self::Expand => "tensor.expand",
212 Self::Slice => "tensor.slice",
213 Self::Pad => "tensor.pad",
214 Self::Tile => "tensor.tile",
215 Self::Constant => "tensor.constant",
217 Self::Zeros => "tensor.zeros",
218 Self::Ones => "tensor.ones",
219 Self::Arange => "tensor.arange",
220 Self::Full => "tensor.full",
221 Self::Attention => "tensor.attention",
223 Self::MultiHeadAttention => "tensor.multi_head_attention",
224 Self::MultiQueryAttention => "tensor.multi_query_attention",
225 Self::GroupedQueryAttention => "tensor.grouped_query_attention",
226 Self::FlashAttention => "tensor.flash_attention",
227 Self::SlidingWindowAttention => "tensor.sliding_window_attention",
228 Self::CrossAttention => "tensor.cross_attention",
229 Self::PagedAttention => "tensor.paged_attention",
230 Self::MoEDispatch => "tensor.moe_dispatch",
232 Self::MoECombine => "tensor.moe_combine",
233 Self::Conv1D => "tensor.conv1d",
235 Self::Conv3D => "tensor.conv3d",
236 Self::ConvTranspose2D => "tensor.conv_transpose2d",
237 Self::DepthwiseConv2D => "tensor.depthwise_conv2d",
238 Self::DilatedConv2D => "tensor.dilated_conv2d",
239 Self::MaxPool2D => "tensor.maxpool2d",
241 Self::AvgPool2D => "tensor.avgpool2d",
242 Self::AdaptiveAvgPool2D => "tensor.adaptive_avgpool2d",
243 Self::GlobalAvgPool => "tensor.global_avgpool",
244 Self::LSTMCell => "tensor.lstm_cell",
246 Self::GRUCell => "tensor.gru_cell",
247 Self::RNNCell => "tensor.rnn_cell",
248 Self::Einsum => "tensor.einsum",
250 Self::FFT => "tensor.fft",
251 Self::IFFT => "tensor.ifft",
252 Self::SVD => "tensor.svd",
253 Self::Eig => "tensor.eig",
254 Self::Solve => "tensor.solve",
255 Self::TopK => "tensor.topk",
256 Self::Sort => "tensor.sort",
257 Self::Cumsum => "tensor.cumsum",
258 Self::Where => "tensor.where",
259 Self::Clamp => "tensor.clamp",
260 Self::SparseMatMul => "tensor.sparse_matmul",
262 Self::SparseEmbedding => "tensor.sparse_embedding",
263 Self::Quantize => "tensor.quantize",
265 Self::Dequantize => "tensor.dequantize",
266 Self::QuantizeInt4 => "tensor.quantize_int4",
267 Self::DequantizeInt4 => "tensor.dequantize_int4",
268 Self::QuantizeFp8 => "tensor.quantize_fp8",
269 Self::DequantizeFp8 => "tensor.dequantize_fp8",
270 Self::UNetDownBlock => "tensor.unet_down_block",
272 Self::UNetUpBlock => "tensor.unet_up_block",
273 Self::TimestepEmbedding => "tensor.timestep_embedding",
274 Self::GNNMessagePassing => "tensor.gnn_message_passing",
276 Self::GNNGlobalPooling => "tensor.gnn_global_pooling",
277 Self::Checkpoint => "tensor.checkpoint",
279 Self::Offload => "tensor.offload",
280 Self::GradAccumulate => "tensor.grad_accumulate",
281 Self::GradMatMul => "tensor.grad_matmul",
283 Self::GradReLU => "tensor.grad_relu",
284 Self::GradSoftmax => "tensor.grad_softmax",
285 Self::GradLayerNorm => "tensor.grad_layernorm",
286 Self::GradAttention => "tensor.grad_attention",
287 Self::GradConv2D => "tensor.grad_conv2d",
288 Self::GradLinear => "tensor.grad_linear",
289 Self::GradGeLU => "tensor.grad_gelu",
290 Self::ParallelSplit => "tensor.parallel_split",
292 Self::ParallelAllReduce => "tensor.parallel_allreduce",
293 Self::PipelineSend => "tensor.pipeline_send",
294 Self::PipelineReceive => "tensor.pipeline_receive",
295 Self::FusedMatMulBiasReLU => "tensor.fused_matmul_bias_relu",
297 Self::FusedMatMulBias => "tensor.fused_matmul_bias",
298 Self::FusedLinearGeLU => "tensor.fused_linear_gelu",
299 Self::FusedAttentionLayerNorm => "tensor.fused_attention_layernorm",
300 Self::FusedLinearSiLU => "tensor.fused_linear_silu",
301 Self::FusedConvBatchNormReLU => "tensor.fused_conv_batchnorm_relu",
302 }
303 }
304
305 pub fn from_name(name: &str) -> Option<Self> {
306 match name {
307 "tensor.add" => Some(Self::Add),
308 "tensor.sub" => Some(Self::Sub),
309 "tensor.mul" => Some(Self::Mul),
310 "tensor.div" => Some(Self::Div),
311 "tensor.neg" => Some(Self::Neg),
312 "tensor.matmul" => Some(Self::MatMul),
313 "tensor.linear" => Some(Self::Linear),
314 "tensor.conv2d" => Some(Self::Conv2D),
315 "tensor.embedding" => Some(Self::Embedding),
316 "tensor.relu" => Some(Self::ReLU),
317 "tensor.gelu" => Some(Self::GeLU),
318 "tensor.silu" => Some(Self::SiLU),
319 "tensor.sigmoid" => Some(Self::Sigmoid),
320 "tensor.softmax" => Some(Self::Softmax),
321 "tensor.tanh" => Some(Self::Tanh),
322 "tensor.leaky_relu" => Some(Self::LeakyReLU),
323 "tensor.elu" => Some(Self::ELU),
324 "tensor.mish" => Some(Self::Mish),
325 "tensor.hard_swish" => Some(Self::HardSwish),
326 "tensor.hard_sigmoid" => Some(Self::HardSigmoid),
327 "tensor.layernorm" => Some(Self::LayerNorm),
328 "tensor.rmsnorm" => Some(Self::RMSNorm),
329 "tensor.batchnorm" => Some(Self::BatchNorm),
330 "tensor.groupnorm" => Some(Self::GroupNorm),
331 "tensor.instancenorm" => Some(Self::InstanceNorm),
332 "tensor.reshape" => Some(Self::Reshape),
333 "tensor.transpose" => Some(Self::Transpose),
334 "tensor.concat" => Some(Self::Concat),
335 "tensor.split" => Some(Self::Split),
336 "tensor.gather" => Some(Self::Gather),
337 "tensor.scatter" => Some(Self::Scatter),
338 "tensor.squeeze" => Some(Self::Squeeze),
339 "tensor.unsqueeze" => Some(Self::Unsqueeze),
340 "tensor.permute" => Some(Self::Permute),
341 "tensor.expand" => Some(Self::Expand),
342 "tensor.slice" => Some(Self::Slice),
343 "tensor.pad" => Some(Self::Pad),
344 "tensor.tile" => Some(Self::Tile),
345 "tensor.constant" => Some(Self::Constant),
346 "tensor.zeros" => Some(Self::Zeros),
347 "tensor.ones" => Some(Self::Ones),
348 "tensor.arange" => Some(Self::Arange),
349 "tensor.full" => Some(Self::Full),
350 "tensor.attention" => Some(Self::Attention),
351 "tensor.multi_head_attention" => Some(Self::MultiHeadAttention),
352 "tensor.multi_query_attention" => Some(Self::MultiQueryAttention),
353 "tensor.grouped_query_attention" => Some(Self::GroupedQueryAttention),
354 "tensor.flash_attention" => Some(Self::FlashAttention),
355 "tensor.sliding_window_attention" => Some(Self::SlidingWindowAttention),
356 "tensor.cross_attention" => Some(Self::CrossAttention),
357 "tensor.paged_attention" => Some(Self::PagedAttention),
358 "tensor.moe_dispatch" => Some(Self::MoEDispatch),
359 "tensor.moe_combine" => Some(Self::MoECombine),
360 "tensor.conv1d" => Some(Self::Conv1D),
361 "tensor.conv3d" => Some(Self::Conv3D),
362 "tensor.conv_transpose2d" => Some(Self::ConvTranspose2D),
363 "tensor.depthwise_conv2d" => Some(Self::DepthwiseConv2D),
364 "tensor.dilated_conv2d" => Some(Self::DilatedConv2D),
365 "tensor.maxpool2d" => Some(Self::MaxPool2D),
366 "tensor.avgpool2d" => Some(Self::AvgPool2D),
367 "tensor.adaptive_avgpool2d" => Some(Self::AdaptiveAvgPool2D),
368 "tensor.global_avgpool" => Some(Self::GlobalAvgPool),
369 "tensor.lstm_cell" => Some(Self::LSTMCell),
370 "tensor.gru_cell" => Some(Self::GRUCell),
371 "tensor.rnn_cell" => Some(Self::RNNCell),
372 "tensor.einsum" => Some(Self::Einsum),
373 "tensor.fft" => Some(Self::FFT),
374 "tensor.ifft" => Some(Self::IFFT),
375 "tensor.svd" => Some(Self::SVD),
376 "tensor.eig" => Some(Self::Eig),
377 "tensor.solve" => Some(Self::Solve),
378 "tensor.topk" => Some(Self::TopK),
379 "tensor.sort" => Some(Self::Sort),
380 "tensor.cumsum" => Some(Self::Cumsum),
381 "tensor.where" => Some(Self::Where),
382 "tensor.clamp" => Some(Self::Clamp),
383 "tensor.sparse_matmul" => Some(Self::SparseMatMul),
384 "tensor.sparse_embedding" => Some(Self::SparseEmbedding),
385 "tensor.quantize" => Some(Self::Quantize),
386 "tensor.dequantize" => Some(Self::Dequantize),
387 "tensor.quantize_int4" => Some(Self::QuantizeInt4),
388 "tensor.dequantize_int4" => Some(Self::DequantizeInt4),
389 "tensor.quantize_fp8" => Some(Self::QuantizeFp8),
390 "tensor.dequantize_fp8" => Some(Self::DequantizeFp8),
391 "tensor.unet_down_block" => Some(Self::UNetDownBlock),
392 "tensor.unet_up_block" => Some(Self::UNetUpBlock),
393 "tensor.timestep_embedding" => Some(Self::TimestepEmbedding),
394 "tensor.gnn_message_passing" => Some(Self::GNNMessagePassing),
395 "tensor.gnn_global_pooling" => Some(Self::GNNGlobalPooling),
396 "tensor.checkpoint" => Some(Self::Checkpoint),
397 "tensor.offload" => Some(Self::Offload),
398 "tensor.grad_accumulate" => Some(Self::GradAccumulate),
399 "tensor.grad_matmul" => Some(Self::GradMatMul),
400 "tensor.grad_relu" => Some(Self::GradReLU),
401 "tensor.grad_softmax" => Some(Self::GradSoftmax),
402 "tensor.grad_layernorm" => Some(Self::GradLayerNorm),
403 "tensor.grad_attention" => Some(Self::GradAttention),
404 "tensor.grad_conv2d" => Some(Self::GradConv2D),
405 "tensor.grad_linear" => Some(Self::GradLinear),
406 "tensor.grad_gelu" => Some(Self::GradGeLU),
407 "tensor.parallel_split" => Some(Self::ParallelSplit),
408 "tensor.parallel_allreduce" => Some(Self::ParallelAllReduce),
409 "tensor.pipeline_send" => Some(Self::PipelineSend),
410 "tensor.pipeline_receive" => Some(Self::PipelineReceive),
411 "tensor.fused_matmul_bias_relu" => Some(Self::FusedMatMulBiasReLU),
412 "tensor.fused_matmul_bias" => Some(Self::FusedMatMulBias),
413 "tensor.fused_linear_gelu" => Some(Self::FusedLinearGeLU),
414 "tensor.fused_attention_layernorm" => Some(Self::FusedAttentionLayerNorm),
415 "tensor.fused_linear_silu" => Some(Self::FusedLinearSiLU),
416 "tensor.fused_conv_batchnorm_relu" => Some(Self::FusedConvBatchNormReLU),
417 _ => None,
418 }
419 }
420
421 pub fn num_inputs(&self) -> (usize, usize) {
422 match self {
423 Self::Neg | Self::ReLU | Self::GeLU | Self::SiLU |
425 Self::Sigmoid | Self::Tanh | Self::LeakyReLU | Self::ELU |
426 Self::Mish | Self::HardSwish | Self::HardSigmoid |
427 Self::Reshape | Self::Transpose | Self::Squeeze | Self::Unsqueeze |
428 Self::Permute | Self::Expand | Self::Slice | Self::Pad | Self::Tile |
429 Self::Quantize | Self::Dequantize |
430 Self::QuantizeInt4 | Self::DequantizeInt4 |
431 Self::QuantizeFp8 | Self::DequantizeFp8 |
432 Self::Offload | Self::Checkpoint |
433 Self::GradReLU | Self::GradGeLU |
434 Self::Softmax | Self::Cumsum | Self::Sort | Self::TopK |
435 Self::FFT | Self::IFFT | Self::SVD | Self::Eig |
436 Self::GlobalAvgPool | Self::AdaptiveAvgPool2D |
437 Self::GNNGlobalPooling => (1, 1),
438
439 Self::Add | Self::Sub | Self::Mul | Self::Div |
441 Self::MatMul | Self::SparseMatMul |
442 Self::GradMatMul | Self::Embedding | Self::SparseEmbedding |
443 Self::Conv2D | Self::Conv1D | Self::Conv3D |
444 Self::ConvTranspose2D | Self::DepthwiseConv2D | Self::DilatedConv2D |
445 Self::MaxPool2D | Self::AvgPool2D |
446 Self::Solve | Self::GradConv2D | Self::Concat => (2, 2),
447
448 Self::Linear | Self::FusedMatMulBias | Self::FusedLinearGeLU |
450 Self::FusedMatMulBiasReLU | Self::FusedLinearSiLU |
451 Self::Where | Self::Clamp |
452 Self::GradLinear => (3, 3),
453
454 Self::Attention | Self::MultiHeadAttention |
456 Self::MultiQueryAttention | Self::GroupedQueryAttention |
457 Self::FlashAttention | Self::SlidingWindowAttention |
458 Self::CrossAttention | Self::GradAttention => (3, 4),
459 Self::PagedAttention => (3, 5),
460 Self::FusedAttentionLayerNorm => (3, 5),
461
462 Self::LayerNorm | Self::RMSNorm | Self::GroupNorm |
464 Self::InstanceNorm | Self::GradLayerNorm => (2, 3),
465 Self::BatchNorm | Self::FusedConvBatchNormReLU => (3, 5),
466
467 Self::LSTMCell | Self::GRUCell | Self::RNNCell => (2, 2),
469
470 Self::GNNMessagePassing => (2, 3),
472
473 Self::UNetDownBlock | Self::UNetUpBlock => (2, 3),
475 Self::TimestepEmbedding => (1, 1),
476
477 Self::MoEDispatch => (2, 3),
479 Self::MoECombine => (2, 3),
480
481 Self::Constant | Self::Zeros | Self::Ones |
483 Self::Arange | Self::Full => (0, 0),
484
485 Self::Einsum => (1, usize::MAX),
487
488 Self::GradAccumulate | Self::GradSoftmax |
490 Self::ParallelSplit | Self::ParallelAllReduce |
491 Self::PipelineSend | Self::PipelineReceive |
492 Self::Gather | Self::Scatter | Self::Split => (1, usize::MAX),
493 }
494 }
495
496 pub fn flops_formula(&self) -> &'static str {
498 match self {
499 Self::MatMul | Self::SparseMatMul => "2*M*N*K",
500 Self::Linear => "2*M*N*K + N (bias)",
501 Self::Add | Self::Sub | Self::Mul | Self::Div => "N (element count)",
502 Self::ReLU | Self::Sigmoid | Self::Tanh |
503 Self::LeakyReLU | Self::ELU | Self::HardSigmoid => "N",
504 Self::GeLU | Self::SiLU | Self::Mish | Self::HardSwish => "~8*N",
505 Self::Softmax => "5*N (exp + sum + div)",
506 Self::LayerNorm | Self::RMSNorm |
507 Self::GroupNorm | Self::InstanceNorm => "7*N",
508 Self::BatchNorm => "5*N",
509 Self::Conv2D | Self::DepthwiseConv2D | Self::DilatedConv2D => "2*Cout*Cin*Kh*Kw*Oh*Ow",
510 Self::Conv1D => "2*Cout*Cin*K*Oout",
511 Self::Conv3D => "2*Cout*Cin*Kd*Kh*Kw*Od*Oh*Ow",
512 Self::Attention | Self::MultiHeadAttention |
513 Self::GroupedQueryAttention | Self::MultiQueryAttention |
514 Self::FlashAttention | Self::SlidingWindowAttention |
515 Self::CrossAttention => "2*B*H*(S^2*D + S*D^2)",
516 Self::LSTMCell => "4*(input_size+hidden)*hidden*2",
517 Self::GRUCell => "3*(input_size+hidden)*hidden*2",
518 Self::RNNCell => "(input_size+hidden)*hidden*2",
519 Self::FFT | Self::IFFT => "5*N*log2(N)",
520 Self::Einsum => "depends on equation",
521 Self::MaxPool2D | Self::AvgPool2D | Self::AdaptiveAvgPool2D |
522 Self::GlobalAvgPool => "N (comparisons or additions)",
523 Self::Reshape | Self::Transpose | Self::Squeeze | Self::Unsqueeze |
524 Self::Permute | Self::Expand | Self::Slice | Self::Pad | Self::Tile |
525 Self::Concat | Self::Split | Self::Gather | Self::Scatter => "0 (no compute)",
526 _ => "varies",
527 }
528 }
529
530 #[inline]
532 pub fn is_zero_flop(&self) -> bool {
533 matches!(self,
534 Self::Reshape | Self::Transpose | Self::Squeeze | Self::Unsqueeze |
535 Self::Permute | Self::Expand | Self::Slice | Self::Pad | Self::Tile |
536 Self::Concat | Self::Split | Self::Gather | Self::Scatter |
537 Self::Constant | Self::Zeros | Self::Ones | Self::Arange | Self::Full |
538 Self::Checkpoint | Self::Offload |
539 Self::PipelineSend | Self::PipelineReceive |
540 Self::ParallelSplit | Self::ParallelAllReduce
541 )
542 }
543
544 #[inline]
546 pub fn is_activation(&self) -> bool {
547 matches!(self,
548 Self::ReLU | Self::GeLU | Self::SiLU | Self::Sigmoid | Self::Tanh |
549 Self::LeakyReLU | Self::ELU | Self::Mish |
550 Self::HardSwish | Self::HardSigmoid
551 )
552 }
553
554 #[inline]
556 pub fn is_attention(&self) -> bool {
557 matches!(self,
558 Self::Attention | Self::MultiHeadAttention | Self::MultiQueryAttention |
559 Self::GroupedQueryAttention | Self::FlashAttention |
560 Self::SlidingWindowAttention | Self::CrossAttention | Self::PagedAttention
561 )
562 }
563
564 #[inline]
566 pub fn is_convolution(&self) -> bool {
567 matches!(self,
568 Self::Conv1D | Self::Conv2D | Self::Conv3D |
569 Self::ConvTranspose2D | Self::DepthwiseConv2D | Self::DilatedConv2D
570 )
571 }
572
573 #[inline]
575 pub fn is_normalisation(&self) -> bool {
576 matches!(self,
577 Self::LayerNorm | Self::RMSNorm | Self::BatchNorm |
578 Self::GroupNorm | Self::InstanceNorm
579 )
580 }
581
582 #[inline]
584 pub fn is_fused(&self) -> bool {
585 matches!(self,
586 Self::FusedMatMulBiasReLU | Self::FusedMatMulBias |
587 Self::FusedLinearGeLU | Self::FusedAttentionLayerNorm |
588 Self::FusedLinearSiLU | Self::FusedConvBatchNormReLU
589 )
590 }
591
592 #[inline]
594 pub fn is_gradient(&self) -> bool {
595 matches!(self,
596 Self::GradMatMul | Self::GradReLU | Self::GradSoftmax |
597 Self::GradLayerNorm | Self::GradAttention |
598 Self::GradConv2D | Self::GradLinear | Self::GradGeLU
599 )
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `TensorOp` variant, listed explicitly so the tests cover the
    /// whole enum rather than a hand-picked sample.
    fn every_op() -> Vec<TensorOp> {
        vec![
            TensorOp::Add, TensorOp::Sub, TensorOp::Mul, TensorOp::Div, TensorOp::Neg,
            TensorOp::MatMul, TensorOp::Linear, TensorOp::Conv2D, TensorOp::Embedding,
            TensorOp::ReLU, TensorOp::GeLU, TensorOp::SiLU, TensorOp::Sigmoid,
            TensorOp::Softmax, TensorOp::Tanh, TensorOp::LeakyReLU, TensorOp::ELU,
            TensorOp::Mish, TensorOp::HardSwish, TensorOp::HardSigmoid,
            TensorOp::LayerNorm, TensorOp::RMSNorm, TensorOp::BatchNorm,
            TensorOp::GroupNorm, TensorOp::InstanceNorm,
            TensorOp::Reshape, TensorOp::Transpose, TensorOp::Concat, TensorOp::Split,
            TensorOp::Gather, TensorOp::Scatter, TensorOp::Squeeze, TensorOp::Unsqueeze,
            TensorOp::Permute, TensorOp::Expand, TensorOp::Slice, TensorOp::Pad,
            TensorOp::Tile,
            TensorOp::Constant, TensorOp::Zeros, TensorOp::Ones, TensorOp::Arange,
            TensorOp::Full,
            TensorOp::Attention, TensorOp::MultiHeadAttention,
            TensorOp::MultiQueryAttention, TensorOp::GroupedQueryAttention,
            TensorOp::FlashAttention, TensorOp::SlidingWindowAttention,
            TensorOp::CrossAttention, TensorOp::PagedAttention,
            TensorOp::MoEDispatch, TensorOp::MoECombine,
            TensorOp::Conv1D, TensorOp::Conv3D, TensorOp::ConvTranspose2D,
            TensorOp::DepthwiseConv2D, TensorOp::DilatedConv2D,
            TensorOp::MaxPool2D, TensorOp::AvgPool2D, TensorOp::AdaptiveAvgPool2D,
            TensorOp::GlobalAvgPool,
            TensorOp::LSTMCell, TensorOp::GRUCell, TensorOp::RNNCell,
            TensorOp::Einsum, TensorOp::FFT, TensorOp::IFFT, TensorOp::SVD, TensorOp::Eig,
            TensorOp::Solve, TensorOp::TopK, TensorOp::Sort, TensorOp::Cumsum,
            TensorOp::Where, TensorOp::Clamp,
            TensorOp::SparseMatMul, TensorOp::SparseEmbedding,
            TensorOp::Quantize, TensorOp::Dequantize, TensorOp::QuantizeInt4,
            TensorOp::DequantizeInt4, TensorOp::QuantizeFp8, TensorOp::DequantizeFp8,
            TensorOp::UNetDownBlock, TensorOp::UNetUpBlock, TensorOp::TimestepEmbedding,
            TensorOp::GNNMessagePassing, TensorOp::GNNGlobalPooling,
            TensorOp::Checkpoint, TensorOp::Offload, TensorOp::GradAccumulate,
            TensorOp::GradMatMul, TensorOp::GradReLU, TensorOp::GradSoftmax,
            TensorOp::GradLayerNorm, TensorOp::GradAttention, TensorOp::GradConv2D,
            TensorOp::GradLinear, TensorOp::GradGeLU,
            TensorOp::ParallelSplit, TensorOp::ParallelAllReduce,
            TensorOp::PipelineSend, TensorOp::PipelineReceive,
            TensorOp::FusedMatMulBiasReLU, TensorOp::FusedMatMulBias,
            TensorOp::FusedLinearGeLU, TensorOp::FusedAttentionLayerNorm,
            TensorOp::FusedLinearSiLU, TensorOp::FusedConvBatchNormReLU,
        ]
    }

    /// `from_name(op.name())` must give back the same op for every variant.
    #[test]
    fn test_op_name_roundtrip() {
        for op in every_op() {
            let name = op.name();
            let recovered = TensorOp::from_name(name)
                .unwrap_or_else(|| panic!("from_name failed for {}", name));
            assert_eq!(op, recovered);
        }
    }

    /// Every variant has a non-empty, `tensor.`-prefixed name that is unique.
    #[test]
    fn test_all_ops_have_names() {
        let ops = every_op();
        let mut seen = HashSet::new();
        for op in &ops {
            let name = op.name();
            assert!(!name.is_empty());
            assert!(name.starts_with("tensor."), "{} lacks the tensor. prefix", name);
            assert!(seen.insert(name), "duplicate op name {}", name);
        }
        assert_eq!(seen.len(), ops.len());
    }

    /// Strings that are not exact canonical names are rejected.
    #[test]
    fn test_unknown_name_is_none() {
        assert_eq!(TensorOp::from_name(""), None);
        assert_eq!(TensorOp::from_name("add"), None);
        assert_eq!(TensorOp::from_name("tensor.does_not_exist"), None);
        assert_eq!(TensorOp::from_name("TENSOR.ADD"), None);
    }
}
628}