1use crate::arch::SmVersion;
32use crate::builder::KernelBuilder;
33use crate::error::PtxGenError;
34use crate::ir::PtxType;
35
/// Elementwise operations for which a standalone PTX kernel can be generated.
///
/// Binary ops read two input tensors; unary ops read one; `Scale`,
/// `AddScalar`, and `FusedScaleAdd` additionally take scalar kernel
/// parameters (see `is_binary` / `needs_scalar`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementwiseOp {
    /// `c[i] = a[i] + b[i]`
    Add,
    /// `c[i] = a[i] - b[i]`
    Sub,
    /// `c[i] = a[i] * b[i]`
    Mul,
    /// `c[i] = a[i] / b[i]` (IEEE round-to-nearest division)
    Div,
    /// `b[i] = max(a[i], 0)`
    Relu,
    /// GELU, tanh approximation.
    Gelu,
    /// Logistic sigmoid `1 / (1 + e^-x)`.
    Sigmoid,
    /// SiLU / swish: `x * sigmoid(x)`.
    Silu,
    /// Hyperbolic tangent.
    Tanh,
    /// Arithmetic negation.
    Neg,
    /// Absolute value.
    Abs,
    /// Square root (round-to-nearest).
    Sqrt,
    /// Approximate reciprocal square root.
    Rsqrt,
    /// Natural exponential.
    Exp,
    /// Natural logarithm.
    Log,
    /// `b[i] = alpha * a[i]` (scalar kernel parameter `alpha`).
    Scale,
    /// `b[i] = a[i] + scalar` (scalar kernel parameter).
    AddScalar,
    /// Round toward +infinity.
    Ceil,
    /// Round toward -infinity.
    Floor,
    /// `clamp(0.2 * x + 0.5, 0, 1)`
    HardSigmoid,
    /// `x * clamp(x + 3, 0, 6) / 6`
    HardSwish,
    /// `ln(1 + e^x)`
    Softplus,
    /// `x >= 0 ? x : 0.01 * x`
    LeakyRelu,
    /// `c[i] = max(a[i] + b[i], 0)`
    FusedAddRelu,
    /// `c[i] = alpha * a[i] + beta * b[i]` (two scalar parameters).
    FusedScaleAdd,
}
95
96impl ElementwiseOp {
97 #[must_use]
99 pub const fn as_str(self) -> &'static str {
100 match self {
101 Self::Add => "add",
102 Self::Sub => "sub",
103 Self::Mul => "mul",
104 Self::Div => "div",
105 Self::Relu => "relu",
106 Self::Gelu => "gelu",
107 Self::Sigmoid => "sigmoid",
108 Self::Silu => "silu",
109 Self::Tanh => "tanh",
110 Self::Neg => "neg",
111 Self::Abs => "abs",
112 Self::Sqrt => "sqrt",
113 Self::Rsqrt => "rsqrt",
114 Self::Exp => "exp",
115 Self::Log => "log",
116 Self::Ceil => "ceil",
117 Self::Floor => "floor",
118 Self::HardSigmoid => "hard_sigmoid",
119 Self::HardSwish => "hard_swish",
120 Self::Softplus => "softplus",
121 Self::LeakyRelu => "leaky_relu",
122 Self::Scale => "scale",
123 Self::AddScalar => "add_scalar",
124 Self::FusedAddRelu => "fused_add_relu",
125 Self::FusedScaleAdd => "fused_scale_add",
126 }
127 }
128
129 #[must_use]
131 pub const fn is_binary(self) -> bool {
132 matches!(
133 self,
134 Self::Add
135 | Self::Sub
136 | Self::Mul
137 | Self::Div
138 | Self::FusedAddRelu
139 | Self::FusedScaleAdd
140 )
141 }
142
143 #[must_use]
145 pub const fn needs_scalar(self) -> bool {
146 matches!(self, Self::Scale | Self::AddScalar | Self::FusedScaleAdd)
147 }
148}
149
/// Description of one elementwise kernel to generate: the operation, the
/// element type of the tensors in global memory, and the SM target.
pub struct ElementwiseTemplate {
    /// Operation to lower to PTX.
    pub op: ElementwiseOp,
    /// Element precision; `generate` rejects anything but F16/BF16/F32/F64.
    pub precision: PtxType,
    /// SM architecture passed to the kernel builder's `.target(...)`.
    pub target: SmVersion,
}
166
167impl ElementwiseTemplate {
168 #[must_use]
170 pub const fn new(op: ElementwiseOp, precision: PtxType, target: SmVersion) -> Self {
171 Self {
172 op,
173 precision,
174 target,
175 }
176 }
177
178 #[must_use]
183 pub fn kernel_name(&self) -> String {
184 let type_str = self.precision.as_ptx_str().trim_start_matches('.');
185 format!("elementwise_{}_{}", self.op.as_str(), type_str)
186 }
187
    /// Generates the complete PTX source for this template.
    ///
    /// Validates the element precision first, then dispatches to the
    /// per-op generator.
    ///
    /// # Errors
    /// Returns `PtxGenError::InvalidType` for non-float precisions, or
    /// whatever error the underlying `KernelBuilder::build` surfaces.
    pub fn generate(&self) -> Result<String, PtxGenError> {
        self.validate_precision()?;

        match self.op {
            ElementwiseOp::Add => self.generate_binary_arith("add"),
            ElementwiseOp::Sub => self.generate_binary_arith("sub"),
            ElementwiseOp::Mul => self.generate_binary_arith("mul"),
            ElementwiseOp::Div => self.generate_div(),
            ElementwiseOp::Relu => self.generate_relu(),
            ElementwiseOp::Gelu => self.generate_gelu(),
            ElementwiseOp::Sigmoid => self.generate_sigmoid(),
            ElementwiseOp::Silu => self.generate_silu(),
            ElementwiseOp::Tanh => self.generate_tanh(),
            ElementwiseOp::Neg => self.generate_unary("neg"),
            ElementwiseOp::Abs => self.generate_unary("abs"),
            ElementwiseOp::Sqrt => self.generate_sqrt(),
            ElementwiseOp::Rsqrt => self.generate_rsqrt(),
            ElementwiseOp::Exp => self.generate_exp(),
            ElementwiseOp::Log => self.generate_log(),
            ElementwiseOp::Ceil => self.generate_ceil(),
            ElementwiseOp::Floor => self.generate_floor(),
            ElementwiseOp::HardSigmoid => self.generate_hard_sigmoid(),
            ElementwiseOp::HardSwish => self.generate_hard_swish(),
            ElementwiseOp::Softplus => self.generate_softplus(),
            ElementwiseOp::LeakyRelu => self.generate_leaky_relu(),
            ElementwiseOp::Scale => self.generate_scale(),
            ElementwiseOp::AddScalar => self.generate_add_scalar(),
            ElementwiseOp::FusedAddRelu => self.generate_fused_add_relu(),
            ElementwiseOp::FusedScaleAdd => self.generate_fused_scale_add(),
        }
    }
225
226 fn validate_precision(&self) -> Result<(), PtxGenError> {
228 if !matches!(
229 self.precision,
230 PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64
231 ) {
232 return Err(PtxGenError::InvalidType(format!(
233 "elementwise operations require F16, BF16, F32, or F64, got {}",
234 self.precision.as_ptx_str()
235 )));
236 }
237 Ok(())
238 }
239
    /// PTX type suffix including the leading '.', e.g. ".f32"; spliced
    /// directly into instruction mnemonics by the generators below.
    const fn ty_str(&self) -> &'static str {
        self.precision.as_ptx_str()
    }
