Skip to main content

any_gpu/
device.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3
4use anyhow::{Context, Result};
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher, DefaultHasher};
7use std::sync::{Arc, Mutex};
8use wgpu::util::DeviceExt;
9
10/// GPU device handle. wgpu picks the right backend — Vulkan, Metal, DX12.
11/// One codepath, every vendor.
12pub struct GpuDevice {
13    pub(crate) device: wgpu::Device,
14    pub(crate) queue: wgpu::Queue,
15    pub adapter_name: String,
16    pub backend: String,
17    /// Compiled pipeline cache. Key = hash of WGSL source. Eliminates per-dispatch recompilation.
18    pipeline_cache: Mutex<HashMap<u64, Arc<wgpu::ComputePipeline>>>,
19}
20
21/// GPU-resident f32 buffer with element count metadata.
22pub struct GpuBuffer {
23    pub(crate) buffer: wgpu::Buffer,
24    pub(crate) size: u64,
25    pub len: usize,
26}
27
28impl GpuDevice {
29    /// Discover the best GPU and initialize it. wgpu auto-selects the backend:
30    /// Vulkan on Linux (AMD/NVIDIA/Intel), Metal on macOS, DX12 on Windows.
31    pub fn gpu() -> Result<Self> {
32        pollster::block_on(Self::init_async())
33    }
34
35    async fn init_async() -> Result<Self> {
36        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
37            backends: wgpu::Backends::all(),
38            ..Default::default()
39        });
40
41        // Skip enumerate_adapters — can crash on Linux when probing GL/other backends.
42        // Just request the best adapter directly.
43        let adapter = instance
44            .request_adapter(&wgpu::RequestAdapterOptions {
45                power_preference: wgpu::PowerPreference::HighPerformance,
46                compatible_surface: None,
47                force_fallback_adapter: false,
48            })
49            .await
50            .context("no GPU found")?;
51
52        let info = adapter.get_info();
53        eprintln!("  any-gpu: {} ({:?}, {:?})", info.name, info.device_type, info.backend);
54
55        // Use the adapter's actual limits — not Limits::default() which can
56        // request capabilities the driver doesn't support (SIGSEGV on RADV/RDNA1).
57        let (device, queue) = adapter
58            .request_device(
59                &wgpu::DeviceDescriptor {
60                    label: Some("any-gpu"),
61                    required_features: wgpu::Features::empty(),
62                    required_limits: adapter.limits(),
63                    memory_hints: wgpu::MemoryHints::Performance,
64                },
65                None,
66            )
67            .await
68            .context("failed to create GPU device")?;
69
70        Ok(Self {
71            device,
72            queue,
73            adapter_name: info.name.clone(),
74            backend: format!("{:?}", info.backend),
75            pipeline_cache: Mutex::new(HashMap::new()),
76        })
77    }
78
79    /// Upload f32 slice to GPU. Returns a storage buffer usable in compute shaders.
80    pub fn upload(&self, data: &[f32]) -> GpuBuffer {
81        let bytes = bytemuck::cast_slice(data);
82        let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
83            label: None,
84            contents: bytes,
85            usage: wgpu::BufferUsages::STORAGE
86                | wgpu::BufferUsages::COPY_SRC
87                | wgpu::BufferUsages::COPY_DST,
88        });
89        GpuBuffer {
90            size: bytes.len() as u64,
91            len: data.len(),
92            buffer,
93        }
94    }
95
96    /// Allocate an empty GPU buffer for `n` f32 elements.
97    pub fn alloc(&self, n: usize) -> GpuBuffer {
98        let size = (n * std::mem::size_of::<f32>()) as u64;
99        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
100            label: None,
101            size,
102            usage: wgpu::BufferUsages::STORAGE
103                | wgpu::BufferUsages::COPY_SRC
104                | wgpu::BufferUsages::COPY_DST,
105            mapped_at_creation: false,
106        });
107        GpuBuffer {
108            size,
109            len: n,
110            buffer,
111        }
112    }
113
114    /// Read GPU buffer back to CPU as f32 vec.
115    pub fn read(&self, buf: &GpuBuffer) -> Result<Vec<f32>> {
116        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
117            label: None,
118            size: buf.size,
119            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
120            mapped_at_creation: false,
121        });
122
123        let mut encoder = self
124            .device
125            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
126        encoder.copy_buffer_to_buffer(&buf.buffer, 0, &staging, 0, buf.size);
127        self.queue.submit(Some(encoder.finish()));
128
129        let slice = staging.slice(..);
130        let (tx, rx) = std::sync::mpsc::channel();
131        slice.map_async(wgpu::MapMode::Read, move |result| {
132            let _ = tx.send(result);
133        });
134        self.device.poll(wgpu::Maintain::Wait);
135        rx.recv()
136            .context("channel closed")?
137            .context("buffer map failed")?;
138
139        let data = slice.get_mapped_range();
140        let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
141        drop(data);
142        staging.unmap();
143
144        Ok(result)
145    }
146
147    /// Create a small uniform buffer from a bytemuck-able struct.
148    pub(crate) fn upload_uniform<T: bytemuck::Pod>(&self, data: &T) -> wgpu::Buffer {
149        self.device
150            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
151                label: None,
152                contents: bytemuck::bytes_of(data),
153                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
154            })
155    }
156
157    /// Get or create a compiled compute pipeline for the given WGSL source.
158    /// First call compiles; subsequent calls return the cached Arc. Thread-safe.
159    pub(crate) fn pipeline(&self, shader_src: &str, label: Option<&str>) -> Arc<wgpu::ComputePipeline> {
160        let mut h = DefaultHasher::new();
161        shader_src.hash(&mut h);
162        let key = h.finish();
163
164        let mut cache = self.pipeline_cache.lock().unwrap();
165        if let Some(p) = cache.get(&key) {
166            return Arc::clone(p);
167        }
168
169        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
170            label,
171            source: wgpu::ShaderSource::Wgsl(shader_src.into()),
172        });
173        let pipeline = Arc::new(self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
174            label,
175            layout: None,
176            module: &shader,
177            entry_point: Some("main"),
178            compilation_options: Default::default(),
179            cache: None,
180        }));
181        cache.insert(key, Arc::clone(&pipeline));
182        pipeline
183    }
184
185    /// Number of pipelines currently in the cache. For testing only.
186    #[cfg(test)]
187    pub(crate) fn pipeline_cache_len(&self) -> usize {
188        self.pipeline_cache.lock().unwrap().len()
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared device used by every test in this module.
    fn dev() -> &'static GpuDevice { &crate::ops::TEST_DEV }

    /// Initialization must populate the descriptive adapter metadata.
    #[test]
    fn test_gpu_init() {
        let d = dev();
        assert!(!d.adapter_name.is_empty(), "adapter_name should be populated");
        assert!(!d.backend.is_empty(), "backend should be populated");
    }

    /// Uploading and reading back must return the exact same values.
    #[test]
    fn test_upload_read_roundtrip() {
        let data = vec![1.0f32, 2.5, -3.7, 0.0, f32::MIN_POSITIVE, 999.999];
        let buf = dev().upload(&data);
        assert_eq!(buf.len, data.len());
        let result = dev().read(&buf).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_upload_odd_length() {
        // 13 elements — not aligned to any power of 2
        let data: Vec<f32> = (0..13).map(|i| i as f32 * 0.1).collect();
        let buf = dev().upload(&data);
        assert_eq!(buf.len, 13);
        let result = dev().read(&buf).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_upload_single_element() {
        let buf = dev().upload(&[42.0]);
        assert_eq!(dev().read(&buf).unwrap(), vec![42.0]);
    }

    /// `alloc` must record both the element count and the byte size.
    #[test]
    fn test_alloc_size() {
        let buf = dev().alloc(100);
        assert_eq!(buf.len, 100);
        assert_eq!(buf.size, 400); // 100 * 4 bytes
    }

    #[test]
    fn test_alloc_buffers_independent() {
        // Two allocations should not share data
        let a = dev().upload(&[1.0, 2.0, 3.0]);
        let b = dev().upload(&[10.0, 20.0, 30.0]);
        assert_eq!(dev().read(&a).unwrap(), vec![1.0, 2.0, 3.0]);
        assert_eq!(dev().read(&b).unwrap(), vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_pipeline_cache_same_shader_returns_same_arc() {
        // Two calls with identical shader source must return the same compiled pipeline.
        const SRC: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x];
}";
        let p1 = dev().pipeline(SRC, None);
        let p2 = dev().pipeline(SRC, None);
        assert!(Arc::ptr_eq(&p1, &p2), "same shader src must return the same Arc");
    }

    #[test]
    fn test_pipeline_cache_different_shaders_different_arcs() {
        const SRC_A: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x] + 1.0;
}";
        const SRC_B: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x] + 2.0;
}";
        let pa = dev().pipeline(SRC_A, None);
        let pb = dev().pipeline(SRC_B, None);
        assert!(!Arc::ptr_eq(&pa, &pb), "different shaders must produce different pipeline entries");
    }

    #[test]
    fn test_pipeline_cache_grows_then_stabilizes() {
        // The device is shared with tests that may run concurrently and insert
        // their own pipelines, so the global cache length can change between
        // two observations. Check stabilization through pipeline identity
        // instead: every repeat call must hand back the entry created by the
        // first call.
        const SRC: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= p.n { return; }
    out[gid.x] = a[gid.x] * 3.0;
}";
        // Warm the cache for this shader.
        let first = dev().pipeline(SRC, None);
        assert!(dev().pipeline_cache_len() >= 1,
            "warmed shader must be present in the cache");
        // Call 9 more times — each call must reuse the cached entry.
        for _ in 0..9 {
            let again = dev().pipeline(SRC, None);
            assert!(Arc::ptr_eq(&first, &again),
                "repeated calls with same shader must reuse the cached pipeline");
        }
    }

    #[test]
    fn test_pipeline_cache_correctness_after_caching() {
        // Verify that an op produces correct results on the 2nd+ call (uses cached pipeline).
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let b = dev().upload(&[10.0, 20.0, 30.0, 40.0]);
        // Run add twice — second call hits pipeline cache.
        let r1 = dev().add(&a, &b).unwrap();
        let r2 = dev().add(&a, &b).unwrap();
        let v1 = dev().read(&r1).unwrap();
        let v2 = dev().read(&r2).unwrap();
        assert_eq!(v1, v2);
        assert_eq!(v1, vec![11.0, 22.0, 33.0, 44.0]);
    }

    /// Values read back must match the uploaded values to within 1e-7.
    #[test]
    fn test_read_preserves_precision() {
        let data: Vec<f32> = (0..100).map(|i| (i as f32) * 0.001 + 0.0001).collect();
        let buf = dev().upload(&data);
        let result = dev().read(&buf).unwrap();
        for (i, (g, e)) in result.iter().zip(data.iter()).enumerate() {
            assert!((g - e).abs() < 1e-7, "index {i}: got {g}, expected {e}");
        }
    }
}