// cuda_rust_wasm/neural_integration/cuda_kernels.rs

use super::{ActivationFunction, NeuralResult, NeuralIntegrationError};

/// Optimized CUDA kernel sources for neural-network operations.
pub struct OptimizedKernels;

impl OptimizedKernels {
    /// Tiled matrix-multiplication kernel (32x32 shared-memory tiles).
    /// The dimension arguments are currently unused: the kernel takes M, N and K
    /// as launch-time arguments, so the same source works for any shape.
    pub fn matrix_multiply_kernel(_rows_a: usize, _cols_a: usize, _cols_b: usize) -> &'static str {
        r#"
extern "C" __global__ void optimized_matrix_multiply(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    // Shared memory for tiling
    __shared__ float As[32][32];
    __shared__ float Bs[32][32];

    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;

    // Calculate global position
    int row = by * 32 + ty;
    int col = bx * 32 + tx;

    float sum = 0.0f;

    // Loop over tiles
    for (int t = 0; t < (K + 31) / 32; ++t) {
        // Load tiles into shared memory
        int a_col = t * 32 + tx;
        int b_row = t * 32 + ty;

        if (row < M && a_col < K) {
            As[ty][tx] = A[row * K + a_col];
        } else {
            As[ty][tx] = 0.0f;
        }

        if (b_row < K && col < N) {
            Bs[ty][tx] = B[b_row * N + col];
        } else {
            Bs[ty][tx] = 0.0f;
        }

        __syncthreads();

        // Compute partial sum using shared memory
        #pragma unroll 32
        for (int k = 0; k < 32; ++k) {
            sum += As[ty][k] * Bs[k][tx];
        }

        __syncthreads();
    }

    // Write result
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}
"#
    }

    /// Element-wise vector kernels (add, multiply, scale) using grid-stride loops.
    pub fn vector_operations_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_vector_add(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    // Grid-stride loop for better memory coalescing
    for (int i = idx; i < size; i += stride) {
        result[i] = a[i] + b[i];
    }
}

extern "C" __global__ void optimized_vector_multiply(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        result[i] = a[i] * b[i];
    }
}

extern "C" __global__ void optimized_vector_scale(
    const float* __restrict__ input,
    float scale,
    float* __restrict__ result,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        result[i] = input[i] * scale;
    }
}
"#
    }

    /// Activation functions and their derivatives, selected at run time by an
    /// integer `activation_type` code (see the switch statements below).
    pub fn activation_functions_kernel() -> &'static str {
        r#"
// Fast approximation functions
__device__ __forceinline__ float fast_sigmoid(float x) {
    return 1.0f / (1.0f + expf(-x));
}

__device__ __forceinline__ float fast_tanh(float x) {
    // Rational polynomial approximation of tanh; accurate for moderate |x| but
    // it does not saturate to +/-1 for large inputs.
    float x2 = x * x;
    float a = x * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));
    float b = 135135.0f + x2 * (62370.0f + x2 * (3150.0f + x2 * 28.0f));
    return a / b;
}

__device__ __forceinline__ float fast_gelu(float x) {
    // Fast GELU approximation (tanh form)
    return 0.5f * x * (1.0f + fast_tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

extern "C" __global__ void optimized_activation_functions(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size,
    int activation_type
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        float x = input[i];
        float result;

        switch (activation_type) {
            case 0: // Sigmoid
                result = fast_sigmoid(x);
                break;
            case 1: // ReLU
                result = fmaxf(0.0f, x);
                break;
            case 2: // Tanh
                result = fast_tanh(x);
                break;
            case 3: // Leaky ReLU
                result = x > 0.0f ? x : 0.01f * x;
                break;
            case 4: // Swish
                result = x * fast_sigmoid(x);
                break;
            case 5: // GELU
                result = fast_gelu(x);
                break;
            default:
                result = x; // Linear
        }

        output[i] = result;
    }
}

extern "C" __global__ void optimized_activation_derivatives(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size,
    int activation_type
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        float x = input[i];
        float result;

        switch (activation_type) {
            case 0: { // Sigmoid derivative
                float s = fast_sigmoid(x);
                result = s * (1.0f - s);
                break;
            }
            case 1: // ReLU derivative
                result = x > 0.0f ? 1.0f : 0.0f;
                break;
            case 2: { // Tanh derivative
                float t = fast_tanh(x);
                result = 1.0f - t * t;
                break;
            }
            case 3: // Leaky ReLU derivative
                result = x > 0.0f ? 1.0f : 0.01f;
                break;
            case 4: { // Swish derivative
                float s = fast_sigmoid(x);
                result = s + x * s * (1.0f - s);
                break;
            }
            case 5: { // GELU derivative (approximation)
                float tanh_arg = 0.7978845608f * (x + 0.044715f * x * x * x);
                float tanh_val = fast_tanh(tanh_arg);
                result = 0.5f * (1.0f + tanh_val) + x * 0.5f * (1.0f - tanh_val * tanh_val) *
                         0.7978845608f * (1.0f + 3.0f * 0.044715f * x * x);
                break;
            }
            default:
                result = 1.0f; // Linear derivative
        }

        output[i] = result;
    }
}
"#
    }

    /// Direct 2-D convolution. Each block handles one (batch, output-channel) pair
    /// and a strip of output columns; its y-threads stride over the output rows.
    pub fn convolution_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_conv2d(
    const float* __restrict__ input,
    const float* __restrict__ kernel,
    float* __restrict__ output,
    int batch_size,
    int in_channels,
    int out_channels,
    int in_height,
    int in_width,
    int kernel_size,
    int stride,
    int padding
) {
    // Dynamic shared memory reserved for an input tile (the direct path below
    // does not use it yet).
    extern __shared__ float shared_input[];

    int batch = blockIdx.z;
    int out_channel = blockIdx.y;
    int out_x = blockIdx.x * blockDim.x + threadIdx.x;

    int out_height = (in_height + 2 * padding - kernel_size) / stride + 1;
    int out_width = (in_width + 2 * padding - kernel_size) / stride + 1;

    if (batch >= batch_size || out_channel >= out_channels || out_x >= out_width) return;

    // The block's y-threads stride over the output rows so any out_height is covered.
    for (int out_y = threadIdx.y; out_y < out_height; out_y += blockDim.y) {
        float sum = 0.0f;

        for (int in_channel = 0; in_channel < in_channels; ++in_channel) {
            for (int ky = 0; ky < kernel_size; ++ky) {
                for (int kx = 0; kx < kernel_size; ++kx) {
                    int in_y = out_y * stride - padding + ky;
                    int in_x = out_x * stride - padding + kx;

                    if (in_y >= 0 && in_y < in_height && in_x >= 0 && in_x < in_width) {
                        int input_idx = batch * in_channels * in_height * in_width +
                                       in_channel * in_height * in_width +
                                       in_y * in_width + in_x;

                        int kernel_idx = out_channel * in_channels * kernel_size * kernel_size +
                                        in_channel * kernel_size * kernel_size +
                                        ky * kernel_size + kx;

                        sum += input[input_idx] * kernel[kernel_idx];
                    }
                }
            }
        }

        int output_idx = batch * out_channels * out_height * out_width +
                        out_channel * out_height * out_width +
                        out_y * out_width + out_x;

        output[output_idx] = sum;
    }
}
"#
    }

    /// Dense (fully connected) layer forward pass with fused bias and activation.
    pub fn forward_propagation_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_forward_propagation(
    const float* __restrict__ input,
    const float* __restrict__ weights,
    const float* __restrict__ biases,
    float* __restrict__ output,
    int batch_size,
    int input_size,
    int output_size,
    int activation_type
) {
    // Shared memory for weight and input tiles
    __shared__ float weight_tile[32][32];
    __shared__ float input_tile[32];

    int batch = blockIdx.z;
    int output_neuron = blockIdx.y * blockDim.y + threadIdx.y;
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    // No early return here: every thread in the block must reach the
    // __syncthreads() calls below, so out-of-range threads just load zeros.
    bool valid = (output_neuron < output_size);

    float sum = 0.0f;

    // Process the input in 32-wide tiles
    for (int tile = 0; tile < (input_size + 31) / 32; ++tile) {
        int input_idx = tile * 32 + tx;

        // Load the input tile (one row of threads is enough)
        if (ty == 0) {
            input_tile[tx] = (input_idx < input_size)
                ? input[batch * input_size + input_idx]
                : 0.0f;
        }

        // Load the weight tile
        weight_tile[ty][tx] = (valid && input_idx < input_size)
            ? weights[output_neuron * input_size + input_idx]
            : 0.0f;

        __syncthreads();

        // Accumulate the partial dot product
        #pragma unroll 32
        for (int k = 0; k < 32; ++k) {
            sum += weight_tile[ty][k] * input_tile[k];
        }

        __syncthreads();
    }

    if (!valid) return;

    // Add bias
    sum += biases[output_neuron];

    // Apply activation function
    float result;
    switch (activation_type) {
        case 0: // Sigmoid
            result = 1.0f / (1.0f + expf(-sum));
            break;
        case 1: // ReLU
            result = fmaxf(0.0f, sum);
            break;
        case 2: // Tanh
            result = tanhf(sum);
            break;
        default:
            result = sum; // Linear
    }

    // Every thread in a row holds the same sum; the redundant writes are benign.
    output[batch * output_size + output_neuron] = result;
}
"#
    }

    /// Dense-layer backward pass (input deltas) plus a separate gradient kernel.
    pub fn backward_propagation_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_backward_propagation(
    const float* __restrict__ delta_output,
    const float* __restrict__ weights,
    const float* __restrict__ activations,
    float* __restrict__ delta_input,
    float* __restrict__ weight_gradients,
    float* __restrict__ bias_gradients,
    int batch_size,
    int input_size,
    int output_size,
    int activation_type
) {
    // Weight and bias gradients are computed by optimized_compute_gradients below;
    // this kernel only propagates the deltas to the previous layer.
    __shared__ float delta_shared[32];

    int batch = blockIdx.z;
    int input_neuron = blockIdx.x * blockDim.x + threadIdx.x;
    int tx = threadIdx.x;

    // No early return before the tile loop: every thread must reach __syncthreads().
    bool valid = (input_neuron < input_size);

    float delta_sum = 0.0f;

    // Process output deltas in tiles
    for (int tile = 0; tile < (output_size + 31) / 32; ++tile) {
        int output_idx = tile * 32 + tx;

        // Load delta tile
        delta_shared[tx] = (output_idx < output_size)
            ? delta_output[batch * output_size + output_idx]
            : 0.0f;

        __syncthreads();

        // Compute delta contribution
        if (valid) {
            #pragma unroll 32
            for (int k = 0; k < 32; ++k) {
                int out_neuron = tile * 32 + k;
                if (out_neuron < output_size) {
                    delta_sum += delta_shared[k] * weights[out_neuron * input_size + input_neuron];
                }
            }
        }

        __syncthreads();
    }

    if (!valid) return;

    // Apply activation derivative (expressed in terms of the stored activation)
    float activation = activations[batch * input_size + input_neuron];
    float derivative;

    switch (activation_type) {
        case 0: // Sigmoid derivative
            derivative = activation * (1.0f - activation);
            break;
        case 1: // ReLU derivative
            derivative = activation > 0.0f ? 1.0f : 0.0f;
            break;
        case 2: // Tanh derivative
            derivative = 1.0f - activation * activation;
            break;
        default:
            derivative = 1.0f; // Linear derivative
    }

    delta_input[batch * input_size + input_neuron] = delta_sum * derivative;
}

extern "C" __global__ void optimized_compute_gradients(
    const float* __restrict__ delta_output,
    const float* __restrict__ input_activations,
    float* __restrict__ weight_gradients,
    float* __restrict__ bias_gradients,
    int batch_size,
    int input_size,
    int output_size,
    float learning_rate
) {
    int input_neuron = blockIdx.x * blockDim.x + threadIdx.x;
    int output_neuron = blockIdx.y * blockDim.y + threadIdx.y;

    if (input_neuron >= input_size || output_neuron >= output_size) return;

    float gradient_sum = 0.0f;

    // Accumulate gradients across the batch
    for (int batch = 0; batch < batch_size; ++batch) {
        float delta = delta_output[batch * output_size + output_neuron];
        float activation = input_activations[batch * input_size + input_neuron];
        gradient_sum += delta * activation;
    }

    // Update weight gradient
    int weight_idx = output_neuron * input_size + input_neuron;
    weight_gradients[weight_idx] = gradient_sum / batch_size;

    // Update bias gradient (only the first input neuron writes it, avoiding a race)
    if (input_neuron == 0) {
        float bias_gradient = 0.0f;
        for (int batch = 0; batch < batch_size; ++batch) {
            bias_gradient += delta_output[batch * output_size + output_neuron];
        }
        bias_gradients[output_neuron] = bias_gradient / batch_size;
    }
}
"#
    }

    /// Reduction kernels (sum, max) and a row-wise softmax. The warp/block
    /// reduction helpers are defined first so that optimized_softmax can call them.
    pub fn reduction_operations_kernel() -> &'static str {
        r#"
// Warp-level reductions
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

// Block-level reductions
__device__ float blockReduceSum(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceSum(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
    if (wid == 0) val = warpReduceSum(val);
    return val;
}

__device__ float blockReduceMax(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceMax(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : -INFINITY;
    if (wid == 0) val = warpReduceMax(val);
    return val;
}

extern "C" __global__ void optimized_reduce_sum(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size
) {
    __shared__ float sdata[256];

    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load data into shared memory
    sdata[tid] = (i < size) ? input[i] : 0.0f;
    __syncthreads();

    // Reduction in shared memory
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Write result for this block
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

extern "C" __global__ void optimized_reduce_max(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size
) {
    __shared__ float sdata[256];

    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load data into shared memory
    sdata[tid] = (i < size) ? input[i] : -INFINITY;
    __syncthreads();

    // Reduction in shared memory
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    // Write result for this block
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

extern "C" __global__ void optimized_softmax(
    const float* __restrict__ input,
    float* __restrict__ output,
    int batch_size,
    int size
) {
    __shared__ float max_val;
    __shared__ float sum_exp;

    int batch = blockIdx.x;
    int tid = threadIdx.x;

    if (batch >= batch_size) return;

    const float* batch_input = input + batch * size;
    float* batch_output = output + batch * size;

    // Find the maximum value (for numerical stability)
    float local_max = -INFINITY;
    for (int i = tid; i < size; i += blockDim.x) {
        local_max = fmaxf(local_max, batch_input[i]);
    }

    // Reduce maximum across threads
    local_max = blockReduceMax(local_max);
    if (tid == 0) max_val = local_max;
    __syncthreads();

    // Compute sum of exponentials
    float local_sum = 0.0f;
    for (int i = tid; i < size; i += blockDim.x) {
        local_sum += expf(batch_input[i] - max_val);
    }

    // Reduce sum across threads
    local_sum = blockReduceSum(local_sum);
    if (tid == 0) sum_exp = local_sum;
    __syncthreads();

    // Compute softmax
    for (int i = tid; i < size; i += blockDim.x) {
        batch_output[i] = expf(batch_input[i] - max_val) / sum_exp;
    }
}
"#
    }

    /// Builds a specialized single-activation kernel on top of the shared helpers.
    pub fn get_activation_kernel(function: ActivationFunction) -> NeuralResult<String> {
        let activation_code = match function {
            ActivationFunction::Sigmoid => "fast_sigmoid(x)",
            ActivationFunction::ReLU => "fmaxf(0.0f, x)",
            ActivationFunction::Tanh => "fast_tanh(x)",
            ActivationFunction::LeakyReLU => "x > 0.0f ? x : 0.01f * x",
            ActivationFunction::Swish => "x * fast_sigmoid(x)",
            ActivationFunction::GELU => "fast_gelu(x)",
        };

        let kernel = format!(r#"
{}

extern "C" __global__ void specialized_activation(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size
) {{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {{
        float x = input[i];
        output[i] = {};
    }}
}}
"#, Self::activation_functions_kernel(), activation_code);

        Ok(kernel)
    }

    /// Combined preamble of fast math approximations and warp/block reductions,
    /// plus a listing of the kernel entry points defined in the other sources.
    pub fn get_combined_kernels() -> &'static str {
        r#"
// Combined optimized kernels for neural operations
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <cmath>

// Fast math approximations
__device__ __forceinline__ float fast_sigmoid(float x) {
    return 1.0f / (1.0f + expf(-x));
}

__device__ __forceinline__ float fast_tanh(float x) {
    float x2 = x * x;
    float a = x * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));
    float b = 135135.0f + x2 * (62370.0f + x2 * (3150.0f + x2 * 28.0f));
    return a / b;
}

__device__ __forceinline__ float fast_gelu(float x) {
    return 0.5f * x * (1.0f + fast_tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

// Warp-level reductions
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

// Block-level reductions
__device__ float blockReduceSum(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceSum(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
    if (wid == 0) val = warpReduceSum(val);
    return val;
}

__device__ float blockReduceMax(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceMax(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : -INFINITY;
    if (wid == 0) val = warpReduceMax(val);
    return val;
}

// Kernel entry points defined in the other sources of this module:
//   Matrix operations:     optimized_matrix_multiply
//   Vector operations:     optimized_vector_add, optimized_vector_multiply, optimized_vector_scale
//   Activation functions:  optimized_activation_functions, optimized_activation_derivatives
//   Neural network layers: optimized_forward_propagation, optimized_backward_propagation,
//                          optimized_compute_gradients
//   Convolution:           optimized_conv2d
//   Reductions:            optimized_reduce_sum, optimized_reduce_max, optimized_softmax
"#
    }
}

#[derive(Debug, Clone)]
pub struct KernelConfig {
    pub block_size: (u32, u32, u32),
    pub grid_size: (u32, u32, u32),
    pub shared_memory_size: u32,
}

impl KernelConfig {
    pub fn for_matrix_multiply(rows: usize, cols: usize) -> Self {
        let block_x = 32u32;
        let block_y = 32u32;
        let grid_x = (cols as u32).div_ceil(block_x);
        let grid_y = (rows as u32).div_ceil(block_y);

        Self {
            block_size: (block_x, block_y, 1),
            grid_size: (grid_x, grid_y, 1),
            shared_memory_size: block_x * block_y * 4 * 2, // two 32x32 f32 tiles
        }
    }

    pub fn for_vector_operation(size: usize) -> Self {
        let block_size = 256u32;
        let grid_size = (size as u32).div_ceil(block_size);

        Self {
            block_size: (block_size, 1, 1),
            grid_size: (grid_size, 1, 1),
            shared_memory_size: 0,
        }
    }

    pub fn for_convolution(batch_size: usize, out_channels: usize, _out_height: usize, out_width: usize) -> Self {
        let block_x = 16u32;
        let block_y = 16u32;
        let grid_x = (out_width as u32).div_ceil(block_x);
        // optimized_conv2d reads blockIdx.y as the output-channel index, so launch
        // one block row per channel; its y-threads stride over the output rows.
        let grid_y = out_channels as u32;
        let grid_z = batch_size as u32;

        Self {
            block_size: (block_x, block_y, 1),
            grid_size: (grid_x, grid_y, grid_z),
            shared_memory_size: block_x * block_y * 4, // one 16x16 f32 input tile
        }
    }
}

#[derive(Debug, Clone)]
pub struct LaunchParams {
    pub config: KernelConfig,
    pub stream: Option<u64>,
}

impl LaunchParams {
    pub fn new(config: KernelConfig) -> Self {
        Self {
            config,
            stream: None,
        }
    }

    pub fn with_stream(mut self, stream: u64) -> Self {
        self.stream = Some(stream);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_availability() {
        let matrix_kernel = OptimizedKernels::matrix_multiply_kernel(4, 4, 4);
        assert!(matrix_kernel.contains("optimized_matrix_multiply"));
        assert!(matrix_kernel.contains("__shared__"));
    }
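
    // Sketch of an additional check for the layer and convolution sources; it only
    // asserts that the expected entry-point names appear in the kernel strings.
    #[test]
    fn test_layer_and_conv_kernels() {
        let forward = OptimizedKernels::forward_propagation_kernel();
        assert!(forward.contains("optimized_forward_propagation"));

        let backward = OptimizedKernels::backward_propagation_kernel();
        assert!(backward.contains("optimized_backward_propagation"));
        assert!(backward.contains("optimized_compute_gradients"));

        let conv = OptimizedKernels::convolution_kernel();
        assert!(conv.contains("optimized_conv2d"));
    }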

    #[test]
    fn test_vector_kernels() {
        let vector_kernel = OptimizedKernels::vector_operations_kernel();
        assert!(vector_kernel.contains("optimized_vector_add"));
        // `contains` is case-sensitive; match the comment exactly as written in the kernel.
        assert!(vector_kernel.contains("Grid-stride loop"));
    }
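
    // Sketch of a check for the reduction/softmax source; the names below
    // (blockReduceSum, warpReduceMax, ...) are defined in that kernel string.
    #[test]
    fn test_reduction_kernels() {
        let reduction = OptimizedKernels::reduction_operations_kernel();
        assert!(reduction.contains("optimized_reduce_sum"));
        assert!(reduction.contains("optimized_reduce_max"));
        assert!(reduction.contains("optimized_softmax"));
        assert!(reduction.contains("blockReduceSum"));
        assert!(reduction.contains("warpReduceMax"));
    }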

    #[test]
    fn test_activation_kernels() {
        let activation_kernel = OptimizedKernels::activation_functions_kernel();
        assert!(activation_kernel.contains("fast_sigmoid"));
        assert!(activation_kernel.contains("fast_gelu"));
    }
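
    // Sketch of a check for the combined preamble; it asserts only on helpers the
    // string defines directly (fast math and reduction functions).
    #[test]
    fn test_combined_kernels() {
        let combined = OptimizedKernels::get_combined_kernels();
        assert!(combined.contains("fast_sigmoid"));
        assert!(combined.contains("fast_gelu"));
        assert!(combined.contains("warpReduceSum"));
        assert!(combined.contains("blockReduceMax"));
    }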

    #[test]
    fn test_kernel_config() {
        let config = KernelConfig::for_matrix_multiply(128, 128);
        assert_eq!(config.block_size, (32, 32, 1));
        assert_eq!(config.grid_size, (4, 4, 1));
    }
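
    // Sketch of a check for the 1-D launch helper; the size 1000 is illustrative
    // (1000 elements in 256-thread blocks round up to 4 blocks).
    #[test]
    fn test_vector_operation_config() {
        let config = KernelConfig::for_vector_operation(1000);
        assert_eq!(config.block_size, (256, 1, 1));
        assert_eq!(config.grid_size, (4, 1, 1));
        assert_eq!(config.shared_memory_size, 0);
    }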

    #[test]
    fn test_activation_kernel_generation() {
        let kernel = OptimizedKernels::get_activation_kernel(ActivationFunction::ReLU).unwrap();
        assert!(kernel.contains("fmaxf"));
        assert!(kernel.contains("specialized_activation"));
    }
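
    // Sketch of a check for the LaunchParams builder; the stream handle value is arbitrary.
    #[test]
    fn test_launch_params_builder() {
        let config = KernelConfig::for_vector_operation(1024);
        let params = LaunchParams::new(config.clone());
        assert!(params.stream.is_none());

        let params = LaunchParams::new(config).with_stream(42);
        assert_eq!(params.stream, Some(42));
    }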
}