244
    /// Emits a one-thread-per-element kernel `c[i] = a[i] <op> b[i]`, where
    /// `op_name` is a PTX arithmetic mnemonic prefix ("add", "sub", "mul").
    ///
    /// Shape shared by all generators here: guard `tid < n`, compute the byte
    /// offset `tid * byte_size`, then load / operate / store via raw PTX.
    fn generate_binary_arith(&self, op_name: &str) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        // Owned copy so the builder's move closure can capture it.
        let op_name = op_name.to_string();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("c_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");
                    let c_ptr = b.load_param_u64("c_ptr");

                    // Element addresses: %rd_{a,b,c} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;\n \
                         add.u64 %rd_c, {c_ptr}, %rd_off;"
                    ));

                    // Load both operands, apply the op, store the result.
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_a, [%rd_a];\n \
                         ld.global{ty} %f_b, [%rd_b];\n \
                         {op_name}{ty} %f_c, %f_a, %f_b;\n \
                         st.global{ty} [%rd_c], %f_c;"
                    ));
                });
                b.ret();
            })
            .build()
    }
291
    /// Emits `c[i] = a[i] / b[i]` using `div.rn` (IEEE round-to-nearest
    /// division), unlike the `.approx` transcendentals used elsewhere.
    fn generate_div(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("c_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");
                    let c_ptr = b.load_param_u64("c_ptr");

                    // Element addresses: %rd_{a,b,c} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;\n \
                         add.u64 %rd_c, {c_ptr}, %rd_off;"
                    ));

                    // NOTE(review): `div.rn` has no .f16/.bf16 form in the
                    // PTX ISA — confirm half precisions are excluded upstream.
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_a, [%rd_a];\n \
                         ld.global{ty} %f_b, [%rd_b];\n \
                         div.rn{ty} %f_c, %f_a, %f_b;\n \
                         st.global{ty} [%rd_c], %f_c;"
                    ));
                });
                b.ret();
            })
            .build()
    }
