1use std::{borrow::Cow, collections::BTreeMap, sync::Arc};
2
3use futures::Future;
4use thiserror::Error;
5use wasm_bindgen::prelude::wasm_bindgen;
6use web_rwkv_derive::{Deref, DerefMut};
7use wgpu::{
8 util::{BufferInitDescriptor, DeviceExt},
9 Adapter, BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
10 BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer, BufferDescriptor, BufferUsages,
11 ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, ExperimentalFeatures,
12 Features, Instance, Limits, MemoryHints, PipelineLayoutDescriptor, PowerPreference, Queue,
13 RequestAdapterOptions, ShaderModuleDescriptor, Trace,
14};
15
16use crate::tensor::{
17 cache::{ResourceCache, SharedResourceCache},
18 shape::{IntoBytes, Shape},
19 ResourceKey, TensorResource, View,
20};
21
/// Extension methods for [`wgpu::Instance`].
pub trait InstanceExt {
    /// Requests a GPU adapter matching `power_preference`.
    ///
    /// Resolves to [`ContextError::RequestAdapterFailed`] when no suitable
    /// adapter is available.
    fn adapter(
        &self,
        power_preference: PowerPreference,
    ) -> impl Future<Output = Result<Adapter, ContextError>>;
}
28
29impl InstanceExt for Instance {
30 async fn adapter(&self, power_preference: PowerPreference) -> Result<Adapter, ContextError> {
31 self.request_adapter(&RequestAdapterOptions {
32 power_preference,
33 force_fallback_adapter: false,
34 compatible_surface: None,
35 })
36 .await
37 .or(Err(ContextError::RequestAdapterFailed))
38 }
39}
40
/// Marker type parameterizing [`uid::Id`] so context ids form their own id space.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ContextId;
43
/// A read-back request serviced by the context's background thread (native only):
/// the thread maps `buffer`, copies its contents, and replies through `sender`.
#[cfg(not(target_arch = "wasm32"))]
pub struct ContextEvent {
    // Buffer to read back; must have `MAP_READ` usage (asserted in `read_back_buffer`).
    pub buffer: Arc<Buffer>,
    // Channel on which the copied bytes are returned.
    pub sender: flume::Sender<Box<[u8]>>,
}
49
/// A GPU execution context: an adapter/device/queue triple plus the caches
/// used to reuse pipelines, uniform buffers, scratch buffers, and bind groups.
///
/// Equality is identity-based: two contexts compare equal iff they share `id`.
#[derive(Debug, Clone)]
pub struct Context {
    // Unique identity of this context; basis of `PartialEq`/`Eq`.
    pub id: uid::Id<ContextId>,
    pub adapter: Adapter,
    pub device: Device,
    pub queue: Queue,

    // Compiled compute pipelines keyed by shader name/entry point/macros.
    pipelines: SharedResourceCache<PipelineKey, CachedPipeline>,
    // Uniform buffers holding tensor shape/view metadata, keyed by `View`.
    shapes: ResourceCache<View, Buffer>,
    // Reusable GPU buffers keyed by (size, usage).
    buffers: ResourceCache<BufferKey, Buffer>,
    // Bind groups keyed by pipeline + bound resource keys.
    bindings: SharedResourceCache<BindGroupKey, BindGroup>,

