1use std::collections::HashMap;
5
/// Registry of CUDA C kernel sources for GPU-accelerated vector operations
/// (similarity/distance computation, normalization, graph search).
///
/// Kernels are stored as CUDA C source strings keyed by a stable lookup name;
/// callers are expected to JIT-compile them (e.g. via NVRTC). Construction via
/// [`KernelManager::new`] pre-registers every built-in kernel.
#[derive(Debug)]
pub struct KernelManager {
    // Kernel lookup name -> CUDA C source code.
    kernels: HashMap<String, String>,
}

impl KernelManager {
    /// Creates a manager pre-populated with all built-in kernel sources.
    pub fn new() -> Self {
        let mut manager = Self {
            kernels: HashMap::new(),
        };
        manager.initialize_kernels();
        manager
    }

    /// Registers every built-in kernel under its lookup name.
    fn initialize_kernels(&mut self) {
        let kernels = vec![
            ("cosine_similarity", self.get_cosine_similarity_kernel()),
            ("euclidean_distance", self.get_euclidean_distance_kernel()),
            ("dot_product", self.get_dot_product_kernel()),
            ("vector_addition", self.get_vector_addition_kernel()),
            ("vector_normalization", self.get_vector_normalization_kernel()),
            ("hnsw_search", self.get_hnsw_search_kernel()),
            ("batch_distance_computation", self.get_batch_distance_kernel()),
        ];

        for (name, kernel) in kernels {
            self.kernels.insert(name.to_string(), kernel);
        }
    }

    /// Returns the CUDA source registered under `name`, or `None` if there is
    /// no kernel with that name.
    pub fn get_kernel(&self, name: &str) -> Option<&String> {
        self.kernels.get(name)
    }

    /// All-pairs cosine similarity between `query_count` query rows and
    /// `db_count` database rows, each of length `dim`; one thread per
    /// (query, db) pair, result written to `results[query_idx * db_count + db_idx]`.
    fn get_cosine_similarity_kernel(&self) -> String {
        r#"
        extern "C" __global__ void cosine_similarity_kernel(
            const float* __restrict__ queries,
            const float* __restrict__ database,
            float* __restrict__ results,
            const int query_count,
            const int db_count,
            const int dim
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int query_idx = tid / db_count;
            const int db_idx = tid % db_count;

            if (query_idx >= query_count || db_idx >= db_count) return;

            const float* q = queries + query_idx * dim;
            const float* db = database + db_idx * dim;

            float dot = 0.0f, norm_q = 0.0f, norm_db = 0.0f;

            if ((dim & 3) == 0) {
                // Vectorized float4 path. Only valid when dim is a multiple
                // of 4: otherwise a row pointer is not 16-byte aligned (even
                // assuming the base allocations are), and rounding the trip
                // count up would read past the end of the row.
                const float4* q_vec = (const float4*)q;
                const float4* db_vec = (const float4*)db;
                const int vec_dim = dim / 4;

                for (int i = 0; i < vec_dim; i++) {
                    float4 a = q_vec[i];
                    float4 b = db_vec[i];

                    dot += a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
                    norm_q += a.x * a.x + a.y * a.y + a.z * a.z + a.w * a.w;
                    norm_db += b.x * b.x + b.y * b.y + b.z * b.z + b.w * b.w;
                }
            } else {
                // Scalar fallback for dimensions not divisible by 4.
                for (int d = 0; d < dim; d++) {
                    dot += q[d] * db[d];
                    norm_q += q[d] * q[d];
                    norm_db += db[d] * db[d];
                }
            }

            // Guard against division by (near-)zero norms: degenerate
            // vectors get similarity 0.
            const float norm_product = sqrtf(norm_q) * sqrtf(norm_db);
            const float similarity = (norm_product > 1e-8f) ? dot / norm_product : 0.0f;

            results[query_idx * db_count + db_idx] = similarity;
        }
        "#
        .to_string()
    }

    /// All-pairs Euclidean (L2) distance between query and database rows;
    /// same thread/output layout as the cosine-similarity kernel.
    fn get_euclidean_distance_kernel(&self) -> String {
        r#"
        extern "C" __global__ void euclidean_distance_kernel(
            const float* __restrict__ queries,
            const float* __restrict__ database,
            float* __restrict__ results,
            const int query_count,
            const int db_count,
            const int dim
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int query_idx = tid / db_count;
            const int db_idx = tid % db_count;

            if (query_idx >= query_count || db_idx >= db_count) return;

            const float* q = queries + query_idx * dim;
            const float* db = database + db_idx * dim;

            float sum_sq_diff = 0.0f;

            if ((dim & 3) == 0) {
                // float4 path is only safe when dim % 4 == 0: it keeps row
                // pointers 16-byte aligned and the trip count exact (no
                // out-of-bounds tail reads).
                const float4* q_vec = (const float4*)q;
                const float4* db_vec = (const float4*)db;
                const int vec_dim = dim / 4;

                for (int i = 0; i < vec_dim; i++) {
                    float4 a = q_vec[i];
                    float4 b = db_vec[i];
                    float4 diff = make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
                    sum_sq_diff += diff.x * diff.x + diff.y * diff.y +
                                   diff.z * diff.z + diff.w * diff.w;
                }
            } else {
                // Scalar fallback for dimensions not divisible by 4.
                for (int d = 0; d < dim; d++) {
                    float diff = q[d] - db[d];
                    sum_sq_diff += diff * diff;
                }
            }

            results[query_idx * db_count + db_idx] = sqrtf(sum_sq_diff);
        }
        "#
        .to_string()
    }

