hive-gpu 0.2.0

High-performance GPU acceleration for vector operations with Device Info API (Metal, CUDA, ROCm)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
#include <metal_stdlib>
using namespace metal;

// HNSW Node structure - represents a node in hierarchical graph
struct HnswNode {
    uint id;
    uint level;                // Maximum layer this node appears in
    uint base_connections_offset; // Offset to connections in base layer (layer 0)
    uint vector_offset;        // Offset into vectors buffer
};

// Layer-specific node data for GPU processing
struct HnswLayerNode {
    uint node_id;
    uint connections_offset;   // Offset to connections in this layer
    uint connection_count;     // Number of connections in this layer
};

// Search candidate with distance for priority queue
struct SearchCandidate {
    uint node_id;
    float distance;
};

// Search result structure
struct SearchResult {
    uint node_id;
    float distance;
};

// HNSW search state for layer navigation
struct HnswSearchState {
    uint current_layer;
    uint entry_point;
    uint ef;                   // Size of dynamic candidate list
    uint visited_count;        // Number of visited nodes
};

// Layer navigation result
struct LayerNavigationResult {
    uint best_node_id;
    float best_distance;
    bool found_better;
};

// Full GPU search structures
struct GpuSearchQuery {
    float data[512];           // Query vector (max 512 dimensions)
    uint dimension;            // Actual dimension
};

struct GpuSearchResult {
    uint vector_id;            // ID of the matched vector
    float distance;            // Distance to query
    uint vector_index;         // Index in vectors buffer
};

// Vector metadata for GPU processing
struct GpuVectorMetadata {
    uint vector_id;            // Original vector ID
    uint is_active;            // 1 if vector is active, 0 if removed
};

// Distance calculation function with robust numerical handling
float calculate_cosine_distance(device const float* vector_a, device const float* vector_b, uint dimension) {
    float dot_product = 0.0;
    float norm_a = 0.0;
    float norm_b = 0.0;
    
    for (uint i = 0; i < dimension; i++) {
        float a_val = vector_a[i];
        float b_val = vector_b[i];
        dot_product += a_val * b_val;
        norm_a += a_val * a_val;
        norm_b += b_val * b_val;
    }
    
    norm_a = sqrt(norm_a);
    norm_b = sqrt(norm_b);
    
    // Avoid division by zero and ensure numerical stability
    const float epsilon = 1e-8;
    float denom = max(norm_a * norm_b, epsilon);
    float similarity = dot_product / denom;
    
    // Clamp similarity to valid range [-1, 1]
    similarity = clamp(similarity, -1.0, 1.0);
    
    return 1.0 - similarity;
}

