1use anyhow::{Context, Result};
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher, DefaultHasher};
7use std::sync::{Arc, Mutex};
8use wgpu::util::DeviceExt;
9
10pub struct GpuDevice {
13 pub(crate) device: wgpu::Device,
14 pub(crate) queue: wgpu::Queue,
15 pub adapter_name: String,
16 pub backend: String,
17 pipeline_cache: Mutex<HashMap<u64, Arc<wgpu::ComputePipeline>>>,
19}
20
21pub struct GpuBuffer {
23 pub(crate) buffer: wgpu::Buffer,
24 pub(crate) size: u64,
25 pub len: usize,
26}
27
28impl GpuDevice {
29 pub fn gpu() -> Result<Self> {
32 pollster::block_on(Self::init_async())
33 }
34
35 async fn init_async() -> Result<Self> {
36 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
37 backends: wgpu::Backends::all(),
38 ..Default::default()
39 });
40
41 let adapter = instance
44 .request_adapter(&wgpu::RequestAdapterOptions {
45 power_preference: wgpu::PowerPreference::HighPerformance,
46 compatible_surface: None,
47 force_fallback_adapter: false,
48 })
49 .await
50 .context("no GPU found")?;
51
52 let info = adapter.get_info();
53 eprintln!(" any-gpu: {} ({:?}, {:?})", info.name, info.device_type, info.backend);
54
55 let (device, queue) = adapter
58 .request_device(
59 &wgpu::DeviceDescriptor {
60 label: Some("any-gpu"),
61 required_features: wgpu::Features::empty(),
62 required_limits: adapter.limits(),
63 memory_hints: wgpu::MemoryHints::Performance,
64 },
65 None,
66 )
67 .await
68 .context("failed to create GPU device")?;
69
70 Ok(Self {
71 device,
72 queue,
73 adapter_name: info.name.clone(),
74 backend: format!("{:?}", info.backend),
75 pipeline_cache: Mutex::new(HashMap::new()),
76 })
77 }
78
79 pub fn upload(&self, data: &[f32]) -> GpuBuffer {
81 let bytes = bytemuck::cast_slice(data);
82 let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
83 label: None,
84 contents: bytes,
85 usage: wgpu::BufferUsages::STORAGE
86 | wgpu::BufferUsages::COPY_SRC
87 | wgpu::BufferUsages::COPY_DST,
88 });
89 GpuBuffer {
90 size: bytes.len() as u64,
91 len: data.len(),
92 buffer,
93 }
94 }
95
96 pub fn alloc(&self, n: usize) -> GpuBuffer {
98 let size = (n * std::mem::size_of::<f32>()) as u64;
99 let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
100 label: None,
101 size,
102 usage: wgpu::BufferUsages::STORAGE
103 | wgpu::BufferUsages::COPY_SRC
104 | wgpu::BufferUsages::COPY_DST,
105 mapped_at_creation: false,
106 });
107 GpuBuffer {
108 size,
109 len: n,
110 buffer,
111 }
112 }
113
114 pub fn read(&self, buf: &GpuBuffer) -> Result<Vec<f32>> {
116 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
117 label: None,
118 size: buf.size,
119 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
120 mapped_at_creation: false,
121 });
122
123 let mut encoder = self
124 .device
125 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
126 encoder.copy_buffer_to_buffer(&buf.buffer, 0, &staging, 0, buf.size);
127 self.queue.submit(Some(encoder.finish()));
128
129 let slice = staging.slice(..);
130 let (tx, rx) = std::sync::mpsc::channel();
131 slice.map_async(wgpu::MapMode::Read, move |result| {
132 let _ = tx.send(result);
133 });
134 self.device.poll(wgpu::Maintain::Wait);
135 rx.recv()
136 .context("channel closed")?
137 .context("buffer map failed")?;
138
139 let data = slice.get_mapped_range();
140 let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
141 drop(data);
142 staging.unmap();
143
144 Ok(result)
145 }
146
147 pub(crate) fn upload_uniform<T: bytemuck::Pod>(&self, data: &T) -> wgpu::Buffer {
149 self.device
150 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
151 label: None,
152 contents: bytemuck::bytes_of(data),
153 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
154 })
155 }
156
157 pub(crate) fn pipeline(&self, shader_src: &str, label: Option<&str>) -> Arc<wgpu::ComputePipeline> {
160 let mut h = DefaultHasher::new();
161 shader_src.hash(&mut h);
162 let key = h.finish();
163
164 let mut cache = self.pipeline_cache.lock().unwrap();
165 if let Some(p) = cache.get(&key) {
166 return Arc::clone(p);
167 }
168
169 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
170 label,
171 source: wgpu::ShaderSource::Wgsl(shader_src.into()),
172 });
173 let pipeline = Arc::new(self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
174 label,
175 layout: None,
176 module: &shader,
177 entry_point: Some("main"),
178 compilation_options: Default::default(),
179 cache: None,
180 }));
181 cache.insert(key, Arc::clone(&pipeline));
182 pipeline
183 }
184
185 #[cfg(test)]
187 pub(crate) fn pipeline_cache_len(&self) -> usize {
188 self.pipeline_cache.lock().unwrap().len()
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Device shared by the tests in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    #[test]
    fn test_gpu_init() {
        let d = dev();
        assert!(!d.adapter_name.is_empty(), "adapter_name should be populated");
        assert!(!d.backend.is_empty(), "backend should be populated");
    }

    #[test]
    fn test_upload_read_roundtrip() {
        let data = vec![1.0f32, 2.5, -3.7, 0.0, f32::MIN_POSITIVE, 999.999];
        let buf = dev().upload(&data);
        assert_eq!(buf.len, data.len());
        let result = dev().read(&buf).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_upload_odd_length() {
        let data: Vec<f32> = (0..13).map(|i| i as f32 * 0.1).collect();
        let buf = dev().upload(&data);
        assert_eq!(buf.len, 13);
        let result = dev().read(&buf).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_upload_single_element() {
        let buf = dev().upload(&[42.0]);
        assert_eq!(dev().read(&buf).unwrap(), vec![42.0]);
    }

    #[test]
    fn test_alloc_size() {
        let buf = dev().alloc(100);
        assert_eq!(buf.len, 100);
        // 100 f32 values at 4 bytes each.
        assert_eq!(buf.size, 400);
    }

    #[test]
    fn test_alloc_buffers_independent() {
        let a = dev().upload(&[1.0, 2.0, 3.0]);
        let b = dev().upload(&[10.0, 20.0, 30.0]);
        assert_eq!(dev().read(&a).unwrap(), vec![1.0, 2.0, 3.0]);
        assert_eq!(dev().read(&b).unwrap(), vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_pipeline_cache_same_shader_returns_same_arc() {
        const SRC: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x];
}";
        let p1 = dev().pipeline(SRC, None);
        let p2 = dev().pipeline(SRC, None);
        assert!(Arc::ptr_eq(&p1, &p2), "same shader src must return the same Arc");
    }

    #[test]
    fn test_pipeline_cache_different_shaders_different_arcs() {
        const SRC_A: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x] + 1.0;
}";
        const SRC_B: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x] + 2.0;
}";
        let pa = dev().pipeline(SRC_A, None);
        let pb = dev().pipeline(SRC_B, None);
        assert!(
            !Arc::ptr_eq(&pa, &pb),
            "different shaders must produce different pipeline entries"
        );
    }

    #[test]
    fn test_pipeline_cache_grows_then_stabilizes() {
        const SRC: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x] * 3.0;
}";
        let first = dev().pipeline(SRC, None);
        assert!(
            dev().pipeline_cache_len() >= 1,
            "first call must leave an entry in the cache"
        );

        // The device (and therefore the cache) is shared with tests running in
        // parallel, so the absolute cache length may change underneath us.
        // Getting the very same Arc back proves each repeated call was a cache
        // hit, i.e. this shader did not add another entry.
        for _ in 0..9 {
            let again = dev().pipeline(SRC, None);
            assert!(
                Arc::ptr_eq(&first, &again),
                "repeated calls with same shader must not grow the cache"
            );
        }
    }

    #[test]
    fn test_pipeline_cache_correctness_after_caching() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let b = dev().upload(&[10.0, 20.0, 30.0, 40.0]);
        // Run the same op twice; the second run is expected to reuse the cached pipeline.
        let r1 = dev().add(&a, &b).unwrap();
        let r2 = dev().add(&a, &b).unwrap();
        let v1 = dev().read(&r1).unwrap();
        let v2 = dev().read(&r2).unwrap();
        assert_eq!(v1, v2);
        assert_eq!(v1, vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_read_preserves_precision() {
        let data: Vec<f32> = (0..100).map(|i| (i as f32) * 0.001 + 0.0001).collect();
        let buf = dev().upload(&data);
        let result = dev().read(&buf).unwrap();
        for (i, (g, e)) in result.iter().zip(data.iter()).enumerate() {
            assert!((g - e).abs() < 1e-7, "index {i}: got {g}, expected {e}");
        }
    }
}