    /// Grid-wide dot product of two length-`n` vectors, accumulated into
    /// `*result` via per-block shared-memory reduction + one atomicAdd.
    ///
    /// Launch requirements (implied by the fixed shared array and the
    /// halving reduction): blockDim.x must be a power of two and <= 256,
    /// and `*result` must be zeroed before launch.
    fn get_dot_product_kernel(&self) -> String {
        r#"
        extern "C" __global__ void dot_product_kernel(
            const float* __restrict__ a,
            const float* __restrict__ b,
            float* __restrict__ result,
            const int n
        ) {
            // NOTE: requires blockDim.x to be a power of two and <= 256
            // (shared_sum size / halving reduction), and *result to be
            // zero-initialized by the host before launch.
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = blockDim.x * gridDim.x;

            // Grid-stride partial sum so any grid size covers all n elements.
            float sum = 0.0f;
            for (int i = tid; i < n; i += stride) {
                sum += a[i] * b[i];
            }

            __shared__ float shared_sum[256];
            shared_sum[threadIdx.x] = sum;
            __syncthreads();

            // Tree reduction within the block.
            for (int s = blockDim.x / 2; s > 0; s >>= 1) {
                if (threadIdx.x < s) {
                    shared_sum[threadIdx.x] += shared_sum[threadIdx.x + s];
                }
                __syncthreads();
            }

            // One atomic per block combines the block-level partial sums.
            if (threadIdx.x == 0) {
                atomicAdd(result, shared_sum[0]);
            }
        }
        "#
        .to_string()
    }

    /// Element-wise addition of two length-`n` vectors; one thread per element.
    fn get_vector_addition_kernel(&self) -> String {
        r#"
        extern "C" __global__ void vector_addition_kernel(
            const float* __restrict__ a,
            const float* __restrict__ b,
            float* __restrict__ result,
            const int n
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            if (tid < n) {
                result[tid] = a[tid] + b[tid];
            }
        }
        "#
        .to_string()
    }

    /// In-place L2 normalization of `count` vectors of length `dim`;
    /// one block per vector, threads cooperate on the norm via shared memory.
    /// Near-zero vectors are mapped to the zero vector instead of being
    /// scaled by a meaningless reciprocal.
    fn get_vector_normalization_kernel(&self) -> String {
        r#"
        extern "C" __global__ void vector_normalization_kernel(
            float* __restrict__ vectors,
            const int count,
            const int dim
        ) {
            const int vector_idx = blockIdx.x;
            const int tid = threadIdx.x;

            if (vector_idx >= count) return;

            float* vector = vectors + vector_idx * dim;

            __shared__ float shared_norm;
            if (tid == 0) shared_norm = 0.0f;
            __syncthreads();

            // Block-strided partial sum of squares.
            float local_sum = 0.0f;
            for (int i = tid; i < dim; i += blockDim.x) {
                local_sum += vector[i] * vector[i];
            }

            atomicAdd(&shared_norm, local_sum);
            __syncthreads();

            if (tid == 0) {
                // Convert the squared norm into the scale factor applied
                // below. Degenerate (near-zero) vectors get scale 0 so the
                // output is the zero vector rather than garbage.
                float norm = sqrtf(shared_norm);
                shared_norm = (norm > 1e-8f) ? 1.0f / norm : 0.0f;
            }
            __syncthreads();

            for (int i = tid; i < dim; i += blockDim.x) {
                vector[i] *= shared_norm;
            }
        }
        "#
        .to_string()
    }

