//! GPU-accelerated vector search engine (`rustyhdf5_gpu/accelerator.rs`).
1use crate::device::DeviceInfo;
2use crate::error::{GpuError, Result};
3use crate::helpers::{bgl_entry, div_ceil, top_k_cpu, Params4, Params4U, WORKGROUP_SIZE};
4use crate::shaders;
5
6use bytemuck::Pod;
7use wgpu::util::DeviceExt;
8
/// GPU-accelerated vector search engine.
///
/// Upload vectors once, then run many searches against them.
/// If GPU initialization fails, callers should fall back to CPU SIMD.
///
/// Vectors are automatically split into chunks when they exceed the device's
/// `max_storage_buffer_binding_size` (typically 128 MB). Searches dispatch
/// against each chunk and merge results transparently.
pub struct GpuAccelerator {
    device: wgpu::Device,
    queue: wgpu::Queue,
    /// Adapter metadata captured at initialization (name, backend, limits).
    info: DeviceInfo,
    /// Device `max_storage_buffer_binding_size` in bytes; drives chunk sizing.
    max_binding_size: u32,
    /// One storage buffer per vector chunk.
    vectors_bufs: Vec<wgpu::Buffer>,
    /// Pre-computed L2 norms, one buffer per chunk (parallel to `vectors_bufs`).
    norms_bufs: Vec<wgpu::Buffer>,
    /// Number of vectors in each chunk (parallel to `vectors_bufs`).
    chunk_counts: Vec<usize>,
    /// Dimensionality of the uploaded vectors (0 until `upload_vectors`).
    dim: usize,
    /// Total number of uploaded vectors across all chunks.
    n_vectors: usize,
}
28
29impl GpuAccelerator {
30    /// Check if any GPU is available on this system.
31    pub fn is_available() -> bool {
32        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
33            backends: wgpu::Backends::all(),
34            ..Default::default()
35        });
36        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
37            power_preference: wgpu::PowerPreference::HighPerformance,
38            compatible_surface: None,
39            force_fallback_adapter: false,
40        }));
41        adapter.is_ok()
42    }
43
    /// Initialize the best available GPU device.
    ///
    /// Requests the adapter's maximum buffer limits so that chunking only
    /// kicks in when truly necessary.
    ///
    /// # Errors
    /// Returns `GpuError::NoDevice` when no adapter is found, and
    /// `GpuError::DeviceRequest` when the logical device cannot be created.
    pub fn new() -> Result<Self> {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        // Prefer a high-performance adapter; no surface is needed for compute.
        let adapter =
            pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            }))
            .map_err(|_| GpuError::NoDevice)?;

        let adapter_info = adapter.get_info();
        let adapter_limits = adapter.limits();

        // Snapshot adapter metadata for later diagnostics via `device_info()`.
        let info = DeviceInfo {
            name: adapter_info.name.clone(),
            backend: format!("{:?}", adapter_info.backend),
            device_type: format!("{:?}", adapter_info.device_type),
            max_buffer_size: adapter_limits.max_buffer_size,
            max_storage_buffer_binding_size: adapter_limits.max_storage_buffer_binding_size,
        };

        // Request the adapter's own buffer limits (rather than wgpu defaults)
        // so large datasets need fewer chunks.
        let (device, queue) = pollster::block_on(adapter.request_device(
            &wgpu::DeviceDescriptor {
                label: Some("rustyhdf5-gpu"),
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits {
                    max_storage_buffer_binding_size: adapter_limits
                        .max_storage_buffer_binding_size,
                    max_buffer_size: adapter_limits.max_buffer_size,
                    ..wgpu::Limits::default()
                },
                ..Default::default()
            },
        ))
        .map_err(|e: wgpu::RequestDeviceError| GpuError::DeviceRequest(e.to_string()))?;

        Ok(Self {
            device,
            queue,
            info,
            max_binding_size: adapter_limits.max_storage_buffer_binding_size,
            vectors_bufs: Vec::new(),
            norms_bufs: Vec::new(),
            chunk_counts: Vec::new(),
            dim: 0,
            n_vectors: 0,
        })
    }
100
    /// Metadata about the adapter selected at initialization.
    pub fn device_info(&self) -> &DeviceInfo {
        &self.info
    }
