Skip to main content

adele_ring/
gpu.rs

1//! `GpuBackend` — wgpu compute-shader implementation of [`ArithmeticBackend`].
2//!
3//! wgpu is always compiled in (not feature-gated). At startup [`GpuBackend::try_init`]
4//! probes for a compatible adapter; if none is found it returns `Err` and the
5//! [`crate::backend::Executor`] transparently falls back to the CPU backend.
6//!
7//! Each shader thread handles one `(batch_item × channel)` pair. The buffer
8//! layout is identical to [`crate::batch::RnsBatch`], so there is no reformatting
9//! on upload beyond the `u64 -> u32` narrowing (safe because all moduli `< 2^16`).
10
11use bytemuck::{Pod, Zeroable};
12use num_bigint::BigUint;
13use wgpu::util::DeviceExt;
14
15use crate::backend::ArithmeticBackend;
16use crate::batch::RnsBatch;
17use crate::rns::garner_crt;
18
19/// Errors that can occur while bringing up the GPU backend.
20#[derive(Debug, thiserror::Error)]
21pub enum GpuError {
22    #[error("no compatible GPU adapter found")]
23    NoAdapter,
24    #[error("failed to acquire GPU device: {0}")]
25    Device(#[from] wgpu::RequestDeviceError),
26}
27
28#[repr(C)]
29#[derive(Clone, Copy, Pod, Zeroable)]
30struct Params {
31    batch_size: u32,
32    n_channels: u32,
33    _pad: [u32; 2], // pad to 16 bytes for std140 uniform layout
34}
35
36/// GPU backend holding a device, queue, and the pre-built compute pipelines.
37pub struct GpuBackend {
38    device: wgpu::Device,
39    queue: wgpu::Queue,
40    bind_group_layout: wgpu::BindGroupLayout,
41    add_pipeline: wgpu::ComputePipeline,
42    mul_pipeline: wgpu::ComputePipeline,
43    adapter_info: wgpu::AdapterInfo,
44}
45
46impl GpuBackend {
47    /// Probe for a GPU and build the pipelines. Blocks on async init.
48    pub fn try_init() -> Result<Self, GpuError> {
49        pollster::block_on(Self::try_init_async())
50    }
51
52    async fn try_init_async() -> Result<Self, GpuError> {
53        let instance = wgpu::Instance::default();
54        let adapter = instance
55            .request_adapter(&wgpu::RequestAdapterOptions {
56                power_preference: wgpu::PowerPreference::HighPerformance,
57                force_fallback_adapter: false,
58                compatible_surface: None,
59            })
60            .await
61            .ok_or(GpuError::NoAdapter)?;
62
63        let adapter_info = adapter.get_info();
64
65        let (device, queue) = adapter
66            .request_device(
67                &wgpu::DeviceDescriptor {
68                    label: Some("adele-ring-device"),
69                    required_features: wgpu::Features::empty(),
70                    required_limits: wgpu::Limits::downlevel_defaults(),
71                    memory_hints: wgpu::MemoryHints::Performance,
72                },
73                None,
74            )
75            .await?;
76
77        let add_shader =
78            device.create_shader_module(wgpu::include_wgsl!("../shaders/rns_add.wgsl"));
79        let mul_shader =
80            device.create_shader_module(wgpu::include_wgsl!("../shaders/rns_mul.wgsl"));
81
82        let bind_group_layout = Self::make_bind_group_layout(&device);
83        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
84            label: Some("adele-ring-pipeline-layout"),
85            bind_group_layouts: &[&bind_group_layout],
86            push_constant_ranges: &[],
87        });
88
89        let add_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
90            label: Some("rns-add"),
91            layout: Some(&pipeline_layout),
92            module: &add_shader,
93            entry_point: "main",
94            compilation_options: Default::default(),
95            cache: None,
96        });
97        let mul_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
98            label: Some("rns-mul"),
99            layout: Some(&pipeline_layout),
100            module: &mul_shader,
101            entry_point: "main",
102            compilation_options: Default::default(),
103            cache: None,
104        });
105
106        Ok(Self {
107            device,
108            queue,
109            bind_group_layout,
110            add_pipeline,
111            mul_pipeline,
112            adapter_info,
113        })
114    }
115
116    /// Human-readable adapter name (e.g. "NVIDIA GeForce RTX 4080").
117    pub fn adapter_name(&self) -> &str {
118        &self.adapter_info.name
119    }
120
121    fn make_bind_group_layout(device: &wgpu::Device) -> wgpu::BindGroupLayout {
122        let storage = |read_only: bool| wgpu::BindingType::Buffer {
123            ty: wgpu::BufferBindingType::Storage { read_only },
124            has_dynamic_offset: false,
125            min_binding_size: None,
126        };
127        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
128            label: Some("adele-ring-bgl"),
129            entries: &[
130                wgpu::BindGroupLayoutEntry {
131                    binding: 0,
132                    visibility: wgpu::ShaderStages::COMPUTE,
133                    ty: wgpu::BindingType::Buffer {
134                        ty: wgpu::BufferBindingType::Uniform,
135                        has_dynamic_offset: false,
136                        min_binding_size: None,
137                    },
138                    count: None,
139                },
140                wgpu::BindGroupLayoutEntry {
141                    binding: 1,
142                    visibility: wgpu::ShaderStages::COMPUTE,
143                    ty: storage(true),
144                    count: None,
145                },
146                wgpu::BindGroupLayoutEntry {
147                    binding: 2,
148                    visibility: wgpu::ShaderStages::COMPUTE,
149                    ty: storage(true),
150                    count: None,
151                },
152                wgpu::BindGroupLayoutEntry {
153                    binding: 3,
154                    visibility: wgpu::ShaderStages::COMPUTE,
155                    ty: storage(true),
156                    count: None,
157                },
158                wgpu::BindGroupLayoutEntry {
159                    binding: 4,
160                    visibility: wgpu::ShaderStages::COMPUTE,
161                    ty: storage(false),
162                    count: None,
163                },
164            ],
165        })
166    }
167
168    fn run_pipeline(
169        &self,
170        pipeline: &wgpu::ComputePipeline,
171        a: &RnsBatch,
172        b: &RnsBatch,
173    ) -> RnsBatch {
174        let k = a.channels.len();
175        let b_size = a.batch_size;
176        let n_elems = b_size * k;
177        let byte_len = (n_elems * std::mem::size_of::<u32>()) as u64;
178
179        let params = Params {
180            batch_size: b_size as u32,
181            n_channels: k as u32,
182            _pad: [0, 0],
183        };
184        let params_buf = self
185            .device
186            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
187                label: Some("params"),
188                contents: bytemuck::bytes_of(&params),
189                usage: wgpu::BufferUsages::UNIFORM,
190            });
191
192        let moduli_u32: Vec<u32> = a.channels.moduli().iter().map(|&m| m as u32).collect();
193        let moduli_buf = self
194            .device
195            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
196                label: Some("moduli"),
197                contents: bytemuck::cast_slice(&moduli_u32),
198                usage: wgpu::BufferUsages::STORAGE,
199            });
200
201        let a_buf = self
202            .device
203            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
204                label: Some("a"),
205                contents: &a.as_u32_bytes(),
206                usage: wgpu::BufferUsages::STORAGE,
207            });
208        let b_buf = self
209            .device
210            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
211                label: Some("b"),
212                contents: &b.as_u32_bytes(),
213                usage: wgpu::BufferUsages::STORAGE,
214            });
215
216        let out_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
217            label: Some("out"),
218            size: byte_len,
219            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
220            mapped_at_creation: false,
221        });
222        let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
223            label: Some("staging"),
224            size: byte_len,
225            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
226            mapped_at_creation: false,
227        });
228
229        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
230            label: Some("adele-ring-bg"),
231            layout: &self.bind_group_layout,
232            entries: &[
233                wgpu::BindGroupEntry {
234                    binding: 0,
235                    resource: params_buf.as_entire_binding(),
236                },
237                wgpu::BindGroupEntry {
238                    binding: 1,
239                    resource: moduli_buf.as_entire_binding(),
240                },
241                wgpu::BindGroupEntry {
242                    binding: 2,
243                    resource: a_buf.as_entire_binding(),
244                },
245                wgpu::BindGroupEntry {
246                    binding: 3,
247                    resource: b_buf.as_entire_binding(),
248                },
249                wgpu::BindGroupEntry {
250                    binding: 4,
251                    resource: out_buf.as_entire_binding(),
252                },
253            ],
254        });
255
256        let mut encoder = self
257            .device
258            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
259        {
260            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
261                label: Some("rns-pass"),
262                timestamp_writes: None,
263            });
264            pass.set_pipeline(pipeline);
265            pass.set_bind_group(0, &bind_group, &[]);
266            pass.dispatch_workgroups(
267                (b_size as u32).div_ceil(16),
268                (k as u32).div_ceil(16),
269                1,
270            );
271        }
272        encoder.copy_buffer_to_buffer(&out_buf, 0, &staging_buf, 0, byte_len);
273        self.queue.submit([encoder.finish()]);
274
275        let slice = staging_buf.slice(..);
276        let (tx, rx) = std::sync::mpsc::channel();
277        slice.map_async(wgpu::MapMode::Read, move |res| {
278            let _ = tx.send(res);
279        });
280        self.device.poll(wgpu::Maintain::Wait);
281        rx.recv()
282            .expect("map_async channel closed")
283            .expect("buffer map failed");
284
285        let data = slice.get_mapped_range();
286        let values: &[u32] = bytemuck::cast_slice(&data);
287        let result = RnsBatch::from_u32(values, b_size, a.channels.clone());
288        drop(data);
289        staging_buf.unmap();
290        result
291    }
292}
293
294impl ArithmeticBackend for GpuBackend {
295    fn batch_rns_add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
296        self.run_pipeline(&self.add_pipeline, a, b)
297    }
298
299    fn batch_rns_mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
300        self.run_pipeline(&self.mul_pipeline, a, b)
301    }
302
303    fn batch_crt(&self, batch: &RnsBatch) -> Vec<BigUint> {
304        // CRT (Garner) is sequential; do it on the CPU regardless of backend.
305        let k = batch.channels.len();
306        let moduli = batch.channels.moduli();
307        (0..batch.batch_size)
308            .map(|b| garner_crt(&batch.data[b * k..(b + 1) * k], moduli))
309            .collect()
310    }
311
312    fn name(&self) -> &'static str {
313        "gpu-wgpu"
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rns::{Channels, RnsInt};

    /// The GPU backend must produce exactly the residues the CPU backend does.
    /// On machines without a usable GPU the test returns early (a skip).
    #[test]
    fn gpu_matches_cpu_when_available() {
        let Ok(gpu) = GpuBackend::try_init() else {
            // no GPU on this machine; skip
            return;
        };

        let channels = Channels::standard(32);
        // 256 copies of the same value, spread over all 32 channels.
        let replicate = |value: i64| {
            RnsBatch::from_rns_ints(&vec![RnsInt::from_i64(value, channels.clone()); 256])
        };
        let lhs = replicate(123);
        let rhs = replicate(456);

        let cpu = crate::cpu::CpuBackend::new();

        let expected_sum = cpu.batch_rns_add(&lhs, &rhs).data;
        let actual_sum = gpu.batch_rns_add(&lhs, &rhs).data;
        assert_eq!(expected_sum, actual_sum);

        let expected_product = cpu.batch_rns_mul(&lhs, &rhs).data;
        let actual_product = gpu.batch_rns_mul(&lhs, &rhs).data;
        assert_eq!(expected_product, actual_product);
    }
}