1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
//! Scalar operation CUDA kernel launchers
//!
//! Provides launchers for element-wise tensor-scalar operations
//! (add_scalar, mul_scalar, etc.).
use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
kernel_names, launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// Macro to generate scalar operation launcher functions.
///
/// Avoids repeating the identical launch sequence once per dtype. Every
/// generated launcher:
/// 1. Loads (or fetches the cached) scalar kernel module
/// 2. Resolves the dtype-specific kernel symbol via `kernel_name`
/// 3. Builds a flat 1-D elementwise launch configuration
/// 4. Dispatches with the argument tuple (input_ptr, scalar, output_ptr, count)
macro_rules! define_scalar_launcher {
    (
        $(#[$meta:meta])*
        $vis:vis fn $name:ident, $scalar_ty:ty, $dtype:expr, $dtype_name:literal
    ) => {
        $(#[$meta])*
        $vis unsafe fn $name(
            context: &Arc<CudaContext>,
            stream: &CudaStream,
            device_index: usize,
            op: &str,
            a_ptr: u64,
            scalar: $scalar_ty,
            out_ptr: u64,
            numel: usize,
        ) -> Result<()> {
            unsafe {
                // Resolve the dtype-specific kernel from the scalar module.
                let scalar_module =
                    get_or_load_module(context, device_index, kernel_names::SCALAR_MODULE)?;
                let symbol = kernel_name(op, $dtype);
                let kernel = get_kernel_function(&scalar_module, &symbol)?;

                // Flat 1-D grid sized by the element count; no shared memory.
                let element_count = numel as u32;
                let config =
                    launch_config(elementwise_launch_config(numel), (BLOCK_SIZE, 1, 1), 0);

                // Argument order must match the kernel signature:
                // (input_ptr, scalar, output_ptr, count).
                let mut launcher = stream.launch_builder(&kernel);
                launcher.arg(&a_ptr);
                launcher.arg(&scalar);
                launcher.arg(&out_ptr);
                launcher.arg(&element_count);
                launcher.launch(config).map_err(|e| {
                    Error::Internal(format!(
                        "CUDA scalar kernel '{}' ({}) launch failed: {:?}",
                        op, $dtype_name, e
                    ))
                })?;
                Ok(())
            }
        }
    };
}
define_scalar_launcher!(
/// Launch a scalar operation kernel for f32.
///
/// Performs the element-wise operation: `output[i] = op(input[i], scalar)`
///
/// # Supported Operations
///
/// - `add_scalar`: Add scalar to each element
/// - `sub_scalar`: Subtract scalar from each element
/// - `mul_scalar`: Multiply each element by scalar
/// - `div_scalar`: Divide each element by scalar
/// - `pow_scalar`: Raise each element to scalar power
///
/// The `op` string is mapped to a kernel symbol via `kernel_name`, so any
/// op compiled into the scalar module can be launched here.
///
/// # Errors
///
/// Returns an error if the scalar kernel module or the kernel function
/// cannot be resolved; launch failures are reported as `Error::Internal`.
///
/// # Safety
///
/// - `a_ptr` and `out_ptr` must be valid device memory addresses
/// - Both buffers must hold at least `numel` f32 elements
pub fn launch_scalar_op_f32, f32, DType::F32, "f32"
);
define_scalar_launcher!(
/// Launch a scalar operation kernel for f64.
///
/// Performs the element-wise operation: `output[i] = op(input[i], scalar)`.
/// See [`launch_scalar_op_f32`] for the list of supported operations.
///
/// # Safety
///
/// - `a_ptr` and `out_ptr` must be valid device memory addresses
/// - Both buffers must hold at least `numel` f64 elements
pub fn launch_scalar_op_f64, f64, DType::F64, "f64"
);
define_scalar_launcher!(
/// Launch a scalar operation kernel for i32.
///
/// Performs the element-wise operation: `output[i] = op(input[i], scalar)`.
/// See [`launch_scalar_op_f32`] for the list of supported operations.
///
/// # Safety
///
/// - `a_ptr` and `out_ptr` must be valid device memory addresses
/// - Both buffers must hold at least `numel` i32 elements
pub fn launch_scalar_op_i32, i32, DType::I32, "i32"
);
define_scalar_launcher!(
/// Launch a scalar operation kernel for i64.
///
/// Performs the element-wise operation: `output[i] = op(input[i], scalar)`.
/// See [`launch_scalar_op_f32`] for the list of supported operations.
///
/// # Safety
///
/// - `a_ptr` and `out_ptr` must be valid device memory addresses
/// - Both buffers must hold at least `numel` i64 elements
pub fn launch_scalar_op_i64, i64, DType::I64, "i64"
);
/// Launch a scalar operation kernel for f16/bf16/fp8 (uses f32 scalar value).
///
/// Unlike the macro-generated launchers, this one takes the dtype at
/// runtime because several half-precision types share an f32 scalar
/// argument:
/// - F16 (IEEE 754 half precision, requires "f16" feature)
/// - BF16 (bfloat16, requires "f16" feature)
/// - FP8E4M3 / FP8E5M2 (8-bit floating point, always available)
///
/// # Safety
///
/// Same requirements as `launch_scalar_op_f32`.
pub unsafe fn launch_scalar_op_half(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    op: &str,
    dtype: DType,
    a_ptr: u64,
    scalar: f32,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    unsafe {
        // Resolve the kernel for the runtime-selected dtype.
        let scalar_module =
            get_or_load_module(context, device_index, kernel_names::SCALAR_MODULE)?;
        let symbol = kernel_name(op, dtype);
        let kernel = get_kernel_function(&scalar_module, &symbol)?;

        // Flat 1-D grid sized by the element count; no shared memory.
        let element_count = numel as u32;
        let config = launch_config(elementwise_launch_config(numel), (BLOCK_SIZE, 1, 1), 0);

        // Argument order must match the kernel signature:
        // (input_ptr, scalar, output_ptr, count).
        let mut launcher = stream.launch_builder(&kernel);
        launcher.arg(&a_ptr);
        launcher.arg(&scalar);
        launcher.arg(&out_ptr);
        launcher.arg(&element_count);
        launcher.launch(config).map_err(|e| {
            Error::Internal(format!(
                "CUDA scalar kernel '{}' ({:?}) launch failed: {:?}",
                op, dtype, e
            ))
        })?;
        Ok(())
    }
}
define_scalar_launcher!(
/// Launch a scalar operation kernel for Complex64 (uses f32 scalar value).
///
/// The scalar is a real f32; the kernel applies it to complex elements:
/// - `add_scalar`: (a+bi) + s = (a+s) + bi
/// - `sub_scalar`: (a+bi) - s = (a-s) + bi
/// - `mul_scalar`: s(a+bi) = sa + sbi
/// - `div_scalar`: (a+bi)/s = (a/s) + (b/s)i
/// - `pow_scalar`: z^s using polar form
///
/// # Safety
///
/// - `a_ptr` and `out_ptr` must be valid device memory addresses
/// - Both buffers must hold at least `numel` Complex64 elements
pub fn launch_scalar_op_c64, f32, DType::Complex64, "c64"
);
define_scalar_launcher!(
/// Launch a scalar operation kernel for Complex128 (uses f64 scalar value).
///
/// The scalar is a real f64; see [`launch_scalar_op_c64`] for how each
/// operation treats complex elements.
///
/// # Safety
///
/// - `a_ptr` and `out_ptr` must be valid device memory addresses
/// - Both buffers must hold at least `numel` Complex128 elements
pub fn launch_scalar_op_c128, f64, DType::Complex128, "c128"
);