104
    /// The device's maximum storage-buffer binding size in bytes,
    /// as captured from the adapter limits at initialization.
    pub fn max_storage_buffer_binding_size(&self) -> u32 {
        self.max_binding_size
    }
108
109    /// Upload a flat array of vectors to GPU memory.
110    /// Automatically splits into chunks when data exceeds the binding limit.
111    pub fn upload_vectors(&mut self, vectors: &[f32], dim: usize) -> Result<()> {
112        if vectors.is_empty() || dim == 0 {
113            return Err(GpuError::DimensionMismatch {
114                expected: 1,
115                got: 0,
116            });
117        }
118        let n = vectors.len() / dim;
119        if vectors.len() != n * dim {
120            return Err(GpuError::DimensionMismatch {
121                expected: n * dim,
122                got: vectors.len(),
123            });
124        }
125
126        let max_vecs_per_chunk = self.max_binding_size as usize / (dim * 4);
127        if max_vecs_per_chunk == 0 {
128            return Err(GpuError::OutOfMemory {
129                need_mb: (dim as u64 * 4) / (1024 * 1024),
130                avail_mb: self.max_binding_size as u64 / (1024 * 1024),
131            });
132        }
133
134        let mut bufs = Vec::new();
135        let mut counts = Vec::new();
136        let mut offset = 0;
137        while offset < n {
138            let chunk_n = (n - offset).min(max_vecs_per_chunk);
139            let start = offset * dim;
140            let end = start + chunk_n * dim;
141            bufs.push(self.make_storage_buf("vectors_chunk", &vectors[start..end]));
142            counts.push(chunk_n);
143            offset += chunk_n;
144        }
145
146        self.vectors_bufs = bufs;
147        self.chunk_counts = counts;
148        self.norms_bufs.clear();
149        self.dim = dim;
150        self.n_vectors = n;
151        Ok(())
152    }
153
154    /// Upload pre-computed L2 norms, split to match vector chunk layout.
155    pub fn upload_norms(&mut self, norms: &[f32]) -> Result<()> {
156        if norms.len() != self.n_vectors {
157            return Err(GpuError::DimensionMismatch {
158                expected: self.n_vectors,
159                got: norms.len(),
160            });
161        }
162        let mut bufs = Vec::new();
163        let mut offset = 0;
164        for &chunk_n in &self.chunk_counts {
165            bufs.push(self.make_storage_buf("norms_chunk", &norms[offset..offset + chunk_n]));
166            offset += chunk_n;
167        }
168        self.norms_bufs = bufs;
169        Ok(())
170    }
171
172    /// Cosine similarity search: returns top-k (index, score) pairs, highest first.
173    /// Dispatches against each vector chunk and merges results.
174    pub fn cosine_search(&self, query: &[f32], k: usize) -> Result<Vec<(usize, f32)>> {
175        self.check_ready(query.len())?;
176        if k > self.n_vectors {
177            return Err(GpuError::KExceedsN {
178                k,
179                n: self.n_vectors,
180            });
181        }
182        let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
183        let dim = self.dim as u32;
184        let mut all_results: Vec<(usize, f32)> = Vec::new();
185        let mut offset = 0usize;
186
187        for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
188            let chunk_n = self.chunk_counts[ci];
189            let params = Params4 {
190                a: dim,
191                b: chunk_n as u32,
192                c: query_norm,
193                d: 0,
194            };
195            let scores = self.run_4bind_shader(
196                shaders::COSINE_SIMILARITY,
197                &params,
198                query,
199                vecs_buf,
200                Some(&self.norms_bufs[ci]),
201                chunk_n,
202            )?;
203            let chunk_topk = top_k_cpu(&scores, k.min(chunk_n), true);
204            for (idx, score) in chunk_topk {
205                all_results.push((idx + offset, score));
206            }
207            offset += chunk_n;
208        }
209
210        all_results
211            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212        all_results.truncate(k);
213        Ok(all_results)
214    }
215
216    /// Batch cosine search: multiple queries at once.
217    pub fn batch_cosine_search(
218        &self,
219        queries: &[Vec<f32>],
220        k: usize,
221    ) -> Result<Vec<Vec<(usize, f32)>>> {
222        queries.iter().map(|q| self.cosine_search(q, k)).collect()
223    }
224
225    /// L2 distance search: returns top-k (index, distance) pairs, smallest first.
226    pub fn l2_search(&self, query: &[f32], k: usize) -> Result<Vec<(usize, f32)>> {
227        self.check_vectors(query.len())?;
228        if k > self.n_vectors {
229            return Err(GpuError::KExceedsN {
230                k,
231                n: self.n_vectors,
232            });
233        }
234        let dim = self.dim as u32;
235        let mut all_results: Vec<(usize, f32)> = Vec::new();
236        let mut offset = 0usize;
237
238        for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
239            let chunk_n = self.chunk_counts[ci];
240            let params = Params4U {
241                a: dim,
242                b: chunk_n as u32,
243                c: 0,
244                d: 0,
245            };
246            let scores =
247                self.run_3bind_shader(shaders::L2_DISTANCE, &params, query, vecs_buf, chunk_n)?;
248            let chunk_topk = top_k_cpu(&scores, k.min(chunk_n), false);
249            for (idx, dist) in chunk_topk {
250                all_results.push((idx + offset, dist));
251            }
252            offset += chunk_n;
253        }
254
255        all_results
256            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
257        all_results.truncate(k);
258        Ok(all_results)
259    }
260
261    /// Compute L2 norms for all uploaded vectors on GPU.
262    pub fn compute_norms(&self) -> Result<Vec<f32>> {
263        if self.vectors_bufs.is_empty() {
264            return Err(GpuError::NoVectors);
265        }
266        let mut all_norms = Vec::with_capacity(self.n_vectors);
267        for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
268            let norms = self.run_norms_shader(vecs_buf, self.chunk_counts[ci], self.dim)?;
269            all_norms.extend_from_slice(&norms);
270        }
271        Ok(all_norms)
272    }
273
274    /// Compute L2 norms from raw vectors (not previously uploaded).
275    /// Handles chunking automatically for large inputs.
276    pub fn compute_norms_gpu(&self, vectors: &[f32], dim: usize) -> Result<Vec<f32>> {
277        if vectors.is_empty() || dim == 0 {
278            return Err(GpuError::DimensionMismatch {
279                expected: 1,
280                got: 0,
281            });
282        }
283        let n = vectors.len() / dim;
284        if vectors.len() != n * dim {
285            return Err(GpuError::DimensionMismatch {
286                expected: n * dim,
287                got: vectors.len(),
288            });
289        }
290        let max_vecs = self.max_binding_size as usize / (dim * 4);
291        let mut all_norms = Vec::with_capacity(n);
292        let mut offset = 0;
293        while offset < n {
294            let chunk_n = (n - offset).min(max_vecs);
295            let start = offset * dim;
296            let end = start + chunk_n * dim;
297            let vecs_buf = self.make_storage_buf("temp_vectors", &vectors[start..end]);
298            let norms = self.run_norms_shader(&vecs_buf, chunk_n, dim)?;
299            all_norms.extend_from_slice(&norms);
300            offset += chunk_n;
301        }
302        Ok(all_norms)
303    }
304
305    /// Batch dot product: queries [Q×D] × vectors [N×D] -> flat [Q×N] scores.
306    pub fn batch_dot_product(
307        &self,
308        queries_flat: &[f32],
309        num_queries: usize,
310    ) -> Result<Vec<f32>> {
311        if self.vectors_bufs.is_empty() {
312            return Err(GpuError::NoVectors);
313        }
314        if queries_flat.len() != num_queries * self.dim {
315            return Err(GpuError::DimensionMismatch {
316                expected: num_queries * self.dim,
317                got: queries_flat.len(),
318            });
319        }
320        let queries_buf = self.make_storage_buf("queries", queries_flat);
321        let mut output = vec![0.0f32; num_queries * self.n_vectors];
322        let mut col_offset = 0usize;
323
324        for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
325            let chunk_n = self.chunk_counts[ci];
326            let total = (num_queries * chunk_n) as u32;
327            let params = Params4U {
328                a: self.dim as u32,
329                b: chunk_n as u32,
330                c: num_queries as u32,
331                d: 0,
332            };
333            let chunk_scores = self.run_batch_shader(
334                shaders::BATCH_DOT_PRODUCT,
335                &params,
336                &queries_buf,
337                vecs_buf,
338                total,
339                num_queries * chunk_n,
340            )?;
341            for qi in 0..num_queries {
342                let src = qi * chunk_n;
343                let dst = qi * self.n_vectors + col_offset;
344                output[dst..dst + chunk_n].copy_from_slice(&chunk_scores[src..src + chunk_n]);
345            }
346            col_offset += chunk_n;
347        }
348        Ok(output)
349    }
350
351    /// Compute L2 distance matrix: queries × vectors -> Q×N distances.
352    /// Uses 16×16 workgroup tiling for cache efficiency.
353    pub fn distance_matrix(
354        &self,
355        queries: &[f32],
356        vectors: &[f32],
357        dim: usize,
358    ) -> Result<Vec<Vec<f32>>> {
359        if queries.is_empty() || vectors.is_empty() || dim == 0 {
360            return Err(GpuError::DimensionMismatch {
361                expected: 1,
362                got: 0,
363            });
364        }
365        let num_queries = queries.len() / dim;
366        let n = vectors.len() / dim;
367        if queries.len() != num_queries * dim || vectors.len() != n * dim {
368            return Err(GpuError::DimensionMismatch {
369                expected: num_queries * dim,
370                got: queries.len(),
371            });
372        }
373        let queries_buf = self.make_storage_buf("dm_queries", queries);
374
375        let max_vecs_input = self.max_binding_size as usize / (dim * 4);
376        let max_vecs_output = if num_queries > 0 {
377            self.max_binding_size as usize / (num_queries * 4)
378        } else {
379            max_vecs_input
380        };
381        let max_vecs = max_vecs_input.min(max_vecs_output).max(1);
382
383        let mut flat_output = vec![0.0f32; num_queries * n];
384        let mut col_offset = 0usize;
385        let mut vec_offset = 0usize;
386
387        while vec_offset < n {
388            let chunk_n = (n - vec_offset).min(max_vecs);
389            let start = vec_offset * dim;
390            let end = start + chunk_n * dim;
391            let vecs_buf = self.make_storage_buf("dm_vectors", &vectors[start..end]);
392            let params = Params4U {
393                a: dim as u32,
394                b: chunk_n as u32,
395                c: num_queries as u32,
396                d: 0,
397            };
398            let chunk_dists = self.run_distance_matrix_shader(
399                &params,
400                &queries_buf,
401                &vecs_buf,
402                num_queries,
403                chunk_n,
404            )?;
405            for qi in 0..num_queries {
406                let src = qi * chunk_n;
407                let dst = qi * n + col_offset;
408                flat_output[dst..dst + chunk_n]
409                    .copy_from_slice(&chunk_dists[src..src + chunk_n]);
410            }
411            col_offset += chunk_n;
412            vec_offset += chunk_n;
413        }
414
415        Ok((0..num_queries)
416            .map(|qi| flat_output[qi * n..(qi + 1) * n].to_vec())
417            .collect())
418    }
419
420    /// Convert f16 values (as raw u16 bits) to f32 on the GPU.
421    pub fn f16_to_f32_batch(&self, f16_bits: &[u16]) -> Result<Vec<f32>> {
422        let total = f16_bits.len() as u32;
423        let params = Params4U { a: total, b: 0, c: 0, d: 0 };
424        let packed: Vec<u32> = f16_bits
425            .chunks(2)
426            .map(|c| {
427                let lo = c[0] as u32;
428                let hi = if c.len() > 1 { c[1] as u32 } else { 0 };
429                lo | (hi << 16)
430            })
431            .collect();
432
433        let pair_count = f16_bits.len().div_ceil(2);
434        let (_params_buf, _input_buf, output_buf, bgl, bind_group) = self.make_3bind_group(
435            &params,
436            bytemuck::cast_slice(&packed),
437            (f16_bits.len() * 4) as u64,
438        );
439        let module = self.make_module("f16_to_f32", shaders::F16_TO_F32);
440        let pipeline = self.create_pipeline(&module, &bgl);
441        self.dispatch(&pipeline, &bind_group, div_ceil(pair_count as u32, WORKGROUP_SIZE));
442        self.read_buffer::<f32>(&output_buf, f16_bits.len())
443    }
444
445    /// Convert f32 values to f16 (as raw u16 bits) on the GPU.
446    pub fn f32_to_f16_batch(&self, values: &[f32]) -> Result<Vec<u16>> {
447        let total = values.len() as u32;
448        let params = Params4U { a: total, b: 0, c: 0, d: 0 };
449        let pair_count = values.len().div_ceil(2);
450
451        let (_params_buf, _input_buf, output_buf, bgl, bind_group) = self.make_3bind_group(
452            &params,
453            bytemuck::cast_slice(values),
454            (pair_count * 4) as u64,
455        );
456        let module = self.make_module("f32_to_f16", shaders::F32_TO_F16);
457        let pipeline = self.create_pipeline(&module, &bgl);
458        self.dispatch(&pipeline, &bind_group, div_ceil(pair_count as u32, WORKGROUP_SIZE));
459
460        let packed = self.read_buffer::<u32>(&output_buf, pair_count)?;
461        let mut result = Vec::with_capacity(values.len());
462        for (i, &word) in packed.iter().enumerate() {
463            result.push((word & 0xFFFF) as u16);
464            if i * 2 + 1 < values.len() {
465                result.push((word >> 16) as u16);
466            }
467        }
468        Ok(result)
469    }
470
    /// Total number of vectors uploaded across all chunks.
    pub fn vector_count(&self) -> usize {
        self.n_vectors
    }
