Skip to main content

rlx_wgpu/
device.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! wgpu device discovery + capabilities.
//!
//! [`wgpu_device`] returns a process-global singleton. [`select_vulkan_backend`]
//! routes subsequent calls to a Vulkan-only instance (used by the runtime's
//! `Device::Vulkan`).
20
21use std::sync::OnceLock;
22use std::sync::atomic::{AtomicU8, Ordering};
23
/// Backend routing used by [`wgpu_device`]; always holds one of the `PREF_*`
/// tags below. Starts at `PREF_DEFAULT`; [`select_vulkan_backend`] switches it
/// to `PREF_VULKAN`.
static BACKEND_PREF: AtomicU8 = AtomicU8::new(PREF_DEFAULT);

/// Tag: route to the instance built from the platform-default backend set.
const PREF_DEFAULT: u8 = 0;
/// Tag: route to the Vulkan-only instance.
const PREF_VULKAN: u8 = 1;
28
29/// Detected wgpu adapter + device. We hold them together because
30/// every command submission needs both the device (for encoding) and
31/// the queue (for committing).
32pub struct WgpuDevice {
33    pub instance: wgpu::Instance,
34    pub adapter: wgpu::Adapter,
35    pub device: wgpu::Device,
36    pub queue: wgpu::Queue,
37    pub name: String,
38    pub backend: wgpu::Backend,
39}
40
41impl WgpuDevice {
42    fn new_with_backends(backends: wgpu::Backends) -> Option<Self> {
43        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
44            backends,
45            flags: wgpu::InstanceFlags::default(),
46            backend_options: wgpu::BackendOptions::default(),
47            memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
48            display: None,
49        });
50        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
51            power_preference: wgpu::PowerPreference::HighPerformance,
52            compatible_surface: None,
53            force_fallback_adapter: false,
54        }))
55        .ok()?;
56
57        let info = adapter.get_info();
58        let limits = adapter.limits();
59        let adapter_feats = adapter.features();
60        let mut required_features = wgpu::Features::empty();
61        if adapter_feats.contains(wgpu::Features::SHADER_F16) {
62            required_features |= wgpu::Features::SHADER_F16;
63        }
64        if adapter_feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
65            required_features |= wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX;
66        }
67        if adapter_feats.contains(wgpu::Features::SUBGROUP) {
68            required_features |= wgpu::Features::SUBGROUP;
69        }
70
71        let (device, queue) =
72            match pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
73                label: Some("rlx-wgpu device"),
74                required_features,
75                required_limits: limits,
76                memory_hints: wgpu::MemoryHints::Performance,
77                experimental_features: unsafe { wgpu::ExperimentalFeatures::enabled() },
78                trace: wgpu::Trace::Off,
79            })) {
80                Ok(p) => p,
81                Err(e) => {
82                    eprintln!("rlx-wgpu request_device failed: {e}");
83                    return None;
84                }
85            };
86
87        Some(Self {
88            instance,
89            adapter,
90            device,
91            queue,
92            name: info.name,
93            backend: info.backend,
94        })
95    }
96
97    fn new_default() -> Option<Self> {
98        Self::new_with_backends(default_backends())
99    }
100}
101
102fn default_backends() -> wgpu::Backends {
103    if let Some(b) = wgpu::Backends::from_env() {
104        return b;
105    }
106    #[cfg(target_os = "windows")]
107    {
108        // Native DX12 first on Windows MSVC; Vulkan remains as fallback.
109        wgpu::Backends::DX12 | wgpu::Backends::VULKAN
110    }
111    #[cfg(target_os = "linux")]
112    {
113        // WSL Ubuntu + native Linux: prefer Vulkan (NVIDIA passthrough).
114        wgpu::Backends::VULKAN
115    }
116    #[cfg(target_os = "macos")]
117    {
118        wgpu::Backends::METAL | wgpu::Backends::VULKAN
119    }
120    #[cfg(not(any(target_os = "windows", target_os = "linux", target_os = "macos")))]
121    {
122        wgpu::Backends::all()
123    }
124}
125
126// SAFETY: wgpu's Device + Queue are documented thread-safe.
127unsafe impl Send for WgpuDevice {}
128unsafe impl Sync for WgpuDevice {}
129
130fn default_device() -> Option<&'static WgpuDevice> {
131    static DEVICE: OnceLock<Option<WgpuDevice>> = OnceLock::new();
132    DEVICE.get_or_init(WgpuDevice::new_default).as_ref()
133}
134
135fn vulkan_device() -> Option<&'static WgpuDevice> {
136    static DEVICE: OnceLock<Option<WgpuDevice>> = OnceLock::new();
137    DEVICE
138        .get_or_init(|| WgpuDevice::new_with_backends(wgpu::Backends::VULKAN))
139        .as_ref()
140}
141
142/// Prefer the Vulkan-only wgpu instance for [`Device::Vulkan`] sessions.
143/// Call before the first [`wgpu_device`] use in that process (or use
144/// `Device::Vulkan` via the runtime registry, which calls this).
145pub fn select_vulkan_backend() {
146    BACKEND_PREF.store(PREF_VULKAN, Ordering::SeqCst);
147}
148
149/// True when a Vulkan adapter is reachable (MoltenVK on macOS, native on Linux/Windows).
150pub fn is_vulkan_available() -> bool {
151    vulkan_device().is_some()
152}
153
154/// Get or initialize the global wgpu device singleton. Returns None
155/// on systems with no compatible adapter.
156pub fn wgpu_device() -> Option<&'static WgpuDevice> {
157    if BACKEND_PREF.load(Ordering::SeqCst) == PREF_VULKAN {
158        vulkan_device()
159    } else {
160        default_device()
161    }
162}
163
164/// Adapter name for calibration cache keys.
165pub fn adapter_name() -> Option<String> {
166    wgpu_device().map(|d| d.name.clone())
167}
168
169/// True on Vulkan or DX12 — discrete GPU cooperative-matrix backends.
170pub fn coop_discrete_backend() -> bool {
171    wgpu_device()
172        .map(|d| matches!(d.backend, wgpu::Backend::Vulkan | wgpu::Backend::Dx12))
173        .unwrap_or(false)
174}
175
176/// True when the adapter reports 8×8×8 f32 cooperative-matrix support
177/// (required for `matmul_coop_f32_portable` on Vulkan/DX12).
178pub fn coop_f32_8x8_supported() -> bool {
179    let dev = match wgpu_device() {
180        Some(d) => d,
181        None => return false,
182    };
183    dev.adapter.cooperative_matrix_properties().iter().any(|p| {
184        p.m_size == 8
185            && p.n_size == 8
186            && p.k_size == 8
187            && p.ab_type == wgpu::CooperativeScalarType::F32
188            && p.cr_type == wgpu::CooperativeScalarType::F32
189    })
190}
191
192/// True when the adapter reports 16×16×16 f16 cooperative-matrix support
193/// (NVIDIA / AMD tensor-core path on Vulkan/DX12).
194pub fn coop_f16_16x16_supported() -> bool {
195    let dev = match wgpu_device() {
196        Some(d) => d,
197        None => return false,
198    };
199    dev.adapter.cooperative_matrix_properties().iter().any(|p| {
200        p.m_size == 16
201            && p.n_size == 16
202            && p.k_size == 16
203            && p.ab_type == wgpu::CooperativeScalarType::F16
204            && (p.cr_type == wgpu::CooperativeScalarType::F16
205                || p.cr_type == wgpu::CooperativeScalarType::F32)
206    })
207}
208
209/// True when f16 operands with f32 accumulator are available (preferred on RTX).
210pub fn coop_f16_16x16_f32_acc_supported() -> bool {
211    let dev = match wgpu_device() {
212        Some(d) => d,
213        None => return false,
214    };
215    dev.adapter.cooperative_matrix_properties().iter().any(|p| {
216        p.m_size == 16
217            && p.n_size == 16
218            && p.k_size == 16
219            && p.ab_type == wgpu::CooperativeScalarType::F16
220            && p.cr_type == wgpu::CooperativeScalarType::F32
221    })
222}