oxirs_vec/gpu/kernels.rs

//! CUDA kernel implementations for various vector operations

// Note: anyhow imports removed as they were unused
use std::collections::HashMap;

/// Manages the CUDA kernel sources used for GPU-accelerated vector operations,
/// keyed by kernel name.
#[derive(Debug)]
pub struct KernelManager {
    /// Kernel name -> CUDA C source code.
    kernels: HashMap<String, String>,
}

impl KernelManager {
    pub fn new() -> Self {
        let mut manager = Self {
            kernels: HashMap::new(),
        };
        manager.initialize_kernels();
        manager
    }

    fn initialize_kernels(&mut self) {
        let kernels = vec![
            (
                "cosine_similarity".to_string(),
                self.get_cosine_similarity_kernel(),
            ),
            (
                "euclidean_distance".to_string(),
                self.get_euclidean_distance_kernel(),
            ),
            ("dot_product".to_string(), self.get_dot_product_kernel()),
            (
                "vector_addition".to_string(),
                self.get_vector_addition_kernel(),
            ),
            (
                "vector_normalization".to_string(),
                self.get_vector_normalization_kernel(),
            ),
            ("hnsw_search".to_string(), self.get_hnsw_search_kernel()),
            (
                "batch_distance_computation".to_string(),
                self.get_batch_distance_kernel(),
            ),
        ];

        for (name, kernel) in kernels {
            self.kernels.insert(name, kernel);
        }
    }

    /// Returns the CUDA source for the named kernel, if it is registered.
    pub fn get_kernel(&self, name: &str) -> Option<&String> {
        self.kernels.get(name)
    }

    /// Pairwise cosine similarity: one thread per (query, database) pair.
    fn get_cosine_similarity_kernel(&self) -> String {
        r#"
        extern "C" __global__ void cosine_similarity_kernel(
            const float* __restrict__ queries,
            const float* __restrict__ database,
            float* __restrict__ results,
            const int query_count,
            const int db_count,
            const int dim
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int query_idx = tid / db_count;
            const int db_idx = tid % db_count;

            if (query_idx >= query_count || db_idx >= db_count) return;

            float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;

            // Vectorized float4 loads: assumes dim is a multiple of 4 and that
            // each row is 16-byte aligned; otherwise the last iteration reads
            // past the end of the vector.
            const int vec_dim = (dim + 3) / 4;
            const float4* q_vec = (const float4*)(queries + query_idx * dim);
            const float4* db_vec = (const float4*)(database + db_idx * dim);

            for (int i = 0; i < vec_dim; i++) {
                float4 q_vals = q_vec[i];
                float4 db_vals = db_vec[i];

                dot += q_vals.x * db_vals.x + q_vals.y * db_vals.y +
                       q_vals.z * db_vals.z + q_vals.w * db_vals.w;
                norm_q += q_vals.x * q_vals.x + q_vals.y * q_vals.y +
                          q_vals.z * q_vals.z + q_vals.w * q_vals.w;
                norm_db += db_vals.x * db_vals.x + db_vals.y * db_vals.y +
                           db_vals.z * db_vals.z + db_vals.w * db_vals.w;
            }

            // Guard against division by zero for degenerate (near-zero) vectors.
            const float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
            const float similarity = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;

            results[query_idx * db_count + db_idx] = similarity;
        }
        "#
        .to_string()
    }

    /// Pairwise Euclidean distance: one thread per (query, database) pair.
    fn get_euclidean_distance_kernel(&self) -> String {
        r#"
        extern "C" __global__ void euclidean_distance_kernel(
            const float* __restrict__ queries,
            const float* __restrict__ database,
            float* __restrict__ results,
            const int query_count,
            const int db_count,
            const int dim
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int query_idx = tid / db_count;
            const int db_idx = tid % db_count;

            if (query_idx >= query_count || db_idx >= db_count) return;

            float sum_sq_diff = 0.0f;

            // Same float4 layout assumption as the cosine kernel: dim must be a
            // multiple of 4 and rows must be 16-byte aligned.
            const int vec_dim = (dim + 3) / 4;
            const float4* q_vec = (const float4*)(queries + query_idx * dim);
            const float4* db_vec = (const float4*)(database + db_idx * dim);

            for (int i = 0; i < vec_dim; i++) {
                float4 q_vals = q_vec[i];
                float4 db_vals = db_vec[i];
                float4 diff = make_float4(
                    q_vals.x - db_vals.x,
                    q_vals.y - db_vals.y,
                    q_vals.z - db_vals.z,
                    q_vals.w - db_vals.w
                );
                sum_sq_diff += diff.x * diff.x + diff.y * diff.y +
                               diff.z * diff.z + diff.w * diff.w;
            }

            results[query_idx * db_count + db_idx] = sqrtf(sum_sq_diff);
        }
        "#
        .to_string()
    }

    /// Block-wise reduction dot product; per-block partial sums are accumulated
    /// into `result` with atomicAdd.
    fn get_dot_product_kernel(&self) -> String {
        r#"
        extern "C" __global__ void dot_product_kernel(
            const float* __restrict__ a,
            const float* __restrict__ b,
            float* __restrict__ result,
            const int n
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = blockDim.x * gridDim.x;

            // Grid-stride loop: each thread accumulates a partial sum.
            float sum = 0.0f;
            for (int i = tid; i < n; i += stride) {
                sum += a[i] * b[i];
            }

            // Tree reduction in shared memory. Assumes blockDim.x is a power of
            // two and at most 256. `result` must be zero-initialized by the host
            // before launch.
            __shared__ float shared_sum[256];
            shared_sum[threadIdx.x] = sum;
            __syncthreads();

            for (int s = blockDim.x / 2; s > 0; s >>= 1) {
                if (threadIdx.x < s) {
                    shared_sum[threadIdx.x] += shared_sum[threadIdx.x + s];
                }
                __syncthreads();
            }

            if (threadIdx.x == 0) {
                atomicAdd(result, shared_sum[0]);
            }
        }
        "#
        .to_string()
    }

    fn get_vector_addition_kernel(&self) -> String {
        r#"
        extern "C" __global__ void vector_addition_kernel(
            const float* __restrict__ a,
            const float* __restrict__ b,
            float* __restrict__ result,
            const int n
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            if (tid < n) {
                result[tid] = a[tid] + b[tid];
            }
        }
        "#
        .to_string()
    }

    /// In-place L2 normalization; one block per vector.
    fn get_vector_normalization_kernel(&self) -> String {
        r#"
        extern "C" __global__ void vector_normalization_kernel(
            float* __restrict__ vectors,
            const int count,
            const int dim
        ) {
            const int vector_idx = blockIdx.x;
            const int tid = threadIdx.x;

            if (vector_idx >= count) return;

            float* vector = vectors + vector_idx * dim;

            __shared__ float shared_norm;
            if (tid == 0) shared_norm = 0.0f;
            __syncthreads();

            // Each thread accumulates part of the squared norm.
            float local_sum = 0.0f;
            for (int i = tid; i < dim; i += blockDim.x) {
                local_sum += vector[i] * vector[i];
            }

            atomicAdd(&shared_norm, local_sum);
            __syncthreads();

            // Convert to the reciprocal norm; near-zero vectors are left
            // unchanged instead of being scaled by their own tiny norm.
            if (tid == 0) {
                shared_norm = sqrtf(shared_norm);
                shared_norm = (shared_norm > 1e-8f) ? 1.0f / shared_norm : 1.0f;
            }
            __syncthreads();

            for (int i = tid; i < dim; i += blockDim.x) {
                vector[i] *= shared_norm;
            }
        }
        "#
        .to_string()
    }

    /// Single-block breadth-first candidate expansion over the HNSW adjacency
    /// list, starting from `entry_point`. Produces up to 128 candidates with
    /// their Euclidean distances to the query.
    fn get_hnsw_search_kernel(&self) -> String {
        r#"
        extern "C" __global__ void hnsw_search_kernel(
            const float* __restrict__ query,
            const float* __restrict__ vectors,
            const int* __restrict__ adjacency_list,
            const int* __restrict__ adjacency_offsets,
            int* __restrict__ candidate_queue,
            float* __restrict__ candidate_distances,
            int* __restrict__ queue_size,
            const int dim,
            const int entry_point
        ) {
            const int tid = threadIdx.x;

            // Dynamic shared memory layout: dim floats for the query, followed by
            // a 128-entry candidate queue (ids) and its distances. The launch must
            // provide at least dim * sizeof(float) + 128 * (sizeof(int) + sizeof(float)) bytes.
            extern __shared__ float shared_data[];
            float* shared_query = shared_data;
            int* shared_queue = (int*)(shared_data + dim);
            float* shared_queue_dist = (float*)(shared_queue + 128);

            // Cooperative, strided load of the query so any dim works,
            // not just dim <= blockDim.x.
            for (int i = tid; i < dim; i += blockDim.x) {
                shared_query[i] = query[i];
            }
            __syncthreads();

            // Queue bounds live in shared memory so every thread sees the same
            // loop condition and the __syncthreads() below cannot diverge.
            __shared__ int queue_head;
            __shared__ int queue_tail;

            if (tid == 0) {
                queue_head = 0;
                queue_tail = 1;
                shared_queue[0] = entry_point;
                shared_queue_dist[0] = 0.0f;
            }
            __syncthreads();

            while (queue_head < queue_tail && queue_tail < 128) {
                // Expansion is serialized on thread 0; no visited set is kept, so
                // duplicates are only bounded by the 128-entry queue capacity.
                if (tid == 0) {
                    int current_node = shared_queue[queue_head];
                    queue_head++;

                    int neighbor_start = adjacency_offsets[current_node];
                    int neighbor_end = adjacency_offsets[current_node + 1];

                    for (int i = neighbor_start; i < neighbor_end && queue_tail < 128; i++) {
                        int neighbor = adjacency_list[i];

                        const float* neighbor_vector = vectors + neighbor * dim;
                        float neighbor_dist = 0.0f;
                        for (int d = 0; d < dim; d++) {
                            float diff = shared_query[d] - neighbor_vector[d];
                            neighbor_dist += diff * diff;
                        }
                        neighbor_dist = sqrtf(neighbor_dist);

                        shared_queue[queue_tail] = neighbor;
                        shared_queue_dist[queue_tail] = neighbor_dist;
                        queue_tail++;
                    }
                }
                // Publish the updated head/tail before re-evaluating the loop condition.
                __syncthreads();
            }

            // Write the candidates back to global memory.
            if (tid < queue_tail) {
                candidate_queue[tid] = shared_queue[tid];
                candidate_distances[tid] = shared_queue_dist[tid];
            }

            if (tid == 0) {
                *queue_size = queue_tail;
            }
        }
        "#
        .to_string()
    }

    /// Full pairwise distance matrix between two batches.
    /// metric_type 0 = Euclidean distance, 1 = cosine distance (1 - cosine similarity);
    /// any other value yields 0.0.
    fn get_batch_distance_kernel(&self) -> String {
        r#"
        extern "C" __global__ void batch_distance_kernel(
            const float* __restrict__ batch_a,
            const float* __restrict__ batch_b,
            float* __restrict__ distances,
            const int batch_size_a,
            const int batch_size_b,
            const int dim,
            const int metric_type
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int i = tid / batch_size_b;
            const int j = tid % batch_size_b;

            if (i >= batch_size_a || j >= batch_size_b) return;

            const float* vec_a = batch_a + i * dim;
            const float* vec_b = batch_b + j * dim;

            float distance = 0.0f;

            if (metric_type == 0) { // Euclidean
                for (int d = 0; d < dim; d++) {
                    float diff = vec_a[d] - vec_b[d];
                    distance += diff * diff;
                }
                distance = sqrtf(distance);
            } else if (metric_type == 1) { // Cosine distance
                float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
                for (int d = 0; d < dim; d++) {
                    dot += vec_a[d] * vec_b[d];
                    norm_a += vec_a[d] * vec_a[d];
                    norm_b += vec_b[d] * vec_b[d];
                }
                float norm_product = sqrtf(norm_a) * sqrtf(norm_b);
                distance = (norm_product > 1e-8f) ? 1.0f - (dot / norm_product) : 1.0f;
            }
            // Unknown metric_type values fall through and report 0.0.

            distances[i * batch_size_b + j] = distance;
        }
        "#
        .to_string()
    }
}
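
// Illustrative sketch (not part of the original module): host-side launch-parameter
// helpers that mirror the indexing used by the kernels above. The pairwise kernels map
// one thread to one (query, database) pair via `tid / db_count` and `tid % db_count`,
// and the HNSW kernel expects dynamic shared memory for the query plus a 128-entry
// candidate queue. The names and the 256-thread block size are assumptions, not an
// established API of this crate.
#[allow(dead_code)]
fn pairwise_launch_dims(query_count: usize, db_count: usize) -> (u32, u32) {
    const BLOCK_SIZE: usize = 256;
    let total_threads = query_count * db_count;
    // Round up so every (query, database) pair gets a thread; the kernels' bounds
    // checks discard the excess threads in the last block.
    let grid_size = (total_threads + BLOCK_SIZE - 1) / BLOCK_SIZE;
    (grid_size as u32, BLOCK_SIZE as u32)
}

#[allow(dead_code)]
fn hnsw_shared_mem_bytes(dim: usize) -> usize {
    const QUEUE_CAPACITY: usize = 128;
    // dim floats for the query + 128 ints (candidate ids) + 128 floats (distances).
    dim * std::mem::size_of::<f32>()
        + QUEUE_CAPACITY * (std::mem::size_of::<i32>() + std::mem::size_of::<f32>())
}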

impl Default for KernelManager {
    fn default() -> Self {
        Self::new()
    }
}
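
// A minimal test sketch (an addition, not from the original file): it only checks that
// every expected kernel is registered and that each source contains the matching
// `__global__` entry point, since actually compiling the CUDA sources requires NVRTC
// and a GPU.
#[cfg(test)]
mod tests {
    use super::KernelManager;

    #[test]
    fn all_builtin_kernels_are_registered() {
        let manager = KernelManager::new();
        let expected = [
            ("cosine_similarity", "cosine_similarity_kernel"),
            ("euclidean_distance", "euclidean_distance_kernel"),
            ("dot_product", "dot_product_kernel"),
            ("vector_addition", "vector_addition_kernel"),
            ("vector_normalization", "vector_normalization_kernel"),
            ("hnsw_search", "hnsw_search_kernel"),
            ("batch_distance_computation", "batch_distance_kernel"),
        ];
        for (name, entry_point) in expected {
            let source = manager
                .get_kernel(name)
                .unwrap_or_else(|| panic!("kernel `{name}` should be registered"));
            assert!(
                source.contains("__global__") && source.contains(entry_point),
                "kernel `{name}` should define `{entry_point}`"
            );
        }
    }

    #[test]
    fn unknown_kernel_returns_none() {
        assert!(KernelManager::new().get_kernel("does_not_exist").is_none());
    }
}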