474
    /// Dimensionality of the uploaded vectors (0 before any upload).
    pub fn dimension(&self) -> usize {
        self.dim
    }
478
    /// Number of storage-buffer chunks the vectors were split into.
    pub fn chunk_count(&self) -> usize {
        self.chunk_counts.len()
    }
482
483    // ── Internal helpers ──────────────────────────────────────────
484
485    fn check_ready(&self, query_dim: usize) -> Result<()> {
486        self.check_vectors(query_dim)?;
487        if self.norms_bufs.is_empty() {
488            return Err(GpuError::NoNorms);
489        }
490        Ok(())
491    }
492
493    fn check_vectors(&self, query_dim: usize) -> Result<()> {
494        if self.vectors_bufs.is_empty() {
495            return Err(GpuError::NoVectors);
496        }
497        if query_dim != self.dim {
498            return Err(GpuError::DimensionMismatch {
499                expected: self.dim,
500                got: query_dim,
501            });
502        }
503        Ok(())
504    }
505
506    fn make_storage_buf(&self, label: &str, data: &[f32]) -> wgpu::Buffer {
507        self.device
508            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
509                label: Some(label),
510                contents: bytemuck::cast_slice(data),
511                usage: wgpu::BufferUsages::STORAGE,
512            })
513    }
514
515    fn make_module(&self, label: &str, src: &str) -> wgpu::ShaderModule {
516        self.device
517            .create_shader_module(wgpu::ShaderModuleDescriptor {
518                label: Some(label),
519                source: wgpu::ShaderSource::Wgsl(src.into()),
520            })
521    }
522
    /// Create a 3-binding group: uniform params, storage input, storage RW output.
    ///
    /// Returns all created buffers alongside the layout and bind group so the
    /// caller can keep them alive for the dispatch and read back the output.
    #[allow(clippy::type_complexity)]
    fn make_3bind_group(
        &self,
        params: &Params4U,
        input_data: &[u8],
        output_size: u64,
    ) -> (
        wgpu::Buffer,
        wgpu::Buffer,
        wgpu::Buffer,
        wgpu::BindGroupLayout,
        wgpu::BindGroup,
    ) {
        // Binding 0: shader uniforms.
        let params_buf = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("params"),
                contents: bytemuck::bytes_of(params),
                usage: wgpu::BufferUsages::UNIFORM,
            });
        // Binding 1: read-only input payload.
        let input_buf = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("input"),
                contents: input_data,
                usage: wgpu::BufferUsages::STORAGE,
            });
        // Binding 2: GPU-writable output; COPY_SRC enables host readback.
        let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("output"),
            size: output_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        let bgl = self
            .device
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: None,
                entries: &[
                    bgl_entry(0, wgpu::BufferBindingType::Uniform),
                    bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
                    bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
                ],
            });
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: params_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: input_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: output_buf.as_entire_binding(),
                },
            ],
        });
        (params_buf, input_buf, output_buf, bgl, bind_group)
    }
