1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
//! Process-wide wgpu context: instance, adapter, device, queue.
use crate::error::{Result, RullamaError};
/// Holds the wgpu device and queue for the lifetime of a [`crate::api::Model`].
///
/// All inner handles are Arc-internal in wgpu, so `clone()` is cheap and lets us
/// hand the same ctx to both `Forward` (text) and `VisionForward` (multimodal)
/// without juggling Arc/Rc wrappers.
///
/// Built by [`WgpuCtx::new`], which also decides the two capability flags
/// below from what the adapter reports.
#[derive(Clone)]
pub struct WgpuCtx {
/// Instance created in [`WgpuCtx::new`]; the adapter is requested from it.
pub instance: wgpu::Instance,
/// Adapter picked by [`WgpuCtx::new`] (high-performance preference, no
/// forced fallback adapter, no surface constraint).
pub adapter: wgpu::Adapter,
/// Logical device requested from `adapter`.
pub device: wgpu::Device,
/// Queue returned together with `device` by the same `request_device` call.
pub queue: wgpu::Queue,
/// True iff the device was created with `Features::SUBGROUP`. Kernels that
/// require `enable subgroups;` only get registered/dispatched when this is set.
///
/// [`WgpuCtx::new`] only requests the feature when the adapter also exposes
/// `Features::SUBGROUP_BARRIER` and reports a `subgroup_max_size` of at
/// least 64; both features are then requested together.
pub has_subgroups: bool,
/// True iff `Features::SHADER_F16` was granted. Kernels that declare
/// `enable f16;` only get registered when this is set.
pub has_f16: bool,
}
impl WgpuCtx {
    /// Bring up wgpu on the adapter it considers the best match and return the
    /// resulting instance / adapter / device / queue bundle.
    ///
    /// The adapter is requested with a high-performance power preference, no
    /// forced fallback adapter and no surface constraint. On wasm32 this goes
    /// through `navigator.gpu`; on native, wgpu's default backend selection
    /// picks Metal/Vulkan/DX12. Described by its author as a test helper used
    /// during M0–M2 bring-up.
    ///
    /// Optional features (subgroups, f16) are requested only when the adapter
    /// offers them; the outcome is recorded in `has_subgroups` / `has_f16`.
    /// The limits actually granted are logged once (browser console on wasm32,
    /// stderr elsewhere).
    ///
    /// # Errors
    ///
    /// * [`RullamaError::NoAdapter`] if the adapter request fails.
    /// * [`RullamaError::DeviceRequest`] carrying the wgpu error text if the
    ///   device request fails.
    pub async fn new() -> Result<Self> {
        let instance =
            wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle_from_env());

        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            force_fallback_adapter: false,
            compatible_surface: None,
        };
        let adapter = instance
            .request_adapter(&adapter_options)
            .await
            .map_err(|_| RullamaError::NoAdapter)?;

        // Perf-relevant features are strictly opt-in:
        //   * SUBGROUP: subgroup intrinsics (subgroupAdd/Max) collapse the
        //     barrier-tree reductions in vision attention; on AMD GCN a
        //     64-thread WG == 1 subgroup, so a whole-WG reduction becomes a
        //     single op.
        //   * SUBGROUP_BARRIER: cross-subgroup ordering the kernel needs when
        //     a WG spans more than one subgroup.
        //   * SHADER_F16: tile data can live in workgroup memory as f16,
        //     halving LDS bandwidth and (with v3-style register tiles)
        //     doubling the in-flight tile size for the same LDS budget.
        // An adapter missing any of them runs the f32-only path; the f32
        // kernels remain the correctness oracle regardless.
        let supported = adapter.features();
        let info = adapter.get_info();

        // Having SUBGROUP is not sufficient on its own. The kernels declare
        // `@workgroup_size(64)` and reduce across the whole WG with
        // `subgroupMax` / `subgroupAdd`, which only reduce *within* one
        // subgroup. If the adapter cannot produce 64-wide subgroups the WG is
        // split and the reduction silently yields wrong results, so such
        // adapters are excluded here.
        //
        // Values the original author recorded (wgpu/Metal-reported):
        //   AMD GCN / Vega / Qualcomm Adreno: max >= 64 -> enabled.
        //   AMD RDNA+:                        max 64    -> enabled.
        //   Apple Silicon:                    max 32    -> skipped.
        //   NVIDIA:                           max 32    -> skipped.
        //   Intel:                            max 16-32 -> skipped.
        //
        // NOTE(review): `subgroup_max_size` is only the ceiling. An adapter
        // whose minimum subgroup size is below 64 could presumably still pick
        // a narrower size for a given pipeline — confirm whether the minimum
        // should be checked instead.
        let fits_whole_workgroup = info.subgroup_max_size >= 64;
        let has_subgroups = fits_whole_workgroup
            && supported.contains(wgpu::Features::SUBGROUP)
            && supported.contains(wgpu::Features::SUBGROUP_BARRIER);
        let has_f16 = supported.contains(wgpu::Features::SHADER_F16);

        let mut requested = wgpu::Features::empty();
        if has_subgroups {
            requested |= wgpu::Features::SUBGROUP;
            requested |= wgpu::Features::SUBGROUP_BARRIER;
        }
        if has_f16 {
            requested |= wgpu::Features::SHADER_F16;
        }

        let descriptor = wgpu::DeviceDescriptor {
            label: Some("rullama device"),
            required_features: requested,
            required_limits: Self::requested_limits(&adapter.limits()),
            experimental_features: wgpu::ExperimentalFeatures::default(),
            memory_hints: wgpu::MemoryHints::Performance,
            trace: wgpu::Trace::Off,
        };
        let (device, queue) = adapter
            .request_device(&descriptor)
            .await
            .map_err(|err| RullamaError::DeviceRequest(err.to_string()))?;

        // Report what was actually granted, so iPhone runs show whether
        // `max_buffer_size` is the 256 MiB downlevel floor or the (hopefully)
        // larger A18-class number. Helps diagnose OOMs on tight-RAM phones.
        Self::log_granted_limits(&device.limits());

        Ok(Self {
            instance,
            adapter,
            device,
            queue,
            has_subgroups,
            has_f16,
        })
    }

    /// Build the limits passed to `request_device`: downlevel defaults, with a
    /// handful of fields raised to whatever `adapter_limits` advertises.
    ///
    /// Each raised field is the larger of the downlevel default and the
    /// adapter's value, so the downlevel default stays the floor.
    fn requested_limits(adapter_limits: &wgpu::Limits) -> wgpu::Limits {
        let mut limits = wgpu::Limits::downlevel_defaults();

        // WebGPU mandates max_storage_buffers_per_shader_stage >= 8, while
        // downlevel_defaults caps it at 4 (legacy OpenGL ES targets). The
        // Conformer block-local attention kernel binds 5 storage buffers
        // (Q, K, V, pos_proj, out), so only this field is bumped to 8.
        limits.max_storage_buffers_per_shader_stage = 8;

        // Take as much LDS as the adapter offers (Pro 555 exposes 32 KB vs the
        // WebGPU minimum of 16 KB). Kernels needing >16 KB are gated; the rest
        // simply get more headroom.
        limits.max_compute_workgroup_storage_size = limits
            .max_compute_workgroup_storage_size
            .max(adapter_limits.max_compute_workgroup_storage_size);
        limits.max_compute_invocations_per_workgroup = limits
            .max_compute_invocations_per_workgroup
            .max(adapter_limits.max_compute_invocations_per_workgroup);
        limits.max_compute_workgroup_size_x = limits
            .max_compute_workgroup_size_x
            .max(adapter_limits.max_compute_workgroup_size_x);

        // Same for max_buffer_size / max_storage_buffer_binding_size. The
        // downlevel defaults are 256 MiB / 128 MiB; iPad Pro reportedly
        // advertises ~993 MB and Apple A18 likely exceeds the 256 MiB floor
        // too. Asking for the adapter's own value costs nothing, and the
        // default remains the floor when the adapter reports less.
        limits.max_buffer_size = limits.max_buffer_size.max(adapter_limits.max_buffer_size);
        limits.max_storage_buffer_binding_size = limits
            .max_storage_buffer_binding_size
            .max(adapter_limits.max_storage_buffer_binding_size);

        limits
    }

    /// Emit one diagnostic line with the granted buffer limits, in whole MiB.
    ///
    /// wasm32 writes to the browser console and also includes the storage
    /// buffer count; every other target writes to stderr.
    fn log_granted_limits(granted: &wgpu::Limits) {
        let buffer_mib = granted.max_buffer_size / (1024 * 1024);
        let binding_mib = granted.max_storage_buffer_binding_size / (1024 * 1024);

        #[cfg(target_arch = "wasm32")]
        {
            use wasm_bindgen::JsValue;

            let storage_buffers = granted.max_storage_buffers_per_shader_stage;
            let message = format!(
                "[rullama wgpu limits] max_buffer_size={buffer_mib} MiB \
                 max_storage_buffer_binding_size={binding_mib} MiB \
                 max_storage_buffers_per_shader_stage={storage_buffers}"
            );
            web_sys::console::log_1(&JsValue::from_str(&message));
        }

        #[cfg(not(target_arch = "wasm32"))]
        {
            eprintln!(
                "[rullama wgpu limits] max_buffer_size={buffer_mib} MiB \
                 max_storage_buffer_binding_size={binding_mib} MiB"
            );
        }
    }
}