    // Sender feeding the read-back thread spawned in `ContextBuilder::build`
    // (native targets only). Its sender count also gates cleanup in `Drop`.
    #[cfg(not(target_arch = "wasm32"))]
    event: flume::Sender<ContextEvent>,
}
65
#[cfg(not(target_arch = "wasm32"))]
impl Drop for Context {
    fn drop(&mut self) {
        // Only run teardown when this is the last live sender on the event
        // channel, i.e. the last `Context` clone (plus any sender handed out
        // via `Context::event`) — earlier drops must leave the device alone.
        if self.event.sender_count() <= 1 {
            self.clear_buffers();
            // Submit an empty batch to flush pending work, then block until
            // the device is idle so cached buffers are safe to release.
            self.queue.submit(None);
            _ = self.device.poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            });
        }
    }
}
79
80impl PartialEq for Context {
81 fn eq(&self, other: &Self) -> bool {
82 self.id == other.id
83 }
84}
85
/// Builder producing a [`Context`] from an adapter, with configurable
/// device features and limits.
pub struct ContextBuilder {
    pub adapter: Adapter,
    // Features requested from the device in `build`.
    pub features: Features,
    // Limits requested from the device in `build`.
    pub limits: Limits,
}
91
92#[wasm_bindgen]
93#[derive(Debug, Error)]
94pub enum ContextError {
95 #[error("failed to request adaptor")]
96 RequestAdapterFailed,
97 #[error("failed to request device")]
98 RequestDeviceFailed,
99}
100
impl ContextBuilder {
    /// Creates a builder for `adapter` with default limits and an empty
    /// feature set (plus `Features::SUBGROUP` when the `subgroup-ops`
    /// crate feature is enabled).
    pub fn new(adapter: Adapter) -> Self {
        let features = Features::empty();
        #[cfg(feature = "subgroup-ops")]
        let features = features | Features::SUBGROUP;
        Self {
            adapter,
            features,
            limits: Default::default(),
        }
    }

    /// Requests a device/queue from the adapter and assembles a [`Context`].
    ///
    /// On native (non-wasm) targets this also spawns a background thread that
    /// services [`ContextEvent`] read-back requests; the thread loop ends when
    /// every sender (i.e. every `Context` clone) has been dropped.
    ///
    /// # Errors
    /// Returns [`ContextError::RequestDeviceFailed`] if the device request fails.
    pub async fn build(self) -> Result<Context, ContextError> {
        let Self {
            adapter,
            features,
            limits,
        } = self;

        let (device, queue) = adapter
            .request_device(&DeviceDescriptor {
                label: None,
                required_features: features,
                required_limits: limits,
                memory_hints: MemoryHints::Performance,
                trace: Trace::Off,
                experimental_features: ExperimentalFeatures::disabled(),
            })
            .await
            .map_err(|_| ContextError::RequestDeviceFailed)?;

        // Channel feeding the read-back thread; the sender lives inside the
        // context, the receiver moves into the thread below.
        #[cfg(not(target_arch = "wasm32"))]
        let (event, receiver) = flume::unbounded();

        let context = Context {
            id: uid::Id::new(),
            adapter,
            device,
            queue,
            pipelines: Default::default(),
            shapes: Default::default(),
            // Cache capacities/retention policies are defined by the cache
            // types in `tensor::cache`; the constants here are their tuning
            // parameters — see that module for their exact meaning.
            buffers: ResourceCache::new(4),
            bindings: SharedResourceCache::new(64),
            #[cfg(not(target_arch = "wasm32"))]
            event,
        };

        #[cfg(not(target_arch = "wasm32"))]
        {
            let id = context.id;
            let device = context.device.clone();
            // Dedicated thread performing blocking buffer read-backs so async
            // callers never block on `device.poll` themselves.
            std::thread::spawn(move || {
                // `recv` returns `Err` once all senders are gone, ending the loop.
                while let Ok(ContextEvent { buffer, sender }) = receiver.recv() {
                    #[cfg(feature = "trace")]
                    let _span = tracing::trace_span!("device").entered();
                    let data = read_back_buffer(&device, &buffer);
                    // Receiver may already be gone; dropping the reply is fine.
                    let _ = sender.send(data);
                }
                log::info!("context dropped: {id}");
            });
        }

        Ok(context)
    }

    /// Replaces the requested device limits.
    pub fn limits(mut self, limits: Limits) -> Self {
        self.limits = limits;
        self
    }

    /// Mutates the requested device limits in place.
    pub fn update_limits(mut self, f: impl FnOnce(&mut Limits)) -> Self {
        f(&mut self.limits);
        self
    }