587
    /// Dispatch the BATCH_NORMS shader over `n` vectors of `dim` f32s held in
    /// `vecs_buf`, reading back one norm per vector.
    fn run_norms_shader(
        &self,
        vecs_buf: &wgpu::Buffer,
        n: usize,
        dim: usize,
    ) -> Result<Vec<f32>> {
        // Shader uniforms: a = dimension, b = vector count; c/d unused.
        let params = Params4U {
            a: dim as u32,
            b: n as u32,
            c: 0,
            d: 0,
        };
        let params_buf = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("params"),
                contents: bytemuck::bytes_of(&params),
                usage: wgpu::BufferUsages::UNIFORM,
            });
        // One f32 result per vector; COPY_SRC enables host readback.
        let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("norms_out"),
            size: (n * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        let module = self.make_module("batch_norms", shaders::BATCH_NORMS);
        // Layout: 0 = uniforms, 1 = vectors (RO), 2 = norms output (RW).
        let bgl = self
            .device
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: None,
                entries: &[
                    bgl_entry(0, wgpu::BufferBindingType::Uniform),
                    bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
                    bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
                ],
            });
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: params_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: vecs_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: output_buf.as_entire_binding(),
                },
            ],
        });
        let pipeline = self.create_pipeline(&module, &bgl);
        // One thread per vector.
        self.dispatch(&pipeline, &bind_group, div_ceil(n as u32, WORKGROUP_SIZE));
        self.read_buffer::<f32>(&output_buf, n)
    }