    /// Single-block BFS expansion over an HNSW-style adjacency structure,
    /// starting from `entry_point`, collecting up to 128 unique candidates
    /// with their L2 distances to `query`.
    ///
    /// Dynamic shared memory layout: dim floats (query copy), then 128 ints
    /// (candidate ids), then 128 floats (candidate distances).
    ///
    /// The traversal itself is sequential on thread 0 — it mutates a shared
    /// queue — with NO barrier inside the divergent region (a `__syncthreads`
    /// there would be a divergent barrier: only thread 0 reaches it, which is
    /// undefined behavior). All threads cooperate on the query load and the
    /// final copy-out.
    fn get_hnsw_search_kernel(&self) -> String {
        r#"
        extern "C" __global__ void hnsw_search_kernel(
            const float* __restrict__ query,
            const float* __restrict__ vectors,
            const int* __restrict__ adjacency_list,
            const int* __restrict__ adjacency_offsets,
            int* __restrict__ candidate_queue,
            float* __restrict__ candidate_distances,
            int* __restrict__ queue_size,
            const int dim,
            const int entry_point
        ) {
            const int tid = threadIdx.x;

            extern __shared__ float shared_data[];
            float* shared_query = shared_data;
            int* shared_queue = (int*)(shared_data + dim);
            float* shared_queue_dist = (float*)(shared_queue + 128);

            // Queue length lives in shared memory so every thread sees the
            // final value (a per-thread local would only be updated on
            // thread 0 and stay stale everywhere else).
            __shared__ int shared_tail;

            // Strided load handles dim > blockDim.x (a plain `if (tid < dim)`
            // would silently truncate the query).
            for (int i = tid; i < dim; i += blockDim.x) {
                shared_query[i] = query[i];
            }
            __syncthreads();

            if (tid == 0) {
                // Seed the queue with the entry point and its true distance.
                float entry_dist = 0.0f;
                const float* entry_vector = vectors + entry_point * dim;
                for (int d = 0; d < dim; d++) {
                    float diff = shared_query[d] - entry_vector[d];
                    entry_dist += diff * diff;
                }
                shared_queue[0] = entry_point;
                shared_queue_dist[0] = sqrtf(entry_dist);

                int head = 0;
                int tail = 1;

                while (head < tail && tail < 128) {
                    int current_node = shared_queue[head++];

                    int neighbor_start = adjacency_offsets[current_node];
                    int neighbor_end = adjacency_offsets[current_node + 1];

                    for (int i = neighbor_start; i < neighbor_end && tail < 128; i++) {
                        int neighbor = adjacency_list[i];

                        // Skip nodes already enqueued: entries are never
                        // removed from shared_queue, so scanning [0, tail)
                        // doubles as the visited set.
                        bool seen = false;
                        for (int k = 0; k < tail; k++) {
                            if (shared_queue[k] == neighbor) { seen = true; break; }
                        }
                        if (seen) continue;

                        const float* neighbor_vector = vectors + neighbor * dim;
                        float neighbor_dist = 0.0f;
                        for (int d = 0; d < dim; d++) {
                            float diff = shared_query[d] - neighbor_vector[d];
                            neighbor_dist += diff * diff;
                        }

                        shared_queue[tail] = neighbor;
                        shared_queue_dist[tail] = sqrtf(neighbor_dist);
                        tail++;
                    }
                }

                shared_tail = tail;
            }
            __syncthreads();

            // Cooperative copy-out of the collected candidates.
            for (int i = tid; i < shared_tail; i += blockDim.x) {
                candidate_queue[i] = shared_queue[i];
                candidate_distances[i] = shared_queue_dist[i];
            }

            if (tid == 0) {
                *queue_size = shared_tail;
            }
        }
        "#
        .to_string()
    }

    /// All-pairs distance between two batches of vectors; one thread per
    /// (i, j) pair. `metric_type` selects the metric: 0 = Euclidean (L2),
    /// 1 = cosine distance (1 - cosine similarity); any other value yields 0.
    fn get_batch_distance_kernel(&self) -> String {
        r#"
        extern "C" __global__ void batch_distance_kernel(
            const float* __restrict__ batch_a,
            const float* __restrict__ batch_b,
            float* __restrict__ distances,
            const int batch_size_a,
            const int batch_size_b,
            const int dim,
            const int metric_type
        ) {
            const int tid = blockIdx.x * blockDim.x + threadIdx.x;
            const int i = tid / batch_size_b;
            const int j = tid % batch_size_b;

            if (i >= batch_size_a || j >= batch_size_b) return;

            const float* vec_a = batch_a + i * dim;
            const float* vec_b = batch_b + j * dim;

            float distance = 0.0f;

            if (metric_type == 0) { // Euclidean
                for (int d = 0; d < dim; d++) {
                    float diff = vec_a[d] - vec_b[d];
                    distance += diff * diff;
                }
                distance = sqrtf(distance);
            } else if (metric_type == 1) { // Cosine distance
                float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
                for (int d = 0; d < dim; d++) {
                    dot += vec_a[d] * vec_b[d];
                    norm_a += vec_a[d] * vec_a[d];
                    norm_b += vec_b[d] * vec_b[d];
                }
                // Degenerate vectors are treated as maximally distant (1.0).
                float norm_product = sqrtf(norm_a) * sqrtf(norm_b);
                distance = (norm_product > 1e-8f) ? 1.0f - (dot / norm_product) : 1.0f;
            }
            // Unknown metric_type falls through with distance = 0.0f.

            distances[i * batch_size_b + j] = distance;
        }
        "#
        .to_string()
    }
}
353
354impl Default for KernelManager {
355 fn default() -> Self {
356 Self::new()
357 }
358}