1use std::collections::HashMap;
5
/// Registry of CUDA C kernel sources, looked up by name.
///
/// This type only stores the source text; nothing in it compiles or launches kernels.
#[derive(Debug)]
pub struct KernelManager {
    /// Lookup name -> CUDA C source text of the kernel registered under that name.
    kernels: HashMap<String, String>,
}
11
12impl KernelManager {
13 pub fn new() -> Self {
14 let mut manager = Self {
15 kernels: HashMap::new(),
16 };
17 manager.initialize_kernels();
18 manager
19 }
20
21 fn initialize_kernels(&mut self) {
22 let kernels = vec![
23 (
25 "cosine_similarity".to_string(),
26 self.get_cosine_similarity_kernel(),
27 ),
28 ("dot_product".to_string(), self.get_dot_product_kernel()),
29 (
30 "pearson_correlation".to_string(),
31 self.get_pearson_correlation_kernel(),
32 ),
33 (
34 "jaccard_similarity".to_string(),
35 self.get_jaccard_similarity_kernel(),
36 ),
37 (
38 "dice_coefficient".to_string(),
39 self.get_dice_coefficient_kernel(),
40 ),
41 (
42 "angular_similarity".to_string(),
43 self.get_angular_similarity_kernel(),
44 ),
45 (
47 "euclidean_distance".to_string(),
48 self.get_euclidean_distance_kernel(),
49 ),
50 (
51 "manhattan_distance".to_string(),
52 self.get_manhattan_distance_kernel(),
53 ),
54 (
55 "minkowski_distance".to_string(),
56 self.get_minkowski_distance_kernel(),
57 ),
58 (
59 "hamming_distance".to_string(),
60 self.get_hamming_distance_kernel(),
61 ),
62 (
63 "canberra_distance".to_string(),
64 self.get_canberra_distance_kernel(),
65 ),
66 (
67 "chebyshev_distance".to_string(),
68 self.get_chebyshev_distance_kernel(),
69 ),
70 (
72 "vector_addition".to_string(),
73 self.get_vector_addition_kernel(),
74 ),
75 (
76 "vector_normalization".to_string(),
77 self.get_vector_normalization_kernel(),
78 ),
79 ("hnsw_search".to_string(), self.get_hnsw_search_kernel()),
80 (
81 "batch_distance_computation".to_string(),
82 self.get_batch_distance_kernel(),
83 ),
84 (
86 "cosine_similarity_fp16".to_string(),
87 self.get_cosine_similarity_fp16_kernel(),
88 ),
89 (
90 "euclidean_distance_fp16".to_string(),
91 self.get_euclidean_distance_fp16_kernel(),
92 ),
93 (
95 "matmul_tensor_core".to_string(),
96 self.get_matmul_tensor_core_kernel(),
97 ),
98 ];
99
100 for (name, kernel) in kernels {
101 self.kernels.insert(name, kernel);
102 }
103 }
104
105 pub fn get_kernel(&self, name: &str) -> Option<&String> {
106 self.kernels.get(name)
107 }
108
109 fn get_cosine_similarity_kernel(&self) -> String {
110 r#"
111 extern "C" __global__ void cosine_similarity_kernel(
112 const float* __restrict__ queries,
113 const float* __restrict__ database,
114 float* __restrict__ results,
115 const int query_count,
116 const int db_count,
117 const int dim
118 ) {
119 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
120 const int query_idx = tid / db_count;
121 const int db_idx = tid % db_count;
122
123 if (query_idx >= query_count || db_idx >= db_count) return;
124
125 float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;
126
127 const int vec_dim = (dim + 3) / 4;
128 const float4* q_vec = (const float4*)(queries + query_idx * dim);
129 const float4* db_vec = (const float4*)(database + db_idx * dim);
130
131 for (int i = 0; i < vec_dim; i++) {
132 float4 q_vals = q_vec[i];
133 float4 db_vals = db_vec[i];
134
135 dot += q_vals.x * db_vals.x + q_vals.y * db_vals.y +
136 q_vals.z * db_vals.z + q_vals.w * db_vals.w;
137 norm_q += q_vals.x * q_vals.x + q_vals.y * q_vals.y +
138 q_vals.z * q_vals.z + q_vals.w * q_vals.w;
139 norm_db += db_vals.x * db_vals.x + db_vals.y * db_vals.y +
140 db_vals.z * db_vals.z + db_vals.w * db_vals.w;
141 }
142
143 const float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
144 const float similarity = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;
145
146 results[query_idx * db_count + db_idx] = similarity;
147 }
148 "#
149 .to_string()
150 }
151
152 fn get_euclidean_distance_kernel(&self) -> String {
153 r#"
154 extern "C" __global__ void euclidean_distance_kernel(
155 const float* __restrict__ queries,
156 const float* __restrict__ database,
157 float* __restrict__ results,
158 const int query_count,
159 const int db_count,
160 const int dim
161 ) {
162 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
163 const int query_idx = tid / db_count;
164 const int db_idx = tid % db_count;
165
166 if (query_idx >= query_count || db_idx >= db_count) return;
167
168 float sum_sq_diff = 0.0f;
169
170 const int vec_dim = (dim + 3) / 4;
171 const float4* q_vec = (const float4*)(queries + query_idx * dim);
172 const float4* db_vec = (const float4*)(database + db_idx * dim);
173
174 for (int i = 0; i < vec_dim; i++) {
175 float4 q_vals = q_vec[i];
176 float4 db_vals = db_vec[i];
177 float4 diff = make_float4(
178 q_vals.x - db_vals.x,
179 q_vals.y - db_vals.y,
180 q_vals.z - db_vals.z,
181 q_vals.w - db_vals.w
182 );
183 sum_sq_diff += diff.x * diff.x + diff.y * diff.y +
184 diff.z * diff.z + diff.w * diff.w;
185 }
186
187 results[query_idx * db_count + db_idx] = sqrtf(sum_sq_diff);
188 }
189 "#
190 .to_string()
191 }
192
193 fn get_dot_product_kernel(&self) -> String {
194 r#"
195 extern "C" __global__ void dot_product_kernel(
196 const float* __restrict__ a,
197 const float* __restrict__ b,
198 float* __restrict__ result,
199 const int n
200 ) {
201 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
202 const int stride = blockDim.x * gridDim.x;
203
204 float sum = 0.0f;
205 for (int i = tid; i < n; i += stride) {
206 sum += a[i] * b[i];
207 }
208
209 __shared__ float shared_sum[256];
210 shared_sum[threadIdx.x] = sum;
211 __syncthreads();
212
213 for (int s = blockDim.x / 2; s > 0; s >>= 1) {
214 if (threadIdx.x < s) {
215 shared_sum[threadIdx.x] += shared_sum[threadIdx.x + s];
216 }
217 __syncthreads();
218 }
219
220 if (threadIdx.x == 0) {
221 atomicAdd(result, shared_sum[0]);
222 }
223 }
224 "#
225 .to_string()
226 }
227
228 fn get_vector_addition_kernel(&self) -> String {
229 r#"
230 extern "C" __global__ void vector_addition_kernel(
231 const float* __restrict__ a,
232 const float* __restrict__ b,
233 float* __restrict__ result,
234 const int n
235 ) {
236 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
237 if (tid < n) {
238 result[tid] = a[tid] + b[tid];
239 }
240 }
241 "#
242 .to_string()
243 }
244
245 fn get_vector_normalization_kernel(&self) -> String {
246 r#"
247 extern "C" __global__ void vector_normalization_kernel(
248 float* __restrict__ vectors,
249 const int count,
250 const int dim
251 ) {
252 const int vector_idx = blockIdx.x;
253 const int tid = threadIdx.x;
254
255 if (vector_idx >= count) return;
256
257 float* vector = vectors + vector_idx * dim;
258
259 __shared__ float shared_norm;
260 if (tid == 0) shared_norm = 0.0f;
261 __syncthreads();
262
263 float local_sum = 0.0f;
264 for (int i = tid; i < dim; i += blockDim.x) {
265 local_sum += vector[i] * vector[i];
266 }
267
268 atomicAdd(&shared_norm, local_sum);
269 __syncthreads();
270
271 if (tid == 0) {
272 shared_norm = sqrtf(shared_norm);
273 if (shared_norm > 1e-8f) shared_norm = 1.0f / shared_norm;
274 }
275 __syncthreads();
276
277 for (int i = tid; i < dim; i += blockDim.x) {
278 vector[i] *= shared_norm;
279 }
280 }
281 "#
282 .to_string()
283 }
284
285 fn get_hnsw_search_kernel(&self) -> String {
286 r#"
287 extern "C" __global__ void hnsw_search_kernel(
288 const float* __restrict__ query,
289 const float* __restrict__ vectors,
290 const int* __restrict__ adjacency_list,
291 const int* __restrict__ adjacency_offsets,
292 int* __restrict__ candidate_queue,
293 float* __restrict__ candidate_distances,
294 int* __restrict__ queue_size,
295 const int dim,
296 const int entry_point
297 ) {
298 const int tid = threadIdx.x;
299
300 extern __shared__ float shared_data[];
301 float* shared_query = shared_data;
302 int* shared_queue = (int*)(shared_data + dim);
303 float* shared_queue_dist = (float*)(shared_queue + 128);
304
305 if (tid < dim) {
306 shared_query[tid] = query[tid];
307 }
308 __syncthreads();
309
310 int queue_head = 0;
311 int queue_tail = 0;
312
313 if (tid == 0) {
314 shared_queue[0] = entry_point;
315 shared_queue_dist[0] = 0.0f;
316 queue_tail = 1;
317 }
318 __syncthreads();
319
320 while (queue_head < queue_tail && queue_tail < 128) {
321 __syncthreads();
322
323 if (tid == 0 && queue_head < queue_tail) {
324 int current_node = shared_queue[queue_head];
325 queue_head++;
326
327 int neighbor_start = adjacency_offsets[current_node];
328 int neighbor_end = adjacency_offsets[current_node + 1];
329
330 for (int i = neighbor_start; i < neighbor_end && queue_tail < 128; i++) {
331 int neighbor = adjacency_list[i];
332
333 const float* neighbor_vector = vectors + neighbor * dim;
334 float neighbor_dist = 0.0f;
335 for (int d = 0; d < dim; d++) {
336 float diff = shared_query[d] - neighbor_vector[d];
337 neighbor_dist += diff * diff;
338 }
339 neighbor_dist = sqrtf(neighbor_dist);
340
341 shared_queue[queue_tail] = neighbor;
342 shared_queue_dist[queue_tail] = neighbor_dist;
343 queue_tail++;
344 }
345 }
346 }
347
348 if (tid < queue_tail) {
349 candidate_queue[tid] = shared_queue[tid];
350 candidate_distances[tid] = shared_queue_dist[tid];
351 }
352
353 if (tid == 0) {
354 *queue_size = queue_tail;
355 }
356 }
357 "#
358 .to_string()
359 }
360
361 fn get_batch_distance_kernel(&self) -> String {
362 r#"
363 extern "C" __global__ void batch_distance_kernel(
364 const float* __restrict__ batch_a,
365 const float* __restrict__ batch_b,
366 float* __restrict__ distances,
367 const int batch_size_a,
368 const int batch_size_b,
369 const int dim,
370 const int metric_type
371 ) {
372 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
373 const int i = tid / batch_size_b;
374 const int j = tid % batch_size_b;
375
376 if (i >= batch_size_a || j >= batch_size_b) return;
377
378 const float* vec_a = batch_a + i * dim;
379 const float* vec_b = batch_b + j * dim;
380
381 float distance = 0.0f;
382
383 if (metric_type == 0) { // Euclidean
384 for (int d = 0; d < dim; d++) {
385 float diff = vec_a[d] - vec_b[d];
386 distance += diff * diff;
387 }
388 distance = sqrtf(distance);
389 } else if (metric_type == 1) { // Cosine
390 float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
391 for (int d = 0; d < dim; d++) {
392 dot += vec_a[d] * vec_b[d];
393 norm_a += vec_a[d] * vec_a[d];
394 norm_b += vec_b[d] * vec_b[d];
395 }
396 float norm_product = sqrtf(norm_a) * sqrtf(norm_b);
397 distance = (norm_product > 1e-8f) ? 1.0f - (dot / norm_product) : 1.0f;
398 }
399
400 distances[i * batch_size_b + j] = distance;
401 }
402 "#
403 .to_string()
404 }
405
406 fn get_manhattan_distance_kernel(&self) -> String {
409 r#"
410 extern "C" __global__ void manhattan_distance_kernel(
411 const float* __restrict__ queries,
412 const float* __restrict__ database,
413 float* __restrict__ results,
414 const int query_count,
415 const int db_count,
416 const int dim
417 ) {
418 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
419 const int query_idx = tid / db_count;
420 const int db_idx = tid % db_count;
421
422 if (query_idx >= query_count || db_idx >= db_count) return;
423
424 float sum_abs_diff = 0.0f;
425 const float* q_vec = queries + query_idx * dim;
426 const float* db_vec = database + db_idx * dim;
427
428 for (int i = 0; i < dim; i++) {
429 sum_abs_diff += fabsf(q_vec[i] - db_vec[i]);
430 }
431
432 results[query_idx * db_count + db_idx] = sum_abs_diff;
433 }
434 "#
435 .to_string()
436 }
437
438 fn get_minkowski_distance_kernel(&self) -> String {
439 r#"
440 extern "C" __global__ void minkowski_distance_kernel(
441 const float* __restrict__ queries,
442 const float* __restrict__ database,
443 float* __restrict__ results,
444 const int query_count,
445 const int db_count,
446 const int dim,
447 const float p
448 ) {
449 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
450 const int query_idx = tid / db_count;
451 const int db_idx = tid % db_count;
452
453 if (query_idx >= query_count || db_idx >= db_count) return;
454
455 float sum_pow_diff = 0.0f;
456 const float* q_vec = queries + query_idx * dim;
457 const float* db_vec = database + db_idx * dim;
458
459 for (int i = 0; i < dim; i++) {
460 float diff = fabsf(q_vec[i] - db_vec[i]);
461 sum_pow_diff += powf(diff, p);
462 }
463
464 results[query_idx * db_count + db_idx] = powf(sum_pow_diff, 1.0f / p);
465 }
466 "#
467 .to_string()
468 }
469
470 fn get_pearson_correlation_kernel(&self) -> String {
471 r#"
472 extern "C" __global__ void pearson_correlation_kernel(
473 const float* __restrict__ queries,
474 const float* __restrict__ database,
475 float* __restrict__ results,
476 const int query_count,
477 const int db_count,
478 const int dim
479 ) {
480 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
481 const int query_idx = tid / db_count;
482 const int db_idx = tid % db_count;
483
484 if (query_idx >= query_count || db_idx >= db_count) return;
485
486 const float* q_vec = queries + query_idx * dim;
487 const float* db_vec = database + db_idx * dim;
488
489 // Calculate means
490 float mean_q = 0.0f, mean_db = 0.0f;
491 for (int i = 0; i < dim; i++) {
492 mean_q += q_vec[i];
493 mean_db += db_vec[i];
494 }
495 mean_q /= dim;
496 mean_db /= dim;
497
498 // Calculate correlation
499 float numerator = 0.0f, var_q = 0.0f, var_db = 0.0f;
500 for (int i = 0; i < dim; i++) {
501 float q_centered = q_vec[i] - mean_q;
502 float db_centered = db_vec[i] - mean_db;
503 numerator += q_centered * db_centered;
504 var_q += q_centered * q_centered;
505 var_db += db_centered * db_centered;
506 }
507
508 float denominator = sqrtf(var_q) * sqrtf(var_db);
509 float correlation = (denominator > 1e-8f) ? numerator / denominator : 0.0f;
510
511 results[query_idx * db_count + db_idx] = (correlation + 1.0f) / 2.0f; // Normalize to [0, 1]
512 }
513 "#
514 .to_string()
515 }
516
517 fn get_jaccard_similarity_kernel(&self) -> String {
518 r#"
519 extern "C" __global__ void jaccard_similarity_kernel(
520 const float* __restrict__ queries,
521 const float* __restrict__ database,
522 float* __restrict__ results,
523 const int query_count,
524 const int db_count,
525 const int dim
526 ) {
527 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
528 const int query_idx = tid / db_count;
529 const int db_idx = tid % db_count;
530
531 if (query_idx >= query_count || db_idx >= db_count) return;
532
533 const float* q_vec = queries + query_idx * dim;
534 const float* db_vec = database + db_idx * dim;
535
536 float intersection = 0.0f, union_val = 0.0f;
537 for (int i = 0; i < dim; i++) {
538 float min_val = fminf(q_vec[i], db_vec[i]);
539 float max_val = fmaxf(q_vec[i], db_vec[i]);
540 intersection += min_val;
541 union_val += max_val;
542 }
543
544 float similarity = (union_val > 1e-8f) ? intersection / union_val : 0.0f;
545 results[query_idx * db_count + db_idx] = similarity;
546 }
547 "#
548 .to_string()
549 }
550
551 fn get_dice_coefficient_kernel(&self) -> String {
552 r#"
553 extern "C" __global__ void dice_coefficient_kernel(
554 const float* __restrict__ queries,
555 const float* __restrict__ database,
556 float* __restrict__ results,
557 const int query_count,
558 const int db_count,
559 const int dim
560 ) {
561 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
562 const int query_idx = tid / db_count;
563 const int db_idx = tid % db_count;
564
565 if (query_idx >= query_count || db_idx >= db_count) return;
566
567 const float* q_vec = queries + query_idx * dim;
568 const float* db_vec = database + db_idx * dim;
569
570 float intersection = 0.0f, sum_q = 0.0f, sum_db = 0.0f;
571 for (int i = 0; i < dim; i++) {
572 intersection += fminf(q_vec[i], db_vec[i]);
573 sum_q += q_vec[i];
574 sum_db += db_vec[i];
575 }
576
577 float denominator = sum_q + sum_db;
578 float dice = (denominator > 1e-8f) ? (2.0f * intersection) / denominator : 0.0f;
579 results[query_idx * db_count + db_idx] = dice;
580 }
581 "#
582 .to_string()
583 }
584
585 fn get_hamming_distance_kernel(&self) -> String {
586 r#"
587 extern "C" __global__ void hamming_distance_kernel(
588 const float* __restrict__ queries,
589 const float* __restrict__ database,
590 float* __restrict__ results,
591 const int query_count,
592 const int db_count,
593 const int dim
594 ) {
595 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
596 const int query_idx = tid / db_count;
597 const int db_idx = tid % db_count;
598
599 if (query_idx >= query_count || db_idx >= db_count) return;
600
601 const float* q_vec = queries + query_idx * dim;
602 const float* db_vec = database + db_idx * dim;
603
604 int hamming_dist = 0;
605 for (int i = 0; i < dim; i++) {
606 if (fabsf(q_vec[i] - db_vec[i]) > 1e-6f) {
607 hamming_dist++;
608 }
609 }
610
611 results[query_idx * db_count + db_idx] = (float)hamming_dist / (float)dim;
612 }
613 "#
614 .to_string()
615 }
616
617 fn get_canberra_distance_kernel(&self) -> String {
618 r#"
619 extern "C" __global__ void canberra_distance_kernel(
620 const float* __restrict__ queries,
621 const float* __restrict__ database,
622 float* __restrict__ results,
623 const int query_count,
624 const int db_count,
625 const int dim
626 ) {
627 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
628 const int query_idx = tid / db_count;
629 const int db_idx = tid % db_count;
630
631 if (query_idx >= query_count || db_idx >= db_count) return;
632
633 const float* q_vec = queries + query_idx * dim;
634 const float* db_vec = database + db_idx * dim;
635
636 float distance = 0.0f;
637 for (int i = 0; i < dim; i++) {
638 float numerator = fabsf(q_vec[i] - db_vec[i]);
639 float denominator = fabsf(q_vec[i]) + fabsf(db_vec[i]);
640 if (denominator > 1e-8f) {
641 distance += numerator / denominator;
642 }
643 }
644
645 results[query_idx * db_count + db_idx] = distance;
646 }
647 "#
648 .to_string()
649 }
650
651 fn get_chebyshev_distance_kernel(&self) -> String {
652 r#"
653 extern "C" __global__ void chebyshev_distance_kernel(
654 const float* __restrict__ queries,
655 const float* __restrict__ database,
656 float* __restrict__ results,
657 const int query_count,
658 const int db_count,
659 const int dim
660 ) {
661 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
662 const int query_idx = tid / db_count;
663 const int db_idx = tid % db_count;
664
665 if (query_idx >= query_count || db_idx >= db_count) return;
666
667 const float* q_vec = queries + query_idx * dim;
668 const float* db_vec = database + db_idx * dim;
669
670 float max_diff = 0.0f;
671 for (int i = 0; i < dim; i++) {
672 float diff = fabsf(q_vec[i] - db_vec[i]);
673 max_diff = fmaxf(max_diff, diff);
674 }
675
676 results[query_idx * db_count + db_idx] = max_diff;
677 }
678 "#
679 .to_string()
680 }
681
682 fn get_angular_similarity_kernel(&self) -> String {
683 r#"
684 extern "C" __global__ void angular_similarity_kernel(
685 const float* __restrict__ queries,
686 const float* __restrict__ database,
687 float* __restrict__ results,
688 const int query_count,
689 const int db_count,
690 const int dim
691 ) {
692 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
693 const int query_idx = tid / db_count;
694 const int db_idx = tid % db_count;
695
696 if (query_idx >= query_count || db_idx >= db_count) return;
697
698 float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;
699 const float* q_vec = queries + query_idx * dim;
700 const float* db_vec = database + db_idx * dim;
701
702 for (int i = 0; i < dim; i++) {
703 dot += q_vec[i] * db_vec[i];
704 norm_q += q_vec[i] * q_vec[i];
705 norm_db += db_vec[i] * db_vec[i];
706 }
707
708 float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
709 float cosine = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;
710 cosine = fminf(1.0f, fmaxf(-1.0f, cosine)); // Clamp to [-1, 1]
711
712 // Angular distance in radians, normalized to [0, 1]
713 float angular_dist = acosf(cosine) / 3.14159265359f;
714 float similarity = 1.0f - angular_dist;
715
716 results[query_idx * db_count + db_idx] = similarity;
717 }
718 "#
719 .to_string()
720 }
721
722 fn get_cosine_similarity_fp16_kernel(&self) -> String {
725 r#"
726 #include <cuda_fp16.h>
727
728 extern "C" __global__ void cosine_similarity_fp16_kernel(
729 const half* __restrict__ queries,
730 const half* __restrict__ database,
731 float* __restrict__ results,
732 const int query_count,
733 const int db_count,
734 const int dim
735 ) {
736 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
737 const int query_idx = tid / db_count;
738 const int db_idx = tid % db_count;
739
740 if (query_idx >= query_count || db_idx >= db_count) return;
741
742 float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;
743 const half* q_vec = queries + query_idx * dim;
744 const half* db_vec = database + db_idx * dim;
745
746 // Process in chunks of 2 for half2 vectorization
747 const int vec_dim = dim / 2;
748 const half2* q_vec2 = (const half2*)q_vec;
749 const half2* db_vec2 = (const half2*)db_vec;
750
751 for (int i = 0; i < vec_dim; i++) {
752 half2 q_vals = q_vec2[i];
753 half2 db_vals = db_vec2[i];
754
755 float2 q_f = __half22float2(q_vals);
756 float2 db_f = __half22float2(db_vals);
757
758 dot += q_f.x * db_f.x + q_f.y * db_f.y;
759 norm_q += q_f.x * q_f.x + q_f.y * q_f.y;
760 norm_db += db_f.x * db_f.x + db_f.y * db_f.y;
761 }
762
763 // Handle odd dimension
764 if (dim % 2 == 1) {
765 float q_last = __half2float(q_vec[dim - 1]);
766 float db_last = __half2float(db_vec[dim - 1]);
767 dot += q_last * db_last;
768 norm_q += q_last * q_last;
769 norm_db += db_last * db_last;
770 }
771
772 const float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
773 const float similarity = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;
774
775 results[query_idx * db_count + db_idx] = similarity;
776 }
777 "#
778 .to_string()
779 }
780
781 fn get_euclidean_distance_fp16_kernel(&self) -> String {
782 r#"
783 #include <cuda_fp16.h>
784
785 extern "C" __global__ void euclidean_distance_fp16_kernel(
786 const half* __restrict__ queries,
787 const half* __restrict__ database,
788 float* __restrict__ results,
789 const int query_count,
790 const int db_count,
791 const int dim
792 ) {
793 const int tid = blockIdx.x * blockDim.x + threadIdx.x;
794 const int query_idx = tid / db_count;
795 const int db_idx = tid % db_count;
796
797 if (query_idx >= query_count || db_idx >= db_count) return;
798
799 float sum_sq_diff = 0.0f;
800 const half* q_vec = queries + query_idx * dim;
801 const half* db_vec = database + db_idx * dim;
802
803 const int vec_dim = dim / 2;
804 const half2* q_vec2 = (const half2*)q_vec;
805 const half2* db_vec2 = (const half2*)db_vec;
806
807 for (int i = 0; i < vec_dim; i++) {
808 float2 q_f = __half22float2(q_vec2[i]);
809 float2 db_f = __half22float2(db_vec2[i]);
810
811 float diff_x = q_f.x - db_f.x;
812 float diff_y = q_f.y - db_f.y;
813 sum_sq_diff += diff_x * diff_x + diff_y * diff_y;
814 }
815
816 if (dim % 2 == 1) {
817 float diff = __half2float(q_vec[dim - 1]) - __half2float(db_vec[dim - 1]);
818 sum_sq_diff += diff * diff;
819 }
820
821 results[query_idx * db_count + db_idx] = sqrtf(sum_sq_diff);
822 }
823 "#
824 .to_string()
825 }
826
827 fn get_matmul_tensor_core_kernel(&self) -> String {
830 r#"
831 #include <mma.h>
832 using namespace nvcuda;
833
834 extern "C" __global__ void matmul_tensor_core_kernel(
835 const half* __restrict__ a,
836 const half* __restrict__ b,
837 float* __restrict__ c,
838 const int m,
839 const int n,
840 const int k
841 ) {
842 // Warp and lane identifiers
843 const int warp_id = threadIdx.x / 32;
844 const int lane_id = threadIdx.x % 32;
845
846 // WMMA fragment declarations (16x16x16)
847 wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
848 wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
849 wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
850
851 // Initialize accumulator
852 wmma::fill_fragment(c_frag, 0.0f);
853
854 // Tile indices
855 const int tile_m = blockIdx.y * 16;
856 const int tile_n = blockIdx.x * 16;
857
858 // Compute matrix multiplication using Tensor Cores
859 for (int i = 0; i < k; i += 16) {
860 wmma::load_matrix_sync(a_frag, a + tile_m * k + i, k);
861 wmma::load_matrix_sync(b_frag, b + i * n + tile_n, n);
862 wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
863 }
864
865 // Store result
866 wmma::store_matrix_sync(c + tile_m * n + tile_n, c_frag, n, wmma::mem_row_major);
867 }
868 "#
869 .to_string()
870 }
871}
872
873impl Default for KernelManager {
874 fn default() -> Self {
875 Self::new()
876 }
877}