    /// Replaces the requested device features.
    pub fn features(mut self, features: Features) -> Self {
        self.features = features;
        self
    }

    /// Mutates the requested device features in place.
    pub fn update_features(mut self, f: impl FnOnce(&mut Features)) -> Self {
        f(&mut self.features);
        self
    }
}
187
/// Shader preprocessor macro definitions (name -> value), kept sorted by name
/// via `BTreeMap` so compiled macro lists are deterministic for cache keys.
#[derive(Debug, Default, Clone, Deref, DerefMut, PartialEq, Eq, Hash)]
pub struct Macros(BTreeMap<String, String>);
191
192impl Macros {
193 pub fn new() -> Self {
194 Default::default()
195 }
196
197 pub fn compile(self) -> Vec<(String, String)> {
198 self.0.into_iter().collect()
199 }
200}
201
/// Cache key identifying a compiled compute pipeline: shader name, entry
/// point, and the macro set it was preprocessed with.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PipelineKey {
    name: String,
    entry_point: String,
    // Sorted (name, value) pairs produced by `Macros::compile`.
    macros: Vec<(String, String)>,
}
208
209impl PipelineKey {
210 pub fn new(name: impl AsRef<str>, entry_point: impl AsRef<str>, macros: Macros) -> Self {
211 let name = name.as_ref().into();
212 let entry_point = entry_point.as_ref().into();
213 let macros = macros.compile();
214 Self {
215 name,
216 entry_point,
217 macros,
218 }
219 }
220}
221
/// A compiled compute pipeline together with the bind group layout it was
/// created against, as stored in the context's pipeline cache.
#[derive(Debug, Clone)]
pub struct CachedPipeline {
    pub pipeline: ComputePipeline,
    pub layout: BindGroupLayout,
}
227
/// Cache key for reusable GPU buffers: byte size plus usage flags.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BufferKey {
    size: usize,
    usage: BufferUsages,
}
233
/// Cache key for bind groups: the pipeline they target plus the
/// (binding index, resource key) pairs recorded while building.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BindGroupKey {
    pipeline: PipelineKey,
    bindings: Vec<(u32, ResourceKey)>,
}
239
/// Incrementally assembles a (possibly cached) bind group for a pipeline.
///
/// `'a` is the lifetime of the bound tensor resources; `'b` of the context
/// and layout borrows.
pub struct BindGroupBuilder<'a, 'b> {
    context: &'b Context,
    layout: &'b BindGroupLayout,
    // Cache key accumulated via `touch`/`bind`.
    key: BindGroupKey,
    // wgpu entries accumulated via `bind`/`bind_meta`.
    entries: Vec<BindGroupEntry<'a>>,
}
246
impl<'a, 'b> BindGroupBuilder<'a, 'b> {
    /// Starts building a bind group for the pipeline identified by `key`,
    /// to be created against `layout` on a cache miss.
    pub fn new(key: &PipelineKey, context: &'b Context, layout: &'b BindGroupLayout) -> Self {
        Self {
            context,
            layout,
            key: BindGroupKey {
                pipeline: key.clone(),
                bindings: vec![],
            },
            entries: vec![],
        }
    }

    /// Records `tensor`'s resource key under `binding` in the cache key
    /// without adding a bind group entry.
    pub fn touch(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
        let key = tensor.resource_key();
        self.key.bindings.push((binding, key));
        self
    }

    /// Adds `tensor`'s main binding at `binding` and records it in the cache key.
    pub fn bind(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
        let resource = tensor.binding();
        self.entries.push(BindGroupEntry { binding, resource });
        self.touch(binding, tensor)
    }