333
    /// Emits `b[i] = max(a[i], 0)` with a precision-appropriate zero literal
    /// (see `float_zero_literal`).
    fn generate_relu(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let zero_lit = float_zero_literal(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         max{ty} %f_y, %f_x, {zero_lit};\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
373
    /// Emits `b[i] = 1 / (1 + e^-x)` built from approximate hardware ops:
    /// `e^-x = 2^(-x * log2(e))` via `ex2.approx`, then `rcp.approx`.
    ///
    /// Literals: 0f3FB8AA3B = log2(e), 0f3F800000 = 1.0.
    ///
    /// NOTE(review): the constants are 32-bit `0f…` literals and
    /// `ex2.approx`/`rcp.approx` are f32-oriented, yet `validate_precision`
    /// admits F64 — confirm non-f32 precisions are handled or excluded.
    fn generate_sigmoid(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // sigmoid(x) = 1 / (1 + 2^(-x * log2 e))
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         neg{ty} %f_neg, %f_x;\n \
                         mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n \
                         ex2.approx{ty} %f_exp, %f_neg;\n \
                         add{ty} %f_denom, %f_exp, 0f3F800000;\n \
                         rcp.approx{ty} %f_y, %f_denom;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
421
422 fn generate_gelu(&self) -> Result<String, PtxGenError> {
429 let kernel_name = self.kernel_name();
430 let ty = self.ty_str();
431 let byte_size = self.precision.size_bytes();
432
433 KernelBuilder::new(&kernel_name)
434 .target(self.target)
435 .param("a_ptr", PtxType::U64)
436 .param("b_ptr", PtxType::U64)
437 .param("n", PtxType::U32)
438 .max_threads_per_block(256)
439 .body(move |b| {
440 let tid = b.global_thread_id_x();
441 let tid_name = tid.to_string();
442 let n_reg = b.load_param_u32("n");
443 b.if_lt_u32(tid, n_reg, move |b| {
444 let a_ptr = b.load_param_u64("a_ptr");
445 let b_ptr = b.load_param_u64("b_ptr");
446
447 b.raw_ptx(&format!(
448 "cvt.u64.u32 %rd_off, {tid_name};\n \
449 mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
450 add.u64 %rd_a, {a_ptr}, %rd_off;\n \
451 add.u64 %rd_b, {b_ptr}, %rd_off;"
452 ));
453
454 b.raw_ptx(&format!(
462 "ld.global{ty} %f_x, [%rd_a];\n \
463 mul{ty} %f_x3, %f_x, %f_x;\n \
464 mul{ty} %f_x3, %f_x3, %f_x;\n \
465 mul{ty} %f_x3, %f_x3, 0f3D372713;\n \
466 add{ty} %f_inner, %f_x, %f_x3;\n \
467 mul{ty} %f_inner, %f_inner, 0f3F4C422A;\n \
468 mul{ty} %f_2a, %f_inner, 0f40000000;\n \
469 neg{ty} %f_neg2a, %f_2a;\n \
470 mul{ty} %f_neg2a, %f_neg2a, 0f3FB8AA3B;\n \
471 ex2.approx{ty} %f_exp, %f_neg2a;\n \
472 add{ty} %f_denom, %f_exp, 0f3F800000;\n \
473 rcp.approx{ty} %f_sig, %f_denom;\n \
474 mul{ty} %f_sig, %f_sig, 0f40000000;\n \
475 sub{ty} %f_tanh, %f_sig, 0f3F800000;\n \
476 add{ty} %f_tanh, %f_tanh, 0f3F800000;\n \
477 mul{ty} %f_y, 0f3F000000, %f_x;\n \
478 mul{ty} %f_y, %f_y, %f_tanh;\n \
479 st.global{ty} [%rd_b], %f_y;"
480 ));
481 });
482 b.ret();
483 })
484 .build()
485 }
486
    /// Emits SiLU/swish `b[i] = x * sigmoid(x)`, reusing the same
    /// `ex2.approx`-based sigmoid sequence as `generate_sigmoid`.
    ///
    /// NOTE(review): 32-bit `0f…` literals — see the caveat on
    /// `generate_sigmoid` regarding non-f32 precisions.
    fn generate_silu(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // y = x * sigmoid(x)
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         neg{ty} %f_neg, %f_x;\n \
                         mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n \
                         ex2.approx{ty} %f_exp, %f_neg;\n \
                         add{ty} %f_denom, %f_exp, 0f3F800000;\n \
                         rcp.approx{ty} %f_sig, %f_denom;\n \
                         mul{ty} %f_y, %f_x, %f_sig;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
530
    /// Emits `b[i] = tanh(x)` via the identity `tanh(x) = 2*sigmoid(2x) - 1`,
    /// with sigmoid built from `ex2.approx` / `rcp.approx`.
    ///
    /// Literals: 0f40000000 = 2.0, 0f3FB8AA3B = log2(e), 0f3F800000 = 1.0.
    /// NOTE(review): 32-bit `0f…` literals — see `generate_sigmoid` caveat.
    fn generate_tanh(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // y = 2 * sigmoid(2x) - 1
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         mul{ty} %f_2x, %f_x, 0f40000000;\n \
                         neg{ty} %f_neg, %f_2x;\n \
                         mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n \
                         ex2.approx{ty} %f_exp, %f_neg;\n \
                         add{ty} %f_denom, %f_exp, 0f3F800000;\n \
                         rcp.approx{ty} %f_sig, %f_denom;\n \
                         mul{ty} %f_y, %f_sig, 0f40000000;\n \
                         sub{ty} %f_y, %f_y, 0f3F800000;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
576
    /// Emits `b[i] = <op> a[i]` for a single-operand PTX mnemonic;
    /// dispatched with "neg" and "abs" from `generate`.
    fn generate_unary(&self, op_name: &str) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        // Owned copy so the builder's move closure can capture it.
        let op_name = op_name.to_string();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         {op_name}{ty} %f_y, %f_x;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
615
    /// Emits `b[i] = sqrt(a[i])` using `sqrt.rn` (round-to-nearest).
    fn generate_sqrt(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         sqrt.rn{ty} %f_y, %f_x;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
653
    /// Emits `b[i] = 1 / sqrt(a[i])` using the hardware `rsqrt.approx`.
    fn generate_rsqrt(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         rsqrt.approx{ty} %f_y, %f_x;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
691
    /// Emits `b[i] = e^x` as `2^(x * log2 e)` via `ex2.approx`
    /// (0f3FB8AA3B = log2(e)).
    ///
    /// NOTE(review): `ex2.approx` is f32-oriented and the scale constant is a
    /// 32-bit `0f…` literal, yet F64 passes `validate_precision` — confirm.
    fn generate_exp(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // e^x = 2^(x * log2 e)
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         mul{ty} %f_x2, %f_x, 0f3FB8AA3B;\n \
                         ex2.approx{ty} %f_y, %f_x2;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
731
    /// Emits `b[i] = ln(x)` as `log2(x) * ln(2)` via `lg2.approx`
    /// (0f3F317218 = ln 2).
    ///
    /// NOTE(review): 32-bit `0f…` literal / f32-oriented `lg2.approx` — see
    /// the caveat on `generate_exp` for non-f32 precisions.
    fn generate_log(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // ln(x) = log2(x) * ln(2)
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         lg2.approx{ty} %f_lg, %f_x;\n \
                         mul{ty} %f_y, %f_lg, 0f3F317218;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
771
    /// Emits `b[i] = ceil(a[i])` using same-type `cvt` with the `.rpi`
    /// (round toward +infinity) modifier.
    fn generate_ceil(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         cvt.rpi{ty}{ty} %f_y, %f_x;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
811
    /// Emits `b[i] = floor(a[i])` using same-type `cvt` with the `.rmi`
    /// (round toward -infinity) modifier.
    fn generate_floor(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         cvt.rmi{ty}{ty} %f_y, %f_x;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
851
    /// Emits hard-sigmoid `b[i] = clamp(0.2 * x + 0.5, 0, 1)`.
    ///
    /// Literals: 0f3E4CCCCD = 0.2, 0f3F000000 = 0.5, 0f3F800000 = 1.0.
    /// NOTE(review): 32-bit `0f…` literals — confirm F64 handling upstream.
    fn generate_hard_sigmoid(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let zero_lit = float_zero_literal(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // y = max(min(0.2x + 0.5, 1), 0)
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         mul{ty} %f_ax, %f_x, 0f3E4CCCCD;\n \
                         add{ty} %f_lin, %f_ax, 0f3F000000;\n \
                         min{ty} %f_clip, %f_lin, 0f3F800000;\n \
                         max{ty} %f_y, %f_clip, {zero_lit};\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
897
    /// Emits hard-swish `b[i] = x * clamp(x + 3, 0, 6) / 6`.
    ///
    /// Literals: 0f40400000 = 3.0, 0f40C00000 = 6.0, 0f3E2AAAAB = 1/6.
    /// NOTE(review): 32-bit `0f…` literals — confirm F64 handling upstream.
    fn generate_hard_swish(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let zero_lit = float_zero_literal(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // y = x * clamp(x + 3, 0, 6) * (1/6)
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         add{ty} %f_xp3, %f_x, 0f40400000;\n \
                         min{ty} %f_clip, %f_xp3, 0f40C00000;\n \
                         max{ty} %f_clip, %f_clip, {zero_lit};\n \
                         mul{ty} %f_div, %f_clip, 0f3E2AAAAB;\n \
                         mul{ty} %f_y, %f_x, %f_div;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
944
    /// Emits softplus `b[i] = ln(1 + e^x)` via `ex2.approx` / `lg2.approx`
    /// with log2(e) / ln(2) scaling (0f3FB8AA3B, 0f3F317218).
    ///
    /// NOTE(review): no large-x shortcut — `e^x` overflows to +inf for
    /// x beyond the f32 exp range, giving +inf instead of ~x; confirm
    /// acceptable for callers. 32-bit `0f…` literals as elsewhere.
    fn generate_softplus(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // y = log2(1 + 2^(x log2 e)) * ln 2
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         mul{ty} %f_xe, %f_x, 0f3FB8AA3B;\n \
                         ex2.approx{ty} %f_exp, %f_xe;\n \
                         add{ty} %f_sum, %f_exp, 0f3F800000;\n \
                         lg2.approx{ty} %f_lg, %f_sum;\n \
                         mul{ty} %f_y, %f_lg, 0f3F317218;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
992
    /// Emits leaky ReLU `b[i] = x >= 0 ? x : 0.01 * x`
    /// (0f3C23D70A = 0.01) using `setp.ge` + `selp`.
    ///
    /// NOTE(review): the PTX ISA defines `selp` for .f32/.f64 but not .f16 —
    /// confirm F16/BF16 are excluded or lowered differently. The slope is a
    /// hard-coded 32-bit literal.
    fn generate_leaky_relu(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let zero_lit = float_zero_literal(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    // y = select(x, 0.01x, x >= 0)
                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_x, [%rd_a];\n \
                         mul{ty} %f_leak, %f_x, 0f3C23D70A;\n \
                         setp.ge{ty} %p_ge, %f_x, {zero_lit};\n \
                         selp{ty} %f_y, %f_x, %f_leak, %p_ge;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
1038
    /// Emits `b[i] = alpha * a[i]` with `alpha` passed as a kernel parameter.
    ///
    /// NOTE(review): for F16/BF16 the parameter is declared as F32 by
    /// `scalar_param_type`, but it is loaded here with `ld.param{ty}` using
    /// the element type — a width mismatch for half precisions; a
    /// `ld.param.f32` + `cvt` pair looks intended. Confirm.
    /// Also assumes the builder exposes the parameter as `%param_alpha`.
    fn generate_scale(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let scalar_ty = scalar_param_type(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("alpha", scalar_ty)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.param{ty} %f_alpha, [%param_alpha];\n \
                         ld.global{ty} %f_x, [%rd_a];\n \
                         mul{ty} %f_y, %f_alpha, %f_x;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
1079
    /// Emits `b[i] = a[i] + scalar` with `scalar` as a kernel parameter.
    ///
    /// NOTE(review): same F16/BF16 width mismatch as `generate_scale` —
    /// the param is declared via `scalar_param_type` (F32 for halves) but
    /// loaded with `ld.param{ty}`. Confirm. Assumes the builder names the
    /// parameter register `%param_scalar`.
    fn generate_add_scalar(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let scalar_ty = scalar_param_type(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("scalar", scalar_ty)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");

                    // Element addresses: %rd_{a,b} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.param{ty} %f_s, [%param_scalar];\n \
                         ld.global{ty} %f_x, [%rd_a];\n \
                         add{ty} %f_y, %f_x, %f_s;\n \
                         st.global{ty} [%rd_b], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
1120
    /// Emits the fused kernel `c[i] = max(a[i] + b[i], 0)` — one pass
    /// instead of separate add and relu launches.
    fn generate_fused_add_relu(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let zero_lit = float_zero_literal(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("c_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");
                    let c_ptr = b.load_param_u64("c_ptr");

                    // Element addresses: %rd_{a,b,c} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;\n \
                         add.u64 %rd_c, {c_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.global{ty} %f_a, [%rd_a];\n \
                         ld.global{ty} %f_b, [%rd_b];\n \
                         add{ty} %f_sum, %f_a, %f_b;\n \
                         max{ty} %f_y, %f_sum, {zero_lit};\n \
                         st.global{ty} [%rd_c], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
1164
    /// Emits the fused axpby kernel `c[i] = alpha * a[i] + beta * b[i]`
    /// with `alpha` / `beta` passed as kernel parameters.
    ///
    /// NOTE(review): same F16/BF16 param-width mismatch as `generate_scale`
    /// (`scalar_param_type` widens to F32 but `ld.param{ty}` uses the element
    /// type); also assumes `%param_alpha` / `%param_beta` register names.
    fn generate_fused_scale_add(&self) -> Result<String, PtxGenError> {
        let kernel_name = self.kernel_name();
        let ty = self.ty_str();
        let byte_size = self.precision.size_bytes();
        let scalar_ty = scalar_param_type(self.precision);

        KernelBuilder::new(&kernel_name)
            .target(self.target)
            .param("a_ptr", PtxType::U64)
            .param("b_ptr", PtxType::U64)
            .param("c_ptr", PtxType::U64)
            .param("alpha", scalar_ty)
            .param("beta", scalar_ty)
            .param("n", PtxType::U32)
            .max_threads_per_block(256)
            .body(move |b| {
                let tid = b.global_thread_id_x();
                let tid_name = tid.to_string();
                let n_reg = b.load_param_u32("n");
                b.if_lt_u32(tid, n_reg, move |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let b_ptr = b.load_param_u64("b_ptr");
                    let c_ptr = b.load_param_u64("c_ptr");

                    // Element addresses: %rd_{a,b,c} = base + tid * byte_size.
                    b.raw_ptx(&format!(
                        "cvt.u64.u32 %rd_off, {tid_name};\n \
                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n \
                         add.u64 %rd_a, {a_ptr}, %rd_off;\n \
                         add.u64 %rd_b, {b_ptr}, %rd_off;\n \
                         add.u64 %rd_c, {c_ptr}, %rd_off;"
                    ));

                    b.raw_ptx(&format!(
                        "ld.param{ty} %f_alpha, [%param_alpha];\n \
                         ld.param{ty} %f_beta, [%param_beta];\n \
                         ld.global{ty} %f_a, [%rd_a];\n \
                         ld.global{ty} %f_b, [%rd_b];\n \
                         mul{ty} %f_aa, %f_alpha, %f_a;\n \
                         mul{ty} %f_bb, %f_beta, %f_b;\n \
                         add{ty} %f_y, %f_aa, %f_bb;\n \
                         st.global{ty} [%rd_c], %f_y;"
                    ));
                });
                b.ret();
            })
            .build()
    }
1213}
1214
1215const fn float_zero_literal(ty: PtxType) -> &'static str {
1217 match ty {
1218 PtxType::F64 => "0d0000000000000000",
1219 _ => "0f00000000",
1220 }
1221}
1222
1223const fn scalar_param_type(ty: PtxType) -> PtxType {
1227 match ty {
1228 PtxType::F16 | PtxType::BF16 => PtxType::F32,
1229 other => other,
1230 }
1231}
1232
#[cfg(test)]
mod tests {
    //! Tests fall into three groups: (1) enum/name plumbing, (2) structural
    //! checks on the generated PTX text (presence/absence of instructions and
    //! hex float constants), and (3) CPU reference implementations of the
    //! activation math used to sanity-check expected numeric behavior.
    use super::*;
    use crate::arch::SmVersion;

    // `as_str` must yield the snake_case fragment embedded in kernel names.
    #[test]
    fn elementwise_op_names() {
        assert_eq!(ElementwiseOp::Add.as_str(), "add");
        assert_eq!(ElementwiseOp::Relu.as_str(), "relu");
        assert_eq!(ElementwiseOp::FusedScaleAdd.as_str(), "fused_scale_add");
    }

    // Arity (`is_binary`) and scalar-operand (`needs_scalar`) classification.
    #[test]
    fn elementwise_op_classification() {
        assert!(ElementwiseOp::Add.is_binary());
        assert!(ElementwiseOp::Sub.is_binary());
        assert!(!ElementwiseOp::Relu.is_binary());
        assert!(!ElementwiseOp::Sigmoid.is_binary());

        assert!(ElementwiseOp::Scale.needs_scalar());
        assert!(ElementwiseOp::FusedScaleAdd.needs_scalar());
        assert!(!ElementwiseOp::Add.needs_scalar());
    }

    // Kernel names follow the `elementwise_<op>_<precision>` pattern.
    #[test]
    fn kernel_name_format() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Add, PtxType::F32, SmVersion::Sm80);
        assert_eq!(t.kernel_name(), "elementwise_add_f32");

        let t2 = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F16, SmVersion::Sm90);
        assert_eq!(t2.kernel_name(), "elementwise_relu_f16");
    }

    // Non-float element types (here U32) must be rejected by `generate`.
    #[test]
    fn invalid_precision_rejected() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Add, PtxType::U32, SmVersion::Sm80);
        let result = t.generate();
        assert!(result.is_err());
    }

    // Smoke test: f32 add kernel has the expected entry, target, and add op.
    #[test]
    fn generate_add_f32() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Add, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("should generate add kernel");
        assert!(ptx.contains(".entry elementwise_add_f32"));
        assert!(ptx.contains(".target sm_80"));
        assert!(ptx.contains("add.f32"));
    }

    // relu lowers to a max-with-zero.
    #[test]
    fn generate_relu_f32() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("should generate relu kernel");
        assert!(ptx.contains(".entry elementwise_relu_f32"));
        assert!(ptx.contains("max.f32"));
    }

    // sigmoid lowers via ex2 (exp) and rcp (reciprocal) approximations.
    #[test]
    fn generate_sigmoid_f32() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Sigmoid, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("should generate sigmoid kernel");
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("rcp.approx.f32"));
    }

    // gelu's tanh approximation also goes through ex2.
    #[test]
    fn generate_gelu_f32() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Gelu, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("should generate gelu kernel");
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains(".entry elementwise_gelu_f32"));
    }

    // Negative checks: relu PTX must be a pure max-against-zero with no
    // transcendental instructions leaking in from other op templates.
    #[test]
    fn test_relu_ptx_correct_arithmetic() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("relu PTX generation failed");
        assert!(ptx.contains("max.f32"), "relu must emit max.f32");
        // 0f00000000 is the PTX hex literal for f32 0.0.
        assert!(ptx.contains("0f00000000"), "relu must compare against 0.0");
        assert!(!ptx.contains("sin.approx"), "relu must not emit sin");
        assert!(!ptx.contains("cos.approx"), "relu must not emit cos");
        assert!(!ptx.contains("ex2.approx"), "relu must not use exp");
        assert!(!ptx.contains("rcp.approx"), "relu must not use rcp");
    }

    // sigmoid(x) = 1 / (1 + exp(-x)): expects negate, exp-via-ex2 scaled by
    // log2(e), add 1.0, then reciprocal. Hex literals: 0f3FB8AA3B = log2(e),
    // 0f3F800000 = 1.0.
    #[test]
    fn test_sigmoid_ptx_contains_exp_and_rcp() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Sigmoid, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("sigmoid PTX generation failed");
        assert!(ptx.contains("neg.f32"), "sigmoid must negate input");
        assert!(
            ptx.contains("ex2.approx.f32"),
            "sigmoid must use ex2.approx for exp"
        );
        assert!(ptx.contains("0f3FB8AA3B"), "sigmoid must scale by log2(e)");
        assert!(
            ptx.contains("rcp.approx.f32"),
            "sigmoid must use rcp.approx for 1/denom"
        );
        assert!(
            ptx.contains("0f3F800000"),
            "sigmoid must add 1.0 to denominator"
        );
        assert!(!ptx.contains("sin.approx"), "sigmoid must not emit sin");
        assert!(
            !ptx.contains("max.f32"),
            "sigmoid must not use max (relu op)"
        );
    }

    // gelu tanh-approximation constants: 0f3D372713 = 0.044715,
    // 0f3F4C422A = sqrt(2/pi) ≈ 0.7978846.
    #[test]
    fn test_gelu_ptx_contains_tanh_approximation() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Gelu, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("gelu PTX generation failed");
        assert!(
            ptx.contains("0f3D372713"),
            "gelu must use 0.044715 constant"
        );
        assert!(
            ptx.contains("0f3F4C422A"),
            "gelu must use sqrt(2/pi) constant"
        );
        assert!(
            ptx.contains("ex2.approx.f32"),
            "gelu must use ex2.approx for tanh approximation"
        );
        assert!(
            ptx.contains("rcp.approx.f32"),
            "gelu must use rcp.approx inside tanh"
        );
        assert!(!ptx.contains("sin.approx"), "gelu must not emit sin");
    }

    // tanh(x) = 2*sigmoid(2x) - 1: expects a scale by 2.0 (0f40000000) and a
    // subtraction on top of the sigmoid exp/rcp sequence.
    #[test]
    fn test_tanh_ptx_contains_exp_instructions() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Tanh, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("tanh PTX generation failed");
        assert!(
            ptx.contains("ex2.approx.f32"),
            "tanh must use ex2.approx for exp"
        );
        assert!(
            ptx.contains("rcp.approx.f32"),
            "tanh must use rcp.approx in sigmoid step"
        );
        assert!(ptx.contains("0f40000000"), "tanh must scale by 2.0");
        assert!(
            ptx.contains("sub.f32"),
            "tanh must subtract 1.0 for tanh formula"
        );
        assert!(!ptx.contains("sin.approx"), "tanh must not emit sin");
    }

    // silu(x) = x * sigmoid(x): the sigmoid sequence plus a final multiply.
    #[test]
    fn test_silu_ptx_contains_mul_and_sigmoid() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Silu, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("silu PTX generation failed");
        assert!(
            ptx.contains("ex2.approx.f32"),
            "silu must use ex2.approx for sigmoid"
        );
        assert!(
            ptx.contains("rcp.approx.f32"),
            "silu must use rcp.approx for sigmoid"
        );
        assert!(
            ptx.contains("mul.f32"),
            "silu must multiply x by sigmoid(x)"
        );
        assert!(!ptx.contains("sin.approx"), "silu must not emit sin");
        assert!(!ptx.contains("max.f32"), "silu must not use relu max");
    }

    // Every generated module must carry the standard PTX preamble directives.
    #[test]
    fn test_elementwise_ptx_has_valid_headers() {
        let ops_and_types = [
            (ElementwiseOp::Add, PtxType::F32),
            (ElementwiseOp::Relu, PtxType::F32),
            (ElementwiseOp::Sigmoid, PtxType::F32),
            (ElementwiseOp::Gelu, PtxType::F32),
            (ElementwiseOp::Tanh, PtxType::F32),
            (ElementwiseOp::Silu, PtxType::F32),
            (ElementwiseOp::Neg, PtxType::F32),
            (ElementwiseOp::Exp, PtxType::F32),
            (ElementwiseOp::Log, PtxType::F32),
        ];

        for (op, ty) in ops_and_types {
            let t = ElementwiseTemplate::new(op, ty, SmVersion::Sm80);
            let ptx = t
                .generate()
                .unwrap_or_else(|e| panic!("PTX generation failed for {op:?}: {e}"));
            assert!(
                ptx.contains(".version"),
                "PTX for {op:?} must have .version header"
            );
            assert!(
                ptx.contains(".target"),
                "PTX for {op:?} must have .target header"
            );
            assert!(
                ptx.contains(".entry"),
                "PTX for {op:?} must have .entry directive"
            );
        }
    }

    // --- CPU reference implementations of the activation math. These mirror
    // --- the formulas the PTX templates are expected to implement and back
    // --- the numeric-precision tests below.

    // relu(x) = max(x, 0).
    fn cpu_relu_f32(x: f32) -> f32 {
        x.max(0.0)
    }

    // sigmoid(x) = 1 / (1 + exp(-x)).
    fn cpu_sigmoid_f32(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    // Tanh-approximation GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
    // k0 = sqrt(2/pi), k1 = 0.044715 — same constants asserted as hex literals
    // in the PTX structural tests above.
    fn cpu_gelu_f32(x: f32) -> f32 {
        let k0: f32 = 0.797_884_6; let k1: f32 = 0.044_715;
        let inner = k0 * k1.mul_add(x * x * x, x);
        0.5 * x * (1.0 + inner.tanh())
    }

    // tanh via the standard library (reference, not the approximation).
    fn cpu_tanh_f32(x: f32) -> f32 {
        x.tanh()
    }

    // silu(x) = x * sigmoid(x).
    fn cpu_silu_f32(x: f32) -> f32 {
        x * cpu_sigmoid_f32(x)
    }

    // relu at exact and near-zero points.
    #[test]
    fn relu_precision_known_values() {
        assert!((cpu_relu_f32(0.0) - 0.0_f32).abs() < f32::EPSILON);
        assert!((cpu_relu_f32(-1.0) - 0.0_f32).abs() < f32::EPSILON);
        assert!((cpu_relu_f32(1.0) - 1.0_f32).abs() < f32::EPSILON);
        assert!((cpu_relu_f32(-0.001) - 0.0_f32).abs() < f32::EPSILON);
        assert!((cpu_relu_f32(100.0) - 100.0_f32).abs() < f32::EPSILON);
    }

    // IEEE-754 -0.0 must not escape relu as a negative value.
    #[test]
    fn relu_precision_negative_zero() {
        assert!(cpu_relu_f32(-0.0) >= 0.0);
    }

    // sigmoid at 0 (exactly 0.5), at saturation, and at x = 1.
    #[test]
    fn sigmoid_precision_known_values() {
        assert!((cpu_sigmoid_f32(0.0) - 0.5).abs() < 1e-7_f32);
        assert!((cpu_sigmoid_f32(100.0) - 1.0).abs() < 1e-6_f32);
        assert!(cpu_sigmoid_f32(-100.0).abs() < 1e-6_f32);
        let expected_sig1: f32 = 0.731_058_6;
        assert!(
            (cpu_sigmoid_f32(1.0) - expected_sig1).abs() < 1e-5_f32,
            "sigmoid(1.0) expected ~{expected_sig1}, got {}",
            cpu_sigmoid_f32(1.0)
        );
    }

    // Strictly inside (0,1) for moderate inputs; clamped bounds at saturation.
    #[test]
    fn sigmoid_output_in_unit_interval() {
        let inputs: &[f32] = &[-10.0, -1.0, 0.0, 1.0, 10.0];
        for &x in inputs {
            let s = cpu_sigmoid_f32(x);
            assert!(s > 0.0 && s < 1.0, "sigmoid({x}) = {s} not in (0,1)");
        }
        assert!(cpu_sigmoid_f32(-100.0) >= 0.0);
        assert!(cpu_sigmoid_f32(100.0) <= 1.0);
    }

    // gelu at 0, ±1 (textbook ~±0.8413/-0.1587) and large x (identity-like).
    #[test]
    fn gelu_precision_known_values() {
        assert!(cpu_gelu_f32(0.0).abs() < 1e-7_f32);
        assert!(
            (cpu_gelu_f32(1.0) - 0.8413_f32).abs() < 0.001_f32,
            "gelu(1) should be ~0.8413, got {}",
            cpu_gelu_f32(1.0)
        );
        assert!(
            (cpu_gelu_f32(-1.0) + 0.1587_f32).abs() < 0.001_f32,
            "gelu(-1) should be ~-0.1587, got {}",
            cpu_gelu_f32(-1.0)
        );
        assert!(
            (cpu_gelu_f32(5.0) - 5.0_f32).abs() < 0.001_f32,
            "gelu(5) should be ~5.0, got {}",
            cpu_gelu_f32(5.0)
        );
    }

    // gelu keeps the sign of its input away from zero.
    #[test]
    fn gelu_sign_preservation() {
        assert!(cpu_gelu_f32(0.5) > 0.0);
        assert!(cpu_gelu_f32(2.0) > 0.0);
        assert!(cpu_gelu_f32(-2.0) < 0.0);
    }

    // tanh at 0, ±1 (~±0.7615942), and at saturation.
    #[test]
    fn tanh_precision_known_values() {
        assert!(cpu_tanh_f32(0.0).abs() < 1e-7_f32);
        let expected_tanh1: f32 = 0.761_594_2;
        assert!(
            (cpu_tanh_f32(1.0) - expected_tanh1).abs() < 1e-5_f32,
            "tanh(1.0) expected ~{expected_tanh1}, got {}",
            cpu_tanh_f32(1.0)
        );
        assert!(
            (cpu_tanh_f32(-1.0) + expected_tanh1).abs() < 1e-5_f32,
            "tanh(-1.0) expected ~-{expected_tanh1}, got {}",
            cpu_tanh_f32(-1.0)
        );
        assert!(
            (cpu_tanh_f32(10.0) - 1.0).abs() < 1e-5_f32,
            "tanh(10) should be ~1.0"
        );
        assert!(
            (cpu_tanh_f32(-10.0) + 1.0).abs() < 1e-5_f32,
            "tanh(-10) should be ~-1.0"
        );
    }

    // Strictly inside (-1,1) for |x| <= 5; clamped bounds at saturation.
    #[test]
    fn tanh_output_in_bounded_range() {
        let inputs: &[f32] = &[-5.0, -1.0, 0.0, 1.0, 5.0];
        for &x in inputs {
            let t = cpu_tanh_f32(x);
            assert!(t > -1.0 && t < 1.0, "tanh({x}) = {t} not in (-1,1)");
        }
        assert!(cpu_tanh_f32(-100.0) >= -1.0);
        assert!(cpu_tanh_f32(100.0) <= 1.0);
    }

    // silu at 0, 1 (= sigmoid(1) since x = 1), and -1 (~-0.2689).
    #[test]
    fn silu_precision_known_values() {
        assert!(cpu_silu_f32(0.0).abs() < 1e-7_f32);
        let expected_sig1: f32 = 0.731_058_6;
        assert!(
            (cpu_silu_f32(1.0) - expected_sig1).abs() < 1e-5_f32,
            "silu(1.0) expected ~{expected_sig1}, got {}",
            cpu_silu_f32(1.0)
        );
        assert!(
            (cpu_silu_f32(-1.0) + 0.2689_f32).abs() < 0.001_f32,
            "silu(-1) should be ~-0.2689, got {}",
            cpu_silu_f32(-1.0)
        );
    }

    // silu(x) has the same sign as x (sigmoid is always positive).
    #[test]
    fn silu_sign_matches_input() {
        for &x in &[0.1_f32, 0.5, 1.0, 2.0, 5.0] {
            assert!(
                cpu_silu_f32(x) > 0.0,
                "silu({x}) should be positive, got {}",
                cpu_silu_f32(x)
            );
        }
        for &x in &[-0.1_f32, -0.5, -2.0] {
            assert!(
                cpu_silu_f32(x) < 0.0,
                "silu({x}) should be negative, got {}",
                cpu_silu_f32(x)
            );
        }
    }

    // Fused add+relu must emit both the add and the relu max in one kernel.
    #[test]
    fn elementwise_ptx_generates_fused_add_relu() {
        let tmpl =
            ElementwiseTemplate::new(ElementwiseOp::FusedAddRelu, PtxType::F32, SmVersion::Sm80);
        let ptx = tmpl
            .generate()
            .expect("FusedAddRelu should generate successfully");
        assert!(
            ptx.contains("add"),
            "fused kernel should contain add instruction"
        );
        assert!(
            ptx.contains("max"),
            "fused kernel should contain max for relu"
        );
    }

    // Range invariants of the reference math across a sweep of inputs.
    #[test]
    fn elementwise_ops_precision_sweep() {
        let test_inputs: &[f32] = &[-5.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 5.0, 10.0];

        for &x in test_inputs {
            assert!(
                cpu_relu_f32(x) >= 0.0,
                "relu({x}) = {} should be non-negative",
                cpu_relu_f32(x)
            );

            let s = cpu_sigmoid_f32(x);
            assert!(s > 0.0 && s < 1.0, "sigmoid({x}) = {s} should be in (0,1)");

            let t = cpu_tanh_f32(x);
            assert!(
                (-1.0_f32..=1.0).contains(&t),
                "tanh({x}) = {t} should be in [-1,1]"
            );

            // silu sign check only for clearly-positive x (silu(0) == 0).
            if x > 0.1 {
                assert!(
                    cpu_silu_f32(x) > 0.0,
                    "silu({x}) should be positive for positive input"
                );
            }
        }
    }

    // Every activation op generates PTX with the expected entry point name.
    #[test]
    fn all_activation_ops_generate_ptx_for_f32() {
        let activation_ops = [
            ElementwiseOp::Relu,
            ElementwiseOp::Gelu,
            ElementwiseOp::Sigmoid,
            ElementwiseOp::Silu,
            ElementwiseOp::Tanh,
        ];
        for op in activation_ops {
            let t = ElementwiseTemplate::new(op, PtxType::F32, SmVersion::Sm80);
            let result = t.generate();
            assert!(
                result.is_ok(),
                "PTX generation failed for op {:?}: {:?}",
                op,
                result.err()
            );
            let ptx = result.expect("already checked is_ok");
            let name = op.as_str();
            assert!(
                ptx.contains(&format!(".entry elementwise_{name}_f32")),
                "PTX for {name} missing expected entry point"
            );
        }
    }

    // Minimal relu structural check (subset of test_relu_ptx_correct_arithmetic).
    #[test]
    fn relu_ptx_uses_max_instruction() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("relu PTX generation should succeed");
        assert!(
            ptx.contains("max.f32"),
            "relu PTX must use max.f32 instruction"
        );
    }

    // Accepts either the ex2-based sequence or a native tanh.approx lowering
    // (tanh.approx.f32 exists on newer architectures).
    #[test]
    fn tanh_ptx_uses_tanh_or_approx_sequence() {
        let t = ElementwiseTemplate::new(ElementwiseOp::Tanh, PtxType::F32, SmVersion::Sm80);
        let ptx = t.generate().expect("tanh PTX generation should succeed");
        let has_approx = ptx.contains("ex2.approx") || ptx.contains("tanh.approx");
        assert!(
            has_approx,
            "tanh PTX should use ex2.approx or tanh.approx, got:\n{ptx}"
        );
    }
}