cuda_rust_wasm/neural_integration/cuda_kernels.rs

//! Pre-optimized CUDA Kernels for Neural Operations
//!
//! This module contains hand-optimized CUDA kernels for common neural network
//! operations (matrix multiplication, element-wise vector ops, activations,
//! convolution, dense-layer forward/backward passes, and reductions), returned
//! as CUDA C source strings. The kernels use shared-memory tiling, grid-stride
//! loops, and warp-level reductions.

use super::{ActivationFunction, NeuralResult};

/// Collection of optimized CUDA kernels for neural operations
pub struct OptimizedKernels;

impl OptimizedKernels {
    /// Get optimized matrix multiplication kernel.
    ///
    /// The dimension arguments are currently unused; the kernel reads M, N, and K
    /// at launch time.
    pub fn matrix_multiply_kernel(_rows_a: usize, _cols_a: usize, _cols_b: usize) -> &'static str {
        r#"
extern "C" __global__ void optimized_matrix_multiply(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    // Shared memory for tiling
    __shared__ float As[32][32];
    __shared__ float Bs[32][32];

    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;

    // Calculate global position
    int row = by * 32 + ty;
    int col = bx * 32 + tx;

    float sum = 0.0f;

    // Loop over tiles
    for (int t = 0; t < (K + 31) / 32; ++t) {
        // Load tiles into shared memory
        int a_col = t * 32 + tx;
        int b_row = t * 32 + ty;

        if (row < M && a_col < K) {
            As[ty][tx] = A[row * K + a_col];
        } else {
            As[ty][tx] = 0.0f;
        }

        if (b_row < K && col < N) {
            Bs[ty][tx] = B[b_row * N + col];
        } else {
            Bs[ty][tx] = 0.0f;
        }

        __syncthreads();

        // Compute partial sum using shared memory
        #pragma unroll 32
        for (int k = 0; k < 32; ++k) {
            sum += As[ty][k] * Bs[k][tx];
        }

        __syncthreads();
    }

    // Write result
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}
"#
    }

    /// Get optimized vector operations kernel
    pub fn vector_operations_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_vector_add(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    // Grid-stride loop for better memory coalescing
    for (int i = idx; i < size; i += stride) {
        result[i] = a[i] + b[i];
    }
}

extern "C" __global__ void optimized_vector_multiply(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        result[i] = a[i] * b[i];
    }
}

extern "C" __global__ void optimized_vector_scale(
    const float* __restrict__ input,
    float scale,
    float* __restrict__ result,
    int size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        result[i] = input[i] * scale;
    }
}
"#
    }

    /// Get optimized activation functions kernel
    pub fn activation_functions_kernel() -> &'static str {
        r#"
// Fast approximation functions
__device__ __forceinline__ float fast_sigmoid(float x) {
    return 1.0f / (1.0f + expf(-x));
}

__device__ __forceinline__ float fast_tanh(float x) {
    // Fast tanh approximation using a rational (Padé-type) function;
    // accurate for moderate |x|
    float x2 = x * x;
    float a = x * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));
    float b = 135135.0f + x2 * (62370.0f + x2 * (3150.0f + x2 * 28.0f));
    return a / b;
}

__device__ __forceinline__ float fast_gelu(float x) {
    // Fast GELU approximation (tanh form, sqrt(2/pi) = 0.7978845608)
    return 0.5f * x * (1.0f + fast_tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

extern "C" __global__ void optimized_activation_functions(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size,
    int activation_type
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        float x = input[i];
        float result;

        switch (activation_type) {
            case 0: // Sigmoid
                result = fast_sigmoid(x);
                break;
            case 1: // ReLU
                result = fmaxf(0.0f, x);
                break;
            case 2: // Tanh
                result = fast_tanh(x);
                break;
            case 3: // Leaky ReLU
                result = x > 0.0f ? x : 0.01f * x;
                break;
            case 4: // Swish
                result = x * fast_sigmoid(x);
                break;
            case 5: // GELU
                result = fast_gelu(x);
                break;
            default:
                result = x; // Linear
        }

        output[i] = result;
    }
}

extern "C" __global__ void optimized_activation_derivatives(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size,
    int activation_type
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {
        float x = input[i];
        float result;

        switch (activation_type) {
            case 0: { // Sigmoid derivative
                float s = fast_sigmoid(x);
                result = s * (1.0f - s);
                break;
            }
            case 1: // ReLU derivative
                result = x > 0.0f ? 1.0f : 0.0f;
                break;
            case 2: { // Tanh derivative
                float t = fast_tanh(x);
                result = 1.0f - t * t;
                break;
            }
            case 3: // Leaky ReLU derivative
                result = x > 0.0f ? 1.0f : 0.01f;
                break;
            case 4: { // Swish derivative
                float s = fast_sigmoid(x);
                result = s + x * s * (1.0f - s);
                break;
            }
            case 5: { // GELU derivative (approximation)
                float tanh_arg = 0.7978845608f * (x + 0.044715f * x * x * x);
                float tanh_val = fast_tanh(tanh_arg);
                result = 0.5f * (1.0f + tanh_val) + x * 0.5f * (1.0f - tanh_val * tanh_val) *
                        0.7978845608f * (1.0f + 3.0f * 0.044715f * x * x);
                break;
            }
            default:
                result = 1.0f; // Linear derivative
        }

        output[i] = result;
    }
}
"#
    }
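
    // The kernels above dispatch on an integer `activation_type` code
    // (0 = Sigmoid, 1 = ReLU, 2 = Tanh, 3 = Leaky ReLU, 4 = Swish, 5 = GELU,
    // anything else = Linear). A minimal sketch of how a caller might derive that
    // code from `ActivationFunction`; the helper name is illustrative only and it
    // assumes the enum has exactly the variants matched in `get_activation_kernel`.
    #[allow(dead_code)]
    fn activation_type_code(function: ActivationFunction) -> i32 {
        match function {
            ActivationFunction::Sigmoid => 0,
            ActivationFunction::ReLU => 1,
            ActivationFunction::Tanh => 2,
            ActivationFunction::LeakyReLU => 3,
            ActivationFunction::Swish => 4,
            ActivationFunction::GELU => 5,
        }
    }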

    /// Get optimized convolution kernel
    pub fn convolution_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_conv2d(
    const float* __restrict__ input,
    const float* __restrict__ kernel,
    float* __restrict__ output,
    int batch_size,
    int in_channels,
    int out_channels,
    int in_height,
    int in_width,
    int kernel_size,
    int stride,
    int padding
) {
    int batch = blockIdx.z;
    int out_channel = blockIdx.y; // expects gridDim.y == out_channels
    int out_x = blockIdx.x * blockDim.x + threadIdx.x;

    int out_height = (in_height + 2 * padding - kernel_size) / stride + 1;
    int out_width = (in_width + 2 * padding - kernel_size) / stride + 1;

    if (out_x >= out_width) return;

    // Threads stride over the output rows so the full output height is covered
    // regardless of blockDim.y.
    for (int out_y = threadIdx.y; out_y < out_height; out_y += blockDim.y) {
        float sum = 0.0f;

        for (int in_channel = 0; in_channel < in_channels; ++in_channel) {
            for (int ky = 0; ky < kernel_size; ++ky) {
                for (int kx = 0; kx < kernel_size; ++kx) {
                    int in_y = out_y * stride - padding + ky;
                    int in_x = out_x * stride - padding + kx;

                    if (in_y >= 0 && in_y < in_height && in_x >= 0 && in_x < in_width) {
                        int input_idx = batch * in_channels * in_height * in_width +
                                        in_channel * in_height * in_width +
                                        in_y * in_width + in_x;

                        int kernel_idx = out_channel * in_channels * kernel_size * kernel_size +
                                         in_channel * kernel_size * kernel_size +
                                         ky * kernel_size + kx;

                        sum += input[input_idx] * kernel[kernel_idx];
                    }
                }
            }
        }

        int output_idx = batch * out_channels * out_height * out_width +
                         out_channel * out_height * out_width +
                         out_y * out_width + out_x;

        output[output_idx] = sum;
    }
}
"#
    }

    /// Get optimized forward propagation kernel
    pub fn forward_propagation_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_forward_propagation(
    const float* __restrict__ input,
    const float* __restrict__ weights,
    const float* __restrict__ biases,
    float* __restrict__ output,
    int batch_size,
    int input_size,
    int output_size,
    int activation_type
) {
    // Shared memory for weight and input tiles
    __shared__ float weight_tile[32][32];
    __shared__ float input_tile[32];

    int batch = blockIdx.z;
    int output_neuron = blockIdx.y * blockDim.y + threadIdx.y;
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    // Out-of-range threads still participate in the shared-memory barriers below,
    // so track validity instead of returning early.
    bool valid = output_neuron < output_size;

    float sum = 0.0f;

    // Process input in tiles
    for (int tile = 0; tile < (input_size + 31) / 32; ++tile) {
        int input_idx = tile * 32 + tx;

        // Load input tile (first row of threads only)
        if (ty == 0) {
            input_tile[tx] = (input_idx < input_size)
                ? input[batch * input_size + input_idx]
                : 0.0f;
        }

        // Load weight tile
        weight_tile[ty][tx] = (valid && input_idx < input_size)
            ? weights[output_neuron * input_size + input_idx]
            : 0.0f;

        __syncthreads();

        // Compute partial sum
        #pragma unroll 32
        for (int k = 0; k < 32; ++k) {
            sum += weight_tile[ty][k] * input_tile[k];
        }

        __syncthreads();
    }

    if (!valid) return;

    // Add bias
    sum += biases[output_neuron];

    // Apply activation function
    float result;
    switch (activation_type) {
        case 0: // Sigmoid
            result = 1.0f / (1.0f + expf(-sum));
            break;
        case 1: // ReLU
            result = fmaxf(0.0f, sum);
            break;
        case 2: // Tanh
            result = tanhf(sum);
            break;
        default:
            result = sum; // Linear
    }

    output[batch * output_size + output_neuron] = result;
}
"#
    }

    /// Get optimized backward propagation kernel
    pub fn backward_propagation_kernel() -> &'static str {
        r#"
extern "C" __global__ void optimized_backward_propagation(
    const float* __restrict__ delta_output,
    const float* __restrict__ weights,
    const float* __restrict__ activations,
    float* __restrict__ delta_input,
    float* __restrict__ weight_gradients,
    float* __restrict__ bias_gradients,
    int batch_size,
    int input_size,
    int output_size,
    int activation_type
) {
    __shared__ float delta_shared[32];

    int batch = blockIdx.z;
    int input_neuron = blockIdx.x * blockDim.x + threadIdx.x;
    int tx = threadIdx.x;

    // Out-of-range threads still take part in the barriers below, so track
    // validity instead of returning early.
    bool valid = input_neuron < input_size;

    float delta_sum = 0.0f;

    // Process output deltas in tiles
    for (int tile = 0; tile < (output_size + 31) / 32; ++tile) {
        int output_idx = tile * 32 + tx;

        // Load delta tile
        delta_shared[tx] = (output_idx < output_size)
            ? delta_output[batch * output_size + output_idx]
            : 0.0f;

        __syncthreads();

        // Compute delta contribution
        if (valid) {
            #pragma unroll 32
            for (int k = 0; k < 32; ++k) {
                int out_neuron = tile * 32 + k;
                if (out_neuron < output_size) {
                    delta_sum += delta_shared[k] * weights[out_neuron * input_size + input_neuron];
                }
            }
        }

        __syncthreads();
    }

    if (!valid) return;

    // Apply activation derivative
    float activation = activations[batch * input_size + input_neuron];
    float derivative;

    switch (activation_type) {
        case 0: // Sigmoid derivative
            derivative = activation * (1.0f - activation);
            break;
        case 1: // ReLU derivative
            derivative = activation > 0.0f ? 1.0f : 0.0f;
            break;
        case 2: // Tanh derivative
            derivative = 1.0f - activation * activation;
            break;
        default:
            derivative = 1.0f; // Linear derivative
    }

    delta_input[batch * input_size + input_neuron] = delta_sum * derivative;
}

extern "C" __global__ void optimized_compute_gradients(
    const float* __restrict__ delta_output,
    const float* __restrict__ input_activations,
    float* __restrict__ weight_gradients,
    float* __restrict__ bias_gradients,
    int batch_size,
    int input_size,
    int output_size,
    float learning_rate
) {
    int input_neuron = blockIdx.x * blockDim.x + threadIdx.x;
    int output_neuron = blockIdx.y * blockDim.y + threadIdx.y;

    if (input_neuron >= input_size || output_neuron >= output_size) return;

    float gradient_sum = 0.0f;

    // Accumulate gradients across batch
    for (int batch = 0; batch < batch_size; ++batch) {
        float delta = delta_output[batch * output_size + output_neuron];
        float activation = input_activations[batch * input_size + input_neuron];
        gradient_sum += delta * activation;
    }

    // Update weight gradient
    int weight_idx = output_neuron * input_size + input_neuron;
    weight_gradients[weight_idx] = gradient_sum / batch_size;

    // Update bias gradient (only for first input neuron to avoid race conditions)
    if (input_neuron == 0) {
        float bias_gradient = 0.0f;
        for (int batch = 0; batch < batch_size; ++batch) {
            bias_gradient += delta_output[batch * output_size + output_neuron];
        }
        bias_gradients[output_neuron] = bias_gradient / batch_size;
    }
}
"#
    }

    /// Get optimized reduction operations kernel
    pub fn reduction_operations_kernel() -> &'static str {
        r#"
// Warp- and block-level reduction helpers (defined first so the kernels below can use them)
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

__device__ float blockReduceSum(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceSum(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0;
    if (wid == 0) val = warpReduceSum(val);
    return val;
}

__device__ float blockReduceMax(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceMax(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : -INFINITY;
    if (wid == 0) val = warpReduceMax(val);
    return val;
}

extern "C" __global__ void optimized_reduce_sum(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size
) {
    __shared__ float sdata[256];

    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load data into shared memory
    sdata[tid] = (i < size) ? input[i] : 0.0f;
    __syncthreads();

    // Reduction in shared memory
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Write result for this block
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

extern "C" __global__ void optimized_reduce_max(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size
) {
    __shared__ float sdata[256];

    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load data into shared memory
    sdata[tid] = (i < size) ? input[i] : -INFINITY;
    __syncthreads();

    // Reduction in shared memory
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    // Write result for this block
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

extern "C" __global__ void optimized_softmax(
    const float* __restrict__ input,
    float* __restrict__ output,
    int batch_size,
    int size
) {
    __shared__ float max_val;
    __shared__ float sum_exp;

    int batch = blockIdx.x;
    int tid = threadIdx.x;

    if (batch >= batch_size) return;

    const float* batch_input = input + batch * size;
    float* batch_output = output + batch * size;

    // Find maximum value
    float local_max = -INFINITY;
    for (int i = tid; i < size; i += blockDim.x) {
        local_max = fmaxf(local_max, batch_input[i]);
    }

    // Reduce maximum across threads
    local_max = blockReduceMax(local_max);
    if (tid == 0) max_val = local_max;
    __syncthreads();

    // Compute sum of exponentials
    float local_sum = 0.0f;
    for (int i = tid; i < size; i += blockDim.x) {
        local_sum += expf(batch_input[i] - max_val);
    }

    // Reduce sum across threads
    local_sum = blockReduceSum(local_sum);
    if (tid == 0) sum_exp = local_sum;
    __syncthreads();

    // Compute softmax
    for (int i = tid; i < size; i += blockDim.x) {
        batch_output[i] = expf(batch_input[i] - max_val) / sum_exp;
    }
}
"#
    }
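
    // Note: optimized_reduce_sum and optimized_reduce_max each write one partial
    // result per block (output[blockIdx.x]), so a full reduction needs a second
    // pass (or a host-side sum) over the per-block partials. A minimal sketch of
    // the partial-count arithmetic; the helper name is illustrative and not part
    // of the public API.
    #[allow(dead_code)]
    fn reduction_partial_count(size: usize, block_size: usize) -> usize {
        // One partial value is produced per launched block.
        size.div_ceil(block_size)
    }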

    /// Get kernel for specific activation function
    pub fn get_activation_kernel(function: ActivationFunction) -> NeuralResult<String> {
        let activation_code = match function {
            ActivationFunction::Sigmoid => "fast_sigmoid(x)",
            ActivationFunction::ReLU => "fmaxf(0.0f, x)",
            ActivationFunction::Tanh => "fast_tanh(x)",
            ActivationFunction::LeakyReLU => "x > 0.0f ? x : 0.01f * x",
            ActivationFunction::Swish => "x * fast_sigmoid(x)",
            ActivationFunction::GELU => "fast_gelu(x)",
        };

        let kernel = format!(r#"
{}

extern "C" __global__ void specialized_activation(
    const float* __restrict__ input,
    float* __restrict__ output,
    int size
) {{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < size; i += stride) {{
        float x = input[i];
        output[i] = {};
    }}
}}
"#, Self::activation_functions_kernel(), activation_code);

        Ok(kernel)
    }

    /// Get all kernels as a single compilation unit
    pub fn get_combined_kernels() -> &'static str {
        r#"
// Combined optimized kernels for neural operations
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <cmath>

// Fast math approximations
__device__ __forceinline__ float fast_sigmoid(float x) {
    return 1.0f / (1.0f + expf(-x));
}

__device__ __forceinline__ float fast_tanh(float x) {
    float x2 = x * x;
    float a = x * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));
    float b = 135135.0f + x2 * (62370.0f + x2 * (3150.0f + x2 * 28.0f));
    return a / b;
}

__device__ __forceinline__ float fast_gelu(float x) {
    return 0.5f * x * (1.0f + fast_tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

// Warp-level reductions
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

// Block-level reductions
__device__ float blockReduceSum(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceSum(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0;
    if (wid == 0) val = warpReduceSum(val);
    return val;
}

__device__ float blockReduceMax(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    val = warpReduceMax(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : -INFINITY;
    if (wid == 0) val = warpReduceMax(val);
    return val;
}

// Forward declarations of the optimized kernels (definitions live in the other kernel strings)
// Matrix operations
extern "C" __global__ void optimized_matrix_multiply(...);

// Vector operations
extern "C" __global__ void optimized_vector_add(...);
extern "C" __global__ void optimized_vector_multiply(...);
extern "C" __global__ void optimized_vector_scale(...);

// Activation functions
extern "C" __global__ void optimized_activation_functions(...);
extern "C" __global__ void optimized_activation_derivatives(...);

// Neural network layers
extern "C" __global__ void optimized_forward_propagation(...);
extern "C" __global__ void optimized_backward_propagation(...);
extern "C" __global__ void optimized_compute_gradients(...);

// Convolution
extern "C" __global__ void optimized_conv2d(...);

// Reductions
extern "C" __global__ void optimized_reduce_sum(...);
extern "C" __global__ void optimized_reduce_max(...);
extern "C" __global__ void optimized_softmax(...);
"#
    }
}

/// Kernel configuration parameters
#[derive(Debug, Clone)]
pub struct KernelConfig {
    pub block_size: (u32, u32, u32),
    pub grid_size: (u32, u32, u32),
    pub shared_memory_size: u32,
}

impl KernelConfig {
    /// Create optimal configuration for matrix multiplication
    pub fn for_matrix_multiply(rows: usize, cols: usize) -> Self {
        let block_x = 32u32;
        let block_y = 32u32;
        let grid_x = (cols as u32).div_ceil(block_x);
        let grid_y = (rows as u32).div_ceil(block_y);

        Self {
            block_size: (block_x, block_y, 1),
            grid_size: (grid_x, grid_y, 1),
            shared_memory_size: block_x * block_y * 4 * 2, // Two tiles of f32
        }
    }

    /// Create optimal configuration for vector operations
    pub fn for_vector_operation(size: usize) -> Self {
        let block_size = 256u32;
        let grid_size = (size as u32).div_ceil(block_size);

        Self {
            block_size: (block_size, 1, 1),
            grid_size: (grid_size, 1, 1),
            shared_memory_size: 0,
        }
    }

    /// Create optimal configuration for convolution
    pub fn for_convolution(batch_size: usize, out_channels: usize, _out_height: usize, out_width: usize) -> Self {
        let block_x = 16u32;
        let block_y = 16u32;
        let grid_x = (out_width as u32).div_ceil(block_x);
        // optimized_conv2d reads the output channel from blockIdx.y, so launch one
        // block row per output channel; threads stride over the output height in y.
        let grid_y = out_channels as u32;
        let grid_z = batch_size as u32;

        Self {
            block_size: (block_x, block_y, 1),
            grid_size: (grid_x, grid_y, grid_z),
            shared_memory_size: 0, // optimized_conv2d uses no dynamic shared memory
        }
    }
}

/// Kernel launch parameters
#[derive(Debug, Clone)]
pub struct LaunchParams {
    pub config: KernelConfig,
    pub stream: Option<u64>, // CUDA stream handle
}

impl LaunchParams {
    pub fn new(config: KernelConfig) -> Self {
        Self {
            config,
            stream: None,
        }
    }

    pub fn with_stream(mut self, stream: u64) -> Self {
        self.stream = Some(stream);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_availability() {
        let matrix_kernel = OptimizedKernels::matrix_multiply_kernel(4, 4, 4);
        assert!(matrix_kernel.contains("optimized_matrix_multiply"));
        assert!(matrix_kernel.contains("__shared__"));
    }
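
    // Additional smoke checks, in the same spirit as test_kernel_availability, that
    // the layer and convolution kernel sources export the entry points named
    // elsewhere in this module.
    #[test]
    fn test_propagation_and_conv_kernels() {
        let forward = OptimizedKernels::forward_propagation_kernel();
        assert!(forward.contains("optimized_forward_propagation"));

        let backward = OptimizedKernels::backward_propagation_kernel();
        assert!(backward.contains("optimized_backward_propagation"));
        assert!(backward.contains("optimized_compute_gradients"));

        let conv = OptimizedKernels::convolution_kernel();
        assert!(conv.contains("optimized_conv2d"));
    }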

    #[test]
    fn test_vector_kernels() {
        let vector_kernel = OptimizedKernels::vector_operations_kernel();
        assert!(vector_kernel.contains("optimized_vector_add"));
        assert!(vector_kernel.contains("Grid-stride loop"));
    }
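
    // Sanity check of the ceil-division launch math in KernelConfig::for_vector_operation:
    // 1000 elements with 256-thread blocks should need 4 blocks and no dynamic shared memory.
    #[test]
    fn test_vector_operation_config() {
        let config = KernelConfig::for_vector_operation(1000);
        assert_eq!(config.block_size, (256, 1, 1));
        assert_eq!(config.grid_size, (4, 1, 1));
        assert_eq!(config.shared_memory_size, 0);
    }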

    #[test]
    fn test_activation_kernels() {
        let activation_kernel = OptimizedKernels::activation_functions_kernel();
        assert!(activation_kernel.contains("fast_sigmoid"));
        assert!(activation_kernel.contains("fast_gelu"));
    }

    #[test]
    fn test_kernel_config() {
        let config = KernelConfig::for_matrix_multiply(128, 128);
        assert_eq!(config.block_size, (32, 32, 1));
        assert_eq!(config.grid_size, (4, 4, 1));
    }
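
    // Exercises the convolution config and the LaunchParams builder. The concrete
    // shape (batch 2, 8 output channels, 32x32 output) is chosen only for illustration.
    #[test]
    fn test_convolution_config_and_launch_params() {
        let config = KernelConfig::for_convolution(2, 8, 32, 32);
        assert_eq!(config.block_size, (16, 16, 1));
        assert_eq!(config.grid_size.0, 2); // 32 output columns / 16-wide blocks
        assert_eq!(config.grid_size.2, 2); // one grid layer per batch element

        let params = LaunchParams::new(config).with_stream(7);
        assert_eq!(params.stream, Some(7));
    }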

    #[test]
    fn test_activation_kernel_generation() {
        let kernel = OptimizedKernels::get_activation_kernel(ActivationFunction::ReLU).unwrap();
        assert!(kernel.contains("fmaxf"));
        assert!(kernel.contains("specialized_activation"));
    }
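
    // The reduction source should define both the kernels and the block/warp helpers
    // they rely on, and the combined listing should expose the same helper names.
    // These checks only inspect the returned source strings.
    #[test]
    fn test_reduction_and_combined_kernels() {
        let reductions = OptimizedKernels::reduction_operations_kernel();
        assert!(reductions.contains("optimized_reduce_sum"));
        assert!(reductions.contains("optimized_softmax"));
        assert!(reductions.contains("blockReduceSum"));

        let combined = OptimizedKernels::get_combined_kernels();
        assert!(combined.contains("warpReduceMax"));
        assert!(combined.contains("fast_gelu"));
    }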
}