1use bytemuck::Pod;
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
25pub fn storage_buffer_init(device: &wgpu::Device, label: &str, data: &[u8]) -> wgpu::Buffer {
43 assert!(
44 !data.is_empty(),
45 "storage_buffer_init: data must be non-empty"
46 );
47 use wgpu::util::DeviceExt as _;
48 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
49 label: non_empty_label(label),
50 contents: data,
51 usage: wgpu::BufferUsages::STORAGE
52 | wgpu::BufferUsages::COPY_DST
53 | wgpu::BufferUsages::COPY_SRC,
54 })
55}
56
57pub fn uniform_buffer(device: &wgpu::Device, label: &str, data: &[u8]) -> wgpu::Buffer {
66 assert!(!data.is_empty(), "uniform_buffer: data must be non-empty");
67 use wgpu::util::DeviceExt as _;
68 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
69 label: non_empty_label(label),
70 contents: data,
71 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
72 })
73}
74
75pub fn staging_buffer(device: &wgpu::Device, label: &str, size: u64) -> wgpu::Buffer {
84 assert!(size > 0, "staging_buffer: size must be > 0");
85 device.create_buffer(&wgpu::BufferDescriptor {
86 label: non_empty_label(label),
87 size,
88 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
89 mapped_at_creation: false,
90 })
91}
92
93#[cfg_attr(
112 feature = "tracing",
113 tracing::instrument(level = "debug", skip(device, queue, buf))
114)]
115pub fn read_back<T: Pod>(
116 device: &wgpu::Device,
117 queue: &wgpu::Queue,
118 buf: &wgpu::Buffer,
119 len: usize,
120) -> Vec<T> {
121 let byte_size = (std::mem::size_of::<T>() * len) as u64;
122 assert!(byte_size > 0, "read_back: requested size must be > 0");
123
124 let staging = staging_buffer(device, "oxiui-compute-wgpu readback staging", byte_size);
126
127 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
129 label: Some("oxiui-compute-wgpu readback encoder"),
130 });
131 encoder.copy_buffer_to_buffer(buf, 0, &staging, 0, byte_size);
132 queue.submit(std::iter::once(encoder.finish()));
133
134 let slice = staging.slice(..);
136 let (tx, rx) = std::sync::mpsc::channel();
137 slice.map_async(wgpu::MapMode::Read, move |result| {
138 let _ = tx.send(result);
139 });
140
141 device
145 .poll(wgpu::PollType::wait_indefinitely())
146 .expect("read_back: device poll failed");
147
148 rx.recv()
149 .expect("read_back: channel closed before map callback")
150 .expect("read_back: GPU mapping failed");
151
152 let mapped = slice.get_mapped_range();
154 let result: Vec<T> = bytemuck::cast_slice::<u8, T>(&mapped).to_vec();
155
156 drop(mapped);
158 staging.unmap();
159
160 result
161}
162
163pub fn mapped_storage_init(device: &wgpu::Device, label: &str, data: &[u8]) -> wgpu::Buffer {
176 assert!(
177 !data.is_empty(),
178 "mapped_storage_init: data must be non-empty"
179 );
180 let buffer = device.create_buffer(&wgpu::BufferDescriptor {
181 label: non_empty_label(label),
182 size: data.len() as u64,
183 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
184 mapped_at_creation: true,
185 });
186 buffer
187 .slice(..)
188 .get_mapped_range_mut()
189 .copy_from_slice(data);
190 buffer.unmap();
191 buffer
192}
193
194pub fn read_back_range<T: bytemuck::Pod>(
206 device: &wgpu::Device,
207 queue: &wgpu::Queue,
208 src: &wgpu::Buffer,
209 byte_offset: u64,
210 len: usize,
211) -> Vec<T> {
212 let byte_size = (len * std::mem::size_of::<T>()) as u64;
213 assert!(byte_size > 0, "read_back_range: requested size must be > 0");
214 let staging = staging_buffer(device, "", byte_size);
215
216 let mut encoder =
217 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
218 encoder.copy_buffer_to_buffer(src, byte_offset, &staging, 0, byte_size);
219 queue.submit(std::iter::once(encoder.finish()));
220
221 let slice = staging.slice(..);
222 let (tx, rx) = std::sync::mpsc::channel();
223 slice.map_async(wgpu::MapMode::Read, move |result| {
224 let _ = tx.send(result);
225 });
226 device
227 .poll(wgpu::PollType::wait_indefinitely())
228 .expect("read_back_range: device poll failed");
229 rx.recv()
230 .expect("read_back_range: channel closed before map callback")
231 .expect("read_back_range: GPU mapping failed");
232
233 let mapped = slice.get_mapped_range();
234 let result = bytemuck::cast_slice::<u8, T>(&mapped).to_vec();
235 drop(mapped);
236 staging.unmap();
237 result
238}
239
240pub async fn read_back_async<T: bytemuck::Pod>(
255 device: &wgpu::Device,
256 queue: &wgpu::Queue,
257 src: &wgpu::Buffer,
258 len: usize,
259) -> Result<Vec<T>, crate::ComputeError> {
260 let byte_size = (len * std::mem::size_of::<T>()) as u64;
261 assert!(byte_size > 0, "read_back_async: requested size must be > 0");
262 let staging = staging_buffer(device, "read-back-async", byte_size);
263
264 let mut encoder = device.create_command_encoder(&Default::default());
265 encoder.copy_buffer_to_buffer(src, 0, &staging, 0, byte_size);
266 queue.submit(std::iter::once(encoder.finish()));
267
268 let (tx, rx) = std::sync::mpsc::channel::<Result<(), wgpu::BufferAsyncError>>();
270 staging.slice(..).map_async(wgpu::MapMode::Read, move |r| {
271 let _ = tx.send(r);
272 });
273
274 std::future::ready(()).await;
277 device
278 .poll(wgpu::PollType::wait_indefinitely())
279 .map_err(|e| crate::ComputeError::Operation {
280 op: "read_back_async",
281 detail: e.to_string(),
282 })?;
283
284 rx.recv()
285 .map_err(|_| crate::ComputeError::Operation {
286 op: "read_back_async",
287 detail: "channel closed before map callback fired".into(),
288 })?
289 .map_err(|e| crate::ComputeError::Operation {
290 op: "read_back_async",
291 detail: e.to_string(),
292 })?;
293
294 let mapped = staging.slice(..).get_mapped_range();
295 let data = bytemuck::cast_slice::<u8, T>(&mapped).to_vec();
296 drop(mapped);
297 staging.unmap();
298 Ok(data)
299}
300
301pub struct TypedBuffer<T: bytemuck::Pod> {
316 buffer: wgpu::Buffer,
317 len: usize,
318 _phantom: PhantomData<T>,
319}
320
321impl<T: bytemuck::Pod> TypedBuffer<T> {
322 pub fn new(device: &wgpu::Device, label: &str, usage: wgpu::BufferUsages, len: usize) -> Self {
325 let size = (len * std::mem::size_of::<T>()) as u64;
326 let buffer = device.create_buffer(&wgpu::BufferDescriptor {
327 label: non_empty_label(label),
328 size,
329 usage,
330 mapped_at_creation: false,
331 });
332 TypedBuffer {
333 buffer,
334 len,
335 _phantom: PhantomData,
336 }
337 }
338
339 pub fn from_data(device: &wgpu::Device, label: &str, data: &[T]) -> Self {
341 let bytes = bytemuck::cast_slice(data);
342 let buffer = storage_buffer_init(device, label, bytes);
343 TypedBuffer {
344 buffer,
345 len: data.len(),
346 _phantom: PhantomData,
347 }
348 }
349
350 pub fn len(&self) -> usize {
352 self.len
353 }
354
355 pub fn is_empty(&self) -> bool {
357 self.len == 0
358 }
359
360 pub fn byte_len(&self) -> u64 {
362 (self.len * std::mem::size_of::<T>()) as u64
363 }
364
365 pub fn as_entire_binding(&self) -> wgpu::BindingResource<'_> {
368 self.buffer.as_entire_binding()
369 }
370
371 pub fn inner(&self) -> &wgpu::Buffer {
373 &self.buffer
374 }
375
376 pub fn upload(&self, queue: &wgpu::Queue, data: &[T]) {
381 assert_eq!(data.len(), self.len, "upload length mismatch");
382 queue.write_buffer(&self.buffer, 0, bytemuck::cast_slice(data));
383 }
384
385 pub fn download(&self, device: &wgpu::Device, queue: &wgpu::Queue) -> Vec<T> {
387 read_back(device, queue, &self.buffer, self.len)
388 }
389}
390
391pub struct BufferPool {
405 buckets: HashMap<(u64, wgpu::BufferUsages), Vec<wgpu::Buffer>>,
406}
407
408impl BufferPool {
409 pub fn new() -> Self {
411 BufferPool {
412 buckets: HashMap::new(),
413 }
414 }
415
416 pub fn acquire(
422 &mut self,
423 device: &wgpu::Device,
424 size: u64,
425 usage: wgpu::BufferUsages,
426 ) -> wgpu::Buffer {
427 let rounded = size.next_power_of_two().max(256);
428 let bucket = self.buckets.entry((rounded, usage)).or_default();
429 if let Some(buf) = bucket.pop() {
430 return buf;
431 }
432 device.create_buffer(&wgpu::BufferDescriptor {
433 label: Some("pool-buffer"),
434 size: rounded,
435 usage,
436 mapped_at_creation: false,
437 })
438 }
439
440 pub fn release(&mut self, size: u64, usage: wgpu::BufferUsages, buffer: wgpu::Buffer) {
447 let rounded = size.next_power_of_two().max(256);
448 self.buckets
449 .entry((rounded, usage))
450 .or_default()
451 .push(buffer);
452 }
453
454 pub fn available_count(&self, size: u64, usage: wgpu::BufferUsages) -> usize {
456 let rounded = size.next_power_of_two().max(256);
457 self.buckets.get(&(rounded, usage)).map_or(0, |v| v.len())
458 }
459}
460
461impl Default for BufferPool {
462 fn default() -> Self {
463 Self::new()
464 }
465}
466
/// A byte range inside a buffer managed by `SubAllocator`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SubRegion {
    /// Byte offset of the region from the start of the buffer.
    pub offset: u64,
    /// Length of the region in bytes.
    pub size: u64,
}
477
478pub struct SubAllocator {
489 buffer: wgpu::Buffer,
490 capacity: u64,
491 cursor: u64,
492 alignment: u64,
493}
494
495impl SubAllocator {
496 pub fn new(buffer: wgpu::Buffer, capacity: u64, alignment: u64) -> Self {
499 SubAllocator {
500 buffer,
501 capacity,
502 cursor: 0,
503 alignment: alignment.max(1),
504 }
505 }
506
507 pub fn alloc(&mut self, size: u64) -> Option<SubRegion> {
512 let aligned_cursor = align_up(self.cursor, self.alignment);
513 let end = aligned_cursor.checked_add(size)?;
514 if end > self.capacity {
515 return None;
516 }
517 self.cursor = end;
518 Some(SubRegion {
519 offset: aligned_cursor,
520 size,
521 })
522 }
523
524 pub fn reset(&mut self) {
527 self.cursor = 0;
528 }
529
530 pub fn inner(&self) -> &wgpu::Buffer {
532 &self.buffer
533 }
534
535 pub fn used(&self) -> u64 {
538 self.cursor
539 }
540
541 pub fn remaining(&self) -> u64 {
543 self.capacity.saturating_sub(self.cursor)
544 }
545}
546
/// Maps an empty label to `None` so the wgpu object is left unlabelled.
/// Any other string is passed through as `Some(label)`.
#[inline]
fn non_empty_label(label: &str) -> Option<&str> {
    (!label.is_empty()).then_some(label)
}
558
/// Rounds `value` up to the next multiple of `alignment`.
///
/// An `alignment` of zero means "no alignment" and returns `value` unchanged.
/// If the aligned value does not fit in a `u64`, the result saturates to
/// `u64::MAX`, which is generally *not* aligned. Callers such as
/// `SubAllocator::alloc` then see an offset too large to fit, instead of a
/// wrapped-around small offset.
#[inline]
fn align_up(value: u64, alignment: u64) -> u64 {
    if alignment == 0 {
        return value;
    }
    // `div_ceil(..) * alignment` wraps in release builds near `u64::MAX`
    // (and panics in debug builds); saturate instead.
    value
        .checked_next_multiple_of(alignment)
        .unwrap_or(u64::MAX)
}
569
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ComputeContext;

    // Binds a `ComputeContext` to `$ctx`, or returns from the test early when
    // `ComputeContext::try_new()` yields `None` (presumably no usable GPU), so
    // GPU tests are skipped rather than failed.
    macro_rules! require_gpu {
        ($ctx:ident) => {
            let Some($ctx) = ComputeContext::try_new() else {
                return;
            };
        };
    }

    #[test]
    fn storage_buffer_init_roundtrip() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = bytemuck::cast_slice::<f32, u8>(&data);
        let buf = storage_buffer_init(&ctx.device, "test-storage", bytes);
        let back: Vec<f32> = read_back(&ctx.device, &ctx.queue, &buf, data.len());
        assert_eq!(back, data);
    }

    #[test]
    fn uniform_buffer_created() {
        require_gpu!(ctx);
        let data: [f32; 4] = [0.1, 0.2, 0.3, 0.4];
        let bytes = bytemuck::cast_slice::<f32, u8>(&data);
        let _buf = uniform_buffer(&ctx.device, "test-uniform", bytes);
    }

    #[test]
    fn staging_buffer_created() {
        require_gpu!(ctx);
        let _buf = staging_buffer(&ctx.device, "test-staging", 256);
    }

    #[test]
    fn non_empty_label_behaviour() {
        assert_eq!(non_empty_label("foo"), Some("foo"));
        assert_eq!(non_empty_label(""), None);
    }

    #[test]
    fn typed_buffer_len_math() {
        assert_eq!(std::mem::size_of::<f32>(), 4);
        let len: usize = 8;
        assert_eq!(len * std::mem::size_of::<f32>(), 32);
        assert_eq!((len * std::mem::size_of::<f32>()) as u64, 32u64);
    }

    #[test]
    fn typed_buffer_roundtrip() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![1.5, 2.5, 3.5, 4.5];
        let tb = TypedBuffer::from_data(&ctx.device, "typed-roundtrip", &data);
        assert_eq!(tb.len(), 4);
        assert!(!tb.is_empty());
        assert_eq!(tb.byte_len(), 16);
        assert_eq!(tb.download(&ctx.device, &ctx.queue), data);

        let updated: Vec<f32> = vec![-1.0, -2.0, -3.0, -4.0];
        tb.upload(&ctx.queue, &updated);
        assert_eq!(tb.download(&ctx.device, &ctx.queue), updated);
    }

    #[test]
    fn suballocator_offsets_aligned() {
        let first_aligned = align_up(0, 256);
        assert_eq!(first_aligned, 0);
        let after_first = first_aligned + 100;
        let second_aligned = align_up(after_first, 256);
        assert!(
            second_aligned >= 256,
            "second offset {second_aligned} should be >= 256"
        );
        assert_eq!(second_aligned % 256, 0);
    }

    #[test]
    fn align_up_edge_cases() {
        assert_eq!(align_up(5, 0), 5, "zero alignment is a no-op");
        assert_eq!(align_up(256, 256), 256, "aligned values are unchanged");
        assert_eq!(align_up(257, 256), 512);
        assert_eq!(align_up(7, 3), 9, "non power-of-two alignment");
    }

    #[test]
    fn suballocator_reset_rewinds() {
        require_gpu!(ctx);

        let backing = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("sub-alloc-test"),
            size: 1024,
            usage: wgpu::BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        let mut sa = SubAllocator::new(backing, 1024, 256);

        let r1 = sa.alloc(100).expect("first alloc should succeed");
        assert_eq!(r1.offset, 0);
        sa.reset();
        let r2 = sa.alloc(100).expect("post-reset alloc should succeed");
        assert_eq!(r2.offset, 0, "after reset, offset must restart at 0");
    }

    #[test]
    fn suballocator_exhaustion_returns_none() {
        require_gpu!(ctx);

        let backing = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("sub-alloc-exhaustion-test"),
            size: 1024,
            usage: wgpu::BufferUsages::STORAGE,
            mapped_at_creation: false,
        });
        let mut sa = SubAllocator::new(backing, 1024, 256);

        let r = sa.alloc(1000).expect("1000 bytes fit in 1024");
        assert_eq!(r, SubRegion { offset: 0, size: 1000 });
        assert_eq!(sa.used(), 1000);
        assert_eq!(sa.remaining(), 24);
        assert!(
            sa.alloc(1).is_none(),
            "next aligned offset is 1024, leaving no room"
        );
        assert_eq!(sa.used(), 1000, "a failed alloc must not move the cursor");
    }

    #[test]
    fn buffer_pool_size_rounds_up() {
        assert_eq!(256u64.next_power_of_two(), 256);
        assert_eq!(300u64.next_power_of_two(), 512);
        assert_eq!(1u64.next_power_of_two().max(256), 256);
        assert_eq!(255u64.next_power_of_two().max(256), 256);
    }

    #[test]
    fn bytemuck_pod_roundtrip() {
        let original: [f32; 3] = [1.0, 2.0, 3.0];
        let bytes: &[u8] = bytemuck::cast_slice(&original);
        assert_eq!(bytes.len(), 12);
        let back: &[f32] = bytemuck::cast_slice(bytes);
        assert_eq!(back, &original);
    }

    #[test]
    fn pool_acquire_reuses_buffer() {
        require_gpu!(ctx);
        let mut pool = BufferPool::new();
        let usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
        let size: u64 = 256;

        assert_eq!(pool.available_count(size, usage), 0);

        let buf = pool.acquire(&ctx.device, size, usage);

        pool.release(size, usage, buf);
        assert_eq!(pool.available_count(size, usage), 1);

        let _buf2 = pool.acquire(&ctx.device, size, usage);
        assert_eq!(pool.available_count(size, usage), 0);
    }

    #[test]
    fn pool_buckets_are_keyed_by_usage() {
        require_gpu!(ctx);
        let mut pool = BufferPool::default();
        let storage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
        let uniform = wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST;

        let buf = pool.acquire(&ctx.device, 1, storage);
        assert_eq!(buf.size(), 256, "sizes round up to at least 256 bytes");
        pool.release(1, storage, buf);

        assert_eq!(pool.available_count(1, storage), 1);
        assert_eq!(pool.available_count(1, uniform), 0, "usage is part of the key");
        assert_eq!(
            pool.available_count(200, storage),
            1,
            "200 rounds to the same 256-byte bucket"
        );
    }

    #[test]
    fn mapped_init_roundtrip() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let bytes = bytemuck::cast_slice::<f32, u8>(&data);

        let src = mapped_storage_init(&ctx.device, "mapped-init-test", bytes);

        let staging = staging_buffer(&ctx.device, "mapped-init-staging", bytes.len() as u64);
        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("mapped-init-readback"),
            });
        encoder.copy_buffer_to_buffer(&src, 0, &staging, 0, bytes.len() as u64);
        ctx.queue.submit(std::iter::once(encoder.finish()));

        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            let _ = tx.send(r);
        });
        ctx.device
            .poll(wgpu::PollType::wait_indefinitely())
            .expect("poll failed");
        rx.recv()
            .expect("channel closed")
            .expect("map_async failed");

        let mapped = slice.get_mapped_range();
        let back: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&mapped).to_vec();
        drop(mapped);
        staging.unmap();

        assert_eq!(back, data);
    }

    #[test]
    fn read_back_range_returns_subslice() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let bytes = bytemuck::cast_slice::<f32, u8>(&data);
        let buf = storage_buffer_init(&ctx.device, "range-test", bytes);

        let sub: Vec<f32> = read_back_range(&ctx.device, &ctx.queue, &buf, 4, 2);
        assert_eq!(sub, vec![20.0f32, 30.0]);
    }

    #[test]
    fn async_readback_matches_sync() {
        require_gpu!(ctx);
        let data: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];
        let bytes = bytemuck::cast_slice::<f32, u8>(&data);
        let buf = storage_buffer_init(&ctx.device, "async-readback-test", bytes);

        let sync_result: Vec<f32> = read_back(&ctx.device, &ctx.queue, &buf, data.len());
        let async_result: Vec<f32> =
            pollster::block_on(read_back_async(&ctx.device, &ctx.queue, &buf, data.len()))
                .expect("async readback failed");

        assert_eq!(sync_result, async_result);
        assert_eq!(async_result, data);
    }
}