    /// Adds `tensor`'s metadata (shape/view uniform) binding at `binding`.
    ///
    /// NOTE(review): unlike `bind`, this does NOT record the binding in the
    /// cache key — presumably the meta resource is fully determined by the
    /// tensor's main resource key, so including it would be redundant; verify
    /// against `TensorResource::meta_binding` before relying on this.
    pub fn bind_meta(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
        let resource = tensor.meta_binding();
        self.entries.push(BindGroupEntry { binding, resource });
        self
    }

    /// Checks the accumulated key out of the context's bind group cache,
    /// creating the wgpu bind group only on a miss.
    pub fn build(self) -> Arc<BindGroup> {
        let name = self.key.pipeline.name.clone();
        self.context.bindings.checkout(self.key, || {
            self.context.device.create_bind_group(&BindGroupDescriptor {
                label: Some(&name),
                layout: self.layout,
                entries: &self.entries,
            })
        })
    }
}
294
295impl Eq for Context {}
296
impl Context {
    /// Returns the cached compute pipeline for `key`, compiling it on a miss.
    ///
    /// On a miss, `source` is run through the `gpp` preprocessor with the
    /// key's macros, compiled as WGSL, and paired with a bind group layout
    /// built from `entries`.
    ///
    /// # Panics
    /// Panics if preprocessing fails (`process_str(..).unwrap()`); shader
    /// sources are expected to be valid, compiled-in assets.
    pub fn checkout_pipeline(
        &self,
        key: &PipelineKey,
        source: impl AsRef<str>,
        entries: &[BindGroupLayoutEntry],
    ) -> Arc<CachedPipeline> {
        self.pipelines.checkout(key.clone(), || {
            // `gpp::Context`, shadowing this module's `Context`, holds the
            // preprocessor macro definitions.
            use gpp::{process_str, Context};
            let mut context = Context::new();
            context.macros = key.macros.iter().cloned().collect();

            let shader = process_str(source.as_ref(), &mut context).unwrap();
            let module = &self.device.create_shader_module(ShaderModuleDescriptor {
                label: Some(&key.name),
                source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
            });

            let layout = self
                .device
                .create_bind_group_layout(&BindGroupLayoutDescriptor {
                    label: Some(&key.name),
                    entries,
                });
            let pipeline_layout = self
                .device
                .create_pipeline_layout(&PipelineLayoutDescriptor {
                    label: Some(&key.name),
                    bind_group_layouts: &[&layout],
                    push_constant_ranges: &[],
                });

            let pipeline = self
                .device
                .create_compute_pipeline(&ComputePipelineDescriptor {
                    label: Some(&key.name),
                    layout: Some(&pipeline_layout),
                    module,
                    entry_point: Some(&key.entry_point),
                    compilation_options: Default::default(),
                    cache: None,
                });
            CachedPipeline { pipeline, layout }
        })
    }

    /// Returns a cached uniform buffer holding the metadata of a contiguous
    /// view over `shape` (stride == shape, zero offset).
    pub(crate) fn checkout_shape_uniform(&self, shape: Shape) -> Arc<Buffer> {
        let view = View {
            shape,
            stride: shape,
            offset: Shape::new(0, 0, 0, 0),
        };
        let desc = BufferInitDescriptor {
            label: None,
            contents: &view.into_bytes(),
            usage: BufferUsages::UNIFORM,
        };
        self.shapes
            .checkout(view, || self.device.create_buffer_init(&desc))
    }

    /// Returns a cached uniform buffer holding the metadata of `view`.
    pub(crate) fn checkout_view_uniform(&self, view: View) -> Arc<Buffer> {
        let desc = BufferInitDescriptor {
            label: None,
            contents: &view.into_bytes(),
            usage: BufferUsages::UNIFORM,
        };
        self.shapes
            .checkout(view, || self.device.create_buffer_init(&desc))
    }

    /// Creates a buffer initialized with `contents`.
    ///
    /// Initialized buffers are never cached: every call allocates a fresh
    /// buffer. `_key` is computed but unused — apparently a remnant of a
    /// disabled caching path; TODO confirm before removing.
    pub(crate) fn checkout_buffer_init(&self, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
        let size = std::mem::size_of_val(contents);
        let _key = BufferKey { size, usage };
        let desc = BufferInitDescriptor {
            label: None,
            contents,
            usage,
        };
        self.device.create_buffer_init(&desc).into()
    }