// HNSW Layer Navigation Kernel - Navigate within a single layer using greedy search
kernel void hnsw_navigate_layer(
    device const float* vectors [[buffer(0)]],
    device const HnswNode* nodes [[buffer(1)]],
    device const HnswLayerNode* layer_nodes [[buffer(2)]],
    device const uint* layer_connections [[buffer(3)]],
    device const float* query_vector [[buffer(4)]],
    device SearchCandidate* candidates [[buffer(5)]],
    device uint* visited_nodes [[buffer(6)]],
    constant uint& vector_dim [[buffer(7)]],
    constant uint& layer_node_count [[buffer(8)]],
    constant uint& max_candidates [[buffer(9)]],
    constant uint& entry_point [[buffer(10)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= layer_node_count) return;

    // This kernel processes one node from the candidates list
    // tid represents the index in the candidates array to explore

    // Get the candidate to explore
    SearchCandidate current_candidate = candidates[tid];
    uint current_node_id = current_candidate.node_id;

    // Find the layer node data for this node
    HnswLayerNode layer_node;
    bool found_node = false;

    for (uint i = 0; i < layer_node_count; i++) {
        if (layer_nodes[i].node_id == current_node_id) {
            layer_node = layer_nodes[i];
            found_node = true;
            break;
        }
    }

    if (!found_node || layer_node.connection_count == 0) {
        // No neighbors to explore from this node
        return;
    }

    // Explore all neighbors of this node
    for (uint i = 0; i < layer_node.connection_count; i++) {
        uint neighbor_id = layer_connections[layer_node.connections_offset + i];

        // Check if already visited (simple linear search for now)
        bool already_visited = false;
        for (uint j = 0; j < max_candidates; j++) {
            if (visited_nodes[j] == neighbor_id) {
                already_visited = true;
                break;
            }
        }

        if (already_visited) continue;

        // Calculate distance to neighbor
        HnswNode neighbor_node;
        bool found_neighbor = false;

        for (uint j = 0; j < layer_node_count; j++) {
            if (nodes[j].id == neighbor_id) {
                neighbor_node = nodes[j];
                found_neighbor = true;
                break;
            }
        }

        if (!found_neighbor) continue;

        device const float* neighbor_vector = &vectors[neighbor_node.vector_offset * vector_dim];
        float distance = calculate_cosine_distance(query_vector, neighbor_vector, vector_dim);

        // Try to add to candidates if better than current worst
        // This is a simplified version - in practice, we'd use a priority queue
        for (uint j = 0; j < max_candidates; j++) {
            if (candidates[j].distance > distance) {
                // Shift elements to make room
                for (uint k = max_candidates - 1; k > j; k--) {
                    candidates[k] = candidates[k - 1];
                }
                candidates[j] = SearchCandidate{.node_id = neighbor_id, .distance = distance};
                visited_nodes[j] = neighbor_id;
                break;
            }
        }
    }
}

// HNSW Complete Search Kernel - Performs full hierarchical search
kernel void hnsw_search_complete(
    device const float* vectors [[buffer(0)]],
    device const HnswNode* nodes [[buffer(1)]],
    device const HnswLayerNode* layer_nodes_base [[buffer(2)]],
    device const HnswLayerNode* layer_nodes_higher [[buffer(3)]],
    device const uint* layer_connections_base [[buffer(4)]],
    device const uint* layer_connections_higher [[buffer(5)]],
    device const float* query_vector [[buffer(6)]],
    device SearchResult* final_results [[buffer(7)]],
    constant uint& vector_dim [[buffer(8)]],
    constant uint& max_level [[buffer(9)]],
    constant uint& entry_point [[buffer(10)]],
    constant uint& ef_search [[buffer(11)]],
    constant uint& k [[buffer(12)]],
    uint search_id [[thread_position_in_grid]]
) {
    // This kernel performs a complete HNSW search for one query
    // Implements real hierarchical navigation with greedy search

    if (search_id >= 1) return; // Only handle one search at a time for now

    // Shared memory for candidates (max 256 candidates)
    threadgroup SearchCandidate candidates[256];
    threadgroup uint visited[256];
    threadgroup uint candidate_count[1];
    threadgroup uint visited_count[1];

    // Initialize
    if (search_id == 0) {
        candidate_count[0] = 0;
        visited_count[0] = 0;

        // Add entry point as first candidate
        if (entry_point < 10000) { // Safety check
            HnswNode entry_node = nodes[entry_point];
            device const float* entry_vector = &vectors[entry_node.vector_offset * vector_dim];
            float entry_distance = calculate_cosine_distance(query_vector, entry_vector, vector_dim);

            candidates[0] = SearchCandidate{.node_id = entry_point, .distance = entry_distance};
            visited[0] = entry_point;
            candidate_count[0] = 1;
            visited_count[0] = 1;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Perform hierarchical search from top layer down to layer 0 using beam search
    uint current_entry = entry_point;

    for (uint current_level = max_level; current_level > 0; current_level--) {
        // Beam search in current layer - explore multiple candidates simultaneously
        uint beam_width = min(ef_search / 4 + 1, 8u); // Adaptive beam width
        bool found_improvement = true;
        uint iterations = 0;

        while (found_improvement && iterations < 10 && candidate_count[0] > 0) { // Limit iterations
            found_improvement = false;
            iterations++;

            // Sort current candidates by distance (beam selection)
            for (uint i = 0; i < candidate_count[0] - 1 && i < 256; i++) {
                for (uint j = i + 1; j < candidate_count[0] && j < 256; j++) {
                    if (candidates[j].distance < candidates[i].distance) {
                        SearchCandidate temp = candidates[i];
                        candidates[i] = candidates[j];
                        candidates[j] = temp;
                    }
                }
            }

            // Keep only top beam_width candidates
            if (candidate_count[0] > beam_width) {
                candidate_count[0] = beam_width;
            }

            // Explore neighbors of all beam candidates
            uint new_candidates_start = candidate_count[0];

            for (uint beam_idx = 0; beam_idx < beam_width && beam_idx < candidate_count[0]; beam_idx++) {
                uint current_node_id = candidates[beam_idx].node_id;

                // Explore neighbors of this beam candidate
                device const HnswLayerNode* layer_nodes = (current_level == 0) ?
                    layer_nodes_base : layer_nodes_higher;
                device const uint* layer_connections = (current_level == 0) ?
                    layer_connections_base : layer_connections_higher;

                // Find the layer node for current_node_id
                HnswLayerNode layer_node;
                bool found_layer_node = false;

                for (uint i = 0; i < 1000 && !found_layer_node; i++) { // Safety limit
                    if (layer_nodes[i].node_id == current_node_id) {
                        layer_node = layer_nodes[i];
                        found_layer_node = true;
                    }
                }

                if (found_layer_node && layer_node.connection_count > 0) {
                    for (uint i = 0; i < layer_node.connection_count && i < 32; i++) { // Max 32 neighbors
                        uint neighbor_id = layer_connections[layer_node.connections_offset + i];

                        // Check if already visited
                        bool already_visited = false;
                        for (uint j = 0; j < visited_count[0] && j < 256; j++) {
                            if (visited[j] == neighbor_id) {
                                already_visited = true;
                                break;
                            }
                        }

                        if (!already_visited && candidate_count[0] < 256) {
                            // Calculate distance to neighbor
                            HnswNode neighbor_node = nodes[neighbor_id];
                            device const float* neighbor_vector = &vectors[neighbor_node.vector_offset * vector_dim];
                            float distance = calculate_cosine_distance(query_vector, neighbor_vector, vector_dim);

                            // Add to candidates
                            candidates[candidate_count[0]] = SearchCandidate{
                                .node_id = neighbor_id,
                                .distance = distance
                            };
                            candidate_count[0]++;

                            // Add to visited
                            if (visited_count[0] < 256) {
                                visited[visited_count[0]] = neighbor_id;
                                visited_count[0]++;
                                found_improvement = true;
                            }
                        }
                    }
                }
            }

            // Remove duplicates and sort again
            if (candidate_count[0] > beam_width) {
                // Simple deduplication - keep unique nodes with best distances
                for (uint i = new_candidates_start; i < candidate_count[0] && i < 256; i++) {
                    for (uint j = 0; j < new_candidates_start && j < 256; j++) {
                        if (candidates[i].node_id == candidates[j].node_id) {
                            // Keep the better distance
                            if (candidates[i].distance < candidates[j].distance) {
                                candidates[j] = candidates[i];
                            }
                            // Mark for removal
                            candidates[i].distance = INFINITY;
                            break;
                        }
                    }
                }

                // Compact array (remove INFINITY entries)
                uint write_idx = 0;
                for (uint i = 0; i < candidate_count[0] && i < 256; i++) {
                    if (candidates[i].distance < INFINITY) {
                        if (write_idx != i) {
                            candidates[write_idx] = candidates[i];
                        }
                        write_idx++;
                    }
                }
                candidate_count[0] = write_idx;

                // Keep only top beam_width
                if (candidate_count[0] > beam_width) {
                    candidate_count[0] = beam_width;
                }
            }
        }

        // Prepare candidates for next level (keep only the best one)
        if (candidate_count[0] > 0) {
            // Find the best candidate to carry to next level
            uint best_idx = 0;
            float best_dist = INFINITY;

            for (uint i = 0; i < candidate_count[0] && i < 256; i++) {
                if (candidates[i].distance < best_dist) {
                    best_dist = candidates[i].distance;
                    best_idx = i;
                }
            }

            // Reset candidates for next level
            candidate_count[0] = 1;
            candidates[0] = candidates[best_idx];
            current_entry = candidates[0].node_id;
        }
    }

    // Now perform beam search in base layer (level 0) with ef_search candidates
    candidate_count[0] = 1;
    candidates[0] = SearchCandidate{.node_id = current_entry, .distance = INFINITY};

    // Calculate distance for entry point
    HnswNode entry_node = nodes[current_entry];
    device const float* entry_vector = &vectors[entry_node.vector_offset * vector_dim];
    candidates[0].distance = calculate_cosine_distance(query_vector, entry_vector, vector_dim);

    visited_count[0] = 1;
    visited[0] = current_entry;

    // Beam search in base layer with ef_search beam width
    uint base_beam_width = min(ef_search, 64u); // Limit beam width for base layer
    bool found_improvement = true;
    uint base_iterations = 0;

    while (found_improvement && base_iterations < 15 && candidate_count[0] > 0) { // More iterations for base layer
        found_improvement = false;
        base_iterations++;

        // Sort current candidates by distance and keep top beam_width
        for (uint i = 0; i < candidate_count[0] - 1 && i < 256; i++) {
            for (uint j = i + 1; j < candidate_count[0] && j < 256; j++) {
                if (candidates[j].distance < candidates[i].distance) {
                    SearchCandidate temp = candidates[i];
                    candidates[i] = candidates[j];
                    candidates[j] = temp;
                }
            }
        }

        // Keep only top base_beam_width candidates
        if (candidate_count[0] > base_beam_width) {
            candidate_count[0] = base_beam_width;
        }

        // Explore neighbors of all beam candidates in base layer
        uint new_candidates_start = candidate_count[0];

        for (uint beam_idx = 0; beam_idx < base_beam_width && beam_idx < candidate_count[0]; beam_idx++) {
            uint current_node_id = candidates[beam_idx].node_id;

            // Find the layer node for current_node_id in base layer
            HnswLayerNode layer_node;
            bool found_layer_node = false;

            for (uint i = 0; i < 1000 && !found_layer_node; i++) { // Safety limit
                if (layer_nodes_base[i].node_id == current_node_id) {
                    layer_node = layer_nodes_base[i];
                    found_layer_node = true;
                }
            }

            if (found_layer_node && layer_node.connection_count > 0) {
                for (uint i = 0; i < layer_node.connection_count && i < 32; i++) {
                    uint neighbor_id = layer_connections_base[layer_node.connections_offset + i];

                    // Check if already visited
                    bool already_visited = false;
                    for (uint j = 0; j < visited_count[0] && j < 256; j++) {
                        if (visited[j] == neighbor_id) {
                            already_visited = true;
                            break;
                        }
                    }

                    if (!already_visited && candidate_count[0] < ef_search && candidate_count[0] < 256) {
                        // Calculate distance to neighbor
                        HnswNode neighbor_node = nodes[neighbor_id];
                        device const float* neighbor_vector = &vectors[neighbor_node.vector_offset * vector_dim];
                        float distance = calculate_cosine_distance(query_vector, neighbor_vector, vector_dim);

                        // Add to candidates
                        candidates[candidate_count[0]] = SearchCandidate{
                            .node_id = neighbor_id,
                            .distance = distance
                        };
                        candidate_count[0]++;

                        // Add to visited
                        if (visited_count[0] < 256) {
                            visited[visited_count[0]] = neighbor_id;
                            visited_count[0]++;
                            found_improvement = true;
                        }
                    }
                }
            }
        }

        // Maintain ef_search candidates - replace worst if we have too many
        if (candidate_count[0] > ef_search) {
            // Sort all candidates and keep best ef_search
            for (uint i = 0; i < candidate_count[0] - 1 && i < 256; i++) {
                for (uint j = i + 1; j < candidate_count[0] && j < 256; j++) {
                    if (candidates[j].distance < candidates[i].distance) {
                        SearchCandidate temp = candidates[i];
                        candidates[i] = candidates[j];
                        candidates[j] = temp;
                    }
                }
            }
            candidate_count[0] = ef_search;
        }
    }

    // Select top k results from candidates
    for (uint i = 0; i < k && i < candidate_count[0] && i < 256; i++) {
        // Simple selection - find i-th best
        uint best_idx = 0;
        float best_dist = INFINITY;

        for (uint j = 0; j < candidate_count[0] && j < 256; j++) {
            bool already_selected = false;
            for (uint m = 0; m < i; m++) {
                if (final_results[m].node_id == candidates[j].node_id) {
                    already_selected = true;
                    break;
                }
            }

            if (!already_selected && candidates[j].distance < best_dist) {
                best_dist = candidates[j].distance;
                best_idx = j;
            }
        }

        if (best_dist < INFINITY) {
            final_results[i] = SearchResult{
                .node_id = candidates[best_idx].node_id,
                .distance = candidates[best_idx].distance
            };
        }
    }
}

// Top-K selection kernel with correct implementation
kernel void select_top_k(
    device const SearchCandidate* candidates [[buffer(0)]],
    device SearchResult* top_results [[buffer(1)]],
    constant uint& candidate_count [[buffer(2)]],
    constant uint& k [[buffer(3)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= k) return;
    
    // Sort candidates by distance (simplified selection sort)
    threadgroup SearchCandidate sorted_candidates[256]; // Shared memory for sorting

    // Copy candidates to threadgroup memory
    for (uint i = 0; i < candidate_count && i < 256; i++) {
        sorted_candidates[i] = candidates[i];
    }

    // Simple selection sort (not efficient but correct)
    for (uint i = 0; i < candidate_count - 1 && i < k; i++) {
        uint min_idx = i;
        for (uint j = i + 1; j < candidate_count && j < 256; j++) {
            if (sorted_candidates[j].distance < sorted_candidates[min_idx].distance) {
                min_idx = j;
            }
        }

        // Swap
        SearchCandidate temp = sorted_candidates[i];
        sorted_candidates[i] = sorted_candidates[min_idx];
        sorted_candidates[min_idx] = temp;
    }

    // Store the tid-th best result
    if (tid < candidate_count && tid < 256) {
        SearchCandidate result = sorted_candidates[tid];
        top_results[tid] = SearchResult{
            .node_id = result.node_id,
            .distance = result.distance
        };
    }
}

// HNSW Search Initialization Kernel - Sets up initial candidates for search
kernel void hnsw_search_init(
    device const float* vectors [[buffer(0)]],
    device const HnswNode* nodes [[buffer(1)]],
    device const float* query_vector [[buffer(2)]],
    device SearchCandidate* candidates [[buffer(3)]],
    device uint* visited_nodes [[buffer(4)]],
    constant uint& vector_dim [[buffer(5)]],
    constant uint& entry_point [[buffer(6)]],
    constant uint& max_candidates [[buffer(7)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= 1) return; // Only one thread initializes

    // Initialize with entry point
    HnswNode entry_node = nodes[entry_point];

    device const float* entry_vector = &vectors[entry_node.vector_offset * vector_dim];
    float entry_distance = calculate_cosine_distance(query_vector, entry_vector, vector_dim);

    // Set first candidate as entry point
    candidates[0] = SearchCandidate{
        .node_id = entry_point,
        .distance = entry_distance
    };

    // Initialize visited nodes
    visited_nodes[0] = entry_point;

    // Fill remaining candidates with invalid data (distance = INFINITY)
    for (uint i = 1; i < max_candidates; i++) {
        candidates[i] = SearchCandidate{
            .node_id = UINT_MAX,
            .distance = INFINITY
        };
        visited_nodes[i] = UINT_MAX;
    }
}

// HNSW Connection Building Kernel - Builds connections during index construction
kernel void hnsw_build_connections(
    device const float* vectors [[buffer(0)]],
    device const HnswNode* nodes [[buffer(1)]],
    device HnswLayerNode* layer_nodes [[buffer(2)]],
    device uint* layer_connections [[buffer(3)]],
    device const float* new_vector [[buffer(4)]],
    constant uint& vector_dim [[buffer(5)]],
    constant uint& layer [[buffer(6)]],
    constant uint& max_connections [[buffer(7)]],
    constant uint& new_node_id [[buffer(8)]],
    constant uint& layer_node_count [[buffer(9)]],
    uint tid [[thread_position_in_grid]]
) {
    // This kernel helps build connections for a new node in a specific layer
    // tid represents the index of existing nodes to consider for connection

    if (tid >= layer_node_count) return;

    HnswLayerNode existing_node = layer_nodes[tid];

    // Skip if this is the new node itself
    if (existing_node.node_id == new_node_id) return;

    // Calculate distance between new vector and existing node
    HnswNode existing_node_data;
    bool found = false;

    // Find the node data (simplified linear search)
    for (uint i = 0; i < 10000; i++) { // Safety limit
        if (nodes[i].id == existing_node.node_id) {
            existing_node_data = nodes[i];
            found = true;
            break;
        }
    }

    if (!found) return;

    device const float* existing_vector = &vectors[existing_node_data.vector_offset * vector_dim];
    // Calculate distance (stored for potential future use)
    // float distance = calculate_cosine_distance(new_vector, existing_vector, vector_dim);

    // This is a simplified version - in practice, we'd need atomic operations
    // and a more sophisticated selection algorithm
    // For now, this kernel is a placeholder for future connection building
}

// ============================================================================
// FULL GPU VECTOR SEARCH - Maintains all data in VRAM
// ============================================================================

// Full GPU search kernel - searches all vectors entirely on GPU
kernel void gpu_full_vector_search(
    device const float* vectors [[buffer(0)]],                    // All vector data (flattened)
    device const GpuVectorMetadata* metadata [[buffer(1)]],       // Vector metadata
    device const GpuSearchQuery* query [[buffer(2)]],             // Search query
    device GpuSearchResult* results [[buffer(3)]],                // Search results buffer
    constant uint& vector_count [[buffer(4)]],                   // Total number of vectors
    constant uint& k [[buffer(5)]],                              // Number of results to return
    constant uint& dimension [[buffer(6)]],                      // Vector dimension
    uint tid [[thread_position_in_grid]]                         // Thread ID
) {
    // Each thread processes one vector
    if (tid >= vector_count) return;

    // Check if vector is active (not removed)
    if (metadata[tid].is_active == 0) return;

    // Get vector data pointer
    device const float* vector_data = &vectors[tid * dimension];

    // Get query data
    device const float* query_data = query->data;

    // Calculate cosine distance
    float distance = calculate_cosine_distance(query_data, vector_data, dimension);

    // Store result for this vector
    results[tid].vector_id = metadata[tid].vector_id;
    results[tid].distance = distance;
    results[tid].vector_index = tid;
}

// Parallel reduction kernel to find top-k results
kernel void gpu_find_top_k_results(
    device const GpuSearchResult* all_results [[buffer(0)]],     // All search results
    device GpuSearchResult* final_results [[buffer(1)]],         // Final top-k results
    constant uint& total_vectors [[buffer(2)]],                  // Total number of vectors
    constant uint& k [[buffer(3)]],                              // Number of results to return
    uint tid [[thread_position_in_grid]]                         // Thread ID
) {
    // This is a simplified implementation
    // In practice, we'd use parallel reduction or prefix sum algorithms
    // For now, we'll do a simple bubble sort approach (not optimal)

    if (tid >= k) return;

    // Initialize with worst possible result
    GpuSearchResult best = { UINT_MAX, FLT_MAX, UINT_MAX };

    // Find the tid-th best result
    for (uint i = 0; i < total_vectors; i++) {
        GpuSearchResult current = all_results[i];

        // Skip invalid results
        if (current.vector_id == UINT_MAX) continue;

        // Check if this result is better than current best for this position
        bool better = false;
        if (current.distance < best.distance) {
            better = true;
        } else if (current.distance == best.distance && current.vector_id < best.vector_id) {
            better = true;
        }

        if (better) {
            // Count how many results are better than this one
            uint better_count = 0;
            for (uint j = 0; j < total_vectors; j++) {
                if (i == j) continue;
                GpuSearchResult other = all_results[j];
                if (other.vector_id == UINT_MAX) continue;

                if (other.distance < current.distance ||
                   (other.distance == current.distance && other.vector_id < current.vector_id)) {
                    better_count++;
                }
            }

            // If this is the right position for us, use it
            if (better_count == tid) {
                best = current;
            }
        }
    }

    final_results[tid] = best;
}

// =====================================================================
// Brute-force SGEMV / SGEMM-style kernels (added for phase4a / phase4c)
// =====================================================================
//
// These kernels provide cuBLAS-SGEMV-equivalent primitives for the Metal
// backend without pulling in Metal Performance Shaders. They are used by
// MetalNativeVectorStorage::search (brute-force) and MetalIvfIndex (coarse
// cluster selection, refined per-cluster search, and k-means assignment).

/// SGEMV: scores[i] = sum_d matrix[i, d] * query[d].
/// matrix is row-major of shape (n_vectors, dimension).
/// One thread per output row — each thread reads an entire vector and
/// produces one dot product.
kernel void sgemv_dot(
    const device float* matrix     [[buffer(0)]],
    const device float* query      [[buffer(1)]],
    device float*       scores     [[buffer(2)]],
    constant uint&      dimension  [[buffer(3)]],
    constant uint&      n_vectors  [[buffer(4)]],
    uint                tid        [[thread_position_in_grid]])
{
    if (tid >= n_vectors) {
        return;
    }
    const device float* row = matrix + (uint)tid * dimension;
    float sum = 0.0f;
    for (uint d = 0; d < dimension; ++d) {
        sum += row[d] * query[d];
    }
    scores[tid] = sum;
}

/// SGEMM-lite: out[i, j] = sum_d samples[i, d] * centroids[j, d].
/// samples is row-major (n_samples, dimension); centroids is row-major
/// (n_list, dimension). Output is row-major (n_samples, n_list).
/// Grid = (n_samples, n_list); one thread per output element.
/// Used by MetalIvfIndex for k-means assignment and batch add.
kernel void sgemm_dot(
    const device float* samples    [[buffer(0)]],
    const device float* centroids  [[buffer(1)]],
    device float*       out        [[buffer(2)]],
    constant uint&      dimension  [[buffer(3)]],
    constant uint&      n_list     [[buffer(4)]],
    constant uint&      n_samples  [[buffer(5)]],
    uint2               gid        [[thread_position_in_grid]])
{
    uint i = gid.x; // sample
    uint j = gid.y; // centroid
    if (i >= n_samples || j >= n_list) {
        return;
    }
    const device float* s = samples   + i * dimension;
    const device float* c = centroids + j * dimension;
    float sum = 0.0f;
    for (uint d = 0; d < dimension; ++d) {
        sum += s[d] * c[d];
    }
    out[i * n_list + j] = sum;
}