646
    /// Generic driver for 4-binding batch shaders (uniforms, queries, vectors,
    /// RW output): builds the pipeline, dispatches `total_threads` threads,
    /// and reads back `output_len` f32 results.
    fn run_batch_shader(
        &self,
        shader_src: &str,
        params: &Params4U,
        queries_buf: &wgpu::Buffer,
        vecs_buf: &wgpu::Buffer,
        total_threads: u32,
        output_len: usize,
    ) -> Result<Vec<f32>> {
        let params_buf = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("params"),
                contents: bytemuck::bytes_of(params),
                usage: wgpu::BufferUsages::UNIFORM,
            });
        // COPY_SRC enables host readback of the scores.
        let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("scores"),
            size: (output_len * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        let module = self
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: None,
                source: wgpu::ShaderSource::Wgsl(shader_src.into()),
            });
        // Layout: 0 = uniforms, 1 = queries (RO), 2 = vectors (RO),
        // 3 = output (RW).
        let bgl = self
            .device
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: None,
                entries: &[
                    bgl_entry(0, wgpu::BufferBindingType::Uniform),
                    bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
                    bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
                    bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }),
                ],
            });
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: params_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: queries_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: vecs_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: output_buf.as_entire_binding(),
                },
            ],
        });
        let pipeline = self.create_pipeline(&module, &bgl);
        self.dispatch(&pipeline, &bind_group, div_ceil(total_threads, WORKGROUP_SIZE));
        self.read_buffer::<f32>(&output_buf, output_len)
    }