    /// Returns a cached (or newly created, unmapped) buffer of `size` bytes
    /// with the given usage. Note: reused buffers retain previous contents.
    pub(crate) fn checkout_buffer(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
        let key = BufferKey { size, usage };
        let desc = BufferDescriptor {
            label: None,
            size: size as u64,
            usage,
            mapped_at_creation: false,
        };
        self.buffers
            .checkout(key, || self.device.create_buffer(&desc))
    }

    /// Runs the periodic maintenance step of every cache (eviction policy is
    /// defined by the cache types in `tensor::cache`).
    #[inline]
    pub fn maintain(&self) {
        self.pipelines.maintain();
        self.shapes.maintain();
        self.buffers.maintain();
        self.bindings.maintain();
    }

    /// Drops all cached shape uniforms and reusable buffers; pipelines and
    /// bind groups are kept.
    #[inline]
    pub fn clear_buffers(&self) {
        self.shapes.clear();
        self.buffers.clear();
    }

    /// Returns a sender for submitting read-back requests to the context's
    /// background thread (native targets only).
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) fn event(&self) -> flume::Sender<ContextEvent> {
        self.event.clone()
    }

    /// Minimum subgroup size reported by the adapter.
    #[cfg(feature = "subgroup-ops")]
    pub fn min_subgroup_size(&self) -> u32 {
        self.adapter.limits().min_subgroup_size
    }

    /// Maximum subgroup size reported by the adapter.
    #[cfg(feature = "subgroup-ops")]
    pub fn max_subgroup_size(&self) -> u32 {
        self.adapter.limits().max_subgroup_size
    }
}
438
/// Synchronously maps `buffer` for reading and copies its contents into a
/// boxed byte slice, blocking the calling thread until the map completes.
///
/// # Panics
/// Panics if `buffer` lacks `MAP_READ` usage, if the map callback channel is
/// closed, or if mapping fails.
#[cfg(not(target_arch = "wasm32"))]
fn read_back_buffer(device: &Device, buffer: &Buffer) -> Box<[u8]> {
    assert!(buffer.usage().contains(BufferUsages::MAP_READ));

    let (sender, receiver) = flume::bounded(1);
    let slice = buffer.slice(..);
    slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

    // Drive the device until the map callback has fired.
    _ = device.poll(wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    });
    receiver
        .recv()
        .expect("failed to receive read back buffer")
        .expect("failed to map buffer");

    // Copy out through a `u32`-backed allocation, evidently so the returned
    // bytes are 4-byte aligned for later reinterpretation.
    //
    // NOTE(review): two concerns to confirm here:
    // 1. `copy_from_slice` requires the byte view (len rounded up to a
    //    multiple of 4) to equal `map.len()` exactly, so this panics unless
    //    the mapped size is a multiple of 4 (wgpu's buffer alignment makes
    //    that hold in practice).
    // 2. The leak/`from_raw` round-trip turns a `Box<[u32]>` (align-4
    //    allocation) into a `Box<[u8]>`, which is later freed with an align-1
    //    layout — that mismatches the allocator's layout contract and is
    //    technically undefined behavior. Consider a sound alternative that
    //    still preserves the alignment guarantee for downstream casts.
    let data = {
        let map = slice.get_mapped_range();
        let len = map.len();
        let size = std::mem::size_of::<u32>();
        let data = vec![0u32; len.div_ceil(size)].into_boxed_slice();
        unsafe {
            let data = Box::leak(data);
            let data: &mut [u8] = bytemuck::cast_slice_mut(data);
            data.copy_from_slice(&map);
            Box::from_raw(data)
        }
    };
    buffer.unmap();
    data
}