Skip to main content

oxicuda_ptx/templates/
elementwise.rs

1//! Elementwise GPU operation templates.
2//!
//! This module generates complete PTX kernels for unary and binary elementwise
//! operations over device arrays. It supports basic arithmetic (`add`, `sub`, `mul`, `div`),
//! activation functions (`ReLU`, leaky `ReLU`, GELU, sigmoid, hard sigmoid, `SiLU`,
//! hard swish, softplus, tanh), unary math (neg, abs, sqrt, rsqrt, exp, log, ceil, floor),
//! scalar operations (scale, add-scalar), and fused operations (fused-add-relu,
//! fused-scale-add).
7//!
8//! Each template produces a kernel that:
9//! 1. Computes a global thread index
10//! 2. Performs a bounds check against the array length
11//! 3. Loads input element(s)
12//! 4. Applies the operation
13//! 5. Stores the result
14//!
15//! # Example
16//!
17//! ```
18//! use oxicuda_ptx::templates::elementwise::{ElementwiseTemplate, ElementwiseOp};
19//! use oxicuda_ptx::ir::PtxType;
20//! use oxicuda_ptx::arch::SmVersion;
21//!
22//! let template = ElementwiseTemplate::new(
23//!     ElementwiseOp::Add,
24//!     PtxType::F32,
25//!     SmVersion::Sm80,
26//! );
27//! let ptx = template.generate().expect("PTX generation failed");
28//! assert!(ptx.contains("add.f32"));
29//! ```
30
31use crate::arch::SmVersion;
32use crate::builder::KernelBuilder;
33use crate::error::PtxGenError;
34use crate::ir::PtxType;
35
36/// Elementwise operation type.
37///
38/// Covers binary arithmetic, unary activations, unary math, and fused operations.
39/// Each variant determines the kernel signature (number of input/output pointers)
40/// and the PTX instruction sequence emitted in the kernel body.
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
42pub enum ElementwiseOp {
43    /// Element-wise addition: `c[i] = a[i] + b[i]`.
44    Add,
45    /// Element-wise subtraction: `c[i] = a[i] - b[i]`.
46    Sub,
47    /// Element-wise multiplication: `c[i] = a[i] * b[i]`.
48    Mul,
49    /// Element-wise division: `c[i] = a[i] / b[i]`.
50    Div,
51    /// Rectified linear unit: `b[i] = max(0, a[i])`.
52    Relu,
53    /// Gaussian error linear unit (tanh approximation):
54    /// `b[i] = 0.5 * a[i] * (1 + tanh(sqrt(2/pi) * (a[i] + 0.044715 * a[i]^3)))`.
55    Gelu,
56    /// Sigmoid activation: `b[i] = 1 / (1 + exp(-a[i]))`.
57    Sigmoid,
58    /// Sigmoid linear unit: `b[i] = a[i] * sigmoid(a[i])`.
59    Silu,
60    /// Hyperbolic tangent: `b[i] = tanh(a[i])`.
61    Tanh,
62    /// Arithmetic negation: `b[i] = -a[i]`.
63    Neg,
64    /// Absolute value: `b[i] = |a[i]|`.
65    Abs,
66    /// Square root: `b[i] = sqrt(a[i])`.
67    Sqrt,
68    /// Reciprocal square root: `b[i] = 1 / sqrt(a[i])`.
69    Rsqrt,
70    /// Exponential: `b[i] = exp(a[i])`.
71    Exp,
72    /// Natural logarithm: `b[i] = ln(a[i])`.
73    Log,
74    /// Scalar scaling: `b[i] = alpha * a[i]`.
75    Scale,
76    /// Add scalar: `b[i] = a[i] + scalar`.
77    AddScalar,
78    /// Ceiling (round toward +inf): `b[i] = ceil(a[i])`.
79    Ceil,
80    /// Floor (round toward -inf): `b[i] = floor(a[i])`.
81    Floor,
82    /// Hard sigmoid: `b[i] = max(0, min(1, 0.2*a[i] + 0.5))`.
83    HardSigmoid,
84    /// Hard swish: `b[i] = a[i] * max(0, min(6, a[i]+3)) / 6`.
85    HardSwish,
86    /// Softplus: `b[i] = ln(1 + exp(a[i]))`.
87    Softplus,
88    /// Leaky relu: `b[i] = a[i] >= 0 ? a[i] : 0.01 * a[i]`.
89    LeakyRelu,
90    /// Fused add-relu: `c[i] = relu(a[i] + b[i])`.
91    FusedAddRelu,
92    /// Fused scale-add: `c[i] = alpha * a[i] + beta * b[i]`.
93    FusedScaleAdd,
94}
95
96impl ElementwiseOp {
97    /// Returns a short lowercase name suitable for kernel naming.
98    #[must_use]
99    pub const fn as_str(self) -> &'static str {
100        match self {
101            Self::Add => "add",
102            Self::Sub => "sub",
103            Self::Mul => "mul",
104            Self::Div => "div",
105            Self::Relu => "relu",
106            Self::Gelu => "gelu",
107            Self::Sigmoid => "sigmoid",
108            Self::Silu => "silu",
109            Self::Tanh => "tanh",
110            Self::Neg => "neg",
111            Self::Abs => "abs",
112            Self::Sqrt => "sqrt",
113            Self::Rsqrt => "rsqrt",
114            Self::Exp => "exp",
115            Self::Log => "log",
116            Self::Ceil => "ceil",
117            Self::Floor => "floor",
118            Self::HardSigmoid => "hard_sigmoid",
119            Self::HardSwish => "hard_swish",
120            Self::Softplus => "softplus",
121            Self::LeakyRelu => "leaky_relu",
122            Self::Scale => "scale",
123            Self::AddScalar => "add_scalar",
124            Self::FusedAddRelu => "fused_add_relu",
125            Self::FusedScaleAdd => "fused_scale_add",
126        }
127    }
128
129    /// Returns `true` if this is a binary operation requiring two input arrays.
130    #[must_use]
131    pub const fn is_binary(self) -> bool {
132        matches!(
133            self,
134            Self::Add
135                | Self::Sub
136                | Self::Mul
137                | Self::Div
138                | Self::FusedAddRelu
139                | Self::FusedScaleAdd
140        )
141    }
142
143    /// Returns `true` if this operation requires scalar parameter(s).
144    #[must_use]
145    pub const fn needs_scalar(self) -> bool {
146        matches!(self, Self::Scale | Self::AddScalar | Self::FusedScaleAdd)
147    }
148}
149
150/// Template for generating elementwise PTX kernels.
151///
152/// Combines an [`ElementwiseOp`], a precision ([`PtxType`]), and a target
153/// architecture ([`SmVersion`]) to produce a complete PTX module string.
154///
155/// The generated kernel handles global thread indexing and bounds checking.
156/// For complex activations (GELU, sigmoid, `SiLU`), the template emits
157/// approximate PTX instruction sequences using `ex2.approx` and `rcp.approx`.
158pub struct ElementwiseTemplate {
159    /// The elementwise operation to generate.
160    pub op: ElementwiseOp,
161    /// The data precision for computation (e.g., `PtxType::F32`).
162    pub precision: PtxType,
163    /// The target GPU architecture.
164    pub target: SmVersion,
165}
166
167impl ElementwiseTemplate {
168    /// Creates a new elementwise template with the given parameters.
169    #[must_use]
170    pub const fn new(op: ElementwiseOp, precision: PtxType, target: SmVersion) -> Self {
171        Self {
172            op,
173            precision,
174            target,
175        }
176    }
177
178    /// Returns the kernel function name derived from the operation and precision.
179    ///
180    /// The name follows the pattern `elementwise_{op}_{type}`, for example
181    /// `elementwise_add_f32` or `elementwise_relu_f16`.
182    #[must_use]
183    pub fn kernel_name(&self) -> String {
184        let type_str = self.precision.as_ptx_str().trim_start_matches('.');
185        format!("elementwise_{}_{}", self.op.as_str(), type_str)
186    }
187
188    /// Generates the complete PTX module text for this elementwise operation.
189    ///
190    /// # Errors
191    ///
192    /// Returns [`PtxGenError`] if the precision type is unsupported for the
193    /// requested operation or if PTX text generation fails.
194    pub fn generate(&self) -> Result<String, PtxGenError> {
195        self.validate_precision()?;
196
197        match self.op {
198            ElementwiseOp::Add => self.generate_binary_arith("add"),
199            ElementwiseOp::Sub => self.generate_binary_arith("sub"),
200            ElementwiseOp::Mul => self.generate_binary_arith("mul"),
201            ElementwiseOp::Div => self.generate_div(),
202            ElementwiseOp::Relu => self.generate_relu(),
203            ElementwiseOp::Gelu => self.generate_gelu(),
204            ElementwiseOp::Sigmoid => self.generate_sigmoid(),
205            ElementwiseOp::Silu => self.generate_silu(),
206            ElementwiseOp::Tanh => self.generate_tanh(),
207            ElementwiseOp::Neg => self.generate_unary("neg"),
208            ElementwiseOp::Abs => self.generate_unary("abs"),
209            ElementwiseOp::Sqrt => self.generate_sqrt(),
210            ElementwiseOp::Rsqrt => self.generate_rsqrt(),
211            ElementwiseOp::Exp => self.generate_exp(),
212            ElementwiseOp::Log => self.generate_log(),
213            ElementwiseOp::Ceil => self.generate_ceil(),
214            ElementwiseOp::Floor => self.generate_floor(),
215            ElementwiseOp::HardSigmoid => self.generate_hard_sigmoid(),
216            ElementwiseOp::HardSwish => self.generate_hard_swish(),
217            ElementwiseOp::Softplus => self.generate_softplus(),
218            ElementwiseOp::LeakyRelu => self.generate_leaky_relu(),
219            ElementwiseOp::Scale => self.generate_scale(),
220            ElementwiseOp::AddScalar => self.generate_add_scalar(),
221            ElementwiseOp::FusedAddRelu => self.generate_fused_add_relu(),
222            ElementwiseOp::FusedScaleAdd => self.generate_fused_scale_add(),
223        }
224    }
225
226    /// Validates that the precision type is a supported floating-point type.
227    fn validate_precision(&self) -> Result<(), PtxGenError> {
228        if !matches!(
229            self.precision,
230            PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64
231        ) {
232            return Err(PtxGenError::InvalidType(format!(
233                "elementwise operations require F16, BF16, F32, or F64, got {}",
234                self.precision.as_ptx_str()
235            )));
236        }
237        Ok(())
238    }
239
240    /// Returns the PTX type suffix string (e.g., `.f32`).
241    const fn ty_str(&self) -> &'static str {
242        self.precision.as_ptx_str()
243    }
244
245    /// Generates a binary arithmetic kernel (add, sub, mul).
246    ///
247    /// Kernel signature: `(a_ptr: u64, b_ptr: u64, c_ptr: u64, n: u32)`
248    fn generate_binary_arith(&self, op_name: &str) -> Result<String, PtxGenError> {
249        let kernel_name = self.kernel_name();
250        let ty = self.ty_str();
251        let byte_size = self.precision.size_bytes();
252        let op_name = op_name.to_string();
253
254        KernelBuilder::new(&kernel_name)
255            .target(self.target)
256            .param("a_ptr", PtxType::U64)
257            .param("b_ptr", PtxType::U64)
258            .param("c_ptr", PtxType::U64)
259            .param("n", PtxType::U32)
260            .max_threads_per_block(256)
261            .body(move |b| {
262                let tid = b.global_thread_id_x();
263                let tid_name = tid.to_string();
264                let n_reg = b.load_param_u32("n");
265                b.if_lt_u32(tid, n_reg, move |b| {
266                    let a_ptr = b.load_param_u64("a_ptr");
267                    let b_ptr = b.load_param_u64("b_ptr");
268                    let c_ptr = b.load_param_u64("c_ptr");
269
270                    // Compute byte offset: tid * sizeof(element)
271                    b.raw_ptx(&format!(
272                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
273                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
274                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
275                         add.u64 %rd_b, {b_ptr}, %rd_off;\n    \
276                         add.u64 %rd_c, {c_ptr}, %rd_off;"
277                    ));
278
279                    // Load, compute, store via raw PTX for reliable type handling
280                    b.raw_ptx(&format!(
281                        "ld.global{ty} %f_a, [%rd_a];\n    \
282                         ld.global{ty} %f_b, [%rd_b];\n    \
283                         {op_name}{ty} %f_c, %f_a, %f_b;\n    \
284                         st.global{ty} [%rd_c], %f_c;"
285                    ));
286                });
287                b.ret();
288            })
289            .build()
290    }
291
292    /// Generates a division kernel with appropriate rounding.
293    fn generate_div(&self) -> Result<String, PtxGenError> {
294        let kernel_name = self.kernel_name();
295        let ty = self.ty_str();
296        let byte_size = self.precision.size_bytes();
297
298        KernelBuilder::new(&kernel_name)
299            .target(self.target)
300            .param("a_ptr", PtxType::U64)
301            .param("b_ptr", PtxType::U64)
302            .param("c_ptr", PtxType::U64)
303            .param("n", PtxType::U32)
304            .max_threads_per_block(256)
305            .body(move |b| {
306                let tid = b.global_thread_id_x();
307                let tid_name = tid.to_string();
308                let n_reg = b.load_param_u32("n");
309                b.if_lt_u32(tid, n_reg, move |b| {
310                    let a_ptr = b.load_param_u64("a_ptr");
311                    let b_ptr = b.load_param_u64("b_ptr");
312                    let c_ptr = b.load_param_u64("c_ptr");
313
314                    b.raw_ptx(&format!(
315                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
316                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
317                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
318                         add.u64 %rd_b, {b_ptr}, %rd_off;\n    \
319                         add.u64 %rd_c, {c_ptr}, %rd_off;"
320                    ));
321
322                    b.raw_ptx(&format!(
323                        "ld.global{ty} %f_a, [%rd_a];\n    \
324                         ld.global{ty} %f_b, [%rd_b];\n    \
325                         div.rn{ty} %f_c, %f_a, %f_b;\n    \
326                         st.global{ty} [%rd_c], %f_c;"
327                    ));
328                });
329                b.ret();
330            })
331            .build()
332    }
333
334    /// Generates a `ReLU` kernel: `max(0, x)`.
335    fn generate_relu(&self) -> Result<String, PtxGenError> {
336        let kernel_name = self.kernel_name();
337        let ty = self.ty_str();
338        let byte_size = self.precision.size_bytes();
339        // IEEE 754 zero in hex for PTX immediate
340        let zero_lit = float_zero_literal(self.precision);
341
342        KernelBuilder::new(&kernel_name)
343            .target(self.target)
344            .param("a_ptr", PtxType::U64)
345            .param("b_ptr", PtxType::U64)
346            .param("n", PtxType::U32)
347            .max_threads_per_block(256)
348            .body(move |b| {
349                let tid = b.global_thread_id_x();
350                let tid_name = tid.to_string();
351                let n_reg = b.load_param_u32("n");
352                b.if_lt_u32(tid, n_reg, move |b| {
353                    let a_ptr = b.load_param_u64("a_ptr");
354                    let b_ptr = b.load_param_u64("b_ptr");
355
356                    b.raw_ptx(&format!(
357                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
358                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
359                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
360                         add.u64 %rd_b, {b_ptr}, %rd_off;"
361                    ));
362
363                    b.raw_ptx(&format!(
364                        "ld.global{ty} %f_x, [%rd_a];\n    \
365                         max{ty} %f_y, %f_x, {zero_lit};\n    \
366                         st.global{ty} [%rd_b], %f_y;"
367                    ));
368                });
369                b.ret();
370            })
371            .build()
372    }
373
374    /// Generates a sigmoid kernel: `1 / (1 + exp(-x))`.
375    ///
376    /// Uses `ex2.approx.f32` with a log2(e) scaling factor for the exponential,
377    /// then `rcp.approx.f32` for the reciprocal.
378    fn generate_sigmoid(&self) -> Result<String, PtxGenError> {
379        let kernel_name = self.kernel_name();
380        let ty = self.ty_str();
381        let byte_size = self.precision.size_bytes();
382
383        KernelBuilder::new(&kernel_name)
384            .target(self.target)
385            .param("a_ptr", PtxType::U64)
386            .param("b_ptr", PtxType::U64)
387            .param("n", PtxType::U32)
388            .max_threads_per_block(256)
389            .body(move |b| {
390                let tid = b.global_thread_id_x();
391                let tid_name = tid.to_string();
392                let n_reg = b.load_param_u32("n");
393                b.if_lt_u32(tid, n_reg, move |b| {
394                    let a_ptr = b.load_param_u64("a_ptr");
395                    let b_ptr = b.load_param_u64("b_ptr");
396
397                    b.raw_ptx(&format!(
398                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
399                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
400                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
401                         add.u64 %rd_b, {b_ptr}, %rd_off;"
402                    ));
403
404                    // sigmoid(x) = 1 / (1 + exp(-x))
405                    // exp(-x) = 2^(-x * log2(e))
406                    // log2(e) ~= 1.4426950408889634 = 0f3FB8AA3B in float hex
407                    b.raw_ptx(&format!(
408                        "ld.global{ty} %f_x, [%rd_a];\n    \
409                         neg{ty} %f_neg, %f_x;\n    \
410                         mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n    \
411                         ex2.approx{ty} %f_exp, %f_neg;\n    \
412                         add{ty} %f_denom, %f_exp, 0f3F800000;\n    \
413                         rcp.approx{ty} %f_y, %f_denom;\n    \
414                         st.global{ty} [%rd_b], %f_y;"
415                    ));
416                });
417                b.ret();
418            })
419            .build()
420    }
421
422    /// Generates a GELU kernel using the tanh approximation.
423    ///
424    /// GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
425    ///
426    /// Since PTX does not have a native tanh, this uses the identity:
427    /// tanh(a) = 2 * sigmoid(2a) - 1 = (2 / (1 + exp(-2a))) - 1
428    fn generate_gelu(&self) -> Result<String, PtxGenError> {
429        let kernel_name = self.kernel_name();
430        let ty = self.ty_str();
431        let byte_size = self.precision.size_bytes();
432
433        KernelBuilder::new(&kernel_name)
434            .target(self.target)
435            .param("a_ptr", PtxType::U64)
436            .param("b_ptr", PtxType::U64)
437            .param("n", PtxType::U32)
438            .max_threads_per_block(256)
439            .body(move |b| {
440                let tid = b.global_thread_id_x();
441                let tid_name = tid.to_string();
442                let n_reg = b.load_param_u32("n");
443                b.if_lt_u32(tid, n_reg, move |b| {
444                    let a_ptr = b.load_param_u64("a_ptr");
445                    let b_ptr = b.load_param_u64("b_ptr");
446
447                    b.raw_ptx(&format!(
448                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
449                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
450                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
451                         add.u64 %rd_b, {b_ptr}, %rd_off;"
452                    ));
453
454                    // Constants (IEEE 754 hex):
455                    //   0.5          = 0f3F000000
456                    //   0.044715     = 0f3D372713
457                    //   sqrt(2/pi)   = 0f3F4C422A  (~0.7978845608)
458                    //   2.0          = 0f40000000
459                    //   1.0          = 0f3F800000
460                    //   log2(e)      = 0f3FB8AA3B
461                    b.raw_ptx(&format!(
462                        "ld.global{ty} %f_x, [%rd_a];\n    \
463                         mul{ty} %f_x3, %f_x, %f_x;\n    \
464                         mul{ty} %f_x3, %f_x3, %f_x;\n    \
465                         mul{ty} %f_x3, %f_x3, 0f3D372713;\n    \
466                         add{ty} %f_inner, %f_x, %f_x3;\n    \
467                         mul{ty} %f_inner, %f_inner, 0f3F4C422A;\n    \
468                         mul{ty} %f_2a, %f_inner, 0f40000000;\n    \
469                         neg{ty} %f_neg2a, %f_2a;\n    \
470                         mul{ty} %f_neg2a, %f_neg2a, 0f3FB8AA3B;\n    \
471                         ex2.approx{ty} %f_exp, %f_neg2a;\n    \
472                         add{ty} %f_denom, %f_exp, 0f3F800000;\n    \
473                         rcp.approx{ty} %f_sig, %f_denom;\n    \
474                         mul{ty} %f_sig, %f_sig, 0f40000000;\n    \
475                         sub{ty} %f_tanh, %f_sig, 0f3F800000;\n    \
476                         add{ty} %f_tanh, %f_tanh, 0f3F800000;\n    \
477                         mul{ty} %f_y, 0f3F000000, %f_x;\n    \
478                         mul{ty} %f_y, %f_y, %f_tanh;\n    \
479                         st.global{ty} [%rd_b], %f_y;"
480                    ));
481                });
482                b.ret();
483            })
484            .build()
485    }
486
487    /// Generates a `SiLU` kernel: `x * sigmoid(x)`.
488    fn generate_silu(&self) -> Result<String, PtxGenError> {
489        let kernel_name = self.kernel_name();
490        let ty = self.ty_str();
491        let byte_size = self.precision.size_bytes();
492
493        KernelBuilder::new(&kernel_name)
494            .target(self.target)
495            .param("a_ptr", PtxType::U64)
496            .param("b_ptr", PtxType::U64)
497            .param("n", PtxType::U32)
498            .max_threads_per_block(256)
499            .body(move |b| {
500                let tid = b.global_thread_id_x();
501                let tid_name = tid.to_string();
502                let n_reg = b.load_param_u32("n");
503                b.if_lt_u32(tid, n_reg, move |b| {
504                    let a_ptr = b.load_param_u64("a_ptr");
505                    let b_ptr = b.load_param_u64("b_ptr");
506
507                    b.raw_ptx(&format!(
508                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
509                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
510                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
511                         add.u64 %rd_b, {b_ptr}, %rd_off;"
512                    ));
513
514                    // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
515                    b.raw_ptx(&format!(
516                        "ld.global{ty} %f_x, [%rd_a];\n    \
517                         neg{ty} %f_neg, %f_x;\n    \
518                         mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n    \
519                         ex2.approx{ty} %f_exp, %f_neg;\n    \
520                         add{ty} %f_denom, %f_exp, 0f3F800000;\n    \
521                         rcp.approx{ty} %f_sig, %f_denom;\n    \
522                         mul{ty} %f_y, %f_x, %f_sig;\n    \
523                         st.global{ty} [%rd_b], %f_y;"
524                    ));
525                });
526                b.ret();
527            })
528            .build()
529    }
530
531    /// Generates a tanh kernel using `tanh(x) = 2 * sigmoid(2x) - 1`.
532    fn generate_tanh(&self) -> Result<String, PtxGenError> {
533        let kernel_name = self.kernel_name();
534        let ty = self.ty_str();
535        let byte_size = self.precision.size_bytes();
536
537        KernelBuilder::new(&kernel_name)
538            .target(self.target)
539            .param("a_ptr", PtxType::U64)
540            .param("b_ptr", PtxType::U64)
541            .param("n", PtxType::U32)
542            .max_threads_per_block(256)
543            .body(move |b| {
544                let tid = b.global_thread_id_x();
545                let tid_name = tid.to_string();
546                let n_reg = b.load_param_u32("n");
547                b.if_lt_u32(tid, n_reg, move |b| {
548                    let a_ptr = b.load_param_u64("a_ptr");
549                    let b_ptr = b.load_param_u64("b_ptr");
550
551                    b.raw_ptx(&format!(
552                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
553                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
554                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
555                         add.u64 %rd_b, {b_ptr}, %rd_off;"
556                    ));
557
558                    // tanh(x) = 2*sigmoid(2x) - 1
559                    b.raw_ptx(&format!(
560                        "ld.global{ty} %f_x, [%rd_a];\n    \
561                         mul{ty} %f_2x, %f_x, 0f40000000;\n    \
562                         neg{ty} %f_neg, %f_2x;\n    \
563                         mul{ty} %f_neg, %f_neg, 0f3FB8AA3B;\n    \
564                         ex2.approx{ty} %f_exp, %f_neg;\n    \
565                         add{ty} %f_denom, %f_exp, 0f3F800000;\n    \
566                         rcp.approx{ty} %f_sig, %f_denom;\n    \
567                         mul{ty} %f_y, %f_sig, 0f40000000;\n    \
568                         sub{ty} %f_y, %f_y, 0f3F800000;\n    \
569                         st.global{ty} [%rd_b], %f_y;"
570                    ));
571                });
572                b.ret();
573            })
574            .build()
575    }
576
577    /// Generates a unary operation kernel (neg, abs).
578    fn generate_unary(&self, op_name: &str) -> Result<String, PtxGenError> {
579        let kernel_name = self.kernel_name();
580        let ty = self.ty_str();
581        let byte_size = self.precision.size_bytes();
582        let op_name = op_name.to_string();
583
584        KernelBuilder::new(&kernel_name)
585            .target(self.target)
586            .param("a_ptr", PtxType::U64)
587            .param("b_ptr", PtxType::U64)
588            .param("n", PtxType::U32)
589            .max_threads_per_block(256)
590            .body(move |b| {
591                let tid = b.global_thread_id_x();
592                let tid_name = tid.to_string();
593                let n_reg = b.load_param_u32("n");
594                b.if_lt_u32(tid, n_reg, move |b| {
595                    let a_ptr = b.load_param_u64("a_ptr");
596                    let b_ptr = b.load_param_u64("b_ptr");
597
598                    b.raw_ptx(&format!(
599                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
600                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
601                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
602                         add.u64 %rd_b, {b_ptr}, %rd_off;"
603                    ));
604
605                    b.raw_ptx(&format!(
606                        "ld.global{ty} %f_x, [%rd_a];\n    \
607                         {op_name}{ty} %f_y, %f_x;\n    \
608                         st.global{ty} [%rd_b], %f_y;"
609                    ));
610                });
611                b.ret();
612            })
613            .build()
614    }
615
616    /// Generates a sqrt kernel with rounding.
617    fn generate_sqrt(&self) -> Result<String, PtxGenError> {
618        let kernel_name = self.kernel_name();
619        let ty = self.ty_str();
620        let byte_size = self.precision.size_bytes();
621
622        KernelBuilder::new(&kernel_name)
623            .target(self.target)
624            .param("a_ptr", PtxType::U64)
625            .param("b_ptr", PtxType::U64)
626            .param("n", PtxType::U32)
627            .max_threads_per_block(256)
628            .body(move |b| {
629                let tid = b.global_thread_id_x();
630                let tid_name = tid.to_string();
631                let n_reg = b.load_param_u32("n");
632                b.if_lt_u32(tid, n_reg, move |b| {
633                    let a_ptr = b.load_param_u64("a_ptr");
634                    let b_ptr = b.load_param_u64("b_ptr");
635
636                    b.raw_ptx(&format!(
637                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
638                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
639                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
640                         add.u64 %rd_b, {b_ptr}, %rd_off;"
641                    ));
642
643                    b.raw_ptx(&format!(
644                        "ld.global{ty} %f_x, [%rd_a];\n    \
645                         sqrt.rn{ty} %f_y, %f_x;\n    \
646                         st.global{ty} [%rd_b], %f_y;"
647                    ));
648                });
649                b.ret();
650            })
651            .build()
652    }
653
654    /// Generates an rsqrt (reciprocal square root) kernel.
655    fn generate_rsqrt(&self) -> Result<String, PtxGenError> {
656        let kernel_name = self.kernel_name();
657        let ty = self.ty_str();
658        let byte_size = self.precision.size_bytes();
659
660        KernelBuilder::new(&kernel_name)
661            .target(self.target)
662            .param("a_ptr", PtxType::U64)
663            .param("b_ptr", PtxType::U64)
664            .param("n", PtxType::U32)
665            .max_threads_per_block(256)
666            .body(move |b| {
667                let tid = b.global_thread_id_x();
668                let tid_name = tid.to_string();
669                let n_reg = b.load_param_u32("n");
670                b.if_lt_u32(tid, n_reg, move |b| {
671                    let a_ptr = b.load_param_u64("a_ptr");
672                    let b_ptr = b.load_param_u64("b_ptr");
673
674                    b.raw_ptx(&format!(
675                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
676                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
677                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
678                         add.u64 %rd_b, {b_ptr}, %rd_off;"
679                    ));
680
681                    b.raw_ptx(&format!(
682                        "ld.global{ty} %f_x, [%rd_a];\n    \
683                         rsqrt.approx{ty} %f_y, %f_x;\n    \
684                         st.global{ty} [%rd_b], %f_y;"
685                    ));
686                });
687                b.ret();
688            })
689            .build()
690    }
691
692    /// Generates an exp kernel using base-2 exponentiation: `exp(x) = 2^(x * log2(e))`.
693    fn generate_exp(&self) -> Result<String, PtxGenError> {
694        let kernel_name = self.kernel_name();
695        let ty = self.ty_str();
696        let byte_size = self.precision.size_bytes();
697
698        KernelBuilder::new(&kernel_name)
699            .target(self.target)
700            .param("a_ptr", PtxType::U64)
701            .param("b_ptr", PtxType::U64)
702            .param("n", PtxType::U32)
703            .max_threads_per_block(256)
704            .body(move |b| {
705                let tid = b.global_thread_id_x();
706                let tid_name = tid.to_string();
707                let n_reg = b.load_param_u32("n");
708                b.if_lt_u32(tid, n_reg, move |b| {
709                    let a_ptr = b.load_param_u64("a_ptr");
710                    let b_ptr = b.load_param_u64("b_ptr");
711
712                    b.raw_ptx(&format!(
713                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
714                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
715                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
716                         add.u64 %rd_b, {b_ptr}, %rd_off;"
717                    ));
718
719                    // log2(e) = 0f3FB8AA3B
720                    b.raw_ptx(&format!(
721                        "ld.global{ty} %f_x, [%rd_a];\n    \
722                         mul{ty} %f_x2, %f_x, 0f3FB8AA3B;\n    \
723                         ex2.approx{ty} %f_y, %f_x2;\n    \
724                         st.global{ty} [%rd_b], %f_y;"
725                    ));
726                });
727                b.ret();
728            })
729            .build()
730    }
731
732    /// Generates a natural log kernel using base-2 logarithm: `ln(x) = lg2(x) / lg2(e)`.
733    fn generate_log(&self) -> Result<String, PtxGenError> {
734        let kernel_name = self.kernel_name();
735        let ty = self.ty_str();
736        let byte_size = self.precision.size_bytes();
737
738        KernelBuilder::new(&kernel_name)
739            .target(self.target)
740            .param("a_ptr", PtxType::U64)
741            .param("b_ptr", PtxType::U64)
742            .param("n", PtxType::U32)
743            .max_threads_per_block(256)
744            .body(move |b| {
745                let tid = b.global_thread_id_x();
746                let tid_name = tid.to_string();
747                let n_reg = b.load_param_u32("n");
748                b.if_lt_u32(tid, n_reg, move |b| {
749                    let a_ptr = b.load_param_u64("a_ptr");
750                    let b_ptr = b.load_param_u64("b_ptr");
751
752                    b.raw_ptx(&format!(
753                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
754                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
755                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
756                         add.u64 %rd_b, {b_ptr}, %rd_off;"
757                    ));
758
759                    // 1/log2(e) = ln(2) ~= 0.6931471805599453 = 0f3F317218
760                    b.raw_ptx(&format!(
761                        "ld.global{ty} %f_x, [%rd_a];\n    \
762                         lg2.approx{ty} %f_lg, %f_x;\n    \
763                         mul{ty} %f_y, %f_lg, 0f3F317218;\n    \
764                         st.global{ty} [%rd_b], %f_y;"
765                    ));
766                });
767                b.ret();
768            })
769            .build()
770    }
771
772    /// Generates a ceil kernel: `b[i] = ceil(a[i])`.
773    ///
774    /// Uses `cvt.rpi` (round-to-positive-infinity) for ceiling.
775    fn generate_ceil(&self) -> Result<String, PtxGenError> {
776        let kernel_name = self.kernel_name();
777        let ty = self.ty_str();
778        let byte_size = self.precision.size_bytes();
779
780        KernelBuilder::new(&kernel_name)
781            .target(self.target)
782            .param("a_ptr", PtxType::U64)
783            .param("b_ptr", PtxType::U64)
784            .param("n", PtxType::U32)
785            .max_threads_per_block(256)
786            .body(move |b| {
787                let tid = b.global_thread_id_x();
788                let tid_name = tid.to_string();
789                let n_reg = b.load_param_u32("n");
790                b.if_lt_u32(tid, n_reg, move |b| {
791                    let a_ptr = b.load_param_u64("a_ptr");
792                    let b_ptr = b.load_param_u64("b_ptr");
793
794                    b.raw_ptx(&format!(
795                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
796                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
797                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
798                         add.u64 %rd_b, {b_ptr}, %rd_off;"
799                    ));
800
801                    b.raw_ptx(&format!(
802                        "ld.global{ty} %f_x, [%rd_a];\n    \
803                         cvt.rpi{ty}{ty} %f_y, %f_x;\n    \
804                         st.global{ty} [%rd_b], %f_y;"
805                    ));
806                });
807                b.ret();
808            })
809            .build()
810    }
811
812    /// Generates a floor kernel: `b[i] = floor(a[i])`.
813    ///
814    /// Uses `cvt.rmi` (round-to-minus-infinity) for floor.
815    fn generate_floor(&self) -> Result<String, PtxGenError> {
816        let kernel_name = self.kernel_name();
817        let ty = self.ty_str();
818        let byte_size = self.precision.size_bytes();
819
820        KernelBuilder::new(&kernel_name)
821            .target(self.target)
822            .param("a_ptr", PtxType::U64)
823            .param("b_ptr", PtxType::U64)
824            .param("n", PtxType::U32)
825            .max_threads_per_block(256)
826            .body(move |b| {
827                let tid = b.global_thread_id_x();
828                let tid_name = tid.to_string();
829                let n_reg = b.load_param_u32("n");
830                b.if_lt_u32(tid, n_reg, move |b| {
831                    let a_ptr = b.load_param_u64("a_ptr");
832                    let b_ptr = b.load_param_u64("b_ptr");
833
834                    b.raw_ptx(&format!(
835                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
836                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
837                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
838                         add.u64 %rd_b, {b_ptr}, %rd_off;"
839                    ));
840
841                    b.raw_ptx(&format!(
842                        "ld.global{ty} %f_x, [%rd_a];\n    \
843                         cvt.rmi{ty}{ty} %f_y, %f_x;\n    \
844                         st.global{ty} [%rd_b], %f_y;"
845                    ));
846                });
847                b.ret();
848            })
849            .build()
850    }
851
852    /// Generates a hard-sigmoid kernel: `max(0, min(1, alpha*x + beta))`.
853    ///
854    /// Uses ONNX default constants: alpha=0.2 (0f3E4CCCCD), beta=0.5 (0f3F000000).
855    fn generate_hard_sigmoid(&self) -> Result<String, PtxGenError> {
856        let kernel_name = self.kernel_name();
857        let ty = self.ty_str();
858        let byte_size = self.precision.size_bytes();
859        let zero_lit = float_zero_literal(self.precision);
860
861        KernelBuilder::new(&kernel_name)
862            .target(self.target)
863            .param("a_ptr", PtxType::U64)
864            .param("b_ptr", PtxType::U64)
865            .param("n", PtxType::U32)
866            .max_threads_per_block(256)
867            .body(move |b| {
868                let tid = b.global_thread_id_x();
869                let tid_name = tid.to_string();
870                let n_reg = b.load_param_u32("n");
871                b.if_lt_u32(tid, n_reg, move |b| {
872                    let a_ptr = b.load_param_u64("a_ptr");
873                    let b_ptr = b.load_param_u64("b_ptr");
874
875                    b.raw_ptx(&format!(
876                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
877                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
878                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
879                         add.u64 %rd_b, {b_ptr}, %rd_off;"
880                    ));
881
882                    // HardSigmoid: max(0, min(1, 0.2*x + 0.5))
883                    // 0.2 = 0f3E4CCCCD, 0.5 = 0f3F000000, 1.0 = 0f3F800000
884                    b.raw_ptx(&format!(
885                        "ld.global{ty} %f_x, [%rd_a];\n    \
886                         mul{ty} %f_ax, %f_x, 0f3E4CCCCD;\n    \
887                         add{ty} %f_lin, %f_ax, 0f3F000000;\n    \
888                         min{ty} %f_clip, %f_lin, 0f3F800000;\n    \
889                         max{ty} %f_y, %f_clip, {zero_lit};\n    \
890                         st.global{ty} [%rd_b], %f_y;"
891                    ));
892                });
893                b.ret();
894            })
895            .build()
896    }
897
898    /// Generates a hard-swish kernel: `x * max(0, min(6, x+3)) / 6`.
899    ///
900    /// IEEE 754 hex: 3.0=0f40400000, 6.0=0f40C00000, 1/6=0f3E2AAAAB.
901    fn generate_hard_swish(&self) -> Result<String, PtxGenError> {
902        let kernel_name = self.kernel_name();
903        let ty = self.ty_str();
904        let byte_size = self.precision.size_bytes();
905        let zero_lit = float_zero_literal(self.precision);
906
907        KernelBuilder::new(&kernel_name)
908            .target(self.target)
909            .param("a_ptr", PtxType::U64)
910            .param("b_ptr", PtxType::U64)
911            .param("n", PtxType::U32)
912            .max_threads_per_block(256)
913            .body(move |b| {
914                let tid = b.global_thread_id_x();
915                let tid_name = tid.to_string();
916                let n_reg = b.load_param_u32("n");
917                b.if_lt_u32(tid, n_reg, move |b| {
918                    let a_ptr = b.load_param_u64("a_ptr");
919                    let b_ptr = b.load_param_u64("b_ptr");
920
921                    b.raw_ptx(&format!(
922                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
923                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
924                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
925                         add.u64 %rd_b, {b_ptr}, %rd_off;"
926                    ));
927
928                    // HardSwish: x * max(0, min(6, x+3)) / 6
929                    // 3.0 = 0f40400000, 6.0 = 0f40C00000, 1/6 = 0f3E2AAAAB
930                    b.raw_ptx(&format!(
931                        "ld.global{ty} %f_x, [%rd_a];\n    \
932                         add{ty} %f_xp3, %f_x, 0f40400000;\n    \
933                         min{ty} %f_clip, %f_xp3, 0f40C00000;\n    \
934                         max{ty} %f_clip, %f_clip, {zero_lit};\n    \
935                         mul{ty} %f_div, %f_clip, 0f3E2AAAAB;\n    \
936                         mul{ty} %f_y, %f_x, %f_div;\n    \
937                         st.global{ty} [%rd_b], %f_y;"
938                    ));
939                });
940                b.ret();
941            })
942            .build()
943    }
944
945    /// Generates a softplus kernel: `ln(1 + exp(x))`.
946    ///
947    /// Uses exp(x) = 2^(x * log2(e)) and ln(z) = lg2(z) * ln(2).
948    fn generate_softplus(&self) -> Result<String, PtxGenError> {
949        let kernel_name = self.kernel_name();
950        let ty = self.ty_str();
951        let byte_size = self.precision.size_bytes();
952
953        KernelBuilder::new(&kernel_name)
954            .target(self.target)
955            .param("a_ptr", PtxType::U64)
956            .param("b_ptr", PtxType::U64)
957            .param("n", PtxType::U32)
958            .max_threads_per_block(256)
959            .body(move |b| {
960                let tid = b.global_thread_id_x();
961                let tid_name = tid.to_string();
962                let n_reg = b.load_param_u32("n");
963                b.if_lt_u32(tid, n_reg, move |b| {
964                    let a_ptr = b.load_param_u64("a_ptr");
965                    let b_ptr = b.load_param_u64("b_ptr");
966
967                    b.raw_ptx(&format!(
968                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
969                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
970                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
971                         add.u64 %rd_b, {b_ptr}, %rd_off;"
972                    ));
973
974                    // softplus(x) = ln(1 + exp(x))
975                    // exp(x) = 2^(x * log2(e)),  log2(e) = 0f3FB8AA3B
976                    // ln(z) = lg2(z) * ln(2),    ln(2) = 0f3F317218
977                    // 1.0 = 0f3F800000
978                    b.raw_ptx(&format!(
979                        "ld.global{ty} %f_x, [%rd_a];\n    \
980                         mul{ty} %f_xe, %f_x, 0f3FB8AA3B;\n    \
981                         ex2.approx{ty} %f_exp, %f_xe;\n    \
982                         add{ty} %f_sum, %f_exp, 0f3F800000;\n    \
983                         lg2.approx{ty} %f_lg, %f_sum;\n    \
984                         mul{ty} %f_y, %f_lg, 0f3F317218;\n    \
985                         st.global{ty} [%rd_b], %f_y;"
986                    ));
987                });
988                b.ret();
989            })
990            .build()
991    }
992
993    /// Generates a leaky-relu kernel: `x >= 0 ? x : alpha*x` (alpha=0.01).
994    ///
995    /// IEEE 754 hex: 0.01 = 0f3C23D70A.
996    fn generate_leaky_relu(&self) -> Result<String, PtxGenError> {
997        let kernel_name = self.kernel_name();
998        let ty = self.ty_str();
999        let byte_size = self.precision.size_bytes();
1000        let zero_lit = float_zero_literal(self.precision);
1001
1002        KernelBuilder::new(&kernel_name)
1003            .target(self.target)
1004            .param("a_ptr", PtxType::U64)
1005            .param("b_ptr", PtxType::U64)
1006            .param("n", PtxType::U32)
1007            .max_threads_per_block(256)
1008            .body(move |b| {
1009                let tid = b.global_thread_id_x();
1010                let tid_name = tid.to_string();
1011                let n_reg = b.load_param_u32("n");
1012                b.if_lt_u32(tid, n_reg, move |b| {
1013                    let a_ptr = b.load_param_u64("a_ptr");
1014                    let b_ptr = b.load_param_u64("b_ptr");
1015
1016                    b.raw_ptx(&format!(
1017                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
1018                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
1019                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
1020                         add.u64 %rd_b, {b_ptr}, %rd_off;"
1021                    ));
1022
1023                    // LeakyRelu: x >= 0 ? x : 0.01*x
1024                    // Compute both paths then select via setp + selp
1025                    // 0.01 = 0f3C23D70A
1026                    b.raw_ptx(&format!(
1027                        "ld.global{ty} %f_x, [%rd_a];\n    \
1028                         mul{ty} %f_leak, %f_x, 0f3C23D70A;\n    \
1029                         setp.ge{ty} %p_ge, %f_x, {zero_lit};\n    \
1030                         selp{ty} %f_y, %f_x, %f_leak, %p_ge;\n    \
1031                         st.global{ty} [%rd_b], %f_y;"
1032                    ));
1033                });
1034                b.ret();
1035            })
1036            .build()
1037    }
1038
1039    /// Generates a scale kernel: `b[i] = alpha * a[i]`.
1040    fn generate_scale(&self) -> Result<String, PtxGenError> {
1041        let kernel_name = self.kernel_name();
1042        let ty = self.ty_str();
1043        let byte_size = self.precision.size_bytes();
1044        let scalar_ty = scalar_param_type(self.precision);
1045
1046        KernelBuilder::new(&kernel_name)
1047            .target(self.target)
1048            .param("a_ptr", PtxType::U64)
1049            .param("b_ptr", PtxType::U64)
1050            .param("alpha", scalar_ty)
1051            .param("n", PtxType::U32)
1052            .max_threads_per_block(256)
1053            .body(move |b| {
1054                let tid = b.global_thread_id_x();
1055                let tid_name = tid.to_string();
1056                let n_reg = b.load_param_u32("n");
1057                b.if_lt_u32(tid, n_reg, move |b| {
1058                    let a_ptr = b.load_param_u64("a_ptr");
1059                    let b_ptr = b.load_param_u64("b_ptr");
1060
1061                    b.raw_ptx(&format!(
1062                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
1063                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
1064                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
1065                         add.u64 %rd_b, {b_ptr}, %rd_off;"
1066                    ));
1067
1068                    b.raw_ptx(&format!(
1069                        "ld.param{ty} %f_alpha, [%param_alpha];\n    \
1070                         ld.global{ty} %f_x, [%rd_a];\n    \
1071                         mul{ty} %f_y, %f_alpha, %f_x;\n    \
1072                         st.global{ty} [%rd_b], %f_y;"
1073                    ));
1074                });
1075                b.ret();
1076            })
1077            .build()
1078    }
1079
1080    /// Generates an add-scalar kernel: `b[i] = a[i] + scalar`.
1081    fn generate_add_scalar(&self) -> Result<String, PtxGenError> {
1082        let kernel_name = self.kernel_name();
1083        let ty = self.ty_str();
1084        let byte_size = self.precision.size_bytes();
1085        let scalar_ty = scalar_param_type(self.precision);
1086
1087        KernelBuilder::new(&kernel_name)
1088            .target(self.target)
1089            .param("a_ptr", PtxType::U64)
1090            .param("b_ptr", PtxType::U64)
1091            .param("scalar", scalar_ty)
1092            .param("n", PtxType::U32)
1093            .max_threads_per_block(256)
1094            .body(move |b| {
1095                let tid = b.global_thread_id_x();
1096                let tid_name = tid.to_string();
1097                let n_reg = b.load_param_u32("n");
1098                b.if_lt_u32(tid, n_reg, move |b| {
1099                    let a_ptr = b.load_param_u64("a_ptr");
1100                    let b_ptr = b.load_param_u64("b_ptr");
1101
1102                    b.raw_ptx(&format!(
1103                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
1104                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
1105                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
1106                         add.u64 %rd_b, {b_ptr}, %rd_off;"
1107                    ));
1108
1109                    b.raw_ptx(&format!(
1110                        "ld.param{ty} %f_s, [%param_scalar];\n    \
1111                         ld.global{ty} %f_x, [%rd_a];\n    \
1112                         add{ty} %f_y, %f_x, %f_s;\n    \
1113                         st.global{ty} [%rd_b], %f_y;"
1114                    ));
1115                });
1116                b.ret();
1117            })
1118            .build()
1119    }
1120
1121    /// Generates a fused add-relu kernel: `c[i] = max(0, a[i] + b[i])`.
1122    fn generate_fused_add_relu(&self) -> Result<String, PtxGenError> {
1123        let kernel_name = self.kernel_name();
1124        let ty = self.ty_str();
1125        let byte_size = self.precision.size_bytes();
1126        let zero_lit = float_zero_literal(self.precision);
1127
1128        KernelBuilder::new(&kernel_name)
1129            .target(self.target)
1130            .param("a_ptr", PtxType::U64)
1131            .param("b_ptr", PtxType::U64)
1132            .param("c_ptr", PtxType::U64)
1133            .param("n", PtxType::U32)
1134            .max_threads_per_block(256)
1135            .body(move |b| {
1136                let tid = b.global_thread_id_x();
1137                let tid_name = tid.to_string();
1138                let n_reg = b.load_param_u32("n");
1139                b.if_lt_u32(tid, n_reg, move |b| {
1140                    let a_ptr = b.load_param_u64("a_ptr");
1141                    let b_ptr = b.load_param_u64("b_ptr");
1142                    let c_ptr = b.load_param_u64("c_ptr");
1143
1144                    b.raw_ptx(&format!(
1145                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
1146                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
1147                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
1148                         add.u64 %rd_b, {b_ptr}, %rd_off;\n    \
1149                         add.u64 %rd_c, {c_ptr}, %rd_off;"
1150                    ));
1151
1152                    b.raw_ptx(&format!(
1153                        "ld.global{ty} %f_a, [%rd_a];\n    \
1154                         ld.global{ty} %f_b, [%rd_b];\n    \
1155                         add{ty} %f_sum, %f_a, %f_b;\n    \
1156                         max{ty} %f_y, %f_sum, {zero_lit};\n    \
1157                         st.global{ty} [%rd_c], %f_y;"
1158                    ));
1159                });
1160                b.ret();
1161            })
1162            .build()
1163    }
1164
1165    /// Generates a fused scale-add kernel: `c[i] = alpha * a[i] + beta * b[i]`.
1166    fn generate_fused_scale_add(&self) -> Result<String, PtxGenError> {
1167        let kernel_name = self.kernel_name();
1168        let ty = self.ty_str();
1169        let byte_size = self.precision.size_bytes();
1170        let scalar_ty = scalar_param_type(self.precision);
1171
1172        KernelBuilder::new(&kernel_name)
1173            .target(self.target)
1174            .param("a_ptr", PtxType::U64)
1175            .param("b_ptr", PtxType::U64)
1176            .param("c_ptr", PtxType::U64)
1177            .param("alpha", scalar_ty)
1178            .param("beta", scalar_ty)
1179            .param("n", PtxType::U32)
1180            .max_threads_per_block(256)
1181            .body(move |b| {
1182                let tid = b.global_thread_id_x();
1183                let tid_name = tid.to_string();
1184                let n_reg = b.load_param_u32("n");
1185                b.if_lt_u32(tid, n_reg, move |b| {
1186                    let a_ptr = b.load_param_u64("a_ptr");
1187                    let b_ptr = b.load_param_u64("b_ptr");
1188                    let c_ptr = b.load_param_u64("c_ptr");
1189
1190                    b.raw_ptx(&format!(
1191                        "cvt.u64.u32 %rd_off, {tid_name};\n    \
1192                         mul.lo.u64 %rd_off, %rd_off, {byte_size};\n    \
1193                         add.u64 %rd_a, {a_ptr}, %rd_off;\n    \
1194                         add.u64 %rd_b, {b_ptr}, %rd_off;\n    \
1195                         add.u64 %rd_c, {c_ptr}, %rd_off;"
1196                    ));
1197
1198                    b.raw_ptx(&format!(
1199                        "ld.param{ty} %f_alpha, [%param_alpha];\n    \
1200                         ld.param{ty} %f_beta, [%param_beta];\n    \
1201                         ld.global{ty} %f_a, [%rd_a];\n    \
1202                         ld.global{ty} %f_b, [%rd_b];\n    \
1203                         mul{ty} %f_aa, %f_alpha, %f_a;\n    \
1204                         mul{ty} %f_bb, %f_beta, %f_b;\n    \
1205                         add{ty} %f_y, %f_aa, %f_bb;\n    \
1206                         st.global{ty} [%rd_c], %f_y;"
1207                    ));
1208                });
1209                b.ret();
1210            })
1211            .build()
1212    }
1213}
1214
1215/// Returns the IEEE 754 hex literal for 0.0 in the given precision.
1216const fn float_zero_literal(ty: PtxType) -> &'static str {
1217    match ty {
1218        PtxType::F64 => "0d0000000000000000",
1219        _ => "0f00000000",
1220    }
1221}
1222
1223/// Returns the scalar parameter type matching the given float precision.
1224///
1225/// For F16 and BF16, scalar parameters are passed as F32 (promoted).
1226const fn scalar_param_type(ty: PtxType) -> PtxType {
1227    match ty {
1228        PtxType::F16 | PtxType::BF16 => PtxType::F32,
1229        other => other,
1230    }
1231}
1232
1233#[cfg(test)]
1234mod tests {
1235    use super::*;
1236    use crate::arch::SmVersion;
1237
1238    #[test]
1239    fn elementwise_op_names() {
1240        assert_eq!(ElementwiseOp::Add.as_str(), "add");
1241        assert_eq!(ElementwiseOp::Relu.as_str(), "relu");
1242        assert_eq!(ElementwiseOp::FusedScaleAdd.as_str(), "fused_scale_add");
1243    }
1244
1245    #[test]
1246    fn elementwise_op_classification() {
1247        assert!(ElementwiseOp::Add.is_binary());
1248        assert!(ElementwiseOp::Sub.is_binary());
1249        assert!(!ElementwiseOp::Relu.is_binary());
1250        assert!(!ElementwiseOp::Sigmoid.is_binary());
1251
1252        assert!(ElementwiseOp::Scale.needs_scalar());
1253        assert!(ElementwiseOp::FusedScaleAdd.needs_scalar());
1254        assert!(!ElementwiseOp::Add.needs_scalar());
1255    }
1256
1257    #[test]
1258    fn kernel_name_format() {
1259        let t = ElementwiseTemplate::new(ElementwiseOp::Add, PtxType::F32, SmVersion::Sm80);
1260        assert_eq!(t.kernel_name(), "elementwise_add_f32");
1261
1262        let t2 = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F16, SmVersion::Sm90);
1263        assert_eq!(t2.kernel_name(), "elementwise_relu_f16");
1264    }
1265
1266    #[test]
1267    fn invalid_precision_rejected() {
1268        let t = ElementwiseTemplate::new(ElementwiseOp::Add, PtxType::U32, SmVersion::Sm80);
1269        let result = t.generate();
1270        assert!(result.is_err());
1271    }
1272
1273    #[test]
1274    fn generate_add_f32() {
1275        let t = ElementwiseTemplate::new(ElementwiseOp::Add, PtxType::F32, SmVersion::Sm80);
1276        let ptx = t.generate().expect("should generate add kernel");
1277        assert!(ptx.contains(".entry elementwise_add_f32"));
1278        assert!(ptx.contains(".target sm_80"));
1279        assert!(ptx.contains("add.f32"));
1280    }
1281
1282    #[test]
1283    fn generate_relu_f32() {
1284        let t = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F32, SmVersion::Sm80);
1285        let ptx = t.generate().expect("should generate relu kernel");
1286        assert!(ptx.contains(".entry elementwise_relu_f32"));
1287        assert!(ptx.contains("max.f32"));
1288    }
1289
1290    #[test]
1291    fn generate_sigmoid_f32() {
1292        let t = ElementwiseTemplate::new(ElementwiseOp::Sigmoid, PtxType::F32, SmVersion::Sm80);
1293        let ptx = t.generate().expect("should generate sigmoid kernel");
1294        assert!(ptx.contains("ex2.approx.f32"));
1295        assert!(ptx.contains("rcp.approx.f32"));
1296    }
1297
1298    #[test]
1299    fn generate_gelu_f32() {
1300        let t = ElementwiseTemplate::new(ElementwiseOp::Gelu, PtxType::F32, SmVersion::Sm80);
1301        let ptx = t.generate().expect("should generate gelu kernel");
1302        assert!(ptx.contains("ex2.approx.f32"));
1303        assert!(ptx.contains(".entry elementwise_gelu_f32"));
1304    }
1305
1306    // -------------------------------------------------------------------------
    // P4: Precision tests – check that the generated PTX contains the expected instructions and constants