712
    /// Dispatch the DISTANCE_MATRIX shader over a 2D grid of 16×16 tiles
    /// (x over `chunk_n` vectors, y over `num_queries` queries) and read back
    /// the flat `num_queries * chunk_n` distance block.
    fn run_distance_matrix_shader(
        &self,
        params: &Params4U,
        queries_buf: &wgpu::Buffer,
        vecs_buf: &wgpu::Buffer,
        num_queries: usize,
        chunk_n: usize,
    ) -> Result<Vec<f32>> {
        let params_buf = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("params"),
                contents: bytemuck::bytes_of(params),
                usage: wgpu::BufferUsages::UNIFORM,
            });
        // One f32 per (query, vector) pair; COPY_SRC enables host readback.
        let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("distances"),
            size: (num_queries * chunk_n * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        let module = self.make_module("distance_matrix", shaders::DISTANCE_MATRIX);
        // Layout: 0 = uniforms, 1 = queries (RO), 2 = vectors (RO),
        // 3 = output (RW).
        let bgl = self
            .device
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: None,
                entries: &[
                    bgl_entry(0, wgpu::BufferBindingType::Uniform),
                    bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
                    bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
                    bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }),
                ],
            });
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: params_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: queries_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: vecs_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: output_buf.as_entire_binding(),
                },
            ],
        });
        let pipeline = self.create_pipeline(&module, &bgl);
        // 16×16 threads per workgroup: x covers vectors, y covers queries.
        let wg_x = div_ceil(chunk_n as u32, 16);
        let wg_y = div_ceil(num_queries as u32, 16);
        self.dispatch_2d(&pipeline, &bind_group, wg_x, wg_y);
        self.read_buffer::<f32>(&output_buf, num_queries * chunk_n)
    }
