oxirs_vec/gpu/
kernels.rs

1//! CUDA kernel implementations for various vector operations
2
3// Note: anyhow imports removed as they were unused
4use std::collections::HashMap;
5
6/// CUDA kernel manager
7#[derive(Debug)]
8pub struct KernelManager {
9    kernels: HashMap<String, String>,
10}
11
12impl KernelManager {
13    pub fn new() -> Self {
14        let mut manager = Self {
15            kernels: HashMap::new(),
16        };
17        manager.initialize_kernels();
18        manager
19    }
20
21    fn initialize_kernels(&mut self) {
22        let kernels = vec![
23            // Similarity metrics
24            (
25                "cosine_similarity".to_string(),
26                self.get_cosine_similarity_kernel(),
27            ),
28            ("dot_product".to_string(), self.get_dot_product_kernel()),
29            (
30                "pearson_correlation".to_string(),
31                self.get_pearson_correlation_kernel(),
32            ),
33            (
34                "jaccard_similarity".to_string(),
35                self.get_jaccard_similarity_kernel(),
36            ),
37            (
38                "dice_coefficient".to_string(),
39                self.get_dice_coefficient_kernel(),
40            ),
41            (
42                "angular_similarity".to_string(),
43                self.get_angular_similarity_kernel(),
44            ),
45            // Distance metrics
46            (
47                "euclidean_distance".to_string(),
48                self.get_euclidean_distance_kernel(),
49            ),
50            (
51                "manhattan_distance".to_string(),
52                self.get_manhattan_distance_kernel(),
53            ),
54            (
55                "minkowski_distance".to_string(),
56                self.get_minkowski_distance_kernel(),
57            ),
58            (
59                "hamming_distance".to_string(),
60                self.get_hamming_distance_kernel(),
61            ),
62            (
63                "canberra_distance".to_string(),
64                self.get_canberra_distance_kernel(),
65            ),
66            (
67                "chebyshev_distance".to_string(),
68                self.get_chebyshev_distance_kernel(),
69            ),
70            // Utility kernels
71            (
72                "vector_addition".to_string(),
73                self.get_vector_addition_kernel(),
74            ),
75            (
76                "vector_normalization".to_string(),
77                self.get_vector_normalization_kernel(),
78            ),
79            ("hnsw_search".to_string(), self.get_hnsw_search_kernel()),
80            (
81                "batch_distance_computation".to_string(),
82                self.get_batch_distance_kernel(),
83            ),
84            // Mixed-precision kernels (FP16/BF16)
85            (
86                "cosine_similarity_fp16".to_string(),
87                self.get_cosine_similarity_fp16_kernel(),
88            ),
89            (
90                "euclidean_distance_fp16".to_string(),
91                self.get_euclidean_distance_fp16_kernel(),
92            ),
93            // Tensor Core kernels
94            (
95                "matmul_tensor_core".to_string(),
96                self.get_matmul_tensor_core_kernel(),
97            ),
98        ];
99
100        for (name, kernel) in kernels {
101            self.kernels.insert(name, kernel);
102        }
103    }
104
105    pub fn get_kernel(&self, name: &str) -> Option<&String> {
106        self.kernels.get(name)
107    }
108
109    fn get_cosine_similarity_kernel(&self) -> String {
110        r#"
111        extern "C" __global__ void cosine_similarity_kernel(
112            const float* __restrict__ queries,
113            const float* __restrict__ database,
114            float* __restrict__ results,
115            const int query_count,
116            const int db_count,
117            const int dim
118        ) {
119            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
120            const int query_idx = tid / db_count;
121            const int db_idx = tid % db_count;
122
123            if (query_idx >= query_count || db_idx >= db_count) return;
124
125            float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;
126
127            const int vec_dim = (dim + 3) / 4;
128            const float4* q_vec = (const float4*)(queries + query_idx * dim);
129            const float4* db_vec = (const float4*)(database + db_idx * dim);
130
131            for (int i = 0; i < vec_dim; i++) {
132                float4 q_vals = q_vec[i];
133                float4 db_vals = db_vec[i];
134
135                dot += q_vals.x * db_vals.x + q_vals.y * db_vals.y +
136                       q_vals.z * db_vals.z + q_vals.w * db_vals.w;
137                norm_q += q_vals.x * q_vals.x + q_vals.y * q_vals.y +
138                          q_vals.z * q_vals.z + q_vals.w * q_vals.w;
139                norm_db += db_vals.x * db_vals.x + db_vals.y * db_vals.y +
140                           db_vals.z * db_vals.z + db_vals.w * db_vals.w;
141            }
142
143            const float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
144            const float similarity = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;
145
146            results[query_idx * db_count + db_idx] = similarity;
147        }
148        "#
149        .to_string()
150    }
151
152    fn get_euclidean_distance_kernel(&self) -> String {
153        r#"
154        extern "C" __global__ void euclidean_distance_kernel(
155            const float* __restrict__ queries,
156            const float* __restrict__ database,
157            float* __restrict__ results,
158            const int query_count,
159            const int db_count,
160            const int dim
161        ) {
162            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
163            const int query_idx = tid / db_count;
164            const int db_idx = tid % db_count;
165
166            if (query_idx >= query_count || db_idx >= db_count) return;
167
168            float sum_sq_diff = 0.0f;
169
170            const int vec_dim = (dim + 3) / 4;
171            const float4* q_vec = (const float4*)(queries + query_idx * dim);
172            const float4* db_vec = (const float4*)(database + db_idx * dim);
173
174            for (int i = 0; i < vec_dim; i++) {
175                float4 q_vals = q_vec[i];
176                float4 db_vals = db_vec[i];
177                float4 diff = make_float4(
178                    q_vals.x - db_vals.x,
179                    q_vals.y - db_vals.y,
180                    q_vals.z - db_vals.z,
181                    q_vals.w - db_vals.w
182                );
183                sum_sq_diff += diff.x * diff.x + diff.y * diff.y +
184                               diff.z * diff.z + diff.w * diff.w;
185            }
186
187            results[query_idx * db_count + db_idx] = sqrtf(sum_sq_diff);
188        }
189        "#
190        .to_string()
191    }
192
193    fn get_dot_product_kernel(&self) -> String {
194        r#"
195        extern "C" __global__ void dot_product_kernel(
196            const float* __restrict__ a,
197            const float* __restrict__ b,
198            float* __restrict__ result,
199            const int n
200        ) {
201            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
202            const int stride = blockDim.x * gridDim.x;
203
204            float sum = 0.0f;
205            for (int i = tid; i < n; i += stride) {
206                sum += a[i] * b[i];
207            }
208
209            __shared__ float shared_sum[256];
210            shared_sum[threadIdx.x] = sum;
211            __syncthreads();
212
213            for (int s = blockDim.x / 2; s > 0; s >>= 1) {
214                if (threadIdx.x < s) {
215                    shared_sum[threadIdx.x] += shared_sum[threadIdx.x + s];
216                }
217                __syncthreads();
218            }
219
220            if (threadIdx.x == 0) {
221                atomicAdd(result, shared_sum[0]);
222            }
223        }
224        "#
225        .to_string()
226    }
227
228    fn get_vector_addition_kernel(&self) -> String {
229        r#"
230        extern "C" __global__ void vector_addition_kernel(
231            const float* __restrict__ a,
232            const float* __restrict__ b,
233            float* __restrict__ result,
234            const int n
235        ) {
236            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
237            if (tid < n) {
238                result[tid] = a[tid] + b[tid];
239            }
240        }
241        "#
242        .to_string()
243    }
244
245    fn get_vector_normalization_kernel(&self) -> String {
246        r#"
247        extern "C" __global__ void vector_normalization_kernel(
248            float* __restrict__ vectors,
249            const int count,
250            const int dim
251        ) {
252            const int vector_idx = blockIdx.x;
253            const int tid = threadIdx.x;
254
255            if (vector_idx >= count) return;
256
257            float* vector = vectors + vector_idx * dim;
258
259            __shared__ float shared_norm;
260            if (tid == 0) shared_norm = 0.0f;
261            __syncthreads();
262
263            float local_sum = 0.0f;
264            for (int i = tid; i < dim; i += blockDim.x) {
265                local_sum += vector[i] * vector[i];
266            }
267
268            atomicAdd(&shared_norm, local_sum);
269            __syncthreads();
270
271            if (tid == 0) {
272                shared_norm = sqrtf(shared_norm);
273                if (shared_norm > 1e-8f) shared_norm = 1.0f / shared_norm;
274            }
275            __syncthreads();
276
277            for (int i = tid; i < dim; i += blockDim.x) {
278                vector[i] *= shared_norm;
279            }
280        }
281        "#
282        .to_string()
283    }
284
285    fn get_hnsw_search_kernel(&self) -> String {
286        r#"
287        extern "C" __global__ void hnsw_search_kernel(
288            const float* __restrict__ query,
289            const float* __restrict__ vectors,
290            const int* __restrict__ adjacency_list,
291            const int* __restrict__ adjacency_offsets,
292            int* __restrict__ candidate_queue,
293            float* __restrict__ candidate_distances,
294            int* __restrict__ queue_size,
295            const int dim,
296            const int entry_point
297        ) {
298            const int tid = threadIdx.x;
299
300            extern __shared__ float shared_data[];
301            float* shared_query = shared_data;
302            int* shared_queue = (int*)(shared_data + dim);
303            float* shared_queue_dist = (float*)(shared_queue + 128);
304
305            if (tid < dim) {
306                shared_query[tid] = query[tid];
307            }
308            __syncthreads();
309
310            int queue_head = 0;
311            int queue_tail = 0;
312
313            if (tid == 0) {
314                shared_queue[0] = entry_point;
315                shared_queue_dist[0] = 0.0f;
316                queue_tail = 1;
317            }
318            __syncthreads();
319
320            while (queue_head < queue_tail && queue_tail < 128) {
321                __syncthreads();
322
323                if (tid == 0 && queue_head < queue_tail) {
324                    int current_node = shared_queue[queue_head];
325                    queue_head++;
326
327                    int neighbor_start = adjacency_offsets[current_node];
328                    int neighbor_end = adjacency_offsets[current_node + 1];
329
330                    for (int i = neighbor_start; i < neighbor_end && queue_tail < 128; i++) {
331                        int neighbor = adjacency_list[i];
332
333                        const float* neighbor_vector = vectors + neighbor * dim;
334                        float neighbor_dist = 0.0f;
335                        for (int d = 0; d < dim; d++) {
336                            float diff = shared_query[d] - neighbor_vector[d];
337                            neighbor_dist += diff * diff;
338                        }
339                        neighbor_dist = sqrtf(neighbor_dist);
340
341                        shared_queue[queue_tail] = neighbor;
342                        shared_queue_dist[queue_tail] = neighbor_dist;
343                        queue_tail++;
344                    }
345                }
346            }
347
348            if (tid < queue_tail) {
349                candidate_queue[tid] = shared_queue[tid];
350                candidate_distances[tid] = shared_queue_dist[tid];
351            }
352
353            if (tid == 0) {
354                *queue_size = queue_tail;
355            }
356        }
357        "#
358        .to_string()
359    }
360
361    fn get_batch_distance_kernel(&self) -> String {
362        r#"
363        extern "C" __global__ void batch_distance_kernel(
364            const float* __restrict__ batch_a,
365            const float* __restrict__ batch_b,
366            float* __restrict__ distances,
367            const int batch_size_a,
368            const int batch_size_b,
369            const int dim,
370            const int metric_type
371        ) {
372            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
373            const int i = tid / batch_size_b;
374            const int j = tid % batch_size_b;
375
376            if (i >= batch_size_a || j >= batch_size_b) return;
377
378            const float* vec_a = batch_a + i * dim;
379            const float* vec_b = batch_b + j * dim;
380
381            float distance = 0.0f;
382
383            if (metric_type == 0) { // Euclidean
384                for (int d = 0; d < dim; d++) {
385                    float diff = vec_a[d] - vec_b[d];
386                    distance += diff * diff;
387                }
388                distance = sqrtf(distance);
389            } else if (metric_type == 1) { // Cosine
390                float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
391                for (int d = 0; d < dim; d++) {
392                    dot += vec_a[d] * vec_b[d];
393                    norm_a += vec_a[d] * vec_a[d];
394                    norm_b += vec_b[d] * vec_b[d];
395                }
396                float norm_product = sqrtf(norm_a) * sqrtf(norm_b);
397                distance = (norm_product > 1e-8f) ? 1.0f - (dot / norm_product) : 1.0f;
398            }
399
400            distances[i * batch_size_b + j] = distance;
401        }
402        "#
403        .to_string()
404    }
405
406    // Additional distance metric kernels
407
408    fn get_manhattan_distance_kernel(&self) -> String {
409        r#"
410        extern "C" __global__ void manhattan_distance_kernel(
411            const float* __restrict__ queries,
412            const float* __restrict__ database,
413            float* __restrict__ results,
414            const int query_count,
415            const int db_count,
416            const int dim
417        ) {
418            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
419            const int query_idx = tid / db_count;
420            const int db_idx = tid % db_count;
421
422            if (query_idx >= query_count || db_idx >= db_count) return;
423
424            float sum_abs_diff = 0.0f;
425            const float* q_vec = queries + query_idx * dim;
426            const float* db_vec = database + db_idx * dim;
427
428            for (int i = 0; i < dim; i++) {
429                sum_abs_diff += fabsf(q_vec[i] - db_vec[i]);
430            }
431
432            results[query_idx * db_count + db_idx] = sum_abs_diff;
433        }
434        "#
435        .to_string()
436    }
437
438    fn get_minkowski_distance_kernel(&self) -> String {
439        r#"
440        extern "C" __global__ void minkowski_distance_kernel(
441            const float* __restrict__ queries,
442            const float* __restrict__ database,
443            float* __restrict__ results,
444            const int query_count,
445            const int db_count,
446            const int dim,
447            const float p
448        ) {
449            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
450            const int query_idx = tid / db_count;
451            const int db_idx = tid % db_count;
452
453            if (query_idx >= query_count || db_idx >= db_count) return;
454
455            float sum_pow_diff = 0.0f;
456            const float* q_vec = queries + query_idx * dim;
457            const float* db_vec = database + db_idx * dim;
458
459            for (int i = 0; i < dim; i++) {
460                float diff = fabsf(q_vec[i] - db_vec[i]);
461                sum_pow_diff += powf(diff, p);
462            }
463
464            results[query_idx * db_count + db_idx] = powf(sum_pow_diff, 1.0f / p);
465        }
466        "#
467        .to_string()
468    }
469
470    fn get_pearson_correlation_kernel(&self) -> String {
471        r#"
472        extern "C" __global__ void pearson_correlation_kernel(
473            const float* __restrict__ queries,
474            const float* __restrict__ database,
475            float* __restrict__ results,
476            const int query_count,
477            const int db_count,
478            const int dim
479        ) {
480            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
481            const int query_idx = tid / db_count;
482            const int db_idx = tid % db_count;
483
484            if (query_idx >= query_count || db_idx >= db_count) return;
485
486            const float* q_vec = queries + query_idx * dim;
487            const float* db_vec = database + db_idx * dim;
488
489            // Calculate means
490            float mean_q = 0.0f, mean_db = 0.0f;
491            for (int i = 0; i < dim; i++) {
492                mean_q += q_vec[i];
493                mean_db += db_vec[i];
494            }
495            mean_q /= dim;
496            mean_db /= dim;
497
498            // Calculate correlation
499            float numerator = 0.0f, var_q = 0.0f, var_db = 0.0f;
500            for (int i = 0; i < dim; i++) {
501                float q_centered = q_vec[i] - mean_q;
502                float db_centered = db_vec[i] - mean_db;
503                numerator += q_centered * db_centered;
504                var_q += q_centered * q_centered;
505                var_db += db_centered * db_centered;
506            }
507
508            float denominator = sqrtf(var_q) * sqrtf(var_db);
509            float correlation = (denominator > 1e-8f) ? numerator / denominator : 0.0f;
510
511            results[query_idx * db_count + db_idx] = (correlation + 1.0f) / 2.0f;  // Normalize to [0, 1]
512        }
513        "#
514        .to_string()
515    }
516
517    fn get_jaccard_similarity_kernel(&self) -> String {
518        r#"
519        extern "C" __global__ void jaccard_similarity_kernel(
520            const float* __restrict__ queries,
521            const float* __restrict__ database,
522            float* __restrict__ results,
523            const int query_count,
524            const int db_count,
525            const int dim
526        ) {
527            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
528            const int query_idx = tid / db_count;
529            const int db_idx = tid % db_count;
530
531            if (query_idx >= query_count || db_idx >= db_count) return;
532
533            const float* q_vec = queries + query_idx * dim;
534            const float* db_vec = database + db_idx * dim;
535
536            float intersection = 0.0f, union_val = 0.0f;
537            for (int i = 0; i < dim; i++) {
538                float min_val = fminf(q_vec[i], db_vec[i]);
539                float max_val = fmaxf(q_vec[i], db_vec[i]);
540                intersection += min_val;
541                union_val += max_val;
542            }
543
544            float similarity = (union_val > 1e-8f) ? intersection / union_val : 0.0f;
545            results[query_idx * db_count + db_idx] = similarity;
546        }
547        "#
548        .to_string()
549    }
550
551    fn get_dice_coefficient_kernel(&self) -> String {
552        r#"
553        extern "C" __global__ void dice_coefficient_kernel(
554            const float* __restrict__ queries,
555            const float* __restrict__ database,
556            float* __restrict__ results,
557            const int query_count,
558            const int db_count,
559            const int dim
560        ) {
561            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
562            const int query_idx = tid / db_count;
563            const int db_idx = tid % db_count;
564
565            if (query_idx >= query_count || db_idx >= db_count) return;
566
567            const float* q_vec = queries + query_idx * dim;
568            const float* db_vec = database + db_idx * dim;
569
570            float intersection = 0.0f, sum_q = 0.0f, sum_db = 0.0f;
571            for (int i = 0; i < dim; i++) {
572                intersection += fminf(q_vec[i], db_vec[i]);
573                sum_q += q_vec[i];
574                sum_db += db_vec[i];
575            }
576
577            float denominator = sum_q + sum_db;
578            float dice = (denominator > 1e-8f) ? (2.0f * intersection) / denominator : 0.0f;
579            results[query_idx * db_count + db_idx] = dice;
580        }
581        "#
582        .to_string()
583    }
584
585    fn get_hamming_distance_kernel(&self) -> String {
586        r#"
587        extern "C" __global__ void hamming_distance_kernel(
588            const float* __restrict__ queries,
589            const float* __restrict__ database,
590            float* __restrict__ results,
591            const int query_count,
592            const int db_count,
593            const int dim
594        ) {
595            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
596            const int query_idx = tid / db_count;
597            const int db_idx = tid % db_count;
598
599            if (query_idx >= query_count || db_idx >= db_count) return;
600
601            const float* q_vec = queries + query_idx * dim;
602            const float* db_vec = database + db_idx * dim;
603
604            int hamming_dist = 0;
605            for (int i = 0; i < dim; i++) {
606                if (fabsf(q_vec[i] - db_vec[i]) > 1e-6f) {
607                    hamming_dist++;
608                }
609            }
610
611            results[query_idx * db_count + db_idx] = (float)hamming_dist / (float)dim;
612        }
613        "#
614        .to_string()
615    }
616
617    fn get_canberra_distance_kernel(&self) -> String {
618        r#"
619        extern "C" __global__ void canberra_distance_kernel(
620            const float* __restrict__ queries,
621            const float* __restrict__ database,
622            float* __restrict__ results,
623            const int query_count,
624            const int db_count,
625            const int dim
626        ) {
627            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
628            const int query_idx = tid / db_count;
629            const int db_idx = tid % db_count;
630
631            if (query_idx >= query_count || db_idx >= db_count) return;
632
633            const float* q_vec = queries + query_idx * dim;
634            const float* db_vec = database + db_idx * dim;
635
636            float distance = 0.0f;
637            for (int i = 0; i < dim; i++) {
638                float numerator = fabsf(q_vec[i] - db_vec[i]);
639                float denominator = fabsf(q_vec[i]) + fabsf(db_vec[i]);
640                if (denominator > 1e-8f) {
641                    distance += numerator / denominator;
642                }
643            }
644
645            results[query_idx * db_count + db_idx] = distance;
646        }
647        "#
648        .to_string()
649    }
650
651    fn get_chebyshev_distance_kernel(&self) -> String {
652        r#"
653        extern "C" __global__ void chebyshev_distance_kernel(
654            const float* __restrict__ queries,
655            const float* __restrict__ database,
656            float* __restrict__ results,
657            const int query_count,
658            const int db_count,
659            const int dim
660        ) {
661            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
662            const int query_idx = tid / db_count;
663            const int db_idx = tid % db_count;
664
665            if (query_idx >= query_count || db_idx >= db_count) return;
666
667            const float* q_vec = queries + query_idx * dim;
668            const float* db_vec = database + db_idx * dim;
669
670            float max_diff = 0.0f;
671            for (int i = 0; i < dim; i++) {
672                float diff = fabsf(q_vec[i] - db_vec[i]);
673                max_diff = fmaxf(max_diff, diff);
674            }
675
676            results[query_idx * db_count + db_idx] = max_diff;
677        }
678        "#
679        .to_string()
680    }
681
682    fn get_angular_similarity_kernel(&self) -> String {
683        r#"
684        extern "C" __global__ void angular_similarity_kernel(
685            const float* __restrict__ queries,
686            const float* __restrict__ database,
687            float* __restrict__ results,
688            const int query_count,
689            const int db_count,
690            const int dim
691        ) {
692            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
693            const int query_idx = tid / db_count;
694            const int db_idx = tid % db_count;
695
696            if (query_idx >= query_count || db_idx >= db_count) return;
697
698            float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;
699            const float* q_vec = queries + query_idx * dim;
700            const float* db_vec = database + db_idx * dim;
701
702            for (int i = 0; i < dim; i++) {
703                dot += q_vec[i] * db_vec[i];
704                norm_q += q_vec[i] * q_vec[i];
705                norm_db += db_vec[i] * db_vec[i];
706            }
707
708            float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
709            float cosine = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;
710            cosine = fminf(1.0f, fmaxf(-1.0f, cosine));  // Clamp to [-1, 1]
711
712            // Angular distance in radians, normalized to [0, 1]
713            float angular_dist = acosf(cosine) / 3.14159265359f;
714            float similarity = 1.0f - angular_dist;
715
716            results[query_idx * db_count + db_idx] = similarity;
717        }
718        "#
719        .to_string()
720    }
721
722    // Mixed-precision kernels (FP16)
723
724    fn get_cosine_similarity_fp16_kernel(&self) -> String {
725        r#"
726        #include <cuda_fp16.h>
727
728        extern "C" __global__ void cosine_similarity_fp16_kernel(
729            const half* __restrict__ queries,
730            const half* __restrict__ database,
731            float* __restrict__ results,
732            const int query_count,
733            const int db_count,
734            const int dim
735        ) {
736            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
737            const int query_idx = tid / db_count;
738            const int db_idx = tid % db_count;
739
740            if (query_idx >= query_count || db_idx >= db_count) return;
741
742            float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;
743            const half* q_vec = queries + query_idx * dim;
744            const half* db_vec = database + db_idx * dim;
745
746            // Process in chunks of 2 for half2 vectorization
747            const int vec_dim = dim / 2;
748            const half2* q_vec2 = (const half2*)q_vec;
749            const half2* db_vec2 = (const half2*)db_vec;
750
751            for (int i = 0; i < vec_dim; i++) {
752                half2 q_vals = q_vec2[i];
753                half2 db_vals = db_vec2[i];
754
755                float2 q_f = __half22float2(q_vals);
756                float2 db_f = __half22float2(db_vals);
757
758                dot += q_f.x * db_f.x + q_f.y * db_f.y;
759                norm_q += q_f.x * q_f.x + q_f.y * q_f.y;
760                norm_db += db_f.x * db_f.x + db_f.y * db_f.y;
761            }
762
763            // Handle odd dimension
764            if (dim % 2 == 1) {
765                float q_last = __half2float(q_vec[dim - 1]);
766                float db_last = __half2float(db_vec[dim - 1]);
767                dot += q_last * db_last;
768                norm_q += q_last * q_last;
769                norm_db += db_last * db_last;
770            }
771
772            const float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
773            const float similarity = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;
774
775            results[query_idx * db_count + db_idx] = similarity;
776        }
777        "#
778        .to_string()
779    }
780
781    fn get_euclidean_distance_fp16_kernel(&self) -> String {
782        r#"
783        #include <cuda_fp16.h>
784
785        extern "C" __global__ void euclidean_distance_fp16_kernel(
786            const half* __restrict__ queries,
787            const half* __restrict__ database,
788            float* __restrict__ results,
789            const int query_count,
790            const int db_count,
791            const int dim
792        ) {
793            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
794            const int query_idx = tid / db_count;
795            const int db_idx = tid % db_count;
796
797            if (query_idx >= query_count || db_idx >= db_count) return;
798
799            float sum_sq_diff = 0.0f;
800            const half* q_vec = queries + query_idx * dim;
801            const half* db_vec = database + db_idx * dim;
802
803            const int vec_dim = dim / 2;
804            const half2* q_vec2 = (const half2*)q_vec;
805            const half2* db_vec2 = (const half2*)db_vec;
806
807            for (int i = 0; i < vec_dim; i++) {
808                float2 q_f = __half22float2(q_vec2[i]);
809                float2 db_f = __half22float2(db_vec2[i]);
810
811                float diff_x = q_f.x - db_f.x;
812                float diff_y = q_f.y - db_f.y;
813                sum_sq_diff += diff_x * diff_x + diff_y * diff_y;
814            }
815
816            if (dim % 2 == 1) {
817                float diff = __half2float(q_vec[dim - 1]) - __half2float(db_vec[dim - 1]);
818                sum_sq_diff += diff * diff;
819            }
820
821            results[query_idx * db_count + db_idx] = sqrtf(sum_sq_diff);
822        }
823        "#
824        .to_string()
825    }
826
827    // Tensor Core kernel for matrix multiplication
828
829    fn get_matmul_tensor_core_kernel(&self) -> String {
830        r#"
831        #include <mma.h>
832        using namespace nvcuda;
833
834        extern "C" __global__ void matmul_tensor_core_kernel(
835            const half* __restrict__ a,
836            const half* __restrict__ b,
837            float* __restrict__ c,
838            const int m,
839            const int n,
840            const int k
841        ) {
842            // Warp and lane identifiers
843            const int warp_id = threadIdx.x / 32;
844            const int lane_id = threadIdx.x % 32;
845
846            // WMMA fragment declarations (16x16x16)
847            wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
848            wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
849            wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
850
851            // Initialize accumulator
852            wmma::fill_fragment(c_frag, 0.0f);
853
854            // Tile indices
855            const int tile_m = blockIdx.y * 16;
856            const int tile_n = blockIdx.x * 16;
857
858            // Compute matrix multiplication using Tensor Cores
859            for (int i = 0; i < k; i += 16) {
860                wmma::load_matrix_sync(a_frag, a + tile_m * k + i, k);
861                wmma::load_matrix_sync(b_frag, b + i * n + tile_n, n);
862                wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
863            }
864
865            // Store result
866            wmma::store_matrix_sync(c + tile_m * n + tile_n, c_frag, n, wmma::mem_row_major);
867        }
868        "#
869        .to_string()
870    }
871}
872
873impl Default for KernelManager {
874    fn default() -> Self {
875        Self::new()
876    }
877}