1use std::sync::OnceLock;
7use tracing::{debug, info, warn};
8use wgpu::{Device, DeviceDescriptor, Instance, Queue, RequestAdapterOptions};
9
10use crate::error::{GpuError, GpuResult};
11
12static GPU_CONTEXT: OnceLock<Option<GpuContext>> = OnceLock::new();
14
/// Which kind of GPU adapter to ask wgpu for when creating a context.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum GpuDevicePreference {
    /// Let the crate choose; currently treated the same as `HighPerformance`.
    #[default]
    Auto,
    /// Prefer the highest-performance adapter (typically a discrete GPU).
    HighPerformance,
    /// Prefer the adapter with the lowest power draw.
    LowPower,
}
26
/// Human-readable, owned summary of the GPU adapter backing a context.
///
/// All fields are strings so the summary can be logged or displayed without
/// depending on wgpu's types.
#[derive(Clone, Debug)]
pub struct GpuAdapterInfo {
    /// Adapter name as reported by the driver.
    pub name: String,
    /// Vendor identifier; when built via `From<wgpu::AdapterInfo>`, the `Display`
    /// rendering of wgpu's vendor field.
    pub vendor: String,
    /// Device type; when built via `From<wgpu::AdapterInfo>`, the `Debug` rendering
    /// of wgpu's device-type enum.
    pub device_type: String,
    /// Graphics backend; when built via `From<wgpu::AdapterInfo>`, the `Debug`
    /// rendering of wgpu's backend enum.
    pub backend: String,
}
39
40impl From<wgpu::AdapterInfo> for GpuAdapterInfo {
41 fn from(info: wgpu::AdapterInfo) -> Self {
42 Self {
43 name: info.name,
44 vendor: format!("{}", info.vendor),
45 device_type: format!("{:?}", info.device_type),
46 backend: format!("{:?}", info.backend),
47 }
48 }
49}
50
51pub struct GpuContext {
57 pub device: Device,
59 pub queue: Queue,
61 pub adapter_info: GpuAdapterInfo,
63 pub limits: wgpu::Limits,
65}
66
67impl GpuContext {
68 pub fn get() -> Option<&'static GpuContext> {
84 GPU_CONTEXT
85 .get_or_init(
86 || match pollster::block_on(Self::try_init(GpuDevicePreference::Auto)) {
87 Ok(ctx) => {
88 info!(
89 adapter = %ctx.adapter_info.name,
90 backend = %ctx.adapter_info.backend,
91 "GPU context initialized"
92 );
93 Some(ctx)
94 }
95 Err(e) => {
96 warn!("GPU initialization failed: {}", e);
97 None
98 }
99 },
100 )
101 .as_ref()
102 }
103
104 pub fn try_get() -> GpuResult<&'static GpuContext> {
109 Self::get().ok_or(GpuError::NotAvailable)
110 }
111
112 pub fn is_available() -> bool {
116 Self::get().is_some()
117 }
118
119 async fn try_init(preference: GpuDevicePreference) -> GpuResult<GpuContext> {
121 debug!("Initializing GPU context with preference: {:?}", preference);
122
123 let instance = Instance::new(&wgpu::InstanceDescriptor {
125 backends: wgpu::Backends::all(),
126 ..Default::default()
127 });
128
129 let power_preference = match preference {
131 GpuDevicePreference::Auto | GpuDevicePreference::HighPerformance => {
132 wgpu::PowerPreference::HighPerformance
133 }
134 GpuDevicePreference::LowPower => wgpu::PowerPreference::LowPower,
135 };
136
137 let adapter = instance
138 .request_adapter(&RequestAdapterOptions {
139 power_preference,
140 force_fallback_adapter: false,
141 compatible_surface: None,
142 })
143 .await
144 .ok_or(GpuError::NotAvailable)?;
145
146 let adapter_info = adapter.get_info();
147 debug!(
148 name = %adapter_info.name,
149 vendor = adapter_info.vendor,
150 device_type = ?adapter_info.device_type,
151 backend = ?adapter_info.backend,
152 "GPU adapter found"
153 );
154
155 let (device, queue) = adapter
157 .request_device(
158 &DeviceDescriptor {
159 label: Some("mesh-gpu"),
160 required_features: wgpu::Features::empty(),
161 required_limits: wgpu::Limits::default(),
162 ..Default::default()
163 },
164 None,
165 )
166 .await
167 .map_err(|e| GpuError::Execution(format!("device request failed: {}", e)))?;
168
169 let limits = device.limits();
170
171 Ok(GpuContext {
172 device,
173 queue,
174 adapter_info: adapter_info.into(),
175 limits,
176 })
177 }
178
179 pub fn max_buffer_size(&self) -> u64 {
181 self.limits.max_buffer_size
182 }
183
184 pub fn max_storage_buffer_size(&self) -> u32 {
186 self.limits.max_storage_buffer_binding_size
187 }
188
189 pub fn max_workgroup_size(&self) -> [u32; 3] {
191 [
192 self.limits.max_compute_workgroup_size_x,
193 self.limits.max_compute_workgroup_size_y,
194 self.limits.max_compute_workgroup_size_z,
195 ]
196 }
197
198 pub fn max_invocations_per_workgroup(&self) -> u32 {
200 self.limits.max_compute_invocations_per_workgroup
201 }
202
203 pub fn estimate_available_memory(&self) -> u64 {
208 self.limits.max_buffer_size
211 }
212}
213
214impl std::fmt::Debug for GpuContext {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.debug_struct("GpuContext")
217 .field("adapter_info", &self.adapter_info)
218 .field("max_buffer_size", &self.limits.max_buffer_size)
219 .finish_non_exhaustive()
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// `is_available`, `get` and `try_get` must agree, whether or not the test machine has a GPU.
    #[test]
    fn test_gpu_availability_check() {
        let available = GpuContext::is_available();
        assert_eq!(available, GpuContext::get().is_some());
        assert_eq!(available, GpuContext::try_get().is_ok());
        if !available {
            assert!(matches!(GpuContext::try_get(), Err(GpuError::NotAvailable)));
        }
    }

    /// Repeated `get` calls return the same cached context (or consistently `None`).
    #[test]
    fn test_gpu_context_get() {
        let first = GpuContext::get();
        let second = GpuContext::get();

        match (first, second) {
            (Some(a), Some(b)) => {
                assert!(std::ptr::eq(a, b), "get() must return the cached singleton");
                assert!(!a.adapter_info.name.is_empty());
                // Debug must not require formatting the wgpu handles.
                assert!(format!("{:?}", a).starts_with("GpuContext"));
            }
            (None, None) => {}
            _ => panic!("GPU availability changed between calls to GpuContext::get"),
        }
    }

    #[test]
    fn test_gpu_device_preference_default() {
        assert_eq!(GpuDevicePreference::default(), GpuDevicePreference::Auto);
    }
}