1308    // -------------------------------------------------------------------------
1309
1310    /// `ReLU` must use `max` (not `setp`/`selp`, not `sin` or other wrong ops).
1311    /// The implementation emits `max.f32 %f_y, %f_x, 0f00000000`.
1312    #[test]
1313    fn test_relu_ptx_correct_arithmetic() {
1314        let t = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F32, SmVersion::Sm80);
1315        let ptx = t.generate().expect("relu PTX generation failed");
1316        // Must contain max (implements max(x, 0))
1317        assert!(ptx.contains("max.f32"), "relu must emit max.f32");
1318        // Must contain the IEEE 754 zero literal
1319        assert!(ptx.contains("0f00000000"), "relu must compare against 0.0");
1320        // Must NOT contain wrong operations
1321        assert!(!ptx.contains("sin.approx"), "relu must not emit sin");
1322        assert!(!ptx.contains("cos.approx"), "relu must not emit cos");
1323        assert!(!ptx.contains("ex2.approx"), "relu must not use exp");
1324        assert!(!ptx.contains("rcp.approx"), "relu must not use rcp");
1325    }
1326
1327    /// Sigmoid must contain: neg (for -x), ex2.approx (exp via base-2),
1328    /// rcp.approx (reciprocal for 1/(1+exp(-x))), and add (for +1.0).
1329    /// Must NOT contain wrong operations.
1330    #[test]
1331    fn test_sigmoid_ptx_contains_exp_and_rcp() {
1332        let t = ElementwiseTemplate::new(ElementwiseOp::Sigmoid, PtxType::F32, SmVersion::Sm80);
1333        let ptx = t.generate().expect("sigmoid PTX generation failed");
1334        // neg for computing -x
1335        assert!(ptx.contains("neg.f32"), "sigmoid must negate input");
1336        // exp approximation via base-2 exponentiation
1337        assert!(
1338            ptx.contains("ex2.approx.f32"),
1339            "sigmoid must use ex2.approx for exp"
1340        );
1341        // log2(e) scaling constant
1342        assert!(ptx.contains("0f3FB8AA3B"), "sigmoid must scale by log2(e)");
1343        // reciprocal for the division step
1344        assert!(
1345            ptx.contains("rcp.approx.f32"),
1346            "sigmoid must use rcp.approx for 1/denom"
1347        );
1348        // +1.0 in denominator
1349        assert!(
1350            ptx.contains("0f3F800000"),
1351            "sigmoid must add 1.0 to denominator"
1352        );
1353        // Must NOT contain wrong operations
1354        assert!(!ptx.contains("sin.approx"), "sigmoid must not emit sin");
1355        assert!(
1356            !ptx.contains("max.f32"),
1357            "sigmoid must not use max (relu op)"
1358        );
1359    }
1360
1361    /// GELU uses tanh approximation: check for the three key constants
1362    /// (0.044715, sqrt(2/pi), 2.0) and the ex2+rcp pattern for tanh-via-sigmoid.
1363    #[test]
1364    fn test_gelu_ptx_contains_tanh_approximation() {
1365        let t = ElementwiseTemplate::new(ElementwiseOp::Gelu, PtxType::F32, SmVersion::Sm80);
1366        let ptx = t.generate().expect("gelu PTX generation failed");
1367        // 0.044715 constant (IEEE 754: 0f3D372713)
1368        assert!(
1369            ptx.contains("0f3D372713"),
1370            "gelu must use 0.044715 constant"
1371        );
1372        // sqrt(2/pi) constant (IEEE 754: 0f3F4C422A)
1373        assert!(
1374            ptx.contains("0f3F4C422A"),
1375            "gelu must use sqrt(2/pi) constant"
1376        );
1377        // tanh implemented via 2*sigmoid(2a)-1 using ex2
1378        assert!(
1379            ptx.contains("ex2.approx.f32"),
1380            "gelu must use ex2.approx for tanh approximation"
1381        );
1382        assert!(
1383            ptx.contains("rcp.approx.f32"),
1384            "gelu must use rcp.approx inside tanh"
1385        );
1386        // Must NOT emit a raw sine (wrong operation for gelu)
1387        assert!(!ptx.contains("sin.approx"), "gelu must not emit sin");
1388    }
1389
1390    /// Tanh (implemented as `2*sigmoid(2x)-1`) must contain ex2.approx,
1391    /// rcp.approx, the 2.0 constant, and the subtract of 1.0.
1392    #[test]
1393    fn test_tanh_ptx_contains_exp_instructions() {
1394        let t = ElementwiseTemplate::new(ElementwiseOp::Tanh, PtxType::F32, SmVersion::Sm80);
1395        let ptx = t.generate().expect("tanh PTX generation failed");
1396        // exp via base-2
1397        assert!(
1398            ptx.contains("ex2.approx.f32"),
1399            "tanh must use ex2.approx for exp"
1400        );
1401        // rcp for sigmoid step
1402        assert!(
1403            ptx.contains("rcp.approx.f32"),
1404            "tanh must use rcp.approx in sigmoid step"
1405        );
1406        // 2.0 constant (0f40000000)
1407        assert!(ptx.contains("0f40000000"), "tanh must scale by 2.0");
1408        // sub 1.0 to complete tanh = 2*sigmoid(2x) - 1
1409        assert!(
1410            ptx.contains("sub.f32"),
1411            "tanh must subtract 1.0 for tanh formula"
1412        );
1413        // Must NOT emit wrong operation
1414        assert!(!ptx.contains("sin.approx"), "tanh must not emit sin");
1415    }
1416
1417    /// `SiLU` (`x * sigmoid(x)`) must contain both multiplication and the
1418    /// sigmoid sub-pattern (ex2.approx + rcp.approx).
1419    #[test]
1420    fn test_silu_ptx_contains_mul_and_sigmoid() {
1421        let t = ElementwiseTemplate::new(ElementwiseOp::Silu, PtxType::F32, SmVersion::Sm80);
1422        let ptx = t.generate().expect("silu PTX generation failed");
1423        // sigmoid sub-pattern
1424        assert!(
1425            ptx.contains("ex2.approx.f32"),
1426            "silu must use ex2.approx for sigmoid"
1427        );
1428        assert!(
1429            ptx.contains("rcp.approx.f32"),
1430            "silu must use rcp.approx for sigmoid"
1431        );
1432        // outer multiplication x * sigmoid(x)
1433        assert!(
1434            ptx.contains("mul.f32"),
1435            "silu must multiply x by sigmoid(x)"
1436        );
1437        // Must NOT emit wrong operation
1438        assert!(!ptx.contains("sin.approx"), "silu must not emit sin");
1439        assert!(!ptx.contains("max.f32"), "silu must not use relu max");
1440    }
1441
1442    /// Every generated elementwise kernel must have valid PTX structural headers:
1443    /// `.version`, `.target`, and `.entry`.
1444    #[test]
1445    fn test_elementwise_ptx_has_valid_headers() {
1446        let ops_and_types = [
1447            (ElementwiseOp::Add, PtxType::F32),
1448            (ElementwiseOp::Relu, PtxType::F32),
1449            (ElementwiseOp::Sigmoid, PtxType::F32),
1450            (ElementwiseOp::Gelu, PtxType::F32),
1451            (ElementwiseOp::Tanh, PtxType::F32),
1452            (ElementwiseOp::Silu, PtxType::F32),
1453            (ElementwiseOp::Neg, PtxType::F32),
1454            (ElementwiseOp::Exp, PtxType::F32),
1455            (ElementwiseOp::Log, PtxType::F32),
1456        ];
1457
1458        for (op, ty) in ops_and_types {
1459            let t = ElementwiseTemplate::new(op, ty, SmVersion::Sm80);
1460            let ptx = t
1461                .generate()
1462                .unwrap_or_else(|e| panic!("PTX generation failed for {op:?}: {e}"));
1463            assert!(
1464                ptx.contains(".version"),
1465                "PTX for {op:?} must have .version header"
1466            );
1467            assert!(
1468                ptx.contains(".target"),
1469                "PTX for {op:?} must have .target header"
1470            );
1471            assert!(
1472                ptx.contains(".entry"),
1473                "PTX for {op:?} must have .entry directive"
1474            );
1475        }
1476    }
1477
1478    // -----------------------------------------------------------------------
1479    // CPU reference implementations mirroring the PTX kernel arithmetic.
1480    // These validate numerical precision of the elementwise operations,
1481    // verifying that the same arithmetic would produce correct results
1482    // when executed in a PTX kernel on device.
1483    // -----------------------------------------------------------------------
1484
1485    /// CPU reference for `ReLU`: `max(0, x)`.
1486    fn cpu_relu_f32(x: f32) -> f32 {
1487        x.max(0.0)
1488    }
1489
1490    /// CPU reference for sigmoid: 1 / (1 + exp(-x)).
1491    fn cpu_sigmoid_f32(x: f32) -> f32 {
1492        1.0 / (1.0 + (-x).exp())
1493    }
1494
1495    /// CPU reference for GELU (tanh approximation matching PTX):
1496    /// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
1497    fn cpu_gelu_f32(x: f32) -> f32 {
1498        let k0: f32 = 0.797_884_6; // sqrt(2/pi)
1499        let k1: f32 = 0.044_715;
1500        let inner = k0 * k1.mul_add(x * x * x, x);
1501        0.5 * x * (1.0 + inner.tanh())
1502    }
1503
1504    /// CPU reference for tanh: `x.tanh()`.
1505    fn cpu_tanh_f32(x: f32) -> f32 {
1506        x.tanh()
1507    }
1508
1509    /// CPU reference for `SiLU`: `x * sigmoid(x)`.
1510    fn cpu_silu_f32(x: f32) -> f32 {
1511        x * cpu_sigmoid_f32(x)
1512    }
1513
1514    // -- relu precision tests ------------------------------------------------
1515
1516    #[test]
1517    fn relu_precision_known_values() {
1518        assert!((cpu_relu_f32(0.0) - 0.0_f32).abs() < f32::EPSILON);
1519        assert!((cpu_relu_f32(-1.0) - 0.0_f32).abs() < f32::EPSILON);
1520        assert!((cpu_relu_f32(1.0) - 1.0_f32).abs() < f32::EPSILON);
1521        assert!((cpu_relu_f32(-0.001) - 0.0_f32).abs() < f32::EPSILON);
1522        assert!((cpu_relu_f32(100.0) - 100.0_f32).abs() < f32::EPSILON);
1523    }
1524
1525    #[test]
1526    fn relu_precision_negative_zero() {
1527        // -0.0 is <= 0, so relu should return 0.0 (non-negative zero)
1528        assert!(cpu_relu_f32(-0.0) >= 0.0);
1529    }
1530
1531    // -- sigmoid precision tests ---------------------------------------------
1532
1533    #[test]
1534    fn sigmoid_precision_known_values() {
1535        // sigmoid(0) = 0.5 exactly
1536        assert!((cpu_sigmoid_f32(0.0) - 0.5).abs() < 1e-7_f32);
1537        // sigmoid(large) -> 1.0
1538        assert!((cpu_sigmoid_f32(100.0) - 1.0).abs() < 1e-6_f32);
1539        // sigmoid(-large) -> 0.0
1540        assert!(cpu_sigmoid_f32(-100.0).abs() < 1e-6_f32);
1541        // sigmoid(1.0) ~= 0.73105858
1542        let expected_sig1: f32 = 0.731_058_6;
1543        assert!(
1544            (cpu_sigmoid_f32(1.0) - expected_sig1).abs() < 1e-5_f32,
1545            "sigmoid(1.0) expected ~{expected_sig1}, got {}",
1546            cpu_sigmoid_f32(1.0)
1547        );
1548    }
1549
1550    #[test]
1551    fn sigmoid_output_in_unit_interval() {
1552        // For moderate inputs, sigmoid is strictly in (0, 1).
1553        let inputs: &[f32] = &[-10.0, -1.0, 0.0, 1.0, 10.0];
1554        for &x in inputs {
1555            let s = cpu_sigmoid_f32(x);
1556            assert!(s > 0.0 && s < 1.0, "sigmoid({x}) = {s} not in (0,1)");
1557        }
1558        // For extreme inputs, sigmoid saturates to 0.0 or 1.0 in f32 precision.
1559        assert!(cpu_sigmoid_f32(-100.0) >= 0.0);
1560        assert!(cpu_sigmoid_f32(100.0) <= 1.0);
1561    }
1562
1563    // -- gelu precision tests ------------------------------------------------
1564
1565    #[test]
1566    fn gelu_precision_known_values() {
1567        // gelu(0) = 0
1568        assert!(cpu_gelu_f32(0.0).abs() < 1e-7_f32);
1569        // gelu(1) ~= 0.8413 (well-known reference)
1570        assert!(
1571            (cpu_gelu_f32(1.0) - 0.8413_f32).abs() < 0.001_f32,
1572            "gelu(1) should be ~0.8413, got {}",
1573            cpu_gelu_f32(1.0)
1574        );
1575        // gelu(-1) ~= -0.1587
1576        assert!(
1577            (cpu_gelu_f32(-1.0) + 0.1587_f32).abs() < 0.001_f32,
1578            "gelu(-1) should be ~-0.1587, got {}",
1579            cpu_gelu_f32(-1.0)
1580        );
1581        // gelu(large positive) ~= x (saturation)
1582        assert!(
1583            (cpu_gelu_f32(5.0) - 5.0_f32).abs() < 0.001_f32,
1584            "gelu(5) should be ~5.0, got {}",
1585            cpu_gelu_f32(5.0)
1586        );
1587    }
1588
1589    #[test]
1590    fn gelu_sign_preservation() {
1591        // GELU should be positive for positive inputs (at least for x > 0)
1592        assert!(cpu_gelu_f32(0.5) > 0.0);
1593        assert!(cpu_gelu_f32(2.0) > 0.0);
1594        // GELU negative at large negative inputs
1595        assert!(cpu_gelu_f32(-2.0) < 0.0);
1596    }
1597
1598    // -- tanh precision tests ------------------------------------------------
1599
1600    #[test]
1601    fn tanh_precision_known_values() {
1602        assert!(cpu_tanh_f32(0.0).abs() < 1e-7_f32);
1603        let expected_tanh1: f32 = 0.761_594_2;
1604        assert!(
1605            (cpu_tanh_f32(1.0) - expected_tanh1).abs() < 1e-5_f32,
1606            "tanh(1.0) expected ~{expected_tanh1}, got {}",
1607            cpu_tanh_f32(1.0)
1608        );
1609        assert!(
1610            (cpu_tanh_f32(-1.0) + expected_tanh1).abs() < 1e-5_f32,
1611            "tanh(-1.0) expected ~-{expected_tanh1}, got {}",
1612            cpu_tanh_f32(-1.0)
1613        );
1614        // tanh saturates at ±1
1615        assert!(
1616            (cpu_tanh_f32(10.0) - 1.0).abs() < 1e-5_f32,
1617            "tanh(10) should be ~1.0"
1618        );
1619        assert!(
1620            (cpu_tanh_f32(-10.0) + 1.0).abs() < 1e-5_f32,
1621            "tanh(-10) should be ~-1.0"
1622        );
1623    }
1624
1625    #[test]
1626    fn tanh_output_in_bounded_range() {
1627        // For moderate inputs, tanh is strictly in (-1, 1).
1628        let inputs: &[f32] = &[-5.0, -1.0, 0.0, 1.0, 5.0];
1629        for &x in inputs {
1630            let t = cpu_tanh_f32(x);
1631            assert!(t > -1.0 && t < 1.0, "tanh({x}) = {t} not in (-1,1)");
1632        }
1633        // For extreme inputs, tanh saturates to ±1 in f32 precision.
1634        assert!(cpu_tanh_f32(-100.0) >= -1.0);
1635        assert!(cpu_tanh_f32(100.0) <= 1.0);
1636    }
1637
1638    // -- silu precision tests ------------------------------------------------
1639
1640    #[test]
1641    fn silu_precision_known_values() {
1642        // silu(0) = 0
1643        assert!(cpu_silu_f32(0.0).abs() < 1e-7_f32);
1644        // silu(1) = 1 * sigmoid(1) ~= 0.73106
1645        let expected_sig1: f32 = 0.731_058_6;
1646        assert!(
1647            (cpu_silu_f32(1.0) - expected_sig1).abs() < 1e-5_f32,
1648            "silu(1.0) expected ~{expected_sig1}, got {}",
1649            cpu_silu_f32(1.0)
1650        );
1651        // silu(-1) ~= -0.2689
1652        assert!(
1653            (cpu_silu_f32(-1.0) + 0.2689_f32).abs() < 0.001_f32,
1654            "silu(-1) should be ~-0.2689, got {}",
1655            cpu_silu_f32(-1.0)
1656        );
1657    }
1658
1659    #[test]
1660    fn silu_sign_matches_input() {
1661        // silu has same sign as its input for non-zero values
1662        for &x in &[0.1_f32, 0.5, 1.0, 2.0, 5.0] {
1663            assert!(
1664                cpu_silu_f32(x) > 0.0,
1665                "silu({x}) should be positive, got {}",
1666                cpu_silu_f32(x)
1667            );
1668        }
1669        for &x in &[-0.1_f32, -0.5, -2.0] {
1670            assert!(
1671                cpu_silu_f32(x) < 0.0,
1672                "silu({x}) should be negative, got {}",
1673                cpu_silu_f32(x)
1674            );
1675        }
1676    }
1677
1678    // -- PTX generation test for fused add+relu ------------------------------
1679
1680    #[test]
1681    fn elementwise_ptx_generates_fused_add_relu() {
1682        let tmpl =
1683            ElementwiseTemplate::new(ElementwiseOp::FusedAddRelu, PtxType::F32, SmVersion::Sm80);
1684        let ptx = tmpl
1685            .generate()
1686            .expect("FusedAddRelu should generate successfully");
1687        assert!(
1688            ptx.contains("add"),
1689            "fused kernel should contain add instruction"
1690        );
1691        assert!(
1692            ptx.contains("max"),
1693            "fused kernel should contain max for relu"
1694        );
1695    }
1696
1697    // -- grid sweep precision test -------------------------------------------
1698
1699    #[test]
1700    fn elementwise_ops_precision_sweep() {
1701        // Validate mathematical invariants of all reference functions across a
1702        // 10-point grid spanning negative, zero, and positive inputs.
1703        let test_inputs: &[f32] = &[-5.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 5.0, 10.0];
1704
1705        for &x in test_inputs {
1706            // relu: output must be >= 0
1707            assert!(
1708                cpu_relu_f32(x) >= 0.0,
1709                "relu({x}) = {} should be non-negative",
1710                cpu_relu_f32(x)
1711            );
1712
1713            // sigmoid: output must be strictly in (0, 1)
1714            let s = cpu_sigmoid_f32(x);
1715            assert!(s > 0.0 && s < 1.0, "sigmoid({x}) = {s} should be in (0,1)");
1716
1717            // tanh: output must be in [-1, 1] (saturates at ±1 in f32 for large |x|)
1718            let t = cpu_tanh_f32(x);
1719            assert!(
1720                (-1.0_f32..=1.0).contains(&t),
1721                "tanh({x}) = {t} should be in [-1,1]"
1722            );
1723
1724            // silu: should have same sign as input for |x| > small threshold
1725            if x > 0.1 {
1726                assert!(
1727                    cpu_silu_f32(x) > 0.0,
1728                    "silu({x}) should be positive for positive input"
1729                );
1730            }
1731        }
1732    }
1733
1734    // -- PTX generation consistency tests ------------------------------------
1735
1736    #[test]
1737    fn all_activation_ops_generate_ptx_for_f32() {
1738        let activation_ops = [
1739            ElementwiseOp::Relu,
1740            ElementwiseOp::Gelu,
1741            ElementwiseOp::Sigmoid,
1742            ElementwiseOp::Silu,
1743            ElementwiseOp::Tanh,
1744        ];
1745        for op in activation_ops {
1746            let t = ElementwiseTemplate::new(op, PtxType::F32, SmVersion::Sm80);
1747            let result = t.generate();
1748            assert!(
1749                result.is_ok(),
1750                "PTX generation failed for op {:?}: {:?}",
1751                op,
1752                result.err()
1753            );
1754            let ptx = result.expect("already checked is_ok");
1755            let name = op.as_str();
1756            assert!(
1757                ptx.contains(&format!(".entry elementwise_{name}_f32")),
1758                "PTX for {name} missing expected entry point"
1759            );
1760        }
1761    }
1762
1763    #[test]
1764    fn relu_ptx_uses_max_instruction() {
1765        // The PTX relu must use max.f32 to implement max(0, x)
1766        let t = ElementwiseTemplate::new(ElementwiseOp::Relu, PtxType::F32, SmVersion::Sm80);
1767        let ptx = t.generate().expect("relu PTX generation should succeed");
1768        assert!(
1769            ptx.contains("max.f32"),
1770            "relu PTX must use max.f32 instruction"
1771        );
1772    }
1773
1774    #[test]
1775    fn tanh_ptx_uses_tanh_or_approx_sequence() {
1776        // Tanh PTX must use some form of approximation (ex2.approx or tanh.approx)
1777        let t = ElementwiseTemplate::new(ElementwiseOp::Tanh, PtxType::F32, SmVersion::Sm80);
1778        let ptx = t.generate().expect("tanh PTX generation should succeed");
1779        let has_approx = ptx.contains("ex2.approx") || ptx.contains("tanh.approx");
1780        assert!(
1781            has_approx,
1782            "tanh PTX should use ex2.approx or tanh.approx, got:\n{ptx}"
1783        );
1784    }
1785}