774
775    fn run_4bind_shader(
776        &self,
777        shader_src: &str,
778        params: &Params4,
779        query: &[f32],
780        vectors_buf: &wgpu::Buffer,
781        extra_buf: Option<&wgpu::Buffer>,
782        output_len: usize,
783    ) -> Result<Vec<f32>> {
784        let params_buf = self
785            .device
786            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
787                label: Some("params"),
788                contents: bytemuck::bytes_of(params),
789                usage: wgpu::BufferUsages::UNIFORM,
790            });
791        let query_buf = self.make_storage_buf("query", query);
792        let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
793            label: Some("scores"),
794            size: (output_len * 4) as u64,
795            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
796            mapped_at_creation: false,
797        });
798        let module = self
799            .device
800            .create_shader_module(wgpu::ShaderModuleDescriptor {
801                label: None,
802                source: wgpu::ShaderSource::Wgsl(shader_src.into()),
803            });
804
805        let mut entries_desc = vec![
806            bgl_entry(0, wgpu::BufferBindingType::Uniform),
807            bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
808            bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
809        ];
810        let mut bind_entries = vec![
811            wgpu::BindGroupEntry {
812                binding: 0,
813                resource: params_buf.as_entire_binding(),
814            },
815            wgpu::BindGroupEntry {
816                binding: 1,
817                resource: query_buf.as_entire_binding(),
818            },
819            wgpu::BindGroupEntry {
820                binding: 2,
821                resource: vectors_buf.as_entire_binding(),
822            },
823        ];
824
825        if let Some(eb) = extra_buf {
826            entries_desc.push(bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: true }));
827            entries_desc.push(bgl_entry(4, wgpu::BufferBindingType::Storage { read_only: false }));
828            bind_entries.push(wgpu::BindGroupEntry {
829                binding: 3,
830                resource: eb.as_entire_binding(),
831            });
832            bind_entries.push(wgpu::BindGroupEntry {
833                binding: 4,
834                resource: output_buf.as_entire_binding(),
835            });
836        } else {
837            entries_desc.push(bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }));
838            bind_entries.push(wgpu::BindGroupEntry {
839                binding: 3,
840                resource: output_buf.as_entire_binding(),
841            });
842        }
843
844        let bgl = self
845            .device
846            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
847                label: None,
848                entries: &entries_desc,
849            });
850        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
851            label: None,
852            layout: &bgl,
853            entries: &bind_entries,
854        });
855
856        let pipeline = self.create_pipeline(&module, &bgl);
857        self.dispatch(
858            &pipeline,
859            &bind_group,
860            div_ceil(output_len as u32, WORKGROUP_SIZE),
861        );
862        self.read_buffer::<f32>(&output_buf, output_len)
863    }
864
865    fn run_3bind_shader(
866        &self,
867        shader_src: &str,
868        params: &Params4U,
869        query: &[f32],
870        vectors_buf: &wgpu::Buffer,
871        output_len: usize,
872    ) -> Result<Vec<f32>> {
873        let params_buf = self
874            .device
875            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
876                label: Some("params"),
877                contents: bytemuck::bytes_of(params),
878                usage: wgpu::BufferUsages::UNIFORM,
879            });
880        let query_buf = self.make_storage_buf("query", query);
881        let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
882            label: Some("scores"),
883            size: (output_len * 4) as u64,
884            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
885            mapped_at_creation: false,
886        });
887        let module = self
888            .device
889            .create_shader_module(wgpu::ShaderModuleDescriptor {
890                label: None,
891                source: wgpu::ShaderSource::Wgsl(shader_src.into()),
892            });
893        let bgl = self
894            .device
895            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
896                label: None,
897                entries: &[
898                    bgl_entry(0, wgpu::BufferBindingType::Uniform),
899                    bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
900                    bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
901                    bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }),
902                ],
903            });
904        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
905            label: None,
906            layout: &bgl,
907            entries: &[
908                wgpu::BindGroupEntry {
909                    binding: 0,
910                    resource: params_buf.as_entire_binding(),
911                },
912                wgpu::BindGroupEntry {
913                    binding: 1,
914                    resource: query_buf.as_entire_binding(),
915                },
916                wgpu::BindGroupEntry {
917                    binding: 2,
918                    resource: vectors_buf.as_entire_binding(),
919                },
920                wgpu::BindGroupEntry {
921                    binding: 3,
922                    resource: output_buf.as_entire_binding(),
923                },
924            ],
925        });
926        let pipeline = self.create_pipeline(&module, &bgl);
927        self.dispatch(
928            &pipeline,
929            &bind_group,
930            div_ceil(output_len as u32, WORKGROUP_SIZE),
931        );
932        self.read_buffer::<f32>(&output_buf, output_len)
933    }
934
935    fn create_pipeline(
936        &self,
937        module: &wgpu::ShaderModule,
938        bgl: &wgpu::BindGroupLayout,
939    ) -> wgpu::ComputePipeline {
940        let layout = self
941            .device
942            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
943                label: None,
944                bind_group_layouts: &[bgl],
945                immediate_size: 0,
946            });
947        self.device
948            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
949                label: None,
950                layout: Some(&layout),
951                module,
952                entry_point: Some("main"),
953                compilation_options: Default::default(),
954                cache: None,
955            })
956    }
957
958    fn dispatch(
959        &self,
960        pipeline: &wgpu::ComputePipeline,
961        bind_group: &wgpu::BindGroup,
962        workgroups: u32,
963    ) {
964        let mut encoder = self
965            .device
966            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
967        {
968            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
969                label: None,
970                timestamp_writes: None,
971            });
972            pass.set_pipeline(pipeline);
973            pass.set_bind_group(0, bind_group, &[]);
974            pass.dispatch_workgroups(workgroups, 1, 1);
975        }
976        self.queue.submit(std::iter::once(encoder.finish()));
977    }
978
979    fn dispatch_2d(
980        &self,
981        pipeline: &wgpu::ComputePipeline,
982        bind_group: &wgpu::BindGroup,
983        wg_x: u32,
984        wg_y: u32,
985    ) {
986        let mut encoder = self
987            .device
988            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
989        {
990            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
991                label: None,
992                timestamp_writes: None,
993            });
994            pass.set_pipeline(pipeline);
995            pass.set_bind_group(0, bind_group, &[]);
996            pass.dispatch_workgroups(wg_x, wg_y, 1);
997        }
998        self.queue.submit(std::iter::once(encoder.finish()));
999    }
1000
1001    fn read_buffer<T: Pod>(&self, buffer: &wgpu::Buffer, count: usize) -> Result<Vec<T>> {
1002        let elem_size = std::mem::size_of::<T>();
1003        let byte_len = (count * elem_size) as u64;
1004        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
1005            label: Some("staging"),
1006            size: byte_len,
1007            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1008            mapped_at_creation: false,
1009        });
1010        let mut encoder = self
1011            .device
1012            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
1013        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, byte_len);
1014        self.queue.submit(std::iter::once(encoder.finish()));
1015
1016        let slice = staging.slice(..);
1017        let (tx, rx) = std::sync::mpsc::channel();
1018        slice.map_async(wgpu::MapMode::Read, move |result| {
1019            let _ = tx.send(result);
1020        });
1021        let _ = self.device.poll(wgpu::PollType::Wait {
1022            submission_index: None,
1023            timeout: None,
1024        });
1025        rx.recv()
1026            .map_err(|e| GpuError::BufferMap(e.to_string()))?
1027            .map_err(|e| GpuError::BufferMap(e.to_string()))?;
1028
1029        let data = slice.get_mapped_range();
1030        let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
1031        drop(data);
1032        staging.unmap();
1033        Ok(result)
1034    }
1035}