1use bytemuck::{Pod, Zeroable};
12use num_bigint::BigUint;
13use wgpu::util::DeviceExt;
14
15use crate::backend::ArithmeticBackend;
16use crate::batch::RnsBatch;
17use crate::rns::garner_crt;
18
19#[derive(Debug, thiserror::Error)]
21pub enum GpuError {
22 #[error("no compatible GPU adapter found")]
23 NoAdapter,
24 #[error("failed to acquire GPU device: {0}")]
25 Device(#[from] wgpu::RequestDeviceError),
26}
27
28#[repr(C)]
29#[derive(Clone, Copy, Pod, Zeroable)]
30struct Params {
31 batch_size: u32,
32 n_channels: u32,
33 _pad: [u32; 2], }
35
36pub struct GpuBackend {
38 device: wgpu::Device,
39 queue: wgpu::Queue,
40 bind_group_layout: wgpu::BindGroupLayout,
41 add_pipeline: wgpu::ComputePipeline,
42 mul_pipeline: wgpu::ComputePipeline,
43 adapter_info: wgpu::AdapterInfo,
44}
45
46impl GpuBackend {
47 pub fn try_init() -> Result<Self, GpuError> {
49 pollster::block_on(Self::try_init_async())
50 }
51
52 async fn try_init_async() -> Result<Self, GpuError> {
53 let instance = wgpu::Instance::default();
54 let adapter = instance
55 .request_adapter(&wgpu::RequestAdapterOptions {
56 power_preference: wgpu::PowerPreference::HighPerformance,
57 force_fallback_adapter: false,
58 compatible_surface: None,
59 })
60 .await
61 .ok_or(GpuError::NoAdapter)?;
62
63 let adapter_info = adapter.get_info();
64
65 let (device, queue) = adapter
66 .request_device(
67 &wgpu::DeviceDescriptor {
68 label: Some("adele-ring-device"),
69 required_features: wgpu::Features::empty(),
70 required_limits: wgpu::Limits::downlevel_defaults(),
71 memory_hints: wgpu::MemoryHints::Performance,
72 },
73 None,
74 )
75 .await?;
76
77 let add_shader =
78 device.create_shader_module(wgpu::include_wgsl!("../shaders/rns_add.wgsl"));
79 let mul_shader =
80 device.create_shader_module(wgpu::include_wgsl!("../shaders/rns_mul.wgsl"));
81
82 let bind_group_layout = Self::make_bind_group_layout(&device);
83 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
84 label: Some("adele-ring-pipeline-layout"),
85 bind_group_layouts: &[&bind_group_layout],
86 push_constant_ranges: &[],
87 });
88
89 let add_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
90 label: Some("rns-add"),
91 layout: Some(&pipeline_layout),
92 module: &add_shader,
93 entry_point: "main",
94 compilation_options: Default::default(),
95 cache: None,
96 });
97 let mul_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
98 label: Some("rns-mul"),
99 layout: Some(&pipeline_layout),
100 module: &mul_shader,
101 entry_point: "main",
102 compilation_options: Default::default(),
103 cache: None,
104 });
105
106 Ok(Self {
107 device,
108 queue,
109 bind_group_layout,
110 add_pipeline,
111 mul_pipeline,
112 adapter_info,
113 })
114 }
115
116 pub fn adapter_name(&self) -> &str {
118 &self.adapter_info.name
119 }
120
121 fn make_bind_group_layout(device: &wgpu::Device) -> wgpu::BindGroupLayout {
122 let storage = |read_only: bool| wgpu::BindingType::Buffer {
123 ty: wgpu::BufferBindingType::Storage { read_only },
124 has_dynamic_offset: false,
125 min_binding_size: None,
126 };
127 device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
128 label: Some("adele-ring-bgl"),
129 entries: &[
130 wgpu::BindGroupLayoutEntry {
131 binding: 0,
132 visibility: wgpu::ShaderStages::COMPUTE,
133 ty: wgpu::BindingType::Buffer {
134 ty: wgpu::BufferBindingType::Uniform,
135 has_dynamic_offset: false,
136 min_binding_size: None,
137 },
138 count: None,
139 },
140 wgpu::BindGroupLayoutEntry {
141 binding: 1,
142 visibility: wgpu::ShaderStages::COMPUTE,
143 ty: storage(true),
144 count: None,
145 },
146 wgpu::BindGroupLayoutEntry {
147 binding: 2,
148 visibility: wgpu::ShaderStages::COMPUTE,
149 ty: storage(true),
150 count: None,
151 },
152 wgpu::BindGroupLayoutEntry {
153 binding: 3,
154 visibility: wgpu::ShaderStages::COMPUTE,
155 ty: storage(true),
156 count: None,
157 },
158 wgpu::BindGroupLayoutEntry {
159 binding: 4,
160 visibility: wgpu::ShaderStages::COMPUTE,
161 ty: storage(false),
162 count: None,
163 },
164 ],
165 })
166 }
167
168 fn run_pipeline(
169 &self,
170 pipeline: &wgpu::ComputePipeline,
171 a: &RnsBatch,
172 b: &RnsBatch,
173 ) -> RnsBatch {
174 let k = a.channels.len();
175 let b_size = a.batch_size;
176 let n_elems = b_size * k;
177 let byte_len = (n_elems * std::mem::size_of::<u32>()) as u64;
178
179 let params = Params {
180 batch_size: b_size as u32,
181 n_channels: k as u32,
182 _pad: [0, 0],
183 };
184 let params_buf = self
185 .device
186 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
187 label: Some("params"),
188 contents: bytemuck::bytes_of(¶ms),
189 usage: wgpu::BufferUsages::UNIFORM,
190 });
191
192 let moduli_u32: Vec<u32> = a.channels.moduli().iter().map(|&m| m as u32).collect();
193 let moduli_buf = self
194 .device
195 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
196 label: Some("moduli"),
197 contents: bytemuck::cast_slice(&moduli_u32),
198 usage: wgpu::BufferUsages::STORAGE,
199 });
200
201 let a_buf = self
202 .device
203 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
204 label: Some("a"),
205 contents: &a.as_u32_bytes(),
206 usage: wgpu::BufferUsages::STORAGE,
207 });
208 let b_buf = self
209 .device
210 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
211 label: Some("b"),
212 contents: &b.as_u32_bytes(),
213 usage: wgpu::BufferUsages::STORAGE,
214 });
215
216 let out_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
217 label: Some("out"),
218 size: byte_len,
219 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
220 mapped_at_creation: false,
221 });
222 let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
223 label: Some("staging"),
224 size: byte_len,
225 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
226 mapped_at_creation: false,
227 });
228
229 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
230 label: Some("adele-ring-bg"),
231 layout: &self.bind_group_layout,
232 entries: &[
233 wgpu::BindGroupEntry {
234 binding: 0,
235 resource: params_buf.as_entire_binding(),
236 },
237 wgpu::BindGroupEntry {
238 binding: 1,
239 resource: moduli_buf.as_entire_binding(),
240 },
241 wgpu::BindGroupEntry {
242 binding: 2,
243 resource: a_buf.as_entire_binding(),
244 },
245 wgpu::BindGroupEntry {
246 binding: 3,
247 resource: b_buf.as_entire_binding(),
248 },
249 wgpu::BindGroupEntry {
250 binding: 4,
251 resource: out_buf.as_entire_binding(),
252 },
253 ],
254 });
255
256 let mut encoder = self
257 .device
258 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
259 {
260 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
261 label: Some("rns-pass"),
262 timestamp_writes: None,
263 });
264 pass.set_pipeline(pipeline);
265 pass.set_bind_group(0, &bind_group, &[]);
266 pass.dispatch_workgroups(
267 (b_size as u32).div_ceil(16),
268 (k as u32).div_ceil(16),
269 1,
270 );
271 }
272 encoder.copy_buffer_to_buffer(&out_buf, 0, &staging_buf, 0, byte_len);
273 self.queue.submit([encoder.finish()]);
274
275 let slice = staging_buf.slice(..);
276 let (tx, rx) = std::sync::mpsc::channel();
277 slice.map_async(wgpu::MapMode::Read, move |res| {
278 let _ = tx.send(res);
279 });
280 self.device.poll(wgpu::Maintain::Wait);
281 rx.recv()
282 .expect("map_async channel closed")
283 .expect("buffer map failed");
284
285 let data = slice.get_mapped_range();
286 let values: &[u32] = bytemuck::cast_slice(&data);
287 let result = RnsBatch::from_u32(values, b_size, a.channels.clone());
288 drop(data);
289 staging_buf.unmap();
290 result
291 }
292}
293
294impl ArithmeticBackend for GpuBackend {
295 fn batch_rns_add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
296 self.run_pipeline(&self.add_pipeline, a, b)
297 }
298
299 fn batch_rns_mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
300 self.run_pipeline(&self.mul_pipeline, a, b)
301 }
302
303 fn batch_crt(&self, batch: &RnsBatch) -> Vec<BigUint> {
304 let k = batch.channels.len();
306 let moduli = batch.channels.moduli();
307 (0..batch.batch_size)
308 .map(|b| garner_crt(&batch.data[b * k..(b + 1) * k], moduli))
309 .collect()
310 }
311
312 fn name(&self) -> &'static str {
313 "gpu-wgpu"
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rns::{Channels, RnsInt};

    /// Batch of `len` distinct values `i * scale + offset`.
    ///
    /// Distinct rows mean an indexing mistake in a kernel (e.g. swapped batch and
    /// channel axes) cannot hide behind identical data.
    fn ramp(len: i64, scale: i64, offset: i64, ch: &Channels) -> RnsBatch {
        let ints: Vec<RnsInt> = (0..len)
            .map(|i| RnsInt::from_i64(i * scale + offset, ch.clone()))
            .collect();
        RnsBatch::from_rns_ints(&ints)
    }

    /// The GPU kernels must agree exactly with the CPU reference backend.
    ///
    /// When no GPU adapter or device is available, the test is skipped with a note on stderr.
    #[test]
    fn gpu_matches_cpu_when_available() {
        let gpu = match GpuBackend::try_init() {
            Ok(g) => g,
            Err(e) => {
                eprintln!("skipping gpu_matches_cpu_when_available: {e}");
                return;
            }
        };
        let ch = Channels::standard(32);
        // Large-ish operands, so that products exercise modular reduction rather than
        // staying far below every modulus.
        let a = ramp(256, 1_000_003, 12_345, &ch);
        let b = ramp(256, 7_919, 456, &ch);

        let cpu = crate::cpu::CpuBackend::new();
        assert_eq!(
            cpu.batch_rns_add(&a, &b).data,
            gpu.batch_rns_add(&a, &b).data
        );
        assert_eq!(
            cpu.batch_rns_mul(&a, &b).data,
            gpu.batch_rns_mul(&a, &b).data
        );
    }
}